/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.metrics.calculator;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.searchrelevance.exception.SearchRelevanceException;

/**
 * Pairwise similarity metrics for comparing two lists of strings (e.g. ranked result identifiers).
 *
 * <p>Every public calculator returns its score rounded to two decimal places.
 */
public class PairComparison {
    /** Field name for the Jaccard similarity metric. */
    public static final String JACCARD_SIMILARITY_FIELD_NAME = "jaccard";
    /** Field name for the RBO similarity metric with p = 0.5. */
    public static final String RBO_50_SIMILARITY_FIELD_NAME = "rbo50";
    /** Field name for the RBO similarity metric with p = 0.9. */
    public static final String RBO_90_SIMILARITY_FIELD_NAME = "rbo90";
    /** Field name for the frequency-weighted similarity metric. */
    public static final String FREQUENCY_WEIGHTED_SIMILARITY_FIELD_NAME = "frequencyWeighted";

    /**
     * Computes the Jaccard similarity |A ∩ B| / |A ∪ B| over the distinct elements of two lists.
     * Element order and duplicates are ignored.
     *
     * @param listA first list (must not be null)
     * @param listB second list (must not be null)
     * @return similarity in [0, 1] rounded to two decimals; 0.0 when both lists are empty
     */
    public static double calculateJaccardSimilarity(List<String> listA, List<String> listB) {
        Set<String> setA = new HashSet<>(listA);
        Set<String> setB = new HashSet<>(listB);
        Set<String> intersection = new HashSet<>(setA);
        intersection.retainAll(setB);
        // Inclusion-exclusion gives the union size without materializing the union set.
        int unionSize = setA.size() + setB.size() - intersection.size();
        if (unionSize == 0) {
            return 0.0;
        }
        return roundToTwoDecimals((double) intersection.size() / unionSize);
    }

    /**
     * Computes a rank-biased-overlap (RBO) style similarity between two ranked lists, evaluated down
     * to the length of the longer list.
     *
     * <p>At each depth d (1-based) the agreement is |A_d ∩ B_d| / max(|A_d|, |B_d|). Here A_d and
     * B_d are the sets of distinct elements among the first d entries of each list. The agreement at
     * depth d is weighted by p^(d-1). The weighted sum is then multiplied by
     * (1 - p) / (1 - p^maxDepth), so two identical duplicate-free lists score 1.0.
     *
     * @param listA first ranked list (must not be null)
     * @param listB second ranked list (must not be null)
     * @param p persistence parameter, strictly between 0 and 1; smaller values concentrate the
     *     weight on the top ranks
     * @return similarity rounded to two decimals; 0.0 when both lists are empty
     * @throws SearchRelevanceException if p is not strictly between 0 and 1 (NaN included)
     */
    public static double calculateRBOSimilarity(List<String> listA, List<String> listB, double p) {
        // Negated range check so that NaN is rejected too: every comparison involving NaN is false.
        // NOTE(review): an invalid p looks like a caller error; confirm INTERNAL_SERVER_ERROR is intended.
        if (!(p > 0.0 && p < 1.0)) {
            throw new SearchRelevanceException("p must be between 0 and 1", RestStatus.INTERNAL_SERVER_ERROR);
        }
        int maxDepth = Math.max(listA.size(), listB.size());
        if (maxDepth == 0) {
            return 0.0;
        }
        // The prefix sets and the size of their intersection are grown one rank at a time instead of
        // being rebuilt at every depth. This keeps the computation linear in the list length.
        Set<String> prefixA = new HashSet<>();
        Set<String> prefixB = new HashSet<>();
        Iterator<String> itA = listA.iterator();
        Iterator<String> itB = listB.iterator();
        int overlapCount = 0;
        double weightedSum = 0.0;
        double weight = 1.0;
        for (int d = 0; d < maxDepth; d++) {
            if (itA.hasNext()) {
                String a = itA.next();
                // A newly seen element of A extends the intersection iff B's prefix already has it.
                if (prefixA.add(a) && prefixB.contains(a)) {
                    overlapCount++;
                }
            }
            if (itB.hasNext()) {
                String b = itB.next();
                // Checked against A's prefix *after* adding a, so a == b at the same rank counts once.
                if (prefixB.add(b) && prefixA.contains(b)) {
                    overlapCount++;
                }
            }
            // At least one prefix is non-empty here because d < maxDepth.
            double overlap = (double) overlapCount / Math.max(prefixA.size(), prefixB.size());
            weightedSum += weight * overlap;
            weight *= p;
        }
        return roundToTwoDecimals(weightedSum * (1.0 - p) / (1.0 - Math.pow(p, maxDepth)));
    }

    /**
     * Computes a frequency-weighted overlap between two lists.
     *
     * <p>Each distinct element gets a combined weight: the average of its relative frequency in
     * listA and its relative frequency in listB (0 when absent from a list). The score is the total
     * weight of elements present in both lists divided by the total weight of all elements.
     *
     * @param listA first list (must not be null)
     * @param listB second list (must not be null)
     * @return similarity in [0, 1] rounded to two decimals; 0.0 when both lists are empty
     */
    public static double calculateFrequencyWeightedSimilarity(List<String> listA, List<String> listB) {
        Map<String, Double> weights = calculateCombinedWeights(listA, listB);
        // Hash lookup instead of List.contains keeps this linear rather than O(|A| * |B|).
        Set<String> itemsB = new HashSet<>(listB);
        double intersectionWeight = 0.0;
        for (String item : new HashSet<>(listA)) {
            if (itemsB.contains(item)) {
                intersectionWeight += weights.get(item);
            }
        }
        double unionWeight = weights.values().stream().mapToDouble(Double::doubleValue).sum();
        return roundToTwoDecimals(unionWeight == 0.0 ? 0.0 : intersectionWeight / unionWeight);
    }

    /**
     * Builds the combined weight of every distinct element of either list.
     *
     * @return map from element to (relative frequency in listA + relative frequency in listB) / 2
     */
    private static Map<String, Double> calculateCombinedWeights(List<String> listA, List<String> listB) {
        Map<String, Double> weightsA = calculateFrequencyWeights(listA);
        Map<String, Double> weightsB = calculateFrequencyWeights(listB);
        Map<String, Double> combinedWeights = new HashMap<>();
        Set<String> allItems = new HashSet<>(weightsA.keySet());
        allItems.addAll(weightsB.keySet());
        for (String item : allItems) {
            double weightA = weightsA.getOrDefault(item, 0.0);
            double weightB = weightsB.getOrDefault(item, 0.0);
            combinedWeights.put(item, (weightA + weightB) / 2.0);
        }
        return combinedWeights;
    }

    /**
     * Computes the relative frequency (occurrences / list size) of each distinct element.
     *
     * @return map from element to relative frequency; empty for an empty list
     */
    private static Map<String, Double> calculateFrequencyWeights(List<String> list) {
        Map<String, Integer> frequencies = new HashMap<>();
        for (String item : list) {
            frequencies.merge(item, 1, Integer::sum);
        }
        // The sum of all occurrence counts is simply the list length.
        double totalFrequency = list.size();
        Map<String, Double> weights = new HashMap<>();
        for (Map.Entry<String, Integer> entry : frequencies.entrySet()) {
            weights.put(entry.getKey(), entry.getValue() / totalFrequency);
        }
        return weights;
    }

    /** Rounds to two decimal places with {@link Math#round(double)} (ties toward positive infinity). */
    private static double roundToTwoDecimals(double value) {
        return Math.round(value * 100.0) / 100.0;
    }
}

