/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.sampler;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.sampler.AbstractStreamSampler;
import com.amazon.randomcutforest.sampler.AcceptPointState;
import com.amazon.randomcutforest.sampler.ISampled;
import com.amazon.randomcutforest.sampler.Weighted;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * A stream sampler over integer point references that keeps its sample in parallel arrays
 * ({@code weight}, {@code pointIndex} and, optionally, {@code sequenceIndex}) arranged as a binary
 * max-heap keyed on {@code weight}: the entry with the largest weight always sits at index 0.
 *
 * <p>Outside of the initial fill phase, a candidate point is admitted only when its weight is
 * smaller than the current heap maximum, and that maximum is evicted to make room.
 *
 * <p>Admission is a two-step protocol: {@link #acceptPoint(long)} decides whether the candidate is
 * admitted (evicting if necessary) and remembers its weight and sequence index; the caller then
 * supplies the point reference via {@link #addPoint(Integer)}.
 */
public class CompactSampler extends AbstractStreamSampler<Integer> {

    /** Sequence index reported for sampled/evicted points when sequence indexes are not stored. */
    public static final long SEQUENCE_INDEX_NA = -1L;

    /** Weights in max-heap order; only the first {@code size} entries are meaningful. */
    protected final float[] weight;

    /** Point references, parallel to {@link #weight}. */
    protected final int[] pointIndex;

    /**
     * Sequence indexes, parallel to {@link #weight}. Kept in heap order only when sequence-index
     * storage is enabled; {@code null} for a freshly built sampler with storage disabled.
     */
    protected final long[] sequenceIndex;

    /** Number of valid entries in the parallel arrays. */
    protected int size;

    /** Whether {@link #sequenceIndex} is maintained alongside the weights and point references. */
    private final boolean storeSequenceIndexesEnabled;

    /**
     * Returns a new, unconfigured builder.
     *
     * @return a builder for {@code CompactSampler}
     */
    public static Builder<?> builder() {
        return new Builder<>();
    }

    /**
     * Creates a sampler with the given capacity and random seed and a time decay of {@code 0.0}.
     *
     * @param sampleSize     sampler capacity; must be positive
     * @param randomSeed     seed passed to the builder's random source
     * @param storeSequences whether sequence indexes are stored with each sampled point
     * @return the new sampler
     */
    public static CompactSampler uniformSampler(int sampleSize, long randomSeed, boolean storeSequences) {
        return builder()
                .capacity(sampleSize)
                .timeDecay(0.0)
                .randomSeed(randomSeed)
                .storeSequenceIndexesEnabled(storeSequences)
                .build();
    }

    /**
     * Builds a sampler either from scratch or from previously saved state.
     *
     * <p>If the builder carries any of the {@code weight}, {@code pointIndex} or
     * {@code sequenceIndex} arrays, or has {@code validateHeap} set, the arrays are adopted as-is
     * (not copied). Their lengths and {@code size} are checked against the capacity. The first
     * {@code size} entries are then either re-heapified or, when {@code validateHeap} is set,
     * checked for the heap property. Otherwise empty arrays of length {@code capacity} are
     * allocated, and {@code size} must be 0.
     *
     * @param builder the configured builder
     * @throws IllegalStateException if {@code validateHeap} is set and the supplied arrays do not
     *                               satisfy the max-heap property
     */
    protected CompactSampler(Builder<?> builder) {
        super(builder);
        CommonUtils.checkArgument(builder.initialAcceptFraction > 0.0, " the admittance fraction cannot be <= 0");
        CommonUtils.checkArgument(builder.capacity > 0, " sampler capacity cannot be <=0 ");
        this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
        this.timeDecay = builder.timeDecay;
        this.maxSequenceIndex = builder.maxSequenceIndex;
        this.mostRecentTimeDecayUpdate = builder.sequenceIndexOfMostRecentTimeDecayUpdate;

        boolean restoringState = builder.weight != null || builder.pointIndex != null
                || builder.sequenceIndex != null || builder.validateHeap;
        if (restoringState) {
            CommonUtils.checkArgument(builder.weight != null && builder.weight.length == builder.capacity,
                    " incorrect state");
            CommonUtils.checkArgument(builder.pointIndex != null && builder.pointIndex.length == builder.capacity,
                    " incorrect state");
            CommonUtils.checkArgument(!builder.storeSequenceIndexesEnabled
                    || builder.sequenceIndex != null && builder.sequenceIndex.length == builder.capacity,
                    " incorrect state");
            // size bounds every heap operation below, so it must fit inside the adopted arrays.
            CommonUtils.checkArgument(builder.size >= 0 && builder.size <= builder.capacity, " incorrect state");
            this.weight = builder.weight;
            this.pointIndex = builder.pointIndex;
            this.sequenceIndex = builder.sequenceIndex;
            this.size = builder.size;
            reheap(builder.validateHeap);
        } else {
            CommonUtils.checkArgument(builder.size == 0, "incorrect state");
            this.size = 0;
            this.weight = new float[builder.capacity];
            this.pointIndex = new int[builder.capacity];
            this.sequenceIndex = this.storeSequenceIndexesEnabled ? new long[builder.capacity] : null;
        }
    }

    /**
     * Decides whether the point arriving at {@code sequenceIndex} is admitted to the sample.
     *
     * <p>A point is admitted either during the initial fill phase (sampler not full and a random
     * draw falls below {@code initialAcceptProbability(size)}), or when its weight (from the
     * inherited {@code computeWeight}) is smaller than the current heap maximum. In the latter case
     * the maximum is evicted and becomes available via {@link #getEvictedPoint()}. On admission the
     * weight and sequence index are remembered until {@link #addPoint(Integer)} is called.
     *
     * @param sequenceIndex sequence index of the candidate; must not precede the most recent
     *                      time-decay update
     * @return {@code true} if the point was admitted
     */
    @Override
    public boolean acceptPoint(long sequenceIndex) {
        CommonUtils.checkState(sequenceIndex >= this.mostRecentTimeDecayUpdate, "incorrect sequences submitted to sampler");
        this.evictedPoint = null;
        float candidateWeight = computeWeight(sequenceIndex);
        boolean initial = this.size < this.capacity && this.random.nextDouble() < initialAcceptProbability(this.size);
        // weight[0] is only a valid heap maximum while the heap is non-empty; otherwise it is stale
        // and evicting from an empty heap would underflow size.
        boolean beatsMaximum = this.size > 0 && candidateWeight < this.weight[0];
        if (!initial && !beatsMaximum) {
            return false;
        }
        this.acceptPointState = new AcceptPointState(sequenceIndex, candidateWeight);
        if (!initial) {
            evictMax();
        }
        return true;
    }

    /**
     * Removes the entry with the largest weight (the heap root) and records it as the evicted point.
     * When sequence indexes are not stored, the evicted point carries {@link #SEQUENCE_INDEX_NA}.
     */
    public void evictMax() {
        CommonUtils.checkState(this.size > 0, "cannot evict from an empty sampler");
        long evictedIndex = this.storeSequenceIndexesEnabled ? this.sequenceIndex[0] : SEQUENCE_INDEX_NA;
        this.evictedPoint = new Weighted<Integer>(this.pointIndex[0], this.weight[0], evictedIndex);
        int last = --this.size;
        // Move the last entry into the root slot, then restore the heap property from the top.
        this.weight[0] = this.weight[last];
        this.pointIndex[0] = this.pointIndex[last];
        if (this.storeSequenceIndexesEnabled) {
            this.sequenceIndex[0] = this.sequenceIndex[last];
        }
        swapDown(0);
    }

    /**
     * Sifts the entry at {@code startIndex} down until neither child has a larger weight.
     *
     * @param startIndex heap position to start from
     * @param validate   if {@code true}, throw instead of swapping when a child is larger
     * @throws IllegalStateException if {@code validate} is set and the heap property is violated
     */
    private void swapDown(int startIndex, boolean validate) {
        int current = startIndex;
        while (2 * current + 1 < this.size) {
            int left = 2 * current + 1;
            int right = left + 1;
            int largerChild = (right < this.size && this.weight[right] > this.weight[left]) ? right : left;
            if (!(this.weight[largerChild] > this.weight[current])) {
                break;
            }
            if (validate) {
                throw new IllegalStateException("the heap property is not satisfied at index " + current);
            }
            swapWeights(current, largerChild);
            current = largerChild;
        }
    }

    /** Sifts the entry at {@code startIndex} down, repairing any heap-property violations. */
    private void swapDown(int startIndex) {
        swapDown(startIndex, false);
    }

    /**
     * Rebuilds (or, with {@code validate}, checks) the max-heap over the first {@code size} entries.
     *
     * @param validate if {@code true}, throw rather than repair when the heap property is violated
     * @throws IllegalStateException if {@code validate} is set and the heap property is violated
     */
    public void reheap(boolean validate) {
        // Positions >= size / 2 have no children, so sifting starts at the last internal node.
        for (int i = this.size / 2 - 1; i >= 0; --i) {
            swapDown(i, validate);
        }
    }

    /**
     * Inserts a point with an explicit weight and sequence index, bypassing
     * {@link #acceptPoint(long)}. Only allowed when no admission is pending and the sampler is not
     * full.
     *
     * @param pointIndex    point reference; must not be {@code null}
     * @param weight        weight to store
     * @param sequenceIndex sequence index to store (ignored unless storage is enabled)
     */
    public void addPoint(Integer pointIndex, float weight, long sequenceIndex) {
        CommonUtils.checkArgument(this.acceptPointState == null && this.size < this.capacity && pointIndex != null, " operation not permitted");
        this.acceptPointState = new AcceptPointState(sequenceIndex, weight);
        addPoint(pointIndex);
    }

    /**
     * Completes an admission started by {@link #acceptPoint(long)}: stores {@code pointIndex} with
     * the remembered weight and sequence index and sifts it up into heap position. A {@code null}
     * argument is ignored.
     *
     * @param pointIndex point reference, or {@code null} to do nothing
     */
    @Override
    public void addPoint(Integer pointIndex) {
        if (pointIndex == null) {
            return;
        }
        CommonUtils.checkState(this.size < this.capacity, "sampler full");
        CommonUtils.checkState(this.acceptPointState != null, "this method should only be called after a successful call to acceptSample(long)");
        int current = this.size;
        this.weight[current] = this.acceptPointState.getWeight();
        this.pointIndex[current] = pointIndex;
        if (this.storeSequenceIndexesEnabled) {
            this.sequenceIndex[current] = this.acceptPointState.getSequenceIndex();
        }
        this.size++;
        while (current > 0) {
            int parent = (current - 1) / 2;
            if (!(this.weight[parent] < this.weight[current])) {
                break;
            }
            swapWeights(current, parent);
            current = parent;
        }
        this.acceptPointState = null;
    }

    /** Returns a snapshot of the current sample (after folding in any accumulated time decay). */
    @Override
    public List<ISampled<Integer>> getSample() {
        return streamSample().collect(Collectors.toList());
    }

    /** Returns a snapshot of the current sample (after folding in any accumulated time decay). */
    public List<Weighted<Integer>> getWeightedSample() {
        return streamSample().collect(Collectors.toList());
    }

    /**
     * Folds accumulated time decay into the stored weights, then streams the first {@code size}
     * entries in array (heap) order. When sequence indexes are not stored, each entry carries
     * {@link #SEQUENCE_INDEX_NA}.
     */
    private Stream<Weighted<Integer>> streamSample() {
        reset_weights();
        return IntStream.range(0, this.size).mapToObj(i -> new Weighted<Integer>(this.pointIndex[i],
                this.weight[i], this.storeSequenceIndexesEnabled ? this.sequenceIndex[i] : SEQUENCE_INDEX_NA));
    }

    /**
     * Adds the pending {@code accumuluatedTimeDecay} to every stored weight and clears it. Adding
     * the same amount to all entries leaves the heap order intact.
     */
    private void reset_weights() {
        if (this.accumuluatedTimeDecay == 0.0) {
            return;
        }
        for (int i = 0; i < this.size; i++) {
            this.weight[i] = (float) ((double) this.weight[i] + this.accumuluatedTimeDecay);
        }
        this.accumuluatedTimeDecay = 0.0;
    }

    /** Returns the point evicted by the most recent {@link #acceptPoint(long)} call, if any. */
    @Override
    public Optional<ISampled<Integer>> getEvictedPoint() {
        return Optional.ofNullable(this.evictedPoint);
    }

    /** Returns the number of points currently in the sample. */
    @Override
    public int size() {
        return this.size;
    }

    /**
     * Returns the internal weight array (not a copy). Only the first {@link #size()} entries are
     * valid, and any pending accumulated time decay has not been folded in.
     */
    public float[] getWeightArray() {
        return this.weight;
    }

    /** Returns the internal point-reference array (not a copy). */
    public int[] getPointIndexArray() {
        return this.pointIndex;
    }

    /** Returns the internal sequence-index array (not a copy); may be {@code null}. */
    public long[] getSequenceIndexArray() {
        return this.sequenceIndex;
    }

    /** Returns whether sequence indexes are stored with each sampled point. */
    public boolean isStoreSequenceIndexesEnabled() {
        return this.storeSequenceIndexesEnabled;
    }

    /** Swaps heap entries {@code a} and {@code b} across all maintained parallel arrays. */
    private void swapWeights(int a, int b) {
        int tmpPoint = this.pointIndex[a];
        this.pointIndex[a] = this.pointIndex[b];
        this.pointIndex[b] = tmpPoint;
        float tmpWeight = this.weight[a];
        this.weight[a] = this.weight[b];
        this.weight[b] = tmpWeight;
        if (this.storeSequenceIndexesEnabled) {
            long tmpSequence = this.sequenceIndex[a];
            this.sequenceIndex[a] = this.sequenceIndex[b];
            this.sequenceIndex[b] = tmpSequence;
        }
    }

    /**
     * Builder for {@link CompactSampler}. Supplying any of the state arrays (or
     * {@code validateHeap}) makes the built sampler adopt saved state instead of starting empty.
     *
     * @param <T> the concrete builder type, returned by each setter for chaining
     */
    public static class Builder<T extends Builder<T>> extends AbstractStreamSampler.Builder<T> {
        private int size = 0;
        private float[] weight = null;
        private int[] pointIndex = null;
        private long[] sequenceIndex = null;
        private boolean validateHeap = false;
        private boolean storeSequenceIndexesEnabled = false;

        /** Sets the number of valid entries in the supplied state arrays. */
        @SuppressWarnings("unchecked")
        public T size(int size) {
            this.size = size;
            return (T) this;
        }

        /** Sets the saved weight array (adopted, not copied). */
        @SuppressWarnings("unchecked")
        public T weight(float[] weight) {
            this.weight = weight;
            return (T) this;
        }

        /** Sets the saved point-reference array (adopted, not copied). */
        @SuppressWarnings("unchecked")
        public T pointIndex(int[] pointIndex) {
            this.pointIndex = pointIndex;
            return (T) this;
        }

        /** Sets the saved sequence-index array (adopted, not copied). */
        @SuppressWarnings("unchecked")
        public T sequenceIndex(long[] sequenceIndex) {
            this.sequenceIndex = sequenceIndex;
            return (T) this;
        }

        /** Sets whether sequence indexes are stored with each sampled point. */
        @SuppressWarnings("unchecked")
        public T storeSequenceIndexesEnabled(boolean storeSequenceIndexesEnabled) {
            this.storeSequenceIndexesEnabled = storeSequenceIndexesEnabled;
            return (T) this;
        }

        /** Sets whether supplied state must already satisfy the heap property (checked, not repaired). */
        @SuppressWarnings("unchecked")
        public T validateHeap(boolean validateHeap) {
            this.validateHeap = validateHeap;
            return (T) this;
        }

        /** Builds the sampler. */
        public CompactSampler build() {
            return new CompactSampler(this);
        }
    }
}

