/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.index.sai.disk.v1.vector;

import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.IntStream;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.BitsUtil;
import org.apache.cassandra.index.sai.disk.v1.vector.CompactionVectorValues;
import org.apache.cassandra.index.sai.disk.v1.vector.ConcurrentVectorValues;
import org.apache.cassandra.index.sai.disk.v1.vector.RamAwareVectorValues;
import org.apache.cassandra.index.sai.disk.v1.vector.RamEstimation;
import org.apache.cassandra.index.sai.disk.v1.vector.VectorPostings;
import org.apache.cassandra.index.sai.utils.IndexIdentifier;
import org.apache.cassandra.io.util.SequentialWriter;
import org.apache.cassandra.tracing.Tracing;
import org.cliffc.high_scale_lib.NonBlockingHashMapLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * In-memory vector index: maps each distinct vector to the keys that contain it ({@link VectorPostings}) and
 * incrementally builds a jvector graph over those vectors.
 * <p>
 * Each distinct vector is assigned a sequential ordinal the first time it is added. That ordinal identifies the
 * vector in {@code vectorValues}, in the graph built by {@code builder}, and in {@code postingsByOrdinal}.
 * <p>
 * NOTE(review): this source was recovered with the CFR decompiler; {@link #writeData} could not be decompiled and
 * currently always throws.
 *
 * @param <T> key type stored in the postings (presumably a primary key / row identifier -- confirm against callers)
 */
public class OnHeapGraph<T> {
    private static final Logger logger = LoggerFactory.getLogger(OnHeapGraph.class);

    /** Largest absolute value accepted for a single vector component by {@link #checkInBounds(float[])}. */
    public static final float MAX_FLOAT32_COMPONENT = 1.0E17f;

    /** jvector graph construction: factor by which node degree may overflow during construction. */
    private static final float NEIGHBOR_OVERFLOW = 1.2f;
    /** jvector graph construction: relaxation factor for the neighbor-diversity requirement. */
    private static final float ALPHA = 1.4f;
    /** Below this many vectors no product quantization is computed in {@link #writePQ}. */
    private static final int MIN_PQ_ROWS = 1024;

    private final RamAwareVectorValues vectorValues;
    private final GraphIndexBuilder<float[]> builder;
    private final VectorType<?> vectorType;
    private final VectorSimilarityFunction similarityFunction;
    /** Distinct vector (compared element-wise) -> keys containing it. Entries are never removed here. */
    private final ConcurrentMap<float[], VectorPostings<T>> postingsMap;
    /** Ordinal -> postings, used to translate graph search results back into keys. */
    private final NonBlockingHashMapLong<VectorPostings<T>> postingsByOrdinal;
    private final AtomicInteger nextOrdinal = new AtomicInteger();
    /** Set once any {@link #remove} finds the vector; makes {@link #search} filter out deleted ordinals. */
    private volatile boolean hasDeletions;

    /** Creates a graph backed by {@link ConcurrentVectorValues}. */
    public OnHeapGraph(AbstractType<?> termComparator, IndexWriterConfig indexWriterConfig) {
        this(termComparator, indexWriterConfig, true);
    }

    /**
     * @param termComparator    must be a {@link VectorType}; it is cast unconditionally
     * @param indexWriterConfig supplies the similarity function, max node connections and construction beam width
     * @param concurrent        {@code true} to store vectors in {@link ConcurrentVectorValues},
     *                          {@code false} to use {@link CompactionVectorValues}
     */
    public OnHeapGraph(AbstractType<?> termComparator, IndexWriterConfig indexWriterConfig, boolean concurrent) {
        this.vectorType = (VectorType) termComparator;
        this.vectorValues = concurrent
                            ? new ConcurrentVectorValues(this.vectorType.dimension)
                            : new CompactionVectorValues((VectorType) termComparator);
        this.similarityFunction = indexWriterConfig.getSimilarityFunction();
        this.postingsMap = new ConcurrentSkipListMap<float[], VectorPostings<T>>(Arrays::compare);
        this.postingsByOrdinal = new NonBlockingHashMapLong<>();
        this.builder = new GraphIndexBuilder<float[]>(this.vectorValues,
                                                      VectorEncoding.FLOAT32,
                                                      this.similarityFunction,
                                                      indexWriterConfig.getMaximumNodeConnections(),
                                                      indexWriterConfig.getConstructionBeamWidth(),
                                                      NEIGHBOR_OVERFLOW,
                                                      ALPHA);
    }

    /** @return the number of distinct vectors stored so far */
    public int size() {
        return this.vectorValues.size();
    }

    /** @return {@code true} if every stored vector has an empty posting list (including when none are stored) */
    public boolean isEmpty() {
        return this.postingsMap.values().stream().allMatch(VectorPostings::isEmpty);
    }

    /**
     * Records that {@code key} contains the vector encoded in {@code term}.
     *
     * @param term     serialized vector; must be non-null and non-empty (asserted)
     * @param key      key to add to the vector's postings
     * @param behavior {@code IGNORE} drops invalid vectors (trace-logged, returns 0);
     *                 {@code FAIL} lets {@link InvalidRequestException} propagate
     * @return estimated number of bytes of memory added by this call
     * @throws InvalidRequestException if the vector is invalid and {@code behavior} is {@code FAIL}
     */
    public long add(ByteBuffer term, T key, InvalidVectorBehavior behavior) {
        assert (term != null && term.remaining() != 0);
        float[] vector = this.vectorType.composeAsFloat(term);
        if (behavior == InvalidVectorBehavior.IGNORE) {
            try {
                validateIndexable(vector, this.similarityFunction);
            } catch (InvalidRequestException e) {
                logger.trace("Ignoring invalid vector during index build against existing data: {}", (Object) vector, (Object) e);
                return 0L;
            }
        } else {
            assert (behavior == InvalidVectorBehavior.FAIL);
            validateIndexable(vector, this.similarityFunction);
        }

        VectorPostings<T> postings = this.postingsMap.get(vector);
        if (postings == null) {
            VectorPostings<T> fresh = new VectorPostings<>(key);
            VectorPostings<T> existing = this.postingsMap.putIfAbsent(vector, fresh);
            if (existing == null) {
                // We won the insert: this vector is new, and 'fresh' already contains 'key'.
                return registerNewVector(fresh, vector, term);
            }
            // Another thread registered this vector first; add our key to its postings instead.
            postings = existing;
        }
        return postings.add(key) ? VectorPostings.bytesPerPosting() : 0L;
    }

    /**
     * Assigns the next ordinal to a newly inserted vector and registers it in the vector store, the ordinal index
     * and the graph. Returns the estimated bytes used.
     */
    private long registerNewVector(VectorPostings<T> postings, float[] vector, ByteBuffer term) {
        int ordinal = this.nextOrdinal.getAndIncrement();
        postings.setOrdinal(ordinal);
        long bytesUsed = RamEstimation.concurrentHashMapRamUsed(1);
        bytesUsed += this.vectorValues instanceof ConcurrentVectorValues
                     ? ((ConcurrentVectorValues) this.vectorValues).add(ordinal, vector)
                     : ((CompactionVectorValues) this.vectorValues).add(ordinal, term);
        bytesUsed += VectorPostings.emptyBytesUsed() + VectorPostings.bytesPerPosting();
        // Publish postings for the ordinal before the node becomes reachable through the graph.
        this.postingsByOrdinal.put(ordinal, postings);
        bytesUsed += this.builder.addGraphNode(ordinal, this.vectorValues);
        return bytesUsed;
    }

    /**
     * Checks that every component is finite and its absolute value does not exceed {@link #MAX_FLOAT32_COMPONENT}.
     *
     * @throws IllegalArgumentException naming the first offending index and value
     */
    public static void checkInBounds(float[] v) {
        for (int i = 0; i < v.length; ++i) {
            float component = v[i];
            if (!Float.isFinite(component)) {
                throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + component);
            }
            if (Math.abs(component) > MAX_FLOAT32_COMPONENT) {
                throw new IllegalArgumentException("Out-of-bounds value at vector[" + i + "]=" + component);
            }
        }
    }

    /**
     * Validates that {@code vector} may be indexed or used as a query with {@code similarityFunction}.
     *
     * @throws InvalidRequestException if a component is non-finite or out of bounds, or if the function is
     *                                 COSINE and every component is zero
     */
    public static void validateIndexable(float[] vector, VectorSimilarityFunction similarityFunction) {
        try {
            checkInBounds(vector);
        } catch (IllegalArgumentException e) {
            throw new InvalidRequestException(e.getMessage());
        }
        if (similarityFunction == VectorSimilarityFunction.COSINE && isZeroVector(vector)) {
            throw new InvalidRequestException("Zero vectors cannot be indexed or queried with cosine similarity");
        }
    }

    /** @return {@code true} if every component equals zero (vacuously true for an empty array) */
    private static boolean isZeroVector(float[] vector) {
        for (float component : vector) {
            if (component != 0.0f) {
                return false;
            }
        }
        return true;
    }

    /** @return the keys recorded for the vector with ordinal {@code node}; throws NPE if the ordinal is unknown */
    public Collection<T> keysFromOrdinal(int node) {
        return this.postingsByOrdinal.get(node).getPostings();
    }

    /**
     * Removes {@code key} from the postings of the vector encoded in {@code term}.
     *
     * @return the value returned by {@code VectorPostings.remove}, or 0 if the vector was never added
     */
    public long remove(ByteBuffer term, T key) {
        assert (term != null && term.remaining() != 0);
        float[] vector = this.vectorType.composeAsFloat(term);
        VectorPostings<T> postings = this.postingsMap.get(vector);
        if (postings == null) {
            return 0L;
        }
        this.hasDeletions = true;
        return postings.remove(key);
    }

    /**
     * Searches the in-memory graph for up to {@code limit} nearest vectors and returns all keys attached to them.
     *
     * @param queryVector query; validated with {@link #validateIndexable}
     * @param limit       maximum number of graph nodes requested from the searcher
     * @param toAccept    ordinal filter; combined with a deleted-ordinal filter once any removal has happened
     * @return keys of the matching vectors; empty if no vectors have been added
     */
    public PriorityQueue<T> search(float[] queryVector, int limit, Bits toAccept) {
        validateIndexable(queryVector, this.similarityFunction);
        if (this.vectorValues.size() == 0) {
            return new PriorityQueue<>();
        }
        Bits bits = this.hasDeletions ? BitsUtil.bitsIgnoringDeleted(toAccept, this.postingsByOrdinal) : toAccept;
        OnHeapGraphIndex<float[]> graph = this.builder.getGraph();
        GraphSearcher searcher = new GraphSearcher.Builder(graph.getView()).withConcurrentUpdates().build();
        NeighborSimilarity.ExactScoreFunction scoreFunction = node -> this.vectorCompareFunction(queryVector, node);
        SearchResult result = searcher.search(scoreFunction, null, limit, bits);
        SearchResult.NodeScore[] nodes = result.getNodes();
        Tracing.trace("ANN search visited {} in-memory nodes to return {} results", (Object) result.getVisitedCount(), (Object) nodes.length);
        PriorityQueue<T> keyQueue = new PriorityQueue<>();
        for (SearchResult.NodeScore nodeScore : nodes) {
            keyQueue.addAll(this.keysFromOrdinal(nodeScore.node));
        }
        return keyQueue;
    }

    /*
     * NOTE(review): CFR 0.152 failed to decompile this method
     * (org.benf.cfr.reader.util.ConfusedCFRException: Started 2 blocks at once), so the original logic is missing
     * and this stub always throws. It must be restored from the original source before this class is usable for
     * flushing.
     */
    public SegmentMetadata.ComponentMetadataMap writeData(IndexDescriptor indexDescriptor, IndexIdentifier indexIdentifier, Function<T, Integer> postingTransformer) throws IOException {
        throw new IllegalStateException("Decompilation failed");
    }

    /** Scores the stored vector at ordinal {@code node} against {@code queryVector}. */
    private float vectorCompareFunction(float[] queryVector, int node) {
        return this.similarityFunction.compare(queryVector, (float[]) this.vectorValues.vectorValue(node));
    }

    /**
     * Writes a flag saying whether product quantization is present. When there are at least
     * {@link #MIN_PQ_ROWS} vectors, also writes the compressed (PQ-encoded) vectors.
     * <p>
     * PQ training and encoding are serialized across all instances by locking {@code OnHeapGraph.class}.
     * Only that CPU-heavy step is under the lock; the write to {@code writer} happens after the lock is released.
     *
     * @return the writer position after writing
     */
    private long writePQ(SequentialWriter writer) throws IOException {
        int rows = this.vectorValues.size();
        boolean compress = rows >= MIN_PQ_ROWS;
        writer.writeBoolean(compress);
        if (!compress) {
            logger.debug("Skipping PQ for only {} vectors", (Object) rows);
            return writer.position();
        }
        logger.debug("Computing PQ for {} vectors", (Object) rows);
        // dimension / 2 subspaces, i.e. two dimensions per PQ subspace.
        int subspaces = this.vectorValues.dimension() / 2;
        final ProductQuantization pq;
        final byte[][] encoded;
        synchronized (OnHeapGraph.class) {
            pq = ProductQuantization.compute(this.vectorValues, subspaces, false);
            // Parallel encoding below reads vectors concurrently, so the vector source must not reuse one buffer.
            assert (!this.vectorValues.isValueShared());
            encoded = IntStream.range(0, rows)
                               .parallel()
                               .mapToObj(i -> pq.encode(this.vectorValues.vectorValue(i)))
                               .toArray(byte[][]::new);
        }
        CompressedVectors cv = new CompressedVectors(pq, encoded);
        cv.write(writer);
        return writer.position();
    }

    /*
     * Decompiler-emitted body of a lambda from writeData: records the ordinal of the given postings in
     * deletedOrdinals. Kept for reference while writeData remains unrecovered; nothing visible calls it.
     */
    private static /* synthetic */ void lambda$writeData$1(Set<Integer> deletedOrdinals, VectorPostings<?> vectorPostings) {
        deletedOrdinals.add(vectorPostings.getOrdinal());
    }

    /** How {@link #add} treats a vector that fails {@link #validateIndexable}. */
    public static enum InvalidVectorBehavior {
        /** Drop the vector (trace-logged) and report 0 bytes used. */
        IGNORE,
        /** Propagate the {@link InvalidRequestException}. */
        FAIL;

    }
}

