/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.mapred;

import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.Utils;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.hadoop.mapred.Counters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link OutputStream} that buffers serialized map-output records in memory, sorts them by key
 * within each reduce partition and pushes them to Celeborn through {@link ShuffleClient}.
 *
 * <p>The key and value {@link Serializer}s are opened on this stream, so serializing a record
 * appends its bytes to an in-memory buffer of {@code maxIOBufferSize} bytes. Once at least
 * {@code spillIOBufferSize} bytes are buffered, every partition is sorted and pushed before the next
 * record is inserted.
 *
 * <p>{@link #insert}, {@link #flush} and {@link #close} never throw. The first failure is recorded
 * and surfaced by {@link #checkException()}.
 *
 * <p>Every push and the final mapper-end call use shuffle id {@code 0}.
 *
 * <p>NOTE(review): apart from the {@link AtomicReference} that holds the first failure, this class
 * performs no synchronization. It is presumably driven by a single mapper thread; confirm before
 * sharing an instance.
 */
public class CelebornSortBasedPusher<K, V> extends OutputStream {
    private static final Logger logger = LoggerFactory.getLogger(CelebornSortBasedPusher.class);
    /** Shuffle id passed to every {@code pushData} and {@code mapperEnd} call. */
    private static final int SHUFFLE_ID = 0;
    /** Bytes reserved at the head of every pushed chunk for the payload length. */
    private static final int LENGTH_HEADER_SIZE = 4;
    /** Written twice as a vlong after the last record, matching Hadoop's IFile end-of-stream marker. */
    private static final long EOF_MARKER = -1L;

    private final int mapId;
    private final int attempt;
    private final int numMappers;
    private final int numReducers;
    private final ShuffleClient shuffleClient;
    private final int maxIOBufferSize;
    private final int spillIOBufferSize;
    private final Serializer<K> kSer;
    private final Serializer<V> vSer;
    private final RawComparator<K> comparator;
    /** First failure seen by insert/flush/close/constructor; later failures are dropped. */
    private final AtomicReference<Exception> exception = new AtomicReference<>();
    private final Counters.Counter mapOutputByteCounter;
    private final Counters.Counter mapOutputRecordCounter;
    /** Buffered records per reduce partition, as offsets into {@link #serializedKV}. */
    private final Map<Integer, List<SerializedKV>> partitionedKVs;
    /** Next free index in {@link #serializedKV}; never exceeds {@link #maxIOBufferSize}. */
    private int writePos;
    /** Backing buffer for serialized keys and values; released (set to null) by {@link #close()}. */
    private byte[] serializedKV;
    private final int maxPushDataSize;
    private boolean closed;

    /**
     * Creates the pusher, allocates its {@code maxIOBufferSize}-byte buffer and opens both serializers
     * on this stream. A failure while opening a serializer is recorded rather than thrown; it is
     * reported by {@link #checkException()}.
     */
    public CelebornSortBasedPusher(
            int numMappers,
            int numReducers,
            int mapId,
            int attemptId,
            Serializer<K> kSer,
            Serializer<V> vSer,
            int maxIOBufferSize,
            int spillIOBufferSize,
            RawComparator<K> comparator,
            Counters.Counter mapOutputByteCounter,
            Counters.Counter mapOutputRecordCounter,
            ShuffleClient shuffleClient,
            CelebornConf celebornConf) {
        this.numMappers = numMappers;
        this.numReducers = numReducers;
        this.mapId = mapId;
        this.attempt = attemptId;
        this.kSer = kSer;
        this.vSer = vSer;
        this.maxIOBufferSize = maxIOBufferSize;
        this.spillIOBufferSize = spillIOBufferSize;
        this.mapOutputByteCounter = mapOutputByteCounter;
        this.mapOutputRecordCounter = mapOutputRecordCounter;
        this.comparator = comparator;
        this.shuffleClient = shuffleClient;
        this.partitionedKVs = new HashMap<>();
        this.serializedKV = new byte[maxIOBufferSize];
        this.maxPushDataSize = (int) celebornConf.clientMrMaxPushData();
        logger.info(
                "Sort based push initialized with numMappers:{} numReducers:{} mapId:{} attemptId:{} maxIOBufferSize:{} spillIOBufferSize:{}",
                new Object[] {numMappers, numReducers, mapId, attemptId, maxIOBufferSize, spillIOBufferSize});
        try {
            kSer.open(this);
            vSer.open(this);
        } catch (IOException e) {
            exception.compareAndSet(null, e);
        }
    }

    /**
     * Serializes one record into the buffer under {@code partition}.
     *
     * <p>If the buffer already holds at least {@code spillIOBufferSize} bytes, all buffered partitions
     * are sorted and pushed first. On success the record counter is incremented by one, and the byte
     * counter by the serialized key+value length. Any {@link IOException} is recorded, not thrown.
     */
    public void insert(K key, V value, int partition) {
        try {
            if (writePos >= spillIOBufferSize) {
                if (logger.isDebugEnabled()) {
                    logger.debug(
                            "Data is large enough {}/{}/{}, trigger sort and flush",
                            new Object[] {
                                Utils.bytesToString(writePos),
                                Utils.bytesToString(spillIOBufferSize),
                                Utils.bytesToString(maxIOBufferSize)
                            });
                }
                sortAndSendAll();
            }
            int dataLen = insertRecordInternal(key, value, partition);
            if (logger.isDebugEnabled()) {
                logger.debug("Sort based pusher insert into partition:{} with {} bytes", partition, dataLen);
            }
            mapOutputRecordCounter.increment(1L);
            mapOutputByteCounter.increment((long) dataLen);
        } catch (IOException e) {
            exception.compareAndSet(null, e);
        }
    }

    /** Sorts every buffered partition by key, then pushes all of them and resets the buffer. */
    private void sortAndSendAll() throws IOException {
        sortKVs();
        sendKVAndUpdateWritePos();
    }

    /**
     * Pushes every buffered partition in chunks, then empties the partition map and rewinds
     * {@link #writePos} to 0.
     */
    private void sendKVAndUpdateWritePos() throws IOException {
        Iterator<Map.Entry<Integer, List<SerializedKV>>> entryIter = partitionedKVs.entrySet().iterator();
        while (entryIter.hasNext()) {
            Map.Entry<Integer, List<SerializedKV>> entry = entryIter.next();
            // Removed up front: if a push below fails, a later flush will not retry this partition.
            entryIter.remove();
            int partition = entry.getKey();
            List<SerializedKV> kvs = entry.getValue();
            List<SerializedKV> batch = new ArrayList<>();
            int batchLen = 0;
            for (SerializedKV kv : kvs) {
                batch.add(kv);
                batchLen += kv.kLen + kv.vLen;
                // The limit is checked after adding the record, so a chunk can exceed
                // maxPushDataSize by at most its last record.
                if (batchLen > maxPushDataSize) {
                    sendSortedBuffersPartition(partition, batch, batchLen);
                    batch.clear();
                    batchLen = 0;
                }
            }
            if (!batch.isEmpty()) {
                sendSortedBuffersPartition(partition, batch, batchLen);
            }
            kvs.clear();
        }
        partitionedKVs.clear();
        writePos = 0;
    }

    /**
     * Encodes {@code localKVs} into one chunk and pushes it to {@code partition}.
     *
     * <p>Chunk layout: a 4-byte payload length, written with {@code Platform.putInt}. Then, for each
     * record: vlong key length, vlong value length, key bytes, value bytes. Finally, two vlong
     * {@code -1} end markers.
     *
     * @param partitionKVTotalLen sum of {@code kLen + vLen} over {@code localKVs}
     */
    private void sendSortedBuffersPartition(int partition, List<SerializedKV> localKVs, int partitionKVTotalLen) throws IOException {
        int extraSize = 2 * WritableUtils.getVIntSize(EOF_MARKER);
        for (SerializedKV kv : localKVs) {
            extraSize += WritableUtils.getVIntSize(kv.kLen) + WritableUtils.getVIntSize(kv.vLen);
        }
        int payloadLen = extraSize + partitionKVTotalLen;
        int chunkLen = LENGTH_HEADER_SIZE + payloadLen;
        byte[] pkvs = new byte[chunkLen];
        Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, payloadLen);
        int pos = LENGTH_HEADER_SIZE;
        for (SerializedKV kv : localKVs) {
            pos = writeVLong(pkvs, pos, kv.kLen);
            pos = writeVLong(pkvs, pos, kv.vLen);
            // Key and value bytes are contiguous in serializedKV, so one copy moves both.
            int recordLen = kv.kLen + kv.vLen;
            System.arraycopy(serializedKV, kv.offset, pkvs, pos, recordLen);
            pos += recordLen;
        }
        pos = writeVLong(pkvs, pos, EOF_MARKER);
        writeVLong(pkvs, pos, EOF_MARKER);
        int compressedSize = shuffleClient.pushData(
                SHUFFLE_ID, mapId, attempt, partition, pkvs, 0, chunkLen, numMappers, numReducers);
        if (logger.isDebugEnabled()) {
            logger.debug(
                    "Send sorted buffer mapId:{} attemptId:{} to partition:{} uncompressed size:{} compressed size:{}",
                    new Object[] {
                        mapId, attempt, partition, Utils.bytesToString(chunkLen), Utils.bytesToString(compressedSize)
                    });
        }
    }

    /**
     * Writes {@code value} at {@code data[offset]} with the variable-length encoding of Hadoop's
     * {@code WritableUtils.writeVLong}.
     *
     * <p>Values in [-112, 127] take one byte. Otherwise a length/sign prefix byte is followed by the
     * magnitude's bytes, most significant first.
     *
     * @return the offset just past the last byte written
     */
    private static int writeVLong(byte[] data, int offset, long value) {
        if (value >= -112L && value <= 127L) {
            data[offset] = (byte) value;
            return offset + 1;
        }
        int len = -112;
        if (value < 0L) {
            // Negative values are stored as their one's complement with a different prefix range.
            value = ~value;
            len = -120;
        }
        for (long tmp = value; tmp != 0L; tmp >>= 8) {
            --len;
        }
        data[offset++] = (byte) len;
        int numBytes = len < -120 ? -(len + 120) : -(len + 112);
        for (int idx = numBytes; idx != 0; --idx) {
            int shiftBits = (idx - 1) * 8;
            data[offset++] = (byte) ((value & (0xFFL << shiftBits)) >> shiftBits);
        }
        return offset;
    }

    /** Sorts each partition's records by serialized key using the configured {@link RawComparator}. */
    private void sortKVs() {
        for (List<SerializedKV> kvs : partitionedKVs.values()) {
            kvs.sort((a, b) -> comparator.compare(serializedKV, a.offset, a.kLen, serializedKV, b.offset, b.kLen));
        }
    }

    /**
     * Serializes key then value into the buffer and records their location under {@code partition}.
     *
     * <p>NOTE(review): lengths are measured from {@link #writePos}, which assumes the serializers
     * write through to this stream without internal buffering; confirm for non-Writable serializers.
     *
     * @return serialized key length plus value length
     */
    private int insertRecordInternal(K key, V value, int partition) throws IOException {
        int offset = writePos;
        kSer.serialize(key);
        int keyLen = writePos - offset;
        vSer.serialize(value);
        int valLen = writePos - offset - keyLen;
        partitionedKVs.computeIfAbsent(partition, p -> new ArrayList<>()).add(new SerializedKV(offset, keyLen, valLen));
        if (logger.isDebugEnabled()) {
            logger.debug(
                    "Pusher insert into buffer partition:{} offset:{} keyLen:{} valueLen:{} size:{}",
                    new Object[] {partition, offset, keyLen, valLen, partitionedKVs.size()});
        }
        return keyLen + valLen;
    }

    /**
     * Rethrows the first failure recorded by this pusher, if any.
     *
     * @throws IOException wrapping the recorded failure
     */
    public void checkException() throws IOException {
        Exception failure = exception.get();
        if (failure != null) {
            throw new IOException("Write data to celeborn failed", failure);
        }
    }

    /**
     * Appends one byte to the buffer.
     *
     * @throws IOException if the pusher is closed or the buffer is full
     */
    @Override
    public void write(int b) throws IOException {
        ensureOpen();
        if (writePos >= maxIOBufferSize) {
            throw memoryExhausted();
        }
        serializedKV[writePos++] = (byte) b;
    }

    /**
     * Appends {@code len} bytes with a single array copy, instead of OutputStream's byte-by-byte
     * default.
     *
     * <p>Like the default, bytes that fit are written before failing when the buffer runs out.
     *
     * @throws IOException if the pusher is closed or the buffer cannot hold all {@code len} bytes
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        if (off < 0 || len < 0 || len > b.length - off) {
            throw new IndexOutOfBoundsException("off=" + off + " len=" + len + " length=" + b.length);
        }
        if (len == 0) {
            return;
        }
        ensureOpen();
        int copied = Math.min(len, maxIOBufferSize - writePos);
        System.arraycopy(b, off, serializedKV, writePos, copied);
        writePos += copied;
        if (copied < len) {
            throw memoryExhausted();
        }
    }

    /** Logs the buffer-full warning and builds the exception thrown by the write methods. */
    private IOException memoryExhausted() {
        logger.warn("Sort push memory high, write pos {} max size {}", writePos, maxIOBufferSize);
        return new IOException("Sort pusher memory exhausted.");
    }

    /** Fails writes after {@link #close()} instead of hitting the released (null) buffer. */
    private void ensureOpen() throws IOException {
        if (closed) {
            throw new IOException("Sort pusher is closed.");
        }
    }

    /** Sorts and pushes all buffered records; a failure is recorded rather than thrown. */
    @Override
    public void flush() {
        logger.info("Sort based pusher called flush");
        try {
            sortAndSendAll();
        } catch (IOException e) {
            exception.compareAndSet(null, e);
        }
    }

    /**
     * Flushes remaining records, reports mapper end to Celeborn and releases the buffer.
     *
     * <p>Idempotent: calls after the first have no effect.
     *
     * <p>If a failure has been recorded (now or earlier), mapper end is not reported. The recorded
     * failure means some of this attempt's records may not have been pushed, so the attempt should
     * not be declared complete. The failure remains available via {@link #checkException()}.
     */
    @Override
    public void close() {
        if (closed) {
            return;
        }
        closed = true;
        flush();
        Exception failure = exception.get();
        if (failure != null) {
            logger.warn(
                    "Skip mapper end shuffleId:{} mapId:{} attemptId:{} because writing to celeborn failed",
                    new Object[] {SHUFFLE_ID, mapId, attempt, failure});
        } else {
            try {
                logger.info(
                        "Call mapper end shuffleId:{} mapId:{} attemptId:{} numMappers:{}",
                        new Object[] {SHUFFLE_ID, mapId, attempt, numMappers});
                shuffleClient.mapperEnd(SHUFFLE_ID, mapId, attempt, numMappers);
            } catch (IOException e) {
                exception.compareAndSet(null, e);
            }
        }
        partitionedKVs.clear();
        serializedKV = null;
    }

    /**
     * Location of one buffered record: the key occupies {@code kLen} bytes at {@code offset} in the
     * pusher's buffer, immediately followed by {@code vLen} value bytes.
     */
    static class SerializedKV {
        final int offset;
        final int kLen;
        final int vLen;

        public SerializedKV(int offset, int kLen, int vLen) {
            this.offset = offset;
            this.kLen = kLen;
            this.vLen = vLen;
        }
    }
}

