/*
 * Decompiled with CFR 0.152.
 */
package io.helidon.webserver;

import io.helidon.http.DirectHandler;
import io.helidon.http.RequestException;
import io.helidon.webserver.ProxyProtocolData;
import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
import java.io.UncheckedIOException;
import java.net.Inet6Address;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.function.Supplier;

/**
 * Parses a PROXY protocol header (text version 1 or binary version 2) from the beginning of a
 * socket's input stream and exposes the result as {@link ProxyProtocolData}.
 *
 * <p>Every malformed, truncated or oversized header is reported by throwing the shared
 * {@link #BAD_PROTOCOL_EXCEPTION}; socket I/O failures are wrapped in
 * {@link UncheckedIOException} by {@link #get()}.
 */
class ProxyProtocolHandler
implements Supplier<ProxyProtocolData> {
    private static final System.Logger LOGGER = System.getLogger(ProxyProtocolHandler.class.getName());
    /** Capacity of the buffer used for each space-separated v1 field; fields must be shorter than this. */
    private static final int MAX_V1_FIELD_LENGTH = 40;
    /** Upper bound on the bytes (TLVs or unparsed data) discarded after a v2 address block. */
    private static final int MAX_TLV_BYTES_TO_SKIP = 512;
    /** v2 IPv4 block: 4-byte source and destination addresses followed by 2-byte source and destination ports. */
    private static final int V2_IPV4_BLOCK_LENGTH = 12;
    /** Length of a single IPv6 address in a v2 header. */
    private static final int V2_IPV6_ADDRESS_LENGTH = 16;
    /** v2 IPv6 block: two 16-byte addresses followed by two 2-byte ports. */
    private static final int V2_IPV6_BLOCK_LENGTH = 2 * V2_IPV6_ADDRESS_LENGTH + 4;
    /** Length of a single zero-padded UNIX socket path in a v2 header. */
    private static final int V2_UNIX_ADDRESS_LENGTH = 108;
    /** v2 UNIX block: source path followed by destination path. */
    private static final int V2_UNIX_BLOCK_LENGTH = 2 * V2_UNIX_ADDRESS_LENGTH;
    private static final byte SP = (byte) ' ';
    private static final byte CR = (byte) '\r';
    private static final byte LF = (byte) '\n';
    /** ASCII {@code "PROXY"}, the signature of a v1 header. */
    static final byte[] V1_PREFIX = new byte[]{80, 82, 79, 88, 89};
    /** First five bytes of the v2 signature: CR LF CR LF NUL. */
    static final byte[] V2_PREFIX_1 = new byte[]{13, 10, 13, 10, 0};
    /** Remaining seven bytes of the v2 signature: CR LF "QUIT" LF. */
    static final byte[] V2_PREFIX_2 = new byte[]{13, 10, 81, 85, 73, 84, 10};
    static final RequestException BAD_PROTOCOL_EXCEPTION = RequestException.builder()
            .type(DirectHandler.EventType.OTHER)
            .message("Unable to parse proxy protocol header")
            .build();
    private final Socket socket;
    private final String channelId;

    /**
     * Creates a handler for one accepted connection.
     *
     * @param socket    socket whose input stream starts with the PROXY protocol header
     * @param channelId channel identifier, used only for logging
     */
    ProxyProtocolHandler(Socket socket, String channelId) {
        this.socket = socket;
        this.channelId = channelId;
    }

    /**
     * Reads the PROXY protocol header from the socket, choosing the v1 or v2 parser by signature.
     *
     * @return the parsed header data
     * @throws RequestException     {@link #BAD_PROTOCOL_EXCEPTION} if the header is missing, truncated or malformed
     * @throws UncheckedIOException if reading from the socket fails
     */
    @Override
    public ProxyProtocolData get() {
        // System.Logger formats parameters with java.text.MessageFormat, so the placeholder is {0}, not %s
        LOGGER.log(System.Logger.Level.DEBUG, "Reading proxy protocol data for channel {0}", this.channelId);
        try {
            // One byte of pushback is all readUntil needs; each unread delimiter is consumed again by the
            // parsers before they return, so discarding this wrapper afterwards loses no data
            PushbackInputStream inputStream = new PushbackInputStream(this.socket.getInputStream(), 1);
            // V1_PREFIX and V2_PREFIX_1 have the same length, so one read decides between the versions
            byte[] prefix = new byte[V1_PREFIX.length];
            readFully(inputStream, prefix, prefix.length);
            if (arrayEquals(prefix, V1_PREFIX, V1_PREFIX.length)) {
                return handleV1Protocol(inputStream);
            }
            if (arrayEquals(prefix, V2_PREFIX_1, V2_PREFIX_1.length)) {
                return handleV2Protocol(inputStream);
            }
            throw BAD_PROTOCOL_EXCEPTION;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Parses the remainder of a v1 (text) header after the {@code "PROXY"} signature:
     * {@code " <FAMILY> <src> <dst> <srcPort> <dstPort>\r\n"} or {@code " UNKNOWN\r\n"}.
     *
     * @param inputStream stream positioned right after the signature
     * @return the parsed header data
     * @throws IOException if reading fails
     * @throws RequestException {@link #BAD_PROTOCOL_EXCEPTION} on any syntax error, including an unparsable
     *                          port or a family/protocol token rejected by {@code fromString}
     */
    static ProxyProtocolData handleV1Protocol(PushbackInputStream inputStream) throws IOException {
        try {
            byte[] buffer = new byte[MAX_V1_FIELD_LENGTH];
            match(inputStream, SP);
            String familyProtocol = readField(inputStream, buffer, SP, CR);
            ProxyProtocolData.Family family = ProxyProtocolData.Family.fromString(familyProtocol);
            ProxyProtocolData.Protocol protocol = ProxyProtocolData.Protocol.fromString(familyProtocol);
            byte next = readNext(inputStream);
            if (next == CR && family == ProxyProtocolData.Family.UNKNOWN) {
                // "PROXY UNKNOWN\r\n": the LF must be consumed too, or it leaks into the next protocol
                match(inputStream, LF);
                return new ProxyProtocolDataImpl(ProxyProtocolData.Family.UNKNOWN,
                                                 ProxyProtocolData.Protocol.UNKNOWN,
                                                 "", "", -1, -1);
            }
            match(next, SP);
            String sourceAddress = readField(inputStream, buffer, SP);
            match(inputStream, SP);
            String destAddress = readField(inputStream, buffer, SP);
            match(inputStream, SP);
            int sourcePort = Integer.parseInt(readField(inputStream, buffer, SP));
            match(inputStream, SP);
            int destPort = Integer.parseInt(readField(inputStream, buffer, CR));
            match(inputStream, CR);
            match(inputStream, LF);
            return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, sourcePort, destPort);
        } catch (IllegalArgumentException e) {
            // Also covers NumberFormatException from Integer.parseInt
            throw BAD_PROTOCOL_EXCEPTION;
        }
    }

    /**
     * Parses the remainder of a v2 (binary) header after the first {@link #V2_PREFIX_1} bytes: the rest of
     * the signature, the version/command byte, the family/protocol byte, a 2-byte big-endian length, the
     * address block for the family, and finally any remaining declared bytes, which are discarded.
     *
     * <p>NOTE(review): the command nibble (low 4 bits of the version byte) is not inspected.
     *
     * @param inputStream stream positioned right after {@link #V2_PREFIX_1}
     * @return the parsed header data; for an unrecognized family the addresses are empty and ports are -1
     * @throws IOException if reading fails
     * @throws RequestException {@link #BAD_PROTOCOL_EXCEPTION} if the header is malformed, truncated, declares
     *                          a length shorter than its address block, or carries more than
     *                          {@link #MAX_TLV_BYTES_TO_SKIP} extra bytes
     */
    static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throws IOException {
        match(inputStream, V2_PREFIX_2);
        int versionCommand = readNext(inputStream) & 0xFF;
        if (versionCommand >>> 4 != 2) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
        int familyProtocol = readNext(inputStream) & 0xFF;
        ProxyProtocolData.Family family = switch (familyProtocol >>> 4) {
            case 1 -> ProxyProtocolData.Family.IPv4;
            case 2 -> ProxyProtocolData.Family.IPv6;
            case 3 -> ProxyProtocolData.Family.UNIX;
            default -> ProxyProtocolData.Family.UNKNOWN;
        };
        ProxyProtocolData.Protocol protocol = switch (familyProtocol & 0xF) {
            case 1 -> ProxyProtocolData.Protocol.TCP;
            case 2 -> ProxyProtocolData.Protocol.UDP;
            default -> ProxyProtocolData.Protocol.UNKNOWN;
        };
        // Operands are evaluated left to right: high byte first, then low byte
        int headerLength = (readNext(inputStream) & 0xFF) << 8 | (readNext(inputStream) & 0xFF);
        String sourceAddress = "";
        String destAddress = "";
        int sourcePort = -1;
        int destPort = -1;
        switch (family) {
            case IPv4 -> {
                byte[] block = readAddressBlock(inputStream, headerLength, V2_IPV4_BLOCK_LENGTH);
                sourceAddress = ipv4Address(block, 0);
                destAddress = ipv4Address(block, 4);
                sourcePort = uint16(block, 8);
                destPort = uint16(block, 10);
                headerLength -= block.length;
            }
            case IPv6 -> {
                byte[] block = readAddressBlock(inputStream, headerLength, V2_IPV6_BLOCK_LENGTH);
                sourceAddress = Inet6Address.getByAddress(
                        Arrays.copyOfRange(block, 0, V2_IPV6_ADDRESS_LENGTH)).getHostAddress();
                destAddress = Inet6Address.getByAddress(
                        Arrays.copyOfRange(block, V2_IPV6_ADDRESS_LENGTH, 2 * V2_IPV6_ADDRESS_LENGTH)).getHostAddress();
                sourcePort = uint16(block, 2 * V2_IPV6_ADDRESS_LENGTH);
                destPort = uint16(block, 2 * V2_IPV6_ADDRESS_LENGTH + 2);
                headerLength -= block.length;
            }
            case UNIX -> {
                byte[] block = readAddressBlock(inputStream, headerLength, V2_UNIX_BLOCK_LENGTH);
                sourceAddress = unixPath(block, 0, V2_UNIX_ADDRESS_LENGTH);
                destAddress = unixPath(block, V2_UNIX_ADDRESS_LENGTH, V2_UNIX_ADDRESS_LENGTH);
                headerLength -= block.length;
            }
            default -> {
                // No address block is parsed; all declared bytes are discarded below
            }
        }
        if (headerLength > MAX_TLV_BYTES_TO_SKIP) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
        if (headerLength > 0) {
            // Read instead of skip(): skip() returns 0 at EOF, which would make a skip loop spin forever
            readFully(inputStream, new byte[headerLength], headerLength);
        }
        return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, sourcePort, destPort);
    }

    /** Reads a fixed-size v2 address block after checking that it fits in the declared header length. */
    private static byte[] readAddressBlock(InputStream inputStream, int headerLength, int blockLength)
            throws IOException {
        if (headerLength < blockLength) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
        byte[] block = new byte[blockLength];
        readFully(inputStream, block, blockLength);
        return block;
    }

    /** Formats four bytes starting at {@code offset} as a dotted-decimal IPv4 address. */
    private static String ipv4Address(byte[] buffer, int offset) {
        return (buffer[offset] & 0xFF) + "." + (buffer[offset + 1] & 0xFF) + "."
                + (buffer[offset + 2] & 0xFF) + "." + (buffer[offset + 3] & 0xFF);
    }

    /** Decodes an unsigned big-endian 16-bit value starting at {@code offset}. */
    private static int uint16(byte[] buffer, int offset) {
        return (buffer[offset] & 0xFF) << 8 | buffer[offset + 1] & 0xFF;
    }

    /** Decodes a zero-padded UNIX socket path, dropping only the trailing zero padding. */
    private static String unixPath(byte[] buffer, int offset, int length) {
        int end = offset + length;
        while (end > offset && buffer[end - 1] == 0) {
            end--;
        }
        return new String(buffer, offset, end - offset, StandardCharsets.US_ASCII);
    }

    /**
     * Fills the first {@code length} bytes of {@code buffer}, blocking until they arrive.
     *
     * @throws RequestException {@link #BAD_PROTOCOL_EXCEPTION} if the stream ends first
     */
    private static void readFully(InputStream inputStream, byte[] buffer, int length) throws IOException {
        if (inputStream.readNBytes(buffer, 0, length) < length) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
    }

    /**
     * Reads one byte.
     *
     * @throws RequestException {@link #BAD_PROTOCOL_EXCEPTION} at end of stream
     */
    private static byte readNext(InputStream inputStream) throws IOException {
        int b = inputStream.read();
        if (b < 0) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
        return (byte) b;
    }

    /** Fails with {@link #BAD_PROTOCOL_EXCEPTION} unless {@code a == b}. */
    private static void match(byte a, byte b) {
        if (a != b) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
    }

    /** Reads one byte and fails with {@link #BAD_PROTOCOL_EXCEPTION} unless it equals {@code b}. */
    private static void match(PushbackInputStream inputStream, byte b) throws IOException {
        // read() yields 0..255 or -1, so compare against the unsigned value; EOF never matches
        if (inputStream.read() != (b & 0xFF)) {
            throw BAD_PROTOCOL_EXCEPTION;
        }
    }

    /** Reads {@code bs.length} bytes and fails with {@link #BAD_PROTOCOL_EXCEPTION} unless they equal {@code bs}. */
    private static void match(PushbackInputStream inputStream, byte ... bs) throws IOException {
        for (byte b : bs) {
            match(inputStream, b);
        }
    }

    /** Reads a field with {@link #readUntil} and decodes it as US-ASCII. */
    private static String readField(PushbackInputStream inputStream, byte[] buffer, byte ... delims)
            throws IOException {
        int n = readUntil(inputStream, buffer, delims);
        return new String(buffer, 0, n, StandardCharsets.US_ASCII);
    }

    /**
     * Copies bytes into {@code buffer} until one of {@code delims} is read; that delimiter is pushed back
     * so the caller can match it explicitly.
     *
     * @return number of bytes copied into {@code buffer}
     * @throws RequestException {@link #BAD_PROTOCOL_EXCEPTION} at end of stream or if {@code buffer} fills
     *                          up before a delimiter is seen
     */
    private static int readUntil(PushbackInputStream inputStream, byte[] buffer, byte ... delims) throws IOException {
        int n = 0;
        do {
            byte b = readNext(inputStream);
            if (arrayContains(delims, b)) {
                inputStream.unread(b);
                return n;
            }
            buffer[n++] = b;
        } while (n < buffer.length);
        throw BAD_PROTOCOL_EXCEPTION;
    }

    /** Compares the first {@code prefix} bytes of both arrays. */
    private static boolean arrayEquals(byte[] array1, byte[] array2, int prefix) {
        return Arrays.equals(array1, 0, prefix, array2, 0, prefix);
    }

    /** Returns whether {@code b} occurs in {@code array}. */
    private static boolean arrayContains(byte[] array, byte b) {
        for (byte a : array) {
            if (a == b) {
                return true;
            }
        }
        return false;
    }

    /** Immutable {@link ProxyProtocolData} produced by the parsers above. */
    record ProxyProtocolDataImpl(ProxyProtocolData.Family family, ProxyProtocolData.Protocol protocol, String sourceAddress, String destAddress, int sourcePort, int destPort) implements ProxyProtocolData
    {
    }
}

