/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.ingest;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;
import org.opensearch.ml.engine.ingest.AbstractIngestion;
import org.opensearch.ml.engine.utils.S3Utils;
import org.opensearch.transport.client.Client;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;

/**
 * Batch ingester registered under the name {@code "s3"}: reads each configured S3 object line by
 * line and hands the lines to {@code batchIngest} in groups of {@code bulkSize}.
 */
@Ingester(value="s3")
public class S3DataIngestion
extends AbstractIngestion {
    @Generated
    private static final Logger log = LogManager.getLogger(S3DataIngestion.class);
    /** Key in the input's data-sources map under which the list of S3 URIs is stored. */
    public static final String SOURCE = "source";

    public S3DataIngestion(Client client) {
        super(client);
    }

    /**
     * Ingests every S3 object listed under {@link #SOURCE} in the input's data sources.
     *
     * <p>One S3 client is built from the {@code access_key}, {@code secret_key},
     * {@code session_token} and {@code region} credential entries, shared by all sources, and
     * closed exactly once when this method exits (normally, exceptionally, or via the
     * empty-source early return).
     *
     * @param mlBatchIngestionInput ingestion request carrying credentials and data sources
     * @param bulkSize number of lines per bulk batch
     * @return {@code 100.0} when no source is configured, otherwise the value produced by
     *         {@code calculateSuccessRate} over the per-source success rates
     */
    @Override
    public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
        String accessKey = (String)mlBatchIngestionInput.getCredential().get("access_key");
        String secretKey = (String)mlBatchIngestionInput.getCredential().get("secret_key");
        String sessionToken = (String)mlBatchIngestionInput.getCredential().get("session_token");
        String region = (String)mlBatchIngestionInput.getCredential().get("region");
        S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, region, sessionToken);
        try {
            // The data-sources map is untyped; entries under SOURCE are expected to be URI strings.
            @SuppressWarnings("unchecked")
            List<String> s3Uris = (List<String>)mlBatchIngestionInput.getDataSources().get(SOURCE);
            if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
                return 100.0;
            }
            boolean isSoleSource = s3Uris.size() == 1;
            // Sources are processed sequentially on this thread, so a plain list is sufficient.
            List<Double> successRates = new ArrayList<>(s3Uris.size());
            for (int sourceIndex = 0; sourceIndex < s3Uris.size(); ++sourceIndex) {
                // Use the non-closing helper: the client must stay open for the remaining sources.
                successRates.add(this.readAndIngest(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
            }
            return this.calculateSuccessRate(successRates);
        }
        finally {
            s3.close();
        }
    }

    /**
     * Ingests a single S3 object and then closes {@code s3}.
     *
     * <p>The supplied client is always closed before this method returns or throws, so callers
     * must not reuse it afterwards.
     *
     * @param s3 client used to fetch the object; closed by this method
     * @param s3Uri URI of the object to ingest
     * @param mlBatchIngestionInput ingestion request forwarded to {@code batchIngest}
     * @param sourceIndex position of this source in the source list
     * @param isSoleSource whether this is the only configured source
     * @param bulkSize number of lines per bulk batch
     * @return percentage (0-100) of batches counted as successful; {@code 100.0} for an empty object
     * @throws S3Exception if the S3 read fails (rethrown unchanged)
     * @throws OpenSearchStatusException (500) for any other failure while reading or ingesting
     */
    public double ingestSingleSource(S3Client s3, String s3Uri, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource, int bulkSize) {
        try {
            return this.readAndIngest(s3, s3Uri, mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize);
        }
        finally {
            s3.close();
        }
    }

    /**
     * Streams one S3 object, submits its lines in batches and waits for all batches to finish.
     * Does not close {@code s3}; the caller owns the client's lifecycle.
     */
    private double readAndIngest(S3Client s3, String s3Uri, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource, int bulkSize) {
        String bucketName = S3Utils.getS3BucketName(s3Uri);
        String keyName = S3Utils.getS3KeyName(s3Uri);
        GetObjectRequest getObjectRequest = (GetObjectRequest)GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
        // The explicit PrivilegedExceptionAction cast selects the doPrivileged overload that
        // throws PrivilegedActionException; a bare lambda matches both overloads.
        try (ResponseInputStream<?> s3is = AccessController.doPrivileged((PrivilegedExceptionAction<ResponseInputStream<?>>)() -> s3.getObject(getObjectRequest));
             BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8));){
            List<String> linesBuffer = new ArrayList<>();
            int lineCount = 0;
            AtomicInteger successfulBatches = new AtomicInteger(0);
            AtomicInteger failedBatches = new AtomicInteger(0);
            List<CompletableFuture<Void>> futures = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                linesBuffer.add(line);
                if (++lineCount % bulkSize != 0) continue;
                futures.add(this.submitBatch(linesBuffer, mlBatchIngestionInput, successfulBatches, failedBatches, sourceIndex, isSoleSource));
                linesBuffer.clear();
            }
            // Flush the trailing partial batch, if any.
            if (!linesBuffer.isEmpty()) {
                futures.add(this.submitBatch(linesBuffer, mlBatchIngestionInput, successfulBatches, failedBatches, sourceIndex, isSoleSource));
            }
            // Release the S3 stream before blocking on the batch futures; the implicit
            // try-with-resources close that follows is then a no-op.
            reader.close();
            // join() rethrows (wrapped) if any batch future completed exceptionally, which lands
            // in the generic catch below.
            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
            int totalBatches = successfulBatches.get() + failedBatches.get();
            return totalBatches == 0 ? 100.0 : (double)successfulBatches.get() / (double)totalBatches * 100.0;
        }
        catch (S3Exception e) {
            // awsErrorDetails() may be absent; fall back to the exception message so logging
            // cannot mask the original failure with a NullPointerException.
            String detail = e.awsErrorDetails() != null ? e.awsErrorDetails().errorMessage() : e.getMessage();
            log.error("Error reading from S3: " + detail, e);
            throw e;
        }
        catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to get S3 Object: ", e);
        }
        catch (Exception e) {
            log.error(e.getMessage(), e);
            // Keep the original exception as the cause so the root failure is not lost.
            throw new OpenSearchStatusException("Failed to batch ingest: " + e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR, e);
        }
    }

    /**
     * Submits one batch of lines via {@code batchIngest} and returns the future handed to the
     * bulk response listener (presumably completed by that listener when the bulk call finishes).
     */
    private CompletableFuture<Void> submitBatch(List<String> lines, MLBatchIngestionInput mlBatchIngestionInput, AtomicInteger successfulBatches, AtomicInteger failedBatches, int sourceIndex, boolean isSoleSource) {
        CompletableFuture<Void> future = new CompletableFuture<>();
        this.batchIngest(lines, mlBatchIngestionInput, this.getBulkResponseListener(successfulBatches, failedBatches, future), sourceIndex, isSoleSource);
        return future;
    }
}

