Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions sqlx-core/src/net/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,37 @@ impl std::fmt::Display for CertificateInput {
pub struct TlsConfig<'a> {
pub accept_invalid_certs: bool,
pub accept_invalid_hostnames: bool,
pub hostname: &'a str,
pub root_cert_path: Option<&'a CertificateInput>,
pub client_cert_path: Option<&'a CertificateInput>,
pub client_key_path: Option<&'a CertificateInput>,
}

#[cfg(feature = "_tls-native-tls")]
pub use self::tls_native_tls::NativeTlsConnector as TlsConnector;
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
pub use self::tls_rustls::RustlsConnector as TlsConnector;
#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
#[derive(Debug, Clone)]
pub struct TlsConnector(std::convert::Infallible);

pub async fn connector(config: TlsConfig<'_>) -> crate::Result<TlsConnector> {
#[cfg(feature = "_tls-native-tls")]
return tls_native_tls::connector(config).await;

#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
return tls_rustls::connector(config).await;

#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
{
_ = config;
panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
}
}

pub async fn handshake<S, Ws>(
socket: S,
config: TlsConfig<'_>,
hostname: &str,
connector: &TlsConnector,
with_socket: Ws,
) -> crate::Result<Ws::Output>
where
Expand All @@ -77,18 +99,18 @@ where
{
#[cfg(feature = "_tls-native-tls")]
return Ok(with_socket
.with_socket(tls_native_tls::handshake(socket, config).await?)
.with_socket(tls_native_tls::handshake(socket, hostname, connector).await?)
.await);

#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
return Ok(with_socket
.with_socket(tls_rustls::handshake(socket, config).await?)
.with_socket(tls_rustls::handshake(socket, hostname, connector).await?)
.await);

#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
{
drop((socket, config, with_socket));
panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
drop((socket, hostname, with_socket));
match connector.0 {}
}
}

Expand Down
22 changes: 17 additions & 5 deletions sqlx-core/src/net/tls/tls_native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ impl<S: Socket> Socket for NativeTlsSocket<S> {
}
}

pub async fn handshake<S: Socket>(
socket: S,
config: TlsConfig<'_>,
) -> crate::Result<NativeTlsSocket<S>> {
#[derive(Debug, Clone)]
pub struct NativeTlsConnector {
connector: native_tls::TlsConnector,
}

pub async fn connector(config: TlsConfig<'_>) -> crate::Result<NativeTlsConnector> {
let mut builder = native_tls::TlsConnector::builder();

builder
Expand All @@ -67,8 +69,18 @@ pub async fn handshake<S: Socket>(
let connector = rt::spawn_blocking(move || builder.build())
.await
.map_err(Error::tls)?;
Ok(NativeTlsConnector { connector })
}

let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
pub async fn handshake<S: Socket>(
socket: S,
hostname: &str,
connector: &NativeTlsConnector,
) -> crate::Result<NativeTlsSocket<S>> {
let mut mid_handshake = match connector
.connector
.connect(hostname, StdSocket::new(socket))
{
Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,
Expand Down
27 changes: 21 additions & 6 deletions sqlx-core/src/net/tls/tls_rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ impl<S: Socket> Socket for RustlsSocket<S> {
}
}

pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
where
S: Socket,
{
#[derive(Debug, Clone)]
pub struct RustlsConnector {
config: Arc<ClientConfig>,
}

pub async fn connector(tls_config: TlsConfig<'_>) -> Result<RustlsConnector, Error> {
#[cfg(all(
feature = "_tls-rustls-aws-lc-rs",
not(feature = "_tls-rustls-ring-webpki"),
Expand Down Expand Up @@ -180,11 +182,24 @@ where
}
};

let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?;
Ok(RustlsConnector {
config: Arc::new(config),
})
}

pub async fn handshake<S>(
socket: S,
hostname: &str,
connector: &RustlsConnector,
) -> Result<RustlsSocket<S>, Error>
where
S: Socket,
{
let host = ServerName::try_from(hostname.to_owned()).map_err(Error::tls)?;

let mut socket = RustlsSocket {
inner: StdSocket::new(socket),
state: ClientConnection::new(Arc::new(config), host).map_err(Error::tls)?,
state: ClientConnection::new(connector.config.clone(), host).map_err(Error::tls)?,
close_notify_sent: false,
};

Expand Down
2 changes: 1 addition & 1 deletion sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<'a> DoHandshake<'a> {
fn new(options: &'a MySqlConnectOptions) -> Result<Self, Error> {
if options.enable_cleartext_plugin
&& matches!(
options.ssl_mode,
options.ssl_options.ssl_mode,
MySqlSslMode::Disabled | MySqlSslMode::Preferred
)
{
Expand Down
38 changes: 25 additions & 13 deletions sqlx-mysql/src/connection/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ pub(super) async fn maybe_upgrade<S: Socket>(
) -> Result<MySqlStream, Error> {
let server_supports_tls = stream.capabilities.contains(Capabilities::SSL);

if matches!(options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() {
if matches!(options.ssl_options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() {
// remove the SSL capability if SSL has been explicitly disabled
stream.capabilities.remove(Capabilities::SSL);
}

// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
match options.ssl_mode {
match options.ssl_options.ssl_mode {
MySqlSslMode::Disabled => return Ok(stream.boxed_socket()),

MySqlSslMode::Preferred => {
Expand All @@ -53,16 +53,27 @@ pub(super) async fn maybe_upgrade<S: Socket>(
}
}

let tls_config = TlsConfig {
accept_invalid_certs: !matches!(
options.ssl_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
),
accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
hostname: &options.host,
root_cert_path: options.ssl_ca.as_ref(),
client_cert_path: options.ssl_client_cert.as_ref(),
client_key_path: options.ssl_client_key.as_ref(),
let connector = if let Some(c) = options.ssl_options.cached_connector.get() {
c
} else {
let tls_config = TlsConfig {
accept_invalid_certs: !matches!(
options.ssl_options.ssl_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
),
accept_invalid_hostnames: !matches!(
options.ssl_options.ssl_mode,
MySqlSslMode::VerifyIdentity
),
root_cert_path: options.ssl_options.ssl_ca.as_ref(),
client_cert_path: options.ssl_options.ssl_client_cert.as_ref(),
client_key_path: options.ssl_options.ssl_client_key.as_ref(),
};
let connector = tls::connector(tls_config).await?;
options
.ssl_options
.cached_connector
.get_or_init(|| connector)
};

// Request TLS upgrade
Expand All @@ -75,7 +86,8 @@ pub(super) async fn maybe_upgrade<S: Socket>(

tls::handshake(
stream.socket.into_inner(),
tls_config,
&options.host,
connector,
MapStream {
server_version: stream.server_version,
capabilities: stream.capabilities,
Expand Down
56 changes: 39 additions & 17 deletions sqlx-mysql/src/options/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::path::{Path, PathBuf};
use std::{
path::{Path, PathBuf},
sync::{Arc, OnceLock},
};

mod connect;
mod parse;
mod ssl_mode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use sqlx_core::net::tls::TlsConnector;
pub use ssl_mode::MySqlSslMode;

/// Options and flags which can be used to configure a MySQL connection.
Expand Down Expand Up @@ -67,10 +71,7 @@ pub struct MySqlConnectOptions {
pub(crate) username: String,
pub(crate) password: Option<String>,
pub(crate) database: Option<String>,
pub(crate) ssl_mode: MySqlSslMode,
pub(crate) ssl_ca: Option<CertificateInput>,
pub(crate) ssl_client_cert: Option<CertificateInput>,
pub(crate) ssl_client_key: Option<CertificateInput>,
pub(crate) ssl_options: SslOptions,
pub(crate) statement_cache_capacity: usize,
pub(crate) charset: String,
pub(crate) collation: Option<String>,
Expand All @@ -88,6 +89,15 @@ impl Default for MySqlConnectOptions {
}
}

#[derive(Debug, Clone)]
pub(crate) struct SslOptions {
pub(crate) ssl_mode: MySqlSslMode,
pub(crate) ssl_ca: Option<CertificateInput>,
pub(crate) ssl_client_cert: Option<CertificateInput>,
pub(crate) ssl_client_key: Option<CertificateInput>,
pub(crate) cached_connector: Arc<OnceLock<TlsConnector>>,
}

impl MySqlConnectOptions {
/// Creates a new, default set of options ready for configuration
pub fn new() -> Self {
Expand All @@ -100,10 +110,13 @@ impl MySqlConnectOptions {
database: None,
charset: String::from("utf8mb4"),
collation: None,
ssl_mode: MySqlSslMode::Preferred,
ssl_ca: None,
ssl_client_cert: None,
ssl_client_key: None,
ssl_options: SslOptions {
ssl_mode: MySqlSslMode::Preferred,
ssl_ca: None,
ssl_client_cert: None,
ssl_client_key: None,
cached_connector: Arc::new(OnceLock::new()),
},
statement_cache_capacity: 100,
log_settings: Default::default(),
pipes_as_concat: true,
Expand Down Expand Up @@ -158,6 +171,11 @@ impl MySqlConnectOptions {
self
}

fn ssl_options_mut(&mut self) -> &mut SslOptions {
Arc::make_mut(&mut self.ssl_options.cached_connector).take();
&mut self.ssl_options
}

/// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated
/// with the server.
///
Expand All @@ -172,7 +190,7 @@ impl MySqlConnectOptions {
/// .ssl_mode(MySqlSslMode::Required);
/// ```
pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self {
self.ssl_mode = mode;
self.ssl_options_mut().ssl_mode = mode;
self
}

Expand All @@ -187,7 +205,7 @@ impl MySqlConnectOptions {
/// .ssl_ca("path/to/ca.crt");
/// ```
pub fn ssl_ca(mut self, file_name: impl AsRef<Path>) -> Self {
self.ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned()));
self.ssl_options_mut().ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned()));
self
}

Expand All @@ -202,7 +220,7 @@ impl MySqlConnectOptions {
/// .ssl_ca_from_pem(vec![]);
/// ```
pub fn ssl_ca_from_pem(mut self, pem_certificate: Vec<u8>) -> Self {
self.ssl_ca = Some(CertificateInput::Inline(pem_certificate));
self.ssl_options_mut().ssl_ca = Some(CertificateInput::Inline(pem_certificate));
self
}

Expand All @@ -217,7 +235,8 @@ impl MySqlConnectOptions {
/// .ssl_client_cert("path/to/client.crt");
/// ```
pub fn ssl_client_cert(mut self, cert: impl AsRef<Path>) -> Self {
self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
self.ssl_options_mut().ssl_client_cert =
Some(CertificateInput::File(cert.as_ref().to_path_buf()));
self
}

Expand All @@ -242,7 +261,8 @@ impl MySqlConnectOptions {
/// .ssl_client_cert_from_pem(CERT);
/// ```
pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self {
self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec()));
self.ssl_options_mut().ssl_client_cert =
Some(CertificateInput::Inline(cert.as_ref().to_vec()));
self
}

Expand All @@ -257,7 +277,8 @@ impl MySqlConnectOptions {
/// .ssl_client_key("path/to/client.key");
/// ```
pub fn ssl_client_key(mut self, key: impl AsRef<Path>) -> Self {
self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf()));
self.ssl_options_mut().ssl_client_key =
Some(CertificateInput::File(key.as_ref().to_path_buf()));
self
}

Expand All @@ -282,7 +303,8 @@ impl MySqlConnectOptions {
/// .ssl_client_key_from_pem(KEY);
/// ```
pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self {
self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec()));
self.ssl_options_mut().ssl_client_key =
Some(CertificateInput::Inline(key.as_ref().to_vec()));
self
}

Expand Down Expand Up @@ -497,7 +519,7 @@ impl MySqlConnectOptions {
/// assert!(matches!(options.get_ssl_mode(), MySqlSslMode::Preferred));
/// ```
pub fn get_ssl_mode(&self) -> MySqlSslMode {
self.ssl_mode
self.ssl_options.ssl_mode
}

/// Get the server charset.
Expand Down
8 changes: 4 additions & 4 deletions sqlx-mysql/src/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl MySqlConnectOptions {
url.set_path(database);
}

let ssl_mode = match self.ssl_mode {
let ssl_mode = match self.ssl_options.ssl_mode {
MySqlSslMode::Disabled => "DISABLED",
MySqlSslMode::Preferred => "PREFERRED",
MySqlSslMode::Required => "REQUIRED",
Expand All @@ -112,7 +112,7 @@ impl MySqlConnectOptions {
};
url.query_pairs_mut().append_pair("ssl-mode", ssl_mode);

if let Some(ssl_ca) = &self.ssl_ca {
if let Some(ssl_ca) = &self.ssl_options.ssl_ca {
url.query_pairs_mut()
.append_pair("ssl-ca", &ssl_ca.to_string());
}
Expand All @@ -123,12 +123,12 @@ impl MySqlConnectOptions {
url.query_pairs_mut().append_pair("charset", collation);
}

if let Some(ssl_client_cert) = &self.ssl_client_cert {
if let Some(ssl_client_cert) = &self.ssl_options.ssl_client_cert {
url.query_pairs_mut()
.append_pair("ssl-cert", &ssl_client_cert.to_string());
}

if let Some(ssl_client_key) = &self.ssl_client_key {
if let Some(ssl_client_key) = &self.ssl_options.ssl_client_key {
url.query_pairs_mut()
.append_pair("ssl-key", &ssl_client_key.to_string());
}
Expand Down
Loading
Loading