package org.springframework.security.web.server.firewall;

import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.SslInfo;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebExchangeDecorator;
import reactor.core.publisher.Mono;

/* loaded from: input_file:WEB-INF/lib/spring-security-web-6.5.1.jar:org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewall.class */
public class StrictServerWebExchangeFirewall implements ServerWebExchangeFirewall {
    private static final String ENCODED_PERCENT = "%25";

    private static final String PERCENT = "%";

    // Strings that cause a rejection when found in the undecoded request path.
    private Set<String> encodedUrlBlocklist = new HashSet<>();

    // Strings that cause a rejection when found in the decoded request path.
    private Set<String> decodedUrlBlocklist = new HashSet<>();

    private Set<HttpMethod> allowedHttpMethods = createDefaultAllowedHttpMethods();

    // Every hostname is trusted until setAllowedHostnames is called.
    private Predicate<String> allowedHostnames = (hostname) -> true;

    private Predicate<String> allowedHeaderNames = ALLOWED_HEADER_NAMES;

    private Predicate<String> allowedHeaderValues = ALLOWED_HEADER_VALUES;

    private Predicate<String> allowedParameterNames = ALLOWED_PARAMETER_NAMES;

    private Predicate<String> allowedParameterValues = ALLOWED_PARAMETER_VALUES;

    // Sentinel compared by identity (==): when allowedHttpMethods is this instance, any method is accepted.
    private static final Set<HttpMethod> ALLOW_ANY_HTTP_METHOD = Collections.emptySet();

    private static final List<String> FORBIDDEN_ENCODED_PERIOD = Collections.unmodifiableList(Arrays.asList("%2e", "%2E"));

    private static final List<String> FORBIDDEN_SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B"));

    private static final List<String> FORBIDDEN_FORWARDSLASH = Collections.unmodifiableList(Arrays.asList("%2f", "%2F"));

    private static final List<String> FORBIDDEN_DOUBLE_FORWARDSLASH = Collections.unmodifiableList(Arrays.asList("//", "%2f%2f", "%2f%2F", "%2F%2f", "%2F%2F"));

    private static final List<String> FORBIDDEN_BACKSLASH = Collections.unmodifiableList(Arrays.asList("\\", "%5c", "%5C"));

    // Decoded form is the NUL character (U+0000); encoded form is %00. The decompiled source had
    // mis-encoded replacement characters here, which left a decoded NUL unblocked.
    private static final List<String> FORBIDDEN_NULL = Collections.unmodifiableList(Arrays.asList("\0", "%00"));

    private static final List<String> FORBIDDEN_LF = Collections.unmodifiableList(Arrays.asList("\n", "%0a", "%0A"));

    // Plain "\r" literal instead of commons-lang3 StringUtils.CR, which was a decompiler artifact.
    private static final List<String> FORBIDDEN_CR = Collections.unmodifiableList(Arrays.asList("\r", "%0d", "%0D"));

    private static final List<String> FORBIDDEN_LINE_SEPARATOR = Collections.unmodifiableList(Arrays.asList("\u2028"));

    private static final List<String> FORBIDDEN_PARAGRAPH_SEPARATOR = Collections.unmodifiableList(Arrays.asList("\u2029"));

    // Whole-string match: only assigned Unicode code points that are not ISO control characters.
    private static final Pattern ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN = Pattern.compile("[\\p{IsAssigned}&&[^\\p{IsControl}]]*");

    private static final Predicate<String> ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE = (value) -> ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN.matcher(value).matches();

    // Like ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN, but a horizontal tab is also permitted.
    private static final Pattern HEADER_VALUE_PATTERN = Pattern.compile("[\\p{IsAssigned}&&[[^\\p{IsControl}]||\\t]]*");

    // A null value (header absent) is accepted.
    private static final Predicate<String> HEADER_VALUE_PREDICATE = (value) -> value == null || HEADER_VALUE_PATTERN.matcher(value).matches();

    /**
     * Default header-name predicate: accepts names made only of assigned, non-control characters.
     */
    public static final Predicate<String> ALLOWED_HEADER_NAMES = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;

    /**
     * Default header-value predicate: accepts {@code null}, or values made only of assigned,
     * non-control characters plus horizontal tab.
     */
    public static final Predicate<String> ALLOWED_HEADER_VALUES = HEADER_VALUE_PREDICATE;

    /**
     * Default parameter-name predicate: accepts names made only of assigned, non-control characters.
     */
    public static final Predicate<String> ALLOWED_PARAMETER_NAMES = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;

    /**
     * Default parameter-value predicate: accepts every value.
     */
    public static final Predicate<String> ALLOWED_PARAMETER_VALUES = (value) -> true;

    // NOTE(review): the decompiler reports this class and StrictFirewallHttpRequest were private in the
    // original bytecode; they are left public here so the visible API of this source is unchanged.
    /**
     * Exchange decorator produced by {@code getFirewalledExchange}. Each call to {@link #getRequest()}
     * wraps the delegate request in a {@link StrictFirewallHttpRequest}, so header and query parameter
     * names and values are validated against the enclosing firewall's predicates when they are read.
     */
    public final class StrictFirewallServerWebExchange extends ServerWebExchangeDecorator {

        private StrictFirewallServerWebExchange(ServerWebExchange delegate) {
            super(delegate);
        }

        /**
         * Returns the delegate's request wrapped in a new {@link StrictFirewallHttpRequest}.
         */
        @Override
        public ServerHttpRequest getRequest() {
            return new StrictFirewallHttpRequest(super.getRequest());
        }

        /**
         * Request decorator that validates headers and query parameters when they are accessed.
         * Requests rebuilt through {@link #mutate()} are validated the same way.
         */
        public final class StrictFirewallHttpRequest extends ServerHttpRequestDecorator {

            private StrictFirewallHttpRequest(ServerHttpRequest delegate) {
                super(delegate);
            }

            /**
             * Returns the request headers wrapped so that {@code getFirst}, {@code get} and
             * {@code keySet} validate header names and values.
             */
            @Override
            public HttpHeaders getHeaders() {
                return new StrictFirewallHttpHeaders(super.getHeaders());
            }

            /**
             * Returns the delegate's query parameters after checking every parameter name and every
             * value against the firewall's parameter predicates.
             *
             * @throws ServerExchangeRejectedException if a name or value is not allowed
             */
            @Override
            public MultiValueMap<String, String> getQueryParams() {
                MultiValueMap<String, String> queryParams = super.getQueryParams();
                // MultiValueMap<K, V> is a Map<K, List<V>>, so each entry carries a list of values.
                for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
                    String name = entry.getKey();
                    StrictServerWebExchangeFirewall.this.validateAllowedParameterName(name);
                    for (String value : entry.getValue()) {
                        StrictServerWebExchangeFirewall.this.validateAllowedParameterValue(name, value);
                    }
                }
                return queryParams;
            }

            /**
             * Returns a builder whose {@code build()} yields another {@link StrictFirewallHttpRequest},
             * so mutated requests remain firewalled.
             */
            @Override
            public ServerHttpRequest.Builder mutate() {
                return new StrictFirewallBuilder(super.mutate());
            }

            /**
             * {@link HttpHeaders} subclass that validates header names and values as they are read
             * through {@code getFirst}, {@code get} and {@code keySet}.
             */
            private final class StrictFirewallHttpHeaders extends HttpHeaders {

                private StrictFirewallHttpHeaders(HttpHeaders delegate) {
                    super(delegate);
                }

                @Override
                public String getFirst(String headerName) {
                    StrictServerWebExchangeFirewall.this.validateAllowedHeaderName(headerName);
                    String value = super.getFirst(headerName);
                    // value is null when the header is absent; the default ALLOWED_HEADER_VALUES accepts null.
                    StrictServerWebExchangeFirewall.this.validateAllowedHeaderValue(headerName, value);
                    return value;
                }

                @Override
                public List<String> get(Object key) {
                    // Map.get accepts any Object; only String keys can be checked as header names.
                    if (key instanceof String) {
                        StrictServerWebExchangeFirewall.this.validateAllowedHeaderName((String) key);
                    }
                    List<String> values = super.get(key);
                    if (values != null) {
                        for (String value : values) {
                            StrictServerWebExchangeFirewall.this.validateAllowedHeaderValue(key, value);
                        }
                    }
                    return values;
                }

                @Override
                public Set<String> keySet() {
                    Set<String> names = super.keySet();
                    for (String name : names) {
                        StrictServerWebExchangeFirewall.this.validateAllowedHeaderName(name);
                    }
                    return names;
                }
            }

            /**
             * Builder decorator. Every step is forwarded to the delegate builder and re-wrapped, and
             * {@link #build()} wraps the result in a {@link StrictFirewallHttpRequest}.
             */
            private final class StrictFirewallBuilder implements ServerHttpRequest.Builder {

                private final ServerHttpRequest.Builder delegate;

                private StrictFirewallBuilder(ServerHttpRequest.Builder delegate) {
                    this.delegate = delegate;
                }

                @Override
                public ServerHttpRequest.Builder method(HttpMethod httpMethod) {
                    return new StrictFirewallBuilder(this.delegate.method(httpMethod));
                }

                @Override
                public ServerHttpRequest.Builder uri(URI uri) {
                    return new StrictFirewallBuilder(this.delegate.uri(uri));
                }

                @Override
                public ServerHttpRequest.Builder path(String path) {
                    return new StrictFirewallBuilder(this.delegate.path(path));
                }

                @Override
                public ServerHttpRequest.Builder contextPath(String contextPath) {
                    return new StrictFirewallBuilder(this.delegate.contextPath(contextPath));
                }

                @Override
                public ServerHttpRequest.Builder header(String headerName, String... headerValues) {
                    return new StrictFirewallBuilder(this.delegate.header(headerName, headerValues));
                }

                @Override
                public ServerHttpRequest.Builder headers(Consumer<HttpHeaders> headersConsumer) {
                    return new StrictFirewallBuilder(this.delegate.headers(headersConsumer));
                }

                @Override
                public ServerHttpRequest.Builder sslInfo(SslInfo sslInfo) {
                    return new StrictFirewallBuilder(this.delegate.sslInfo(sslInfo));
                }

                @Override
                public ServerHttpRequest.Builder remoteAddress(InetSocketAddress remoteAddress) {
                    return new StrictFirewallBuilder(this.delegate.remoteAddress(remoteAddress));
                }

                @Override
                public ServerHttpRequest build() {
                    return new StrictFirewallHttpRequest(this.delegate.build());
                }
            }
        }
    }

    public StrictServerWebExchangeFirewall() {
        urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
        urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
        urlBlocklistsAddAll(FORBIDDEN_DOUBLE_FORWARDSLASH);
        urlBlocklistsAddAll(FORBIDDEN_BACKSLASH);
        urlBlocklistsAddAll(FORBIDDEN_NULL);
        urlBlocklistsAddAll(FORBIDDEN_LF);
        urlBlocklistsAddAll(FORBIDDEN_CR);
        this.encodedUrlBlocklist.add(ENCODED_PERCENT);
        this.encodedUrlBlocklist.addAll(FORBIDDEN_ENCODED_PERIOD);
        this.decodedUrlBlocklist.add("%");
        this.decodedUrlBlocklist.addAll(FORBIDDEN_LINE_SEPARATOR);
        this.decodedUrlBlocklist.addAll(FORBIDDEN_PARAGRAPH_SEPARATOR);
    }

    /**
     * Returns the live set of strings rejected when found in the undecoded request path. The
     * internal set itself is returned, so callers may add or remove entries.
     *
     * @return the mutable encoded-URL blocklist
     */
    public Set<String> getEncodedUrlBlocklist() {
        final Set<String> blocklist = this.encodedUrlBlocklist;
        return blocklist;
    }

    /**
     * Returns the live set of strings rejected when found in the decoded request path. The internal
     * set itself is returned, so callers may add or remove entries.
     *
     * @return the mutable decoded-URL blocklist
     */
    public Set<String> getDecodedUrlBlocklist() {
        final Set<String> blocklist = this.decodedUrlBlocklist;
        return blocklist;
    }

    /**
     * Validates the exchange's request and, when it passes, returns the exchange wrapped in a
     * {@link StrictFirewallServerWebExchange}.
     * <p>
     * The work is deferred inside {@link Mono#fromCallable}. It checks the HTTP method, the URL
     * blocklists, the host and path normalization, in that order. A
     * {@link ServerExchangeRejectedException} thrown by any check is propagated through the
     * returned {@code Mono}. A before-commit hook is also registered that rejects response headers
     * whose name or value contains CR or LF.
     *
     * @param serverWebExchange the exchange to validate
     * @return a {@code Mono} emitting the firewalled exchange
     */
    @Override
    public Mono<ServerWebExchange> getFirewalledExchange(ServerWebExchange serverWebExchange) {
        return Mono.fromCallable(() -> {
            ServerHttpRequest request = serverWebExchange.getRequest();
            rejectForbiddenHttpMethod(request);
            rejectedBlocklistedUrls(request);
            rejectedUntrustedHosts(request);
            if (!isNormalized(request)) {
                throw new ServerExchangeRejectedException("The request was rejected because the URL was not normalized");
            }
            // Response headers are only final just before commit, so inspect them then.
            serverWebExchange.getResponse()
                .beforeCommit(() -> Mono.fromRunnable(() -> validateResponseHeaders(serverWebExchange)));
            return new StrictFirewallServerWebExchange(serverWebExchange);
        });
    }

    /** Asserts that no response header name or value of the exchange contains CR or LF. */
    private static void validateResponseHeaders(ServerWebExchange exchange) {
        HttpHeaders responseHeaders = exchange.getResponse().getHeaders();
        for (Map.Entry<String, List<String>> header : responseHeaders.entrySet()) {
            String headerName = header.getKey();
            for (String headerValue : header.getValue()) {
                validateCrlf(headerName, headerValue);
            }
        }
    }

    /**
     * Fails with {@link IllegalArgumentException} (via {@code Assert.isTrue}) when either the header
     * name or the value contains a carriage return or line feed.
     */
    private static void validateCrlf(String name, String value) {
        boolean clean = !hasCrlf(name) && !hasCrlf(value);
        Assert.isTrue(clean, () -> "Invalid characters (CR/LF) in header " + name);
    }

    /** Returns {@code true} if the non-null value contains {@code '\n'} or {@code '\r'}. */
    private static boolean hasCrlf(String value) {
        if (value == null) {
            return false;
        }
        return value.indexOf('\n') != -1 || value.indexOf('\r') != -1;
    }

    /**
     * Disables (when {@code true}) or restores (when {@code false}) the HTTP method check. Passing
     * {@code false} resets the allowed methods to the defaults.
     *
     * @param unsafeAllowAnyHttpMethod whether any HTTP method should be accepted
     */
    public void setUnsafeAllowAnyHttpMethod(boolean unsafeAllowAnyHttpMethod) {
        if (unsafeAllowAnyHttpMethod) {
            this.allowedHttpMethods = ALLOW_ANY_HTTP_METHOD;
        }
        else {
            this.allowedHttpMethods = createDefaultAllowedHttpMethods();
        }
    }

    /**
     * Replaces the set of accepted HTTP methods with a copy of the given collection. Passing the
     * internal "allow any" sentinel itself keeps that sentinel.
     *
     * @param allowedHttpMethods the methods to accept; must not be {@code null}
     */
    public void setAllowedHttpMethods(Collection<HttpMethod> allowedHttpMethods) {
        Assert.notNull(allowedHttpMethods, "allowedHttpMethods cannot be null");
        if (allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) {
            // Identity must be preserved: rejectForbiddenHttpMethod compares with ==.
            this.allowedHttpMethods = ALLOW_ANY_HTTP_METHOD;
            return;
        }
        this.allowedHttpMethods = new HashSet<>(allowedHttpMethods);
    }

    /**
     * Permits or forbids {@code ;}, {@code %3b} and {@code %3B} in both the undecoded and decoded path.
     *
     * @param allowSemicolon {@code true} to permit, {@code false} to reject
     */
    public void setAllowSemicolon(boolean allowSemicolon) {
        if (!allowSemicolon) {
            urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_SEMICOLON);
    }

    /**
     * Permits or forbids {@code %2f} / {@code %2F} in both the undecoded and decoded path.
     *
     * @param allowUrlEncodedSlash {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedSlash(boolean allowUrlEncodedSlash) {
        if (!allowUrlEncodedSlash) {
            urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_FORWARDSLASH);
    }

    /**
     * Permits or forbids {@code //} and its encoded variants ({@code %2f%2f} in any case mix) in
     * both the undecoded and decoded path.
     *
     * @param allowUrlEncodedDoubleSlash {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedDoubleSlash(boolean allowUrlEncodedDoubleSlash) {
        if (!allowUrlEncodedDoubleSlash) {
            urlBlocklistsAddAll(FORBIDDEN_DOUBLE_FORWARDSLASH);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_DOUBLE_FORWARDSLASH);
    }

    /**
     * Permits or forbids {@code %2e} / {@code %2E}. Only the undecoded-path blocklist is affected.
     *
     * @param allowUrlEncodedPeriod {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedPeriod(boolean allowUrlEncodedPeriod) {
        Set<String> blocklist = this.encodedUrlBlocklist;
        if (!allowUrlEncodedPeriod) {
            blocklist.addAll(FORBIDDEN_ENCODED_PERIOD);
            return;
        }
        blocklist.removeAll(FORBIDDEN_ENCODED_PERIOD);
    }

    /**
     * Permits or forbids {@code \}, {@code %5c} and {@code %5C} in both the undecoded and decoded path.
     *
     * @param allowBackSlash {@code true} to permit, {@code false} to reject
     */
    public void setAllowBackSlash(boolean allowBackSlash) {
        if (!allowBackSlash) {
            urlBlocklistsAddAll(FORBIDDEN_BACKSLASH);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_BACKSLASH);
    }

    /**
     * Permits or forbids the {@code FORBIDDEN_NULL} entries (including {@code %00}) in both the
     * undecoded and decoded path.
     *
     * @param allowNull {@code true} to permit, {@code false} to reject
     */
    public void setAllowNull(boolean allowNull) {
        if (!allowNull) {
            urlBlocklistsAddAll(FORBIDDEN_NULL);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_NULL);
    }

    /**
     * Permits or forbids an encoded percent. This covers {@code %25} in the undecoded path and a
     * literal {@code %} in the decoded path.
     *
     * @param allowUrlEncodedPercent {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
        if (!allowUrlEncodedPercent) {
            this.encodedUrlBlocklist.add(ENCODED_PERCENT);
            this.decodedUrlBlocklist.add(PERCENT);
            return;
        }
        this.encodedUrlBlocklist.remove(ENCODED_PERCENT);
        this.decodedUrlBlocklist.remove(PERCENT);
    }

    /**
     * Permits or forbids the {@code FORBIDDEN_CR} entries (including {@code %0d} / {@code %0D}) in
     * both the undecoded and decoded path.
     *
     * @param allowUrlEncodedCarriageReturn {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedCarriageReturn(boolean allowUrlEncodedCarriageReturn) {
        if (!allowUrlEncodedCarriageReturn) {
            urlBlocklistsAddAll(FORBIDDEN_CR);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_CR);
    }

    /**
     * Permits or forbids {@code \n}, {@code %0a} and {@code %0A} in both the undecoded and decoded path.
     *
     * @param allowUrlEncodedLineFeed {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedLineFeed(boolean allowUrlEncodedLineFeed) {
        if (!allowUrlEncodedLineFeed) {
            urlBlocklistsAddAll(FORBIDDEN_LF);
            return;
        }
        urlBlocklistsRemoveAll(FORBIDDEN_LF);
    }

    /**
     * Permits or forbids U+2029 (paragraph separator). Only the decoded-path blocklist is affected.
     *
     * @param allowUrlEncodedParagraphSeparator {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedParagraphSeparator(boolean allowUrlEncodedParagraphSeparator) {
        Set<String> blocklist = this.decodedUrlBlocklist;
        if (!allowUrlEncodedParagraphSeparator) {
            blocklist.addAll(FORBIDDEN_PARAGRAPH_SEPARATOR);
            return;
        }
        blocklist.removeAll(FORBIDDEN_PARAGRAPH_SEPARATOR);
    }

    /**
     * Permits or forbids U+2028 (line separator). Only the decoded-path blocklist is affected.
     *
     * @param allowUrlEncodedLineSeparator {@code true} to permit, {@code false} to reject
     */
    public void setAllowUrlEncodedLineSeparator(boolean allowUrlEncodedLineSeparator) {
        Set<String> blocklist = this.decodedUrlBlocklist;
        if (!allowUrlEncodedLineSeparator) {
            blocklist.addAll(FORBIDDEN_LINE_SEPARATOR);
            return;
        }
        blocklist.removeAll(FORBIDDEN_LINE_SEPARATOR);
    }

    /**
     * Sets the predicate that request header names must satisfy when read.
     *
     * @param allowedHeaderNames the predicate; must not be {@code null}
     */
    public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
        Assert.notNull(allowedHeaderNames, "allowedHeaderNames cannot be null");
        this.allowedHeaderNames = allowedHeaderNames;
    }

    /**
     * Sets the predicate that request header values must satisfy when read. The predicate may
     * receive {@code null} for an absent header.
     *
     * @param allowedHeaderValues the predicate; must not be {@code null}
     */
    public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
        Assert.notNull(allowedHeaderValues, "allowedHeaderValues cannot be null");
        this.allowedHeaderValues = allowedHeaderValues;
    }

    /**
     * Sets the predicate that query parameter names must satisfy when read.
     *
     * @param allowedParameterNames the predicate; must not be {@code null}
     */
    public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
        Assert.notNull(allowedParameterNames, "allowedParameterNames cannot be null");
        this.allowedParameterNames = allowedParameterNames;
    }

    /**
     * Sets the predicate that query parameter values must satisfy when read.
     *
     * @param allowedParameterValues the predicate; must not be {@code null}
     */
    public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
        Assert.notNull(allowedParameterValues, "allowedParameterValues cannot be null");
        this.allowedParameterValues = allowedParameterValues;
    }

    /**
     * Sets the predicate that the request URI host must satisfy. Requests whose URI has no host
     * are not checked.
     *
     * @param allowedHostnames the predicate; must not be {@code null}
     */
    public void setAllowedHostnames(Predicate<String> allowedHostnames) {
        Assert.notNull(allowedHostnames, "allowedHostnames cannot be null");
        this.allowedHostnames = allowedHostnames;
    }

    /** Adds every given entry to both the encoded and the decoded URL blocklists. */
    private void urlBlocklistsAddAll(Collection<String> entries) {
        Set<String> encoded = this.encodedUrlBlocklist;
        Set<String> decoded = this.decodedUrlBlocklist;
        encoded.addAll(entries);
        decoded.addAll(entries);
    }

    /** Removes every given entry from both the encoded and the decoded URL blocklists. */
    private void urlBlocklistsRemoveAll(Collection<String> entries) {
        Set<String> encoded = this.encodedUrlBlocklist;
        Set<String> decoded = this.decodedUrlBlocklist;
        encoded.removeAll(entries);
        decoded.removeAll(entries);
    }

    /**
     * Throws {@link ServerExchangeRejectedException} when {@code toCheck} contains a character
     * outside printable ASCII ({@code ' '} to {@code '~'}). A {@code null} value passes.
     * NOTE(review): not called anywhere in this file.
     *
     * @param toCheck the value to inspect
     * @param propertyName the name used in the rejection message
     */
    private void rejectNonPrintableAsciiCharactersInFieldName(String toCheck, String propertyName) {
        if (containsOnlyPrintableAsciiCharacters(toCheck)) {
            return;
        }
        throw new ServerExchangeRejectedException(String.format("The %s was rejected because it can only contain printable ASCII characters.", propertyName));
    }

    /**
     * Rejects the request when its HTTP method is not in the allowed set. The check is skipped
     * entirely when the "allow any" sentinel is configured.
     *
     * @throws ServerExchangeRejectedException if the method is not allowed
     */
    private void rejectForbiddenHttpMethod(ServerHttpRequest serverHttpRequest) {
        if (this.allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) {
            return;
        }
        HttpMethod method = serverHttpRequest.getMethod();
        if (this.allowedHttpMethods.contains(method)) {
            return;
        }
        throw new ServerExchangeRejectedException("The request was rejected because the HTTP method \"" + method + "\" was not included within the list of allowed HTTP methods " + this.allowedHttpMethods);
    }

    /**
     * Rejects the request if any encoded-blocklist entry occurs in the undecoded path, or any
     * decoded-blocklist entry occurs in the decoded URI path. The encoded blocklist is checked first.
     *
     * @throws ServerExchangeRejectedException naming the first offending entry found
     */
    private void rejectedBlocklistedUrls(ServerHttpRequest serverHttpRequest) {
        for (String forbidden : this.encodedUrlBlocklist) {
            if (encodedUrlContains(serverHttpRequest, forbidden)) {
                throw potentiallyMaliciousUrl(forbidden);
            }
        }
        for (String forbidden : this.decodedUrlBlocklist) {
            if (decodedUrlContains(serverHttpRequest, forbidden)) {
                throw potentiallyMaliciousUrl(forbidden);
            }
        }
    }

    /** Builds the rejection raised when the URL contains a blocklisted string. */
    private static ServerExchangeRejectedException potentiallyMaliciousUrl(String forbidden) {
        return new ServerExchangeRejectedException("The request was rejected because the URL contained a potentially malicious String \"" + forbidden + "\"");
    }

    /**
     * Rejects the request when its URI host is present but fails the allowed-hostnames predicate.
     *
     * @throws ServerExchangeRejectedException if the host is untrusted
     */
    private void rejectedUntrustedHosts(ServerHttpRequest serverHttpRequest) {
        String host = serverHttpRequest.getURI().getHost();
        if (host == null || this.allowedHostnames.test(host)) {
            return;
        }
        throw new ServerExchangeRejectedException("The request was rejected because the domain " + host + " is untrusted.");
    }

    /**
     * Builds a new mutable set containing the default allowed HTTP methods: DELETE, GET, HEAD,
     * OPTIONS, PATCH, POST and PUT. Any other method (for example TRACE) is rejected unless the
     * allowed methods are reconfigured.
     *
     * @return a fresh, typed {@code HashSet} of the default methods
     */
    private static Set<HttpMethod> createDefaultAllowedHttpMethods() {
        // Typed set (the decompiled version used a raw HashSet); insertion order is unchanged.
        return new HashSet<>(Arrays.asList(
                HttpMethod.DELETE,
                HttpMethod.GET,
                HttpMethod.HEAD,
                HttpMethod.OPTIONS,
                HttpMethod.PATCH,
                HttpMethod.POST,
                HttpMethod.PUT));
    }

    /**
     * Returns {@code true} when the request path value, the raw URI path and the decoded URI path
     * all contain no {@code .} or {@code ..} segments.
     */
    private boolean isNormalized(ServerHttpRequest serverHttpRequest) {
        if (!isNormalized(serverHttpRequest.getPath().value())) {
            return false;
        }
        URI uri = serverHttpRequest.getURI();
        return isNormalized(uri.getRawPath()) && isNormalized(uri.getPath());
    }

    /**
     * Throws {@link ServerExchangeRejectedException} when the header name fails the configured
     * header-name predicate.
     */
    private void validateAllowedHeaderName(String headerName) {
        if (this.allowedHeaderNames.test(headerName)) {
            return;
        }
        throw new ServerExchangeRejectedException("The request was rejected because the header name \"" + headerName + "\" is not allowed.");
    }

    /**
     * Throws {@link ServerExchangeRejectedException} when the header value fails the configured
     * header-value predicate. The header key is only used in the message.
     */
    private void validateAllowedHeaderValue(Object headerName, String headerValue) {
        if (this.allowedHeaderValues.test(headerValue)) {
            return;
        }
        throw new ServerExchangeRejectedException("The request was rejected because the header: \"" + String.valueOf(headerName) + " \" has a value \"" + headerValue + "\" that is not allowed.");
    }

    /**
     * Throws {@link ServerExchangeRejectedException} when the parameter name fails the configured
     * parameter-name predicate.
     */
    private void validateAllowedParameterName(String parameterName) {
        if (this.allowedParameterNames.test(parameterName)) {
            return;
        }
        throw new ServerExchangeRejectedException("The request was rejected because the parameter name \"" + parameterName + "\" is not allowed.");
    }

    /**
     * Throws {@link ServerExchangeRejectedException} when the parameter value fails the configured
     * parameter-value predicate. The parameter name is only used in the message.
     */
    private void validateAllowedParameterValue(String parameterName, String parameterValue) {
        if (this.allowedParameterValues.test(parameterValue)) {
            return;
        }
        throw new ServerExchangeRejectedException("The request was rejected because the parameter: \"" + parameterName + " \" has a value \"" + parameterValue + "\" that is not allowed.");
    }

    /**
     * Returns {@code true} if {@code needle} occurs in the request path value or in the raw
     * (undecoded) URI path.
     */
    private static boolean encodedUrlContains(ServerHttpRequest serverHttpRequest, String needle) {
        return valueContains(serverHttpRequest.getPath().value(), needle)
                || valueContains(serverHttpRequest.getURI().getRawPath(), needle);
    }

    /** Returns {@code true} if {@code needle} occurs in the decoded URI path. */
    private static boolean decodedUrlContains(ServerHttpRequest serverHttpRequest, String needle) {
        String decodedPath = serverHttpRequest.getURI().getPath();
        return valueContains(decodedPath, needle);
    }

    /**
     * Returns {@code true} when every UTF-16 unit lies in the printable ASCII range {@code ' '}
     * (0x20) to {@code '~'} (0x7E). A {@code null} input counts as printable.
     */
    private static boolean containsOnlyPrintableAsciiCharacters(String str) {
        if (str == null) {
            return true;
        }
        return str.chars().allMatch((c) -> c >= ' ' && c <= '~');
    }

    /** Null-safe {@link String#contains}: a {@code null} haystack never contains anything. */
    private static boolean valueContains(String value, String needle) {
        if (value == null) {
            return false;
        }
        return value.contains(needle);
    }

    /**
     * Returns {@code false} if the path contains a {@code "."} or {@code ".."} segment. A
     * {@code null} path counts as normalized.
     * <p>
     * The path is walked right to left, one slash-delimited segment at a time. The segment text is
     * {@code path.substring(slash + 1, end)}, so {@code end - slash} equals the segment length plus
     * one.
     */
    private static boolean isNormalized(String str) {
        if (str == null) {
            return true;
        }
        for (int end = str.length(); end > 0; ) {
            int slash = str.lastIndexOf('/', end - 1);
            int span = end - slash;
            boolean singleDot = span == 2 && str.charAt(slash + 1) == '.';
            if (singleDot) {
                return false;
            }
            boolean doubleDot = span == 3 && str.charAt(slash + 1) == '.' && str.charAt(slash + 2) == '.';
            if (doubleDot) {
                return false;
            }
            end = slash;
        }
        return true;
    }
}
