Commit b71f5595 authored by chenhuaqing's avatar chenhuaqing

optimized for context

parent 72201d5f
...@@ -8,7 +8,7 @@ use std::{net::IpAddr, time::Duration}; ...@@ -8,7 +8,7 @@ use std::{net::IpAddr, time::Duration};
use lru_time_cache::LruCache; use lru_time_cache::LruCache;
use shadowsocks::{ use shadowsocks::{
config::ServerType, config::ServerType,
context::{Context, SharedContext}, context::{AuthContext, Context, SharedAuthContext, SharedContext},
dns_resolver::DnsResolver, dns_resolver::DnsResolver,
net::{AcceptOpts, ConnectOpts}, net::{AcceptOpts, ConnectOpts},
relay::Address, relay::Address,
...@@ -119,6 +119,17 @@ impl ServiceContext { ...@@ -119,6 +119,17 @@ impl ServiceContext {
self.context.dns_resolver() self.context.dns_resolver()
} }
/// Set authentication context
pub fn set_auth_context(&mut self, auth_context: SharedAuthContext) {
let context = Arc::get_mut(&mut self.context).expect("cannot set DNS resolver on a shared context");
context.set_auth_context(auth_context)
}
/// Get reference of authentication context
pub fn auth_context(&self) -> &AuthContext {
self.context.auth_context()
}
/// Check if target should be bypassed /// Check if target should be bypassed
pub async fn check_target_bypassed(&self, addr: &Address) -> bool { pub async fn check_target_bypassed(&self, addr: &Address) -> bool {
match self.acl { match self.acl {
......
...@@ -13,12 +13,11 @@ use futures::{ ...@@ -13,12 +13,11 @@ use futures::{
use log::{error, trace}; use log::{error, trace};
use shadowsocks::{ use shadowsocks::{
config::Mode, config::Mode,
context::SharedAuthContext, context::AuthContext,
net::{AcceptOpts, ConnectOpts, TcpStream as OutboundTcpStream}, net::{AcceptOpts, ConnectOpts, TcpStream as OutboundTcpStream},
plugin::{Plugin, PluginMode}, plugin::{Plugin, PluginMode},
ServerConfig, ServerConfig,
}; };
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[cfg(feature = "local-flow-stat")] #[cfg(feature = "local-flow-stat")]
use crate::net::FlowStat; use crate::net::FlowStat;
...@@ -117,6 +116,12 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -117,6 +116,12 @@ pub async fn run(mut config: Config) -> io::Result<()> {
assert!(!config.local.is_empty(), "no valid local server configuration"); assert!(!config.local.is_empty(), "no valid local server configuration");
if let Some(server) = config.server.first() {
if let Some(token) = server.token() {
context.set_auth_context(AuthContext::new_shared(token));
}
}
let context = Arc::new(context); let context = Arc::new(context);
let vfut = FuturesUnordered::new(); let vfut = FuturesUnordered::new();
...@@ -150,9 +155,7 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -150,9 +155,7 @@ pub async fn run(mut config: Config) -> io::Result<()> {
if let Some(watched_config) = config.watched_config { if let Some(watched_config) = config.watched_config {
for server in &config.server { for server in &config.server {
if let Some(token) = server.token() { if server.token().is_some() {
context.context_ref().auth_context().update_token(token);
// Start monitor for token based server // Start monitor for token based server
if let Some(auth_stream) = authenticate_server(&server, context.clone()).await? { if let Some(auth_stream) = authenticate_server(&server, context.clone()).await? {
vfut.push( vfut.push(
...@@ -165,10 +168,8 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -165,10 +168,8 @@ pub async fn run(mut config: Config) -> io::Result<()> {
.boxed(), .boxed(),
); );
} }
vfut.push( vfut.push(config_notify_update_token(watched_config.clone(), context.auth_context()).boxed());
config_notify_update_token(watched_config.clone(), context.context_ref().auth_context()).boxed(),
);
} }
} }
} }
...@@ -373,7 +374,10 @@ async fn authenticate_server( ...@@ -373,7 +374,10 @@ async fn authenticate_server(
svr_context: Arc<ServiceContext>, svr_context: Arc<ServiceContext>,
) -> io::Result<Option<OutboundTcpStream>> { ) -> io::Result<Option<OutboundTcpStream>> {
use log::debug; use log::debug;
use tokio::time; use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time,
};
// connect to remote server // connect to remote server
if svr_cfg.token().is_some() { if svr_cfg.token().is_some() {
...@@ -386,7 +390,7 @@ async fn authenticate_server( ...@@ -386,7 +390,7 @@ async fn authenticate_server(
) )
.await?; .await?;
auth_stream auth_stream
.write_buf(&mut svr_context.context_ref().auth_context().get_token().as_bytes()) .write_buf(&mut svr_context.auth_context().get_token().as_bytes())
.await?; .await?;
let mut retry_count = svr_cfg.retry_limit(); let mut retry_count = svr_cfg.retry_limit();
...@@ -435,7 +439,10 @@ async fn create_auth_stream_monitor( ...@@ -435,7 +439,10 @@ async fn create_auth_stream_monitor(
beats_interval: u64, beats_interval: u64,
) -> io::Result<()> { ) -> io::Result<()> {
use log::debug; use log::debug;
use tokio::time; use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time,
};
let mut retry_count = retry_limit; let mut retry_count = retry_limit;
let auth_timeout = Duration::from_secs(auth_timeout); let auth_timeout = Duration::from_secs(auth_timeout);
...@@ -483,10 +490,13 @@ async fn create_auth_stream_monitor( ...@@ -483,10 +490,13 @@ async fn create_auth_stream_monitor(
time::sleep(beats_interval).await; time::sleep(beats_interval).await;
retry_count = retry_limit; retry_count = retry_limit;
} }
return Err(io::Error::from(io::ErrorKind::BrokenPipe)); return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"auth_stream_monitor cannot communicate with remote server",
));
} }
pub async fn config_notify_update_token(watched_config: String, auth_context: &SharedAuthContext) -> io::Result<()> { pub async fn config_notify_update_token(watched_config: String, auth_context: &AuthContext) -> io::Result<()> {
use log::debug; use log::debug;
use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Result as NotifyResult, Watcher}; use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Result as NotifyResult, Watcher};
use tokio::sync::watch; use tokio::sync::watch;
......
...@@ -161,7 +161,7 @@ impl Socks5TcpHandler { ...@@ -161,7 +161,7 @@ impl Socks5TcpHandler {
}; };
if svr_cfg.token().is_some() { if svr_cfg.token().is_some() {
let token = self.context.context_ref().auth_context().get_token(); let token = self.context.auth_context().get_token();
trace!("stream send token {} to remote server", token); trace!("stream send token {} to remote server", token);
remote.write_buf(&mut token.as_bytes()).await?; remote.write_buf(&mut token.as_bytes()).await?;
} }
......
...@@ -110,6 +110,10 @@ impl AuthContext { ...@@ -110,6 +110,10 @@ impl AuthContext {
} }
} }
pub fn new_shared(token: &str) -> SharedAuthContext {
Arc::new(Self::new(token))
}
pub fn update_token(&self, token: &str) { pub fn update_token(&self, token: &str) {
if let Ok(mut t) = self.token.write() { if let Ok(mut t) = self.token.write() {
*t = token.to_string(); *t = token.to_string();
...@@ -196,8 +200,13 @@ impl Context { ...@@ -196,8 +200,13 @@ impl Context {
self.dns_resolver.resolve(addr, port).await self.dns_resolver.resolve(addr, port).await
} }
/// Set the authentication context
pub fn set_auth_context(&mut self, auth_context: SharedAuthContext) {
self.auth_context = auth_context;
}
/// Get the authentication context /// Get the authentication context
pub fn auth_context(&self) -> &Arc<AuthContext> { pub fn auth_context(&self) -> &SharedAuthContext {
&self.auth_context &self.auth_context
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment