/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.oracle;

import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.catalog.TablePath;
import org.apache.seatunnel.api.table.converter.BasicTypeDefine;
import org.apache.seatunnel.api.table.converter.TypeConverter;
import org.apache.seatunnel.api.table.schema.event.AlterTableAddColumnEvent;
import org.apache.seatunnel.api.table.schema.event.AlterTableChangeColumnEvent;
import org.apache.seatunnel.api.table.schema.event.AlterTableColumnEvent;
import org.apache.seatunnel.api.table.schema.event.AlterTableModifyColumnEvent;
import org.apache.seatunnel.connectors.seatunnel.jdbc.config.JdbcOptions;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.converter.JdbcRowConverter;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.DatabaseIdentifier;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.JdbcDialect;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.JdbcDialectTypeMapper;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.SQLUtils;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.dialectenum.FieldIdeEnum;
import org.apache.seatunnel.connectors.seatunnel.jdbc.source.JdbcSourceTable;

import org.apache.commons.lang3.StringUtils;

import lombok.extern.slf4j.Slf4j;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

@Slf4j
public class OracleDialect implements JdbcDialect {

    private static final int DEFAULT_ORACLE_FETCH_SIZE = 128;
    public String fieldIde = FieldIdeEnum.ORIGINAL.getValue();
    private final boolean handleBlobAsString;

    public OracleDialect(String fieldIde) {
        this(fieldIde, JdbcOptions.HANDLE_BLOB_AS_STRING.defaultValue());
    }

    public OracleDialect() {
        this(FieldIdeEnum.ORIGINAL.getValue(), JdbcOptions.HANDLE_BLOB_AS_STRING.defaultValue());
    }

    public OracleDialect(String fieldIde, boolean handleBlobAsString) {
        this.fieldIde = fieldIde;
        this.handleBlobAsString = handleBlobAsString;
    }

    @Override
    public String dialectName() {
        return DatabaseIdentifier.ORACLE;
    }

    @Override
    public JdbcRowConverter getRowConverter() {
        return new OracleJdbcRowConverter();
    }

    @Override
    public TypeConverter<BasicTypeDefine> getTypeConverter() {
        return new OracleTypeConverter(true, handleBlobAsString);
    }

    @Override
    public String hashModForField(String fieldName, int mod) {
        return "MOD(ORA_HASH(" + quoteIdentifier(fieldName) + ")," + mod + ")";
    }

    @Override
    public JdbcDialectTypeMapper getJdbcDialectTypeMapper() {
        return new OracleTypeMapper(true, handleBlobAsString);
    }

    @Override
    public String quoteIdentifier(String identifier) {
        if (identifier.contains(".")) {
            String[] parts = identifier.split("\\.");
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < parts.length - 1; i++) {
                sb.append("\"").append(parts[i]).append("\"").append(".");
            }
            return sb.append("\"")
                    .append(getFieldIde(parts[parts.length - 1], fieldIde))
                    .append("\"")
                    .toString();
        }
        return "\"" + getFieldIde(identifier, fieldIde) + "\"";
    }

    @Override
    public String tableIdentifier(String database, String tableName) {
        return quoteIdentifier(tableName);
    }

    @Override
    public Optional<String> getUpsertStatement(
            String database, String tableName, String[] fieldNames, String[] uniqueKeyFields) {
        List<String> nonUniqueKeyFields =
                Arrays.stream(fieldNames)
                        .filter(fieldName -> !Arrays.asList(uniqueKeyFields).contains(fieldName))
                        .collect(Collectors.toList());
        String valuesBinding =
                Arrays.stream(fieldNames)
                        .map(fieldName -> ":" + fieldName + " " + quoteIdentifier(fieldName))
                        .collect(Collectors.joining(", "));

        String usingClause = String.format("SELECT %s FROM DUAL", valuesBinding);
        String onConditions =
                Arrays.stream(uniqueKeyFields)
                        .map(
                                fieldName ->
                                        String.format(
                                                "TARGET.%s=SOURCE.%s",
                                                quoteIdentifier(fieldName),
                                                quoteIdentifier(fieldName)))
                        .collect(Collectors.joining(" AND "));
        String updateSetClause =
                nonUniqueKeyFields.stream()
                        .map(
                                fieldName ->
                                        String.format(
                                                "TARGET.%s=SOURCE.%s",
                                                quoteIdentifier(fieldName),
                                                quoteIdentifier(fieldName)))
                        .collect(Collectors.joining(", "));
        String insertFields =
                Arrays.stream(fieldNames)
                        .map(this::quoteIdentifier)
                        .collect(Collectors.joining(", "));
        String insertValues =
                Arrays.stream(fieldNames)
                        .map(fieldName -> "SOURCE." + quoteIdentifier(fieldName))
                        .collect(Collectors.joining(", "));

        String upsertSQL =
                String.format(
                        " MERGE INTO %s TARGET"
                                + " USING (%s) SOURCE"
                                + " ON (%s) "
                                + " WHEN MATCHED THEN"
                                + " UPDATE SET %s"
                                + " WHEN NOT MATCHED THEN"
                                + " INSERT (%s) VALUES (%s)",
                        tableIdentifier(database, tableName),
                        usingClause,
                        onConditions,
                        updateSetClause,
                        insertFields,
                        insertValues);

        return Optional.of(upsertSQL);
    }

    @Override
    public PreparedStatement creatPreparedStatement(
            Connection connection, String queryTemplate, int fetchSize) throws SQLException {
        PreparedStatement statement =
                connection.prepareStatement(
                        queryTemplate, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
        if (fetchSize > 0) {
            statement.setFetchSize(fetchSize);
        } else {
            statement.setFetchSize(DEFAULT_ORACLE_FETCH_SIZE);
        }
        return statement;
    }

    @Override
    public TablePath parse(String tablePath) {
        return TablePath.of(tablePath, true);
    }

    @Override
    public String tableIdentifier(TablePath tablePath) {
        return quoteIdentifier(tablePath.getSchemaAndTableName());
    }

    @Override
    public Long approximateRowCntStatement(Connection connection, JdbcSourceTable table)
            throws SQLException {

        // 1. Use select count
        // 2. If no query is configured, use TABLE STATUS.
        // 3. If a query is configured but does not contain a WHERE clause and tablePath is
        // configured, use TABLE STATUS.
        // 4. If a query is configured with a WHERE clause, or a query statement is configured but
        // tablePath is TablePath.DEFAULT, use COUNT(*).

        String query = table.getQuery();

        boolean useTableStats =
                StringUtils.isBlank(query)
                        || (!query.toLowerCase().contains("where")
                                && table.getTablePath() != null
                                && !TablePath.DEFAULT
                                        .getFullName()
                                        .equals(table.getTablePath().getFullName()));

        if (table.getUseSelectCount()) {
            useTableStats = false;
            if (StringUtils.isBlank(query)) {
                query = "SELECT * FROM " + tableIdentifier(table.getTablePath());
            }
        }

        if (useTableStats) {
            TablePath tablePath = table.getTablePath();
            String rowCountQuery =
                    String.format(
                            "select NUM_ROWS from all_tables where OWNER = '%s' AND TABLE_NAME = '%s' ",
                            tablePath.getSchemaName(), tablePath.getTableName());
            try (Statement stmt = connection.createStatement()) {
                String analyzeTable =
                        String.format(
                                "analyze table %s compute statistics for table",
                                tableIdentifier(tablePath));
                if (!table.getSkipAnalyze()) {
                    log.info("Split Chunk, approximateRowCntStatement: {}", analyzeTable);
                    stmt.execute(analyzeTable);
                } else {
                    log.warn("Skip analyze, approximateRowCntStatement: {}", analyzeTable);
                }
                log.info("Split Chunk, approximateRowCntStatement: {}", rowCountQuery);
                try (ResultSet rs = stmt.executeQuery(rowCountQuery)) {
                    if (!rs.next()) {
                        throw new SQLException(
                                String.format(
                                        "No result returned after running query [%s]",
                                        rowCountQuery));
                    }
                    return rs.getLong(1);
                }
            }
        }
        return SQLUtils.countForSubquery(connection, query);
    }

    @Override
    public Object queryNextChunkMax(
            Connection connection,
            JdbcSourceTable table,
            String columnName,
            int chunkSize,
            Object includedLowerBound)
            throws SQLException {
        String quotedColumn = quoteIdentifier(columnName);
        String sqlQuery;
        if (StringUtils.isNotBlank(table.getQuery())) {
            sqlQuery =
                    String.format(
                            "SELECT MAX(%s) FROM ("
                                    + "SELECT %s FROM (%s) WHERE %s >= ? ORDER BY %s ASC "
                                    + ") WHERE ROWNUM <= %s",
                            quotedColumn,
                            quotedColumn,
                            table.getQuery(),
                            quotedColumn,
                            quotedColumn,
                            chunkSize);
        } else {
            sqlQuery =
                    String.format(
                            "SELECT MAX(%s) FROM ("
                                    + "SELECT %s FROM %s WHERE %s >= ? ORDER BY %s ASC "
                                    + ") WHERE ROWNUM <= %s",
                            quotedColumn,
                            quotedColumn,
                            tableIdentifier(table.getTablePath()),
                            quotedColumn,
                            quotedColumn,
                            chunkSize);
        }

        try (PreparedStatement ps = connection.prepareStatement(sqlQuery)) {
            ps.setObject(1, includedLowerBound);
            try (ResultSet rs = ps.executeQuery()) {
                if (!rs.next()) {
                    // this should never happen
                    throw new SQLException(
                            String.format("No result returned after running query [%s]", sqlQuery));
                }
                return rs.getObject(1);
            }
        }
    }

    @Override
    public Object[] sampleDataFromColumn(
            Connection connection,
            JdbcSourceTable table,
            String columnName,
            int samplingRate,
            int fetchSize)
            throws Exception {
        String sampleQuery;
        if (StringUtils.isNotBlank(table.getQuery())) {
            sampleQuery =
                    String.format(
                            "SELECT %s FROM (%s) T", quoteIdentifier(columnName), table.getQuery());
        } else {
            sampleQuery =
                    String.format(
                            "SELECT %s FROM %s",
                            quoteIdentifier(columnName), tableIdentifier(table.getTablePath()));
        }

        try (PreparedStatement stmt = creatPreparedStatement(connection, sampleQuery, fetchSize)) {
            try (ResultSet rs = stmt.executeQuery()) {
                int count = 0;
                List<Object> results = new ArrayList<>();

                while (rs.next()) {
                    count++;
                    if (count % samplingRate == 0) {
                        results.add(rs.getObject(1));
                    }
                    if (Thread.currentThread().isInterrupted()) {
                        throw new InterruptedException("Thread interrupted");
                    }
                }
                Object[] resultsArray = results.toArray();
                Arrays.sort(resultsArray);
                return resultsArray;
            }
        }
    }

    @Override
    public void applySchemaChange(
            Connection connection, TablePath tablePath, AlterTableAddColumnEvent event)
            throws SQLException {
        List<String> ddlSQL = new ArrayList<>();
        ddlSQL.add(buildUpdateColumnSQL(connection, tablePath, event));

        if (event.getColumn().getComment() != null) {
            ddlSQL.add(buildUpdateColumnCommentSQL(tablePath, event.getColumn()));
        }

        try (Statement statement = connection.createStatement()) {
            for (String sql : ddlSQL) {
                log.info("Executing add column SQL: {}", sql);
                statement.execute(sql);
            }
        }
    }

    @Override
    public void applySchemaChange(
            Connection connection, TablePath tablePath, AlterTableChangeColumnEvent event)
            throws SQLException {
        List<String> ddlSQL = new ArrayList<>();
        if (event.getOldColumn() != null
                && !(event.getColumn().getName().equals(event.getOldColumn()))) {
            StringBuilder sqlBuilder =
                    new StringBuilder()
                            .append("ALTER TABLE ")
                            .append(tableIdentifier(tablePath))
                            .append(" RENAME COLUMN ")
                            .append(quoteIdentifier(event.getOldColumn()))
                            .append(" TO ")
                            .append(quoteIdentifier(event.getColumn().getName()));
            ddlSQL.add(sqlBuilder.toString());
        }

        try (Statement statement = connection.createStatement()) {
            for (String sql : ddlSQL) {
                log.info("Executing change column SQL: {}", sql);
                statement.execute(sql);
            }
        }

        if (event.getColumn().getDataType() != null) {
            applySchemaChange(
                    connection,
                    tablePath,
                    AlterTableModifyColumnEvent.modify(event.tableIdentifier(), event.getColumn()));
        }
    }

    @Override
    public void applySchemaChange(
            Connection connection, TablePath tablePath, AlterTableModifyColumnEvent event)
            throws SQLException {
        List<String> ddlSQL = new ArrayList<>();
        ddlSQL.add(buildUpdateColumnSQL(connection, tablePath, event));

        if (event.getColumn().getComment() != null) {
            ddlSQL.add(buildUpdateColumnCommentSQL(tablePath, event.getColumn()));
        }

        try (Statement statement = connection.createStatement()) {
            for (String sql : ddlSQL) {
                log.info("Executing modify column SQL: {}", sql);
                statement.execute(sql);
            }
        }
    }

    private String buildUpdateColumnSQL(
            Connection connection, TablePath tablePath, AlterTableColumnEvent event)
            throws SQLException {
        String actionType;
        Column column;
        if (event instanceof AlterTableModifyColumnEvent) {
            actionType = "MODIFY";
            column = ((AlterTableModifyColumnEvent) event).getColumn();
        } else if (event instanceof AlterTableAddColumnEvent) {
            actionType = "ADD";
            column = ((AlterTableAddColumnEvent) event).getColumn();
        } else {
            throw new IllegalArgumentException("Unsupported AlterTableColumnEvent: " + event);
        }
        String sourceDialectName = event.getSourceDialectName();
        boolean sameCatalog = StringUtils.equals(dialectName(), sourceDialectName);
        BasicTypeDefine typeDefine = getTypeConverter().reconvert(column);
        String columnType = sameCatalog ? column.getSourceType() : typeDefine.getColumnType();
        StringBuilder sqlBuilder =
                new StringBuilder()
                        .append("ALTER TABLE  ")
                        .append(tableIdentifier(tablePath))
                        .append(" ")
                        .append(actionType)
                        .append(" ")
                        .append(quoteIdentifier(column.getName()))
                        .append(" ")
                        .append(columnType);
        // Only decorate with default value when source dialect is same as sink dialect
        // Todo Support for cross-database default values for ddl statements
        if (column.getDefaultValue() != null && sameCatalog) {
            sqlBuilder.append(" ").append(sqlClauseWithDefaultValue(typeDefine, sourceDialectName));
        }
        if (event instanceof AlterTableModifyColumnEvent) {
            boolean targetColumnNullable =
                    columnIsNullable(connection, tablePath, column.getName());
            if (column.isNullable() != targetColumnNullable) {
                sqlBuilder.append(" ").append(column.isNullable() ? "NULL" : "NOT NULL");
            }
        } else {
            sqlBuilder.append(" ").append(column.isNullable() ? "NULL" : "NOT NULL");
        }
        return sqlBuilder.toString();
    }

    private String buildUpdateColumnCommentSQL(TablePath tablePath, Column column) {
        return String.format(
                "COMMENT ON COLUMN %s.%s IS '%s'",
                tableIdentifier(tablePath), quoteIdentifier(column.getName()), column.getComment());
    }

    private boolean columnIsNullable(Connection connection, TablePath tablePath, String column)
            throws SQLException {
        String selectColumnSQL =
                "SELECT"
                        + "        NULLABLE FROM"
                        + "        ALL_TAB_COLUMNS c"
                        + "        WHERE c.owner = '"
                        + tablePath.getSchemaName()
                        + "'"
                        + "        AND c.table_name = '"
                        + tablePath.getTableName()
                        + "'"
                        + "        AND c.column_name = '"
                        + column
                        + "'";
        try (Statement statement = connection.createStatement()) {
            ResultSet rs = statement.executeQuery(selectColumnSQL);
            rs.next();
            return rs.getString("NULLABLE").equals("Y");
        }
    }

    @Override
    public String dualTable() {
        return " FROM dual ";
    }

    @Override
    public String getCollateSql(String collate) {
        if (StringUtils.isNotBlank(collate)) {
            StringBuilder sql = new StringBuilder();
            sql.append("NLSSORT(")
                    .append("char_val")
                    .append(", 'NLS_SORT=")
                    .append(collate)
                    .append("')");
            return sql.toString();
        } else {
            return "char_val";
        }
    }
}
