/*
 * Decompiled with CFR 0.152.
 */
package io.seata.rm.datasource.exec;

import io.seata.common.exception.NotSupportYetException;
import io.seata.common.exception.ShouldNeverHappenException;
import io.seata.common.util.CollectionUtils;
import io.seata.common.util.StringUtils;
import io.seata.rm.datasource.PreparedStatementProxy;
import io.seata.rm.datasource.StatementProxy;
import io.seata.rm.datasource.exec.AbstractDMLBaseExecutor;
import io.seata.rm.datasource.exec.StatementCallback;
import io.seata.rm.datasource.sql.struct.ColumnMeta;
import io.seata.rm.datasource.sql.struct.TableRecords;
import io.seata.sqlparser.SQLInsertRecognizer;
import io.seata.sqlparser.SQLRecognizer;
import io.seata.sqlparser.struct.Null;
import io.seata.sqlparser.struct.SqlMethodExpr;
import io.seata.sqlparser.struct.SqlSequenceExpr;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Executor for INSERT statements.
 *
 * <p>The before-image of an insert is always empty; the after-image is built from the
 * primary-key values of the inserted rows. Those values are either read from the
 * statement itself (literals or bound parameters) or fetched from the database after
 * execution (generated keys / sequence current value).
 *
 * @param <T> result type produced by the statement callback
 * @param <S> concrete {@link Statement} type being proxied
 */
public class InsertExecutor<T, S extends Statement>
extends AbstractDMLBaseExecutor<T, S> {
    private static final Logger LOGGER = LoggerFactory.getLogger(InsertExecutor.class);
    /**
     * SQL state that triggers the {@code SELECT LAST_INSERT_ID()} fallback in
     * {@link #defaultByAuto()}.
     * NOTE(review): presumably the state the driver reports when generated keys were not
     * requested on the statement — confirm against the driver in use.
     */
    protected static final String ERR_SQL_STATE = "S1009";
    /** Token used in the recognizer's insert rows for a bound parameter. */
    private static final String PLACEHOLDER = "?";

    /**
     * Creates an insert executor.
     *
     * @param statementProxy    proxy of the statement being executed
     * @param statementCallback callback that performs the real execution
     * @param sqlRecognizer     recognizer of the INSERT statement; it is cast to
     *                          {@link SQLInsertRecognizer} by the methods of this class
     */
    public InsertExecutor(StatementProxy<S> statementProxy, StatementCallback<T, S> statementCallback, SQLRecognizer sqlRecognizer) {
        super(statementProxy, statementCallback, sqlRecognizer);
    }

    /**
     * Returns an empty record set for the target table: nothing exists before an insert.
     */
    @Override
    protected TableRecords beforeImage() throws SQLException {
        return TableRecords.empty(this.getTableMeta());
    }

    /**
     * Builds the after-image from the primary-key values of the inserted rows.
     *
     * <p>PK values are taken from the statement when the PK column is listed or when no
     * column list is given at all; when a column list is given without the PK, they are
     * fetched via {@link #getPkValuesByAuto()}.
     *
     * @param beforeImage the before-image (not used here)
     * @return the after-image, never {@code null}
     * @throws SQLException if the after-image cannot be built or a database call fails
     */
    @Override
    protected TableRecords afterImage(TableRecords beforeImage) throws SQLException {
        List<Object> pkValues;
        if (this.containsPK() || !this.containsColumns()) {
            pkValues = this.getPkValuesByColumn();
        } else {
            pkValues = this.getPkValuesByAuto();
        }
        TableRecords afterImage = this.buildTableRecords(pkValues);
        if (afterImage == null) {
            throw new SQLException("Failed to build after-image for insert");
        }
        return afterImage;
    }

    /**
     * Tells whether the explicit column list of the INSERT includes the primary key.
     *
     * @return {@code false} when there is no column list, otherwise the result of the
     *         inherited {@code containsPK(List)} check
     */
    protected boolean containsPK() {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        List<String> insertColumns = recognizer.getInsertColumns();
        if (CollectionUtils.isEmpty(insertColumns)) {
            return false;
        }
        return this.containsPK(insertColumns);
    }

    /**
     * Tells whether the INSERT carries an explicit, non-empty column list.
     */
    protected boolean containsColumns() {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        List<String> insertColumns = recognizer.getInsertColumns();
        return insertColumns != null && !insertColumns.isEmpty();
    }

    /**
     * Reads the primary-key values from the statement's VALUES rows.
     *
     * <p>For prepared statements a {@code ?} in the PK position is resolved against the
     * bound parameters; otherwise the value found in the row is used as is. If the
     * resulting first value is a sequence expression, a method expression or a SQL NULL,
     * the real values are fetched from the database instead.
     *
     * @return the primary-key values of the inserted rows
     * @throws SQLException               if fetching generated values fails
     * @throws ShouldNeverHappenException if no PK index or no values could be determined
     * @throws NotSupportYetException     if the combination of PK values is rejected by
     *                                    {@link #checkPkValues(List)}
     */
    protected List<Object> getPkValuesByColumn() throws SQLException {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        final int pkIndex = this.getPkIndex();
        if (pkIndex == -1) {
            throw new ShouldNeverHappenException(String.format("pkIndex is %d", pkIndex));
        }
        List<Object> pkValues = null;
        if (this.statementProxy instanceof PreparedStatementProxy) {
            PreparedStatementProxy preparedStatementProxy = (PreparedStatementProxy)this.statementProxy;
            List<List<Object>> insertRows = recognizer.getInsertRows();
            if (insertRows != null && !insertRows.isEmpty()) {
                ArrayList<Object>[] parameters = preparedStatementProxy.getParameters();
                int rowSize = insertRows.size();
                if (rowSize == 1) {
                    List<Object> onlyRow = insertRows.get(0);
                    if (PLACEHOLDER.equals(onlyRow.get(pkIndex))) {
                        // Index by placeholder position, not by column position: literals
                        // in front of the PK do not consume a bound parameter.
                        pkValues = parameters[InsertExecutor.countPlaceholders(onlyRow, pkIndex)];
                    } else {
                        pkValues = insertRows.stream().map(insertRow -> insertRow.get(pkIndex)).collect(Collectors.toList());
                    }
                } else {
                    // Number of placeholders consumed by the rows already processed.
                    int placeholdersInEarlierRows = 0;
                    pkValues = new ArrayList<>(rowSize);
                    for (List<Object> row : insertRows) {
                        if (row.isEmpty()) {
                            continue;
                        }
                        Object pkValue = row.get(pkIndex);
                        if (PLACEHOLDER.equals(pkValue)) {
                            int idx = placeholdersInEarlierRows + InsertExecutor.countPlaceholders(row, pkIndex);
                            pkValues.addAll(parameters[idx]);
                        } else {
                            pkValues.add(pkValue);
                        }
                        placeholdersInEarlierRows += InsertExecutor.countPlaceholders(row, row.size());
                    }
                }
            }
        } else {
            List<List<Object>> insertRows = recognizer.getInsertRows();
            pkValues = new ArrayList<Object>(insertRows.size());
            for (List<Object> row : insertRows) {
                pkValues.add(row.get(pkIndex));
            }
        }
        if (pkValues == null) {
            throw new ShouldNeverHappenException();
        }
        if (!this.checkPkValues(pkValues)) {
            throw new NotSupportYetException(String.format("not support sql [%s]", this.sqlRecognizer.getOriginalSQL()));
        }
        if (!pkValues.isEmpty() && pkValues.get(0) instanceof SqlSequenceExpr) {
            pkValues = this.getPkValuesBySequence(pkValues.get(0));
        } else if (pkValues.size() == 1 && pkValues.get(0) instanceof SqlMethodExpr) {
            pkValues = this.getPkValuesByAuto();
        } else if (!pkValues.isEmpty() && pkValues.get(0) instanceof Null) {
            pkValues = this.getPkValuesByAuto();
        }
        return pkValues;
    }

    /**
     * Resolves primary-key values produced by a sequence expression.
     *
     * <p>Generated keys of the executed statement are tried first. If that fails, the
     * sequence's {@code currval} is queried on a separate statement of the same connection.
     *
     * @param expr the expression found in the PK position; must be a {@link SqlSequenceExpr}
     *             for the fallback to be possible
     * @return the resolved primary-key values
     * @throws SQLException           if the fallback query fails
     * @throws NotSupportYetException if generated keys are unavailable and {@code expr} is
     *                                not a sequence expression
     */
    protected List<Object> getPkValuesBySequence(Object expr) throws SQLException {
        try {
            return this.oracleByAuto();
        }
        catch (NotSupportYetException | SQLException ignored) {
            // Deliberately ignored: generated keys are only the preferred source, the
            // sequence query below is the fallback.
        }
        if (!(expr instanceof SqlSequenceExpr)) {
            throw new NotSupportYetException(String.format("not support expr [%s]", expr.getClass().getName()));
        }
        SqlSequenceExpr sequenceExpr = (SqlSequenceExpr)expr;
        String sql = "SELECT " + sequenceExpr.getSequence() + ".currval FROM DUAL";
        LOGGER.warn("Fail to get auto-generated keys, use '{}' instead. Be cautious, statement could be polluted. Recommend you set the statement to return generated keys.", sql);
        // The helper statement is created here and owned by this method, so it is closed
        // once the values have been copied out.
        try (Statement sequenceStatement = this.statementProxy.getConnection().createStatement();
             ResultSet genKeys = sequenceStatement.executeQuery(sql)) {
            return InsertExecutor.readFirstColumn(genKeys);
        }
    }

    /**
     * Fetches database-generated primary-key values, using the Oracle strategy when
     * {@code getDbType()} equals {@code "oracle"} (case-insensitive) and the default
     * strategy otherwise.
     *
     * @return the generated primary-key values
     * @throws SQLException if the database call fails
     */
    protected List<Object> getPkValuesByAuto() throws SQLException {
        if (StringUtils.equalsIgnoreCase("oracle", this.getDbType())) {
            return this.oracleByAuto();
        }
        return this.defaultByAuto();
    }

    /**
     * Finds the position of the primary-key column within an insert row.
     *
     * <p>With an explicit column list the position in that list is returned, or -1 when
     * the PK is not listed. Without a column list the position is taken from the iteration
     * order of the table's column map.
     * NOTE(review): on the no-column-list path, if no column matches the PK the index of
     * the last column is returned rather than -1 — confirm that table metadata always
     * contains the PK column.
     *
     * @return zero-based PK position, or -1 as described above
     */
    protected int getPkIndex() {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer)this.sqlRecognizer;
        List<String> insertColumns = recognizer.getInsertColumns();
        if (CollectionUtils.isNotEmpty(insertColumns)) {
            int insertColumnsSize = insertColumns.size();
            for (int paramIdx = 0; paramIdx < insertColumnsSize; ++paramIdx) {
                if (this.equalsPK(insertColumns.get(paramIdx))) {
                    return paramIdx;
                }
            }
            return -1;
        }
        int pkIndex = -1;
        Map<String, ColumnMeta> allColumns = this.getTableMeta().getAllColumns();
        for (Map.Entry<String, ColumnMeta> entry : allColumns.entrySet()) {
            ++pkIndex;
            if (this.equalsPK(entry.getValue().getColumnName())) {
                break;
            }
        }
        return pkIndex;
    }

    /**
     * Validates the combination of PK values taken from the statement.
     *
     * <p>A single value is always accepted. With several values, any
     * {@link SqlMethodExpr} is rejected, and so is a mix of {@link Null} and
     * non-{@link Null} values.
     *
     * @param pkValues PK values, one per inserted row
     * @return {@code true} if the values can be handled
     */
    private boolean checkPkValues(List<Object> pkValues) {
        if (pkValues.size() == 1) {
            return true;
        }
        boolean pkParameterHasNull = false;
        boolean pkParameterHasNotNull = false;
        for (Object pkValue : pkValues) {
            if (pkValue instanceof Null) {
                pkParameterHasNull = true;
                continue;
            }
            pkParameterHasNotNull = true;
            if (pkValue instanceof SqlMethodExpr) {
                // One method expression is enough to reject the whole statement.
                return false;
            }
        }
        return !pkParameterHasNull || !pkParameterHasNotNull;
    }

    /**
     * Fetches auto-increment keys from the executed statement.
     *
     * <p>If {@code getGeneratedKeys()} fails with SQL state {@link #ERR_SQL_STATE}, the
     * keys are read with {@code SELECT LAST_INSERT_ID()} on the same target statement;
     * any other failure is rethrown unchanged.
     *
     * @return the generated key values
     * @throws SQLException               if the keys cannot be read
     * @throws NotSupportYetException     if the table does not have exactly one PK column
     * @throws ShouldNeverHappenException if the PK column is not auto-increment
     */
    private List<Object> defaultByAuto() throws SQLException {
        Map<String, ColumnMeta> pkMetaMap = this.getTableMeta().getPrimaryKeyMap();
        if (pkMetaMap.size() != 1) {
            throw new NotSupportYetException();
        }
        ColumnMeta pkMeta = pkMetaMap.values().iterator().next();
        if (!pkMeta.isAutoincrement()) {
            throw new ShouldNeverHappenException();
        }
        ResultSet genKeys;
        try {
            genKeys = this.statementProxy.getTargetStatement().getGeneratedKeys();
        }
        catch (SQLException e) {
            if (!ERR_SQL_STATE.equalsIgnoreCase(e.getSQLState())) {
                throw e;
            }
            // Only this state is recoverable; use the fallback result instead of rethrowing.
            LOGGER.warn("Fail to get auto-generated keys, use 'SELECT LAST_INSERT_ID()' instead. Be cautious, statement could be polluted. Recommend you set the statement to return generated keys.");
            genKeys = this.statementProxy.getTargetStatement().executeQuery("SELECT LAST_INSERT_ID()");
        }
        List<Object> pkValues = InsertExecutor.readFirstColumn(genKeys);
        try {
            // Rewind the cursor after reading; the result set belongs to the target
            // statement and is intentionally not closed here.
            genKeys.beforeFirst();
        }
        catch (SQLException e) {
            // Best effort: the values are already collected, so only warn.
            LOGGER.warn("Fail to reset ResultSet cursor. can not get primary key value");
        }
        return pkValues;
    }

    /**
     * Fetches generated keys from the executed statement for Oracle.
     *
     * @return the generated key values, never empty
     * @throws SQLException           if the keys cannot be read
     * @throws NotSupportYetException if the table does not have exactly one PK column or
     *                                no key was returned
     */
    private List<Object> oracleByAuto() throws SQLException {
        Map<String, ColumnMeta> pkMetaMap = this.getTableMeta().getPrimaryKeyMap();
        if (pkMetaMap.size() != 1) {
            throw new NotSupportYetException();
        }
        ResultSet genKeys = this.statementProxy.getTargetStatement().getGeneratedKeys();
        List<Object> pkValues = InsertExecutor.readFirstColumn(genKeys);
        if (pkValues.isEmpty()) {
            throw new NotSupportYetException(String.format("not support sql [%s]", this.sqlRecognizer.getOriginalSQL()));
        }
        return pkValues;
    }

    /** Counts the {@code ?} entries of {@code row} at positions below {@code endExclusive}. */
    private static int countPlaceholders(List<Object> row, int endExclusive) {
        int count = 0;
        for (int i = 0; i < endExclusive; ++i) {
            if (PLACEHOLDER.equals(row.get(i))) {
                ++count;
            }
        }
        return count;
    }

    /** Reads column 1 of every remaining row of {@code resultSet} into a new list. */
    private static List<Object> readFirstColumn(ResultSet resultSet) throws SQLException {
        List<Object> values = new ArrayList<>();
        while (resultSet.next()) {
            values.add(resultSet.getObject(1));
        }
        return values;
    }
}

