Commit b71f5595 authored by chenhuaqing's avatar chenhuaqing

optimized for context

parent 72201d5f
......@@ -8,7 +8,7 @@ use std::{net::IpAddr, time::Duration};
use lru_time_cache::LruCache;
use shadowsocks::{
config::ServerType,
context::{Context, SharedContext},
context::{AuthContext, Context, SharedAuthContext, SharedContext},
dns_resolver::DnsResolver,
net::{AcceptOpts, ConnectOpts},
relay::Address,
......@@ -119,6 +119,17 @@ impl ServiceContext {
self.context.dns_resolver()
}
/// Set authentication context
pub fn set_auth_context(&mut self, auth_context: SharedAuthContext) {
let context = Arc::get_mut(&mut self.context).expect("cannot set DNS resolver on a shared context");
context.set_auth_context(auth_context)
}
/// Get reference of authentication context
pub fn auth_context(&self) -> &AuthContext {
self.context.auth_context()
}
/// Check if target should be bypassed
pub async fn check_target_bypassed(&self, addr: &Address) -> bool {
match self.acl {
......
......@@ -13,12 +13,11 @@ use futures::{
use log::{error, trace};
use shadowsocks::{
config::Mode,
context::SharedAuthContext,
context::AuthContext,
net::{AcceptOpts, ConnectOpts, TcpStream as OutboundTcpStream},
plugin::{Plugin, PluginMode},
ServerConfig,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[cfg(feature = "local-flow-stat")]
use crate::net::FlowStat;
......@@ -117,6 +116,12 @@ pub async fn run(mut config: Config) -> io::Result<()> {
assert!(!config.local.is_empty(), "no valid local server configuration");
if let Some(server) = config.server.first() {
if let Some(token) = server.token() {
context.set_auth_context(AuthContext::new_shared(token));
}
}
let context = Arc::new(context);
let vfut = FuturesUnordered::new();
......@@ -150,9 +155,7 @@ pub async fn run(mut config: Config) -> io::Result<()> {
if let Some(watched_config) = config.watched_config {
for server in &config.server {
if let Some(token) = server.token() {
context.context_ref().auth_context().update_token(token);
if server.token().is_some() {
// Start monitor for token based server
if let Some(auth_stream) = authenticate_server(&server, context.clone()).await? {
vfut.push(
......@@ -166,9 +169,7 @@ pub async fn run(mut config: Config) -> io::Result<()> {
);
}
vfut.push(
config_notify_update_token(watched_config.clone(), context.context_ref().auth_context()).boxed(),
);
vfut.push(config_notify_update_token(watched_config.clone(), context.auth_context()).boxed());
}
}
}
......@@ -373,7 +374,10 @@ async fn authenticate_server(
svr_context: Arc<ServiceContext>,
) -> io::Result<Option<OutboundTcpStream>> {
use log::debug;
use tokio::time;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time,
};
// connect to remote server
if svr_cfg.token().is_some() {
......@@ -386,7 +390,7 @@ async fn authenticate_server(
)
.await?;
auth_stream
.write_buf(&mut svr_context.context_ref().auth_context().get_token().as_bytes())
.write_buf(&mut svr_context.auth_context().get_token().as_bytes())
.await?;
let mut retry_count = svr_cfg.retry_limit();
......@@ -435,7 +439,10 @@ async fn create_auth_stream_monitor(
beats_interval: u64,
) -> io::Result<()> {
use log::debug;
use tokio::time;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time,
};
let mut retry_count = retry_limit;
let auth_timeout = Duration::from_secs(auth_timeout);
......@@ -483,10 +490,13 @@ async fn create_auth_stream_monitor(
time::sleep(beats_interval).await;
retry_count = retry_limit;
}
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"auth_stream_monitor cannot communicate with remote server",
));
}
pub async fn config_notify_update_token(watched_config: String, auth_context: &SharedAuthContext) -> io::Result<()> {
pub async fn config_notify_update_token(watched_config: String, auth_context: &AuthContext) -> io::Result<()> {
use log::debug;
use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Result as NotifyResult, Watcher};
use tokio::sync::watch;
......
......@@ -161,7 +161,7 @@ impl Socks5TcpHandler {
};
if svr_cfg.token().is_some() {
let token = self.context.context_ref().auth_context().get_token();
let token = self.context.auth_context().get_token();
trace!("stream send token {} to remote server", token);
remote.write_buf(&mut token.as_bytes()).await?;
}
......
......@@ -110,6 +110,10 @@ impl AuthContext {
}
}
pub fn new_shared(token: &str) -> SharedAuthContext {
Arc::new(Self::new(token))
}
pub fn update_token(&self, token: &str) {
if let Ok(mut t) = self.token.write() {
*t = token.to_string();
......@@ -196,8 +200,13 @@ impl Context {
self.dns_resolver.resolve(addr, port).await
}
/// Set the authentication context
pub fn set_auth_context(&mut self, auth_context: SharedAuthContext) {
self.auth_context = auth_context;
}
/// Get the authentication context
pub fn auth_context(&self) -> &Arc<AuthContext> {
pub fn auth_context(&self) -> &SharedAuthContext {
&self.auth_context
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment