Commit 51fdf3d4 authored by Dominik Charousset's avatar Dominik Charousset

Implement publisher adapter

parent df9e8117
...@@ -61,7 +61,8 @@ caf_incubator_add_component( ...@@ -61,7 +61,8 @@ caf_incubator_add_component(
multiplexer multiplexer
net.actor_shell net.actor_shell
net.length_prefix_framing net.length_prefix_framing
net.subscriber_adapter net.observer_adapter
net.publisher_adapter
net.typed_actor_shell net.typed_actor_shell
net.web_socket.client net.web_socket.client
net.web_socket.handshake net.web_socket.handshake
......
...@@ -64,6 +64,12 @@ stack *up*. Outgoing data always travels the protocol stack *down*. ...@@ -64,6 +64,12 @@ stack *up*. Outgoing data always travels the protocol stack *down*.
/// event loop, `false` otherwise. /// event loop, `false` otherwise.
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool done_sending(LowerLayerPtr down); bool done_sending(LowerLayerPtr down);
/// When provided, the underlying transport calls this member function
/// before leaving `handle_read_event`. The primary use case for this
/// callback is flushing buffers.
template <class LowerLayerPtr>
[[optional]] void after_reading(LowerLayerPtr down);
} }
interface base [role: lower layer] { interface base [role: lower layer] {
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
namespace caf::detail {
template <class T, class LowerLayerPtr>
class has_after_reading {
private:
template <class A, class B>
static auto sfinae(A& up, B& ptr)
-> decltype(up.after_reading(ptr), std::true_type{});
template <class A>
static std::false_type sfinae(A&, ...);
using sfinae_result
= decltype(sfinae(std::declval<T&>(), std::declval<LowerLayerPtr&>()));
public:
static constexpr bool value = sfinae_result::value;
};
template <class T, class LowerLayerPtr>
constexpr bool has_after_reading_v = has_after_reading<T, LowerLayerPtr>::value;
} // namespace caf::detail
...@@ -13,7 +13,7 @@ namespace caf::net { ...@@ -13,7 +13,7 @@ namespace caf::net {
// -- templates ---------------------------------------------------------------- // -- templates ----------------------------------------------------------------
template <class Application> template <class UpperLayer>
class stream_transport; class stream_transport;
template <class Factory> template <class Factory>
......
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <type_traits>
#include "caf/byte.hpp" #include "caf/byte.hpp"
#include "caf/byte_span.hpp" #include "caf/byte_span.hpp"
#include "caf/detail/has_after_reading.hpp"
#include "caf/detail/network_order.hpp" #include "caf/detail/network_order.hpp"
#include "caf/error.hpp" #include "caf/error.hpp"
#include "caf/net/message_oriented_layer_ptr.hpp" #include "caf/net/message_oriented_layer_ptr.hpp"
...@@ -35,9 +37,9 @@ public: ...@@ -35,9 +37,9 @@ public:
using length_prefix_type = uint32_t; using length_prefix_type = uint32_t;
static constexpr size_t max_message_length = INT32_MAX - sizeof(uint32_t); static constexpr size_t hdr_size = sizeof(uint32_t);
static constexpr uint32_t default_receive_size = 4 * 1024; // 4kb. static constexpr size_t max_message_length = INT32_MAX - sizeof(uint32_t);
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
...@@ -51,8 +53,7 @@ public: ...@@ -51,8 +53,7 @@ public:
template <class LowerLayerPtr> template <class LowerLayerPtr>
error init(socket_manager* owner, LowerLayerPtr down, const settings& cfg) { error init(socket_manager* owner, LowerLayerPtr down, const settings& cfg) {
down->configure_read( down->configure_read(receive_policy::exactly(hdr_size));
receive_policy::between(sizeof(uint32_t), default_receive_size));
return upper_layer_.init(owner, this_layer_ptr(down), cfg); return upper_layer_.init(owner, this_layer_ptr(down), cfg);
} }
...@@ -78,6 +79,11 @@ public: ...@@ -78,6 +79,11 @@ public:
return down->handle(); return down->handle();
} }
template <class LowerLayerPtr>
static void suspend_reading(LowerLayerPtr down) {
return down->suspend_reading();
}
template <class LowerLayerPtr> template <class LowerLayerPtr>
void begin_message(LowerLayerPtr down) { void begin_message(LowerLayerPtr down) {
down->begin_output(); down->begin_output();
...@@ -123,6 +129,14 @@ public: ...@@ -123,6 +129,14 @@ public:
// -- interface for the lower layer ------------------------------------------ // -- interface for the lower layer ------------------------------------------
template <class LowerLayerPtr>
std::enable_if_t<detail::has_after_reading_v<
UpperLayer,
message_oriented_layer_ptr<length_prefix_framing, LowerLayerPtr>>>
after_reading(LowerLayerPtr down) {
return upper_layer_.after_reading(this_layer_ptr(down));
}
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool prepare_send(LowerLayerPtr down) { bool prepare_send(LowerLayerPtr down) {
return upper_layer_.prepare_send(this_layer_ptr(down)); return upper_layer_.prepare_send(this_layer_ptr(down));
...@@ -140,40 +154,41 @@ public: ...@@ -140,40 +154,41 @@ public:
template <class LowerLayerPtr> template <class LowerLayerPtr>
ptrdiff_t consume(LowerLayerPtr down, byte_span input, byte_span) { ptrdiff_t consume(LowerLayerPtr down, byte_span input, byte_span) {
auto buffer = input;
auto consumed = ptrdiff_t{0};
auto this_layer = this_layer_ptr(down); auto this_layer = this_layer_ptr(down);
for (;;) { if (input.size() < sizeof(uint32_t)) {
if (input.size() < sizeof(uint32_t)) { auto err = make_error(sec::runtime_error,
return consumed; "received too few bytes from underlying transport");
down->abort_reason(std::move(err));
return -1;
} else if (input.size() == sizeof(uint32_t)) {
auto u32_size = uint32_t{0};
memcpy(&u32_size, input.data(), sizeof(uint32_t));
auto msg_size = static_cast<size_t>(detail::from_network_order(u32_size));
if (msg_size == 0) {
// Ignore empty messages.
return static_cast<ptrdiff_t>(input.size());
} else if (msg_size > max_message_length) {
auto err = make_error(sec::runtime_error,
"maximum message size exceeded");
down->abort_reason(std::move(err));
return -1;
} else { } else {
auto [msg_size, sub_buffer] = split(input); down->configure_read(receive_policy::exactly(hdr_size + msg_size));
if (msg_size == 0) { return 0;
consumed += static_cast<ptrdiff_t>(sizeof(uint32_t)); }
input = sub_buffer; } else {
} else if (msg_size > max_message_length) { auto [msg_size, msg] = split(input);
auto err = make_error(sec::runtime_error, if (msg_size == msg.size() && msg_size + hdr_size == input.size()) {
"maximum message size exceeded"); if (upper_layer_.consume(this_layer, msg) >= 0) {
down->abort_reason(std::move(err)); down->configure_read(receive_policy::exactly(hdr_size));
return -1; return static_cast<ptrdiff_t>(input.size());
} else if (msg_size > sub_buffer.size()) {
if (msg_size + sizeof(uint32_t) > receive_buf_upper_bound_) {
auto min_read_size = static_cast<uint32_t>(sizeof(uint32_t));
receive_buf_upper_bound_
= static_cast<uint32_t>(msg_size + sizeof(uint32_t));
down->configure_read(
receive_policy::between(min_read_size, receive_buf_upper_bound_));
}
return consumed;
} else { } else {
auto msg = sub_buffer.subspan(0, msg_size); return -1;
if (auto res = upper_layer_.consume(this_layer, msg); res >= 0) {
consumed += static_cast<ptrdiff_t>(msg.size()) + sizeof(uint32_t);
input = sub_buffer.subspan(msg_size);
} else {
return -1;
}
} }
} else {
auto err = make_error(sec::runtime_error, "received malformed message");
down->abort_reason(std::move(err));
return -1;
} }
} }
} }
...@@ -200,7 +215,6 @@ private: ...@@ -200,7 +215,6 @@ private:
UpperLayer upper_layer_; UpperLayer upper_layer_;
size_t message_offset_ = 0; size_t message_offset_ = 0;
uint32_t receive_buf_upper_bound_ = default_receive_size;
}; };
} // namespace caf::net } // namespace caf::net
...@@ -22,6 +22,10 @@ public: ...@@ -22,6 +22,10 @@ public:
// nop // nop
} }
void suspend_reading() {
return lptr_->suspend_reading(llptr_);
}
bool can_send_more() const noexcept { bool can_send_more() const noexcept {
return lptr_->can_send_more(llptr_); return lptr_->can_send_more(llptr_);
} }
......
...@@ -70,6 +70,10 @@ public: ...@@ -70,6 +70,10 @@ public:
/// @thread-safe /// @thread-safe
void register_writing(const socket_manager_ptr& mgr); void register_writing(const socket_manager_ptr& mgr);
/// Schedules a call to `mgr->handle_error(sec::discarded)`.
/// @thread-safe
void discard(const socket_manager_ptr& mgr);
/// Registers `mgr` for initialization in the multiplexer's thread. /// Registers `mgr` for initialization in the multiplexer's thread.
/// @thread-safe /// @thread-safe
void init(const socket_manager_ptr& mgr); void init(const socket_manager_ptr& mgr);
......
...@@ -4,24 +4,32 @@ ...@@ -4,24 +4,32 @@
#pragma once #pragma once
#include "caf/flow/poll_subscriber.hpp" #include "caf/async/observer_buffer.hpp"
#include "caf/net/multiplexer.hpp" #include "caf/net/multiplexer.hpp"
#include "caf/net/socket_manager.hpp" #include "caf/net/socket_manager.hpp"
namespace caf::net { namespace caf::net {
/// Base class for buffered consumption of published items. /// Connects a socket manager to an asynchronous publisher using a buffer.
/// Whenever the buffer becomes non-empty, the adapter registers the socket
/// manager for writing. The usual pattern for using the adapter then is to call
/// `poll` on the adapter in `prepare_send`.
template <class T> template <class T>
class subscriber_adapter : public flow::poll_subscriber<T> { class observer_adapter : public async::observer_buffer<T> {
public: public:
using super = flow::poll_subscriber<T>; using super = async::observer_buffer<T>;
explicit subscriber_adapter(socket_manager* owner) : mgr_(owner) { explicit observer_adapter(socket_manager* owner) : mgr_(owner) {
// nop // nop
} }
private: private:
void wakeup(std::unique_lock<std::mutex>&) { void deinit(std::unique_lock<std::mutex>& guard) final {
wakeup(guard);
mgr_ = nullptr;
}
void wakeup(std::unique_lock<std::mutex>&) final {
mgr_->mpx().register_writing(mgr_); mgr_->mpx().register_writing(mgr_);
} }
...@@ -29,6 +37,6 @@ private: ...@@ -29,6 +37,6 @@ private:
}; };
template <class T> template <class T>
using subscriber_adapter_ptr = intrusive_ptr<subscriber_adapter<T>>; using observer_adapter_ptr = intrusive_ptr<observer_adapter<T>>;
} // namespace caf::net } // namespace caf::net
...@@ -30,6 +30,8 @@ public: ...@@ -30,6 +30,8 @@ public:
static constexpr uint8_t init_manager_code = 0x02; static constexpr uint8_t init_manager_code = 0x02;
static constexpr uint8_t discard_manager_code = 0x03;
static constexpr uint8_t shutdown_code = 0x04; static constexpr uint8_t shutdown_code = 0x04;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include <memory>
#include <new>
#include "caf/async/publisher.hpp"
#include "caf/flow/observer.hpp"
#include "caf/flow/subscription.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/socket_manager.hpp"
namespace caf::net {
template <class T>
class publisher_adapter final : public async::publisher<T>::impl,
public flow::subscription::impl {
public:
publisher_adapter(socket_manager* owner, uint32_t max_in_flight,
uint32_t batch_size)
: batch_size_(batch_size), max_in_flight_(max_in_flight), mgr_(owner) {
CAF_ASSERT(max_in_flight > batch_size);
buf_ = reinterpret_cast<T*>(malloc(sizeof(T) * max_in_flight * 2));
}
~publisher_adapter() {
auto first = buf_ + rd_pos_;
auto last = buf_ + wr_pos_;
std::destroy(first, last);
free(buf_);
}
void subscribe(flow::observer<T> sink) override {
if (std::unique_lock guard{mtx_}; !sink_) {
sink_ = std::move(sink);
auto ptr = intrusive_ptr<flow::subscription::impl>{this};
sink_.on_attach(flow::subscription{std::move(ptr)});
} else {
sink.on_error(
make_error(sec::downstream_already_exists,
"caf::net::publisher_adapter only allows one observer"));
}
}
void request(size_t n) override {
CAF_ASSERT(n > 0);
// Reactive Streams specification 1.0.3:
// > Subscription.request MUST place an upper bound on possible synchronous
// > recursion between Publisher and Subscriber.
std::unique_lock guard{mtx_};
if (!sink_)
return;
credit_ += static_cast<uint32_t>(n);
if (!in_request_body_) {
in_request_body_ = true;
auto n = std::min(size(), credit_);
// When full, we take whatever we can out of the buffer even if the client
// requests less than a batch. Otherwise, we try to wait until we have
// sufficient credit for a full batch.
if (n == 0) {
in_request_body_ = false;
return;
} else if (full()) {
wakeup();
} else if (n < batch_size_) {
in_request_body_ = false;
return;
}
auto m = std::min(n, batch_size_);
deliver(m);
n -= m;
while (sink_ && n >= batch_size_) {
deliver(batch_size_);
n -= batch_size_;
}
shift_elements();
in_request_body_ = false;
}
}
void cancel() override {
std::unique_lock guard{mtx_};
discard();
}
void on_complete() {
std::unique_lock guard{mtx_};
if (sink_) {
sink_.on_complete();
sink_ = nullptr;
}
}
void on_error(const error& what) {
std::unique_lock guard{mtx_};
if (sink_) {
sink_.on_error(what);
sink_ = nullptr;
}
}
/// Enqueues a new element to the buffer.
/// @returns The remaining buffer capacity. If this function return 0, the
/// manager MUST suspend reading until the observer consumes at least
/// one element.
size_t push(T value) {
std::unique_lock guard{mtx_};
if (!mgr_)
return 0;
new (buf_ + wr_pos_) T(std::move(value));
++wr_pos_;
if (auto n = std::min(size(), credit_); n >= batch_size_) {
do {
deliver(n);
n -= batch_size_;
} while (n >= batch_size_);
shift_elements();
}
if (auto result = capacity(); result == 0 && credit_ > 0) {
// Can only reach here if batch_size_ > credit_.
deliver(credit_);
shift_elements();
return capacity();
} else {
return result;
}
}
/// Pushes any buffered items to the observer as long as there is any
/// available credit.
void flush() {
std::unique_lock guard{mtx_};
while (sink_) {
if (auto n = std::min({size(), credit_, batch_size_}); n > 0)
deliver(n);
else
break;
}
shift_elements();
}
private:
void discard() {
if (mgr_) {
sink_ = nullptr;
mgr_->mpx().discard(mgr_);
mgr_ = nullptr;
credit_ = 0;
}
}
/// @pre `mtx_` is locked
[[nodiscard]] uint32_t size() const noexcept {
return wr_pos_ - rd_pos_;
}
/// @pre `mtx_` is locked
[[nodiscard]] uint32_t capacity() const noexcept {
return max_in_flight_ - size();
}
/// @pre `mtx_` is locked
[[nodiscard]] bool full() const noexcept {
return capacity() == 0;
}
/// @pre `mtx_` is locked
[[nodiscard]] bool empty() const noexcept {
return wr_pos_ == rd_pos_;
}
/// @pre `mtx_` is locked
void wakeup() {
CAF_ASSERT(mgr_ != nullptr);
mgr_->mpx().register_reading(mgr_);
}
void deliver(uint32_t n) {
auto first = buf_ + rd_pos_;
auto last = first + n;
sink_.on_next(span<const T>{first, n});
std::destroy(first, last);
CAF_ASSERT(rd_pos_ + n <= wr_pos_);
rd_pos_ += n;
CAF_ASSERT(credit_ >= n);
credit_ -= n;
}
void shift_elements() {
if (rd_pos_ >= max_in_flight_) {
if (empty()) {
rd_pos_ = 0;
wr_pos_ = 0;
} else {
// No need to check for overlap: the first half of the buffer is empty.
auto first = buf_ + rd_pos_;
auto last = buf_ + wr_pos_;
std::uninitialized_move(first, last, buf_);
std::destroy(first, last);
wr_pos_ -= rd_pos_;
rd_pos_ = 0;
}
}
}
std::recursive_mutex mtx_;
/// Allocated to max_in_flight_ * 2, but at most holds max_in_flight_ elements
/// at any point in time. We dynamically shift elements into the first half of
/// the buffer whenever rd_pos_ crosses the midpoint.
T* buf_;
uint32_t rd_pos_ = 0;
uint32_t wr_pos_ = 0;
uint32_t credit_ = 0;
uint32_t batch_size_;
uint32_t max_in_flight_;
bool in_request_body_ = false;
flow::observer<T> sink_;
intrusive_ptr<socket_manager> mgr_;
};
template <class T>
using publisher_adapter_ptr = intrusive_ptr<publisher_adapter<T>>;
} // namespace caf::net
...@@ -23,6 +23,10 @@ public: ...@@ -23,6 +23,10 @@ public:
// nop // nop
} }
void suspend_reading() {
return lptr_->suspend_reading(llptr_);
}
bool can_send_more() const noexcept { bool can_send_more() const noexcept {
return lptr_->can_send_more(llptr_); return lptr_->can_send_more(llptr_);
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "caf/byte_buffer.hpp" #include "caf/byte_buffer.hpp"
#include "caf/defaults.hpp" #include "caf/defaults.hpp"
#include "caf/detail/has_after_reading.hpp"
#include "caf/fwd.hpp" #include "caf/fwd.hpp"
#include "caf/logger.hpp" #include "caf/logger.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
...@@ -48,6 +49,11 @@ public: ...@@ -48,6 +49,11 @@ public:
// -- interface for stream_oriented_layer_ptr -------------------------------- // -- interface for stream_oriented_layer_ptr --------------------------------
template <class ParentPtr>
void suspend_reading(ParentPtr) {
suspend_reading_ = true;
}
template <class ParentPtr> template <class ParentPtr>
bool can_send_more(ParentPtr) const noexcept { bool can_send_more(ParentPtr) const noexcept {
return write_buf_.size() < max_write_buf_size_; return write_buf_.size() < max_write_buf_size_;
...@@ -160,9 +166,13 @@ public: ...@@ -160,9 +166,13 @@ public:
if (read_buf_.size() < max_read_size_) if (read_buf_.size() < max_read_size_)
read_buf_.resize(max_read_size_); read_buf_.resize(max_read_size_);
auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent); auto this_layer_ptr = make_stream_oriented_layer_ptr(this, parent);
static constexpr bool has_after_reading
= detail::has_after_reading_v<UpperLayer, decltype(this_layer_ptr)>;
for (size_t i = 0; max_read_size_ > 0 && i < max_consecutive_reads_; ++i) { for (size_t i = 0; max_read_size_ > 0 && i < max_consecutive_reads_; ++i) {
// Calling configure_read(read_policy::stop()) halts receive events. // Calling configure_read(read_policy::stop()) halts receive events.
if (max_read_size_ == 0) { if (max_read_size_ == 0) {
if constexpr (has_after_reading)
upper_layer_.after_reading(this_layer_ptr);
return false; return false;
} else if (offset_ >= max_read_size_) { } else if (offset_ >= max_read_size_) {
auto old_max = max_read_size_; auto old_max = max_read_size_;
...@@ -223,19 +233,31 @@ public: ...@@ -223,19 +233,31 @@ public:
if (read_buf_.size() != max_read_size_) if (read_buf_.size() != max_read_size_)
if (offset_ < max_read_size_) if (offset_ < max_read_size_)
read_buf_.resize(max_read_size_); read_buf_.resize(max_read_size_);
// Upper layer may have called suspend_reading().
if (suspend_reading_) {
suspend_reading_ = false;
if constexpr (has_after_reading)
upper_layer_.after_reading(this_layer_ptr);
return false;
}
} else if (read_res < 0) { } else if (read_res < 0) {
// Try again later on temporary errors such as EWOULDBLOCK and // Try again later on temporary errors such as EWOULDBLOCK and
// stop reading on the socket on hard errors. // stop reading on the socket on hard errors.
return last_socket_error_is_temporary() if (last_socket_error_is_temporary()) {
? true if constexpr (has_after_reading)
: fail(sec::socket_operation_failed); upper_layer_.after_reading(this_layer_ptr);
return true;
} else {
return fail(sec::socket_operation_failed);
}
} else { } else {
// read() returns 0 iff the connection was closed. // read() returns 0 iff the connection was closed.
return fail(sec::socket_disconnected); return fail(sec::socket_disconnected);
} }
} }
// Calling configure_read(read_policy::stop()) halts receive events. // Calling configure_read(read_policy::stop()) halts receive events.
if constexpr (has_after_reading)
upper_layer_.after_reading(this_layer_ptr);
return max_read_size_ > 0; return max_read_size_ > 0;
} }
...@@ -301,6 +323,9 @@ private: ...@@ -301,6 +323,9 @@ private:
// Stores the offset in `read_buf_` since last calling `upper_layer_.consume`. // Stores the offset in `read_buf_` since last calling `upper_layer_.consume`.
ptrdiff_t delta_offset_ = 0; ptrdiff_t delta_offset_ = 0;
// Stores whether the user called `suspend_reading()`.
bool suspend_reading_ = false;
// Caches incoming data. // Caches incoming data.
byte_buffer read_buf_; byte_buffer read_buf_;
......
...@@ -118,7 +118,7 @@ void multiplexer::register_reading(const socket_manager_ptr& mgr) { ...@@ -118,7 +118,7 @@ void multiplexer::register_reading(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { if (shutting_down_) {
// discard // nop
} else if (mgr->mask() != operation::none) { } else if (mgr->mask() != operation::none) {
if (auto index = index_of(mgr); if (auto index = index_of(mgr);
index != -1 && mgr->mask_add(operation::read)) { index != -1 && mgr->mask_add(operation::read)) {
...@@ -137,7 +137,7 @@ void multiplexer::register_writing(const socket_manager_ptr& mgr) { ...@@ -137,7 +137,7 @@ void multiplexer::register_writing(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { if (shutting_down_) {
// discard // nop
} else if (mgr->mask() != operation::none) { } else if (mgr->mask() != operation::none) {
if (auto index = index_of(mgr); if (auto index = index_of(mgr);
index != -1 && mgr->mask_add(operation::write)) { index != -1 && mgr->mask_add(operation::write)) {
...@@ -152,11 +152,24 @@ void multiplexer::register_writing(const socket_manager_ptr& mgr) { ...@@ -152,11 +152,24 @@ void multiplexer::register_writing(const socket_manager_ptr& mgr) {
} }
} }
void multiplexer::discard(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) {
if (shutting_down_) {
// nop
} else {
mgr->handle_error(sec::discarded);
}
} else {
write_to_pipe(pollset_updater::discard_manager_code, mgr);
}
}
void multiplexer::init(const socket_manager_ptr& mgr) { void multiplexer::init(const socket_manager_ptr& mgr) {
CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id)); CAF_LOG_TRACE(CAF_ARG2("socket", mgr->handle().id));
if (std::this_thread::get_id() == tid_) { if (std::this_thread::get_id() == tid_) {
if (shutting_down_) { if (shutting_down_) {
// discard // nop
} else { } else {
if (auto err = mgr->init(content(system().config()))) { if (auto err = mgr->init(content(system().config()))) {
CAF_LOG_ERROR("mgr->init failed: " << err); CAF_LOG_ERROR("mgr->init failed: " << err);
...@@ -293,14 +306,12 @@ short multiplexer::handle(const socket_manager_ptr& mgr, short events, ...@@ -293,14 +306,12 @@ short multiplexer::handle(const socket_manager_ptr& mgr, short events,
checkerror = false; checkerror = false;
if (!mgr->handle_read_event()) { if (!mgr->handle_read_event()) {
mgr->mask_del(operation::read); mgr->mask_del(operation::read);
events &= ~input_mask;
} }
} }
if ((revents & output_mask) != 0) { if ((revents & output_mask) != 0) {
checkerror = false; checkerror = false;
if (!mgr->handle_write_event()) { if (!mgr->handle_write_event()) {
mgr->mask_del(operation::write); mgr->mask_del(operation::write);
events &= ~output_mask;
} }
} }
if (checkerror && ((revents & error_mask) != 0)) { if (checkerror && ((revents & error_mask) != 0)) {
...@@ -312,6 +323,20 @@ short multiplexer::handle(const socket_manager_ptr& mgr, short events, ...@@ -312,6 +323,20 @@ short multiplexer::handle(const socket_manager_ptr& mgr, short events,
mgr->handle_error(sec::socket_operation_failed); mgr->handle_error(sec::socket_operation_failed);
mgr->mask_del(operation::read_write); mgr->mask_del(operation::read_write);
events = 0; events = 0;
} else {
switch (mgr->mask()){
case operation::read:
events = input_mask;
break;
case operation::write:
events = output_mask;
break;
case operation::read_write:
events = input_mask | output_mask;
break;
default:
events = 0;
}
} }
return events; return events;
} }
......
...@@ -53,6 +53,9 @@ bool pollset_updater::handle_read_event() { ...@@ -53,6 +53,9 @@ bool pollset_updater::handle_read_event() {
case init_manager_code: case init_manager_code:
parent_->init(mgr); parent_->init(mgr);
break; break;
case discard_manager_code:
parent_->discard(mgr);
break;
case shutdown_code: case shutdown_code:
parent_->shutdown(); parent_->shutdown();
break; break;
......
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
#include "caf/byte_buffer.hpp" #include "caf/byte_buffer.hpp"
#include "caf/byte_span.hpp" #include "caf/byte_span.hpp"
#include "caf/detail/network_order.hpp" #include "caf/detail/network_order.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/socket_guard.hpp"
#include "caf/net/socket_manager.hpp"
#include "caf/net/stream_socket.hpp"
#include "caf/net/stream_transport.hpp"
#include "caf/span.hpp" #include "caf/span.hpp"
#include "caf/tag/message_oriented.hpp" #include "caf/tag/message_oriented.hpp"
...@@ -27,6 +32,7 @@ namespace { ...@@ -27,6 +32,7 @@ namespace {
using string_list = std::vector<std::string>; using string_list = std::vector<std::string>;
template <bool EnableSuspend>
struct app { struct app {
using input_tag = tag::message_oriented; using input_tag = tag::message_oriented;
...@@ -56,6 +62,9 @@ struct app { ...@@ -56,6 +62,9 @@ struct app {
if (CHECK(std::all_of(buf.begin(), buf.end(), printable))) { if (CHECK(std::all_of(buf.begin(), buf.end(), printable))) {
auto str_buf = reinterpret_cast<char*>(buf.data()); auto str_buf = reinterpret_cast<char*>(buf.data());
inputs.emplace_back(std::string{str_buf, buf.size()}); inputs.emplace_back(std::string{str_buf, buf.size()});
if constexpr (EnableSuspend)
if (inputs.back() == "pause")
down->suspend_reading();
std::string response = "ok "; std::string response = "ok ";
response += std::to_string(inputs.size()); response += std::to_string(inputs.size());
auto response_bytes = as_bytes(make_span(response)); auto response_bytes = as_bytes(make_span(response));
...@@ -86,7 +95,7 @@ auto decode(byte_buffer& buf) { ...@@ -86,7 +95,7 @@ auto decode(byte_buffer& buf) {
string_list result; string_list result;
auto input = make_span(buf); auto input = make_span(buf);
while (!input.empty()) { while (!input.empty()) {
auto [msg_size, msg] = net::length_prefix_framing<app>::split(input); auto [msg_size, msg] = net::length_prefix_framing<app<false>>::split(input);
if (msg_size > msg.size()) { if (msg_size > msg.size()) {
CAF_FAIL("cannot decode buffer: invalid message size"); CAF_FAIL("cannot decode buffer: invalid message size");
} else if (!std::all_of(msg.begin(), msg.begin() + msg_size, printable)) { } else if (!std::all_of(msg.begin(), msg.begin() + msg_size, printable)) {
...@@ -103,15 +112,15 @@ auto decode(byte_buffer& buf) { ...@@ -103,15 +112,15 @@ auto decode(byte_buffer& buf) {
} // namespace } // namespace
SCENARIO("length-prefix framing reads data with 32-bit size headers") { SCENARIO("length-prefix framing reads data with 32-bit size headers") {
GIVEN("a length_prefix_framing with an app that consumed strings") { GIVEN("a length_prefix_framing with an app that consumes strings") {
mock_stream_transport<net::length_prefix_framing<app>> uut;
CHECK_EQ(uut.init(), error{});
WHEN("pushing data into the unit-under-test") { WHEN("pushing data into the unit-under-test") {
encode(uut.input, "hello"); mock_stream_transport<net::length_prefix_framing<app<false>>> uut;
encode(uut.input, "world"); CHECK_EQ(uut.init(), error{});
auto input_size = static_cast<ptrdiff_t>(uut.input.size());
CHECK_EQ(uut.handle_input(), input_size);
THEN("the app receives all strings as individual messages") { THEN("the app receives all strings as individual messages") {
encode(uut.input, "hello");
encode(uut.input, "world");
auto input_size = static_cast<ptrdiff_t>(uut.input.size());
CHECK_EQ(uut.handle_input(), input_size);
auto& state = uut.upper_layer.upper_layer(); auto& state = uut.upper_layer.upper_layer();
if (CHECK_EQ(state.inputs.size(), 2u)) { if (CHECK_EQ(state.inputs.size(), 2u)) {
CHECK_EQ(state.inputs[0], "hello"); CHECK_EQ(state.inputs[0], "hello");
...@@ -124,3 +133,62 @@ SCENARIO("length-prefix framing reads data with 32-bit size headers") { ...@@ -124,3 +133,62 @@ SCENARIO("length-prefix framing reads data with 32-bit size headers") {
} }
} }
} }
SCENARIO("calling suspend_reading removes message apps temporarily") {
using namespace std::literals;
GIVEN("a length_prefix_framing with an app that consumes strings") {
auto [fd1, fd2] = unbox(net::make_stream_socket_pair());
auto writer = std::thread{[fd1{fd1}] {
auto guard = make_socket_guard(fd1);
std::vector<std::string_view> inputs{"first", "second", "pause", "third",
"fourth"};
byte_buffer wr_buf;
byte_buffer rd_buf;
rd_buf.resize(512);
for (auto input : inputs) {
wr_buf.clear();
encode(wr_buf, input);
write(fd1, wr_buf);
read(fd1, rd_buf);
}
}};
net::multiplexer mpx{nullptr};
if (auto err = mpx.init())
FAIL("mpx.init failed: " << err);
mpx.set_thread_id();
REQUIRE_EQ(mpx.num_socket_managers(), 1u);
if (auto err = net::nonblocking(fd2, true))
CAF_FAIL("nonblocking returned an error: " << err);
auto mgr = net::make_socket_manager<app<true>, net::length_prefix_framing,
net::stream_transport>(fd2, &mpx);
CHECK_EQ(mgr->init(settings{}), none);
REQUIRE_EQ(mpx.num_socket_managers(), 2u);
CHECK_EQ(mgr->mask(), net::operation::read);
auto& state = mgr->top_layer();
WHEN("the app calls suspend_reading") {
while (mpx.num_socket_managers() > 1u)
mpx.poll_once(true);
CHECK_EQ(mgr->mask(), net::operation::none);
if (CHECK_EQ(state.inputs.size(), 3u)) {
CHECK_EQ(state.inputs[0], "first");
CHECK_EQ(state.inputs[1], "second");
CHECK_EQ(state.inputs[2], "pause");
}
THEN("users can resume it via register_reading ") {
mpx.register_reading(mgr);
CHECK_EQ(mgr->mask(), net::operation::read);
//mgr->register_reading();
while (mpx.num_socket_managers() > 1u)
mpx.poll_once(true);
if (CHECK_EQ(state.inputs.size(), 5u)) {
CHECK_EQ(state.inputs[0], "first");
CHECK_EQ(state.inputs[1], "second");
CHECK_EQ(state.inputs[2], "pause");
CHECK_EQ(state.inputs[3], "third");
CHECK_EQ(state.inputs[4], "fourth");
}
}
}
writer.join();
}
}
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
// the main distribution directory for license terms and copyright or visit // the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE. // https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#define CAF_SUITE net.subscriber_adapter #define CAF_SUITE net.observer_adapter
#include "caf/net/subscriber_adapter.hpp" #include "caf/net/observer_adapter.hpp"
#include "net-test.hpp" #include "net-test.hpp"
#include "caf/flow/async/publisher.hpp" #include "caf/async/publisher.hpp"
#include "caf/net/middleman.hpp" #include "caf/net/middleman.hpp"
#include "caf/net/socket_guard.hpp" #include "caf/net/socket_guard.hpp"
#include "caf/net/stream_socket.hpp" #include "caf/net/stream_socket.hpp"
...@@ -21,8 +21,6 @@ namespace { ...@@ -21,8 +21,6 @@ namespace {
class reader { class reader {
public: public:
reader(net::stream_socket fd, size_t n) : sg_(fd) { reader(net::stream_socket fd, size_t n) : sg_(fd) {
if (auto err = nonblocking(fd, true))
FAIL("unable to set nonblocking flag: " << err);
buf_.resize(n); buf_.resize(n);
} }
...@@ -60,14 +58,14 @@ class app_t { ...@@ -60,14 +58,14 @@ class app_t {
public: public:
using input_tag = tag::stream_oriented; using input_tag = tag::stream_oriented;
explicit app_t(flow::async::publisher_ptr<int32_t> input) : input_(input) { explicit app_t(async::publisher<int32_t> input) : input_(std::move(input)) {
// nop // nop
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
error init(net::socket_manager* owner, LowerLayerPtr, const settings&) { error init(net::socket_manager* owner, LowerLayerPtr, const settings&) {
adapter_ = make_counted<net::subscriber_adapter<int32_t>>(owner); adapter_ = make_counted<net::observer_adapter<int32_t>>(owner);
input_->async_subscribe(adapter_); input_.subscribe(adapter_->as_observer());
input_ = nullptr; input_ = nullptr;
return none; return none;
} }
...@@ -126,8 +124,8 @@ private: ...@@ -126,8 +124,8 @@ private:
bool done_ = false; bool done_ = false;
std::vector<int32_t> written_values_; std::vector<int32_t> written_values_;
std::vector<byte> written_bytes_; std::vector<byte> written_bytes_;
net::subscriber_adapter_ptr<int32_t> adapter_; net::observer_adapter_ptr<int32_t> adapter_;
flow::async::publisher_ptr<int32_t> input_; async::publisher<int32_t> input_;
}; };
struct fixture : test_coordinator_fixture<>, host_fixture { struct fixture : test_coordinator_fixture<>, host_fixture {
...@@ -151,8 +149,8 @@ BEGIN_FIXTURE_SCOPE(fixture) ...@@ -151,8 +149,8 @@ BEGIN_FIXTURE_SCOPE(fixture)
SCENARIO("subscriber adapters wake up idle socket managers") { SCENARIO("subscriber adapters wake up idle socket managers") {
GIVEN("a publisher<T>") { GIVEN("a publisher<T>") {
static constexpr size_t num_items = 4211; static constexpr size_t num_items = 4211;
auto src = flow::async::publisher_from(sys, [](auto* self) { auto src = async::publisher_from<event_based_actor>(sys, [](auto* self) {
return self->make_publisher()->repeat(42)->take(num_items); return self->make_observable().repeat(42).take(num_items);
}); });
WHEN("sending items of the stream over a socket") { WHEN("sending items of the stream over a socket") {
auto [fd1, fd2] = unbox(net::make_stream_socket_pair()); auto [fd1, fd2] = unbox(net::make_stream_socket_pair());
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#define CAF_SUITE net.publisher_adapter
#include "caf/net/publisher_adapter.hpp"
#include "net-test.hpp"
#include "caf/async/publisher.hpp"
#include "caf/detail/network_order.hpp"
#include "caf/net/length_prefix_framing.hpp"
#include "caf/net/middleman.hpp"
#include "caf/net/socket_guard.hpp"
#include "caf/net/stream_socket.hpp"
#include "caf/scheduled_actor/flow.hpp"
#include "caf/tag/message_oriented.hpp"
using namespace caf;
namespace {
class writer {
public:
explicit writer(net::stream_socket fd) : sg_(fd) {
// nop
}
auto fd() {
return sg_.socket();
}
byte_buffer encode(string_view msg) {
using detail::to_network_order;
auto prefix = to_network_order(static_cast<uint32_t>(msg.size()));
auto prefix_bytes = as_bytes(make_span(&prefix, 1));
byte_buffer buf;
buf.insert(buf.end(), prefix_bytes.begin(), prefix_bytes.end());
auto bytes = as_bytes(make_span(msg));
buf.insert(buf.end(), bytes.begin(), bytes.end());
return buf;
}
void write(string_view msg) {
auto buf = encode(msg);
if (net::write(fd(), buf) < 0)
FAIL("failed to write: " << net::last_socket_error_as_string());
}
private:
net::socket_guard<net::stream_socket> sg_;
};
class app {
public:
using input_tag = tag::message_oriented;
template <class LowerLayerPtr>
error init(net::socket_manager* owner, LowerLayerPtr, const settings&) {
adapter = make_counted<net::publisher_adapter<int32_t>>(owner, 3, 2);
return none;
}
template <class LowerLayerPtr>
bool prepare_send(LowerLayerPtr) {
return true;
}
template <class LowerLayerPtr>
bool done_sending(LowerLayerPtr) {
return true;
}
template <class LowerLayerPtr>
void abort(LowerLayerPtr, const error& reason) {
adapter->flush();
if (reason == caf::sec::socket_disconnected)
adapter->on_complete();
else
adapter->on_error(reason);
}
template <class LowerLayerPtr>
void after_reading(LowerLayerPtr) {
adapter->flush();
}
template <class LowerLayerPtr>
ptrdiff_t consume(LowerLayerPtr down, byte_span buf) {
auto val = int32_t{0};
auto str = string_view{reinterpret_cast<char*>(buf.data()), buf.size()};
if (auto err = detail::parse(str, val))
FAIL("unable to parse input: " << err);
++received_messages;
if (auto n = adapter->push(val); n == 0)
down->suspend_reading();
return static_cast<ptrdiff_t>(buf.size());
}
size_t received_messages = 0;
net::publisher_adapter_ptr<int32_t> adapter;
};
struct mock_observer : flow::observer<int32_t>::impl {
void dispose() {
if (sub) {
sub.cancel();
sub = nullptr;
}
done = true;
}
bool disposed() const noexcept {
return done;
}
void on_complete() {
sub = nullptr;
done = true;
}
void on_error(const error& what) {
FAIL("observer received an error: " << what);
}
void on_attach(flow::subscription new_sub) {
REQUIRE(!sub);
sub = std::move(new_sub);
}
void on_next(span<const int32_t> items) {
buf.insert(buf.end(), items.begin(), items.end());
}
bool done = false;
flow::subscription sub;
std::vector<int32_t> buf;
};
struct fixture {
};
} // namespace
CAF_TEST_FIXTURE_SCOPE(publisher_adapter_tests, fixture)
SCENARIO("publisher adapters suspend reads if the buffer becomes full") {
auto ls = [](auto... xs) { return std::vector<int32_t>{xs...}; };
GIVEN("a writer and a message-based application") {
auto [fd1, fd2] = unbox(net::make_stream_socket_pair());
auto writer_thread = std::thread{[fd1{fd1}] {
writer out{fd1};
for (int i = 0; i < 12; ++i)
out.write(std::to_string(i));
}};
net::multiplexer mpx{nullptr};
if (auto err = mpx.init())
FAIL("mpx.init failed: " << err);
mpx.set_thread_id();
REQUIRE_EQ(mpx.num_socket_managers(), 1u);
if (auto err = net::nonblocking(fd2, true))
CAF_FAIL("nonblocking returned an error: " << err);
auto mgr = net::make_socket_manager<app, net::length_prefix_framing,
net::stream_transport>(fd2, &mpx);
auto& st = mgr->top_layer();
CHECK_EQ(mgr->init(settings{}), none);
REQUIRE_EQ(mpx.num_socket_managers(), 2u);
CHECK_EQ(mgr->mask(), net::operation::read);
WHEN("the publisher adapter runs out of capacity") {
while (mpx.num_socket_managers() > 1u)
mpx.poll_once(true);
CHECK_EQ(mgr->mask(), net::operation::none);
CHECK_EQ(st.received_messages, 3u);
THEN("reading from the adapter registers the manager for reading again") {
auto obs = make_counted<mock_observer>();
st.adapter->subscribe(flow::observer<int32_t>{obs});
REQUIRE(obs->sub.valid());
obs->sub.request(1);
while (st.received_messages != 4u)
mpx.poll_once(true);
CHECK_EQ(obs->buf, ls(0));
obs->sub.request(20);
while (st.received_messages != 12u)
mpx.poll_once(true);
CHECK_EQ(obs->buf, ls(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11));
}
}
writer_thread.join();
}
}
CAF_TEST_FIXTURE_SCOPE_END()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment