/*
 * Decompiled with CFR 0.152.
 */
package io.questdb.griffin.engine.functions.array;

import io.questdb.cairo.CairoConfiguration;
import io.questdb.cairo.CairoException;
import io.questdb.cairo.ColumnType;
import io.questdb.cairo.arr.ArrayView;
import io.questdb.cairo.arr.DerivedArrayView;
import io.questdb.cairo.arr.DirectArray;
import io.questdb.cairo.sql.ArrayFunction;
import io.questdb.cairo.sql.Function;
import io.questdb.cairo.sql.Record;
import io.questdb.cairo.vm.api.MemoryA;
import io.questdb.griffin.FunctionFactory;
import io.questdb.griffin.PlanSink;
import io.questdb.griffin.SqlException;
import io.questdb.griffin.SqlExecutionContext;
import io.questdb.griffin.engine.functions.BinaryFunction;
import io.questdb.std.IntList;
import io.questdb.std.Misc;
import io.questdb.std.ObjList;

/**
 * Factory for the {@code matmul(D[]D[])} SQL function, which multiplies two double arrays as matrices.
 * <p>
 * Each argument must be declared as a one- or two-dimensional double array.
 * <ul>
 *   <li>A one-dimensional left argument is viewed as a matrix by prepending a dimension.</li>
 *   <li>A one-dimensional right argument is viewed as a matrix by appending a dimension.</li>
 * </ul>
 * Presumably these views are a row vector and a column vector respectively. The result is always
 * a two-dimensional double array.
 */
public class DoubleMatrixMultiplyFunctionFactory
implements FunctionFactory {
    @Override
    public String getSignature() {
        return "matmul(D[]D[])";
    }

    @Override
    public Function newInstance(int position, ObjList<Function> args, IntList argPositions, CairoConfiguration configuration, SqlExecutionContext sqlExecutionContext) throws SqlException {
        return new Func(configuration, args.getQuick(0), args.getQuick(1), argPositions.getQuick(0), argPositions.getQuick(1));
    }

    /**
     * Evaluates {@code matmul(left, right)} for each record.
     * <p>
     * The output buffer and the derived views are reused across calls, which is why this function
     * reports itself as not thread-safe.
     */
    private static class Func
    extends ArrayFunction
    implements BinaryFunction {
        // Output buffer, overwritten on every getArray() call and freed in close().
        private final DirectArray arrayOut;
        private final Function leftArg;
        // SQL position of the left argument; used for runtime shape errors.
        private final int leftArgPos;
        // Non-null only when the left argument is 1-D; used to view it as a 2-D matrix.
        private final DerivedArrayView leftDerived;
        private final Function rightArg;
        // Non-null only when the right argument is 1-D; used to view it as a 2-D matrix.
        private final DerivedArrayView rightDerived;

        /**
         * Creates the function and validates the declared dimensionality of both arguments.
         *
         * @param configuration Cairo configuration used to allocate the output array
         * @param leftArg       left matrix (or vector) argument
         * @param rightArg      right matrix (or vector) argument
         * @param leftArgPos    SQL position of the left argument, used in error messages
         * @param rightArgPos   SQL position of the right argument, used in error messages
         * @throws SqlException if either argument is not declared as a 1-D or 2-D array. The error
         *                      position points at the offending argument. On failure, both
         *                      arguments and the output buffer are closed before the exception
         *                      propagates.
         */
        public Func(CairoConfiguration configuration, Function leftArg, Function rightArg, int leftArgPos, int rightArgPos) throws SqlException {
            try {
                this.leftArg = leftArg;
                this.rightArg = rightArg;
                this.arrayOut = new DirectArray(configuration);
                this.leftArgPos = leftArgPos;
                int nDimsLeft = ColumnType.decodeArrayDimensionality(leftArg.getType());
                int nDimsRight = ColumnType.decodeArrayDimensionality(rightArg.getType());
                DerivedArrayView leftDerived = null;
                if (nDimsLeft == 1) {
                    leftDerived = new DerivedArrayView();
                } else if (nDimsLeft != 2) {
                    // Point the error at the left argument, since that is the one being rejected.
                    throw SqlException.position(leftArgPos).put("left array is not one or two-dimensional");
                }
                DerivedArrayView rightDerived = null;
                if (nDimsRight == 1) {
                    rightDerived = new DerivedArrayView();
                } else if (nDimsRight != 2) {
                    throw SqlException.position(rightArgPos).put("right array is not one or two-dimensional");
                }
                this.rightDerived = rightDerived;
                this.leftDerived = leftDerived;
                // The result is always a 2-D double array, even for vector inputs.
                this.type = ColumnType.encodeArrayType(ColumnType.DOUBLE, 2);
            }
            catch (Throwable th) {
                // Release the argument functions and the output buffer before propagating.
                this.close();
                throw th;
            }
        }

        @Override
        public void close() {
            BinaryFunction.super.close();
            Misc.free(this.arrayOut);
        }

        /**
         * Computes the matrix product of the two argument arrays for the given record.
         * <p>
         * The product is written row by row into the reused output buffer.
         *
         * @param rec record to evaluate the arguments against
         * @return the 2-D product array. It is set to null if either argument is null.
         * @throws CairoException if the left matrix's column count differs from the right
         *                        matrix's row count
         */
        @Override
        public ArrayView getArray(Record rec) {
            ArrayView left = this.leftArg.getArray(rec);
            ArrayView right = this.rightArg.getArray(rec);
            if (left.isNull() || right.isNull()) {
                this.arrayOut.ofNull();
                return this.arrayOut;
            }
            // Promote 1-D inputs to 2-D views so the loop below can treat both sides as matrices.
            if (this.leftDerived != null) {
                this.leftDerived.of(left);
                this.leftDerived.prependDimensions(1);
                left = this.leftDerived;
            }
            if (this.rightDerived != null) {
                this.rightDerived.of(right);
                this.rightDerived.appendDimensions(1);
                right = this.rightDerived;
            }
            int commonDimLen = left.getDimLen(1);
            if (right.getDimLen(0) != commonDimLen) {
                throw CairoException.nonCritical().position(this.leftArgPos).put("left array row length doesn't match right array column length ").put("[leftRowLen=").put(commonDimLen).put(", rightColLen=").put(right.getDimLen(0)).put(']');
            }
            int outRowCount = left.getDimLen(0);
            int outColCount = right.getDimLen(1);
            int leftStride0 = left.getStride(0);
            int leftStride1 = left.getStride(1);
            int rightStride0 = right.getStride(0);
            int rightStride1 = right.getStride(1);
            this.arrayOut.setType(this.type);
            this.arrayOut.setDimLen(0, outRowCount);
            this.arrayOut.setDimLen(1, outColCount);
            this.arrayOut.applyShape(this.leftArgPos);
            MemoryA memOut = this.arrayOut.startMemoryA();
            for (int rowOut = 0; rowOut < outRowCount; ++rowOut) {
                // Flat offset of the current left row; invariant across the inner loops.
                final int leftRowBase = leftStride0 * rowOut;
                for (int colOut = 0; colOut < outColCount; ++colOut) {
                    // Flat offset of the current right column; invariant across the inner loop.
                    final int rightColBase = rightStride1 * colOut;
                    double sum = 0.0;
                    for (int commonDim = 0; commonDim < commonDimLen; ++commonDim) {
                        sum += left.getDouble(leftRowBase + leftStride1 * commonDim)
                                * right.getDouble(rightStride0 * commonDim + rightColBase);
                    }
                    // Output cells are appended sequentially, so rows come out in order.
                    memOut.putDouble(sum);
                }
            }
            return this.arrayOut;
        }

        @Override
        public Function getLeft() {
            return this.leftArg;
        }

        @Override
        public String getName() {
            return "matmul";
        }

        @Override
        public Function getRight() {
            return this.rightArg;
        }

        /**
         * Returns {@code false}.
         * <p>
         * {@link #getArray(Record)} mutates the shared output buffer and derived views.
         */
        @Override
        public boolean isThreadSafe() {
            return false;
        }

        @Override
        public void toPlan(PlanSink sink) {
            sink.val(this.getName()).val('(').val(this.leftArg).val(", ").val(this.rightArg).val(')');
        }
    }
}

