Commit c6fbe7ad authored by Dominik Charousset's avatar Dominik Charousset

Improve WebSocket API

- Implement proper WebSocket close handshake.
- Add convenience API for flow-based WebSocket servers.
parent 61fa2989
......@@ -17,6 +17,7 @@ caf_incubator_add_component(
net.basp.ec
net.basp.message_type
net.operation
net.web_socket.status
HEADERS
${CAF_NET_HEADERS}
SOURCES
......
......@@ -31,7 +31,8 @@ public:
// -- constructors, destructors, and assignment operators --------------------
template <class... Ts>
explicit connection_acceptor(Ts&&... xs) : factory_(std::forward<Ts>(xs)...) {
explicit connection_acceptor(size_t limit, Ts&&... xs)
: factory_(std::forward<Ts>(xs)...), limit_(limit) {
// nop
}
......@@ -54,13 +55,21 @@ public:
CAF_LOG_TRACE("");
if (auto x = accept(parent->handle())) {
socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr());
CAF_ASSERT(child != nullptr);
if (!child) {
CAF_LOG_ERROR("factory failed to create a new child");
parent->abort_reason(sec::runtime_error);
return false;
}
if (auto err = child->init(cfg_)) {
CAF_LOG_ERROR("failed to initialize new child:" << err);
parent->abort_reason(std::move(err));
return false;
}
if (limit_ == 0) {
return true;
} else {
return ++accepted_ < limit_;
}
} else {
CAF_LOG_ERROR("accept failed:" << x.error());
return false;
......@@ -89,6 +98,10 @@ private:
socket_manager* owner_;
size_t limit_;
size_t accepted_ = 0;
settings cfg_;
};
......
......@@ -121,6 +121,18 @@ public:
}
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr) {
// nop; this framing layer has no close handshake
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr, const error&) {
// nop; this framing layer has no close handshake
return true;
}
template <class LowerLayerPtr>
static void abort_reason(LowerLayerPtr down, error reason) {
return down->abort_reason(std::move(reason));
......@@ -237,8 +249,8 @@ private:
/// @param out Outputs from the socket.
/// @param trait Converts between the native and the wire format.
/// @relates length_prefix_framing
template <template <class> class Transport = stream_transport, class T,
class Socket, class Trait>
template <template <class> class Transport = stream_transport, class Socket,
class T, class Trait>
error run_with_length_prefix_framing(multiplexer& mpx, Socket fd,
const settings& cfg,
async::consumer_resource<T> in,
......
......@@ -13,6 +13,7 @@
#include "caf/sec.hpp"
#include "caf/settings.hpp"
#include "caf/tag/message_oriented.hpp"
#include "caf/tag/mixed_message_oriented.hpp"
#include "caf/tag/no_auto_reading.hpp"
#include <utility>
......@@ -29,56 +30,78 @@ namespace caf::net {
/// bool convert(const_byte_span bytes, T& value);
/// };
/// ~~~
template <class T, class Trait>
class message_flow_bridge : public caf::tag::no_auto_reading {
template <class T, class Trait, class Tag = tag::message_oriented>
class message_flow_bridge : public tag::no_auto_reading {
public:
using input_tag = caf::tag::message_oriented;
using input_tag = Tag;
using buffer_type = caf::async::spsc_buffer<T>;
using buffer_type = async::spsc_buffer<T>;
using consumer_resource_t = async::consumer_resource<T>;
using producer_resource_t = async::producer_resource<T>;
explicit message_flow_bridge(Trait trait) : trait_(std::move(trait)) {
// nop
}
void connect_flows(caf::net::socket_manager* mgr,
async::consumer_resource<T> in,
async::producer_resource<T> out) {
void connect_flows(net::socket_manager* mgr, consumer_resource_t in,
producer_resource_t out) {
in_ = consumer_adapter<buffer_type>::try_open(mgr, in);
out_ = producer_adapter<buffer_type>::try_open(mgr, out);
}
template <class LowerLayerPtr>
caf::error
init(caf::net::socket_manager* mgr, LowerLayerPtr&&, const caf::settings&) {
error
init(net::socket_manager* mgr, LowerLayerPtr down, const settings& cfg) {
mgr_ = mgr;
if constexpr (caf::detail::has_init_v<Trait>) {
if (auto err = init_res(trait_.init(cfg)))
return err;
}
if (!in_ && !out_)
return make_error(sec::cannot_open_resource,
"flow bridge cannot run without at least one resource");
else
return caf::none;
if (!out_)
down->suspend_reading();
return none;
}
template <class LowerLayerPtr>
bool write(LowerLayerPtr down, const T& item) {
if constexpr (std::is_same_v<Tag, tag::message_oriented>) {
down->begin_message();
auto& buf = down->message_buffer();
return trait_.convert(item, buf) && down->end_message();
} else {
static_assert(std::is_same_v<Tag, tag::mixed_message_oriented>);
if (trait_.converts_to_binary(item)) {
down->begin_binary_message();
auto& buf = down->binary_message_buffer();
return trait_.convert(item, buf) && down->end_binary_message();
} else {
down->begin_text_message();
auto& buf = down->text_message_buffer();
return trait_.convert(item, buf) && down->end_text_message();
}
}
}
template <class LowerLayerPtr>
struct send_helper {
struct write_helper {
using bridge_type = message_flow_bridge;
bridge_type* bridge;
LowerLayerPtr down;
bool aborted = false;
size_t consumed = 0;
error err;
send_helper(bridge_type* bridge, LowerLayerPtr down)
write_helper(bridge_type* bridge, LowerLayerPtr down)
: bridge(bridge), down(down) {
// nop
}
void on_next(caf::span<const T> items) {
void on_next(span<const T> items) {
CAF_ASSERT(items.size() == 1);
for (const auto& item : items) {
if (!bridge->write(down, item)) {
......@@ -92,17 +115,22 @@ public:
// nop
}
void on_error(const caf::error&) {
// nop
void on_error(const error& x) {
err = x;
}
};
template <class LowerLayerPtr>
bool prepare_send(LowerLayerPtr down) {
send_helper<LowerLayerPtr> helper{this, down};
write_helper<LowerLayerPtr> helper{this, down};
while (down->can_send_more() && in_) {
auto [again, consumed] = in_->pull(caf::async::delay_errors, 1, helper);
auto [again, consumed] = in_->pull(async::delay_errors, 1, helper);
if (!again) {
if (helper.err) {
down->send_close_message(helper.err);
} else {
down->send_close_message();
}
in_ = nullptr;
} else if (helper.aborted) {
down->abort_reason(make_error(sec::conversion_failed));
......@@ -122,11 +150,10 @@ public:
}
template <class LowerLayerPtr>
void abort(LowerLayerPtr, const caf::error& reason) {
void abort(LowerLayerPtr, const error& reason) {
CAF_LOG_TRACE(CAF_ARG(reason));
if (out_) {
if (reason == caf::sec::socket_disconnected
|| reason == caf::sec::discarded)
if (reason == sec::socket_disconnected || reason == sec::discarded)
out_->close();
else
out_->abort(reason);
......@@ -138,8 +165,29 @@ public:
}
}
template <class LowerLayerPtr>
ptrdiff_t consume(LowerLayerPtr down, caf::byte_span buf) {
template <class U = Tag, class LowerLayerPtr>
ptrdiff_t consume(LowerLayerPtr down, byte_span buf) {
if (!out_) {
down->abort_reason(make_error(sec::connection_closed));
return -1;
}
T val;
if (!trait_.convert(buf, val)) {
down->abort_reason(make_error(sec::conversion_failed));
return -1;
}
if (out_->push(std::move(val)) == 0)
down->suspend_reading();
return static_cast<ptrdiff_t>(buf.size());
}
template <class U = Tag, class LowerLayerPtr>
ptrdiff_t consume_binary(LowerLayerPtr down, byte_span buf) {
return consume(down, buf);
}
template <class U = Tag, class LowerLayerPtr>
ptrdiff_t consume_text(LowerLayerPtr down, string_view buf) {
if (!out_) {
down->abort_reason(make_error(sec::connection_closed));
return -1;
......@@ -155,8 +203,35 @@ public:
}
private:
error init_res(error err) {
return err;
}
error init_res(consumer_resource_t in, producer_resource_t out) {
connect_flows(mgr_, std::move(in), std::move(out));
return caf::none;
}
error init_res(std::tuple<consumer_resource_t, producer_resource_t> in_out) {
auto& [in, out] = in_out;
return init_res(std::move(in), std::move(out));
}
error init_res(std::pair<consumer_resource_t, producer_resource_t> in_out) {
auto& [in, out] = in_out;
return init_res(std::move(in), std::move(out));
}
template <class R>
error init_res(expected<R> res) {
if (res)
return init_res(*res);
else
return std::move(res.error());
}
/// Points to the manager that runs this protocol stack.
caf::net::socket_manager* mgr_ = nullptr;
net::socket_manager* mgr_ = nullptr;
/// Incoming messages, serialized to the socket.
consumer_adapter_ptr<buffer_type> in_;
......
......@@ -46,6 +46,11 @@ public:
return lptr_->end_message(llptr_);
}
template <class... Ts>
bool send_close_message(Ts&&... xs) {
return lptr_->send_close_message(llptr_, std::forward<Ts>(xs)...);
}
void abort_reason(error reason) {
return lptr_->abort_reason(llptr_, std::move(reason));
}
......
......@@ -44,20 +44,24 @@ public:
// -- socket manager functions -----------------------------------------------
///
/// Creates a new acceptor that accepts incoming connections from @p sock and
/// creates socket managers using @p factory.
/// @param sock An accept socket in listening mode. For a TCP socket, this
/// socket must already listen to an address plus port.
/// @param factory An application stack factory.
/// @param factory A function object for creating socket managers that take
/// ownership of incoming connections.
/// @param limit The maximum number of connections that this acceptor should
/// establish or 0 for 'no limit'.
template <class Socket, class Factory>
auto make_acceptor(Socket sock, Factory factory) {
auto make_acceptor(Socket sock, Factory factory, size_t limit = 0) {
using connected_socket_type = typename Socket::connected_socket_type;
if constexpr (detail::is_callable_with<Factory, connected_socket_type,
multiplexer*>::value) {
connection_acceptor_factory_adapter<Factory> adapter{std::move(factory)};
return make_acceptor(std::move(sock), std::move(adapter));
return make_acceptor(std::move(sock), std::move(adapter), limit);
} else {
using impl = connection_acceptor<Socket, Factory>;
auto ptr = make_socket_manager<impl>(std::move(sock), &mpx_,
auto ptr = make_socket_manager<impl>(std::move(sock), &mpx_, limit,
std::move(factory));
mpx_.init(ptr);
return ptr;
......
......@@ -42,8 +42,8 @@ public:
return lptr_->binary_message_buffer(llptr_);
}
void end_binary_message() {
lptr_->end_binary_message(llptr_);
bool end_binary_message() {
return lptr_->end_binary_message(llptr_);
}
void begin_text_message() {
......@@ -54,8 +54,13 @@ public:
return lptr_->text_message_buffer(llptr_);
}
void end_text_message() {
lptr_->end_text_message(llptr_);
bool end_text_message() {
return lptr_->end_text_message(llptr_);
}
template <class... Ts>
bool send_close_message(Ts&&... xs) {
return lptr_->send_close_message(llptr_, std::forward<Ts>(xs)...);
}
void abort_reason(error reason) {
......
......@@ -8,6 +8,7 @@
#include "caf/byte_span.hpp"
#include "caf/detail/rfc6455.hpp"
#include "caf/net/mixed_message_oriented_layer_ptr.hpp"
#include "caf/net/web_socket/status.hpp"
#include "caf/sec.hpp"
#include "caf/span.hpp"
#include "caf/string_view.hpp"
......@@ -20,7 +21,7 @@
namespace caf::net::web_socket {
/// Implements the WebProtocol framing protocol as defined in RFC-6455.
/// Implements the WebSocket framing protocol as defined in RFC-6455.
template <class UpperLayer>
class framing {
public:
......@@ -87,11 +88,8 @@ public:
}
template <class LowerLayerPtr>
static void suspend_reading(LowerLayerPtr) {
CAF_RAISE_ERROR("suspending / resuming a WebSocket not implemented yet");
// TODO: uncommenting this isn't enough since consume() also must make sure
// to not override the configure_read.
// down->configure_read(receive_policy::stop());
static void suspend_reading(LowerLayerPtr down) {
down->configure_read(receive_policy::stop());
}
template <class LowerLayerPtr>
......@@ -105,8 +103,9 @@ public:
}
template <class LowerLayerPtr>
void end_binary_message(LowerLayerPtr down) {
bool end_binary_message(LowerLayerPtr down) {
ship_frame(down, binary_buf_);
return true;
}
template <class LowerLayerPtr>
......@@ -120,8 +119,28 @@ public:
}
template <class LowerLayerPtr>
void end_text_message(LowerLayerPtr down) {
bool end_text_message(LowerLayerPtr down) {
ship_frame(down, text_buf_);
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr down) {
ship_close(down);
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr down, status code, string_view desc) {
ship_close(down, static_cast<uint16_t>(code), desc);
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr down, const error& reason) {
ship_close(down, static_cast<uint16_t>(status::unexpected_condition),
to_string(reason));
return true;
}
template <class LowerLayerPtr>
......@@ -138,12 +157,12 @@ public:
template <class LowerLayerPtr>
bool prepare_send(LowerLayerPtr down) {
return upper_layer_.prepare_send(down);
return upper_layer_.prepare_send(this_layer_ptr(down));
}
template <class LowerLayerPtr>
bool done_sending(LowerLayerPtr down) {
return upper_layer_.done_sending(down);
return upper_layer_.done_sending(this_layer_ptr(down));
}
template <class LowerLayerPtr>
......@@ -185,6 +204,9 @@ public:
// Wait for more data if necessary.
size_t frame_size = hdr_bytes + hdr.payload_len;
if (buffer.size() < frame_size) {
// Ask for more data unless the upper layer called suspend_reading.
if (!down->stopped())
down->configure_read(receive_policy::up_to(2048));
down->configure_read(receive_policy::exactly(frame_size));
return consumed;
}
......@@ -196,6 +218,7 @@ public:
}
if (hdr.fin) {
if (opcode_ == nil_code) {
// Call upper layer.
if (!handle(down, hdr.opcode, payload))
return -1;
} else if (hdr.opcode != detail::rfc6455::continuation_frame) {
......@@ -243,6 +266,8 @@ public:
// Advance to next frame in the input.
buffer = buffer.subspan(frame_size);
if (buffer.empty()) {
// Ask for more data unless the upper layer called suspend_reading.
if (!down->stopped())
down->configure_read(receive_policy::up_to(2048));
return consumed + static_cast<ptrdiff_t>(frame_size);
}
......@@ -291,6 +316,42 @@ private:
down->end_output();
}
template <class LowerLayerPtr>
void ship_close(LowerLayerPtr down, uint16_t code, string_view msg) {
uint32_t mask_key = 0;
std::vector<byte> payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<byte>((code & 0xFF00) >> 8));
payload.push_back(static_cast<byte>(code & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down->output_buffer());
down->end_output();
}
template <class LowerLayerPtr>
void ship_close(LowerLayerPtr down) {
uint32_t mask_key = 0;
byte payload[] = {
byte{0x03}, byte{0xE8}, // Error code 1000: normal close.
byte{'E'}, byte{'O'}, byte{'F'}, // "EOF" string as goodbye message.
};
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down->output_buffer());
down->end_output();
}
template <class LowerLayerPtr, class T>
void ship_frame(LowerLayerPtr down, std::vector<T>& buf) {
uint32_t mask_key = 0;
......
......@@ -125,6 +125,10 @@ public:
/// @pre `has_valid_key()`
void write_http_1_response(byte_buffer& buf) const;
/// Writes an HTTP 1.1 'Bad Request' error to `buf` with `descr` providing
/// additional information to the client.
static void write_http_1_bad_request(byte_buffer& buf, string_view descr);
/// Writes a HTTP 1.1 431 (Request Header Fields Too Large) response.
static void write_http_1_header_too_large(byte_buffer& buf);
......
This diff is collapsed.
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/default_enum_inspect.hpp"
#include "caf/detail/net_export.hpp"
#include <cstdint>
namespace caf::net::web_socket {
/// Status codes as defined by RFC 6455, Section 7.4.
enum class status : uint16_t {
/// Indicates a normal closure, meaning that the purpose for which the
/// connection was established has been fulfilled.
normal_close = 1000,
/// Indicates that an endpoint is "going away", such as a server going down or
/// a browser having navigated away from a page.
going_away = 1001,
/// Indicates that an endpoint is terminating the connection due to a protocol
/// error.
protocol_error = 1002,
/// Indicates that an endpoint is terminating the connection because it has
/// received a type of data it cannot accept (e.g., an endpoint that
/// understands only text data MAY send this if it receives a binary message).
invalid_data = 1003,
/// A reserved value and MUST NOT be set as a status code in a Close control
/// frame by an endpoint. It is designated for use in applications expecting
/// a status code to indicate that no status code was actually present.
no_status = 1005,
/// A reserved value and MUST NOT be set as a status code in a Close control
/// frame by an endpoint. It is designated for use in applications expecting
/// a status code to indicate that the connection was closed abnormally, e.g.,
/// without sending or receiving a Close control frame.
abnormal_exit = 1006,
/// Indicates that an endpoint is terminating the connection because it has
/// received data within a message that was not consistent with the type of
/// the message (e.g., non-UTF-8 [RFC3629] data within a text message).
inconsistent_data = 1007,
/// Indicates that an endpoint is terminating the connection because it has
/// received a message that violates its policy. This is a generic status
/// code that can be returned when there is no other more suitable status code
/// (e.g., 1003 or 1009) or if there is a need to hide specific details about
/// the policy.
policy_violation = 1008,
/// Indicates that an endpoint is terminating the connection because it has
/// received a message that is too big for it to process.
message_too_big = 1009,
/// Indicates that an endpoint (client) is terminating the connection because
/// it has expected the server to negotiate one or more extension, but the
/// server didn't return them in the response message of the WebSocket
/// handshake. The list of extensions that are needed SHOULD appear in the
/// /reason/ part of the Close frame. Note that this status code is not used
/// by the server, because it can fail the WebSocket handshake instead.
missing_extensions = 1010,
/// Indicates that a server is terminating the connection because it
/// encountered an unexpected condition that prevented it from fulfilling the
/// request.
unexpected_condition = 1011,
/// A reserved value and MUST NOT be set as a status code in a Close control
/// frame by an endpoint. It is designated for use in applications expecting
/// a status code to indicate that the connection was closed due to a failure
/// to perform a TLS handshake (e.g., the server certificate can't be
/// verified).
tls_handshake_failure = 1015,
};
/// @relates status
CAF_NET_EXPORT std::string to_string(status);
/// @relates status
CAF_NET_EXPORT bool from_string(string_view, status&);
/// @relates status
CAF_NET_EXPORT bool from_integer(std::underlying_type_t<status>, status&);
/// @relates status
template <class Inspector>
bool inspect(Inspector& f, status& x) {
return default_enum_inspect(f, x);
}
} // namespace caf::net::web_socket
......@@ -10,6 +10,11 @@
#include <random>
#include <tuple>
#include <iostream>
#include "caf/config.hpp"
#include "caf/detail/base64.hpp"
#include "caf/hash/sha1.hpp"
......@@ -117,6 +122,15 @@ void handshake::write_http_1_response(byte_buffer& buf) const {
<< response_key() << "\r\n\r\n";
}
void handshake::write_http_1_bad_request(byte_buffer& buf, string_view descr) {
std::cout<<"BAD REQUEST: "<<descr<<'\n';
writer out{&buf};
out << "HTTP/1.1 400 Bad Request\r\n"
"Content-Type: text/plain\r\n"
"\r\n"
<< descr << "\r\n";
}
void handshake::write_http_1_header_too_large(byte_buffer& buf) {
writer out{&buf};
out << "HTTP/1.1 431 Request Header Fields Too Large\r\n"
......
......@@ -113,7 +113,7 @@ struct fixture : host_fixture {
};
constexpr auto opening_handshake
= "GET /chat HTTP/1.1\r\n"
= "GET /chat?room=lounge HTTP/1.1\r\n"
"Host: server.example.com\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
......@@ -143,7 +143,7 @@ CAF_TEST(applications receive handshake data via config) {
CAF_CHECK_EQUAL(transport.unconsumed(), 0u);
CAF_CHECK(ws->handshake_complete());
CHECK_SETTING("web-socket.method", "GET");
CHECK_SETTING("web-socket.request-uri", "/chat");
CHECK_SETTING("web-socket.path", "/chat");
CHECK_SETTING("web-socket.http-version", "HTTP/1.1");
CHECK_SETTING("web-socket.fields.Host", "server.example.com");
CHECK_SETTING("web-socket.fields.Upgrade", "websocket");
......@@ -153,6 +153,10 @@ CAF_TEST(applications receive handshake data via config) {
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Version", "13");
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Key",
"dGhlIHNhbXBsZSBub25jZQ==");
using str_map = std::map<std::string, std::string>;
if (auto query = get_as<str_map>(app->cfg, "web-socket.query");
CAF_CHECK(query))
CAF_CHECK_EQUAL(*query, str_map({{"room"s, "lounge"s}}));
}
CAF_TEST(the server responds with an HTTP response on success) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment