/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.storage.scan;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.NumberUtil;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder;
import org.opensearch.search.aggregations.bucket.missing.MissingOrder;
import org.opensearch.search.sort.ScoreSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser;
import org.opensearch.sql.opensearch.storage.OpenSearchIndex;

public abstract class AbstractCalciteIndexScan
extends TableScan {
    /** Class logger; used by {@code pushDownSort} to report sorts that could not be pushed down. */
    private static final Logger LOG = LogManager.getLogger(AbstractCalciteIndexScan.class);
    /** Index being scanned. Never {@code null}: the constructor rejects a null value. */
    public final OpenSearchIndex osIndex;
    /** Row type reported by {@code deriveRowType()}. */
    protected final RelDataType schema;
    /** Ordered collection of the actions that have been pushed down into this scan. */
    protected final PushDownContext pushDownContext;

    /**
     * Creates a scan over the given OpenSearch index.
     *
     * @param cluster forwarded to the {@code TableScan} constructor
     * @param traitSet forwarded to the {@code TableScan} constructor
     * @param hints forwarded to the {@code TableScan} constructor
     * @param table forwarded to the {@code TableScan} constructor
     * @param osIndex index to scan; must not be {@code null}
     * @param schema row type this scan reports from {@code deriveRowType()}
     * @param pushDownContext actions already pushed down into this scan
     * @throws NullPointerException if {@code osIndex} is {@code null}
     */
    protected AbstractCalciteIndexScan(
            RelOptCluster cluster,
            RelTraitSet traitSet,
            List<RelHint> hints,
            RelOptTable table,
            OpenSearchIndex osIndex,
            RelDataType schema,
            PushDownContext pushDownContext) {
        super(cluster, traitSet, hints, table);
        if (osIndex == null) {
            throw new NullPointerException("OpenSearch index");
        }
        this.pushDownContext = pushDownContext;
        this.schema = schema;
        this.osIndex = osIndex;
    }

    /**
     * Returns the row type of this scan.
     *
     * @return the {@code schema} supplied at construction time
     */
    public RelDataType deriveRowType() {
        final RelDataType rowType = this.schema;
        return rowType;
    }

    /**
     * Adds a {@code PushDownContext} item to the plan explanation.
     *
     * <p>The item's value is the push-down context followed by the string form of a fresh request
     * builder to which every pushed-down action has been applied. The item is only written when
     * the context is non-empty.
     *
     * @param pw writer receiving the explain terms
     * @return the writer returned by the chained {@code itemIf} call
     */
    public RelWriter explainTerms(RelWriter pw) {
        final OpenSearchRequestBuilder replayTarget = this.osIndex.createRequestBuilder();
        // Replay the pushed-down actions so the explanation shows the resulting request.
        for (PushDownAction pushed : this.pushDownContext) {
            pushed.apply(replayTarget);
        }
        final String description = this.pushDownContext + ", " + replayTarget;
        final RelWriter writer = super.explainTerms(pw);
        return writer.itemIf("PushDownContext", description, !this.pushDownContext.isEmpty());
    }

    /**
     * Reads the {@code QUERY_SIZE_LIMIT} setting of the scanned index.
     *
     * @return the configured value, as returned by the index settings
     */
    protected Integer getQuerySizeLimit() {
        final Settings indexSettings = this.osIndex.getSettings();
        final Object configured = indexSettings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT);
        return (Integer) configured;
    }

    /**
     * Estimates the number of rows this scan produces after all pushed-down actions.
     *
     * <p>The estimate starts at the index's max result window and is refined by each pushed-down
     * action in push-down order. After every action the running estimate is multiplied by the
     * {@code CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR} setting.
     *
     * <ul>
     *   <li>{@code AGGREGATION}: row count of the aggregate node stored as the digest</li>
     *   <li>{@code PROJECT}, {@code SORT}: unchanged</li>
     *   <li>{@code FILTER}: scaled by the guessed selectivity of the digest expression</li>
     *   <li>{@code SCRIPT}: as {@code FILTER}, then multiplied by 1.1</li>
     *   <li>{@code LIMIT}: capped at the limit stored as the digest</li>
     * </ul>
     *
     * @param mq metadata query used to obtain the row count of a pushed-down aggregate
     * @return the estimated row count
     */
    public double estimateRowCount(RelMetadataQuery mq) {
        double estimateRowCountFactor = (Double) this.osIndex.getSettings().getSettingValue(Settings.Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR);
        double rowCount = this.osIndex.getMaxResultWindow().doubleValue();
        // A plain loop keeps the fold strictly sequential; dispatching on the enum constants
        // (not their ordinals) keeps the estimate correct if PushDownType is ever reordered
        // and makes the compiler flag any newly added constant.
        for (PushDownAction action : this.pushDownContext) {
            double estimate = switch (action.type()) {
                case AGGREGATION -> mq.getRowCount((RelNode) action.digest());
                case PROJECT, SORT -> rowCount;
                case FILTER -> NumberUtil.multiply(
                        Double.valueOf(rowCount),
                        Double.valueOf(RelMdUtil.guessSelectivity((RexNode) action.digest())));
                case SCRIPT -> NumberUtil.multiply(
                        Double.valueOf(rowCount),
                        Double.valueOf(RelMdUtil.guessSelectivity((RexNode) action.digest()))) * 1.1;
                case LIMIT -> Math.min(rowCount, (double) ((Integer) action.digest()).intValue());
            };
            rowCount = estimate * estimateRowCountFactor;
        }
        return rowCount;
    }

    /**
     * Factory hook for creating a scan instance of the concrete subclass.
     *
     * <p>{@code pushDownSort} calls this with, in order: the cluster, a trait set, the hints, the
     * table, the OpenSearch index, the row type, and the push-down context for the new scan.
     *
     * @return the newly built scan
     */
    protected abstract AbstractCalciteIndexScan buildScan(RelOptCluster var1, RelTraitSet var2, List<RelHint> var3, RelOptTable var4, OpenSearchIndex var5, RelDataType var6, PushDownContext var7);

    /**
     * Resolves each collation's field index to the corresponding field name of this scan's row type.
     *
     * @param collations collations to resolve
     * @return an unmodifiable list of field names, in the same order as {@code collations}
     */
    private List<String> getCollationNames(List<RelFieldCollation> collations) {
        final List<String> names = new ArrayList<>(collations.size());
        for (RelFieldCollation fieldCollation : collations) {
            names.add(this.getRowType().getFieldNames().get(fieldCollation.getFieldIndex()));
        }
        return Collections.unmodifiableList(names);
    }

    /**
     * Checks whether any of the given sort field names is an aggregator output (non-grouping
     * field) of an aggregate that has been pushed down.
     *
     * @param collations names of the fields being sorted on
     * @return {@code true} if at least one pushed-down aggregate exposes one of the names as a
     *     non-grouping output field
     */
    private boolean hasAggregatorInSortBy(List<String> collations) {
        boolean found = false;
        for (PushDownAction pushed : this.pushDownContext) {
            if (pushed.type() != PushDownType.AGGREGATION) {
                continue;
            }
            LogicalAggregate aggregate = (LogicalAggregate) pushed.digest();
            // Non-short-circuit OR: every pushed aggregate is inspected.
            found |= AbstractCalciteIndexScan.isAnyCollationNameInAggregateOutput(aggregate, collations);
        }
        return found;
    }

    /**
     * Checks whether any of the given names is a non-grouping output field of the aggregate.
     *
     * <p>The aggregate's output fields after the first {@code groupSet.cardinality()} entries are
     * treated as the aggregator outputs.
     *
     * @param aggregate aggregate whose output fields are inspected
     * @param collations names of the fields being sorted on
     * @return {@code true} if at least one name matches a non-grouping output field
     */
    private static boolean isAnyCollationNameInAggregateOutput(LogicalAggregate aggregate, List<String> collations) {
        List<String> fieldNames = aggregate.getRowType().getFieldNames();
        int groupOffset = aggregate.getGroupSet().cardinality();
        List<String> aggregatorFields = fieldNames.subList(groupOffset, fieldNames.size());
        return collations.stream().anyMatch(aggregatorFields::contains);
    }

    /**
     * Builds a new context holding every action of the given context except {@code SORT} actions,
     * in their original order.
     *
     * @param pushDownContext context to copy from; it is not modified
     * @return a new context without sort actions
     */
    protected PushDownContext cloneWithoutSort(PushDownContext pushDownContext) {
        final PushDownContext withoutSort = new PushDownContext();
        pushDownContext.forEach(candidate -> {
            if (candidate.type() != PushDownType.SORT) {
                withoutSort.add(candidate);
            }
        });
        return withoutSort;
    }

    /**
     * Attempts to push the given sort into this scan.
     *
     * <p>On success a new scan is returned. Its trait set includes the collations, and its
     * push-down context is this scan's context with earlier {@code SORT} actions dropped and one
     * new {@code SORT} action appended.
     *
     * <ul>
     *   <li>If an aggregation has been pushed down and the sort references one of its aggregator
     *       outputs, the sort is not pushed down.</li>
     *   <li>If an aggregation has been pushed down otherwise, the sort is applied to the
     *       aggregation's buckets and the appended {@code SORT} action is a no-op marker.</li>
     *   <li>Otherwise one sort builder per collation is created; a field named {@code _score}
     *       becomes a score sort, any other field a field sort on the name returned by
     *       {@code OpenSearchTextType.toKeywordSubField}.</li>
     * </ul>
     *
     * @param collations field collations to sort by
     * @return the new scan, or {@code null} if the sort cannot be pushed down (including when any
     *     exception is raised while trying)
     */
    public AbstractCalciteIndexScan pushDownSort(List<RelFieldCollation> collations) {
        // Kept outside the try block so the failure log never has to recompute the names:
        // recomputing could rethrow the very exception being handled.
        List<String> collationNames = null;
        try {
            collationNames = this.getCollationNames(collations);
            if (this.getPushDownContext().isAggregatePushed() && this.hasAggregatorInSortBy(collationNames)) {
                return null;
            }
            RelTraitSet traitsWithCollations = this.getTraitSet().plus((RelTrait) RelCollations.of(collations));
            // Earlier SORT actions are discarded: the new collations replace them.
            AbstractCalciteIndexScan newScan = this.buildScan(this.getCluster(), traitsWithCollations, (List<RelHint>) this.hints, this.table, this.osIndex, this.getRowType(), this.cloneWithoutSort(this.pushDownContext));
            AbstractAction action;
            Object digest;
            if (this.pushDownContext.isAggregatePushed()) {
                // NOTE(review): this updates the AggPushDownAction instance held by this scan's
                // context, and the same instance is carried into newScan's context.
                // If the last action is not an AggPushDownAction the cast fails and the sort is
                // reported as not pushable by the catch block below.
                PushDownAction lastAction = Objects.requireNonNull(this.pushDownContext.peekLast());
                ((AggPushDownAction) lastAction.action()).pushDownSortIntoAggBucket(collations);
                action = requestBuilder -> {};
                digest = collations;
            } else {
                List<SortBuilder<?>> builders = new ArrayList<>(collations.size());
                for (RelFieldCollation collation : collations) {
                    String fieldName = this.getRowType().getFieldNames().get(collation.getFieldIndex());
                    // isDescending() also covers STRICTLY_DESCENDING.
                    SortOrder order = collation.getDirection().isDescending() ? SortOrder.DESC : SortOrder.ASC;
                    SortBuilder<?> sortBuilder;
                    if (ScoreSortBuilder.NAME.equals(fieldName)) {
                        sortBuilder = SortBuilders.scoreSort();
                    } else {
                        String missing = switch (collation.nullDirection) {
                            case FIRST -> "_first";
                            case LAST -> "_last";
                            default -> null;
                        };
                        ExprType fieldType = this.osIndex.getFieldTypes().get(fieldName);
                        String field = OpenSearchTextType.toKeywordSubField(fieldName, fieldType);
                        sortBuilder = SortBuilders.fieldSort(field).missing(missing);
                    }
                    builders.add(sortBuilder.order(order));
                }
                action = requestBuilder -> requestBuilder.pushDownSort(builders);
                digest = builders.toString();
            }
            newScan.pushDownContext.add(PushDownAction.of(PushDownType.SORT, digest, action));
            return newScan;
        }
        catch (Exception e) {
            if (LOG.isDebugEnabled()) {
                // Fall back to the raw collations when the names could not be resolved.
                Object sortDescription = collationNames != null ? collationNames : collations;
                LOG.debug("Cannot pushdown the sort {}", sortDescription, e);
            }
            return null;
        }
    }

    /**
     * Returns the OpenSearch index this scan reads from.
     *
     * @return the index supplied at construction time; never {@code null}
     */
    @Generated
    public OpenSearchIndex getOsIndex() {
        final OpenSearchIndex index = this.osIndex;
        return index;
    }

    /**
     * Returns the row type supplied at construction time.
     *
     * @return the schema of this scan
     */
    @Generated
    public RelDataType getSchema() {
        final RelDataType rowType = this.schema;
        return rowType;
    }

    /**
     * Returns the actions pushed down into this scan.
     *
     * @return the live push-down context (not a copy)
     */
    @Generated
    public PushDownContext getPushDownContext() {
        final PushDownContext context = this.pushDownContext;
        return context;
    }

    /**
     * Ordered collection of pushed-down actions that also remembers whether an aggregation or a
     * limit has been added.
     *
     * <p>The flags are updated by {@link #add(PushDownAction)} only; they are never reset.
     */
    public static class PushDownContext extends ArrayDeque<PushDownAction> {
        private boolean isAggregatePushed;
        private boolean isLimitPushed;

        /** Returns a copy produced by {@code ArrayDeque.clone()}, cast to this type. */
        @Override
        public PushDownContext clone() {
            Object copy = super.clone();
            return (PushDownContext) copy;
        }

        /**
         * Appends the action and records whether it is an aggregation or a limit.
         *
         * @param pushDownAction action to append
         * @return the result of {@code ArrayDeque.add}
         */
        @Override
        public boolean add(PushDownAction pushDownAction) {
            final PushDownType kind = pushDownAction.type();
            this.isAggregatePushed |= kind == PushDownType.AGGREGATION;
            this.isLimitPushed |= kind == PushDownType.LIMIT;
            return super.add(pushDownAction);
        }

        /**
         * Reports whether an aggregation has been pushed down.
         *
         * <p>When the flag is not yet set, it is refreshed from the last element: it becomes
         * {@code true} if the context is non-empty and its last action is an aggregation.
         *
         * @return {@code true} if an aggregation has been pushed down
         */
        public boolean isAggregatePushed() {
            if (!this.isAggregatePushed) {
                boolean lastIsAggregation = false;
                if (!this.isEmpty()) {
                    lastIsAggregation = super.peekLast().type() == PushDownType.AGGREGATION;
                }
                this.isAggregatePushed = lastIsAggregation;
            }
            return this.isAggregatePushed;
        }

        /**
         * Reports whether a limit action has been added through {@link #add(PushDownAction)}.
         *
         * @return {@code true} if a limit has been added
         */
        @Generated
        public boolean isLimitPushed() {
            final boolean limitPushed = this.isLimitPushed;
            return limitPushed;
        }
    }

    /**
     * A single pushed-down operation.
     *
     * @param type kind of operation
     * @param digest description of the operation; its runtime type depends on {@code type}
     * @param action callback that applies the operation to a request builder
     */
    public record PushDownAction(PushDownType type, Object digest, AbstractAction action) {
        /** Factory equivalent to the canonical constructor. */
        static PushDownAction of(PushDownType type, Object digest, AbstractAction action) {
            PushDownAction created = new PushDownAction(type, digest, action);
            return created;
        }

        /** Renders the action as {@code <type>-><digest>}. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append(this.type);
            text.append("->");
            text.append(this.digest);
            return text.toString();
        }

        /**
         * Applies this action's callback to the given request builder.
         *
         * @param requestBuilder builder to modify
         */
        public void apply(OpenSearchRequestBuilder requestBuilder) {
            final AbstractAction callback = this.action;
            callback.apply(requestBuilder);
        }
    }

    protected static enum PushDownType {
        FILTER,
        PROJECT,
        AGGREGATION,
        SORT,
        LIMIT,
        SCRIPT;

    }

    public static interface AbstractAction {
        public void apply(OpenSearchRequestBuilder var1);
    }

    public static class AggPushDownAction
    implements AbstractAction {
        private Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> aggregationBuilder;
        private final Map<String, OpenSearchDataType> extendedTypeMapping;

        public AggPushDownAction(Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> aggregationBuilder, Map<String, OpenSearchDataType> extendedTypeMapping) {
            this.aggregationBuilder = aggregationBuilder;
            this.extendedTypeMapping = extendedTypeMapping;
        }

        @Override
        public void apply(OpenSearchRequestBuilder requestBuilder) {
            requestBuilder.pushDownAggregation(this.aggregationBuilder);
            requestBuilder.pushTypeMapping(this.extendedTypeMapping);
        }

        public void pushDownSortIntoAggBucket(List<RelFieldCollation> collations) {
            CompositeAggregationBuilder compositeAggregationBuilder = (CompositeAggregationBuilder)((List)this.aggregationBuilder.getLeft()).getFirst();
            List buckets = ((CompositeAggregationBuilder)((List)this.aggregationBuilder.getLeft()).getFirst()).sources();
            ArrayList newBuckets = new ArrayList(buckets.size());
            ArrayList selected = new ArrayList(collations.size());
            collations.forEach(collation -> {
                CompositeValuesSourceBuilder bucket = (CompositeValuesSourceBuilder)buckets.get(collation.getFieldIndex());
                RelFieldCollation.Direction direction = collation.getDirection();
                RelFieldCollation.NullDirection nullDirection = collation.nullDirection;
                SortOrder order = RelFieldCollation.Direction.DESCENDING.equals((Object)direction) ? SortOrder.DESC : SortOrder.ASC;
                MissingOrder missingOrder = switch (nullDirection) {
                    case RelFieldCollation.NullDirection.FIRST -> MissingOrder.FIRST;
                    case RelFieldCollation.NullDirection.LAST -> MissingOrder.LAST;
                    default -> MissingOrder.DEFAULT;
                };
                newBuckets.add(bucket.order(order).missingOrder(missingOrder));
                selected.add(collation.getFieldIndex());
            });
            IntStream.range(0, buckets.size()).filter(i -> !selected.contains(i)).forEach(i -> newBuckets.add((CompositeValuesSourceBuilder)buckets.get(i)));
            AggregatorFactories.Builder newAggBuilder = new AggregatorFactories.Builder();
            compositeAggregationBuilder.getSubAggregations().forEach(arg_0 -> ((AggregatorFactories.Builder)newAggBuilder).addAggregator(arg_0));
            this.aggregationBuilder = Pair.of(Collections.singletonList(((CompositeAggregationBuilder)AggregationBuilders.composite((String)"composite_buckets", newBuckets).subAggregations(newAggBuilder)).size(1000)), (Object)((OpenSearchAggregationResponseParser)this.aggregationBuilder.getRight()));
        }
    }
}

