/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.protocol.client.ainode;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService;
import org.apache.iotdb.ainode.rpc.thrift.TConfigs;
import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.ClientManager;
import org.apache.iotdb.commons.client.IClientManager;
import org.apache.iotdb.commons.client.ThriftClient;
import org.apache.iotdb.commons.client.factory.ThriftClientFactory;
import org.apache.iotdb.commons.client.property.ThriftClientProperty;
import org.apache.iotdb.commons.conf.CommonConfig;
import org.apache.iotdb.commons.conf.CommonDescriptor;
import org.apache.iotdb.commons.consensus.ConfigRegionId;
import org.apache.iotdb.commons.exception.ainode.LoadModelException;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp;
import org.apache.iotdb.db.protocol.client.ConfigNodeClient;
import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager;
import org.apache.iotdb.db.protocol.client.ConfigNodeInfo;
import org.apache.iotdb.rpc.TConfigurationConst;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.thrift.TException;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.layered.TFramedTransport;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Thrift client for the AINode RPC service ({@link IAINodeRPCService}).
 *
 * <p>Instances are created by {@link Factory} for a {@link ClientManager} and hand themselves back
 * to that manager in {@link #close()}. Most RPC wrappers go through
 * {@link #executeRemoteCallWithRetry(RemoteCall)}, which reconnects and re-resolves the AINode
 * endpoint from the ConfigNode when a call fails.
 *
 * <p>An instance holds a single transport and mutable endpoint state without any synchronization,
 * so it should be used by one thread at a time.
 */
public class AINodeClient implements AutoCloseable, ThriftClient {
    private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class);
    // Not referenced by the code in this class; kept so class initialization is unchanged.
    private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig();

    /** Endpoint the transport is (or will next be) connected to; may change across retries. */
    private TEndPoint endPoint;
    /**
     * Endpoint this client was constructed with. It is used as the key when the client is handed
     * back to {@link #clientManager}, because {@link #endPoint} may be re-pointed during retries.
     */
    private final TEndPoint poolKey;
    private TTransport transport;
    private final ThriftClientProperty property;
    private IAINodeRPCService.Client client;

    public static final String MSG_CONNECTION_FAIL = "Fail to connect to AINode. Please check status of AINode";
    /** Maximum number of attempts made by {@link #executeRemoteCallWithRetry(RemoteCall)}. */
    private static final int MAX_RETRY = 3;
    /** Base back-off between attempts; the actual wait is this value times the attempt number. */
    private static final long RETRY_BACKOFF_MS = 1000L;

    private final TsBlockSerde tsBlockSerde = new TsBlockSerde();
    ClientManager<TEndPoint, AINodeClient> clientManager;
    private static final IClientManager<ConfigRegionId, ConfigNodeClient> CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance();
    /** Last AINode location seen, shared by all clients; {@code null} until first resolved. */
    private static final AtomicReference<TAINodeLocation> CURRENT_LOCATION = new AtomicReference<>();

    /**
     * Returns the internal endpoint of the cached AINode location, asking the ConfigNode first when
     * nothing is cached yet.
     *
     * @return the endpoint, or {@code null} when no location could be obtained or the location
     *     carries no internal endpoint
     */
    public static TEndPoint getCurrentEndpoint() {
        TAINodeLocation loc = CURRENT_LOCATION.get();
        if (loc == null) {
            loc = AINodeClient.refreshFromConfigNode();
        }
        return AINodeClient.pickEndpointFrom(loc);
    }

    /**
     * Replaces the cached AINode location.
     *
     * @param loc the new location; {@code null} is ignored so the cache is never cleared here
     */
    public static void updateGlobalAINodeLocation(TAINodeLocation loc) {
        if (loc != null) {
            CURRENT_LOCATION.set(loc);
        }
    }

    /**
     * Runs {@code call} against the Thrift client, retrying up to {@link #MAX_RETRY} times.
     *
     * <p>Before each attempt the transport is (re)opened if it is missing or closed, using
     * {@link #getCurrentEndpoint()}. After a failed attempt the transport is closed, the AINode
     * location is refreshed from the ConfigNode, and the thread backs off for
     * {@code RETRY_BACKOFF_MS * attempt} milliseconds. No back-off happens after the final attempt,
     * and an interrupt during back-off stops further retries (the interrupt flag is restored).
     *
     * @param call the RPC to perform
     * @param <R> result type of the RPC
     * @return whatever {@code call} returns on the first successful attempt
     * @throws TException the failure of the last attempt made
     */
    private <R> R executeRemoteCallWithRetry(RemoteCall<R> call) throws TException {
        TException last = null;
        for (int attempt = 1; attempt <= MAX_RETRY; ++attempt) {
            try {
                if (this.transport == null || !this.transport.isOpen()) {
                    TEndPoint ep = AINodeClient.getCurrentEndpoint();
                    if (ep == null) {
                        throw new TException("AINode endpoint unavailable");
                    }
                    this.endPoint = ep;
                    this.init();
                }
                return call.apply(this.client);
            } catch (TException e) {
                last = e;
                this.invalidate();
                // Refresh the shared location cache; only adopt the endpoint when one is present
                // so this.endPoint never becomes null.
                TEndPoint refreshed = AINodeClient.pickEndpointFrom(AINodeClient.refreshFromConfigNode());
                if (refreshed != null) {
                    this.endPoint = refreshed;
                }
                if (attempt == MAX_RETRY) {
                    // Nothing left to retry: do not delay the failure with a useless sleep.
                    break;
                }
                try {
                    Thread.sleep(RETRY_BACKOFF_MS * attempt);
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }
        throw last != null ? last : new TException(MSG_CONNECTION_FAIL);
    }

    /**
     * Asks the ConfigNode for the AINode location and, when one is returned, stores it in
     * {@link #CURRENT_LOCATION}.
     *
     * @return the location, or {@code null} when the response has none or the lookup failed (the
     *     failure is logged at debug level and otherwise ignored)
     */
    private static TAINodeLocation refreshFromConfigNode() {
        try (ConfigNodeClient cn = CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
            TGetAINodeLocationResp resp = cn.getAINodeLocation();
            if (resp == null || !resp.isSetAiNodeLocation()) {
                return null;
            }
            TAINodeLocation loc = resp.getAiNodeLocation();
            CURRENT_LOCATION.set(loc);
            return loc;
        } catch (Exception e) {
            // Best effort: callers treat null as "location unknown".
            logger.debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString());
            return null;
        }
    }

    /**
     * Extracts the internal endpoint from a location.
     *
     * @param loc the location, may be {@code null}
     * @return the internal endpoint, or {@code null} when {@code loc} is {@code null} or has none
     */
    private static TEndPoint pickEndpointFrom(TAINodeLocation loc) {
        if (loc == null || !loc.isSetInternalEndPoint()) {
            return null;
        }
        return loc.getInternalEndPoint();
    }

    /**
     * Creates a client and immediately opens a connection to {@code endPoint}.
     *
     * @param property Thrift settings (connection timeout, protocol factory, logging flag)
     * @param endPoint endpoint to connect to; also remembered as the key used in {@link #close()}
     * @param clientManager manager this client is returned to on {@link #close()}
     * @throws TException if the connection cannot be opened
     */
    public AINodeClient(ThriftClientProperty property, TEndPoint endPoint, ClientManager<TEndPoint, AINodeClient> clientManager) throws TException {
        this.property = property;
        this.clientManager = clientManager;
        this.endPoint = endPoint;
        this.poolKey = endPoint;
        this.init();
    }

    /**
     * Opens a framed socket transport to {@link #endPoint} and builds the RPC client on top of it.
     *
     * @throws TException with message {@link #MSG_CONNECTION_FAIL} (and the transport error as
     *     cause) when the transport cannot be created or opened
     */
    private void init() throws TException {
        try {
            this.transport = new TFramedTransport.Factory().getTransport(new TSocket(TConfigurationConst.defaultTConfiguration, this.endPoint.getIp(), this.endPoint.getPort(), this.property.getConnectionTimeoutMs()));
            if (!this.transport.isOpen()) {
                this.transport.open();
            }
        } catch (TTransportException e) {
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
        this.client = new IAINodeRPCService.Client(this.property.getProtocolFactory().getProtocol(this.transport));
    }

    /** Returns the current transport; may be {@code null} or closed. */
    public TTransport getTransport() {
        return this.transport;
    }

    /**
     * Asks the AINode to stop. This call is made once, without the retry logic.
     *
     * @return the success status returned by the AINode
     * @throws TException with message {@link #MSG_CONNECTION_FAIL} when the RPC fails or the
     *     AINode answers with a non-success status; the original failure is attached as cause
     */
    public TSStatus stopAINode() throws TException {
        try {
            TSStatus status = this.client.stopAINode();
            if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
                throw new TException(status.message);
            }
            return status;
        } catch (TException e) {
            logger.warn("Failed to connect to AINode from ConfigNode when executing {}: {}", "stopAINode", e.getMessage());
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /**
     * Registers a model on the AINode. This call is made once, without the retry logic.
     *
     * @param modelName name to register the model under
     * @param uri location the AINode should load the model from
     * @return model information built from the attributes and configs in the response
     * @throws LoadModelException when the AINode answers with a non-success status (carrying that
     *     status' message and code) or when the RPC itself fails
     */
    public ModelInformation registerModel(String modelName, String uri) throws LoadModelException {
        try {
            TRegisterModelReq req = new TRegisterModelReq(uri, modelName);
            TRegisterModelResp resp = this.client.registerModel(req);
            if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
                throw new LoadModelException(resp.status.message, resp.status.getCode());
            }
            return this.parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs());
        } catch (TException e) {
            throw new LoadModelException(e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode());
        }
    }

    /**
     * Converts the configs of a register-model response into a {@link ModelInformation}.
     *
     * <p>The element at index 1 of each shape is used as the number of input/output types to read,
     * and each type code is used as an index into {@code TSDataType.values()}.
     * NOTE(review): this assumes both shapes have at least two dimensions and that the type lists
     * hold at least that many entries - confirm against the AINode contract.
     */
    private ModelInformation parseModelInformation(String modelName, String attributes, TConfigs configs) {
        int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray();
        int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray();
        // values() copies the enum array on every call, so fetch it once.
        TSDataType[] allTypes = TSDataType.values();
        TSDataType[] inputType = new TSDataType[inputShape[1]];
        TSDataType[] outputType = new TSDataType[outputShape[1]];
        for (int i = 0; i < inputType.length; ++i) {
            inputType[i] = allTypes[configs.getInput_type().get(i)];
        }
        for (int i = 0; i < outputType.length; ++i) {
            outputType[i] = allTypes[configs.getOutput_type().get(i)];
        }
        return new ModelInformation(modelName, inputShape, outputShape, inputType, outputType, attributes);
    }

    /** Forwards {@code req} to the AINode's {@code deleteModel} RPC, with retries. */
    public TSStatus deleteModel(TDeleteModelReq req) throws TException {
        return this.executeRemoteCallWithRetry(c -> c.deleteModel(req));
    }

    /** Forwards {@code req} to the AINode's {@code loadModel} RPC, with retries. */
    public TSStatus loadModel(TLoadModelReq req) throws TException {
        return this.executeRemoteCallWithRetry(c -> c.loadModel(req));
    }

    /** Forwards {@code req} to the AINode's {@code unloadModel} RPC, with retries. */
    public TSStatus unloadModel(TUnloadModelReq req) throws TException {
        return this.executeRemoteCallWithRetry(c -> c.unloadModel(req));
    }

    /** Forwards {@code req} to the AINode's {@code showModels} RPC, with retries. */
    public TShowModelsResp showModels(TShowModelsReq req) throws TException {
        return this.executeRemoteCallWithRetry(c -> c.showModels(req));
    }

    /** Forwards {@code req} to the AINode's {@code showLoadedModels} RPC, with retries. */
    public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException {
        return this.executeRemoteCallWithRetry(c -> c.showLoadedModels(req));
    }

    /** Calls the AINode's {@code showAIDevices} RPC, with retries. */
    public TShowAIDevicesResp showAIDevices() throws TException {
        return this.executeRemoteCallWithRetry(IAINodeRPCService.Client::showAIDevices);
    }

    /**
     * Serializes {@code inputTsBlock} and runs inference on the AINode, with retries.
     *
     * @param modelId model to run
     * @param inputTsBlock input data
     * @param inferenceAttributes optional attributes; not set on the request when {@code null}
     * @param windowParams optional window parameters; not set on the request when {@code null}
     * @return the AINode's response
     * @throws TException when the input cannot be serialized, or with message
     *     {@link #MSG_CONNECTION_FAIL} (original failure attached as cause) when the RPC fails
     */
    public TInferenceResp inference(String modelId, TsBlock inputTsBlock, Map<String, String> inferenceAttributes, TWindowParams windowParams) throws TException {
        try {
            TInferenceReq inferenceReq = new TInferenceReq(modelId, this.tsBlockSerde.serialize(inputTsBlock));
            if (windowParams != null) {
                inferenceReq.setWindowParams(windowParams);
            }
            if (inferenceAttributes != null) {
                inferenceReq.setInferenceAttributes(inferenceAttributes);
            }
            return this.executeRemoteCallWithRetry(c -> c.inference(inferenceReq));
        } catch (IOException e) {
            throw new TException("An exception occurred while serializing input data", e);
        } catch (TException e) {
            logger.warn("Error happens in AINode when executing {}: {}", "inference", e.getMessage());
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /**
     * Serializes {@code inputTsBlock} and requests a forecast from the AINode, with retries.
     *
     * <p>Unlike the other RPC wrappers this method does not throw on failure; it reports failures
     * through the status of the returned response.
     *
     * @param modelId model to run
     * @param inputTsBlock input data
     * @param outputLength value passed as the request's output length
     * @param options options set on the request as given
     * @return the AINode's response, or a response whose status is {@code INTERNAL_SERVER_ERROR}
     *     (serialization failed) or {@code CAN_NOT_CONNECT_AINODE} (RPC failed)
     */
    public TForecastResp forecast(String modelId, TsBlock inputTsBlock, int outputLength, Map<String, String> options) {
        try {
            TForecastReq forecastReq = new TForecastReq(modelId, this.tsBlockSerde.serialize(inputTsBlock), outputLength);
            forecastReq.setOptions(options);
            return this.executeRemoteCallWithRetry(c -> c.forecast(forecastReq));
        } catch (IOException e) {
            TSStatus tsStatus = new TSStatus(TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
            tsStatus.setMessage(String.format("Failed to serialize input tsblock %s", e.getMessage()));
            return new TForecastResp(tsStatus);
        } catch (TException e) {
            TSStatus tsStatus = new TSStatus(TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode());
            tsStatus.setMessage(String.format("Failed to connect to AINode when executing %s: %s", "forecast", e.getMessage()));
            return new TForecastResp(tsStatus);
        }
    }

    /** Forwards {@code req} to the AINode's {@code createTrainingTask} RPC, with retries. */
    public TSStatus createTrainingTask(TTrainingReq req) throws TException {
        return this.executeRemoteCallWithRetry(c -> c.createTrainingTask(req));
    }

    /**
     * Hands this client back to its {@link #clientManager} instead of closing the connection.
     *
     * <p>The client is returned under the endpoint it was constructed with rather than the current
     * {@link #endPoint}, which retries may have re-pointed to a different address.
     */
    @Override
    public void close() throws Exception {
        this.clientManager.returnClient(this.poolKey, this);
    }

    /** Closes the underlying transport, if there is one. */
    public void invalidate() {
        Optional.ofNullable(this.transport).ifPresent(TTransport::close);
    }

    /** Asks the client manager to clear the clients it holds for the current endpoint. */
    public void invalidateAll() {
        this.clientManager.clear(this.endPoint);
    }

    /** Returns the {@code printLogWhenEncounterException} flag of this client's Thrift property. */
    public boolean printLogWhenEncounterException() {
        return this.property.isPrintLogWhenEncounterException();
    }

    /**
     * A single RPC against the AINode service that may fail with a {@link TException}.
     *
     * @param <R> result type of the RPC
     */
    @FunctionalInterface
    private static interface RemoteCall<R> {
        public R apply(IAINodeRPCService.Client client) throws TException;
    }

    /** Pooled-object factory that creates, validates and destroys {@link AINodeClient}s. */
    public static class Factory extends ThriftClientFactory<TEndPoint, AINodeClient> {
        public Factory(ClientManager<TEndPoint, AINodeClient> clientClientManager, ThriftClientProperty thriftClientProperty) {
            super(clientClientManager, thriftClientProperty);
        }

        /** Closes the transport of the client being discarded. */
        public void destroyObject(TEndPoint tEndPoint, PooledObject<AINodeClient> pooledObject) throws Exception {
            pooledObject.getObject().invalidate();
        }

        /** Creates a new, already connected client for {@code endPoint}. */
        public PooledObject<AINodeClient> makeObject(TEndPoint endPoint) throws Exception {
            return new DefaultPooledObject<>(new AINodeClient(this.thriftClientProperty, endPoint, this.clientManager));
        }

        /** A client is valid only while it has a transport and that transport is open. */
        public boolean validateObject(TEndPoint tEndPoint, PooledObject<AINodeClient> pooledObject) {
            return Optional.ofNullable(pooledObject.getObject().getTransport()).map(TTransport::isOpen).orElse(false);
        }
    }
}

