Commit 51f9f415 authored by chenhuaqing's avatar chenhuaqing

code refinement for tcp

parent 34d1d07b
......@@ -358,6 +358,8 @@ fn main() {
local_dns_config.local_dns_addr = local_config.local_dns_addr.take();
local_dns_config.remote_dns_addr = local_config.remote_dns_addr.take();
local_dns_config.mode = Mode::UdpOnly;
config.local.push(local_dns_config);
}
}
......
......@@ -12,7 +12,8 @@ use futures::{
};
use log::{error, trace};
use shadowsocks::{
config::{Mode, TokenHolder},
config::Mode,
context::SharedAuthContext,
net::{AcceptOpts, ConnectOpts, TcpStream as OutboundTcpStream},
plugin::{Plugin, PluginMode},
ServerConfig,
......@@ -116,6 +117,13 @@ pub async fn run(mut config: Config) -> io::Result<()> {
assert!(!config.local.is_empty(), "no valid local server configuration");
if let Some(svr_cfg) = config.server.first() {
context
.context_ref()
.auth_context()
.update_token(svr_cfg.token().expect("token"));
}
let context = Arc::new(context);
let vfut = FuturesUnordered::new();
......@@ -162,8 +170,9 @@ pub async fn run(mut config: Config) -> io::Result<()> {
);
}
let token_holder = server.token_holder().expect("token holder").clone();
vfut.push(config_notify_update_token(watched_config.clone(), token_holder).boxed());
vfut.push(
config_notify_update_token(watched_config.clone(), context.context_ref().auth_context()).boxed(),
);
}
}
......@@ -380,7 +389,7 @@ async fn authenticate_server(
)
.await?;
auth_stream
.write_buf(&mut svr_cfg.token_holder().expect("token").get_token().as_bytes())
.write_buf(&mut svr_context.context_ref().auth_context().get_token().as_bytes())
.await?;
let mut retry_count = svr_cfg.retry_limit();
......@@ -480,7 +489,7 @@ async fn create_auth_stream_monitor(
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
pub async fn config_notify_update_token(watched_config: String, token_holder: TokenHolder) -> io::Result<()> {
pub async fn config_notify_update_token(watched_config: String, auth_context: &SharedAuthContext) -> io::Result<()> {
use log::debug;
use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Result as NotifyResult, Watcher};
use tokio::sync::watch;
......@@ -518,7 +527,7 @@ pub async fn config_notify_update_token(watched_config: String, token_holder: To
match Config::load_from_file(&watched_config, ConfigType::Local) {
Ok(cfg) => {
if let Some(new_svr_cfg) = cfg.server.first() {
token_holder.set_token(&new_svr_cfg.token_holder().expect("new token holder").get_token());
auth_context.update_token(&new_svr_cfg.token().expect("token"));
}
}
Err(err) => {
......
......@@ -161,13 +161,9 @@ impl Socks5TcpHandler {
};
if svr_cfg.method().is_none() {
debug!(
"stream send token {} to remote server",
svr_cfg.token_holder().expect("token").get_token()
);
remote
.write_buf(&mut svr_cfg.token_holder().expect("token").get_token().as_bytes())
.await?;
let token = self.context.context_ref().auth_context().get_token();
trace!("stream send token {} to remote server", token);
remote.write_buf(&mut token.as_bytes()).await?;
}
establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, &target_addr).await
......
......@@ -7,7 +7,6 @@ use std::{
fmt::{self, Display},
net::SocketAddr,
str::FromStr,
sync::{Arc, RwLock},
time::Duration,
};
......@@ -145,33 +144,6 @@ impl ServerWeight {
}
}
#[derive(Debug, Clone)]
pub struct TokenHolder {
token: Arc<RwLock<String>>,
}
impl TokenHolder {
pub fn new(token: String) -> TokenHolder {
TokenHolder {
token: Arc::new(RwLock::new(token)),
}
}
pub fn set_token(&self, t: &str) {
if let Ok(mut token) = self.token.write() {
*token = t.to_string();
}
}
pub fn get_token(&self) -> String {
if let Ok(token) = self.token.read() {
return (*token).clone();
} else {
return String::default();
}
}
}
/// Configuration for a server
#[derive(Clone, Debug)]
pub struct ServerConfig {
......@@ -187,7 +159,7 @@ pub struct ServerConfig {
timeout: Option<Duration>,
/// Authencation token
token: Option<TokenHolder>,
token: Option<String>,
/// Auth retry limit
retry_limit: u32,
......@@ -293,12 +265,12 @@ impl ServerConfig {
where
S: Into<String>,
{
self.token = Some(TokenHolder::new(token.into()));
self.token = Some(token.into());
}
/// Get token holder
pub fn token_holder(&self) -> Option<&TokenHolder> {
self.token.as_ref()
pub fn token(&self) -> Option<&str> {
self.token.as_ref().map(AsRef::as_ref)
}
/// Set auth retry limit
......
//! Shadowsocks service context
use std::{io, net::SocketAddr, sync::{Arc, RwLock}};
use std::{
io,
net::SocketAddr,
sync::{Arc, RwLock},
};
use bloomfilter::Bloom;
use spin::Mutex as SpinMutex;
......@@ -90,7 +94,43 @@ impl PingPongBloom {
}
}
pub type TokenHolder = RwLock<String>;
pub type SharedToken = Arc<RwLock<String>>;
pub type SharedAuthContext = Arc<AuthContext>;
pub struct AuthContext {
token: SharedToken,
token_len: usize,
}
impl AuthContext {
pub fn new(token: &str) -> AuthContext {
AuthContext {
token: Arc::new(RwLock::new(token.to_string())),
token_len: token.len(),
}
}
pub fn update_token(&self, token: &str) {
if let Ok(mut t) = self.token.write() {
*t = token.to_string();
} else {
log::error!("cannot update token");
}
}
pub fn get_token(&self) -> String {
if let Ok(t) = self.token.read() {
return t.clone();
} else {
log::error!("cannot read token");
return String::default();
}
}
pub fn get_token_len(&self) -> usize {
self.token_len
}
}
/// Service context
pub struct Context {
......@@ -101,7 +141,8 @@ pub struct Context {
// trust-dns resolver, which supports REAL asynchronous resolving, and also customizable
dns_resolver: Arc<DnsResolver>,
token_holder: Option<Arc<TokenHolder>>,
/// authentication context
auth_context: Arc<AuthContext>,
}
/// `Context` for sharing between services
......@@ -114,7 +155,7 @@ impl Context {
Context {
nonce_ppbloom,
dns_resolver: Arc::new(DnsResolver::system_resolver()),
token_holder: None,
auth_context: Arc::new(AuthContext::new(&String::default())),
}
}
......@@ -155,8 +196,8 @@ impl Context {
self.dns_resolver.resolve(addr, port).await
}
pub fn save_token(&mut self, token: &str) {
self.token_holder = Some(Arc::new(RwLock::new(token.to_string())));
/// Get the authentication context
pub fn auth_context(&self) -> &Arc<AuthContext> {
&self.auth_context
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment