Commit 2a163af7 authored by Dominik Charousset's avatar Dominik Charousset

Factor out with(...) DSL scaffolding

parent 5c897919
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/callback.hpp"
#include "caf/defaults.hpp"
#include "caf/detail/plain_ref_counted.hpp"
#include "caf/intrusive_ptr.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/connection.hpp"
#include "caf/net/ssl/context.hpp"
#include "caf/net/stream_socket.hpp"
#include "caf/uri.hpp"
#include <cassert>
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// The server config type enum class.
enum class client_config_type { lazy, socket, conn, fail };
/// Base class for server configuration objects.
template <class Trait>
class client_config : public detail::plain_ref_counted {
public:
class lazy;
class socket;
class conn;
class fail;
friend class lazy;
friend class socket;
friend class conn;
friend class fail;
client_config(const client_config&) = delete;
client_config& operator=(const client_config&) = delete;
/// Virtual destructor.
virtual ~client_config() = default;
/// Returns the server configuration type.
virtual client_config_type type() const noexcept = 0;
/// The pointer to the @ref multiplexer for running the server.
multiplexer* mpx;
/// The user-defined trait for configuration serialization.
Trait trait;
/// User-defined callback for errors.
shared_callback_ptr<void(const error&)> on_error;
/// Calls `on_error` if non-null.
void call_on_error(const error& what) {
if (on_error)
(*on_error)(what);
}
friend void intrusive_ptr_add_ref(const client_config* ptr) noexcept {
ptr->ref();
}
friend void intrusive_ptr_release(const client_config* ptr) noexcept {
ptr->deref();
}
private:
/// Private constructor to enforce sealing.
client_config(multiplexer* mpx, const Trait& trait) : mpx(mpx), trait(trait) {
// nop
}
};
/// Intrusive pointer type for server configurations.
template <class Trait>
using client_config_ptr = intrusive_ptr<client_config<Trait>>;
/// Simple type for storing host and port information for reaching a server.
struct client_config_server_address {
/// The host name or IP address of the host.
std::string host;
/// The port to connect to.
uint16_t port;
};
/// Configuration for a client that creates the socket on demand.
template <class Trait>
class client_config<Trait>::lazy final : public client_config<Trait> {
public:
static constexpr auto type_token = client_config_type::lazy;
using super = client_config;
lazy(multiplexer* mpx, const Trait& trait, std::string host, uint16_t port)
: super(mpx, trait) {
server = client_config_server_address{std::move(host), port};
}
lazy(multiplexer* mpx, const Trait& trait, const uri& addr)
: super(mpx, trait) {
server = addr;
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// Type for holding a client address.
using server_t = std::variant<client_config_server_address, uri>;
/// The address for reaching the server or an error.
server_t server;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
/// The delay between connection attempts.
timespan retry_delay = std::chrono::seconds{1};
/// The timeout when trying to connect.
timespan connection_timeout = infinite;
/// The maximum amount of retries.
size_t max_retry_count = 0;
};
/// Configuration for a client that uses a user-provided socket.
template <class Trait>
class client_config<Trait>::socket final : public client_config<Trait> {
public:
static constexpr auto type_token = client_config_type::socket;
using super = client_config;
socket(multiplexer* mpx, const Trait& trait) : super(mpx, trait) {
// nop
}
socket(multiplexer* mpx, const Trait& trait, stream_socket fd)
: super(mpx, trait), fd(fd) {
// nop
}
~socket() override {
if (fd != invalid_socket)
close(fd);
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// The socket file descriptor to use.
stream_socket fd;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
/// Returns the file descriptor and setting the `fd` member variable to the
/// invalid socket.
stream_socket take_fd() noexcept {
auto result = fd;
fd.id = invalid_socket_id;
return result;
}
};
/// Configuration for a client that uses an already established SSL connection.
template <class Trait>
class client_config<Trait>::conn final : public client_config<Trait> {
public:
static constexpr auto type_token = client_config_type::conn;
using super = client_config;
conn(multiplexer* mpx, const Trait& trait, ssl::connection state)
: super(mpx, trait), state(std::move(state)) {
// nop
}
~conn() override {
if (state) {
if (auto fd = state.fd(); fd != invalid_socket)
close(fd);
}
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// SSL state for the connection.
ssl::connection state;
};
/// Wraps an error that occurred earlier in the setup phase.
template <class Trait>
class client_config<Trait>::fail final : public client_config<Trait> {
public:
static constexpr auto type_token = client_config_type::fail;
using super = client_config;
fail(multiplexer* mpx, const Trait& trait, error err)
: super(mpx, trait), err(std::move(err)) {
// nop
}
/// Returns the server configuration type.
client_config_type type() const noexcept override {
return type_token;
}
/// The forwarded error.
error err;
};
/// Convenience alias for the `lazy` sub-type of @ref client_config.
template <class Trait>
using lazy_client_config = typename client_config<Trait>::lazy;
/// Convenience alias for the `socket` sub-type of @ref client_config.
template <class Trait>
using socket_client_config = typename client_config<Trait>::socket;
/// Convenience alias for the `conn` sub-type of @ref client_config.
template <class Trait>
using conn_client_config = typename client_config<Trait>::conn;
/// Convenience alias for the `fail` sub-type of @ref client_config.
template <class Trait>
using fail_client_config = typename client_config<Trait>::fail;
/// Calls a function object with the actual subtype of a client configuration
/// and returns its result.
template <class F, class Trait>
decltype(auto) visit(F&& f, client_config<Trait>& cfg) {
auto type = cfg.type();
switch (cfg.type()) {
case client_config_type::lazy:
return f(static_cast<lazy_client_config<Trait>&>(cfg));
case client_config_type::socket:
return f(static_cast<socket_client_config<Trait>&>(cfg));
case client_config_type::conn:
return f(static_cast<conn_client_config<Trait>&>(cfg));
default:
assert(type == client_config_type::fail);
return f(static_cast<fail_client_config<Trait>&>(cfg));
}
}
/// Calls a function object with the actual subtype of a client configuration
/// and returns its result.
template <class F, class Trait>
decltype(auto) visit(F&& f, const client_config<Trait>& cfg) {
auto type = cfg.type();
switch (cfg.type()) {
case client_config_type::lazy:
return f(static_cast<const lazy_client_config<Trait>&>(cfg));
case client_config_type::socket:
return f(static_cast<const socket_client_config<Trait>&>(cfg));
case client_config_type::conn:
return f(static_cast<const conn_client_config<Trait>&>(cfg));
default:
assert(type == client_config_type::fail);
return f(static_cast<const fail_client_config<Trait>&>(cfg));
}
}
/// Gets a pointer to a specific subtype of a client configuration.
template <class T, class Trait>
T* get_if(client_config<Trait>* config) {
if (T::type_token == config->type())
return static_cast<T*>(config);
return nullptr;
}
/// Gets a pointer to a specific subtype of a client configuration.
template <class T, class Trait>
const T* get_if(const client_config<Trait>* config) {
if (T::type_token == config->type())
return static_cast<const T*>(config);
return nullptr;
}
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/make_counted.hpp"
#include "caf/net/dsl/client_config.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/acceptor.hpp"
#include "caf/net/tcp_accept_socket.hpp"
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// Base type for client factories for use with `can_connect`.
template <class Trait, class Derived>
class client_factory_base {
public:
using trait_type = Trait;
explicit client_factory_base(client_config_ptr<Trait> cfg)
: cfg_(std::move(cfg)) {
// nop
}
client_factory_base(const client_factory_base&) = default;
client_factory_base& operator=(const client_factory_base&) = default;
/// Sets the callback for errors.
template <class F>
Derived& do_on_error(F callback) {
static_assert(std::is_invocable_v<F, const error&>);
cfg_->on_error = make_shared_type_erased_callback(std::move(callback));
return dref();
}
/// Sets the retry delay for connection attempts.
///
/// @param value The new retry delay.
/// @returns a reference to this `client_factory`.
Derived& retry_delay(timespan value) {
if (auto* cfg = std::get_if<lazy_client_config<Trait>>(&cfg_.get()))
cfg->retry_delay = value;
return dref();
}
/// Sets the connection timeout for connection attempts.
///
/// @param value The new connection timeout.
/// @returns a reference to this `client_factory`.
Derived& connection_timeout(timespan value) {
if (auto* cfg = std::get_if<lazy_client_config<Trait>>(&cfg_.get()))
cfg->connection_timeout = value;
return dref();
}
/// Sets the maximum number of connection retry attempts.
///
/// @param value The new maximum retry count.
/// @returns a reference to this `client_factory`.
Derived& max_retry_count(size_t value) {
if (auto* cfg = std::get_if<lazy_client_config<Trait>>(&cfg_.get()))
cfg->max_retry_count = value;
return dref();
}
client_config<Trait>& config() {
return *cfg_;
}
private:
Derived& dref() {
return static_cast<Derived&>(*this);
}
client_config_ptr<Trait> cfg_;
};
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/make_counted.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/dsl/server_config.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/acceptor.hpp"
#include "caf/net/tcp_accept_socket.hpp"
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// DSL entry point for creating a server.
template <class ServerFactory>
class has_accept : public has_trait<typename ServerFactory::trait_type> {
public:
using trait_type = typename ServerFactory::trait_type;
using super = has_trait<trait_type>;
using super::super;
/// Creates an `accept_factory` object for the given TCP `port` and
/// `bind_address`.
///
/// @param port Port number to bind to.
/// @param bind_address IP address to bind to. Default is an empty string.
/// @returns an `accept_factory` object initialized with the given parameters.
ServerFactory accept(uint16_t port, std::string bind_address = "") {
auto cfg = make_lazy_config(port, std::move(bind_address));
return ServerFactory{std::move(cfg)};
}
/// Creates an `accept_factory` object for the given TCP `port` and
/// `bind_address`.
///
/// @param ctx The SSL context for encryption.
/// @param port Port number to bind to.
/// @param bind_address IP address to bind to. Default is an empty string.
/// @returns an `accept_factory` object initialized with the given parameters.
ServerFactory accept(ssl::context ctx, uint16_t port,
std::string bind_address = "") {
auto cfg = make_lazy_config(port, std::move(bind_address));
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ServerFactory{std::move(cfg)};
}
/// Creates an `accept_factory` object for the given accept socket.
///
/// @param fd File descriptor for the accept socket.
/// @returns an `accept_factory` object that will start a Prometheus server on
/// the given socket.
ServerFactory accept(tcp_accept_socket fd) {
auto cfg = make_socket_config(fd);
return ServerFactory{std::move(cfg)};
}
/// Creates an `accept_factory` object for the given acceptor.
///
/// @param ctx The SSL context for encryption.
/// @param fd File descriptor for the accept socket.
/// @returns an `accept_factory` object that will start a Prometheus server on
/// the given acceptor.
ServerFactory accept(ssl::context ctx, tcp_accept_socket fd) {
auto cfg = make_socket_config(fd);
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ServerFactory{std::move(cfg)};
}
/// Creates an `accept_factory` object for the given acceptor.
///
/// @param acc The SSL acceptor for incoming connections.
/// @returns an `accept_factory` object that will start a Prometheus server on
/// the given acceptor.
ServerFactory accept(ssl::acceptor acc) {
return accept(std::move(acc.ctx()), acc.fd());
}
private:
template <class... Ts>
server_config_ptr<trait_type> make_lazy_config(Ts&&... xs) {
using impl_t = typename server_config<trait_type>::lazy;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
template <class... Ts>
server_config_ptr<trait_type> make_socket_config(Ts&&... xs) {
using impl_t = typename server_config<trait_type>::socket;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
};
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/make_counted.hpp"
#include "caf/net/dsl/client_config.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/tcp_stream_socket.hpp"
#include "caf/uri.hpp"
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// DSL entry point for creating a server.
template <class ClientFactory>
class has_connect : public has_trait<typename ClientFactory::trait_type> {
public:
using trait_type = typename ClientFactory::trait_type;
using super = has_trait<trait_type>;
using super::super;
/// Creates a `connect_factory` object for the given TCP `host` and `port`.
///
/// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(std::string host, uint16_t port) {
auto cfg = make_lazy_config(std::move(host), port);
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given SSL `context`, TCP
/// `host`, and `port`.
///
/// @param ctx The SSL context for encryption.
/// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(ssl::context ctx, std::string host, uint16_t port) {
auto cfg = make_lazy_config(std::move(host), port);
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(const uri& endpoint) {
auto cfg = make_lazy_config(endpoint);
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given SSL `context` and TCP
/// `endpoint`.
///
/// @param ctx The SSL context for encryption.
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(ssl::context ctx, const uri& endpoint) {
auto cfg = make_lazy_config(endpoint);
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(expected<uri> endpoint) {
if (endpoint)
return connect(*endpoint);
auto cfg = make_fail_config(endpoint.error());
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given SSL `context` and TCP
/// `endpoint`.
///
/// @param ctx The SSL context for encryption.
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(ssl::context ctx, expected<uri> endpoint) {
if (endpoint)
return connect(std::move(ctx), *endpoint);
auto cfg = make_fail_config(endpoint.error());
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given stream `fd`.
///
/// @param fd The stream socket to use for the connection.
/// @returns a `connect_factory` object that will use the given socket.
ClientFactory connect(stream_socket fd) {
auto cfg = make_socket_config(fd);
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given SSL `connection`.
///
/// @param conn The SSL connection to use.
/// @returns a `connect_factory` object that will use the given connection.
ClientFactory connect(ssl::connection conn) {
auto cfg = make_conn_config(std::move(conn));
return ClientFactory{std::move(cfg)};
}
private:
template <class... Ts>
client_config_ptr<trait_type> make_lazy_config(Ts&&... xs) {
using impl_t = lazy_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
template <class... Ts>
client_config_ptr<trait_type> make_socket_config(Ts&&... xs) {
using impl_t = socket_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
template <class... Ts>
client_config_ptr<trait_type> make_conn_config(Ts&&... xs) {
using impl_t = conn_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
template <class... Ts>
client_config_ptr<trait_type> make_fail_config(Ts&&... xs) {
using impl_t = fail_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
};
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/net/fwd.hpp"
namespace caf::net::dsl {
/// Base type for DSL classes.
template <class Trait>
class has_trait {
public:
virtual ~has_trait() {
// nop
}
/// @returns the pointer to the @ref multiplexer.
virtual multiplexer* mpx() const noexcept = 0;
/// @returns the trait object.
virtual const Trait& trait() const noexcept = 0;
};
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/callback.hpp"
#include "caf/defaults.hpp"
#include "caf/detail/plain_ref_counted.hpp"
#include "caf/intrusive_ptr.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/context.hpp"
#include "caf/net/tcp_accept_socket.hpp"
#include <cassert>
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// The server config type enum class.
enum class server_config_type { lazy, socket };
/// Base class for server configuration objects.
template <class Trait>
class server_config : public detail::plain_ref_counted {
public:
class lazy;
class socket;
friend class lazy;
friend class socket;
server_config(const server_config&) = delete;
server_config& operator=(const server_config&) = delete;
/// Virtual destructor.
virtual ~server_config() = default;
/// Returns the server configuration type.
virtual server_config_type type() const noexcept = 0;
/// The pointer to the @ref multiplexer for running the server.
multiplexer* mpx;
/// The user-defined trait for configuration serialization.
Trait trait;
/// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx;
/// User-defined callback for errors.
shared_callback_ptr<void(const error&)> on_error;
/// Configures the maximum number of concurrent connections.
size_t max_connections = defaults::net::max_connections.fallback;
/// Calls `on_error` if non-null.
void call_on_error(const error& what) {
if (on_error)
(*on_error)(what);
}
friend void intrusive_ptr_add_ref(const server_config* ptr) noexcept {
ptr->ref();
}
friend void intrusive_ptr_release(const server_config* ptr) noexcept {
ptr->deref();
}
private:
/// Private constructor to enforce sealing.
server_config(multiplexer* mpx, const Trait& trait) : mpx(mpx), trait(trait) {
// nop
}
};
/// Intrusive pointer type for server configurations.
template <class Trait>
using server_config_ptr = intrusive_ptr<server_config<Trait>>;
/// Configuration for a server that creates the socket on demand.
template <class Trait>
class server_config<Trait>::lazy final : public server_config<Trait> {
public:
static constexpr auto type_token = server_config_type::lazy;
using super = server_config;
lazy(multiplexer* mpx, const Trait& trait, uint16_t port,
std::string bind_address)
: super(mpx, trait), port(port), bind_address(std::move(bind_address)) {
// nop
}
/// Returns the server configuration type.
server_config_type type() const noexcept override {
return type_token;
}
/// The port number to bind to.
uint16_t port = 0;
/// The address to bind to.
std::string bind_address;
/// Whether to set `SO_REUSEADDR` on the socket.
bool reuse_addr = true;
};
/// Configuration for a server that uses a user-provided socket.
template <class Trait>
class server_config<Trait>::socket final : public server_config<Trait> {
public:
static constexpr auto type_token = server_config_type::socket;
using super = server_config;
socket(multiplexer* mpx, const Trait& trait, tcp_accept_socket fd)
: super(mpx, trait), fd(fd) {
// nop
}
~socket() override {
if (fd != invalid_socket)
close(fd);
}
/// Returns the server configuration type.
server_config_type type() const noexcept override {
return type_token;
}
/// The socket file descriptor to use.
tcp_accept_socket fd;
/// Returns the file descriptor and setting the `fd` member variable to the
/// invalid socket.
tcp_accept_socket take_fd() noexcept {
auto result = fd;
fd.id = invalid_socket_id;
return result;
}
};
/// Convenience alias for the `lazy` sub-type of @ref server_config.
template <class Trait>
using lazy_server_config = typename server_config<Trait>::lazy;
/// Convenience alias for the `socket` sub-type of @ref server_config.
template <class Trait>
using socket_server_config = typename server_config<Trait>::socket;
/// Calls a function object with the actual subtype of a server configuration
/// and returns its result.
template <class F, class Trait>
decltype(auto) visit(F&& f, server_config<Trait>& cfg) {
auto type = cfg.type();
if (cfg.type() == server_config_type::lazy)
return f(static_cast<lazy_server_config<Trait>&>(cfg));
assert(type == server_config_type::socket);
return f(static_cast<socket_server_config<Trait>&>(cfg));
}
/// Calls a function object with the actual subtype of a server configuration.
template <class F, class Trait>
decltype(auto) visit(F&& f, const server_config<Trait>& cfg) {
auto type = cfg.type();
if (cfg.type() == server_config_type::lazy)
return f(static_cast<const lazy_server_config<Trait>&>(cfg));
assert(type == server_config_type::socket);
return f(static_cast<const socket_server_config<Trait>&>(cfg));
}
/// Gets a pointer to a specific subtype of a server configuration.
template <class T, class Trait>
T* get_if(server_config<Trait>* cfg) {
if (T::type_token == cfg->type())
return static_cast<T*>(cfg);
return nullptr;
}
/// Gets a pointer to a specific subtype of a server configuration.
template <class T, class Trait>
const T* get_if(const server_config<Trait>* cfg) {
if (T::type_token == cfg->type())
return static_cast<const T*>(cfg);
return nullptr;
}
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/make_counted.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/dsl/server_config.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/acceptor.hpp"
#include "caf/net/tcp_accept_socket.hpp"
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// Base type for server factories for use with `can_accept`.
template <class Trait, class Derived>
class server_factory_base {
public:
using trait_type = Trait;
explicit server_factory_base(server_config_ptr<Trait> cfg)
: cfg_(std::move(cfg)) {
// nop
}
server_factory_base(const server_factory_base&) = default;
server_factory_base& operator=(const server_factory_base&) = default;
/// Sets the callback for errors.
template <class F>
Derived& do_on_error(F callback) {
static_assert(std::is_invocable_v<F, const error&>);
cfg_->on_error = make_shared_type_erased_callback(std::move(callback));
return dref();
}
/// Configures how many concurrent connections the server accepts.
Derived& max_connections(size_t value) {
cfg_->max_connections = value;
return dref();
}
/// Configures whether the server creates its socket with `SO_REUSEADDR`.
Derived& reuse_addr(bool value) {
if (auto* cfg = get_if<lazy_server_config<Trait>>(cfg_.get()))
cfg->reuse_addr = value;
return dref();
}
server_config<Trait>& config() {
return *cfg_;
}
private:
Derived& dref() {
return static_cast<Derived&>(*this);
}
server_config_ptr<Trait> cfg_;
};
} // namespace caf::net::dsl
......@@ -8,7 +8,9 @@
#include "caf/detail/binary_flow_bridge.hpp"
#include "caf/detail/flow_connector.hpp"
#include "caf/disposable.hpp"
#include "caf/net/dsl/client_factory_base.hpp"
#include "caf/net/lp/framing.hpp"
#include "caf/net/ssl/connection.hpp"
#include "caf/net/tcp_stream_socket.hpp"
#include "caf/timespan.hpp"
......@@ -25,17 +27,12 @@ class with_t;
/// Factory for the `with(...).connect(...).start(...)` DSL.
template <class Trait>
class connect_factory {
class client_factory
: public dsl::client_factory_base<Trait, client_factory<Trait>> {
public:
friend class with_t<Trait>;
using super = dsl::client_factory_base<Trait, client_factory<Trait>>;
connect_factory(const connect_factory&) noexcept = delete;
connect_factory& operator=(const connect_factory&) noexcept = delete;
connect_factory(connect_factory&&) noexcept = default;
connect_factory& operator=(connect_factory&&) noexcept = default;
using super::super;
/// Starts a connection with the length-prefixing protocol.
template <class OnStart>
......@@ -43,114 +40,32 @@ public:
using input_res_t = typename Trait::input_resource;
using output_res_t = typename Trait::output_resource;
static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>);
switch (state_.index()) {
case 1: { // config
auto fd = try_connect(std::get<1>(state_));
if (fd) {
if (ctx_) {
auto conn = ctx_->new_connection(*fd);
if (conn)
return do_start(std::move(*conn), on_start);
if (do_on_error_)
do_on_error_(conn.error());
return {};
}
return do_start(*fd, on_start);
}
if (do_on_error_)
do_on_error_(fd.error());
return {};
}
case 2: { // stream_socket
// Pass ownership of the stream socket.
auto fd = std::get<2>(state_);
state_ = none;
return do_start(fd, on_start);
}
case 3: { // ssl::connection
// Pass ownership of the SSL connection.
auto conn = std::move(std::get<3>(state_));
state_ = none;
return do_start(std::move(conn), on_start);
}
case 4: // error
if (do_on_error_)
do_on_error_(std::get<4>(state_));
return {};
default:
return {};
}
}
/// Sets the retry delay for connection attempts.
///
/// @param value The new retry delay.
/// @returns a reference to this `connect_factory`.
connect_factory& retry_delay(timespan value) {
if (auto* cfg = std::get_if<config>(&state_))
cfg->retry_delay = value;
return *this;
}
/// Sets the connection timeout for connection attempts.
///
/// @param value The new connection timeout.
/// @returns a reference to this `connect_factory`.
connect_factory& connection_timeout(timespan value) {
if (auto* cfg = std::get_if<config>(&state_))
cfg->connection_timeout = value;
return *this;
}
/// Sets the maximum number of connection retry attempts.
///
/// @param value The new maximum retry count.
/// @returns a reference to this `connect_factory`.
connect_factory& max_retry_count(size_t value) {
if (auto* cfg = std::get_if<config>(&state_))
cfg->max_retry_count = value;
return *this;
}
/// Sets the callback for errors.
/// @returns a reference to this `connect_factory`.
template <class F>
connect_factory& do_on_error(F callback) {
do_on_error_ = std::move(callback);
return *this;
auto f = [this, &on_start](auto& cfg) {
return this->do_start(cfg, on_start);
};
return visit(f, this->config());
}
private:
struct config {
config(std::string address, uint16_t port)
: address(std::move(address)), port(port) {
// nop
}
std::string address;
uint16_t port;
timespan retry_delay = std::chrono::seconds{1};
timespan connection_timeout = infinite;
size_t max_retry_count = 0;
};
expected<tcp_stream_socket> try_connect(const config& cfg) {
auto result = make_connected_tcp_stream_socket(cfg.address, cfg.port,
expected<stream_socket> try_connect(const dsl::lazy_client_config<Trait>& cfg,
const std::string& host, uint16_t port) {
auto result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
return {*result};
for (size_t i = 1; i <= cfg.max_retry_count; ++i) {
std::this_thread::sleep_for(cfg.retry_delay);
result = make_connected_tcp_stream_socket(cfg.address, cfg.port,
result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout);
if (result)
return result;
return {*result};
}
return result;
return {std::move(result.error())};
}
template <class Conn, class OnStart>
disposable do_start(Conn conn, OnStart& on_start) {
disposable
do_start_impl(dsl::client_config<Trait>& cfg, Conn conn, OnStart& on_start) {
// s2a: socket-to-application (and a2s is the inverse).
using input_t = typename Trait::input_type;
using output_t = typename Trait::output_type;
......@@ -159,74 +74,72 @@ private:
auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>();
auto fc = detail::flow_connector<Trait>::make_trivial(std::move(a2s_pull),
std::move(s2a_push));
auto bridge = detail::binary_flow_bridge<Trait>::make(mpx_, std::move(fc));
auto bridge = detail::binary_flow_bridge<Trait>::make(cfg.mpx,
std::move(fc));
auto bridge_ptr = bridge.get();
auto impl = framing::make(std::move(bridge));
auto transport = transport_t::make(std::move(conn), std::move(impl));
auto ptr = socket_manager::make(mpx_, std::move(transport));
auto ptr = socket_manager::make(cfg.mpx, std::move(transport));
bridge_ptr->self_ref(ptr->as_disposable());
mpx_->start(ptr);
cfg.mpx->start(ptr);
on_start(std::move(s2a_pull), std::move(a2s_push));
return disposable{std::move(ptr)};
}
explicit connect_factory(multiplexer* mpx) : mpx_(mpx) {
// nop
}
connect_factory(multiplexer* mpx, error err)
: mpx_(mpx), state_(std::move(err)) {
// nop
}
/// Initializes the connect factory to connect to the given TCP `host` and
/// `port`.
///
/// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to.
void init(std::string host, uint16_t port) {
state_ = config{std::move(host), port};
template <class OnStart>
disposable do_start(dsl::lazy_client_config<Trait>& cfg,
const std::string& host, uint16_t port,
OnStart& on_start) {
auto fd = try_connect(cfg, host, port);
if (fd) {
if (cfg.ctx) {
auto conn = cfg.ctx->new_connection(*fd);
if (conn)
return do_start_impl(cfg, std::move(*conn), on_start);
cfg.call_on_error(conn.error());
return {};
}
return do_start_impl(cfg, *fd, on_start);
}
cfg.call_on_error(fd.error());
return {};
}
/// Initializes the connect factory to connect to the given TCP `socket`.
///
/// @param fd The TCP socket to connect.
void init(stream_socket fd) {
state_ = fd;
template <class OnStart>
disposable do_start(dsl::lazy_client_config<Trait>& cfg, OnStart& on_start) {
if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server))
return do_start(cfg, st->host, st->port, on_start);
auto fail = [&cfg](auto code, std::string description) {
auto err = make_error(code, std::move(description));
cfg.call_on_error(err);
return disposable{};
};
auto& server_uri = std::get<uri>(cfg.server);
if (server_uri.scheme() != "tcp")
return fail(sec::invalid_argument, "connect expects tcp://<host>:<port>");
auto& auth = server_uri.authority();
if (auth.empty() || auth.port == 0)
return fail(sec::invalid_argument,
"connect expects tcp://<host>:<port> with non-zero port");
return do_start(cfg, auth.host_str(), auth.port, on_start);
}
/// Initializes the connect factory to connect to the given TCP `socket`.
///
/// @param conn The SSL connection object.
void init(ssl::connection conn) {
state_ = std::move(conn);
template <class OnStart>
disposable
do_start(dsl::socket_client_config<Trait>& cfg, OnStart& on_start) {
return do_start_impl(cfg, cfg.take_fd(), on_start);
}
/// Initializes the connect factory with an error.
///
/// @param err The error to be later forwarded to the `do_on_error_` handler.
void init(error err) {
state_ = std::move(err);
template <class OnStart>
disposable do_start(dsl::conn_client_config<Trait>& cfg, OnStart& on_start) {
return do_start_impl(cfg, std::move(cfg.state), on_start);
}
void set_ssl(ssl::context ctx) {
ctx_ = std::make_shared<ssl::context>(std::move(ctx));
template <class OnStart>
disposable do_start(dsl::fail_client_config<Trait>& cfg, OnStart&) {
cfg.call_on_error(cfg.err);
return {};
}
/// Pointer to multiplexer that runs the protocol stack.
multiplexer* mpx_;
/// Callback for errors.
std::function<void(const error&)> do_on_error_;
/// Configures the maximum number of concurrent connections.
size_t max_connections_ = defaults::net::max_connections.fallback;
/// User-defined state for getting things up and running.
std::variant<none_t, config, stream_socket, ssl::connection, error> state_;
/// Pointer to the (optional) SSL context.
std::shared_ptr<ssl::context> ctx_;
};
} // namespace caf::net::lp
......@@ -11,11 +11,10 @@
#include "caf/detail/flow_connector.hpp"
#include "caf/detail/shared_ssl_acceptor.hpp"
#include "caf/fwd.hpp"
#include "caf/net/dsl/server_factory_base.hpp"
#include "caf/net/http/server.hpp"
#include "caf/net/lp/framing.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/prometheus/accept_factory.hpp"
#include "caf/net/prometheus/server.hpp"
#include "caf/net/ssl/transport.hpp"
#include "caf/net/stream_transport.hpp"
#include "caf/net/tcp_accept_socket.hpp"
......@@ -62,40 +61,14 @@ private:
namespace caf::net::lp {
template <class>
class with_t;
/// Factory for the `with(...).accept(...).start(...)` DSL.
/// Factory type for the `with(...).accept(...).start(...)` DSL.
template <class Trait>
class accept_factory {
class server_factory
: public dsl::server_factory_base<Trait, server_factory<Trait>> {
public:
friend class with_t<Trait>;
accept_factory(accept_factory&&) = default;
accept_factory(const accept_factory&) = delete;
accept_factory& operator=(accept_factory&&) noexcept = default;
using super = dsl::server_factory_base<Trait, server_factory<Trait>>;
accept_factory& operator=(const accept_factory&) noexcept = delete;
~accept_factory() {
if (auto* fd = std::get_if<tcp_accept_socket>(&state_))
close(*fd);
}
/// Configures how many concurrent connections we are allowing.
accept_factory& max_connections(size_t value) {
max_connections_ = value;
return *this;
}
/// Sets the callback for errors.
template <class F>
accept_factory& do_on_error(F callback) {
do_on_error_ = std::move(callback);
return *this;
}
using super::super;
/// Starts a server that accepts incoming connections with the
/// length-prefixing protocol.
......@@ -103,94 +76,69 @@ public:
disposable start(OnStart on_start) {
using acceptor_resource = typename Trait::acceptor_resource;
static_assert(std::is_invocable_v<OnStart, acceptor_resource>);
switch (state_.index()) {
case 1: {
auto& cfg = std::get<1>(state_);
auto fd = make_tcp_accept_socket(cfg.port, cfg.address, cfg.reuse_addr);
if (fd)
return do_start(*fd, on_start);
if (do_on_error_)
do_on_error_(fd.error());
return {};
}
case 2: {
// Pass ownership of the socket to the accept handler.
auto fd = std::get<2>(state_);
state_ = none;
return do_start(fd, on_start);
}
default:
return {};
}
auto f = [this, &on_start](auto& cfg) {
return this->do_start(cfg, on_start);
};
return visit(f, this->config());
}
private:
struct config {
uint16_t port;
std::string address;
bool reuse_addr;
};
explicit accept_factory(multiplexer* mpx) : mpx_(mpx) {
// nop
}
template <class Factory, class AcceptHandler, class Acceptor, class OnStart>
disposable do_start_impl(Acceptor&& acc, OnStart& on_start) {
disposable do_start_impl(dsl::server_config<Trait>& cfg, Acceptor acc,
OnStart& on_start) {
using accept_event = typename Trait::accept_event;
using connector_t = detail::flow_connector<Trait>;
auto [pull, push] = async::make_spsc_buffer_resource<accept_event>();
auto serv = connector_t::make_basic_server(push.try_open());
auto factory = std::make_unique<Factory>(std::move(serv));
auto impl = AcceptHandler::make(std::move(acc), std::move(factory),
max_connections_);
cfg.max_connections);
auto impl_ptr = impl.get();
auto ptr = net::socket_manager::make(mpx_, std::move(impl));
auto ptr = net::socket_manager::make(cfg.mpx, std::move(impl));
impl_ptr->self_ref(ptr->as_disposable());
mpx_->start(ptr);
cfg.mpx->start(ptr);
on_start(std::move(pull));
return disposable{std::move(ptr)};
}
template <class OnStart>
disposable do_start(tcp_accept_socket fd, OnStart& on_start) {
if (!ctx_) {
disposable do_start(dsl::server_config<Trait>& cfg, tcp_accept_socket fd,
OnStart& on_start) {
if (!cfg.ctx) {
using factory_t = detail::lp_connection_factory<Trait, stream_transport>;
using impl_t = detail::accept_handler<tcp_accept_socket, stream_socket>;
return do_start_impl<factory_t, impl_t>(fd, on_start);
return do_start_impl<factory_t, impl_t>(cfg, fd, on_start);
}
using factory_t = detail::lp_connection_factory<Trait, ssl::transport>;
using acc_t = detail::shared_ssl_acceptor;
using impl_t = detail::accept_handler<acc_t, ssl::connection>;
return do_start_impl<factory_t, impl_t>(acc_t{fd, ctx_}, on_start);
}
void set_ssl(ssl::context ctx) {
ctx_ = std::make_shared<ssl::context>(std::move(ctx));
return do_start_impl<factory_t, impl_t>(cfg, acc_t{fd, cfg.ctx}, on_start);
}
void init(uint16_t port, std::string address, bool reuse_addr) {
state_ = config{port, std::move(address), reuse_addr};
template <class OnStart>
disposable
do_start(typename dsl::server_config<Trait>::socket& cfg, OnStart& on_start) {
if (cfg.fd == invalid_socket) {
auto err = make_error(
sec::runtime_error,
"server factory cannot create a server on an invalid socket");
cfg.call_on_error(err);
return {};
}
return do_start(cfg, cfg.take_fd(), on_start);
}
void init(tcp_accept_socket fd) {
state_ = fd;
template <class OnStart>
disposable
do_start(typename dsl::server_config<Trait>::lazy& cfg, OnStart& on_start) {
auto fd = make_tcp_accept_socket(cfg.port, cfg.bind_address,
cfg.reuse_addr);
if (!fd) {
cfg.call_on_error(fd.error());
return {};
}
return do_start(cfg, *fd, on_start);
}
/// Pointer to the hosting actor system.
multiplexer* mpx_;
/// Callback for errors.
std::function<void(const error&)> do_on_error_;
/// Configures the maximum number of concurrent connections.
size_t max_connections_ = defaults::net::max_connections.fallback;
/// User-defined state for getting things up and running.
std::variant<none_t, config, tcp_accept_socket> state_;
/// Pointer to the (optional) SSL context.
std::shared_ptr<ssl::context> ctx_;
};
} // namespace caf::net::lp
......@@ -5,8 +5,10 @@
#pragma once
#include "caf/fwd.hpp"
#include "caf/net/lp/accept_factory.hpp"
#include "caf/net/lp/connect_factory.hpp"
#include "caf/net/dsl/has_accept.hpp"
#include "caf/net/dsl/has_connect.hpp"
#include "caf/net/lp/client_factory.hpp"
#include "caf/net/lp/server_factory.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/ssl/acceptor.hpp"
#include "caf/net/ssl/context.hpp"
......@@ -18,9 +20,12 @@ namespace caf::net::lp {
/// Entry point for the `with(...)` DSL.
template <class Trait>
class with_t {
class with_t : public dsl::has_accept<server_factory<Trait>>,
public dsl::has_connect<client_factory<Trait>> {
public:
explicit with_t(multiplexer* mpx) : mpx_(mpx) {
template <class... Ts>
explicit with_t(multiplexer* mpx, Ts&&... xs)
: mpx_(mpx), trait_(std::forward<Ts>(xs)...) {
// nop
}
......@@ -28,168 +33,20 @@ public:
with_t& operator=(const with_t&) noexcept = default;
/// Creates an `accept_factory` object for the given TCP `port` and
/// `bind_address`.
///
/// @param port Port number to bind to.
/// @param bind_address IP address to bind to. Default is an empty string.
/// @param reuse_addr Whether or not to set `SO_REUSEADDR`.
/// @returns an `accept_factory` object initialized with the given parameters.
accept_factory<Trait> accept(uint16_t port, std::string bind_address = "",
bool reuse_addr = true) {
accept_factory<Trait> factory{mpx_};
factory.init(port, std::move(bind_address), std::move(reuse_addr));
return factory;
multiplexer* mpx() const noexcept override {
return mpx_;
}
/// Creates an `accept_factory` object for the given accept socket.
///
/// @param fd File descriptor for the accept socket.
/// @returns an `accept_factory` object that will start a Prometheus server on
/// the given socket.
accept_factory<Trait> accept(tcp_accept_socket fd) {
accept_factory<Trait> factory{mpx_};
factory.init(fd);
return factory;
}
/// Creates an `accept_factory` object for the given acceptor.
///
/// @param acc The SSL acceptor for incoming connections.
/// @returns an `accept_factory` object that will start a Prometheus server on
/// the given acceptor.
accept_factory<Trait> accept(ssl::acceptor acc) {
accept_factory<Trait> factory{mpx_};
factory.set_ssl(std::move(std::move(acc.ctx())));
factory.init(acc.fd());
return factory;
}
/// Creates an `accept_factory` object for the given TCP `port` and
/// `bind_address`.
///
/// @param ctx The SSL context for encryption.
/// @param port Port number to bind to.
/// @param bind_address IP address to bind to. Default is an empty string.
/// @param reuse_addr Whether or not to set `SO_REUSEADDR`.
/// @returns an `accept_factory` object initialized with the given parameters.
accept_factory<Trait> accept(ssl::context ctx, uint16_t port,
std::string bind_address = "",
bool reuse_addr = true) {
accept_factory<Trait> factory{mpx_};
factory.set_ssl(std::move(std::move(ctx)));
factory.init(port, std::move(bind_address), std::move(reuse_addr));
return factory;
}
/// Creates a `connect_factory` object for the given TCP `host` and `port`.
///
/// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
connect_factory<Trait> connect(std::string host, uint16_t port) {
connect_factory<Trait> factory{mpx_};
factory.init(std::move(host), port);
return factory;
}
/// Creates a `connect_factory` object for the given SSL `context`, TCP
/// `host`, and `port`.
///
/// @param ctx The SSL context for encryption.
/// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
connect_factory<Trait> connect(ssl::context ctx, std::string host,
uint16_t port) {
connect_factory<Trait> factory{mpx_};
factory.set_ssl(std::move(ctx));
factory.init(std::move(host), port);
return factory;
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
connect_factory<Trait> connect(const uri& endpoint) {
return connect_impl(nullptr, endpoint);
}
/// Creates a `connect_factory` object for the given SSL `context` and TCP
/// `endpoint`.
///
/// @param ctx The SSL context for encryption.
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
connect_factory<Trait> connect(ssl::context ctx, const uri& endpoint) {
return connect_impl(&ctx, endpoint);
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
connect_factory<Trait> connect(expected<uri> endpoint) {
if (endpoint)
return connect_impl(nullptr, std::move(*endpoint));
return connect_factory<Trait>{std::move(endpoint.error())};
}
/// Creates a `connect_factory` object for the given SSL `context` and TCP
/// `endpoint`.
///
/// @param ctx The SSL context for encryption.
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
connect_factory<Trait> connect(ssl::context ctx, expected<uri> endpoint) {
if (endpoint)
return connect_impl(&ctx, std::move(*endpoint));
return connect_factory<Trait>{std::move(endpoint.error())};
}
/// Creates a `connect_factory` object for the given stream `fd`.
///
/// @param fd The stream socket to use for the connection.
/// @returns a `connect_factory` object that will use the given socket.
connect_factory<Trait> connect(stream_socket fd) {
connect_factory<Trait> factory{mpx_};
factory.init(fd);
return factory;
}
/// Creates a `connect_factory` object for the given SSL `connection`.
///
/// @param conn The SSL connection to use.
/// @returns a `connect_factory` object that will use the given connection.
connect_factory<Trait> connect(ssl::connection conn) {
connect_factory<Trait> factory{mpx_};
factory.init(std::move(conn));
return factory;
const Trait& trait() const noexcept override {
return trait_;
}
private:
connect_factory<Trait> connect_impl(ssl::context* ctx, const uri& endpoint) {
if (endpoint.scheme() != "tcp" || endpoint.authority().empty()) {
auto err = make_error(sec::invalid_argument,
"lp::connect expects tcp://<host>:<port> URIs");
return connect_factory<Trait>{mpx_, std::move(err)};
}
if (endpoint.authority().port == 0) {
auto err = make_error(sec::invalid_argument,
"lp::connect expects URIs with a non-zero port");
return connect_factory<Trait>{mpx_, std::move(err)};
}
connect_factory<Trait> factory{mpx_};
if (ctx != nullptr) {
factory.set_ssl(std::move(*ctx));
}
factory.init(endpoint.authority().host_str(), endpoint.authority().port);
return factory;
}
/// Pointer to multiplexer that runs the protocol stack.
multiplexer* mpx_;
/// User-defined trait for configuring serialization.
Trait trait_;
};
template <class Trait = binary::default_trait>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment