/*
 * Decompiled with CFR 0.152.
 */
package com.hedera.node.app.blocks.impl;

import com.hedera.node.app.blocks.StreamingTreeHasher;
import com.hedera.node.app.blocks.impl.BlockImplUtils;
import com.hedera.node.app.hapi.utils.CommonUtils;
import com.hedera.pbj.runtime.io.buffer.Bytes;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;

/**
 * A {@link StreamingTreeHasher} that computes the root hash of a binary Merkle tree whose leaf
 * hashes are supplied one at a time through {@link #addLeaf(ByteBuffer)}.
 *
 * <p>Hashes are combined level by level through a chain of {@link HashCombiner}s, one per tree
 * height. Each combiner collects pending hashes and, once {@code hashCombineBatchSize} of them
 * are available, pairs them up and forwards the resulting parent hashes to the combiner one
 * level higher. Batches of at least {@code MIN_TO_SCHEDULE} hashes are combined on the supplied
 * {@link ExecutorService}; smaller batches are combined inline on the calling thread.
 *
 * <p>NOTE(review): none of the mutable state is synchronized, so this class presumably expects
 * {@code addLeaf}, {@code status} and {@code rootHash} to be called from a single thread —
 * confirm against callers.
 */
public class ConcurrentStreamingTreeHasher implements StreamingTreeHasher {
    /** Batch size used by the single-argument constructor. */
    private static final int DEFAULT_HASH_COMBINE_BATCH_SIZE = 8;

    /** The combiner that receives leaf hashes (height 0). */
    private final HashCombiner combiner = new HashCombiner(0);
    private final ExecutorService executorService;
    private final int hashCombineBatchSize;
    /** Number of leaves added so far. */
    private int numLeaves;
    /** Height of the root; only assigned when {@link #rootHash()} is called. */
    private int rootHeight;
    /** Set once {@link #rootHash()} has been called; further leaves are then rejected. */
    private boolean rootHashRequested = false;

    /**
     * Creates a hasher that uses the default hash combine batch size.
     *
     * @param executorService the executor on which large batches are combined
     * @throws NullPointerException if {@code executorService} is null
     */
    public ConcurrentStreamingTreeHasher(@NonNull ExecutorService executorService) {
        this(executorService, DEFAULT_HASH_COMBINE_BATCH_SIZE);
    }

    /**
     * Creates a hasher with an explicit hash combine batch size.
     *
     * <p>Zero and negative even values are accepted; since the number of pending hashes never
     * equals such a value, combination then only happens when {@link #status()} or
     * {@link #rootHash()} is called.
     *
     * @param executorService the executor on which large batches are combined
     * @param hashCombineBatchSize the number of pending hashes that triggers a combination
     * @throws NullPointerException if {@code executorService} is null
     * @throws IllegalArgumentException if {@code hashCombineBatchSize} is odd
     */
    public ConcurrentStreamingTreeHasher(@NonNull ExecutorService executorService, int hashCombineBatchSize) {
        this.executorService = Objects.requireNonNull(executorService);
        // Compare against 0 rather than 1: Java's % takes the sign of the dividend, so a
        // negative odd value yields -1 and would slip past an "== 1" check.
        if (hashCombineBatchSize % 2 != 0) {
            throw new IllegalArgumentException("Hash combine batch size must be an even number");
        }
        this.hashCombineBatchSize = hashCombineBatchSize;
    }

    /**
     * Adds the next leaf hash to the tree. Exactly {@code HASH_LENGTH} bytes are read from the
     * buffer, advancing its position.
     *
     * @param hash a buffer with at least {@code HASH_LENGTH} bytes remaining
     * @throws NullPointerException if {@code hash} is null
     * @throws IllegalStateException if the root hash has already been requested
     * @throws IllegalArgumentException if the buffer has fewer than {@code HASH_LENGTH} bytes remaining
     */
    @Override
    public void addLeaf(@NonNull ByteBuffer hash) {
        Objects.requireNonNull(hash);
        if (this.rootHashRequested) {
            throw new IllegalStateException("Cannot add leaves after requesting the root hash");
        }
        if (hash.remaining() < HASH_LENGTH) {
            throw new IllegalArgumentException("Buffer has less than " + HASH_LENGTH + " bytes remaining");
        }
        ++this.numLeaves;
        final byte[] bytes = new byte[HASH_LENGTH];
        hash.get(bytes);
        this.combiner.combine(bytes);
    }

    /**
     * Requests the root hash of the tree built from all leaves added so far. After this call,
     * {@link #addLeaf(ByteBuffer)} throws {@link IllegalStateException}.
     *
     * @return a future that completes with the root hash
     */
    @Override
    public CompletableFuture<Bytes> rootHash() {
        this.rootHashRequested = true;
        this.rootHeight = rootHeightFor(this.numLeaves);
        return this.combiner.finalCombination();
    }

