Commit c0663ea7 authored by Dominik Charousset's avatar Dominik Charousset

Simplify client_config scaffolding

parent 5310b68e
......@@ -154,4 +154,10 @@ namespace caf::defaults::net {
/// previous connection has been closed.
constexpr auto max_connections = make_parameter("max-connections", size_t{64});
/// The default port for HTTP servers.
constexpr uint16_t http_default_port = 80;
/// The default port for HTTPS servers.
constexpr uint16_t https_default_port = 443;
} // namespace caf::defaults::net
......@@ -9,7 +9,6 @@
#include "caf/intrusive_ptr.hpp"
#include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/config_base.hpp"
#include "caf/net/dsl/fwd.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/connection.hpp"
#include "caf/net/ssl/context.hpp"
......@@ -22,81 +21,12 @@
namespace caf::net::dsl {
/// The server config type enum class.
enum class client_config_type { lazy, socket, conn, fail };
/// Meta programming utility for `as_base_ptr()`.
struct client_config_tag {};
/// Meta programming utility for `client_config<Base>::make()`.
/// Meta programming utility.
template <class T>
struct client_config_token {};
/// Base class for client configuration objects.
template <class Base>
class client_config : public Base, public client_config_tag {
public:
friend class lazy_client_config<Base>;
friend class socket_client_config<Base>;
friend class conn_client_config<Base>;
friend class fail_client_config<Base>;
using lazy = lazy_client_config<Base>;
using socket = socket_client_config<Base>;
using conn = conn_client_config<Base>;
using fail = fail_client_config<Base>;
/// Anchor type for meta programming.
using base_type = client_config;
client_config(client_config&&) = default;
client_config(const client_config&) = default;
/// Virtual destructor.
virtual ~client_config() = default;
/// Returns the server configuration type.
virtual client_config_type type() const noexcept = 0;
template <class Token, class... Ts>
static auto make(client_config_token<Token>, Ts&&... xs);
private:
/// Private constructor to enforce sealing.
template <class... Ts>
explicit client_config(multiplexer* mpx, Ts&&... xs)
: Base(mpx, std::forward<Ts>(xs)...) {
// nop
}
};
#define CAF_NET_DSL_ADD_CLIENT_TOKEN(type) \
struct client_config_bind_##type { \
template <class Base> \
using bind = typename client_config<Base>::type; \
}; \
static constexpr auto client_config_##type##_v \
= client_config_token<client_config_bind_##type> {}
/// Compile-time constant for `client_config::lazy`.
CAF_NET_DSL_ADD_CLIENT_TOKEN(lazy);
/// Compile-time constant for `client_config::socket`.
CAF_NET_DSL_ADD_CLIENT_TOKEN(socket);
/// Compile-time constant for `client_config::conn`.
CAF_NET_DSL_ADD_CLIENT_TOKEN(conn);
/// Compile-time constant for `client_config::fail`.
CAF_NET_DSL_ADD_CLIENT_TOKEN(fail);
/// Intrusive pointer type for server configurations.
template <class Base>
using client_config_ptr = intrusive_ptr<client_config<Base>>;
struct client_config_tag {};
/// Simple type for storing host and port information for reaching a server.
struct client_config_server_address {
struct server_address {
/// The host name or IP address of the host.
std::string host;
......@@ -104,220 +34,181 @@ struct client_config_server_address {
uint16_t port;
};
/// Configuration for a client that creates the socket on demand.
template <class Base>
class lazy_client_config final : public client_config<Base> {
class client_config {
public:
static constexpr auto type_token = client_config_type::lazy;
/// Configuration for a client that creates the socket on demand.
class lazy {
public:
/// Type for holding a client address.
using server_t = std::variant<server_address, uri>;
lazy(std::string host, uint16_t port) {
server = server_address{std::move(host), port};
}
using super = client_config<Base>;
explicit lazy(uri addr) {
server = addr;
}
template <class... Ts>
lazy_client_config(std::string host, uint16_t port, multiplexer* mpx,
Ts&&... xs)
: super(mpx, std::forward<Ts>(xs)...) {
server = client_config_server_address{std::move(host), port};
}
/// The address for reaching the server or an error.
server_t server;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
/// The delay between connection attempts.
timespan retry_delay = std::chrono::seconds{1};
/// The timeout when trying to connect.
timespan connection_timeout = infinite;
/// The maximum amount of retries.
size_t max_retry_count = 0;
/// Returns a function that, when called with a @ref stream_socket, calls
/// `f` either with a new SSL connection from `ctx` or with the file the
/// file descriptor if no SSL context is defined.
template <class F>
auto with_ctx(F&& f) {
return [this, g = std::forward<F>(f)](stream_socket fd) mutable {
using res_t = decltype(g(fd));
if (ctx) {
auto conn = ctx->new_connection(fd);
if (conn)
return g(*conn);
else
return res_t{std::move(conn.error())};
} else
return g(fd);
};
}
};
template <class... Ts>
lazy_client_config(const uri& addr, multiplexer* mpx, Ts&&... xs)
: super(mpx, std::forward<Ts>(xs)...) {
server = addr;
}
static constexpr auto lazy_v = client_config_tag<lazy>{};
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// Configuration for a client that uses a user-provided socket.
class socket {
public:
explicit socket(stream_socket fd) : fd(fd) {
// nop
}
/// Type for holding a client address.
using server_t = std::variant<client_config_server_address, uri>;
socket() = delete;
/// The address for reaching the server or an error.
server_t server;
socket(const socket&) = delete;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
socket& operator=(const socket&) = delete;
/// The delay between connection attempts.
timespan retry_delay = std::chrono::seconds{1};
socket(socket&& other) noexcept : fd(other.fd) {
other.fd.id = invalid_socket_id;
}
/// The timeout when trying to connect.
timespan connection_timeout = infinite;
socket& operator=(socket&& other) noexcept {
using std::swap;
swap(fd, other.fd);
return *this;
}
/// The maximum amount of retries.
size_t max_retry_count = 0;
};
~socket() {
if (fd != invalid_socket)
close(fd);
}
/// Configuration for a client that uses a user-provided socket.
template <class Base>
class socket_client_config final : public client_config<Base> {
public:
static constexpr auto type_token = client_config_type::socket;
using super = client_config<Base>;
template <class... Ts>
socket_client_config(stream_socket fd, multiplexer* mpx, Ts&&... xs)
: super(mpx, std::forward<Ts>(xs)...), fd(fd) {
// nop
}
~socket_client_config() override {
if (fd != invalid_socket)
close(fd);
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// The socket file descriptor to use.
stream_socket fd;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
/// Returns the file descriptor and setting the `fd` member variable to the
/// invalid socket.
stream_socket take_fd() noexcept {
auto result = fd;
fd.id = invalid_socket_id;
return result;
}
};
/// The socket file descriptor to use.
stream_socket fd;
/// Configuration for a client that uses an already established SSL connection.
template <class Base>
class conn_client_config final : public client_config<Base> {
public:
static constexpr auto type_token = client_config_type::conn;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
using super = client_config<Base>;
/// Returns the file descriptor and setting the `fd` member variable to the
/// invalid socket.
stream_socket take_fd() noexcept {
auto result = fd;
fd.id = invalid_socket_id;
return result;
}
};
template <class... Ts>
conn_client_config(ssl::connection state, multiplexer* mpx, Ts&&... xs)
: super(mpx, std::forward<Ts>(xs)...), state(std::move(state)) {
// nop
}
static constexpr auto socket_v = client_config_tag<socket>{};
~conn_client_config() override {
if (state) {
if (auto fd = state.fd(); fd != invalid_socket)
close(fd);
/// Configuration for a client that uses an already established SSL
/// connection.
class conn {
public:
explicit conn(ssl::connection st) : state(std::move(st)) {
// nop
}
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
conn() = delete;
/// SSL state for the connection.
ssl::connection state;
};
conn(const conn&) = delete;
/// Wraps an error that occurred earlier in the setup phase.
template <class Base>
class fail_client_config final : public client_config<Base> {
public:
static constexpr auto type_token = client_config_type::fail;
conn& operator=(const conn&) = delete;
using super = client_config<Base>;
conn(conn&&) noexcept = default;
template <class... Ts>
fail_client_config(error err, multiplexer* mpx, Ts&&... xs)
: super(mpx, std::forward<Ts>(xs)...), err(std::move(err)) {
// nop
}
conn& operator=(conn&&) noexcept = default;
fail_client_config(error err, const super& other)
: super(other), err(std::move(err)) {
// nop
}
~conn() {
if (state) {
if (auto fd = state.fd(); fd != invalid_socket)
close(fd);
}
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// SSL state for the connection.
ssl::connection state;
};
/// The forwarded error.
error err;
};
static constexpr auto conn_v = client_config_tag<conn>{};
/// Calls a function object with the actual subtype of a client configuration
/// and returns its result.
template <class F, class Base>
decltype(auto) visit(F&& f, client_config<Base>& cfg) {
auto type = cfg.type();
switch (cfg.type()) {
case client_config_type::lazy:
return f(static_cast<lazy_client_config<Base>&>(cfg));
case client_config_type::socket:
return f(static_cast<socket_client_config<Base>&>(cfg));
case client_config_type::conn:
return f(static_cast<conn_client_config<Base>&>(cfg));
default:
assert(type == client_config_type::fail);
return f(static_cast<fail_client_config<Base>&>(cfg));
}
}
static constexpr auto fail_v = client_config_tag<error>{};
/// Calls a function object with the actual subtype of a client configuration
/// and returns its result.
template <class F, class Base>
decltype(auto) visit(F&& f, const client_config<Base>& cfg) {
auto type = cfg.type();
switch (cfg.type()) {
case client_config_type::lazy:
return f(static_cast<const lazy_client_config<Base>&>(cfg));
case client_config_type::socket:
return f(static_cast<const socket_client_config<Base>&>(cfg));
case client_config_type::conn:
return f(static_cast<const conn_client_config<Base>&>(cfg));
default:
assert(type == client_config_type::fail);
return f(static_cast<const fail_client_config<Base>&>(cfg));
}
}
template <class Base>
class value : public Base {
public:
using super = Base;
/// Gets a pointer to a specific subtype of a client configuration.
template <class T, class Base>
T* get_if(client_config<Base>* config) {
if (T::type_token == config->type())
return static_cast<T*>(config);
return nullptr;
}
template <class Trait, class... Data>
value(net::multiplexer* mpx, Trait trait, Data&&... arg)
: super(mpx, std::move(trait)), data(std::forward<Data>(arg)...) {
// nop
}
/// Gets a pointer to a specific subtype of a client configuration.
template <class T, class Base>
const T* get_if(const client_config<Base>* config) {
if (T::type_token == config->type())
return static_cast<const T*>(config);
return nullptr;
}
template <class... Data>
explicit value(const value& other, Data&&... arg)
: super(other), data(std::forward<Data>(arg)...) {
// nop
}
std::variant<error, lazy, socket, conn> data;
template <class T, class Trait, class... Args>
static intrusive_ptr<value> make(client_config_tag<T>,
net::multiplexer* mpx, Trait trait,
Args&&... args) {
return make_counted<value>(mpx, std::move(trait), std::in_place_type<T>,
std::forward<Args>(args)...);
}
template <class F>
auto visit(F&& f) {
return std::visit([&](auto& arg) { return f(arg); }, data);
}
};
};
/// Creates a `fail_client_config` from another configuration object plus error.
template <class Base>
auto to_fail_config(client_config_ptr<Base> ptr, error err) {
using impl_t = fail_client_config<Base>;
return make_counted<impl_t>(std::move(err), *ptr);
}
using client_config_value = client_config::value<Base>;
/// Returns the pointer as a pointer to the `client_config` base type.
template <class T>
auto as_base_ptr(
intrusive_ptr<T> ptr,
std::enable_if_t<std::is_base_of_v<client_config_tag, T>>* = nullptr) {
return std::move(ptr).template upcast<typename T::base_type>();
}
template <class Base>
using client_config_ptr = intrusive_ptr<client_config_value<Base>>;
/// Creates a `fail_client_config` from another configuration object plus error.
template <class Base>
template <class Token, class... Ts>
auto client_config<Base>::make(client_config_token<Token>, Ts&&... xs) {
using type_t = typename Token::template bind<Base>;
return make_counted<type_t>(std::forward<Ts>(xs)...);
client_config_ptr<Base> to_fail_config(client_config_ptr<Base> ptr, error err) {
using val_t = typename client_config::template value<Base>;
return make_counted<val_t>(*ptr, std::move(err));
}
} // namespace caf::net::dsl
......@@ -20,22 +20,18 @@ namespace caf::net::dsl {
template <class ConfigBase, class Derived>
class client_factory_base {
public:
using config_type = client_config<ConfigBase>;
using config_type = client_config_value<ConfigBase>;
using trait_type = typename config_type::trait_type;
using config_pointer = intrusive_ptr<config_type>;
explicit client_factory_base(config_pointer cfg) : cfg_(std::move(cfg)) {
// nop
}
client_factory_base(const client_factory_base&) = default;
client_factory_base& operator=(const client_factory_base&) = default;
template <class T, class... Ts>
explicit client_factory_base(dsl::client_config_token<T> token, Ts&&... xs) {
explicit client_factory_base(dsl::client_config_tag<T> token, Ts&&... xs) {
cfg_ = config_type::make(token, std::forward<Ts>(xs)...);
}
......@@ -52,8 +48,8 @@ public:
/// @param value The new retry delay.
/// @returns a reference to this `client_factory`.
Derived& retry_delay(timespan value) {
if (auto* cfg = get_if<lazy_client_config<ConfigBase>>(cfg_.get()))
cfg->retry_delay = value;
if (auto* lazy = get_if<client_config::lazy>(&cfg_->data))
lazy->retry_delay = value;
return dref();
}
......@@ -62,8 +58,8 @@ public:
/// @param value The new connection timeout.
/// @returns a reference to this `client_factory`.
Derived& connection_timeout(timespan value) {
if (auto* cfg = get_if<lazy_client_config<ConfigBase>>(cfg_.get()))
cfg->connection_timeout = value;
if (auto* lazy = get_if<client_config::lazy>(&cfg_->data))
lazy->connection_timeout = value;
return dref();
}
......@@ -72,8 +68,8 @@ public:
/// @param value The new maximum retry count.
/// @returns a reference to this `client_factory`.
Derived& max_retry_count(size_t value) {
if (auto* cfg = get_if<lazy_client_config<ConfigBase>>(cfg_.get()))
cfg->max_retry_count = value;
if (auto* lazy = get_if<client_config::lazy>(&cfg_->data))
lazy->max_retry_count = value;
return dref();
}
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
namespace caf::net::dsl {
enum class client_config_type;
template <class Base>
class client_config;
template <class Base>
class lazy_client_config;
template <class Base>
class socket_client_config;
template <class Base>
class conn_client_config;
template <class Base>
class fail_client_config;
} // namespace caf::net::dsl
......@@ -27,8 +27,8 @@ public:
/// @returns a `connect_factory` object initialized with the given parameters.
auto connect(std::string host, uint16_t port) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_lazy_v, std::move(host), port, this->mpx(),
this->trait());
return dref.make(client_config::lazy_v, this->mpx(), this->trait(),
std::move(host), port);
}
/// Creates a `connect_factory` object for the given stream `fd`.
......@@ -36,7 +36,7 @@ public:
/// @returns a `connect_factory` object that will use the given socket.
auto connect(stream_socket fd) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_socket_v, fd, this->mpx(), this->trait());
return dref.make(client_config::socket_v, this->mpx(), this->trait(), fd);
}
/// Creates a `connect_factory` object for the given SSL `connection`.
......@@ -44,8 +44,8 @@ public:
/// @returns a `connect_factory` object that will use the given connection.
auto connect(ssl::connection conn) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_conn_v, std::move(conn), this->mpx(),
this->trait());
return dref.make(client_config::conn_v, this->mpx(), this->trait(),
std::move(conn));
}
};
......
......@@ -31,8 +31,8 @@ public:
/// @returns a `connect_factory` object initialized with the given parameters.
auto connect(const uri& endpoint) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_lazy_v, endpoint, this->mpx(),
this->trait());
return dref.make(client_config::lazy_v, this->mpx(), this->trait(),
endpoint);
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
......@@ -43,8 +43,8 @@ public:
if (endpoint)
return connect(*endpoint);
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_fail_v, std::move(endpoint.error()),
this->mpx(), this->trait());
return dref.make(client_config::fail_v, this->mpx(), this->trait(),
std::move(endpoint.error()));
}
};
......
......@@ -45,6 +45,12 @@ public:
std::move(push));
}
void abort(const error& err) override {
super::abort(err);
if (push_)
push_.abort(err);
}
error start(net::binary::lower_layer* down_ptr) override {
super::down_ = down_ptr;
return super::init(std::move(pull_), std::move(push_));
......@@ -75,39 +81,21 @@ public:
using super::super;
using start_res_t = expected<disposable>;
/// Starts a connection with the length-prefixing protocol.
template <class OnStart>
[[nodiscard]] expected<disposable> start(OnStart on_start) {
using input_res_t = typename Trait::input_resource;
using output_res_t = typename Trait::output_resource;
static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>);
auto f = [this, &on_start](auto& cfg) {
return this->do_start(cfg, on_start);
};
return visit(f, this->config());
return super::config().visit([this, &on_start](auto& data) {
return do_start(super::config(), data, on_start);
});
}
private:
auto try_connect(const typename config_type::lazy& cfg,
const std::string& host, uint16_t port) {
auto result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
for (size_t i = 1; i <= cfg.max_retry_count; ++i) {
std::this_thread::sleep_for(cfg.retry_delay);
result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
}
return result;
}
template <class Conn, class OnStart>
start_res_t do_start_impl(config_type& cfg, Conn conn, OnStart& on_start) {
expected<disposable>
do_start_impl(config_type& cfg, Conn conn, OnStart& on_start) {
// s2a: socket-to-application (and a2s is the inverse).
using input_t = typename Trait::input_type;
using output_t = typename Trait::output_type;
......@@ -126,51 +114,45 @@ private:
bridge_ptr->self_ref(ptr->as_disposable());
cfg.mpx->start(ptr);
on_start(std::move(s2a_pull), std::move(a2s_push));
return start_res_t{disposable{std::move(ptr)}};
return expected<disposable>{disposable{std::move(ptr)}};
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, const std::string& host,
uint16_t port, OnStart& on_start) {
auto fd = try_connect(cfg, host, port);
if (fd) {
if (cfg.ctx) {
auto conn = cfg.ctx->new_connection(*fd);
if (conn)
return do_start_impl(cfg, std::move(*conn), on_start);
cfg.call_on_error(conn.error());
return start_res_t{std::move(conn.error())};
}
return do_start_impl(cfg, *fd, on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data,
OnStart& on_start) {
if (std::holds_alternative<uri>(data.server)) {
auto err = make_error(sec::invalid_argument,
"length-prefix factories do not accept URIs");
return do_start(cfg, err, on_start);
}
cfg.call_on_error(fd.error());
return start_res_t{std::move(fd.error())};
auto& addr = std::get<dsl::server_address>(data.server);
return detail::tcp_try_connect(std::move(addr.host), addr.port,
data.connection_timeout,
data.max_retry_count, data.retry_delay)
.and_then(data.with_ctx([this, &cfg, &on_start](auto& conn) {
return this->do_start_impl(cfg, std::move(conn), on_start);
}));
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, OnStart& on_start) {
if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server))
return do_start(cfg, st->host, st->port, on_start);
auto err = make_error(sec::invalid_argument,
"length-prefix factories do not accept URIs");
cfg.call_on_error(err);
return start_res_t{std::move(err)};
expected<disposable> do_start(config_type& cfg,
dsl::client_config::socket& data,
OnStart& on_start) {
return do_start_impl(cfg, data.take_fd(), on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::socket& cfg, OnStart& on_start) {
return do_start_impl(cfg, cfg.take_fd(), on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::conn& data,
OnStart& on_start) {
return do_start_impl(cfg, std::move(data.state), on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::conn& cfg, OnStart& on_start) {
return do_start_impl(cfg, std::move(cfg.state), on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::fail& cfg, OnStart&) {
cfg.call_on_error(cfg.err);
return start_res_t{std::move(cfg.err)};
expected<disposable> do_start(config_type& cfg, error& err, OnStart&) {
cfg.call_on_error(err);
return expected<disposable>{std::move(err)};
}
};
......
......@@ -52,8 +52,8 @@ public:
/// @private
template <class T, class... Ts>
auto make(dsl::client_config_token<T> token, Ts&&... xs) {
return client_factory<Trait>{token, std::forward<Ts>(xs)...};
auto make(dsl::client_config_tag<T> tag, Ts&&... xs) {
return client_factory<Trait>{tag, std::forward<Ts>(xs)...};
}
private:
......
......@@ -9,7 +9,7 @@
namespace caf::net::ssl {
/// Configures the allowed DTLS versions on a @ref context.
/// Configures the allowed TLS versions on a @ref context.
enum class tls {
any,
v1_0,
......
......@@ -60,3 +60,11 @@ make_connected_tcp_stream_socket(std::string host, uint16_t port,
timespan timeout = infinite);
} // namespace caf::net
namespace caf::detail {
expected<net::tcp_stream_socket> CAF_NET_EXPORT //
tcp_try_connect(std::string host, uint16_t port, timespan connection_timeout,
size_t max_retry_count, timespan retry_delay);
} // namespace caf::detail
......@@ -61,12 +61,19 @@ private:
namespace caf::net::web_socket {
/// Configuration type for WebSocket clients with a handshake object. The
/// handshake object sets the default endpoint to '/' for convenience.
template <class Trait>
class client_factory_config : public dsl::config_with_trait<Trait> {
public:
using super = dsl::config_with_trait<Trait>;
using super::super;
client_factory_config(multiplexer* mpx, Trait trait)
: super(mpx, std::move(trait)) {
hs.endpoint("/");
}
client_factory_config(const client_factory_config&) = default;
handshake hs;
};
......@@ -82,9 +89,7 @@ public:
using super::super;
using start_res_t = expected<disposable>;
using config_type = dsl::client_config<client_factory_config<Trait>>;
using config_type = typename super::config_type;
/// Starts a connection with the length-prefixing protocol.
template <class OnStart>
......@@ -92,32 +97,15 @@ public:
using input_res_t = typename Trait::input_resource;
using output_res_t = typename Trait::output_resource;
static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>);
auto f = [this, &on_start](auto& cfg) {
return this->do_start(cfg, on_start);
};
return visit(f, this->config());
return super::config().visit([this, &on_start](auto& data) {
return do_start(super::config(), data, on_start);
});
}
private:
auto try_connect(const typename config_type::lazy& cfg,
const std::string& host, uint16_t port) {
auto result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
for (size_t i = 1; i <= cfg.max_retry_count; ++i) {
std::this_thread::sleep_for(cfg.retry_delay);
result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
}
return result;
}
template <class Conn, class OnStart>
start_res_t do_start_impl(config_type& cfg, net::web_socket::handshake& hs,
Conn conn, OnStart& on_start) {
expected<disposable>
do_start_impl(config_type& cfg, Conn conn, OnStart& on_start) {
// s2a: socket-to-application (and a2s is the inverse).
using input_t = typename Trait::input_type;
using output_t = typename Trait::output_type;
......@@ -128,7 +116,7 @@ private:
auto bridge = bridge_t::make(cfg.mpx, std::move(a2s_pull),
std::move(s2a_push));
auto bridge_ptr = bridge.get();
auto impl = client::make(std::move(hs), std::move(bridge));
auto impl = client::make(std::move(cfg.hs), std::move(bridge));
auto fd = conn.fd();
auto transport = transport_t::make(std::move(conn), std::move(impl));
transport->active_policy().connect(fd);
......@@ -136,66 +124,112 @@ private:
bridge_ptr->self_ref(ptr->as_disposable());
cfg.mpx->start(ptr);
on_start(std::move(s2a_pull), std::move(a2s_push));
return start_res_t{disposable{std::move(ptr)}};
return expected<disposable>{disposable{std::move(ptr)}};
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, const std::string& host,
uint16_t port, std::string path, OnStart& on_start) {
net::web_socket::handshake hs;
hs.host(host);
hs.endpoint(path);
auto fd = try_connect(cfg, host, port);
if (!fd) {
cfg.call_on_error(fd.error());
return start_res_t{std::move(fd.error())};
}
if (cfg.ctx) {
auto conn = cfg.ctx->new_connection(*fd);
if (conn)
return do_start_impl(cfg, hs, std::move(*conn), on_start);
cfg.call_on_error(conn.error());
return start_res_t{std::move(conn.error())};
expected<void> sanity_check(config_type& cfg) {
if (cfg.hs.has_mandatory_fields()) {
return {};
} else {
auto err = make_error(sec::invalid_argument,
"WebSocket handshake lacks mandatory fields");
cfg.call_on_error(err);
return {std::move(err)};
}
return do_start_impl(cfg, hs, *fd, on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, OnStart& on_start) {
if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server))
return do_start(cfg, st->host, st->port, "/", on_start);
const auto& addr = std::get<uri>(cfg.server);
if (addr.scheme() != "ws" && addr.scheme() != "wss") {
return make_error(sec::invalid_argument, "URI must use ws or wss scheme");
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data,
dsl::server_address& addr, OnStart& on_start) {
cfg.hs.host(addr.host);
return detail::tcp_try_connect(std::move(addr.host), addr.port,
data.connection_timeout,
data.max_retry_count, data.retry_delay)
.and_then(data.with_ctx([this, &cfg, &on_start](auto& conn) {
return this->do_start_impl(cfg, std::move(conn), on_start);
}));
}
template <class OnStart>
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data, const uri& addr,
OnStart& on_start) {
const auto& auth = addr.authority();
auto host = auth.host_str();
auto port = auth.port;
// Sanity checking.
if (host.empty()) {
auto err = make_error(sec::invalid_argument,
"URI must provide a valid hostname");
return do_start(cfg, err, on_start);
}
if (addr.scheme() == "ws" && cfg.ctx) {
return make_error(sec::logic_error, "found SSL config with scheme ws");
if (addr.scheme() == "ws") {
if (data.ctx) {
auto err = make_error(sec::logic_error,
"found SSL config with scheme ws");
return do_start(cfg, err, on_start);
}
if (port == 0)
port = defaults::net::http_default_port;
} else if (addr.scheme() == "wss") {
if (port == 0)
port = defaults::net::https_default_port;
if (!data.ctx) { // Auto-initialize SSL context for wss.
auto ctx = ssl::context::make_client(ssl::tls::v1_0);
if (!ctx)
return do_start(cfg, ctx.error(), on_start);
data.ctx = std::make_shared<ssl::context>(std::move(*ctx));
}
} else {
auto err = make_error(sec::invalid_argument,
"URI must use ws or wss scheme");
return do_start(cfg, err, on_start);
}
auto port = addr.authority().port;
return do_start(cfg, std::string{addr.host_str()}, port == 0 ? 80 : port,
addr.path_query_fragment(), on_start);
// Fill the handshake with fields from the URI.
cfg.hs.host(host);
cfg.hs.endpoint(addr.path_query_fragment());
// Try to connect.
return detail::tcp_try_connect(std::move(host), port,
data.connection_timeout,
data.max_retry_count, data.retry_delay)
.and_then(data.with_ctx([this, &cfg, &on_start](auto& conn) {
return this->do_start_impl(cfg, std::move(conn), on_start);
}));
}
template <class OnStart>
start_res_t do_start(typename config_type::socket& cfg, OnStart&) {
auto err = make_error(sec::logic_error, "not implemented yet");
cfg.call_on_error(err);
return start_res_t{std::move(err)};
// return do_start_impl(cfg, cfg.take_fd(), on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data,
OnStart& on_start) {
auto fn = [this, &cfg, &data, &on_start](auto& addr) {
return this->do_start(cfg, data, addr, on_start);
};
return std::visit(fn, data.server);
}
template <class OnStart>
start_res_t do_start(typename config_type::conn& cfg, OnStart&) {
auto err = make_error(sec::logic_error, "not implemented yet");
cfg.call_on_error(err);
return start_res_t{std::move(err)};
// return do_start_impl(cfg, std::move(cfg.state), on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::socket& data,
OnStart& on_start) {
return sanity_check(cfg).and_then([&] { //
return do_start_impl(cfg, data.take_fd(), on_start);
});
}
template <class OnStart>
start_res_t do_start(typename config_type::fail& cfg, OnStart&) {
cfg.call_on_error(cfg.err);
return start_res_t{std::move(cfg.err)};
expected<disposable> do_start(config_type& cfg,
dsl::client_config::conn& data,
OnStart& on_start) {
return sanity_check(cfg).and_then([&] { //
return do_start_impl(cfg, std::move(data.state), on_start);
});
}
template <class OnStart>
expected<disposable> do_start(config_type& cfg, error& err, OnStart&) {
cfg.call_on_error(err);
return expected<disposable>{std::move(err)};
}
};
......
......@@ -86,6 +86,11 @@ public:
fields_["_endpoint"] = std::move(value);
}
/// Checks whether the handshake has an `endpoint` defined.
bool has_endpoint() const noexcept {
return fields_.contains("_endpoint");
}
/// Sets a value for the mandatory `Host` field.
/// @param value The Internet host and port number of the resource being
/// requested, as obtained from the original URI given by the
......@@ -94,6 +99,11 @@ public:
fields_["_host"] = std::move(value);
}
/// Checks whether the handshake has an `host` defined.
bool has_host() const noexcept {
return fields_.contains("_host");
}
/// Sets a value for the optional `Origin` field.
/// @param value Indicates where the GET request originates from. Usually only
/// sent by browser clients.
......
......@@ -51,7 +51,7 @@ public:
/// @private
template <class T, class... Ts>
auto make(dsl::client_config_token<T> token, Ts&&... xs) {
auto make(dsl::client_config_tag<T> token, Ts&&... xs) {
return client_factory<Trait>{token, std::forward<Ts>(xs)...};
}
......
......@@ -104,6 +104,7 @@ void socket_manager::schedule(action what) {
}
void socket_manager::shutdown() {
CAF_LOG_TRACE("");
if (!shutting_down_) {
shutting_down_ = true;
dispose();
......@@ -122,6 +123,7 @@ error socket_manager::start() {
CAF_LOG_TRACE("");
if (auto err = nonblocking(fd_, true)) {
CAF_LOG_ERROR("failed to set nonblocking flag in socket:" << err);
handler_->abort(err);
cleanup();
return err;
} else if (err = handler_->start(this); err) {
......@@ -148,6 +150,7 @@ void socket_manager::handle_write_event() {
}
void socket_manager::handle_error(sec code) {
CAF_LOG_TRACE("");
if (!disposed_)
disposed_ = true;
if (handler_) {
......
......@@ -182,3 +182,25 @@ expected<tcp_stream_socket> make_connected_tcp_stream_socket(std::string host,
}
} // namespace caf::net
namespace caf::detail {
expected<net::tcp_stream_socket>
tcp_try_connect(std::string host, uint16_t port, timespan connection_timeout,
size_t max_retry_count, timespan retry_delay) {
uri::authority_type auth;
auth.host = std::move(host);
auth.port = port;
auto result = net::make_connected_tcp_stream_socket(auth, connection_timeout);
if (result)
return result;
for (size_t i = 1; i <= max_retry_count; ++i) {
std::this_thread::sleep_for(retry_delay);
result = net::make_connected_tcp_stream_socket(auth, connection_timeout);
if (result)
return result;
}
return result;
}
} // namespace caf::detail
......@@ -67,7 +67,7 @@ void handshake::randomize_key(unsigned seed) {
}
bool handshake::has_mandatory_fields() const noexcept {
return fields_.contains("_endpoint") && fields_.contains("_host");
return has_endpoint() && has_host();
}
// -- HTTP generation and validation -------------------------------------------
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment