/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.sql2rel;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.calcite.linq4j.Nullness;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlDynamicParam;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelectKeyword;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.SqlVisitor;
import org.apache.calcite.sql.validate.AggregatingSelectScope;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.sql2rel.AuxiliaryConverter;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Util;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableMap;
import org.apache.flink.calcite.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;

/**
 * Converts the aggregate-related parts of a query block into the inputs and
 * {@link AggregateCall}s of an aggregate relational expression. Those parts are
 * grouping expressions, measure expressions and aggregate function calls.
 *
 * <p>As a {@link SqlVisitor}, visiting an expression finds the aggregate calls in it
 * and translates each one via {@code translateAgg}. Translation asserts that this
 * converter is installed as {@code bb.agg}. It temporarily sets {@code bb.agg} to
 * {@code null} while converting operands, and restores it in a {@code finally} block.
 *
 * <p>Instances are obtained from the {@code create} factory methods.
 */
class AggConverter implements SqlVisitor<Void> {
    private final SqlToRelConverter.Blackboard bb;

    /**
     * Maps the SQL text of a select item (with any {@code AS} stripped) to its name.
     * The name is either the explicit alias or one derived by
     * {@link SqlValidatorUtil#alias(SqlNode, int)}.
     */
    private final Map<String, String> nameMap;

    /** Grouping expressions in insertion order; a position in this list is a group ordinal. */
    final SqlNodeList groupExprs = new SqlNodeList(SqlParserPos.ZERO);

    /**
     * Auxiliary calls derived from grouping expressions. Each is mapped to the ordinal of
     * the grouping expression it belongs to and to the converter that translates it.
     */
    private final Map<SqlNode, Ord<AuxiliaryConverter>> auxiliaryGroupExprs = new HashMap<>();

    /** Measure expressions registered via {@link #addMeasureExpr}. */
    private final SqlNodeList measureExprs = new SqlNodeList(SqlParserPos.ZERO);

    /**
     * Converted input expressions of the aggregate, each with an optional field name.
     * An expression's position in this list is the input index used to refer to it.
     * Aggregate arguments, filter conditions, distinct keys and sort keys are appended
     * on demand.
     */
    final PairList<RexNode, @Nullable String> convertedInputExprs = PairList.of();

    /** Aggregate calls; passed to {@link RexBuilder#addAggCall}, which adds to it. */
    final List<AggregateCall> aggCalls = new ArrayList<>();

    /**
     * Maps the outermost SQL call of each translated aggregate to the expression that
     * refers to its result.
     */
    private final Map<SqlNode, RexNode> aggMapping = new HashMap<>();

    /**
     * Passed to {@link RexBuilder#addAggCall}. It presumably lets equal aggregate calls
     * share a single result (confirm against RexBuilder).
     */
    private final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();

    /**
     * Validator used to build a {@link FlinkSqlCallBinding} for aggregate operands.
     * It is null for converters built via {@link #create(SqlToRelConverter.Blackboard)}.
     * NOTE(review): translateAgg still passes it to FlinkSqlCallBinding in that case —
     * confirm this is safe.
     */
    private final SqlValidator validator;

    /**
     * Scope used to build a {@link FlinkSqlCallBinding} for aggregate operands.
     * It is null for converters built via {@link #create(SqlToRelConverter.Blackboard)}.
     */
    private final AggregatingSelectScope scope;

    /**
     * Whether the visitor is inside the first operand of an {@code OVER} call. While set,
     * a windowed aggregate found there is not translated as a group aggregate.
     */
    boolean inOver = false;

    /** Creates a converter with no validator and no scope. */
    private AggConverter(SqlToRelConverter.Blackboard bb, ImmutableMap<String, String> nameMap) {
        this(bb, nameMap, null, null);
    }

    private AggConverter(
            SqlToRelConverter.Blackboard bb,
            ImmutableMap<String, String> nameMap,
            SqlValidator validator,
            AggregatingSelectScope scope) {
        this.bb = bb;
        this.nameMap = nameMap;
        this.validator = validator;
        this.scope = scope;
    }

    /**
     * Creates a converter with no select-item names, no validator and no scope.
     * {@link #getResolved()} throws {@link UnsupportedOperationException} on the result.
     */
    static AggConverter create(SqlToRelConverter.Blackboard bb) {
        return new AggConverter(bb, ImmutableMap.of());
    }

