diff options
Diffstat (limited to 'ext/http/lib.rs')
-rw-r--r-- | ext/http/lib.rs | 99 |
1 files changed, 69 insertions, 30 deletions
diff --git a/ext/http/lib.rs b/ext/http/lib.rs index 24dd77c92..e11d42da1 100644 --- a/ext/http/lib.rs +++ b/ext/http/lib.rs @@ -39,6 +39,7 @@ use hyper::service::Service; use hyper::Body; use hyper::Request; use hyper::Response; +use percent_encoding::percent_encode; use serde::Deserialize; use serde::Serialize; use std::borrow::Cow; @@ -49,7 +50,6 @@ use std::future::Future; use std::io; use std::mem::replace; use std::mem::take; -use std::net::SocketAddr; use std::pin::Pin; use std::rc::Rc; use std::sync::Arc; @@ -83,8 +83,27 @@ pub fn init() -> Extension { .build() } +pub enum HttpSocketAddr { + IpSocket(std::net::SocketAddr), + #[cfg(unix)] + UnixSocket(tokio::net::unix::SocketAddr), +} + +impl From<std::net::SocketAddr> for HttpSocketAddr { + fn from(addr: std::net::SocketAddr) -> Self { + Self::IpSocket(addr) + } +} + +#[cfg(unix)] +impl From<tokio::net::unix::SocketAddr> for HttpSocketAddr { + fn from(addr: tokio::net::unix::SocketAddr) -> Self { + Self::UnixSocket(addr) + } +} + struct HttpConnResource { - addr: SocketAddr, + addr: HttpSocketAddr, scheme: &'static str, acceptors_tx: mpsc::UnboundedSender<HttpAcceptor>, closed_fut: Shared<RemoteHandle<Result<(), Arc<hyper::Error>>>>, @@ -92,7 +111,7 @@ struct HttpConnResource { } impl HttpConnResource { - fn new<S>(io: S, scheme: &'static str, addr: SocketAddr) -> Self + fn new<S>(io: S, scheme: &'static str, addr: HttpSocketAddr) -> Self where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -172,8 +191,8 @@ impl HttpConnResource { self.scheme } - fn addr(&self) -> SocketAddr { - self.addr + fn addr(&self) -> &HttpSocketAddr { + &self.addr } } @@ -188,16 +207,17 @@ impl Resource for HttpConnResource { } /// Creates a new HttpConn resource which uses `io` as its transport. -pub fn http_create_conn_resource<S>( +pub fn http_create_conn_resource<S, A>( state: &mut OpState, io: S, - addr: SocketAddr, + addr: A, scheme: &'static str, ) -> Result<ResourceId, AnyError> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + A: Into<HttpSocketAddr>, { - let conn = HttpConnResource::new(io, scheme, addr); + let conn = HttpConnResource::new(io, scheme, addr.into()); let rid = state.resource_table.add(conn); Ok(rid) } @@ -375,30 +395,49 @@ async fn op_http_accept( fn req_url( req: &hyper::Request<hyper::Body>, scheme: &'static str, - addr: SocketAddr, + addr: &HttpSocketAddr, ) -> String { - let host: Cow<str> = if let Some(auth) = req.uri().authority() { - match addr.port() { - 443 if scheme == "https" => Cow::Borrowed(auth.host()), - 80 if scheme == "http" => Cow::Borrowed(auth.host()), - _ => Cow::Borrowed(auth.as_str()), // Includes port number. - } - } else if let Some(host) = req.uri().host() { - Cow::Borrowed(host) - } else if let Some(host) = req.headers().get("HOST") { - match host.to_str() { - Ok(host) => Cow::Borrowed(host), - Err(_) => Cow::Owned( - host - .as_bytes() - .iter() - .cloned() - .map(char::from) - .collect::<String>(), - ), + let host: Cow<str> = match addr { + HttpSocketAddr::IpSocket(addr) => { + if let Some(auth) = req.uri().authority() { + match addr.port() { + 443 if scheme == "https" => Cow::Borrowed(auth.host()), + 80 if scheme == "http" => Cow::Borrowed(auth.host()), + _ => Cow::Borrowed(auth.as_str()), // Includes port number. + } + } else if let Some(host) = req.uri().host() { + Cow::Borrowed(host) + } else if let Some(host) = req.headers().get("HOST") { + match host.to_str() { + Ok(host) => Cow::Borrowed(host), + Err(_) => Cow::Owned( + host + .as_bytes() + .iter() + .cloned() + .map(char::from) + .collect::<String>(), + ), + } + } else { + Cow::Owned(addr.to_string()) + } } - } else { - Cow::Owned(addr.to_string()) + // There is no standard way for unix domain socket URLs + // nginx and nodejs request use http://unix:[socket_path]:/ but it is not a valid URL + // httpie uses http+unix://[percent_encoding_of_path]/ which we follow + #[cfg(unix)] + HttpSocketAddr::UnixSocket(addr) => Cow::Owned( + percent_encode( + addr + .as_pathname() + .and_then(|x| x.to_str()) + .unwrap_or_default() + .as_bytes(), + percent_encoding::NON_ALPHANUMERIC, + ) + .to_string(), + ), }; let path = req.uri().path_and_query().map_or("/", |p| p.as_str()); [scheme, "://", &host, path].concat() |