Commit f1a3b267 authored by Dominik Charousset's avatar Dominik Charousset

Allow socket managers to transfer socket ownership

parent 085a1045
...@@ -28,6 +28,10 @@ public: ...@@ -28,6 +28,10 @@ public:
using factory_type = Factory; using factory_type = Factory;
using read_result = typename socket_manager::read_result;
using write_result = typename socket_manager::write_result;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
template <class... Ts> template <class... Ts>
...@@ -51,28 +55,28 @@ public: ...@@ -51,28 +55,28 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool handle_read_event(LowerLayerPtr parent) { read_result handle_read_event(LowerLayerPtr parent) {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
if (auto x = accept(parent->handle())) { if (auto x = accept(parent->handle())) {
socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr()); socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr());
if (!child) { if (!child) {
CAF_LOG_ERROR("factory failed to create a new child"); CAF_LOG_ERROR("factory failed to create a new child");
parent->abort_reason(sec::runtime_error); parent->abort_reason(sec::runtime_error);
return false; return read_result::stop;
} }
if (auto err = child->init(cfg_)) { if (auto err = child->init(cfg_)) {
CAF_LOG_ERROR("failed to initialize new child:" << err); CAF_LOG_ERROR("failed to initialize new child:" << err);
parent->abort_reason(std::move(err)); parent->abort_reason(std::move(err));
return false; return read_result::stop;
} }
if (limit_ == 0) { if (limit_ == 0) {
return true; return read_result::again;
} else { } else {
return ++accepted_ < limit_; return ++accepted_ < limit_ ? read_result::again : read_result::stop;
} }
} else { } else {
CAF_LOG_ERROR("accept failed:" << x.error()); CAF_LOG_ERROR("accept failed:" << x.error());
return false; return read_result::stop;
} }
} }
...@@ -82,9 +86,9 @@ public: ...@@ -82,9 +86,9 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool handle_write_event(LowerLayerPtr) { write_result handle_write_event(LowerLayerPtr) {
CAF_LOG_ERROR("connection_acceptor received write event"); CAF_LOG_ERROR("connection_acceptor received write event");
return false; return write_result::stop;
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
......
...@@ -23,6 +23,10 @@ public: ...@@ -23,6 +23,10 @@ public:
using application_type = typename transport_type::application_type; using application_type = typename transport_type::application_type;
using read_result = typename super::read_result;
using write_result = typename super::write_result;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
endpoint_manager_impl(const multiplexer_ptr& parent, actor_system& sys, endpoint_manager_impl(const multiplexer_ptr& parent, actor_system& sys,
...@@ -52,11 +56,11 @@ public: ...@@ -52,11 +56,11 @@ public:
return transport_.init(*this); return transport_.init(*this);
} }
bool handle_read_event() override { read_result handle_read_event() override {
return transport_.handle_read_event(*this); return transport_.handle_read_event(*this);
} }
bool handle_write_event() override { write_result handle_write_event() override {
if (!this->queue_.blocked()) { if (!this->queue_.blocked()) {
this->queue_.fetch_more(); this->queue_.fetch_more();
auto& q = std::get<0>(this->queue_.queue().queues()); auto& q = std::get<0>(this->queue_.queue().queues());
...@@ -83,10 +87,13 @@ public: ...@@ -83,10 +87,13 @@ public:
} }
if (!transport_.handle_write_event(*this)) { if (!transport_.handle_write_event(*this)) {
if (this->queue_.blocked()) if (this->queue_.blocked())
return false; return write_result::stop;
return !(this->queue_.empty() && this->queue_.try_block()); else if (!(this->queue_.empty() && this->queue_.try_block()))
return write_result::again;
else
return write_result::stop;
} }
return true; return write_result::again;
} }
void handle_error(sec code) override { void handle_error(sec code) override {
......
...@@ -117,7 +117,7 @@ public: ...@@ -117,7 +117,7 @@ public:
std::string len; std::string len;
header_fields_type fields; header_fields_type fields;
if (!content.empty()) { if (!content.empty()) {
auto len = std::to_string(content.size()); len = std::to_string(content.size());
fields.emplace("Content-Type", content_type); fields.emplace("Content-Type", content_type);
fields.emplace("Content-Length", len); fields.emplace("Content-Length", len);
} }
......
...@@ -160,7 +160,7 @@ public: ...@@ -160,7 +160,7 @@ public:
void abort(LowerLayerPtr, const error& reason) { void abort(LowerLayerPtr, const error& reason) {
CAF_LOG_TRACE(CAF_ARG(reason)); CAF_LOG_TRACE(CAF_ARG(reason));
if (out_) { if (out_) {
if (reason == sec::socket_disconnected || reason == sec::discarded) if (reason == sec::socket_disconnected || reason == sec::disposed)
out_->close(); out_->close();
else else
out_->abort(reason); out_->abort(reason);
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "caf/action.hpp" #include "caf/action.hpp"
#include "caf/detail/net_export.hpp" #include "caf/detail/net_export.hpp"
#include "caf/detail/unordered_flat_map.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/operation.hpp" #include "caf/net/operation.hpp"
#include "caf/net/pipe_socket.hpp" #include "caf/net/pipe_socket.hpp"
...@@ -24,15 +25,28 @@ struct pollfd; ...@@ -24,15 +25,28 @@ struct pollfd;
namespace caf::net { namespace caf::net {
class pollset_updater;
/// Multiplexes any number of ::socket_manager objects with a ::socket. /// Multiplexes any number of ::socket_manager objects with a ::socket.
class CAF_NET_EXPORT multiplexer { class CAF_NET_EXPORT multiplexer {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
struct poll_update {
short events = 0;
socket_manager_ptr mgr;
};
using poll_update_map = detail::unordered_flat_map<socket, poll_update>;
using pollfd_list = std::vector<pollfd>; using pollfd_list = std::vector<pollfd>;
using manager_list = std::vector<socket_manager_ptr>; using manager_list = std::vector<socket_manager_ptr>;
// -- friends ----------------------------------------------------------------
friend class pollset_updater; // Needs access to the `do_*` functions.
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
/// @param parent Points to the owning middleman instance. May be `nullptr` /// @param parent Points to the owning middleman instance. May be `nullptr`
...@@ -55,12 +69,18 @@ public: ...@@ -55,12 +69,18 @@ public:
/// Returns the index of `mgr` in the pollset or `-1`. /// Returns the index of `mgr` in the pollset or `-1`.
ptrdiff_t index_of(const socket_manager_ptr& mgr); ptrdiff_t index_of(const socket_manager_ptr& mgr);
/// Returns the index of `fd` in the pollset or `-1`.
ptrdiff_t index_of(socket fd);
/// Returns the owning @ref middleman instance. /// Returns the owning @ref middleman instance.
middleman& owner(); middleman& owner();
/// Returns the enclosing @ref actor_system. /// Returns the enclosing @ref actor_system.
actor_system& system(); actor_system& system();
/// Computes the current mask for the manager. Mostly useful for testing.
operation mask_of(const socket_manager_ptr& mgr);
// -- thread-safe signaling -------------------------------------------------- // -- thread-safe signaling --------------------------------------------------
/// Registers `mgr` for read events. /// Registers `mgr` for read events.
...@@ -109,6 +129,9 @@ public: ...@@ -109,6 +129,9 @@ public:
/// ready as a result. /// ready as a result.
bool poll_once(bool blocking); bool poll_once(bool blocking);
/// Applies all pending updates.
void apply_updates();
/// Sets the thread ID to `std::this_thread::id()`. /// Sets the thread ID to `std::this_thread::id()`.
void set_thread_id(); void set_thread_id();
...@@ -123,13 +146,7 @@ protected: ...@@ -123,13 +146,7 @@ protected:
// -- utility functions ------------------------------------------------------ // -- utility functions ------------------------------------------------------
/// Handles an I/O event on given manager. /// Handles an I/O event on given manager.
short handle(const socket_manager_ptr& mgr, short events, short revents); void handle(const socket_manager_ptr& mgr, short events, short revents);
/// Adds a new socket manager to the pollset.
void add(socket_manager_ptr mgr);
/// Deletes a known socket manager from the pollset.
void del(ptrdiff_t index);
// -- member variables ------------------------------------------------------- // -- member variables -------------------------------------------------------
...@@ -140,6 +157,10 @@ protected: ...@@ -140,6 +157,10 @@ protected:
/// order as their sockets appear in `pollset_`. /// order as their sockets appear in `pollset_`.
manager_list managers_; manager_list managers_;
/// Caches changes to the events mask of managed sockets until they can safely
/// take place.
poll_update_map updates_;
/// Stores the ID of the thread this multiplexer is running in. Set when /// Stores the ID of the thread this multiplexer is running in. Set when
/// calling `init()`. /// calling `init()`.
std::thread::id tid_; std::thread::id tid_;
...@@ -157,15 +178,39 @@ protected: ...@@ -157,15 +178,39 @@ protected:
bool shutting_down_ = false; bool shutting_down_ = false;
private: private:
/// Returns a change entry for the socket at given index. Lazily creates a new
/// entry before returning if necessary.
poll_update& update_for(ptrdiff_t index);
/// Returns a change entry for the socket of the manager.
poll_update& update_for(const socket_manager_ptr& mgr);
/// Writes `opcode` and pointer to `mgr` the the pipe for handling an event /// Writes `opcode` and pointer to `mgr` the the pipe for handling an event
/// later via the pollset updater. /// later via the pollset updater.
template <class T> template <class T>
void write_to_pipe(uint8_t opcode, T* ptr); void write_to_pipe(uint8_t opcode, T* ptr);
/// @copydoc write_to_pipe
template <class Enum, class T> template <class Enum, class T>
std::enable_if_t<std::is_enum_v<Enum>> write_to_pipe(Enum opcode, T* ptr) { std::enable_if_t<std::is_enum_v<Enum>> write_to_pipe(Enum opcode, T* ptr) {
write_to_pipe(static_cast<uint8_t>(opcode), ptr); write_to_pipe(static_cast<uint8_t>(opcode), ptr);
} }
// -- internal callback the pollset updater ----------------------------------
void do_shutdown();
void do_register_reading(const socket_manager_ptr& mgr);
void do_register_writing(const socket_manager_ptr& mgr);
void do_discard(const socket_manager_ptr& mgr);
void do_shutdown_reading(const socket_manager_ptr& mgr);
void do_shutdown_writing(const socket_manager_ptr& mgr);
void do_init(const socket_manager_ptr& mgr);
}; };
} // namespace caf::net } // namespace caf::net
...@@ -51,9 +51,9 @@ public: ...@@ -51,9 +51,9 @@ public:
error init(const settings& config) override; error init(const settings& config) override;
bool handle_read_event() override; read_result handle_read_event() override;
bool handle_write_event() override; write_result handle_write_event() override;
void handle_error(sec code) override; void handle_error(sec code) override;
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "caf/make_counted.hpp" #include "caf/make_counted.hpp"
#include "caf/net/actor_shell.hpp" #include "caf/net/actor_shell.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/operation.hpp"
#include "caf/net/socket.hpp" #include "caf/net/socket.hpp"
#include "caf/net/typed_actor_shell.hpp" #include "caf/net/typed_actor_shell.hpp"
#include "caf/ref_counted.hpp" #include "caf/ref_counted.hpp"
...@@ -28,13 +27,48 @@ class CAF_NET_EXPORT socket_manager : public ref_counted { ...@@ -28,13 +27,48 @@ class CAF_NET_EXPORT socket_manager : public ref_counted {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
/// A callback for unprocessed messages.
using fallback_handler = unique_callback_ptr<result<message>(message&)>; using fallback_handler = unique_callback_ptr<result<message>(message&)>;
/// Encodes how a manager wishes to proceed after a read operation.
enum class read_result {
/// Indicates that a manager wants to read again later.
again,
/// Indicates that a manager wants to stop reading until explicitly resumed.
stop,
/// Indicates that a manager wants to write to the socket instead of reading
/// from the socket.
want_write,
/// Indicates that a manager is done with the socket and hands ownership to
/// another manager.
handover,
};
/// Encodes how a manager wishes to proceed after a write operation.
enum class write_result {
/// Indicates that a manager wants to read again later.
again,
/// Indicates that a manager wants to stop reading until explicitly resumed.
stop,
/// Indicates that a manager wants to read from the socket instead of
/// writing to the socket.
want_read,
/// Indicates that a manager is done with the socket and hands ownership to
/// another manager.
handover,
};
/// Stores manager-related flags in a single block.
struct flags_t {
bool read_closed : 1;
bool write_closed : 1;
};
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
/// @pre `handle != invalid_socket` /// @pre `handle != invalid_socket`
/// @pre `parent != nullptr` /// @pre `mpx!= nullptr`
socket_manager(socket handle, multiplexer* parent); socket_manager(socket handle, multiplexer* mpx);
~socket_manager() override; ~socket_manager() override;
...@@ -59,69 +93,45 @@ public: ...@@ -59,69 +93,45 @@ public:
/// Returns the owning @ref multiplexer instance. /// Returns the owning @ref multiplexer instance.
multiplexer& mpx() noexcept { multiplexer& mpx() noexcept {
return *parent_; return *mpx_;
} }
/// Returns the owning @ref multiplexer instance. /// Returns the owning @ref multiplexer instance.
const multiplexer& mpx() const noexcept { const multiplexer& mpx() const noexcept {
return *parent_; return *mpx_;
} }
/// Returns a pointer to the owning @ref multiplexer instance. /// Returns a pointer to the owning @ref multiplexer instance.
multiplexer* mpx_ptr() noexcept { multiplexer* mpx_ptr() noexcept {
return parent_; return mpx_;
} }
/// Returns a pointer to the owning @ref multiplexer instance. /// Returns a pointer to the owning @ref multiplexer instance.
const multiplexer* mpx_ptr() const noexcept { const multiplexer* mpx_ptr() const noexcept {
return parent_; return mpx_;
} }
/// Returns registered operations (read, write, or both). /// Closes the read channel of the socket.
operation mask() const noexcept { void close_read() noexcept;
return mask_;
}
/// Convenience function for checking whether `mask()` contains the read bit. /// Closes the write channel of the socket.
bool is_reading() const noexcept { void close_write() noexcept;
return net::is_reading(mask_);
}
/// Convenience function for checking whether `mask()` contains the write bit. /// Returns whether the manager closed read operations on the socket.
bool is_writing() const noexcept { [[nodiscard]] bool read_closed() const noexcept {
return net::is_writing(mask_); return flags_.read_closed;
} }
/// Tries to add the read flag to the event mask. /// Returns whether the manager closed write operations on the socket.
/// @returns `true` if the flag was added, `false` if this call had no effect. [[nodiscard]] bool write_closed() const noexcept {
bool set_read_flag() noexcept; return flags_.write_closed;
}
/// Tries to add the write flag to the event mask.
/// @returns `true` if the flag was added, `false` if this call had no effect.
bool set_write_flag() noexcept;
/// Removes the read flag from the event mask if present.
bool unset_read_flag() noexcept;
/// Removes the write flag from the event mask if present.
bool unset_write_flag() noexcept;
/// Adds the `block_read` flag to the event mask.
void block_reads() noexcept;
/// Adds the `block_write` flag to the event mask.
void block_writes() noexcept;
/// Blocks reading and writing in the event mask.
void block_reads_and_writes() noexcept;
const error& abort_reason() const noexcept { const error& abort_reason() const noexcept {
return abort_reason_; return abort_reason_;
} }
void abort_reason(error reason) noexcept { void abort_reason(error reason) noexcept;
abort_reason_ = std::move(reason);
}
template <class... Ts> template <class... Ts>
const error& abort_reason_or(Ts&&... xs) { const error& abort_reason_or(Ts&&... xs) {
...@@ -153,23 +163,26 @@ public: ...@@ -153,23 +163,26 @@ public:
// -- event loop management -------------------------------------------------- // -- event loop management --------------------------------------------------
/// Registers the manager for read operations on the @ref multiplexer.
void register_reading(); void register_reading();
/// Registers the manager for write operations on the @ref multiplexer.
void register_writing(); void register_writing();
void shutdown_reading(); /// Performs a handover to another manager after `handle_read_event` or
/// `handle_read_event` returned `handover`.
void shutdown_writing(); socket_manager_ptr do_handover();
// -- pure virtual member functions ------------------------------------------ // -- pure virtual member functions ------------------------------------------
/// Initializes the manager and its all of its sub-components.
virtual error init(const settings& config) = 0; virtual error init(const settings& config) = 0;
/// Called whenever the socket received new data. /// Called whenever the socket received new data.
virtual bool handle_read_event() = 0; virtual read_result handle_read_event() = 0;
/// Called whenever the socket is allowed to send data. /// Called whenever the socket is allowed to send data.
virtual bool handle_write_event() = 0; virtual write_result handle_write_event() = 0;
/// Called when the remote side becomes unreachable due to an error. /// Called when the remote side becomes unreachable due to an error.
/// @param code The error code as reported by the operating system. /// @param code The error code as reported by the operating system.
...@@ -179,16 +192,25 @@ public: ...@@ -179,16 +192,25 @@ public:
/// function on active managers is a no-op. /// function on active managers is a no-op.
virtual void continue_reading() = 0; virtual void continue_reading() = 0;
/// Returns the new manager for the socket after `handle_read_event` or
/// `handle_read_event` returned `handover`.
/// @note When returning a non-null pointer, the new manager *must* also be
/// initialized.
virtual socket_manager_ptr make_next_manager(socket handle);
protected: protected:
// -- member variables ------------------------------------------------------- // -- protected member variables ---------------------------------------------
socket handle_; socket handle_;
operation mask_; multiplexer* mpx_;
multiplexer* parent_; private:
// -- private member variables -----------------------------------------------
error abort_reason_; error abort_reason_;
flags_t flags_;
}; };
template <class Protocol> template <class Protocol>
...@@ -196,10 +218,16 @@ class socket_manager_impl : public socket_manager { ...@@ -196,10 +218,16 @@ class socket_manager_impl : public socket_manager {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
using super = socket_manager;
using output_tag = tag::io_event_oriented; using output_tag = tag::io_event_oriented;
using socket_type = typename Protocol::socket_type; using socket_type = typename Protocol::socket_type;
using read_result = typename super::read_result;
using write_result = typename super::write_result;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
template <class... Ts> template <class... Ts>
...@@ -228,20 +256,20 @@ public: ...@@ -228,20 +256,20 @@ public:
// -- event callbacks -------------------------------------------------------- // -- event callbacks --------------------------------------------------------
bool handle_read_event() override { read_result handle_read_event() override {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
return protocol_.handle_read_event(this); return protocol_.handle_read_event(this);
} }
bool handle_write_event() override { write_result handle_write_event() override {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
return protocol_.handle_write_event(this); return protocol_.handle_write_event(this);
} }
void handle_error(sec code) override { void handle_error(sec code) override {
CAF_LOG_TRACE(CAF_ARG(code)); CAF_LOG_TRACE(CAF_ARG(code));
abort_reason_ = code; this->abort_reason(make_error(code));
return protocol_.abort(this, abort_reason_); return protocol_.abort(this, this->abort_reason());
} }
void continue_reading() override { void continue_reading() override {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "caf/logger.hpp" #include "caf/logger.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/receive_policy.hpp" #include "caf/net/receive_policy.hpp"
#include "caf/net/socket_manager.hpp"
#include "caf/net/stream_oriented_layer_ptr.hpp" #include "caf/net/stream_oriented_layer_ptr.hpp"
#include "caf/net/stream_socket.hpp" #include "caf/net/stream_socket.hpp"
#include "caf/sec.hpp" #include "caf/sec.hpp"
...@@ -23,9 +24,46 @@ ...@@ -23,9 +24,46 @@
namespace caf::net { namespace caf::net {
enum class stream_transport_error {
/// Indicates that the transport should try again later.
temporary,
/// Indicates that the transport must read data before trying again.
want_read,
/// Indicates that the transport must write data before trying again.
want_write,
/// Indicates that the transport cannot resume this operation.
permanent,
};
/// Configures a stream transport with default socket operations.
struct default_stream_transport_policy {
public:
/// Reads data from the socket into the buffer.
static ptrdiff_t read(stream_socket x, span<byte> buf) {
return net::read(x, buf);
}
/// Writes data from the buffer to the socket.
static ptrdiff_t write(stream_socket x, span<const byte> buf) {
return net::write(x, buf);
}
/// Returns the last socket error on this thread.
static stream_transport_error last_error(stream_socket, ptrdiff_t) {
return last_socket_error_is_temporary() ? stream_transport_error::temporary
: stream_transport_error::permanent;
}
/// Returns the number of bytes that are buffered internally and that
/// available for immediate read.
static constexpr size_t buffered() {
return 0;
}
};
/// Implements a stream_transport that manages a stream socket. /// Implements a stream_transport that manages a stream socket.
template <class UpperLayer> template <class Policy, class UpperLayer>
class stream_transport { class stream_transport_base {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
...@@ -35,12 +73,16 @@ public: ...@@ -35,12 +73,16 @@ public:
using socket_type = stream_socket; using socket_type = stream_socket;
using read_result = typename socket_manager::read_result;
using write_result = typename socket_manager::write_result;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
template <class... Ts> template <class... Ts>
explicit stream_transport(Ts&&... xs) explicit stream_transport_base(Policy policy, Ts&&... xs)
: upper_layer_(std::forward<Ts>(xs)...) { : upper_layer_(std::forward<Ts>(xs)...), policy_(std::move(policy)) {
// nop memset(&flags, 0, sizeof(flags_t));
} }
// -- interface for stream_oriented_layer_ptr -------------------------------- // -- interface for stream_oriented_layer_ptr --------------------------------
...@@ -151,139 +193,205 @@ public: ...@@ -151,139 +193,205 @@ public:
// -- event callbacks -------------------------------------------------------- // -- event callbacks --------------------------------------------------------
template <class ParentPtr> template <class ParentPtr>
bool handle_read_event(ParentPtr parent) { read_result handle_read_event(ParentPtr parent) {
CAF_LOG_TRACE(CAF_ARG2("socket", parent->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", parent->handle().id));
auto fail = [this, parent](auto reason) { // Pointer for passing "this layer" to the next one down the chain.
auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent);
// Convenience lambda for failing the application.
auto fail = [this, &parent, &this_layer_ptr](auto reason) {
CAF_LOG_DEBUG("read failed" << CAF_ARG(reason)); CAF_LOG_DEBUG("read failed" << CAF_ARG(reason));
parent->abort_reason(std::move(reason)); parent->abort_reason(std::move(reason));
auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent);
upper_layer_.abort(this_layer_ptr, parent->abort_reason()); upper_layer_.abort(this_layer_ptr, parent->abort_reason());
return false; return read_result::stop;
}; };
if (read_buf_.size() < max_read_size_) // Convenience lambda for invoking the next layer.
read_buf_.resize(max_read_size_); auto invoke_upper_layer = [this, &this_layer_ptr](byte* ptr, ptrdiff_t off,
auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent); ptrdiff_t delta) {
static constexpr bool has_after_reading auto bytes = make_span(ptr, off);
= detail::has_after_reading_v<UpperLayer, decltype(this_layer_ptr)>; return upper_layer_.consume(this_layer_ptr, bytes, bytes.subspan(delta));
for (size_t i = 0; max_read_size_ > 0 && i < max_consecutive_reads_; ++i) { };
// Calling configure_read(read_policy::stop()) halts receive events. // Resume a write operation if the transport waited for the socket to be
if (max_read_size_ == 0) { // readable from the last call to handle_write_event.
if constexpr (has_after_reading) if (flags.wanted_read_from_write_event) {
upper_layer_.after_reading(this_layer_ptr); flags.wanted_read_from_write_event = false;
return false; switch (handle_write_event(parent))
} else if (offset_ >= max_read_size_) { {
auto old_max = max_read_size_; case write_result::want_read:
// This may happen if the upper layer changes it receive policy to a CAF_ASSERT(flags.wanted_read_from_write_event);
// smaller max. size than what was already available. In this case, the return read_result::again;
// upper layer must consume bytes before we can receive new data. case write_result::handover:
auto bytes = make_span(read_buf_.data(), max_read_size_); return read_result::handover;
ptrdiff_t consumed = upper_layer_.consume(this_layer_ptr, bytes, {}); case write_result::again:
CAF_LOG_DEBUG(CAF_ARG2("socket", parent->handle().id) parent->register_writing();
<< CAF_ARG(consumed)); break;
if (consumed < 0) { default:
upper_layer_.abort(this_layer_ptr, break;
parent->abort_reason_or(caf::sec::runtime_error));
return false;
} else if (consumed == 0) {
// At the very least, the upper layer must accept more data next time.
if (old_max >= max_read_size_) {
upper_layer_.abort(this_layer_ptr, parent->abort_reason_or(
caf::sec::runtime_error,
"unable to make progress"));
return false;
}
} }
// Try again.
continue;
} }
// Before returning from the event handler, we always call after_reading for
// clients that request this callback.
auto after_reading_guard
= detail::make_scope_guard([this, &this_layer_ptr] {
if constexpr (detail::has_after_reading_v<UpperLayer,
decltype(this_layer_ptr)>)
upper_layer_.after_reading(this_layer_ptr);
});
// Loop until meeting one of our stop criteria.
for (size_t read_count = 0;;) {
// Stop condition 1: the application halted receive operations. Usually by
// calling `configure_read(read_policy::stop())`. Here, we return and ask
// the multiplexer to not call this event handler until
// `register_reading()` gets called.
if (max_read_size_ == 0)
return read_result::stop;
// Make sure our buffer has sufficient space.
if (read_buf_.size() < max_read_size_)
read_buf_.resize(max_read_size_);
// Fetch new data and stop on errors.
auto rd_buf = make_span(read_buf_.data() + offset_, auto rd_buf = make_span(read_buf_.data() + offset_,
max_read_size_ - static_cast<size_t>(offset_)); max_read_size_ - static_cast<size_t>(offset_));
auto read_res = read(parent->handle(), rd_buf); auto read_res = policy_.read(parent->handle(), rd_buf);
CAF_LOG_DEBUG(CAF_ARG2("socket", parent->handle().id) // CAF_LOG_DEBUG(CAF_ARG2("socket", parent->handle().id) //
<< CAF_ARG(max_read_size_) << CAF_ARG(offset_) << CAF_ARG(max_read_size_) << CAF_ARG(offset_)
<< CAF_ARG(read_res)); << CAF_ARG(read_res));
// Update state. // Stop condition 2: cannot get data from the socket.
if (read_res > 0) { if (read_res < 0) {
// Try again later on temporary errors such as EWOULDBLOCK and
// stop reading on the socket on hard errors.
switch (policy_.last_error(parent->handle(), read_res)) {
case stream_transport_error::temporary:
case stream_transport_error::want_read:
return read_result::again;
case stream_transport_error::want_write:
flags.wanted_write_from_read_event = true;
return read_result::want_write;
default:
return fail(sec::socket_operation_failed);
}
} else if (read_res == 0) {
// read() returns 0 only if the connection was closed.
return fail(sec::socket_disconnected);
}
++read_count;
// Ask the next layer to process some data.
offset_ += read_res; offset_ += read_res;
if (offset_ < min_read_size_) auto internal_buffer_size = policy_.buffered();
continue; // The offset_ may change as a result of invoking the upper layer. Hence,
auto bytes = make_span(read_buf_.data(), offset_); // need to run this in a loop to push data up for as long as we have
auto delta = bytes.subspan(delta_offset_); // buffered data available.
ptrdiff_t consumed = upper_layer_.consume(this_layer_ptr, bytes, delta); while (offset_ >= min_read_size_) {
// Here, we have yet another loop. This one makes sure that we do not
// leave this event handler if we can make progress from the data
// buffered inside the socket policy. For 'raw' policies (like the
// default policy), there is no buffer. However, any block-oriented
// transport like OpenSSL has to buffer data internally. We need to make
// sure to consume the buffer because the OS does not know about it and
// will not trigger a read event based on data available there.
do {
ptrdiff_t consumed = invoke_upper_layer(read_buf_.data(), offset_,
delta_offset_);
CAF_LOG_DEBUG(CAF_ARG2("socket", parent->handle().id) CAF_LOG_DEBUG(CAF_ARG2("socket", parent->handle().id)
<< CAF_ARG(consumed)); << CAF_ARG(consumed));
if (consumed > 0) { if (consumed < 0) {
upper_layer_.abort(this_layer_ptr,
parent->abort_reason_or(caf::sec::runtime_error,
"consumed < 0"));
return read_result::stop;
}
// Shift unconsumed bytes to the beginning of the buffer. // Shift unconsumed bytes to the beginning of the buffer.
if (consumed < offset_) if (consumed < offset_)
std::copy(read_buf_.begin() + consumed, read_buf_.begin() + offset_, std::copy(read_buf_.begin() + consumed, read_buf_.begin() + offset_,
read_buf_.begin()); read_buf_.begin());
offset_ -= consumed; offset_ -= consumed;
delta_offset_ = offset_; delta_offset_ = offset_;
} else if (consumed < 0) { // Stop if the application asked for it.
upper_layer_.abort(this_layer_ptr, if (max_read_size_ == 0)
parent->abort_reason_or(caf::sec::runtime_error, return read_result::stop;
"consumed < 0")); if (internal_buffer_size > 0 && offset_ < max_read_size_) {
return false; // Fetch already buffered data to 'refill' the buffer as we go.
auto n = std::min(internal_buffer_size,
max_read_size_ - static_cast<size_t>(offset_));
auto rdb = make_span(read_buf_.data() + offset_, n);
auto rd = policy_.read(parent->handle(), rdb);
if (rd < 0)
return fail(make_error(caf::sec::runtime_error,
"policy error: reading buffered data "
"may not result in an error"));
offset_ += rd;
internal_buffer_size = policy_.buffered();
}
} while (internal_buffer_size > 0);
} }
// Our thresholds may have changed if the upper layer called // Our thresholds may have changed if the upper layer called
// configure_read. Shrink/grow buffer as necessary. // configure_read. Shrink/grow buffer as necessary.
if (read_buf_.size() != max_read_size_) if (read_buf_.size() != max_read_size_ && offset_ <= max_read_size_)
if (offset_ < max_read_size_)
read_buf_.resize(max_read_size_); read_buf_.resize(max_read_size_);
} else if (read_res < 0) { // Try again (next for-loop iteration) unless we hit the read limit.
// Try again later on temporary errors such as EWOULDBLOCK and if (read_count >= max_consecutive_reads_)
// stop reading on the socket on hard errors. return read_result::again;
if (last_socket_error_is_temporary()) {
if constexpr (has_after_reading)
upper_layer_.after_reading(this_layer_ptr);
return true;
} else {
return fail(sec::socket_operation_failed);
}
} else {
// read() returns 0 iff the connection was closed.
return fail(sec::socket_disconnected);
}
} }
// Calling configure_read(read_policy::stop()) halts receive events.
if constexpr (has_after_reading)
upper_layer_.after_reading(this_layer_ptr);
return max_read_size_ > 0;
} }
template <class ParentPtr> template <class ParentPtr>
bool handle_write_event(ParentPtr parent) { write_result handle_write_event(ParentPtr parent) {
CAF_LOG_TRACE(CAF_ARG2("socket", parent->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", parent->handle().id));
auto fail = [this, parent](sec reason) { auto fail = [this, parent](sec reason) {
CAF_LOG_DEBUG("read failed" << CAF_ARG(reason)); CAF_LOG_DEBUG("read failed" << CAF_ARG(reason));
parent->abort_reason(reason); parent->abort_reason(reason);
auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent); auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent);
upper_layer_.abort(this_layer_ptr, reason); upper_layer_.abort(this_layer_ptr, reason);
return false; return write_result::stop;
}; };
// Resume a read operation if the transport waited for the socket to be
// writable from the last call to handle_read_event.
if (flags.wanted_write_from_read_event) {
flags.wanted_write_from_read_event = false;
switch (handle_read_event(parent)) {
case read_result::want_write:
CAF_ASSERT(flags.wanted_write_from_read_event);
return write_result::again;
case read_result::handover:
return write_result::handover;
case read_result::again:
parent->register_reading();
break;
default:
break;
}
// Fall though and see if we also have something to write.
}
// Allow the upper layer to add extra data to the write buffer. // Allow the upper layer to add extra data to the write buffer.
auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent); auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent);
if (!upper_layer_.prepare_send(this_layer_ptr)) { if (!upper_layer_.prepare_send(this_layer_ptr)) {
upper_layer_.abort(this_layer_ptr, upper_layer_.abort(this_layer_ptr,
parent->abort_reason_or(caf::sec::runtime_error, parent->abort_reason_or(caf::sec::runtime_error,
"prepare_send failed")); "prepare_send failed"));
return false; return write_result::stop;
} }
if (write_buf_.empty()) if (write_buf_.empty())
return !upper_layer_.done_sending(this_layer_ptr); return !upper_layer_.done_sending(this_layer_ptr) ? write_result::again
auto written = write(parent->handle(), write_buf_); : write_result::stop;
if (written > 0) { auto write_res = policy_.write(parent->handle(), write_buf_);
write_buf_.erase(write_buf_.begin(), write_buf_.begin() + written); if (write_res > 0) {
return !write_buf_.empty() || !upper_layer_.done_sending(this_layer_ptr); write_buf_.erase(write_buf_.begin(), write_buf_.begin() + write_res);
} else if (written < 0) { return !write_buf_.empty() || !upper_layer_.done_sending(this_layer_ptr)
? write_result::again
: write_result::stop;
} else if (write_res < 0) {
// Try again later on temporary errors such as EWOULDBLOCK and // Try again later on temporary errors such as EWOULDBLOCK and
// stop writing to the socket on hard errors. // stop writing to the socket on hard errors.
return last_socket_error_is_temporary() switch (policy_.last_error(parent->handle(), write_res)) {
? true case stream_transport_error::temporary:
: fail(sec::socket_operation_failed); case stream_transport_error::want_write:
return write_result::again;
case stream_transport_error::want_read:
flags.wanted_read_from_write_event = true;
return write_result::want_read;
default:
return fail(sec::socket_operation_failed);
}
} else { } else {
// write() returns 0 iff the connection was closed. // write() returns 0 if the connection was closed.
return fail(sec::socket_disconnected); return fail(sec::socket_disconnected);
} }
} }
...@@ -303,32 +411,60 @@ public: ...@@ -303,32 +411,60 @@ public:
} }
private: private:
// Caches the config parameter for limiting max. socket operations. ///
struct flags_t {
bool wanted_read_from_write_event : 1;
bool wanted_write_from_read_event : 1;
} flags;
/// Caches the config parameter for limiting max. socket operations.
uint32_t max_consecutive_reads_ = 0; uint32_t max_consecutive_reads_ = 0;
// Caches the write buffer size of the socket. /// Caches the write buffer size of the socket.
uint32_t max_write_buf_size_ = 0; uint32_t max_write_buf_size_ = 0;
// Stores what the user has configured as read threshold. /// Stores what the user has configured as read threshold.
uint32_t min_read_size_ = 0; uint32_t min_read_size_ = 0;
// Stores what the user has configured as max. number of bytes to receive. /// Stores what the user has configured as max. number of bytes to receive.
uint32_t max_read_size_ = 0; uint32_t max_read_size_ = 0;
// Stores the current offset in `read_buf_`. /// Stores the current offset in `read_buf_`.
ptrdiff_t offset_ = 0; ptrdiff_t offset_ = 0;
// Stores the offset in `read_buf_` since last calling `upper_layer_.consume`. /// Stores the offset in `read_buf_` since last calling `upper_layer_.consume`.
ptrdiff_t delta_offset_ = 0; ptrdiff_t delta_offset_ = 0;
// Caches incoming data. /// Caches incoming data.
byte_buffer read_buf_; byte_buffer read_buf_;
// Caches outgoing data. /// Caches outgoing data.
byte_buffer write_buf_; byte_buffer write_buf_;
// Processes incoming data and generates outgoing data. /// Processes incoming data and generates outgoing data.
UpperLayer upper_layer_; UpperLayer upper_layer_;
/// Configures how we read and write to the socket.
Policy policy_;
};
/// Implements a stream_transport that manages a stream socket.
template <class UpperLayer>
class stream_transport
: public stream_transport_base<default_stream_transport_policy, UpperLayer> {
public:
// -- member types -----------------------------------------------------------
using super
= stream_transport_base<default_stream_transport_policy, UpperLayer>;
// -- constructors, destructors, and assignment operators --------------------
template <class... Ts>
explicit stream_transport(Ts&&... xs)
: super(default_stream_transport_policy{}, std::forward<Ts>(xs)...) {
// nop
}
}; };
} // namespace caf::net } // namespace caf::net
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
#include "caf/net/multiplexer.hpp" #include "caf/net/multiplexer.hpp"
#include <algorithm>
#include "caf/action.hpp" #include "caf/action.hpp"
#include "caf/byte.hpp" #include "caf/byte.hpp"
#include "caf/config.hpp" #include "caf/config.hpp"
...@@ -21,6 +19,9 @@ ...@@ -21,6 +19,9 @@
#include "caf/span.hpp" #include "caf/span.hpp"
#include "caf/variant.hpp" #include "caf/variant.hpp"
#include <algorithm>
#include <optional>
#ifndef CAF_WINDOWS #ifndef CAF_WINDOWS
# include <poll.h> # include <poll.h>
#else #else
...@@ -51,9 +52,25 @@ const short error_mask = POLLRDHUP | POLLERR | POLLHUP | POLLNVAL; ...@@ -51,9 +52,25 @@ const short error_mask = POLLRDHUP | POLLERR | POLLHUP | POLLNVAL;
const short output_mask = POLLOUT; const short output_mask = POLLOUT;
short to_bitmask(operation mask) { // short to_bitmask(operation mask) {
return static_cast<short>((is_reading(mask) ? input_mask : 0) // return static_cast<short>((is_reading(mask) ? input_mask : 0)
| (is_writing(mask) ? output_mask : 0)); // | (is_writing(mask) ? output_mask : 0));
// }
operation to_operation(const socket_manager_ptr& mgr,
std::optional<short> mask) {
operation res = operation::none;
if (mgr->read_closed())
res = block_reads(res);
if (mgr->write_closed())
res = block_writes(res);
if (mask) {
if ((*mask & input_mask) != 0)
res = add_read_flag(res);
if ((*mask & output_mask) != 0)
res = add_write_flag(res);
}
return res;
} }
} // namespace } // namespace
...@@ -96,7 +113,8 @@ error multiplexer::init() { ...@@ -96,7 +113,8 @@ error multiplexer::init() {
settings dummy; settings dummy;
if (auto err = updater->init(dummy)) if (auto err = updater->init(dummy))
return err; return err;
add(std::move(updater)); register_reading(updater);
apply_updates();
write_handle_ = pipe_handles->second; write_handle_ = pipe_handles->second;
return none; return none;
} }
...@@ -114,6 +132,13 @@ ptrdiff_t multiplexer::index_of(const socket_manager_ptr& mgr) { ...@@ -114,6 +132,13 @@ ptrdiff_t multiplexer::index_of(const socket_manager_ptr& mgr) {
return i == last ? -1 : std::distance(first, i); return i == last ? -1 : std::distance(first, i);
} }
ptrdiff_t multiplexer::index_of(socket fd) {
auto first = pollset_.begin();
auto last = pollset_.end();
auto i = std::find_if(first, last, [fd](pollfd& x) { return x.fd == fd.id; });
return i == last ? -1 : std::distance(first, i);
}
middleman& multiplexer::owner() { middleman& multiplexer::owner() {
CAF_ASSERT(owner_ != nullptr); CAF_ASSERT(owner_ != nullptr);
return *owner_; return *owner_;
...@@ -123,93 +148,105 @@ actor_system& multiplexer::system() { ...@@ -123,93 +148,105 @@ actor_system& multiplexer::system() {
return owner().system(); return owner().system();
} }
// -- thread-safe signaling ---------------------------------------------------- operation multiplexer::mask_of(const socket_manager_ptr& mgr) {
auto fd = mgr->handle();
if (auto i = updates_.find(fd); i != updates_.end())
return to_operation(mgr, i->second.events);
else if (auto index = index_of(mgr);index!=-1)
return to_operation(mgr, pollset_[index].events);
else
return to_operation(mgr, std::nullopt);
}
// -- thread-safe signaling and their internal callbacks -----------------------
void multiplexer::register_reading(const socket_manager_ptr& mgr) { void multiplexer::register_reading(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_ || mgr->abort_reason()) { do_register_reading(mgr);
// nop
} else if (!is_idle(mgr->mask())) {
if (auto index = index_of(mgr); index != -1 && mgr->set_read_flag()) {
auto& fd = pollset_[index];
fd.events |= input_mask;
}
} else if (mgr->set_read_flag()) {
add(mgr);
}
} else { } else {
write_to_pipe(pollset_updater::code::register_reading, mgr.get()); write_to_pipe(pollset_updater::code::register_reading, mgr.get());
} }
} }
void multiplexer::do_register_reading(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
// When shutting down, no new reads are allowed.
if (shutting_down_)
mgr->close_read();
else if (!mgr->read_closed())
update_for(mgr).events |= input_mask;
}
void multiplexer::register_writing(const socket_manager_ptr& mgr) { void multiplexer::register_writing(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
CAF_ASSERT(mgr != nullptr);
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (mgr->abort_reason()) { do_register_writing(mgr);
// nop
} else if (!is_idle(mgr->mask())) {
if (auto index = index_of(mgr); index != -1 && mgr->set_write_flag()) {
auto& fd = pollset_[index];
fd.events |= output_mask;
}
} else if (mgr->set_write_flag()) {
add(mgr);
}
} else { } else {
write_to_pipe(pollset_updater::code::register_writing, mgr.get()); write_to_pipe(pollset_updater::code::register_writing, mgr.get());
} }
} }
void multiplexer::do_register_writing(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
// When shutting down, we do allow managers to write whatever is currently
// pending but we make sure that all read channels are closed.
if (shutting_down_)
mgr->close_read();
if (!mgr->write_closed())
update_for(mgr).events |= output_mask;
}
void multiplexer::discard(const socket_manager_ptr& mgr) { void multiplexer::discard(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { do_discard(mgr);
// nop
} else {
mgr->handle_error(sec::discarded);
if (auto mgr_index = index_of(mgr); mgr_index != -1)
del(mgr_index);
}
} else { } else {
write_to_pipe(pollset_updater::code::discard_manager, mgr.get()); write_to_pipe(pollset_updater::code::discard_manager, mgr.get());
} }
} }
void multiplexer::do_discard(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
mgr->handle_error(sec::disposed);
update_for(mgr).events = 0;
}
void multiplexer::shutdown_reading(const socket_manager_ptr& mgr) { void multiplexer::shutdown_reading(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { do_shutdown_reading(mgr);
// nop
} else if (auto index = index_of(mgr); index != -1) {
mgr->block_reads();
auto& entry = pollset_[index];
entry.events &= ~input_mask;
if (entry.events == 0)
del(index);
}
} else { } else {
write_to_pipe(pollset_updater::code::shutdown_reading, mgr.get()); write_to_pipe(pollset_updater::code::shutdown_reading, mgr.get());
} }
} }
void multiplexer::do_shutdown_reading(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (!shutting_down_ && !mgr->read_closed()) {
mgr->close_read();
update_for(mgr).events &= ~input_mask;
}
}
void multiplexer::shutdown_writing(const socket_manager_ptr& mgr) { void multiplexer::shutdown_writing(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { do_shutdown_writing(mgr);
// nop
} else if (auto index = index_of(mgr); index != -1) {
mgr->block_writes();
auto& entry = pollset_[index];
entry.events &= ~output_mask;
if (entry.events == 0)
del(index);
}
} else { } else {
write_to_pipe(pollset_updater::code::shutdown_writing, mgr.get()); write_to_pipe(pollset_updater::code::shutdown_writing, mgr.get());
} }
} }
void multiplexer::do_shutdown_writing(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (!shutting_down_ && !mgr->write_closed()) {
mgr->close_write();
update_for(mgr).events &= ~output_mask;
}
}
void multiplexer::schedule(const action& what) { void multiplexer::schedule(const action& what) {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
write_to_pipe(pollset_updater::code::run_action, what.ptr()); write_to_pipe(pollset_updater::code::run_action, what.ptr());
...@@ -218,21 +255,35 @@ void multiplexer::schedule(const action& what) { ...@@ -218,21 +255,35 @@ void multiplexer::schedule(const action& what) {
void multiplexer::init(const socket_manager_ptr& mgr) { void multiplexer::init(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { if (!shutting_down_) {
// nop
} else {
if (auto err = mgr->init(content(system().config()))) { if (auto err = mgr->init(content(system().config()))) {
CAF_LOG_ERROR("mgr->init failed: " << err); CAF_LOG_DEBUG("mgr->init failed: " << err);
// The socket manager should not register itself for any events if // The socket manager should not register itself for any events if
// initialization fails. So there's probably nothing we could do // initialization fails. Purge any state just in case.
// here other than discarding the manager. update_for(mgr).events = 0;
} }
// Else: no update since the manager is supposed to call continue_reading
// and continue_writing as necessary.
} }
} else { } else {
write_to_pipe(pollset_updater::code::init_manager, mgr.get()); write_to_pipe(pollset_updater::code::init_manager, mgr.get());
} }
} }
void multiplexer::do_init(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (!shutting_down_) {
if (auto err = mgr->init(content(system().config()))) {
CAF_LOG_DEBUG("mgr->init failed: " << err);
// The socket manager should not register itself for any events if
// initialization fails. Purge any state just in case.
update_for(mgr).events = 0;
}
// Else: no update since the manager is supposed to call continue_reading
// and continue_writing as necessary.
}
}
void multiplexer::close_pipe() { void multiplexer::close_pipe() {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
std::lock_guard<std::mutex> guard{write_lock_}; std::lock_guard<std::mutex> guard{write_lock_};
...@@ -263,22 +314,21 @@ bool multiplexer::poll_once(bool blocking) { ...@@ -263,22 +314,21 @@ bool multiplexer::poll_once(bool blocking) {
<< presult << "event(s)"); << presult << "event(s)");
// Scan pollset for events. // Scan pollset for events.
CAF_LOG_DEBUG("scan pollset for socket events"); CAF_LOG_DEBUG("scan pollset for socket events");
for (size_t i = 0; i < pollset_.size() && presult > 0;) { if (auto revents = pollset_[0].revents; revents != 0) {
auto revents = pollset_[i].revents; // Index 0 is always the pollset updater. This is the only handler that
if (revents != 0) { // is allowed to modify pollset_ and managers_. Since this may very well
auto events = pollset_[i].events; // mess with the for loop below, we process this handler first.
auto mgr = managers_[i]; auto mgr = managers_[0];
auto new_events = handle(mgr, events, revents); handle(mgr,pollset_[0].events,revents);
--presult; --presult;
if (new_events == 0) {
del(i);
continue;
} else if (new_events != events) {
pollset_[i].events = new_events;
} }
for (size_t i = 1; i < pollset_.size() && presult > 0; ++i) {
if (auto revents = pollset_[i].revents; revents != 0) {
handle(managers_[i], pollset_[i].events, revents);
--presult;
} }
++i;
} }
apply_updates();
return true; return true;
} else if (presult == 0) { } else if (presult == 0) {
// No activity. // No activity.
...@@ -310,6 +360,28 @@ bool multiplexer::poll_once(bool blocking) { ...@@ -310,6 +360,28 @@ bool multiplexer::poll_once(bool blocking) {
} }
} }
void multiplexer::apply_updates() {
CAF_LOG_DEBUG("apply" << updates_.size() << "updates");
if (!updates_.empty()) {
for (auto& [fd, update] : updates_) {
if (auto index = index_of(fd); index == -1) {
if (update.events != 0) {
pollfd new_entry{socket_cast<socket_id>(fd), update.events, 0};
pollset_.emplace_back(new_entry);
managers_.emplace_back(std::move(update.mgr));
}
} else if (update.events != 0) {
pollset_[index].events = update.events;
managers_[index].swap(update.mgr);
} else {
pollset_.erase(pollset_.begin() + index);
managers_.erase(managers_.begin() + index);
}
}
updates_.clear();
}
}
void multiplexer::set_thread_id() { void multiplexer::set_thread_id() {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
tid_ = std::this_thread::get_id(); tid_ = std::this_thread::get_id();
...@@ -323,51 +395,86 @@ void multiplexer::run() { ...@@ -323,51 +395,86 @@ void multiplexer::run() {
} }
void multiplexer::shutdown() { void multiplexer::shutdown() {
CAF_LOG_TRACE(""); // Note: there is no 'shortcut' when calling the function in the multiplexer's
if (std::this_thread::get_id() == tid_) { // thread, because do_shutdown calls apply_updates. This must only be called
CAF_LOG_DEBUG("initiate shutdown"); // from the pollset_updater.
shutting_down_ = true;
// First manager is the pollset_updater. Skip it and delete later.
for (size_t i = 1; i < managers_.size();) {
auto& mgr = managers_[i];
if (mgr->unset_read_flag()) {
auto& fd = pollset_[index_of(mgr)];
fd.events &= ~input_mask;
}
mgr->block_reads();
if (is_idle(mgr->mask()))
del(i);
else
++i;
}
} else {
CAF_LOG_DEBUG("push shutdown event to pipe"); CAF_LOG_DEBUG("push shutdown event to pipe");
write_to_pipe(pollset_updater::code::shutdown, write_to_pipe(pollset_updater::code::shutdown,
static_cast<socket_manager*>(nullptr)); static_cast<socket_manager*>(nullptr));
}
void multiplexer::do_shutdown() {
// Note: calling apply_updates here is only safe because we know that the
// pollset updater runs outside of the for-loop in run_once.
CAF_LOG_DEBUG("initiate shutdown");
shutting_down_ = true;
apply_updates();
// Skip the first manager (the pollset updater).
for (size_t i = 1; i < managers_.size(); ++i) {
auto& mgr = managers_[i];
mgr->close_read();
update_for(static_cast<ptrdiff_t>(i)).events &= ~input_mask;
} }
apply_updates();
} }
// -- utility functions -------------------------------------------------------- // -- utility functions --------------------------------------------------------
short multiplexer::handle(const socket_manager_ptr& mgr, void multiplexer::handle(const socket_manager_ptr& mgr,
[[maybe_unused]] short events, short revents) { [[maybe_unused]] short events, short revents) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id) CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)
<< CAF_ARG(events) << CAF_ARG(revents)); << CAF_ARG(events) << CAF_ARG(revents));
CAF_ASSERT(mgr != nullptr); CAF_ASSERT(mgr != nullptr);
bool checkerror = true; bool checkerror = true;
// Convenience function for performing a handover between managers.
auto do_handover = [this, &mgr] {
// Make sure to override the manager pointer in the update. Updates are
// associated to sockets, so the new manager is likely to modify this update
// again. Hence, it *must not* point to the old manager.
auto& update = update_for(mgr);
auto new_mgr = mgr->do_handover();
update.events = 0;
if (new_mgr != nullptr)
update.mgr = new_mgr;
};
//
// Note: we double-check whether the manager is actually reading because a // Note: we double-check whether the manager is actually reading because a
// previous action from the pipe may have called shutdown_reading. // previous action from the pipe may have called shutdown_reading.
if ((revents & input_mask) != 0 && mgr->is_reading()) { if ((events & revents & input_mask) != 0) {
checkerror = false; checkerror = false;
if (!mgr->handle_read_event()) switch (mgr->handle_read_event()) {
mgr->unset_read_flag(); default: // socket_manager::read_result::again
// Nothing to do, bitmask may remain unchanged.
break;
case socket_manager::read_result::stop:
update_for(mgr).events &= ~input_mask;
break;
case socket_manager::read_result::want_write:
update_for(mgr).events = output_mask;
break;
case socket_manager::read_result::handover: {
do_handover();
return;
}
}
} }
// Similar reasoning than before: double-check whether this event should still // Similar reasoning than before: double-check whether this event should still
// get dispatched. // get dispatched.
if ((revents & output_mask) != 0 && mgr->is_writing()) { if ((events & revents & output_mask) != 0) {
checkerror = false; checkerror = false;
if (!mgr->handle_write_event()) switch (mgr->handle_write_event()) {
mgr->unset_write_flag(); default: // socket_manager::write_result::again
break;
case socket_manager::write_result::stop:
update_for(mgr).events &= ~output_mask;
break;
case socket_manager::write_result::want_read:
update_for(mgr).events = input_mask;
break;
case socket_manager::write_result::handover:
do_handover();
return;
}
} }
if (checkerror && ((revents & error_mask) != 0)) { if (checkerror && ((revents & error_mask) != 0)) {
if (revents & POLLNVAL) if (revents & POLLNVAL)
...@@ -376,26 +483,29 @@ short multiplexer::handle(const socket_manager_ptr& mgr, ...@@ -376,26 +483,29 @@ short multiplexer::handle(const socket_manager_ptr& mgr,
mgr->handle_error(sec::socket_disconnected); mgr->handle_error(sec::socket_disconnected);
else else
mgr->handle_error(sec::socket_operation_failed); mgr->handle_error(sec::socket_operation_failed);
mgr->block_reads_and_writes(); update_for(mgr).events = 0;
return 0;
} else {
return to_bitmask(mgr->mask());
} }
} }
void multiplexer::add(socket_manager_ptr mgr) { multiplexer::poll_update& multiplexer::update_for(ptrdiff_t index) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle())); auto fd = socket{pollset_[index].fd};
CAF_ASSERT(index_of(mgr) == -1); if (auto i = updates_.find(fd); i != updates_.end()) {
pollfd new_entry{socket_cast<socket_id>(mgr->handle()), return i->second;
to_bitmask(mgr->mask()), 0}; } else {
pollset_.emplace_back(new_entry); updates_.container().emplace_back(fd, poll_update{pollset_[index].events,
managers_.emplace_back(std::move(mgr)); managers_[index]});
return updates_.container().back().second;
}
} }
void multiplexer::del(ptrdiff_t index) { multiplexer::poll_update&
CAF_ASSERT(index != -1); multiplexer::update_for(const socket_manager_ptr& mgr) {
pollset_.erase(pollset_.begin() + index); auto fd = mgr->handle();
managers_.erase(managers_.begin() + index); if (auto index = index_of(fd); index != -1) {
return update_for(index);
} else {
return updates_.emplace(fd, poll_update{0, mgr}).first->second;
}
} }
} // namespace caf::net } // namespace caf::net
...@@ -18,7 +18,7 @@ namespace caf::net { ...@@ -18,7 +18,7 @@ namespace caf::net {
pollset_updater::pollset_updater(pipe_socket read_handle, multiplexer* parent) pollset_updater::pollset_updater(pipe_socket read_handle, multiplexer* parent)
: super(read_handle, parent) { : super(read_handle, parent) {
mask_ = operation::read; // nop
} }
pollset_updater::~pollset_updater() { pollset_updater::~pollset_updater() {
...@@ -45,7 +45,7 @@ void run_action(intptr_t ptr) { ...@@ -45,7 +45,7 @@ void run_action(intptr_t ptr) {
} // namespace } // namespace
bool pollset_updater::handle_read_event() { pollset_updater::read_result pollset_updater::handle_read_event() {
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
for (;;) { for (;;) {
CAF_ASSERT((buf_.size() - buf_size_) > 0); CAF_ASSERT((buf_.size() - buf_size_) > 0);
...@@ -60,29 +60,29 @@ bool pollset_updater::handle_read_event() { ...@@ -60,29 +60,29 @@ bool pollset_updater::handle_read_event() {
memcpy(&ptr, buf_.data() + 1, sizeof(intptr_t)); memcpy(&ptr, buf_.data() + 1, sizeof(intptr_t));
switch (static_cast<code>(opcode)) { switch (static_cast<code>(opcode)) {
case code::register_reading: case code::register_reading:
parent_->register_reading(as_mgr(ptr)); mpx_->do_register_reading(as_mgr(ptr));
break; break;
case code::register_writing: case code::register_writing:
parent_->register_writing(as_mgr(ptr)); mpx_->do_register_writing(as_mgr(ptr));
break; break;
case code::init_manager: case code::init_manager:
parent_->init(as_mgr(ptr)); mpx_->do_init(as_mgr(ptr));
break; break;
case code::discard_manager: case code::discard_manager:
parent_->discard(as_mgr(ptr)); mpx_->do_discard(as_mgr(ptr));
break; break;
case code::shutdown_reading: case code::shutdown_reading:
parent_->shutdown_reading(as_mgr(ptr)); mpx_->do_shutdown_reading(as_mgr(ptr));
break; break;
case code::shutdown_writing: case code::shutdown_writing:
parent_->shutdown_writing(as_mgr(ptr)); mpx_->do_shutdown_writing(as_mgr(ptr));
break; break;
case code::run_action: case code::run_action:
run_action(ptr); run_action(ptr);
break; break;
case code::shutdown: case code::shutdown:
CAF_ASSERT(ptr == 0); CAF_ASSERT(ptr == 0);
parent_->shutdown(); mpx_->do_shutdown();
break; break;
default: default:
CAF_LOG_ERROR("opcode not recognized: " << CAF_ARG(opcode)); CAF_LOG_ERROR("opcode not recognized: " << CAF_ARG(opcode));
...@@ -91,15 +91,17 @@ bool pollset_updater::handle_read_event() { ...@@ -91,15 +91,17 @@ bool pollset_updater::handle_read_event() {
} }
} else if (num_bytes == 0) { } else if (num_bytes == 0) {
CAF_LOG_DEBUG("pipe closed, assume shutdown"); CAF_LOG_DEBUG("pipe closed, assume shutdown");
return false; return read_result::stop;
} else if (last_socket_error_is_temporary()) {
return read_result::again;
} else { } else {
return last_socket_error_is_temporary(); return read_result::stop;
} }
} }
} }
bool pollset_updater::handle_write_event() { pollset_updater::write_result pollset_updater::handle_write_event() {
return false; return write_result::stop;
} }
void pollset_updater::handle_error(sec) { void pollset_updater::handle_error(sec) {
......
...@@ -10,10 +10,11 @@ ...@@ -10,10 +10,11 @@
namespace caf::net { namespace caf::net {
socket_manager::socket_manager(socket handle, multiplexer* parent) socket_manager::socket_manager(socket handle, multiplexer* mpx)
: handle_(handle), mask_(operation::none), parent_(parent) { : handle_(handle), mpx_(mpx) {
CAF_ASSERT(handle_ != invalid_socket); CAF_ASSERT(handle_ != invalid_socket);
CAF_ASSERT(parent != nullptr); CAF_ASSERT(mpx_ != nullptr);
memset(&flags_, 0, sizeof(flags_t));
} }
socket_manager::~socket_manager() { socket_manager::~socket_manager() {
...@@ -21,62 +22,51 @@ socket_manager::~socket_manager() { ...@@ -21,62 +22,51 @@ socket_manager::~socket_manager() {
} }
actor_system& socket_manager::system() noexcept { actor_system& socket_manager::system() noexcept {
CAF_ASSERT(parent_ != nullptr); CAF_ASSERT(mpx_ != nullptr);
return parent_->system(); return mpx_->system();
} }
bool socket_manager::set_read_flag() noexcept { void socket_manager::close_read() noexcept {
auto old = mask_; // TODO: extend transport API for closing read operations.
mask_ = add_read_flag(mask_); flags_.read_closed = true;
return old != mask_;
} }
bool socket_manager::set_write_flag() noexcept { void socket_manager::close_write() noexcept {
auto old = mask_; // TODO: extend transport API for closing write operations.
mask_ = add_write_flag(mask_); flags_.write_closed = true;
return old != mask_;
} }
bool socket_manager::unset_read_flag() noexcept { void socket_manager::abort_reason(error reason) noexcept {
auto old = mask_; abort_reason_ = std::move(reason);
mask_ = remove_read_flag(mask_); flags_.read_closed = true;
return old != mask_; flags_.write_closed = true;
}
bool socket_manager::unset_write_flag() noexcept {
auto old = mask_;
mask_ = remove_write_flag(mask_);
return old != mask_;
}
void socket_manager::block_reads() noexcept {
mask_ = net::block_reads(mask_);
}
void socket_manager::block_writes() noexcept {
mask_ = net::block_writes(mask_);
}
void socket_manager::block_reads_and_writes() noexcept {
mask_ = operation::shutdown;
} }
void socket_manager::register_reading() { void socket_manager::register_reading() {
if (!net::is_reading(mask_) && !is_read_blocked(mask_)) if (!read_closed())
parent_->register_reading(this); mpx_->register_reading(this);
} }
void socket_manager::register_writing() { void socket_manager::register_writing() {
if (!net::is_writing(mask_) && !is_write_blocked(mask_)) if (!write_closed())
parent_->register_writing(this); mpx_->register_writing(this);
} }
void socket_manager::shutdown_reading() { socket_manager_ptr socket_manager::do_handover() {
parent_->shutdown_reading(this); flags_.read_closed = true;
flags_.write_closed = true;
auto hdl = handle_;
handle_ = invalid_socket;
if (auto ptr = make_next_manager(hdl)) {
return ptr;
} else {
close(hdl);
return nullptr;
}
} }
void socket_manager::shutdown_writing() { socket_manager_ptr socket_manager::make_next_manager(socket) {
parent_->shutdown_writing(this); return {};
} }
} // namespace caf::net } // namespace caf::net
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
#include "caf/net/multiplexer.hpp" #include "caf/net/multiplexer.hpp"
#include "caf/net/test/host_fixture.hpp" #include "net-test.hpp"
#include "caf/test/dsl.hpp"
#include <new> #include <new>
#include <tuple> #include <tuple>
...@@ -23,19 +22,21 @@ using namespace caf::net; ...@@ -23,19 +22,21 @@ using namespace caf::net;
namespace { namespace {
using shared_atomic_count = std::shared_ptr<std::atomic<size_t>>;
class dummy_manager : public socket_manager { class dummy_manager : public socket_manager {
public: public:
dummy_manager(size_t& manager_count, stream_socket handle, dummy_manager(stream_socket handle, multiplexer* parent, std::string name,
multiplexer* parent) shared_atomic_count count)
: socket_manager(handle, parent), count_(manager_count) { : socket_manager(handle, parent), name(std::move(name)), count_(count) {
CAF_MESSAGE("created new dummy manager"); MESSAGE("created new dummy manager");
++count_; ++*count_;
rd_buf_.resize(1024); rd_buf_.resize(1024);
} }
~dummy_manager() { ~dummy_manager() {
CAF_MESSAGE("destroyed dummy manager"); MESSAGE("destroyed dummy manager");
--count_; --*count_;
} }
error init(const settings&) override { error init(const settings&) override {
...@@ -46,7 +47,11 @@ public: ...@@ -46,7 +47,11 @@ public:
return socket_cast<stream_socket>(handle_); return socket_cast<stream_socket>(handle_);
} }
bool handle_read_event() override { read_result handle_read_event() override {
if (trigger_handover) {
MESSAGE(name << " triggered a handover");
return read_result::handover;
}
if (read_capacity() < 1024) if (read_capacity() < 1024)
rd_buf_.resize(rd_buf_.size() + 2048); rd_buf_.resize(rd_buf_.size() + 2048);
auto num_bytes = read(handle(), auto num_bytes = read(handle(),
...@@ -54,28 +59,47 @@ public: ...@@ -54,28 +59,47 @@ public:
if (num_bytes > 0) { if (num_bytes > 0) {
CAF_ASSERT(num_bytes > 0); CAF_ASSERT(num_bytes > 0);
rd_buf_pos_ += num_bytes; rd_buf_pos_ += num_bytes;
return true; return read_result::again;
} else if (num_bytes < 0 && last_socket_error_is_temporary()) {
return read_result::again;
} else {
return read_result::stop;
} }
return num_bytes < 0 && last_socket_error_is_temporary();
} }
bool handle_write_event() override { write_result handle_write_event() override {
if (trigger_handover) {
MESSAGE(name << " triggered a handover");
return write_result::handover;
}
if (wr_buf_.size() == 0) if (wr_buf_.size() == 0)
return false; return write_result::stop;
auto num_bytes = write(handle(), wr_buf_); auto num_bytes = write(handle(), wr_buf_);
if (num_bytes > 0) { if (num_bytes > 0) {
wr_buf_.erase(wr_buf_.begin(), wr_buf_.begin() + num_bytes); wr_buf_.erase(wr_buf_.begin(), wr_buf_.begin() + num_bytes);
return wr_buf_.size() > 0; return wr_buf_.size() > 0 ? write_result::again : write_result::stop;
} }
return num_bytes < 0 && last_socket_error_is_temporary(); return num_bytes < 0 && last_socket_error_is_temporary()
? write_result::again
: write_result::stop;
} }
void handle_error(sec code) override { void handle_error(sec code) override {
CAF_FAIL("handle_error called with code " << code); FAIL("handle_error called with code " << code);
} }
void continue_reading() override { void continue_reading() override {
CAF_FAIL("continue_reading called"); FAIL("continue_reading called");
}
socket_manager_ptr make_next_manager(socket handle) override {
if (next != nullptr)
FAIL("asked to do handover twice!");
next = make_counted<dummy_manager>(socket_cast<stream_socket>(handle), mpx_,
"Carl", count_);
if (auto err = next->init(settings{}))
FAIL("next->init failed: " << err);
return next;
} }
void send(string_view x) { void send(string_view x) {
...@@ -89,6 +113,12 @@ public: ...@@ -89,6 +113,12 @@ public:
return result; return result;
} }
bool trigger_handover = false;
intrusive_ptr<dummy_manager> next;
std::string name;
private: private:
byte* read_position_begin() { byte* read_position_begin() {
return rd_buf_.data() + rd_buf_pos_; return rd_buf_.data() + rd_buf_pos_;
...@@ -102,7 +132,7 @@ private: ...@@ -102,7 +132,7 @@ private:
return rd_buf_.size() - rd_buf_pos_; return rd_buf_.size() - rd_buf_pos_;
} }
size_t& count_; shared_atomic_count count_;
size_t rd_buf_pos_ = 0; size_t rd_buf_pos_ = 0;
...@@ -115,84 +145,150 @@ using dummy_manager_ptr = intrusive_ptr<dummy_manager>; ...@@ -115,84 +145,150 @@ using dummy_manager_ptr = intrusive_ptr<dummy_manager>;
struct fixture : host_fixture { struct fixture : host_fixture {
fixture() : mpx(nullptr) { fixture() : mpx(nullptr) {
manager_count = std::make_shared<std::atomic<size_t>>(0);
mpx.set_thread_id(); mpx.set_thread_id();
} }
~fixture() { ~fixture() {
CAF_REQUIRE_EQUAL(manager_count, 0u); mpx.shutdown();
exhaust();
REQUIRE_EQ(*manager_count, 0u);
} }
void exhaust() { void exhaust() {
mpx.apply_updates();
while (mpx.poll_once(false)) while (mpx.poll_once(false))
; // Repeat. ; // Repeat.
} }
size_t manager_count = 0; void apply_updates() {
mpx.apply_updates();
}
auto make_manager(stream_socket fd, std::string name) {
return make_counted<dummy_manager>(fd, &mpx, std::move(name),
manager_count);
}
void init() {
if (auto err = mpx.init())
FAIL("mpx.init failed: " << err);
exhaust();
}
shared_atomic_count manager_count;
multiplexer mpx; multiplexer mpx;
}; };
} // namespace } // namespace
CAF_TEST_FIXTURE_SCOPE(multiplexer_tests, fixture) BEGIN_FIXTURE_SCOPE(fixture)
CAF_TEST(default construction) { SCENARIO("the multiplexer has no socket managers after default construction") {
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 0u); GIVEN("a default constructed multiplexer") {
WHEN("querying the number of socket managers") {
THEN("the result is 0") {
CHECK_EQ(mpx.num_socket_managers(), 0u);
}
}
}
} }
CAF_TEST(init) { SCENARIO("the multiplexer constructs the pollset updater while initializing") {
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 0u); GIVEN("an initialized multiplexer") {
CAF_REQUIRE_EQUAL(mpx.init(), none); WHEN("querying the number of socket managers") {
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 1u); THEN("the result is 1") {
mpx.shutdown(); CHECK_EQ(mpx.num_socket_managers(), 0u);
CHECK_EQ(mpx.init(), none);
exhaust(); exhaust();
// Calling run must have no effect now. CHECK_EQ(mpx.num_socket_managers(), 1u);
mpx.run(); }
}
}
} }
CAF_TEST(send and receive) { SCENARIO("socket managers can register for read and write operations") {
CAF_REQUIRE_EQUAL(mpx.init(), none); GIVEN("an initialized multiplexer") {
auto sockets = unbox(make_stream_socket_pair()); init();
{ // Lifetime scope of alice and bob. WHEN("socket managers register for read and write operations") {
auto alice = make_counted<dummy_manager>(manager_count, sockets.first, auto [alice_fd, bob_fd] = unbox(make_stream_socket_pair());
&mpx); auto alice = make_manager(alice_fd, "Alice");
auto bob = make_counted<dummy_manager>(manager_count, sockets.second, &mpx); auto bob = make_manager(bob_fd, "Bob");
alice->register_reading(); alice->register_reading();
bob->register_reading(); bob->register_reading();
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 3u); apply_updates();
alice->send("hello bob"); CHECK_EQ(mpx.num_socket_managers(), 3u);
THEN("the multiplexer runs callbacks on socket activity") {
alice->send("Hello Bob!");
alice->register_writing(); alice->register_writing();
exhaust(); exhaust();
CAF_CHECK_EQUAL(bob->receive(), "hello bob"); CHECK_EQ(bob->receive(), "Hello Bob!");
}
}
} }
mpx.shutdown();
} }
CAF_TEST(shutdown) { SCENARIO("a multiplexer terminates its thread after shutting down") {
std::mutex m; GIVEN("a multiplexer running in its own thread and some socket managers") {
std::condition_variable cv; init();
bool thread_id_set = false; auto go_time = std::make_shared<barrier>(2);
auto run_mpx = [&] { auto mpx_thread = std::thread{[this, go_time] {
{
std::unique_lock<std::mutex> guard(m);
mpx.set_thread_id(); mpx.set_thread_id();
thread_id_set = true; go_time->arrive_and_wait();
cv.notify_one();
}
mpx.run(); mpx.run();
}; }};
CAF_REQUIRE_EQUAL(mpx.init(), none); go_time->arrive_and_wait();
auto sockets = unbox(make_stream_socket_pair()); auto [alice_fd, bob_fd] = unbox(make_stream_socket_pair());
auto alice = make_counted<dummy_manager>(manager_count, sockets.first, &mpx); auto alice = make_manager(alice_fd, "Alice");
auto bob = make_counted<dummy_manager>(manager_count, sockets.second, &mpx); auto bob = make_manager(bob_fd, "Bob");
alice->register_reading(); alice->register_reading();
bob->register_reading(); bob->register_reading();
CAF_REQUIRE_EQUAL(mpx.num_socket_managers(), 3u); WHEN("calling shutdown on the multiplexer") {
std::thread mpx_thread{run_mpx};
std::unique_lock<std::mutex> lk(m);
cv.wait(lk, [&] { return thread_id_set; });
mpx.shutdown(); mpx.shutdown();
THEN("the thread terminates and all socket managers get shut down") {
mpx_thread.join(); mpx_thread.join();
CHECK(alice->read_closed());
CHECK(bob->read_closed());
}
}
}
}
SCENARIO("a multiplexer allows managers to perform socket handovers") {
GIVEN("an initialized multiplexer") {
init();
WHEN("socket manager triggers a handover") {
auto [alice_fd, bob_fd] = unbox(make_stream_socket_pair());
auto alice = make_manager(alice_fd, "Alice");
auto bob = make_manager(bob_fd, "Bob");
alice->register_reading();
bob->register_reading();
apply_updates();
CHECK_EQ(mpx.num_socket_managers(), 3u);
THEN("the multiplexer swaps out the socket managers for the socket") {
alice->send("Hello Bob!");
alice->register_writing();
exhaust();
CHECK_EQ(bob->receive(), "Hello Bob!");
bob->trigger_handover = true;
alice->send("Hello Carl!");
alice->register_writing();
bob->register_reading();
exhaust();
CHECK_EQ(bob->receive(), "");
CHECK_EQ(bob->handle(), invalid_socket);
if (CHECK_NE(bob->next, nullptr)) {
auto carl = bob->next;
CHECK_EQ(carl->handle(), socket{bob_fd});
carl->register_reading();
exhaust();
CHECK_EQ(carl->name, "Carl");
CHECK_EQ(carl->receive(), "Hello Carl!");
}
}
}
}
} }
CAF_TEST_FIXTURE_SCOPE_END() END_FIXTURE_SCOPE()
...@@ -148,3 +148,31 @@ private: ...@@ -148,3 +148,31 @@ private:
caf::error abort_reason_; caf::error abort_reason_;
}; };
// Drop-in replacement for std::barrier (based on the TS API as of 2020).
class barrier {
public:
explicit barrier(ptrdiff_t num_threads)
: num_threads_(num_threads), count_(0) {
// nop
}
void arrive_and_wait() {
std::unique_lock<std::mutex> guard{mx_};
auto new_count = ++count_;
if (new_count == num_threads_) {
cv_.notify_all();
} else if (new_count > num_threads_) {
count_ = 1;
cv_.wait(guard, [this] { return count_.load() == num_threads_; });
} else {
cv_.wait(guard, [this] { return count_.load() == num_threads_; });
}
}
private:
ptrdiff_t num_threads_;
std::mutex mx_;
std::atomic<ptrdiff_t> count_;
std::condition_variable cv_;
};
...@@ -163,6 +163,7 @@ struct fixture : host_fixture, test_coordinator_fixture<> { ...@@ -163,6 +163,7 @@ struct fixture : host_fixture, test_coordinator_fixture<> {
if (!predicate()) if (!predicate())
return; return;
for (size_t i = 0; i < 1000; ++i) { for (size_t i = 0; i < 1000; ++i) {
mpx.apply_updates();
mpx.poll_once(false); mpx.poll_once(false);
byte tmp[1024]; byte tmp[1024];
auto bytes = read(self_socket_guard.socket(), make_span(tmp, 1024)); auto bytes = read(self_socket_guard.socket(), make_span(tmp, 1024));
......
...@@ -153,22 +153,24 @@ SCENARIO("calling suspend_reading removes message apps temporarily") { ...@@ -153,22 +153,24 @@ SCENARIO("calling suspend_reading removes message apps temporarily") {
} }
}}; }};
net::multiplexer mpx{nullptr}; net::multiplexer mpx{nullptr};
mpx.set_thread_id();
if (auto err = mpx.init()) if (auto err = mpx.init())
FAIL("mpx.init failed: " << err); FAIL("mpx.init failed: " << err);
mpx.set_thread_id(); mpx.apply_updates();
REQUIRE_EQ(mpx.num_socket_managers(), 1u); REQUIRE_EQ(mpx.num_socket_managers(), 1u);
if (auto err = net::nonblocking(fd2, true)) if (auto err = net::nonblocking(fd2, true))
CAF_FAIL("nonblocking returned an error: " << err); CAF_FAIL("nonblocking returned an error: " << err);
auto mgr = net::make_socket_manager<app<true>, net::length_prefix_framing, auto mgr = net::make_socket_manager<app<true>, net::length_prefix_framing,
net::stream_transport>(fd2, &mpx); net::stream_transport>(fd2, &mpx);
CHECK_EQ(mgr->init(settings{}), none); CHECK_EQ(mgr->init(settings{}), none);
mpx.apply_updates();
REQUIRE_EQ(mpx.num_socket_managers(), 2u); REQUIRE_EQ(mpx.num_socket_managers(), 2u);
CHECK_EQ(mgr->mask(), net::operation::read); CHECK_EQ(mpx.mask_of(mgr), net::operation::read);
auto& state = mgr->top_layer(); auto& state = mgr->top_layer();
WHEN("the app calls suspend_reading") { WHEN("the app calls suspend_reading") {
while (mpx.num_socket_managers() > 1u) while (mpx.num_socket_managers() > 1u)
mpx.poll_once(true); mpx.poll_once(true);
CHECK_EQ(mgr->mask(), net::operation::none); CHECK_EQ(mpx.mask_of(mgr), net::operation::none);
if (CHECK_EQ(state.inputs.size(), 3u)) { if (CHECK_EQ(state.inputs.size(), 3u)) {
CHECK_EQ(state.inputs[0], "first"); CHECK_EQ(state.inputs[0], "first");
CHECK_EQ(state.inputs[1], "second"); CHECK_EQ(state.inputs[1], "second");
...@@ -176,7 +178,8 @@ SCENARIO("calling suspend_reading removes message apps temporarily") { ...@@ -176,7 +178,8 @@ SCENARIO("calling suspend_reading removes message apps temporarily") {
} }
THEN("users can resume it via continue_reading ") { THEN("users can resume it via continue_reading ") {
mgr->continue_reading(); mgr->continue_reading();
CHECK_EQ(mgr->mask(), net::operation::read); mpx.apply_updates();
CHECK_EQ(mpx.mask_of(mgr), net::operation::read);
while (mpx.num_socket_managers() > 1u) while (mpx.num_socket_managers() > 1u)
mpx.poll_once(true); mpx.poll_once(true);
if (CHECK_EQ(state.inputs.size(), 5u)) { if (CHECK_EQ(state.inputs.size(), 5u)) {
......
...@@ -89,8 +89,7 @@ public: ...@@ -89,8 +89,7 @@ public:
template <class LowerLayerPtr> template <class LowerLayerPtr>
void abort(LowerLayerPtr, const error& reason) { void abort(LowerLayerPtr, const error& reason) {
if (reason == caf::sec::socket_disconnected if (reason == caf::sec::socket_disconnected || reason == caf::sec::disposed)
|| reason == caf::sec::discarded)
adapter_->close(); adapter_->close();
else else
adapter_->abort(reason); adapter_->abort(reason);
......
...@@ -166,6 +166,7 @@ struct fixture : host_fixture, test_coordinator_fixture<> { ...@@ -166,6 +166,7 @@ struct fixture : host_fixture, test_coordinator_fixture<> {
if (!predicate()) if (!predicate())
return; return;
for (size_t i = 0; i < 1000; ++i) { for (size_t i = 0; i < 1000; ++i) {
mpx.apply_updates();
mpx.poll_once(false); mpx.poll_once(false);
byte tmp[1024]; byte tmp[1024];
auto bytes = read(self_socket_guard.socket(), make_span(tmp, 1024)); auto bytes = read(self_socket_guard.socket(), make_span(tmp, 1024));
......
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
#include "caf/net/stream_transport.hpp" #include "caf/net/stream_transport.hpp"
#include "caf/net/test/host_fixture.hpp" #include "net-test.hpp"
#include "caf/test/dsl.hpp"
#include "caf/binary_deserializer.hpp" #include "caf/binary_deserializer.hpp"
#include "caf/binary_serializer.hpp" #include "caf/binary_serializer.hpp"
...@@ -26,6 +25,7 @@ using namespace caf; ...@@ -26,6 +25,7 @@ using namespace caf;
using namespace caf::net; using namespace caf::net;
namespace { namespace {
constexpr string_view hello_manager = "hello manager!"; constexpr string_view hello_manager = "hello manager!";
struct fixture : test_coordinator_fixture<>, host_fixture { struct fixture : test_coordinator_fixture<>, host_fixture {
...@@ -36,15 +36,16 @@ struct fixture : test_coordinator_fixture<>, host_fixture { ...@@ -36,15 +36,16 @@ struct fixture : test_coordinator_fixture<>, host_fixture {
recv_buf(1024), recv_buf(1024),
shared_recv_buf{std::make_shared<byte_buffer>()}, shared_recv_buf{std::make_shared<byte_buffer>()},
shared_send_buf{std::make_shared<byte_buffer>()} { shared_send_buf{std::make_shared<byte_buffer>()} {
if (auto err = mpx.init())
CAF_FAIL("mpx.init failed: " << err);
mpx.set_thread_id(); mpx.set_thread_id();
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 1u); mpx.apply_updates();
if (auto err = mpx.init())
FAIL("mpx.init failed: " << err);
REQUIRE_EQ(mpx.num_socket_managers(), 1u);
auto sockets = unbox(make_stream_socket_pair()); auto sockets = unbox(make_stream_socket_pair());
send_socket_guard.reset(sockets.first); send_socket_guard.reset(sockets.first);
recv_socket_guard.reset(sockets.second); recv_socket_guard.reset(sockets.second);
if (auto err = nonblocking(recv_socket_guard.socket(), true)) if (auto err = nonblocking(recv_socket_guard.socket(), true))
CAF_FAIL("nonblocking returned an error: " << err); FAIL("nonblocking returned an error: " << err);
} }
bool handle_io_event() override { bool handle_io_event() override {
...@@ -82,7 +83,7 @@ public: ...@@ -82,7 +83,7 @@ public:
template <class ParentPtr> template <class ParentPtr>
bool prepare_send(ParentPtr parent) { bool prepare_send(ParentPtr parent) {
CAF_MESSAGE("prepare_send called"); MESSAGE("prepare_send called");
auto& buf = parent->output_buffer(); auto& buf = parent->output_buffer();
auto data = as_bytes(make_span(hello_manager)); auto data = as_bytes(make_span(hello_manager));
buf.insert(buf.end(), data.begin(), data.end()); buf.insert(buf.end(), data.begin(), data.end());
...@@ -91,44 +92,30 @@ public: ...@@ -91,44 +92,30 @@ public:
template <class ParentPtr> template <class ParentPtr>
bool done_sending(ParentPtr) { bool done_sending(ParentPtr) {
CAF_MESSAGE("done_sending called"); MESSAGE("done_sending called");
return true; return true;
} }
template <class ParentPtr> template <class ParentPtr>
void continue_reading(ParentPtr) { void continue_reading(ParentPtr) {
CAF_FAIL("continue_reading called"); FAIL("continue_reading called");
} }
template <class ParentPtr> template <class ParentPtr>
size_t consume(ParentPtr, span<const byte> data, span<const byte>) { size_t consume(ParentPtr, span<const byte> data, span<const byte>) {
recv_buf_->clear(); recv_buf_->clear();
recv_buf_->insert(recv_buf_->begin(), data.begin(), data.end()); recv_buf_->insert(recv_buf_->begin(), data.begin(), data.end());
CAF_MESSAGE("Received " << recv_buf_->size() MESSAGE("Received " << recv_buf_->size() << " bytes in dummy_application");
<< " bytes in dummy_application");
return recv_buf_->size(); return recv_buf_->size();
} }
template <class ParentPtr>
void resolve(ParentPtr parent, string_view path, const actor& listener) {
actor_id aid = 42;
auto hid = string_view("0011223344556677889900112233445566778899");
auto nid = unbox(make_node_id(42, hid));
actor_config cfg;
endpoint_manager_ptr ptr{&parent->manager()};
auto p = make_actor<actor_proxy_impl, strong_actor_ptr>(
aid, nid, &parent->system(), cfg, std::move(ptr));
anon_send(listener, resolve_atom_v, std::string{path.begin(), path.end()},
p);
}
static void handle_error(sec code) { static void handle_error(sec code) {
CAF_FAIL("handle_error called with " << CAF_ARG(code)); FAIL("handle_error called with " << CAF_ARG(code));
} }
template <class ParentPtr> template <class ParentPtr>
static void abort(ParentPtr, const error& reason) { static void abort(ParentPtr, const error& reason) {
CAF_FAIL("abort called with " << CAF_ARG(reason)); FAIL("abort called with " << CAF_ARG(reason));
} }
private: private:
...@@ -138,20 +125,20 @@ private: ...@@ -138,20 +125,20 @@ private:
} // namespace } // namespace
CAF_TEST_FIXTURE_SCOPE(endpoint_manager_tests, fixture) BEGIN_FIXTURE_SCOPE(fixture)
CAF_TEST(receive) { CAF_TEST(receive) {
auto mgr = make_socket_manager<dummy_application, stream_transport>( auto mgr = make_socket_manager<dummy_application, stream_transport>(
recv_socket_guard.release(), &mpx, shared_recv_buf, shared_send_buf); recv_socket_guard.release(), &mpx, shared_recv_buf, shared_send_buf);
CAF_CHECK_EQUAL(mgr->init(config), none); CHECK_EQ(mgr->init(config), none);
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 2u); mpx.apply_updates();
CAF_CHECK_EQUAL(static_cast<size_t>( CHECK_EQ(mpx.num_socket_managers(), 2u);
write(send_socket_guard.socket(), CHECK_EQ(static_cast<size_t>(write(send_socket_guard.socket(),
as_bytes(make_span(hello_manager)))), as_bytes(make_span(hello_manager)))),
hello_manager.size()); hello_manager.size());
CAF_MESSAGE("wrote " << hello_manager.size() << " bytes."); MESSAGE("wrote " << hello_manager.size() << " bytes.");
run(); run();
CAF_CHECK_EQUAL(string_view(reinterpret_cast<char*>(shared_recv_buf->data()), CHECK_EQ(string_view(reinterpret_cast<char*>(shared_recv_buf->data()),
shared_recv_buf->size()), shared_recv_buf->size()),
hello_manager); hello_manager);
} }
...@@ -159,18 +146,20 @@ CAF_TEST(receive) { ...@@ -159,18 +146,20 @@ CAF_TEST(receive) {
CAF_TEST(send) { CAF_TEST(send) {
auto mgr = make_socket_manager<dummy_application, stream_transport>( auto mgr = make_socket_manager<dummy_application, stream_transport>(
recv_socket_guard.release(), &mpx, shared_recv_buf, shared_send_buf); recv_socket_guard.release(), &mpx, shared_recv_buf, shared_send_buf);
CAF_CHECK_EQUAL(mgr->init(config), none); CHECK_EQ(mgr->init(config), none);
CAF_CHECK_EQUAL(mpx.num_socket_managers(), 2u); mpx.apply_updates();
CHECK_EQ(mpx.num_socket_managers(), 2u);
mgr->register_writing(); mgr->register_writing();
mpx.apply_updates();
while (handle_io_event()) while (handle_io_event())
; ;
recv_buf.resize(hello_manager.size()); recv_buf.resize(hello_manager.size());
auto res = read(send_socket_guard.socket(), make_span(recv_buf)); auto res = read(send_socket_guard.socket(), make_span(recv_buf));
CAF_MESSAGE("received " << res << " bytes"); MESSAGE("received " << res << " bytes");
recv_buf.resize(res); recv_buf.resize(res);
CAF_CHECK_EQUAL(string_view(reinterpret_cast<char*>(recv_buf.data()), CHECK_EQ(string_view(reinterpret_cast<char*>(recv_buf.data()),
recv_buf.size()), recv_buf.size()),
hello_manager); hello_manager);
} }
CAF_TEST_FIXTURE_SCOPE_END() END_FIXTURE_SCOPE()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment