Browse Source

增加对mysql的caching_sha2_password认证插件fullauth流程支持 (#4767)

高原 2 years ago
parent
commit
b8c2a757a0

+ 75 - 29
driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/MysqlConnector.java

@@ -26,7 +26,7 @@ import com.alibaba.otter.canal.parse.driver.mysql.utils.PacketManager;
 
 
 /**
 /**
  * 基于mysql socket协议的链接实现
  * 基于mysql socket协议的链接实现
- * 
+ *
  * @author jianghang 2013-2-18 下午09:22:30
  * @author jianghang 2013-2-18 下午09:22:30
  * @version 1.0.1
  * @version 1.0.1
  */
  */
@@ -220,55 +220,45 @@ public class MysqlConnector {
                 packet.fromBytes(body);
                 packet.fromBytes(body);
                 authData = packet.authData;
                 authData = packet.authData;
                 pluginName = packet.authName;
                 pluginName = packet.authName;
+                logger.info("auth switch pluginName is {}.", pluginName);
             }
             }
 
 
-            boolean isSha2Password = false;
             byte[] encryptedPassword = null;
             byte[] encryptedPassword = null;
             if ("mysql_clear_password".equals(pluginName)) {
             if ("mysql_clear_password".equals(pluginName)) {
                 encryptedPassword = getPassword().getBytes();
                 encryptedPassword = getPassword().getBytes();
+                header = authSwitchAfterAuth(encryptedPassword, header);
+                body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
             } else if ("mysql_native_password".equals(pluginName)) {
             } else if ("mysql_native_password".equals(pluginName)) {
                 try {
                 try {
                     encryptedPassword = MySQLPasswordEncrypter.scramble411(getPassword().getBytes(), authData);
                     encryptedPassword = MySQLPasswordEncrypter.scramble411(getPassword().getBytes(), authData);
                 } catch (NoSuchAlgorithmException e) {
                 } catch (NoSuchAlgorithmException e) {
                     throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                     throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                 }
                 }
+                header = authSwitchAfterAuth(encryptedPassword, header);
+                body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
             } else if ("caching_sha2_password".equals(pluginName)) {
             } else if ("caching_sha2_password".equals(pluginName)) {
-                isSha2Password = true;
+                byte[] scramble = authData;
                 try {
                 try {
-                    encryptedPassword = MySQLPasswordEncrypter.scrambleCachingSha2(getPassword().getBytes(), authData);
+                    encryptedPassword = MySQLPasswordEncrypter.scrambleCachingSha2(getPassword().getBytes(), scramble);
                 } catch (DigestException e) {
                 } catch (DigestException e) {
                     throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                     throw new RuntimeException("can't encrypt password that will be sent to MySQL server.", e);
                 }
                 }
-            }
-            assert encryptedPassword != null;
-            AuthSwitchResponsePacket responsePacket = new AuthSwitchResponsePacket();
-            responsePacket.authData = encryptedPassword;
-            byte[] auth = responsePacket.toBytes();
-
-            h = new HeaderPacket();
-            h.setPacketBodyLength(auth.length);
-            h.setPacketSequenceNumber((byte) (header.getPacketSequenceNumber() + 1));
-            PacketManager.writePkg(channel, h.toBytes(), auth);
-            logger.info("auth switch response packet is sent out.");
-
-            header = null;
-            header = PacketManager.readHeader(channel, 4);
-            body = null;
-            body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
-            assert body != null;
-            if (isSha2Password) {
+                header = authSwitchAfterAuth(encryptedPassword, header);
+                body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
+                assert body != null;
                 if (body[0] == 0x01 && body[1] == 0x04) {
                 if (body[0] == 0x01 && body[1] == 0x04) {
-                    // password auth failed
-                    throw new IOException("caching_sha2_password Auth failed");
+                    header = cachingSha2PasswordFullAuth(channel, header, getPassword().getBytes(), scramble);
+                    body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
+                } else {
+                    header = PacketManager.readHeader(channel, 4);
+                    body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
                 }
                 }
-
-                header = null;
-                header = PacketManager.readHeader(channel, 4);
-                body = null;
+            } else {
+                header = authSwitchAfterAuth(encryptedPassword, header);
                 body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
                 body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
             }
             }
         }
         }
-
+        assert body != null;
         if (body[0] < 0) {
         if (body[0] < 0) {
             if (body[0] == -1) {
             if (body[0] == -1) {
                 ErrorPacket err = new ErrorPacket();
                 ErrorPacket err = new ErrorPacket();
@@ -280,6 +270,62 @@ public class MysqlConnector {
         }
         }
     }
     }
 
 
+    private HeaderPacket cachingSha2PasswordFullAuth(SocketChannel channel, HeaderPacket header, byte[] pass,
+                                                     byte[] seed) throws IOException {
+        AuthSwitchResponsePacket responsePacket = new AuthSwitchResponsePacket();
+        responsePacket.authData = new byte[] { 2 };
+        byte[] auth = responsePacket.toBytes();
+        HeaderPacket h = new HeaderPacket();
+        h.setPacketBodyLength(auth.length);
+        h.setPacketSequenceNumber((byte) (header.getPacketSequenceNumber() + 1));
+        PacketManager.writePkg(channel, h.toBytes(), auth);
+        logger.info("caching sha2 password fullAuth request public key packet is sent out.");
+
+        header = PacketManager.readHeader(channel, 4);
+        byte[] body = PacketManager.readBytes(channel, header.getPacketBodyLength(), timeout);
+        AuthSwitchRequestMoreData packet = new AuthSwitchRequestMoreData();
+        packet.fromBytes(body);
+        if (packet.status != 0x01) {
+            throw new IOException("caching_sha2_password get public key failed");
+        }
+        logger.info("caching sha2 password fullAuth get server public key succeed.");
+        byte[] publicKeyBytes = packet.authData;
+
+        byte[] encryptedPassword = null;
+        try {
+            encryptedPassword = MySQLPasswordEncrypter.scrambleRsa(publicKeyBytes, pass, seed);
+        } catch (Exception e) {
+            logger.error("rsa encrypt failed {}", publicKeyBytes);
+            throw new IOException("caching_sha2_password auth failed", e);
+        }
+
+        // send auth
+        responsePacket = new AuthSwitchResponsePacket();
+        responsePacket.authData = encryptedPassword;
+        auth = responsePacket.toBytes();
+        h = new HeaderPacket();
+        h.setPacketBodyLength(auth.length);
+        h.setPacketSequenceNumber((byte) (header.getPacketSequenceNumber() + 1));
+        PacketManager.writePkg(channel, h.toBytes(), auth);
+        logger.info("caching sha2 password fullAuth response auth data packet is sent out.");
+        return PacketManager.readHeader(channel, 4);
+    }
+
+    private HeaderPacket authSwitchAfterAuth(byte[] encryptedPassword, HeaderPacket header) throws IOException {
+        assert encryptedPassword != null;
+        AuthSwitchResponsePacket responsePacket = new AuthSwitchResponsePacket();
+        responsePacket.authData = encryptedPassword;
+        byte[] auth = responsePacket.toBytes();
+
+        HeaderPacket h = new HeaderPacket();
+        h.setPacketBodyLength(auth.length);
+        h.setPacketSequenceNumber((byte) (header.getPacketSequenceNumber() + 1));
+        PacketManager.writePkg(channel, h.toBytes(), auth);
+        logger.info("auth switch response packet is sent out.");
+        header = PacketManager.readHeader(channel, 4);
+        return header;
+    }
+
     private void auth323(SocketChannel channel, byte packetSequenceNumber, byte[] seed) throws IOException {
     private void auth323(SocketChannel channel, byte packetSequenceNumber, byte[] seed) throws IOException {
         // auth 323
         // auth 323
         Reply323Packet r323 = new Reply323Packet();
         Reply323Packet r323 = new Reply323Packet();

+ 29 - 0
driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/utils/MySQLPasswordEncrypter.java

@@ -1,8 +1,17 @@
 package com.alibaba.otter.canal.parse.driver.mysql.utils;
 package com.alibaba.otter.canal.parse.driver.mysql.utils;
 
 
 import java.security.DigestException;
 import java.security.DigestException;
+import java.security.InvalidKeyException;
+import java.security.KeyFactory;
 import java.security.MessageDigest;
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 import java.security.NoSuchAlgorithmException;
+import java.security.PublicKey;
+import java.security.spec.InvalidKeySpecException;
+import java.security.spec.X509EncodedKeySpec;
+import javax.crypto.BadPaddingException;
+import javax.crypto.Cipher;
+import javax.crypto.IllegalBlockSizeException;
+import javax.crypto.NoSuchPaddingException;
 
 
 public class MySQLPasswordEncrypter {
 public class MySQLPasswordEncrypter {
 
 
@@ -85,6 +94,26 @@ public class MySQLPasswordEncrypter {
         return new String(chars);
         return new String(chars);
     }
     }
 
 
+    public static final byte[] scrambleRsa(byte[] publicKeyBytes, byte[] pass,
+                                           byte[] seed) throws NoSuchAlgorithmException, InvalidKeySpecException,
+                                                        NoSuchPaddingException, InvalidKeyException,
+                                                        IllegalBlockSizeException, BadPaddingException {
+        byte[] input = new byte[pass.length + 1];
+        System.arraycopy(pass, 0, input, 0, pass.length);
+        byte[] encryptedPassword = new byte[input.length];
+        xorString(input, encryptedPassword, seed, input.length);
+        String publicKeyPem = new String(publicKeyBytes).replace("\n", "")
+            .replace("-----BEGIN PUBLIC KEY-----", "")
+            .replace("-----END PUBLIC KEY-----", "");
+        byte[] certificateData = java.util.Base64.getDecoder().decode(publicKeyPem.getBytes());
+        X509EncodedKeySpec keySpec = new X509EncodedKeySpec(certificateData);
+        KeyFactory keyFactory = KeyFactory.getInstance("RSA");
+        PublicKey publicKey = keyFactory.generatePublic(keySpec);
+        Cipher cipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-1AndMGF1Padding");
+        cipher.init(Cipher.ENCRYPT_MODE, publicKey);
+        return cipher.doFinal(encryptedPassword);
+    }
+
     private static long[] hash(String src) {
     private static long[] hash(String src) {
         long nr = 1345345333L;
         long nr = 1345345333L;
         long add = 7;
         long add = 7;