【Mybatis】單獨使用mybatis的SQL模板解析
前言
由於公司的專案歷史設計問題坑多不見底,新專案沒時間改,舊專案改不動。生產存在非常多的需要且只能通過資料庫指令碼改資料的違規操作。
每次開發到一半,一個工單就丟過來讓去生產改資料,天天寫指令碼,這怎麼受得了。
幾個季度下來,忍無可忍,我一拍桌子,決定chi.....從頭開發一個指令碼執行工具,管理寫過的指令碼,指令碼間可關聯執行,可跨資料庫。以解決臨時寫指令碼,及指令碼共享的問題。
設計上,指令碼執行工具的指令碼最早僅是佔位符,如 "{引數名}",然後replaceAll,後面想了下,為啥不能像mybatis一樣,或者說,為啥不直接借mybatis的xml模板解析功能來解析指令碼呢。
————————————————————————————————
正文
通過走查程式碼,可以發現mybatis在XMLLanguageDriver
類裡的createSqlSource
裡實現SQL解析,並且支援使用<script></script>
執行字串模板。
OK,既然如此,就想辦法按這裡的程式碼執行即可。
直接修改XPathParser的話太過複雜,還得理內部邏輯,這裡走模擬法
構建配置Configuration
首先需要一個mybatis的xml最小配置,這裡寫為字串:
String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n" // + "<!DOCTYPE configuration\r\n" // + " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n" // + " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"// + "<configuration>\r\n" // + "</configuration>";
然後根據mybatis的實現構建Configuration
點選檢視程式碼
// Feed the in-memory XML to MyBatis' own config builder to obtain a Configuration.
byte[] emptyConfigBytes = EMPTY_XML.getBytes(StandardCharsets.UTF_8);
InputStream configStream = new ByteArrayInputStream(emptyConfigBytes);
Configuration configuration = new XMLConfigBuilder(configStream, null, null).parse();
有了Configuration就可以建立XPathParser來解析SQL了
點選檢視程式碼
// The SQL template is wrapped in <script> so it is parsed as a MyBatis dynamic-SQL script.
String script = "<script>\nSELECT COUNT(0) FROM TABLE_NAME where 1=1 <if test=\"param!=null\"> AND num = #{param}</if> \n</script>";
XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), null);
// Values referenced by the template (test="param!=null" and #{param}).
Map<String, String> params = new HashMap<>();
params.put("param", "1");
// Evaluate the dynamic tags against the parameters and read the resulting SQL text.
BoundSql boundSql = source.getBoundSql(params);
String sql = boundSql.getSql();
結果
通過上述程式碼解析出來的SQL,若帶了引數param,則為
SELECT COUNT(0) FROM TABLE_NAME where 1=1 AND num = ?
是帶佔位符的安全度高的預編譯SQL,使用時需要構建PreparedStatement
,然後通過如jdbc傳入PreparedStatement及手動set引數。
自動化PreparedStatement引數設定
參照Mybatis的DefaultParameterHandler
類的setParameters
建立PreparedStatement
可以通過如springJdbc進行構建:
// Prepare the parsed SQL on a connection taken from the JdbcTemplate's DataSource.
// NOTE(review): this snippet closes neither the Connection nor the PreparedStatement; the caller has to close both.
PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
然後呼叫DefaultParameterHandler
的setParameters即可
獲取完整SQL
參照mybatis-plus的PerformanceInterceptor
類,該類可通過Statement
獲取SQL
需要的程式碼如下:
點選檢視程式碼
/**
 * Fully-qualified class names of the statement implementations that get special handling.
 * Copied from mybatis-plus {@code PerformanceInterceptor}.
 */
private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
// Reflective method handles, looked up on first use and then reused.
private Method oracleGetOriginalSqlMethod;
private Method druidGetSQLMethod;
/**
 * Reads the SQL text held by a prepared statement; adapted from mybatis-plus
 * {@code PerformanceInterceptor}.
 *
 * <p>Druid statements are asked through {@code getSql}, Oracle statements through
 * {@code getOriginalSql}; both are called reflectively. For any other statement class, or when
 * the reflective call yields nothing, {@code statement.toString()} is returned.
 *
 * @param statement the statement to inspect
 * @return the SQL reported by the statement, or its {@code toString()} as a fallback
 */
private String getOriginSql(PreparedStatement statement) {
    String className = statement.getClass().getName();
    String sqlText = null;
    if (DruidPooledPreparedStatement.equals(className)) {
        try {
            // Look the accessor up once and reuse it on later calls.
            if (druidGetSQLMethod == null) {
                druidGetSQLMethod = Class.forName(DruidPooledPreparedStatement).getMethod("getSql");
            }
            Object result = druidGetSQLMethod.invoke(statement);
            if (result instanceof String) {
                sqlText = (String) result;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    } else if (T4CPreparedStatement.equals(className) || OraclePreparedStatementWrapper.equals(className)) {
        try {
            if (oracleGetOriginalSqlMethod == null) {
                oracleGetOriginalSqlMethod = getMethodRegular(Class.forName(className), "getOriginalSql");
                if (oracleGetOriginalSqlMethod != null) {
                    // OraclePreparedStatementWrapper is not a public class, need set this.
                    oracleGetOriginalSqlMethod.setAccessible(true);
                }
            }
            if (oracleGetOriginalSqlMethod != null) {
                Object result = oracleGetOriginalSqlMethod.invoke(statement);
                if (result instanceof String) {
                    sqlText = (String) result;
                }
            }
        } catch (Exception e) {
            // ignore
        }
    }
    return sqlText != null ? sqlText : statement.toString();
}
獲取出來的SQL在mybatis中為com.mxxxx : SELECT....,大概是這樣,需要擷取:
點選檢視程式碼
/**
 * Finds where the actual SQL statement starts inside a statement's string form, which may carry
 * a prefix in front of the SQL. Adapted from mybatis-plus {@code PerformanceInterceptor}.
 *
 * <p>The scan is case-insensitive and runs on the original string, so the returned offset is
 * always valid for {@code sql.substring(...)}. Upper-casing a copy first is locale dependent and
 * can change the string length, which would shift the offset.
 *
 * @param sql text to search
 * @return index of the first {@code SELECT }, {@code UPDATE }, {@code INSERT } or
 *         {@code DELETE } keyword (each followed by a space), or {@code -1} if none occurs
 */
private int indexOfSqlStart(String sql) {
    String[] keywords = {"SELECT ", "UPDATE ", "INSERT ", "DELETE "};
    // The first position at which any keyword matches is the smallest keyword index.
    for (int i = 0; i < sql.length(); i++) {
        for (String keyword : keywords) {
            if (sql.regionMatches(true, i, keyword, 0, keyword.length())) {
                return i;
            }
        }
    }
    return -1;
}
這樣就獲取出來完整的SQL了,可喜可賀。但是!
Oracle
資料庫使用該方法無效,打印出來還是預編譯SQL
完整SQL第二方案
在第一個方案執行失敗的情況下(可遍歷字串,檢查其中是否仍含有佔位符「?」
),增加該方案:
參照:https://www.cnblogs.com/aipan/p/7237854.html
增加LoggableStatement,改造DefaultParameterHandler
的setParameters(Copy出來作為新方法):
點選檢視程式碼
/**
 * Prepares the bound SQL and binds its parameters through a {@code LoggableStatement} wrapper.
 * The binding loop follows MyBatis' {@code DefaultParameterHandler#setParameters}.
 *
 * <p>The connection obtained from the JdbcTemplate's DataSource is not closed by this method.
 *
 * @param jdbcTemplate supplies the DataSource the statement is prepared on
 * @param boundSql     parsed SQL together with its parameter mappings and parameter object
 * @return the wrapper holding the prepared statement with all non-OUT parameters set
 * @throws SQLException if the connection or statement cannot be created
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private LoggableStatement buildPreparedStatement(JdbcTemplate jdbcTemplate, BoundSql boundSql) throws SQLException {
    PreparedStatement prepared = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
    // Difference from the MyBatis original: parameters are set on the logging wrapper.
    LoggableStatement loggable = new LoggableStatement(prepared, boundSql.getSql());
    List<ParameterMapping> mappings = boundSql.getParameterMappings();
    Object paramObject = boundSql.getParameterObject();
    if (mappings == null) {
        return loggable;
    }
    int position = 0;
    for (ParameterMapping mapping : mappings) {
        // JDBC parameter positions are 1-based and advance for OUT parameters too.
        position++;
        if (mapping.getMode() == ParameterMode.OUT) {
            continue;
        }
        String property = mapping.getProperty();
        Object value;
        if (boundSql.hasAdditionalParameter(property)) {
            value = boundSql.getAdditionalParameter(property);
        } else if (paramObject != null) {
            value = configuration.newMetaObject(paramObject).getValue(property);
        } else {
            value = null;
        }
        TypeHandler handler = mapping.getTypeHandler();
        JdbcType jdbcType = mapping.getJdbcType();
        if (jdbcType == null && value == null) {
            jdbcType = configuration.getJdbcTypeForNull();
        }
        try {
            handler.setParameter(loggable, position, value, jdbcType);
        } catch (TypeException | SQLException e) {
            throw new TypeException("Could not set parameters for mapping: " + mapping + ". Cause: " + e, e);
        }
    }
    return loggable;
}
在獲取完整SQL失敗後,即可通過LoggableStatement來獲取SQL:
if (!isCorrectGetSql(boundSql, originalSql)) {
originalSql = statement.getQueryString();
}
————————————————————————————
上面提到專案的該類完整程式碼:
點選檢視程式碼
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import javax.sql.DataSource;
import org.apache.ibatis.builder.xml.XMLConfigBuilder;
import org.apache.ibatis.builder.xml.XMLMapperEntityResolver;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.parsing.XPathParser;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.xmltags.XMLScriptBuilder;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeException;
import org.apache.ibatis.type.TypeHandler;
import org.springframework.jdbc.core.JdbcTemplate;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
/**
 * Executor that builds SQL from templates using MyBatis' XML dynamic-SQL parsing.
 *
 * <p>A template is wrapped in {@code <script>} tags, parsed by {@code XMLScriptBuilder}, bound to
 * the supplied parameters and prepared on a JDBC connection. The resulting SQL text is then read
 * back from the statement.
 */
public class MybatisTemplateSqlExcutor {

    /** Wrapper that turns a raw SQL template into a MyBatis script document. */
    private static final String SCRIPT_TEMPLATE = "<script>\n%s\n</script>";

    /** Keywords (each followed by a space) that mark the start of the SQL text. */
    private static final String[] SQL_START_KEYWORDS = {"SELECT ", "UPDATE ", "INSERT ", "DELETE "};

    /*
     * Statement class names that get special handling; copied from mybatis-plus
     * PerformanceInterceptor.
     */
    private static final String DRUID_POOLED_PREPARED_STATEMENT =
            "com.alibaba.druid.pool.DruidPooledPreparedStatement";
    private static final String T4C_PREPARED_STATEMENT = "oracle.jdbc.driver.T4CPreparedStatement";
    private static final String ORACLE_PREPARED_STATEMENT_WRAPPER =
            "oracle.jdbc.driver.OraclePreparedStatementWrapper";

    /** MyBatis configuration built from {@link #EMPTY_XML}. */
    private final Configuration configuration;

    /** Reflective handles, resolved on first use and then reused. */
    private Method oracleGetOriginalSqlMethod;
    private Method druidGetSQLMethod;

    /**
     * Creates the executor and initialises the MyBatis configuration from the empty XML template.
     */
    public MybatisTemplateSqlExcutor() {
        InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
        XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
        configuration = xmlConfigBuilder.parse();
    }

    /**
     * Renders a MyBatis-style SQL template with the given parameters and returns the SQL text.
     *
     * <p>The SQL is first read from the driver/pool statement. If that text does not contain
     * fewer {@code ?} characters than the prepared SQL, the SQL recorded by the
     * {@code LoggableStatement} is returned instead.
     *
     * <p>The JDBC connection and prepared statement opened here are closed before returning.
     *
     * @param jdbcTemplate supplies the DataSource the statement is prepared on
     * @param sqlTemplate  SQL text that may contain MyBatis dynamic tags and {@code #{...}} placeholders
     * @param params       values referenced by the template
     * @return the SQL text obtained for the template
     * @throws SQLException if no DataSource is configured or a JDBC operation fails
     */
    public String parseSql(JdbcTemplate jdbcTemplate, String sqlTemplate, Map<String, Object> params)
            throws SQLException {
        String script = String.format(SCRIPT_TEMPLATE, sqlTemplate);
        XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
        SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), Map.class);
        BoundSql boundSql = source.getBoundSql(params);

        DataSource dataSource = jdbcTemplate.getDataSource();
        if (dataSource == null) {
            throw new SQLException("JdbcTemplate has no DataSource configured");
        }
        // try-with-resources: the connection and statement are only needed to read the SQL back.
        try (Connection connection = dataSource.getConnection();
                PreparedStatement ps = connection.prepareStatement(boundSql.getSql())) {
            LoggableStatement statement = bindParameters(ps, boundSql);
            String originalSql = getOriginSql(statement.getPreparedStatement());
            // Drop whatever precedes the SQL keyword in the statement's string form.
            int index = indexOfSqlStart(originalSql);
            if (index > 0) {
                originalSql = originalSql.substring(index);
            }
            if (!isCorrectGetSql(boundSql, originalSql)) {
                originalSql = statement.getQueryString();
            }
            return originalSql;
        }
    }

    /**
     * Parses a {@code <script>} node into a {@link SqlSource}; copied from the MyBatis(-Plus)
     * script parsing code.
     */
    private SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
        XMLScriptBuilder builder = new XMLScriptBuilder(configuration, script, parameterType);
        return builder.parseScriptNode();
    }

    /**
     * Wraps the prepared statement in a {@code LoggableStatement} and binds every non-OUT
     * parameter of the bound SQL through it, so the wrapper can report the SQL afterwards.
     * The loop follows MyBatis' {@code DefaultParameterHandler#setParameters}.
     *
     * @param ps       statement prepared from {@code boundSql.getSql()}; owned by the caller
     * @param boundSql parsed SQL together with its parameter mappings and parameter object
     * @return the wrapper with all parameters set
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private LoggableStatement bindParameters(PreparedStatement ps, BoundSql boundSql) {
        LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null) {
            return ls;
        }
        Object parameterObject = boundSql.getParameterObject();
        for (int i = 0; i < parameterMappings.size(); i++) {
            ParameterMapping parameterMapping = parameterMappings.get(i);
            if (parameterMapping.getMode() == ParameterMode.OUT) {
                continue;
            }
            String propertyName = parameterMapping.getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(propertyName)) {
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                value = metaObject.getValue(propertyName);
            }
            TypeHandler typeHandler = parameterMapping.getTypeHandler();
            JdbcType jdbcType = parameterMapping.getJdbcType();
            if (value == null && jdbcType == null) {
                jdbcType = configuration.getJdbcTypeForNull();
            }
            try {
                // JDBC parameter positions are 1-based.
                typeHandler.setParameter(ls, i + 1, value, jdbcType);
            } catch (TypeException | SQLException e) {
                throw new TypeException(
                        "Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
            }
        }
        return ls;
    }

    /**
     * Finds where the SQL statement starts inside a statement's string form.
     *
     * <p>The scan is case-insensitive and runs on the original string, so the returned offset is
     * always valid for {@code sql.substring(...)}. Upper-casing a copy first is locale dependent
     * and can change the string length, which would shift the offset.
     *
     * @param sql text to search
     * @return index of the first start keyword, or {@code -1} if none occurs
     */
    private int indexOfSqlStart(String sql) {
        // The first position at which any keyword matches is the smallest keyword index.
        for (int i = 0; i < sql.length(); i++) {
            for (String keyword : SQL_START_KEYWORDS) {
                if (sql.regionMatches(true, i, keyword, 0, keyword.length())) {
                    return i;
                }
            }
        }
        return -1;
    }

    /**
     * Reads the SQL text held by a prepared statement; adapted from mybatis-plus
     * {@code PerformanceInterceptor}.
     *
     * <p>Druid statements are asked through {@code getSql}, Oracle statements through
     * {@code getOriginalSql}; both are called reflectively. For any other statement class, or
     * when the reflective call yields nothing, {@code statement.toString()} is returned.
     *
     * @param statement the statement to inspect
     * @return the SQL reported by the statement, or its {@code toString()} as a fallback
     */
    private String getOriginSql(PreparedStatement statement) {
        String originalSql = null;
        String stmtClassName = statement.getClass().getName();
        if (DRUID_POOLED_PREPARED_STATEMENT.equals(stmtClassName)) {
            try {
                if (druidGetSQLMethod == null) {
                    druidGetSQLMethod = Class.forName(DRUID_POOLED_PREPARED_STATEMENT).getMethod("getSql");
                }
                originalSql = invokeSqlGetter(druidGetSQLMethod, statement);
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (T4C_PREPARED_STATEMENT.equals(stmtClassName)
                || ORACLE_PREPARED_STATEMENT_WRAPPER.equals(stmtClassName)) {
            try {
                if (oracleGetOriginalSqlMethod == null) {
                    Method resolved = getMethodRegular(Class.forName(stmtClassName), "getOriginalSql");
                    if (resolved != null) {
                        // OraclePreparedStatementWrapper is not a public class, so access must be forced.
                        resolved.setAccessible(true);
                    }
                    oracleGetOriginalSqlMethod = resolved;
                }
                if (oracleGetOriginalSqlMethod != null) {
                    originalSql = invokeSqlGetter(oracleGetOriginalSqlMethod, statement);
                }
            } catch (Exception ignored) {
                // Best effort: the toString() fallback below is used instead.
            }
        }
        return originalSql != null ? originalSql : statement.toString();
    }

    /** Invokes a no-argument getter on the statement and returns its result if it is a String. */
    private static String invokeSqlGetter(Method getter, PreparedStatement statement)
            throws ReflectiveOperationException {
        Object result = getter.invoke(statement);
        return result instanceof String ? (String) result : null;
    }

    /**
     * Looks up a declared method by name on a class or its superclasses.
     *
     * <p>The search stops before {@code Object}. The first declared method with a matching name
     * is returned, regardless of its parameters.
     *
     * @param clazz      class to start from; may be {@code null}
     * @param methodName method name to look for
     * @return the first matching method, or {@code null} if there is none
     */
    public Method getMethodRegular(Class<?> clazz, String methodName) {
        // getSuperclass() is null for interfaces and for Object, so the null check ends the walk.
        for (Class<?> current = clazz;
                current != null && !Object.class.equals(current);
                current = current.getSuperclass()) {
            for (Method method : current.getDeclaredMethods()) {
                if (method.getName().equals(methodName)) {
                    return method;
                }
            }
        }
        return null;
    }

    /**
     * Decides whether the SQL read from the statement can be used.
     *
     * @return {@code true} when {@code originSql} contains fewer {@code ?} characters than the
     *         prepared SQL of {@code boundSql}
     */
    private boolean isCorrectGetSql(BoundSql boundSql, String originSql) {
        return countQuestionMark(boundSql.getSql()) > countQuestionMark(originSql);
    }

    /** Counts the {@code ?} characters in the given SQL text. */
    private int countQuestionMark(String sql) {
        int result = 0;
        for (char c : sql.toCharArray()) {
            if (c == '?') {
                result++;
            }
        }
        return result;
    }

    /**
     * Empty MyBatis configuration template used to build the configuration. It is declared at the
     * end of the class because the blog editor mis-highlights the code otherwise.
     */
    private static final String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n"
            + "<!DOCTYPE configuration\r\n"
            + " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n"
            + " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"
            + "<configuration>\r\n"
            + "</configuration>";
}
稱之為Excutor是因為這其實是個子類,只提供parseSql
其他
若有更好的方法,請務必告訴我(๑•̀ㅂ•́)و✧