浏览代码

fixed issue #1660 ,support mysql8.0 caching_sha2_password auth

agapple 6 年之前
父节点
当前提交
b95c0d03cf

+ 1 - 1
deployer/src/main/resources/example/instance.properties

@@ -48,6 +48,6 @@ canal.mq.topic=example
 #canal.mq.dynamicTopic=mytest1.user,mytest2\\..*,.*\\..*
 canal.mq.partition=0
 # hash partition config
-#canal.mq.partitionsNum=4
+#canal.mq.partitionsNum=3
 #canal.mq.partitionHash=test.table:id^name,.*\\..*
 #################################################

+ 32 - 11
driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/MysqlConnector.java

@@ -2,6 +2,7 @@ package com.alibaba.otter.canal.parse.driver.mysql;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.security.DigestException;
 import java.security.NoSuchAlgorithmException;
 import java.util.concurrent.atomic.AtomicBoolean;
 
@@ -218,28 +219,48 @@ public class MysqlConnector {
                 pluginName = packet.authName;
             }
 
+            boolean isSha2Password = false;
+            byte[] encryptedPassword = null;
             if (pluginName != null && "mysql_native_password".equals(pluginName)) {
-                byte[] encryptedPassword = null;
                 try {
                     encryptedPassword = MySQLPasswordEncrypter.scramble411(getPassword().getBytes(), authData);
                 } catch (NoSuchAlgorithmException e) {
                     throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                 }
-                AuthSwitchResponsePacket responsePacket = new AuthSwitchResponsePacket();
-                responsePacket.authData = encryptedPassword;
-                byte[] auth = responsePacket.toBytes();
-
-                h = new HeaderPacket();
-                h.setPacketBodyLength(auth.length);
-                h.setPacketSequenceNumber((byte) (header.getPacketSequenceNumber() + 1));
-                PacketManager.writePkg(channel, h.toBytes(), auth);
-                logger.info("auth switch response packet is sent out.");
+            } else if (pluginName != null && "caching_sha2_password".equals(pluginName)) {
+                isSha2Password = true;
+                try {
+                    encryptedPassword = MySQLPasswordEncrypter.scrambleCachingSha2(getPassword().getBytes(), authData);
+                } catch (DigestException e) {
+                    throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
+                }
+            }
+            assert encryptedPassword != null;
+            AuthSwitchResponsePacket responsePacket = new AuthSwitchResponsePacket();
+            responsePacket.authData = encryptedPassword;
+            byte[] auth = responsePacket.toBytes();
+
+            h = new HeaderPacket();
+            h.setPacketBodyLength(auth.length);
+            h.setPacketSequenceNumber((byte) (header.getPacketSequenceNumber() + 1));
+            PacketManager.writePkg(channel, h.toBytes(), auth);
+            logger.info("auth switch response packet is sent out.");
+
+            header = null;
+            header = PacketManager.readHeader(channel, 4);
+            body = null;
+            body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
+            assert body != null;
+            if (isSha2Password) {
+                if (body[0] == 0x01 && body[1] == 0x04) {
+                    // password auth failed
+                    throw new IOException("caching_sha2_password Auth failed");
+                }
 
                 header = null;
                 header = PacketManager.readHeader(channel, 4);
                 body = null;
                 body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
-                assert body != null;
             }
         }
 

+ 46 - 0
driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/utils/MySQLPasswordEncrypter.java

@@ -1,10 +1,47 @@
 package com.alibaba.otter.canal.parse.driver.mysql.utils;
 
+import java.security.DigestException;
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 
 public class MySQLPasswordEncrypter {
 
+    private static final int CACHING_SHA2_DIGEST_LENGTH = 32;
+
+    public static byte[] scrambleCachingSha2(byte[] password, byte[] seed) throws DigestException {
+        MessageDigest md;
+        try {
+            md = MessageDigest.getInstance("SHA-256");
+        } catch (NoSuchAlgorithmException ex) {
+            throw new DigestException(ex);
+        }
+
+        byte[] dig1 = new byte[CACHING_SHA2_DIGEST_LENGTH];
+        byte[] dig2 = new byte[CACHING_SHA2_DIGEST_LENGTH];
+        byte[] scramble1 = new byte[CACHING_SHA2_DIGEST_LENGTH];
+
+        // SHA2(src) => digest_stage1
+        md.update(password, 0, password.length);
+        md.digest(dig1, 0, CACHING_SHA2_DIGEST_LENGTH);
+        md.reset();
+
+        // SHA2(digest_stage1) => digest_stage2
+        md.update(dig1, 0, dig1.length);
+        md.digest(dig2, 0, CACHING_SHA2_DIGEST_LENGTH);
+        md.reset();
+
+        // SHA2(digest_stage2, m_rnd) => scramble_stage1
+        md.update(dig2, 0, dig1.length);
+        md.update(seed, 0, seed.length);
+        md.digest(scramble1, 0, CACHING_SHA2_DIGEST_LENGTH);
+
+        // XOR(digest_stage1, scramble_stage1) => scramble
+        byte[] mysqlScrambleBuff = new byte[CACHING_SHA2_DIGEST_LENGTH];
+        xorString(dig1, mysqlScrambleBuff, scramble1, CACHING_SHA2_DIGEST_LENGTH);
+
+        return mysqlScrambleBuff;
+    }
+
     public static final byte[] scramble411(byte[] pass, byte[] seed) throws NoSuchAlgorithmException {
         MessageDigest md = MessageDigest.getInstance("SHA-1");
         byte[] pass1 = md.digest(pass);
@@ -71,4 +108,13 @@ public class MySQLPasswordEncrypter {
         return result;
     }
 
+    private static void xorString(byte[] from, byte[] to, byte[] scramble, int length) {
+        int pos = 0;
+        int scrambleLength = scramble.length;
+        while (pos < length) {
+            to[pos] = (byte) (from[pos] ^ scramble[pos % scrambleLength]);
+            pos++;
+        }
+    }
+
 }