Skip to content

Commit

Permalink
Add support for unix domain sockets
Browse files Browse the repository at this point in the history
Preferrably, we get a Connection impl for UnixStream
into hyper_util to avoid the new UnixStreamWrapper.

Closes #39
  • Loading branch information
flash-freezing-lava committed Jul 18, 2024
1 parent c660535 commit 8ebac73
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ wasm-bindgen-test = "0.3"
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] }

# to test unix domain sockets
[target.'cfg(unix)'.dev-dependencies]
tempfile = "3.3.0"

[[example]]
name = "blocking"
path = "examples/blocking.rs"
Expand Down
102 changes: 102 additions & 0 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use hyper_util::client::legacy::connect::{Connected, Connection};
use hyper_util::rt::TokioIo;
#[cfg(feature = "default-tls")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
#[cfg(all(feature = "__tls", unix))]
use tokio::net::UnixStream;
use tower_service::Service;

use pin_project_lite::pin_project;
Expand All @@ -18,6 +20,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
#[cfg(unix)]
use std::path::Path;

#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
Expand Down Expand Up @@ -203,6 +207,16 @@ impl Connector {
self.verbose.0 = enabled;
}

#[cfg(unix)]
async fn connect_unix_socket<P: AsRef<Path>>(&self, socket: P) -> Result<Conn, BoxError> {
let unix_stream = unix_socket_conn::connect(socket).await?;
Ok(Conn {
inner: self.verbose.wrap(unix_stream),
is_proxy: false, // defaults to false to have the same behavior as curl's --unix-socket
tls_info: false,
})
}

#[cfg(feature = "socks")]
async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
let dns = match proxy {
Expand All @@ -215,6 +229,10 @@ impl Connector {
ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
unreachable!("connect_socks is only called for socks proxies");
}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => {
unreachable!("connect_socks is only called for socks proxies");
}
};

match &self.inner {
Expand Down Expand Up @@ -368,6 +386,8 @@ impl Connector {
ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
#[cfg(unix)]
ProxyScheme::UnixSocket { socket } => return self.connect_unix_socket(socket).await,
};

#[cfg(feature = "__tls")]
Expand Down Expand Up @@ -611,6 +631,13 @@ impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpSt
}
}

#[cfg(all(feature = "__tls", unix))]
impl TlsInfoFactory for UnixStream {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
None
}
}

pub(crate) trait AsyncConn:
Read + Write + Connection + Send + Sync + Unpin + 'static
{
Expand Down Expand Up @@ -1089,6 +1116,81 @@ mod socks {
}
}

#[cfg(unix)]
mod unix_socket_conn {
use std::path::Path;
use hyper_util::client::legacy::connect::{Connected, Connection};
use hyper_util::rt::TokioIo;
use tokio::net::UnixStream;
use hyper::rt::{Read, Write};
use pin_project_lite::pin_project;
use crate::error::BoxError;

#[cfg(feature = "__tls")]
use super::TlsInfoFactory;

pub async fn connect<P: AsRef<Path>>(socket: P) -> Result<UnixStreamWrapper, BoxError> {
let target_stream = UnixStream::connect(&socket).await?;
Ok(UnixStreamWrapper { inner: TokioIo::new(target_stream) })
}

pin_project! {
/// This wrapper is necessary because Connection is not implemented for UnixStream in hyper_utils.
pub struct UnixStreamWrapper {
#[pin]
inner: TokioIo<UnixStream>,
}
}

impl Connection for UnixStreamWrapper {
fn connected(&self) -> Connected {
Connected::new()
}
}

impl Write for UnixStreamWrapper {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let proj = self.project();
proj.inner.poll_write(cx, buf)
}

fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.inner.poll_flush(cx)
}

fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.inner.poll_shutdown(cx)
}
}

impl Read for UnixStreamWrapper {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: hyper::rt::ReadBufCursor<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.inner.poll_read(cx, buf)
}
}

#[cfg(feature = "__tls")]
impl TlsInfoFactory for UnixStreamWrapper {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
None
}
}
}

mod verbose {
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};
Expand Down
2 changes: 1 addition & 1 deletion src/into_url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl IntoUrlSealed for Url {
return Ok(self);
}

if self.has_host() {
if self.scheme() == "unix" || self.has_host() {
Ok(self)
} else {
Err(crate::error::url_bad_scheme(self))
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@
//! export https_proxy=socks5://127.0.0.1:1086
//! ```
//!
//! You can aso configure a proxy to send requests through unix domain sockets (see [Proxy](Proxy) for details).
//!
//! ## TLS
//!
//! A `Client` will use transport layer security (TLS) by default to connect to
Expand Down
66 changes: 66 additions & 0 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use system_configuration::{
sys::schema_definitions::kSCPropNetProxiesHTTPSPort,
sys::schema_definitions::kSCPropNetProxiesHTTPSProxy,
};
#[cfg(unix)]
use std::path::PathBuf;
#[cfg(target_os = "windows")]
use winreg::enums::HKEY_CURRENT_USER;
#[cfg(target_os = "windows")]
Expand Down Expand Up @@ -66,6 +68,16 @@ use winreg::RegKey;
/// # Ok(())
/// # }
/// ```
///
/// On unix, it is also possible to send request to a unix socket via url or [Proxy::unix]:
/// ```rust
/// # fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let proxy = reqwest::Proxy::all("unix:///run/snapd.socket")?;
/// // equivalent to:
/// let proxy = reqwest::Proxy::unix("/run/snapd.socket");
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct Proxy {
intercept: Intercept,
Expand Down Expand Up @@ -115,6 +127,10 @@ pub enum ProxyScheme {
auth: Option<(String, String)>,
remote_dns: bool,
},
#[cfg(unix)]
UnixSocket {
socket: PathBuf,
},
}

impl ProxyScheme {
Expand All @@ -123,6 +139,8 @@ impl ProxyScheme {
ProxyScheme::Http { auth, .. } | ProxyScheme::Https { auth, .. } => auth.as_ref(),
#[cfg(feature = "socks")]
_ => None,
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => None,
}
}
}
Expand Down Expand Up @@ -250,6 +268,26 @@ impl Proxy {
)))
}

/// Proxy **all** traffic to the passed unix domain socket.
///
/// # Example
///
/// ```
/// # extern crate reqwest;
/// # fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let client = reqwest::Client::builder()
/// .proxy(reqwest::Proxy::unix("/run/snapd.socket"))
/// .build()?;
/// # Ok(())
/// # }
/// # fn main() {}
/// ```
pub fn unix<Path: Into<PathBuf>>(socket_path: Path) -> Proxy {
Proxy::new(Intercept::All(
ProxyScheme::unix_socket(socket_path),
))
}

/// Provide a custom function to determine what traffic to proxy to where.
///
/// # Example
Expand Down Expand Up @@ -611,6 +649,14 @@ impl ProxyScheme {
})
}

/// Proxy traffic via the specified URL over HTTPS
#[cfg(unix)]
fn unix_socket<Path: Into<PathBuf>>(path: Path) -> Self {
ProxyScheme::UnixSocket {
socket: path.into(),
}
}

/// Use a username and password when connecting to the proxy server
fn with_basic_auth<T: Into<String>, U: Into<String>>(
mut self,
Expand All @@ -635,6 +681,8 @@ impl ProxyScheme {
ProxyScheme::Socks5 { ref mut auth, .. } => {
*auth = Some((username.into(), password.into()));
}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => (),
}
}

Expand All @@ -650,6 +698,10 @@ impl ProxyScheme {
ProxyScheme::Socks5 { .. } => {
panic!("Socks is not supported for this method")
}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => {
panic!("Unix sockets are not supported for this method")
}
}
}

Expand All @@ -667,6 +719,8 @@ impl ProxyScheme {
}
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => {}
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => {}
}

self
Expand Down Expand Up @@ -701,6 +755,8 @@ impl ProxyScheme {
"socks5" => Self::socks5(to_addr()?)?,
#[cfg(feature = "socks")]
"socks5h" => Self::socks5h(to_addr()?)?,
#[cfg(unix)]
"unix" => Self::unix_socket(url.path()),
_ => return Err(crate::error::builder("unknown proxy scheme")),
};

Expand All @@ -720,6 +776,8 @@ impl ProxyScheme {
ProxyScheme::Https { .. } => "https",
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => "socks5",
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => "unix",
}
}

Expand All @@ -730,6 +788,8 @@ impl ProxyScheme {
ProxyScheme::Https { host, .. } => host.as_str(),
#[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => panic!("socks5"),
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => panic!("unix"),
}
}
}
Expand All @@ -748,6 +808,10 @@ impl fmt::Debug for ProxyScheme {
let h = if *remote_dns { "h" } else { "" };
write!(f, "socks5{h}://{addr}")
}
#[cfg(unix)]
ProxyScheme::UnixSocket { socket } => {
write!(f, "unix://{}", socket.display())
}
}
}
}
Expand Down Expand Up @@ -1127,6 +1191,8 @@ mod tests {
let (scheme, host) = match p.intercept(&url(s)).unwrap() {
ProxyScheme::Http { host, .. } => ("http", host),
ProxyScheme::Https { host, .. } => ("https", host),
#[cfg(unix)]
ProxyScheme::UnixSocket { .. } => panic!("intercepted as unix"),
#[cfg(feature = "socks")]
_ => panic!("intercepted as socks"),
};
Expand Down
Loading

0 comments on commit 8ebac73

Please sign in to comment.