白小兔的小小站

既然选择了远方,便只顾风雨兼程

0%

通过Mybatis的Interceptor动态修改SQL实现限制时间查询范围

随着业务表的数据量越来越大,DBA催着进行了分表,然后需要业务查询层限制查询时间范围。考虑到直接修改Mybatis的Mapper文件(xml)比较繁琐,遂采用其拦截器来实现这一功能。

要解决的问题有以下几点:

  1. 确保拦截且只拦截查询操作
  2. 获取并正确修改SQL
  3. 与其他组件(如PageHelper)正确协作

在找现成的轮子的时候,找到了《Druid拦截sql语句,实现在添加一个查询条件》这篇博文。所以可以发现,我的代码基本上都是来自于这里。

拦截器的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGSelectQueryBlock;
import com.alibaba.druid.sql.dialect.postgresql.parser.PGSQLStatementParser;
import com.alibaba.druid.sql.dialect.postgresql.visitor.PGOutputVisitor;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;

/**
 * MyBatis plugin that restricts intercepted SELECT statements to recent data.
 *
 * <p>For every {@code Executor.query} call the bound SQL is parsed with Druid's PostgreSQL
 * parser. A predicate {@code <time column> >= <now minus N months>} is then AND-ed into the
 * WHERE clause for each plain table found in the FROM clause. The same is done for sub-queries
 * in the select list, sub-queries in the FROM clause (including inside joins), the branches of
 * a UNION, and a WHERE clause that is itself an {@code IN (sub-query)} expression. The AST is
 * modified in place and the whole statement is rendered once at the end.
 *
 * <p>The feature is switched on and off by {@code QueryTimeLimitUtil.isQueryTimeLimitEnabled()};
 * the column name used for a table comes from {@code QueryTimeLimitUtil.getTimeIdentifier}.
 * The window defaults to {@value #DEFAULT_LIMIT_MONTHS} months and can be overridden with the
 * plugin property {@code limitMonths}.
 *
 * <p>There is no fallback for SQL the Druid parser rejects: a parse failure propagates to the
 * caller of the query.
 *
 * @author littleRabbit on 2020-04-13
 */
@Intercepts({
    @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class,
            CacheKey.class, BoundSql.class})
})
public class QueryTimeLimitInterceptor implements Interceptor {

    private static final String PG_STRING = "postgresql";

    /** Name of the optional plugin property holding the window size in months. */
    private static final String LIMIT_MONTHS_PROPERTY = "limitMonths";

    /** Window size used when {@code limitMonths} is not configured. */
    private static final int DEFAULT_LIMIT_MONTHS = 3;

    // Shared reflection helpers for MetaObject; reusing one reflector factory lets it cache
    // the class metadata instead of rebuilding it on every query.
    private static final DefaultObjectFactory OBJECT_FACTORY = new DefaultObjectFactory();
    private static final DefaultObjectWrapperFactory OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /** How many months back from "now" queries are allowed to read. */
    private int limitMonths = DEFAULT_LIMIT_MONTHS;

    /**
     * Rewrites the SQL of the intercepted query (when the feature is enabled) and then lets the
     * invocation continue.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return whatever the intercepted call returns
     * @throws Throwable anything thrown while parsing the SQL or by the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (QueryTimeLimitUtil.isQueryTimeLimitEnabled()) {
            limitQueryTime(invocation);
        }
        return invocation.proceed();
    }

    /**
     * Parses the bound SQL, adds the time predicates and writes the new SQL back into the
     * invocation. Statements that are not SELECTs are left completely untouched.
     *
     * @param invocation the intercepted {@code Executor.query} call
     */
    private void limitQueryTime(Invocation invocation) {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        // The four-argument signature carries no BoundSql, so it is derived from the statement;
        // the six-argument signature supplies the BoundSql to execute as its last argument.
        BoundSql boundSql = args.length == 4 ? ms.getBoundSql(args[1]) : (BoundSql) args[5];
        if (boundSql == null || !StringUtils.hasText(boundSql.getSql())) {
            return;
        }
        SQLStatement stmt = new PGSQLStatementParser(boundSql.getSql()).parseStatement();
        if (!(stmt instanceof SQLSelectStatement)) {
            return;
        }
        limitSelect(((SQLSelectStatement) stmt).getSelect());
        // Render the whole statement rather than only the query block, so that anything Druid
        // attaches to the enclosing SQLSelect/statement is not lost.
        StringBuilder rendered = new StringBuilder();
        stmt.accept(new PGOutputVisitor(rendered));
        resetSql2Invocation(invocation, boundSql, rendered.toString());
    }

    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    /**
     * Reads the optional {@code limitMonths} plugin property. When the property is absent or
     * blank the current value (default {@value #DEFAULT_LIMIT_MONTHS}) is kept.
     *
     * @param properties plugin properties from the MyBatis configuration, may be {@code null}
     * @throws IllegalArgumentException if {@code limitMonths} is not a positive integer
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        String configured = properties.getProperty(LIMIT_MONTHS_PROPERTY);
        if (!StringUtils.hasText(configured)) {
            return;
        }
        int months;
        try {
            months = Integer.parseInt(configured.trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                LIMIT_MONTHS_PROPERTY + " must be a positive integer, got: " + configured, e);
        }
        if (months <= 0) {
            throw new IllegalArgumentException(
                LIMIT_MONTHS_PROPERTY + " must be a positive integer, got: " + configured);
        }
        this.limitMonths = months;
    }

    /**
     * Applies the time limit to the query held by a {@code SQLSelect}; {@code null} is ignored.
     *
     * @param select the select to process, may be {@code null}
     */
    private void limitSelect(SQLSelect select) {
        if (select != null) {
            limitQuery(select.getQuery());
        }
    }

    /**
     * Dispatches on the concrete query type. UNIONs and plain query blocks are processed; any
     * other kind of query (or {@code null}) is left as it is instead of being cast blindly.
     *
     * @param query the query to process, may be {@code null}
     */
    private void limitQuery(SQLSelectQuery query) {
        if (query instanceof SQLUnionQuery) {
            doUnionSelect((SQLUnionQuery) query);
        } else if (query instanceof SQLSelectQueryBlock) {
            doSelectSql((SQLSelectQueryBlock) query);
        }
    }

    /**
     * Applies the time limit to both branches of a UNION; nested UNIONs are handled
     * recursively through {@link #limitQuery(SQLSelectQuery)}.
     *
     * @param unionQuery the UNION to process
     */
    private void doUnionSelect(SQLUnionQuery unionQuery) {
        limitQuery(unionQuery.getLeft());
        limitQuery(unionQuery.getRight());
    }

    /**
     * Applies the time limit to one query block, in place: first to sub-queries in the select
     * list, then to sub-queries in the FROM clause, and finally to the block's own WHERE clause.
     *
     * @param select the query block to modify
     */
    private void doSelectSql(SQLSelectQueryBlock select) {
        // Sub-queries used as selected columns get their own predicate.
        for (SQLSelectItem item : select.getSelectList()) {
            if (item.getExpr() instanceof SQLQueryExpr) {
                limitSelect(((SQLQueryExpr) item.getExpr()).getSubQuery());
            }
        }
        SQLTableSource from = select.getFrom();
        setTableSourceNewSql(from);
        // The WHERE clause is computed exactly once, so an IN sub-query in it is not
        // processed (and limited) twice.
        select.setWhere(getNewWhereCondition(select.getWhere(), from));
    }

    /**
     * Applies the time limit to every sub-query reachable from a FROM clause, descending
     * through joins of any depth. Plain tables are not touched here.
     *
     * @param tableSource the FROM clause (or part of it), may be {@code null}
     */
    private void setTableSourceNewSql(SQLTableSource tableSource) {
        if (tableSource instanceof SQLSubqueryTableSource) {
            limitSelect(((SQLSubqueryTableSource) tableSource).getSelect());
        } else if (tableSource instanceof SQLJoinTableSource) {
            SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
            setTableSourceNewSql(join.getLeft());
            setTableSourceNewSql(join.getRight());
        }
    }

    /**
     * Builds the WHERE clause for a query block: the time predicates of all plain tables in the
     * FROM clause, AND-ed in front of the existing condition.
     *
     * @param where       the current WHERE clause, may be {@code null}
     * @param tableSource the FROM clause of the same query block
     * @return the new WHERE clause, or {@code where} unchanged when no table needs a predicate
     */
    private SQLExpr getNewWhereCondition(SQLExpr where, SQLTableSource tableSource) {
        // Only a WHERE clause that is itself "x IN (sub-query)" is descended into.
        if (where instanceof SQLInSubQueryExpr) {
            limitSelect(((SQLInSubQueryExpr) where).getSubQuery());
        }
        List<SourceFromInfo> tableNameList = new ArrayList<>();
        getTableNames(tableSource, tableNameList);
        SQLBinaryOpExpr conditionByTableName = getWhereConditionByTableList(tableNameList);
        if (conditionByTableName == null) {
            return where;
        }
        if (where == null) {
            return conditionByTableName;
        }
        return getAndCondition(conditionByTableName, where.clone());
    }

    /**
     * Collects the entries of a FROM clause. Joins are walked recursively; a sub-query is
     * recorded with its alias but not marked for a predicate (it is limited from the inside);
     * a plain table is recorded through {@link #addOnlyTable}.
     *
     * @param tableSource   the FROM clause (or part of it), may be {@code null}
     * @param tableNameList receives the collected entries
     */
    private void getTableNames(SQLTableSource tableSource, List<SourceFromInfo> tableNameList) {
        if (tableSource instanceof SQLJoinTableSource) {
            SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
            getTableNames(join.getLeft(), tableNameList);
            getTableNames(join.getRight(), tableNameList);
        } else if (tableSource instanceof SQLSubqueryTableSource) {
            SourceFromInfo fromInfo = new SourceFromInfo();
            fromInfo.setSubQuery(true);
            fromInfo.setAlias(tableSource.getAlias());
            tableNameList.add(fromInfo);
        } else if (tableSource instanceof SQLExprTableSource) {
            addOnlyTable(tableSource, tableNameList);
        }
    }

    /**
     * Records a plain table and marks it as needing a time predicate. This is the place to
     * clear {@code needAddCondition} for tables that must not be limited.
     *
     * @param tableSource   the plain table from the FROM clause
     * @param tableNameList receives the new entry
     */
    private void addOnlyTable(SQLTableSource tableSource, List<SourceFromInfo> tableNameList) {
        SourceFromInfo fromInfo = new SourceFromInfo();
        // NOTE(review): this is the table source's string form, which may include more than the
        // bare table name - confirm it is what QueryTimeLimitUtil.getTimeIdentifier expects.
        fromInfo.setTableName(String.valueOf(tableSource));
        fromInfo.setAlias(tableSource.getAlias());
        fromInfo.setNeedAddCondition(true);
        tableNameList.add(fromInfo);
    }

    /**
     * AND-s together the time predicates of every entry marked {@code needAddCondition}.
     *
     * @param tableNameList entries collected from a FROM clause
     * @return the combined predicate, or {@code null} when no entry needs one
     */
    private SQLBinaryOpExpr getWhereConditionByTableList(List<SourceFromInfo> tableNameList) {
        SQLBinaryOpExpr combined = null;
        for (SourceFromInfo tableNameInfo : tableNameList) {
            if (!tableNameInfo.isNeedAddCondition()) {
                continue;
            }
            SQLBinaryOpExpr timeRangeLimit = getTimeRangeCondition(tableNameInfo);
            combined = combined == null ? timeRangeLimit : getAndCondition(combined, timeRangeLimit);
        }
        return combined;
    }

    /**
     * Builds {@code [alias.]<time column> >= TIMESTAMP '<now minus limitMonths months>'} for one
     * table. The cut-off is taken from the system clock in the system default time zone.
     *
     * @param tableNameInfo the table to build the predicate for
     * @return the predicate
     */
    private SQLBinaryOpExpr getTimeRangeCondition(SourceFromInfo tableNameInfo) {
        String timeIdentifier = QueryTimeLimitUtil.getTimeIdentifier(tableNameInfo.getTableName());
        SQLExpr column = StringUtils.hasText(tableNameInfo.getAlias())
            ? new SQLPropertyExpr(tableNameInfo.getAlias(), timeIdentifier)
            : new SQLIdentifierExpr(timeIdentifier);
        Timestamp cutOff = new Timestamp(LocalDateTime.now()
            .minus(limitMonths, ChronoUnit.MONTHS)
            .atZone(ZoneId.systemDefault())
            .toInstant()
            .toEpochMilli());

        SQLBinaryOpExpr timeRangeLimit = new SQLBinaryOpExpr(PG_STRING);
        timeRangeLimit.setLeft(column);
        timeRangeLimit.setOperator(SQLBinaryOperator.GreaterThanOrEqual);
        timeRangeLimit.setRight(new SQLTimestampExpr(cutOff.toString()));
        return timeRangeLimit;
    }

    /**
     * Builds {@code left AND right}.
     *
     * @param left  left operand
     * @param right right operand
     * @return the combined condition
     */
    private SQLBinaryOpExpr getAndCondition(SQLExpr left, SQLExpr right) {
        SQLBinaryOpExpr condition = new SQLBinaryOpExpr(PG_STRING);
        condition.setLeft(left);
        condition.setOperator(SQLBinaryOperator.BooleanAnd);
        condition.setRight(right);
        return condition;
    }

    /**
     * Makes the invocation execute {@code sql} instead of the original statement text.
     *
     * <p>The SQL is written into the given {@code BoundSql}, so its parameter mappings and
     * additional parameters are kept. For the six-argument signature that object is the one
     * already sitting in the argument array, so updating it is sufficient. For the
     * four-argument signature the {@code MappedStatement} is replaced by a copy whose
     * {@code SqlSource} always returns the updated {@code BoundSql}.
     *
     * @param invocation the intercepted call whose arguments are updated
     * @param boundSql   the {@code BoundSql} the new SQL was derived from
     * @param sql        the rewritten SQL
     */
    private void resetSql2Invocation(Invocation invocation, BoundSql boundSql, String sql) {
        MetaObject.forObject(boundSql, OBJECT_FACTORY, OBJECT_WRAPPER_FACTORY, REFLECTOR_FACTORY)
            .setValue("sql", sql);
        Object[] args = invocation.getArgs();
        if (args.length == 4) {
            args[0] = newMappedStatement((MappedStatement) args[0], new BoundSqlSqlSource(boundSql));
        }
    }

    /**
     * Copies a {@code MappedStatement}, swapping in a different {@code SqlSource}.
     *
     * @param ms           the statement to copy
     * @param newSqlSource the SQL source of the copy
     * @return the copy
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder =
            new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        String[] keyProperties = ms.getKeyProperties();
        if (keyProperties != null && keyProperties.length != 0) {
            builder.keyProperty(String.join(",", keyProperties));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /** {@code SqlSource} that always hands out one fixed, already rewritten {@code BoundSql}. */
    private static final class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /** One entry of a FROM clause: either a plain table or a sub-query. */
    private static final class SourceFromInfo {
        private String tableName;

        private String alias;

        /** Whether a time predicate must be generated for this entry. */
        private boolean needAddCondition;

        /** Whether this entry is a sub-query rather than a plain table. */
        private boolean subQuery;

        public String getTableName() {
            return tableName;
        }

        public void setTableName(String tableName) {
            this.tableName = tableName;
        }

        public String getAlias() {
            return alias;
        }

        public void setAlias(String alias) {
            this.alias = alias;
        }

        public boolean isNeedAddCondition() {
            return needAddCondition;
        }

        public void setNeedAddCondition(boolean needAddCondition) {
            this.needAddCondition = needAddCondition;
        }

        public boolean isSubQuery() {
            return subQuery;
        }

        public void setSubQuery(boolean subQuery) {
            this.subQuery = subQuery;
        }
    }

}

解决开篇提出的问题

  1. 在mybatis的配置文件中配置上我们自定义的拦截器

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    <?xml version="1.0" encoding="UTF-8" ?>
    <!DOCTYPE configuration PUBLIC "-//mybatis.org//DTD Config 3.0//EN"
    "http://mybatis.org/dtd/mybatis-3-config.dtd">
    <configuration>
    <settings>
    <setting name="cacheEnabled" value="true"/>
    <setting name="mapUnderscoreToCamelCase" value="true"/>
    </settings>
    <plugins>
    <plugin interceptor="com.github.pagehelper.PageInterceptor">
    <property name="helperDialect" value="postgresql"/>
    <property name="pageSizeZero" value="true"/>
    </plugin>
    <plugin interceptor="com.leuncle.interceptor.QueryTimeLimitInterceptor"/>
    </plugins>
    </configuration>

这里将自定义的拦截器放在PageHelper的后面,它们就可以和谐共处了。不能放前面的原因是,Mybatis的拦截器按从后向前的顺序执行,由于PageHelper的intercept中没有执行invocation.proceed(),所以导致后续的拦截器不再生效,所谓的职责链模式,这篇文章讲了原因和解决方案。

  2. 正确修改SQL:主要是加进去的查询条件放的位置不能错;再一点就是字段名,因为拦截了所有的查询,所以必须小心,否则弄个不存在的字段,就完蛋了。这里就需要我们自己进行业务相关的实现了。我因为需求只是对创建时间进行限制,PG库的所有表都有这个字段且命名统一,所以没这个烦恼。
  3. 可能存在不需要处理的表,那么直接在addOnlyTable方法中判断下表名,如果是不需要处理的表,将needAddCondition设置为false即可。