    /**
     * Returns the current status of the tree: the number of leaves and, for each height below
     * the root height of a tree with one more leaf, either the unpaired hash pending at that
     * height or {@code Bytes.EMPTY} if there is none.
     *
     * <p>This call flushes all pairable pending hashes upward and blocks until the resulting
     * combinations have completed.
     *
     * @return {@code Status.EMPTY} if no leaves have been added, otherwise the current status
     */
    @Override
    public StreamingTreeHasher.Status status() {
        if (this.numLeaves == 0) {
            return StreamingTreeHasher.Status.EMPTY;
        }
        final var rightmostHashes = new ArrayList<Bytes>();
        this.combiner.flushAvailable(rightmostHashes, rootHeightFor(this.numLeaves + 1));
        return new StreamingTreeHasher.Status(this.numLeaves, rightmostHashes);
    }

    /**
     * Computes the root hash of a tree from the status taken just before its last leaf was
     * added, together with the hash of that last leaf.
     *
     * <p>Walking up from the last leaf, each level either combines the stored rightmost hash
     * (as the left sibling) with the running hash, or, if no hash is stored at that level,
     * combines the running hash (as the left sibling) with the empty hash for that level.
     *
     * @param penultimateStatus the status of the tree before the last leaf was added
     * @param lastLeafHash the hash of the last leaf
     * @return the root hash of the tree including the last leaf
     * @throws NullPointerException if either argument is null
     */
    public static Bytes rootHashFrom(@NonNull StreamingTreeHasher.Status penultimateStatus, @NonNull Bytes lastLeafHash) {
        Objects.requireNonNull(penultimateStatus);
        Objects.requireNonNull(lastLeafHash);
        Bytes hash = lastLeafHash;
        final int rootHeight = rootHeightFor(penultimateStatus.numLeaves() + 1);
        for (int i = 0; i < rootHeight; ++i) {
            final Bytes rightmostHash = penultimateStatus.rightmostHashes().get(i);
            hash = rightmostHash.length() == 0L
                    ? BlockImplUtils.hashInternalNode(hash, HashCombiner.EMPTY_HASHES[i])
                    : BlockImplUtils.hashInternalNode(rightmostHash, hash);
        }
        return hash;
    }

    /**
     * Returns the height of the root of a tree with the given number of leaves, i.e. the base-2
     * logarithm of the smallest power of two that is not less than {@code numLeaves}; zero
     * leaves yield height zero.
     */
    private static int rootHeightFor(int numLeaves) {
        final int numPerfectLeaves = containingPowerOfTwo(numLeaves);
        return numPerfectLeaves == 0 ? 0 : Integer.numberOfTrailingZeros(numPerfectLeaves);
    }

    /**
     * Returns {@code n} itself if it has at most one bit set (which includes zero), otherwise
     * twice its highest set bit.
     */
    private static int containingPowerOfTwo(int n) {
        if ((n & (n - 1)) == 0) {
            return n;
        }
        return Integer.highestOneBit(n) << 1;
    }

    /**
     * Combines pairs of hashes at one height of the tree and forwards the results to a delegate
     * combiner at the next height.
     */
    private class HashCombiner {
        /** One SHA-384 digest per thread, since combinations may run on executor threads. */
        private static final ThreadLocal<MessageDigest> DIGESTS = ThreadLocal.withInitial(CommonUtils::sha384DigestOrThrow);
        /** Number of heights supported; combiners exist only for heights below this value. */
        private static final int MAX_DEPTH = 24;
        /** Minimum number of pending hashes for which combination is handed to the executor. */
        private static final int MIN_TO_SCHEDULE = 16;
        /**
         * Padding hashes by height: index 0 is the leaf hash of {@code HASH_LENGTH} zero bytes,
         * and each subsequent entry is the internal-node hash of two copies of the previous one.
         */
        private static final byte[][] EMPTY_HASHES = new byte[MAX_DEPTH][];

        private final int height;
        /** Combiner at {@code height + 1}; created lazily the first time work is scheduled. */
        private HashCombiner delegate;
        /** Hashes received at this height that have not yet been scheduled for combination. */
        private List<byte[]> pendingHashes = new ArrayList<>();
        /** Completes once every combination scheduled so far has been forwarded to the delegate. */
        private CompletableFuture<Void> combination = CompletableFuture.completedFuture(null);

        /**
         * @param height the tree height this combiner works at
         * @throws IllegalArgumentException if {@code height} is not below {@code MAX_DEPTH}
         */
        private HashCombiner(int height) {
            if (height >= MAX_DEPTH) {
                throw new IllegalArgumentException("Cannot combine hashes at height " + height);
            }
            this.height = height;
        }

        /**
         * Adds a hash at this height, scheduling a combination as soon as the number of pending
         * hashes equals the configured batch size.
         */
        public void combine(@NonNull byte[] hash) {
            this.pendingHashes.add(hash);
            if (this.pendingHashes.size() == ConcurrentStreamingTreeHasher.this.hashCombineBatchSize) {
                this.schedulePendingWork();
            }
        }

        /**
         * Finishes the tree. At the root height this returns the first pending hash (or the
         * empty leaf hash if there is none); below it, any pending hashes are scheduled and the
         * request is passed to the delegate once all scheduled combinations have completed.
         */
        public CompletableFuture<Bytes> finalCombination() {
            if (this.height == ConcurrentStreamingTreeHasher.this.rootHeight) {
                final byte[] rootHash = this.pendingHashes.isEmpty() ? EMPTY_HASHES[0] : this.pendingHashes.getFirst();
                return CompletableFuture.completedFuture(Bytes.wrap(rootHash));
            }
            if (!this.pendingHashes.isEmpty()) {
                this.schedulePendingWork();
            }
            return this.combination.thenCompose(ignore -> this.delegate.finalCombination());
        }

        /**
         * For this height and every height above it that is below {@code stopHeight}, combines
         * all pairable pending hashes, waits for that work to finish, and records the unpaired
         * pending hash (or {@code Bytes.EMPTY} if there is none) in {@code rightmostHashes}.
         *
         * @param rightmostHashes the list to append one entry per visited height to
         * @param stopHeight the first height that is not visited
         */
        public void flushAvailable(@NonNull List<Bytes> rightmostHashes, int stopHeight) {
            if (this.height < stopHeight) {
                // Hold back an odd trailing hash so that it is not padded with an empty hash;
                // it stays pending and may be paired with a later hash.
                final byte[] newPendingHash = this.pendingHashes.size() % 2 == 0 ? null : this.pendingHashes.removeLast();
                this.schedulePendingWork();
                // Block until the delegate has received everything scheduled at this height.
                this.combination.join();
                if (newPendingHash != null) {
                    this.pendingHashes.add(newPendingHash);
                    rightmostHashes.add(Bytes.wrap(newPendingHash));
                } else {
                    rightmostHashes.add(Bytes.EMPTY);
                }
                this.delegate.flushAvailable(rightmostHashes, stopHeight);
            }
        }

        /**
         * Combines the current pending hashes, inline if there are fewer than
         * {@code MIN_TO_SCHEDULE} and on the executor otherwise, and chains the forwarding of
         * the results to the delegate after all previously scheduled work. The pending list is
         * then replaced with a fresh one.
         */
        private void schedulePendingWork() {
            final CompletableFuture<List<byte[]>> pendingCombination;
            if (this.delegate == null) {
                this.delegate = new HashCombiner(this.height + 1);
            }
            if (this.pendingHashes.size() < MIN_TO_SCHEDULE) {
                pendingCombination = CompletableFuture.completedFuture(this.combine(this.pendingHashes));
            } else {
                // Capture the current list; the field is reassigned below before the task runs.
                final List<byte[]> hashes = this.pendingHashes;
                pendingCombination = CompletableFuture.supplyAsync(() -> this.combine(hashes), ConcurrentStreamingTreeHasher.this.executorService);
            }
            // Chaining onto the previous combination keeps results flowing to the delegate in
            // the order in which they were scheduled.
            this.combination = this.combination.thenCombine(pendingCombination, (ignore, combined) -> {
                combined.forEach(this.delegate::combine);
                return null;
            });
            this.pendingHashes = new ArrayList<>();
        }

        /**
         * Hashes adjacent pairs of the given hashes into parent hashes. If the number of hashes
         * is odd, the last one is paired with the empty hash for this height.
         *
         * @param hashes the hashes to pair up, in order
         * @return the parent hashes, in order
         */
        private List<byte[]> combine(@NonNull List<byte[]> hashes) {
            final int m = hashes.size();
            final List<byte[]> result = new ArrayList<>((m + 1) / 2);
            final MessageDigest digest = DIGESTS.get();
            for (int i = 0; i < m; i += 2) {
                final byte[] left = hashes.get(i);
                final byte[] right = i + 1 < m ? hashes.get(i + 1) : EMPTY_HASHES[this.height];
                result.add(BlockImplUtils.hashInternalNode(digest, left, right));
            }
            return result;
        }

        static {
            EMPTY_HASHES[0] = BlockImplUtils.hashLeaf(new byte[StreamingTreeHasher.HASH_LENGTH]);
            for (int i = 1; i < MAX_DEPTH; ++i) {
                EMPTY_HASHES[i] = BlockImplUtils.hashInternalNode(EMPTY_HASHES[i - 1], EMPTY_HASHES[i - 1]);
            }
        }
    }
}

