Bläddra i källkod

set backtick by DbType (#3984)

Co-authored-by: agapple <jianghang.loujh@alibaba-inc.com>
He Wang 3 år sedan
förälder
incheckning
c775478a56

+ 10 - 3
client-adapter/rdb/src/main/java/com/alibaba/otter/canal/client/adapter/rdb/RdbAdapter.java

@@ -14,6 +14,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import com.alibaba.druid.pool.DruidDataSource;
+import com.alibaba.druid.util.JdbcUtils;
 import com.alibaba.otter.canal.client.adapter.OuterAdapter;
 import com.alibaba.otter.canal.client.adapter.rdb.config.ConfigLoader;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
@@ -73,6 +74,11 @@ public class RdbAdapter implements OuterAdapter {
     @Override
     public void init(OuterAdapterConfig configuration, Properties envProperties) {
         this.envProperties = envProperties;
+
+        // 从jdbc url获取db类型
+        Map<String, String> properties = configuration.getProperties();
+        String dbType = JdbcUtils.getDbType(properties.get("jdbc.url"), null);
+
         Map<String, MappingConfig> rdbMappingTmp = ConfigLoader.load(envProperties);
         // 过滤不匹配的key的配置
         rdbMappingTmp.forEach((key, mappingConfig) -> {
@@ -112,7 +118,6 @@ public class RdbAdapter implements OuterAdapter {
         }
 
         // 初始化连接池
-        Map<String, String> properties = configuration.getProperties();
         dataSource = new DruidDataSource();
         dataSource.setDriverClassName(properties.get("jdbc.driverClassName"));
         dataSource.setUrl(properties.get("jdbc.url"));
@@ -125,6 +130,8 @@ public class RdbAdapter implements OuterAdapter {
         dataSource.setTimeBetweenEvictionRunsMillis(60000);
         dataSource.setMinEvictableIdleTimeMillis(300000);
         dataSource.setUseUnfairLock(true);
+        dataSource.setDbType(dbType);
+
         // List<String> array = new ArrayList<>();
         // array.add("set names utf8mb4;");
         // dataSource.setConnectionInitSqls(array);
@@ -226,7 +233,7 @@ public class RdbAdapter implements OuterAdapter {
     public Map<String, Object> count(String task) {
         MappingConfig config = rdbMapping.get(task);
         MappingConfig.DbMapping dbMapping = config.getDbMapping();
-        String sql = "SELECT COUNT(1) AS cnt FROM " + SyncUtil.getDbTableName(dbMapping);
+        String sql = "SELECT COUNT(1) AS cnt FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType());
         Connection conn = null;
         Map<String, Object> res = new LinkedHashMap<>();
         try {
@@ -252,7 +259,7 @@ public class RdbAdapter implements OuterAdapter {
                 }
             }
         }
-        res.put("targetTable", SyncUtil.getDbTableName(dbMapping));
+        res.put("targetTable", SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()));
 
         return res;
     }

+ 9 - 3
client-adapter/rdb/src/main/java/com/alibaba/otter/canal/client/adapter/rdb/service/RdbEtlService.java

@@ -13,6 +13,7 @@ import java.util.concurrent.atomic.AtomicLong;
 
 import javax.sql.DataSource;
 
+import com.alibaba.druid.pool.DruidDataSource;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig.DbMapping;
 import com.alibaba.otter.canal.client.adapter.rdb.support.SyncUtil;
@@ -56,8 +57,11 @@ public class RdbEtlService extends AbstractEtlService {
             DbMapping dbMapping = (DbMapping) mapping;
             Map<String, String> columnsMap = new LinkedHashMap<>();
             Map<String, Integer> columnType = new LinkedHashMap<>();
+            DruidDataSource dataSource = (DruidDataSource) srcDS;
 
-            Util.sqlRS(targetDS, "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping) + " LIMIT 1 ", rs -> {
+            Util.sqlRS(targetDS,
+                "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()) + " LIMIT 1 ",
+                rs -> {
                 try {
 
                     ResultSetMetaData rsd = rs.getMetaData();
@@ -83,7 +87,9 @@ public class RdbEtlService extends AbstractEtlService {
                     boolean completed = false;
 
                     StringBuilder insertSql = new StringBuilder();
-                    insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping)).append(" (");
+                    insertSql.append("INSERT INTO ")
+                        .append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()))
+                        .append(" (");
                     columnsMap
                         .forEach((targetColumnName, srcColumnName) -> insertSql.append(targetColumnName).append(","));
 
@@ -107,7 +113,7 @@ public class RdbEtlService extends AbstractEtlService {
                             // 删除数据
                             Map<String, Object> pkVal = new LinkedHashMap<>();
                             StringBuilder deleteSql = new StringBuilder(
-                                "DELETE FROM " + SyncUtil.getDbTableName(dbMapping) + " WHERE ");
+                                "DELETE FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()) + " WHERE ");
                             appendCondition(dbMapping, deleteSql, pkVal, rs);
                             try (PreparedStatement pstmt2 = connTarget.prepareStatement(deleteSql.toString())) {
                                 int k = 1;

+ 12 - 5
client-adapter/rdb/src/main/java/com/alibaba/otter/canal/client/adapter/rdb/service/RdbMirrorDbSyncService.java

@@ -7,16 +7,17 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
-import javax.sql.DataSource;
-
 import org.apache.commons.lang.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.alibaba.druid.pool.DruidDataSource;
 import com.alibaba.fastjson2.JSON;
 import com.alibaba.fastjson2.JSONWriter.Feature;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MirrorDbConfig;
+import com.alibaba.otter.canal.client.adapter.rdb.support.SingleDml;
+import com.alibaba.otter.canal.client.adapter.rdb.support.SyncUtil;
 import com.alibaba.otter.canal.client.adapter.support.Dml;
 
 /**
@@ -30,10 +31,10 @@ public class RdbMirrorDbSyncService {
     private static final Logger         logger = LoggerFactory.getLogger(RdbMirrorDbSyncService.class);
 
     private Map<String, MirrorDbConfig> mirrorDbConfigCache;                                           // 镜像库配置
-    private DataSource                  dataSource;
+    private DruidDataSource             dataSource;
     private RdbSyncService              rdbSyncService;                                                // rdbSyncService代理
 
-    public RdbMirrorDbSyncService(Map<String, MirrorDbConfig> mirrorDbConfigCache, DataSource dataSource,
+    public RdbMirrorDbSyncService(Map<String, MirrorDbConfig> mirrorDbConfigCache, DruidDataSource dataSource,
                                   Integer threads, Map<String, Map<String, Integer>> columnsTypeCache,
                                   boolean skipDupException){
         this.mirrorDbConfigCache = mirrorDbConfigCache;
@@ -150,7 +151,13 @@ public class RdbMirrorDbSyncService {
      */
     private void executeDdl(MirrorDbConfig mirrorDbConfig, Dml ddl) {
         try (Connection conn = dataSource.getConnection(); Statement statement = conn.createStatement()) {
-            statement.execute(ddl.getSql());
+            // 替换反引号
+            String sql = ddl.getSql();
+            String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
+            if (!"`".equals(backtick)) {
+                sql = sql.replaceAll("`", backtick);
+            }
+            statement.execute(sql);
             // 移除对应配置
             mirrorDbConfig.getTableConfig().remove(ddl.getTable());
             if (logger.isTraceEnabled()) {

+ 18 - 16
client-adapter/rdb/src/main/java/com/alibaba/otter/canal/client/adapter/rdb/service/RdbSyncService.java

@@ -15,12 +15,11 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.function.Function;
 
-import javax.sql.DataSource;
-
 import org.apache.commons.lang.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.alibaba.druid.pool.DruidDataSource;
 import com.alibaba.fastjson2.JSON;
 import com.alibaba.fastjson2.JSONWriter.Feature;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
@@ -41,6 +40,7 @@ public class RdbSyncService {
 
     private static final Logger               logger  = LoggerFactory.getLogger(RdbSyncService.class);
 
+    private DruidDataSource                   dataSource;
     // 源库表字段类型缓存: instance.schema.table -> <columnName, jdbcType>
     private Map<String, Map<String, Integer>> columnsTypeCache;
 
@@ -59,13 +59,14 @@ public class RdbSyncService {
         return columnsTypeCache;
     }
 
-    public RdbSyncService(DataSource dataSource, Integer threads, boolean skipDupException){
+    public RdbSyncService(DruidDataSource dataSource, Integer threads, boolean skipDupException){
         this(dataSource, threads, new ConcurrentHashMap<>(), skipDupException);
     }
 
     @SuppressWarnings("unchecked")
-    public RdbSyncService(DataSource dataSource, Integer threads, Map<String, Map<String, Integer>> columnsTypeCache,
+    public RdbSyncService(DruidDataSource dataSource, Integer threads, Map<String, Map<String, Integer>> columnsTypeCache,
                           boolean skipDupException){
+        this.dataSource = dataSource;
         this.columnsTypeCache = columnsTypeCache;
         this.skipDupException = skipDupException;
         try {
@@ -251,15 +252,15 @@ public class RdbSyncService {
         }
 
         DbMapping dbMapping = config.getDbMapping();
-
+        String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
         Map<String, String> columnsMap = SyncUtil.getColumnsMap(dbMapping, data);
 
         StringBuilder insertSql = new StringBuilder();
-        insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping)).append(" (");
+        insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType())).append(" (");
 
-        columnsMap.forEach((targetColumnName, srcColumnName) -> insertSql.append("`")
+        columnsMap.forEach((targetColumnName, srcColumnName) -> insertSql.append(backtick)
             .append(targetColumnName)
-            .append("`")
+            .append(backtick)
             .append(","));
         int len = insertSql.length();
         insertSql.delete(len - 1, len).append(") VALUES (");
@@ -323,13 +324,13 @@ public class RdbSyncService {
         }
 
         DbMapping dbMapping = config.getDbMapping();
-
+        String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
         Map<String, String> columnsMap = SyncUtil.getColumnsMap(dbMapping, data);
 
         Map<String, Integer> ctype = getTargetColumnType(batchExecutor.getConn(), config);
 
         StringBuilder updateSql = new StringBuilder();
-        updateSql.append("UPDATE ").append(SyncUtil.getDbTableName(dbMapping)).append(" SET ");
+        updateSql.append("UPDATE ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType())).append(" SET ");
         List<Map<String, ?>> values = new ArrayList<>();
         boolean hasMatched = false;
         for (String srcColumnName : old.keySet()) {
@@ -342,7 +343,7 @@ public class RdbSyncService {
             if (!targetColumnNames.isEmpty()) {
                 hasMatched = true;
                 for (String targetColumnName : targetColumnNames) {
-                    updateSql.append("`").append(targetColumnName).append("`").append("=?, ");
+                    updateSql.append(backtick).append(targetColumnName).append(backtick).append("=?, ");
                     Integer type = ctype.get(Util.cleanColumn(targetColumnName).toLowerCase());
                     if (type == null) {
                         throw new RuntimeException("Target column: " + targetColumnName + " not matched");
@@ -379,11 +380,10 @@ public class RdbSyncService {
         }
 
         DbMapping dbMapping = config.getDbMapping();
-
         Map<String, Integer> ctype = getTargetColumnType(batchExecutor.getConn(), config);
 
         StringBuilder sql = new StringBuilder();
-        sql.append("DELETE FROM ").append(SyncUtil.getDbTableName(dbMapping)).append(" WHERE ");
+        sql.append("DELETE FROM ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType())).append(" WHERE ");
 
         List<Map<String, ?>> values = new ArrayList<>();
         // 拼接主键
@@ -402,7 +402,7 @@ public class RdbSyncService {
     private void truncate(BatchExecutor batchExecutor, MappingConfig config) throws SQLException {
         DbMapping dbMapping = config.getDbMapping();
         StringBuilder sql = new StringBuilder();
-        sql.append("TRUNCATE TABLE ").append(SyncUtil.getDbTableName(dbMapping));
+        sql.append("TRUNCATE TABLE ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()));
         batchExecutor.execute(sql.toString(), new ArrayList<>());
         if (logger.isTraceEnabled()) {
             logger.trace("Truncate target table, sql: {}", sql);
@@ -426,7 +426,7 @@ public class RdbSyncService {
                 if (columnType == null) {
                     columnType = new LinkedHashMap<>();
                     final Map<String, Integer> columnTypeTmp = columnType;
-                    String sql = "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping) + " WHERE 1=2";
+                    String sql = "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()) + " WHERE 1=2";
                     Util.sqlRS(conn, sql, rs -> {
                         try {
                             ResultSetMetaData rsd = rs.getMetaData();
@@ -455,6 +455,8 @@ public class RdbSyncService {
 
     private void appendCondition(MappingConfig.DbMapping dbMapping, StringBuilder sql, Map<String, Integer> ctype,
                                  List<Map<String, ?>> values, Map<String, Object> d, Map<String, Object> o) {
+        String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
+
         // 拼接主键
         for (Map.Entry<String, String> entry : dbMapping.getTargetPk().entrySet()) {
             String targetColumnName = entry.getKey();
@@ -462,7 +464,7 @@ public class RdbSyncService {
             if (srcColumnName == null) {
                 srcColumnName = Util.cleanColumn(targetColumnName);
             }
-            sql.append("`").append(targetColumnName).append("`").append("=? AND ");
+            sql.append(backtick).append(targetColumnName).append(backtick).append("=? AND ");
             Integer type = ctype.get(Util.cleanColumn(targetColumnName).toLowerCase());
             if (type == null) {
                 throw new RuntimeException("Target column: " + targetColumnName + " not matched");

+ 28 - 3
client-adapter/rdb/src/main/java/com/alibaba/otter/canal/client/adapter/rdb/support/SyncUtil.java

@@ -1,5 +1,6 @@
 package com.alibaba.otter.canal.client.adapter.rdb.support;
 
+import com.alibaba.druid.DbType;
 import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
 import com.alibaba.otter.canal.client.adapter.support.Util;
 import org.apache.commons.lang.StringUtils;
@@ -255,12 +256,36 @@ public class SyncUtil {
         }
     }
 
-    public static String getDbTableName(MappingConfig.DbMapping dbMapping) {
+    public static String getDbTableName(MappingConfig.DbMapping dbMapping, String dbType) {
         String result = "";
+        String backtick = getBacktickByDbType(dbType);
         if (StringUtils.isNotEmpty(dbMapping.getTargetDb())) {
-            result += ("`" + dbMapping.getTargetDb() + "`.");
+            result += (backtick + dbMapping.getTargetDb() + backtick + ".");
         }
-        result += ("`" + dbMapping.getTargetTable() + "`");
+        result += (backtick + dbMapping.getTargetTable() + backtick);
         return result;
     }
+
+    /**
+     * 根据DbType返回反引号或空字符串
+     *
+     * @param dbTypeName DbType名称
+     * @return 反引号或空字符串
+     */
+    public static String getBacktickByDbType(String dbTypeName) {
+        DbType dbType = DbType.of(dbTypeName);
+        if (dbType == null) {
+            dbType = DbType.other;
+        }
+
+        // 只有当dbType为MySQL/MariaDB或OceanBase时返回反引号
+        switch (dbType) {
+            case mysql:
+            case mariadb:
+            case oceanbase:
+                return "`";
+            default:
+                return "";
+        }
+    }
 }