/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.server.coordinator.balancer;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.apache.commons.math3.util.FastMath;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.server.coordinator.SegmentCountsPerInterval;
import org.apache.druid.server.coordinator.ServerHolder;
import org.apache.druid.server.coordinator.balancer.BalancerStrategy;
import org.apache.druid.server.coordinator.loading.SegmentAction;
import org.apache.druid.server.coordinator.stats.CoordinatorRunStats;
import org.apache.druid.server.coordinator.stats.Dimension;
import org.apache.druid.server.coordinator.stats.RowKey;
import org.apache.druid.server.coordinator.stats.Stats;
import org.apache.druid.timeline.DataSegment;
import org.joda.time.Interval;
import org.joda.time.ReadableInterval;

/**
 * Balancer strategy that places segments on servers based on a pairwise
 * "joint cost" between segment intervals.
 * <p>
 * The cost of two intervals X and Y is the double integral of
 * {@code exp(-lambda * |x - y|)} over {@code x in X, y in Y}, with time measured
 * in hours and {@code lambda = ln(2) / HALF_LIFE}. Segments that are close to
 * each other in time are therefore expensive to co-locate, and the contribution
 * halves for every {@link #HALF_LIFE} hours of distance.
 */
public class CostBalancerStrategy implements BalancerStrategy {
    private static final EmittingLogger log = new EmittingLogger(CostBalancerStrategy.class);

    /** Half-life of the cost function, in hours. */
    private static final double HALF_LIFE = 24.0;
    static final double LAMBDA = Math.log(2.0) / HALF_LIFE;
    static final double INV_LAMBDA_SQUARE = 1.0 / (LAMBDA * LAMBDA);

    private static final double MILLIS_IN_HOUR = 3600000.0;
    /** Dividing a duration in millis by this factor yields {@code lambda * hours}. */
    private static final double MILLIS_FACTOR = MILLIS_IN_HOUR / LAMBDA;

    /**
     * Orders (cost, server) pairs by increasing cost. Pairs with equal cost are
     * ordered by the natural ordering of {@link ServerHolder}.
     */
    private static final Comparator<Pair<Double, ServerHolder>> CHEAPEST_SERVERS_FIRST =
        Comparator.<Pair<Double, ServerHolder>, Double>comparing(pair -> pair.lhs)
                  .thenComparing(pair -> pair.rhs);

    private final CoordinatorRunStats stats = new CoordinatorRunStats();
    /** Time spent in cost computations since the last call to {@link #getStats()}. */
    private final AtomicLong computeTimeNanos = new AtomicLong(0L);
    private final ListeningExecutorService exec;

    /**
     * Computes the total joint cost of placing {@code segment} together with
     * every segment of {@code segmentSet}.
     * <p>
     * Segments whose interval does not overlap the cost compute interval of
     * {@code segment} (see {@link #getCostComputeInterval}) are skipped.
     *
     * @param segment    segment being placed
     * @param segmentSet segments to compute the cost against
     * @return sum of the pairwise joint costs
     */
    public static double computeJointSegmentsCost(DataSegment segment, Iterable<DataSegment> segmentSet) {
        final Interval costComputeInterval = getCostComputeInterval(segment);
        double totalCost = 0.0;
        for (DataSegment s : segmentSet) {
            if (costComputeInterval.overlaps(s.getInterval())) {
                totalCost += computeJointSegmentsCost(segment, s);
            }
        }
        return totalCost;
    }

    /**
     * Computes the joint cost of two segments from their intervals.
     * The cost is doubled when both segments belong to the same datasource.
     *
     * @return interval cost of the two segments times the datasource multiplier
     */
    public static double computeJointSegmentsCost(DataSegment segmentA, DataSegment segmentB) {
        final Interval intervalA = segmentA.getInterval();
        final Interval intervalB = segmentB.getInterval();
        final double multiplier = segmentA.getDataSource().equals(segmentB.getDataSource()) ? 2.0 : 1.0;
        return intervalCost(intervalA, intervalB) * multiplier;
    }

    /**
     * Computes the joint cost of two intervals.
     * <p>
     * All timestamps are shifted so that {@code intervalA} starts at 0 and are
     * scaled to units of {@code lambda * hours} before delegating to
     * {@link #intervalCost(double, double, double)}. The result is multiplied by
     * {@link #INV_LAMBDA_SQUARE} to undo that scaling for both integration
     * variables.
     *
     * @return joint cost of the two intervals
     */
    public static double intervalCost(Interval intervalA, Interval intervalB) {
        final double t0 = intervalA.getStartMillis();
        final double t1 = ((double) intervalA.getEndMillis() - t0) / MILLIS_FACTOR;
        final double start = ((double) intervalB.getStartMillis() - t0) / MILLIS_FACTOR;
        final double end = ((double) intervalB.getEndMillis() - t0) / MILLIS_FACTOR;
        return INV_LAMBDA_SQUARE * intervalCost(t1, start, end);
    }

    /**
     * Computes the joint cost of the intervals {@code X = [0, x1)} and
     * {@code Y = [y0, y1)}, i.e. the double integral of {@code exp(-|x - y|)}
     * over {@code x in X, y in Y}. Inputs are expected to be already scaled by
     * lambda, as done in {@link #intervalCost(Interval, Interval)}.
     *
     * @param x1 end of interval X, which starts at 0
     * @param y0 start of interval Y
     * @param y1 end of interval Y
     * @return joint cost of X and Y, 0 if either interval is empty
     */
    public static double intervalCost(double x1, double y0, double y1) {
        // An empty interval has no cost
        if (x1 == 0.0 || y1 == y0) {
            return 0.0;
        }

        // Cost is symmetric: if Y starts before X, swap the two and shift the
        // origin to the old y0 so that the first interval again starts at 0
        if (y0 < 0.0) {
            final double tmp = x1;
            x1 = y1 - y0;
            y1 = tmp - y0;
            y0 = -y0;
        }

        if (y0 < x1) {
            // X and Y overlap. Split into:
            //   [0, y0) vs Y, the overlapping part of length beta vs itself,
            //   and the overlapping part vs the non-overlapping remainder.
            final double beta;  // length of the overlap
            final double gamma; // distance from y0 to the later of the two ends
            if (y1 <= x1) {
                beta = y1 - y0;
                gamma = x1 - y0;
            } else {
                beta = x1 - y0;
                gamma = y1 - y0;
            }
            return intervalCost(y0, y0, y1)
                   + intervalCost(beta, beta, gamma)
                   // cost of an interval of length beta with itself
                   + 2.0 * (beta + FastMath.exp(-beta) - 1.0);
        }

        // X and Y are disjoint with X entirely before Y: closed form of the integral
        final double exy0 = FastMath.exp(x1 - y0);
        final double exy1 = FastMath.exp(x1 - y1);
        final double ey0 = FastMath.exp(0.0 - y0);
        final double ey1 = FastMath.exp(0.0 - y1);
        return ey1 - ey0 - (exy1 - exy0);
    }

    /**
     * @param exec executor on which the per-server placement costs are computed
     */
    public CostBalancerStrategy(ListeningExecutorService exec) {
        this.exec = exec;
    }

    /**
     * Returns the servers that can load the given segment, cheapest first.
     */
    @Override
    public Iterator<ServerHolder> findServersToLoadSegment(DataSegment segmentToLoad, List<ServerHolder> serverHolders) {
        return orderServersByPlacementCost(segmentToLoad, serverHolders, SegmentAction.LOAD)
            .stream()
            .filter(server -> server.canLoadSegment(segmentToLoad))
            .iterator();
    }

    /**
     * Finds the cheapest server for the given segment.
     *
     * @return the cheapest server, or {@code null} if no cost could be computed
     *         or if the cheapest server is the source server itself
     */
    @Override
    public ServerHolder findDestinationServerToMoveSegment(DataSegment segmentToMove, ServerHolder sourceServer, List<ServerHolder> serverHolders) {
        final List<ServerHolder> servers =
            orderServersByPlacementCost(segmentToMove, serverHolders, SegmentAction.MOVE_TO);
        if (servers.isEmpty()) {
            return null;
        }
        final ServerHolder candidateServer = servers.get(0);
        return candidateServer.equals(sourceServer) ? null : candidateServer;
    }

    /**
     * Returns the given servers ordered by decreasing placement cost, so that
     * the segment is dropped from the costliest server first.
     */
    @Override
    public Iterator<ServerHolder> findServersToDropSegment(DataSegment segmentToDrop, List<ServerHolder> serverHolders) {
        final List<ServerHolder> serversByCost =
            orderServersByPlacementCost(segmentToDrop, serverHolders, SegmentAction.DROP);
        return Lists.reverse(serversByCost).iterator();
    }

    /**
     * Adds the computation time accumulated since the previous call (in millis)
     * to the stats, resets the accumulator and returns the stats.
     */
    @Override
    public CoordinatorRunStats getStats() {
        stats.add(
            Stats.Balancer.COMPUTATION_TIME,
            TimeUnit.NANOSECONDS.toMillis(computeTimeNanos.getAndSet(0L))
        );
        return stats;
    }

    /**
     * Computes the cost of placing the proposal segment on the given server,
     * using the projected segment counts per interval of that server.
     *
     * @param proposalSegment segment being placed
     * @param server          candidate server
     * @return placement cost of the segment on the server
     */
    protected double computePlacementCost(DataSegment proposalSegment, ServerHolder server) {
        final Interval costComputeInterval = getCostComputeInterval(proposalSegment);

        // Number of projected segments in each interval that is close enough to matter
        final Object2IntOpenHashMap<Interval> intervalToSegmentCount = new Object2IntOpenHashMap<>();
        final SegmentCountsPerInterval projectedSegments = server.getProjectedSegmentCounts();
        projectedSegments.getIntervalToTotalSegmentCount().object2IntEntrySet().forEach(entry -> {
            final Interval interval = entry.getKey();
            if (costComputeInterval.overlaps(interval)) {
                intervalToSegmentCount.addTo(interval, entry.getIntValue());
            }
        });

        // Segments of the same datasource are counted a second time, which
        // gives them twice the weight of segments of other datasources
        final String datasource = proposalSegment.getDataSource();
        projectedSegments.getIntervalToSegmentCount(datasource).object2IntEntrySet().forEach(entry -> {
            final Interval interval = entry.getKey();
            if (costComputeInterval.overlaps(interval)) {
                intervalToSegmentCount.addTo(interval, entry.getIntValue());
            }
        });

        double cost = 0.0;
        final Interval segmentInterval = proposalSegment.getInterval();
        cost += intervalToSegmentCount
            .object2IntEntrySet()
            .stream()
            .mapToDouble(entry -> intervalCost(segmentInterval, entry.getKey()) * (double) entry.getIntValue())
            .sum();

        // The segment itself must not contribute to its own placement cost.
        // It has been counted twice above: once in the total and once for its datasource.
        if (server.isProjectedSegment(proposalSegment)) {
            cost -= intervalCost(segmentInterval, segmentInterval) * 2.0;
        }
        return cost;
    }

    /**
     * Computes the placement cost of the segment on every given server using
     * the executor and returns the servers sorted by increasing cost.
     * <p>
     * If the computation fails or does not finish within 1 minute, the failure
     * is recorded in the stats and logged, and an empty list is returned.
     *
     * @return all the given servers ordered cheapest first, or an empty list on failure
     */
    private List<ServerHolder> orderServersByPlacementCost(DataSegment segment, List<ServerHolder> serverHolders, SegmentAction action) {
        final Stopwatch computeTime = Stopwatch.createStarted();
        final List<ListenableFuture<Pair<Double, ServerHolder>>> futures = new ArrayList<>();
        for (ServerHolder server : serverHolders) {
            futures.add(exec.submit(() -> Pair.of(computePlacementCost(segment, server), server)));
        }

        final String tier = serverHolders.isEmpty() ? null : serverHolders.get(0).getServer().getTier();
        final RowKey metricKey = RowKey.with(Dimension.TIER, tier)
                                       .with(Dimension.DATASOURCE, segment.getDataSource())
                                       .and(Dimension.DESCRIPTION, action.name());

        final List<Pair<Double, ServerHolder>> serverCosts = new ArrayList<>(futures.size());
        try {
            serverCosts.addAll(Futures.allAsList(futures).get(1L, TimeUnit.MINUTES));
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Do not swallow the interrupt, callers further up may need to see it
                Thread.currentThread().interrupt();
            }
            stats.add(Stats.Balancer.COMPUTATION_ERRORS, metricKey, 1L);
            handleFailure(e, segment, action);
        }

        // Sort explicitly. Streaming or iterating over a PriorityQueue does NOT
        // visit elements in priority order, only the head is guaranteed to be
        // the smallest, and callers rely on the complete order of this list.
        serverCosts.sort(CHEAPEST_SERVERS_FIRST);

        // Report computation stats
        computeTime.stop();
        stats.add(Stats.Balancer.COMPUTATION_COUNT, 1L);
        computeTimeNanos.addAndGet(computeTime.elapsed(TimeUnit.NANOSECONDS));

        return serverCosts.stream().map(pair -> pair.rhs).collect(Collectors.toList());
    }

    /**
     * Logs a warning (without stack trace) describing why the cost computation
     * for the given segment and action has failed.
     */
    private void handleFailure(Exception e, DataSegment segment, SegmentAction action) {
        final String reason;
        String suggestion = "";
        if (exec.isShutdown()) {
            reason = "Executor shutdown";
        } else if (e instanceof TimeoutException) {
            reason = "Timed out";
            suggestion = " Try setting a higher value for 'balancerComputeThreads'.";
        } else {
            reason = e.getMessage();
        }

        final String msgFormat =
            "Cost strategy computations failed for action[%s] on segment[%s] due to reason[%s].[%s]";
        log.noStackTrace().warn(e, msgFormat, action, segment.getId(), reason, suggestion);
    }

    /**
     * Returns the interval within which other segments are considered while
     * computing the cost of the given segment: the segment interval extended by
     * 45 days on either side. A gap of 45 days corresponds to 45 half-lives of
     * the cost function, so anything farther away contributes a negligible cost.
     * <p>
     * An eternity interval is returned unchanged as it cannot be extended.
     */
    private static Interval getCostComputeInterval(DataSegment segment) {
        final Interval segmentInterval = segment.getInterval();
        if (Intervals.isEternity(segmentInterval)) {
            return segmentInterval;
        }
        final long maxGap = TimeUnit.DAYS.toMillis(45L);
        return Intervals.utc(
            segmentInterval.getStartMillis() - maxGap,
            segmentInterval.getEndMillis() + maxGap
        );
    }
}

