Commit c0663ea7 authored by Dominik Charousset's avatar Dominik Charousset

Simplify client_config scaffolding

parent 5310b68e
......@@ -154,4 +154,10 @@ namespace caf::defaults::net {
/// previous connection has been closed.
constexpr auto max_connections = make_parameter("max-connections", size_t{64});
/// The default port for HTTP servers.
constexpr uint16_t http_default_port = 80;
/// The default port for HTTPS servers.
constexpr uint16_t https_default_port = 443;
} // namespace caf::defaults::net
This diff is collapsed.
......@@ -20,22 +20,18 @@ namespace caf::net::dsl {
template <class ConfigBase, class Derived>
class client_factory_base {
public:
using config_type = client_config<ConfigBase>;
using config_type = client_config_value<ConfigBase>;
using trait_type = typename config_type::trait_type;
using config_pointer = intrusive_ptr<config_type>;
explicit client_factory_base(config_pointer cfg) : cfg_(std::move(cfg)) {
// nop
}
client_factory_base(const client_factory_base&) = default;
client_factory_base& operator=(const client_factory_base&) = default;
template <class T, class... Ts>
explicit client_factory_base(dsl::client_config_token<T> token, Ts&&... xs) {
explicit client_factory_base(dsl::client_config_tag<T> token, Ts&&... xs) {
cfg_ = config_type::make(token, std::forward<Ts>(xs)...);
}
......@@ -52,8 +48,8 @@ public:
/// @param value The new retry delay.
/// @returns a reference to this `client_factory`.
Derived& retry_delay(timespan value) {
if (auto* cfg = get_if<lazy_client_config<ConfigBase>>(cfg_.get()))
cfg->retry_delay = value;
if (auto* lazy = get_if<client_config::lazy>(&cfg_->data))
lazy->retry_delay = value;
return dref();
}
......@@ -62,8 +58,8 @@ public:
/// @param value The new connection timeout.
/// @returns a reference to this `client_factory`.
Derived& connection_timeout(timespan value) {
if (auto* cfg = get_if<lazy_client_config<ConfigBase>>(cfg_.get()))
cfg->connection_timeout = value;
if (auto* lazy = get_if<client_config::lazy>(&cfg_->data))
lazy->connection_timeout = value;
return dref();
}
......@@ -72,8 +68,8 @@ public:
/// @param value The new maximum retry count.
/// @returns a reference to this `client_factory`.
Derived& max_retry_count(size_t value) {
if (auto* cfg = get_if<lazy_client_config<ConfigBase>>(cfg_.get()))
cfg->max_retry_count = value;
if (auto* lazy = get_if<client_config::lazy>(&cfg_->data))
lazy->max_retry_count = value;
return dref();
}
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
namespace caf::net::dsl {
enum class client_config_type;
template <class Base>
class client_config;
template <class Base>
class lazy_client_config;
template <class Base>
class socket_client_config;
template <class Base>
class conn_client_config;
template <class Base>
class fail_client_config;
} // namespace caf::net::dsl
......@@ -27,8 +27,8 @@ public:
/// @returns a `connect_factory` object initialized with the given parameters.
auto connect(std::string host, uint16_t port) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_lazy_v, std::move(host), port, this->mpx(),
this->trait());
return dref.make(client_config::lazy_v, this->mpx(), this->trait(),
std::move(host), port);
}
/// Creates a `connect_factory` object for the given stream `fd`.
......@@ -36,7 +36,7 @@ public:
/// @returns a `connect_factory` object that will use the given socket.
auto connect(stream_socket fd) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_socket_v, fd, this->mpx(), this->trait());
return dref.make(client_config::socket_v, this->mpx(), this->trait(), fd);
}
/// Creates a `connect_factory` object for the given SSL `connection`.
......@@ -44,8 +44,8 @@ public:
/// @returns a `connect_factory` object that will use the given connection.
auto connect(ssl::connection conn) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_conn_v, std::move(conn), this->mpx(),
this->trait());
return dref.make(client_config::conn_v, this->mpx(), this->trait(),
std::move(conn));
}
};
......
......@@ -31,8 +31,8 @@ public:
/// @returns a `connect_factory` object initialized with the given parameters.
auto connect(const uri& endpoint) {
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_lazy_v, endpoint, this->mpx(),
this->trait());
return dref.make(client_config::lazy_v, this->mpx(), this->trait(),
endpoint);
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
......@@ -43,8 +43,8 @@ public:
if (endpoint)
return connect(*endpoint);
auto& dref = static_cast<Subtype&>(*this);
return dref.make(client_config_fail_v, std::move(endpoint.error()),
this->mpx(), this->trait());
return dref.make(client_config::fail_v, this->mpx(), this->trait(),
std::move(endpoint.error()));
}
};
......
......@@ -45,6 +45,12 @@ public:
std::move(push));
}
void abort(const error& err) override {
super::abort(err);
if (push_)
push_.abort(err);
}
error start(net::binary::lower_layer* down_ptr) override {
super::down_ = down_ptr;
return super::init(std::move(pull_), std::move(push_));
......@@ -75,39 +81,21 @@ public:
using super::super;
using start_res_t = expected<disposable>;
/// Starts a connection with the length-prefixing protocol.
template <class OnStart>
[[nodiscard]] expected<disposable> start(OnStart on_start) {
using input_res_t = typename Trait::input_resource;
using output_res_t = typename Trait::output_resource;
static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>);
auto f = [this, &on_start](auto& cfg) {
return this->do_start(cfg, on_start);
};
return visit(f, this->config());
return super::config().visit([this, &on_start](auto& data) {
return do_start(super::config(), data, on_start);
});
}
private:
auto try_connect(const typename config_type::lazy& cfg,
const std::string& host, uint16_t port) {
auto result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
for (size_t i = 1; i <= cfg.max_retry_count; ++i) {
std::this_thread::sleep_for(cfg.retry_delay);
result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
}
return result;
}
template <class Conn, class OnStart>
start_res_t do_start_impl(config_type& cfg, Conn conn, OnStart& on_start) {
expected<disposable>
do_start_impl(config_type& cfg, Conn conn, OnStart& on_start) {
// s2a: socket-to-application (and a2s is the inverse).
using input_t = typename Trait::input_type;
using output_t = typename Trait::output_type;
......@@ -126,51 +114,45 @@ private:
bridge_ptr->self_ref(ptr->as_disposable());
cfg.mpx->start(ptr);
on_start(std::move(s2a_pull), std::move(a2s_push));
return start_res_t{disposable{std::move(ptr)}};
return expected<disposable>{disposable{std::move(ptr)}};
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, const std::string& host,
uint16_t port, OnStart& on_start) {
auto fd = try_connect(cfg, host, port);
if (fd) {
if (cfg.ctx) {
auto conn = cfg.ctx->new_connection(*fd);
if (conn)
return do_start_impl(cfg, std::move(*conn), on_start);
cfg.call_on_error(conn.error());
return start_res_t{std::move(conn.error())};
}
return do_start_impl(cfg, *fd, on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data,
OnStart& on_start) {
if (std::holds_alternative<uri>(data.server)) {
auto err = make_error(sec::invalid_argument,
"length-prefix factories do not accept URIs");
return do_start(cfg, err, on_start);
}
cfg.call_on_error(fd.error());
return start_res_t{std::move(fd.error())};
auto& addr = std::get<dsl::server_address>(data.server);
return detail::tcp_try_connect(std::move(addr.host), addr.port,
data.connection_timeout,
data.max_retry_count, data.retry_delay)
.and_then(data.with_ctx([this, &cfg, &on_start](auto& conn) {
return this->do_start_impl(cfg, std::move(conn), on_start);
}));
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, OnStart& on_start) {
if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server))
return do_start(cfg, st->host, st->port, on_start);
auto err = make_error(sec::invalid_argument,
"length-prefix factories do not accept URIs");
cfg.call_on_error(err);
return start_res_t{std::move(err)};
expected<disposable> do_start(config_type& cfg,
dsl::client_config::socket& data,
OnStart& on_start) {
return do_start_impl(cfg, data.take_fd(), on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::socket& cfg, OnStart& on_start) {
return do_start_impl(cfg, cfg.take_fd(), on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::conn& data,
OnStart& on_start) {
return do_start_impl(cfg, std::move(data.state), on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::conn& cfg, OnStart& on_start) {
return do_start_impl(cfg, std::move(cfg.state), on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::fail& cfg, OnStart&) {
cfg.call_on_error(cfg.err);
return start_res_t{std::move(cfg.err)};
expected<disposable> do_start(config_type& cfg, error& err, OnStart&) {
cfg.call_on_error(err);
return expected<disposable>{std::move(err)};
}
};
......
......@@ -52,8 +52,8 @@ public:
/// @private
template <class T, class... Ts>
auto make(dsl::client_config_token<T> token, Ts&&... xs) {
return client_factory<Trait>{token, std::forward<Ts>(xs)...};
auto make(dsl::client_config_tag<T> tag, Ts&&... xs) {
return client_factory<Trait>{tag, std::forward<Ts>(xs)...};
}
private:
......
......@@ -9,7 +9,7 @@
namespace caf::net::ssl {
/// Configures the allowed DTLS versions on a @ref context.
/// Configures the allowed TLS versions on a @ref context.
enum class tls {
any,
v1_0,
......
......@@ -60,3 +60,11 @@ make_connected_tcp_stream_socket(std::string host, uint16_t port,
timespan timeout = infinite);
} // namespace caf::net
namespace caf::detail {
expected<net::tcp_stream_socket> CAF_NET_EXPORT //
tcp_try_connect(std::string host, uint16_t port, timespan connection_timeout,
size_t max_retry_count, timespan retry_delay);
} // namespace caf::detail
......@@ -61,12 +61,19 @@ private:
namespace caf::net::web_socket {
/// Configuration type for WebSocket clients with a handshake object. The
/// handshake object sets the default endpoint to '/' for convenience.
template <class Trait>
class client_factory_config : public dsl::config_with_trait<Trait> {
public:
using super = dsl::config_with_trait<Trait>;
using super::super;
client_factory_config(multiplexer* mpx, Trait trait)
: super(mpx, std::move(trait)) {
hs.endpoint("/");
}
client_factory_config(const client_factory_config&) = default;
handshake hs;
};
......@@ -82,9 +89,7 @@ public:
using super::super;
using start_res_t = expected<disposable>;
using config_type = dsl::client_config<client_factory_config<Trait>>;
using config_type = typename super::config_type;
/// Starts a connection with the length-prefixing protocol.
template <class OnStart>
......@@ -92,32 +97,15 @@ public:
using input_res_t = typename Trait::input_resource;
using output_res_t = typename Trait::output_resource;
static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>);
auto f = [this, &on_start](auto& cfg) {
return this->do_start(cfg, on_start);
};
return visit(f, this->config());
return super::config().visit([this, &on_start](auto& data) {
return do_start(super::config(), data, on_start);
});
}
private:
auto try_connect(const typename config_type::lazy& cfg,
const std::string& host, uint16_t port) {
auto result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
for (size_t i = 1; i <= cfg.max_retry_count; ++i) {
std::this_thread::sleep_for(cfg.retry_delay);
result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
}
return result;
}
template <class Conn, class OnStart>
start_res_t do_start_impl(config_type& cfg, net::web_socket::handshake& hs,
Conn conn, OnStart& on_start) {
expected<disposable>
do_start_impl(config_type& cfg, Conn conn, OnStart& on_start) {
// s2a: socket-to-application (and a2s is the inverse).
using input_t = typename Trait::input_type;
using output_t = typename Trait::output_type;
......@@ -128,7 +116,7 @@ private:
auto bridge = bridge_t::make(cfg.mpx, std::move(a2s_pull),
std::move(s2a_push));
auto bridge_ptr = bridge.get();
auto impl = client::make(std::move(hs), std::move(bridge));
auto impl = client::make(std::move(cfg.hs), std::move(bridge));
auto fd = conn.fd();
auto transport = transport_t::make(std::move(conn), std::move(impl));
transport->active_policy().connect(fd);
......@@ -136,66 +124,112 @@ private:
bridge_ptr->self_ref(ptr->as_disposable());
cfg.mpx->start(ptr);
on_start(std::move(s2a_pull), std::move(a2s_push));
return start_res_t{disposable{std::move(ptr)}};
return expected<disposable>{disposable{std::move(ptr)}};
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, const std::string& host,
uint16_t port, std::string path, OnStart& on_start) {
net::web_socket::handshake hs;
hs.host(host);
hs.endpoint(path);
auto fd = try_connect(cfg, host, port);
if (!fd) {
cfg.call_on_error(fd.error());
return start_res_t{std::move(fd.error())};
}
if (cfg.ctx) {
auto conn = cfg.ctx->new_connection(*fd);
if (conn)
return do_start_impl(cfg, hs, std::move(*conn), on_start);
cfg.call_on_error(conn.error());
return start_res_t{std::move(conn.error())};
expected<void> sanity_check(config_type& cfg) {
if (cfg.hs.has_mandatory_fields()) {
return {};
} else {
auto err = make_error(sec::invalid_argument,
"WebSocket handshake lacks mandatory fields");
cfg.call_on_error(err);
return {std::move(err)};
}
return do_start_impl(cfg, hs, *fd, on_start);
}
template <class OnStart>
start_res_t do_start(typename config_type::lazy& cfg, OnStart& on_start) {
if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server))
return do_start(cfg, st->host, st->port, "/", on_start);
const auto& addr = std::get<uri>(cfg.server);
if (addr.scheme() != "ws" && addr.scheme() != "wss") {
return make_error(sec::invalid_argument, "URI must use ws or wss scheme");
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data,
dsl::server_address& addr, OnStart& on_start) {
cfg.hs.host(addr.host);
return detail::tcp_try_connect(std::move(addr.host), addr.port,
data.connection_timeout,
data.max_retry_count, data.retry_delay)
.and_then(data.with_ctx([this, &cfg, &on_start](auto& conn) {
return this->do_start_impl(cfg, std::move(conn), on_start);
}));
}
template <class OnStart>
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data, const uri& addr,
OnStart& on_start) {
const auto& auth = addr.authority();
auto host = auth.host_str();
auto port = auth.port;
// Sanity checking.
if (host.empty()) {
auto err = make_error(sec::invalid_argument,
"URI must provide a valid hostname");
return do_start(cfg, err, on_start);
}
if (addr.scheme() == "ws" && cfg.ctx) {
return make_error(sec::logic_error, "found SSL config with scheme ws");
if (addr.scheme() == "ws") {
if (data.ctx) {
auto err = make_error(sec::logic_error,
"found SSL config with scheme ws");
return do_start(cfg, err, on_start);
}
if (port == 0)
port = defaults::net::http_default_port;
} else if (addr.scheme() == "wss") {
if (port == 0)
port = defaults::net::https_default_port;
if (!data.ctx) { // Auto-initialize SSL context for wss.
auto ctx = ssl::context::make_client(ssl::tls::v1_0);
if (!ctx)
return do_start(cfg, ctx.error(), on_start);
data.ctx = std::make_shared<ssl::context>(std::move(*ctx));
}
} else {
auto err = make_error(sec::invalid_argument,
"URI must use ws or wss scheme");
return do_start(cfg, err, on_start);
}
auto port = addr.authority().port;
return do_start(cfg, std::string{addr.host_str()}, port == 0 ? 80 : port,
addr.path_query_fragment(), on_start);
// Fill the handshake with fields from the URI.
cfg.hs.host(host);
cfg.hs.endpoint(addr.path_query_fragment());
// Try to connect.
return detail::tcp_try_connect(std::move(host), port,
data.connection_timeout,
data.max_retry_count, data.retry_delay)
.and_then(data.with_ctx([this, &cfg, &on_start](auto& conn) {
return this->do_start_impl(cfg, std::move(conn), on_start);
}));
}
template <class OnStart>
start_res_t do_start(typename config_type::socket& cfg, OnStart&) {
auto err = make_error(sec::logic_error, "not implemented yet");
cfg.call_on_error(err);
return start_res_t{std::move(err)};
// return do_start_impl(cfg, cfg.take_fd(), on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::lazy& data,
OnStart& on_start) {
auto fn = [this, &cfg, &data, &on_start](auto& addr) {
return this->do_start(cfg, data, addr, on_start);
};
return std::visit(fn, data.server);
}
template <class OnStart>
start_res_t do_start(typename config_type::conn& cfg, OnStart&) {
auto err = make_error(sec::logic_error, "not implemented yet");
cfg.call_on_error(err);
return start_res_t{std::move(err)};
// return do_start_impl(cfg, std::move(cfg.state), on_start);
expected<disposable> do_start(config_type& cfg,
dsl::client_config::socket& data,
OnStart& on_start) {
return sanity_check(cfg).and_then([&] { //
return do_start_impl(cfg, data.take_fd(), on_start);
});
}
template <class OnStart>
start_res_t do_start(typename config_type::fail& cfg, OnStart&) {
cfg.call_on_error(cfg.err);
return start_res_t{std::move(cfg.err)};
expected<disposable> do_start(config_type& cfg,
dsl::client_config::conn& data,
OnStart& on_start) {
return sanity_check(cfg).and_then([&] { //
return do_start_impl(cfg, std::move(data.state), on_start);
});
}
template <class OnStart>
expected<disposable> do_start(config_type& cfg, error& err, OnStart&) {
cfg.call_on_error(err);
return expected<disposable>{std::move(err)};
}
};
......
......@@ -86,6 +86,11 @@ public:
fields_["_endpoint"] = std::move(value);
}
/// Checks whether the handshake has an `endpoint` defined.
bool has_endpoint() const noexcept {
return fields_.contains("_endpoint");
}
/// Sets a value for the mandatory `Host` field.
/// @param value The Internet host and port number of the resource being
/// requested, as obtained from the original URI given by the
......@@ -94,6 +99,11 @@ public:
fields_["_host"] = std::move(value);
}
/// Checks whether the handshake has an `host` defined.
bool has_host() const noexcept {
return fields_.contains("_host");
}
/// Sets a value for the optional `Origin` field.
/// @param value Indicates where the GET request originates from. Usually only
/// sent by browser clients.
......
......@@ -51,7 +51,7 @@ public:
/// @private
template <class T, class... Ts>
auto make(dsl::client_config_token<T> token, Ts&&... xs) {
auto make(dsl::client_config_tag<T> token, Ts&&... xs) {
return client_factory<Trait>{token, std::forward<Ts>(xs)...};
}
......
......@@ -104,6 +104,7 @@ void socket_manager::schedule(action what) {
}
void socket_manager::shutdown() {
CAF_LOG_TRACE("");
if (!shutting_down_) {
shutting_down_ = true;
dispose();
......@@ -122,6 +123,7 @@ error socket_manager::start() {
CAF_LOG_TRACE("");
if (auto err = nonblocking(fd_, true)) {
CAF_LOG_ERROR("failed to set nonblocking flag in socket:" << err);
handler_->abort(err);
cleanup();
return err;
} else if (err = handler_->start(this); err) {
......@@ -148,6 +150,7 @@ void socket_manager::handle_write_event() {
}
void socket_manager::handle_error(sec code) {
CAF_LOG_TRACE("");
if (!disposed_)
disposed_ = true;
if (handler_) {
......
......@@ -182,3 +182,25 @@ expected<tcp_stream_socket> make_connected_tcp_stream_socket(std::string host,
}
} // namespace caf::net
namespace caf::detail {
expected<net::tcp_stream_socket>
tcp_try_connect(std::string host, uint16_t port, timespan connection_timeout,
size_t max_retry_count, timespan retry_delay) {
uri::authority_type auth;
auth.host = std::move(host);
auth.port = port;
auto result = net::make_connected_tcp_stream_socket(auth, connection_timeout);
if (result)
return result;
for (size_t i = 1; i <= max_retry_count; ++i) {
std::this_thread::sleep_for(retry_delay);
result = net::make_connected_tcp_stream_socket(auth, connection_timeout);
if (result)
return result;
}
return result;
}
} // namespace caf::detail
......@@ -67,7 +67,7 @@ void handshake::randomize_key(unsigned seed) {
}
bool handshake::has_mandatory_fields() const noexcept {
return fields_.contains("_endpoint") && fields_.contains("_host");
return has_endpoint() && has_host();
}
// -- HTTP generation and validation -------------------------------------------
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment