diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java index a7c13c596..2f3acf22c 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java @@ -126,10 +126,12 @@ final class InitFlow { * @param password the password of the {@code user}. * @param compressionAlgorithms the list of compression algorithms. * @param zstdCompressionLevel the zstd compression level. + * @param serverRSAPublicKeyFile the local file path of the MySQL server's public key * @return a {@link Mono} that indicates the initialization is done, or an error if the initialization failed. */ static Mono initHandshake(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, Set compressionAlgorithms, int zstdCompressionLevel) { + @Nullable CharSequence password, Set compressionAlgorithms, int zstdCompressionLevel, + @Nullable String serverRSAPublicKeyFile) { return client.exchange(new HandshakeExchangeable( client, sslMode, @@ -137,7 +139,8 @@ static Mono initHandshake(Client client, SslMode sslMode, String database, user, password, compressionAlgorithms, - zstdCompressionLevel + zstdCompressionLevel, + serverRSAPublicKeyFile )).then(); } @@ -511,9 +514,12 @@ final class HandshakeExchangeable extends FluxExchangeable { private boolean sslCompleted; + @Nullable + private String serverRSAPublicKeyFile; + HandshakeExchangeable(Client client, SslMode sslMode, String database, String user, @Nullable CharSequence password, Set compressions, - int zstdCompressionLevel) { + int zstdCompressionLevel, @Nullable String serverRSAPublicKeyFile) { this.client = client; this.sslMode = sslMode; this.database = database; @@ -522,6 +528,7 @@ final class HandshakeExchangeable extends FluxExchangeable { this.compressions = compressions; this.zstdCompressionLevel = zstdCompressionLevel; this.sslCompleted = sslMode == SslMode.TUNNEL; + this.serverRSAPublicKeyFile = serverRSAPublicKeyFile; } @Override @@ -605,6 +612,11 @@ private AuthResponse createAuthResponse(String phase) { MySqlAuthProvider authProvider = getAndNextProvider(); if (authProvider.isSslNecessary() && !sslCompleted) { + if (serverRSAPublicKeyFile != null && sslMode.equals(SslMode.DISABLED)) { + return new AuthResponse(MySqlAuthProvider.rsaEncryption(authProvider.authentication( + password, salt, client.getContext().getClientCollation()), serverRSAPublicKeyFile, + client.getContext().getServerVersion(), salt)); + } throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC); } @@ -709,12 +721,17 @@ private MySqlAuthProvider getAndNextProvider() { private HandshakeResponse createHandshakeResponse(Capability capability) { MySqlAuthProvider authProvider = getAndNextProvider(); - if (authProvider.isSslNecessary() && !sslCompleted) { - throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"), - CLI_SPECIFIC); + if (authProvider.isSslNecessary() && !sslCompleted && serverRSAPublicKeyFile == null) { + throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"), CLI_SPECIFIC); } byte[] authorization = authProvider.authentication(password, salt, client.getContext().getClientCollation()); + + if (authProvider.isSslNecessary() && !sslCompleted && sslMode.equals(SslMode.DISABLED)) { + authorization = MySqlAuthProvider.rsaEncryption(authorization, serverRSAPublicKeyFile, + client.getContext().getServerVersion(), salt); + } + String authType = authProvider.getType(); if (MySqlAuthProvider.NO_AUTH_PROVIDER.equals(authType)) { diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java index 39fb91eb6..af17ae80c 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java @@ -136,6 +136,9 @@ public final class MySqlConnectionConfiguration { private final boolean tinyInt1isBit; + @Nullable + private final String serverRSAPublicKeyFile; + private MySqlConnectionConfiguration( boolean isHost, String domain, int port, MySqlSslConfiguration ssl, boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout, @@ -153,7 +156,7 @@ private MySqlConnectionConfiguration( Extensions extensions, @Nullable Publisher passwordPublisher, @Nullable AddressResolverGroup resolver, boolean metrics, - boolean tinyInt1isBit) { + boolean tinyInt1isBit, @Nullable String serverRSAPublicKeyFile) { this.isHost = isHost; this.domain = domain; this.port = port; @@ -185,6 +188,7 @@ private MySqlConnectionConfiguration( this.resolver = resolver; this.metrics = metrics; this.tinyInt1isBit = tinyInt1isBit; + this.serverRSAPublicKeyFile = serverRSAPublicKeyFile; } /** @@ -328,6 +332,11 @@ boolean isTinyInt1isBit() { return tinyInt1isBit; } + @Nullable + String getServerRSAPublicKeyFile() { + return serverRSAPublicKeyFile; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -367,7 +376,8 @@ public boolean equals(Object o) { Objects.equals(passwordPublisher, that.passwordPublisher) && Objects.equals(resolver, that.resolver) && metrics == that.metrics && - tinyInt1isBit == that.tinyInt1isBit; + tinyInt1isBit == that.tinyInt1isBit && + Objects.equals(serverRSAPublicKeyFile, that.serverRSAPublicKeyFile); } @Override @@ -382,7 +392,8 @@ public int hashCode() { loadLocalInfilePath, localInfileBufferSize, queryCacheSize, prepareCacheSize, compressionAlgorithms, zstdCompressionLevel, - loopResources, extensions, passwordPublisher, resolver, metrics, tinyInt1isBit); + loopResources, extensions, passwordPublisher, resolver, metrics, tinyInt1isBit, + serverRSAPublicKeyFile); } @Override @@ -418,7 +429,8 @@ private String buildCommonToStringPart() { ", passwordPublisher=" + passwordPublisher + ", resolver=" + resolver + ", metrics=" + metrics + - ", tinyInt1isBit=" + tinyInt1isBit; + ", tinyInt1isBit=" + tinyInt1isBit + + ", serverRSAPublicKeyFile=" + serverRSAPublicKeyFile; } /** @@ -522,6 +534,9 @@ public static final class Builder { private boolean tinyInt1isBit = true; + @Nullable + private String serverRSAPublicKeyFile; + /** * Builds an immutable {@link MySqlConnectionConfiguration} with current options. * @@ -556,7 +571,8 @@ public MySqlConnectionConfiguration build() { loadLocalInfilePath, localInfileBufferSize, queryCacheSize, prepareCacheSize, compressionAlgorithms, zstdCompressionLevel, loopResources, - Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, metrics, tinyInt1isBit); + Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, metrics, tinyInt1isBit, + serverRSAPublicKeyFile); } /** @@ -1234,6 +1250,21 @@ public Builder tinyInt1isBit(boolean tinyInt1isBit) { return this; } + /** + * Option to configure the database server's RSA Public Key file path on the local system if RSA encryption + * is desired such as when using caching_sha2_password authentication type while SSLMode is DISABLED. If + * serverRSAPublicKeyFile not null and SSLMode is not DISABLED, SSL encryption takes precedence. + * + * @param serverRSAPublicKeyFile the local file path of the database server's RSA Public Key file or + * {@code null} when RSA encryption not desired + * @return this {@link Builder} + * @since 1.4.2 + */ + public Builder serverRSAPublicKeyFile(@Nullable String serverRSAPublicKeyFile) { + this.serverRSAPublicKeyFile = serverRSAPublicKeyFile; + return this; + } + private SslMode requireSslMode() { SslMode sslMode = this.sslMode; diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java index 094674f2a..faae5f485 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java @@ -162,7 +162,8 @@ private static Mono getMySqlConnection( user, password, configuration.getCompressionAlgorithms(), - configuration.getZstdCompressionLevel() + configuration.getZstdCompressionLevel(), + configuration.getServerRSAPublicKeyFile() ).then(InitFlow.initSession( client, sessionDb, diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java index 5905c56ca..fe6d2a60d 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java @@ -341,6 +341,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr */ public static final Option TINY_INT_1_IS_BIT = Option.valueOf("tinyInt1isBit"); + /** + * Option to configure the database server's RSA Public Key file path on the local system if RSA encryption + * is desired such as when using caching_sha2_password authentication type while + * {@link io.asyncer.r2dbc.mysql.constant.SslMode} is DISABLED. If serverRSAPublicKeyFile not + * {@code null} and {@link io.asyncer.r2dbc.mysql.constant.SslMode} is not DISABLED, SSL encryption + * takes precedence. + * + * @since 1.4.2 + */ + public static final Option SERVER_RSA_PUBLIC_KEY_FILE = Option.valueOf("serverRSAPublicKeyFile"); + @Override public ConnectionFactory create(ConnectionFactoryOptions options) { requireNonNull(options, "connectionFactoryOptions must not be null"); @@ -438,6 +449,8 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) { .to(builder::metrics); mapper.optional(TINY_INT_1_IS_BIT).asBoolean() .to(builder::tinyInt1isBit); + mapper.optional(SERVER_RSA_PUBLIC_KEY_FILE).asString() + .to(builder::serverRSAPublicKeyFile); return builder.build(); } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java index 2a9cd84e8..0dc004952 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java @@ -113,5 +113,17 @@ private static byte[] allBytesXor(byte[] left, byte[] right) { return left; } + static byte[] rotatingXor(byte[] password, byte[] seedBytes) { + int seedLength = seedBytes.length; + int passwordLength = password.length; + byte[] buffer = new byte[passwordLength]; + + for (int i = 0; i < passwordLength; i++) { + buffer[i] = (byte) (password[i] ^ seedBytes[i % seedLength]); + } + + return buffer; + } + private AuthUtils() { } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlAuthProvider.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlAuthProvider.java index 2ae157271..63a033274 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlAuthProvider.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlAuthProvider.java @@ -16,12 +16,30 @@ package io.asyncer.r2dbc.mysql.authentication; +import io.asyncer.r2dbc.mysql.ServerVersion; import io.asyncer.r2dbc.mysql.collation.CharCollation; import io.r2dbc.spi.R2dbcPermissionDeniedException; import org.jetbrains.annotations.Nullable; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.InvalidKeyException; +import java.security.KeyFactory; +import java.security.NoSuchAlgorithmException; +import java.security.interfaces.RSAPublicKey; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.X509EncodedKeySpec; +import java.util.Base64; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; + /** * An abstraction of the MySQL authorization plugin provider for connection phase. More information for MySQL * authentication type: @@ -124,4 +142,57 @@ static MySqlAuthProvider build(String type) { * @return the next provider */ MySqlAuthProvider next(); + + /** + * Encrypts data with the RSA Public Key of MySQL server + * @param bytesToEncrypt the data to encrypt + * @param serverRSAPublicKeyFile the file path on the local system of the database server's RSA Public Key + * @param serverVersion the version of the MySQL server + * @param seed the seed bytes for rotating XOR obfuscation + * @return the encrypted bytes + */ + static byte[] rsaEncryption(byte[] bytesToEncrypt, String serverRsaPublicKeyFile, ServerVersion serverVersion, + byte[] seed) { + try { + bytesToEncrypt = AuthUtils.rotatingXor(bytesToEncrypt, seed); + + String key = new String(Files.readAllBytes(Paths.get(serverRsaPublicKeyFile)), Charset.defaultCharset()); + + int startIndex = key.indexOf("-----BEGIN PUBLIC KEY-----") + 26; + int endIndex = key.indexOf("-----END PUBLIC KEY-----"); + key = key.substring(startIndex, endIndex); + String publicKeyPEM = key.replaceAll(System.lineSeparator(), ""); + + byte[] encoded = Base64.getDecoder().decode(publicKeyPEM); + + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + X509EncodedKeySpec keySpec = new X509EncodedKeySpec(encoded); + RSAPublicKey pk = (RSAPublicKey) keyFactory.generatePublic(keySpec); + + Cipher cipher; + if (serverVersion.isGreaterThanOrEqualTo(ServerVersion.create(8, 0, 5))) { + cipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-1AndMGF1Padding"); + } else { + cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding"); + } + cipher.init(Cipher.ENCRYPT_MODE, pk); + return cipher.doFinal(bytesToEncrypt); + } catch (IOException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (InvalidKeySpecException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (NoSuchPaddingException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (InvalidKeyException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (IllegalBlockSizeException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (BadPaddingException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } catch (IndexOutOfBoundsException e) { + throw new IllegalArgumentException(e.getLocalizedMessage(), e); + } + } } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java index e62fea190..c653aeb18 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java @@ -282,6 +282,7 @@ private static MySqlConnectionConfiguration filledUp() { .lockWaitTimeout(Duration.ofSeconds(5)) .statementTimeout(Duration.ofSeconds(10)) .autodetectExtensions(false) + .serverRSAPublicKeyFile("/path/to/mysql/serverRSAPublicKey.pem") .build(); } } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java index be48a2255..beb33e59a 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java @@ -54,6 +54,7 @@ import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.METRICS; import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.PASSWORD_PUBLISHER; import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.RESOLVER; +import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.SERVER_RSA_PUBLIC_KEY_FILE; import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.USE_SERVER_PREPARE_STATEMENT; import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT; import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE; @@ -530,6 +531,18 @@ void sessionVariables(String input, List expected) { assertThat(MySqlConnectionFactoryProvider.setup(options).getSessionVariables()).isEqualTo(expected); } + @Test + void validServerRSAPublicKeyFile() { + ConnectionFactoryOptions options = ConnectionFactoryOptions.builder() + .option(DRIVER, "mysql") + .option(HOST, "127.0.0.1") + .option(USER, "root") + .option(SERVER_RSA_PUBLIC_KEY_FILE, "/path/to/mysql/serverRSAPublicKey.pem") + .build(); + + assertThat(MySqlConnectionFactoryProvider.setup(options).getServerRSAPublicKeyFile()).isEqualTo("/path/to/mysql/serverRSAPublicKey.pem"); + } + static Stream sessionVariables() { return Stream.of( Arguments.of("", Collections.emptyList()), diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/authentication/AuthUtilsTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/authentication/AuthUtilsTest.java new file mode 100644 index 000000000..195b3192b --- /dev/null +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/authentication/AuthUtilsTest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.authentication; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; + +/** + * Unit tests for {@link AuthUtils} + */ +public class AuthUtilsTest { + + static byte[] seedBytes; + + static { + try { + seedBytes = generateSalt(20); + } catch (NoSuchAlgorithmException e) { + seedBytes = "random".getBytes(); + } + } + + @Test + void rotatingXor() { + byte[] password = "abc123".getBytes(); + + assertDoesNotThrow(() -> AuthUtils.rotatingXor(password, seedBytes)); + } + + static byte[] generateSalt(int length) throws NoSuchAlgorithmException { + SecureRandom sr = SecureRandom.getInstance("SHA1PRNG"); + byte[] salt = new byte[length]; + sr.nextBytes(salt); + return salt; + } +}