
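[Mybatis] Using MyBatis's SQL template parsing on its own

Preface

Our company's projects carry a bottomless pile of legacy design problems: there is no time to fix the new projects, and the old ones can't be changed at all. As a result, production sees a great many non-compliant operations where data needs to be, and can only be, changed with database scripts.
Every time I'm halfway through a feature, a ticket gets tossed over telling me to go fix data in production. Writing scripts every single day, who could put up with that?
After a few quarters I'd had enough. I slammed the desk and decided to qu..... to build a script execution tool from scratch. It manages the scripts already written, lets scripts be chained and executed together, and works across databases. The point was to end ad hoc script writing and make scripts shareable.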

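In the original design, the tool's scripts used nothing more than placeholders such as "{paramName}", filled in with replaceAll. Later it occurred to me: why not work the way MyBatis does, or rather, why not borrow MyBatis's XML template parsing directly to parse the scripts?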
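For contrast, the first design boils down to plain text substitution. Below is a minimal sketch, assuming a hypothetical NaivePlaceholderRenderer.render helper and "{name}" placeholders. Each value is pasted in verbatim, so nothing is quoted or escaped, and a condition cannot be dropped when its parameter is missing. The sketch uses String.replace rather than replaceAll because replaceAll treats its first argument as a regular expression, in which "{" is a metacharacter.

```java
import java.util.Map;

// Hypothetical sketch of the first design: "{name}" placeholders replaced textually.
final class NaivePlaceholderRenderer {

    private NaivePlaceholderRenderer() {
    }

    static String render(String template, Map<String, String> params) {
        String sql = template;
        for (Map.Entry<String, String> e : params.entrySet()) {
            // Literal (non-regex) replacement; the value goes in as-is, with no quoting or escaping
            sql = sql.replace("{" + e.getKey() + "}", e.getValue());
        }
        return sql;
    }
}
```

Passing "1 OR 1=1" as id turns the WHERE clause into one that matches every row. The #{} binding in MyBatis avoids this.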
————————————————————————————————

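Main text

Walking through the source shows that MyBatis implements SQL parsing in createSqlSource of the XMLLanguageDriver class, and that it supports string templates wrapped in <script></script>.
OK, in that case all that's needed is a way to run that same code ourselves.

Modifying XPathParser directly would be too complicated, since I'd have to work through its internal logic. So I take the imitation route instead and reproduce what MyBatis does, step by step.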

Building the Configuration

First we need a minimal MyBatis XML configuration, written here as a string:

  String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n" //
            + "<!DOCTYPE configuration\r\n" //
            + " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n" //
            + " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"//
            + "<configuration>\r\n" //
            + "</configuration>";

Then build the Configuration the same way MyBatis does:

InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
Configuration configuration = xmlConfigBuilder.parse();

With a Configuration in hand, we can create an XPathParser to parse the SQL:

String script = "<script>\nSELECT COUNT(0) FROM TABLE_NAME where 1=1 <if test=\"param!=null\"> AND num = #{param}</if> \n</script>";
XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
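// createSqlSource(Configuration, XNode, Class<?>): same logic as in XMLLanguageDriver, see the full class below
SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), null);

Map<String, Object> params = new HashMap<>();
params.put("param", "1");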
BoundSql boundSql = source.getBoundSql(params);
String sql = boundSql.getSql();

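Result

When the parameter is supplied (param is not null), the code above produces the following SQL (whitespace aside):
SELECT COUNT(0) FROM TABLE_NAME where 1=1 AND num = ?

This is a prepared SQL statement with placeholders, which is far safer than pasting values in. To use it, you build a PreparedStatement (for example with JDBC's prepareStatement) and set the parameters by hand.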
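What happens to #{param} can be roughly imitated without MyBatis. Every #{name} becomes ?, and the names are remembered in placeholder order, which is the information BoundSql.getParameterMappings() exposes. This regex sketch (HashPlaceholderParser is a hypothetical name) is only an illustration, not MyBatis's actual parser. It keeps only the property name before options such as ",jdbcType=VARCHAR" and ignores ${} text substitution.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Illustrative approximation of #{...} handling; not MyBatis's real implementation.
final class HashPlaceholderParser {

    private static final Pattern HASH_PLACEHOLDER = Pattern.compile("#\\{([^}]*)\\}");

    /** The prepared SQL plus the parameter names in placeholder order. */
    static final class Parsed {
        final String sql;
        final List<String> names;

        Parsed(String sql, List<String> names) {
            this.sql = sql;
            this.names = names;
        }
    }

    private HashPlaceholderParser() {
    }

    static Parsed parse(String template) {
        Matcher m = HASH_PLACEHOLDER.matcher(template);
        StringBuffer sql = new StringBuffer();
        List<String> names = new ArrayList<>();
        while (m.find()) {
            // Keep only the property name, dropping options such as ",jdbcType=VARCHAR"
            names.add(m.group(1).split(",")[0].trim());
            m.appendReplacement(sql, "?");
        }
        m.appendTail(sql);
        return new Parsed(sql.toString(), names);
    }
}
```

The names list determines which value goes to which 1-based index when the statement is filled in.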

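Setting PreparedStatement parameters automatically

This follows setParameters in MyBatis's DefaultParameterHandler.

The PreparedStatement can be created with, for example, Spring JDBC:
Connection connection = jdbcTemplate.getDataSource().getConnection();
PreparedStatement ps = connection.prepareStatement(boundSql.getSql());
// close ps and connection when finished, or the connection never goes back to the pool

Then call DefaultParameterHandler's setParameters. Its constructor, however, also takes a MappedStatement. With only a BoundSql at hand, it can be simpler to copy the loop from setParameters, as the second approach below does.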
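Without TypeHandler and MetaObject, the loop inside setParameters amounts to binding each parameter name, in placeholder order, to a 1-based index. The following simplified sketch works under that assumption and uses the hypothetical helpers bind and recordingStatement. Non-null values go through setObject, and nulls go through setNull with Types.OTHER, since MyBatis's default jdbcTypeForNull is JdbcType.OTHER. The recording statement is a java.lang.reflect.Proxy, so the example runs without a database.

```java
import java.lang.reflect.Proxy;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;

// Simplified version of the setParameters idea: names in placeholder order -> 1-based indexes.
final class ParameterBinder {

    private ParameterBinder() {
    }

    static void bind(PreparedStatement ps, List<String> names, Map<String, ?> params) throws SQLException {
        for (int i = 0; i < names.size(); i++) {
            Object value = params.get(names.get(i));
            if (value == null) {
                // Mirrors MyBatis's default jdbcTypeForNull (JdbcType.OTHER)
                ps.setNull(i + 1, Types.OTHER);
            } else {
                ps.setObject(i + 1, value);
            }
        }
    }

    /** Hypothetical test double: records setObject/setNull calls instead of talking to a database. */
    static PreparedStatement recordingStatement(Map<Integer, Object> recorded) {
        return (PreparedStatement) Proxy.newProxyInstance(
                ParameterBinder.class.getClassLoader(),
                new Class<?>[] {PreparedStatement.class},
                (proxy, method, args) -> {
                    if (method.getName().equals("setObject")) {
                        recorded.put((Integer) args[0], args[1]);
                    } else if (method.getName().equals("setNull")) {
                        recorded.put((Integer) args[0], null);
                    }
                    return null; // only the void setters are exercised here
                });
    }
}
```

Some drivers (Oracle's is a frequently reported case) reject Types.OTHER in setNull. MyBatis's jdbcTypeForNull setting exists so that this can be changed.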

Getting the complete SQL

This follows MyBatis-Plus's PerformanceInterceptor, which can read the SQL out of a Statement.
The code needed is:

    /**
     * COPY FROM {@link PerformanceInterceptor}
     */
    private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
    private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
    private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
    private Method oracleGetOriginalSqlMethod;
    private Method druidGetSQLMethod;

    /**
     * Get the original SQL, COPY FROM {@link PerformanceInterceptor}
     */
    private String getOriginSql(PreparedStatement statement) {
        String originalSql = null;
        String stmtClassName = statement.getClass().getName();
        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
            try {
                if (druidGetSQLMethod == null) {
                    Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
                    druidGetSQLMethod = clazz.getMethod("getSql");
                }
                Object stmtSql = druidGetSQLMethod.invoke(statement);
                if (stmtSql instanceof String) {
                    originalSql = (String) stmtSql;
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
            try {
                if (oracleGetOriginalSqlMethod != null) {
                    Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                    if (stmtSql instanceof String) {
                        originalSql = (String) stmtSql;
                    }
                } else {
                    Class<?> clazz = Class.forName(stmtClassName);
                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
                    if (oracleGetOriginalSqlMethod != null) {
                        // OraclePreparedStatementWrapper is not a public class, so make the method accessible first
                        oracleGetOriginalSqlMethod.setAccessible(true);
                        Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                        if (stmtSql instanceof String) {
                            originalSql = (String) stmtSql;
                        }
                    }
                }
            } catch (Exception e) {
                // ignore
            }
        }
        if (originalSql == null) {
            originalSql = statement.toString();
        }

        return originalSql;
    }
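The Oracle branch depends on two reflection tricks. It finds a method by name anywhere up the class hierarchy (getMethodRegular, defined in the full class at the end), and it calls that method even though the declaring class is not public. Below is a self-contained sketch of that technique. It runs against hypothetical stand-in classes (HiddenStatementBase, HiddenStatement), not the real Oracle driver classes.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical stand-ins: a non-public hierarchy that hides its SQL behind a private getter.
class HiddenStatementBase {
    private String getOriginalSql() {
        return "SELECT * FROM t WHERE id = ?";
    }
}

class HiddenStatement extends HiddenStatementBase {
}

final class ReflectiveSqlReader {

    private ReflectiveSqlReader() {
    }

    /** Same idea as getMethodRegular: check declared methods, then walk up to the superclass. */
    static Method findMethod(Class<?> clazz, String name) {
        for (Class<?> c = clazz; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Method m : c.getDeclaredMethods()) {
                if (m.getName().equals(name)) {
                    return m;
                }
            }
        }
        return null;
    }

    static String readOriginalSql(Object statement) {
        Method m = findMethod(statement.getClass(), "getOriginalSql");
        if (m == null) {
            return statement.toString(); // same fallback as getOriginSql
        }
        try {
            m.setAccessible(true); // required because the method is not public
            Object sql = m.invoke(statement);
            return sql instanceof String ? (String) sql : statement.toString();
        } catch (IllegalAccessException | InvocationTargetException e) {
            return statement.toString();
        }
    }
}
```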

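With MyBatis, the SQL obtained this way looks roughly like "com.mxxxx : SELECT....", with the driver's class name first, so the SQL part has to be cut out: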

    /**
     * Find where the SQL statement itself starts, COPY FROM {@link PerformanceInterceptor}
     */
    private int indexOfSqlStart(String sql) {
        String upperCaseSql = sql.toUpperCase();
        Set<Integer> set = new HashSet<>();
        set.add(upperCaseSql.indexOf("SELECT "));
        set.add(upperCaseSql.indexOf("UPDATE "));
        set.add(upperCaseSql.indexOf("INSERT "));
        set.add(upperCaseSql.indexOf("DELETE "));
        set.remove(-1);
        if (CollectionUtils.isEmpty(set)) {
            return -1;
        }
        List<Integer> list = new ArrayList<>(set);
        list.sort(Comparator.naturalOrder());
        return list.get(0);
    }
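The same search can be written without the Set/List round trip or MyBatis-Plus's CollectionUtils by keeping the smallest non-negative index. Below is a sketch under a hypothetical name, SqlStartFinder. It also passes Locale.ROOT to toUpperCase, so that a default locale such as Turkish cannot change how "insert" or "delete" are upper-cased.

```java
import java.util.Locale;

// Same search as indexOfSqlStart: earliest position of a DML keyword, or -1 if none is present.
final class SqlStartFinder {

    private static final String[] KEYWORDS = {"SELECT ", "UPDATE ", "INSERT ", "DELETE "};

    private SqlStartFinder() {
    }

    static int indexOfSqlStart(String sql) {
        String upper = sql.toUpperCase(Locale.ROOT);
        int best = -1;
        for (String keyword : KEYWORDS) {
            int i = upper.indexOf(keyword);
            if (i >= 0 && (best < 0 || i < best)) {
                best = i;
            }
        }
        return best;
    }
}
```

Like the original, it only recognizes the four keywords when a space follows them.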

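And with that we have the complete SQL. Hooray! But!

With an Oracle database this method does not work: what gets printed is still the prepared SQL with placeholders.

Complete SQL, second approach

If the first approach fails (scan the string to see whether any ? is left), add this approach:
Reference: https://www.cnblogs.com/aipan/p/7237854.html
Add LoggableStatement, and adapt DefaultParameterHandler's setParameters (copied out as a new method):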

    @SuppressWarnings({"unchecked", "rawtypes"})
    private LoggableStatement buildPreparedStatement(JdbcTemplate jdbcTemplate, BoundSql boundSql) throws SQLException {
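        // This opens a connection: the caller must later close ps and ps.getConnection()
        PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
        // The change: wrap ps in LoggableStatement so the values that get set are recorded
        LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());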
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Object parameterObject = boundSql.getParameterObject();
        if (parameterMappings != null) {
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    String propertyName = parameterMapping.getProperty();
                    if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (parameterObject == null) {
                        value = null;
                    } else {
                        MetaObject metaObject = configuration.newMetaObject(parameterObject);
                        value = metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        typeHandler.setParameter(ls, i + 1, value, jdbcType);
                    } catch (TypeException | SQLException e) {
                        throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }

        return ls;
    }
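What LoggableStatement contributes is a record of every value set on the statement, which it can splice back into the SQL. Below is a simplified stand-in for that rendering step, under the hypothetical name SqlInliner. It is not the LoggableStatement class from the linked article. Strings are single-quoted with embedded quotes doubled, null becomes NULL, and other values use toString(). A ? inside a quoted literal is left alone. The output is meant for display and review only, since executing it would give up the protection that the prepared statement provides.

```java
import java.util.Map;

// Hypothetical sketch: replace each '?' outside single-quoted literals with the recorded value.
final class SqlInliner {

    private SqlInliner() {
    }

    /** values maps 1-based parameter index -> value, like the calls recorded on the statement. */
    static String inline(String sql, Map<Integer, ?> values) {
        StringBuilder out = new StringBuilder();
        boolean inLiteral = false;
        int index = 0;
        for (char c : sql.toCharArray()) {
            if (c == '\'') {
                inLiteral = !inLiteral; // a doubled '' toggles twice, so the state stays consistent
                out.append(c);
            } else if (c == '?' && !inLiteral) {
                index++;
                out.append(format(values.get(index)));
            } else {
                out.append(c);
            }
        }
        return out.toString();
    }

    private static String format(Object value) {
        if (value == null) {
            return "NULL";
        }
        if (value instanceof CharSequence) {
            return "'" + value.toString().replace("'", "''") + "'";
        }
        return value.toString();
    }
}
```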

When getting the complete SQL the first way fails, the SQL can be taken from the LoggableStatement instead:

if (!isCorrectGetSql(boundSql, originalSql)) {
   originalSql = statement.getQueryString();
}
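The check relies on countQuestionMark, which counts every ?. As a result, a filled-in value such as 'why?' makes a successful result look like a failure and triggers an unnecessary fallback. Below is a variant (PlaceholderCounter, a hypothetical name) that skips ? inside single-quoted literals. It is only a sketch and does not understand SQL comments or double-quoted identifiers.

```java
// Sketch: count '?' placeholders outside single-quoted SQL literals.
final class PlaceholderCounter {

    private PlaceholderCounter() {
    }

    static int countPlaceholders(String sql) {
        int count = 0;
        boolean inLiteral = false;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (c == '\'') {
                inLiteral = !inLiteral;
            } else if (c == '?' && !inLiteral) {
                count++;
            }
        }
        return count;
    }

    /** Same rule as isCorrectGetSql: the values were filled in if placeholders disappeared. */
    static boolean isCorrectGetSql(String boundSql, String candidateSql) {
        return countPlaceholders(boundSql) > countPlaceholders(candidateSql);
    }
}
```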

————————————————————————————

The complete code of the class from the project mentioned above:

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import org.apache.ibatis.builder.xml.XMLConfigBuilder;
import org.apache.ibatis.builder.xml.XMLMapperEntityResolver;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.parsing.XPathParser;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.xmltags.XMLScriptBuilder;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeException;
import org.apache.ibatis.type.TypeHandler;
import org.springframework.jdbc.core.JdbcTemplate;

import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;

/**
 * Executor that builds SQL using MyBatis's XML parsing
 */
public class MybatisTemplateSqlExcutor {

    /**
     * Script template
     */
    private static final String SCRIPT_TEMPLATE = "<script>\n%s\n</script>";

    /**
     * MP environment configuration
     */
    private Configuration configuration;

    /**
     * COPY FROM {@link PerformanceInterceptor}
     */
    private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
    private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
    private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
    private Method oracleGetOriginalSqlMethod;
    private Method druidGetSQLMethod;

    /**
     * Constructor: initializes the MP environment configuration
     */
    public MybatisTemplateSqlExcutor() {
        InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
        XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
        configuration = xmlConfigBuilder.parse();
    }

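    public String parseSql(JdbcTemplate jdbcTemplate, String sqlTemplate, Map<String, Object> params) throws SQLException {
        String script = String.format(SCRIPT_TEMPLATE, sqlTemplate);
        XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
        SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), Map.class);
        BoundSql boundSql = source.getBoundSql(params);
        LoggableStatement statement = buildPreparedStatement(jdbcTemplate, boundSql);
        PreparedStatement ps = statement.getPreparedStatement();
        // buildPreparedStatement borrows a connection from the DataSource; try-with-resources
        // closes the statement and then that connection once the SQL has been read
        try (java.sql.Connection connection = ps.getConnection(); PreparedStatement psToClose = ps) {
            String originalSql = getOriginSql(ps);
            int index = indexOfSqlStart(originalSql);
            if (index > 0) {
                // cut off the driver's class-name prefix, e.g. "com.mxxxx : "
                originalSql = originalSql.substring(index);
            }
            if (!isCorrectGetSql(boundSql, originalSql)) {
                // placeholders were not filled in (e.g. Oracle): use the SQL rebuilt by LoggableStatement
                originalSql = statement.getQueryString();
            }
            return originalSql;
        }
    }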

    /**
     * Script-parsing method copied from MP (same logic as XMLLanguageDriver.createSqlSource)
     */
    private SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
        XMLScriptBuilder builder = new XMLScriptBuilder(configuration, script, parameterType);
        return builder.parseScriptNode();
    }

    /**
     * Builds the PreparedStatement from the BoundSql, used to obtain the actual SQL
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private LoggableStatement buildPreparedStatement(JdbcTemplate jdbcTemplate, BoundSql boundSql) throws SQLException {
        PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
        LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Object parameterObject = boundSql.getParameterObject();
        if (parameterMappings != null) {
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    String propertyName = parameterMapping.getProperty();
                    if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (parameterObject == null) {
                        value = null;
                    } else {
                        MetaObject metaObject = configuration.newMetaObject(parameterObject);
                        value = metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        typeHandler.setParameter(ls, i + 1, value, jdbcType);
                    } catch (TypeException | SQLException e) {
                        throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }

        return ls;
    }

    /**
     * Find where the SQL statement itself starts
     */
    private int indexOfSqlStart(String sql) {
        String upperCaseSql = sql.toUpperCase();
        Set<Integer> set = new HashSet<>();
        set.add(upperCaseSql.indexOf("SELECT "));
        set.add(upperCaseSql.indexOf("UPDATE "));
        set.add(upperCaseSql.indexOf("INSERT "));
        set.add(upperCaseSql.indexOf("DELETE "));
        set.remove(-1);
        if (CollectionUtils.isEmpty(set)) {
            return -1;
        }
        List<Integer> list = new ArrayList<>(set);
        list.sort(Comparator.naturalOrder());
        return list.get(0);
    }

    /**
     * Get the original SQL, COPY FROM {@link PerformanceInterceptor}
     */
    private String getOriginSql(PreparedStatement statement) {
        String originalSql = null;
        String stmtClassName = statement.getClass().getName();
        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
            try {
                if (druidGetSQLMethod == null) {
                    Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
                    druidGetSQLMethod = clazz.getMethod("getSql");
                }
                Object stmtSql = druidGetSQLMethod.invoke(statement);
                if (stmtSql instanceof String) {
                    originalSql = (String) stmtSql;
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
            try {
                if (oracleGetOriginalSqlMethod != null) {
                    Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                    if (stmtSql instanceof String) {
                        originalSql = (String) stmtSql;
                    }
                } else {
                    Class<?> clazz = Class.forName(stmtClassName);
                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
                    if (oracleGetOriginalSqlMethod != null) {
                        // OraclePreparedStatementWrapper is not a public class, so make the method accessible first
                        oracleGetOriginalSqlMethod.setAccessible(true);
                        Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                        if (stmtSql instanceof String) {
                            originalSql = (String) stmtSql;
                        }
                    }
                }
            } catch (Exception e) {
                // ignore
            }
        }
        if (originalSql == null) {
            originalSql = statement.toString();
        }

        return originalSql;
    }

    /**
     * Find the Method with the given name, searching superclasses as well
     *
     * @param clazz class object
     * @param methodName method name
     * @return the method, or null if it is not found
     */
    public Method getMethodRegular(Class<?> clazz, String methodName) {
        if (Object.class.equals(clazz)) {
            return null;
        }
        for (Method method : clazz.getDeclaredMethods()) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        return getMethodRegular(clazz.getSuperclass(), methodName);
    }

    /**
     * Check whether the SQL was obtained correctly: fewer ? than in the bound SQL means the values were filled in
     */
    private boolean isCorrectGetSql(BoundSql boundSql, String originSql) {
        return countQuestionMark(boundSql.getSql()) > countQuestionMark(originSql);
    }

    /**
     * Count the ? placeholders
     */
    private int countQuestionMark(String sql) {
        int result = 0;
        for (char c : sql.toCharArray()) {
            if (c == '?') {
                result++;
            }
        }
        return result;
    }

    /**
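     * Empty MP config template used to build the MP environment configuration (placed at the end because the blog editor otherwise breaks the syntax highlighting)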
     */
    private static final String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n"
            + "<!DOCTYPE configuration\r\n"
            + " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n"
            + " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"
            + "<configuration>\r\n"
            + "</configuration>";
}

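The class is named an Excutor because in the project it is actually a subclass; only parseSql is provided here.

Other

If you know a better way, please do let me know (๑•̀ㅂ•́)و✧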