/*
 * Decompiled with CFR 0.152.
 */
package com.swirlds.virtualmap.internal.hash;

import com.swirlds.logging.legacy.LogMarker;
import com.swirlds.virtualmap.VirtualKey;
import com.swirlds.virtualmap.VirtualValue;
import com.swirlds.virtualmap.config.VirtualMapConfig;
import com.swirlds.virtualmap.datasource.VirtualLeafRecord;
import com.swirlds.virtualmap.internal.Path;
import com.swirlds.virtualmap.internal.hash.VirtualHashListener;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.LongFunction;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.hiero.base.concurrent.AbstractTask;
import org.hiero.base.crypto.Cryptography;
import org.hiero.base.crypto.CryptographyProvider;
import org.hiero.base.crypto.Hash;
import org.hiero.base.crypto.HashBuilder;

/**
 * Recomputes the root hash of a binary hash tree (addressed by {@link Path} paths, root = 0)
 * after a batch of leaf changes.
 *
 * <p>Only the supplied dirty leaves are rehashed. Internal nodes are grouped into chunks spanning
 * up to {@code virtualHasherChunkHeight} ranks. Each chunk is hashed by one {@link ChunkHashTask}
 * running on a {@link ForkJoinPool}. A chunk input whose subtree contains no dirty leaf is not
 * recomputed; its hash is read through the caller-supplied {@code hashReader}.
 *
 * <p>NOTE: {@link #hash} stores per-call state ({@link #hashReader}, {@link #listener}) in
 * instance fields. Concurrent {@code hash} calls on the same instance would therefore overwrite
 * each other's state.
 *
 * @param <K> virtual key type
 * @param <V> virtual value type
 */
public final class VirtualHasher<K extends VirtualKey, V extends VirtualValue> {
    private static final Logger logger = LogManager.getLogger(VirtualHasher.class);

    /** Per-thread reusable builder used to combine two child hashes into a parent hash. */
    private static final ThreadLocal<HashBuilder> HASH_BUILDER_THREAD_LOCAL =
            ThreadLocal.withInitial(() -> new HashBuilder(Cryptography.DEFAULT_DIGEST_TYPE));

    /** Reads the stored hash of a node by path. Set by {@link #hash}, cleared after a successful run. */
    private LongFunction<Hash> hashReader;

    /** Receives hashing callbacks. Set by {@link #hash}, cleared after a successful run. */
    private VirtualHashListener<K, V> listener;

    private static final Cryptography CRYPTOGRAPHY = CryptographyProvider.getInstance();

    /** Set by {@link #shutdown()}; consulted only when waiting for the root task fails. */
    private final AtomicBoolean shutdown = new AtomicBoolean(false);

    /**
     * Marks this hasher as shut down.
     *
     * <p>This does not cancel any task. If waiting for the root task in {@link #hash} later fails,
     * {@code hash} returns {@code null} instead of logging and rethrowing the failure.
     */
    public void shutdown() {
        shutdown.set(true);
    }

    /**
     * Returns the height of the chunk that consumes nodes at {@code rank} as its inputs.
     *
     * <p>When leaves span two ranks, leaves on the last rank feed 1-rank chunks. Chunks fed from the
     * first leaf rank are sized so that their output lands on a multiple of {@code
     * defaultChunkHeight}. Every other input rank is expected to already be such a multiple.
     */
    private static int getChunkHeightForInputRank(
            final int rank, final int firstLeafRank, final int lastLeafRank, final int defaultChunkHeight) {
        if (rank == lastLeafRank && firstLeafRank != lastLeafRank) {
            return 1;
        }
        if (rank == firstLeafRank) {
            return (rank - 1) % defaultChunkHeight + 1;
        }
        assert rank % defaultChunkHeight == 0;
        return defaultChunkHeight;
    }

    /**
     * Returns the height of a chunk whose output (top) node is at {@code rank}.
     *
     * <p>A chunk rooted on the first leaf rank spans 1 rank when leaves span two ranks. Otherwise the
     * chunk spans {@code defaultChunkHeight} ranks, capped so it does not reach past the first leaf
     * rank.
     */
    private static int getChunkHeightForOutputRank(
            final int rank, final int firstLeafRank, final int lastLeafRank, final int defaultChunkHeight) {
        if (rank == firstLeafRank && firstLeafRank != lastLeafRank) {
            return 1;
        }
        assert rank % defaultChunkHeight == 0;
        assert rank < firstLeafRank;
        return Math.min(defaultChunkHeight, firstLeafRank - rank);
    }

    /**
     * Rehashes the tree for the given dirty leaves and returns the new root hash.
     *
     * <p>Tasks run on the current thread's {@link ForkJoinPool} when called from a fork/join worker.
     * Otherwise they run on the common pool.
     *
     * @param hashReader returns the stored hash for a path. Used for every chunk input that was not
     *     recomputed in this call
     * @param sortedDirtyLeaves dirty leaves sorted by path in ascending order (the sibling bookkeeping
     *     below relies on this order)
     * @param firstLeafPath path of the first leaf of the tree
     * @param lastLeafPath path of the last leaf of the tree
     * @param listener hashing callbacks; a no-op listener is used when {@code null}
     * @param virtualMapConfig supplies the chunk height
     * @return the new root hash. Returns {@code null} when there are no dirty leaves, or when
     *     waiting for the tasks failed after {@link #shutdown()} was called
     * @throws NullPointerException if {@code virtualMapConfig} is {@code null}
     * @throws IllegalArgumentException if there are dirty leaves but either leaf path is less than 1
     */
    public Hash hash(
            final LongFunction<Hash> hashReader,
            final Iterator<VirtualLeafRecord<K, V>> sortedDirtyLeaves,
            final long firstLeafPath,
            final long lastLeafPath,
            VirtualHashListener<K, V> listener,
            @NonNull final VirtualMapConfig virtualMapConfig) {
        Objects.requireNonNull(virtualMapConfig);
        if (listener == null) {
            // The empty anonymous body shows the listener has no abstract methods to implement.
            // The decompiled form passed a synthetic outer-instance argument here, which does not compile.
            listener = new VirtualHashListener<K, V>() {};
        }

        // Stay on the caller's fork/join pool if there is one.
        final ForkJoinPool hashingPool = Thread.currentThread() instanceof ForkJoinWorkerThread worker
                ? worker.getPool()
                : ForkJoinPool.commonPool();

        listener.onHashingStarted(firstLeafPath, lastLeafPath);
        if (!sortedDirtyLeaves.hasNext()) {
            listener.onHashingCompleted();
            return null;
        }
        if (firstLeafPath < 1L || lastLeafPath < 1L) {
            throw new IllegalArgumentException("Dirty leaves stream is not empty, but leaf path range is empty");
        }

        this.hashReader = hashReader;
        this.listener = listener;

        final int chunkHeight = virtualMapConfig.virtualHasherChunkHeight();
        final int firstLeafRank = Path.getRank(firstLeafPath);
        final int lastLeafRank = Path.getRank(lastLeafPath);

        // Tasks created ahead of time (as a chunk's siblings) that no dirty leaf has claimed yet, keyed by path.
        final HashMap<Long, HashProducingTask> allTasks = new HashMap<>();

        final int rootTaskHeight = Math.min(firstLeafRank, chunkHeight);
        final ChunkHashTask rootTask = new ChunkHashTask(hashingPool, 0L, rootTaskHeight);
        // The root has no consumer. setOut(null) still releases the task's "output wired" dependency.
        rootTask.setOut(null);
        allTasks.put(0L, rootTask);

        boolean firstLeaf = true;
        // stack[rank]: first path on that rank whose pending sibling tasks should be completed (as clean)
        // when the next node on that rank is visited; -1 means nothing pending.
        final long[] stack = new long[lastLeafRank + 1];
        Arrays.fill(stack, -1L);

        while (sortedDirtyLeaves.hasNext()) {
            final VirtualLeafRecord<K, V> leaf = sortedDirtyLeaves.next();
            long curPath = leaf.getPath();
            LeafHashTask leafTask = (LeafHashTask) allTasks.remove(curPath);
            if (leafTask == null) {
                leafTask = new LeafHashTask(hashingPool, curPath);
            }
            leafTask.setLeaf(leaf);

            // Climb from the leaf towards the root, wiring each task to its chunk parent.
            HashProducingTask curTask = leafTask;
            while (true) {
                final int curRank = Path.getRank(curPath);
                final int parentChunkHeight =
                        getChunkHeightForInputRank(curRank, firstLeafRank, lastLeafRank, chunkHeight);
                final int chunkWidth = 1 << parentChunkHeight;

                // Complete (as clean) pending tasks on this rank. Start at the remembered position and stop
                // at the current path or at the end of that position's chunk, whichever comes first.
                long curStackPath = stack[curRank];
                if (curStackPath != -1L) {
                    stack[curRank] = -1L;
                    final long firstPathInRank = Path.getPathForRankAndIndex(curRank, 0L);
                    final long curStackChunkNoInRank = (curStackPath - firstPathInRank) / chunkWidth;
                    final long firstPathInNextChunk = firstPathInRank + (curStackChunkNoInRank + 1L) * chunkWidth;
                    while (curStackPath < firstPathInNextChunk) {
                        if (curStackPath == curPath) {
                            // Reached the current node: siblings to its right in this chunk stay pending.
                            if (curPath + 1L < firstPathInNextChunk) {
                                stack[curRank] = curPath + 1L;
                            }
                            break;
                        }
                        final HashProducingTask t = allTasks.remove(curStackPath);
                        assert t != null;
                        t.complete();
                        curStackPath++;
                    }
                }

                // Stop once this task is already wired to a consumer, or it is the root.
                if (curTask.hasOut() || curTask == rootTask) {
                    break;
                }

                final long parentPath = Path.getGrandParentPath(curPath, parentChunkHeight);
                ChunkHashTask parentTask = (ChunkHashTask) allTasks.remove(parentPath);
                if (parentTask == null) {
                    parentTask = new ChunkHashTask(hashingPool, parentPath, parentChunkHeight);
                }
                curTask.setOut(parentTask);

                // Every other input of parentTask (this node's siblings within the chunk) must either
                // resolve the dependency now or get a task that is wired to parentTask.
                final long firstPathInRank = Path.getPathForRankAndIndex(curRank, 0L);
                final long chunkNoInRank = (curPath - firstPathInRank) / chunkWidth;
                final long firstSiblingPath = firstPathInRank + chunkNoInRank * chunkWidth;
                final long lastSiblingPath = firstSiblingPath + chunkWidth - 1L;
                for (long siblingPath = firstSiblingPath; siblingPath <= lastSiblingPath; siblingPath++) {
                    if (siblingPath == curPath) {
                        continue;
                    }
                    if (siblingPath > lastLeafPath) {
                        // Per the assert, only sibling path 2 can lie past the last leaf (i.e. lastLeafPath == 1).
                        assert siblingPath == 2L;
                        parentTask.setHash(siblingPath, Cryptography.NULL_HASH);
                        continue;
                    }
                    if (siblingPath < curPath && !firstLeaf) {
                        // Treated as clean: release the dependency without a hash, so parentTask reads this
                        // input through hashReader.
                        assert !allTasks.containsKey(siblingPath);
                        parentTask.send();
                        continue;
                    }
                    final HashProducingTask siblingTask;
                    if (siblingPath >= firstLeafPath) {
                        assert !allTasks.containsKey(siblingPath);
                        siblingTask = allTasks.computeIfAbsent(siblingPath, p -> new LeafHashTask(hashingPool, p));
                    } else {
                        final int taskChunkHeight =
                                getChunkHeightForOutputRank(curRank, firstLeafRank, lastLeafRank, chunkHeight);
                        siblingTask = allTasks.computeIfAbsent(
                                siblingPath, p -> new ChunkHashTask(hashingPool, p, taskChunkHeight));
                    }
                    siblingTask.setOut(parentTask);
                }

                // Remember where this rank's still-pending siblings (to the right) begin.
                if (curPath != lastSiblingPath && !firstLeaf) {
                    stack[curRank] = curPath + 1L;
                }
                curPath = parentPath;
                curTask = parentTask;
            }
            firstLeaf = false;
        }

        // Whatever is still unclaimed was never reached by a dirty leaf: complete it as clean.
        allTasks.forEach((path, task) -> task.complete());
        allTasks.clear();

        try {
            rootTask.join();
        } catch (final Exception e) {
            if (shutdown.get()) {
                return null;
            }
            logger.error(LogMarker.EXCEPTION.getMarker(), "Failed to wait for all hashing tasks", e);
            throw e;
        }
        listener.onHashingCompleted();
        this.hashReader = null;
        this.listener = null;
        return rootTask.getResult();
    }

    /** Returns the root hash of a tree whose two children both carry {@link Cryptography#NULL_HASH}. */
    public Hash emptyRootHash() {
        return ChunkHashTask.hash(0L, Cryptography.NULL_HASH, Cryptography.NULL_HASH);
    }

    /**
     * Hashes a chunk: the subtree of {@code height} ranks below {@code path}, whose inputs are the
     * {@code 2^height} nodes at the chunk's bottom rank.
     *
     * <p>Dependencies: one per input plus one for {@link #setOut}. A {@code null} input means "not
     * recomputed". Its stored hash is read through {@code hashReader} when a sibling input changed.
     */
    class ChunkHashTask extends HashProducingTask {
        private final long path;
        private final int height;
        // Input hashes indexed left-to-right; reused in place as each level is reduced.
        private final Hash[] ins;

        ChunkHashTask(final ForkJoinPool pool, final long path, final int height) {
            super(pool, 1 + (1 << height));
            this.path = path;
            this.height = height;
            this.ins = new Hash[1 << height];
        }

        /** Releases this chunk as clean. It must not have received any input hash. */
        @Override
        public void complete() {
            assert Arrays.stream(ins).allMatch(Objects::isNull);
            super.complete();
        }

        /** Records the hash of input {@code childPath} (which must be on the chunk's bottom rank). */
        void setHash(final long childPath, final Hash hash) {
            assert Path.getRank(this.path) + height == Path.getRank(childPath);
            final long firstPathInPathRank = Path.getLeftGrandChildPath(this.path, height);
            final int index = Math.toIntExact(childPath - firstPathInPathRank);
            assert index >= 0 && index < ins.length;
            ins[index] = hash;
            send();
        }

        /** Returns the hash computed for {@link #path}; valid only once the task is done. */
        Hash getResult() {
            assert isDone();
            return ins[0];
        }

        /**
         * Reduces the inputs level by level up to {@link #path}, then forwards the result to
         * {@link #out} when there is one.
         */
        @Override
        protected boolean onExecute() {
            long rankPath = Path.getLeftGrandChildPath(path, height);
            for (int len = 1 << height; len > 1; len >>= 1) {
                for (int i = 0; i < len / 2; i++) {
                    final long leftPath = rankPath + 2L * i;
                    final long hashedPath = Path.getParentPath(leftPath);
                    Hash left = ins[i * 2];
                    Hash right = ins[i * 2 + 1];
                    if (left == null && right == null) {
                        // Neither child recomputed: the parent is not recomputed either.
                        ins[i] = null;
                        continue;
                    }
                    if (left == null) {
                        left = VirtualHasher.this.hashReader.apply(leftPath);
                    }
                    if (right == null) {
                        right = VirtualHasher.this.hashReader.apply(leftPath + 1L);
                    }
                    ins[i] = ChunkHashTask.hash(hashedPath, left, right);
                    VirtualHasher.this.listener.onNodeHashed(hashedPath, ins[i]);
                }
                rankPath = Path.getParentPath(rankPath);
            }
            if (out != null) {
                out.setHash(path, ins[0]);
            }
            return true;
        }

        /**
         * Combines two child hashes into the hash of node {@code path}.
         *
         * <p>The root (path 0) uses a different class id and version than other internal nodes.
         * NOTE(review): these presumably match the root / internal node class ids and serialization
         * versions; confirm before changing.
         */
        static Hash hash(final long path, final Hash left, final Hash right) {
            final boolean isRoot = path == 0L;
            final long classId = isRoot ? 5367589755328273141L : -5826388714229745985L;
            final int serId = isRoot ? 3 : 1;
            final HashBuilder builder = HASH_BUILDER_THREAD_LOCAL.get();
            builder.reset();
            builder.update(classId);
            builder.update(serId);
            builder.update(left);
            builder.update(right);
            return builder.build();
        }
    }

    /**
     * Hashes a single leaf record.
     *
     * <p>Dependencies: {@link #setLeaf} and {@link #setOut}. A leaf task that is {@link #complete()
     * completed} instead never runs; it only releases its consumer's dependency.
     */
    class LeafHashTask extends HashProducingTask {
        private final long path;
        private VirtualLeafRecord<K, V> leaf;

        LeafHashTask(final ForkJoinPool pool, final long path) {
            super(pool, 2);
            this.path = path;
        }

        /** Releases this leaf as clean. No leaf record may have been set. */
        @Override
        public void complete() {
            assert leaf == null;
            super.complete();
        }

        /** Supplies the dirty leaf record; its path must equal this task's path. */
        void setLeaf(final VirtualLeafRecord<K, V> leaf) {
            assert leaf != null;
            assert this.path == leaf.getPath();
            this.leaf = leaf;
            send();
        }

        @Override
        protected boolean onExecute() {
            Hash hash = null;
            if (leaf != null) {
                hash = CRYPTOGRAPHY.digestSync(leaf);
                VirtualHasher.this.listener.onLeafHashed(leaf);
                VirtualHasher.this.listener.onNodeHashed(path, hash);
            }
            out.setHash(path, hash);
            return true;
        }
    }

    /** Base for tasks that produce one hash and feed it into a consumer {@link ChunkHashTask}. */
    class HashProducingTask extends AbstractTask {
        /** Consumer of this task's hash; {@code null} only for the root task. */
        protected ChunkHashTask out;

        HashProducingTask(final ForkJoinPool pool, final int dependencyCount) {
            super(pool, dependencyCount);
        }

        boolean hasOut() {
            return out != null;
        }

        /** Wires this task to its consumer and releases the corresponding dependency. */
        void setOut(final ChunkHashTask out) {
            this.out = out;
            send();
        }

        /**
         * Marks this task's subtree as clean without running it. The consumer's dependency is
         * released without a hash, so the consumer falls back to {@code hashReader} for this input.
         */
        void complete() {
            out.send();
        }

        protected boolean onExecute() {
            return true;
        }

        /** Propagates a failure to the consumer, so it eventually reaches the root task. */
        protected void onException(final Throwable t) {
            if (out != null) {
                out.completeExceptionally(t);
            }
        }
    }
}

