/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.aggregation;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.opensearch.sql.data.model.ExprCollectionValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.aggregation.AggregationState;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.utils.ExpressionUtils;

/**
 * Aggregator for the {@code take} function: keeps the first {@code size} values it is fed, in
 * iteration order, and returns them as a collection value.
 *
 * <p>The limit comes from the argument at index 1, which {@link #create()} evaluates through the
 * no-argument {@code valueOf()} — presumably a literal; confirm against the function resolver.
 */
public class TakeAggregator extends Aggregator<TakeAggregator.TakeState> {

    /**
     * Creates a take aggregator.
     *
     * @param arguments the function arguments; index 1 must evaluate to the integer limit
     * @param returnType the declared return type of the aggregation
     */
    public TakeAggregator(List<Expression> arguments, ExprCoreType returnType) {
        super(BuiltinFunctionName.TAKE.getName(), arguments, returnType);
    }

    /**
     * Creates a fresh, empty state whose limit is the integer value of argument 1.
     *
     * @return a new {@link TakeState}
     * @throws IllegalArgumentException if the evaluated limit is not greater than 0
     */
    @Override
    public TakeState create() {
        return new TakeState(this.getArguments().get(1).valueOf().integerValue());
    }

    /**
     * Offers one value to the state; values beyond the limit are ignored.
     *
     * @param value the value to collect
     * @param state the state to update (mutated in place)
     * @return the same {@code state} instance
     */
    @Override
    protected TakeState iterate(ExprValue value, TakeState state) {
        state.take(value);
        return state;
    }

    @Override
    public String toString() {
        return String.format(Locale.ROOT, "take(%s)", ExpressionUtils.format(this.getArguments()));
    }

    /** Mutable accumulation state for {@link TakeAggregator}. */
    protected static class TakeState implements AggregationState {
        /** Number of values accepted so far; never exceeds {@link #size}. */
        protected int index;
        /** Maximum number of values to keep; always greater than 0. */
        protected int size;
        /** Values accepted so far, in the order they were taken. */
        protected List<ExprValue> hits;

        /**
         * Creates an empty state.
         *
         * @param size maximum number of values to keep
         * @throws IllegalArgumentException if {@code size <= 0}
         */
        TakeState(int size) {
            if (size <= 0) {
                throw new IllegalArgumentException("size must be greater than 0");
            }
            this.index = 0;
            this.size = size;
            this.hits = new ArrayList<>();
        }

        /**
         * Keeps {@code value} if fewer than {@code size} values have been kept so far.
         *
         * @param value the value to offer
         */
        public void take(ExprValue value) {
            // Only count accepted values: incrementing on every call would overflow the int
            // after 2^31 rows, wrap negative, and start accepting values past the limit again.
            if (this.index < this.size) {
                this.hits.add(value);
                ++this.index;
            }
        }

        /**
         * Returns the kept values as a collection value.
         *
         * @return a collection of at most {@code size} values
         */
        @Override
        public ExprValue result() {
            // Copy so a result handed out mid-aggregation is not changed by later take() calls.
            return new ExprCollectionValue(new ArrayList<>(this.hits));
        }
    }
}

