/*
 * Decompiled with CFR 0.152.
 */
package org.apache.gravitino.server.authentication;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.Principal;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.apache.gravitino.Config;
import org.apache.gravitino.UserPrincipal;
import org.apache.gravitino.exceptions.UnauthorizedException;
import org.apache.gravitino.server.authentication.OAuthConfig;
import org.apache.gravitino.server.authentication.OAuthTokenValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JwksTokenValidator
implements OAuthTokenValidator {
    private static final Logger LOG = LoggerFactory.getLogger(JwksTokenValidator.class);
    private String jwksUri;
    private String expectedIssuer;
    private List<String> principalFields;
    private long allowSkewSeconds;

    @Override
    public void initialize(Config config) {
        this.jwksUri = (String)config.get(OAuthConfig.JWKS_URI);
        this.expectedIssuer = (String)config.get(OAuthConfig.AUTHORITY);
        this.principalFields = (List)config.get(OAuthConfig.PRINCIPAL_FIELDS);
        this.allowSkewSeconds = (Long)config.get(OAuthConfig.ALLOW_SKEW_SECONDS);
        LOG.info("Initializing JWKS token validator");
        if (StringUtils.isBlank((CharSequence)this.jwksUri)) {
            throw new IllegalArgumentException("JWKS URI must be configured when using JWKS-based OAuth providers");
        }
        try {
            new URL(this.jwksUri);
        }
        catch (Exception e) {
            LOG.error("Invalid JWKS URI format: {}", (Object)this.jwksUri);
            throw new IllegalArgumentException("Invalid JWKS URI format: " + this.jwksUri, e);
        }
    }

    @Override
    public Principal validateToken(String token, String serviceAudience) {
        if (token == null || token.trim().isEmpty()) {
            LOG.error("Token is null or empty");
            throw new UnauthorizedException("Token cannot be null or empty", new Object[0]);
        }
        if (serviceAudience == null || serviceAudience.trim().isEmpty()) {
            LOG.error("Service audience is null or empty");
            throw new UnauthorizedException("Service audience cannot be null or empty", new Object[0]);
        }
        try {
            SignedJWT signedJWT = SignedJWT.parse((String)token);
            JWKSource<SecurityContext> jwkSource = this.createJwkSource();
            JWSAlgorithm algorithm = JWSAlgorithm.parse((String)signedJWT.getHeader().getAlgorithm().getName());
            JWSVerificationKeySelector keySelector = new JWSVerificationKeySelector(algorithm, jwkSource);
            DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor();
            jwtProcessor.setJWSKeySelector((JWSKeySelector)keySelector);
            JWTClaimsSet.Builder expectedClaimsBuilder = new JWTClaimsSet.Builder();
            if (StringUtils.isNotBlank((CharSequence)this.expectedIssuer)) {
                expectedClaimsBuilder.issuer(this.expectedIssuer);
            }
            if (StringUtils.isNotBlank((CharSequence)serviceAudience)) {
                expectedClaimsBuilder.audience(serviceAudience);
            }
            DefaultJWTClaimsVerifier claimsVerifier = new DefaultJWTClaimsVerifier(expectedClaimsBuilder.build(), null);
            claimsVerifier.setMaxClockSkew((int)this.allowSkewSeconds);
            jwtProcessor.setJWTClaimsSetVerifier((JWTClaimsSetVerifier)claimsVerifier);
            JWTClaimsSet validatedClaims = jwtProcessor.process(signedJWT, null);
            String principal = this.extractPrincipal(validatedClaims);
            if (principal == null) {
                LOG.error("No valid principal found in token");
                throw new UnauthorizedException("No valid principal found in token", new Object[0]);
            }
            return new UserPrincipal(principal);
        }
        catch (Exception e) {
            LOG.error("JWKS JWT validation error: {}", (Object)e.getMessage());
            throw new UnauthorizedException((Throwable)e, "JWKS JWT validation error", new Object[0]);
        }
    }

    private JWKSource<SecurityContext> createJwkSource() throws Exception {
        try {
            return JWKSourceBuilder.create((URL)new URL(this.jwksUri)).build();
        }
        catch (Exception e) {
            LOG.error("Failed to create JWKS source from URI: {}", (Object)this.jwksUri, (Object)e);
            throw new Exception("Failed to create JWKS source: " + e.getMessage(), e);
        }
    }

    private String extractPrincipal(JWTClaimsSet validatedClaims) {
        if (this.principalFields != null && !this.principalFields.isEmpty()) {
            for (String field : this.principalFields) {
                String principal;
                if (!StringUtils.isNotBlank((CharSequence)field) || (principal = (String)validatedClaims.getClaim(field)) == null) continue;
                return principal;
            }
        }
        return null;
    }
}