    /**
     * Creates a converter for an aggregating SELECT.
     *
     * <p>Records a name for each select item, keyed by the item's SQL text. Group
     * expressions, measures and aggregate calls whose text matches a select item then
     * reuse that name. The scope's resolved grouping information is captured now, and
     * {@link #getResolved()} returns it.
     */
    static AggConverter create(
            SqlToRelConverter.Blackboard bb, AggregatingSelectScope scope, SqlValidator validator) {
        final Map<String, String> nameMap = new HashMap<>();
        Ord.forEach(scope.getNode().getSelectList(), (selectItem, i) -> {
            final String name;
            if (SqlUtil.isCallTo(selectItem, SqlStdOperatorTable.AS)) {
                final SqlCall call = (SqlCall) selectItem;
                // Key on the aliased expression, not on the AS call itself.
                selectItem = call.operand(0);
                name = call.operand(1).toString();
            } else {
                name = SqlValidatorUtil.alias(selectItem, i);
            }
            nameMap.put(selectItem.toString(), name);
        });
        final AggregatingSelectScope.Resolved resolved =
                (AggregatingSelectScope.Resolved) scope.resolved.get();
        return new AggConverter(bb, ImmutableMap.copyOf(nameMap), validator, scope) {
            @Override
            AggregatingSelectScope.Resolved getResolved() {
                return resolved;
            }
        };
    }

    /**
     * Adds a grouping expression unless a deeply-equal one is already registered. Any
     * auxiliary calls derived from it by
     * {@link SqlStdOperatorTable#convertGroupToAuxiliaryCalls} are also registered.
     *
     * @param expr grouping expression
     * @return ordinal of the new or existing grouping expression
     */
    int addGroupExpr(SqlNode expr) {
        final int ref = lookupGroupExpr(expr);
        if (ref >= 0) {
            return ref;
        }
        final int index = groupExprs.size();
        groupExprs.add(expr);
        final String name = nameMap.get(expr.toString());
        final RexNode convExpr = bb.convertExpression(expr);
        addExpr(convExpr, name);

        if (expr instanceof SqlCall) {
            final SqlCall call = (SqlCall) expr;
            SqlStdOperatorTable.convertGroupToAuxiliaryCalls(
                    call, (node, converter) -> addAuxiliaryGroupExpr(node, index, converter));
        }
        return index;
    }

    /**
     * Registers an auxiliary call for the grouping expression at ordinal {@code index}.
     * If a deeply-equal call is already registered, the first registration is kept.
     */
    void addAuxiliaryGroupExpr(SqlNode node, int index, AuxiliaryConverter converter) {
        for (SqlNode existing : auxiliaryGroupExprs.keySet()) {
            if (existing.equalsDeep(node, Litmus.IGNORE)) {
                return;
            }
        }
        auxiliaryGroupExprs.put(node, Ord.of(index, converter));
    }

    /**
     * Adds a measure expression and its converted form to the aggregate's inputs.
     *
     * @param expr measure expression
     * @return {@code false} if a deeply-equal measure was already registered,
     *     {@code true} if it was added
     */
    boolean addMeasureExpr(SqlNode expr) {
        if (isMeasureExpr(expr)) {
            return false;
        }
        measureExprs.add(expr);
        final String name = nameMap.get(expr.toString());
        final RexNode convExpr = bb.convertExpression(expr);
        addExpr(convExpr, name);
        return true;
    }

    /**
     * Appends a converted input expression.
     *
     * <p>If no name is given and the expression is an input reference, it takes the name
     * of the referenced field of {@code bb.root()}. If the name is already taken, the
     * expression is added with a null name instead.
     */
    private void addExpr(RexNode expr, @Nullable String name) {
        if (name == null && expr instanceof RexInputRef) {
            final int fieldIndex = ((RexInputRef) expr).getIndex();
            name = bb.root().getRowType().getFieldList().get(fieldIndex).getName();
        }
        if (convertedInputExprs.rightList().contains(name)) {
            name = null;
        }
        convertedInputExprs.add(expr, name);
    }

    @Override
    public Void visit(SqlIdentifier id) {
        return null;
    }

    @Override
    public Void visit(SqlNodeList nodeList) {
        nodeList.forEach(this::visitNode);
        return null;
    }

    @Override
    public Void visit(SqlLiteral lit) {
        return null;
    }

    @Override
    public Void visit(SqlDataTypeSpec type) {
        return null;
    }

    @Override
    public Void visit(SqlDynamicParam param) {
        return null;
    }

    @Override
    public Void visit(SqlIntervalQualifier intervalQualifier) {
        return null;
    }

    /**
     * Finds aggregate calls inside {@code call} and translates them as group aggregates.
     * Windowed aggregates are skipped, but the operands they contain are still searched.
     * Sub-queries are not searched.
     */
    @Override
    public Void visit(SqlCall call) {
        switch (call.getKind()) {
            case FILTER:
            case IGNORE_NULLS:
            case RESPECT_NULLS:
            case WITHIN_DISTINCT:
            case WITHIN_GROUP:
                // Wrappers around an aggregate call; translateAgg peels them off.
                translateAgg(call);
                return null;
            case SELECT:
                // Aggregates inside sub-queries are not collected here.
                return null;
            default:
                break;
        }
        final boolean prevInOver = inOver;
        if (call.getOperator().getKind() == SqlKind.OVER) {
            final List<SqlNode> operandList = call.getOperandList();
            assert operandList.size() == 2;
            // First operand: the windowed function. Flag it so that it is not
            // translated as a group aggregate.
            inOver = true;
            operandList.get(0).accept(this);
            // Second operand: the window specification, visited normally.
            inOver = false;
            operandList.get(1).accept(this);
            return null;
        }
        if (call.getOperator().isAggregator()) {
            if (!inOver) {
                translateAgg(call);
                return null;
            }
            // Windowed aggregate: skip it, but look for group aggregates in its operands.
            inOver = false;
        }
        for (SqlNode operand : call.getOperandList()) {
            if (operand != null) {
                operand.accept(this);
            }
        }
        inOver = prevInOver;
        return null;
    }

    /** Translates {@code call}, which is its own outermost call, with no collected clauses. */
    private void translateAgg(SqlCall call) {
        translateAgg(call, null, null, null, false, call);
    }

