Skip to content

Commit

Permalink
fix #171 sharding-jdbc-core is adjusted for the feature of id-generator
Browse files Browse the repository at this point in the history
  • Loading branch information
hanahmily authored and gaohongtao committed Nov 10, 2016
1 parent 905b94a commit 6fc065e
Show file tree
Hide file tree
Showing 32 changed files with 671 additions and 95 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ logs/
# system ignore
.DS_Store
Thumbs.db

*.class
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@
import com.dangdang.ddframe.rdb.sharding.api.strategy.database.NoneDatabaseShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.NoneTableShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.TableShardingStrategy;
import com.dangdang.ddframe.rdb.sharding.exception.ShardingJdbcException;
import com.dangdang.ddframe.rdb.sharding.id.generator.IdGenerator;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;

/**
* 分库分表规则配置对象.
Expand Down Expand Up @@ -84,12 +87,12 @@ public static ShardingRuleBuilder builder() {
}

/**
* 根据逻辑表名称查找分片规则.
* 试着根据逻辑表名称查找分片规则.
*
* @param logicTableName 逻辑表名称
* @return 该逻辑表的分片规则
*/
public Optional<TableRule> findTableRule(final String logicTableName) {
public Optional<TableRule> tryFindTableRule(final String logicTableName) {
for (TableRule each : tableRules) {
if (each.getLogicTable().equals(logicTableName)) {
return Optional.of(each);
Expand All @@ -98,6 +101,20 @@ public Optional<TableRule> findTableRule(final String logicTableName) {
return Optional.absent();
}

/**
* 根据逻辑表名找到指定分片规则.
*
* @param logicTableName 逻辑表名称
* @return 该逻辑表的分片规则
*/
public TableRule findTableRule(final String logicTableName) {
Optional<TableRule> tableRuleOptional = tryFindTableRule(logicTableName);
if (tableRuleOptional.isPresent()) {
return tableRuleOptional.get();
}
throw new ShardingJdbcException(String.format("%s does not exist in ShardingRule", logicTableName));
}

/**
* 获取数据库分片策略.
*
Expand Down Expand Up @@ -194,14 +211,17 @@ public Optional<BindingTableRule> findBindingTableRule(final String logicTable)
/**
* 获取所有的分片列名.
*
* @param tableName 表名
* @return 分片列名集合
*/
// TODO 目前使用分片列名称, 为了进一步提升解析性能,应考虑使用表名 + 列名
public Collection<String> getAllShardingColumns() {
Set<String> result = new HashSet<>();
public Collection<String> getAllShardingColumns(final String tableName) {
Set<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
result.addAll(databaseShardingStrategy.getShardingColumns());
result.addAll(tableShardingStrategy.getShardingColumns());
for (TableRule each : tableRules) {
if (!each.getLogicTable().equalsIgnoreCase(tableName)) {
continue;
}
if (null != each.getDatabaseShardingStrategy()) {
result.addAll(each.getDatabaseShardingStrategy().getShardingColumns());
}
Expand All @@ -212,6 +232,22 @@ public Collection<String> getAllShardingColumns() {
return result;
}

/**
* 获取所有需要自增的列名.
*
* @param tableName 表名
* @return 自增列
*/
public Collection<String> getAutoIncrementColumns(final String tableName) {
for (TableRule each : tableRules) {
if (!each.getLogicTable().equalsIgnoreCase(tableName)) {
continue;
}
return Sets.newLinkedHashSet(each.getAutoIncrementColumnMap().keySet());
}
return Collections.emptySet();
}

/**
* 分片规则配置对象构建器.
*/
Expand All @@ -228,6 +264,8 @@ public static class ShardingRuleBuilder {

private TableShardingStrategy tableShardingStrategy;

private Class<? extends IdGenerator> idGeneratorClass;

/**
* 构建数据源配置规则.
*
Expand Down Expand Up @@ -282,14 +320,32 @@ public ShardingRuleBuilder tableShardingStrategy(final TableShardingStrategy tab
this.tableShardingStrategy = tableShardingStrategy;
return this;
}

/**
* 构建默认id生成器.
*
* @param idGeneratorClass 默认的Id生成器
* @return 分片规则配置对象构建器
*/
public ShardingRuleBuilder idGenerator(final Class<? extends IdGenerator> idGeneratorClass) {
this.idGeneratorClass = idGeneratorClass;
return this;
}

/**
* 构建分片规则配置对象.
*
* @return 分片规则配置对象
*/
public ShardingRule build() {
return new ShardingRule(dataSourceRule, tableRules, bindingTableRules, databaseShardingStrategy, tableShardingStrategy);
ShardingRule result = new ShardingRule(dataSourceRule, tableRules, bindingTableRules, databaseShardingStrategy, tableShardingStrategy);
if (null == idGeneratorClass) {
return result;
}
for (TableRule each : tableRules) {
each.fillIdGenerator(idGeneratorClass);
}
return result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@

import com.dangdang.ddframe.rdb.sharding.api.strategy.database.DatabaseShardingStrategy;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.TableShardingStrategy;
import com.dangdang.ddframe.rdb.sharding.id.generator.IdGenerator;
import com.google.common.base.Preconditions;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;

/**
* 表规则配置对象.
Expand All @@ -49,6 +53,9 @@ public final class TableRule {

private final TableShardingStrategy tableShardingStrategy;

@Getter(AccessLevel.PACKAGE)
private final Map<String, IdGenerator> autoIncrementColumnMap = new LinkedHashMap<>();

/**
* 全属性构造器.
*
Expand Down Expand Up @@ -193,6 +200,27 @@ int findActualTableIndex(final String dataSourceName, final String actualTableNa
return -1;
}

void fillIdGenerator(final Class<? extends IdGenerator> idGeneratorClass) {
for (Map.Entry<String, IdGenerator> each : autoIncrementColumnMap.entrySet()) {
if (null == each.getValue()) {
each.setValue(TableRuleBuilder.instanceIdGenerator(idGeneratorClass));
}
}
}

/**
* 生成Id.
*
* @param columnName 列名称
* @return 生成的id
*/
public Object generateId(final String columnName) {
Object result = autoIncrementColumnMap.get(columnName).generateId();
Preconditions.checkNotNull(result);
Preconditions.checkState(result instanceof Number || result instanceof String, "id %s(%s) should be Number or String", result.toString(), result.getClass().getName());
return result;
}

/**
* 表规则配置对象构建器.
*/
Expand All @@ -212,7 +240,21 @@ public static class TableRuleBuilder {
private DatabaseShardingStrategy databaseShardingStrategy;

private TableShardingStrategy tableShardingStrategy;

private final Map<String, IdGenerator> autoIncrementColumnMap = new LinkedHashMap<>();

private Class<? extends IdGenerator> tableIdGeneratorClass;


static IdGenerator instanceIdGenerator(final Class<? extends IdGenerator> idGeneratorClass) {
Preconditions.checkNotNull(idGeneratorClass);
try {
return idGeneratorClass.newInstance();
} catch (final InstantiationException | IllegalAccessException e) {
throw new IllegalArgumentException(String.format("Class %s should have public privilege and no argument constructor", idGeneratorClass.getName()));
}
}

/**
* 构建是否为动态表.
*
Expand Down Expand Up @@ -278,14 +320,55 @@ public TableRuleBuilder tableShardingStrategy(final TableShardingStrategy tableS
this.tableShardingStrategy = tableShardingStrategy;
return this;
}

/**
* 自增列.
*
* @param autoIncrementColumn 自增列名称
* @return 规则配置对象构建器
*/
public TableRuleBuilder autoIncrementColumns(final String autoIncrementColumn) {
this.autoIncrementColumnMap.put(autoIncrementColumn, null);
return this;
}

/**
* 自增列.
*
* @param autoIncrementColumn 自增列名称
* @param columnIdGeneratorClass 列Id生成器的类
* @return 规则配置对象构建器
*/
public TableRuleBuilder autoIncrementColumns(final String autoIncrementColumn, final Class<? extends IdGenerator> columnIdGeneratorClass) {
this.autoIncrementColumnMap.put(autoIncrementColumn, instanceIdGenerator(columnIdGeneratorClass));
return this;
}

/**
* 整个表的Id生成器.
*
* @param tableIdGeneratorClass Id生成器
* @return 规则配置对象构建器
*/
public TableRuleBuilder tableIdGenerator(final Class<? extends IdGenerator> tableIdGeneratorClass) {
this.tableIdGeneratorClass = tableIdGeneratorClass;
return this;
}

/**
* 构建表规则配置对象.
*
* @return 表规则配置对象
*/
public TableRule build() {
return new TableRule(logicTable, dynamic, actualTables, dataSourceRule, dataSourceNames, databaseShardingStrategy, tableShardingStrategy);
TableRule result = new TableRule(logicTable, dynamic, actualTables, dataSourceRule, dataSourceNames, databaseShardingStrategy, tableShardingStrategy);
result.autoIncrementColumnMap.putAll(autoIncrementColumnMap);
if (null == tableIdGeneratorClass) {
return result;
}
result.fillIdGenerator(tableIdGeneratorClass);
return result;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ public final void replayMethodsInvocation(final Object target) {
}
}

@Override
public boolean add(final Object o) {
int index = jdbcMethodInvocations.size() + 1;
recordMethodInvocation(index, "setObject", new Class[]{int.class, Object.class}, new Object[]{index, o});
return true;
}

/**
* 根据索引设置列表中的值.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.exception.SQLParserException;
import com.dangdang.ddframe.rdb.sharding.parser.result.SQLParsedResult;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.SQLStatementType;
Expand All @@ -32,7 +33,6 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.Collection;
import java.util.List;

/**
Expand All @@ -51,7 +51,7 @@ public final class SQLParseEngine {

private final SQLASTOutputVisitor visitor;

private final Collection<String> shardingColumns;
private final ShardingRule shardingRule;

/**
*  解析SQL.
Expand All @@ -62,7 +62,7 @@ public SQLParsedResult parse() {
Preconditions.checkArgument(visitor instanceof SQLVisitor);
SQLVisitor sqlVisitor = (SQLVisitor) visitor;
visitor.setParameters(parameters);
sqlVisitor.getParseContext().setShardingColumns(shardingColumns);
sqlVisitor.getParseContext().setShardingRule(shardingRule);
sqlStatement.accept(visitor);
SQLParsedResult result = sqlVisitor.getParseContext().getParsedResult();
if (sqlVisitor.getParseContext().isHasOrCondition()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
import com.alibaba.druid.sql.dialect.sqlserver.parser.SQLServerStatementParser;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.constants.DatabaseType;
import com.dangdang.ddframe.rdb.sharding.exception.SQLParserException;
import com.dangdang.ddframe.rdb.sharding.parser.visitor.VisitorLogProxy;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.Collection;
import java.util.List;

/**
Expand All @@ -54,15 +54,15 @@ public final class SQLParserFactory {
* @param databaseType 数据库类型
* @param sql SQL语句
* @param parameters SQL中参数的值
* @param shardingColumns 分片列名称集合
* @param shardingRule 分片规则
* @return 解析器引擎对象
* @throws SQLParserException SQL解析异常
*/
public static SQLParseEngine create(final DatabaseType databaseType, final String sql, final List<Object> parameters, final Collection<String> shardingColumns) throws SQLParserException {
public static SQLParseEngine create(final DatabaseType databaseType, final String sql, final List<Object> parameters, final ShardingRule shardingRule) throws SQLParserException {
log.debug("Logic SQL: {}, {}", sql, parameters);
SQLStatement sqlStatement = getSQLStatementParser(databaseType, sql).parseStatement();
log.trace("Get {} SQL Statement", sqlStatement.getClass().getName());
return new SQLParseEngine(sqlStatement, parameters, getSQLVisitor(databaseType, sqlStatement), shardingColumns);
return new SQLParseEngine(sqlStatement, parameters, getSQLVisitor(databaseType, sqlStatement), shardingRule);
}

private static SQLStatementParser getSQLStatementParser(final DatabaseType databaseType, final String sql) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

package com.dangdang.ddframe.rdb.sharding.parser.result.router;

import java.util.Collection;
import java.util.LinkedHashSet;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.LinkedList;

/**
* SQL路由上下文.
*
Expand All @@ -41,4 +42,6 @@ public final class RouteContext {
private SQLStatementType sqlStatementType;

private SQLBuilder sqlBuilder;

private Collection<String> autoIncrementColumns = new LinkedList<>();
}
Loading

0 comments on commit 6fc065e

Please sign in to comment.