/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup.dictionary;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionarySlice;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockFactory;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;

/**
 * Dictionary that represents an identity matrix of size {@code nRowCol x nRowCol} without materializing it.
 *
 * <p>
 * When {@code withEmpty} is set, the dictionary has one additional trailing row that contains only zeros, so it
 * holds {@code nRowCol + 1} tuples of {@code nRowCol} columns each.
 */
public class IdentityDictionary
extends AIdentityDictionary {
    private static final long serialVersionUID = 2535887782153955098L;

    /**
     * Construct an identity dictionary using the single-argument super constructor.
     *
     * @param nRowCol number of rows and columns of the identity matrix
     */
    private IdentityDictionary(int nRowCol) {
        super(nRowCol);
    }

    /**
     * Create an identity dictionary without the additional empty row.
     *
     * @param nRowCol number of rows and columns of the identity matrix
     * @return the dictionary
     */
    public static IDictionary create(int nRowCol) {
        return IdentityDictionary.create(nRowCol, false);
    }

    /**
     * Construct an identity dictionary.
     *
     * @param nRowCol   number of rows and columns of the identity matrix
     * @param withEmpty whether an additional all-zero row is appended
     */
    private IdentityDictionary(int nRowCol, boolean withEmpty) {
        super(nRowCol, withEmpty);
    }

    /**
     * Create an identity dictionary, optionally with an additional empty row.
     *
     * <p>
     * A 1x1 identity is returned as a plain {@link Dictionary} holding {@code [1.0]} (or {@code [1.0, 0.0]} with the
     * empty row) instead of an {@link IdentityDictionary}.
     *
     * @param nRowCol   number of rows and columns of the identity matrix
     * @param withEmpty whether an additional all-zero row is appended
     * @return the dictionary
     */
    public static IDictionary create(int nRowCol, boolean withEmpty) {
        if (nRowCol == 1) {
            return new Dictionary(withEmpty ? new double[]{1.0, 0.0} : new double[]{1.0});
        }
        return new IdentityDictionary(nRowCol, withEmpty);
    }

    /**
     * Materialize the values in a row-major array. Only supported for {@code nRowCol < 3}.
     *
     * @return the dense values, including the zero row when {@code withEmpty} is set
     * @throws DMLCompressionException if the identity matrix has three or more rows/columns
     */
    @Override
    public double[] getValues() {
        if (this.nRowCol >= 3) {
            throw new DMLCompressionException("Invalid to materialize identity Matrix Please Implement alternative");
        }
        final int emptyCells = this.withEmpty ? this.nRowCol : 0;
        final double[] ret = new double[this.nRowCol * this.nRowCol + emptyCells];
        for (int i = 0; i < this.nRowCol; ++i) {
            ret[i * this.nRowCol + i] = 1.0;
        }
        return ret;
    }

    /**
     * Get the value at a row-major cell index.
     *
     * @param i row-major cell index
     * @return 1.0 on the diagonal, otherwise 0.0 (including every cell at or beyond row {@code nRowCol})
     */
    @Override
    public double getValue(int i) {
        final int row = i / this.nRowCol;
        if (row >= this.nRowCol) {
            // Rows at or past nRowCol can never match a column index (< nRowCol).
            return 0.0;
        }
        return row == i % this.nRowCol ? 1.0 : 0.0;
    }

    /**
     * Get the value at a given row and column.
     *
     * @param r    row index
     * @param c    column index
     * @param nCol number of columns (not used)
     * @return 1.0 if {@code r == c}, otherwise 0.0
     */
    @Override
    public double getValue(int r, int c, int nCol) {
        return r == c ? 1.0 : 0.0;
    }

    @Override
    public long getInMemorySize() {
        return IdentityDictionary.getInMemorySize(-1);
    }

    /**
     * Get the in-memory size, delegating to {@link AIdentityDictionary#getInMemorySize(int)}.
     *
     * @param numberColumns number of columns, passed through to the super class
     * @return the size reported by the super class
     */
    public static long getInMemorySize(int numberColumns) {
        return AIdentityDictionary.getInMemorySize(numberColumns);
    }

    /**
     * Aggregate all values with the given builtin. Only MAX and MIN are supported.
     *
     * @param init initial value
     * @param fn   builtin function
     * @return {@code fn(init, 1.0)} for MAX, {@code fn(init, 0.0)} for MIN
     * @throws NotImplementedException for any other builtin
     */
    @Override
    public double aggregate(double init, Builtin fn) {
        final Builtin.BuiltinCode code = fn.getBuiltinCode();
        if (code == Builtin.BuiltinCode.MAX) {
            return fn.execute(init, 1.0);
        }
        if (code == Builtin.BuiltinCode.MIN) {
            return fn.execute(init, 0.0);
        }
        throw new NotImplementedException();
    }

    /**
     * Aggregate each dictionary row with the given builtin.
     *
     * <p>
     * Each identity row aggregates to {@code fn(1, 0)}. When {@code withEmpty} is set, the result has one additional
     * entry for the all-zero row, {@code fn(0, 0)}, so the result length matches {@link #getNumberOfValues(int)}.
     *
     * @param fn   builtin function
     * @param nCol number of columns (not used)
     * @return one aggregate per dictionary row
     */
    @Override
    public double[] aggregateRows(Builtin fn, int nCol) {
        final double[] ret = new double[this.nRowCol + (this.withEmpty ? 1 : 0)];
        Arrays.fill(ret, 0, this.nRowCol, fn.execute(1L, 0L));
        if (this.withEmpty) {
            ret[this.nRowCol] = fn.execute(0L, 0L);
        }
        return ret;
    }

    /**
     * Aggregate the columns into {@code c}: every column is combined with both 0.0 and 1.0.
     *
     * @param c          output array, indexed by the values of {@code colIndexes}
     * @param fn         builtin function
     * @param colIndexes mapping from dictionary column to output position
     */
    @Override
    public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) {
        for (int i = 0; i < this.nRowCol; ++i) {
            final int idx = colIndexes.get(i);
            c[idx] = fn.execute(c[idx], 0.0);
            c[idx] = fn.execute(c[idx], 1.0);
        }
    }

    /**
     * Apply a binary operator with a row vector on the right side.
     *
     * <p>
     * Returns {@code this} when the operation is a no-op: plus/minus with all selected values equal to 0, or divide
     * with all selected values equal to 1. Otherwise the operation is executed on the matrix block dictionary.
     *
     * @param op         binary operator
     * @param v          right-hand values, indexed by the values of {@code colIndexes}
     * @param colIndexes column indexes to read from {@code v}
     * @return the resulting dictionary
     */
    @Override
    public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) {
        boolean same = false;
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            same = IdentityDictionary.allSelectedEqual(v, colIndexes, 0.0);
        } else if (op.fn instanceof Divide) {
            same = IdentityDictionary.allSelectedEqual(v, colIndexes, 1.0);
        }
        if (same) {
            return this;
        }
        return this.getMBDict().binOpRight(op, v, colIndexes);
    }

    /** Check whether every {@code v[colIndexes.get(i)]} equals {@code expected}; true for empty indexes. */
    private static boolean allSelectedEqual(double[] v, IColIndex colIndexes, double expected) {
        final int size = colIndexes.size();
        for (int i = 0; i < size; ++i) {
            if (v[colIndexes.get(i)] != expected) {
                return false;
            }
        }
        return true;
    }

    @Override
    public IDictionary clone() {
        return new IdentityDictionary(this.nRowCol, this.withEmpty);
    }

    @Override
    public IDictionary.DictType getDictType() {
        return IDictionary.DictType.Identity;
    }

    /**
     * Get the number of tuples in the dictionary.
     *
     * @param ncol number of columns, must equal {@code nRowCol}
     * @return {@code nRowCol}, plus one when {@code withEmpty} is set
     * @throws DMLCompressionException if {@code ncol} does not equal {@code nRowCol}
     */
    @Override
    public int getNumberOfValues(int ncol) {
        if (ncol != this.nRowCol) {
            throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns");
        }
        return this.nRowCol + (this.withEmpty ? 1 : 0);
    }

    /**
     * Get the number of columns in the dictionary.
     *
     * @param nrow number of rows, must equal {@code nRowCol} (plus one when {@code withEmpty} is set)
     * @return {@code nRowCol}
     * @throws DMLCompressionException if {@code nrow} does not match the number of rows
     */
    @Override
    public int getNumberOfColumns(int nrow) {
        final int expectedRows = this.nRowCol + (this.withEmpty ? 1 : 0);
        if (nrow != expectedRows) {
            throw new DMLCompressionException("Invalid call to get Number of columns assuming wrong number of rows");
        }
        return this.nRowCol;
    }

    /**
     * Sum of each dictionary row: 1.0 for every identity row and 0.0 for the optional empty row.
     *
     * @param nrColumns number of columns (not used)
     * @return the row sums
     */
    @Override
    public double[] sumAllRowsToDouble(int nrColumns) {
        // The trailing empty row (if any) keeps its default value of 0.0.
        final double[] ret = new double[this.nRowCol + (this.withEmpty ? 1 : 0)];
        Arrays.fill(ret, 0, this.nRowCol, 1.0);
        return ret;
    }

    /**
     * Sum of each dictionary row, followed by one extra entry holding the sum of the default tuple.
     *
     * @param defaultTuple the default tuple; its length must equal {@code nRowCol}
     * @return the row sums with the default tuple sum in the last position
     */
    @Override
    public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) {
        final double[] ret = new double[this.getNumberOfValues(defaultTuple.length) + 1];
        Arrays.fill(ret, 0, this.nRowCol, 1.0);
        double defaultSum = 0.0;
        for (double d : defaultTuple) {
            defaultSum += d;
        }
        ret[ret.length - 1] = defaultSum;
        return ret;
    }

    /**
     * Sum of each dictionary row after adding the reference tuple to it.
     *
     * @param reference the reference tuple; its length must equal {@code nRowCol}
     * @return {@code 1 + sum(reference)} for identity rows; for the optional empty row the 1 is subtracted again
     */
    @Override
    public double[] sumAllRowsToDoubleWithReference(double[] reference) {
        final double[] ret = new double[this.getNumberOfValues(reference.length)];
        double refSum = 0.0;
        for (double r : reference) {
            refSum += r;
        }
        Arrays.fill(ret, 1.0 + refSum);
        if (this.withEmpty) {
            // The empty row has no diagonal one to contribute.
            ret[ret.length - 1] += -1.0;
        }
        return ret;
    }

    /**
     * Sum of squares of each dictionary row. Since all values are 0 or 1 this equals the plain row sums, including
     * the 0.0 entry for the optional empty row.
     *
     * @param nrColumns number of columns (not used)
     * @return the row sums of squares
     */
    @Override
    public double[] sumAllRowsToDoubleSq(int nrColumns) {
        return this.sumAllRowsToDouble(nrColumns);
    }

    /**
     * Add the column sums to {@code c}: column {@code i} contributes {@code counts[i]}.
     *
     * @param c          output array, indexed by the values of {@code colIndexes}
     * @param counts     number of occurrences of each dictionary row
     * @param colIndexes mapping from dictionary column to output position
     */
    @Override
    public void colSum(double[] c, int[] counts, IColIndex colIndexes) {
        final int size = colIndexes.size();
        for (int i = 0; i < size; ++i) {
            c[colIndexes.get(i)] += counts[i];
        }
    }

    /** Column sum of squares; delegates to {@link #colSum(double[], int[], IColIndex)}. */
    @Override
    public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) {
        this.colSum(c, counts, colIndexes);
    }

    /**
     * Column product: sets every selected output cell to 0.0.
     *
     * @param res        output array, indexed by the values of {@code colIndexes}
     * @param counts     number of occurrences of each dictionary row (not used)
     * @param colIndexes mapping from dictionary column to output position
     */
    @Override
    public void colProduct(double[] res, int[] counts, IColIndex colIndexes) {
        final int size = colIndexes.size();
        for (int i = 0; i < size; ++i) {
            res[colIndexes.get(i)] = 0.0;
        }
    }

    /**
     * Sum of all values scaled by the counts: the total of {@code counts}, minus the last count when
     * {@code withEmpty} is set.
     *
     * @param counts number of occurrences of each dictionary row
     * @param ncol   number of columns (not used)
     * @return the scaled sum
     */
    @Override
    public double sum(int[] counts, int ncol) {
        double total = 0.0;
        for (int count : counts) {
            total += count;
        }
        if (this.withEmpty) {
            total -= counts[counts.length - 1];
        }
        return total;
    }

    /** Sum of squares; delegates to {@link #sum(int[], int)}. */
    @Override
    public double sumSq(int[] counts, int ncol) {
        return this.sum(counts, ncol);
    }

    /**
     * Slice out a column range.
     *
     * @param idxStart                first column (inclusive)
     * @param idxEnd                  last column (exclusive)
     * @param previousNumberOfColumns previous number of columns (not used)
     * @return a new {@link IdentityDictionary} when the full range is selected, otherwise the result of
     *         {@link IdentityDictionarySlice#create}
     */
    @Override
    public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) {
        final boolean fullRange = idxStart == 0 && idxEnd == this.nRowCol;
        if (fullRange) {
            return new IdentityDictionary(this.nRowCol, this.withEmpty);
        }
        return IdentityDictionarySlice.create(this.nRowCol, this.withEmpty, idxStart, idxEnd);
    }

    /** Number of non-zero values scaled by counts; the result of {@link #sum(int[], int)} cast to long. */
    @Override
    public long getNumberNonZeros(int[] counts, int nCol) {
        return (long)this.sum(counts, nCol);
    }

    /**
     * Count non-zeros per column given the row counts.
     *
     * @param counts number of occurrences of each dictionary row
     * @return {@code counts} itself, or a copy truncated to {@code nRowCol} entries when {@code withEmpty} is set
     */
    @Override
    public int[] countNNZZeroColumns(int[] counts) {
        return this.withEmpty ? Arrays.copyOf(counts, this.nRowCol) : counts;
    }

    /** Same as {@link #addToEntry(double[], int, int, int, int)} with {@code rep = 1}. */
    @Override
    public final void addToEntry(double[] v, int fr, int to, int nCol) {
        this.addToEntry(v, fr, to, nCol, 1);
    }

    /**
     * Add dictionary row {@code fr}, repeated {@code rep} times, to row {@code to} of {@code v}.
     *
     * <p>
     * Only the diagonal cell is touched. When {@code withEmpty} is set, rows at or beyond {@code nRowCol} add
     * nothing.
     *
     * @param v    target values with {@code nCol} columns per row
     * @param fr   dictionary row to read
     * @param to   target row to add into
     * @param nCol number of columns in {@code v}
     * @param rep  number of repetitions
     */
    @Override
    public void addToEntry(double[] v, int fr, int to, int nCol, int rep) {
        if (!this.withEmpty || fr < this.nRowCol) {
            v[to * nCol + fr] += rep;
        }
    }

    /**
     * Add eight dictionary rows {@code f1..f8} into the eight target rows {@code t1..t8} of {@code v}.
     */
    @Override
    public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) {
        if (this.withEmpty) {
            this.addToEntryVectorizedWithEmpty(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
        } else {
            this.addToEntryVectorizedNorm(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
        }
    }

    /** Vectorized add that skips source rows at or beyond {@code nRowCol} (the empty row). */
    private void addToEntryVectorizedWithEmpty(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) {
        if (f1 < this.nRowCol) {
            v[t1 * nCol + f1] += 1.0;
        }
        if (f2 < this.nRowCol) {
            v[t2 * nCol + f2] += 1.0;
        }
        if (f3 < this.nRowCol) {
            v[t3 * nCol + f3] += 1.0;
        }
        if (f4 < this.nRowCol) {
            v[t4 * nCol + f4] += 1.0;
        }
        if (f5 < this.nRowCol) {
            v[t5 * nCol + f5] += 1.0;
        }
        if (f6 < this.nRowCol) {
            v[t6 * nCol + f6] += 1.0;
        }
        if (f7 < this.nRowCol) {
            v[t7 * nCol + f7] += 1.0;
        }
        if (f8 < this.nRowCol) {
            v[t8 * nCol + f8] += 1.0;
        }
    }

    /** Vectorized add without any bounds check on the source rows. */
    private void addToEntryVectorizedNorm(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) {
        v[t1 * nCol + f1] += 1.0;
        v[t2 * nCol + f2] += 1.0;
        v[t3 * nCol + f3] += 1.0;
        v[t4 * nCol + f4] += 1.0;
        v[t5 * nCol + f5] += 1.0;
        v[t6 * nCol + f6] += 1.0;
        v[t7 * nCol + f7] += 1.0;
        v[t8 * nCol + f8] += 1.0;
    }

    @Override
    public MatrixBlockDictionary getMBDict() {
        return this.getMBDict(this.nRowCol);
    }

    /**
     * Materialize this dictionary as a sparse {@link MatrixBlockDictionary}.
     *
     * @param nCol number of columns (not used, {@code nRowCol} is used instead)
     * @return the matrix block dictionary with {@code nRowCol} non-zeros
     */
    @Override
    public MatrixBlockDictionary createMBDict(int nCol) {
        final SparseBlock sb;
        final int rows;
        if (this.withEmpty) {
            sb = SparseBlockFactory.createIdentityMatrixWithEmptyRow(this.nRowCol);
            rows = this.nRowCol + 1;
        } else {
            sb = SparseBlockFactory.createIdentityMatrix(this.nRowCol);
            rows = this.nRowCol;
        }
        final MatrixBlock identity = new MatrixBlock(rows, this.nRowCol, this.nRowCol, sb);
        return new MatrixBlockDictionary(identity);
    }

    /**
     * Serialize the dictionary as one type byte followed by one int.
     *
     * <p>
     * The int holds {@code nRowCol}; it is written negated when {@code withEmpty} is set so that the flag survives a
     * round trip without changing the 5 byte layout. Dictionaries without the empty row are written exactly as
     * before.
     *
     * @param out the output to write to
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeByte(DictionaryFactory.Type.IDENTITY.ordinal());
        out.writeInt(this.withEmpty ? -this.nRowCol : this.nRowCol);
    }

    /**
     * Deserialize a dictionary written by {@link #write(DataOutput)}, after the type byte has been consumed.
     *
     * @param in the input to read from
     * @return the dictionary; a negative stored size marks a dictionary with the additional empty row
     * @throws IOException if reading fails
     */
    public static IdentityDictionary read(DataInput in) throws IOException {
        final int encoded = in.readInt();
        if (encoded < 0) {
            return new IdentityDictionary(-encoded, true);
        }
        return new IdentityDictionary(encoded);
    }

    /** Exact serialized size in bytes: one type byte plus one int. */
    @Override
    public long getExactSizeOnDisk() {
        return 5L;
    }

    /**
     * Pre-aggregate with a dense right-hand side by copying the selected cells of {@code b}.
     *
     * <p>
     * Output row {@code i} holds {@code b[colIndexes.get(i) * cut + aggregateColumns.get(j)]} for each {@code j}.
     *
     * @param numVals          number of rows in the returned dictionary
     * @param colIndexes       column indexes of this dictionary, selecting the rows of {@code b}
     * @param aggregateColumns columns of {@code b} to keep
     * @param b                dense right-hand values in row-major layout
     * @param cut              row stride of {@code b}
     * @return a dense {@link MatrixBlockDictionary} with {@code numVals} rows
     */
    @Override
    public IDictionary preaggValuesFromDense(int numVals, IColIndex colIndexes, IColIndex aggregateColumns, double[] b, int cut) {
        final int nSource = colIndexes.size();
        final int nAgg = aggregateColumns.size();
        final double[] ret = new double[nAgg * numVals];
        int off = 0;
        for (int i = 0; i < nSource; ++i) {
            final int offB = colIndexes.get(i) * cut;
            for (int j = 0; j < nAgg; ++j) {
                ret[off++] = b[offB + aggregateColumns.get(j)];
            }
        }
        return new MatrixBlockDictionary(new MatrixBlock(numVals, nAgg, ret));
    }

    /** Sparsity of the dictionary: one non-zero per column, divided by the number of rows. */
    @Override
    public double getSparsity() {
        final int rows = this.withEmpty ? this.nRowCol + 1 : this.nRowCol;
        return 1.0 / (double)rows;
    }

    /**
     * Add {@code v} times dictionary row {@code dictIdx} to {@code ret}; only the diagonal cell is touched and the
     * empty row adds nothing.
     *
     * @param v       scalar to multiply with
     * @param ret     output values
     * @param off     offset into {@code ret}
     * @param dictIdx dictionary row
     * @param cols    mapping from dictionary column to output column
     */
    @Override
    public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) {
        if (!this.withEmpty || dictIdx < this.nRowCol) {
            ret[off + cols.get(dictIdx)] += v;
        }
    }

    /**
     * Multiply a dense left dictionary with this identity and add into {@code result}.
     *
     * @param left      dense left values, read at {@code i + j * rowsLeft.size()}
     * @param rowsLeft  output row for each left index
     * @param colsRight output column for each column of this dictionary
     * @param result    dense result block
     */
    @Override
    public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) {
        final int leftSide = rowsLeft.size();
        final int colsOut = result.getNumColumns();
        final int commonDim = Math.min(left.length / leftSide, this.nRowCol);
        final double[] resV = result.getDenseBlockValues();
        for (int i = 0; i < leftSide; ++i) {
            final int offOut = rowsLeft.get(i) * colsOut;
            for (int j = 0; j < commonDim; ++j) {
                resV[offOut + colsRight.get(j)] += left[i + j * leftSide];
            }
        }
    }

    /**
     * Same as {@link #MMDictDense} but each contribution is multiplied by {@code scaling[j]}, and the inner loop
     * always runs over all {@code nRowCol} columns.
     *
     * @param left      dense left values, read at {@code i + j * rowsLeft.size()}
     * @param rowsLeft  output row for each left index
     * @param colsRight output column for each column of this dictionary
     * @param result    dense result block
     * @param scaling   per-row scaling factors
     */
    @Override
    public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling) {
        final int leftSide = rowsLeft.size();
        final int resCols = result.getNumColumns();
        final double[] resV = result.getDenseBlockValues();
        for (int i = 0; i < leftSide; ++i) {
            final int offOut = rowsLeft.get(i) * resCols;
            for (int j = 0; j < this.nRowCol; ++j) {
                resV[offOut + colsRight.get(j)] += left[i + j * leftSide] * (double)scaling[j];
            }
        }
    }

    /**
     * Compare with another dictionary: identity dictionaries are equal when size and empty-row flag match; anything
     * else is compared through the materialized matrix block dictionary.
     */
    @Override
    public boolean equals(IDictionary o) {
        if (o instanceof IdentityDictionary) {
            final IdentityDictionary other = (IdentityDictionary)o;
            if (other.nRowCol == this.nRowCol && other.withEmpty == this.withEmpty) {
                return true;
            }
        }
        return this.getMBDict().equals(o);
    }

    /**
     * Pre-aggregate with a sparse right-hand side keeping all right columns: row {@code thisCols.get(h)} of
     * {@code b} is copied into row {@code h} of the result.
     */
    @Override
    protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, int nColRight) {
        final int nThisCols = thisCols.size();
        final SparseBlockMCSR ret = new SparseBlockMCSR(numVals);
        for (int h = 0; h < nThisCols; ++h) {
            final int colIdx = thisCols.get(h);
            if (b.isEmpty(colIdx)) {
                continue;
            }
            final double[] sValues = b.values(colIdx);
            final int[] sIndexes = b.indexes(colIdx);
            final int sPos = b.pos(colIdx);
            final int sEnd = b.size(colIdx) + sPos;
            for (int i = sPos; i < sEnd; ++i) {
                ret.add(h, sIndexes[i], sValues[i]);
            }
        }
        // Non-zero count is unknown at construction (-1) and recomputed afterwards.
        final MatrixBlock retB = new MatrixBlock(numVals, nColRight, -1L, ret);
        retB.recomputeNonZeros();
        return MatrixBlockDictionary.create(retB, false);
    }

    /**
     * Pre-aggregate with a sparse right-hand side keeping only {@code aggregateColumns}: values of row
     * {@code thisCols.get(h)} of {@code b} are written into row {@code h} at the position of their column within
     * {@code aggregateColumns}.
     */
    @Override
    protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns) {
        final int nThisCols = thisCols.size();
        final int nAgg = aggregateColumns.size();
        final SparseBlockMCSR ret = new SparseBlockMCSR(numVals);
        for (int h = 0; h < nThisCols; ++h) {
            final int colIdx = thisCols.get(h);
            if (b.isEmpty(colIdx)) {
                continue;
            }
            final double[] sValues = b.values(colIdx);
            final int[] sIndexes = b.indexes(colIdx);
            final int sPos = b.pos(colIdx);
            final int sEnd = b.size(colIdx) + sPos;
            int retIdx = 0;
            for (int i = sPos; i < sEnd; ++i) {
                // Advance to the first aggregate column that is not smaller than the current sparse index.
                while (retIdx < nAgg && aggregateColumns.get(retIdx) < sIndexes[i]) {
                    ++retIdx;
                }
                if (retIdx == nAgg) {
                    break;
                }
                // NOTE(review): no equality check here; presumably every sparse index of the selected rows is
                // contained in aggregateColumns - confirm against the caller.
                ret.add(h, retIdx, sValues[i]);
            }
        }
        // Non-zero count is unknown at construction (-1) and recomputed afterwards.
        final MatrixBlock retB = new MatrixBlock(numVals, aggregateColumns.size(), -1L, ret);
        retB.recomputeNonZeros();
        return MatrixBlockDictionary.create(retB, false);
    }

    @Override
    public String getString(int colIndexes) {
        return "IdentityMatrix of size: " + this.nRowCol + " with empty: " + this.withEmpty;
    }

    @Override
    public String toString() {
        return "IdentityMatrix of size: " + this.nRowCol + " with empty: " + this.withEmpty;
    }
}

