Commit 69587e6e authored by Dominik Charousset's avatar Dominik Charousset

Iterate on the new caf-net DSL

parent 5e3b24eb
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <iostream> #include <iostream>
#include <utility> #include <utility>
using namespace std::literals;
// -- convenience type aliases ------------------------------------------------- // -- convenience type aliases -------------------------------------------------
// The trait for translating between bytes on the wire and flow items. The // The trait for translating between bytes on the wire and flow items. The
...@@ -44,69 +46,88 @@ struct config : caf::actor_system_config { ...@@ -44,69 +46,88 @@ struct config : caf::actor_system_config {
.add<uint16_t>("port,p", "port of the server") .add<uint16_t>("port,p", "port of the server")
.add<std::string>("host,H", "host of the server") .add<std::string>("host,H", "host of the server")
.add<std::string>("name,n", "set name"); .add<std::string>("name,n", "set name");
opt_group{custom_options_, "tls"} //
.add<bool>("enable", "enables encryption via TLS")
.add<std::string>("ca-file", "CA file for trusted servers");
} }
}; };
// -- main --------------------------------------------------------------------- // -- main ---------------------------------------------------------------------
int caf_main(caf::actor_system& sys, const config& cfg) { int caf_main(caf::actor_system& sys, const config& cfg) {
namespace ssl = caf::net::ssl;
// Read the configuration. // Read the configuration.
bool had_error = false; auto use_ssl = caf::get_or(cfg, "tls.enable", false);
auto port = caf::get_or(cfg, "port", default_port); auto port = caf::get_or(cfg, "port", default_port);
auto host = caf::get_or(cfg, "host", default_host); auto host = caf::get_or(cfg, "host", default_host);
auto name = caf::get_or(cfg, "name", ""); auto name = caf::get_or(cfg, "name", "");
auto ca_file = caf::get_as<std::string>(cfg, "tls.ca-file");
if (name.empty()) { if (name.empty()) {
std::cerr << "*** mandatory parameter 'name' missing or empty\n"; std::cerr << "*** mandatory parameter 'name' missing or empty\n";
return EXIT_FAILURE; return EXIT_FAILURE;
} }
// Connect to the server. // Connect to the server.
caf::net::lp::with(sys) auto conn
.connect(host, port) = caf::net::lp::with(sys)
.do_on_error([&](const caf::error& what) { // Optionally enable TLS.
std::cerr << "*** unable to connect to " << host << ":" << port << ": " .context(ssl::context::enable(use_ssl)
<< to_string(what) << '\n'; .and_then(ssl::emplace_client(ssl::tls::v1_2))
had_error = true; .and_then(ssl::load_verify_file_if(ca_file)))
}) // Connect to "$host:$port".
.start([&sys, name](auto pull, auto push) { .connect(host, port)
// Spin up a worker that prints received inputs. // If we don't succeed at first, try up to 10 times with 1s delay.
sys.spawn([pull](caf::event_based_actor* self) { .retry_delay(1s)
pull .max_retry_count(9)
.observe_on(self) // // After connecting, spin up a worker that prints received inputs.
.do_finally([self] { .start([&sys, name](auto pull, auto push) {
std::cout << "*** lost connection to server -> quit\n" sys.spawn([pull](caf::event_based_actor* self) {
<< "*** use CTRL+D or CTRL+C to terminate\n"; pull
self->quit(); .observe_on(self) //
}) .do_on_error([](const caf::error& err) {
.for_each([](const bin_frame& frame) { std::cout << "*** connection error: " << to_string(err) << '\n';
// Interpret the bytes as ASCII characters. })
auto bytes = frame.bytes(); .do_finally([self] {
auto str = std::string_view{ std::cout << "*** lost connection to server -> quit\n"
reinterpret_cast<const char*>(bytes.data()), bytes.size()}; << "*** use CTRL+D or CTRL+C to terminate\n";
if (std::all_of(str.begin(), str.end(), ::isprint)) { self->quit();
std::cout << str << '\n'; })
} else { .for_each([](const bin_frame& frame) {
std::cout << "<non-ascii-data of size " << bytes.size() << ">\n"; // Interpret the bytes as ASCII characters.
auto bytes = frame.bytes();
auto str = std::string_view{
reinterpret_cast<const char*>(bytes.data()), bytes.size()};
if (std::all_of(str.begin(), str.end(), ::isprint)) {
std::cout << str << '\n';
} else {
std::cout << "<non-ascii-data of size " << bytes.size()
<< ">\n";
}
});
});
// Spin up a second worker that reads from std::cin and sends each
// line to the server. Put that to its own thread since it's doing
// I/O.
sys.spawn<caf::detached>([push, name] {
auto lines = caf::async::make_blocking_producer(push);
if (!lines)
throw std::logic_error("failed to create blocking producer");
auto line = std::string{};
auto prefix = name + ": ";
while (std::getline(std::cin, line)) {
line.insert(line.begin(), prefix.begin(), prefix.end());
lines->push(bin_frame{caf::as_bytes(caf::make_span(line))});
line.clear();
} }
}); });
}); });
// Spin up a second worker that reads from std::cin and sends each line to if (!conn) {
// the server. Put that to its own thread since it's doing I/O. std::cerr << "*** unable to connect to " << host << ":" << port << ": "
sys.spawn<caf::detached>([push, name] { << to_string(conn.error()) << '\n';
auto lines = caf::async::make_blocking_producer(push); return EXIT_FAILURE;
if (!lines) }
throw std::logic_error("failed to create blocking producer");
auto line = std::string{};
auto prefix = name + ": ";
while (std::getline(std::cin, line)) {
line.insert(line.begin(), prefix.begin(), prefix.end());
lines->push(bin_frame{caf::as_bytes(caf::make_span(line))});
line.clear();
}
});
});
// Note: the actor system will keep the application running for as long as the // Note: the actor system will keep the application running for as long as the
// workers are still alive. // workers are still alive.
return had_error ? EXIT_FAILURE : EXIT_SUCCESS; return EXIT_SUCCESS;
} }
CAF_MAIN(caf::net::middleman) CAF_MAIN(caf::net::middleman)
...@@ -32,12 +32,18 @@ using message_t = std::pair<caf::uuid, bin_frame>; ...@@ -32,12 +32,18 @@ using message_t = std::pair<caf::uuid, bin_frame>;
static constexpr uint16_t default_port = 7788; static constexpr uint16_t default_port = 7788;
static constexpr size_t default_max_connections = 128;
// -- configuration setup ------------------------------------------------------ // -- configuration setup ------------------------------------------------------
struct config : caf::actor_system_config { struct config : caf::actor_system_config {
config() { config() {
opt_group{custom_options_, "global"} // opt_group{custom_options_, "global"} //
.add<uint16_t>("port,p", "port to listen for incoming connections"); .add<uint16_t>("port,p", "port to listen for incoming connections")
.add<size_t>("max-connections,m", "limit for concurrent clients");
opt_group{custom_options_, "tls"} //
.add<std::string>("key-file,k", "path to the private key file")
.add<std::string>("cert-file,c", "path to the certificate file");
} }
}; };
...@@ -78,16 +84,19 @@ void worker_impl(caf::event_based_actor* self, ...@@ -78,16 +84,19 @@ void worker_impl(caf::event_based_actor* self,
}) })
.subscribe(push); .subscribe(push);
// Feed messages from the `pull` end into the central merge point. // Feed messages from the `pull` end into the central merge point.
auto inputs = pull.observe_on(self) auto inputs
.on_error_complete() // Cary on if a connection breaks. = pull.observe_on(self)
.do_on_complete([conn] { .do_on_error([](const caf::error& err) {
std::cout << "*** lost connection " << to_string(conn) std::cout << "*** connection error: " << to_string(err) << '\n';
<< '\n'; })
}) .on_error_complete() // Cary on if a connection breaks.
.map([conn](const bin_frame& frame) { .do_on_complete([conn] {
return message_t{conn, frame}; std::cout << "*** lost connection " << to_string(conn) << '\n';
}) })
.as_observable(); .map([conn](const bin_frame& frame) {
return message_t{conn, frame};
})
.as_observable();
pub.push(inputs); pub.push(inputs);
}); });
} }
...@@ -95,22 +104,43 @@ void worker_impl(caf::event_based_actor* self, ...@@ -95,22 +104,43 @@ void worker_impl(caf::event_based_actor* self,
// -- main --------------------------------------------------------------------- // -- main ---------------------------------------------------------------------
int caf_main(caf::actor_system& sys, const config& cfg) { int caf_main(caf::actor_system& sys, const config& cfg) {
namespace ssl = caf::net::ssl;
// Read the configuration.
auto port = caf::get_or(cfg, "port", default_port);
auto pem = ssl::format::pem;
auto key_file = caf::get_as<std::string>(cfg, "tls.key-file");
auto cert_file = caf::get_as<std::string>(cfg, "tls.cert-file");
auto max_connections = caf::get_or(cfg, "max-connections",
default_max_connections);
if (!key_file != !cert_file) {
std::cerr << "*** inconsistent TLS config: declare neither file or both\n";
return EXIT_FAILURE;
}
// Open up a TCP port for incoming connections and start the server. // Open up a TCP port for incoming connections and start the server.
auto had_error = false; auto had_error = false;
auto port = caf::get_or(cfg, "port", default_port); auto server
caf::net::lp::with(sys) = caf::net::lp::with(sys)
.accept(port) // Optionally enable TLS.
.do_on_error([&](const caf::error& what) { .context(ssl::context::enable(key_file && cert_file)
std::cerr << "*** unable to open port " << port << ": " << to_string(what) .and_then(ssl::emplace_server(ssl::tls::v1_2))
<< '\n'; .and_then(ssl::use_private_key_file(key_file, pem))
had_error = true; .and_then(ssl::use_certificate_file(cert_file, pem)))
}) // Bind to the user-defined port.
.start([&sys](trait::acceptor_resource accept_events) { .accept(port)
sys.spawn(worker_impl, std::move(accept_events)); // Limit how many clients may be connected at any given time.
}); .max_connections(max_connections)
// When started, run our worker actor to handle incoming connections.
.start([&sys](trait::acceptor_resource accept_events) {
sys.spawn(worker_impl, std::move(accept_events));
});
if (!server) {
std::cerr << "*** unable to run at port " << port << ": "
<< to_string(server.error()) << '\n';
return EXIT_FAILURE;
}
// Note: the actor system will keep the application running for as long as the // Note: the actor system will keep the application running for as long as the
// workers are still alive. // workers are still alive.
return had_error ? EXIT_FAILURE : EXIT_SUCCESS; return EXIT_SUCCESS;
} }
CAF_MAIN(caf::net::middleman) CAF_MAIN(caf::net::middleman)
...@@ -70,27 +70,23 @@ int caf_main(actor_system& sys, const config& cfg) { ...@@ -70,27 +70,23 @@ int caf_main(actor_system& sys, const config& cfg) {
Ui::ChatWindow helper; Ui::ChatWindow helper;
helper.setupUi(&mw); helper.setupUi(&mw);
// Connect to the server. // Connect to the server.
auto had_error = false;
auto conn auto conn
= caf::net::lp::with(sys) = caf::net::lp::with(sys)
.connect(host, port) .connect(host, port)
.do_on_error([&](const caf::error& what) {
std::cerr << "*** unable to connect to " << host << ":" << port
<< ": " << to_string(what) << '\n';
had_error = true;
})
.start([&](auto pull, auto push) { .start([&](auto pull, auto push) {
std::cout << "*** connected to " << host << ":" << port << '\n'; std::cout << "*** connected to " << host << ":" << port << '\n';
helper.chatwidget->init(sys, name, std::move(pull), std::move(push)); helper.chatwidget->init(sys, name, std::move(pull), std::move(push));
}); });
if (had_error) { if (!conn) {
std::cerr << "*** unable to connect to " << host << ":" << port << ": "
<< to_string(conn.error()) << '\n';
mw.close(); mw.close();
return app.exec(); return app.exec();
} }
// Setup and run. // Setup and run.
mw.show(); mw.show();
auto result = app.exec(); auto result = app.exec();
conn.dispose(); conn->dispose();
return result; return result;
} }
......
...@@ -1049,6 +1049,13 @@ struct unboxed_oracle<std::optional<T>> { ...@@ -1049,6 +1049,13 @@ struct unboxed_oracle<std::optional<T>> {
template <class T> template <class T>
using unboxed_t = typename unboxed_oracle<T>::type; using unboxed_t = typename unboxed_oracle<T>::type;
/// Evaluates to true if `T` is a std::string or is convertible to a `const
/// char*`.
template <class T>
constexpr bool is_string_or_cstring_v
= std::is_convertible_v<T, const char*>
|| std::is_same_v<std::string, std::decay_t<T>>;
} // namespace caf::detail } // namespace caf::detail
#undef CAF_HAS_MEMBER_TRAIT #undef CAF_HAS_MEMBER_TRAIT
......
...@@ -176,12 +176,20 @@ public: ...@@ -176,12 +176,20 @@ public:
template <class C> template <class C>
intrusive_ptr<C> downcast() const noexcept { intrusive_ptr<C> downcast() const noexcept {
return (ptr_) ? dynamic_cast<C*>(get()) : nullptr; static_assert(std::is_base_of_v<T, C>);
return intrusive_ptr<C>{ptr_ ? dynamic_cast<C*>(get()) : nullptr};
} }
template <class C> template <class C>
intrusive_ptr<C> upcast() const noexcept { intrusive_ptr<C> upcast() const& noexcept {
return (ptr_) ? static_cast<C*>(get()) : nullptr; static_assert(std::is_base_of_v<C, T>);
return intrusive_ptr<C>{ptr_ ? ptr_ : nullptr};
}
template <class C>
intrusive_ptr<C> upcast() && noexcept {
static_assert(std::is_base_of_v<C, T>);
return intrusive_ptr<C>{ptr_ ? release() : nullptr, false};
} }
private: private:
......
...@@ -39,15 +39,16 @@ caf_add_component( ...@@ -39,15 +39,16 @@ caf_add_component(
src/net/binary/lower_layer.cpp src/net/binary/lower_layer.cpp
src/net/binary/upper_layer.cpp src/net/binary/upper_layer.cpp
src/net/datagram_socket.cpp src/net/datagram_socket.cpp
src/net/dsl/config_base.cpp
src/net/generic_lower_layer.cpp src/net/generic_lower_layer.cpp
src/net/generic_upper_layer.cpp src/net/generic_upper_layer.cpp
src/net/http/header.cpp src/net/http/header.cpp
src/net/http/lower_layer.cpp src/net/http/lower_layer.cpp
src/net/http/serve.cpp
src/net/http/method.cpp src/net/http/method.cpp
src/net/http/request.cpp src/net/http/request.cpp
src/net/http/response.cpp src/net/http/response.cpp
src/net/http/serve.cpp src/net/http/serve.cpp
src/net/http/serve.cpp
src/net/http/server.cpp src/net/http/server.cpp
src/net/http/status.cpp src/net/http/status.cpp
src/net/http/upper_layer.cpp src/net/http/upper_layer.cpp
...@@ -67,6 +68,7 @@ caf_add_component( ...@@ -67,6 +68,7 @@ caf_add_component(
src/net/ssl/connection.cpp src/net/ssl/connection.cpp
src/net/ssl/context.cpp src/net/ssl/context.cpp
src/net/ssl/dtls.cpp src/net/ssl/dtls.cpp
src/net/ssl/errc.cpp
src/net/ssl/format.cpp src/net/ssl/format.cpp
src/net/ssl/password.cpp src/net/ssl/password.cpp
src/net/ssl/startup.cpp src/net/ssl/startup.cpp
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/expected.hpp"
#include <optional>
namespace caf::net::dsl::arg {
/// Represents a null-terminated string or `null`.
class cstring {
public:
cstring() : data_(nullptr) {
// nop
}
cstring(const char* str) : data_(str) {
// nop
}
cstring(std::string str) : data_(std::move(str)) {
// nop
}
cstring(std::optional<const char*> str) : cstring() {
if (str)
data_ = *str;
}
cstring(std::optional<std::string> str) : cstring() {
if (str)
data_ = std::move(*str);
}
cstring(caf::expected<const char*> str) : cstring() {
if (str)
data_ = *str;
}
cstring(caf::expected<std::string> str) : cstring() {
if (str)
data_ = std::move(*str);
}
cstring(cstring&&) = default;
cstring(const cstring&) = default;
cstring& operator=(cstring&&) = default;
cstring& operator=(const cstring&) = default;
/// @returns a pointer to the null-terminated string.
const char* get() const noexcept {
return std::visit(
[](auto& arg) -> const char* {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, const char*>) {
return arg;
} else {
return arg.c_str();
}
},
data_);
}
bool has_value() const noexcept {
return !operator!();
}
explicit operator bool() const noexcept {
return has_value();
}
bool operator!() const noexcept {
return data_.index() == 0 && std::get<0>(data_) == nullptr;
}
private:
std::variant<const char*, std::string> data_;
};
/// Represents a value of type T or `null`.
template <class T>
class val {
public:
val() = default;
val(T value) : data_(std::move(value)) {
// nop
}
val(std::optional<T> value) : data_(std::move(value)) {
// nop
}
val(caf::expected<T> value) {
if (value)
data_ = std::move(*value);
}
val(val&&) = default;
val(const val&) = default;
val& operator=(val&&) = default;
val& operator=(const val&) = default;
const T& get() const noexcept {
return *data_;
}
explicit operator bool() const noexcept {
return data_.has_value();
}
bool operator!() const noexcept {
return !data_;
}
private:
std::optional<T> data_;
};
} // namespace caf::net::dsl::arg
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/fwd.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/fwd.hpp"
namespace caf::net::dsl {
/// Base type for our DSL classes to configure a factory object..
template <class Trait>
class base {
public:
using trait_type = Trait;
virtual ~base() {
// nop
}
/// @returns the pointer to the @ref multiplexer.
virtual multiplexer* mpx() const noexcept = 0;
/// @returns the trait object.
virtual const Trait& trait() const noexcept = 0;
/// @returns the optional SSL context, whereas an object with
/// default-constructed error is treated as "no SSL".
expected<ssl::context>& get_context() {
return get_context_impl();
}
/// @private
template <class ConfigType>
auto with_context(intrusive_ptr<ConfigType> ptr) {
using ConfigBaseType = typename ConfigType::super;
auto as_base_ptr = [](auto& derived_ptr) {
return std::move(derived_ptr).template upcast<ConfigBaseType>();
};
// Move the context into the config if present.
auto& ctx = get_context();
if (ctx) {
ptr->ctx = std::make_shared<ssl::context>(std::move(*ctx));
return as_base_ptr(ptr);
}
// Default-constructed error just means "no SSL".
if (!ctx.error())
return as_base_ptr(ptr);
// We actually have an error: replace `ptr` with a fail config. Need to cast
// to the base type for to_fail_config to pick up the right overload.
auto fptr = to_fail_config(as_base_ptr(ptr), std::move(ctx.error()));
return as_base_ptr(fptr);
}
private:
virtual expected<ssl::context>& get_context_impl() noexcept = 0;
};
} // namespace caf::net::dsl
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#include "caf/callback.hpp" #include "caf/callback.hpp"
#include "caf/defaults.hpp" #include "caf/defaults.hpp"
#include "caf/detail/plain_ref_counted.hpp"
#include "caf/intrusive_ptr.hpp" #include "caf/intrusive_ptr.hpp"
#include "caf/net/dsl/has_trait.hpp" #include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/config_base.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/ssl/connection.hpp" #include "caf/net/ssl/connection.hpp"
#include "caf/net/ssl/context.hpp" #include "caf/net/ssl/context.hpp"
...@@ -24,10 +24,12 @@ namespace caf::net::dsl { ...@@ -24,10 +24,12 @@ namespace caf::net::dsl {
/// The server config type enum class. /// The server config type enum class.
enum class client_config_type { lazy, socket, conn, fail }; enum class client_config_type { lazy, socket, conn, fail };
/// Base class for server configuration objects. /// Base class for client configuration objects.
template <class Trait> template <class Trait>
class client_config : public detail::plain_ref_counted { class client_config : public config_base {
public: public:
using trait_type = Trait;
class lazy; class lazy;
class socket; class socket;
class conn; class conn;
...@@ -38,42 +40,19 @@ public: ...@@ -38,42 +40,19 @@ public:
friend class conn; friend class conn;
friend class fail; friend class fail;
client_config(const client_config&) = delete;
client_config& operator=(const client_config&) = delete;
/// Virtual destructor. /// Virtual destructor.
virtual ~client_config() = default; virtual ~client_config() = default;
/// Returns the server configuration type. /// Returns the server configuration type.
virtual client_config_type type() const noexcept = 0; virtual client_config_type type() const noexcept = 0;
/// The pointer to the @ref multiplexer for running the server.
multiplexer* mpx;
/// The user-defined trait for configuration serialization. /// The user-defined trait for configuration serialization.
Trait trait; Trait trait;
/// User-defined callback for errors.
shared_callback_ptr<void(const error&)> on_error;
/// Calls `on_error` if non-null.
void call_on_error(const error& what) {
if (on_error)
(*on_error)(what);
}
friend void intrusive_ptr_add_ref(const client_config* ptr) noexcept {
ptr->ref();
}
friend void intrusive_ptr_release(const client_config* ptr) noexcept {
ptr->deref();
}
private: private:
/// Private constructor to enforce sealing. /// Private constructor to enforce sealing.
client_config(multiplexer* mpx, const Trait& trait) : mpx(mpx), trait(trait) { client_config(multiplexer* mpx, const Trait& trait)
: config_base(mpx), trait(trait) {
// nop // nop
} }
}; };
...@@ -294,4 +273,11 @@ const T* get_if(const client_config<Trait>* config) { ...@@ -294,4 +273,11 @@ const T* get_if(const client_config<Trait>* config) {
return nullptr; return nullptr;
} }
/// Creates a `fail_client_config` from another configuration object plus error.
template <class Trait>
auto to_fail_config(client_config_ptr<Trait> ptr, error err) {
using impl_t = fail_client_config<Trait>;
return make_counted<impl_t>(ptr->mpx, ptr->trait, std::move(err));
}
} // namespace caf::net::dsl } // namespace caf::net::dsl
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#pragma once #pragma once
#include "caf/make_counted.hpp" #include "caf/make_counted.hpp"
#include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/client_config.hpp" #include "caf/net/dsl/client_config.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/ssl/acceptor.hpp" #include "caf/net/ssl/acceptor.hpp"
#include "caf/net/tcp_accept_socket.hpp" #include "caf/net/tcp_accept_socket.hpp"
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
/// @param value The new retry delay. /// @param value The new retry delay.
/// @returns a reference to this `client_factory`. /// @returns a reference to this `client_factory`.
Derived& retry_delay(timespan value) { Derived& retry_delay(timespan value) {
if (auto* cfg = std::get_if<lazy_client_config<Trait>>(&cfg_.get())) if (auto* cfg = get_if<lazy_client_config<Trait>>(cfg_.get()))
cfg->retry_delay = value; cfg->retry_delay = value;
return dref(); return dref();
} }
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
/// @param value The new connection timeout. /// @param value The new connection timeout.
/// @returns a reference to this `client_factory`. /// @returns a reference to this `client_factory`.
Derived& connection_timeout(timespan value) { Derived& connection_timeout(timespan value) {
if (auto* cfg = std::get_if<lazy_client_config<Trait>>(&cfg_.get())) if (auto* cfg = get_if<lazy_client_config<Trait>>(cfg_.get()))
cfg->connection_timeout = value; cfg->connection_timeout = value;
return dref(); return dref();
} }
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
/// @param value The new maximum retry count. /// @param value The new maximum retry count.
/// @returns a reference to this `client_factory`. /// @returns a reference to this `client_factory`.
Derived& max_retry_count(size_t value) { Derived& max_retry_count(size_t value) {
if (auto* cfg = std::get_if<lazy_client_config<Trait>>(&cfg_.get())) if (auto* cfg = get_if<lazy_client_config<Trait>>(cfg_.get()))
cfg->max_retry_count = value; cfg->max_retry_count = value;
return dref(); return dref();
} }
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/callback.hpp"
#include "caf/defaults.hpp"
#include "caf/detail/net_export.hpp"
#include "caf/intrusive_ptr.hpp"
#include "caf/net/dsl/base.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/ssl/connection.hpp"
#include "caf/net/ssl/context.hpp"
#include "caf/net/stream_socket.hpp"
#include "caf/ref_counted.hpp"
#include "caf/uri.hpp"
#include <cassert>
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// Base class for configuration objects.
class CAF_NET_EXPORT config_base : public ref_counted {
public:
explicit config_base(multiplexer* mpx) : mpx(mpx) {
// nop
}
config_base(const config_base&) = delete;
config_base& operator=(const config_base&) = delete;
virtual ~config_base();
/// The pointer to the parent @ref multiplexer.
multiplexer* mpx;
/// User-defined callback for errors.
shared_callback_ptr<void(const error&)> on_error;
/// Calls `on_error` if non-null.
void call_on_error(const error& what) {
if (on_error)
(*on_error)(what);
}
};
} // namespace caf::net::dsl
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#pragma once #pragma once
#include "caf/make_counted.hpp" #include "caf/make_counted.hpp"
#include "caf/net/dsl/has_trait.hpp" #include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/server_config.hpp" #include "caf/net/dsl/server_config.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/ssl/acceptor.hpp" #include "caf/net/ssl/acceptor.hpp"
...@@ -17,85 +17,76 @@ ...@@ -17,85 +17,76 @@
namespace caf::net::dsl { namespace caf::net::dsl {
/// DSL entry point for creating a server. /// DSL entry point for creating a server.
template <class ServerFactory> template <class Base, class Subtype>
class has_accept : public has_trait<typename ServerFactory::trait_type> { class has_accept : public Base {
public: public:
using trait_type = typename ServerFactory::trait_type; using trait_type = typename Base::trait_type;
using super = has_trait<trait_type>;
using super::super;
/// Creates an `accept_factory` object for the given TCP `port` and
/// `bind_address`.
///
/// @param port Port number to bind to.
/// @param bind_address IP address to bind to. Default is an empty string.
/// @returns an `accept_factory` object initialized with the given parameters.
ServerFactory accept(uint16_t port, std::string bind_address = "") {
auto cfg = make_lazy_config(port, std::move(bind_address));
return ServerFactory{std::move(cfg)};
}
/// Creates an `accept_factory` object for the given TCP `port` and /// Creates an `accept_factory` object for the given TCP `port` and
/// `bind_address`. /// `bind_address`.
/// ///
/// @param ctx The SSL context for encryption.
/// @param port Port number to bind to. /// @param port Port number to bind to.
/// @param bind_address IP address to bind to. Default is an empty string. /// @param bind_address IP address to bind to. Default is an empty string.
/// @returns an `accept_factory` object initialized with the given parameters. /// @returns an `accept_factory` object initialized with the given parameters.
ServerFactory accept(ssl::context ctx, uint16_t port, auto accept(uint16_t port, std::string bind_address = "") {
std::string bind_address = "") { auto& dref = static_cast<Subtype&>(*this);
auto cfg = make_lazy_config(port, std::move(bind_address)); auto cfg = make_lazy_config(port, std::move(bind_address));
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx)); return dref.lift(dref.with_context(std::move(cfg)));
return ServerFactory{std::move(cfg)};
} }
/// Creates an `accept_factory` object for the given accept socket. /// Creates an `accept_factory` object for the given accept socket.
/// ///
/// @param fd File descriptor for the accept socket. /// @param fd File descriptor for the accept socket.
/// @returns an `accept_factory` object that will start a Prometheus server on /// @returns an `accept_factory` object that will start a server on `fd`.
/// the given socket. auto accept(tcp_accept_socket fd) {
ServerFactory accept(tcp_accept_socket fd) { auto& dref = static_cast<Subtype&>(*this);
auto cfg = make_socket_config(fd); return dref.lift(dref.with_context(make_socket_config(fd)));
return ServerFactory{std::move(cfg)};
}
/// Creates an `accept_factory` object for the given acceptor.
///
/// @param ctx The SSL context for encryption.
/// @param fd File descriptor for the accept socket.
/// @returns an `accept_factory` object that will start a Prometheus server on
/// the given acceptor.
ServerFactory accept(ssl::context ctx, tcp_accept_socket fd) {
auto cfg = make_socket_config(fd);
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ServerFactory{std::move(cfg)};
} }
/// Creates an `accept_factory` object for the given acceptor. /// Creates an `accept_factory` object for the given acceptor.
/// ///
/// @param acc The SSL acceptor for incoming connections. /// @param acc The SSL acceptor for incoming connections.
/// @returns an `accept_factory` object that will start a Prometheus server on /// @returns an `accept_factory` object that will start a server on `acc`.
/// the given acceptor. auto accept(ssl::acceptor acc) {
ServerFactory accept(ssl::acceptor acc) { auto& dref = static_cast<Subtype&>(*this);
return accept(std::move(acc.ctx()), acc.fd()); // The SSL acceptor has its own context, we cannot have two.
auto& ctx = dref().context();
if (ctx.has_value()) {
auto err = make_error(
sec::logic_error,
"passed an ssl::acceptor to a factory with a valid SSL context");
return dref.lift(make_fail_config(std::move(err)));
}
// Forward an already existing error.
if (ctx.error()) {
return dref.lift(make_fail_config(std::move(ctx.error())));
}
// Default-constructed error means: "no SSL". Use he one from the acceptor.
ctx = std::move(acc.ctx());
return accept(acc.fd());
} }
private: private:
template <class... Ts> template <class... Ts>
server_config_ptr<trait_type> make_lazy_config(Ts&&... xs) { auto make_lazy_config(Ts&&... xs) {
using impl_t = typename server_config<trait_type>::lazy; using impl_t = typename server_config<trait_type>::lazy;
return make_counted<impl_t>(this->mpx(), this->trait(), return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...); std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
server_config_ptr<trait_type> make_socket_config(Ts&&... xs) { auto make_socket_config(Ts&&... xs) {
using impl_t = typename server_config<trait_type>::socket; using impl_t = typename server_config<trait_type>::socket;
return make_counted<impl_t>(this->mpx(), this->trait(), return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...); std::forward<Ts>(xs)...);
} }
template <class... Ts>
auto make_fail_config(Ts&&... xs) {
using impl_t = fail_server_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...);
}
}; };
} // namespace caf::net::dsl } // namespace caf::net::dsl
...@@ -5,137 +5,77 @@ ...@@ -5,137 +5,77 @@
#pragma once #pragma once
#include "caf/make_counted.hpp" #include "caf/make_counted.hpp"
#include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/client_config.hpp" #include "caf/net/dsl/client_config.hpp"
#include "caf/net/dsl/has_trait.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/tcp_stream_socket.hpp" #include "caf/net/tcp_stream_socket.hpp"
#include "caf/uri.hpp"
#include <cstdint> #include <cstdint>
#include <string> #include <string>
namespace caf::net::dsl { namespace caf::net::dsl {
/// DSL entry point for creating a server. /// DSL entry point for creating a client.
template <class ClientFactory> template <class Base, class Subtype>
class has_connect : public has_trait<typename ClientFactory::trait_type> { class has_connect : public Base {
public: public:
using trait_type = typename ClientFactory::trait_type; using trait_type = typename Base::trait_type;
using super = has_trait<trait_type>;
using super::super;
/// Creates a `connect_factory` object for the given TCP `host` and `port`. /// Creates a `connect_factory` object for the given TCP `host` and `port`.
/// ///
/// @param host The hostname or IP address to connect to. /// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to. /// @param port The port number to connect to.
/// @returns a `connect_factory` object initialized with the given parameters. /// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(std::string host, uint16_t port) { auto connect(std::string host, uint16_t port) {
auto& dref = static_cast<Subtype&>(*this);
auto cfg = make_lazy_config(std::move(host), port); auto cfg = make_lazy_config(std::move(host), port);
return ClientFactory{std::move(cfg)}; return dref.lift(dref.with_context(std::move(cfg)));
}
/// Creates a `connect_factory` object for the given SSL `context`, TCP
/// `host`, and `port`.
///
/// @param ctx The SSL context for encryption.
/// @param host The hostname or IP address to connect to.
/// @param port The port number to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(ssl::context ctx, std::string host, uint16_t port) {
auto cfg = make_lazy_config(std::move(host), port);
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(const uri& endpoint) {
auto cfg = make_lazy_config(endpoint);
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given SSL `context` and TCP
/// `endpoint`.
///
/// @param ctx The SSL context for encryption.
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(ssl::context ctx, const uri& endpoint) {
auto cfg = make_lazy_config(endpoint);
cfg->ctx = std::make_shared<ssl::context>(std::move(ctx));
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(expected<uri> endpoint) {
if (endpoint)
return connect(*endpoint);
auto cfg = make_fail_config(endpoint.error());
return ClientFactory{std::move(cfg)};
}
/// Creates a `connect_factory` object for the given SSL `context` and TCP
/// `endpoint`.
///
/// @param ctx The SSL context for encryption.
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
ClientFactory connect(ssl::context ctx, expected<uri> endpoint) {
if (endpoint)
return connect(std::move(ctx), *endpoint);
auto cfg = make_fail_config(endpoint.error());
return ClientFactory{std::move(cfg)};
} }
/// Creates a `connect_factory` object for the given stream `fd`. /// Creates a `connect_factory` object for the given stream `fd`.
/// ///
/// @param fd The stream socket to use for the connection. /// @param fd The stream socket to use for the connection.
/// @returns a `connect_factory` object that will use the given socket. /// @returns a `connect_factory` object that will use the given socket.
ClientFactory connect(stream_socket fd) { auto connect(stream_socket fd) {
auto& dref = static_cast<Subtype&>(*this);
auto cfg = make_socket_config(fd); auto cfg = make_socket_config(fd);
return ClientFactory{std::move(cfg)}; return dref.lift(dref.with_context(std::move(cfg)));
} }
/// Creates a `connect_factory` object for the given SSL `connection`. /// Creates a `connect_factory` object for the given SSL `connection`.
/// ///
/// @param conn The SSL connection to use. /// @param conn The SSL connection to use.
/// @returns a `connect_factory` object that will use the given connection. /// @returns a `connect_factory` object that will use the given connection.
ClientFactory connect(ssl::connection conn) { auto connect(ssl::connection conn) {
auto& dref = static_cast<Subtype&>(*this);
auto cfg = make_conn_config(std::move(conn)); auto cfg = make_conn_config(std::move(conn));
return ClientFactory{std::move(cfg)}; return dref.lift(std::move(cfg));
} }
private: protected:
template <class... Ts> template <class... Ts>
client_config_ptr<trait_type> make_lazy_config(Ts&&... xs) { auto make_lazy_config(Ts&&... xs) {
using impl_t = lazy_client_config<trait_type>; using impl_t = lazy_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(), return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...); std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
client_config_ptr<trait_type> make_socket_config(Ts&&... xs) { auto make_socket_config(Ts&&... xs) {
using impl_t = socket_client_config<trait_type>; using impl_t = socket_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(), return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...); std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
client_config_ptr<trait_type> make_conn_config(Ts&&... xs) { auto make_conn_config(Ts&&... xs) {
using impl_t = conn_client_config<trait_type>; using impl_t = conn_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(), return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...); std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
client_config_ptr<trait_type> make_fail_config(Ts&&... xs) { auto make_fail_config(Ts&&... xs) {
using impl_t = fail_client_config<trait_type>; using impl_t = fail_client_config<trait_type>;
return make_counted<impl_t>(this->mpx(), this->trait(), return make_counted<impl_t>(this->mpx(), this->trait(),
std::forward<Ts>(xs)...); std::forward<Ts>(xs)...);
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/expected.hpp"
#include "caf/net/ssl/context.hpp"
namespace caf::net::dsl {
/// DSL entry point for creating a server.
template <class Base, class Subtype>
class has_context : public Base {
public:
using trait_type = typename Base::trait_type;
/// Sets the optional SSL context.
///
/// @param ctx The SSL context for encryption.
/// @returns a reference to `*this`.
Subtype& context(expected<ssl::context> ctx) {
auto& dref = static_cast<Subtype&>(*this);
dref.get_context() = std::move(ctx);
return dref;
}
};
} // namespace caf::net::dsl
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/make_counted.hpp"
#include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/client_config.hpp"
#include "caf/net/dsl/has_connect.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/tcp_stream_socket.hpp"
#include "caf/uri.hpp"
#include <cstdint>
#include <string>
namespace caf::net::dsl {
/// DSL entry point for creating a client from an URI.
template <class Base, class Subtype>
class has_uri_connect : public Base {
public:
using trait_type = typename Base::trait_type;
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
auto connect(const uri& endpoint) {
auto& dref = static_cast<Subtype&>(*this);
auto cfg = this->make_lazy_config(endpoint);
return dref.lift(dref.with_context(std::move(cfg)));
}
/// Creates a `connect_factory` object for the given TCP `endpoint`.
///
/// @param endpoint The endpoint of the TCP server to connect to.
/// @returns a `connect_factory` object initialized with the given parameters.
auto connect(expected<uri> endpoint) {
if (endpoint)
return connect(*endpoint);
auto& dref = static_cast<Subtype&>(*this);
auto cfg = this->make_fail_config(endpoint.error());
return dref.lift(dref.with_context(std::move(cfg)));
}
};
} // namespace caf::net::dsl
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#include "caf/callback.hpp" #include "caf/callback.hpp"
#include "caf/defaults.hpp" #include "caf/defaults.hpp"
#include "caf/detail/plain_ref_counted.hpp"
#include "caf/intrusive_ptr.hpp" #include "caf/intrusive_ptr.hpp"
#include "caf/net/dsl/has_trait.hpp" #include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/config_base.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/ssl/context.hpp" #include "caf/net/ssl/context.hpp"
#include "caf/net/tcp_accept_socket.hpp" #include "caf/net/tcp_accept_socket.hpp"
...@@ -20,60 +20,37 @@ ...@@ -20,60 +20,37 @@
namespace caf::net::dsl { namespace caf::net::dsl {
/// The server config type enum class. /// The server config type enum class.
enum class server_config_type { lazy, socket }; enum class server_config_type { lazy, socket, fail };
/// Base class for server configuration objects. /// Base class for server configuration objects.
template <class Trait> template <class Trait>
class server_config : public detail::plain_ref_counted { class server_config : public config_base {
public: public:
using trait_type = Trait;
class lazy; class lazy;
class socket; class socket;
class fail;
friend class lazy; friend class lazy;
friend class socket; friend class socket;
friend class fail;
server_config(const server_config&) = delete;
server_config& operator=(const server_config&) = delete;
/// Virtual destructor.
virtual ~server_config() = default;
/// Returns the server configuration type. /// Returns the server configuration type.
virtual server_config_type type() const noexcept = 0; virtual server_config_type type() const noexcept = 0;
/// The pointer to the @ref multiplexer for running the server.
multiplexer* mpx;
/// The user-defined trait for configuration serialization. /// The user-defined trait for configuration serialization.
Trait trait; Trait trait;
/// SSL context for secure servers. /// SSL context for secure servers.
std::shared_ptr<ssl::context> ctx; std::shared_ptr<ssl::context> ctx;
/// User-defined callback for errors.
shared_callback_ptr<void(const error&)> on_error;
/// Configures the maximum number of concurrent connections.
size_t max_connections = defaults::net::max_connections.fallback; size_t max_connections = defaults::net::max_connections.fallback;
/// Calls `on_error` if non-null.
void call_on_error(const error& what) {
if (on_error)
(*on_error)(what);
}
friend void intrusive_ptr_add_ref(const server_config* ptr) noexcept {
ptr->ref();
}
friend void intrusive_ptr_release(const server_config* ptr) noexcept {
ptr->deref();
}
private: private:
/// Private constructor to enforce sealing. /// Private constructor to enforce sealing.
server_config(multiplexer* mpx, const Trait& trait) : mpx(mpx), trait(trait) { server_config(multiplexer* mpx, const Trait& trait)
: config_base(mpx), trait(trait) {
// nop // nop
} }
}; };
...@@ -146,6 +123,28 @@ public: ...@@ -146,6 +123,28 @@ public:
} }
}; };
/// Wraps an error that occurred earlier in the setup phase.
template <class Trait>
class server_config<Trait>::fail final : public server_config<Trait> {
public:
static constexpr auto type_token = server_config_type::fail;
using super = server_config;
fail(multiplexer* mpx, const Trait& trait, error err)
: super(mpx, trait), err(std::move(err)) {
// nop
}
/// Returns the server configuration type.
server_config_type type() const noexcept override {
return type_token;
}
/// The forwarded error.
error err;
};
/// Convenience alias for the `lazy` sub-type of @ref server_config. /// Convenience alias for the `lazy` sub-type of @ref server_config.
template <class Trait> template <class Trait>
using lazy_server_config = typename server_config<Trait>::lazy; using lazy_server_config = typename server_config<Trait>::lazy;
...@@ -154,25 +153,39 @@ using lazy_server_config = typename server_config<Trait>::lazy; ...@@ -154,25 +153,39 @@ using lazy_server_config = typename server_config<Trait>::lazy;
template <class Trait> template <class Trait>
using socket_server_config = typename server_config<Trait>::socket; using socket_server_config = typename server_config<Trait>::socket;
/// Convenience alias for the `fail` sub-type of @ref server_config.
template <class Trait>
using fail_server_config = typename server_config<Trait>::fail;
/// Calls a function object with the actual subtype of a server configuration /// Calls a function object with the actual subtype of a server configuration
/// and returns its result. /// and returns its result.
template <class F, class Trait> template <class F, class Trait>
decltype(auto) visit(F&& f, server_config<Trait>& cfg) { decltype(auto) visit(F&& f, server_config<Trait>& cfg) {
auto type = cfg.type(); auto type = cfg.type();
if (cfg.type() == server_config_type::lazy) switch (cfg.type()) {
return f(static_cast<lazy_server_config<Trait>&>(cfg)); case server_config_type::lazy:
assert(type == server_config_type::socket); return f(static_cast<lazy_server_config<Trait>&>(cfg));
return f(static_cast<socket_server_config<Trait>&>(cfg)); case server_config_type::socket:
return f(static_cast<socket_server_config<Trait>&>(cfg));
default:
assert(type == server_config_type::fail);
return f(static_cast<fail_server_config<Trait>&>(cfg));
}
} }
/// Calls a function object with the actual subtype of a server configuration. /// Calls a function object with the actual subtype of a server configuration.
template <class F, class Trait> template <class F, class Trait>
decltype(auto) visit(F&& f, const server_config<Trait>& cfg) { decltype(auto) visit(F&& f, const server_config<Trait>& cfg) {
auto type = cfg.type(); auto type = cfg.type();
if (cfg.type() == server_config_type::lazy) switch (cfg.type()) {
return f(static_cast<const lazy_server_config<Trait>&>(cfg)); case server_config_type::lazy:
assert(type == server_config_type::socket); return f(static_cast<const lazy_server_config<Trait>&>(cfg));
return f(static_cast<const socket_server_config<Trait>&>(cfg)); case server_config_type::socket:
return f(static_cast<const socket_server_config<Trait>&>(cfg));
default:
assert(type == server_config_type::fail);
return f(static_cast<const fail_server_config<Trait>&>(cfg));
}
} }
/// Gets a pointer to a specific subtype of a server configuration. /// Gets a pointer to a specific subtype of a server configuration.
...@@ -191,4 +204,11 @@ const T* get_if(const server_config<Trait>* cfg) { ...@@ -191,4 +204,11 @@ const T* get_if(const server_config<Trait>* cfg) {
return nullptr; return nullptr;
} }
/// Creates a `fail_server_config` from another configuration object plus error.
template <class Trait>
auto to_fail_config(server_config_ptr<Trait> ptr, error err) {
using impl_t = fail_server_config<Trait>;
return make_counted<impl_t>(ptr->mpx, ptr->trait, std::move(err));
}
} // namespace caf::net::dsl } // namespace caf::net::dsl
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#pragma once #pragma once
#include "caf/make_counted.hpp" #include "caf/make_counted.hpp"
#include "caf/net/dsl/has_trait.hpp" #include "caf/net/dsl/base.hpp"
#include "caf/net/dsl/server_config.hpp" #include "caf/net/dsl/server_config.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/ssl/acceptor.hpp" #include "caf/net/ssl/acceptor.hpp"
......
...@@ -34,9 +34,11 @@ public: ...@@ -34,9 +34,11 @@ public:
using super::super; using super::super;
using start_res_t = expected<disposable>;
/// Starts a connection with the length-prefixing protocol. /// Starts a connection with the length-prefixing protocol.
template <class OnStart> template <class OnStart>
disposable start(OnStart on_start) { [[nodiscard]] expected<disposable> start(OnStart on_start) {
using input_res_t = typename Trait::input_resource; using input_res_t = typename Trait::input_resource;
using output_res_t = typename Trait::output_resource; using output_res_t = typename Trait::output_resource;
static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>); static_assert(std::is_invocable_v<OnStart, input_res_t, output_res_t>);
...@@ -47,24 +49,24 @@ public: ...@@ -47,24 +49,24 @@ public:
} }
private: private:
expected<stream_socket> try_connect(const dsl::lazy_client_config<Trait>& cfg, auto try_connect(const dsl::lazy_client_config<Trait>& cfg,
const std::string& host, uint16_t port) { const std::string& host, uint16_t port) {
auto result = make_connected_tcp_stream_socket(host, port, auto result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout); cfg.connection_timeout);
if (result) if (result)
return {*result}; return result;
for (size_t i = 1; i <= cfg.max_retry_count; ++i) { for (size_t i = 1; i <= cfg.max_retry_count; ++i) {
std::this_thread::sleep_for(cfg.retry_delay); std::this_thread::sleep_for(cfg.retry_delay);
result = make_connected_tcp_stream_socket(host, port, result = make_connected_tcp_stream_socket(host, port,
cfg.connection_timeout); cfg.connection_timeout);
if (result) if (result)
return {*result}; return result;
} }
return {std::move(result.error())}; return result;
} }
template <class Conn, class OnStart> template <class Conn, class OnStart>
disposable start_res_t
do_start_impl(dsl::client_config<Trait>& cfg, Conn conn, OnStart& on_start) { do_start_impl(dsl::client_config<Trait>& cfg, Conn conn, OnStart& on_start) {
// s2a: socket-to-application (and a2s is the inverse). // s2a: socket-to-application (and a2s is the inverse).
using input_t = typename Trait::input_type; using input_t = typename Trait::input_type;
...@@ -78,18 +80,20 @@ private: ...@@ -78,18 +80,20 @@ private:
std::move(fc)); std::move(fc));
auto bridge_ptr = bridge.get(); auto bridge_ptr = bridge.get();
auto impl = framing::make(std::move(bridge)); auto impl = framing::make(std::move(bridge));
auto fd = conn.fd();
auto transport = transport_t::make(std::move(conn), std::move(impl)); auto transport = transport_t::make(std::move(conn), std::move(impl));
transport->active_policy().connect(fd);
auto ptr = socket_manager::make(cfg.mpx, std::move(transport)); auto ptr = socket_manager::make(cfg.mpx, std::move(transport));
bridge_ptr->self_ref(ptr->as_disposable()); bridge_ptr->self_ref(ptr->as_disposable());
cfg.mpx->start(ptr); cfg.mpx->start(ptr);
on_start(std::move(s2a_pull), std::move(a2s_push)); on_start(std::move(s2a_pull), std::move(a2s_push));
return disposable{std::move(ptr)}; return start_res_t{disposable{std::move(ptr)}};
} }
template <class OnStart> template <class OnStart>
disposable do_start(dsl::lazy_client_config<Trait>& cfg, start_res_t do_start(dsl::lazy_client_config<Trait>& cfg,
const std::string& host, uint16_t port, const std::string& host, uint16_t port,
OnStart& on_start) { OnStart& on_start) {
auto fd = try_connect(cfg, host, port); auto fd = try_connect(cfg, host, port);
if (fd) { if (fd) {
if (cfg.ctx) { if (cfg.ctx) {
...@@ -97,48 +101,39 @@ private: ...@@ -97,48 +101,39 @@ private:
if (conn) if (conn)
return do_start_impl(cfg, std::move(*conn), on_start); return do_start_impl(cfg, std::move(*conn), on_start);
cfg.call_on_error(conn.error()); cfg.call_on_error(conn.error());
return {}; return start_res_t{std::move(conn.error())};
} }
return do_start_impl(cfg, *fd, on_start); return do_start_impl(cfg, *fd, on_start);
} }
cfg.call_on_error(fd.error()); cfg.call_on_error(fd.error());
return {}; return start_res_t{std::move(fd.error())};
} }
template <class OnStart> template <class OnStart>
disposable do_start(dsl::lazy_client_config<Trait>& cfg, OnStart& on_start) { start_res_t do_start(dsl::lazy_client_config<Trait>& cfg, OnStart& on_start) {
if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server)) if (auto* st = std::get_if<dsl::client_config_server_address>(&cfg.server))
return do_start(cfg, st->host, st->port, on_start); return do_start(cfg, st->host, st->port, on_start);
auto fail = [&cfg](auto code, std::string description) { auto err = make_error(sec::invalid_argument,
auto err = make_error(code, std::move(description)); "length-prefix factories do not accept URIs");
cfg.call_on_error(err); cfg.call_on_error(err);
return disposable{}; return start_res_t{std::move(err)};
};
auto& server_uri = std::get<uri>(cfg.server);
if (server_uri.scheme() != "tcp")
return fail(sec::invalid_argument, "connect expects tcp://<host>:<port>");
auto& auth = server_uri.authority();
if (auth.empty() || auth.port == 0)
return fail(sec::invalid_argument,
"connect expects tcp://<host>:<port> with non-zero port");
return do_start(cfg, auth.host_str(), auth.port, on_start);
} }
template <class OnStart> template <class OnStart>
disposable start_res_t
do_start(dsl::socket_client_config<Trait>& cfg, OnStart& on_start) { do_start(dsl::socket_client_config<Trait>& cfg, OnStart& on_start) {
return do_start_impl(cfg, cfg.take_fd(), on_start); return do_start_impl(cfg, cfg.take_fd(), on_start);
} }
template <class OnStart> template <class OnStart>
disposable do_start(dsl::conn_client_config<Trait>& cfg, OnStart& on_start) { start_res_t do_start(dsl::conn_client_config<Trait>& cfg, OnStart& on_start) {
return do_start_impl(cfg, std::move(cfg.state), on_start); return do_start_impl(cfg, std::move(cfg.state), on_start);
} }
template <class OnStart> template <class OnStart>
disposable do_start(dsl::fail_client_config<Trait>& cfg, OnStart&) { start_res_t do_start(dsl::fail_client_config<Trait>& cfg, OnStart&) {
cfg.call_on_error(cfg.err); cfg.call_on_error(cfg.err);
return {}; return start_res_t{std::move(cfg.err)};
} }
}; };
......
...@@ -70,10 +70,12 @@ public: ...@@ -70,10 +70,12 @@ public:
using super::super; using super::super;
using start_res_t = expected<disposable>;
/// Starts a server that accepts incoming connections with the /// Starts a server that accepts incoming connections with the
/// length-prefixing protocol. /// length-prefixing protocol.
template <class OnStart> template <class OnStart>
disposable start(OnStart on_start) { start_res_t start(OnStart on_start) {
using acceptor_resource = typename Trait::acceptor_resource; using acceptor_resource = typename Trait::acceptor_resource;
static_assert(std::is_invocable_v<OnStart, acceptor_resource>); static_assert(std::is_invocable_v<OnStart, acceptor_resource>);
auto f = [this, &on_start](auto& cfg) { auto f = [this, &on_start](auto& cfg) {
...@@ -84,8 +86,8 @@ public: ...@@ -84,8 +86,8 @@ public:
private: private:
template <class Factory, class AcceptHandler, class Acceptor, class OnStart> template <class Factory, class AcceptHandler, class Acceptor, class OnStart>
disposable do_start_impl(dsl::server_config<Trait>& cfg, Acceptor acc, start_res_t do_start_impl(dsl::server_config<Trait>& cfg, Acceptor acc,
OnStart& on_start) { OnStart& on_start) {
using accept_event = typename Trait::accept_event; using accept_event = typename Trait::accept_event;
using connector_t = detail::flow_connector<Trait>; using connector_t = detail::flow_connector<Trait>;
auto [pull, push] = async::make_spsc_buffer_resource<accept_event>(); auto [pull, push] = async::make_spsc_buffer_resource<accept_event>();
...@@ -98,12 +100,12 @@ private: ...@@ -98,12 +100,12 @@ private:
impl_ptr->self_ref(ptr->as_disposable()); impl_ptr->self_ref(ptr->as_disposable());
cfg.mpx->start(ptr); cfg.mpx->start(ptr);
on_start(std::move(pull)); on_start(std::move(pull));
return disposable{std::move(ptr)}; return start_res_t{disposable{std::move(ptr)}};
} }
template <class OnStart> template <class OnStart>
disposable do_start(dsl::server_config<Trait>& cfg, tcp_accept_socket fd, start_res_t do_start(dsl::server_config<Trait>& cfg, tcp_accept_socket fd,
OnStart& on_start) { OnStart& on_start) {
if (!cfg.ctx) { if (!cfg.ctx) {
using factory_t = detail::lp_connection_factory<Trait, stream_transport>; using factory_t = detail::lp_connection_factory<Trait, stream_transport>;
using impl_t = detail::accept_handler<tcp_accept_socket, stream_socket>; using impl_t = detail::accept_handler<tcp_accept_socket, stream_socket>;
...@@ -116,29 +118,35 @@ private: ...@@ -116,29 +118,35 @@ private:
} }
template <class OnStart> template <class OnStart>
disposable start_res_t
do_start(typename dsl::server_config<Trait>::socket& cfg, OnStart& on_start) { do_start(typename dsl::server_config<Trait>::socket& cfg, OnStart& on_start) {
if (cfg.fd == invalid_socket) { if (cfg.fd == invalid_socket) {
auto err = make_error( auto err = make_error(
sec::runtime_error, sec::runtime_error,
"server factory cannot create a server on an invalid socket"); "server factory cannot create a server on an invalid socket");
cfg.call_on_error(err); cfg.call_on_error(err);
return {}; return start_res_t{std::move(err)};
} }
return do_start(cfg, cfg.take_fd(), on_start); return do_start(cfg, cfg.take_fd(), on_start);
} }
template <class OnStart> template <class OnStart>
disposable start_res_t
do_start(typename dsl::server_config<Trait>::lazy& cfg, OnStart& on_start) { do_start(typename dsl::server_config<Trait>::lazy& cfg, OnStart& on_start) {
auto fd = make_tcp_accept_socket(cfg.port, cfg.bind_address, auto fd = make_tcp_accept_socket(cfg.port, cfg.bind_address,
cfg.reuse_addr); cfg.reuse_addr);
if (!fd) { if (!fd) {
cfg.call_on_error(fd.error()); cfg.call_on_error(fd.error());
return {}; return start_res_t{std::move(fd.error())};
} }
return do_start(cfg, *fd, on_start); return do_start(cfg, *fd, on_start);
} }
template <class OnStart>
start_res_t do_start(dsl::fail_server_config<Trait>& cfg, OnStart&) {
cfg.call_on_error(cfg.err);
return start_res_t{std::move(cfg.err)};
}
}; };
} // namespace caf::net::lp } // namespace caf::net::lp
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "caf/fwd.hpp" #include "caf/fwd.hpp"
#include "caf/net/dsl/has_accept.hpp" #include "caf/net/dsl/has_accept.hpp"
#include "caf/net/dsl/has_connect.hpp" #include "caf/net/dsl/has_connect.hpp"
#include "caf/net/dsl/has_context.hpp"
#include "caf/net/lp/client_factory.hpp" #include "caf/net/lp/client_factory.hpp"
#include "caf/net/lp/server_factory.hpp" #include "caf/net/lp/server_factory.hpp"
#include "caf/net/multiplexer.hpp" #include "caf/net/multiplexer.hpp"
...@@ -20,12 +21,12 @@ namespace caf::net::lp { ...@@ -20,12 +21,12 @@ namespace caf::net::lp {
/// Entry point for the `with(...)` DSL. /// Entry point for the `with(...)` DSL.
template <class Trait> template <class Trait>
class with_t : public dsl::has_accept<server_factory<Trait>>, class with_t : public extend<dsl::base<Trait>, with_t<Trait>>::template //
public dsl::has_connect<client_factory<Trait>> { with<dsl::has_accept, dsl::has_connect, dsl::has_context> {
public: public:
template <class... Ts> template <class... Ts>
explicit with_t(multiplexer* mpx, Ts&&... xs) explicit with_t(multiplexer* mpx, Ts&&... xs)
: mpx_(mpx), trait_(std::forward<Ts>(xs)...) { : mpx_(mpx), trait_(std::forward<Ts>(xs)...), ctx_(error{}) {
// nop // nop
} }
...@@ -41,12 +42,29 @@ public: ...@@ -41,12 +42,29 @@ public:
return trait_; return trait_;
} }
/// @private
server_factory<Trait> lift(dsl::server_config_ptr<Trait> cfg) {
return server_factory<Trait>{std::move(cfg)};
}
/// @private
client_factory<Trait> lift(dsl::client_config_ptr<Trait> cfg) {
return client_factory<Trait>{std::move(cfg)};
}
private: private:
expected<ssl::context>& get_context_impl() noexcept override {
return ctx_;
}
/// Pointer to multiplexer that runs the protocol stack. /// Pointer to multiplexer that runs the protocol stack.
multiplexer* mpx_; multiplexer* mpx_;
/// User-defined trait for configuring serialization. /// User-defined trait for configuring serialization.
Trait trait_; Trait trait_;
/// The optional SSL context.
expected<ssl::context> ctx_;
}; };
template <class Trait = binary::default_trait> template <class Trait = binary::default_trait>
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#pragma once #pragma once
#include "caf/detail/net_export.hpp" #include "caf/detail/net_export.hpp"
#include "caf/expected.hpp"
#include "caf/net/dsl/arg.hpp"
#include "caf/net/socket_guard.hpp"
#include "caf/net/ssl/dtls.hpp" #include "caf/net/ssl/dtls.hpp"
#include "caf/net/ssl/format.hpp" #include "caf/net/ssl/format.hpp"
#include "caf/net/ssl/fwd.hpp" #include "caf/net/ssl/fwd.hpp"
...@@ -15,6 +18,7 @@ ...@@ -15,6 +18,7 @@
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <type_traits>
namespace caf::net::ssl { namespace caf::net::ssl {
...@@ -45,19 +49,30 @@ public: ...@@ -45,19 +49,30 @@ public:
// -- factories -------------------------------------------------------------- // -- factories --------------------------------------------------------------
/// Starting point for chaining `expected<T>::and_then()` invocations, whereas
/// the next function in the chain should create the SSL context depending on
/// the value of `flag`.
static expected<void> enable(bool flag);
/// Returns a generic SSL context with TLS.
static expected<context> make(tls min_version, tls max_version = tls::any); static expected<context> make(tls min_version, tls max_version = tls::any);
/// Returns a SSL context with TLS for a server role.
static expected<context> make_server(tls min_version, static expected<context> make_server(tls min_version,
tls max_version = tls::any); tls max_version = tls::any);
/// Returns a SSL context with TLS for a client role.
static expected<context> make_client(tls min_version, static expected<context> make_client(tls min_version,
tls max_version = tls::any); tls max_version = tls::any);
/// Returns a generic SSL context with DTLS.
static expected<context> make(dtls min_version, dtls max_version = dtls::any); static expected<context> make(dtls min_version, dtls max_version = dtls::any);
/// Returns a SSL context with DTLS for a server role.
static expected<context> make_server(dtls min_version, static expected<context> make_server(dtls min_version,
dtls max_version = dtls::any); dtls max_version = dtls::any);
/// Returns a SSL context with TLS for a client role.
static expected<context> make_client(dtls min_version, static expected<context> make_client(dtls min_version,
dtls max_version = dtls::any); dtls max_version = dtls::any);
...@@ -73,26 +88,26 @@ public: ...@@ -73,26 +88,26 @@ public:
/// Overrides the verification mode for this context. /// Overrides the verification mode for this context.
/// @note calls @c SSL_CTX_set_verify /// @note calls @c SSL_CTX_set_verify
void set_verify_mode(verify_t flags); void verify_mode(verify_t flags);
/// Overrides the callback to obtain the password for encrypted PEM files. /// Overrides the callback to obtain the password for encrypted PEM files.
/// @note calls @c SSL_CTX_set_default_passwd_cb /// @note calls @c SSL_CTX_set_default_passwd_cb
template <typename PasswordCallback> template <typename PasswordCallback>
void set_password_callback(PasswordCallback callback) { void password_callback(PasswordCallback callback) {
set_password_callback_impl(password::make_callback(std::move(callback))); password_callback_impl(password::make_callback(std::move(callback)));
} }
/// Overrides the callback to obtain the password for encrypted PEM files with /// Overrides the callback to obtain the password for encrypted PEM files with
/// a function that always returns @p password. /// a function that always returns @p password.
/// @note calls @c SSL_CTX_set_default_passwd_cb /// @note calls @c SSL_CTX_set_default_passwd_cb
void set_password(std::string password) { void password(std::string password) {
auto cb = [pw = std::move(password)](char* buf, int len, auto cb = [pw = std::move(password)](char* buf, int len,
password::purpose) { password::purpose) {
strncpy(buf, pw.c_str(), static_cast<size_t>(len)); strncpy(buf, pw.c_str(), static_cast<size_t>(len));
buf[len - 1] = '\0'; buf[len - 1] = '\0';
return static_cast<int>(pw.size()); return static_cast<int>(pw.size());
}; };
set_password_callback(std::move(cb)); password_callback(std::move(cb));
} }
// -- native handles --------------------------------------------------------- // -- native handles ---------------------------------------------------------
...@@ -107,13 +122,37 @@ public: ...@@ -107,13 +122,37 @@ public:
// -- error handling --------------------------------------------------------- // -- error handling ---------------------------------------------------------
/// Retrieves a human-readable error description for a preceding call to /// Retrieves a human-readable error description for a preceding call to
/// another member functions and removes that error from the error queue. Call /// another member functions and removes that error from the thread-local
/// repeatedly until @ref has_last_error returns `false` to retrieve all /// error queue. Call repeatedly until @ref has_error returns `false` to
/// errors from the queue. /// retrieve all errors from the queue.
static std::string next_error_string();
/// Retrieves a human-readable error description for a preceding call to
/// another member functions, appends the generated string to `buf` and
/// removes that error from the thread-local error queue. Call repeatedly
/// until @ref has_error returns `false` to retrieve all errors from the
/// queue.
static void append_next_error_string(std::string& buf);
/// Convenience function for calling `next_error_string` repeatedly until
/// @ref has_error returns `false`.
static std::string last_error_string(); static std::string last_error_string();
/// Queries whether the error stack has at least one entry. /// Queries whether the thread-local error stack has at least one entry.
static bool has_last_error() noexcept; static bool has_error() noexcept;
/// Retrieves all errors from the thread-local error queue and assembles them
/// into a single error string.
/// @returns all error strings from the thread-local error queue or
static error last_error();
/// Returns @ref last_error or `default_error` if the former is
/// default-constructed.
static error last_error_or(error default_error);
/// Returns @ref last_error or an error that represents an unexpected failure
/// if the former is default-constructed.
static error last_error_or_unexpected(std::string_view description);
// -- connections ------------------------------------------------------------ // -- connections ------------------------------------------------------------
...@@ -132,7 +171,7 @@ public: ...@@ -132,7 +171,7 @@ public:
/// certificates. /// certificates.
/// @returns `true` on success, `false` otherwise and `last_error` can be used /// @returns `true` on success, `false` otherwise and `last_error` can be used
/// to retrieve a human-readable error representation. /// to retrieve a human-readable error representation.
[[nodiscard]] bool set_default_verify_paths(); [[nodiscard]] bool enable_default_verify_paths();
/// Configures the context to load CA certificate from a directory. /// Configures the context to load CA certificate from a directory.
/// @param path Null-terminated string with a path to a directory. Files in /// @param path Null-terminated string with a path to a directory. Files in
...@@ -162,9 +201,15 @@ public: ...@@ -162,9 +201,15 @@ public:
} }
/// Loads the first certificate found in given file. /// Loads the first certificate found in given file.
/// @param path Null-terminated string with a path to a single PEM file. /// @param path Null-terminated string with a path to a single file.
[[nodiscard]] bool use_certificate_file(const char* path, format file_format); [[nodiscard]] bool use_certificate_file(const char* path, format file_format);
/// @copydoc use_certificate_file
[[nodiscard]] bool use_certificate_file(const std::string& path,
format file_format) {
return use_certificate_file(path.c_str(), file_format);
}
/// Loads a certificate chain from a PEM-formatted file. /// Loads a certificate chain from a PEM-formatted file.
/// @note calls @c SSL_CTX_use_certificate_chain_file /// @note calls @c SSL_CTX_use_certificate_chain_file
[[nodiscard]] bool use_certificate_chain_file(const char* path); [[nodiscard]] bool use_certificate_chain_file(const char* path);
...@@ -188,10 +233,252 @@ private: ...@@ -188,10 +233,252 @@ private:
// nop // nop
} }
void set_password_callback_impl(password::callback_ptr callback); void password_callback_impl(password::callback_ptr callback);
impl* pimpl_; impl* pimpl_;
user_data* data_ = nullptr; user_data* data_ = nullptr;
}; };
} // namespace caf::net::ssl } // namespace caf::net::ssl
namespace caf::detail {
// Convenience function for turning the Boolean results into a
// expected<context>.
inline expected<net::ssl::context>
ssl_ctx_chain(net::ssl::context& ctx, std::string_view descr, bool fn_res) {
using net::ssl::context;
if (fn_res)
return expected<context>{std::move(ctx)};
else
return expected<context>{context::last_error_or_unexpected(descr)};
}
// Convenience function for calling a member function on the context with some
// arguments.
template <class... Ts, class... Args>
expected<net::ssl::context>
ssl_ctx_chain(net::ssl::context& ctx, std::string_view arg_check_error,
std::string_view fn_error, bool (net::ssl::context::*fn)(Ts...),
Args&... args) {
using net::ssl::context;
if ((!args && ...)) {
auto err = make_error(sec::invalid_argument, std::string{arg_check_error});
return expected<context>{std::move(err)};
} else if ((ctx.*fn)(args.get()...)) {
return expected<context>{std::move(ctx)};
} else {
return expected<context>{context::last_error_or_unexpected(fn_error)};
}
}
// Convenience function for calling a member function on the context with some
// arguments. Unlike ssl_ctx_chain, this function does not result in an error if
// the arguments are invalid but simply returns the context as-is.
template <class... Ts, class... Args>
expected<net::ssl::context>
ssl_ctx_chain_if(net::ssl::context& ctx, std::string_view fn_error,
bool (net::ssl::context::*fn)(Ts...), Args&... args) {
using net::ssl::context;
if ((!args && ...) || (ctx.*fn)(args.get()...))
return expected<context>{std::move(ctx)};
else
return expected<context>{context::last_error_or_unexpected(fn_error)};
}
} // namespace caf::detail
namespace caf::net::ssl {
// -- utility functions for turning expected<void> into an expected<context> ---
inline auto emplace_context(tls min_version, tls max_version = tls::any) {
return [=] { return context::make(min_version, max_version); };
}
inline auto emplace_server(tls min_version, tls max_version = tls::any) {
return [=] { return context::make_server(min_version, max_version); };
}
inline auto emplace_client(tls min_version, tls max_version = tls::any) {
return [=] { return context::make_client(min_version, max_version); };
}
inline auto emplace_context(dtls min_version, dtls max_version = dtls::any) {
return [=] { return context::make(min_version, max_version); };
}
inline auto emplace_server(dtls min_version, dtls max_version = dtls::any) {
return [=] { return context::make_server(min_version, max_version); };
}
inline auto emplace_client(dtls min_version, dtls max_version = dtls::any) {
return [=] { return context::make_client(min_version, max_version); };
}
// -- utility functions for chaining .and_then(...) on an expected<context> ----
/// Creates a new SSL connection on `fd`. The connection does not take ownership
/// of the socket, i.e., does not close the socket when the SSL session end or
/// on error.
/// @param fd the stream socket for adding encryption.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto new_connection(stream_socket fd) {
// Note: this is a template to force the compiler to evaluate the body at a
// later time, because ssl::connection is incomplete here.
return [fd](auto ctx) { return ctx.new_connection(fd); };
}
/// Creates a new SSL connection on `fd`. The connection takes ownership of
/// the socket, i.e., closes the socket when the SSL session ends.
/// @param fd the stream socket for adding encryption.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto new_connection(stream_socket fd, close_on_shutdown_t) {
// Wrap into a guard to make sure the socket gets closed if this function
// doesn't get called.
return [guard = make_socket_guard(fd)](auto ctx) mutable {
return ctx.new_connection(guard.release(), close_on_shutdown);
};
}
/// Configure a context to use the default locations for loading CA
/// certificates.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto enable_default_verify_paths() {
return [](context ctx) {
return detail::ssl_ctx_chain(ctx, "enable_default_verify_paths failed",
ctx.enable_default_verify_paths());
};
}
/// Configures the context to load CA certificate from a directory.
/// @param path Null-terminated string with a path to a directory. Files in
/// the directory must use the CA subject name hash value as file
/// name with a suffix to disambiguate multiple certificates,
/// e.g., `9d66eef0.0` and `9d66eef0.1`.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto add_verify_path(dsl::arg::cstring path) {
return [arg = std::move(path)](context ctx) {
bool (context::*fn)(const char*) = &context::add_verify_path;
return detail::ssl_ctx_chain(ctx, "add_verify_path: path cannot be null",
"add_verify_path failed", fn, arg);
};
}
/// Configures the context to load CA certificate from a directory if all
/// arguments are non-null. Otherwise, does nothing.
/// @param path Null-terminated string with a path to a directory. Files in
/// the directory must use the CA subject name hash value as file
/// name with a suffix to disambiguate multiple certificates,
/// e.g., `9d66eef0.0` and `9d66eef0.1`.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto add_verify_path_if(dsl::arg::cstring path) {
return [arg = std::move(path)](context ctx) {
bool (context::*fn)(const char*) = &context::add_verify_path;
return detail::ssl_ctx_chain_if(ctx, "add_verify_path failed", fn, arg);
};
}
/// Loads a CA certificate file.
/// @param path String with a path to a single PEM file.
/// @returns `true` on success, `false` otherwise and `last_error` can be used
/// to retrieve a human-readable error representation.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto load_verify_file(dsl::arg::cstring path) {
return [arg = std::move(path)](context ctx) {
bool (context::*fn)(const char*) = &context::load_verify_file;
return detail::ssl_ctx_chain(ctx, "load_verify_file: path cannot be null",
"load_verify_file failed", fn, arg);
};
}
/// Loads a CA certificate file if all arguments are non-null. Otherwise, does
/// nothing.
/// @param path String with a path to a single PEM file.
/// @returns `true` on success, `false` otherwise and `last_error` can be used
/// to retrieve a human-readable error representation.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto load_verify_file_if(dsl::arg::cstring path) {
return [arg = std::move(path)](context ctx) {
bool (context::*fn)(const char*) = &context::load_verify_file;
return detail::ssl_ctx_chain_if(ctx, "load_verify_file failed", fn, arg);
};
}
/// Loads the first certificate found in given file.
/// @param path Null-terminated string with a path to a single file.
/// @param file_format Denotes the format of the certificate file.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto use_certificate_file(dsl::arg::cstring path,
dsl::arg::val<format> file_format) {
return [arg1 = std::move(path), arg2 = file_format](context ctx) mutable {
bool (context::*fn)(const char*, format) = &context::use_certificate_file;
return detail::ssl_ctx_chain(
ctx, "use_certificate_file: path and file_format cannot be null",
"use_certificate_file failed", fn, arg1, arg2);
};
}
/// Loads the first certificate found in given file if all arguments are
/// non-null. Otherwise, does nothing.
/// @param path Null-terminated string with a path to a single file.
/// @param file_format Denotes the format of the certificate file.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto use_certificate_file_if(dsl::arg::cstring path,
dsl::arg::val<format> file_format) {
return [arg1 = std::move(path), arg2 = file_format](context ctx) mutable {
bool (context::*fn)(const char*, format) = &context::use_certificate_file;
return detail::ssl_ctx_chain_if(ctx, "use_certificate_file failed", fn,
arg1, arg2);
};
}
/// Loads a certificate chain from a PEM-formatted file.
/// @note calls @c SSL_CTX_use_certificate_chain_file
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto use_certificate_chain_file(dsl::arg::cstring path) {
return [arg = std::move(path)](context ctx) mutable {
bool (context::*fn)(const char*) = &context::use_certificate_chain_file;
return detail::ssl_ctx_chain(
ctx, "use_certificate_chain_file: path cannot be null",
"use_certificate_chain_file failed", fn, arg);
};
}
/// Loads a certificate chain from a PEM-formatted file if all arguments are
/// non-null. Otherwise, does nothing.
/// @note calls @c SSL_CTX_use_certificate_chain_file
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto use_certificate_chain_file_if(dsl::arg::cstring path) {
return [arg = std::move(path)](context ctx) mutable {
bool (context::*fn)(const char*) = &context::use_certificate_chain_file;
return detail::ssl_ctx_chain_if(ctx, "use_certificate_chain_file failed",
fn, arg);
};
}
/// Loads the first private key found in given file.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto use_private_key_file(dsl::arg::cstring path,
dsl::arg::val<format> file_format) {
return [arg1 = std::move(path), arg2 = file_format](context ctx) mutable {
bool (context::*fn)(const char*, format) = &context::use_private_key_file;
return detail::ssl_ctx_chain(
ctx, "use_private_key_file: path and file_format cannot be null",
"use_private_key_file failed", fn, arg1, arg2);
};
}
/// Loads the first private key found in given file if all arguments are
/// non-null. Otherwise, does nothing.
/// @returns a function object for chaining `expected<T>::and_then()`.
inline auto use_private_key_file_if(dsl::arg::cstring path,
dsl::arg::val<format> file_format) {
return [arg1 = std::move(path), arg2 = file_format](context ctx) mutable {
bool (context::*fn)(const char*, format) = &context::use_private_key_file;
return detail::ssl_ctx_chain_if(ctx, "use_private_key_file failed", fn,
arg1, arg2);
};
}
} // namespace caf::net::ssl
...@@ -61,4 +61,10 @@ bool inspect(Inspector& f, errc& x) { ...@@ -61,4 +61,10 @@ bool inspect(Inspector& f, errc& x) {
} // namespace caf::net::ssl } // namespace caf::net::ssl
namespace caf::detail {
CAF_NET_EXPORT net::ssl::errc ssl_errc_from_native(int);
} // namespace caf::detail
CAF_ERROR_CODE_ENUM(caf::net::ssl::errc) CAF_ERROR_CODE_ENUM(caf::net::ssl::errc)
...@@ -2,25 +2,12 @@ ...@@ -2,25 +2,12 @@
// the main distribution directory for license terms and copyright or visit // the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE. // https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once #include "caf/net/dsl/config_base.hpp"
#include "caf/net/fwd.hpp"
namespace caf::net::dsl { namespace caf::net::dsl {
/// Base type for DSL classes. config_base::~config_base() {
template <class Trait> // nop
class has_trait { }
public:
virtual ~has_trait() {
// nop
}
/// @returns the pointer to the @ref multiplexer.
virtual multiplexer* mpx() const noexcept = 0;
/// @returns the trait object.
virtual const Trait& trait() const noexcept = 0;
};
} // namespace caf::net::dsl } // namespace caf::net::dsl
...@@ -65,40 +65,8 @@ std::string connection::last_error_string(ptrdiff_t ret) const { ...@@ -65,40 +65,8 @@ std::string connection::last_error_string(ptrdiff_t ret) const {
} }
errc connection::last_error(ptrdiff_t ret) const { errc connection::last_error(ptrdiff_t ret) const {
switch (SSL_get_error(native(pimpl_), static_cast<int>(ret))) { auto code = SSL_get_error(native(pimpl_), static_cast<int>(ret));
case SSL_ERROR_NONE: return detail::ssl_errc_from_native(code);
return errc::none;
case SSL_ERROR_ZERO_RETURN:
return errc::closed;
case SSL_ERROR_WANT_READ:
return errc::want_read;
case SSL_ERROR_WANT_WRITE:
return errc::want_write;
case SSL_ERROR_WANT_CONNECT:
return errc::want_connect;
case SSL_ERROR_WANT_ACCEPT:
return errc::want_accept;
case SSL_ERROR_WANT_X509_LOOKUP:
return errc::want_x509_lookup;
#ifdef SSL_ERROR_WANT_ASYNC
case SSL_ERROR_WANT_ASYNC:
return errc::want_async;
#endif
#ifdef SSL_ERROR_WANT_ASYNC_JOB
case SSL_ERROR_WANT_ASYNC_JOB:
return errc::want_async_job;
#endif
#ifdef SSL_ERROR_WANT_CLIENT_HELLO_CB
case SSL_ERROR_WANT_CLIENT_HELLO_CB:
return errc::want_client_hello;
#endif
case SSL_ERROR_SYSCALL:
return errc::syscall_failed;
case SSL_ERROR_SSL:
return errc::fatal;
default:
return errc::unspecified;
}
} }
// -- connecting and teardown -------------------------------------------------- // -- connecting and teardown --------------------------------------------------
......
...@@ -105,6 +105,16 @@ make_ctx(const SSL_METHOD* method, Enum min_val, Enum max_val) { ...@@ -105,6 +105,16 @@ make_ctx(const SSL_METHOD* method, Enum min_val, Enum max_val) {
} // namespace } // namespace
expected<void> context::enable(bool flag) {
// By returning a default-constructed error, we suppress any subsequent
// function calls in an `and_then` chain. The caf-net DSL then treats a
// default-constructed error as "no SSL".
if (flag)
return expected<void>{};
else
return expected<void>{caf::error{}};
}
expected<context> context::make(tls vmin, tls vmax) { expected<context> context::make(tls vmin, tls vmax) {
auto [raw, errstr] = make_ctx(CAF_TLS_METHOD(_), vmin, vmax); auto [raw, errstr] = make_ctx(CAF_TLS_METHOD(_), vmin, vmax);
context ctx{reinterpret_cast<impl*>(raw)}; context ctx{reinterpret_cast<impl*>(raw)};
...@@ -161,7 +171,7 @@ expected<context> context::make_client(dtls vmin, dtls vmax) { ...@@ -161,7 +171,7 @@ expected<context> context::make_client(dtls vmin, dtls vmax) {
// -- properties --------------------------------------------------------------- // -- properties ---------------------------------------------------------------
void context::set_verify_mode(verify_t flags) { void context::verify_mode(verify_t flags) {
auto ptr = native(pimpl_); auto ptr = native(pimpl_);
SSL_CTX_set_verify(ptr, to_integer(flags), SSL_CTX_get_verify_callback(ptr)); SSL_CTX_set_verify(ptr, to_integer(flags), SSL_CTX_get_verify_callback(ptr));
} }
...@@ -178,7 +188,7 @@ int c_password_callback(char* buf, int size, int rwflag, void* ptr) { ...@@ -178,7 +188,7 @@ int c_password_callback(char* buf, int size, int rwflag, void* ptr) {
} // namespace } // namespace
void context::set_password_callback_impl(password::callback_ptr callback) { void context::password_callback_impl(password::callback_ptr callback) {
if (data_ == nullptr) if (data_ == nullptr)
data_ = new user_data; data_ = new user_data;
auto ptr = native(pimpl_); auto ptr = native(pimpl_);
...@@ -199,27 +209,68 @@ void* context::native_handle() const noexcept { ...@@ -199,27 +209,68 @@ void* context::native_handle() const noexcept {
// -- error handling ----------------------------------------------------------- // -- error handling -----------------------------------------------------------
std::string context::last_error_string() { std::string context::next_error_string() {
std::string result;
append_next_error_string(result);
return result;
}
void context::append_next_error_string(std::string& buf) {
auto save_cstr = [](const char* cstr) { return cstr ? cstr : "NULL"; }; auto save_cstr = [](const char* cstr) { return cstr ? cstr : "NULL"; };
if (auto code = ERR_get_error(); code != 0) { if (auto code = ERR_get_error(); code != 0) {
std::string result; buf = "error:";
result.reserve(256); buf += std::to_string(code);
result = "error:"; buf += ':';
result += std::to_string(code); buf += save_cstr(ERR_lib_error_string(code));
result += ':'; buf += "::";
result += save_cstr(ERR_lib_error_string(code)); buf += save_cstr(ERR_reason_error_string(code));
result += "::";
result += save_cstr(ERR_reason_error_string(code));
return result;
} else { } else {
return "no-error"; buf += "no-error";
}
}
std::string context::last_error_string() {
if (!has_error())
return {};
auto result = next_error_string();
while (has_error()) {
result += '\n';
append_next_error_string(result);
} }
return result;
} }
bool context::has_last_error() noexcept { bool context::has_error() noexcept {
return ERR_peek_error() != 0; return ERR_peek_error() != 0;
} }
error context::last_error() {
if (ERR_peek_error() == 0)
return error{};
auto description = next_error_string();
while (has_error()) {
description += '\n';
append_next_error_string(description);
}
// TODO: Mapping the codes to an error enum would be much nicer than using
// the generic 'runtime_error'.
return make_error(sec::runtime_error, std::move(description));
}
error context::last_error_or(error default_error) {
if (ERR_peek_error() == 0)
return default_error;
else
return last_error();
}
error context::last_error_or_unexpected(std::string_view description) {
if (ERR_peek_error() == 0)
return make_error(sec::runtime_error, std::string{description});
else
return last_error();
}
// -- connections -------------------------------------------------------------- // -- connections --------------------------------------------------------------
expected<connection> context::new_connection(stream_socket fd) { expected<connection> context::new_connection(stream_socket fd) {
...@@ -227,6 +278,7 @@ expected<connection> context::new_connection(stream_socket fd) { ...@@ -227,6 +278,7 @@ expected<connection> context::new_connection(stream_socket fd) {
auto conn = connection::from_native(ptr); auto conn = connection::from_native(ptr);
if (auto bio_ptr = BIO_new_socket(fd.id, BIO_NOCLOSE)) { if (auto bio_ptr = BIO_new_socket(fd.id, BIO_NOCLOSE)) {
SSL_set_bio(ptr, bio_ptr, bio_ptr); SSL_set_bio(ptr, bio_ptr, bio_ptr);
return {std::move(conn)}; return {std::move(conn)};
} else { } else {
return {make_error(sec::logic_error, "BIO_new_socket failed")}; return {make_error(sec::logic_error, "BIO_new_socket failed")};
...@@ -252,7 +304,7 @@ expected<connection> context::new_connection(stream_socket fd, ...@@ -252,7 +304,7 @@ expected<connection> context::new_connection(stream_socket fd,
// -- certificates and keys ---------------------------------------------------- // -- certificates and keys ----------------------------------------------------
bool context::set_default_verify_paths() { bool context::enable_default_verify_paths() {
ERR_clear_error(); ERR_clear_error();
return SSL_CTX_set_default_verify_paths(native(pimpl_)) == 1; return SSL_CTX_set_default_verify_paths(native(pimpl_)) == 1;
} }
......
#include "caf/net/ssl/errc.hpp"
#include "caf/config.hpp"
CAF_PUSH_WARNINGS
#include <openssl/err.h>
#include <openssl/ssl.h>
CAF_POP_WARNINGS
namespace caf::detail {
net::ssl::errc ssl_errc_from_native(int code) {
using net::ssl::errc;
switch (code) {
case SSL_ERROR_NONE:
return errc::none;
case SSL_ERROR_ZERO_RETURN:
return errc::closed;
case SSL_ERROR_WANT_READ:
return errc::want_read;
case SSL_ERROR_WANT_WRITE:
return errc::want_write;
case SSL_ERROR_WANT_CONNECT:
return errc::want_connect;
case SSL_ERROR_WANT_ACCEPT:
return errc::want_accept;
case SSL_ERROR_WANT_X509_LOOKUP:
return errc::want_x509_lookup;
#ifdef SSL_ERROR_WANT_ASYNC
case SSL_ERROR_WANT_ASYNC:
return errc::want_async;
#endif
#ifdef SSL_ERROR_WANT_ASYNC_JOB
case SSL_ERROR_WANT_ASYNC_JOB:
return errc::want_async_job;
#endif
#ifdef SSL_ERROR_WANT_CLIENT_HELLO_CB
case SSL_ERROR_WANT_CLIENT_HELLO_CB:
return errc::want_client_hello;
#endif
case SSL_ERROR_SYSCALL:
return errc::syscall_failed;
case SSL_ERROR_SSL:
return errc::fatal;
default:
return errc::unspecified;
}
}
} // namespace caf::detail
...@@ -247,28 +247,32 @@ SCENARIO("lp::with(...).connect(...) translates between flows and socket I/O") { ...@@ -247,28 +247,32 @@ SCENARIO("lp::with(...).connect(...) translates between flows and socket I/O") {
caf::actor_system sys{cfg}; caf::actor_system sys{cfg};
auto buf = std::make_shared<std::vector<std::string>>(); auto buf = std::make_shared<std::vector<std::string>>();
caf::actor hdl; caf::actor hdl;
net::lp::with(sys).connect(fd2).start([&](auto pull, auto push) { auto conn
hdl = sys.spawn([buf, pull, push](event_based_actor* self) { = net::lp::with(sys) //
pull.observe_on(self) .connect(fd2)
.do_on_error([](const error& what) { // .start([&](auto pull, auto push) {
MESSAGE("flow aborted: " << what); hdl = sys.spawn([buf, pull, push](event_based_actor* self) {
}) pull.observe_on(self)
.do_on_complete([] { MESSAGE("flow completed"); }) .do_on_error([](const error& what) { //
.do_on_next([buf](const net::binary::frame& x) { MESSAGE("flow aborted: " << what);
std::string str; })
for (auto val : x.bytes()) .do_on_complete([] { MESSAGE("flow completed"); })
str.push_back(static_cast<char>(val)); .do_on_next([buf](const net::binary::frame& x) {
buf->push_back(std::move(str)); std::string str;
}) for (auto val : x.bytes())
.map([](const net::binary::frame& x) { str.push_back(static_cast<char>(val));
std::string response = "ok "; buf->push_back(std::move(str));
for (auto val : x.bytes()) })
response.push_back(static_cast<char>(val)); .map([](const net::binary::frame& x) {
return net::binary::frame{as_bytes(make_span(response))}; std::string response = "ok ";
}) for (auto val : x.bytes())
.subscribe(push); response.push_back(static_cast<char>(val));
}); return net::binary::frame{as_bytes(make_span(response))};
}); })
.subscribe(push);
});
});
REQUIRE(conn);
scoped_actor self{sys}; scoped_actor self{sys};
self->wait_for(hdl); self->wait_for(hdl);
if (CHECK_EQ(buf->size(), 5u)) { if (CHECK_EQ(buf->size(), 5u)) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment