/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.StatsProvider;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.Constraint;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.FixedWidthType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.TableScanNode;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Statistics rule for {@link TableScanNode}: turns the table statistics reported by
 * {@link Metadata#getTableStatistics} into a {@link PlanNodeStatsEstimate} keyed by the
 * scan's output symbols.
 */
public class TableScanStatsRule
extends SimpleStatsRule<TableScanNode> {
    private static final Pattern<TableScanNode> PATTERN = Patterns.tableScan();
    private final Metadata metadata;

    /**
     * @param metadata   source of connector table statistics; must not be null
     * @param normalizer normalizer handed to {@link SimpleStatsRule}
     */
    public TableScanStatsRule(Metadata metadata, StatsNormalizer normalizer) {
        super(normalizer);
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<TableScanNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates the statistics produced by a table scan.
     *
     * <p>When statistics precalculation for pushdown is enabled in the session and the node
     * already carries statistics, those are returned as-is. Otherwise the table statistics are
     * fetched with an unrestricted constraint ({@code TupleDomain.all()}). Each assigned column's
     * statistics are then converted for its output symbol. A column with no entry in the
     * statistics map gets {@link SymbolStatsEstimate#unknown()}. The output row count is the
     * table's reported row count.
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(TableScanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        boolean usePrecalculated = SystemSessionProperties.isStatisticsPrecalculationForPushdownEnabled(session)
                && node.getStatistics().isPresent();
        if (usePrecalculated) {
            return node.getStatistics();
        }

        // An "all" tuple domain places no restriction on the request.
        Constraint unrestricted = new Constraint(TupleDomain.all());
        TableStatistics tableStats = this.metadata.getTableStatistics(session, node.getTable(), unrestricted);

        Map<Symbol, SymbolStatsEstimate> perSymbol = new HashMap<>();
        node.getAssignments().forEach((outputSymbol, column) -> {
            SymbolStatsEstimate estimate = Optional.ofNullable(tableStats.getColumnStatistics().get(column))
                    .map(columnStats -> toSymbolStatistics(tableStats, columnStats, types.get(outputSymbol)))
                    .orElse(SymbolStatsEstimate.unknown());
            perSymbol.put(outputSymbol, estimate);
        });

        return Optional.of(PlanNodeStatsEstimate.builder()
                .setOutputRowCount(tableStats.getRowCount().getValue())
                .addSymbolStatistics(perSymbol)
                .build());
    }

    /**
     * Converts connector column statistics into a symbol-level estimate.
     *
     * <p>The average row size is computed as follows:
     * <ul>
     *   <li>0 when there are no non-null rows;</li>
     *   <li>NaN for {@link FixedWidthType} columns;</li>
     *   <li>otherwise the column's data size divided by the non-null row count.</li>
     * </ul>
     * The low and high values are copied only when the column reports a range.
     *
     * @throws NullPointerException if any argument is null
     */
    private static SymbolStatsEstimate toSymbolStatistics(TableStatistics tableStatistics, ColumnStatistics columnStatistics, Type type) {
        Objects.requireNonNull(tableStatistics, "tableStatistics is null");
        Objects.requireNonNull(columnStatistics, "columnStatistics is null");
        Objects.requireNonNull(type, "type is null");

        double nullFraction = columnStatistics.getNullsFraction().getValue();
        double nonNullRows = tableStatistics.getRowCount().getValue() * (1.0 - nullFraction);

        double avgRowSize;
        if (nonNullRows == 0.0) {
            avgRowSize = 0.0;
        } else if (type instanceof FixedWidthType) {
            // NOTE(review): presumably left unknown because the width is implied by the type.
            avgRowSize = Double.NaN;
        } else {
            avgRowSize = columnStatistics.getDataSize().getValue() / nonNullRows;
        }

        SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder();
        estimate.setNullsFraction(nullFraction);
        estimate.setDistinctValuesCount(columnStatistics.getDistinctValuesCount().getValue());
        estimate.setAverageRowSize(avgRowSize);
        columnStatistics.getRange().ifPresent(bounds -> {
            estimate.setLowValue(bounds.getMin());
            estimate.setHighValue(bounds.getMax());
        });
        return estimate.build();
    }
}

