/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.network.client;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.TransportContext;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.server.TransportChannelHandler;
import org.apache.celeborn.common.network.util.IOMode;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.network.util.TransportFrameDecoder;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Creates {@link TransportClient}s and caches them in a small per-peer pool.
 *
 * <p>Every remote {@code host:port} (keyed by its <em>unresolved</em> address) owns a
 * {@link ClientPool} with {@code numConnectionsPerPeer} slots. A request picks a slot from its
 * partition id (or at random when the id is negative). The slot's client is created lazily and
 * re-created once it is no longer active. Creating a slot's client happens under that slot's own
 * lock, so different slots of the same peer can connect concurrently.
 *
 * <p>{@link #close()} closes every pooled client and shuts down the client event loop group.
 */
public class TransportClientFactory implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);

    /** DNS lookups slower than this are reported at WARN level instead of TRACE. */
    private static final long SLOW_DNS_RESOLUTION_MS = 2000L;

    private final TransportContext context;
    private final List<TransportClientBootstrap> clientBootstraps;
    /** Per-peer connection pools, keyed by the unresolved remote address. */
    private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
    private final Random rand;
    private final int numConnectionsPerPeer;
    /** TCP connect timeout; also installed as Netty's CONNECT_TIMEOUT_MILLIS option. */
    private final int connectTimeoutMs;
    /** In this class, bounds how long to wait for the TLS handshake future to complete. */
    private final int connectionTimeoutMs;
    private final int sslHandshakeTimeoutMs;
    private final int receiveBuf;
    private final int sendBuf;
    private final Class<? extends Channel> socketChannelClass;
    private final EventLoopGroup workerGroup;
    protected ByteBufAllocator allocator;
    private final int maxClientConnectRetries;
    private final int maxClientConnectRetryWaitTimeMs;

    /**
     * Builds a factory from the settings in {@code context.getConf()}.
     *
     * @param context transport context providing configuration and pipeline setup; must not be null
     * @param clientBootstraps bootstraps run, in order, on every newly created client; the list is
     *     copied and must not be null
     */
    public TransportClientFactory(TransportContext context, List<TransportClientBootstrap> clientBootstraps) {
        this.context = Preconditions.checkNotNull(context);
        TransportConf conf = context.getConf();
        this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
        this.connectionPool = JavaUtils.newConcurrentHashMap();
        this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
        this.connectTimeoutMs = conf.connectTimeoutMs();
        this.connectionTimeoutMs = conf.connectionTimeoutMs();
        this.sslHandshakeTimeoutMs = conf.sslHandshakeTimeoutMs();
        this.receiveBuf = conf.receiveBuf();
        this.sendBuf = conf.sendBuf();
        this.rand = new Random();
        IOMode ioMode = IOMode.valueOf(conf.ioMode());
        this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
        logger.info("Module {} mode {} threads {}", conf.getModuleName(), ioMode, conf.clientThreads());
        this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), conf.conflictAvoidChooserEnable(), conf.getModuleName() + "-client");
        this.allocator = NettyUtils.getByteBufAllocator(conf, context.getSource(), false, conf.clientThreads());
        this.maxClientConnectRetries = conf.maxIORetries();
        this.maxClientConnectRetryWaitTimeMs = conf.ioRetryWaitTimeMs();
    }

    /**
     * Returns a pooled client for the given peer, using a fresh {@link TransportFrameDecoder} when
     * a new connection has to be made. Connection failures are retried (see
     * {@link #retryCreateClient}).
     */
    public TransportClient createClient(String remoteHost, int remotePort, int partitionId) throws IOException, InterruptedException {
        return this.retryCreateClient(remoteHost, remotePort, partitionId, TransportFrameDecoder::new);
    }

    /**
     * Calls {@link #createClient(String, int, int, ChannelInboundHandlerAdapter)} up to
     * {@code maxIORetries} times (at least once), sleeping {@code ioRetryWaitTimeMs} between
     * attempts. A new decoder is taken from {@code supplier} for every attempt.
     *
     * @return the connected client; never {@code null}
     * @throws IOException the failure of the last attempt
     * @throws InterruptedException if interrupted while connecting or waiting to retry; the
     *     thread's interrupt flag is set again when the connect attempt itself was interrupted
     */
    public TransportClient retryCreateClient(String remoteHost, int remotePort, int partitionId, Supplier<ChannelInboundHandlerAdapter> supplier) throws IOException, InterruptedException {
        // Always make at least one attempt. Previously a non-positive retry setting skipped the
        // loop entirely and silently returned null.
        int maxTries = Math.max(1, this.maxClientConnectRetries);
        int numTries = 0;
        while (true) {
            try {
                return this.createClient(remoteHost, remotePort, partitionId, supplier.get());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw e;
            } catch (IOException | RuntimeException e) {
                numTries++;
                logger.warn("Retry create client, times {}/{} with error: {}", numTries, maxTries, e.getMessage(), e);
                if (numTries >= maxTries) {
                    throw e;
                }
                Thread.sleep(this.maxClientConnectRetryWaitTimeMs);
            }
        }
    }

    /**
     * Returns the client cached in the slot chosen for {@code partitionId}, creating a new
     * connection if the slot is empty or its client is no longer active.
     *
     * @param partitionId selects the pool slot ({@code partitionId % numConnectionsPerPeer}); a
     *     negative value picks a random slot
     * @param decoder inbound decoder installed only if a new connection is created
     */
    public TransportClient createClient(String remoteHost, int remotePort, int partitionId, ChannelInboundHandlerAdapter decoder) throws IOException, InterruptedException {
        InetSocketAddress unresolvedAddress = InetSocketAddress.createUnresolved(remoteHost, remotePort);
        ClientPool clientPool = this.connectionPool.computeIfAbsent(unresolvedAddress, key -> new ClientPool(this.numConnectionsPerPeer));
        int clientIndex = partitionId < 0 ? this.rand.nextInt(this.numConnectionsPerPeer) : partitionId % this.numConnectionsPerPeer;

        // Fast path: reuse an active cached client without taking the slot lock.
        TransportClient cachedClient = clientPool.clients[clientIndex];
        if (cachedClient != null && cachedClient.isActive()) {
            TransportChannelHandler handler = cachedClient.getChannel().pipeline().get(TransportChannelHandler.class);
            if (handler != null) {
                synchronized (handler) {
                    handler.getResponseHandler().updateTimeOfLastRequest();
                }
            }
            // The channel may have gone inactive meanwhile; only hand it out if it is still alive.
            if (cachedClient.isActive()) {
                logger.debug("Returning cached connection from {} to {}: {}", cachedClient.getChannel().localAddress(), cachedClient.getSocketAddress(), cachedClient);
                return cachedClient;
            }
        }

        // Resolve outside the slot lock so a slow DNS lookup does not block other callers.
        InetSocketAddress resolvedAddress = resolveAddress(remoteHost, remotePort);
        synchronized (clientPool.locks[clientIndex]) {
            // Re-check under the lock: another thread may have created the client already.
            cachedClient = clientPool.clients[clientIndex];
            if (cachedClient != null) {
                if (cachedClient.isActive()) {
                    logger.debug("Returning cached connection from {} to {}: {}", cachedClient.getChannel().localAddress(), resolvedAddress, cachedClient);
                    return cachedClient;
                }
                logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
            }
            TransportClient newClient = this.internalCreateClient(resolvedAddress, decoder);
            clientPool.clients[clientIndex] = newClient;
            return newClient;
        }
    }

    /** Same as {@link #createClient(String, int, int)} with a randomly chosen pool slot. */
    public TransportClient createClient(String remoteHost, int remotePort) throws IOException, InterruptedException {
        return this.createClient(remoteHost, remotePort, -1);
    }

    /**
     * Resolves {@code host:port} and logs how long the lookup took and whether it succeeded. An
     * unresolved result is returned as-is and handed on to the connect attempt.
     */
    private static InetSocketAddress resolveAddress(String host, int port) {
        long preResolveHost = System.nanoTime();
        InetSocketAddress resolvedAddress = new InetSocketAddress(host, port);
        long hostResolveTimeMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - preResolveHost);
        String resolveMsg = resolvedAddress.isUnresolved() ? "failed" : "succeed";
        if (hostResolveTimeMs > SLOW_DNS_RESOLUTION_MS) {
            logger.warn("DNS resolution {} for {} took {} ms", resolveMsg, resolvedAddress, hostResolveTimeMs);
        } else {
            logger.trace("DNS resolution {} for {} took {} ms", resolveMsg, resolvedAddress, hostResolveTimeMs);
        }
        return resolvedAddress;
    }

    /**
     * Opens a new connection to {@code address}, completes the TLS handshake when SSL encryption
     * is enabled, and runs all client bootstraps on the result.
     *
     * @throws IOException if the connect or TLS handshake fails or times out
     * @throws RuntimeException if a bootstrap fails (the client is closed first)
     */
    private TransportClient internalCreateClient(final InetSocketAddress address, final ChannelInboundHandlerAdapter decoder) throws IOException, InterruptedException {
        Bootstrap bootstrap = new Bootstrap()
            .group(this.workerGroup)
            .channel(this.socketChannelClass)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, this.connectTimeoutMs)
            .option(ChannelOption.ALLOCATOR, this.allocator);
        // Non-positive buffer sizes mean "leave the OS default".
        if (this.receiveBuf > 0) {
            bootstrap.option(ChannelOption.SO_RCVBUF, this.receiveBuf);
        }
        if (this.sendBuf > 0) {
            bootstrap.option(ChannelOption.SO_SNDBUF, this.sendBuf);
        }

        // Filled in by initChannel, which runs when the channel is set up for the connect.
        final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
        bootstrap.handler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) {
                TransportChannelHandler clientHandler = TransportClientFactory.this.context.initializePipeline(ch, decoder, true);
                clientRef.set(clientHandler.getClient());
            }
        });

        long preConnect = System.nanoTime();
        ChannelFuture cf = bootstrap.connect(address);
        awaitConnect(cf, address);
        if (this.context.sslEncryptionEnabled()) {
            awaitTlsHandshake(cf, address);
        }

        TransportClient client = clientRef.get();
        assert client != null : "Channel future completed successfully with null client";

        long preBootstrap = System.nanoTime();
        logger.debug("Running bootstraps for {} ...", address);
        try {
            for (TransportClientBootstrap clientBootstrap : this.clientBootstraps) {
                clientBootstrap.doBootstrap(client);
            }
        } catch (Exception e) {
            long bootstrapTime = System.nanoTime() - preBootstrap;
            logger.error("Exception while bootstrapping client after {}", Utils.nanoDurationToString(bootstrapTime), e);
            client.close();
            throw Throwables.propagate(e);
        }
        long postBootstrap = System.nanoTime();
        logger.debug("Successfully created connection to {} after {} ({} spent in bootstraps)", address, Utils.nanoDurationToString(postBootstrap - preConnect), Utils.nanoDurationToString(postBootstrap - preBootstrap));
        return client;
    }

    /**
     * Waits for {@code cf} to complete: without a deadline when {@code connectTimeoutMs <= 0},
     * otherwise for at most {@code connectTimeoutMs}.
     *
     * @throws IOException if the connect was cancelled, failed, or timed out
     */
    private void awaitConnect(ChannelFuture cf, InetSocketAddress address) throws IOException, InterruptedException {
        if (this.connectTimeoutMs <= 0) {
            cf.await();
            assert cf.isDone();
            if (cf.isCancelled()) {
                throw new IOException(String.format("Connecting to %s cancelled", address));
            }
            if (!cf.isSuccess()) {
                throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
            }
            return;
        }
        if (!cf.await(this.connectTimeoutMs)) {
            // The connect is still pending. Close its channel so a late success cannot leave an
            // orphaned connection that nobody references.
            cf.channel().close();
            throw new CelebornIOException(String.format("Connecting to %s timed out (%s ms)", address, this.connectTimeoutMs));
        }
        if (cf.cause() != null) {
            throw new CelebornIOException(String.format("Failed to connect to %s", address), cf.cause());
        }
    }

    /**
     * Waits up to {@code connectionTimeoutMs} for the TLS handshake on {@code cf}'s channel. The
     * handler's own handshake timeout is set to {@code sslHandshakeTimeoutMs}.
     *
     * @throws IOException if the handshake does not finish in time or finishes unsuccessfully;
     *     the channel is closed in both cases
     */
    private void awaitTlsHandshake(final ChannelFuture cf, final InetSocketAddress address) throws IOException, InterruptedException {
        SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
        sslHandler.setHandshakeTimeoutMillis(this.sslHandshakeTimeoutMs);
        Future<Channel> handshake = sslHandler.handshakeFuture().addListener(new GenericFutureListener<Future<Channel>>() {
            @Override
            public void operationComplete(Future<Channel> handshakeFuture) {
                if (handshakeFuture.isSuccess()) {
                    logger.debug("successfully completed TLS handshake to {}", address);
                } else {
                    logger.info("failed to complete TLS handshake to {}", address, handshakeFuture.cause());
                    cf.channel().close();
                }
            }
        });
        if (!handshake.await(this.connectionTimeoutMs)) {
            cf.channel().close();
            throw new IOException(String.format("Failed to connect to %s within connection timeout", address));
        }
        // A handshake that completed with failure must not yield a (closed) client.
        if (!handshake.isSuccess()) {
            cf.channel().close();
            throw new IOException(String.format("Failed to complete TLS handshake to %s", address), handshake.cause());
        }
    }

    /** Closes every pooled client, clears the pool and shuts down the client event loop group. */
    @Override
    public void close() {
        for (ClientPool clientPool : this.connectionPool.values()) {
            for (int i = 0; i < clientPool.clients.length; ++i) {
                TransportClient client = clientPool.clients[i];
                if (client == null) {
                    continue;
                }
                clientPool.clients[i] = null;
                JavaUtils.closeQuietly(client);
            }
        }
        this.connectionPool.clear();
        if (this.workerGroup != null && !this.workerGroup.isShuttingDown()) {
            this.workerGroup.shutdownGracefully();
        }
    }

    public TransportContext getContext() {
        return this.context;
    }

    /** Fixed-size set of client slots for one peer, each with its own creation lock. */
    private static class ClientPool {
        final TransportClient[] clients;
        final Object[] locks;

        ClientPool(int size) {
            this.clients = new TransportClient[size];
            this.locks = new Object[size];
            for (int i = 0; i < size; ++i) {
                this.locks[i] = new Object();
            }
        }
    }
}

