Commit db0db507 authored by Dominik Charousset's avatar Dominik Charousset

Enbale HTTP routes to switch to WebSocket

parent 215e4ed3
// Simple WebSocket server that sends local files from a working directory to
// the clients.
// Simple HTTP/WebSocket server that sends predefined text snippets
// (philosophers quotes) to the client. Clients may either ask for a single
// quote via HTTP GET request or for all quotes of a selected philosopher by
// connecting via WebSocket.
#include "caf/actor_system.hpp"
#include "caf/actor_system_config.hpp"
......@@ -7,9 +9,10 @@
#include "caf/cow_string.hpp"
#include "caf/cow_tuple.hpp"
#include "caf/event_based_actor.hpp"
#include "caf/net/http/with.hpp"
#include "caf/net/middleman.hpp"
#include "caf/net/ssl/context.hpp"
#include "caf/net/web_socket/with.hpp"
#include "caf/net/web_socket/switch_protocol.hpp"
#include "caf/scheduled_actor/flow.hpp"
#include "caf/span.hpp"
......@@ -81,12 +84,12 @@ struct config : caf::actor_system_config {
// -- helper functions ---------------------------------------------------------
// Returns a list of philosopher quotes by path.
caf::span<const std::string_view> quotes_by_path(std::string_view path) {
if (path == "/epictetus")
caf::span<const std::string_view> quotes_by_name(std::string_view path) {
if (path == "epictetus")
return caf::make_span(epictetus);
else if (path == "/seneca")
else if (path == "seneca")
return caf::make_span(seneca);
else if (path == "/plato")
else if (path == "plato")
return caf::make_span(plato);
else
return {};
......@@ -109,6 +112,13 @@ private:
std::minstd_rand engine_;
};
std::string not_found_str(std::string_view name) {
auto result = "Name '"s;
result += name;
result += "' not found. Try 'epictetus', 'seneca' or 'plato'.";
return result;
}
// -- main ---------------------------------------------------------------------
int caf_main(caf::actor_system& sys, const config& cfg) {
......@@ -129,7 +139,7 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
// Open up a TCP port for incoming connections and start the server.
using trait = ws::default_trait;
auto server
= ws::with(sys)
= http::with(sys)
// Optionally enable TLS.
.context(ssl::context::enable(key_file && cert_file)
.and_then(ssl::emplace_server(ssl::tls::v1_2))
......@@ -139,34 +149,61 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
.accept(port)
// Limit how many clients may be connected at any given time.
.max_connections(max_connections)
// Forward the path from the WebSocket request to the worker.
.on_request([](ws::acceptor<caf::cow_string>& acc) {
// The hdr parameter is a dictionary with fields from the WebSocket
// handshake such as the path. This is only field we care about
// here. By passing the (copy-on-write) string to accept() here, we
// make it available to the worker through the acceptor_resource.
acc.accept(caf::cow_string{acc.header().path()});
// On "/quote/<arg>", we pick one random quote for the client.
.route("/quote/<arg>", http::method::get,
[](http::responder& res, std::string name) {
auto quotes = quotes_by_name(name);
if (quotes.empty()) {
res.respond(http::status::not_found, "text/plain",
not_found_str(name));
} else {
pick_random f;
res.respond(http::status::ok, "text/plain", f(quotes));
}
})
// On "/ws/quotes/<arg>", we switch the protocol to WebSocket.
.route("/ws/quotes/<arg>", http::method::get,
ws::switch_protocol()
// Check that the client asks for a known philosopher.
.on_request(
[](ws::acceptor<caf::cow_string>& acc, std::string name) {
auto quotes = quotes_by_name(name);
if (quotes.empty()) {
auto err = make_error(caf::sec::invalid_argument,
not_found_str(name));
acc.reject(std::move(err));
} else {
// Forward the name to the WebSocket worker.
acc.accept(caf::cow_string{std::move(name)});
}
})
// When started, run our worker actor to handle incoming connections.
.start([&sys](trait::acceptor_resource<caf::cow_string> events) {
// Spawn a worker for the WebSocket clients.
.on_start(
[&sys](trait::acceptor_resource<caf::cow_string> events) {
// Spawn a worker that reads from `events`.
using event_t = trait::accept_event<caf::cow_string>;
sys.spawn([events](caf::event_based_actor* self) {
// For each buffer pair, we create a new flow ...
// Each WS connection has a pull/push buffer pair.
self->make_observable()
.from_resource(events) //
.for_each([self, f = pick_random{}](const event_t& ev) mutable {
// ... that pushes one random quote to the client.
auto [pull, push, path] = ev.data();
auto quotes = quotes_by_path(path);
auto quote = quotes.empty()
? "Try /epictetus, /seneca or /plato."
: f(quotes);
self->make_observable().just(ws::frame{quote}).subscribe(push);
.for_each([self](const event_t& ev) mutable {
// Forward the quotes to the client.
auto [pull, push, name] = ev.data();
auto quotes = quotes_by_name(name);
assert(!quotes.empty()); // Checked in on_request.
self->make_observable()
.from_container(quotes)
.map([](std::string_view quote) {
return ws::frame{quote};
})
.subscribe(push);
// We ignore whatever the client may send to us.
pull.observe_on(self).subscribe(std::ignore);
});
});
});
}))
// Run with the configured routes.
.start();
// Report any error to the user.
if (!server) {
std::cerr << "*** unable to run at port " << port << ": "
......
......@@ -18,6 +18,12 @@
namespace caf {
/// Tag type for selecting case-insensitive algorithms.
struct ignore_case_t {};
/// Tag for selecting case-insensitive algorithms.
constexpr ignore_case_t ignore_case = ignore_case_t{};
// provide boost::split compatible interface
constexpr std::string_view is_any_of(std::string_view arg) noexcept {
......
......@@ -65,14 +65,11 @@ std::string_view trim(std::string_view str) {
}
bool icase_equal(std::string_view x, std::string_view y) {
if (x.size() != y.size()) {
return false;
} else {
for (size_t index = 0; index < x.size(); ++index)
if (tolower(x[index]) != tolower(y[index]))
return false;
return true;
}
auto cmp = [](const char lhs, const char rhs) {
auto to_uchar = [](char c) { return static_cast<unsigned char>(c); };
return tolower(to_uchar(lhs)) == tolower(to_uchar(rhs));
};
return std::equal(x.begin(), x.end(), y.begin(), y.end(), cmp);
}
std::pair<std::string_view, std::string_view> split_by(std::string_view str,
......
......@@ -30,10 +30,6 @@ public:
/// Type for the producer adapter. We produce the input of the application.
using producer_type = async::producer_adapter<input_type>;
flow_bridge_base(async::execution_context_ptr loop) : loop_(std::move(loop)) {
// nop
}
virtual bool write(const output_type& item) = 0;
bool running() const noexcept {
......@@ -41,14 +37,14 @@ public:
}
/// Initializes consumer and producer of the bridge.
error init(async::consumer_resource<output_type> pull,
error init(net::multiplexer* mpx, async::consumer_resource<output_type> pull,
async::producer_resource<input_type> push) {
// Initialize our consumer.
auto do_wakeup = make_action([this] {
if (running())
prepare_send();
});
in_ = consumer_type::make(pull.try_open(), loop_, std::move(do_wakeup));
in_ = consumer_type::make(pull.try_open(), mpx, std::move(do_wakeup));
if (!in_) {
auto err = make_error(sec::runtime_error,
"flow bridge failed to open the input resource");
......@@ -62,7 +58,7 @@ public:
down_->shutdown();
}
});
out_ = producer_type::make(push.try_open(), loop_, std::move(do_resume),
out_ = producer_type::make(push.try_open(), mpx, std::move(do_resume),
std::move(do_cancel));
if (!out_) {
auto err = make_error(sec::runtime_error,
......@@ -139,9 +135,6 @@ protected:
/// Converts between raw bytes and native C++ objects.
Trait trait_;
/// Our event loop.
async::execution_context_ptr loop_;
/// Type-erased handle to the @ref socket_manager. This reference is important
/// to keep the bridge alive while the manager is not registered for writing
/// or reading.
......
......@@ -9,6 +9,7 @@
#include "caf/net/fwd.hpp"
#include "caf/net/http/method.hpp"
#include "caf/net/http/status.hpp"
#include "caf/string_algorithms.hpp"
#include "caf/uri.hpp"
#include <string_view>
......@@ -91,6 +92,25 @@ public:
return {};
}
/// Checks whether the field `key` exists and equals `val` when using
/// case-insensitive compare.
bool field_equals(ignore_case_t, std::string_view key,
std::string_view val) const noexcept {
if (auto i = fields_.find(key); i != fields_.end())
return icase_equal(val, i->second);
else
return false;
}
/// Checks whether the field `key` exists and equals `val` when using
/// case-insensitive compare.
bool field_equals(std::string_view key, std::string_view val) const noexcept {
if (auto i = fields_.find(key); i != fields_.end())
return val == i->second;
else
return false;
}
/// Returns the value of the field with the specified key as the requested
/// type T, or std::nullopt if the field is not found or cannot be converted.
template <class T>
......
......@@ -107,9 +107,10 @@ public:
/// asynchronously.
request to_request() &&;
private:
/// Returns a pointer to the HTTP layer.
lower_layer* down();
private:
const request_header* hdr_;
const_byte_span body_;
http::router* router_;
......
......@@ -72,6 +72,15 @@ bool match_path(std::string_view lhs, std::string_view rhs, F&& predicate) {
}
return tail2.empty();
}
template <class, class = void>
struct http_route_has_init : std::false_type {};
template <class T>
struct http_route_has_init<T, std::void_t<decltype(std::declval<T>().init())>>
: std::true_type {};
template <class T>
constexpr bool http_route_has_init_v = http_route_has_init<T>::value;
/// Base type for HTTP routes that parse one or more arguments from the requests
/// and then forward them to a user-provided function object.
......@@ -145,6 +154,12 @@ public:
// nop
}
void init() override {
if constexpr (detail::http_route_has_init_v<F>) {
f_.init();
}
}
private:
void do_apply(net::http::responder& res, Ts&&... args) override {
f_(res, std::move(args)...);
......@@ -184,6 +199,12 @@ public:
// nop
}
void init() override {
if constexpr (detail::http_route_has_init<F>::value) {
f_.init();
}
}
private:
void do_apply(net::http::responder& res) override {
f_(res);
......@@ -211,7 +232,7 @@ private:
F f_;
};
/// Default policy class for
/// Creates a route from a function object.
template <class F, class... Args>
net::http::route_ptr
make_http_route_impl(std::string& path, std::optional<net::http::method> method,
......
......@@ -170,6 +170,8 @@ public:
[[nodiscard]] expected<disposable> start(OnStart on_start) {
using consumer_resource = async::consumer_resource<request>;
static_assert(std::is_invocable_v<OnStart, consumer_resource>);
for (auto& ptr : super::config().routes)
ptr->init();
auto& cfg = super::config();
return cfg.visit([this, &cfg, &on_start](auto& data) {
return this->do_start(cfg, data, on_start)
......
......@@ -34,15 +34,13 @@ public:
// We produce the input type of the application.
using push_t = async::producer_resource<typename Trait::input_type>;
lp_client_flow_bridge(async::execution_context_ptr loop, pull_t pull,
push_t push)
: super(std::move(loop)), pull_(std::move(pull)), push_(std::move(push)) {
lp_client_flow_bridge(pull_t pull, push_t push)
: pull_(std::move(pull)), push_(std::move(push)) {
// nop
}
static std::unique_ptr<lp_client_flow_bridge> make(net::multiplexer* mpx,
pull_t pull, push_t push) {
return std::make_unique<lp_client_flow_bridge>(mpx, std::move(pull),
static std::unique_ptr<lp_client_flow_bridge> make(pull_t pull, push_t push) {
return std::make_unique<lp_client_flow_bridge>(std::move(pull),
std::move(push));
}
......@@ -54,7 +52,7 @@ public:
error start(net::lp::lower_layer* down_ptr) override {
super::down_ = down_ptr;
return super::init(std::move(pull_), std::move(push_));
return super::init(&down_ptr->mpx(), std::move(pull_), std::move(push_));
}
private:
......@@ -101,8 +99,7 @@ private:
using transport_t = typename Conn::transport_type;
auto [s2a_pull, s2a_push] = async::make_spsc_buffer_resource<input_t>();
auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>();
auto bridge = bridge_t::make(cfg.mpx, std::move(a2s_pull),
std::move(s2a_push));
auto bridge = bridge_t::make(std::move(a2s_pull), std::move(s2a_push));
auto bridge_ptr = bridge.get();
auto impl = framing::make(std::move(bridge));
auto fd = conn.fd();
......
......@@ -46,14 +46,13 @@ public:
// one thread running in the multiplexer (which makes this safe).
using shared_producer_type = std::shared_ptr<producer_type>;
lp_server_flow_bridge(async::execution_context_ptr loop,
shared_producer_type producer)
: super(std::move(loop)), producer_(std::move(producer)) {
lp_server_flow_bridge(shared_producer_type producer)
: producer_(std::move(producer)) {
// nop
}
static auto make(net::multiplexer* mpx, shared_producer_type producer) {
return std::make_unique<lp_server_flow_bridge>(mpx, std::move(producer));
static auto make(shared_producer_type producer) {
return std::make_unique<lp_server_flow_bridge>(std::move(producer));
}
error start(net::lp::lower_layer* down_ptr) override {
......@@ -66,7 +65,8 @@ public:
return make_error(sec::runtime_error,
"Length-prefixed connection dropped: client canceled");
}
return super::init(std::move(lp_pull), std::move(lp_push));
return super::init(&down_ptr->mpx(), std::move(lp_pull),
std::move(lp_push));
}
private:
......@@ -95,7 +95,7 @@ public:
net::socket_manager_ptr make(net::multiplexer* mpx,
connection_handle conn) override {
using bridge_t = lp_server_flow_bridge<Trait>;
auto bridge = bridge_t::make(mpx, producer_);
auto bridge = bridge_t::make(producer_);
auto bridge_ptr = bridge.get();
auto impl = net::lp::framing::make(std::move(bridge));
auto fd = conn.fd();
......
......@@ -33,7 +33,7 @@ public:
using handshake_ptr = std::unique_ptr<handshake>;
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer::client>;
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>;
// -- constructors, destructors, and assignment operators --------------------
......
......@@ -26,9 +26,9 @@ namespace caf::detail {
/// Specializes the WebSocket flow bridge for the server side.
template <class Trait, class... Ts>
class ws_client_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::upper_layer::client> {
: public ws_flow_bridge<Trait, net::web_socket::upper_layer> {
public:
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer::client>;
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer>;
// We consume the output type of the application.
using pull_t = async::consumer_resource<typename Trait::input_type>;
......@@ -36,21 +36,19 @@ public:
// We produce the input type of the application.
using push_t = async::producer_resource<typename Trait::output_type>;
ws_client_flow_bridge(async::execution_context_ptr loop, pull_t pull,
push_t push)
: super(std::move(loop)), pull_(std::move(pull)), push_(std::move(push)) {
ws_client_flow_bridge(pull_t pull, push_t push)
: pull_(std::move(pull)), push_(std::move(push)) {
// nop
}
static std::unique_ptr<ws_client_flow_bridge> make(net::multiplexer* mpx,
pull_t pull, push_t push) {
return std::make_unique<ws_client_flow_bridge>(mpx, std::move(pull),
static std::unique_ptr<ws_client_flow_bridge> make(pull_t pull, push_t push) {
return std::make_unique<ws_client_flow_bridge>(std::move(pull),
std::move(push));
}
error start(net::web_socket::lower_layer* down_ptr) override {
super::down_ = down_ptr;
return super::init(std::move(pull_), std::move(push_));
return super::init(&down_ptr->mpx(), std::move(pull_), std::move(push_));
}
private:
......@@ -117,8 +115,7 @@ private:
using bridge_t = detail::ws_client_flow_bridge<Trait>;
auto [s2a_pull, s2a_push] = async::make_spsc_buffer_resource<input_t>();
auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>();
auto bridge = bridge_t::make(cfg.mpx, std::move(a2s_pull),
std::move(s2a_push));
auto bridge = bridge_t::make(std::move(a2s_pull), std::move(s2a_push));
auto bridge_ptr = bridge.get();
auto impl = client::make(std::move(cfg.hs), std::move(bridge));
auto fd = conn.fd();
......
......@@ -35,10 +35,6 @@ public:
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>;
using server_ptr = std::unique_ptr<web_socket::upper_layer::server>;
using client_ptr = std::unique_ptr<web_socket::upper_layer::client>;
// -- constants --------------------------------------------------------------
/// Restricts the size of received frames (including header).
......@@ -49,13 +45,18 @@ public:
// -- constructors, destructors, and assignment operators --------------------
static std::unique_ptr<framing> make(client_ptr up) {
/// Creates a new framing protocol for client mode.
static std::unique_ptr<framing> make_client(upper_layer_ptr up) {
return std::unique_ptr<framing>{new framing(std::move(up))};
}
static std::unique_ptr<framing> make(server_ptr up,
http::request_header hdr) {
return std::unique_ptr<framing>{new framing(std::move(up), std::move(hdr))};
/// Creates a new framing protocol for server mode.
static std::unique_ptr<framing> make_server(upper_layer_ptr up) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
auto res = std::unique_ptr<framing>{new framing(std::move(up))};
res->mask_outgoing_frames = false;
return res;
}
// -- properties -------------------------------------------------------------
......@@ -127,17 +128,10 @@ public:
private:
// -- implementation details -------------------------------------------------
explicit framing(client_ptr up) : up_(std::move(up)) {
explicit framing(upper_layer_ptr up) : up_(std::move(up)) {
// nop
}
explicit framing(server_ptr up, http::request_header&& hdr)
: up_(std::move(up)), hdr_(std::move(hdr)) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
mask_outgoing_frames = false;
}
bool handle(uint8_t opcode, byte_span payload);
void ship_pong(byte_span payload);
......@@ -167,9 +161,6 @@ private:
/// Next layer in the processing chain.
upper_layer_ptr up_;
/// Stored when running as a server and passed to `up_` in start.
std::optional<http::request_header> hdr_;
};
} // namespace caf::net::web_socket
......@@ -7,6 +7,7 @@
#include "caf/byte_buffer.hpp"
#include "caf/detail/net_export.hpp"
#include "caf/dictionary.hpp"
#include "caf/net/fwd.hpp"
#include <cstddef>
#include <string>
......@@ -135,6 +136,10 @@ public:
/// @pre `has_valid_key()`
void write_http_1_response(byte_buffer& buf) const;
/// Writes the HTTP response message to `down`.
/// @pre `has_valid_key()`
void write_response(http::lower_layer* down) const;
/// Checks whether the `http_response` contains a HTTP 1.1 response to the
/// generated HTTP GET request. A valid response contains:
/// - HTTP status code 101 (Switching Protocols).
......
......@@ -40,49 +40,55 @@ public:
using producer_type = async::blocking_producer<accept_event>;
using acceptor_impl_t = net::web_socket::acceptor_impl<Trait, Ts...>;
using ws_res_type = typename acceptor_impl_t::ws_res_type;
// Note: this is shared with the connection factory. In general, this is
// *unsafe*. However, we exploit the fact that there is currently only
// one thread running in the multiplexer (which makes this safe).
using shared_producer_type = std::shared_ptr<producer_type>;
ws_server_flow_bridge(async::execution_context_ptr loop,
on_request_cb_type on_request,
ws_server_flow_bridge(on_request_cb_type on_request,
shared_producer_type producer)
: super(std::move(loop)),
on_request_(std::move(on_request)),
producer_(std::move(producer)) {
: on_request_(std::move(on_request)), producer_(std::move(producer)) {
// nop
}
static auto make(net::multiplexer* mpx, on_request_cb_type on_request,
static auto make(on_request_cb_type on_request,
shared_producer_type producer) {
return std::make_unique<ws_server_flow_bridge>(mpx, std::move(on_request),
return std::make_unique<ws_server_flow_bridge>(std::move(on_request),
std::move(producer));
}
error start(net::web_socket::lower_layer* down_ptr,
const net::http::request_header& hdr) override {
error start(net::web_socket::lower_layer* down_ptr) override {
CAF_ASSERT(down_ptr != nullptr);
super::down_ = down_ptr;
if (!producer_->push(app_event)) {
return make_error(sec::runtime_error,
"WebSocket connection dropped: client canceled");
}
auto& [pull, push] = ws_resources;
return super::init(&down_ptr->mpx(), std::move(pull), std::move(push));
}
error accept(const net::http::request_header& hdr) override {
net::web_socket::acceptor_impl<Trait, Ts...> acc{hdr};
(*on_request_)(acc);
if (!acc.accepted()) {
if (acc.accepted()) {
app_event = std::move(acc.app_event);
return {};
}
return std::move(acc) //
.reject_reason()
.or_else(sec::runtime_error,
"WebSocket request rejected without reason");
}
if (!producer_->push(acc.app_event)) {
return make_error(sec::runtime_error,
"WebSocket connection dropped: client canceled");
}
auto& [pull, push] = acc.ws_resources;
return super::init(std::move(pull), std::move(push));
.or_else(sec::runtime_error, "WebSocket request rejected without reason");
}
private:
on_request_cb_type on_request_;
shared_producer_type producer_;
accept_event app_event;
ws_res_type ws_resources;
};
/// Specializes @ref connection_factory for the WebSocket protocol.
......@@ -118,7 +124,7 @@ public:
return nullptr;
}
using bridge_t = ws_server_flow_bridge<Trait, Ts...>;
auto app = bridge_t::make(mpx, on_request_, producer_);
auto app = bridge_t::make(on_request_, producer_);
auto app_ptr = app.get();
auto ws = net::web_socket::server::make(std::move(app));
auto fd = conn.fd();
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/detail/type_list.hpp"
#include "caf/net/web_socket/acceptor.hpp"
#include "caf/net/web_socket/default_trait.hpp"
#include "caf/net/web_socket/server_factory.hpp"
#include <memory>
namespace caf::detail {
/// Specializes the WebSocket flow bridge for the switch-protocol use case.
template <class Trait, class... Ts>
class ws_switch_protocol_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::upper_layer> {
public:
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer>;
using accept_event = typename Trait::template accept_event<Ts...>;
using producer_type = async::blocking_producer<accept_event>;
using pull_type = async::consumer_resource<typename Trait::output_type>;
using push_type = async::producer_resource<typename Trait::input_type>;
// Note: in general, this is *unsafe*. However, we exploit the fact that there
// is currently only one thread running in the multiplexer (which makes
// this safe).
using shared_producer_type = std::shared_ptr<producer_type>;
ws_switch_protocol_flow_bridge(shared_producer_type producer, pull_type pull,
push_type push)
: producer_(std::move(producer)),
pull_(std::move(pull)),
push_(std::move(push)) {
// nop
}
static auto make(shared_producer_type producer, pull_type pull,
push_type push) {
using impl_t = ws_switch_protocol_flow_bridge;
return std::make_unique<impl_t>(std::move(producer), std::move(pull),
std::move(push));
}
error start(net::web_socket::lower_layer* down_ptr) override {
CAF_ASSERT(down_ptr != nullptr);
super::down_ = down_ptr;
return super::init(&down_ptr->mpx(), std::move(pull_), std::move(push_));
}
private:
shared_producer_type producer_;
pull_type pull_;
push_type push_;
};
template <class OnRequest, class OnStart>
struct ws_switch_protocol_state {
ws_switch_protocol_state(OnRequest on_request_fn, OnStart on_start_fn)
: on_request(std::move(on_request_fn)), on_start(std::move(on_start_fn)) {
// nop
}
OnRequest on_request;
std::optional<OnStart> on_start;
};
template <class Trait, class State, class Out, class... Ts>
class ws_switch_protocol;
template <class Trait, class State, class... Out, class... Ts>
class ws_switch_protocol<Trait, State, type_list<Out...>, Ts...> {
public:
using accept_event = typename Trait::template accept_event<Out...>;
using producer_type = async::blocking_producer<accept_event>;
using shared_producer_type = std::shared_ptr<producer_type>;
explicit ws_switch_protocol(std::shared_ptr<State> state)
: state_(std::move(state)) {
// nop
}
ws_switch_protocol(ws_switch_protocol&&) = default;
ws_switch_protocol(const ws_switch_protocol&) = default;
ws_switch_protocol& operator=(ws_switch_protocol&&) = default;
ws_switch_protocol& operator=(const ws_switch_protocol&) = default;
void operator()(net::http::responder& res, Ts... args) {
namespace http = net::http;
auto& hdr = res.header();
// Sanity checking.
if (!hdr.field_equals(ignore_case, "Connection", "upgrade")
|| !hdr.field_equals(ignore_case, "Upgrade", "websocket")) {
res.respond(net::http::status::bad_request, "text/plain",
"Expected a WebSocket client handshake request.");
return;
}
auto sec_key = hdr.field("Sec-WebSocket-Key");
if (sec_key.empty()) {
res.respond(net::http::status::bad_request, "text/plain",
"Mandatory field Sec-WebSocket-Key missing or invalid.");
return;
}
// Call user-defined on_request callback.
net::web_socket::acceptor_impl<Trait, Out...> acc{hdr};
(state_->on_request)(acc, args...);
if (!acc.accepted()) {
if (auto& err = acc.reject_reason()) {
auto descr = to_string(err);
res.respond(http::status::bad_request, "text/plain", descr);
} else {
res.respond(http::status::bad_request, "text/plain", "Bad request.");
}
return;
}
if (!producer_->push(acc.app_event)) {
res.respond(http::status::internal_server_error, "text/plain",
"Upstream channel closed.");
return;
}
// Finalize the WebSocket handshake.
net::web_socket::handshake hs;
hs.assign_key(sec_key);
hs.write_response(res.down());
// Switch to the WebSocket framing protocol.
auto& [pull, push] = acc.ws_resources;
using net::web_socket::framing;
using bridge_t = ws_switch_protocol_flow_bridge<Trait, Out...>;
auto bridge = bridge_t::make(producer_, std::move(pull), std::move(push));
res.down()->switch_protocol(framing::make_server(std::move(bridge)));
}
void init() {
if (auto& on_start = state_->on_start; on_start) {
auto [pull, push] = async::make_spsc_buffer_resource<accept_event>();
using producer_t = async::blocking_producer<accept_event>;
producer_ = std::make_shared<producer_t>(producer_t{push.try_open()});
(*on_start)(std::move(pull));
on_start = std::nullopt;
}
}
private:
std::shared_ptr<State> state_;
shared_producer_type producer_;
};
} // namespace caf::detail
namespace caf::net::web_socket {
/// Binds a `switch_protocol` invocation to a trait class and a function object
/// for on_request.
template <class Trait, class OnRequest>
struct switch_protocol_bind_2 {
public:
switch_protocol_bind_2(OnRequest on_request)
: on_request_(std::move(on_request)) {
// nop
}
template <class OnStart>
auto on_start(OnStart on_start) && {
using on_request_trait = detail::get_callable_trait_t<OnRequest>;
using on_request_args = typename on_request_trait::arg_types;
return make(on_start, on_request_args{});
}
private:
template <class OnStart, class... Out, class... Ts>
auto make(OnStart& on_start,
detail::type_list<net::web_socket::acceptor<Out...>&, Ts...>) {
using namespace detail;
using state_t = ws_switch_protocol_state<OnRequest, OnStart>;
using impl_t = ws_switch_protocol<Trait, state_t, type_list<Out...>, Ts...>;
auto state = std::make_shared<state_t>(std::move(on_request_),
std::move(on_start));
static_assert(http_route_has_init_v<impl_t>);
return impl_t{std::move(state)};
}
OnRequest on_request_;
};
/// Binds a `switch_protocol` invocation to a trait class.
template <class Trait>
struct switch_protocol_bind_1 {
template <class OnRequest>
auto on_request(OnRequest on_request) {
return switch_protocol_bind_2<Trait, OnRequest>(std::move(on_request));
}
};
template <class Trait = default_trait>
auto switch_protocol() {
return switch_protocol_bind_1<Trait>{};
}
} // namespace caf::net::web_socket
......@@ -25,18 +25,17 @@ public:
virtual ptrdiff_t consume_binary(byte_span buf) = 0;
virtual ptrdiff_t consume_text(std::string_view buf) = 0;
virtual error start(lower_layer* down) = 0;
};
class upper_layer::server : public upper_layer {
public:
virtual ~server();
virtual error start(lower_layer* down, const http::request_header& hdr) = 0;
};
class upper_layer::client : public upper_layer {
public:
virtual ~client();
virtual error start(lower_layer* down) = 0;
/// Asks the layer to accept a new client.
/// @warning the server calls this function *before* calling `start`.
virtual error accept(const http::request_header& hdr) = 0;
};
} // namespace caf::net::web_socket
......@@ -50,7 +50,7 @@ void client::abort(const error& reason) {
up_->abort(reason);
}
ptrdiff_t client::consume(byte_span buffer, byte_span delta) {
ptrdiff_t client::consume(byte_span buffer, byte_span) {
CAF_LOG_TRACE(CAF_ARG2("buffer", buffer.size()));
// Check whether we have received the HTTP header or else wait for more
// data. Abort when exceeding the maximum size.
......@@ -89,7 +89,7 @@ bool client::handle_header(std::string_view http) {
auto http_ok = hs_->is_valid_http_1_response(http);
hs_.reset();
if (http_ok) {
down_->switch_protocol(framing::make(std::move(up_)));
down_->switch_protocol(framing::make_client(std::move(up_)));
return true;
}
CAF_LOG_DEBUG("received an invalid WebSocket handshake");
......
......@@ -15,22 +15,7 @@ error framing::start(octet_stream::lower_layer* down) {
std::random_device rd;
rng_.seed(rd());
down_ = down;
if (!hdr_) {
using dptr_t = web_socket::upper_layer::client;
return static_cast<dptr_t*>(up_.get())->start(this);
}
using dptr_t = web_socket::upper_layer::server;
auto err = static_cast<dptr_t*>(up_.get())->start(this, *hdr_);
hdr_ = std::nullopt;
if (err) {
auto descr = to_string(err);
CAF_LOG_DEBUG("upper layer rejected a WebSocket connection:" << descr);
down_->begin_output();
http::v1::write_response(http::status::bad_request, "text/plain", descr,
down_->output_buffer());
down_->end_output();
}
return err;
return up_->start(this);
}
void framing::abort(const error& reason) {
......
......@@ -7,6 +7,8 @@
#include "caf/config.hpp"
#include "caf/detail/base64.hpp"
#include "caf/hash/sha1.hpp"
#include "caf/net/http/lower_layer.hpp"
#include "caf/net/http/status.hpp"
#include "caf/string_algorithms.hpp"
#include <algorithm>
......@@ -118,6 +120,15 @@ void handshake::write_http_1_response(byte_buffer& buf) const {
<< response_key() << "\r\n\r\n";
}
void handshake::write_response(http::lower_layer* down) const {
down->begin_header(http::status::switching_protocols);
down->add_header_field("Upgrade", "websocket");
down->add_header_field("Connection", "Upgrade");
down->add_header_field("Sec-WebSocket-Accept", response_key());
down->end_header();
down->send_payload({});
}
namespace {
template <class F>
......
......@@ -80,14 +80,20 @@ bool server::handle_header(std::string_view http) {
CAF_LOG_DEBUG("received invalid WebSocket handshake");
return false;
}
// Kindly ask the upper layer to accept a new WebSocket connection.
if (auto err = up_->accept(hdr)) {
write_response(http::status::bad_request, to_string(err));
return false;
}
// Finalize the WebSocket handshake.
handshake hs;
hs.assign_key(sec_key);
down_->begin_output();
hs.write_http_1_response(down_->output_buffer());
down_->end_output();
// All done. Switch to the framing protocol.
CAF_LOG_DEBUG("completed WebSocket handshake");
down_->switch_protocol(framing::make(std::move(up_), std::move(hdr)));
down_->switch_protocol(framing::make_server(std::move(up_)));
return true;
}
......
......@@ -14,8 +14,4 @@ upper_layer::server::~server() {
// nop
}
upper_layer::client::~client() {
// nop
}
} // namespace caf::net::web_socket
......@@ -14,7 +14,7 @@ namespace {
using svec = std::vector<std::string>;
class app_t : public net::web_socket::upper_layer::client {
class app_t : public net::web_socket::upper_layer {
public:
static auto make() {
return std::make_unique<app_t>();
......
......@@ -32,9 +32,12 @@ public:
return std::make_unique<app_t>();
}
error start(net::web_socket::lower_layer* down,
const net::http::request_header& hdr) override {
error start(net::web_socket::lower_layer* down) override {
down->request_messages();
return none;
}
error accept(const net::http::request_header& hdr) override {
// Store the request information in cfg to evaluate them later.
auto& ws = cfg["web-socket"].as_dictionary();
put(ws, "method", to_rfc_string(hdr.method()));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment