Commit 8e7f7b58 authored by Dominik Charousset's avatar Dominik Charousset

Implement new OpenSSL transport

parent f1a3b267
...@@ -19,6 +19,7 @@ caf_incubator_add_component( ...@@ -19,6 +19,7 @@ caf_incubator_add_component(
net.http.method net.http.method
net.http.status net.http.status
net.operation net.operation
net.stream_transport_error
net.web_socket.status net.web_socket.status
HEADERS HEADERS
${CAF_NET_HEADERS} ${CAF_NET_HEADERS}
...@@ -79,3 +80,9 @@ caf_incubator_add_component( ...@@ -79,3 +80,9 @@ caf_incubator_add_component(
stream_transport stream_transport
tcp_sockets tcp_sockets
udp_datagram_socket) udp_datagram_socket)
if(TARGET OpenSSL::SSL AND TARGET OpenSSL::Crypto)
caf_incubator_add_test_suites(caf-net-test net.openssl_transport)
target_sources(caf-net-test PRIVATE test/net/openssl_transport_constants.cpp)
target_link_libraries(caf-net-test PRIVATE OpenSSL::SSL OpenSSL::Crypto)
endif()
...@@ -40,33 +40,32 @@ public: ...@@ -40,33 +40,32 @@ public:
// nop // nop
} }
// -- member functions ------------------------------------------------------- // -- interface functions ----------------------------------------------------
template <class LowerLayerPtr> template <class LowerLayerPtr>
error error init(socket_manager* owner, LowerLayerPtr down, const settings& cfg) {
init(socket_manager* owner, LowerLayerPtr parent, const settings& config) {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
owner_ = owner; owner_ = owner;
cfg_ = config; cfg_ = cfg;
if (auto err = factory_.init(owner, config)) if (auto err = factory_.init(owner, cfg))
return err; return err;
parent->register_reading(); down->register_reading();
return none; return none;
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
read_result handle_read_event(LowerLayerPtr parent) { read_result handle_read_event(LowerLayerPtr down) {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
if (auto x = accept(parent->handle())) { if (auto x = accept(down->handle())) {
socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr()); socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr());
if (!child) { if (!child) {
CAF_LOG_ERROR("factory failed to create a new child"); CAF_LOG_ERROR("factory failed to create a new child");
parent->abort_reason(sec::runtime_error); down->abort_reason(sec::runtime_error);
return read_result::stop; return read_result::stop;
} }
if (auto err = child->init(cfg_)) { if (auto err = child->init(cfg_)) {
CAF_LOG_ERROR("failed to initialize new child:" << err); CAF_LOG_ERROR("failed to initialize new child:" << err);
parent->abort_reason(std::move(err)); down->abort_reason(std::move(err));
return read_result::stop; return read_result::stop;
} }
if (limit_ == 0) { if (limit_ == 0) {
...@@ -81,8 +80,13 @@ public: ...@@ -81,8 +80,13 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
static void continue_reading(LowerLayerPtr) { static read_result handle_buffered_data(LowerLayerPtr) {
// nop return read_result::again;
}
template <class LowerLayerPtr>
static read_result handle_continue_reading(LowerLayerPtr) {
return read_result::again;
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
...@@ -91,6 +95,12 @@ public: ...@@ -91,6 +95,12 @@ public:
return write_result::stop; return write_result::stop;
} }
template <class LowerLayerPtr>
static write_result handle_continue_writing(LowerLayerPtr) {
CAF_LOG_ERROR("connection_acceptor received continue writing event");
return write_result::stop;
}
template <class LowerLayerPtr> template <class LowerLayerPtr>
void abort(LowerLayerPtr, const error& reason) { void abort(LowerLayerPtr, const error& reason) {
CAF_LOG_ERROR("connection_acceptor aborts due to an error: " << reason); CAF_LOG_ERROR("connection_acceptor aborts due to an error: " << reason);
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/net/socket_manager.hpp"
#include "caf/net/stream_transport_error.hpp"
namespace caf::net {
template <class OnSuccess, class OnError>
struct default_handshake_worker_factory {
OnSuccess make;
OnError abort;
};
/// An connect worker calls an asynchronous `connect` callback until it succeeds.
/// On success, the worker calls a factory object to transfer ownership of
/// socket and communication policy to the create the socket manager that takes
/// care of the established connection.
template <bool IsServer, class Socket, class Policy, class Factory>
class handshake_worker : public socket_manager {
public:
// -- member types -----------------------------------------------------------
using super = socket_manager;
using read_result = typename super::read_result;
using write_result = typename super::write_result;
handshake_worker(Socket handle, multiplexer* parent, Policy policy,
Factory factory)
: super(handle, parent),
policy_(std::move(policy)),
factory_(std::move(factory)) {
// nop
}
// -- interface functions ----------------------------------------------------
error init(const settings& config) override {
cfg_ = config;
register_writing();
return caf::none;
}
read_result handle_read_event() override {
auto fd = socket_cast<Socket>(this->handle());
if (auto res = advance_handshake(fd); res > 0) {
return read_result::handover;
} else if (res == 0) {
factory_.abort(make_error(sec::connection_closed));
return read_result::stop;
} else {
auto err = policy_.last_error(fd, res);
switch (err) {
case stream_transport_error::want_read:
case stream_transport_error::temporary:
return read_result::again;
case stream_transport_error::want_write:
return read_result::want_write;
default:
auto err = make_error(sec::cannot_connect_to_node,
policy_.fetch_error_str());
factory_.abort(std::move(err));
return read_result::stop;
}
}
}
read_result handle_buffered_data() override {
return read_result::again;
}
read_result handle_continue_reading() override {
return read_result::again;
}
write_result handle_write_event() override {
auto fd = socket_cast<Socket>(this->handle());
if (auto res = advance_handshake(fd); res > 0) {
return write_result::handover;
} else if (res == 0) {
factory_.abort(make_error(sec::connection_closed));
return write_result::stop;
} else {
switch (policy_.last_error(fd, res)) {
case stream_transport_error::want_write:
case stream_transport_error::temporary:
return write_result::again;
case stream_transport_error::want_read:
return write_result::want_read;
default:
auto err = make_error(sec::cannot_connect_to_node,
policy_.fetch_error_str());
factory_.abort(std::move(err));
return write_result::stop;
}
}
}
write_result handle_continue_writing() override {
return write_result::again;
}
void handle_error(sec code) override {
factory_.abort(make_error(code));
}
socket_manager_ptr make_next_manager(socket hdl) override {
auto ptr = factory_.make(socket_cast<Socket>(hdl), this->mpx_ptr(),
std::move(policy_));
if (ptr) {
if (auto err = ptr->init(cfg_)) {
factory_.abort(err);
return nullptr;
} else {
return ptr;
}
} else {
factory_.abort(make_error(sec::runtime_error, "factory_.make failed"));
return nullptr;
}
}
private:
ptrdiff_t advance_handshake(Socket fd) {
if constexpr (IsServer)
return policy_.accept(fd);
else
return policy_.connect(fd);
}
settings cfg_;
Policy policy_;
Factory factory_;
};
} // namespace caf::net
...@@ -47,6 +47,12 @@ public: ...@@ -47,6 +47,12 @@ public:
friend class pollset_updater; // Needs access to the `do_*` functions. friend class pollset_updater; // Needs access to the `do_*` functions.
// -- static utility functions -----------------------------------------------
/// Blocks the PIPE signal on the current thread when running on a POSIX
/// windows. Has no effect when running on Windows.
static void block_sigpipe();
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
/// @param parent Points to the owning middleman instance. May be `nullptr` /// @param parent Points to the owning middleman instance. May be `nullptr`
...@@ -91,6 +97,14 @@ public: ...@@ -91,6 +97,14 @@ public:
/// @thread-safe /// @thread-safe
void register_writing(const socket_manager_ptr& mgr); void register_writing(const socket_manager_ptr& mgr);
/// Triggers a continue reading event for `mgr`.
/// @thread-safe
void continue_reading(const socket_manager_ptr& mgr);
/// Triggers a continue writing event for `mgr`.
/// @thread-safe
void continue_writing(const socket_manager_ptr& mgr);
/// Schedules a call to `mgr->handle_error(sec::discarded)`. /// Schedules a call to `mgr->handle_error(sec::discarded)`.
/// @thread-safe /// @thread-safe
void discard(const socket_manager_ptr& mgr); void discard(const socket_manager_ptr& mgr);
...@@ -118,10 +132,9 @@ public: ...@@ -118,10 +132,9 @@ public:
/// @thread-safe /// @thread-safe
void init(const socket_manager_ptr& mgr); void init(const socket_manager_ptr& mgr);
/// Closes the pipe for signaling updates to the multiplexer. After closing /// Signals the multiplexer to initiate shutdown.
/// the pipe, calls to `update` no longer have any effect.
/// @thread-safe /// @thread-safe
void close_pipe(); void shutdown();
// -- control flow ----------------------------------------------------------- // -- control flow -----------------------------------------------------------
...@@ -138,16 +151,42 @@ public: ...@@ -138,16 +151,42 @@ public:
/// Polls until no socket event handler remains. /// Polls until no socket event handler remains.
void run(); void run();
/// Signals the multiplexer to initiate shutdown.
/// @thread-safe
void shutdown();
protected: protected:
// -- utility functions ------------------------------------------------------ // -- utility functions ------------------------------------------------------
/// Handles an I/O event on given manager. /// Handles an I/O event on given manager.
void handle(const socket_manager_ptr& mgr, short events, short revents); void handle(const socket_manager_ptr& mgr, short events, short revents);
/// Transfers socket ownership from one manager to another.
void do_handover(const socket_manager_ptr& mgr);
/// Returns a change entry for the socket at given index. Lazily creates a new
/// entry before returning if necessary.
poll_update& update_for(ptrdiff_t index);
/// Returns a change entry for the socket of the manager.
poll_update& update_for(const socket_manager_ptr& mgr);
/// Writes `opcode` and pointer to `mgr` the the pipe for handling an event
/// later via the pollset updater.
template <class T>
void write_to_pipe(uint8_t opcode, T* ptr);
/// @copydoc write_to_pipe
template <class Enum, class T>
std::enable_if_t<std::is_enum_v<Enum>> write_to_pipe(Enum opcode, T* ptr) {
write_to_pipe(static_cast<uint8_t>(opcode), ptr);
}
/// Queries the currently active event bitmask for `mgr`.
short active_mask_of(const socket_manager_ptr& mgr);
/// Queries whether `mgr` is currently registered for reading.
bool is_reading(const socket_manager_ptr& mgr);
/// Queries whether `mgr` is currently registered for writing.
bool is_writing(const socket_manager_ptr& mgr);
// -- member variables ------------------------------------------------------- // -- member variables -------------------------------------------------------
/// Bookkeeping data for managed sockets. /// Bookkeeping data for managed sockets.
...@@ -178,25 +217,7 @@ protected: ...@@ -178,25 +217,7 @@ protected:
bool shutting_down_ = false; bool shutting_down_ = false;
private: private:
/// Returns a change entry for the socket at given index. Lazily creates a new // -- internal callbacks the pollset updater ---------------------------------
/// entry before returning if necessary.
poll_update& update_for(ptrdiff_t index);
/// Returns a change entry for the socket of the manager.
poll_update& update_for(const socket_manager_ptr& mgr);
/// Writes `opcode` and pointer to `mgr` the the pipe for handling an event
/// later via the pollset updater.
template <class T>
void write_to_pipe(uint8_t opcode, T* ptr);
/// @copydoc write_to_pipe
template <class Enum, class T>
std::enable_if_t<std::is_enum_v<Enum>> write_to_pipe(Enum opcode, T* ptr) {
write_to_pipe(static_cast<uint8_t>(opcode), ptr);
}
// -- internal callback the pollset updater ----------------------------------
void do_shutdown(); void do_shutdown();
...@@ -204,6 +225,10 @@ private: ...@@ -204,6 +225,10 @@ private:
void do_register_writing(const socket_manager_ptr& mgr); void do_register_writing(const socket_manager_ptr& mgr);
void do_continue_reading(const socket_manager_ptr& mgr);
void do_continue_writing(const socket_manager_ptr& mgr);
void do_discard(const socket_manager_ptr& mgr); void do_discard(const socket_manager_ptr& mgr);
void do_shutdown_reading(const socket_manager_ptr& mgr); void do_shutdown_reading(const socket_manager_ptr& mgr);
......
This diff is collapsed.
...@@ -22,11 +22,11 @@ public: ...@@ -22,11 +22,11 @@ public:
using msg_buf = std::array<byte, sizeof(intptr_t) + 1>; using msg_buf = std::array<byte, sizeof(intptr_t) + 1>;
// -- constants --------------------------------------------------------------
enum class code : uint8_t { enum class code : uint8_t {
register_reading, register_reading,
continue_reading,
register_writing, register_writing,
continue_writing,
init_manager, init_manager,
discard_manager, discard_manager,
shutdown_reading, shutdown_reading,
...@@ -34,6 +34,7 @@ public: ...@@ -34,6 +34,7 @@ public:
run_action, run_action,
shutdown, shutdown,
}; };
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
pollset_updater(pipe_socket read_handle, multiplexer* parent); pollset_updater(pipe_socket read_handle, multiplexer* parent);
...@@ -53,11 +54,15 @@ public: ...@@ -53,11 +54,15 @@ public:
read_result handle_read_event() override; read_result handle_read_event() override;
read_result handle_buffered_data() override;
read_result handle_continue_reading() override;
write_result handle_write_event() override; write_result handle_write_event() override;
void handle_error(sec code) override; write_result handle_continue_writing() override;
void continue_reading() override; void handle_error(sec code) override;
private: private:
msg_buf buf_; msg_buf buf_;
......
...@@ -39,8 +39,7 @@ public: ...@@ -39,8 +39,7 @@ public:
void on_consumer_demand(size_t new_demand) override { void on_consumer_demand(size_t new_demand) override {
auto prev = demand_.fetch_add(new_demand); auto prev = demand_.fetch_add(new_demand);
if (prev == 0) if (prev == 0)
mgr_->mpx().schedule_fn( mgr_->continue_reading();
[adapter = strong_this()] { adapter->continue_reading(); });
} }
void ref_producer() const noexcept override { void ref_producer() const noexcept override {
...@@ -129,11 +128,6 @@ private: ...@@ -129,11 +128,6 @@ private:
// nop // nop
} }
void continue_reading() {
if (mgr_)
mgr_->continue_reading();
}
void on_cancel() { void on_cancel() {
if (buf_) { if (buf_) {
mgr_->mpx().shutdown_reading(mgr_); mgr_->mpx().shutdown_reading(mgr_);
......
...@@ -164,11 +164,31 @@ public: ...@@ -164,11 +164,31 @@ public:
// -- event loop management -------------------------------------------------- // -- event loop management --------------------------------------------------
/// Registers the manager for read operations on the @ref multiplexer. /// Registers the manager for read operations on the @ref multiplexer.
/// @thread-safe
void register_reading(); void register_reading();
/// Registers the manager for write operations on the @ref multiplexer. /// Registers the manager for write operations on the @ref multiplexer.
/// @thread-safe
void register_writing(); void register_writing();
/// Schedules a call to `handle_continue_reading` on the @ref multiplexer.
/// This mechanism allows users to signal changes in the environment to the
/// manager that allow it to make progress, e.g., new demand in asynchronous
/// buffer that allow the manager to push available data downstream. The event
/// is a no-op if the manager is already registered for reading.
/// @thread-safe
void continue_reading();
/// Schedules a call to `handle_continue_reading` on the @ref multiplexer.
/// This mechanism allows users to signal changes in the environment to the
/// manager that allow it to make progress, e.g., new data for writing in an
/// asynchronous buffer. The event is a no-op if the manager is already
/// registered for writing.
/// @thread-safe
void continue_writing();
// -- callbacks for the multiplexer ------------------------------------------
/// Performs a handover to another manager after `handle_read_event` or /// Performs a handover to another manager after `handle_read_event` or
/// `handle_read_event` returned `handover`. /// `handle_read_event` returned `handover`.
socket_manager_ptr do_handover(); socket_manager_ptr do_handover();
...@@ -181,20 +201,31 @@ public: ...@@ -181,20 +201,31 @@ public:
/// Called whenever the socket received new data. /// Called whenever the socket received new data.
virtual read_result handle_read_event() = 0; virtual read_result handle_read_event() = 0;
/// Called after handovers to allow the manager to process any data that is
/// already buffered at the transport policy and thus would not trigger a read
/// event on the socket.
virtual read_result handle_buffered_data() = 0;
/// Restarts a socket manager that suspended reads. Calling this member
/// function on active managers is a no-op. This function also should read any
/// data buffered outside of the socket.
virtual read_result handle_continue_reading() = 0;
/// Called whenever the socket is allowed to send data. /// Called whenever the socket is allowed to send data.
virtual write_result handle_write_event() = 0; virtual write_result handle_write_event() = 0;
/// Restarts a socket manager that suspended writes. Calling this member
/// function on active managers is a no-op.
virtual write_result handle_continue_writing() = 0;
/// Called when the remote side becomes unreachable due to an error. /// Called when the remote side becomes unreachable due to an error.
/// @param code The error code as reported by the operating system. /// @param code The error code as reported by the operating system.
virtual void handle_error(sec code) = 0; virtual void handle_error(sec code) = 0;
/// Restarts a socket manager that suspended reads. Calling this member
/// function on active managers is a no-op.
virtual void continue_reading() = 0;
/// Returns the new manager for the socket after `handle_read_event` or /// Returns the new manager for the socket after `handle_read_event` or
/// `handle_read_event` returned `handover`. /// `handle_read_event` returned `handover`.
/// @note When returning a non-null pointer, the new manager *must* also be /// @note Called from `do_handover`.
/// @note When returning a non-null pointer, the new manager *must* be
/// initialized. /// initialized.
virtual socket_manager_ptr make_next_manager(socket handle); virtual socket_manager_ptr make_next_manager(socket handle);
...@@ -236,17 +267,6 @@ public: ...@@ -236,17 +267,6 @@ public:
// nop // nop
} }
// -- initialization ---------------------------------------------------------
error init(const settings& config) override {
CAF_LOG_TRACE("");
if (auto err = nonblocking(handle(), true)) {
CAF_LOG_ERROR("failed to set nonblocking flag in socket:" << err);
return err;
}
return protocol_.init(static_cast<socket_manager*>(this), this, config);
}
// -- properties ------------------------------------------------------------- // -- properties -------------------------------------------------------------
/// Returns the managed socket. /// Returns the managed socket.
...@@ -254,42 +274,56 @@ public: ...@@ -254,42 +274,56 @@ public:
return socket_cast<socket_type>(this->handle_); return socket_cast<socket_type>(this->handle_);
} }
// -- event callbacks -------------------------------------------------------- auto& protocol() noexcept {
return protocol_;
}
read_result handle_read_event() override { const auto& protocol() const noexcept {
CAF_LOG_TRACE(""); return protocol_;
return protocol_.handle_read_event(this);
} }
write_result handle_write_event() override { auto& top_layer() noexcept {
return climb(protocol_);
}
const auto& top_layer() const noexcept {
return climb(protocol_);
}
// -- interface functions ----------------------------------------------------
error init(const settings& config) override {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
return protocol_.handle_write_event(this); if (auto err = nonblocking(handle(), true)) {
CAF_LOG_ERROR("failed to set nonblocking flag in socket:" << err);
return err;
}
return protocol_.init(static_cast<socket_manager*>(this), this, config);
} }
void handle_error(sec code) override { read_result handle_read_event() override {
CAF_LOG_TRACE(CAF_ARG(code)); return protocol_.handle_read_event(this);
this->abort_reason(make_error(code));
return protocol_.abort(this, this->abort_reason());
} }
void continue_reading() override { read_result handle_buffered_data() override {
return protocol_.continue_reading(this); return protocol_.handle_buffered_data(this);
} }
auto& protocol() noexcept { read_result handle_continue_reading() override {
return protocol_; return protocol_.handle_continue_reading(this);
} }
const auto& protocol() const noexcept { write_result handle_write_event() override {
return protocol_; return protocol_.handle_write_event(this);
} }
auto& top_layer() noexcept { write_result handle_continue_writing() override {
return climb(protocol_); return protocol_.handle_continue_writing(this);
} }
const auto& top_layer() const noexcept { void handle_error(sec code) override {
return climb(protocol_); this->abort_reason(make_error(code));
return protocol_.abort(this, this->abort_reason());
} }
private: private:
......
This diff is collapsed.
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/default_enum_inspect.hpp"
#include "caf/detail/net_export.hpp"
#include "caf/fwd.hpp"
#include "caf/is_error_code_enum.hpp"
#include <cstdint>
#include <string>
#include <type_traits>
namespace caf::net {
enum class stream_transport_error {
/// Indicates that the transport should try again later.
temporary,
/// Indicates that the transport must read data before trying again.
want_read,
/// Indicates that the transport must write data before trying again.
want_write,
/// Indicates that the transport cannot resume this operation.
permanent,
};
/// @relates stream_transport_error
CAF_NET_EXPORT std::string to_string(stream_transport_error);
/// @relates stream_transport_error
CAF_NET_EXPORT bool from_string(string_view, stream_transport_error&);
/// @relates stream_transport_error
CAF_NET_EXPORT bool from_integer(std::underlying_type_t<stream_transport_error>,
stream_transport_error&);
/// @relates stream_transport_error
template <class Inspector>
bool inspect(Inspector& f, stream_transport_error& x) {
return default_enum_inspect(f, x);
}
} // namespace caf::net
This diff is collapsed.
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
namespace caf::net { namespace caf::net {
// -- constructors, destructors, and assignment operators ----------------------
pollset_updater::pollset_updater(pipe_socket read_handle, multiplexer* parent) pollset_updater::pollset_updater(pipe_socket read_handle, multiplexer* parent)
: super(read_handle, parent) { : super(read_handle, parent) {
// nop // nop
...@@ -25,28 +27,22 @@ pollset_updater::~pollset_updater() { ...@@ -25,28 +27,22 @@ pollset_updater::~pollset_updater() {
// nop // nop
} }
// -- interface functions ------------------------------------------------------
error pollset_updater::init(const settings&) { error pollset_updater::init(const settings&) {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
return nonblocking(handle(), true); return nonblocking(handle(), true);
} }
namespace { pollset_updater::read_result pollset_updater::handle_read_event() {
CAF_LOG_TRACE("");
auto as_mgr(intptr_t ptr) { auto as_mgr = [](intptr_t ptr) {
CAF_LOG_TRACE(CAF_ARG(ptr));
return intrusive_ptr{reinterpret_cast<socket_manager*>(ptr), false}; return intrusive_ptr{reinterpret_cast<socket_manager*>(ptr), false};
} };
auto run_action = [](intptr_t ptr) {
void run_action(intptr_t ptr) {
CAF_LOG_TRACE(CAF_ARG(ptr));
auto f = action{intrusive_ptr{reinterpret_cast<action::impl*>(ptr), false}}; auto f = action{intrusive_ptr{reinterpret_cast<action::impl*>(ptr), false}};
f.run(); f.run();
} };
} // namespace
pollset_updater::read_result pollset_updater::handle_read_event() {
CAF_LOG_TRACE("");
for (;;) { for (;;) {
CAF_ASSERT((buf_.size() - buf_size_) > 0); CAF_ASSERT((buf_.size() - buf_size_) > 0);
auto num_bytes = read(handle(), make_span(buf_.data() + buf_size_, auto num_bytes = read(handle(), make_span(buf_.data() + buf_size_,
...@@ -62,9 +58,15 @@ pollset_updater::read_result pollset_updater::handle_read_event() { ...@@ -62,9 +58,15 @@ pollset_updater::read_result pollset_updater::handle_read_event() {
case code::register_reading: case code::register_reading:
mpx_->do_register_reading(as_mgr(ptr)); mpx_->do_register_reading(as_mgr(ptr));
break; break;
case code::continue_reading:
mpx_->do_continue_reading(as_mgr(ptr));
break;
case code::register_writing: case code::register_writing:
mpx_->do_register_writing(as_mgr(ptr)); mpx_->do_register_writing(as_mgr(ptr));
break; break;
case code::continue_writing:
mpx_->do_continue_writing(as_mgr(ptr));
break;
case code::init_manager: case code::init_manager:
mpx_->do_init(as_mgr(ptr)); mpx_->do_init(as_mgr(ptr));
break; break;
...@@ -100,16 +102,24 @@ pollset_updater::read_result pollset_updater::handle_read_event() { ...@@ -100,16 +102,24 @@ pollset_updater::read_result pollset_updater::handle_read_event() {
} }
} }
pollset_updater::read_result pollset_updater::handle_buffered_data() {
return read_result::again;
}
pollset_updater::read_result pollset_updater::handle_continue_reading() {
return read_result::again;
}
pollset_updater::write_result pollset_updater::handle_write_event() { pollset_updater::write_result pollset_updater::handle_write_event() {
return write_result::stop; return write_result::stop;
} }
void pollset_updater::handle_error(sec) { pollset_updater::write_result pollset_updater::handle_continue_writing() {
// nop return write_result::stop;
} }
void pollset_updater::continue_reading() { void pollset_updater::handle_error(sec) {
register_reading(); // nop
} }
} // namespace caf::net } // namespace caf::net
...@@ -43,15 +43,21 @@ void socket_manager::abort_reason(error reason) noexcept { ...@@ -43,15 +43,21 @@ void socket_manager::abort_reason(error reason) noexcept {
} }
void socket_manager::register_reading() { void socket_manager::register_reading() {
if (!read_closed())
mpx_->register_reading(this); mpx_->register_reading(this);
} }
void socket_manager::continue_reading() {
mpx_->continue_reading(this);
}
void socket_manager::register_writing() { void socket_manager::register_writing() {
if (!write_closed())
mpx_->register_writing(this); mpx_->register_writing(this);
} }
void socket_manager::continue_writing() {
mpx_->continue_writing(this);
}
socket_manager_ptr socket_manager::do_handover() { socket_manager_ptr socket_manager::do_handover() {
flags_.read_closed = true; flags_.read_closed = true;
flags_.write_closed = true; flags_.write_closed = true;
......
...@@ -26,6 +26,8 @@ using shared_atomic_count = std::shared_ptr<std::atomic<size_t>>; ...@@ -26,6 +26,8 @@ using shared_atomic_count = std::shared_ptr<std::atomic<size_t>>;
class dummy_manager : public socket_manager { class dummy_manager : public socket_manager {
public: public:
// -- constructors, destructors, and assignment operators --------------------
dummy_manager(stream_socket handle, multiplexer* parent, std::string name, dummy_manager(stream_socket handle, multiplexer* parent, std::string name,
shared_atomic_count count) shared_atomic_count count)
: socket_manager(handle, parent), name(std::move(name)), count_(count) { : socket_manager(handle, parent), name(std::move(name)), count_(count) {
...@@ -39,14 +41,31 @@ public: ...@@ -39,14 +41,31 @@ public:
--*count_; --*count_;
} }
error init(const settings&) override { // -- properties -------------------------------------------------------------
return none;
}
stream_socket handle() const noexcept { stream_socket handle() const noexcept {
return socket_cast<stream_socket>(handle_); return socket_cast<stream_socket>(handle_);
} }
// -- testing DSL ------------------------------------------------------------
void send(string_view x) {
auto x_bytes = as_bytes(make_span(x));
wr_buf_.insert(wr_buf_.end(), x_bytes.begin(), x_bytes.end());
}
std::string receive() {
std::string result(reinterpret_cast<char*>(rd_buf_.data()), rd_buf_pos_);
rd_buf_pos_ = 0;
return result;
}
// -- interface functions ----------------------------------------------------
error init(const settings&) override {
return none;
}
read_result handle_read_event() override { read_result handle_read_event() override {
if (trigger_handover) { if (trigger_handover) {
MESSAGE(name << " triggered a handover"); MESSAGE(name << " triggered a handover");
...@@ -67,6 +86,14 @@ public: ...@@ -67,6 +86,14 @@ public:
} }
} }
read_result handle_buffered_data() override {
return read_result::again;
}
read_result handle_continue_reading() override {
return read_result::again;
}
write_result handle_write_event() override { write_result handle_write_event() override {
if (trigger_handover) { if (trigger_handover) {
MESSAGE(name << " triggered a handover"); MESSAGE(name << " triggered a handover");
...@@ -84,12 +111,12 @@ public: ...@@ -84,12 +111,12 @@ public:
: write_result::stop; : write_result::stop;
} }
void handle_error(sec code) override { write_result handle_continue_writing() override {
FAIL("handle_error called with code " << code); return write_result::again;
} }
void continue_reading() override { void handle_error(sec code) override {
FAIL("continue_reading called"); FAIL("handle_error called with code " << code);
} }
socket_manager_ptr make_next_manager(socket handle) override { socket_manager_ptr make_next_manager(socket handle) override {
...@@ -102,16 +129,7 @@ public: ...@@ -102,16 +129,7 @@ public:
return next; return next;
} }
void send(string_view x) { // --
auto x_bytes = as_bytes(make_span(x));
wr_buf_.insert(wr_buf_.end(), x_bytes.begin(), x_bytes.end());
}
std::string receive() {
std::string result(reinterpret_cast<char*>(rd_buf_.data()), rd_buf_pos_);
rd_buf_pos_ = 0;
return result;
}
bool trigger_handover = false; bool trigger_handover = false;
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#define CAF_SUITE net.openssl_transport
#include "caf/net/openssl_transport.hpp"
#include "net-test.hpp"
#include "caf/binary_deserializer.hpp"
#include "caf/binary_serializer.hpp"
#include "caf/byte.hpp"
#include "caf/byte_buffer.hpp"
#include "caf/detail/scope_guard.hpp"
#include "caf/make_actor.hpp"
#include "caf/net/actor_proxy_impl.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/socket_guard.hpp"
#include "caf/net/socket_manager.hpp"
#include "caf/net/stream_socket.hpp"
#include "caf/span.hpp"
#include <filesystem>
#include <random>
// Note: these constants are defined in openssl_transport_constants.cpp.
extern std::string_view ca_pem;
extern std::string_view cert_1_pem;
extern std::string_view cert_2_pem;
extern std::string_view key_1_enc_pem;
extern std::string_view key_1_pem;
extern std::string_view key_2_pem;
using namespace caf;
using namespace caf::net;
namespace {
using byte_buffer_ptr = std::shared_ptr<byte_buffer>;
struct fixture : host_fixture {
using byte_buffer_ptr = std::shared_ptr<byte_buffer>;
fixture(){
multiplexer::block_sigpipe();
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, nullptr);
// Make a directory name with 8 random (hex) character suffix.
std::string dir_name = "caf-net-test-";
std::random_device rd;
std::minstd_rand rng{rd()};
std::uniform_int_distribution<int> dist(0, 255);
for (int i = 0; i < 4; ++i)
detail::append_hex(dir_name, static_cast<uint8_t>(dist(rng)));
// Create the directory under /tmp (or its equivalent on non-POSIX).
namespace fs = std::filesystem;
tmp_dir = fs::temp_directory_path() / dir_name;
if (!fs::create_directory(tmp_dir)) {
std::cerr << "*** failed to create " << tmp_dir.string() << "\n";
abort();
}
// Create the .pem files on disk.
write_file("ca.pem", ca_pem);
write_file("cert.1.pem", cert_1_pem);
write_file("cert.2.pem", cert_1_pem);
write_file("key.1.enc.pem", key_1_enc_pem);
write_file("key.1.pem", key_1_pem);
write_file("key.2.pem", key_1_pem);
}
~fixture() {
// Clean up our files under /tmp.
if (!tmp_dir.empty())
std::filesystem::remove_all(tmp_dir);
}
std::string abs_path(std::string_view fname) {
auto path = tmp_dir / fname;
return path.string();
}
void write_file(std::string_view fname, std::string_view content) {
std::ofstream out{abs_path(fname)};
out << content;
}
std::filesystem::path tmp_dir;
};
class dummy_app {
public:
using input_tag = tag::stream_oriented;
explicit dummy_app(std::shared_ptr<bool> done, byte_buffer_ptr recv_buf)
: done_(std::move(done)), recv_buf_(std::move(recv_buf)) {
// nop
}
~dummy_app() {
*done_ = true;
}
template <class ParentPtr>
error init(socket_manager*, ParentPtr parent, const settings&) {
MESSAGE("initialize dummy app");
parent->configure_read(receive_policy::exactly(4));
parent->begin_output();
auto& buf = parent->output_buffer();
caf::binary_serializer out{nullptr, buf};
static_cast<void>(out.apply(10));
parent->end_output();
return none;
}
template <class ParentPtr>
bool prepare_send(ParentPtr) {
return true;
}
template <class ParentPtr>
bool done_sending(ParentPtr) {
return true;
}
template <class ParentPtr>
void continue_reading(ParentPtr) {
// nop
}
template <class ParentPtr>
size_t consume(ParentPtr down, span<const byte> data, span<const byte>) {
MESSAGE("dummy app received " << data.size() << " bytes");
// Store the received bytes.
recv_buf_->insert(recv_buf_->begin(), data.begin(), data.end());
// Respond with the same data and return.
down->begin_output();
auto& out = down->output_buffer();
out.insert(out.end(), data.begin(), data.end());
down->end_output();
return static_cast<ptrdiff_t>(data.size());
}
template <class ParentPtr>
void abort(ParentPtr, const error& reason) {
MESSAGE("dummy_app::abort called: " << reason);
*done_ = true;
}
private:
std::shared_ptr<bool> done_;
byte_buffer_ptr recv_buf_;
};
// Simulates a remote SSL server.
void dummy_tls_server(stream_socket fd, std::string cert_file,
std::string key_file) {
namespace ssl = caf::net::openssl;
multiplexer::block_sigpipe();
// Make sure we close our socket.
auto guard = detail::make_scope_guard([fd] { close(fd); });
// Get and configure our SSL context.
auto ctx = ssl::make_ctx(TLS_server_method());
if (auto err = ssl::certificate_pem_file(ctx, cert_file)) {
std::cerr << "*** certificate_pem_file failed: " << ssl::fetch_error_str();
return;
}
if (auto err = ssl::private_key_pem_file(ctx, key_file)) {
std::cerr << "*** private_key_pem_file failed: " << ssl::fetch_error_str();
return;
}
// Perform SSL handshake.
auto f = net::openssl::policy::make(std::move(ctx), fd);
if (f.accept(fd) <= 0) {
std::cerr << "*** accept failed: " << ssl::fetch_error_str();
return;
}
// Do some ping-pong messaging.
for (int i = 0; i < 4; ++i) {
byte_buffer buf;
buf.resize(4);
f.read(fd, buf);
f.write(fd, buf);
}
// Graceful shutdown.
f.notify_close();
}
// Simulates a remote SSL client.
void dummy_tls_client(stream_socket fd) {
multiplexer::block_sigpipe();
// Make sure we close our socket.
auto guard = detail::make_scope_guard([fd] { close(fd); });
// Perform SSL handshake.
auto f = net::openssl::policy::make(TLS_client_method(), fd);
if (f.connect(fd) <= 0) {
ERR_print_errors_fp(stderr);
return;
}
// Do some ping-pong messaging.
for (int i = 0; i < 4; ++i) {
byte_buffer buf;
buf.resize(4);
f.read(fd, buf);
f.write(fd, buf);
}
// Graceful shutdown.
f.notify_close();
}
} // namespace
BEGIN_FIXTURE_SCOPE(fixture)
SCENARIO("openssl::async_connect performs the client handshake") {
GIVEN("a connection to a TLS server") {
auto [serv_fd, client_fd] = unbox(make_stream_socket_pair());
if (auto err = net::nonblocking(client_fd, true))
FAIL("net::nonblocking failed: " << err);
std::thread server{dummy_tls_server, serv_fd, abs_path("cert.1.pem"),
abs_path("key.1.pem")};
WHEN("connecting as a client to an OpenSSL server") {
THEN("openssl::async_connect transparently calls SSL_connect") {
using stack_t = openssl_transport<dummy_app>;
net::multiplexer mpx{nullptr};
mpx.set_thread_id();
auto done = std::make_shared<bool>(false);
auto buf = std::make_shared<byte_buffer>();
auto make_manager = [done, buf](stream_socket fd, multiplexer* ptr,
net::openssl::policy policy) {
return make_socket_manager<stack_t>(fd, ptr, std::move(policy), done,
buf);
};
auto on_connect_error = [](const error& reason) {
FAIL("connect failed: " << reason);
};
net::openssl::async_connect(client_fd, &mpx,
net::openssl::policy::make(SSLv23_method(),
client_fd),
make_manager, on_connect_error);
mpx.apply_updates();
while (!*done)
mpx.poll_once(true);
if (CHECK_EQ(buf->size(), 16u)) { // 4x 32-bit integers
caf::binary_deserializer src{nullptr, *buf};
for (int i = 0; i < 4; ++i) {
int32_t value = 0;
static_cast<void>(src.apply(value));
CHECK_EQ(value, 10);
}
}
}
}
server.join();
}
}
SCENARIO("openssl::async_accept performs the server handshake") {
GIVEN("a socket that is connected to a client") {
auto [serv_fd, client_fd] = unbox(make_stream_socket_pair());
if (auto err = net::nonblocking(serv_fd, true))
FAIL("net::nonblocking failed: " << err);
std::thread client{dummy_tls_client, client_fd};
WHEN("acting as the OpenSSL server") {
THEN("openssl::async_accept transparently calls SSL_accept") {
using stack_t = openssl_transport<dummy_app>;
net::multiplexer mpx{nullptr};
mpx.set_thread_id();
auto done = std::make_shared<bool>(false);
auto buf = std::make_shared<byte_buffer>();
auto make_manager = [done, buf](stream_socket fd, multiplexer* ptr,
net::openssl::policy policy) {
return make_socket_manager<stack_t>(fd, ptr, std::move(policy), done,
buf);
};
auto on_accept_error = [](const error& reason) {
FAIL("accept failed: " << reason);
};
auto ssl = net::openssl::policy::make(TLS_server_method(), serv_fd);
if (auto err = ssl.certificate_pem_file(abs_path("cert.1.pem")))
FAIL("certificate_pem_file failed: " << err);
if (auto err = ssl.private_key_pem_file(abs_path("key.1.pem")))
FAIL("privat_key_pem_file failed: " << err);
net::openssl::async_accept(serv_fd, &mpx, std::move(ssl), make_manager,
on_accept_error);
mpx.apply_updates();
while (!*done)
mpx.poll_once(true);
if (CHECK_EQ(buf->size(), 16u)) { // 4x 32-bit integers
caf::binary_deserializer src{nullptr, *buf};
for (int i = 0; i < 4; ++i) {
int32_t value = 0;
static_cast<void>(src.apply(value));
CHECK_EQ(value, 10);
}
}
}
}
client.join();
}
}
END_FIXTURE_SCOPE()
This diff is collapsed.
...@@ -125,6 +125,7 @@ struct fixture : test_coordinator_fixture<>, host_fixture { ...@@ -125,6 +125,7 @@ struct fixture : test_coordinator_fixture<>, host_fixture {
} }
bool handle_io_event() override { bool handle_io_event() override {
mm.mpx().apply_updates();
return mm.mpx().poll_once(false); return mm.mpx().poll_once(false);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment