/*
 * Decompiled with CFR 0.152.
 */
package io.questdb.network;

import io.questdb.ClientTlsConfiguration;
import io.questdb.cutlass.line.LineSenderException;
import io.questdb.cutlass.line.tcp.DelegatingTlsChannel;
import io.questdb.log.Log;
import io.questdb.network.NetworkFacade;
import io.questdb.network.PlainSocket;
import io.questdb.network.Socket;
import io.questdb.network.TlsSessionInitFailedException;
import io.questdb.std.Chars;
import io.questdb.std.Unsafe;
import io.questdb.std.Vect;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

/**
 * Client-side {@link Socket} that runs TLS over a {@link PlainSocket} using the JDK {@link SSLEngine}.
 * <p>
 * Lifecycle: {@link #of(long)} moves the socket from {@code STATE_EMPTY} to {@code STATE_PLAINTEXT},
 * {@link #startTlsSession(CharSequence)} performs the handshake and moves it to {@code STATE_TLS}, and
 * {@link #close()} releases everything and returns it to {@code STATE_EMPTY}, after which it can be reused.
 * <p>
 * Ciphertext is staged in two natively allocated buffers: {@code wrapOutputBuffer} for outgoing records and
 * {@code unwrapInputBuffer} for incoming records. The other two {@link ByteBuffer}s are views whose address,
 * limit and capacity are rewritten through {@link Unsafe} to point at the caller's memory on every
 * {@link #send(long, int)} / {@link #recv(long, int)} call, so plaintext is never staged in an intermediate
 * buffer. No synchronization is performed.
 */
public final class JavaTlsClientSocket
implements Socket {
    // offset of java.nio.Buffer#address, resolved in the static initializer
    private static final long ADDRESS_FIELD_OFFSET;
    // trust managers that accept every certificate; used only in TLS_VALIDATION_MODE_NONE without a trust store
    private static final TrustManager[] BLIND_TRUST_MANAGERS;
    // offset of java.nio.Buffer#capacity (an int field)
    private static final long CAPACITY_FIELD_OFFSET;
    // default size of each native ciphertext buffer; overridable via -Dquestdb.experimental.tls.buffersize
    private static final int INITIAL_BUFFER_CAPACITY_BYTES = 262144;
    // offset of java.nio.Buffer#limit (an int field)
    private static final long LIMIT_FIELD_OFFSET;
    // native-memory tag passed to Unsafe.malloc/realloc/free for this class's buffers
    // NOTE(review): literal inlined by the decompiler, presumably an io.questdb.std.MemoryTag constant - confirm
    private static final int MEMORY_TAG = 58;
    private static final int STATE_CLOSING = 3;
    private static final int STATE_EMPTY = 0;
    private static final int STATE_PLAINTEXT = 1;
    private static final int STATE_TLS = 2;
    // readinessFlags bit on which tlsIO() flushes buffered ciphertext to the socket
    // NOTE(review): presumably mirrors an IOOperation constant - confirm
    private static final int TLS_IO_FLUSH_FLAG = 1;
    // tlsValidationMode under which every server certificate is trusted and hostname verification is skipped
    // NOTE(review): should match the corresponding ClientTlsConfiguration constant - confirm
    private static final int TLS_VALIDATION_MODE_NONE = 1;
    private final Socket delegate;
    private final Log log;
    private final ClientTlsConfiguration tlsConfig;
    // ciphertext read from the socket; invariant outside unwrap calls: position == 0, limit == bytes buffered
    private final ByteBuffer unwrapInputBuffer;
    // view over the caller's plaintext destination during recv()
    private final ByteBuffer unwrapOutputBuffer;
    // view over the caller's plaintext source during send()
    private final ByteBuffer wrapInputBuffer;
    // ciphertext waiting to be written; position == number of pending bytes
    private final ByteBuffer wrapOutputBuffer;
    private SSLEngine sslEngine;
    private int state = STATE_EMPTY;
    private long unwrapInputBufferPtr;
    private long wrapOutputBufferPtr;

    JavaTlsClientSocket(NetworkFacade nf, Log log, ClientTlsConfiguration tlsConfig) {
        this.delegate = new PlainSocket(nf, log);
        this.log = log;
        this.tlsConfig = tlsConfig;
        // zero-capacity direct buffers: they only serve as re-pointable views (see resetBufferToPointer)
        this.wrapInputBuffer = ByteBuffer.allocateDirect(0);
        this.unwrapOutputBuffer = ByteBuffer.allocateDirect(0);
        this.wrapOutputBuffer = ByteBuffer.allocateDirect(0);
        this.unwrapInputBuffer = ByteBuffer.allocateDirect(0);
    }

    /**
     * Closes the socket. In TLS mode this first tries to send a close_notify alert. It then releases the
     * native buffers and closes the underlying socket. Calling it on an empty or closing socket is a no-op.
     */
    @Override
    public void close() {
        log.debug().$("closing TLS socket [fd=").$(delegate.getFd()).$(']').$();
        if (state == STATE_EMPTY || state == STATE_CLOSING) {
            return;
        }
        if (state == STATE_TLS) {
            assert sslEngine != null;
            state = STATE_CLOSING;
            sendCloseNotify();
        }
        state = STATE_CLOSING;
        // Also drops an engine left behind by a failed startTlsSession(), so that isTlsSessionStarted()
        // reports false once the socket is closed (and after it is reused via of()).
        sslEngine = null;
        freeInternalBuffers();
        delegate.close();
        state = STATE_EMPTY;
    }

    @Override
    public long getFd() {
        return delegate.getFd();
    }

    @Override
    public boolean isClosed() {
        return delegate.isClosed();
    }

    @Override
    public boolean isMorePlaintextBuffered() {
        // NOTE(review): ciphertext may remain in unwrapInputBuffer between recv() calls (e.g. after a
        // BUFFER_OVERFLOW return); callers that wait for socket readiness before calling recv() again will
        // not be told about it - confirm this is intended.
        return false;
    }

    @Override
    public boolean isTlsSessionStarted() {
        return sslEngine != null;
    }

    @Override
    public void of(long fd) {
        assert state == STATE_EMPTY;
        delegate.of(fd);
        state = STATE_PLAINTEXT;
    }

    /**
     * Decrypts incoming data into the caller's buffer.
     *
     * @param bufferPtr address of the plaintext destination
     * @param bufferLen capacity of the destination in bytes
     * @return number of plaintext bytes written. If nothing was written, the negative socket result on a
     * read error, or -1 when the engine is closed or unwrapping fails.
     */
    @Override
    public int recv(long bufferPtr, int bufferLen) {
        assert sslEngine != null;
        resetBufferToPointer(unwrapOutputBuffer, bufferPtr, bufferLen);
        unwrapOutputBuffer.position(0);
        try {
            int plainBytesReceived = 0;
            while (true) {
                int n = readFromSocket();
                assert unwrapInputBuffer.position() == 0 : "unwrapInputBuffer is not compacted";
                int bytesAvailable = unwrapInputBuffer.limit();
                if (bytesAvailable == 0) {
                    // nothing left to decrypt; surface a socket error only if no plaintext was produced
                    return (n < 0 && plainBytesReceived == 0) ? n : plainBytesReceived;
                }
                SSLEngineResult result = sslEngine.unwrap(unwrapInputBuffer, unwrapOutputBuffer);
                plainBytesReceived += result.bytesProduced();
                // move unconsumed ciphertext to the front so the position == 0 invariant holds
                compactUnwrapInputBuffer();
                switch (result.getStatus()) {
                    case BUFFER_UNDERFLOW:
                        return plainBytesReceived;
                    case BUFFER_OVERFLOW:
                        if (unwrapOutputBuffer.position() == 0) {
                            throw new AssertionError("Output buffer too small to fit a single TLS record. This should not happen, please report as a bug.");
                        }
                        return plainBytesReceived;
                    case OK:
                        break;
                    case CLOSED:
                        log.debug().$("SSL engine closed").$();
                        return plainBytesReceived == 0 ? -1 : plainBytesReceived;
                }
            }
        } catch (SSLException e) {
            log.error().$("could not unwrap SSL packet").$(e).$();
            return -1;
        }
    }

    /**
     * Encrypts as much of the caller's plaintext as possible and writes it to the socket.
     *
     * @param bufferPtr address of the plaintext source
     * @param bufferLen number of plaintext bytes to send
     * @return number of plaintext bytes consumed. This can be less than {@code bufferLen} when the socket
     * accepts only part of the ciphertext. Returns a negative value on socket or engine failure.
     */
    @Override
    public int send(long bufferPtr, int bufferLen) {
        try {
            resetBufferToPointer(wrapInputBuffer, bufferPtr, bufferLen);
            wrapInputBuffer.position(0);
            int plainBytesConsumed = 0;
            while (true) {
                // flush pending ciphertext first; stop if the socket cannot take all of it right now
                int bytesToSend = wrapOutputBuffer.position();
                if (bytesToSend > 0) {
                    int sent = writeToSocket(bytesToSend);
                    if (sent < 0) {
                        return sent;
                    }
                    if (sent < bytesToSend) {
                        return plainBytesConsumed;
                    }
                }
                if (!wrapInputBuffer.hasRemaining()) {
                    return plainBytesConsumed;
                }
                SSLEngineResult result = sslEngine.wrap(wrapInputBuffer, wrapOutputBuffer);
                plainBytesConsumed += result.bytesConsumed();
                switch (result.getStatus()) {
                    case BUFFER_UNDERFLOW:
                        throw new AssertionError("Underflow while reading a plain text. This should not happen, please report as a bug");
                    case BUFFER_OVERFLOW:
                        // with pending output, flushing it on the next iteration frees space;
                        // an empty buffer that still overflows is simply too small
                        if (wrapOutputBuffer.position() == 0) {
                            growWrapOutputBuffer();
                        }
                        break;
                    case OK:
                        break;
                    case CLOSED:
                        log.error().$("Attempt to send to a closed SSLEngine").$();
                        return -1;
                }
            }
        } catch (SSLException e) {
            log.error().$("could not wrap SSL packet").$(e).$();
            return -1;
        }
    }

    @Override
    public int shutdown(int how) {
        return delegate.shutdown(how);
    }

    /**
     * Allocates the ciphertext buffers, creates the engine and drives the client handshake to completion.
     *
     * @param peerName server name passed to the engine (used for SNI and, unless validation is disabled,
     *                 hostname verification)
     * @throws TlsSessionInitFailedException on socket errors, engine/keystore failures or premature close
     */
    @Override
    public void startTlsSession(CharSequence peerName) throws TlsSessionInitFailedException {
        assert state == STATE_PLAINTEXT;
        prepareInternalBuffers();
        try {
            sslEngine = createSslEngine(peerName);
            sslEngine.beginHandshake();
            SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
            // true when the next unwrap needs fresh ciphertext from the socket instead of what is buffered
            boolean needMoreInput = true;
            while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
                    && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
                switch (handshakeStatus) {
                    case NEED_TASK:
                        handshakeStatus = runDelegatedTasks();
                        break;
                    case NEED_WRAP:
                        handshakeStatus = handshakeWrap();
                        break;
                    case NEED_UNWRAP: {
                        SSLEngineResult result = handshakeUnwrap(needMoreInput);
                        // go back to the socket only if the engine could not make progress on buffered data
                        needMoreInput = result.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW
                                || result.bytesConsumed() == 0;
                        handshakeStatus = result.getHandshakeStatus();
                        break;
                    }
                    default:
                        // the loop cannot make progress on any other status
                        throw TlsSessionInitFailedException.instance("unexpected TLS handshake status [status=")
                                .put(handshakeStatus.name()).put(']');
                }
            }
            // Keep (rather than discard) any ciphertext that arrived after the final handshake record so that
            // recv() can decrypt it; this also guarantees recv()'s position == 0 invariant.
            compactUnwrapInputBuffer();
            unwrapOutputBuffer.clear();
            wrapOutputBuffer.clear();
            state = STATE_TLS;
        } catch (IOException | KeyManagementException | KeyStoreException | NoSuchAlgorithmException | CertificateException e) {
            throw TlsSessionInitFailedException.instance("TLS session creation failed [error=").put(e.getMessage()).put(']');
        }
    }

    @Override
    public boolean supportsTls() {
        return true;
    }

    /**
     * Flushes buffered ciphertext when {@code readinessFlags} has the flush bit set.
     *
     * @return 0 on success (including a partial write), or the negative socket result on error
     */
    @Override
    public int tlsIO(int readinessFlags) {
        int bytesToSend = wrapOutputBuffer.position();
        if ((readinessFlags & TLS_IO_FLUSH_FLAG) != 0 && bytesToSend > 0) {
            return Math.min(writeToSocket(bytesToSend), 0);
        }
        return 0;
    }

    @Override
    public boolean wantsTlsRead() {
        return false;
    }

    @Override
    public boolean wantsTlsWrite() {
        return wrapOutputBuffer.position() > 0;
    }

    private static long allocateMemoryAndResetBuffer(ByteBuffer buffer, int capacity) {
        long newAddress = Unsafe.malloc(capacity, MEMORY_TAG);
        resetBufferToPointer(buffer, newAddress, capacity);
        return newAddress;
    }

    // Doubles the native allocation behind buffer and re-points the buffer at it (position reset to 0).
    private static long expandBuffer(ByteBuffer buffer, long oldAddress) {
        int oldCapacity = buffer.capacity();
        int newCapacity = oldCapacity * 2;
        long newAddress = Unsafe.realloc(oldAddress, oldCapacity, newCapacity, MEMORY_TAG);
        resetBufferToPointer(buffer, newAddress, newCapacity);
        return newAddress;
    }

    // Opens a trust store either from the classpath ("classpath:" prefix) or from the file system.
    private static InputStream openTrustStoreStream(String trustStorePath) throws FileNotFoundException {
        if (trustStorePath.startsWith("classpath:")) {
            String adjustedPath = trustStorePath.substring("classpath:".length());
            InputStream trustStoreStream = DelegatingTlsChannel.class.getResourceAsStream(adjustedPath);
            if (trustStoreStream == null) {
                throw new LineSenderException((CharSequence) "configured trust store is unavailable ").put("[path=").put(trustStorePath).put("]");
            }
            return trustStoreStream;
        }
        return new FileInputStream(trustStorePath);
    }

    // Frees the native memory behind buffer (if any), leaves the buffer empty, and returns the new (null) pointer.
    private static long releaseBuffer(ByteBuffer buffer, long ptr) {
        if (ptr != 0L) {
            int capacity = buffer.capacity();
            assert capacity != 0;
            resetBufferToPointer(buffer, 0L, 0);
            Unsafe.free(ptr, capacity, MEMORY_TAG);
        }
        return 0L;
    }

    // Re-points a direct ByteBuffer at [ptr, ptr + len) and resets its position to 0.
    private static void resetBufferToPointer(ByteBuffer buffer, long ptr, int len) {
        assert buffer.isDirect();
        Unsafe.getUnsafe().putLong(buffer, ADDRESS_FIELD_OFFSET, ptr);
        // Buffer.limit and Buffer.capacity are int fields: writing them with putLong would also overwrite the
        // 4 bytes that follow each field in the object layout.
        Unsafe.getUnsafe().putInt(buffer, LIMIT_FIELD_OFFSET, len);
        Unsafe.getUnsafe().putInt(buffer, CAPACITY_FIELD_OFFSET, len);
        buffer.position(0);
    }

    // Moves unconsumed ciphertext to the start of unwrapInputBuffer (position 0, limit = bytes left).
    private void compactUnwrapInputBuffer() {
        int consumed = unwrapInputBuffer.position();
        if (consumed > 0) {
            int remaining = unwrapInputBuffer.limit() - consumed;
            // source and destination overlap whenever remaining > consumed, hence memmove rather than memcpy
            Vect.memmove(unwrapInputBufferPtr, unwrapInputBufferPtr + (long) consumed, remaining);
            unwrapInputBuffer.position(0);
            unwrapInputBuffer.limit(remaining);
        }
    }

    /**
     * Builds a client-mode engine. A configured trust store takes precedence. Otherwise
     * {@code TLS_VALIDATION_MODE_NONE} trusts every certificate, and any other mode uses the JDK default
     * context. Hostname verification ("https") is enabled unless the mode is {@code TLS_VALIDATION_MODE_NONE}.
     */
    private SSLEngine createSslEngine(CharSequence serverName) throws KeyManagementException, NoSuchAlgorithmException, KeyStoreException, IOException, CertificateException {
        SSLContext sslContext;
        String trustStorePath = tlsConfig.trustStorePath();
        int tlsValidationMode = tlsConfig.tlsValidationMode();
        if (trustStorePath != null) {
            sslContext = SSLContext.getInstance("TLS");
            TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
            KeyStore jks = KeyStore.getInstance("JKS");
            try (InputStream trustStoreStream = openTrustStoreStream(trustStorePath)) {
                jks.load(trustStoreStream, tlsConfig.trustStorePassword());
            }
            tmf.init(jks);
            sslContext.init(null, tmf.getTrustManagers(), new SecureRandom());
        } else if (tlsValidationMode == TLS_VALIDATION_MODE_NONE) {
            sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, BLIND_TRUST_MANAGERS, new SecureRandom());
        } else {
            sslContext = SSLContext.getDefault();
        }
        SSLEngine engine = sslContext.createSSLEngine(Chars.toString(serverName), -1);
        if (tlsValidationMode != TLS_VALIDATION_MODE_NONE) {
            SSLParameters sslParameters = engine.getSSLParameters();
            sslParameters.setEndpointIdentificationAlgorithm("https");
            engine.setSSLParameters(sslParameters);
        }
        engine.setUseClientMode(true);
        return engine;
    }

    // Writes out everything the last handshake wrap produced, then empties wrapOutputBuffer.
    private void flushHandshakeOutput() throws TlsSessionInitFailedException {
        int bytesToWrite = wrapOutputBuffer.position();
        int written = 0;
        while (written < bytesToWrite) {
            int n = delegate.send(wrapOutputBufferPtr + (long) written, bytesToWrite - written);
            if (n < 0) {
                throw TlsSessionInitFailedException.instance("socket write error");
            }
            written += n;
        }
        wrapOutputBuffer.clear();
    }

    // Frees both native ciphertext buffers independently (safe if either was never allocated).
    private void freeInternalBuffers() {
        wrapOutputBufferPtr = releaseBuffer(wrapOutputBuffer, wrapOutputBufferPtr);
        unwrapInputBufferPtr = releaseBuffer(unwrapInputBuffer, unwrapInputBufferPtr);
    }

    private void growWrapOutputBuffer() {
        wrapOutputBufferPtr = expandBuffer(wrapOutputBuffer, wrapOutputBufferPtr);
    }

    // Handles NEED_UNWRAP during the handshake; returns the engine result for the caller's bookkeeping.
    private SSLEngineResult handshakeUnwrap(boolean forceRead) throws SSLException, TlsSessionInitFailedException {
        if (forceRead || !unwrapInputBuffer.hasRemaining()) {
            if (readFromSocket() < 0) {
                throw TlsSessionInitFailedException.instance("socket read error");
            }
        }
        SSLEngineResult result = sslEngine.unwrap(unwrapInputBuffer, unwrapOutputBuffer);
        // reclaim consumed space so readFromSocket() can keep appending and leftovers stay at the front
        compactUnwrapInputBuffer();
        switch (result.getStatus()) {
            case BUFFER_OVERFLOW:
                throw new AssertionError("Buffer overflow during TLS handshake. This should not happen, please report as a bug");
            case CLOSED:
                throw TlsSessionInitFailedException.instance("server closed connection unexpectedly");
            default:
                // OK or BUFFER_UNDERFLOW: the handshake loop decides what to do next
                return result;
        }
    }

    // Handles NEED_WRAP during the handshake: produces handshake records and writes them out.
    private SSLEngineResult.HandshakeStatus handshakeWrap() throws SSLException, TlsSessionInitFailedException {
        SSLEngineResult result = sslEngine.wrap(wrapInputBuffer, wrapOutputBuffer);
        switch (result.getStatus()) {
            case BUFFER_UNDERFLOW:
                throw new AssertionError("Buffer underflow during TLS handshake. This should not happen. please report as a bug");
            case BUFFER_OVERFLOW:
                throw new AssertionError("Buffer overflow during TLS handshake. This should not happen, please report as a bug");
            case OK:
                flushHandshakeOutput();
                break;
            case CLOSED:
                throw TlsSessionInitFailedException.instance("server closed connection unexpectedly");
        }
        return result.getHandshakeStatus();
    }

    private void prepareInternalBuffers() {
        // release buffers left behind by an earlier failed startTlsSession() so that a retry does not leak them
        freeInternalBuffers();
        int initialCapacity = Integer.getInteger("questdb.experimental.tls.buffersize", INITIAL_BUFFER_CAPACITY_BYTES);
        wrapOutputBufferPtr = allocateMemoryAndResetBuffer(wrapOutputBuffer, initialCapacity);
        unwrapInputBufferPtr = allocateMemoryAndResetBuffer(unwrapInputBuffer, initialCapacity);
        // nothing received yet: position 0, limit 0
        unwrapInputBuffer.flip();
        // the handshake must not see caller memory left over from a previous session's send()/recv()
        resetBufferToPointer(wrapInputBuffer, 0L, 0);
        resetBufferToPointer(unwrapOutputBuffer, 0L, 0);
    }

    /**
     * Appends bytes from the socket after the ciphertext already buffered in unwrapInputBuffer.
     *
     * @return bytes read, 0 if the buffer is full, or the negative socket result on error
     */
    private int readFromSocket() {
        int writerPos = unwrapInputBuffer.limit();
        int freeSpace = unwrapInputBuffer.capacity() - writerPos;
        if (freeSpace == 0) {
            return 0;
        }
        assert Unsafe.getUnsafe().getLong(unwrapInputBuffer, ADDRESS_FIELD_OFFSET) == unwrapInputBufferPtr;
        int n = delegate.recv(unwrapInputBufferPtr + (long) writerPos, freeSpace);
        if (n < 0) {
            return n;
        }
        unwrapInputBuffer.limit(writerPos + n);
        return n;
    }

    // Runs all pending delegated engine tasks and reports the resulting handshake status.
    private SSLEngineResult.HandshakeStatus runDelegatedTasks() {
        for (Runnable task = sslEngine.getDelegatedTask(); task != null; task = sslEngine.getDelegatedTask()) {
            task.run();
        }
        return sslEngine.getHandshakeStatus();
    }

    // Best effort: queue a close_notify alert and flush it; failures are only logged at debug level.
    private void sendCloseNotify() {
        sslEngine.closeOutbound();
        try {
            sslEngine.wrap(wrapInputBuffer, wrapOutputBuffer);
            while (wantsTlsWrite()) {
                if (tlsIO(TLS_IO_FLUSH_FLAG) < 0) {
                    log.debug().$("could not send TLS close_notify").$();
                    break;
                }
            }
        } catch (SSLException e) {
            log.debug().$("could not send TLS close_notify").$(e).$();
        }
    }

    /**
     * Writes up to bytesToSend pending ciphertext bytes, then shifts any unsent remainder to the front.
     *
     * @return bytes written, or the negative socket result on error
     */
    private int writeToSocket(int bytesToSend) {
        int n = delegate.send(wrapOutputBufferPtr, bytesToSend);
        if (n < 0) {
            return n;
        }
        int bytesRemaining = bytesToSend - n;
        Vect.memmove(wrapOutputBufferPtr, wrapOutputBufferPtr + (long) n, bytesRemaining);
        wrapOutputBuffer.position(bytesRemaining);
        return n;
    }

    static {
        Field capacityField;
        Field limitField;
        Field addressField;
        BLIND_TRUST_MANAGERS = new TrustManager[]{new X509TrustManager() {

            @Override
            public void checkClientTrusted(X509Certificate[] certs, String t) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] certs, String t) {
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // the X509TrustManager contract requires a non-null (possibly empty) array
                return new X509Certificate[0];
            }
        }};
        try {
            addressField = Buffer.class.getDeclaredField("address");
            limitField = Buffer.class.getDeclaredField("limit");
            capacityField = Buffer.class.getDeclaredField("capacity");
        } catch (NoSuchFieldException e) {
            throw new ExceptionInInitializerError(e);
        }
        ADDRESS_FIELD_OFFSET = Unsafe.getUnsafe().objectFieldOffset(addressField);
        LIMIT_FIELD_OFFSET = Unsafe.getUnsafe().objectFieldOffset(limitField);
        CAPACITY_FIELD_OFFSET = Unsafe.getUnsafe().objectFieldOffset(capacityField);
    }
}

