Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,21 @@ final class InitFlow {
* @param password the password of the {@code user}.
* @param compressionAlgorithms the list of compression algorithms.
* @param zstdCompressionLevel the zstd compression level.
* @param serverRSAPublicKeyFile the local file path of the MySQL server's public key
* @return a {@link Mono} that indicates the initialization is done, or an error if the initialization failed.
*/
static Mono<Void> initHandshake(Client client, SslMode sslMode, String database, String user,
@Nullable CharSequence password, Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel) {
@Nullable CharSequence password, Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable String serverRSAPublicKeyFile) {
return client.exchange(new HandshakeExchangeable(
client,
sslMode,
database,
user,
password,
compressionAlgorithms,
zstdCompressionLevel
zstdCompressionLevel,
serverRSAPublicKeyFile
)).then();
}

Expand Down Expand Up @@ -511,9 +514,12 @@ final class HandshakeExchangeable extends FluxExchangeable<Void> {

private boolean sslCompleted;

@Nullable
private String serverRSAPublicKeyFile;

HandshakeExchangeable(Client client, SslMode sslMode, String database, String user,
@Nullable CharSequence password, Set<CompressionAlgorithm> compressions,
int zstdCompressionLevel) {
int zstdCompressionLevel, @Nullable String serverRSAPublicKeyFile) {
this.client = client;
this.sslMode = sslMode;
this.database = database;
Expand All @@ -522,6 +528,7 @@ final class HandshakeExchangeable extends FluxExchangeable<Void> {
this.compressions = compressions;
this.zstdCompressionLevel = zstdCompressionLevel;
this.sslCompleted = sslMode == SslMode.TUNNEL;
this.serverRSAPublicKeyFile = serverRSAPublicKeyFile;
}

@Override
Expand Down Expand Up @@ -605,6 +612,11 @@ private AuthResponse createAuthResponse(String phase) {
MySqlAuthProvider authProvider = getAndNextProvider();

if (authProvider.isSslNecessary() && !sslCompleted) {
if (serverRSAPublicKeyFile != null && sslMode.equals(SslMode.DISABLED)) {
return new AuthResponse(MySqlAuthProvider.rsaEncryption(authProvider.authentication(
password, salt, client.getContext().getClientCollation()), serverRSAPublicKeyFile,
client.getContext().getServerVersion(), salt));
}
throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC);
}

Expand Down Expand Up @@ -709,12 +721,17 @@ private MySqlAuthProvider getAndNextProvider() {
private HandshakeResponse createHandshakeResponse(Capability capability) {
MySqlAuthProvider authProvider = getAndNextProvider();

if (authProvider.isSslNecessary() && !sslCompleted) {
throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"),
CLI_SPECIFIC);
if (authProvider.isSslNecessary() && !sslCompleted && serverRSAPublicKeyFile == null) {
throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"), CLI_SPECIFIC);
}

byte[] authorization = authProvider.authentication(password, salt, client.getContext().getClientCollation());

if (authProvider.isSslNecessary() && !sslCompleted && sslMode.equals(SslMode.DISABLED)) {
authorization = MySqlAuthProvider.rsaEncryption(authorization, serverRSAPublicKeyFile,
client.getContext().getServerVersion(), salt);
}

String authType = authProvider.getType();

if (MySqlAuthProvider.NO_AUTH_PROVIDER.equals(authType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ public final class MySqlConnectionConfiguration {

private final boolean tinyInt1isBit;

@Nullable
private final String serverRSAPublicKeyFile;

private MySqlConnectionConfiguration(
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
Expand All @@ -153,7 +156,7 @@ private MySqlConnectionConfiguration(
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver,
boolean metrics,
boolean tinyInt1isBit) {
boolean tinyInt1isBit, @Nullable String serverRSAPublicKeyFile) {
this.isHost = isHost;
this.domain = domain;
this.port = port;
Expand Down Expand Up @@ -185,6 +188,7 @@ private MySqlConnectionConfiguration(
this.resolver = resolver;
this.metrics = metrics;
this.tinyInt1isBit = tinyInt1isBit;
this.serverRSAPublicKeyFile = serverRSAPublicKeyFile;
}

/**
Expand Down Expand Up @@ -328,6 +332,11 @@ boolean isTinyInt1isBit() {
return tinyInt1isBit;
}

@Nullable
String getServerRSAPublicKeyFile() {
return serverRSAPublicKeyFile;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -367,7 +376,8 @@ public boolean equals(Object o) {
Objects.equals(passwordPublisher, that.passwordPublisher) &&
Objects.equals(resolver, that.resolver) &&
metrics == that.metrics &&
tinyInt1isBit == that.tinyInt1isBit;
tinyInt1isBit == that.tinyInt1isBit &&
Objects.equals(serverRSAPublicKeyFile, that.serverRSAPublicKeyFile);
}

@Override
Expand All @@ -382,7 +392,8 @@ public int hashCode() {
loadLocalInfilePath, localInfileBufferSize,
queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel,
loopResources, extensions, passwordPublisher, resolver, metrics, tinyInt1isBit);
loopResources, extensions, passwordPublisher, resolver, metrics, tinyInt1isBit,
serverRSAPublicKeyFile);
}

@Override
Expand Down Expand Up @@ -418,7 +429,8 @@ private String buildCommonToStringPart() {
", passwordPublisher=" + passwordPublisher +
", resolver=" + resolver +
", metrics=" + metrics +
", tinyInt1isBit=" + tinyInt1isBit;
", tinyInt1isBit=" + tinyInt1isBit +
", serverRSAPublicKeyFile=" + serverRSAPublicKeyFile;
}

/**
Expand Down Expand Up @@ -522,6 +534,9 @@ public static final class Builder {

private boolean tinyInt1isBit = true;

@Nullable
private String serverRSAPublicKeyFile;

/**
* Builds an immutable {@link MySqlConnectionConfiguration} with current options.
*
Expand Down Expand Up @@ -556,7 +571,8 @@ public MySqlConnectionConfiguration build() {
loadLocalInfilePath,
localInfileBufferSize, queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel, loopResources,
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, metrics, tinyInt1isBit);
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, metrics, tinyInt1isBit,
serverRSAPublicKeyFile);
}

/**
Expand Down Expand Up @@ -1234,6 +1250,21 @@ public Builder tinyInt1isBit(boolean tinyInt1isBit) {
return this;
}

/**
* Option to configure the database server's RSA Public Key file path on the local system if RSA encryption
* is desired such as when using caching_sha2_password authentication type while SSLMode is DISABLED. If
* serverRSAPublicKeyFile not null and SSLMode is not DISABLED, SSL encryption takes precedence.
*
* @param serverRSAPublicKeyFile the local file path of the database server's RSA Public Key file or
* {@code null} when RSA encryption not desired
* @return this {@link Builder}
* @since 1.4.2
*/
public Builder serverRSAPublicKeyFile(@Nullable String serverRSAPublicKeyFile) {
this.serverRSAPublicKeyFile = serverRSAPublicKeyFile;
return this;
}

private SslMode requireSslMode() {
SslMode sslMode = this.sslMode;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ private static Mono<MySqlConnection> getMySqlConnection(
user,
password,
configuration.getCompressionAlgorithms(),
configuration.getZstdCompressionLevel()
configuration.getZstdCompressionLevel(),
configuration.getServerRSAPublicKeyFile()
).then(InitFlow.initSession(
client,
sessionDb,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr
*/
public static final Option<Boolean> TINY_INT_1_IS_BIT = Option.valueOf("tinyInt1isBit");

/**
* Option to configure the database server's RSA Public Key file path on the local system if RSA encryption
* is desired such as when using caching_sha2_password authentication type while
* {@link io.asyncer.r2dbc.mysql.constant.SslMode} is DISABLED. If serverRSAPublicKeyFile not
* {@code null} and {@link io.asyncer.r2dbc.mysql.constant.SslMode} is not DISABLED, SSL encryption
* takes precedence.
*
* @since 1.4.2
*/
public static final Option<String> SERVER_RSA_PUBLIC_KEY_FILE = Option.valueOf("serverRSAPublicKeyFile");

@Override
public ConnectionFactory create(ConnectionFactoryOptions options) {
requireNonNull(options, "connectionFactoryOptions must not be null");
Expand Down Expand Up @@ -438,6 +449,8 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) {
.to(builder::metrics);
mapper.optional(TINY_INT_1_IS_BIT).asBoolean()
.to(builder::tinyInt1isBit);
mapper.optional(SERVER_RSA_PUBLIC_KEY_FILE).asString()
.to(builder::serverRSAPublicKeyFile);

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,17 @@ private static byte[] allBytesXor(byte[] left, byte[] right) {
return left;
}

static byte[] rotatingXor(byte[] password, byte[] seedBytes) {
int seedLength = seedBytes.length;
int passwordLength = password.length;
byte[] buffer = new byte[passwordLength];

for (int i = 0; i < passwordLength; i++) {
buffer[i] = (byte) (password[i] ^ seedBytes[i % seedLength]);
}

return buffer;
}

private AuthUtils() { }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,30 @@

package io.asyncer.r2dbc.mysql.authentication;

import io.asyncer.r2dbc.mysql.ServerVersion;
import io.asyncer.r2dbc.mysql.collation.CharCollation;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import org.jetbrains.annotations.Nullable;

import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;

/**
* An abstraction of the MySQL authorization plugin provider for connection phase. More information for MySQL
* authentication type:
Expand Down Expand Up @@ -124,4 +142,57 @@ static MySqlAuthProvider build(String type) {
* @return the next provider
*/
MySqlAuthProvider next();

/**
* Encrypts data with the RSA Public Key of MySQL server
* @param bytesToEncrypt the data to encrypt
* @param serverRSAPublicKeyFile the file path on the local system of the database server's RSA Public Key
* @param serverVersion the version of the MySQL server
* @param seed the seed bytes for rotating XOR obfuscation
* @return the encrypted bytes
*/
static byte[] rsaEncryption(byte[] bytesToEncrypt, String serverRsaPublicKeyFile, ServerVersion serverVersion,
byte[] seed) {
try {
bytesToEncrypt = AuthUtils.rotatingXor(bytesToEncrypt, seed);

String key = new String(Files.readAllBytes(Paths.get(serverRsaPublicKeyFile)), Charset.defaultCharset());

int startIndex = key.indexOf("-----BEGIN PUBLIC KEY-----") + 26;
int endIndex = key.indexOf("-----END PUBLIC KEY-----");
key = key.substring(startIndex, endIndex);
String publicKeyPEM = key.replaceAll(System.lineSeparator(), "");

byte[] encoded = Base64.getDecoder().decode(publicKeyPEM);

KeyFactory keyFactory = KeyFactory.getInstance("RSA");
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(encoded);
RSAPublicKey pk = (RSAPublicKey) keyFactory.generatePublic(keySpec);

Cipher cipher;
if (serverVersion.isGreaterThanOrEqualTo(ServerVersion.create(8, 0, 5))) {
cipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-1AndMGF1Padding");
} else {
cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding");
}
cipher.init(Cipher.ENCRYPT_MODE, pk);
return cipher.doFinal(bytesToEncrypt);
} catch (IOException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (InvalidKeySpecException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (NoSuchPaddingException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (InvalidKeyException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (IllegalBlockSizeException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (BadPaddingException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
} catch (IndexOutOfBoundsException e) {
throw new IllegalArgumentException(e.getLocalizedMessage(), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ private static MySqlConnectionConfiguration filledUp() {
.lockWaitTimeout(Duration.ofSeconds(5))
.statementTimeout(Duration.ofSeconds(10))
.autodetectExtensions(false)
.serverRSAPublicKeyFile("/path/to/mysql/serverRSAPublicKey.pem")
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.METRICS;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.PASSWORD_PUBLISHER;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.RESOLVER;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.SERVER_RSA_PUBLIC_KEY_FILE;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.USE_SERVER_PREPARE_STATEMENT;
import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT;
import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE;
Expand Down Expand Up @@ -530,6 +531,18 @@ void sessionVariables(String input, List<String> expected) {
assertThat(MySqlConnectionFactoryProvider.setup(options).getSessionVariables()).isEqualTo(expected);
}

@Test
void validServerRSAPublicKeyFile() {
ConnectionFactoryOptions options = ConnectionFactoryOptions.builder()
.option(DRIVER, "mysql")
.option(HOST, "127.0.0.1")
.option(USER, "root")
.option(SERVER_RSA_PUBLIC_KEY_FILE, "/path/to/mysql/serverRSAPublicKey.pem")
.build();

assertThat(MySqlConnectionFactoryProvider.setup(options).getServerRSAPublicKeyFile()).isEqualTo("/path/to/mysql/serverRSAPublicKey.pem");
}

static Stream<Arguments> sessionVariables() {
return Stream.of(
Arguments.of("", Collections.emptyList()),
Expand Down
Loading