    /**
     * Translates an aggregate call into an {@link AggregateCall}. The expression that
     * refers to its result is recorded in {@code aggMapping}, keyed by {@code outerCall}.
     *
     * <p>Wrapper calls are peeled off recursively, and their clauses are accumulated in
     * the parameters. The wrappers are {@code FILTER}, {@code WITHIN DISTINCT},
     * {@code WITHIN GROUP}, {@code IGNORE NULLS} and {@code RESPECT NULLS}.
     *
     * <p>Some functions are rewritten first:
     * <ul>
     *   <li>{@code COUNTIF(c)} becomes {@code COUNT(*)}, with {@code c} ANDed into the filter.
     *   <li>{@code STRING_AGG} and {@code GROUP_CONCAT} become {@code LISTAGG}. For
     *       {@code GROUP_CONCAT}, the SEPARATOR value becomes the last operand.
     *   <li>A trailing order list on {@code STRING_AGG}, {@code GROUP_CONCAT},
     *       {@code ARRAY_AGG} or {@code ARRAY_CONCAT_AGG} becomes {@code orderList}.
     * </ul>
     *
     * @param call aggregate call, or a wrapper around one
     * @param filter FILTER condition collected so far, or null
     * @param distinctList WITHIN DISTINCT expressions, or null
     * @param orderList ordering expressions, or null
     * @param ignoreNulls whether IGNORE NULLS was specified
     * @param outerCall outermost call; used as the key in {@code aggMapping} and, via
     *     its SQL text, to look up the result name
     */
    private void translateAgg(
            SqlCall call,
            @Nullable SqlNode filter,
            @Nullable SqlNodeList distinctList,
            @Nullable SqlNodeList orderList,
            boolean ignoreNulls,
            SqlCall outerCall) {
        assert bb.agg == this;
        final RexBuilder rexBuilder = bb.getRexBuilder();
        final List<SqlNode> operands = call.getOperandList();
        final SqlParserPos pos = call.getParserPosition();
        switch (call.getKind()) {
            case FILTER: {
                assert filter == null;
                translateAgg((SqlCall) call.operand(0), call.operand(1),
                        distinctList, orderList, ignoreNulls, outerCall);
                return;
            }
            case WITHIN_DISTINCT: {
                assert orderList == null;
                translateAgg((SqlCall) call.operand(0), filter,
                        (SqlNodeList) call.operand(1), orderList, ignoreNulls, outerCall);
                return;
            }
            case WITHIN_GROUP: {
                assert orderList == null;
                translateAgg((SqlCall) call.operand(0), filter, distinctList,
                        (SqlNodeList) call.operand(1), ignoreNulls, outerCall);
                return;
            }
            case IGNORE_NULLS: {
                ignoreNulls = true;
                // fall through: unwrap exactly like RESPECT NULLS, with the flag set
            }
            case RESPECT_NULLS: {
                translateAgg((SqlCall) call.operand(0), filter, distinctList, orderList,
                        ignoreNulls, outerCall);
                return;
            }
            case COUNTIF: {
                // COUNTIF(c) is COUNT(*) FILTER (WHERE c).
                final SqlCall call2 = SqlStdOperatorTable.COUNT.createCall(pos, SqlIdentifier.star(pos));
                final SqlNode filter2 = SqlUtil.andExpressions(filter, call.operand(0));
                translateAgg(call2, filter2, distinctList, orderList, ignoreNulls, outerCall);
                return;
            }
            case STRING_AGG: {
                final List<SqlNode> operands2;
                if (!operands.isEmpty() && Util.last(operands) instanceof SqlNodeList) {
                    orderList = (SqlNodeList) Util.last(operands);
                    operands2 = Util.skipLast(operands);
                } else {
                    operands2 = operands;
                }
                final SqlCall call2 = SqlStdOperatorTable.LISTAGG.createCall(
                        call.getFunctionQuantifier(), pos, operands2);
                translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
                return;
            }
            case GROUP_CONCAT: {
                // GROUP_CONCAT(s ORDER BY x SEPARATOR ',') is LISTAGG(s, ',') ordered by x.
                final List<SqlNode> operands2 = new ArrayList<>(operands);
                final @Nullable SqlNode separator;
                if (!operands2.isEmpty() && Util.last(operands2).getKind() == SqlKind.SEPARATOR) {
                    final SqlCall sepCall = (SqlCall) operands2.remove(operands2.size() - 1);
                    separator = sepCall.operand(0);
                } else {
                    separator = null;
                }
                if (!operands2.isEmpty() && Util.last(operands2) instanceof SqlNodeList) {
                    orderList = (SqlNodeList) operands2.remove(operands2.size() - 1);
                }
                if (separator != null) {
                    operands2.add(separator);
                }
                final SqlCall call2 = SqlStdOperatorTable.LISTAGG.createCall(
                        call.getFunctionQuantifier(), pos, operands2);
                translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
                return;
            }
            case ARRAY_AGG:
            case ARRAY_CONCAT_AGG: {
                if (!operands.isEmpty() && Util.last(operands) instanceof SqlNodeList) {
                    orderList = (SqlNodeList) Util.last(operands);
                    final SqlCall call2 = call.getOperator().createCall(
                            call.getFunctionQuantifier(), pos, Util.skipLast(operands));
                    translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
                    return;
                }
                break;
            }
            default:
                break;
        }

        final List<Integer> args = new ArrayList<>();
        int filterArg = -1;
        final @Nullable ImmutableBitSet distinctKeys;
        try {
            // Switch out of aggregate mode while converting operands; restored below.
            bb.agg = null;
            final FlinkSqlCallBinding binding = new FlinkSqlCallBinding(validator, scope, call);
            for (SqlNode operand : binding.operands()) {
                if (operand instanceof SqlIdentifier && ((SqlIdentifier) operand).isStar()) {
                    // COUNT(*): the star contributes no argument.
                    assert call.operandCount() == 1;
                    assert args.isEmpty();
                    break;
                }
                args.add(lookupOrCreateGroupExpr(bb.convertExpression(operand)));
            }
            if (filter != null) {
                RexNode convertedFilter = bb.convertExpression(filter);
                if (convertedFilter.getType().isNullable()) {
                    // Wrap a nullable condition in IS TRUE so only TRUE rows pass the filter.
                    convertedFilter = rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, convertedFilter);
                }
                filterArg = lookupOrCreateGroupExpr(convertedFilter);
            }
            if (distinctList == null) {
                distinctKeys = null;
            } else {
                final ImmutableBitSet.Builder distinctBuilder = ImmutableBitSet.builder();
                for (SqlNode distinctExpr : distinctList) {
                    distinctBuilder.set(lookupOrCreateGroupExpr(bb.convertExpression(distinctExpr)));
                }
                distinctKeys = distinctBuilder.build();
            }
        } finally {
            bb.agg = this;
        }

        SqlAggFunction aggFunction = (SqlAggFunction) call.getOperator();
        final RelDataType type = bb.getValidator().deriveType(bb.scope, call);
        final SqlLiteral quantifier = call.getFunctionQuantifier();
        boolean distinct = quantifier != null && quantifier.getValue() == SqlSelectKeyword.DISTINCT;
        boolean approximate = false;
        if (aggFunction == SqlStdOperatorTable.APPROX_COUNT_DISTINCT) {
            // APPROX_COUNT_DISTINCT(x) is translated as an approximate COUNT(DISTINCT x).
            aggFunction = SqlStdOperatorTable.COUNT;
            distinct = true;
            approximate = true;
        }
        final RelCollation collation = convertCollation(orderList);

        final AggregateCall aggCall = AggregateCall.create(aggFunction, distinct, approximate,
                ignoreNulls, ImmutableList.of(), args, filterArg, distinctKeys, collation, type,
                nameMap.get(outerCall.toString()));
        final RexNode rex = rexBuilder.addAggCall(aggCall, groupExprs.size(), aggCalls,
                aggCallMapping, i -> convertedInputExprs.leftList().get(i).getType().isNullable());
        aggMapping.put(outerCall, rex);
    }

    /**
     * Converts an ordering list into a collation over the aggregate's input expressions.
     * A null or empty list gives {@link RelCollations#EMPTY}. Sort expressions are
     * converted with {@code bb.agg} set to null, and it is restored afterwards.
     */
    private RelCollation convertCollation(@Nullable SqlNodeList orderList) {
        if (orderList == null || orderList.isEmpty()) {
            return RelCollations.EMPTY;
        }
        try {
            bb.agg = null;
            return RelCollations.of(orderList.stream()
                    .map(order -> bb.convertSortExpression(order,
                            RelFieldCollation.Direction.ASCENDING,
                            RelFieldCollation.NullDirection.UNSPECIFIED,
                            this::sortToFieldCollation))
                    .collect(Collectors.toList()));
        } finally {
            bb.agg = this;
        }
    }

    /**
     * Builds a field collation for one sort expression. The expression is registered as
     * an aggregate input if needed. An unspecified null direction is replaced by the
     * direction's default.
     */
    private RelFieldCollation sortToFieldCollation(
            SqlNode expr, RelFieldCollation.Direction direction,
            RelFieldCollation.NullDirection nullDirection) {
        final RexNode node = bb.convertExpression(expr);
        final int fieldIndex = lookupOrCreateGroupExpr(node);
        if (nullDirection == RelFieldCollation.NullDirection.UNSPECIFIED) {
            nullDirection = direction.defaultNullDirection();
        }
        return new RelFieldCollation(fieldIndex, direction, nullDirection);
    }

    /**
     * Returns the index of {@code expr} among the converted input expressions, compared
     * with {@code equals}. If it is absent, it is first appended via {@code addExpr} with
     * no explicit name.
     */
    private int lookupOrCreateGroupExpr(RexNode expr) {
        final List<RexNode> inputs = convertedInputExprs.leftList();
        final int existing = inputs.indexOf(expr);
        if (existing >= 0) {
            return existing;
        }
        final int index = inputs.size();
        addExpr(expr, null);
        return index;
    }

    /** Returns the ordinal of a deeply-equal grouping expression, or a negative value if none. */
    int lookupGroupExpr(SqlNode expr) {
        return SqlUtil.indexOfDeep(groupExprs, expr, Litmus.IGNORE);
    }

    /** Returns whether a deeply-equal measure expression has been registered. */
    boolean isMeasureExpr(SqlNode expr) {
        return SqlUtil.indexOfDeep(measureExprs, expr, Litmus.IGNORE) >= 0;
    }

    /**
     * Returns the expression recorded in {@code aggMapping} for {@code expr}, or null.
     * NOTE(review): this consults the same map that translated aggregates are stored in.
     */
    @Nullable RexNode lookupMeasure(SqlNode expr) {
        return aggMapping.get(expr);
    }

    /**
     * Returns the expression that replaces {@code call} in the converted query, or null.
     *
     * <p>If {@code call} deeply equals a registered auxiliary group call, it is converted
     * by that call's {@link AuxiliaryConverter}. The converter receives the group's
     * converted input expression and an input reference to the group ordinal over
     * {@code bb.root}. Otherwise the expression recorded for {@code call} by
     * {@code translateAgg} is returned.
     */
    @Nullable RexNode lookupAggregates(SqlCall call) {
        assert bb.agg == this;
        for (Map.Entry<SqlNode, Ord<AuxiliaryConverter>> entry : auxiliaryGroupExprs.entrySet()) {
            if (!call.equalsDeep(entry.getKey(), Litmus.IGNORE)) {
                continue;
            }
            final AuxiliaryConverter converter = entry.getValue().e;
            final int groupOrdinal = entry.getValue().i;
            final RexBuilder rexBuilder = bb.getRexBuilder();
            return converter.convert(rexBuilder,
                    convertedInputExprs.leftList().get(groupOrdinal),
                    rexBuilder.makeInputRef(Nullness.castNonNull(bb.root), groupOrdinal));
        }
        return aggMapping.get(call);
    }

    /**
     * Returns the resolved grouping information of the aggregating scope.
     *
     * <p>Only converters created by
     * {@code create(Blackboard, AggregatingSelectScope, SqlValidator)} support this.
     *
     * @throws UnsupportedOperationException on converters without a scope
     */
    AggregatingSelectScope.Resolved getResolved() {
        throw new UnsupportedOperationException();
    }
}

