Commit 72201d5f authored by chenhuaqing's avatar chenhuaqing

fixed configuration without token

parent 51f9f415
...@@ -1220,8 +1220,8 @@ impl Config { ...@@ -1220,8 +1220,8 @@ impl Config {
nsvr.set_timeout(timeout); nsvr.set_timeout(timeout);
} }
if method.is_none() { if let Some(token) = config.token {
nsvr.set_token(config.token.expect("token")); nsvr.set_token(token);
nsvr.set_retry_limit(config.retry_limit.expect("retry_limit")); nsvr.set_retry_limit(config.retry_limit.expect("retry_limit"));
nsvr.set_auth_timeout(config.auth_timeout.expect("auth_timeout")); nsvr.set_auth_timeout(config.auth_timeout.expect("auth_timeout"));
nsvr.set_beats_interval(config.beats_interval.expect("beats_interval")); nsvr.set_beats_interval(config.beats_interval.expect("beats_interval"));
......
...@@ -503,6 +503,7 @@ impl DnsClient { ...@@ -503,6 +503,7 @@ impl DnsClient {
message = result; message = result;
message.set_id(request.id()); message.set_id(request.id());
} else { } else {
error!("acl_lookup error {:?}", r);
message.set_response_code(ResponseCode::ServFail); message.set_response_code(ResponseCode::ServFail);
} }
} }
......
...@@ -79,10 +79,12 @@ impl DnsClient { ...@@ -79,10 +79,12 @@ impl DnsClient {
connect_opts: &ConnectOpts, connect_opts: &ConnectOpts,
flow_stat: Arc<FlowStat>, flow_stat: Arc<FlowStat>,
) -> io::Result<DnsClient> { ) -> io::Result<DnsClient> {
let stream = ProxyClientStream::connect_with_opts_map(context, svr_cfg, ns, connect_opts, |s| { let auth_context = context.auth_context().clone();
let mut stream = ProxyClientStream::connect_with_opts_map(context, svr_cfg, ns, connect_opts, |s| {
MonProxyStream::from_stream(s, flow_stat) MonProxyStream::from_stream(s, flow_stat)
}) })
.await?; .await?;
stream.write_buf(&mut auth_context.get_token().as_bytes()).await?;
Ok(DnsClient::TcpRemote { stream }) Ok(DnsClient::TcpRemote { stream })
} }
...@@ -135,6 +137,7 @@ impl DnsClient { ...@@ -135,6 +137,7 @@ impl DnsClient {
DnsClient::UnixStream { ref mut stream } => stream_query(stream, msg).await, DnsClient::UnixStream { ref mut stream } => stream_query(stream, msg).await,
DnsClient::TcpRemote { ref mut stream } => stream_query(stream, msg).await, DnsClient::TcpRemote { ref mut stream } => stream_query(stream, msg).await,
DnsClient::UdpRemote { ref mut socket, ref ns } => { DnsClient::UdpRemote { ref mut socket, ref ns } => {
trace!("send dns query to remote server {}", ns);
let bytes = msg.to_vec()?; let bytes = msg.to_vec()?;
socket.send(ns, &bytes).await?; socket.send(ns, &bytes).await?;
......
...@@ -117,13 +117,6 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -117,13 +117,6 @@ pub async fn run(mut config: Config) -> io::Result<()> {
assert!(!config.local.is_empty(), "no valid local server configuration"); assert!(!config.local.is_empty(), "no valid local server configuration");
if let Some(svr_cfg) = config.server.first() {
context
.context_ref()
.auth_context()
.update_token(svr_cfg.token().expect("token"));
}
let context = Arc::new(context); let context = Arc::new(context);
let vfut = FuturesUnordered::new(); let vfut = FuturesUnordered::new();
...@@ -157,22 +150,26 @@ pub async fn run(mut config: Config) -> io::Result<()> { ...@@ -157,22 +150,26 @@ pub async fn run(mut config: Config) -> io::Result<()> {
if let Some(watched_config) = config.watched_config { if let Some(watched_config) = config.watched_config {
for server in &config.server { for server in &config.server {
// Start monitor for token based server if let Some(token) = server.token() {
if let Some(auth_stream) = authenticate_server(&server, context.clone()).await? { context.context_ref().auth_context().update_token(token);
// Start monitor for token based server
if let Some(auth_stream) = authenticate_server(&server, context.clone()).await? {
vfut.push(
create_auth_stream_monitor(
auth_stream,
server.retry_limit(),
server.auth_timeout(),
server.beats_interval(),
)
.boxed(),
);
}
vfut.push( vfut.push(
create_auth_stream_monitor( config_notify_update_token(watched_config.clone(), context.context_ref().auth_context()).boxed(),
auth_stream,
server.retry_limit(),
server.auth_timeout(),
server.beats_interval(),
)
.boxed(),
); );
} }
vfut.push(
config_notify_update_token(watched_config.clone(), context.context_ref().auth_context()).boxed(),
);
} }
} }
...@@ -379,7 +376,7 @@ async fn authenticate_server( ...@@ -379,7 +376,7 @@ async fn authenticate_server(
use tokio::time; use tokio::time;
// connect to remote server // connect to remote server
if svr_cfg.method().is_none() { if svr_cfg.token().is_some() {
let context = svr_context.clone().context(); let context = svr_context.clone().context();
debug!("connect to {}", svr_cfg.addr()); debug!("connect to {}", svr_cfg.addr());
let mut auth_stream = OutboundTcpStream::connect_server_with_opts( let mut auth_stream = OutboundTcpStream::connect_server_with_opts(
...@@ -446,7 +443,7 @@ async fn create_auth_stream_monitor( ...@@ -446,7 +443,7 @@ async fn create_auth_stream_monitor(
debug!("prepare send packet to remote server"); debug!("prepare send packet to remote server");
while retry_count > 0 { while retry_count > 0 {
let mut req_buf = BytesMut::with_capacity(1); let mut req_buf = BytesMut::with_capacity(1);
req_buf.put_i32(1); req_buf.put_u8(1);
match time::timeout(auth_timeout, stream.write_buf(&mut req_buf.freeze())).await { match time::timeout(auth_timeout, stream.write_buf(&mut req_buf.freeze())).await {
Ok(res) => { Ok(res) => {
if let Err(err) = res { if let Err(err) = res {
......
...@@ -160,7 +160,7 @@ impl Socks5TcpHandler { ...@@ -160,7 +160,7 @@ impl Socks5TcpHandler {
} }
}; };
if svr_cfg.method().is_none() { if svr_cfg.token().is_some() {
let token = self.context.context_ref().auth_context().get_token(); let token = self.context.context_ref().auth_context().get_token();
trace!("stream send token {} to remote server", token); trace!("stream send token {} to remote server", token);
remote.write_buf(&mut token.as_bytes()).await?; remote.write_buf(&mut token.as_bytes()).await?;
......
...@@ -228,6 +228,12 @@ where ...@@ -228,6 +228,12 @@ where
self.enc.poll_write_encrypted(cx, &mut self.stream, buf) self.enc.poll_write_encrypted(cx, &mut self.stream, buf)
} }
/// Attempt to write data to `stream`
#[inline]
pub fn poll_write(&mut self, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}
/// Polls `flush` on the underlying stream /// Polls `flush` on the underlying stream
#[inline] #[inline]
pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
......
...@@ -184,9 +184,9 @@ where ...@@ -184,9 +184,9 @@ where
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> { fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
let mut this = self.project(); let mut this = self.project();
if !*this.authed { if this.context.auth_context().get_token_len() > 0 && !*this.authed {
*this.authed = true; *this.authed = true;
return this.stream.poll_write_encrypted(cx, buf); return this.stream.poll_write(cx, buf);
} }
loop { loop {
......
...@@ -42,7 +42,9 @@ pub fn encrypt_payload( ...@@ -42,7 +42,9 @@ pub fn encrypt_payload(
) { ) {
match method.category() { match method.category() {
CipherCategory::None => { CipherCategory::None => {
dst.reserve(addr.serialized_len() + payload.len()); let auth_ctx = context.auth_context();
dst.reserve(auth_ctx.get_token_len() + addr.serialized_len() + payload.len());
dst.put_slice(auth_ctx.get_token().as_bytes());
addr.write_to_buf(dst); addr.write_to_buf(dst);
dst.put_slice(payload); dst.put_slice(payload);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment