/*
 *   o_
 * in|tarsys GmbH (c)
 *
 * all rights reserved
 *
 */
package de.intarsys.cloudsuite.gears.demo.auth;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Map;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.security.auth.Subject;

import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;

import com.fasterxml.jackson.jakarta.rs.json.JacksonJsonProvider;

import de.intarsys.tools.collection.ByteArrayTools;
import de.intarsys.tools.functor.IArgs;
import de.intarsys.tools.jaxrs.JaxrsTools;
import de.intarsys.tools.string.StringTools;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.client.ClientBuilder;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.client.Invocation;
import jakarta.ws.rs.client.WebTarget;
import jakarta.ws.rs.core.Form;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status.Family;

/**
 * Authentication context implementing the OpenID Connect authorization code
 * flow for the demo application.
 * <p>
 * {@link #basicAuthenticate(HttpServletRequest)} redirects the user agent to the
 * configured authorization endpoint. {@link #authCallback(HttpServletRequest)}
 * handles the redirect back: it verifies {@code state}, exchanges the
 * authorization code for tokens and adds an {@link OpenIdConnectPrincipal} to
 * the subject.
 * <p>
 * <b>Demo only:</b> the HTTP client trusts every TLS certificate and every host
 * name, and the ID token signature is not verified. Do not use in production.
 */
public class OpenIdConnectContext extends CommonAuthContext {

	private static final int RANDOM_LENGTH = 32;

	private static final String URI_AUTH_CALLBACK = "/api/v1/authentication/callback";

	private static final String SOCKET_TLS = "TLS";

	private static final String CLAIM_ISS = "iss";

	private static final String CLAIM_NONCE = "nonce";

	private static final int HTTP_OK = 200;

	private static final String PARAM_ID_TOKEN = "id_token";

	private static final String PARAM_ACCESS_TOKEN = "access_token";

	private static final String TOKEN_TYPE_BEARER = "Bearer";

	private static final String PARAM_TOKEN_TYPE = "token_type";

	private static final String GRANT_TYPE_AUTHORIZATION_CODE = "authorization_code";

	private static final String PARAM_GRANT_TYPE = "grant_type";

	private static final String PARAM_SCOPE = "scope";

	private static final String PARAM_RESPONSE_TYPE = "response_type";

	private static final String PARAM_RESPONSE_MODE = "response_mode";

	private static final String PARAM_DISPLAY = "display";

	private static final String PARAM_NONCE = "nonce";

	private static final String PARAM_REDIRECT_URI = "redirect_uri";

	private static final String PARAM_CLIENT_ID = "client_id";

	private static final String PARAM_CODE = "code";

	private static final String PARAM_STATE = "state";

	private static final String PARAM_ERROR = "error";

	private static final String PARAM_ERROR_DESCRIPTION = "error_description";

	private static final String PARAM_PROMPT = "prompt";

	private static final String ERROR_ACCESS_DENIED = "access_denied";

	/*
	 * bind request to callback, mitigate CSRF
	 */
	private String state;

	/*
	 * bind the ID token to the authorization request it answers, mitigate
	 * token replay
	 */
	private String nonce;

	private Client client;

	private String accessToken;

	private String idToken;

	private String redirectUri;

	public OpenIdConnectContext(CommonAuthModule authModule, Subject subject, IArgs args) {
		super(authModule, subject, args);
	}

	/**
	 * Handles the redirect from the authorization server.
	 * <p>
	 * Verifies the {@code state} parameter and maps an {@code error} parameter to
	 * an exception. It then exchanges the authorization code for tokens, performs
	 * a (partial) ID token validation and adds an {@link OpenIdConnectPrincipal}
	 * to the subject.
	 *
	 * @param servletRequest
	 *            the callback request
	 * @return the token exposed by the newly created principal
	 * @throws AuthenticationException
	 *             if the callback is invalid, the user canceled, or token
	 *             exchange/validation failed
	 */
	@Override
	public String authCallback(HttpServletRequest servletRequest) throws AuthenticationException {
		String tmpState = servletRequest.getParameter(PARAM_STATE);
		if (StringTools.isEmpty(tmpState)) {
			throw new AuthenticationFailed("authentication failed");
		}
		// the callback must carry exactly the state we issued (CSRF protection)
		if (!isEqualConstantTime(tmpState, getState())) {
			throw new AuthenticationFailed("authentication failed");
		}
		authCheckError(servletRequest);
		String tmpCode = servletRequest.getParameter(PARAM_CODE);
		if (StringTools.isEmpty(tmpCode)) {
			throw new AuthenticationFailed("authentication failed");
		}
		authRequestToken(tmpCode);
		JwtToken jwtToken = authValidateToken();
		OpenIdConnectPrincipal principal = new OpenIdConnectPrincipal(jwtToken, getIdToken());
		getSubject().getPrincipals().add(principal);
		return principal.getToken();
	}

	/**
	 * Maps an {@code error} parameter in the callback request to an exception.
	 * Returns normally if no error is present.
	 *
	 * @param servletRequest
	 *            the callback request
	 * @throws AuthenticationCanceled
	 *             if the error is {@code access_denied}
	 * @throws AuthenticationFailed
	 *             for any other error
	 */
	protected void authCheckError(HttpServletRequest servletRequest) throws AuthenticationException {
		String tmpError = servletRequest.getParameter(PARAM_ERROR);
		if (StringTools.isEmpty(tmpError)) {
			return;
		}
		String tmpErrorDescription = servletRequest.getParameter(PARAM_ERROR_DESCRIPTION);
		tmpErrorDescription = tmpErrorDescription == null ? "authentication failed" : tmpErrorDescription;
		if (ERROR_ACCESS_DENIED.equals(tmpError)) {
			throw new AuthenticationCanceled(tmpErrorDescription);
		}
		throw new AuthenticationFailed(tmpError + "-" + tmpErrorDescription);
	}

	/**
	 * Builds the authorization endpoint URI the user agent is redirected to.
	 * <p>
	 * A fresh nonce is generated and remembered here. That way the value sent
	 * is exactly the one later checked against the ID token.
	 *
	 * @return the complete authorization request URI
	 * @throws AuthenticationException
	 *             declared for subclasses
	 */
	protected String authRequestCode() throws AuthenticationException {
		OpenIdConnectModule module = getAuthModule();
		nonce = createRandomString();
		WebTarget targetMethod = client.target(module.getEndpointAuthorization());
		targetMethod = targetMethod.queryParam(PARAM_CLIENT_ID, module.getClientId());
		targetMethod = targetMethod.queryParam(PARAM_REDIRECT_URI, getRedirectUri());
		targetMethod = targetMethod.queryParam(PARAM_STATE, getState());
		targetMethod = targetMethod.queryParam(PARAM_NONCE, nonce);
		targetMethod = targetMethod.queryParam(PARAM_SCOPE, module.getScope());
		targetMethod = targetMethod.queryParam(PARAM_RESPONSE_TYPE, module.getResponseType());
		// optional parameters are only sent when configured
		if (!StringTools.isEmpty(module.getResponseMode())) {
			targetMethod = targetMethod.queryParam(PARAM_RESPONSE_MODE, module.getResponseMode());
		}
		if (!StringTools.isEmpty(module.getDisplay())) {
			targetMethod = targetMethod.queryParam(PARAM_DISPLAY, module.getDisplay());
		}
		if (!StringTools.isEmpty(module.getPrompt())) {
			targetMethod = targetMethod.queryParam(PARAM_PROMPT, module.getPrompt());
		}
		return targetMethod.getUri().toString();
	}

	/**
	 * Exchanges an authorization code for tokens at the token endpoint.
	 * <p>
	 * The client authenticates with an HTTP Basic header. On success the access
	 * token and ID token are stored in this context.
	 *
	 * @param code
	 *            the authorization code received in the callback
	 * @throws AuthenticationFailed
	 *             if the endpoint does not answer with a 2xx status, or the token
	 *             type is not {@code Bearer}
	 */
	protected void authRequestToken(String code) throws AuthenticationException {
		WebTarget targetMethod = client.target(getAuthModule().getEndpointToken());
		Form form = new Form();
		form.param(PARAM_CODE, code);
		form.param(PARAM_REDIRECT_URI, getRedirectUri());
		form.param(PARAM_GRANT_TYPE, GRANT_TYPE_AUTHORIZATION_CODE);
		/*
		 * should move to BasicAuth
		 */
		String authHeader = createAuthHeader(getAuthModule().getClientId(), getAuthModule().getClientSecret());
		Invocation.Builder builder = targetMethod.request().header("Authorization", authHeader);
		// always close the response so the underlying connection is released
		try (Response response = builder.post(Entity.entity(form, MediaType.APPLICATION_FORM_URLENCODED_TYPE))) {
			if (response.getStatusInfo().getFamily() != Family.SUCCESSFUL) {
				throw new AuthenticationFailed("authentication failed");
			}
			// JSON values are not guaranteed to be strings; read them defensively
			Map<?, ?> result = response.readEntity(Map.class);
			if (result == null) {
				throw new AuthenticationFailed("authentication failed");
			}
			String tokenType = asString(result.get(PARAM_TOKEN_TYPE));
			if (!TOKEN_TYPE_BEARER.equalsIgnoreCase(tokenType)) {
				throw new AuthenticationFailed("authentication failed");
			}
			setAccessToken(asString(result.get(PARAM_ACCESS_TOKEN)));
			setIdToken(asString(result.get(PARAM_ID_TOKEN)));
		}
	}

	/*
	 * This is not a complete validation, do not use in production!!
	 *
	 * The JWS signature, audience and expiry are NOT verified. This should be
	 * done locally anyway.
	 */
	/**
	 * Parses the ID token and checks it partially.
	 * <p>
	 * If a token info endpoint is configured, it must answer with HTTP 200. The
	 * {@code iss} claim must match the configured issuer, if one is set. The
	 * {@code nonce} claim must match the nonce sent with the authorization
	 * request.
	 *
	 * @return the parsed ID token
	 * @throws AuthenticationFailed
	 *             if no ID token is available or any of the checks fails
	 */
	protected JwtToken authValidateToken() throws AuthenticationException {
		String tmpIdToken = getIdToken();
		// the token endpoint did not deliver an ID token; nothing to validate
		if (StringTools.isEmpty(tmpIdToken)) {
			throw new AuthenticationFailed("authentication failed");
		}
		if (!StringTools.isEmpty(getAuthModule().getEndpointTokeninfo())) {
			WebTarget targetMethod = client.target(getAuthModule().getEndpointTokeninfo())
					.queryParam(PARAM_ID_TOKEN, tmpIdToken);
			try (Response response = targetMethod.request().get()) {
				if (response.getStatus() != HTTP_OK) {
					throw new AuthenticationFailed("authentication failed");
				}
				// the token info payload itself is currently not evaluated
			}
		}
		JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(tmpIdToken);
		JwtToken jwtToken = jwtConsumer.getJwtToken();
		String claimIss = asString(jwtToken.getClaim(CLAIM_ISS));
		if (getAuthModule().getIssuer() != null && !getAuthModule().getIssuer().equals(claimIss)) {
			throw new AuthenticationFailed("authentication failed");
		}
		// the ID token must answer our own authorization request (replay protection)
		String claimNonce = asString(jwtToken.getClaim(CLAIM_NONCE));
		if (!isEqualConstantTime(nonce, claimNonce)) {
			throw new AuthenticationFailed("authentication failed");
		}
		return jwtToken;
	}

	/**
	 * Starts the authorization code flow.
	 * <p>
	 * Creates a fresh {@code state} and HTTP client, resets previously obtained
	 * tokens, and answers with a temporary redirect to the authorization endpoint.
	 *
	 * @param servletRequest
	 *            the request used to derive the callback redirect URI
	 * @return a temporary redirect response to the authorization endpoint
	 * @throws AuthenticationException
	 *             if the TLS context cannot be initialized
	 */
	@Override
	protected Response basicAuthenticate(HttpServletRequest servletRequest) throws AuthenticationException {
		state = createRandomString();

		SSLContext sslContext;
		try {
			sslContext = SSLContext.getInstance(SOCKET_TLS);
			// INSECURE (demo only): accepts any server certificate
			sslContext.init(null, new TrustManager[] { getTrustManagerTrustAll() }, null);
		} catch (GeneralSecurityException e) {
			throw new AuthenticationFailed(e.getMessage(), e);
		}
		// release the resources of a client left over from a previous attempt
		closeClient();
		client = ClientBuilder.newBuilder() //
				.sslContext(sslContext) //
				// INSECURE (demo only): accepts any host name
				.hostnameVerifier((hostname, session) -> true) // NOSONAR
				.build() //
				.register(JacksonJsonProvider.class);
		setAccessToken(null);
		setIdToken(null);
		//
		String tempUri = JaxrsTools.getUriBuilderContext(servletRequest).path(URI_AUTH_CALLBACK).toString();
		setRedirectUri(tempUri);
		String authRequestUri = authRequestCode();
		return Response.temporaryRedirect(URI.create(authRequestUri)).build();
	}

	/**
	 * Removes all {@link OpenIdConnectPrincipal}s from the subject.
	 */
	@Override
	protected void basicDeauthenticate() {
		getSubject().getPrincipals().removeIf((principal) -> principal instanceof OpenIdConnectPrincipal);
	}

	/**
	 * Creates an HTTP Basic {@code Authorization} header value from the client
	 * credentials.
	 * <p>
	 * NOTE(review): RFC 6749 section 2.3.1 asks for id and secret to be
	 * form-urlencoded before Base64 encoding. This is not done here, which only
	 * matters for credentials containing special characters.
	 *
	 * @param clientId
	 *            the OAuth client id
	 * @param clientSecret
	 *            the OAuth client secret
	 * @return {@code "Basic " + base64(clientId:clientSecret)}
	 */
	protected String createAuthHeader(String clientId, String clientSecret) {
		String token = clientId + ":" + clientSecret;
		return "Basic " + Base64.getEncoder().encodeToString(token.getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Creates a random string, used for {@code state} and {@code nonce}.
	 * <p>
	 * The URL-safe Base64 alphabet without padding is used because the value
	 * travels in query parameters. There, {@code +} may be decoded as a space
	 * and break the round-trip comparison.
	 * <p>
	 * NOTE(review): this assumes {@code ByteArrayTools.createRandomBytes} uses a
	 * cryptographically strong RNG; confirm.
	 *
	 * @return a random, URL-safe string
	 */
	protected String createRandomString() {
		byte[] randomBytes = ByteArrayTools.createRandomBytes(RANDOM_LENGTH);
		return Base64.getUrlEncoder().withoutPadding().encodeToString(randomBytes);
	}

	public String getAccessToken() {
		return accessToken;
	}

	@Override
	public OpenIdConnectModule getAuthModule() {
		return (OpenIdConnectModule) super.getAuthModule();
	}

	public String getIdToken() {
		return idToken;
	}

	public String getRedirectUri() {
		return redirectUri;
	}

	protected String getState() {
		return state;
	}

	/**
	 * Creates a trust manager that performs no certificate checks at all.
	 * <p>
	 * INSECURE, demo only: it exposes every connection made through the client
	 * to man-in-the-middle attacks.
	 *
	 * @return a trust manager accepting every certificate chain
	 */
	protected X509TrustManager getTrustManagerTrustAll() {
		return new X509TrustManager() {

			@Override
			public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { // NOSONAR
				// unused
			}

			@Override
			public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { // NOSONAR
				// unused
			}

			@Override
			public X509Certificate[] getAcceptedIssuers() {
				return new X509Certificate[0];
			}
		};
	}

	protected void setAccessToken(String oAuth2Token) {
		this.accessToken = oAuth2Token;
	}

	protected void setIdToken(String oAuth2IdToken) {
		this.idToken = oAuth2IdToken;
	}

	public void setRedirectUri(String redirectUri) {
		this.redirectUri = redirectUri;
	}

	/**
	 * Returns {@code value} if it is a {@link String}, otherwise {@code null}.
	 * This avoids {@link ClassCastException}s on unexpected JSON/claim types.
	 */
	private static String asString(Object value) {
		return value instanceof String ? (String) value : null;
	}

	/**
	 * Compares two strings in time independent of where they differ.
	 * {@code null} on either side never matches.
	 */
	private static boolean isEqualConstantTime(String a, String b) {
		if (a == null || b == null) {
			return false;
		}
		return MessageDigest.isEqual(a.getBytes(StandardCharsets.UTF_8), b.getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Closes and forgets the current HTTP client, if any.
	 */
	private void closeClient() {
		if (client != null) {
			client.close();
			client = null;
		}
	}

}
