/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.ubi;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchScrollRequest;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.Scroll;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.searchrelevance.ubi.QuerySampler;
import org.opensearch.transport.client.Client;

/**
 * Samples user queries read from the {@code ubi_queries} index using probability-proportional-to-size
 * sampling: each distinct query string is drawn with probability equal to its share of all queries
 * read from the index.
 */
public class ProbabilityProportionalToSizeQuerySampler
extends QuerySampler {
    public static final String NAME = "pptss";
    private static final Logger LOGGER = LogManager.getLogger(ProbabilityProportionalToSizeQuerySampler.class);
    private static final double EPSILON = 1.0E-5;
    /** Index the user queries are scrolled from. */
    private static final String UBI_QUERIES_INDEX = "ubi_queries";
    /** Source field whose value is taken as the query text. */
    private static final String USER_QUERY_FIELD = "user_query";
    /** Number of hits requested per scroll page. */
    private static final int SCROLL_PAGE_SIZE = 10000;

    public ProbabilityProportionalToSizeQuerySampler(int size, Client client) {
        super(size, client);
    }

    /**
     * Reads every user query from {@code ubi_queries} and samples from them.
     *
     * <p>The returned future never completes exceptionally: if the index yields no queries, retrieval
     * fails, or sampling throws, the error is logged and the future completes with an empty map.
     *
     * @return future of sampled query text mapped to that query's total occurrence count
     */
    @Override
    public CompletableFuture<Map<String, Integer>> sample() {
        QueryBuilder matchAll = QueryBuilders.matchAllQuery();
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(matchAll).size(SCROLL_PAGE_SIZE);
        CompletableFuture<Map<String, Integer>> future = new CompletableFuture<>();
        getUserQueries(searchSourceBuilder, new ActionListener<Collection<String>>() {
            @Override
            public void onResponse(Collection<String> userQueries) {
                try {
                    if (userQueries.isEmpty()) {
                        LOGGER.warn("No queries found in {}", UBI_QUERIES_INDEX);
                        future.complete(new HashMap<>());
                        return;
                    }
                    future.complete(getQuerySet(userQueries));
                } catch (Exception e) {
                    LOGGER.error("Error processing user queries", e);
                    future.complete(new HashMap<>());
                }
            }

            @Override
            public void onFailure(Exception e) {
                // The trailing Throwable argument is not consumed by a placeholder, so log4j records
                // the full stack trace in addition to the message.
                LOGGER.error("Failed to retrieve queries from {}: {}", UBI_QUERIES_INDEX, e.getMessage(), e);
                future.complete(new HashMap<>());
            }
        });
        return future;
    }

    /**
     * Draws {@code getSize()} samples, weighting each distinct query by its frequency.
     *
     * <p>Repeated draws of the same query overwrite one map entry, so the result may contain fewer
     * than {@code getSize()} entries.
     *
     * @param userQueries all query strings read from the index, duplicates included; must be non-empty
     * @return sampled query text mapped to the query's total occurrence count in {@code userQueries}
     * @throws IllegalStateException if the normalized or cumulative weights do not sum to 1.0 within
     *     {@link #EPSILON} (this is also what happens for an empty input)
     */
    private Map<String, Integer> getQuerySet(Collection<String> userQueries) {
        Map<String, Long> weights = new HashMap<>();
        userQueries.forEach(query -> weights.merge(query, 1L, Long::sum));

        long countOfQueries = userQueries.size();
        Map<String, Double> normalizedWeights = new HashMap<>();
        weights.forEach((query, weight) -> normalizedWeights.put(query, weight.doubleValue() / countOfQueries));

        double sumOfNormalizedWeights = normalizedWeights.values().stream().mapToDouble(Double::doubleValue).sum();
        if (!compareDouble(1.0, sumOfNormalizedWeights)) {
            throw new IllegalStateException("Summed normalized weights do not equal 1.0: " + sumOfNormalizedWeights);
        }

        // Insertion-ordered so selectQuery() scans entries in ascending cumulative-weight order;
        // a plain HashMap makes no iteration-order guarantee.
        Map<String, Double> cumulativeWeights = new LinkedHashMap<>();
        double runningWeight = 0.0;
        for (Map.Entry<String, Double> entry : normalizedWeights.entrySet()) {
            runningWeight += entry.getValue();
            cumulativeWeights.put(entry.getKey(), runningWeight);
        }
        if (!compareDouble(runningWeight, 1.0)) {
            throw new IllegalStateException("The sum of cumulative weights does not equal 1.0: " + runningWeight);
        }

        Map<String, Integer> querySet = new HashMap<>();
        UniformRealDistribution uniform = new UniformRealDistribution(0.0, 1.0);
        int sampleCount = getSize();
        for (int i = 0; i < sampleCount; i++) {
            String selected = selectQuery(cumulativeWeights, uniform.sample());
            querySet.put(selected, Math.toIntExact(weights.get(selected)));
        }
        return querySet;
    }

    /**
     * Returns the first query whose cumulative weight is at least {@code r}.
     *
     * @param cumulativeWeights non-empty map iterated in ascending cumulative-weight order
     * @param r uniform draw
     * @return the selected query; the last query if rounding left every cumulative weight below {@code r}
     */
    private static String selectQuery(Map<String, Double> cumulativeWeights, double r) {
        String last = null;
        for (Map.Entry<String, Double> entry : cumulativeWeights.entrySet()) {
            last = entry.getKey();
            if (entry.getValue() >= r) {
                return last;
            }
        }
        // Rounding can leave the final cumulative weight slightly below 1.0; a draw landing in
        // that gap belongs to the last query rather than being discarded.
        return last;
    }

    /**
     * Starts scrolling {@code ubi_queries} from the first page.
     *
     * @param searchSourceBuilder query and page size for the initial search request
     * @param listener receives every collected query string, or the first failure
     */
    private void getUserQueries(SearchSourceBuilder searchSourceBuilder, ActionListener<Collection<String>> listener) {
        Collection<String> userQueries = new ArrayList<>();
        // NOTE(review): the scroll context is never cleared explicitly, so it lingers until this
        // keep-alive expires. Consider issuing a clear-scroll request once iteration finishes.
        Scroll scroll = new Scroll(TimeValue.timeValueMinutes(10L));
        scrollUserQueries(searchSourceBuilder, scroll, userQueries, null, listener);
    }

    /**
     * Requests one page: the initial search when {@code scrollId} is null, otherwise the next scroll page.
     * The response is handed to {@link #processSearchResponse}, which continues or finishes the iteration.
     */
    private void scrollUserQueries(SearchSourceBuilder searchSourceBuilder, Scroll scroll, Collection<String> accumulator, String scrollId, ActionListener<Collection<String>> listener) {
        try {
            ActionListener<SearchResponse> pageListener = ActionListener.wrap(
                searchResponse -> processSearchResponse(searchResponse, scroll, accumulator, listener),
                listener::onFailure
            );
            if (scrollId == null) {
                SearchRequest searchRequest = new SearchRequest(UBI_QUERIES_INDEX).scroll(scroll).source(searchSourceBuilder);
                getClient().search(searchRequest, pageListener);
            } else {
                SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId).scroll(scroll);
                getClient().searchScroll(scrollRequest, pageListener);
            }
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Adds the query text of each hit in the page to {@code accumulator}. An empty page signals the
     * end of the scroll and delivers the accumulator to {@code listener}; otherwise the next page is
     * requested. Hits without a {@code user_query} value are skipped, so one malformed document does
     * not fail the whole sample.
     */
    private void processSearchResponse(SearchResponse searchResponse, Scroll scroll, Collection<String> accumulator, ActionListener<Collection<String>> listener) {
        try {
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (hits == null || hits.length == 0) {
                listener.onResponse(accumulator);
                return;
            }
            int skipped = 0;
            for (SearchHit hit : hits) {
                Map<String, Object> source = hit.getSourceAsMap();
                Object rawQuery = source == null ? null : source.get(USER_QUERY_FIELD);
                if (rawQuery == null) {
                    skipped++;
                    continue;
                }
                String userQuery = rawQuery.toString();
                accumulator.add(userQuery);
                LOGGER.debug("User queries count: {} user query: {}", accumulator.size(), userQuery);
            }
            if (skipped > 0) {
                LOGGER.warn("Skipped {} documents in {} without a {} field", skipped, UBI_QUERIES_INDEX, USER_QUERY_FIELD);
            }
            scrollUserQueries(null, scroll, accumulator, searchResponse.getScrollId(), listener);
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /** Returns true when {@code a} and {@code b} differ by less than {@link #EPSILON}. */
    private boolean compareDouble(double a, double b) {
        return Math.abs(a - b) < EPSILON;
    }
}

