Skip to content

Commit

Permalink
refactor(serve): Refactor the validation module (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x676e67 committed May 3, 2024
1 parent 6ce2f30 commit 20e82e1
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 94 deletions.
22 changes: 3 additions & 19 deletions src/proxy/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,6 @@ use crate::BootArgs;
use std::net::{IpAddr, SocketAddr};
use tokio::sync::OnceCell;

/// Auth Error
#[derive(thiserror::Error, Debug)]
pub enum AuthError {
#[error("Missing credentials")]
MissingCredentials,
#[error("Invalid credentials")]
InvalidCredentials,
#[error("Unauthorized")]
Unauthorized,
}

/// Ip address whitelist
static IP_WHITELIST: OnceCell<Option<Vec<IpAddr>>> = OnceCell::const_new();

Expand All @@ -28,14 +17,9 @@ pub fn init_ip_whitelist(args: &BootArgs) {
}

/// Valid Ip address whitelist
pub fn authenticate_ip(socket: SocketAddr) -> Result<(), AuthError> {
pub fn authenticate_ip(socket: SocketAddr) -> bool {
match IP_WHITELIST.get() {
Some(Some(ip)) => {
if ip.contains(&socket.ip()) {
return Ok(());
}
Err(AuthError::Unauthorized)
}
Some(None) | None => Ok(()),
Some(Some(ip)) => ip.contains(&socket.ip()),
Some(None) | None => true,
}
}
15 changes: 13 additions & 2 deletions src/proxy/http/auth.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
use crate::proxy::auth::{self, AuthError};
use crate::proxy::auth;
use base64::Engine;
use http::{header, HeaderMap};
use std::net::SocketAddr;

/// Auth Error
#[derive(thiserror::Error, Debug)]
pub enum AuthError {
#[error("Missing credentials")]
MissingCredentials,
#[error("Invalid credentials")]
InvalidCredentials,
#[error("Unauthorized")]
Unauthorized,
}

#[derive(Clone)]
pub enum Authenticator {
None,
Expand All @@ -12,7 +23,7 @@ pub enum Authenticator {
impl Authenticator {
pub fn authenticate(&self, headers: &HeaderMap, socket: SocketAddr) -> Result<(), AuthError> {
// If no authentication is required, return immediately
if auth::authenticate_ip(socket).is_ok() {
if auth::authenticate_ip(socket) {
return Ok(());
}
match self {
Expand Down
2 changes: 1 addition & 1 deletion src/proxy/http/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::proxy::auth::AuthError;
use super::auth::AuthError;

/// Proxy Error
#[derive(thiserror::Error, Debug)]
Expand Down
2 changes: 1 addition & 1 deletion src/proxy/socks5/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub enum Error {
InvalidFragmentId(u8),

#[error("Invalid authentication method: {0:?}")]
InvalidAuthMethod(crate::proxy::socks5::proto::AuthMethod),
InvalidAuthMethod(crate::proxy::socks5::proto::Method),

#[error("SOCKS version is 4 when 5 is expected")]
WrongVersion,
Expand Down
6 changes: 3 additions & 3 deletions src/proxy/socks5/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,17 @@ where
{
let server = Server::bind_with_concurrency(listen_addr, auth, concurrent).await?;

while let Ok((conn, socket)) = server.accept().await {
while let Ok((conn, _)) = server.accept().await {
tokio::spawn(async move {
if let Err(err) = handle(conn, socket).await {
if let Err(err) = handle(conn).await {
tracing::error!("{err}");
}
});
}
Ok(())
}

async fn handle<S>(conn: IncomingConnection<S>, _: SocketAddr) -> Result<()>
async fn handle<S>(conn: IncomingConnection<S>) -> Result<()>
where
S: Send + Sync + 'static,
{
Expand Down
52 changes: 26 additions & 26 deletions src/proxy/socks5/proto/handshake/method.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/// A proxy authentication method.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum AuthMethod {
pub enum Method {
/// No authentication required.
NoAuth = 0x00,
/// GSS API.
Expand All @@ -16,47 +16,47 @@ pub enum AuthMethod {
NoAcceptableMethods = 0xff,
}

impl From<u8> for AuthMethod {
impl From<u8> for Method {
fn from(value: u8) -> Self {
match value {
0x00 => AuthMethod::NoAuth,
0x01 => AuthMethod::GssApi,
0x02 => AuthMethod::Password,
0x03..=0x7f => AuthMethod::IanaReserved(value),
0x80..=0xfe => AuthMethod::Private(value),
0xff => AuthMethod::NoAcceptableMethods,
0x00 => Method::NoAuth,
0x01 => Method::GssApi,
0x02 => Method::Password,
0x03..=0x7f => Method::IanaReserved(value),
0x80..=0xfe => Method::Private(value),
0xff => Method::NoAcceptableMethods,
}
}
}

impl From<AuthMethod> for u8 {
fn from(value: AuthMethod) -> Self {
From::<&AuthMethod>::from(&value)
impl From<Method> for u8 {
fn from(value: Method) -> Self {
From::<&Method>::from(&value)
}
}

impl From<&AuthMethod> for u8 {
fn from(value: &AuthMethod) -> Self {
impl From<&Method> for u8 {
fn from(value: &Method) -> Self {
match value {
AuthMethod::NoAuth => 0x00,
AuthMethod::GssApi => 0x01,
AuthMethod::Password => 0x02,
AuthMethod::IanaReserved(value) => *value,
AuthMethod::Private(value) => *value,
AuthMethod::NoAcceptableMethods => 0xff,
Method::NoAuth => 0x00,
Method::GssApi => 0x01,
Method::Password => 0x02,
Method::IanaReserved(value) => *value,
Method::Private(value) => *value,
Method::NoAcceptableMethods => 0xff,
}
}
}

impl std::fmt::Display for AuthMethod {
impl std::fmt::Display for Method {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
AuthMethod::NoAuth => write!(f, "NoAuth"),
AuthMethod::GssApi => write!(f, "GssApi"),
AuthMethod::Password => write!(f, "UserPass"),
AuthMethod::IanaReserved(value) => write!(f, "IanaReserved({0:#x})", value),
AuthMethod::Private(value) => write!(f, "Private({0:#x})", value),
AuthMethod::NoAcceptableMethods => write!(f, "NoAcceptableMethods"),
Method::NoAuth => write!(f, "NoAuth"),
Method::GssApi => write!(f, "GssApi"),
Method::Password => write!(f, "UserPass"),
Method::IanaReserved(value) => write!(f, "IanaReserved({0:#x})", value),
Method::Private(value) => write!(f, "Private({0:#x})", value),
Method::NoAcceptableMethods => write!(f, "NoAcceptableMethods"),
}
}
}
2 changes: 1 addition & 1 deletion src/proxy/socks5/proto/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pub mod password;
mod request;
mod response;

pub use self::{method::AuthMethod, request::Request, response::Response};
pub use self::{method::Method, request::Request, response::Response};
18 changes: 2 additions & 16 deletions src/proxy/socks5/proto/handshake/password/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ impl std::fmt::Display for UsernamePassword {
percent_encode(self.username.as_bytes(), NON_ALPHANUMERIC)
),
(false, false) => {
let username =
percent_encode(self.username.as_bytes(), NON_ALPHANUMERIC).to_string();
let password =
percent_encode(self.password.as_bytes(), NON_ALPHANUMERIC).to_string();
let username = percent_encode(self.username.as_bytes(), NON_ALPHANUMERIC);
let password = percent_encode(self.password.as_bytes(), NON_ALPHANUMERIC);
write!(f, "{}:{}", username, password)
}
}
Expand Down Expand Up @@ -63,15 +61,3 @@ impl UsernamePassword {
self.password.as_bytes().to_vec()
}
}

#[test]
fn test_user_pass() {
let user_pass = UsernamePassword::new("username", "pass@word");
assert_eq!(user_pass.to_string(), "username:pass%40word");
let user_pass = UsernamePassword::new("username", "");
assert_eq!(user_pass.to_string(), "username");
let user_pass = UsernamePassword::new("", "password");
assert_eq!(user_pass.to_string(), ":password");
let user_pass = UsernamePassword::new("", "");
assert_eq!(user_pass.to_string(), "");
}
10 changes: 5 additions & 5 deletions src/proxy/socks5/proto/handshake/request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::proxy::socks5::proto::{AsyncStreamOperation, AuthMethod, StreamOperation, Version};
use crate::proxy::socks5::proto::{AsyncStreamOperation, Method, StreamOperation, Version};
use tokio::io::{AsyncRead, AsyncReadExt};

/// SOCKS5 handshake request
Expand All @@ -12,11 +12,11 @@ use tokio::io::{AsyncRead, AsyncReadExt};
/// ```
#[derive(Clone, Debug)]
pub struct Request {
methods: Vec<AuthMethod>,
methods: Vec<Method>,
}

impl Request {
pub fn evaluate_method(&self, server_method: AuthMethod) -> bool {
pub fn evaluate_method(&self, server_method: Method) -> bool {
self.methods.iter().any(|&m| m == server_method)
}
}
Expand All @@ -39,7 +39,7 @@ impl StreamOperation for Request {
let mut methods = vec![0; mlen as usize];
r.read_exact(&mut methods)?;

let methods = methods.into_iter().map(AuthMethod::from).collect();
let methods = methods.into_iter().map(Method::from).collect();

Ok(Self { methods })
}
Expand Down Expand Up @@ -73,7 +73,7 @@ impl AsyncStreamOperation for Request {
let mut methods = vec![0; mlen as usize];
r.read_exact(&mut methods).await?;

let methods = methods.into_iter().map(AuthMethod::from).collect();
let methods = methods.into_iter().map(Method::from).collect();

Ok(Self { methods })
}
Expand Down
10 changes: 5 additions & 5 deletions src/proxy/socks5/proto/handshake/response.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::proxy::socks5::proto::{AsyncStreamOperation, AuthMethod, StreamOperation, Version};
use crate::proxy::socks5::proto::{AsyncStreamOperation, Method, StreamOperation, Version};
use tokio::io::{AsyncRead, AsyncReadExt};

/// SOCKS5 handshake response
Expand All @@ -12,11 +12,11 @@ use tokio::io::{AsyncRead, AsyncReadExt};
/// ```
#[derive(Clone, Debug)]
pub struct Response {
pub method: AuthMethod,
pub method: Method,
}

impl Response {
pub fn new(method: AuthMethod) -> Self {
pub fn new(method: Method) -> Self {
Self { method }
}
}
Expand All @@ -34,7 +34,7 @@ impl StreamOperation for Response {

let mut method = [0; 1];
r.read_exact(&mut method)?;
let method = AuthMethod::from(method[0]);
let method = Method::from(method[0]);

Ok(Self { method })
}
Expand All @@ -61,7 +61,7 @@ impl AsyncStreamOperation for Response {
return Err(std::io::Error::new(std::io::ErrorKind::Unsupported, err));
}

let method = AuthMethod::from(r.read_u8().await?);
let method = Method::from(r.read_u8().await?);

Ok(Self { method })
}
Expand Down
2 changes: 1 addition & 1 deletion src/proxy/socks5/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod udp;
pub use self::{
address::Address,
command::Command,
handshake::{password::UsernamePassword, AuthMethod},
handshake::{password::UsernamePassword, Method},
reply::Reply,
request::Request,
response::Response,
Expand Down
23 changes: 13 additions & 10 deletions src/proxy/socks5/server/auth.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use crate::proxy::{
auth,
socks5::proto::{handshake::password, AsyncStreamOperation, AuthMethod, UsernamePassword},
socks5::proto::{handshake::password, AsyncStreamOperation, Method, UsernamePassword},
};
use as_any::AsAny;
use async_trait::async_trait;
use std::sync::Arc;
use std::{
io::{Error, ErrorKind},
sync::Arc,
};
use tokio::net::TcpStream;

pub type AuthAdaptor<O> = Arc<dyn Auth<Output = O> + Send + Sync>;

#[async_trait]
pub trait Auth {
type Output: AsAny;
fn auth_method(&self) -> AuthMethod;
fn method(&self) -> Method;
async fn execute(&self, stream: &mut TcpStream) -> Self::Output;
}

Expand All @@ -24,8 +27,8 @@ pub struct NoAuth;
impl Auth for NoAuth {
type Output = ();

fn auth_method(&self) -> AuthMethod {
AuthMethod::NoAuth
fn method(&self) -> Method {
Method::NoAuth
}

async fn execute(&self, _: &mut TcpStream) -> Self::Output {}
Expand All @@ -45,23 +48,23 @@ impl Password {
impl Auth for Password {
type Output = std::io::Result<bool>;

fn auth_method(&self) -> AuthMethod {
AuthMethod::Password
fn method(&self) -> Method {
Method::Password
}

async fn execute(&self, stream: &mut TcpStream) -> Self::Output {
use password::{Request, Response, Status::*};
let req = Request::retrieve_from_async_stream(stream).await?;
let socket = stream.peer_addr()?;

let is_equal = (req.user_pass == self.0) || auth::authenticate_ip(socket).is_ok();
let is_equal = (req.user_pass == self.0) || auth::authenticate_ip(socket);
let resp = Response::new(if is_equal { Succeeded } else { Failed });
resp.write_to_async_stream(stream).await?;
if is_equal {
Ok(true)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::Other,
Err(Error::new(
ErrorKind::Other,
"username or password is incorrect",
))
}
Expand Down
8 changes: 4 additions & 4 deletions src/proxy/socks5/server/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use self::{associate::UdpAssociate, bind::Bind, connect::Connect};
use crate::proxy::{
socks5::{
proto::{self, handshake, Address, AsyncStreamOperation, AuthMethod, Command},
proto::{self, handshake, Address, AsyncStreamOperation, Command, Method},
server::AuthAdaptor,
},
Error,
Expand Down Expand Up @@ -109,15 +109,15 @@ impl<O: 'static> IncomingConnection<O> {
let output = self.auth.execute(&mut self.stream).await;
Ok((AuthenticatedStream::new(self.stream), output))
} else {
let response = handshake::Response::new(AuthMethod::NoAcceptableMethods);
let response = handshake::Response::new(Method::NoAcceptableMethods);
response.write_to_async_stream(&mut self.stream).await?;
let err = "No available handshake method provided by client";
Err(std::io::Error::new(std::io::ErrorKind::Unsupported, err))
}
}

fn evaluate_request(&self, req: &handshake::Request) -> Option<AuthMethod> {
let method = self.auth.auth_method();
fn evaluate_request(&self, req: &handshake::Request) -> Option<Method> {
let method = self.auth.method();
if req.evaluate_method(method) {
Some(method)
} else {
Expand Down

0 comments on commit 20e82e1

Please sign in to comment.