diff --git a/sqlx-core/src/net/tls/mod.rs b/sqlx-core/src/net/tls/mod.rs index 7bb1744189..5957e70d40 100644 --- a/sqlx-core/src/net/tls/mod.rs +++ b/sqlx-core/src/net/tls/mod.rs @@ -60,15 +60,37 @@ impl std::fmt::Display for CertificateInput { pub struct TlsConfig<'a> { pub accept_invalid_certs: bool, pub accept_invalid_hostnames: bool, - pub hostname: &'a str, pub root_cert_path: Option<&'a CertificateInput>, pub client_cert_path: Option<&'a CertificateInput>, pub client_key_path: Option<&'a CertificateInput>, } +#[cfg(feature = "_tls-native-tls")] +pub use self::tls_native_tls::NativeTlsConnector as TlsConnector; +#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] +pub use self::tls_rustls::RustlsConnector as TlsConnector; +#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))] +#[derive(Debug, Clone)] +pub struct TlsConnector(std::convert::Infallible); + +pub async fn connector(config: TlsConfig<'_>) -> crate::Result { + #[cfg(feature = "_tls-native-tls")] + return tls_native_tls::connector(config).await; + + #[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] + return tls_rustls::connector(config).await; + + #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))] + { + _ = config; + panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled") + } +} + pub async fn handshake( socket: S, - config: TlsConfig<'_>, + hostname: &str, + connector: &TlsConnector, with_socket: Ws, ) -> crate::Result where @@ -77,18 +99,18 @@ where { #[cfg(feature = "_tls-native-tls")] return Ok(with_socket - .with_socket(tls_native_tls::handshake(socket, config).await?) + .with_socket(tls_native_tls::handshake(socket, hostname, connector).await?) .await); #[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] return Ok(with_socket - .with_socket(tls_rustls::handshake(socket, config).await?) + .with_socket(tls_rustls::handshake(socket, hostname, connector).await?) .await); #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))] { - drop((socket, config, with_socket)); - panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled") + drop((socket, hostname, with_socket)); + match connector.0 {} } } diff --git a/sqlx-core/src/net/tls/tls_native_tls.rs b/sqlx-core/src/net/tls/tls_native_tls.rs index 3423e48f8c..936cca093d 100644 --- a/sqlx-core/src/net/tls/tls_native_tls.rs +++ b/sqlx-core/src/net/tls/tls_native_tls.rs @@ -39,10 +39,12 @@ impl Socket for NativeTlsSocket { } } -pub async fn handshake( - socket: S, - config: TlsConfig<'_>, -) -> crate::Result> { +#[derive(Debug, Clone)] +pub struct NativeTlsConnector { + connector: native_tls::TlsConnector, +} + +pub async fn connector(config: TlsConfig<'_>) -> crate::Result { let mut builder = native_tls::TlsConnector::builder(); builder @@ -67,8 +69,18 @@ pub async fn handshake( let connector = rt::spawn_blocking(move || builder.build()) .await .map_err(Error::tls)?; + Ok(NativeTlsConnector { connector }) +} - let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) { +pub async fn handshake( + socket: S, + hostname: &str, + connector: &NativeTlsConnector, +) -> crate::Result> { + let mut mid_handshake = match connector + .connector + .connect(hostname, StdSocket::new(socket)) + { Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }), Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)), Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake, diff --git a/sqlx-core/src/net/tls/tls_rustls.rs b/sqlx-core/src/net/tls/tls_rustls.rs index 1ecbbad519..abc195a4f2 100644 --- a/sqlx-core/src/net/tls/tls_rustls.rs +++ b/sqlx-core/src/net/tls/tls_rustls.rs @@ -87,10 +87,12 @@ impl Socket for RustlsSocket { } } -pub async fn handshake(socket: S, tls_config: TlsConfig<'_>) -> Result, Error> -where - S: Socket, -{ +#[derive(Debug, Clone)] +pub struct RustlsConnector { + config: Arc, +} + +pub async fn connector(tls_config: TlsConfig<'_>) -> Result { #[cfg(all( feature = "_tls-rustls-aws-lc-rs", not(feature = "_tls-rustls-ring-webpki"), @@ -180,11 +182,24 @@ where } }; - let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?; + Ok(RustlsConnector { + config: Arc::new(config), + }) +} + +pub async fn handshake( + socket: S, + hostname: &str, + connector: &RustlsConnector, +) -> Result, Error> +where + S: Socket, +{ + let host = ServerName::try_from(hostname.to_owned()).map_err(Error::tls)?; let mut socket = RustlsSocket { inner: StdSocket::new(socket), - state: ClientConnection::new(Arc::new(config), host).map_err(Error::tls)?, + state: ClientConnection::new(connector.config.clone(), host).map_err(Error::tls)?, close_notify_sent: false, }; diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..9e670cfc42 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -42,7 +42,7 @@ impl<'a> DoHandshake<'a> { fn new(options: &'a MySqlConnectOptions) -> Result { if options.enable_cleartext_plugin && matches!( - options.ssl_mode, + options.ssl_options.ssl_mode, MySqlSslMode::Disabled | MySqlSslMode::Preferred ) { diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index 9034fbd63a..dddb4c6aa5 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -20,13 +20,13 @@ pub(super) async fn maybe_upgrade( ) -> Result { let server_supports_tls = stream.capabilities.contains(Capabilities::SSL); - if matches!(options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() { + if matches!(options.ssl_options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() { // remove the SSL capability if SSL has been explicitly disabled stream.capabilities.remove(Capabilities::SSL); } // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS - match options.ssl_mode { + match options.ssl_options.ssl_mode { MySqlSslMode::Disabled => return Ok(stream.boxed_socket()), MySqlSslMode::Preferred => { @@ -53,16 +53,27 @@ pub(super) async fn maybe_upgrade( } } - let tls_config = TlsConfig { - accept_invalid_certs: !matches!( - options.ssl_mode, - MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity - ), - accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity), - hostname: &options.host, - root_cert_path: options.ssl_ca.as_ref(), - client_cert_path: options.ssl_client_cert.as_ref(), - client_key_path: options.ssl_client_key.as_ref(), + let connector = if let Some(c) = options.ssl_options.cached_connector.get() { + c + } else { + let tls_config = TlsConfig { + accept_invalid_certs: !matches!( + options.ssl_options.ssl_mode, + MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity + ), + accept_invalid_hostnames: !matches!( + options.ssl_options.ssl_mode, + MySqlSslMode::VerifyIdentity + ), + root_cert_path: options.ssl_options.ssl_ca.as_ref(), + client_cert_path: options.ssl_options.ssl_client_cert.as_ref(), + client_key_path: options.ssl_options.ssl_client_key.as_ref(), + }; + let connector = tls::connector(tls_config).await?; + options + .ssl_options + .cached_connector + .get_or_init(|| connector) }; // Request TLS upgrade @@ -75,7 +86,8 @@ pub(super) async fn maybe_upgrade( tls::handshake( stream.socket.into_inner(), - tls_config, + &options.host, + connector, MapStream { server_version: stream.server_version, capabilities: stream.capabilities, diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 421bfb700e..db4bed0e88 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -1,10 +1,14 @@ -use std::path::{Path, PathBuf}; +use std::{ + path::{Path, PathBuf}, + sync::{Arc, OnceLock}, +}; mod connect; mod parse; mod ssl_mode; use crate::{connection::LogSettings, net::tls::CertificateInput}; +use sqlx_core::net::tls::TlsConnector; pub use ssl_mode::MySqlSslMode; /// Options and flags which can be used to configure a MySQL connection. @@ -67,10 +71,7 @@ pub struct MySqlConnectOptions { pub(crate) username: String, pub(crate) password: Option, pub(crate) database: Option, - pub(crate) ssl_mode: MySqlSslMode, - pub(crate) ssl_ca: Option, - pub(crate) ssl_client_cert: Option, - pub(crate) ssl_client_key: Option, + pub(crate) ssl_options: SslOptions, pub(crate) statement_cache_capacity: usize, pub(crate) charset: String, pub(crate) collation: Option, @@ -88,6 +89,15 @@ impl Default for MySqlConnectOptions { } } +#[derive(Debug, Clone)] +pub(crate) struct SslOptions { + pub(crate) ssl_mode: MySqlSslMode, + pub(crate) ssl_ca: Option, + pub(crate) ssl_client_cert: Option, + pub(crate) ssl_client_key: Option, + pub(crate) cached_connector: Arc>, +} + impl MySqlConnectOptions { /// Creates a new, default set of options ready for configuration pub fn new() -> Self { @@ -100,10 +110,13 @@ impl MySqlConnectOptions { database: None, charset: String::from("utf8mb4"), collation: None, - ssl_mode: MySqlSslMode::Preferred, - ssl_ca: None, - ssl_client_cert: None, - ssl_client_key: None, + ssl_options: SslOptions { + ssl_mode: MySqlSslMode::Preferred, + ssl_ca: None, + ssl_client_cert: None, + ssl_client_key: None, + cached_connector: Arc::new(OnceLock::new()), + }, statement_cache_capacity: 100, log_settings: Default::default(), pipes_as_concat: true, @@ -158,6 +171,11 @@ impl MySqlConnectOptions { self } + fn ssl_options_mut(&mut self) -> &mut SslOptions { + Arc::make_mut(&mut self.ssl_options.cached_connector).take(); + &mut self.ssl_options + } + /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated /// with the server. /// @@ -172,7 +190,7 @@ impl MySqlConnectOptions { /// .ssl_mode(MySqlSslMode::Required); /// ``` pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self { - self.ssl_mode = mode; + self.ssl_options_mut().ssl_mode = mode; self } @@ -187,7 +205,7 @@ impl MySqlConnectOptions { /// .ssl_ca("path/to/ca.crt"); /// ``` pub fn ssl_ca(mut self, file_name: impl AsRef) -> Self { - self.ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned())); + self.ssl_options_mut().ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned())); self } @@ -202,7 +220,7 @@ impl MySqlConnectOptions { /// .ssl_ca_from_pem(vec![]); /// ``` pub fn ssl_ca_from_pem(mut self, pem_certificate: Vec) -> Self { - self.ssl_ca = Some(CertificateInput::Inline(pem_certificate)); + self.ssl_options_mut().ssl_ca = Some(CertificateInput::Inline(pem_certificate)); self } @@ -217,7 +235,8 @@ impl MySqlConnectOptions { /// .ssl_client_cert("path/to/client.crt"); /// ``` pub fn ssl_client_cert(mut self, cert: impl AsRef) -> Self { - self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); + self.ssl_options_mut().ssl_client_cert = + Some(CertificateInput::File(cert.as_ref().to_path_buf())); self } @@ -242,7 +261,8 @@ impl MySqlConnectOptions { /// .ssl_client_cert_from_pem(CERT); /// ``` pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self { - self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec())); + self.ssl_options_mut().ssl_client_cert = + Some(CertificateInput::Inline(cert.as_ref().to_vec())); self } @@ -257,7 +277,8 @@ impl MySqlConnectOptions { /// .ssl_client_key("path/to/client.key"); /// ``` pub fn ssl_client_key(mut self, key: impl AsRef) -> Self { - self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf())); + self.ssl_options_mut().ssl_client_key = + Some(CertificateInput::File(key.as_ref().to_path_buf())); self } @@ -282,7 +303,8 @@ impl MySqlConnectOptions { /// .ssl_client_key_from_pem(KEY); /// ``` pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self { - self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec())); + self.ssl_options_mut().ssl_client_key = + Some(CertificateInput::Inline(key.as_ref().to_vec())); self } @@ -497,7 +519,7 @@ impl MySqlConnectOptions { /// assert!(matches!(options.get_ssl_mode(), MySqlSslMode::Preferred)); /// ``` pub fn get_ssl_mode(&self) -> MySqlSslMode { - self.ssl_mode + self.ssl_options.ssl_mode } /// Get the server charset. diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index e31ddc46d4..37db00ef6c 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -103,7 +103,7 @@ impl MySqlConnectOptions { url.set_path(database); } - let ssl_mode = match self.ssl_mode { + let ssl_mode = match self.ssl_options.ssl_mode { MySqlSslMode::Disabled => "DISABLED", MySqlSslMode::Preferred => "PREFERRED", MySqlSslMode::Required => "REQUIRED", @@ -112,7 +112,7 @@ impl MySqlConnectOptions { }; url.query_pairs_mut().append_pair("ssl-mode", ssl_mode); - if let Some(ssl_ca) = &self.ssl_ca { + if let Some(ssl_ca) = &self.ssl_options.ssl_ca { url.query_pairs_mut() .append_pair("ssl-ca", &ssl_ca.to_string()); } @@ -123,12 +123,12 @@ impl MySqlConnectOptions { url.query_pairs_mut().append_pair("charset", collation); } - if let Some(ssl_client_cert) = &self.ssl_client_cert { + if let Some(ssl_client_cert) = &self.ssl_options.ssl_client_cert { url.query_pairs_mut() .append_pair("ssl-cert", &ssl_client_cert.to_string()); } - if let Some(ssl_client_key) = &self.ssl_client_key { + if let Some(ssl_client_key) = &self.ssl_options.ssl_client_key { url.query_pairs_mut() .append_pair("ssl-key", &ssl_client_key.to_string()); } diff --git a/sqlx-postgres/src/connection/tls.rs b/sqlx-postgres/src/connection/tls.rs index a49c9caa8c..7b8bea3341 100644 --- a/sqlx-postgres/src/connection/tls.rs +++ b/sqlx-postgres/src/connection/tls.rs @@ -20,7 +20,7 @@ async fn maybe_upgrade( options: &PgConnectOptions, ) -> Result, Error> { // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS - match options.ssl_mode { + match options.ssl_options.ssl_mode { // FIXME: Implement ALLOW PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)), @@ -45,22 +45,31 @@ async fn maybe_upgrade( } } - let accept_invalid_certs = !matches!( - options.ssl_mode, - PgSslMode::VerifyCa | PgSslMode::VerifyFull - ); - let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull); - - let config = TlsConfig { - accept_invalid_certs, - accept_invalid_hostnames, - hostname: &options.host, - root_cert_path: options.ssl_root_cert.as_ref(), - client_cert_path: options.ssl_client_cert.as_ref(), - client_key_path: options.ssl_client_key.as_ref(), + let connector = if let Some(c) = options.ssl_options.cached_connector.get() { + c + } else { + let accept_invalid_certs = !matches!( + options.ssl_options.ssl_mode, + PgSslMode::VerifyCa | PgSslMode::VerifyFull + ); + let accept_invalid_hostnames = + !matches!(options.ssl_options.ssl_mode, PgSslMode::VerifyFull); + + let config = TlsConfig { + accept_invalid_certs, + accept_invalid_hostnames, + root_cert_path: options.ssl_options.ssl_root_cert.as_ref(), + client_cert_path: options.ssl_options.ssl_client_cert.as_ref(), + client_key_path: options.ssl_options.ssl_client_key.as_ref(), + }; + let connector = tls::connector(config).await?; + options + .ssl_options + .cached_connector + .get_or_init(|| connector) }; - tls::handshake(socket, config, SocketIntoBox).await + tls::handshake(socket, &options.host, connector, SocketIntoBox).await } async fn request_upgrade( diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index 21e6628cae..1432673720 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -2,7 +2,9 @@ use std::borrow::Cow; use std::env::var; use std::fmt::{self, Display, Write}; use std::path::{Path, PathBuf}; +use std::sync::{Arc, OnceLock}; +use sqlx_core::net::tls::TlsConnector; pub use ssl_mode::PgSslMode; use crate::{connection::LogSettings, net::tls::CertificateInput}; @@ -21,10 +23,7 @@ pub struct PgConnectOptions { pub(crate) username: String, pub(crate) password: Option, pub(crate) database: Option, - pub(crate) ssl_mode: PgSslMode, - pub(crate) ssl_root_cert: Option, - pub(crate) ssl_client_cert: Option, - pub(crate) ssl_client_key: Option, + pub(crate) ssl_options: SslOptions, pub(crate) statement_cache_capacity: usize, pub(crate) application_name: Option, pub(crate) log_settings: LogSettings, @@ -38,6 +37,15 @@ impl Default for PgConnectOptions { } } +#[derive(Debug, Clone)] +pub(crate) struct SslOptions { + pub(crate) ssl_mode: PgSslMode, + pub(crate) ssl_root_cert: Option, + pub(crate) ssl_client_cert: Option, + pub(crate) ssl_client_key: Option, + pub(crate) cached_connector: Arc>, +} + impl PgConnectOptions { /// Create a default set of connection options populated from the current environment. /// @@ -82,16 +90,19 @@ impl PgConnectOptions { username, password: var("PGPASSWORD").ok(), database, - ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from), - ssl_client_cert: var("PGSSLCERT").ok().map(CertificateInput::from), - // As of writing, the implementation of `From` only looks for - // `-----BEGIN CERTIFICATE-----` and so will not attempt to parse - // a PEM-encoded private key. - ssl_client_key: var("PGSSLKEY").ok().map(CertificateInput::from), - ssl_mode: var("PGSSLMODE") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or_default(), + ssl_options: SslOptions { + ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from), + ssl_client_cert: var("PGSSLCERT").ok().map(CertificateInput::from), + // As of writing, the implementation of `From` only looks for + // `-----BEGIN CERTIFICATE-----` and so will not attempt to parse + // a PEM-encoded private key. + ssl_client_key: var("PGSSLKEY").ok().map(CertificateInput::from), + ssl_mode: var("PGSSLMODE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or_default(), + cached_connector: Arc::new(OnceLock::new()), + }, statement_cache_capacity: 100, application_name: var("PGAPPNAME").ok(), extra_float_digits: Some("2".into()), @@ -205,6 +216,11 @@ impl PgConnectOptions { self } + fn ssl_options_mut(&mut self) -> &mut SslOptions { + Arc::make_mut(&mut self.ssl_options.cached_connector).take(); + &mut self.ssl_options + } + /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated /// with the server. /// @@ -221,7 +237,7 @@ impl PgConnectOptions { /// .ssl_mode(PgSslMode::Require); /// ``` pub fn ssl_mode(mut self, mode: PgSslMode) -> Self { - self.ssl_mode = mode; + self.ssl_options_mut().ssl_mode = mode; self } @@ -239,7 +255,8 @@ impl PgConnectOptions { /// .ssl_root_cert("./ca-certificate.crt"); /// ``` pub fn ssl_root_cert(mut self, cert: impl AsRef) -> Self { - self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); + self.ssl_options_mut().ssl_root_cert = + Some(CertificateInput::File(cert.as_ref().to_path_buf())); self } @@ -255,7 +272,8 @@ impl PgConnectOptions { /// .ssl_client_cert("./client.crt"); /// ``` pub fn ssl_client_cert(mut self, cert: impl AsRef) -> Self { - self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); + self.ssl_options_mut().ssl_client_cert = + Some(CertificateInput::File(cert.as_ref().to_path_buf())); self } @@ -274,14 +292,15 @@ impl PgConnectOptions { /// -----BEGIN CERTIFICATE----- /// /// -----END CERTIFICATE-----"; - /// + /// /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_client_cert_from_pem(CERT); /// ``` pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self { - self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec())); + self.ssl_options_mut().ssl_client_cert = + Some(CertificateInput::Inline(cert.as_ref().to_vec())); self } @@ -297,7 +316,8 @@ impl PgConnectOptions { /// .ssl_client_key("./client.key"); /// ``` pub fn ssl_client_key(mut self, key: impl AsRef) -> Self { - self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf())); + self.ssl_options_mut().ssl_client_key = + Some(CertificateInput::File(key.as_ref().to_path_buf())); self } @@ -323,7 +343,8 @@ impl PgConnectOptions { /// .ssl_client_key_from_pem(KEY); /// ``` pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self { - self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec())); + self.ssl_options_mut().ssl_client_key = + Some(CertificateInput::Inline(key.as_ref().to_vec())); self } @@ -339,7 +360,7 @@ impl PgConnectOptions { /// .ssl_root_cert_from_pem(vec![]); /// ``` pub fn ssl_root_cert_from_pem(mut self, pem_certificate: Vec) -> Self { - self.ssl_root_cert = Some(CertificateInput::Inline(pem_certificate)); + self.ssl_options_mut().ssl_root_cert = Some(CertificateInput::Inline(pem_certificate)); self } @@ -550,7 +571,7 @@ impl PgConnectOptions { /// assert!(matches!(options.get_ssl_mode(), PgSslMode::Prefer)); /// ``` pub fn get_ssl_mode(&self) -> PgSslMode { - self.ssl_mode + self.ssl_options.ssl_mode } /// Get the application name. diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs index e911305698..df8be6366d 100644 --- a/sqlx-postgres/src/options/parse.rs +++ b/sqlx-postgres/src/options/parse.rs @@ -136,7 +136,7 @@ impl PgConnectOptions { url.set_path(database); } - let ssl_mode = match self.ssl_mode { + let ssl_mode = match self.ssl_options.ssl_mode { PgSslMode::Allow => "allow", PgSslMode::Disable => "disable", PgSslMode::Prefer => "prefer", @@ -146,17 +146,17 @@ impl PgConnectOptions { }; url.query_pairs_mut().append_pair("sslmode", ssl_mode); - if let Some(ssl_root_cert) = &self.ssl_root_cert { + if let Some(ssl_root_cert) = &self.ssl_options.ssl_root_cert { url.query_pairs_mut() .append_pair("sslrootcert", &ssl_root_cert.to_string()); } - if let Some(ssl_client_cert) = &self.ssl_client_cert { + if let Some(ssl_client_cert) = &self.ssl_options.ssl_client_cert { url.query_pairs_mut() .append_pair("sslcert", &ssl_client_cert.to_string()); } - if let Some(ssl_client_key) = &self.ssl_client_key { + if let Some(ssl_client_key) = &self.ssl_options.ssl_client_key { url.query_pairs_mut() .append_pair("sslkey", &ssl_client_key.to_string()); }