Commit c6fbe7ad authored by Dominik Charousset's avatar Dominik Charousset

Improve WebSocket API

- Implement proper WebSocket close handshake.
- Add convenience API for flow-based WebSocket servers.
parent 61fa2989
...@@ -17,6 +17,7 @@ caf_incubator_add_component( ...@@ -17,6 +17,7 @@ caf_incubator_add_component(
net.basp.ec net.basp.ec
net.basp.message_type net.basp.message_type
net.operation net.operation
net.web_socket.status
HEADERS HEADERS
${CAF_NET_HEADERS} ${CAF_NET_HEADERS}
SOURCES SOURCES
......
...@@ -31,7 +31,8 @@ public: ...@@ -31,7 +31,8 @@ public:
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
template <class... Ts> template <class... Ts>
explicit connection_acceptor(Ts&&... xs) : factory_(std::forward<Ts>(xs)...) { explicit connection_acceptor(size_t limit, Ts&&... xs)
: factory_(std::forward<Ts>(xs)...), limit_(limit) {
// nop // nop
} }
...@@ -54,13 +55,21 @@ public: ...@@ -54,13 +55,21 @@ public:
CAF_LOG_TRACE(""); CAF_LOG_TRACE("");
if (auto x = accept(parent->handle())) { if (auto x = accept(parent->handle())) {
socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr()); socket_manager_ptr child = factory_.make(*x, owner_->mpx_ptr());
CAF_ASSERT(child != nullptr); if (!child) {
CAF_LOG_ERROR("factory failed to create a new child");
parent->abort_reason(sec::runtime_error);
return false;
}
if (auto err = child->init(cfg_)) { if (auto err = child->init(cfg_)) {
CAF_LOG_ERROR("failed to initialize new child:" << err); CAF_LOG_ERROR("failed to initialize new child:" << err);
parent->abort_reason(std::move(err)); parent->abort_reason(std::move(err));
return false; return false;
} }
if (limit_ == 0) {
return true; return true;
} else {
return ++accepted_ < limit_;
}
} else { } else {
CAF_LOG_ERROR("accept failed:" << x.error()); CAF_LOG_ERROR("accept failed:" << x.error());
return false; return false;
...@@ -89,6 +98,10 @@ private: ...@@ -89,6 +98,10 @@ private:
socket_manager* owner_; socket_manager* owner_;
size_t limit_;
size_t accepted_ = 0;
settings cfg_; settings cfg_;
}; };
......
...@@ -121,6 +121,18 @@ public: ...@@ -121,6 +121,18 @@ public:
} }
} }
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr) {
// nop; this framing layer has no close handshake
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr, const error&) {
// nop; this framing layer has no close handshake
return true;
}
template <class LowerLayerPtr> template <class LowerLayerPtr>
static void abort_reason(LowerLayerPtr down, error reason) { static void abort_reason(LowerLayerPtr down, error reason) {
return down->abort_reason(std::move(reason)); return down->abort_reason(std::move(reason));
...@@ -237,8 +249,8 @@ private: ...@@ -237,8 +249,8 @@ private:
/// @param out Outputs from the socket. /// @param out Outputs from the socket.
/// @param trait Converts between the native and the wire format. /// @param trait Converts between the native and the wire format.
/// @relates length_prefix_framing /// @relates length_prefix_framing
template <template <class> class Transport = stream_transport, class T, template <template <class> class Transport = stream_transport, class Socket,
class Socket, class Trait> class T, class Trait>
error run_with_length_prefix_framing(multiplexer& mpx, Socket fd, error run_with_length_prefix_framing(multiplexer& mpx, Socket fd,
const settings& cfg, const settings& cfg,
async::consumer_resource<T> in, async::consumer_resource<T> in,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "caf/sec.hpp" #include "caf/sec.hpp"
#include "caf/settings.hpp" #include "caf/settings.hpp"
#include "caf/tag/message_oriented.hpp" #include "caf/tag/message_oriented.hpp"
#include "caf/tag/mixed_message_oriented.hpp"
#include "caf/tag/no_auto_reading.hpp" #include "caf/tag/no_auto_reading.hpp"
#include <utility> #include <utility>
...@@ -29,56 +30,78 @@ namespace caf::net { ...@@ -29,56 +30,78 @@ namespace caf::net {
/// bool convert(const_byte_span bytes, T& value); /// bool convert(const_byte_span bytes, T& value);
/// }; /// };
/// ~~~ /// ~~~
template <class T, class Trait> template <class T, class Trait, class Tag = tag::message_oriented>
class message_flow_bridge : public caf::tag::no_auto_reading { class message_flow_bridge : public tag::no_auto_reading {
public: public:
using input_tag = caf::tag::message_oriented; using input_tag = Tag;
using buffer_type = caf::async::spsc_buffer<T>; using buffer_type = async::spsc_buffer<T>;
using consumer_resource_t = async::consumer_resource<T>;
using producer_resource_t = async::producer_resource<T>;
explicit message_flow_bridge(Trait trait) : trait_(std::move(trait)) { explicit message_flow_bridge(Trait trait) : trait_(std::move(trait)) {
// nop // nop
} }
void connect_flows(caf::net::socket_manager* mgr, void connect_flows(net::socket_manager* mgr, consumer_resource_t in,
async::consumer_resource<T> in, producer_resource_t out) {
async::producer_resource<T> out) {
in_ = consumer_adapter<buffer_type>::try_open(mgr, in); in_ = consumer_adapter<buffer_type>::try_open(mgr, in);
out_ = producer_adapter<buffer_type>::try_open(mgr, out); out_ = producer_adapter<buffer_type>::try_open(mgr, out);
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
caf::error error
init(caf::net::socket_manager* mgr, LowerLayerPtr&&, const caf::settings&) { init(net::socket_manager* mgr, LowerLayerPtr down, const settings& cfg) {
mgr_ = mgr; mgr_ = mgr;
if constexpr (caf::detail::has_init_v<Trait>) {
if (auto err = init_res(trait_.init(cfg)))
return err;
}
if (!in_ && !out_) if (!in_ && !out_)
return make_error(sec::cannot_open_resource, return make_error(sec::cannot_open_resource,
"flow bridge cannot run without at least one resource"); "flow bridge cannot run without at least one resource");
else if (!out_)
return caf::none; down->suspend_reading();
return none;
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool write(LowerLayerPtr down, const T& item) { bool write(LowerLayerPtr down, const T& item) {
if constexpr (std::is_same_v<Tag, tag::message_oriented>) {
down->begin_message(); down->begin_message();
auto& buf = down->message_buffer(); auto& buf = down->message_buffer();
return trait_.convert(item, buf) && down->end_message(); return trait_.convert(item, buf) && down->end_message();
} else {
static_assert(std::is_same_v<Tag, tag::mixed_message_oriented>);
if (trait_.converts_to_binary(item)) {
down->begin_binary_message();
auto& buf = down->binary_message_buffer();
return trait_.convert(item, buf) && down->end_binary_message();
} else {
down->begin_text_message();
auto& buf = down->text_message_buffer();
return trait_.convert(item, buf) && down->end_text_message();
}
}
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
struct send_helper { struct write_helper {
using bridge_type = message_flow_bridge; using bridge_type = message_flow_bridge;
bridge_type* bridge; bridge_type* bridge;
LowerLayerPtr down; LowerLayerPtr down;
bool aborted = false; bool aborted = false;
size_t consumed = 0; size_t consumed = 0;
error err;
send_helper(bridge_type* bridge, LowerLayerPtr down) write_helper(bridge_type* bridge, LowerLayerPtr down)
: bridge(bridge), down(down) { : bridge(bridge), down(down) {
// nop // nop
} }
void on_next(caf::span<const T> items) { void on_next(span<const T> items) {
CAF_ASSERT(items.size() == 1); CAF_ASSERT(items.size() == 1);
for (const auto& item : items) { for (const auto& item : items) {
if (!bridge->write(down, item)) { if (!bridge->write(down, item)) {
...@@ -92,17 +115,22 @@ public: ...@@ -92,17 +115,22 @@ public:
// nop // nop
} }
void on_error(const caf::error&) { void on_error(const error& x) {
// nop err = x;
} }
}; };
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool prepare_send(LowerLayerPtr down) { bool prepare_send(LowerLayerPtr down) {
send_helper<LowerLayerPtr> helper{this, down}; write_helper<LowerLayerPtr> helper{this, down};
while (down->can_send_more() && in_) { while (down->can_send_more() && in_) {
auto [again, consumed] = in_->pull(caf::async::delay_errors, 1, helper); auto [again, consumed] = in_->pull(async::delay_errors, 1, helper);
if (!again) { if (!again) {
if (helper.err) {
down->send_close_message(helper.err);
} else {
down->send_close_message();
}
in_ = nullptr; in_ = nullptr;
} else if (helper.aborted) { } else if (helper.aborted) {
down->abort_reason(make_error(sec::conversion_failed)); down->abort_reason(make_error(sec::conversion_failed));
...@@ -122,11 +150,10 @@ public: ...@@ -122,11 +150,10 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
void abort(LowerLayerPtr, const caf::error& reason) { void abort(LowerLayerPtr, const error& reason) {
CAF_LOG_TRACE(CAF_ARG(reason)); CAF_LOG_TRACE(CAF_ARG(reason));
if (out_) { if (out_) {
if (reason == caf::sec::socket_disconnected if (reason == sec::socket_disconnected || reason == sec::discarded)
|| reason == caf::sec::discarded)
out_->close(); out_->close();
else else
out_->abort(reason); out_->abort(reason);
...@@ -138,8 +165,29 @@ public: ...@@ -138,8 +165,29 @@ public:
} }
} }
template <class LowerLayerPtr> template <class U = Tag, class LowerLayerPtr>
ptrdiff_t consume(LowerLayerPtr down, caf::byte_span buf) { ptrdiff_t consume(LowerLayerPtr down, byte_span buf) {
if (!out_) {
down->abort_reason(make_error(sec::connection_closed));
return -1;
}
T val;
if (!trait_.convert(buf, val)) {
down->abort_reason(make_error(sec::conversion_failed));
return -1;
}
if (out_->push(std::move(val)) == 0)
down->suspend_reading();
return static_cast<ptrdiff_t>(buf.size());
}
template <class U = Tag, class LowerLayerPtr>
ptrdiff_t consume_binary(LowerLayerPtr down, byte_span buf) {
return consume(down, buf);
}
template <class U = Tag, class LowerLayerPtr>
ptrdiff_t consume_text(LowerLayerPtr down, string_view buf) {
if (!out_) { if (!out_) {
down->abort_reason(make_error(sec::connection_closed)); down->abort_reason(make_error(sec::connection_closed));
return -1; return -1;
...@@ -155,8 +203,35 @@ public: ...@@ -155,8 +203,35 @@ public:
} }
private: private:
error init_res(error err) {
return err;
}
error init_res(consumer_resource_t in, producer_resource_t out) {
connect_flows(mgr_, std::move(in), std::move(out));
return caf::none;
}
error init_res(std::tuple<consumer_resource_t, producer_resource_t> in_out) {
auto& [in, out] = in_out;
return init_res(std::move(in), std::move(out));
}
error init_res(std::pair<consumer_resource_t, producer_resource_t> in_out) {
auto& [in, out] = in_out;
return init_res(std::move(in), std::move(out));
}
template <class R>
error init_res(expected<R> res) {
if (res)
return init_res(*res);
else
return std::move(res.error());
}
/// Points to the manager that runs this protocol stack. /// Points to the manager that runs this protocol stack.
caf::net::socket_manager* mgr_ = nullptr; net::socket_manager* mgr_ = nullptr;
/// Incoming messages, serialized to the socket. /// Incoming messages, serialized to the socket.
consumer_adapter_ptr<buffer_type> in_; consumer_adapter_ptr<buffer_type> in_;
......
...@@ -46,6 +46,11 @@ public: ...@@ -46,6 +46,11 @@ public:
return lptr_->end_message(llptr_); return lptr_->end_message(llptr_);
} }
template <class... Ts>
bool send_close_message(Ts&&... xs) {
return lptr_->send_close_message(llptr_, std::forward<Ts>(xs)...);
}
void abort_reason(error reason) { void abort_reason(error reason) {
return lptr_->abort_reason(llptr_, std::move(reason)); return lptr_->abort_reason(llptr_, std::move(reason));
} }
......
...@@ -44,20 +44,24 @@ public: ...@@ -44,20 +44,24 @@ public:
// -- socket manager functions ----------------------------------------------- // -- socket manager functions -----------------------------------------------
/// /// Creates a new acceptor that accepts incoming connections from @p sock and
/// creates socket managers using @p factory.
/// @param sock An accept socket in listening mode. For a TCP socket, this /// @param sock An accept socket in listening mode. For a TCP socket, this
/// socket must already listen to an address plus port. /// socket must already listen to an address plus port.
/// @param factory An application stack factory. /// @param factory A function object for creating socket managers that take
/// ownership of incoming connections.
/// @param limit The maximum number of connections that this acceptor should
/// establish or 0 for 'no limit'.
template <class Socket, class Factory> template <class Socket, class Factory>
auto make_acceptor(Socket sock, Factory factory) { auto make_acceptor(Socket sock, Factory factory, size_t limit = 0) {
using connected_socket_type = typename Socket::connected_socket_type; using connected_socket_type = typename Socket::connected_socket_type;
if constexpr (detail::is_callable_with<Factory, connected_socket_type, if constexpr (detail::is_callable_with<Factory, connected_socket_type,
multiplexer*>::value) { multiplexer*>::value) {
connection_acceptor_factory_adapter<Factory> adapter{std::move(factory)}; connection_acceptor_factory_adapter<Factory> adapter{std::move(factory)};
return make_acceptor(std::move(sock), std::move(adapter)); return make_acceptor(std::move(sock), std::move(adapter), limit);
} else { } else {
using impl = connection_acceptor<Socket, Factory>; using impl = connection_acceptor<Socket, Factory>;
auto ptr = make_socket_manager<impl>(std::move(sock), &mpx_, auto ptr = make_socket_manager<impl>(std::move(sock), &mpx_, limit,
std::move(factory)); std::move(factory));
mpx_.init(ptr); mpx_.init(ptr);
return ptr; return ptr;
......
...@@ -42,8 +42,8 @@ public: ...@@ -42,8 +42,8 @@ public:
return lptr_->binary_message_buffer(llptr_); return lptr_->binary_message_buffer(llptr_);
} }
void end_binary_message() { bool end_binary_message() {
lptr_->end_binary_message(llptr_); return lptr_->end_binary_message(llptr_);
} }
void begin_text_message() { void begin_text_message() {
...@@ -54,8 +54,13 @@ public: ...@@ -54,8 +54,13 @@ public:
return lptr_->text_message_buffer(llptr_); return lptr_->text_message_buffer(llptr_);
} }
void end_text_message() { bool end_text_message() {
lptr_->end_text_message(llptr_); return lptr_->end_text_message(llptr_);
}
template <class... Ts>
bool send_close_message(Ts&&... xs) {
return lptr_->send_close_message(llptr_, std::forward<Ts>(xs)...);
} }
void abort_reason(error reason) { void abort_reason(error reason) {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "caf/byte_span.hpp" #include "caf/byte_span.hpp"
#include "caf/detail/rfc6455.hpp" #include "caf/detail/rfc6455.hpp"
#include "caf/net/mixed_message_oriented_layer_ptr.hpp" #include "caf/net/mixed_message_oriented_layer_ptr.hpp"
#include "caf/net/web_socket/status.hpp"
#include "caf/sec.hpp" #include "caf/sec.hpp"
#include "caf/span.hpp" #include "caf/span.hpp"
#include "caf/string_view.hpp" #include "caf/string_view.hpp"
...@@ -20,7 +21,7 @@ ...@@ -20,7 +21,7 @@
namespace caf::net::web_socket { namespace caf::net::web_socket {
/// Implements the WebProtocol framing protocol as defined in RFC-6455. /// Implements the WebSocket framing protocol as defined in RFC-6455.
template <class UpperLayer> template <class UpperLayer>
class framing { class framing {
public: public:
...@@ -87,11 +88,8 @@ public: ...@@ -87,11 +88,8 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
static void suspend_reading(LowerLayerPtr) { static void suspend_reading(LowerLayerPtr down) {
CAF_RAISE_ERROR("suspending / resuming a WebSocket not implemented yet"); down->configure_read(receive_policy::stop());
// TODO: uncommenting this isn't enough since consume() also must make sure
// to not override the configure_read.
// down->configure_read(receive_policy::stop());
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
...@@ -105,8 +103,9 @@ public: ...@@ -105,8 +103,9 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
void end_binary_message(LowerLayerPtr down) { bool end_binary_message(LowerLayerPtr down) {
ship_frame(down, binary_buf_); ship_frame(down, binary_buf_);
return true;
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
...@@ -120,8 +119,28 @@ public: ...@@ -120,8 +119,28 @@ public:
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
void end_text_message(LowerLayerPtr down) { bool end_text_message(LowerLayerPtr down) {
ship_frame(down, text_buf_); ship_frame(down, text_buf_);
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr down) {
ship_close(down);
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr down, status code, string_view desc) {
ship_close(down, static_cast<uint16_t>(code), desc);
return true;
}
template <class LowerLayerPtr>
bool send_close_message(LowerLayerPtr down, const error& reason) {
ship_close(down, static_cast<uint16_t>(status::unexpected_condition),
to_string(reason));
return true;
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
...@@ -138,12 +157,12 @@ public: ...@@ -138,12 +157,12 @@ public:
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool prepare_send(LowerLayerPtr down) { bool prepare_send(LowerLayerPtr down) {
return upper_layer_.prepare_send(down); return upper_layer_.prepare_send(this_layer_ptr(down));
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool done_sending(LowerLayerPtr down) { bool done_sending(LowerLayerPtr down) {
return upper_layer_.done_sending(down); return upper_layer_.done_sending(this_layer_ptr(down));
} }
template <class LowerLayerPtr> template <class LowerLayerPtr>
...@@ -185,6 +204,9 @@ public: ...@@ -185,6 +204,9 @@ public:
// Wait for more data if necessary. // Wait for more data if necessary.
size_t frame_size = hdr_bytes + hdr.payload_len; size_t frame_size = hdr_bytes + hdr.payload_len;
if (buffer.size() < frame_size) { if (buffer.size() < frame_size) {
// Ask for more data unless the upper layer called suspend_reading.
if (!down->stopped())
down->configure_read(receive_policy::up_to(2048));
down->configure_read(receive_policy::exactly(frame_size)); down->configure_read(receive_policy::exactly(frame_size));
return consumed; return consumed;
} }
...@@ -196,6 +218,7 @@ public: ...@@ -196,6 +218,7 @@ public:
} }
if (hdr.fin) { if (hdr.fin) {
if (opcode_ == nil_code) { if (opcode_ == nil_code) {
// Call upper layer.
if (!handle(down, hdr.opcode, payload)) if (!handle(down, hdr.opcode, payload))
return -1; return -1;
} else if (hdr.opcode != detail::rfc6455::continuation_frame) { } else if (hdr.opcode != detail::rfc6455::continuation_frame) {
...@@ -243,6 +266,8 @@ public: ...@@ -243,6 +266,8 @@ public:
// Advance to next frame in the input. // Advance to next frame in the input.
buffer = buffer.subspan(frame_size); buffer = buffer.subspan(frame_size);
if (buffer.empty()) { if (buffer.empty()) {
// Ask for more data unless the upper layer called suspend_reading.
if (!down->stopped())
down->configure_read(receive_policy::up_to(2048)); down->configure_read(receive_policy::up_to(2048));
return consumed + static_cast<ptrdiff_t>(frame_size); return consumed + static_cast<ptrdiff_t>(frame_size);
} }
...@@ -291,6 +316,42 @@ private: ...@@ -291,6 +316,42 @@ private:
down->end_output(); down->end_output();
} }
template <class LowerLayerPtr>
void ship_close(LowerLayerPtr down, uint16_t code, string_view msg) {
uint32_t mask_key = 0;
std::vector<byte> payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<byte>((code & 0xFF00) >> 8));
payload.push_back(static_cast<byte>(code & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down->output_buffer());
down->end_output();
}
template <class LowerLayerPtr>
void ship_close(LowerLayerPtr down) {
uint32_t mask_key = 0;
byte payload[] = {
byte{0x03}, byte{0xE8}, // Error code 1000: normal close.
byte{'E'}, byte{'O'}, byte{'F'}, // "EOF" string as goodbye message.
};
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down->output_buffer());
down->end_output();
}
template <class LowerLayerPtr, class T> template <class LowerLayerPtr, class T>
void ship_frame(LowerLayerPtr down, std::vector<T>& buf) { void ship_frame(LowerLayerPtr down, std::vector<T>& buf) {
uint32_t mask_key = 0; uint32_t mask_key = 0;
......
...@@ -125,6 +125,10 @@ public: ...@@ -125,6 +125,10 @@ public:
/// @pre `has_valid_key()` /// @pre `has_valid_key()`
void write_http_1_response(byte_buffer& buf) const; void write_http_1_response(byte_buffer& buf) const;
/// Writes an HTTP 1.1 'Bad Request' error to `buf` with `descr` providing
/// additional information to the client.
static void write_http_1_bad_request(byte_buffer& buf, string_view descr);
/// Writes a HTTP 1.1 431 (Request Header Fields Too Large) response. /// Writes a HTTP 1.1 431 (Request Header Fields Too Large) response.
static void write_http_1_header_too_large(byte_buffer& buf); static void write_http_1_header_too_large(byte_buffer& buf);
......
...@@ -9,10 +9,15 @@ ...@@ -9,10 +9,15 @@
#include "caf/byte_span.hpp" #include "caf/byte_span.hpp"
#include "caf/error.hpp" #include "caf/error.hpp"
#include "caf/logger.hpp" #include "caf/logger.hpp"
#include "caf/net/connection_acceptor.hpp"
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/message_flow_bridge.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/receive_policy.hpp" #include "caf/net/receive_policy.hpp"
#include "caf/net/socket_manager.hpp"
#include "caf/net/web_socket/framing.hpp" #include "caf/net/web_socket/framing.hpp"
#include "caf/net/web_socket/handshake.hpp" #include "caf/net/web_socket/handshake.hpp"
#include "caf/net/web_socket/status.hpp"
#include "caf/pec.hpp" #include "caf/pec.hpp"
#include "caf/settings.hpp" #include "caf/settings.hpp"
#include "caf/tag/mixed_message_oriented.hpp" #include "caf/tag/mixed_message_oriented.hpp"
...@@ -133,20 +138,50 @@ private: ...@@ -133,20 +138,50 @@ private:
template <class LowerLayerPtr> template <class LowerLayerPtr>
bool handle_header(LowerLayerPtr down, string_view http) { bool handle_header(LowerLayerPtr down, string_view http) {
using namespace std::literals;
// Parse the first line, i.e., "METHOD REQUEST-URI VERSION". // Parse the first line, i.e., "METHOD REQUEST-URI VERSION".
auto [first_line, remainder] = split(http, "\r\n"); auto [first_line, remainder] = split(http, "\r\n");
auto [method, request_uri, version] = split2(first_line, " "); auto [method, request_uri_str, version] = split2(first_line, " ");
auto& hdr = cfg_["web-socket"].as_dictionary(); auto& hdr = cfg_["web-socket"].as_dictionary();
if (method != "GET") { if (method != "GET") {
down->begin_output();
handshake::write_http_1_bad_request(down->output_buffer(),
"Expected WebSocket handshake.");
down->end_output();
auto err = make_error(pec::invalid_argument, auto err = make_error(pec::invalid_argument,
"invalid operation: expected GET, got " "invalid operation: expected GET, got "
+ to_string(method)); + to_string(method));
down->abort_reason(std::move(err)); down->abort_reason(std::move(err));
return false; return false;
} }
// The path must be absolute.
if (request_uri_str.empty() || request_uri_str.front() != '/') {
auto descr = "Malformed Request-URI path: expected absolute path."s;
down->begin_output();
handshake::write_http_1_bad_request(down->output_buffer(), descr);
down->end_output();
down->abort_reason(make_error(pec::invalid_argument, std::move(descr)));
return false;
}
// The path must form a valid URI when prefixing a scheme. We don't actually
// care about the scheme, so just use "foo" here for the validation step.
uri request_uri;
if (auto res = make_uri("foo://localhost" + to_string(request_uri_str))) {
request_uri = std::move(*res);
} else {
auto descr = "Malformed Request-URI path: " + to_string(res.error());
descr += '.';
down->begin_output();
handshake::write_http_1_bad_request(down->output_buffer(), descr);
down->end_output();
down->abort_reason(make_error(pec::invalid_argument, std::move(descr)));
return false;
}
// Store the request information in the settings for the upper layer. // Store the request information in the settings for the upper layer.
put(hdr, "method", method); put(hdr, "method", method);
put(hdr, "request-uri", request_uri); put(hdr, "path", request_uri.path());
put(hdr, "query", request_uri.query());
put(hdr, "fragment", request_uri.fragment());
put(hdr, "http-version", version); put(hdr, "http-version", version);
// Store the remaining header fields. // Store the remaining header fields.
auto& fields = hdr["fields"].as_dictionary(); auto& fields = hdr["fields"].as_dictionary();
...@@ -165,21 +200,27 @@ private: ...@@ -165,21 +200,27 @@ private:
skey_field && hs.assign_key(*skey_field)) { skey_field && hs.assign_key(*skey_field)) {
CAF_LOG_DEBUG("received Sec-WebSocket-Key" << *skey_field); CAF_LOG_DEBUG("received Sec-WebSocket-Key" << *skey_field);
} else { } else {
auto descr = "Mandatory field Sec-WebSocket-Key missing or invalid."s;
down->begin_output();
handshake::write_http_1_bad_request(down->output_buffer(), descr);
down->end_output();
CAF_LOG_DEBUG("received invalid WebSocket handshake"); CAF_LOG_DEBUG("received invalid WebSocket handshake");
down->abort_reason( down->abort_reason(make_error(pec::missing_field, std::move(descr)));
make_error(pec::missing_field,
"mandatory field Sec-WebSocket-Key missing or invalid"));
return false; return false;
} }
// Send server handshake.
down->begin_output();
hs.write_http_1_response(down->output_buffer());
down->end_output();
// Try initializing the upper layer. // Try initializing the upper layer.
if (auto err = upper_layer_.init(owner_, down, cfg_)) { if (auto err = upper_layer_.init(owner_, down, cfg_)) {
auto descr = to_string(err);
down->begin_output();
handshake::write_http_1_bad_request(down->output_buffer(), descr);
down->end_output();
down->abort_reason(std::move(err)); down->abort_reason(std::move(err));
return false; return false;
} }
// Send server handshake.
down->begin_output();
hs.write_http_1_response(down->output_buffer());
down->end_output();
// Done. // Done.
CAF_LOG_DEBUG("completed WebSocket handshake"); CAF_LOG_DEBUG("completed WebSocket handshake");
handshake_complete_ = true; handshake_complete_ = true;
...@@ -234,8 +275,183 @@ private: ...@@ -234,8 +275,183 @@ private:
socket_manager* owner_ = nullptr; socket_manager* owner_ = nullptr;
/// Holds a copy of the settings in order to delay initialization of the upper /// Holds a copy of the settings in order to delay initialization of the upper
/// layer until the handshake completed. /// layer until the handshake completed. We also fill this dictionary with the
/// contents of the HTTP GET header.
settings cfg_; settings cfg_;
}; };
/// Creates a WebSocket server on the connected socket `fd`.
/// @param mpx The multiplexer that takes ownership of the socket.
/// @param fd A connected stream socket.
/// @param in Inputs for writing to the socket.
/// @param out Outputs from the socket.
/// @param trait Converts between the native and the wire format.
template <template <class> class Transport = stream_transport, class Socket,
class T, class Trait>
socket_manager_ptr make_server(multiplexer& mpx, Socket fd,
async::consumer_resource<T> in,
async::producer_resource<T> out, Trait trait) {
using app_t = message_flow_bridge<T, Trait, tag::mixed_message_oriented>;
using stack_t = Transport<server<app_t>>;
auto mgr = make_socket_manager<stack_t>(fd, &mpx, std::move(trait));
mgr->top_layer().connect_flows(mgr.get(), std::move(in), std::move(out));
return mgr;
}
} // namespace caf::net::web_socket
namespace caf::detail {
template <class T, class Trait>
using on_request_result = expected<
std::tuple<async::consumer_resource<T>, // For the connection to read from.
async::producer_resource<T>, // For the connection to write to.
Trait>>; // For translating between native and wire format.
template <class T>
struct is_on_request_result : std::false_type {};
template <class T, class Trait>
struct is_on_request_result<on_request_result<T, Trait>> : std::true_type {};
template <class T>
struct on_request_trait;
template <class T, class ServerTrait>
struct on_request_trait<on_request_result<T, ServerTrait>> {
using value_type = T;
using trait_type = ServerTrait;
};
template <class OnRequest>
class ws_accept_trait {
public:
using on_request_r
= decltype(std::declval<OnRequest&>()(std::declval<const settings&>()));
static_assert(is_on_request_result<on_request_r>::value,
"OnRequest must return an on_request_result");
using on_request_t = on_request_trait<on_request_r>;
using value_type = typename on_request_t::value_type;
using decorated_trait = typename on_request_t::trait_type;
using consumer_resource_t = async::consumer_resource<value_type>;
using producer_resource_t = async::producer_resource<value_type>;
using in_out_tuple = std::tuple<consumer_resource_t, producer_resource_t>;
using init_res_t = expected<in_out_tuple>;
ws_accept_trait() = delete;
explicit ws_accept_trait(OnRequest on_request) : state_(on_request) {
// nop
}
ws_accept_trait(ws_accept_trait&&) = default;
ws_accept_trait& operator=(ws_accept_trait&&) = default;
init_res_t init(const settings& cfg) {
auto f = std::move(std::get<OnRequest>(state_));
if (auto res = f(cfg)) {
auto& [in, out, trait] = *res;
if (auto err = trait.init(cfg)) {
state_ = none;
return std::move(res.error());
} else {
state_ = std::move(trait);
return std::make_tuple(std::move(in), std::move(out));
}
} else {
state_ = none;
return std::move(res.error());
}
}
bool converts_to_binary(const value_type& x) {
return decorated().converts_to_binary(x);
}
bool convert(const value_type& x, byte_buffer& bytes) {
return decorated().convert(x, bytes);
}
bool convert(const value_type& x, std::vector<char>& text) {
return decorated().convert(x, text);
}
bool convert(const_byte_span bytes, int32_t&x) {
return decorated().convert(bytes, x);
}
bool convert(string_view text, int32_t& x) {
return decorated().convert(text, x);
}
private:
decorated_trait& decorated() {
return std::get<decorated_trait>(state_);
}
std::variant<none_t, OnRequest, decorated_trait> state_;
};
template <template <class> class Transport, class OnRequest>
class ws_acceptor_factory {
public:
explicit ws_acceptor_factory(OnRequest on_request)
: on_request_(std::move(on_request)) {
// nop
}
error init(net::socket_manager*, const settings&) {
return none;
}
template <class Socket>
net::socket_manager_ptr make(Socket fd, net::multiplexer* mpx) {
using trait_t = ws_accept_trait<OnRequest>;
using value_type = typename trait_t::value_type;
using app_t = net::message_flow_bridge<value_type, trait_t,
tag::mixed_message_oriented>;
using stack_t = Transport<net::web_socket::server<app_t>>;
return net::make_socket_manager<stack_t>(fd, mpx, trait_t{on_request_});
}
void abort(const error&) {
// nop
}
private:
OnRequest on_request_;
};
} // namespace caf::detail
namespace caf::net::web_socket {
/// Creates a WebSocket server on the connected socket `fd`.
/// @param mpx The multiplexer that takes ownership of the socket.
/// @param fd An accept socket in listening mode. For a TCP socket, this socket
/// must already listen to an address plus port.
/// @param on_request Function object for turning requests into a tuple
/// consisting of a consumer resource, a producer resource,
/// and a trait. These arguments get forwarded to
/// @ref make_server internally.
template <template <class> class Transport = stream_transport, class Socket,
class OnRequest>
void accept(multiplexer& mpx, Socket fd, OnRequest on_request,
size_t limit = 0) {
using factory = detail::ws_acceptor_factory<Transport, OnRequest>;
using impl = connection_acceptor<Socket, factory>;
auto ptr = make_socket_manager<impl>(std::move(fd), &mpx, limit,
factory{std::move(on_request)});
mpx.init(ptr);
}
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/default_enum_inspect.hpp"
#include "caf/detail/net_export.hpp"
#include <cstdint>
namespace caf::net::web_socket {
/// Status codes as defined by RFC 6455, Section 7.4.
enum class status : uint16_t {
/// Indicates a normal closure, meaning that the purpose for which the
/// connection was established has been fulfilled.
normal_close = 1000,
/// Indicates that an endpoint is "going away", such as a server going down or
/// a browser having navigated away from a page.
going_away = 1001,
/// Indicates that an endpoint is terminating the connection due to a protocol
/// error.
protocol_error = 1002,
/// Indicates that an endpoint is terminating the connection because it has
/// received a type of data it cannot accept (e.g., an endpoint that
/// understands only text data MAY send this if it receives a binary message).
invalid_data = 1003,
/// A reserved value and MUST NOT be set as a status code in a Close control
/// frame by an endpoint. It is designated for use in applications expecting
/// a status code to indicate that no status code was actually present.
no_status = 1005,
/// A reserved value and MUST NOT be set as a status code in a Close control
/// frame by an endpoint. It is designated for use in applications expecting
/// a status code to indicate that the connection was closed abnormally, e.g.,
/// without sending or receiving a Close control frame.
abnormal_exit = 1006,
/// Indicates that an endpoint is terminating the connection because it has
/// received data within a message that was not consistent with the type of
/// the message (e.g., non-UTF-8 [RFC3629] data within a text message).
inconsistent_data = 1007,
/// Indicates that an endpoint is terminating the connection because it has
/// received a message that violates its policy. This is a generic status
/// code that can be returned when there is no other more suitable status code
/// (e.g., 1003 or 1009) or if there is a need to hide specific details about
/// the policy.
policy_violation = 1008,
/// Indicates that an endpoint is terminating the connection because it has
/// received a message that is too big for it to process.
message_too_big = 1009,
/// Indicates that an endpoint (client) is terminating the connection because
/// it has expected the server to negotiate one or more extension, but the
/// server didn't return them in the response message of the WebSocket
/// handshake. The list of extensions that are needed SHOULD appear in the
/// /reason/ part of the Close frame. Note that this status code is not used
/// by the server, because it can fail the WebSocket handshake instead.
missing_extensions = 1010,
/// Indicates that a server is terminating the connection because it
/// encountered an unexpected condition that prevented it from fulfilling the
/// request.
unexpected_condition = 1011,
/// A reserved value and MUST NOT be set as a status code in a Close control
/// frame by an endpoint. It is designated for use in applications expecting
/// a status code to indicate that the connection was closed due to a failure
/// to perform a TLS handshake (e.g., the server certificate can't be
/// verified).
tls_handshake_failure = 1015,
};
/// @relates status
CAF_NET_EXPORT std::string to_string(status);
/// @relates status
CAF_NET_EXPORT bool from_string(string_view, status&);
/// @relates status
CAF_NET_EXPORT bool from_integer(std::underlying_type_t<status>, status&);
/// @relates status
template <class Inspector>
bool inspect(Inspector& f, status& x) {
return default_enum_inspect(f, x);
}
} // namespace caf::net::web_socket
...@@ -10,6 +10,11 @@ ...@@ -10,6 +10,11 @@
#include <random> #include <random>
#include <tuple> #include <tuple>
#include <iostream>
#include "caf/config.hpp" #include "caf/config.hpp"
#include "caf/detail/base64.hpp" #include "caf/detail/base64.hpp"
#include "caf/hash/sha1.hpp" #include "caf/hash/sha1.hpp"
...@@ -117,6 +122,15 @@ void handshake::write_http_1_response(byte_buffer& buf) const { ...@@ -117,6 +122,15 @@ void handshake::write_http_1_response(byte_buffer& buf) const {
<< response_key() << "\r\n\r\n"; << response_key() << "\r\n\r\n";
} }
void handshake::write_http_1_bad_request(byte_buffer& buf, string_view descr) {
std::cout<<"BAD REQUEST: "<<descr<<'\n';
writer out{&buf};
out << "HTTP/1.1 400 Bad Request\r\n"
"Content-Type: text/plain\r\n"
"\r\n"
<< descr << "\r\n";
}
void handshake::write_http_1_header_too_large(byte_buffer& buf) { void handshake::write_http_1_header_too_large(byte_buffer& buf) {
writer out{&buf}; writer out{&buf};
out << "HTTP/1.1 431 Request Header Fields Too Large\r\n" out << "HTTP/1.1 431 Request Header Fields Too Large\r\n"
......
...@@ -113,7 +113,7 @@ struct fixture : host_fixture { ...@@ -113,7 +113,7 @@ struct fixture : host_fixture {
}; };
constexpr auto opening_handshake constexpr auto opening_handshake
= "GET /chat HTTP/1.1\r\n" = "GET /chat?room=lounge HTTP/1.1\r\n"
"Host: server.example.com\r\n" "Host: server.example.com\r\n"
"Upgrade: websocket\r\n" "Upgrade: websocket\r\n"
"Connection: Upgrade\r\n" "Connection: Upgrade\r\n"
...@@ -143,7 +143,7 @@ CAF_TEST(applications receive handshake data via config) { ...@@ -143,7 +143,7 @@ CAF_TEST(applications receive handshake data via config) {
CAF_CHECK_EQUAL(transport.unconsumed(), 0u); CAF_CHECK_EQUAL(transport.unconsumed(), 0u);
CAF_CHECK(ws->handshake_complete()); CAF_CHECK(ws->handshake_complete());
CHECK_SETTING("web-socket.method", "GET"); CHECK_SETTING("web-socket.method", "GET");
CHECK_SETTING("web-socket.request-uri", "/chat"); CHECK_SETTING("web-socket.path", "/chat");
CHECK_SETTING("web-socket.http-version", "HTTP/1.1"); CHECK_SETTING("web-socket.http-version", "HTTP/1.1");
CHECK_SETTING("web-socket.fields.Host", "server.example.com"); CHECK_SETTING("web-socket.fields.Host", "server.example.com");
CHECK_SETTING("web-socket.fields.Upgrade", "websocket"); CHECK_SETTING("web-socket.fields.Upgrade", "websocket");
...@@ -153,6 +153,10 @@ CAF_TEST(applications receive handshake data via config) { ...@@ -153,6 +153,10 @@ CAF_TEST(applications receive handshake data via config) {
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Version", "13"); CHECK_SETTING("web-socket.fields.Sec-WebSocket-Version", "13");
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Key", CHECK_SETTING("web-socket.fields.Sec-WebSocket-Key",
"dGhlIHNhbXBsZSBub25jZQ=="); "dGhlIHNhbXBsZSBub25jZQ==");
using str_map = std::map<std::string, std::string>;
if (auto query = get_as<str_map>(app->cfg, "web-socket.query");
CAF_CHECK(query))
CAF_CHECK_EQUAL(*query, str_map({{"room"s, "lounge"s}}));
} }
CAF_TEST(the server responds with an HTTP response on success) { CAF_TEST(the server responds with an HTTP response on success) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment