diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index efbc43989b..e77f0798ed 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::env::var; +use std::ffi::OsStr; use std::fmt::{self, Display, Write}; use std::path::{Path, PathBuf}; @@ -20,6 +21,7 @@ pub struct PgConnectOptions { pub(crate) socket: Option, pub(crate) username: String, pub(crate) password: Option, + pub(crate) passfile_paths: Vec, pub(crate) database: Option, pub(crate) ssl_mode: PgSslMode, pub(crate) ssl_root_cert: Option, @@ -74,6 +76,7 @@ impl PgConnectOptions { socket: None, username, password: var("PGPASSWORD").ok(), + passfile_paths: vec![], database, ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from), ssl_client_cert: var("PGSSLCERT").ok().map(CertificateInput::from), @@ -100,6 +103,7 @@ impl PgConnectOptions { self.port, &self.username, self.database.as_deref(), + &self.passfile_paths, ); } @@ -184,6 +188,23 @@ impl PgConnectOptions { self } + /// Sets the paths to try when looking for the pgpass file. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new() + /// .passfile_paths(&["/non/default/pgpass"]); + /// ``` + pub fn passfile_paths

(mut self, paths: impl IntoIterator) -> Self + where + P: Into + AsRef, + { + self.passfile_paths = paths.into_iter().map(Into::into).collect(); + self + } + /// Sets the database name. Defaults to be the same as the user name. /// /// # Example diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs index e911305698..54e82420a0 100644 --- a/sqlx-postgres/src/options/parse.rs +++ b/sqlx-postgres/src/options/parse.rs @@ -3,6 +3,7 @@ use crate::{PgConnectOptions, PgSslMode}; use sqlx_core::percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; use std::net::IpAddr; +use std::path::PathBuf; use std::str::FromStr; impl PgConnectOptions { @@ -87,6 +88,8 @@ impl PgConnectOptions { "password" => options = options.password(&value), + "passfile" => options.passfile_paths.insert(0, PathBuf::from(&*value)), + "application_name" => options = options.application_name(&value), "options" => { @@ -242,6 +245,20 @@ fn it_parses_password_correctly_from_parameter() { assert_eq!(Some("some_pass"), opts.password.as_deref()); } +#[test] +fn it_parses_passfile_correctly_from_parameter() { + let url = "postgres:///?passfile=/non%20default/pgpass&passfile=.pgpass"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + assert_eq!( + vec![ + PathBuf::from(".pgpass"), + PathBuf::from("/non default/pgpass"), + ], + opts.passfile_paths + ); +} + #[test] fn it_parses_application_name_correctly_from_parameter() { let url = "postgres:///?application_name=some_name"; diff --git a/sqlx-postgres/src/options/pgpass.rs b/sqlx-postgres/src/options/pgpass.rs index bf16559548..6afc6a7228 100644 --- a/sqlx-postgres/src/options/pgpass.rs +++ b/sqlx-postgres/src/options/pgpass.rs @@ -2,46 +2,59 @@ use std::borrow::Cow; use std::env::var_os; use std::fs::File; use std::io::{BufRead, BufReader}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; -/// try to load a password from the various pgpass file locations +/// Try to load a password from the various pgpass file locations. +/// +/// Loading is attempted in the following order: +/// 1. Path given via the `PGPASSFILE` environment variable. +/// 2. Paths given via custom_paths. +/// 3. Default path (`~/.pgpass` on Linux and `%APPDATA%/postgres/pgpass.conf` +/// on Windows) pub fn load_password( host: &str, port: u16, username: &str, database: Option<&str>, + custom_paths: &[impl AsRef], ) -> Option { - let custom_file = var_os("PGPASSFILE"); - if let Some(file) = custom_file { - if let Some(password) = - load_password_from_file(PathBuf::from(file), host, port, username, database) - { - return Some(password); - } - } + let env_path = var_os("PGPASSFILE").map(PathBuf::from); + let default_path = default_path(); + + let path_iter = env_path + .as_deref() + .into_iter() + .chain(custom_paths.iter().map(AsRef::as_ref)) + .chain(default_path.as_deref()); + + path_iter + .filter_map(|path| load_password_from_file(path, host, port, username, database)) + .next() +} + +#[cfg(not(target_os = "windows"))] +fn default_path() -> Option { + home::home_dir().map(|path| path.join(".pgpass")) +} + +#[cfg(target_os = "windows")] +fn default_path() -> Option { + use etcetera::BaseStrategy; - #[cfg(not(target_os = "windows"))] - let default_file = home::home_dir().map(|path| path.join(".pgpass")); - #[cfg(target_os = "windows")] - let default_file = { - use etcetera::BaseStrategy; - - etcetera::base_strategy::Windows::new() - .ok() - .map(|basedirs| basedirs.data_dir().join("postgres").join("pgpass.conf")) - }; - load_password_from_file(default_file?, host, port, username, database) + etcetera::base_strategy::Windows::new() + .ok() + .map(|basedirs| basedirs.data_dir().join("postgres").join("pgpass.conf")) } /// try to extract a password from a pgpass file fn load_password_from_file( - path: PathBuf, + path: &Path, host: &str, port: u16, username: &str, database: Option<&str>, ) -> Option { - let file = File::open(&path) + let file = File::open(path) .map_err(|e| { match e.kind() { std::io::ErrorKind::NotFound => {