/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.memorycontainer;

import java.time.Instant;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLIndex;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.memorycontainer.MLMemoryContainer;
import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerInput;
import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerRequest;
import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action backing {@code cluster:admin/opensearch/ml/memory_containers/create}.
 *
 * <p>Each step runs only if the previous one succeeded; any failure is forwarded to the caller's
 * listener:
 * <ol>
 *   <li>reject the request when agentic memory is disabled or the tenant id is invalid;</li>
 *   <li>when semantic storage is enabled, check that the configured LLM / embedding models can be
 *       fetched and have an acceptable algorithm;</li>
 *   <li>ensure the {@code MLIndex.MEMORY_CONTAINER} index exists;</li>
 *   <li>index the container document and obtain its generated id;</li>
 *   <li>create the per-container memory data index;</li>
 *   <li>write the resolved data index name back onto the container document.</li>
 * </ol>
 *
 * <p>NOTE(review): if step 5 or 6 fails, the container document indexed in step 4 is left in place;
 * no cleanup is attempted here.
 */
public class TransportCreateMemoryContainerAction
    extends HandledTransportAction<MLCreateMemoryContainerRequest, MLCreateMemoryContainerResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportCreateMemoryContainerAction.class);

    /** Index that stores memory container documents. */
    private static final String MEMORY_CONTAINER_INDEX = ".plugins-ml-memory-container";
    /** User segment used in generated data index names when no user context is available. */
    private static final String DEFAULT_USER_ID = "default";

    private final MLIndicesHandler mlIndicesHandler;
    private final Client client;
    private final SdkClient sdkClient;
    private final ClusterService clusterService;
    private final ConnectorAccessControlHelper connectorAccessControlHelper;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MLModelManager mlModelManager;

    @Inject
    public TransportCreateMemoryContainerAction(
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        SdkClient sdkClient,
        ClusterService clusterService,
        MLIndicesHandler mlIndicesHandler,
        ConnectorAccessControlHelper connectorAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting,
        MLModelManager mlModelManager
    ) {
        super("cluster:admin/opensearch/ml/memory_containers/create", transportService, actionFilters, MLCreateMemoryContainerRequest::new);
        this.client = client;
        this.sdkClient = sdkClient;
        this.clusterService = clusterService;
        this.mlIndicesHandler = mlIndicesHandler;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.mlModelManager = mlModelManager;
    }

    /**
     * Entry point: validates the request, then runs the creation pipeline described on the class.
     *
     * @param task     the transport task (unused)
     * @param request  the create request carrying the container input
     * @param listener receives the created container id, or the first failure encountered
     */
    @Override
    protected void doExecute(Task task, MLCreateMemoryContainerRequest request, ActionListener<MLCreateMemoryContainerResponse> listener) {
        if (!this.mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
            listener.onFailure(new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
            return;
        }
        MLCreateMemoryContainerInput input = request.getMlCreateMemoryContainerInput();
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, input.getTenantId(), listener)) {
            return;
        }
        User user = RestActionUtils.getUserContext(this.client);
        String tenantId = input.getTenantId();
        this.validateModels(
            input.getMemoryStorageConfig(),
            ActionListener.wrap(
                isValid -> this.initMemoryContainerIndexIfAbsent(
                    // The boolean reported by initMLIndexIfAbsent is not consulted; success alone gates creation.
                    ActionListener.wrap(indexReady -> this.createContainer(input, user, tenantId, listener), listener::onFailure)
                ),
                listener::onFailure
            )
        );
    }

    /** Ensures the memory container index exists, reporting any synchronous failure to {@code listener}. */
    private void initMemoryContainerIndexIfAbsent(ActionListener<Boolean> listener) {
        try {
            this.mlIndicesHandler.initMLIndexIfAbsent(MLIndex.MEMORY_CONTAINER, listener);
        } catch (Exception e) {
            log.error("Failed to init memory container index", e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds and indexes the container document, then creates its data index and records the index
     * name on the container.
     */
    private void createContainer(MLCreateMemoryContainerInput input, User user, String tenantId, ActionListener<MLCreateMemoryContainerResponse> listener) {
        try {
            MLMemoryContainer memoryContainer = this.buildMemoryContainer(input, user, tenantId);
            this.indexMemoryContainer(
                memoryContainer,
                ActionListener.wrap(
                    memoryContainerId -> this.createMemoryDataIndices(
                        memoryContainerId,
                        memoryContainer,
                        user,
                        ActionListener.wrap(
                            actualIndexName -> this.recordIndexNameAndRespond(memoryContainerId, memoryContainer, actualIndexName, listener),
                            listener::onFailure
                        )
                    ),
                    listener::onFailure
                )
            );
        } catch (Exception e) {
            log.error("Failed to create memory container", e);
            listener.onFailure(e);
        }
    }

    /**
     * Stores {@code actualIndexName} in the container's storage config (creating the config if absent),
     * re-writes the container document and answers the caller with the container id.
     */
    private void recordIndexNameAndRespond(
        String memoryContainerId,
        MLMemoryContainer memoryContainer,
        String actualIndexName,
        ActionListener<MLCreateMemoryContainerResponse> listener
    ) {
        MemoryStorageConfig config = memoryContainer.getMemoryStorageConfig();
        if (config == null) {
            config = MemoryStorageConfig.builder().memoryIndexName(actualIndexName).build();
        } else {
            config.setMemoryIndexName(actualIndexName);
        }
        memoryContainer.setMemoryStorageConfig(config);
        this.updateMemoryContainer(
            memoryContainerId,
            memoryContainer,
            ActionListener.wrap(updated -> listener.onResponse(new MLCreateMemoryContainerResponse(memoryContainerId, "created")), listener::onFailure)
        );
    }

    /** Creates the container model from the request input, stamping owner, tenant and creation/update time. */
    private MLMemoryContainer buildMemoryContainer(MLCreateMemoryContainerInput input, User user, String tenantId) {
        Instant now = Instant.now();
        return MLMemoryContainer
            .builder()
            .name(input.getName())
            .description(input.getDescription())
            .owner(user)
            .tenantId(tenantId)
            .createdTime(now)
            .lastUpdatedTime(now)
            .memoryStorageConfig(input.getMemoryStorageConfig())
            .build();
    }

    /**
     * Resolves the data index name for the container and creates that index.
     *
     * @param listener receives the (lower-cased) index name once the index exists
     */
    private void createMemoryDataIndices(String memoryContainerId, MLMemoryContainer container, User user, ActionListener<String> listener) {
        MemoryStorageConfig memoryStorageConfig = container.getMemoryStorageConfig();
        String finalIndexName;
        try {
            finalIndexName = resolveMemoryIndexName(memoryContainerId, user, memoryStorageConfig);
        } catch (IllegalArgumentException e) {
            log.error("Failed to resolve memory data index name for container: {}", memoryContainerId, e);
            listener.onFailure(e);
            return;
        }
        this.createMemoryDataIndex(finalIndexName, memoryStorageConfig, ActionListener.wrap(success -> listener.onResponse(finalIndexName), listener::onFailure));
    }

    /**
     * Returns the data index name for a container, lower-cased with {@link Locale#ROOT}.
     *
     * <p>An explicitly configured {@code memoryIndexName} wins. Otherwise the name is
     * {@code <prefix><containerId>-<userName>}, where the prefix depends on the storage mode.
     *
     * @throws IllegalArgumentException if semantic storage is enabled, no index name is configured,
     *     and the embedding model type is neither TEXT_EMBEDDING nor SPARSE_ENCODING (previously
     *     this case surfaced as a NullPointerException)
     */
    private static String resolveMemoryIndexName(String memoryContainerId, User user, MemoryStorageConfig config) {
        String configuredName = config != null ? config.getMemoryIndexName() : null;
        if (configuredName != null) {
            return configuredName.toLowerCase(Locale.ROOT);
        }
        String userId = user != null ? user.getName() : DEFAULT_USER_ID;
        String prefix;
        if (config == null || !config.isSemanticStorageEnabled()) {
            prefix = "ml-static-memory-";
        } else if (config.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING) {
            prefix = "ml-knn-memory-";
        } else if (config.getEmbeddingModelType() == FunctionName.SPARSE_ENCODING) {
            prefix = "ml-sparse-memory-";
        } else {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Unsupported embedding model type for semantic memory storage: %s", config.getEmbeddingModelType())
            );
        }
        return (prefix + memoryContainerId + "-" + userId).toLowerCase(Locale.ROOT);
    }

    /**
     * Creates the memory data index with the fixed memory mapping.
     *
     * <p>With semantic storage, a {@code memory_embedding} field is added. It is a lucene HNSW
     * {@code knn_vector} (cosine) for TEXT_EMBEDDING, or {@code rank_features} for SPARSE_ENCODING.
     * An index that already exists is treated as success.
     */
    private void createMemoryDataIndex(String indexName, MemoryStorageConfig memoryStorageConfig, ActionListener<Boolean> listener) {
        try {
            Map<String, Object> indexSettings = new HashMap<>();
            Map<String, Object> properties = new HashMap<>();
            properties.put("user_id", Map.of("type", "keyword"));
            properties.put("agent_id", Map.of("type", "keyword"));
            properties.put("session_id", Map.of("type", "keyword"));
            properties.put("memory", Map.of("type", "text"));
            properties.put("tags", Map.of("type", "flat_object"));
            properties.put("memory_type", Map.of("type", "keyword"));
            properties.put("role", Map.of("type", "text"));
            properties.put("created_time", Map.of("type", "date", "format", "strict_date_time||epoch_millis"));
            properties.put("last_updated_time", Map.of("type", "date", "format", "strict_date_time||epoch_millis"));
            if (memoryStorageConfig != null && memoryStorageConfig.isSemanticStorageEnabled()) {
                if (memoryStorageConfig.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING) {
                    indexSettings.put("index.knn", true);
                    indexSettings.put("index.knn.algo_param.ef_search", 100);
                    int dimension = memoryStorageConfig.getDimension();
                    Map<String, Object> method = new HashMap<>();
                    method.put("name", "hnsw");
                    method.put("space_type", "cosinesimil");
                    method.put("engine", "lucene");
                    method.put("parameters", Map.of("ef_construction", 100, "m", 16));
                    Map<String, Object> knnVector = new HashMap<>();
                    knnVector.put("type", "knn_vector");
                    knnVector.put("dimension", dimension);
                    knnVector.put("method", method);
                    properties.put("memory_embedding", knnVector);
                } else if (memoryStorageConfig.getEmbeddingModelType() == FunctionName.SPARSE_ENCODING) {
                    properties.put("memory_embedding", Map.of("type", "rank_features"));
                }
            }
            Map<String, Object> indexMappings = new HashMap<>();
            indexMappings.put("properties", properties);
            CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(indexSettings).mapping(indexMappings);
            this.client.admin().indices().create(createIndexRequest, ActionListener.wrap(response -> {
                if (response.isAcknowledged()) {
                    log.info("Successfully created memory data index: {}", indexName);
                    listener.onResponse(true);
                } else {
                    listener.onFailure(new RuntimeException("Failed to create memory data index: " + indexName));
                }
            }, e -> {
                if (e instanceof ResourceAlreadyExistsException) {
                    log.info("Memory data index already exists: {}", indexName);
                    listener.onResponse(true);
                } else {
                    log.error("Error creating memory data index: {}", indexName, e);
                    listener.onFailure(e);
                }
            }));
        } catch (Exception e) {
            log.error("Failed to create memory data index", e);
            listener.onFailure(e);
        }
    }

    /**
     * Overwrites the container document with id {@code memoryContainerId}.
     *
     * <p>The write runs under a stashed thread context. The original context is restored inside the
     * completion callback before the listener is notified.
     */
    private void updateMemoryContainer(String memoryContainerId, MLMemoryContainer container, ActionListener<Boolean> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            PutDataObjectRequest putRequest = PutDataObjectRequest
                .builder()
                .tenantId(container.getTenantId())
                .index(MEMORY_CONTAINER_INDEX)
                .id(memoryContainerId)
                .dataObject(container)
                .build();
            this.sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> {
                context.restore();
                if (throwable != null) {
                    Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                    log.error("Failed to update memory container", cause);
                    listener.onFailure(cause);
                } else {
                    log.info("Successfully updated memory container with ID: {}", memoryContainerId);
                    listener.onResponse(true);
                }
            });
        } catch (Exception e) {
            log.error("Failed to update memory container", e);
            listener.onFailure(e);
        }
    }

    /**
     * Indexes a new container document and reports its generated id.
     *
     * <p>Any result other than {@code CREATED}, or a missing index response, is reported as a failure.
     * The write runs under a stashed thread context that is restored in the completion callback.
     */
    private void indexMemoryContainer(MLMemoryContainer container, ActionListener<String> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            PutDataObjectRequest putRequest = PutDataObjectRequest
                .builder()
                .tenantId(container.getTenantId())
                .index(MEMORY_CONTAINER_INDEX)
                .dataObject(container)
                .build();
            this.sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> {
                context.restore();
                if (throwable != null) {
                    Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                    log.error("Failed to index memory container", cause);
                    listener.onFailure(cause);
                    return;
                }
                try {
                    IndexResponse indexResponse = r.indexResponse();
                    // Explicit check instead of `assert`: assertions are normally disabled in production.
                    if (indexResponse != null && indexResponse.getResult() == DocWriteResponse.Result.CREATED) {
                        String generatedId = indexResponse.getId();
                        log.info("Successfully created memory container with ID: {}", generatedId);
                        listener.onResponse(generatedId);
                    } else {
                        listener.onFailure(new RuntimeException("Failed to create memory container"));
                    }
                } catch (Exception e) {
                    listener.onFailure(e);
                }
            });
        } catch (Exception e) {
            log.error("Failed to save memory container", e);
            listener.onFailure(e);
        }
    }

    /**
     * Validates the models referenced by {@code config}; a no-op success when semantic storage is off.
     *
     * <p>If an LLM model id is set, that model must be fetchable and have algorithm REMOTE. The
     * embedding model is then checked via {@link #validateEmbeddingModel}.
     */
    private void validateModels(MemoryStorageConfig config, ActionListener<Boolean> listener) {
        if (config == null || !config.isSemanticStorageEnabled()) {
            listener.onResponse(true);
            return;
        }
        if (config.getLlmModelId() == null) {
            this.validateEmbeddingModel(config, listener);
            return;
        }
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLModel> wrappedListener = ActionListener.runBefore(ActionListener.wrap(llmModel -> {
                if (llmModel.getAlgorithm() != FunctionName.REMOTE) {
                    listener.onFailure(
                        new IllegalArgumentException(String.format(Locale.ROOT, "LLM model must be a REMOTE model, found: %s", llmModel.getAlgorithm()))
                    );
                    return;
                }
                this.validateEmbeddingModel(config, listener);
            }, e -> {
                log.error("Failed to get LLM model: {}", config.getLlmModelId(), e);
                // Any lookup failure is reported as "not found"; the original exception is kept as the cause.
                listener.onFailure(
                    new IllegalArgumentException(String.format(Locale.ROOT, "LLM model with ID %s not found", config.getLlmModelId()), e)
                );
            }), context::restore);
            this.mlModelManager.getModel(config.getLlmModelId(), wrappedListener);
        }
    }

    /**
     * Validates the embedding model, if one is configured.
     *
     * <p>The model's algorithm must equal the configured embedding model type or be REMOTE.
     */
    private void validateEmbeddingModel(MemoryStorageConfig config, ActionListener<Boolean> listener) {
        if (config.getEmbeddingModelId() == null) {
            listener.onResponse(true);
            return;
        }
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLModel> wrappedListener = ActionListener.runBefore(ActionListener.wrap(embeddingModel -> {
                FunctionName modelAlgorithm = embeddingModel.getAlgorithm();
                FunctionName expectedType = config.getEmbeddingModelType();
                if (modelAlgorithm != expectedType && modelAlgorithm != FunctionName.REMOTE) {
                    listener.onFailure(
                        new IllegalArgumentException(
                            String.format(Locale.ROOT, "Embedding model must be of type %s or REMOTE, found: %s", expectedType, modelAlgorithm)
                        )
                    );
                    return;
                }
                listener.onResponse(true);
            }, e -> {
                log.error("Failed to get embedding model: {}", config.getEmbeddingModelId(), e);
                // Any lookup failure is reported as "not found"; the original exception is kept as the cause.
                listener.onFailure(
                    new IllegalArgumentException(String.format(Locale.ROOT, "Embedding model with ID %s not found", config.getEmbeddingModelId()), e)
                );
            }), context::restore);
            this.mlModelManager.getModel(config.getEmbeddingModelId(), wrappedListener);
        }
    }
}

