Commit 300e68f7 authored by Dominik Charousset's avatar Dominik Charousset

Re-implement WebSocket with protocol switching

parent b5eea2ee
......@@ -18,16 +18,15 @@
namespace caf::detail {
/// Convenience alias for referring to the base type of @ref flow_bridge.
template <class Trait, class Role>
template <class Trait, class Base>
using ws_flow_bridge_base_t
= detail::flow_bridge_base<typename Role::upper_layer,
net::web_socket::lower_layer, Trait>;
= detail::flow_bridge_base<Base, net::web_socket::lower_layer, Trait>;
/// Translates between a message-oriented transport and data flows.
template <class Trait, class Role>
class ws_flow_bridge : public ws_flow_bridge_base_t<Trait, Role> {
template <class Trait, class Base>
class ws_flow_bridge : public ws_flow_bridge_base_t<Trait, Base> {
public:
using super = ws_flow_bridge_base_t<Trait, Role>;
using super = ws_flow_bridge_base_t<Trait, Base>;
using input_type = typename Trait::input_type;
......
......@@ -38,6 +38,11 @@ public:
/// Prepares written data for transfer, e.g., by flushing buffers or
/// registering sockets for write events.
virtual bool end_output() = 0;
/// Asks the stream to swap the current upper layer with `next` after
/// returning from `consume()`.
/// @note may only be called from the upper layer in `consume`.
virtual void switch_protocol(std::unique_ptr<upper_layer> next) = 0;
};
} // namespace caf::net::octet_stream
......@@ -81,6 +81,8 @@ public:
void shutdown() override;
void switch_protocol(upper_layer_ptr) override;
// -- properties -------------------------------------------------------------
auto& read_buffer() noexcept {
......@@ -188,6 +190,10 @@ protected:
/// Fallback policy.
policy default_policy_;
/// Setting this to non-null informs the transport to replace `up_` with
/// `next_`.
upper_layer_ptr next_;
// TODO: add [[no_unique_address]] to default_policy_ when switching to C++20.
};
......
......@@ -31,23 +31,16 @@ class CAF_NET_EXPORT client : public octet_stream::upper_layer {
public:
// -- member types -----------------------------------------------------------
class CAF_NET_EXPORT upper_layer : public web_socket::upper_layer {
public:
virtual ~upper_layer();
/// Initializes the upper layer.
/// @param down A pointer to the lower layer that remains valid for the
/// lifetime of the upper layer.
virtual error start(lower_layer* down) = 0;
};
using handshake_ptr = std::unique_ptr<handshake>;
using upper_layer_ptr = std::unique_ptr<upper_layer>;
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer::client>;
// -- constructors, destructors, and assignment operators --------------------
client(handshake_ptr hs, upper_layer_ptr up);
client(handshake_ptr hs, upper_layer_ptr up)
: hs_(std::move(hs)), up_(std::move(up)) {
// nop
}
// -- factories --------------------------------------------------------------
......@@ -57,32 +50,6 @@ public:
return make(std::make_unique<handshake>(std::move(hs)), std::move(up));
}
// -- properties -------------------------------------------------------------
client::upper_layer& up() noexcept {
// This cast is safe, because we know that we have initialized the framing
// layer with a pointer to an web_socket::client::upper_layer object that
// the framing then upcasts to web_socket::upper_layer.
return static_cast<client::upper_layer&>(framing_.up());
}
const client::upper_layer& up() const noexcept {
// See comment in the other up() overload.
return static_cast<const client::upper_layer&>(framing_.up());
}
octet_stream::lower_layer& down() noexcept {
return framing_.down();
}
const octet_stream::lower_layer& down() const noexcept {
return framing_.down();
}
bool handshake_completed() const noexcept {
return hs_ == nullptr;
}
// -- implementation of octet_stream::upper_layer ----------------------------
error start(octet_stream::lower_layer* down) override;
......@@ -102,13 +69,14 @@ private:
// -- member variables -------------------------------------------------------
/// Points to the transport layer below.
octet_stream::lower_layer* down_;
/// Stores the WebSocket handshake data until the handshake completed.
handshake_ptr hs_;
/// Stores the upper layer.
framing framing_;
settings cfg_;
/// Next layer in the processing chain.
upper_layer_ptr up_;
};
} // namespace caf::net::web_socket
......@@ -26,9 +26,9 @@ namespace caf::detail {
/// Specializes the WebSocket flow bridge for the server side.
template <class Trait, class... Ts>
class ws_client_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::client> {
: public ws_flow_bridge<Trait, net::web_socket::upper_layer::client> {
public:
using super = ws_flow_bridge<Trait, net::web_socket::client>;
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer::client>;
// We consume the output type of the application.
using pull_t = async::consumer_resource<typename Trait::input_type>;
......
......@@ -9,6 +9,7 @@
#include "caf/detail/rfc6455.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/octet_stream/lower_layer.hpp"
#include "caf/net/octet_stream/upper_layer.hpp"
#include "caf/net/receive_policy.hpp"
#include "caf/net/web_socket/lower_layer.hpp"
#include "caf/net/web_socket/status.hpp"
......@@ -25,7 +26,8 @@
namespace caf::net::web_socket {
/// Implements the WebSocket framing protocol as defined in RFC-6455.
class CAF_NET_EXPORT framing : public web_socket::lower_layer {
class CAF_NET_EXPORT framing : public octet_stream::upper_layer,
public web_socket::lower_layer {
public:
// -- member types -----------------------------------------------------------
......@@ -33,6 +35,10 @@ public:
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>;
using server_ptr = std::unique_ptr<web_socket::upper_layer::server>;
using client_ptr = std::unique_ptr<web_socket::upper_layer::client>;
// -- constants --------------------------------------------------------------
/// Restricts the size of received frames (including header).
......@@ -43,13 +49,14 @@ public:
// -- constructors, destructors, and assignment operators --------------------
explicit framing(upper_layer_ptr up) : up_(std::move(up)) {
// nop
static std::unique_ptr<framing> make(client_ptr up) {
return std::unique_ptr<framing>{new framing(std::move(up))};
}
// -- initialization ---------------------------------------------------------
void start(octet_stream::lower_layer* down);
static std::unique_ptr<framing> make(server_ptr up,
http::request_header hdr) {
return std::unique_ptr<framing>{new framing(std::move(up), std::move(hdr))};
}
// -- properties -------------------------------------------------------------
......@@ -75,6 +82,18 @@ public:
/// the standard.
bool mask_outgoing_frames = true;
// -- octet_stream::upper_layer implementation -------------------------------
error start(octet_stream::lower_layer* down) override;
void abort(const error& reason) override;
ptrdiff_t consume(byte_span input, byte_span) override;
void prepare_send() override;
bool done_sending() override;
// -- web_socket::lower_layer implementation ---------------------------------
using web_socket::lower_layer::shutdown;
......@@ -105,13 +124,20 @@ public:
bool end_text_message() override;
// -- interface for the lower layer ------------------------------------------
ptrdiff_t consume(byte_span input, byte_span);
private:
// -- implementation details -------------------------------------------------
explicit framing(client_ptr up) : up_(std::move(up)) {
// nop
}
explicit framing(server_ptr up, http::request_header&& hdr)
: up_(std::move(up)), hdr_(std::move(hdr)) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
mask_outgoing_frames = false;
}
bool handle(uint8_t opcode, byte_span payload);
void ship_pong(byte_span payload);
......@@ -141,6 +167,9 @@ private:
/// Next layer in the processing chain.
upper_layer_ptr up_;
/// Stored when running as a server and passed to `up_` in start.
std::optional<http::request_header> hdr_;
};
} // namespace caf::net::web_socket
......@@ -39,51 +39,18 @@ class CAF_NET_EXPORT server : public octet_stream::upper_layer {
public:
// -- member types -----------------------------------------------------------
class CAF_NET_EXPORT upper_layer : public web_socket::upper_layer {
public:
virtual ~upper_layer();
/// Initializes the upper layer.
/// @param down A pointer to the lower layer that remains valid for the
/// lifetime of the upper layer.
/// @param hdr The HTTP request header from the client handshake.
virtual error start(lower_layer* down, const http::request_header& hdr) = 0;
};
using upper_layer_ptr = std::unique_ptr<upper_layer>;
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer::server>;
// -- constructors, destructors, and assignment operators --------------------
explicit server(upper_layer_ptr up);
// -- factories --------------------------------------------------------------
static std::unique_ptr<server> make(upper_layer_ptr up);
// -- properties -------------------------------------------------------------
server::upper_layer& up() noexcept {
// This cast is safe, because we know that we have initialized the framing
// layer with a pointer to an web_socket::server::upper_layer object that
// the framing then upcasts to web_socket::upper_layer.
return static_cast<server::upper_layer&>(framing_.up());
explicit server(upper_layer_ptr up) : up_(std::move(up)) {
// nop
}
const server::upper_layer& up() const noexcept {
// See comment in the other up() overload.
return static_cast<const server::upper_layer&>(framing_.up());
}
octet_stream::lower_layer& down() noexcept {
return framing_.down();
}
const octet_stream::lower_layer& down() const noexcept {
return framing_.down();
}
// -- factories --------------------------------------------------------------
bool handshake_complete() const noexcept {
return handshake_complete_;
static std::unique_ptr<server> make(upper_layer_ptr up) {
return std::make_unique<server>(std::move(up));
}
// -- octet_stream::upper_layer implementation -------------------------------
......@@ -105,11 +72,11 @@ private:
bool handle_header(std::string_view http);
/// Stores whether the WebSocket handshake completed successfully.
bool handshake_complete_ = false;
/// Points to the transport layer below.
octet_stream::lower_layer* down_;
/// Stores the upper layer.
framing framing_;
/// We store this only to pass it to the framing layer after the handshake.
upper_layer_ptr up_;
};
} // namespace caf::net::web_socket
......@@ -29,9 +29,9 @@ namespace caf::detail {
/// Specializes the WebSocket flow bridge for the server side.
template <class Trait, class... Ts>
class ws_server_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::server> {
: public ws_flow_bridge<Trait, net::web_socket::upper_layer::server> {
public:
using super = ws_flow_bridge<Trait, net::web_socket::server>;
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer::server>;
using ws_acceptor_t = net::web_socket::acceptor<Ts...>;
......
......@@ -16,6 +16,10 @@ namespace caf::net::web_socket {
/// a server or a client.
class CAF_NET_EXPORT upper_layer : public generic_upper_layer {
public:
class server;
class client;
virtual ~upper_layer();
virtual ptrdiff_t consume_binary(byte_span buf) = 0;
......@@ -23,4 +27,16 @@ public:
virtual ptrdiff_t consume_text(std::string_view buf) = 0;
};
class upper_layer::server : public upper_layer {
public:
virtual ~server();
virtual error start(lower_layer* down, const http::request_header& hdr) = 0;
};
class upper_layer::client : public upper_layer {
public:
virtual ~client();
virtual error start(lower_layer* down) = 0;
};
} // namespace caf::net::web_socket
......@@ -94,6 +94,10 @@ void transport::shutdown() {
}
}
void transport::switch_protocol(upper_layer_ptr next) {
next_ = std::move(next);
}
// -- implementation of transport ----------------------------------------------
error transport::start(socket_manager* owner) {
......@@ -186,6 +190,19 @@ void transport::handle_buffered_data() {
CAF_LOG_TRACE(CAF_ARG(buffered_));
// Loop until we have drained the buffer as much as we can.
CAF_ASSERT(min_read_size_ <= max_read_size_);
auto switch_to_next_protocol = [this] {
assert(next_);
// Switch to the new protocol and initialize it.
configure_read(receive_policy::stop());
up_.reset(next_.release());
if (auto err = up_->start(this)) {
up_.reset();
parent_->deregister();
parent_->shutdown();
return false;
}
return true;
};
while (parent_->is_reading() && max_read_size_ > 0
&& buffered_ >= min_read_size_) {
auto n = std::min(buffered_, size_t{max_read_size_});
......@@ -205,16 +222,23 @@ void transport::handle_buffered_data() {
parent_->deregister();
return;
} else if (consumed == 0) {
// See whether the next iteration would change what we pass to the
// application (max_read_size_ may have changed). Otherwise, we'll try
// again later.
delta_offset_ = static_cast<ptrdiff_t>(n);
if (n == std::min(buffered_, size_t{max_read_size_})) {
return;
if (next_) {
// When switching protocol, the new layer has never seen the data, so we
// might just re-invoke the same data again.
if (!switch_to_next_protocol())
return;
} else {
// "Fall through".
// See whether the next iteration would change what we pass to the
// application (max_read_size_ may have changed). Otherwise, we'll try
// again later.
delta_offset_ = static_cast<ptrdiff_t>(n);
if (n == std::min(buffered_, size_t{max_read_size_}))
return;
// else: "Fall through".
}
} else {
if (next_ && !switch_to_next_protocol())
return;
// Shove the unread bytes to the beginning of the buffer and continue
// to the next loop iteration.
auto del = static_cast<size_t>(consumed);
......
......@@ -22,19 +22,6 @@
namespace caf::net::web_socket {
// -- member types -------------------------------------------------------------
client::upper_layer::~upper_layer() {
// nop
}
// -- constructors, destructors, and assignment operators ----------------------
client::client(handshake_ptr hs, upper_layer_ptr up_ptr)
: hs_(std::move(hs)), framing_(std::move(up_ptr)) {
// nop
}
// -- factories ----------------------------------------------------------------
std::unique_ptr<client> client::make(handshake_ptr hs, upper_layer_ptr up_ptr) {
......@@ -43,67 +30,56 @@ std::unique_ptr<client> client::make(handshake_ptr hs, upper_layer_ptr up_ptr) {
// -- implementation of octet_stream::upper_layer ------------------------------
error client::start(octet_stream::lower_layer* down_ptr) {
error client::start(octet_stream::lower_layer* down) {
CAF_ASSERT(hs_ != nullptr);
framing_.start(down_ptr);
if (!hs_->has_mandatory_fields())
return make_error(sec::runtime_error,
"handshake data lacks mandatory fields");
"WebSocket client received an incomplete handshake");
if (!hs_->has_valid_key())
hs_->randomize_key();
down().begin_output();
hs_->write_http_1_request(down().output_buffer());
down().end_output();
down().configure_read(receive_policy::up_to(handshake::max_http_size));
down_ = down;
down_->begin_output();
hs_->write_http_1_request(down_->output_buffer());
down_->end_output();
down_->configure_read(receive_policy::up_to(handshake::max_http_size));
return none;
}
void client::abort(const error& reason) {
up().abort(reason);
assert(up_);
up_->abort(reason);
}
ptrdiff_t client::consume(byte_span buffer, byte_span delta) {
CAF_LOG_TRACE(CAF_ARG2("buffer", buffer.size()));
if (handshake_completed()) {
// Short circuit to the framing layer after the handshake completed.
return framing_.consume(buffer, delta);
} else {
// Check whether we have received the HTTP header or else wait for more
// data. Abort when exceeding the maximum size.
auto [hdr, remainder] = http::v1::split_header(buffer);
if (hdr.empty()) {
if (buffer.size() >= handshake::max_http_size) {
CAF_LOG_ERROR("server response exceeded the maximum header size");
up().abort(make_error(sec::protocol_error, "server response exceeded "
"the maximum header size"));
return -1;
} else {
return 0;
}
} else if (!handle_header(hdr)) {
// Note: handle_header() already calls upper_layer().abort().
// Check whether we have received the HTTP header or else wait for more
// data. Abort when exceeding the maximum size.
auto [hdr, remainder] = http::v1::split_header(buffer);
if (hdr.empty()) {
if (buffer.size() >= handshake::max_http_size) {
CAF_LOG_ERROR("server response exceeded the maximum header size");
up_->abort(make_error(sec::protocol_error, "server response exceeded "
"the maximum header size"));
return -1;
} else if (remainder.empty()) {
CAF_ASSERT(hdr.size() == buffer.size());
return hdr.size();
} else {
CAF_LOG_DEBUG(CAF_ARG2("remainder.size", remainder.size()));
if (auto res = framing_.consume(remainder, remainder); res >= 0) {
return hdr.size() + res;
} else {
return res;
}
}
// Wait for more data.
return 0;
}
if (!handle_header(hdr)) {
// Note: handle_header() already calls upper_layer().abort().
return -1;
}
// We only care about the header here. The framing layer is responsible for
// any remaining data.
return static_cast<ptrdiff_t>(hdr.size());
}
void client::prepare_send() {
if (handshake_completed())
up().prepare_send();
// nop
}
bool client::done_sending() {
return handshake_completed() ? up().done_sending() : true;
return true;
}
// -- HTTP response processing -------------------------------------------------
......@@ -113,19 +89,13 @@ bool client::handle_header(std::string_view http) {
auto http_ok = hs_->is_valid_http_1_response(http);
hs_.reset();
if (http_ok) {
if (auto err = up().start(&framing_)) {
CAF_LOG_DEBUG("failed to initialize WebSocket framing layer");
return false;
} else {
CAF_LOG_DEBUG("completed WebSocket handshake");
return true;
}
} else {
CAF_LOG_DEBUG("received an invalid WebSocket handshake");
up().abort(make_error(sec::protocol_error,
"received an invalid WebSocket handshake"));
return false;
down_->switch_protocol(framing::make(std::move(up_)));
return true;
}
CAF_LOG_DEBUG("received an invalid WebSocket handshake");
up_->abort(
make_error(sec::protocol_error, "received an invalid WebSocket handshake"));
return false;
}
} // namespace caf::net::web_socket
......@@ -5,92 +5,38 @@
#include "caf/net/web_socket/framing.hpp"
#include "caf/logger.hpp"
#include "caf/net/http/v1.hpp"
namespace caf::net::web_socket {
// -- initialization ---------------------------------------------------------
// -- octet_stream::upper_layer implementation ---------------------------------
void framing::start(octet_stream::lower_layer* down) {
error framing::start(octet_stream::lower_layer* down) {
std::random_device rd;
rng_.seed(rd());
down_ = down;
}
// -- web_socket::lower_layer implementation -----------------------------------
multiplexer& framing::mpx() noexcept {
return down_->mpx();
}
bool framing::can_send_more() const noexcept {
return down_->can_send_more();
}
void framing::suspend_reading() {
down_->configure_read(receive_policy::stop());
}
bool framing::is_reading() const noexcept {
return down_->is_reading();
}
void framing::write_later() {
down_->write_later();
}
void framing::shutdown(status code, std::string_view msg) {
auto code_val = static_cast<uint16_t>(code);
uint32_t mask_key = 0;
byte_buffer payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<std::byte>((code_val & 0xFF00) >> 8));
payload.push_back(static_cast<std::byte>(code_val & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<std::byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
if (!hdr_) {
using dptr_t = web_socket::upper_layer::client;
return static_cast<dptr_t*>(up_.get())->start(this);
}
down_->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down_->output_buffer());
down_->end_output();
down_->shutdown();
}
void framing::request_messages() {
if (!down_->is_reading())
down_->configure_read(receive_policy::up_to(2048));
}
void framing::begin_binary_message() {
// nop
}
byte_buffer& framing::binary_message_buffer() {
return binary_buf_;
}
bool framing::end_binary_message() {
ship_frame(binary_buf_);
return true;
}
void framing::begin_text_message() {
// nop
}
text_buffer& framing::text_message_buffer() {
return text_buf_;
using dptr_t = web_socket::upper_layer::server;
auto err = static_cast<dptr_t*>(up_.get())->start(this, *hdr_);
hdr_ = std::nullopt;
if (err) {
auto descr = to_string(err);
CAF_LOG_DEBUG("upper layer rejected a WebSocket connection:" << descr);
down_->begin_output();
http::v1::write_response(http::status::bad_request, "text/plain", descr,
down_->output_buffer());
down_->end_output();
}
return err;
}
bool framing::end_text_message() {
ship_frame(text_buf_);
return true;
void framing::abort(const error& reason) {
up_->abort(reason);
}
// -- interface for the lower layer --------------------------------------------
ptrdiff_t framing::consume(byte_span buffer, byte_span) {
// Make sure we're overriding any 'exactly' setting.
down_->configure_read(receive_policy::up_to(2048));
......@@ -181,6 +127,87 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) {
return static_cast<ptrdiff_t>(frame_size);
}
void framing::prepare_send() {
up_->prepare_send();
}
bool framing::done_sending() {
return up_->done_sending();
}
// -- web_socket::lower_layer implementation -----------------------------------
multiplexer& framing::mpx() noexcept {
return down_->mpx();
}
bool framing::can_send_more() const noexcept {
return down_->can_send_more();
}
void framing::suspend_reading() {
down_->configure_read(receive_policy::stop());
}
bool framing::is_reading() const noexcept {
return down_->is_reading();
}
void framing::write_later() {
down_->write_later();
}
void framing::shutdown(status code, std::string_view msg) {
auto code_val = static_cast<uint16_t>(code);
uint32_t mask_key = 0;
byte_buffer payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<std::byte>((code_val & 0xFF00) >> 8));
payload.push_back(static_cast<std::byte>(code_val & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<std::byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down_->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down_->output_buffer());
down_->end_output();
down_->shutdown();
}
void framing::request_messages() {
if (!down_->is_reading())
down_->configure_read(receive_policy::up_to(2048));
}
void framing::begin_binary_message() {
// nop
}
byte_buffer& framing::binary_message_buffer() {
return binary_buf_;
}
bool framing::end_binary_message() {
ship_frame(binary_buf_);
return true;
}
void framing::begin_text_message() {
// nop
}
text_buffer& framing::text_message_buffer() {
return text_buf_;
}
bool framing::end_text_message() {
ship_frame(text_buf_);
return true;
}
// -- implementation details ---------------------------------------------------
bool framing::handle(uint8_t opcode, byte_span payload) {
......
......@@ -6,84 +6,56 @@
namespace caf::net::web_socket {
// -- member types -------------------------------------------------------------
server::upper_layer::~upper_layer() {
// nop
}
// -- constructors, destructors, and assignment operators ----------------------
server::server(upper_layer_ptr up) : framing_(std::move(up)) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
framing_.mask_outgoing_frames = false;
}
// -- factories ----------------------------------------------------------------
std::unique_ptr<server> server::make(upper_layer_ptr up) {
return std::make_unique<server>(std::move(up));
}
// -- octet_stream::upper_layer implementation ---------------------------------
error server::start(octet_stream::lower_layer* down_ptr) {
framing_.start(down_ptr);
down().configure_read(receive_policy::up_to(handshake::max_http_size));
error server::start(octet_stream::lower_layer* down) {
down_ = down;
down_->configure_read(receive_policy::up_to(handshake::max_http_size));
return none;
}
void server::abort(const error& reason) {
if (handshake_complete_)
up().abort(reason);
void server::abort(const error& err) {
up_->abort(err);
}
ptrdiff_t server::consume(byte_span input, byte_span delta) {
ptrdiff_t server::consume(byte_span input, byte_span) {
using namespace std::literals;
CAF_LOG_TRACE(CAF_ARG2("bytes.size", input.size()));
if (handshake_complete_) {
// Short circuit to the framing layer after the handshake completed.
return framing_.consume(input, delta);
} else {
// Check whether we received an HTTP header or else wait for more data.
// Abort when exceeding the maximum size.
auto [hdr, remainder] = http::v1::split_header(input);
if (hdr.empty()) {
if (input.size() >= handshake::max_http_size) {
down().begin_output();
http::v1::write_response(http::status::request_header_fields_too_large,
"text/plain"sv,
"Header exceeds maximum size."sv,
down().output_buffer());
down().end_output();
return -1;
} else {
return 0;
}
} else if (!handle_header(hdr)) {
// Check whether we received an HTTP header or else wait for more data.
// Abort when exceeding the maximum size.
auto [hdr, remainder] = http::v1::split_header(input);
if (hdr.empty()) {
if (input.size() >= handshake::max_http_size) {
down_->begin_output();
http::v1::write_response(http::status::request_header_fields_too_large,
"text/plain"sv, "Header exceeds maximum size."sv,
down_->output_buffer());
down_->end_output();
return -1;
} else {
return hdr.size();
return 0;
}
} else if (!handle_header(hdr)) {
return -1;
} else {
return hdr.size();
}
}
void server::prepare_send() {
if (handshake_complete_)
up().prepare_send();
// nop
}
bool server::done_sending() {
return handshake_complete_ ? up().done_sending() : true;
return true;
}
// -- HTTP request processing ------------------------------------------------
void server::write_response(http::status code, std::string_view msg) {
down().begin_output();
http::v1::write_response(code, "text/plain", msg, down().output_buffer());
down().end_output();
down_->begin_output();
http::v1::write_response(code, "text/plain", msg, down_->output_buffer());
down_->end_output();
}
bool server::handle_header(std::string_view http) {
......@@ -108,23 +80,14 @@ bool server::handle_header(std::string_view http) {
CAF_LOG_DEBUG("received invalid WebSocket handshake");
return false;
}
// Try to initialize the upper layer.
down().configure_read(receive_policy::stop());
if (auto err = up().start(&framing_, hdr)) {
auto descr = to_string(err);
CAF_LOG_DEBUG("upper layer rejected a WebSocket connection:" << descr);
write_response(http::status::bad_request, descr);
return false;
}
// Finalize the WebSocket handshake.
handshake hs;
hs.assign_key(sec_key);
down().begin_output();
hs.write_http_1_response(down().output_buffer());
down().end_output();
down_->begin_output();
hs.write_http_1_response(down_->output_buffer());
down_->end_output();
CAF_LOG_DEBUG("completed WebSocket handshake");
handshake_complete_ = true;
down_->switch_protocol(framing::make(std::move(up_), std::move(hdr)));
return true;
}
......
......@@ -10,4 +10,12 @@ upper_layer::~upper_layer() {
// nop
}
upper_layer::server::~server() {
// nop
}
upper_layer::client::~client() {
// nop
}
} // namespace caf::net::web_socket
......@@ -35,6 +35,10 @@ void mock_stream_transport::shutdown() {
// nop
}
void mock_stream_transport::switch_protocol(upper_layer_ptr new_up) {
next.swap(new_up);
}
void mock_stream_transport::configure_read(net::receive_policy policy) {
min_read_size = policy.min_size;
max_read_size = policy.max_size;
......@@ -54,6 +58,17 @@ bool mock_stream_transport::end_output() {
ptrdiff_t mock_stream_transport::handle_input() {
ptrdiff_t result = 0;
auto switch_to_next_protocol = [this] {
assert(next);
// Switch to the new protocol and initialize it.
configure_read(net::receive_policy::stop());
up.reset(next.release());
if (auto err = up->start(this)) {
up.reset();
return false;
}
return true;
};
// Loop until we have drained the buffer as much as we can.
while (max_read_size > 0 && input.size() >= min_read_size) {
auto n = std::min(input.size(), size_t{max_read_size});
......@@ -71,16 +86,24 @@ ptrdiff_t mock_stream_transport::handle_input() {
up->abort(make_error(sec::logic_error, "consumed > buffer.size"));
return result;
} else if (consumed == 0) {
// See whether the next iteration would change what we pass to the
// application (max_read_size_ may have changed). Otherwise, we'll try
// again later.
delta_offset = static_cast<ptrdiff_t>(n);
if (n == std::min(input.size(), size_t{max_read_size})) {
return result;
if (next) {
// When switching protocol, the new layer has never seen the data, so we
// might just re-invoke the same data again.
if (!switch_to_next_protocol())
return -1;
} else {
// "Fall through".
// See whether the next iteration would change what we pass to the
// application (max_read_size_ may have changed). Otherwise, we'll try
// again later.
delta_offset = static_cast<ptrdiff_t>(n);
if (n == std::min(input.size(), size_t{max_read_size})) {
return result;
}
// else: "Fall through".
}
} else {
if (next && !switch_to_next_protocol())
return -1;
// Shove the unread bytes to the beginning of the buffer and continue
// to the next loop iteration.
result += consumed;
......
......@@ -48,6 +48,8 @@ public:
bool end_output() override;
void switch_protocol(upper_layer_ptr) override;
// -- initialization ---------------------------------------------------------
caf::error start(caf::net::multiplexer* ptr) {
......@@ -81,6 +83,8 @@ public:
upper_layer_ptr up;
upper_layer_ptr next;
caf::byte_buffer output;
caf::byte_buffer input;
......
......@@ -14,7 +14,7 @@ namespace {
using svec = std::vector<std::string>;
class app_t : public net::web_socket::client::upper_layer {
class app_t : public net::web_socket::upper_layer::client {
public:
static auto make() {
return std::make_unique<app_t>();
......@@ -95,7 +95,6 @@ SCENARIO("the client performs the WebSocket handshake on startup") {
WHEN("starting a WebSocket client") {
auto app = app_t::make();
auto ws = net::web_socket::client::make(make_handshake(), std::move(app));
auto& ws_state = *ws;
auto uut = mock_stream_transport::make(std::move(ws));
THEN("the client sends its HTTP request when initializing it") {
CHECK_EQ(uut->start(nullptr), error{});
......@@ -105,7 +104,6 @@ SCENARIO("the client performs the WebSocket handshake on startup") {
uut->push(http_response);
CHECK_EQ(uut->handle_input(),
static_cast<ptrdiff_t>(http_response.size()));
CHECK(ws_state.handshake_completed());
}
}
}
......
......@@ -20,7 +20,7 @@ namespace {
using svec = std::vector<std::string>;
class app_t : public net::web_socket::server::upper_layer {
class app_t : public net::web_socket::upper_layer::server {
public:
std::string text_input;
......@@ -80,7 +80,6 @@ struct fixture {
auto app_ptr = app_t::make();
app = app_ptr.get();
auto ws_ptr = net::web_socket::server::make(std::move(app_ptr));
ws = ws_ptr.get();
transport = mock_stream_transport::make(std::move(ws_ptr));
if (auto err = transport->start(nullptr))
CAF_FAIL("failed to initialize mock transport: " << err);
......@@ -118,8 +117,6 @@ struct fixture {
std::unique_ptr<mock_stream_transport> transport;
net::web_socket::server* ws;
app_t* app;
std::minstd_rand rng;
......@@ -154,7 +151,6 @@ CAF_TEST(applications receive handshake data via config) {
}
CHECK_EQ(transport->input.size(), 0u);
CHECK_EQ(transport->unconsumed(), 0u);
CHECK(ws->handshake_complete());
CHECK_SETTING("web-socket.method", "GET");
CHECK_SETTING("web-socket.path", "/chat");
CHECK_SETTING("web-socket.http-version", "HTTP/1.1");
......@@ -175,7 +171,6 @@ CAF_TEST(the server responds with an HTTP response on success) {
transport->push(opening_handshake);
CHECK_EQ(transport->handle_input(),
static_cast<ptrdiff_t>(opening_handshake.size()));
CHECK(ws->handshake_complete());
CHECK_EQ(transport->output_as_str(),
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
......@@ -194,14 +189,11 @@ CAF_TEST(handshakes may arrive in chunks) {
bufs.emplace_back(i, opening_handshake.end());
transport->push(bufs[0]);
CHECK_EQ(transport->handle_input(), 0u);
CHECK(!ws->handshake_complete());
transport->push(bufs[1]);
CHECK_EQ(transport->handle_input(), 0u);
CHECK(!ws->handshake_complete());
transport->push(bufs[2]);
CHECK_EQ(static_cast<size_t>(transport->handle_input()),
opening_handshake.size());
CHECK(ws->handshake_complete());
}
CAF_TEST(data may follow the handshake immediately) {
......@@ -211,7 +203,6 @@ CAF_TEST(data may follow the handshake immediately) {
rfc6455_append("Bye WebSocket!\n"sv, buf);
transport->push(buf);
CHECK_EQ(transport->handle_input(), static_cast<ptrdiff_t>(buf.size()));
CHECK(ws->handshake_complete());
CHECK_EQ(app->text_input, "Hello WebSocket!\nBye WebSocket!\n");
}
......@@ -219,7 +210,6 @@ CAF_TEST(data may arrive later) {
transport->push(opening_handshake);
CHECK_EQ(transport->handle_input(),
static_cast<ptrdiff_t>(opening_handshake.size()));
CHECK(ws->handshake_complete());
push("Hello WebSocket!\nBye WebSocket!\n"sv);
transport->handle_input();
CHECK_EQ(app->text_input, "Hello WebSocket!\nBye WebSocket!\n");
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment