/*
 * Decompiled with CFR 0.152.
 */
package org.openjdk.jmc.flightrecorder.stacktrace.graph;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.openjdk.jmc.flightrecorder.stacktrace.graph.AggregatableFrame;
import org.openjdk.jmc.flightrecorder.stacktrace.graph.Edge;
import org.openjdk.jmc.flightrecorder.stacktrace.graph.Node;
import org.openjdk.jmc.flightrecorder.stacktrace.graph.StacktraceGraphModel;

/**
 * Reduces a {@link StacktraceGraphModel} to a smaller graph: optionally drops low-weight nodes and
 * edges, then keeps only the highest-ranked nodes according to an entropy-based score.
 */
public class Pruning {
    /**
     * Fraction of the summed node counts used as the cutoff below which a node's cumulative weight
     * causes it to be discarded (only when low-frequency trimming is requested).
     */
    private static final double NODE_CUTOFF_FRACTION = 0.005;

    /**
     * Fraction of the summed node counts used as the cutoff below which an incoming edge's value
     * causes it to be trimmed (only when low-frequency trimming is requested).
     */
    private static final double EDGE_CUTOFF_FRACTION = 0.001;

    /**
     * Prunes {@code model} down to at most {@code maxNodeCount} nodes.
     * <p>
     * The steps are:
     * <ol>
     * <li>Compute the total value as the sum of {@code node.count} over all nodes.</li>
     * <li>If {@code trimLowFrequency} is set and {@code round(total * 0.005) > 0}, keep only the
     * nodes whose {@code cumulativeWeight} reaches that cutoff.</li>
     * <li>Score every remaining node with {@link #entropyScore(Node)} and sort them, highest score
     * first.</li>
     * <li>If {@code trimLowFrequency} is set, remove every incoming edge whose value is below
     * {@code round(total * 0.001)}. The matching outgoing entry on the source node is removed as
     * well.</li>
     * <li>Keep the top {@code maxNodeCount} nodes.</li>
     * </ol>
     * NOTE(review): the edge trimming mutates the in/out maps of the {@link Node} objects held by
     * the model being pruned. Whether those objects are shared with the caller's model depends on
     * the {@link StacktraceGraphModel} copy constructor, which is not visible here.
     *
     * @param model            the graph to prune
     * @param maxNodeCount     the maximum number of nodes to keep; note that the highest-scoring
     *                         node is always kept when the graph is non-empty, even if this is
     *                         {@code <= 0}
     * @param trimLowFrequency whether to discard low-weight nodes and edges before selection
     * @return a new model built from {@code model} restricted to the selected frames
     */
    public static StacktraceGraphModel prune(StacktraceGraphModel model, int maxNodeCount, boolean trimLowFrequency) {
        long totalValue = model.getNodes().stream().mapToLong(node -> node.count).sum();
        if (trimLowFrequency) {
            long nodeCutoff = Math.round(totalValue * NODE_CUTOFF_FRACTION);
            if (nodeCutoff > 0L) {
                model = discardLowFrequencyNodes(model, nodeCutoff);
            }
        }

        // Scores are computed once, before any edge trimming below, so trimming cannot change the ranking.
        Map<Integer, Long> nodeScores = new HashMap<>();
        for (Node node : model.getNodes()) {
            nodeScores.put(node.getNodeId(), entropyScore(node));
        }
        List<Node> sortedNodes = new ArrayList<>(model.getNodes());
        // Highest score first. List.sort is stable, so equal scores keep their original relative order.
        sortedNodes.sort(Comparator.comparingLong((Node n) -> nodeScores.get(n.getNodeId())).reversed());

        if (trimLowFrequency) {
            long edgeCutoff = Math.round(totalValue * EDGE_CUTOFF_FRACTION);
            trimLowFrequencyEdges(sortedNodes, edgeCutoff);
        }
        return selectTopNode(model, sortedNodes, maxNodeCount);
    }

    /**
     * Builds a model containing only the nodes whose cumulative weight is at least {@code nodeCutoff}.
     *
     * @param model      the source model
     * @param nodeCutoff minimum cumulative weight a node needs in order to be kept
     * @return a new model built from {@code model} and the frames of the kept nodes
     */
    private static StacktraceGraphModel discardLowFrequencyNodes(StacktraceGraphModel model, long nodeCutoff) {
        Set<AggregatableFrame> keptFrames = new HashSet<>(model.getNodes().size());
        for (Node node : model.getNodes()) {
            if (node.cumulativeWeight < nodeCutoff) {
                continue;
            }
            keptFrames.add(node.getFrame());
        }
        return new StacktraceGraphModel(model, keptFrames);
    }

    /**
     * Builds a model containing the first {@code maxCount} nodes of {@code sortedNodes}.
     * <p>
     * The first node is always added before the count is checked, so a non-empty
     * {@code sortedNodes} yields at least one node even when {@code maxCount <= 0}.
     *
     * @param model       the source model
     * @param sortedNodes the candidate nodes, in priority order
     * @param maxCount    the maximum number of nodes to keep
     * @return a new model built from {@code model} and the frames of the selected nodes
     */
    private static StacktraceGraphModel selectTopNode(StacktraceGraphModel model, Collection<Node> sortedNodes, int maxCount) {
        Set<AggregatableFrame> keptFrames = new HashSet<>(model.getNodes().size());
        int count = 0;
        for (Node node : sortedNodes) {
            keptFrames.add(node.getFrame());
            count++;
            if (count >= maxCount) {
                break;
            }
        }
        return new StacktraceGraphModel(model, keptFrames);
    }

    /**
     * Removes every incoming edge whose value is below {@code edgeCutoff}. The matching outgoing
     * entry on the edge's source node is removed too, so both sides stay consistent.
     *
     * @param sortedNode the nodes whose incoming edges are examined (mutated in place)
     * @param edgeCutoff edges with a value strictly below this are removed
     * @return the number of edges removed
     */
    private static int trimLowFrequencyEdges(Collection<Node> sortedNode, long edgeCutoff) {
        int droppedEdges = 0;
        for (Node node : sortedNode) {
            // Iterate over a snapshot because node.getIn() is modified inside the loop.
            List<Map.Entry<Node.NodeWrapper, Edge>> incoming =
                    new ArrayList<Map.Entry<Node.NodeWrapper, Edge>>(node.getIn().entrySet());
            for (Map.Entry<Node.NodeWrapper, Edge> entry : incoming) {
                if (entry.getValue().value < edgeCutoff) {
                    node.getIn().remove(entry.getKey());
                    entry.getKey().node.getOut().remove(new Node.NodeWrapper(node.getNodeId(), node));
                    droppedEdges++;
                }
            }
        }
        return droppedEdges;
    }

    /**
     * Scores a node for ranking.
     * <p>
     * The incoming and outgoing sides each contribute to the score:
     * <ul>
     * <li>A side with no edges contributes {@code 1.0}.</li>
     * <li>Otherwise a side contributes the entropy of its edge values (see
     * {@link #edgeEntropyScore(Collection, double)}). The outgoing side also counts
     * {@code node.weight} as an extra "self" share.</li>
     * </ul>
     * The combined score is scaled by the node's cumulative weight, and the node's own weight is
     * then added.
     *
     * @param node the node to score
     * @return {@code round(score * cumulativeWeight + weight)}
     */
    private static long entropyScore(Node node) {
        double score = 0.0;
        score += node.getIn().isEmpty() ? 1.0 : edgeEntropyScore(node.getIn().values(), 0.0);
        score += node.getOut().isEmpty() ? 1.0 : edgeEntropyScore(node.getOut().values(), node.weight);
        return Math.round(score * node.cumulativeWeight + node.weight);
    }

    /**
     * Computes the Shannon entropy (natural log) of the distribution formed by the edge values
     * plus an optional {@code self} share.
     * <p>
     * The normalizing total is {@code self} plus the sum of the strictly positive edge values.
     * Zero-valued edges contribute nothing, following the convention {@code 0·log 0 = 0}. Without
     * that guard they would produce {@code 0 * -Infinity = NaN} and poison the node's score.
     * NOTE(review): a negative edge value contributes a term based on its absolute value but is
     * not part of the total. This matches the previous behavior; confirm whether negative values
     * can occur.
     *
     * @param edges the edges whose values form the distribution
     * @param self  an extra share added to the total and, when positive, to the entropy
     * @return the entropy, or {@code 0.0} when the total is not positive
     */
    private static double edgeEntropyScore(Collection<Edge> edges, double self) {
        double total = self;
        for (Edge edge : edges) {
            if (edge.getValue() > 0.0) {
                total += edge.getValue();
            }
        }
        if (!(total > 0.0)) {
            return 0.0;
        }
        double score = 0.0;
        for (Edge edge : edges) {
            score += entropyTerm(Math.abs(edge.getValue()), total);
        }
        if (self > 0.0) {
            score += entropyTerm(self, total);
        }
        return score;
    }

    /**
     * Returns {@code -p * ln(p)} for {@code p = magnitude / total}, and {@code 0.0} when the
     * magnitude is zero, since the limit of {@code p·ln p} as {@code p → 0} is 0.
     */
    private static double entropyTerm(double magnitude, double total) {
        if (magnitude == 0.0) {
            return 0.0;
        }
        double frac = magnitude / total;
        return -frac * Math.log(frac);
    }
}

