Commit 300e68f7 authored by Dominik Charousset's avatar Dominik Charousset

Re-implement WebSocket with protocol switching

parent b5eea2ee
...@@ -18,16 +18,15 @@ ...@@ -18,16 +18,15 @@
namespace caf::detail { namespace caf::detail {
/// Convenience alias for referring to the base type of @ref flow_bridge. /// Convenience alias for referring to the base type of @ref flow_bridge.
template <class Trait, class Role> template <class Trait, class Base>
using ws_flow_bridge_base_t using ws_flow_bridge_base_t
= detail::flow_bridge_base<typename Role::upper_layer, = detail::flow_bridge_base<Base, net::web_socket::lower_layer, Trait>;
net::web_socket::lower_layer, Trait>;
/// Translates between a message-oriented transport and data flows. /// Translates between a message-oriented transport and data flows.
template <class Trait, class Role> template <class Trait, class Base>
class ws_flow_bridge : public ws_flow_bridge_base_t<Trait, Role> { class ws_flow_bridge : public ws_flow_bridge_base_t<Trait, Base> {
public: public:
using super = ws_flow_bridge_base_t<Trait, Role>; using super = ws_flow_bridge_base_t<Trait, Base>;
using input_type = typename Trait::input_type; using input_type = typename Trait::input_type;
......
...@@ -38,6 +38,11 @@ public: ...@@ -38,6 +38,11 @@ public:
/// Prepares written data for transfer, e.g., by flushing buffers or /// Prepares written data for transfer, e.g., by flushing buffers or
/// registering sockets for write events. /// registering sockets for write events.
virtual bool end_output() = 0; virtual bool end_output() = 0;
/// Asks the stream to swap the current upper layer with `next` after
/// returning from `consume()`.
/// @note may only be called from the upper layer in `consume`.
virtual void switch_protocol(std::unique_ptr<upper_layer> next) = 0;
}; };
} // namespace caf::net::octet_stream } // namespace caf::net::octet_stream
...@@ -81,6 +81,8 @@ public: ...@@ -81,6 +81,8 @@ public:
void shutdown() override; void shutdown() override;
void switch_protocol(upper_layer_ptr) override;
// -- properties ------------------------------------------------------------- // -- properties -------------------------------------------------------------
auto& read_buffer() noexcept { auto& read_buffer() noexcept {
...@@ -188,6 +190,10 @@ protected: ...@@ -188,6 +190,10 @@ protected:
/// Fallback policy. /// Fallback policy.
policy default_policy_; policy default_policy_;
/// Setting this to non-null informs the transport to replace `up_` with
/// `next_`.
upper_layer_ptr next_;
// TODO: add [[no_unique_address]] to default_policy_ when switching to C++20. // TODO: add [[no_unique_address]] to default_policy_ when switching to C++20.
}; };
......
...@@ -31,23 +31,16 @@ class CAF_NET_EXPORT client : public octet_stream::upper_layer { ...@@ -31,23 +31,16 @@ class CAF_NET_EXPORT client : public octet_stream::upper_layer {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
class CAF_NET_EXPORT upper_layer : public web_socket::upper_layer {
public:
virtual ~upper_layer();
/// Initializes the upper layer.
/// @param down A pointer to the lower layer that remains valid for the
/// lifetime of the upper layer.
virtual error start(lower_layer* down) = 0;
};
using handshake_ptr = std::unique_ptr<handshake>; using handshake_ptr = std::unique_ptr<handshake>;
using upper_layer_ptr = std::unique_ptr<upper_layer>; using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer::client>;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
client(handshake_ptr hs, upper_layer_ptr up); client(handshake_ptr hs, upper_layer_ptr up)
: hs_(std::move(hs)), up_(std::move(up)) {
// nop
}
// -- factories -------------------------------------------------------------- // -- factories --------------------------------------------------------------
...@@ -57,32 +50,6 @@ public: ...@@ -57,32 +50,6 @@ public:
return make(std::make_unique<handshake>(std::move(hs)), std::move(up)); return make(std::make_unique<handshake>(std::move(hs)), std::move(up));
} }
// -- properties -------------------------------------------------------------
client::upper_layer& up() noexcept {
// This cast is safe, because we know that we have initialized the framing
// layer with a pointer to an web_socket::client::upper_layer object that
// the framing then upcasts to web_socket::upper_layer.
return static_cast<client::upper_layer&>(framing_.up());
}
const client::upper_layer& up() const noexcept {
// See comment in the other up() overload.
return static_cast<const client::upper_layer&>(framing_.up());
}
octet_stream::lower_layer& down() noexcept {
return framing_.down();
}
const octet_stream::lower_layer& down() const noexcept {
return framing_.down();
}
bool handshake_completed() const noexcept {
return hs_ == nullptr;
}
// -- implementation of octet_stream::upper_layer ---------------------------- // -- implementation of octet_stream::upper_layer ----------------------------
error start(octet_stream::lower_layer* down) override; error start(octet_stream::lower_layer* down) override;
...@@ -102,13 +69,14 @@ private: ...@@ -102,13 +69,14 @@ private:
// -- member variables ------------------------------------------------------- // -- member variables -------------------------------------------------------
/// Points to the transport layer below.
octet_stream::lower_layer* down_;
/// Stores the WebSocket handshake data until the handshake completed. /// Stores the WebSocket handshake data until the handshake completed.
handshake_ptr hs_; handshake_ptr hs_;
/// Stores the upper layer. /// Next layer in the processing chain.
framing framing_; upper_layer_ptr up_;
settings cfg_;
}; };
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -26,9 +26,9 @@ namespace caf::detail { ...@@ -26,9 +26,9 @@ namespace caf::detail {
/// Specializes the WebSocket flow bridge for the server side. /// Specializes the WebSocket flow bridge for the server side.
template <class Trait, class... Ts> template <class Trait, class... Ts>
class ws_client_flow_bridge class ws_client_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::client> { : public ws_flow_bridge<Trait, net::web_socket::upper_layer::client> {
public: public:
using super = ws_flow_bridge<Trait, net::web_socket::client>; using super = ws_flow_bridge<Trait, net::web_socket::upper_layer::client>;
// We consume the output type of the application. // We consume the output type of the application.
using pull_t = async::consumer_resource<typename Trait::input_type>; using pull_t = async::consumer_resource<typename Trait::input_type>;
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "caf/detail/rfc6455.hpp" #include "caf/detail/rfc6455.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/octet_stream/lower_layer.hpp" #include "caf/net/octet_stream/lower_layer.hpp"
#include "caf/net/octet_stream/upper_layer.hpp"
#include "caf/net/receive_policy.hpp" #include "caf/net/receive_policy.hpp"
#include "caf/net/web_socket/lower_layer.hpp" #include "caf/net/web_socket/lower_layer.hpp"
#include "caf/net/web_socket/status.hpp" #include "caf/net/web_socket/status.hpp"
...@@ -25,7 +26,8 @@ ...@@ -25,7 +26,8 @@
namespace caf::net::web_socket { namespace caf::net::web_socket {
/// Implements the WebSocket framing protocol as defined in RFC-6455. /// Implements the WebSocket framing protocol as defined in RFC-6455.
class CAF_NET_EXPORT framing : public web_socket::lower_layer { class CAF_NET_EXPORT framing : public octet_stream::upper_layer,
public web_socket::lower_layer {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
...@@ -33,6 +35,10 @@ public: ...@@ -33,6 +35,10 @@ public:
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>; using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>;
using server_ptr = std::unique_ptr<web_socket::upper_layer::server>;
using client_ptr = std::unique_ptr<web_socket::upper_layer::client>;
// -- constants -------------------------------------------------------------- // -- constants --------------------------------------------------------------
/// Restricts the size of received frames (including header). /// Restricts the size of received frames (including header).
...@@ -43,13 +49,14 @@ public: ...@@ -43,13 +49,14 @@ public:
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
explicit framing(upper_layer_ptr up) : up_(std::move(up)) { static std::unique_ptr<framing> make(client_ptr up) {
// nop return std::unique_ptr<framing>{new framing(std::move(up))};
} }
// -- initialization --------------------------------------------------------- static std::unique_ptr<framing> make(server_ptr up,
http::request_header hdr) {
void start(octet_stream::lower_layer* down); return std::unique_ptr<framing>{new framing(std::move(up), std::move(hdr))};
}
// -- properties ------------------------------------------------------------- // -- properties -------------------------------------------------------------
...@@ -75,6 +82,18 @@ public: ...@@ -75,6 +82,18 @@ public:
/// the standard. /// the standard.
bool mask_outgoing_frames = true; bool mask_outgoing_frames = true;
// -- octet_stream::upper_layer implementation -------------------------------
error start(octet_stream::lower_layer* down) override;
void abort(const error& reason) override;
ptrdiff_t consume(byte_span input, byte_span) override;
void prepare_send() override;
bool done_sending() override;
// -- web_socket::lower_layer implementation --------------------------------- // -- web_socket::lower_layer implementation ---------------------------------
using web_socket::lower_layer::shutdown; using web_socket::lower_layer::shutdown;
...@@ -105,13 +124,20 @@ public: ...@@ -105,13 +124,20 @@ public:
bool end_text_message() override; bool end_text_message() override;
// -- interface for the lower layer ------------------------------------------
ptrdiff_t consume(byte_span input, byte_span);
private: private:
// -- implementation details ------------------------------------------------- // -- implementation details -------------------------------------------------
explicit framing(client_ptr up) : up_(std::move(up)) {
// nop
}
explicit framing(server_ptr up, http::request_header&& hdr)
: up_(std::move(up)), hdr_(std::move(hdr)) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
mask_outgoing_frames = false;
}
bool handle(uint8_t opcode, byte_span payload); bool handle(uint8_t opcode, byte_span payload);
void ship_pong(byte_span payload); void ship_pong(byte_span payload);
...@@ -141,6 +167,9 @@ private: ...@@ -141,6 +167,9 @@ private:
/// Next layer in the processing chain. /// Next layer in the processing chain.
upper_layer_ptr up_; upper_layer_ptr up_;
/// Stored when running as a server and passed to `up_` in start.
std::optional<http::request_header> hdr_;
}; };
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -39,51 +39,18 @@ class CAF_NET_EXPORT server : public octet_stream::upper_layer { ...@@ -39,51 +39,18 @@ class CAF_NET_EXPORT server : public octet_stream::upper_layer {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
class CAF_NET_EXPORT upper_layer : public web_socket::upper_layer { using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer::server>;
public:
virtual ~upper_layer();
/// Initializes the upper layer.
/// @param down A pointer to the lower layer that remains valid for the
/// lifetime of the upper layer.
/// @param hdr The HTTP request header from the client handshake.
virtual error start(lower_layer* down, const http::request_header& hdr) = 0;
};
using upper_layer_ptr = std::unique_ptr<upper_layer>;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
explicit server(upper_layer_ptr up); explicit server(upper_layer_ptr up) : up_(std::move(up)) {
// nop
// -- factories --------------------------------------------------------------
static std::unique_ptr<server> make(upper_layer_ptr up);
// -- properties -------------------------------------------------------------
server::upper_layer& up() noexcept {
// This cast is safe, because we know that we have initialized the framing
// layer with a pointer to an web_socket::server::upper_layer object that
// the framing then upcasts to web_socket::upper_layer.
return static_cast<server::upper_layer&>(framing_.up());
} }
const server::upper_layer& up() const noexcept { // -- factories --------------------------------------------------------------
// See comment in the other up() overload.
return static_cast<const server::upper_layer&>(framing_.up());
}
octet_stream::lower_layer& down() noexcept {
return framing_.down();
}
const octet_stream::lower_layer& down() const noexcept {
return framing_.down();
}
bool handshake_complete() const noexcept { static std::unique_ptr<server> make(upper_layer_ptr up) {
return handshake_complete_; return std::make_unique<server>(std::move(up));
} }
// -- octet_stream::upper_layer implementation ------------------------------- // -- octet_stream::upper_layer implementation -------------------------------
...@@ -105,11 +72,11 @@ private: ...@@ -105,11 +72,11 @@ private:
bool handle_header(std::string_view http); bool handle_header(std::string_view http);
/// Stores whether the WebSocket handshake completed successfully. /// Points to the transport layer below.
bool handshake_complete_ = false; octet_stream::lower_layer* down_;
/// Stores the upper layer. /// We store this only to pass it to the framing layer after the handshake.
framing framing_; upper_layer_ptr up_;
}; };
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -29,9 +29,9 @@ namespace caf::detail { ...@@ -29,9 +29,9 @@ namespace caf::detail {
/// Specializes the WebSocket flow bridge for the server side. /// Specializes the WebSocket flow bridge for the server side.
template <class Trait, class... Ts> template <class Trait, class... Ts>
class ws_server_flow_bridge class ws_server_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::server> { : public ws_flow_bridge<Trait, net::web_socket::upper_layer::server> {
public: public:
using super = ws_flow_bridge<Trait, net::web_socket::server>; using super = ws_flow_bridge<Trait, net::web_socket::upper_layer::server>;
using ws_acceptor_t = net::web_socket::acceptor<Ts...>; using ws_acceptor_t = net::web_socket::acceptor<Ts...>;
......
...@@ -16,6 +16,10 @@ namespace caf::net::web_socket { ...@@ -16,6 +16,10 @@ namespace caf::net::web_socket {
/// a server or a client. /// a server or a client.
class CAF_NET_EXPORT upper_layer : public generic_upper_layer { class CAF_NET_EXPORT upper_layer : public generic_upper_layer {
public: public:
class server;
class client;
virtual ~upper_layer(); virtual ~upper_layer();
virtual ptrdiff_t consume_binary(byte_span buf) = 0; virtual ptrdiff_t consume_binary(byte_span buf) = 0;
...@@ -23,4 +27,16 @@ public: ...@@ -23,4 +27,16 @@ public:
virtual ptrdiff_t consume_text(std::string_view buf) = 0; virtual ptrdiff_t consume_text(std::string_view buf) = 0;
}; };
class upper_layer::server : public upper_layer {
public:
virtual ~server();
virtual error start(lower_layer* down, const http::request_header& hdr) = 0;
};
class upper_layer::client : public upper_layer {
public:
virtual ~client();
virtual error start(lower_layer* down) = 0;
};
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -94,6 +94,10 @@ void transport::shutdown() { ...@@ -94,6 +94,10 @@ void transport::shutdown() {
} }
} }
void transport::switch_protocol(upper_layer_ptr next) {
next_ = std::move(next);
}
// -- implementation of transport ---------------------------------------------- // -- implementation of transport ----------------------------------------------
error transport::start(socket_manager* owner) { error transport::start(socket_manager* owner) {
...@@ -186,6 +190,19 @@ void transport::handle_buffered_data() { ...@@ -186,6 +190,19 @@ void transport::handle_buffered_data() {
CAF_LOG_TRACE(CAF_ARG(buffered_)); CAF_LOG_TRACE(CAF_ARG(buffered_));
// Loop until we have drained the buffer as much as we can. // Loop until we have drained the buffer as much as we can.
CAF_ASSERT(min_read_size_ <= max_read_size_); CAF_ASSERT(min_read_size_ <= max_read_size_);
auto switch_to_next_protocol = [this] {
assert(next_);
// Switch to the new protocol and initialize it.
configure_read(receive_policy::stop());
up_.reset(next_.release());
if (auto err = up_->start(this)) {
up_.reset();
parent_->deregister();
parent_->shutdown();
return false;
}
return true;
};
while (parent_->is_reading() && max_read_size_ > 0 while (parent_->is_reading() && max_read_size_ > 0
&& buffered_ >= min_read_size_) { && buffered_ >= min_read_size_) {
auto n = std::min(buffered_, size_t{max_read_size_}); auto n = std::min(buffered_, size_t{max_read_size_});
...@@ -205,16 +222,23 @@ void transport::handle_buffered_data() { ...@@ -205,16 +222,23 @@ void transport::handle_buffered_data() {
parent_->deregister(); parent_->deregister();
return; return;
} else if (consumed == 0) { } else if (consumed == 0) {
// See whether the next iteration would change what we pass to the if (next_) {
// application (max_read_size_ may have changed). Otherwise, we'll try // When switching protocol, the new layer has never seen the data, so we
// again later. // might just re-invoke the same data again.
delta_offset_ = static_cast<ptrdiff_t>(n); if (!switch_to_next_protocol())
if (n == std::min(buffered_, size_t{max_read_size_})) { return;
return;
} else { } else {
// "Fall through". // See whether the next iteration would change what we pass to the
// application (max_read_size_ may have changed). Otherwise, we'll try
// again later.
delta_offset_ = static_cast<ptrdiff_t>(n);
if (n == std::min(buffered_, size_t{max_read_size_}))
return;
// else: "Fall through".
} }
} else { } else {
if (next_ && !switch_to_next_protocol())
return;
// Shove the unread bytes to the beginning of the buffer and continue // Shove the unread bytes to the beginning of the buffer and continue
// to the next loop iteration. // to the next loop iteration.
auto del = static_cast<size_t>(consumed); auto del = static_cast<size_t>(consumed);
......
...@@ -22,19 +22,6 @@ ...@@ -22,19 +22,6 @@
namespace caf::net::web_socket { namespace caf::net::web_socket {
// -- member types -------------------------------------------------------------
client::upper_layer::~upper_layer() {
// nop
}
// -- constructors, destructors, and assignment operators ----------------------
client::client(handshake_ptr hs, upper_layer_ptr up_ptr)
: hs_(std::move(hs)), framing_(std::move(up_ptr)) {
// nop
}
// -- factories ---------------------------------------------------------------- // -- factories ----------------------------------------------------------------
std::unique_ptr<client> client::make(handshake_ptr hs, upper_layer_ptr up_ptr) { std::unique_ptr<client> client::make(handshake_ptr hs, upper_layer_ptr up_ptr) {
...@@ -43,67 +30,56 @@ std::unique_ptr<client> client::make(handshake_ptr hs, upper_layer_ptr up_ptr) { ...@@ -43,67 +30,56 @@ std::unique_ptr<client> client::make(handshake_ptr hs, upper_layer_ptr up_ptr) {
// -- implementation of octet_stream::upper_layer ------------------------------ // -- implementation of octet_stream::upper_layer ------------------------------
error client::start(octet_stream::lower_layer* down_ptr) { error client::start(octet_stream::lower_layer* down) {
CAF_ASSERT(hs_ != nullptr); CAF_ASSERT(hs_ != nullptr);
framing_.start(down_ptr);
if (!hs_->has_mandatory_fields()) if (!hs_->has_mandatory_fields())
return make_error(sec::runtime_error, return make_error(sec::runtime_error,
"handshake data lacks mandatory fields"); "WebSocket client received an incomplete handshake");
if (!hs_->has_valid_key()) if (!hs_->has_valid_key())
hs_->randomize_key(); hs_->randomize_key();
down().begin_output(); down_ = down;
hs_->write_http_1_request(down().output_buffer()); down_->begin_output();
down().end_output(); hs_->write_http_1_request(down_->output_buffer());
down().configure_read(receive_policy::up_to(handshake::max_http_size)); down_->end_output();
down_->configure_read(receive_policy::up_to(handshake::max_http_size));
return none; return none;
} }
void client::abort(const error& reason) { void client::abort(const error& reason) {
up().abort(reason); assert(up_);
up_->abort(reason);
} }
ptrdiff_t client::consume(byte_span buffer, byte_span delta) { ptrdiff_t client::consume(byte_span buffer, byte_span delta) {
CAF_LOG_TRACE(CAF_ARG2("buffer", buffer.size())); CAF_LOG_TRACE(CAF_ARG2("buffer", buffer.size()));
if (handshake_completed()) { // Check whether we have received the HTTP header or else wait for more
// Short circuit to the framing layer after the handshake completed. // data. Abort when exceeding the maximum size.
return framing_.consume(buffer, delta); auto [hdr, remainder] = http::v1::split_header(buffer);
} else { if (hdr.empty()) {
// Check whether we have received the HTTP header or else wait for more if (buffer.size() >= handshake::max_http_size) {
// data. Abort when exceeding the maximum size. CAF_LOG_ERROR("server response exceeded the maximum header size");
auto [hdr, remainder] = http::v1::split_header(buffer); up_->abort(make_error(sec::protocol_error, "server response exceeded "
if (hdr.empty()) { "the maximum header size"));
if (buffer.size() >= handshake::max_http_size) {
CAF_LOG_ERROR("server response exceeded the maximum header size");
up().abort(make_error(sec::protocol_error, "server response exceeded "
"the maximum header size"));
return -1;
} else {
return 0;
}
} else if (!handle_header(hdr)) {
// Note: handle_header() already calls upper_layer().abort().
return -1; return -1;
} else if (remainder.empty()) {
CAF_ASSERT(hdr.size() == buffer.size());
return hdr.size();
} else {
CAF_LOG_DEBUG(CAF_ARG2("remainder.size", remainder.size()));
if (auto res = framing_.consume(remainder, remainder); res >= 0) {
return hdr.size() + res;
} else {
return res;
}
} }
// Wait for more data.
return 0;
}
if (!handle_header(hdr)) {
// Note: handle_header() already calls upper_layer().abort().
return -1;
} }
// We only care about the header here. The framing layer is responsible for
// any remaining data.
return static_cast<ptrdiff_t>(hdr.size());
} }
void client::prepare_send() { void client::prepare_send() {
if (handshake_completed()) // nop
up().prepare_send();
} }
bool client::done_sending() { bool client::done_sending() {
return handshake_completed() ? up().done_sending() : true; return true;
} }
// -- HTTP response processing ------------------------------------------------- // -- HTTP response processing -------------------------------------------------
...@@ -113,19 +89,13 @@ bool client::handle_header(std::string_view http) { ...@@ -113,19 +89,13 @@ bool client::handle_header(std::string_view http) {
auto http_ok = hs_->is_valid_http_1_response(http); auto http_ok = hs_->is_valid_http_1_response(http);
hs_.reset(); hs_.reset();
if (http_ok) { if (http_ok) {
if (auto err = up().start(&framing_)) { down_->switch_protocol(framing::make(std::move(up_)));
CAF_LOG_DEBUG("failed to initialize WebSocket framing layer"); return true;
return false;
} else {
CAF_LOG_DEBUG("completed WebSocket handshake");
return true;
}
} else {
CAF_LOG_DEBUG("received an invalid WebSocket handshake");
up().abort(make_error(sec::protocol_error,
"received an invalid WebSocket handshake"));
return false;
} }
CAF_LOG_DEBUG("received an invalid WebSocket handshake");
up_->abort(
make_error(sec::protocol_error, "received an invalid WebSocket handshake"));
return false;
} }
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -5,92 +5,38 @@ ...@@ -5,92 +5,38 @@
#include "caf/net/web_socket/framing.hpp" #include "caf/net/web_socket/framing.hpp"
#include "caf/logger.hpp" #include "caf/logger.hpp"
#include "caf/net/http/v1.hpp"
namespace caf::net::web_socket { namespace caf::net::web_socket {
// -- initialization --------------------------------------------------------- // -- octet_stream::upper_layer implementation ---------------------------------
void framing::start(octet_stream::lower_layer* down) { error framing::start(octet_stream::lower_layer* down) {
std::random_device rd; std::random_device rd;
rng_.seed(rd()); rng_.seed(rd());
down_ = down; down_ = down;
} if (!hdr_) {
using dptr_t = web_socket::upper_layer::client;
// -- web_socket::lower_layer implementation ----------------------------------- return static_cast<dptr_t*>(up_.get())->start(this);
multiplexer& framing::mpx() noexcept {
return down_->mpx();
}
bool framing::can_send_more() const noexcept {
return down_->can_send_more();
}
void framing::suspend_reading() {
down_->configure_read(receive_policy::stop());
}
bool framing::is_reading() const noexcept {
return down_->is_reading();
}
void framing::write_later() {
down_->write_later();
}
void framing::shutdown(status code, std::string_view msg) {
auto code_val = static_cast<uint16_t>(code);
uint32_t mask_key = 0;
byte_buffer payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<std::byte>((code_val & 0xFF00) >> 8));
payload.push_back(static_cast<std::byte>(code_val & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<std::byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
} }
down_->begin_output(); using dptr_t = web_socket::upper_layer::server;
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key, auto err = static_cast<dptr_t*>(up_.get())->start(this, *hdr_);
payload, down_->output_buffer()); hdr_ = std::nullopt;
down_->end_output(); if (err) {
down_->shutdown(); auto descr = to_string(err);
} CAF_LOG_DEBUG("upper layer rejected a WebSocket connection:" << descr);
down_->begin_output();
void framing::request_messages() { http::v1::write_response(http::status::bad_request, "text/plain", descr,
if (!down_->is_reading()) down_->output_buffer());
down_->configure_read(receive_policy::up_to(2048)); down_->end_output();
} }
return err;
void framing::begin_binary_message() {
// nop
}
byte_buffer& framing::binary_message_buffer() {
return binary_buf_;
}
bool framing::end_binary_message() {
ship_frame(binary_buf_);
return true;
}
void framing::begin_text_message() {
// nop
}
text_buffer& framing::text_message_buffer() {
return text_buf_;
} }
bool framing::end_text_message() { void framing::abort(const error& reason) {
ship_frame(text_buf_); up_->abort(reason);
return true;
} }
// -- interface for the lower layer --------------------------------------------
ptrdiff_t framing::consume(byte_span buffer, byte_span) { ptrdiff_t framing::consume(byte_span buffer, byte_span) {
// Make sure we're overriding any 'exactly' setting. // Make sure we're overriding any 'exactly' setting.
down_->configure_read(receive_policy::up_to(2048)); down_->configure_read(receive_policy::up_to(2048));
...@@ -181,6 +127,87 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) { ...@@ -181,6 +127,87 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) {
return static_cast<ptrdiff_t>(frame_size); return static_cast<ptrdiff_t>(frame_size);
} }
void framing::prepare_send() {
up_->prepare_send();
}
bool framing::done_sending() {
return up_->done_sending();
}
// -- web_socket::lower_layer implementation -----------------------------------
multiplexer& framing::mpx() noexcept {
return down_->mpx();
}
bool framing::can_send_more() const noexcept {
return down_->can_send_more();
}
void framing::suspend_reading() {
down_->configure_read(receive_policy::stop());
}
bool framing::is_reading() const noexcept {
return down_->is_reading();
}
void framing::write_later() {
down_->write_later();
}
void framing::shutdown(status code, std::string_view msg) {
auto code_val = static_cast<uint16_t>(code);
uint32_t mask_key = 0;
byte_buffer payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<std::byte>((code_val & 0xFF00) >> 8));
payload.push_back(static_cast<std::byte>(code_val & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<std::byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down_->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down_->output_buffer());
down_->end_output();
down_->shutdown();
}
void framing::request_messages() {
if (!down_->is_reading())
down_->configure_read(receive_policy::up_to(2048));
}
void framing::begin_binary_message() {
// nop
}
byte_buffer& framing::binary_message_buffer() {
return binary_buf_;
}
bool framing::end_binary_message() {
ship_frame(binary_buf_);
return true;
}
void framing::begin_text_message() {
// nop
}
text_buffer& framing::text_message_buffer() {
return text_buf_;
}
bool framing::end_text_message() {
ship_frame(text_buf_);
return true;
}
// -- implementation details --------------------------------------------------- // -- implementation details ---------------------------------------------------
bool framing::handle(uint8_t opcode, byte_span payload) { bool framing::handle(uint8_t opcode, byte_span payload) {
......
...@@ -6,84 +6,56 @@ ...@@ -6,84 +6,56 @@
namespace caf::net::web_socket { namespace caf::net::web_socket {
// -- member types -------------------------------------------------------------
server::upper_layer::~upper_layer() {
// nop
}
// -- constructors, destructors, and assignment operators ----------------------
server::server(upper_layer_ptr up) : framing_(std::move(up)) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
framing_.mask_outgoing_frames = false;
}
// -- factories ----------------------------------------------------------------
std::unique_ptr<server> server::make(upper_layer_ptr up) {
return std::make_unique<server>(std::move(up));
}
// -- octet_stream::upper_layer implementation --------------------------------- // -- octet_stream::upper_layer implementation ---------------------------------
error server::start(octet_stream::lower_layer* down_ptr) { error server::start(octet_stream::lower_layer* down) {
framing_.start(down_ptr); down_ = down;
down().configure_read(receive_policy::up_to(handshake::max_http_size)); down_->configure_read(receive_policy::up_to(handshake::max_http_size));
return none; return none;
} }
void server::abort(const error& reason) { void server::abort(const error& err) {
if (handshake_complete_) up_->abort(err);
up().abort(reason);
} }
ptrdiff_t server::consume(byte_span input, byte_span delta) { ptrdiff_t server::consume(byte_span input, byte_span) {
using namespace std::literals; using namespace std::literals;
CAF_LOG_TRACE(CAF_ARG2("bytes.size", input.size())); CAF_LOG_TRACE(CAF_ARG2("bytes.size", input.size()));
if (handshake_complete_) { // Check whether we received an HTTP header or else wait for more data.
// Short circuit to the framing layer after the handshake completed. // Abort when exceeding the maximum size.
return framing_.consume(input, delta); auto [hdr, remainder] = http::v1::split_header(input);
} else { if (hdr.empty()) {
// Check whether we received an HTTP header or else wait for more data. if (input.size() >= handshake::max_http_size) {
// Abort when exceeding the maximum size. down_->begin_output();
auto [hdr, remainder] = http::v1::split_header(input); http::v1::write_response(http::status::request_header_fields_too_large,
if (hdr.empty()) { "text/plain"sv, "Header exceeds maximum size."sv,
if (input.size() >= handshake::max_http_size) { down_->output_buffer());
down().begin_output(); down_->end_output();
http::v1::write_response(http::status::request_header_fields_too_large,
"text/plain"sv,
"Header exceeds maximum size."sv,
down().output_buffer());
down().end_output();
return -1;
} else {
return 0;
}
} else if (!handle_header(hdr)) {
return -1; return -1;
} else { } else {
return hdr.size(); return 0;
} }
} else if (!handle_header(hdr)) {
return -1;
} else {
return hdr.size();
} }
} }
void server::prepare_send() { void server::prepare_send() {
if (handshake_complete_) // nop
up().prepare_send();
} }
bool server::done_sending() { bool server::done_sending() {
return handshake_complete_ ? up().done_sending() : true; return true;
} }
// -- HTTP request processing ------------------------------------------------ // -- HTTP request processing ------------------------------------------------
void server::write_response(http::status code, std::string_view msg) { void server::write_response(http::status code, std::string_view msg) {
down().begin_output(); down_->begin_output();
http::v1::write_response(code, "text/plain", msg, down().output_buffer()); http::v1::write_response(code, "text/plain", msg, down_->output_buffer());
down().end_output(); down_->end_output();
} }
bool server::handle_header(std::string_view http) { bool server::handle_header(std::string_view http) {
...@@ -108,23 +80,14 @@ bool server::handle_header(std::string_view http) { ...@@ -108,23 +80,14 @@ bool server::handle_header(std::string_view http) {
CAF_LOG_DEBUG("received invalid WebSocket handshake"); CAF_LOG_DEBUG("received invalid WebSocket handshake");
return false; return false;
} }
// Try to initialize the upper layer.
down().configure_read(receive_policy::stop());
if (auto err = up().start(&framing_, hdr)) {
auto descr = to_string(err);
CAF_LOG_DEBUG("upper layer rejected a WebSocket connection:" << descr);
write_response(http::status::bad_request, descr);
return false;
}
// Finalize the WebSocket handshake. // Finalize the WebSocket handshake.
handshake hs; handshake hs;
hs.assign_key(sec_key); hs.assign_key(sec_key);
down().begin_output(); down_->begin_output();
hs.write_http_1_response(down().output_buffer()); hs.write_http_1_response(down_->output_buffer());
down().end_output(); down_->end_output();
CAF_LOG_DEBUG("completed WebSocket handshake"); CAF_LOG_DEBUG("completed WebSocket handshake");
handshake_complete_ = true; down_->switch_protocol(framing::make(std::move(up_), std::move(hdr)));
return true; return true;
} }
......
...@@ -10,4 +10,12 @@ upper_layer::~upper_layer() { ...@@ -10,4 +10,12 @@ upper_layer::~upper_layer() {
// nop // nop
} }
upper_layer::server::~server() {
// nop
}
upper_layer::client::~client() {
// nop
}
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -35,6 +35,10 @@ void mock_stream_transport::shutdown() { ...@@ -35,6 +35,10 @@ void mock_stream_transport::shutdown() {
// nop // nop
} }
void mock_stream_transport::switch_protocol(upper_layer_ptr new_up) {
next.swap(new_up);
}
void mock_stream_transport::configure_read(net::receive_policy policy) { void mock_stream_transport::configure_read(net::receive_policy policy) {
min_read_size = policy.min_size; min_read_size = policy.min_size;
max_read_size = policy.max_size; max_read_size = policy.max_size;
...@@ -54,6 +58,17 @@ bool mock_stream_transport::end_output() { ...@@ -54,6 +58,17 @@ bool mock_stream_transport::end_output() {
ptrdiff_t mock_stream_transport::handle_input() { ptrdiff_t mock_stream_transport::handle_input() {
ptrdiff_t result = 0; ptrdiff_t result = 0;
auto switch_to_next_protocol = [this] {
assert(next);
// Switch to the new protocol and initialize it.
configure_read(net::receive_policy::stop());
up.reset(next.release());
if (auto err = up->start(this)) {
up.reset();
return false;
}
return true;
};
// Loop until we have drained the buffer as much as we can. // Loop until we have drained the buffer as much as we can.
while (max_read_size > 0 && input.size() >= min_read_size) { while (max_read_size > 0 && input.size() >= min_read_size) {
auto n = std::min(input.size(), size_t{max_read_size}); auto n = std::min(input.size(), size_t{max_read_size});
...@@ -71,16 +86,24 @@ ptrdiff_t mock_stream_transport::handle_input() { ...@@ -71,16 +86,24 @@ ptrdiff_t mock_stream_transport::handle_input() {
up->abort(make_error(sec::logic_error, "consumed > buffer.size")); up->abort(make_error(sec::logic_error, "consumed > buffer.size"));
return result; return result;
} else if (consumed == 0) { } else if (consumed == 0) {
// See whether the next iteration would change what we pass to the if (next) {
// application (max_read_size_ may have changed). Otherwise, we'll try // When switching protocol, the new layer has never seen the data, so we
// again later. // might just re-invoke the same data again.
delta_offset = static_cast<ptrdiff_t>(n); if (!switch_to_next_protocol())
if (n == std::min(input.size(), size_t{max_read_size})) { return -1;
return result;
} else { } else {
// "Fall through". // See whether the next iteration would change what we pass to the
// application (max_read_size_ may have changed). Otherwise, we'll try
// again later.
delta_offset = static_cast<ptrdiff_t>(n);
if (n == std::min(input.size(), size_t{max_read_size})) {
return result;
}
// else: "Fall through".
} }
} else { } else {
if (next && !switch_to_next_protocol())
return -1;
// Shove the unread bytes to the beginning of the buffer and continue // Shove the unread bytes to the beginning of the buffer and continue
// to the next loop iteration. // to the next loop iteration.
result += consumed; result += consumed;
......
...@@ -48,6 +48,8 @@ public: ...@@ -48,6 +48,8 @@ public:
bool end_output() override; bool end_output() override;
void switch_protocol(upper_layer_ptr) override;
// -- initialization --------------------------------------------------------- // -- initialization ---------------------------------------------------------
caf::error start(caf::net::multiplexer* ptr) { caf::error start(caf::net::multiplexer* ptr) {
...@@ -81,6 +83,8 @@ public: ...@@ -81,6 +83,8 @@ public:
upper_layer_ptr up; upper_layer_ptr up;
upper_layer_ptr next;
caf::byte_buffer output; caf::byte_buffer output;
caf::byte_buffer input; caf::byte_buffer input;
......
...@@ -14,7 +14,7 @@ namespace { ...@@ -14,7 +14,7 @@ namespace {
using svec = std::vector<std::string>; using svec = std::vector<std::string>;
class app_t : public net::web_socket::client::upper_layer { class app_t : public net::web_socket::upper_layer::client {
public: public:
static auto make() { static auto make() {
return std::make_unique<app_t>(); return std::make_unique<app_t>();
...@@ -95,7 +95,6 @@ SCENARIO("the client performs the WebSocket handshake on startup") { ...@@ -95,7 +95,6 @@ SCENARIO("the client performs the WebSocket handshake on startup") {
WHEN("starting a WebSocket client") { WHEN("starting a WebSocket client") {
auto app = app_t::make(); auto app = app_t::make();
auto ws = net::web_socket::client::make(make_handshake(), std::move(app)); auto ws = net::web_socket::client::make(make_handshake(), std::move(app));
auto& ws_state = *ws;
auto uut = mock_stream_transport::make(std::move(ws)); auto uut = mock_stream_transport::make(std::move(ws));
THEN("the client sends its HTTP request when initializing it") { THEN("the client sends its HTTP request when initializing it") {
CHECK_EQ(uut->start(nullptr), error{}); CHECK_EQ(uut->start(nullptr), error{});
...@@ -105,7 +104,6 @@ SCENARIO("the client performs the WebSocket handshake on startup") { ...@@ -105,7 +104,6 @@ SCENARIO("the client performs the WebSocket handshake on startup") {
uut->push(http_response); uut->push(http_response);
CHECK_EQ(uut->handle_input(), CHECK_EQ(uut->handle_input(),
static_cast<ptrdiff_t>(http_response.size())); static_cast<ptrdiff_t>(http_response.size()));
CHECK(ws_state.handshake_completed());
} }
} }
} }
......
...@@ -20,7 +20,7 @@ namespace { ...@@ -20,7 +20,7 @@ namespace {
using svec = std::vector<std::string>; using svec = std::vector<std::string>;
class app_t : public net::web_socket::server::upper_layer { class app_t : public net::web_socket::upper_layer::server {
public: public:
std::string text_input; std::string text_input;
...@@ -80,7 +80,6 @@ struct fixture { ...@@ -80,7 +80,6 @@ struct fixture {
auto app_ptr = app_t::make(); auto app_ptr = app_t::make();
app = app_ptr.get(); app = app_ptr.get();
auto ws_ptr = net::web_socket::server::make(std::move(app_ptr)); auto ws_ptr = net::web_socket::server::make(std::move(app_ptr));
ws = ws_ptr.get();
transport = mock_stream_transport::make(std::move(ws_ptr)); transport = mock_stream_transport::make(std::move(ws_ptr));
if (auto err = transport->start(nullptr)) if (auto err = transport->start(nullptr))
CAF_FAIL("failed to initialize mock transport: " << err); CAF_FAIL("failed to initialize mock transport: " << err);
...@@ -118,8 +117,6 @@ struct fixture { ...@@ -118,8 +117,6 @@ struct fixture {
std::unique_ptr<mock_stream_transport> transport; std::unique_ptr<mock_stream_transport> transport;
net::web_socket::server* ws;
app_t* app; app_t* app;
std::minstd_rand rng; std::minstd_rand rng;
...@@ -154,7 +151,6 @@ CAF_TEST(applications receive handshake data via config) { ...@@ -154,7 +151,6 @@ CAF_TEST(applications receive handshake data via config) {
} }
CHECK_EQ(transport->input.size(), 0u); CHECK_EQ(transport->input.size(), 0u);
CHECK_EQ(transport->unconsumed(), 0u); CHECK_EQ(transport->unconsumed(), 0u);
CHECK(ws->handshake_complete());
CHECK_SETTING("web-socket.method", "GET"); CHECK_SETTING("web-socket.method", "GET");
CHECK_SETTING("web-socket.path", "/chat"); CHECK_SETTING("web-socket.path", "/chat");
CHECK_SETTING("web-socket.http-version", "HTTP/1.1"); CHECK_SETTING("web-socket.http-version", "HTTP/1.1");
...@@ -175,7 +171,6 @@ CAF_TEST(the server responds with an HTTP response on success) { ...@@ -175,7 +171,6 @@ CAF_TEST(the server responds with an HTTP response on success) {
transport->push(opening_handshake); transport->push(opening_handshake);
CHECK_EQ(transport->handle_input(), CHECK_EQ(transport->handle_input(),
static_cast<ptrdiff_t>(opening_handshake.size())); static_cast<ptrdiff_t>(opening_handshake.size()));
CHECK(ws->handshake_complete());
CHECK_EQ(transport->output_as_str(), CHECK_EQ(transport->output_as_str(),
"HTTP/1.1 101 Switching Protocols\r\n" "HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n" "Upgrade: websocket\r\n"
...@@ -194,14 +189,11 @@ CAF_TEST(handshakes may arrive in chunks) { ...@@ -194,14 +189,11 @@ CAF_TEST(handshakes may arrive in chunks) {
bufs.emplace_back(i, opening_handshake.end()); bufs.emplace_back(i, opening_handshake.end());
transport->push(bufs[0]); transport->push(bufs[0]);
CHECK_EQ(transport->handle_input(), 0u); CHECK_EQ(transport->handle_input(), 0u);
CHECK(!ws->handshake_complete());
transport->push(bufs[1]); transport->push(bufs[1]);
CHECK_EQ(transport->handle_input(), 0u); CHECK_EQ(transport->handle_input(), 0u);
CHECK(!ws->handshake_complete());
transport->push(bufs[2]); transport->push(bufs[2]);
CHECK_EQ(static_cast<size_t>(transport->handle_input()), CHECK_EQ(static_cast<size_t>(transport->handle_input()),
opening_handshake.size()); opening_handshake.size());
CHECK(ws->handshake_complete());
} }
CAF_TEST(data may follow the handshake immediately) { CAF_TEST(data may follow the handshake immediately) {
...@@ -211,7 +203,6 @@ CAF_TEST(data may follow the handshake immediately) { ...@@ -211,7 +203,6 @@ CAF_TEST(data may follow the handshake immediately) {
rfc6455_append("Bye WebSocket!\n"sv, buf); rfc6455_append("Bye WebSocket!\n"sv, buf);
transport->push(buf); transport->push(buf);
CHECK_EQ(transport->handle_input(), static_cast<ptrdiff_t>(buf.size())); CHECK_EQ(transport->handle_input(), static_cast<ptrdiff_t>(buf.size()));
CHECK(ws->handshake_complete());
CHECK_EQ(app->text_input, "Hello WebSocket!\nBye WebSocket!\n"); CHECK_EQ(app->text_input, "Hello WebSocket!\nBye WebSocket!\n");
} }
...@@ -219,7 +210,6 @@ CAF_TEST(data may arrive later) { ...@@ -219,7 +210,6 @@ CAF_TEST(data may arrive later) {
transport->push(opening_handshake); transport->push(opening_handshake);
CHECK_EQ(transport->handle_input(), CHECK_EQ(transport->handle_input(),
static_cast<ptrdiff_t>(opening_handshake.size())); static_cast<ptrdiff_t>(opening_handshake.size()));
CHECK(ws->handshake_complete());
push("Hello WebSocket!\nBye WebSocket!\n"sv); push("Hello WebSocket!\nBye WebSocket!\n"sv);
transport->handle_input(); transport->handle_input();
CHECK_EQ(app->text_input, "Hello WebSocket!\nBye WebSocket!\n"); CHECK_EQ(app->text_input, "Hello WebSocket!\nBye WebSocket!\n");
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment