package com.chinamcloud.spider.system.config.dbinterceptors;

import com.alibaba.druid.util.StringUtils;
import com.chinamcloud.spider.system.config.dbinterceptors.dm.DynamicBoundSql;
import com.chinamcloud.spider.system.config.dbinterceptors.dm.DynamicSqlSource;
import lombok.SneakyThrows;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.keygen.NoKeyGenerator;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.JdbcType;

import java.lang.reflect.Field;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * MyBatis {@link Interceptor} that adapts MySQL-flavoured SQL for the DM (Dameng) database.
 *
 * <p>On every intercepted {@code Executor#update} / {@code Executor#query} call it replaces the
 * statement's {@link SqlSource} with the one returned by
 * {@code DynamicSqlSource.getDynamicSqlSource(sqlSource, invocation)}. That class is not visible
 * here; presumably it applies the {@link ProcessTokenEnum} rewrites to the bound SQL — TODO confirm.
 */
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})})
public class DMKeywordInterceptor implements Interceptor {

    /** Reflective handle on the private {@code MappedStatement.sqlSource} field, resolved once. */
    private static final Field SQL_SOURCE_FIELD = lookupSqlSourceField();

    /**
     * Resolves and opens {@code MappedStatement.sqlSource}.
     *
     * @throws IllegalStateException if the field does not exist in the MyBatis version on the classpath
     */
    private static Field lookupSqlSourceField() {
        try {
            Field field = MappedStatement.class.getDeclaredField("sqlSource");
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException e) {
            throw new IllegalStateException("MappedStatement.sqlSource not found; unsupported MyBatis version", e);
        }
    }

    /**
     * Swaps the statement's SQL source for the DM-adapting one, then proceeds with the call.
     *
     * @param invocation the intercepted executor call; {@code args[0]} is the {@link MappedStatement}
     * @return whatever the intercepted executor method returns
     */
    @Override
    public Object intercept(final Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        final SqlSource sqlSource = mappedStatement.getSqlSource();
        // NOTE(review): the MappedStatement is shared across calls, and its sqlSource is overwritten
        // here with a source built from this particular invocation. Unless DynamicSqlSource guards
        // against it, concurrent calls may observe each other's source, and repeated calls may wrap
        // an already-wrapped source — confirm in DynamicSqlSource.
        SQL_SOURCE_FIELD.set(mappedStatement, DynamicSqlSource.getDynamicSqlSource(sqlSource, invocation));
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // Interceptor properties could be configured here; none are used.
    }

    /**
     * Individual MySQL-to-DM SQL rewrites. Each constant decides via {@link #shouldProcess} whether
     * it applies to a statement, and rewrites it in {@link #process}.
     */
    public enum ProcessTokenEnum {
        /**
         * When {@code null} is bound on a PreparedStatement, its jdbcType must be correct and non-null,
         * so missing parameter jdbcTypes are filled in from the mapper's BaseResultMap.
         */
        SET_NULL_JDBC_TYPE {
            /** Whether the Configuration's jdbcTypeForNull has already been reset. */
            boolean processFlag = false;

            /** Reflective handle on {@code ParameterMapping.jdbcType}, which is written directly below. */
            private final Field PARAMETERMAPPING_JDBC_FIELD;
            {
                try {
                    PARAMETERMAPPING_JDBC_FIELD = ParameterMapping.class.getDeclaredField("jdbcType");
                    PARAMETERMAPPING_JDBC_FIELD.setAccessible(true);
                } catch (NoSuchFieldException e) {
                    throw new RuntimeException(e);
                }
            }

            @Override
            public boolean shouldProcess(DynamicBoundSql originSql, Invocation invocation) {
                return true;
            }

            /**
             * On first use, resets the Configuration's jdbcTypeForNull to {@link JdbcType#OTHER}.
             * (Per the original note, mybatis-config.xml sets it to java.sql.JDBCType#NULL, and OTHER
             * is the default being restored.) It then copies the jdbcType of each same-named property
             * in {@code <namespace>.BaseResultMap} onto every parameter mapping that has no jdbcType.
             */
            @Override
            @SneakyThrows
            public DynamicBoundSql process(DynamicBoundSql oldSql, Invocation invocation) {
                if (!processFlag) {
                    MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
                    Configuration configuration = mappedStatement.getConfiguration();
                    configuration.setJdbcTypeForNull(JdbcType.OTHER);
                    processFlag = true;
                }
                ResultMap baseResultMap = ProcessTokenEnum.getBaseResultMap(invocation);
                if (baseResultMap == null) {
                    return oldSql;
                }
                Map<String, JdbcType> jdbcTypeByProperty = new HashMap<>();
                for (ResultMapping resultMapping : baseResultMap.getResultMappings()) {
                    jdbcTypeByProperty.put(resultMapping.getProperty(), resultMapping.getJdbcType());
                }
                // Fill in only the parameter mappings whose jdbcType is missing.
                for (ParameterMapping parameterMapping : oldSql.getParameterMappings()) {
                    if (parameterMapping.getJdbcType() != null) {
                        continue;
                    }
                    JdbcType type = jdbcTypeByProperty.get(parameterMapping.getProperty());
                    if (type != null) {
                        PARAMETERMAPPING_JDBC_FIELD.set(parameterMapping, type);
                    }
                }
                return oldSql;
            }
        },

        /**
         * Words that are DM keywords were used as column names under MySQL; such occurrences are
         * wrapped in double quotes.
         */
        KEYWORD {
            /** Tokens that must be quoted. */
            private final String[] targetToken = {"COMMENT", "SYNONYM", "SYSTEM", "STAT", "RULE", "PAGE", "MAP", "JOB", "DAILY", "CATALOG"};

            /**
             * One pattern per token: a case-insensitive match of the token that is
             * <ul>
             *   <li>preceded by start-of-SQL, whitespace, ',', '.' or '(';</li>
             *   <li>followed by end-of-SQL, whitespace, ',', ')', ';', '=', '&lt;', '&gt;' or '!'.</li>
             * </ul>
             * Already-quoted, backticked, or identifier-embedded occurrences (e.g. catalog_id) do not
             * match.
             */
            private final Pattern[] tokenPatterns = new Pattern[targetToken.length];

            {
                for (int i = 0; i < targetToken.length; i++) {
                    tokenPatterns[i] = Pattern.compile("(?i)(?<=^|[\\s,.(])" + targetToken[i] + "(?=$|[\\s,)=<>!;])");
                }
            }

            @Override
            public boolean shouldProcess(DynamicBoundSql originSql, Invocation invocation) {
                String upperCaseSql = originSql.getSql().toUpperCase(Locale.ROOT);
                for (String token : targetToken) {
                    if (upperCaseSql.contains(token)) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public DynamicBoundSql process(final DynamicBoundSql oldSql, Invocation invocation) {
                String processSql = oldSql.getSql();
                for (int i = 0; i < targetToken.length; i++) {
                    processSql = tokenPatterns[i].matcher(processSql)
                            .replaceAll(Matcher.quoteReplacement("\"" + targetToken[i] + "\""));
                }
                oldSql.setSql(processSql);
                return oldSql;
            }
        },

        /**
         * DM does not support MySQL's subdate function. It could be rewritten as
         * {@code add_time - INTERVAL '60' SECOND}. Instead of a runtime rewrite, the mapper XML was
         * changed to the form {@code now() > add_time - interval '60' SECOND}.
         *
         * @deprecated performs no rewrite (inherits the default {@code shouldProcess} returning false).
         */
        @Deprecated
        SUBDATE,

        /**
         * MySQL accepts non-aggregated select columns that are missing from GROUP BY (non-standard
         * SQL). DM accepts this only with the compatibility hint GROUP_OPT_FLAG(1) placed right after
         * SELECT, so the hint is inserted after the first SELECT keyword.
         */
        GROUP_BY {
            /** The DM compatibility hint. */
            private final String COMPATIBILITY_IDENTIFICATION = "/*+ GROUP_OPT_FLAG(1)*/";

            /** The hint with all whitespace removed, upper-cased, for detection in compacted SQL. */
            private final String COMPACT_IDENTIFICATION =
                    COMPATIBILITY_IDENTIFICATION.replaceAll("\\s+", "").toUpperCase(Locale.ROOT);

            /** A stand-alone SELECT keyword (not part of a longer identifier). */
            private final Pattern FIRST_SELECT = Pattern.compile("(?i)\\bSELECT\\b");

            @Override
            public boolean shouldProcess(DynamicBoundSql dynamicSqlBound, Invocation invocation) {
                // Strip all whitespace (spaces, tabs, newlines) so "GROUP\n BY" is detected too, and
                // compare against the hint compacted the same way; otherwise the hint is never found.
                String compactSql = dynamicSqlBound.getSql().replaceAll("\\s+", "").toUpperCase(Locale.ROOT);
                return compactSql.contains("GROUPBY") && !compactSql.contains(COMPACT_IDENTIFICATION);
            }

            @Override
            public DynamicBoundSql process(DynamicBoundSql dynamicSqlBound, Invocation invocation) {
                String sql = dynamicSqlBound.getSql();
                Matcher matcher = FIRST_SELECT.matcher(sql);
                if (matcher.find()) {
                    int insertAt = matcher.end();
                    dynamicSqlBound.setSql(sql.substring(0, insertAt) + " " + COMPATIBILITY_IDENTIFICATION + " " + sql.substring(insertAt));
                }
                return dynamicSqlBound;
            }
        },

        /**
         * DM does not directly allow inserting explicit values into an identity (auto-increment)
         * column. The statement is therefore wrapped as
         * {@code set IDENTITY_INSERT <table> ON; <insert>; set IDENTITY_INSERT <table> OFF}.
         */
        INSERT_PK {
            /** First identifier after INSERT INTO; stops at whitespace or '(' so "t(a,b)" yields "t". */
            private final Pattern TABLE_NAME = Pattern.compile("(?i)INSERT\\s+INTO\\s+([^\\s(]+)");

            /** Parenthesised groups not followed by VALUES, i.e. the value tuples, not the column list. */
            private final Pattern VALUE_GROUP = Pattern.compile("(?i)\\s*\\([^)]+\\)(?!\\s*VALUES)");

            /**
             * Matches INSERT statements that have no key generator, mention "ID" somewhere, and are
             * not already wrapped. This assumes auto-increment columns carry "id" in their name (in
             * this system only primary keys such as id / *id are auto-increment).
             */
            @Override
            public boolean shouldProcess(DynamicBoundSql oldSql, Invocation invocation) {
                String upperSql = oldSql.getSql().toUpperCase(Locale.ROOT);
                Object arg0 = invocation.getArgs()[0];
                if (!(arg0 instanceof MappedStatement)) {
                    return false;
                }
                MappedStatement statement = (MappedStatement) arg0;
                return statement.getSqlCommandType() == SqlCommandType.INSERT
                        && (statement.getKeyGenerator() == null || statement.getKeyGenerator() instanceof NoKeyGenerator)
                        && upperSql.contains("ID")
                        && !upperSql.contains("IDENTITY_INSERT");
            }

            @Override
            @SneakyThrows
            public DynamicBoundSql process(DynamicBoundSql oldSqlBound, Invocation invocation) {
                String sqlOld = oldSqlBound.getSql();
                Matcher matcher = TABLE_NAME.matcher(sqlOld);
                String tableName = matcher.find() ? matcher.group(1) : "";
                if (StringUtils.isEmpty(tableName)) {
                    System.err.println("没有匹配到表名:" + sqlOld);
                    return oldSqlBound;
                }
                String insertBefore = "set IDENTITY_INSERT " + tableName + " ON;\n";
                String insertAfter = "set IDENTITY_INSERT " + tableName + " OFF\n";
                sqlOld = terminate(sqlOld);

                // Compatibility with the mapper's batch "inserts": DM rejects an explicit NULL for an
                // identity column (MySQL still generates a key), e.g.
                // INSERT INTO CRMS_CATALOG_PRIVILEGE (id, catalog_id, ...) VALUES (null, 9999, ...);
                Object arg0 = invocation.getArgs()[0];
                if (arg0 instanceof MappedStatement && ((MappedStatement) arg0).getId().endsWith("inserts")) {
                    sqlOld = dropPkIfAnyMissing(oldSqlBound, sqlOld, invocation);
                }
                oldSqlBound.setSql(insertBefore + terminate(sqlOld) + insertAfter);
                return oldSqlBound;
            }

            /** Strips trailing whitespace and ensures the statement ends with exactly one ';'. */
            private String terminate(String sql) {
                int end = sql.length();
                while (end > 0 && Character.isWhitespace(sql.charAt(end - 1))) {
                    end--;
                }
                String trimmed = sql.substring(0, end);
                return trimmed.endsWith(";") ? trimmed : trimmed + ";";
            }

            /**
             * Batch-insert handling: if any entity in the batch has no primary key, the key is
             * dropped for the whole batch. That means the key column, one placeholder per value
             * tuple, and its parameter mappings are removed, and every entity's key field is cleared
             * so entities reflect what was inserted.
             *
             * @return the possibly rewritten SQL (unchanged when nothing applies)
             */
            private String dropPkIfAnyMissing(DynamicBoundSql bound, String sql, Invocation invocation)
                    throws ReflectiveOperationException {
                ResultMap baseResultMap = ProcessTokenEnum.getBaseResultMap(invocation);
                ResultMapping idResultMapping = ProcessTokenEnum.getIdResultMapping(baseResultMap);
                if (idResultMapping == null) {
                    return sql;
                }
                List<?> entities = batchEntities(invocation.getArgs()[1]);
                if (entities.isEmpty()) {
                    return sql;
                }
                Field idField = findField(baseResultMap.getType(), idResultMapping.getProperty());
                idField.setAccessible(true);
                boolean anyNullPk = false;
                for (Object entity : entities) {
                    if (idField.get(entity) == null) {
                        anyNullPk = true;
                        break;
                    }
                }
                if (!anyNullPk) {
                    return sql;
                }
                for (Object entity : entities) {
                    idField.set(entity, null);
                }

                // Collect each value tuple, e.g. (?,?,?,?).
                List<String> values = new ArrayList<>();
                Matcher valueMatcher = VALUE_GROUP.matcher(sql);
                while (valueMatcher.find()) {
                    values.add(valueMatcher.group());
                }
                String rewritten = delOneColumn(sql, idResultMapping, values);

                // Drop the key's parameter mappings; foreach items appear as e.g. "__frch_item_0.id".
                String idProperty = idResultMapping.getProperty();
                List<ParameterMapping> kept = new ArrayList<>();
                for (ParameterMapping parameterMapping : bound.getParameterMappings()) {
                    String property = parameterMapping.getProperty();
                    String leafProperty = property.substring(property.lastIndexOf('.') + 1);
                    if (!idProperty.equalsIgnoreCase(leafProperty)) {
                        kept.add(parameterMapping);
                    }
                }
                bound.setParameterMappings(kept);
                return rewritten;
            }

            /**
             * Returns the batch entities from a mapper parameter map ("list", then "collection"), or
             * an empty list. {@code containsKey} is checked first because MyBatis parameter maps may
             * throw on a missing key.
             */
            private List<?> batchEntities(Object parameter) {
                if (parameter instanceof Map) {
                    Map<?, ?> params = (Map<?, ?>) parameter;
                    for (String key : new String[]{"list", "collection"}) {
                        Object value = params.containsKey(key) ? params.get(key) : null;
                        if (value instanceof List) {
                            return (List<?>) value;
                        }
                    }
                }
                return Collections.emptyList();
            }

            /** Finds a declared field on {@code type} or any of its superclasses. */
            private Field findField(Class<?> type, String name) throws NoSuchFieldException {
                for (Class<?> current = type; current != null; current = current.getSuperclass()) {
                    try {
                        return current.getDeclaredField(name);
                    } catch (NoSuchFieldException ignored) {
                        // Not declared at this level; keep walking up the hierarchy.
                    }
                }
                throw new NoSuchFieldException(type.getName() + "." + name);
            }

            /**
             * Removes the key column from the column list and one placeholder from each value tuple.
             * The SQL is rebuilt as {@code <prefix> VALUES <tuple>,<tuple>...}; anything after the
             * last tuple (including a trailing ';') is dropped, and the caller re-terminates it.
             */
            private String delOneColumn(String sqlOld, final ResultMapping idResultMap, final List<String> values) {
                String prefix = sqlOld.split("(?i)\\bVALUES\\b")[0];
                String column = Pattern.quote(idResultMap.getColumn());
                // Whole-word match so that e.g. "catalog_id," is not mistaken for "id,".
                String withoutColumn = prefix.replaceFirst("(?i)\\b" + column + "\\b\\s*,", "");
                if (withoutColumn.equals(prefix)) {
                    // The key may be the last column: remove ", id" instead.
                    withoutColumn = prefix.replaceFirst("(?i),\\s*\\b" + column + "\\b", "");
                }
                StringBuilder builder = new StringBuilder(withoutColumn).append(" VALUES ");
                for (int i = 0; i < values.size(); i++) {
                    if (i > 0) {
                        builder.append(',');
                    }
                    // Placeholders are interchangeable; binding order comes from the parameter mappings.
                    builder.append(values.get(i).replaceFirst("\\?\\s*,", ""));
                }
                return builder.toString();
            }
        },

        /**
         * MySQL group_concat is rewritten to DM wm_concat.
         */
        GROUP_CONCAT {
            /** A group_concat call (any case, optional whitespace before '('), not a longer name ending in it. */
            private final Pattern GROUP_CONCAT_CALL = Pattern.compile("(?i)\\bGROUP_CONCAT\\s*\\(");

            @Override
            public boolean shouldProcess(DynamicBoundSql originSql, Invocation invocation) {
                return GROUP_CONCAT_CALL.matcher(originSql.getSql()).find();
            }

            @Override
            public DynamicBoundSql process(DynamicBoundSql oldSql, Invocation invocation) {
                oldSql.setSql(GROUP_CONCAT_CALL.matcher(oldSql.getSql()).replaceAll("WM_CONCAT("));
                return oldSql;
            }
        },

        /**
         * Returns the generated id of an inserted row on DM. Two approaches exist; approach 1 is used.
         * <ol>
         *   <li>Append {@code RETURNING id} to the end of the insert statement (used).</li>
         *   <li>After the insert, query {@code select LAST_INSERT_ID as last_id}.</li>
         * </ol>
         */
        RETURNING_ID {
            /** An INSERT ... VALUES ... statement, up to the first ';' or the end of the SQL. */
            private final Pattern INSERT_VALUES = Pattern.compile("(?i)INSERT\\s+INTO[\\s\\S]+?VALUES([\\s\\S]+?)(?=;|$)");

            @Override
            public boolean shouldProcess(DynamicBoundSql originSql, Invocation invocation) {
                MappedStatement statement = (MappedStatement) invocation.getArgs()[0];
                String upperCase = originSql.getSql().toUpperCase(Locale.ROOT);
                return statement.getSqlCommandType() == SqlCommandType.INSERT
                        && (statement.getKeyGenerator() != null && !(statement.getKeyGenerator() instanceof NoKeyGenerator))
                        && !upperCase.contains(" RETURNING ");
            }

            @Override
            public DynamicBoundSql process(DynamicBoundSql oldSql, Invocation invocation) {
                ResultMapping idResultMapping =
                        ProcessTokenEnum.getIdResultMapping(ProcessTokenEnum.getBaseResultMap(invocation));
                if (idResultMapping == null) {
                    return oldSql;
                }
                String column = idResultMapping.getColumn();
                // Keep the SQL's original case: upper-casing it would also alter string literals and
                // quoted identifiers. The regex is already case-insensitive.
                String sql = INSERT_VALUES.matcher(oldSql.getSql())
                        .replaceFirst("$0" + Matcher.quoteReplacement(" RETURNING " + column + " INTO :" + column));
                oldSql.setSql(sql);
                return oldSql;
            }
        };

        /**
         * Whether this rewrite applies to the given SQL.
         */
        public boolean shouldProcess(DynamicBoundSql originSql, Invocation invocation) {
            return false;
        }

        /**
         * Rewrites the SQL.
         *
         * @param oldSql the bound SQL to rewrite
         * @return the rewritten bound SQL
         */
        public DynamicBoundSql process(DynamicBoundSql oldSql, Invocation invocation) {
            return oldSql;
        }

        /**
         * Looks up {@code <namespace>.BaseResultMap} for the intercepted statement. The namespace is
         * the statement id up to its last '.'.
         *
         * @return the result map, or {@code null} if {@code args[0]} is not a MappedStatement, the id
         *         has no namespace, or no such result map is registered (mapper XML without one)
         */
        private static ResultMap getBaseResultMap(Invocation invocation) {
            Object arg0 = invocation.getArgs()[0];
            if (!(arg0 instanceof MappedStatement)) {
                return null;
            }
            MappedStatement statement = (MappedStatement) arg0;
            String id = statement.getId();
            int lastDot = id.lastIndexOf('.');
            if (lastDot < 0) {
                return null;
            }
            String resultMapId = id.substring(0, lastDot) + ".BaseResultMap";
            Configuration configuration = statement.getConfiguration();
            return configuration.hasResultMap(resultMapId) ? configuration.getResultMap(resultMapId) : null;
        }

        /**
         * Returns the first id mapping of {@code resultMap}, or {@code null} when the map is null or
         * has none.
         * NOTE(review): MyBatis appears to report all result mappings as id mappings when a result
         * map declares no {@code <id>} element, in which case this may not be the key column —
         * confirm that BaseResultMaps declare an id.
         */
        private static ResultMapping getIdResultMapping(ResultMap resultMap) {
            if (resultMap == null) {
                return null;
            }
            List<ResultMapping> idMappings = resultMap.getIdResultMappings();
            return idMappings == null || idMappings.isEmpty() ? null : idMappings.get(0);
        }
    }
}
