/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableConstantVectorSearchCallToCorrelateRuleConfig;
import org.immutables.value.Value;

/**
 * Planner rule that rewrites a {@link LogicalTableFunctionScan} calling a {@link
 * SqlVectorSearchTableFunction} whose operand at position 2 is constant into a correlated join.
 *
 * <p>The rewritten plan consists of:
 *
 * <ul>
 *   <li>a left side that evaluates the constant operand in a projection over a single-row {@link
 *       LogicalValues};
 *   <li>a right side that re-creates the vector search over the scan's first input, with that
 *       operand replaced by a field access on a correlation variable bound to the left row;
 *   <li>an inner join on {@code TRUE} that carries the correlation id, followed by a projection
 *       that drops the helper constant column.
 * </ul>
 */
public class ConstantVectorSearchCallToCorrelateRule
        extends RelRule<
                ConstantVectorSearchCallToCorrelateRule.ConstantVectorSearchCallToCorrelateRuleConfig> {

    /** Shared rule instance built from the default configuration. */
    public static final ConstantVectorSearchCallToCorrelateRule INSTANCE =
            ConstantVectorSearchCallToCorrelateRuleConfig.DEFAULT.toRule();

    /**
     * Position of the VECTOR_SEARCH operand that must be constant for this rule to fire; the
     * rewrite replaces it with a correlated reference. NOTE(review): presumably the query vector
     * argument — confirm against SqlVectorSearchTableFunction's operand order.
     */
    private static final int CONSTANT_OPERAND_INDEX = 2;

    private ConstantVectorSearchCallToCorrelateRule(
            ConstantVectorSearchCallToCorrelateRuleConfig config) {
        super(config);
    }

    /**
     * Fires only for a {@link SqlVectorSearchTableFunction} call whose operand at {@link
     * #CONSTANT_OPERAND_INDEX} exists and is constant according to {@link RexUtil#isConstant}.
     *
     * @param call the rule call; {@code rel(0)} is the matched table function scan
     * @return {@code true} if the scan should be rewritten
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        LogicalTableFunctionScan scan = call.rel(0);
        RexNode rexNode = scan.getCall();
        if (!(rexNode instanceof RexCall)) {
            return false;
        }
        RexCall rexCall = (RexCall) rexNode;
        if (!(rexCall.getOperator() instanceof SqlVectorSearchTableFunction)) {
            return false;
        }
        List<RexNode> operands = rexCall.getOperands();
        // Guard the index so an unexpectedly short operand list means "no match" rather than an
        // IndexOutOfBoundsException thrown from inside the planner.
        return operands.size() > CONSTANT_OPERAND_INDEX
                && RexUtil.isConstant(operands.get(CONSTANT_OPERAND_INDEX));
    }

    /**
     * Replaces the matched scan with an inner correlated join between a one-row projection of the
     * constant operand and the vector search that reads that operand via the correlation
     * variable.
     *
     * @param call the rule call; {@code rel(0)} is the matched table function scan
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        LogicalTableFunctionScan scan = call.rel(0);
        RexCall functionCall = (RexCall) scan.getCall();
        RexNode constantOperand = functionCall.getOperands().get(CONSTANT_OPERAND_INDEX);
        RelOptCluster cluster = scan.getCluster();
        RelBuilder builder = call.builder();

        // Left side: a single row carrying the constant operand as its only column.
        builder.push(LogicalValues.createOneRow(cluster));
        builder.project(constantOperand);

        // Bind a correlation variable to the left row type (must be read before pushing the
        // right side) and reference its first column.
        CorrelationId correlId = cluster.createCorrel();
        RexNode correlRex =
                cluster.getRexBuilder().makeCorrel(builder.peek().getRowType(), correlId);
        RexNode correlatedOperand = cluster.getRexBuilder().makeFieldAccess(correlRex, 0);

        // Right side: the same vector search over the scan's first input, with the constant
        // operand swapped for the correlated field access.
        builder.push(scan.getInput(0));
        List<RexNode> operands = new ArrayList<>(functionCall.getOperands());
        operands.set(CONSTANT_OPERAND_INDEX, correlatedOperand);
        builder.functionScan(functionCall.getOperator(), 1, operands);

        RexNode alwaysTrue = cluster.getRexBuilder().makeLiteral(true);
        builder.join(JoinRelType.INNER, alwaysTrue, Collections.singleton(correlId));
        // Drop the helper constant column contributed by the left side, leaving only the
        // function scan's output columns.
        builder.projectExcept(builder.field(0));
        call.transformTo(builder.build());
    }

    /** Configuration for {@link ConstantVectorSearchCallToCorrelateRule}. */
    @Value.Immutable
    public interface ConstantVectorSearchCallToCorrelateRuleConfig extends RelRule.Config {

        /**
         * Default configuration: matches any {@link LogicalTableFunctionScan} with any inputs;
         * {@link ConstantVectorSearchCallToCorrelateRule#matches} narrows the match further.
         */
        ConstantVectorSearchCallToCorrelateRuleConfig DEFAULT =
                ImmutableConstantVectorSearchCallToCorrelateRuleConfig.builder()
                        .build()
                        .withOperandSupplier(
                                b0 -> b0.operand(LogicalTableFunctionScan.class).anyInputs())
                        .withDescription("ConstantVectorSearchCallToCorrelateRule");

        @Override
        default ConstantVectorSearchCallToCorrelateRule toRule() {
            return new ConstantVectorSearchCallToCorrelateRule(this);
        }
    }
}

