Commit 51f9f415 authored by chenhuaqing's avatar chenhuaqing

code refinement for tcp

parent 34d1d07b
...@@ -358,6 +358,8 @@ fn main() { ...@@ -358,6 +358,8 @@ fn main() {
local_dns_config.local_dns_addr = local_config.local_dns_addr.take(); local_dns_config.local_dns_addr = local_config.local_dns_addr.take();
local_dns_config.remote_dns_addr = local_config.remote_dns_addr.take(); local_dns_config.remote_dns_addr = local_config.remote_dns_addr.take();
local_dns_config.mode = Mode::UdpOnly;
config.local.push(local_dns_config); config.local.push(local_dns_config);
} }
} }
......
...@@ -12,7 +12,8 @@ use futures::{ ...@@ -12,7 +12,8 @@ use futures::{
}; };
use log::{error, trace}; use log::{error, trace};
use shadowsocks::{ use shadowsocks::{
config::{Mode, TokenHolder}, config::Mode,
context::SharedAuthContext,
net::{AcceptOpts, ConnectOpts, TcpStream as OutboundTcpStream}, net::{AcceptOpts, ConnectOpts, TcpStream as OutboundTcpStream},
plugin::{Plugin, PluginMode}, plugin::{Plugin, PluginMode},
ServerConfig, ServerConfig,
...@@ -116,6 +117,13 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -116,6 +117,13 @@ pub async fn run(mut config: Config) -> io::Result<()> {
assert!(!config.local.is_empty(), "no valid local server configuration"); assert!(!config.local.is_empty(), "no valid local server configuration");
if let Some(svr_cfg) = config.server.first() {
context
.context_ref()
.auth_context()
.update_token(svr_cfg.token().expect("token"));
}
let context = Arc::new(context); let context = Arc::new(context);
let vfut = FuturesUnordered::new(); let vfut = FuturesUnordered::new();
...@@ -162,8 +170,9 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -162,8 +170,9 @@ pub async fn run(mut config: Config) -> io::Result<()> {
); );
} }
let token_holder = server.token_holder().expect("token holder").clone(); vfut.push(
vfut.push(config_notify_update_token(watched_config.clone(), token_holder).boxed()); config_notify_update_token(watched_config.clone(), context.context_ref().auth_context()).boxed(),
);
} }
} }
...@@ -380,7 +389,7 @@ async fn authenticate_server( ...@@ -380,7 +389,7 @@ async fn authenticate_server(
) )
.await?; .await?;
auth_stream auth_stream
.write_buf(&mut svr_cfg.token_holder().expect("token").get_token().as_bytes()) .write_buf(&mut svr_context.context_ref().auth_context().get_token().as_bytes())
.await?; .await?;
let mut retry_count = svr_cfg.retry_limit(); let mut retry_count = svr_cfg.retry_limit();
...@@ -480,7 +489,7 @@ async fn create_auth_stream_monitor( ...@@ -480,7 +489,7 @@ async fn create_auth_stream_monitor(
return Err(io::Error::from(io::ErrorKind::BrokenPipe)); return Err(io::Error::from(io::ErrorKind::BrokenPipe));
} }
pub async fn config_notify_update_token(watched_config: String, token_holder: TokenHolder) -> io::Result<()> { pub async fn config_notify_update_token(watched_config: String, auth_context: &SharedAuthContext) -> io::Result<()> {
use log::debug; use log::debug;
use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Result as NotifyResult, Watcher}; use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Result as NotifyResult, Watcher};
use tokio::sync::watch; use tokio::sync::watch;
...@@ -518,7 +527,7 @@ pub async fn config_notify_update_token(watched_config: String, token_holder: To ...@@ -518,7 +527,7 @@ pub async fn config_notify_update_token(watched_config: String, token_holder: To
match Config::load_from_file(&watched_config, ConfigType::Local) { match Config::load_from_file(&watched_config, ConfigType::Local) {
Ok(cfg) => { Ok(cfg) => {
if let Some(new_svr_cfg) = cfg.server.first() { if let Some(new_svr_cfg) = cfg.server.first() {
token_holder.set_token(&new_svr_cfg.token_holder().expect("new token holder").get_token()); auth_context.update_token(&new_svr_cfg.token().expect("token"));
} }
} }
Err(err) => { Err(err) => {
......
...@@ -161,13 +161,9 @@ impl Socks5TcpHandler { ...@@ -161,13 +161,9 @@ impl Socks5TcpHandler {
}; };
if svr_cfg.method().is_none() { if svr_cfg.method().is_none() {
debug!( let token = self.context.context_ref().auth_context().get_token();
"stream send token {} to remote server", trace!("stream send token {} to remote server", token);
svr_cfg.token_holder().expect("token").get_token() remote.write_buf(&mut token.as_bytes()).await?;
);
remote
.write_buf(&mut svr_cfg.token_holder().expect("token").get_token().as_bytes())
.await?;
} }
establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, &target_addr).await establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, &target_addr).await
......
...@@ -7,7 +7,6 @@ use std::{ ...@@ -7,7 +7,6 @@ use std::{
fmt::{self, Display}, fmt::{self, Display},
net::SocketAddr, net::SocketAddr,
str::FromStr, str::FromStr,
sync::{Arc, RwLock},
time::Duration, time::Duration,
}; };
...@@ -145,33 +144,6 @@ impl ServerWeight { ...@@ -145,33 +144,6 @@ impl ServerWeight {
} }
} }
#[derive(Debug, Clone)]
pub struct TokenHolder {
token: Arc<RwLock<String>>,
}
impl TokenHolder {
pub fn new(token: String) -> TokenHolder {
TokenHolder {
token: Arc::new(RwLock::new(token)),
}
}
pub fn set_token(&self, t: &str) {
if let Ok(mut token) = self.token.write() {
*token = t.to_string();
}
}
pub fn get_token(&self) -> String {
if let Ok(token) = self.token.read() {
return (*token).clone();
} else {
return String::default();
}
}
}
/// Configuration for a server /// Configuration for a server
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ServerConfig { pub struct ServerConfig {
...@@ -187,7 +159,7 @@ pub struct ServerConfig { ...@@ -187,7 +159,7 @@ pub struct ServerConfig {
timeout: Option<Duration>, timeout: Option<Duration>,
/// Authencation token /// Authencation token
token: Option<TokenHolder>, token: Option<String>,
/// Auth retry limit /// Auth retry limit
retry_limit: u32, retry_limit: u32,
...@@ -293,12 +265,12 @@ impl ServerConfig { ...@@ -293,12 +265,12 @@ impl ServerConfig {
where where
S: Into<String>, S: Into<String>,
{ {
self.token = Some(TokenHolder::new(token.into())); self.token = Some(token.into());
} }
/// Get token holder /// Get token holder
pub fn token_holder(&self) -> Option<&TokenHolder> { pub fn token(&self) -> Option<&str> {
self.token.as_ref() self.token.as_ref().map(AsRef::as_ref)
} }
/// Set auth retry limit /// Set auth retry limit
......
//! Shadowsocks service context //! Shadowsocks service context
use std::{io, net::SocketAddr, sync::{Arc, RwLock}}; use std::{
io,
net::SocketAddr,
sync::{Arc, RwLock},
};
use bloomfilter::Bloom; use bloomfilter::Bloom;
use spin::Mutex as SpinMutex; use spin::Mutex as SpinMutex;
...@@ -90,7 +94,43 @@ impl PingPongBloom { ...@@ -90,7 +94,43 @@ impl PingPongBloom {
} }
} }
pub type TokenHolder = RwLock<String>; pub type SharedToken = Arc<RwLock<String>>;
pub type SharedAuthContext = Arc<AuthContext>;
pub struct AuthContext {
token: SharedToken,
token_len: usize,
}
impl AuthContext {
pub fn new(token: &str) -> AuthContext {
AuthContext {
token: Arc::new(RwLock::new(token.to_string())),
token_len: token.len(),
}
}
pub fn update_token(&self, token: &str) {
if let Ok(mut t) = self.token.write() {
*t = token.to_string();
} else {
log::error!("cannot update token");
}
}
pub fn get_token(&self) -> String {
if let Ok(t) = self.token.read() {
return t.clone();
} else {
log::error!("cannot read token");
return String::default();
}
}
pub fn get_token_len(&self) -> usize {
self.token_len
}
}
/// Service context /// Service context
pub struct Context { pub struct Context {
...@@ -101,7 +141,8 @@ pub struct Context { ...@@ -101,7 +141,8 @@ pub struct Context {
// trust-dns resolver, which supports REAL asynchronous resolving, and also customizable // trust-dns resolver, which supports REAL asynchronous resolving, and also customizable
dns_resolver: Arc<DnsResolver>, dns_resolver: Arc<DnsResolver>,
token_holder: Option<Arc<TokenHolder>>, /// authentication context
auth_context: Arc<AuthContext>,
} }
/// `Context` for sharing between services /// `Context` for sharing between services
...@@ -114,7 +155,7 @@ impl Context { ...@@ -114,7 +155,7 @@ impl Context {
Context { Context {
nonce_ppbloom, nonce_ppbloom,
dns_resolver: Arc::new(DnsResolver::system_resolver()), dns_resolver: Arc::new(DnsResolver::system_resolver()),
token_holder: None, auth_context: Arc::new(AuthContext::new(&String::default())),
} }
} }
...@@ -155,8 +196,8 @@ impl Context { ...@@ -155,8 +196,8 @@ impl Context {
self.dns_resolver.resolve(addr, port).await self.dns_resolver.resolve(addr, port).await
} }
pub fn save_token(&mut self, token: &str) { /// Get the authentication context
self.token_holder = Some(Arc::new(RwLock::new(token.to_string()))); pub fn auth_context(&self) -> &Arc<AuthContext> {
&self.auth_context
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment