Commit db0db507 authored by Dominik Charousset's avatar Dominik Charousset

Enbale HTTP routes to switch to WebSocket

parent 215e4ed3
// Simple WebSocket server that sends local files from a working directory to // Simple HTTP/WebSocket server that sends predefined text snippets
// the clients. // (philosophers quotes) to the client. Clients may either ask for a single
// quote via HTTP GET request or for all quotes of a selected philosopher by
// connecting via WebSocket.
#include "caf/actor_system.hpp" #include "caf/actor_system.hpp"
#include "caf/actor_system_config.hpp" #include "caf/actor_system_config.hpp"
...@@ -7,9 +9,10 @@ ...@@ -7,9 +9,10 @@
#include "caf/cow_string.hpp" #include "caf/cow_string.hpp"
#include "caf/cow_tuple.hpp" #include "caf/cow_tuple.hpp"
#include "caf/event_based_actor.hpp" #include "caf/event_based_actor.hpp"
#include "caf/net/http/with.hpp"
#include "caf/net/middleman.hpp" #include "caf/net/middleman.hpp"
#include "caf/net/ssl/context.hpp" #include "caf/net/ssl/context.hpp"
#include "caf/net/web_socket/with.hpp" #include "caf/net/web_socket/switch_protocol.hpp"
#include "caf/scheduled_actor/flow.hpp" #include "caf/scheduled_actor/flow.hpp"
#include "caf/span.hpp" #include "caf/span.hpp"
...@@ -81,12 +84,12 @@ struct config : caf::actor_system_config { ...@@ -81,12 +84,12 @@ struct config : caf::actor_system_config {
// -- helper functions --------------------------------------------------------- // -- helper functions ---------------------------------------------------------
// Returns a list of philosopher quotes by path. // Returns a list of philosopher quotes by path.
caf::span<const std::string_view> quotes_by_path(std::string_view path) { caf::span<const std::string_view> quotes_by_name(std::string_view path) {
if (path == "/epictetus") if (path == "epictetus")
return caf::make_span(epictetus); return caf::make_span(epictetus);
else if (path == "/seneca") else if (path == "seneca")
return caf::make_span(seneca); return caf::make_span(seneca);
else if (path == "/plato") else if (path == "plato")
return caf::make_span(plato); return caf::make_span(plato);
else else
return {}; return {};
...@@ -109,6 +112,13 @@ private: ...@@ -109,6 +112,13 @@ private:
std::minstd_rand engine_; std::minstd_rand engine_;
}; };
std::string not_found_str(std::string_view name) {
auto result = "Name '"s;
result += name;
result += "' not found. Try 'epictetus', 'seneca' or 'plato'.";
return result;
}
// -- main --------------------------------------------------------------------- // -- main ---------------------------------------------------------------------
int caf_main(caf::actor_system& sys, const config& cfg) { int caf_main(caf::actor_system& sys, const config& cfg) {
...@@ -129,7 +139,7 @@ int caf_main(caf::actor_system& sys, const config& cfg) { ...@@ -129,7 +139,7 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
// Open up a TCP port for incoming connections and start the server. // Open up a TCP port for incoming connections and start the server.
using trait = ws::default_trait; using trait = ws::default_trait;
auto server auto server
= ws::with(sys) = http::with(sys)
// Optionally enable TLS. // Optionally enable TLS.
.context(ssl::context::enable(key_file && cert_file) .context(ssl::context::enable(key_file && cert_file)
.and_then(ssl::emplace_server(ssl::tls::v1_2)) .and_then(ssl::emplace_server(ssl::tls::v1_2))
...@@ -139,34 +149,61 @@ int caf_main(caf::actor_system& sys, const config& cfg) { ...@@ -139,34 +149,61 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
.accept(port) .accept(port)
// Limit how many clients may be connected at any given time. // Limit how many clients may be connected at any given time.
.max_connections(max_connections) .max_connections(max_connections)
// Forward the path from the WebSocket request to the worker. // On "/quote/<arg>", we pick one random quote for the client.
.on_request([](ws::acceptor<caf::cow_string>& acc) { .route("/quote/<arg>", http::method::get,
// The hdr parameter is a dictionary with fields from the WebSocket [](http::responder& res, std::string name) {
// handshake such as the path. This is only field we care about auto quotes = quotes_by_name(name);
// here. By passing the (copy-on-write) string to accept() here, we if (quotes.empty()) {
// make it available to the worker through the acceptor_resource. res.respond(http::status::not_found, "text/plain",
acc.accept(caf::cow_string{acc.header().path()}); not_found_str(name));
} else {
pick_random f;
res.respond(http::status::ok, "text/plain", f(quotes));
}
})
// On "/ws/quotes/<arg>", we switch the protocol to WebSocket.
.route("/ws/quotes/<arg>", http::method::get,
ws::switch_protocol()
// Check that the client asks for a known philosopher.
.on_request(
[](ws::acceptor<caf::cow_string>& acc, std::string name) {
auto quotes = quotes_by_name(name);
if (quotes.empty()) {
auto err = make_error(caf::sec::invalid_argument,
not_found_str(name));
acc.reject(std::move(err));
} else {
// Forward the name to the WebSocket worker.
acc.accept(caf::cow_string{std::move(name)});
}
}) })
// When started, run our worker actor to handle incoming connections. // Spawn a worker for the WebSocket clients.
.start([&sys](trait::acceptor_resource<caf::cow_string> events) { .on_start(
[&sys](trait::acceptor_resource<caf::cow_string> events) {
// Spawn a worker that reads from `events`.
using event_t = trait::accept_event<caf::cow_string>; using event_t = trait::accept_event<caf::cow_string>;
sys.spawn([events](caf::event_based_actor* self) { sys.spawn([events](caf::event_based_actor* self) {
// For each buffer pair, we create a new flow ... // Each WS connection has a pull/push buffer pair.
self->make_observable() self->make_observable()
.from_resource(events) // .from_resource(events) //
.for_each([self, f = pick_random{}](const event_t& ev) mutable { .for_each([self](const event_t& ev) mutable {
// ... that pushes one random quote to the client. // Forward the quotes to the client.
auto [pull, push, path] = ev.data(); auto [pull, push, name] = ev.data();
auto quotes = quotes_by_path(path); auto quotes = quotes_by_name(name);
auto quote = quotes.empty() assert(!quotes.empty()); // Checked in on_request.
? "Try /epictetus, /seneca or /plato." self->make_observable()
: f(quotes); .from_container(quotes)
self->make_observable().just(ws::frame{quote}).subscribe(push); .map([](std::string_view quote) {
return ws::frame{quote};
})
.subscribe(push);
// We ignore whatever the client may send to us. // We ignore whatever the client may send to us.
pull.observe_on(self).subscribe(std::ignore); pull.observe_on(self).subscribe(std::ignore);
}); });
}); });
}); }))
// Run with the configured routes.
.start();
// Report any error to the user. // Report any error to the user.
if (!server) { if (!server) {
std::cerr << "*** unable to run at port " << port << ": " std::cerr << "*** unable to run at port " << port << ": "
......
...@@ -18,6 +18,12 @@ ...@@ -18,6 +18,12 @@
namespace caf { namespace caf {
/// Tag type for selecting case-insensitive algorithms.
struct ignore_case_t {};
/// Tag for selecting case-insensitive algorithms.
constexpr ignore_case_t ignore_case = ignore_case_t{};
// provide boost::split compatible interface // provide boost::split compatible interface
constexpr std::string_view is_any_of(std::string_view arg) noexcept { constexpr std::string_view is_any_of(std::string_view arg) noexcept {
......
...@@ -65,14 +65,11 @@ std::string_view trim(std::string_view str) { ...@@ -65,14 +65,11 @@ std::string_view trim(std::string_view str) {
} }
bool icase_equal(std::string_view x, std::string_view y) { bool icase_equal(std::string_view x, std::string_view y) {
if (x.size() != y.size()) { auto cmp = [](const char lhs, const char rhs) {
return false; auto to_uchar = [](char c) { return static_cast<unsigned char>(c); };
} else { return tolower(to_uchar(lhs)) == tolower(to_uchar(rhs));
for (size_t index = 0; index < x.size(); ++index) };
if (tolower(x[index]) != tolower(y[index])) return std::equal(x.begin(), x.end(), y.begin(), y.end(), cmp);
return false;
return true;
}
} }
std::pair<std::string_view, std::string_view> split_by(std::string_view str, std::pair<std::string_view, std::string_view> split_by(std::string_view str,
......
...@@ -30,10 +30,6 @@ public: ...@@ -30,10 +30,6 @@ public:
/// Type for the producer adapter. We produce the input of the application. /// Type for the producer adapter. We produce the input of the application.
using producer_type = async::producer_adapter<input_type>; using producer_type = async::producer_adapter<input_type>;
flow_bridge_base(async::execution_context_ptr loop) : loop_(std::move(loop)) {
// nop
}
virtual bool write(const output_type& item) = 0; virtual bool write(const output_type& item) = 0;
bool running() const noexcept { bool running() const noexcept {
...@@ -41,14 +37,14 @@ public: ...@@ -41,14 +37,14 @@ public:
} }
/// Initializes consumer and producer of the bridge. /// Initializes consumer and producer of the bridge.
error init(async::consumer_resource<output_type> pull, error init(net::multiplexer* mpx, async::consumer_resource<output_type> pull,
async::producer_resource<input_type> push) { async::producer_resource<input_type> push) {
// Initialize our consumer. // Initialize our consumer.
auto do_wakeup = make_action([this] { auto do_wakeup = make_action([this] {
if (running()) if (running())
prepare_send(); prepare_send();
}); });
in_ = consumer_type::make(pull.try_open(), loop_, std::move(do_wakeup)); in_ = consumer_type::make(pull.try_open(), mpx, std::move(do_wakeup));
if (!in_) { if (!in_) {
auto err = make_error(sec::runtime_error, auto err = make_error(sec::runtime_error,
"flow bridge failed to open the input resource"); "flow bridge failed to open the input resource");
...@@ -62,7 +58,7 @@ public: ...@@ -62,7 +58,7 @@ public:
down_->shutdown(); down_->shutdown();
} }
}); });
out_ = producer_type::make(push.try_open(), loop_, std::move(do_resume), out_ = producer_type::make(push.try_open(), mpx, std::move(do_resume),
std::move(do_cancel)); std::move(do_cancel));
if (!out_) { if (!out_) {
auto err = make_error(sec::runtime_error, auto err = make_error(sec::runtime_error,
...@@ -139,9 +135,6 @@ protected: ...@@ -139,9 +135,6 @@ protected:
/// Converts between raw bytes and native C++ objects. /// Converts between raw bytes and native C++ objects.
Trait trait_; Trait trait_;
/// Our event loop.
async::execution_context_ptr loop_;
/// Type-erased handle to the @ref socket_manager. This reference is important /// Type-erased handle to the @ref socket_manager. This reference is important
/// to keep the bridge alive while the manager is not registered for writing /// to keep the bridge alive while the manager is not registered for writing
/// or reading. /// or reading.
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "caf/net/fwd.hpp" #include "caf/net/fwd.hpp"
#include "caf/net/http/method.hpp" #include "caf/net/http/method.hpp"
#include "caf/net/http/status.hpp" #include "caf/net/http/status.hpp"
#include "caf/string_algorithms.hpp"
#include "caf/uri.hpp" #include "caf/uri.hpp"
#include <string_view> #include <string_view>
...@@ -91,6 +92,25 @@ public: ...@@ -91,6 +92,25 @@ public:
return {}; return {};
} }
/// Checks whether the field `key` exists and equals `val` when using
/// case-insensitive compare.
bool field_equals(ignore_case_t, std::string_view key,
std::string_view val) const noexcept {
if (auto i = fields_.find(key); i != fields_.end())
return icase_equal(val, i->second);
else
return false;
}
/// Checks whether the field `key` exists and equals `val` when using
/// case-insensitive compare.
bool field_equals(std::string_view key, std::string_view val) const noexcept {
if (auto i = fields_.find(key); i != fields_.end())
return val == i->second;
else
return false;
}
/// Returns the value of the field with the specified key as the requested /// Returns the value of the field with the specified key as the requested
/// type T, or std::nullopt if the field is not found or cannot be converted. /// type T, or std::nullopt if the field is not found or cannot be converted.
template <class T> template <class T>
......
...@@ -107,9 +107,10 @@ public: ...@@ -107,9 +107,10 @@ public:
/// asynchronously. /// asynchronously.
request to_request() &&; request to_request() &&;
private: /// Returns a pointer to the HTTP layer.
lower_layer* down(); lower_layer* down();
private:
const request_header* hdr_; const request_header* hdr_;
const_byte_span body_; const_byte_span body_;
http::router* router_; http::router* router_;
......
...@@ -72,6 +72,15 @@ bool match_path(std::string_view lhs, std::string_view rhs, F&& predicate) { ...@@ -72,6 +72,15 @@ bool match_path(std::string_view lhs, std::string_view rhs, F&& predicate) {
} }
return tail2.empty(); return tail2.empty();
} }
template <class, class = void>
struct http_route_has_init : std::false_type {};
template <class T>
struct http_route_has_init<T, std::void_t<decltype(std::declval<T>().init())>>
: std::true_type {};
template <class T>
constexpr bool http_route_has_init_v = http_route_has_init<T>::value;
/// Base type for HTTP routes that parse one or more arguments from the requests /// Base type for HTTP routes that parse one or more arguments from the requests
/// and then forward them to a user-provided function object. /// and then forward them to a user-provided function object.
...@@ -145,6 +154,12 @@ public: ...@@ -145,6 +154,12 @@ public:
// nop // nop
} }
void init() override {
if constexpr (detail::http_route_has_init_v<F>) {
f_.init();
}
}
private: private:
void do_apply(net::http::responder& res, Ts&&... args) override { void do_apply(net::http::responder& res, Ts&&... args) override {
f_(res, std::move(args)...); f_(res, std::move(args)...);
...@@ -184,6 +199,12 @@ public: ...@@ -184,6 +199,12 @@ public:
// nop // nop
} }
void init() override {
if constexpr (detail::http_route_has_init<F>::value) {
f_.init();
}
}
private: private:
void do_apply(net::http::responder& res) override { void do_apply(net::http::responder& res) override {
f_(res); f_(res);
...@@ -211,7 +232,7 @@ private: ...@@ -211,7 +232,7 @@ private:
F f_; F f_;
}; };
/// Default policy class for /// Creates a route from a function object.
template <class F, class... Args> template <class F, class... Args>
net::http::route_ptr net::http::route_ptr
make_http_route_impl(std::string& path, std::optional<net::http::method> method, make_http_route_impl(std::string& path, std::optional<net::http::method> method,
......
...@@ -170,6 +170,8 @@ public: ...@@ -170,6 +170,8 @@ public:
[[nodiscard]] expected<disposable> start(OnStart on_start) { [[nodiscard]] expected<disposable> start(OnStart on_start) {
using consumer_resource = async::consumer_resource<request>; using consumer_resource = async::consumer_resource<request>;
static_assert(std::is_invocable_v<OnStart, consumer_resource>); static_assert(std::is_invocable_v<OnStart, consumer_resource>);
for (auto& ptr : super::config().routes)
ptr->init();
auto& cfg = super::config(); auto& cfg = super::config();
return cfg.visit([this, &cfg, &on_start](auto& data) { return cfg.visit([this, &cfg, &on_start](auto& data) {
return this->do_start(cfg, data, on_start) return this->do_start(cfg, data, on_start)
......
...@@ -34,15 +34,13 @@ public: ...@@ -34,15 +34,13 @@ public:
// We produce the input type of the application. // We produce the input type of the application.
using push_t = async::producer_resource<typename Trait::input_type>; using push_t = async::producer_resource<typename Trait::input_type>;
lp_client_flow_bridge(async::execution_context_ptr loop, pull_t pull, lp_client_flow_bridge(pull_t pull, push_t push)
push_t push) : pull_(std::move(pull)), push_(std::move(push)) {
: super(std::move(loop)), pull_(std::move(pull)), push_(std::move(push)) {
// nop // nop
} }
static std::unique_ptr<lp_client_flow_bridge> make(net::multiplexer* mpx, static std::unique_ptr<lp_client_flow_bridge> make(pull_t pull, push_t push) {
pull_t pull, push_t push) { return std::make_unique<lp_client_flow_bridge>(std::move(pull),
return std::make_unique<lp_client_flow_bridge>(mpx, std::move(pull),
std::move(push)); std::move(push));
} }
...@@ -54,7 +52,7 @@ public: ...@@ -54,7 +52,7 @@ public:
error start(net::lp::lower_layer* down_ptr) override { error start(net::lp::lower_layer* down_ptr) override {
super::down_ = down_ptr; super::down_ = down_ptr;
return super::init(std::move(pull_), std::move(push_)); return super::init(&down_ptr->mpx(), std::move(pull_), std::move(push_));
} }
private: private:
...@@ -101,8 +99,7 @@ private: ...@@ -101,8 +99,7 @@ private:
using transport_t = typename Conn::transport_type; using transport_t = typename Conn::transport_type;
auto [s2a_pull, s2a_push] = async::make_spsc_buffer_resource<input_t>(); auto [s2a_pull, s2a_push] = async::make_spsc_buffer_resource<input_t>();
auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>(); auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>();
auto bridge = bridge_t::make(cfg.mpx, std::move(a2s_pull), auto bridge = bridge_t::make(std::move(a2s_pull), std::move(s2a_push));
std::move(s2a_push));
auto bridge_ptr = bridge.get(); auto bridge_ptr = bridge.get();
auto impl = framing::make(std::move(bridge)); auto impl = framing::make(std::move(bridge));
auto fd = conn.fd(); auto fd = conn.fd();
......
...@@ -46,14 +46,13 @@ public: ...@@ -46,14 +46,13 @@ public:
// one thread running in the multiplexer (which makes this safe). // one thread running in the multiplexer (which makes this safe).
using shared_producer_type = std::shared_ptr<producer_type>; using shared_producer_type = std::shared_ptr<producer_type>;
lp_server_flow_bridge(async::execution_context_ptr loop, lp_server_flow_bridge(shared_producer_type producer)
shared_producer_type producer) : producer_(std::move(producer)) {
: super(std::move(loop)), producer_(std::move(producer)) {
// nop // nop
} }
static auto make(net::multiplexer* mpx, shared_producer_type producer) { static auto make(shared_producer_type producer) {
return std::make_unique<lp_server_flow_bridge>(mpx, std::move(producer)); return std::make_unique<lp_server_flow_bridge>(std::move(producer));
} }
error start(net::lp::lower_layer* down_ptr) override { error start(net::lp::lower_layer* down_ptr) override {
...@@ -66,7 +65,8 @@ public: ...@@ -66,7 +65,8 @@ public:
return make_error(sec::runtime_error, return make_error(sec::runtime_error,
"Length-prefixed connection dropped: client canceled"); "Length-prefixed connection dropped: client canceled");
} }
return super::init(std::move(lp_pull), std::move(lp_push)); return super::init(&down_ptr->mpx(), std::move(lp_pull),
std::move(lp_push));
} }
private: private:
...@@ -95,7 +95,7 @@ public: ...@@ -95,7 +95,7 @@ public:
net::socket_manager_ptr make(net::multiplexer* mpx, net::socket_manager_ptr make(net::multiplexer* mpx,
connection_handle conn) override { connection_handle conn) override {
using bridge_t = lp_server_flow_bridge<Trait>; using bridge_t = lp_server_flow_bridge<Trait>;
auto bridge = bridge_t::make(mpx, producer_); auto bridge = bridge_t::make(producer_);
auto bridge_ptr = bridge.get(); auto bridge_ptr = bridge.get();
auto impl = net::lp::framing::make(std::move(bridge)); auto impl = net::lp::framing::make(std::move(bridge));
auto fd = conn.fd(); auto fd = conn.fd();
......
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
using handshake_ptr = std::unique_ptr<handshake>; using handshake_ptr = std::unique_ptr<handshake>;
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer::client>; using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>;
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
......
...@@ -26,9 +26,9 @@ namespace caf::detail { ...@@ -26,9 +26,9 @@ namespace caf::detail {
/// Specializes the WebSocket flow bridge for the server side. /// Specializes the WebSocket flow bridge for the server side.
template <class Trait, class... Ts> template <class Trait, class... Ts>
class ws_client_flow_bridge class ws_client_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::upper_layer::client> { : public ws_flow_bridge<Trait, net::web_socket::upper_layer> {
public: public:
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer::client>; using super = ws_flow_bridge<Trait, net::web_socket::upper_layer>;
// We consume the output type of the application. // We consume the output type of the application.
using pull_t = async::consumer_resource<typename Trait::input_type>; using pull_t = async::consumer_resource<typename Trait::input_type>;
...@@ -36,21 +36,19 @@ public: ...@@ -36,21 +36,19 @@ public:
// We produce the input type of the application. // We produce the input type of the application.
using push_t = async::producer_resource<typename Trait::output_type>; using push_t = async::producer_resource<typename Trait::output_type>;
ws_client_flow_bridge(async::execution_context_ptr loop, pull_t pull, ws_client_flow_bridge(pull_t pull, push_t push)
push_t push) : pull_(std::move(pull)), push_(std::move(push)) {
: super(std::move(loop)), pull_(std::move(pull)), push_(std::move(push)) {
// nop // nop
} }
static std::unique_ptr<ws_client_flow_bridge> make(net::multiplexer* mpx, static std::unique_ptr<ws_client_flow_bridge> make(pull_t pull, push_t push) {
pull_t pull, push_t push) { return std::make_unique<ws_client_flow_bridge>(std::move(pull),
return std::make_unique<ws_client_flow_bridge>(mpx, std::move(pull),
std::move(push)); std::move(push));
} }
error start(net::web_socket::lower_layer* down_ptr) override { error start(net::web_socket::lower_layer* down_ptr) override {
super::down_ = down_ptr; super::down_ = down_ptr;
return super::init(std::move(pull_), std::move(push_)); return super::init(&down_ptr->mpx(), std::move(pull_), std::move(push_));
} }
private: private:
...@@ -117,8 +115,7 @@ private: ...@@ -117,8 +115,7 @@ private:
using bridge_t = detail::ws_client_flow_bridge<Trait>; using bridge_t = detail::ws_client_flow_bridge<Trait>;
auto [s2a_pull, s2a_push] = async::make_spsc_buffer_resource<input_t>(); auto [s2a_pull, s2a_push] = async::make_spsc_buffer_resource<input_t>();
auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>(); auto [a2s_pull, a2s_push] = async::make_spsc_buffer_resource<output_t>();
auto bridge = bridge_t::make(cfg.mpx, std::move(a2s_pull), auto bridge = bridge_t::make(std::move(a2s_pull), std::move(s2a_push));
std::move(s2a_push));
auto bridge_ptr = bridge.get(); auto bridge_ptr = bridge.get();
auto impl = client::make(std::move(cfg.hs), std::move(bridge)); auto impl = client::make(std::move(cfg.hs), std::move(bridge));
auto fd = conn.fd(); auto fd = conn.fd();
......
...@@ -35,10 +35,6 @@ public: ...@@ -35,10 +35,6 @@ public:
using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>; using upper_layer_ptr = std::unique_ptr<web_socket::upper_layer>;
using server_ptr = std::unique_ptr<web_socket::upper_layer::server>;
using client_ptr = std::unique_ptr<web_socket::upper_layer::client>;
// -- constants -------------------------------------------------------------- // -- constants --------------------------------------------------------------
/// Restricts the size of received frames (including header). /// Restricts the size of received frames (including header).
...@@ -49,13 +45,18 @@ public: ...@@ -49,13 +45,18 @@ public:
// -- constructors, destructors, and assignment operators -------------------- // -- constructors, destructors, and assignment operators --------------------
static std::unique_ptr<framing> make(client_ptr up) { /// Creates a new framing protocol for client mode.
static std::unique_ptr<framing> make_client(upper_layer_ptr up) {
return std::unique_ptr<framing>{new framing(std::move(up))}; return std::unique_ptr<framing>{new framing(std::move(up))};
} }
static std::unique_ptr<framing> make(server_ptr up, /// Creates a new framing protocol for server mode.
http::request_header hdr) { static std::unique_ptr<framing> make_server(upper_layer_ptr up) {
return std::unique_ptr<framing>{new framing(std::move(up), std::move(hdr))}; // > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
auto res = std::unique_ptr<framing>{new framing(std::move(up))};
res->mask_outgoing_frames = false;
return res;
} }
// -- properties ------------------------------------------------------------- // -- properties -------------------------------------------------------------
...@@ -127,17 +128,10 @@ public: ...@@ -127,17 +128,10 @@ public:
private: private:
// -- implementation details ------------------------------------------------- // -- implementation details -------------------------------------------------
explicit framing(client_ptr up) : up_(std::move(up)) { explicit framing(upper_layer_ptr up) : up_(std::move(up)) {
// nop // nop
} }
explicit framing(server_ptr up, http::request_header&& hdr)
: up_(std::move(up)), hdr_(std::move(hdr)) {
// > A server MUST NOT mask any frames that it sends to the client.
// See RFC 6455, Section 5.1.
mask_outgoing_frames = false;
}
bool handle(uint8_t opcode, byte_span payload); bool handle(uint8_t opcode, byte_span payload);
void ship_pong(byte_span payload); void ship_pong(byte_span payload);
...@@ -167,9 +161,6 @@ private: ...@@ -167,9 +161,6 @@ private:
/// Next layer in the processing chain. /// Next layer in the processing chain.
upper_layer_ptr up_; upper_layer_ptr up_;
/// Stored when running as a server and passed to `up_` in start.
std::optional<http::request_header> hdr_;
}; };
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "caf/byte_buffer.hpp" #include "caf/byte_buffer.hpp"
#include "caf/detail/net_export.hpp" #include "caf/detail/net_export.hpp"
#include "caf/dictionary.hpp" #include "caf/dictionary.hpp"
#include "caf/net/fwd.hpp"
#include <cstddef> #include <cstddef>
#include <string> #include <string>
...@@ -135,6 +136,10 @@ public: ...@@ -135,6 +136,10 @@ public:
/// @pre `has_valid_key()` /// @pre `has_valid_key()`
void write_http_1_response(byte_buffer& buf) const; void write_http_1_response(byte_buffer& buf) const;
/// Writes the HTTP response message to `down`.
/// @pre `has_valid_key()`
void write_response(http::lower_layer* down) const;
/// Checks whether the `http_response` contains a HTTP 1.1 response to the /// Checks whether the `http_response` contains a HTTP 1.1 response to the
/// generated HTTP GET request. A valid response contains: /// generated HTTP GET request. A valid response contains:
/// - HTTP status code 101 (Switching Protocols). /// - HTTP status code 101 (Switching Protocols).
......
...@@ -40,49 +40,55 @@ public: ...@@ -40,49 +40,55 @@ public:
using producer_type = async::blocking_producer<accept_event>; using producer_type = async::blocking_producer<accept_event>;
using acceptor_impl_t = net::web_socket::acceptor_impl<Trait, Ts...>;
using ws_res_type = typename acceptor_impl_t::ws_res_type;
// Note: this is shared with the connection factory. In general, this is // Note: this is shared with the connection factory. In general, this is
// *unsafe*. However, we exploit the fact that there is currently only // *unsafe*. However, we exploit the fact that there is currently only
// one thread running in the multiplexer (which makes this safe). // one thread running in the multiplexer (which makes this safe).
using shared_producer_type = std::shared_ptr<producer_type>; using shared_producer_type = std::shared_ptr<producer_type>;
ws_server_flow_bridge(async::execution_context_ptr loop, ws_server_flow_bridge(on_request_cb_type on_request,
on_request_cb_type on_request,
shared_producer_type producer) shared_producer_type producer)
: super(std::move(loop)), : on_request_(std::move(on_request)), producer_(std::move(producer)) {
on_request_(std::move(on_request)),
producer_(std::move(producer)) {
// nop // nop
} }
static auto make(net::multiplexer* mpx, on_request_cb_type on_request, static auto make(on_request_cb_type on_request,
shared_producer_type producer) { shared_producer_type producer) {
return std::make_unique<ws_server_flow_bridge>(mpx, std::move(on_request), return std::make_unique<ws_server_flow_bridge>(std::move(on_request),
std::move(producer)); std::move(producer));
} }
error start(net::web_socket::lower_layer* down_ptr, error start(net::web_socket::lower_layer* down_ptr) override {
const net::http::request_header& hdr) override {
CAF_ASSERT(down_ptr != nullptr); CAF_ASSERT(down_ptr != nullptr);
super::down_ = down_ptr; super::down_ = down_ptr;
if (!producer_->push(app_event)) {
return make_error(sec::runtime_error,
"WebSocket connection dropped: client canceled");
}
auto& [pull, push] = ws_resources;
return super::init(&down_ptr->mpx(), std::move(pull), std::move(push));
}
error accept(const net::http::request_header& hdr) override {
net::web_socket::acceptor_impl<Trait, Ts...> acc{hdr}; net::web_socket::acceptor_impl<Trait, Ts...> acc{hdr};
(*on_request_)(acc); (*on_request_)(acc);
if (!acc.accepted()) { if (acc.accepted()) {
app_event = std::move(acc.app_event);
return {};
}
return std::move(acc) // return std::move(acc) //
.reject_reason() .reject_reason()
.or_else(sec::runtime_error, .or_else(sec::runtime_error, "WebSocket request rejected without reason");
"WebSocket request rejected without reason");
}
if (!producer_->push(acc.app_event)) {
return make_error(sec::runtime_error,
"WebSocket connection dropped: client canceled");
}
auto& [pull, push] = acc.ws_resources;
return super::init(std::move(pull), std::move(push));
} }
private: private:
on_request_cb_type on_request_; on_request_cb_type on_request_;
shared_producer_type producer_; shared_producer_type producer_;
accept_event app_event;
ws_res_type ws_resources;
}; };
/// Specializes @ref connection_factory for the WebSocket protocol. /// Specializes @ref connection_factory for the WebSocket protocol.
...@@ -118,7 +124,7 @@ public: ...@@ -118,7 +124,7 @@ public:
return nullptr; return nullptr;
} }
using bridge_t = ws_server_flow_bridge<Trait, Ts...>; using bridge_t = ws_server_flow_bridge<Trait, Ts...>;
auto app = bridge_t::make(mpx, on_request_, producer_); auto app = bridge_t::make(on_request_, producer_);
auto app_ptr = app.get(); auto app_ptr = app.get();
auto ws = net::web_socket::server::make(std::move(app)); auto ws = net::web_socket::server::make(std::move(app));
auto fd = conn.fd(); auto fd = conn.fd();
......
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/detail/type_list.hpp"
#include "caf/net/web_socket/acceptor.hpp"
#include "caf/net/web_socket/default_trait.hpp"
#include "caf/net/web_socket/server_factory.hpp"
#include <memory>
namespace caf::detail {
/// Specializes the WebSocket flow bridge for the switch-protocol use case.
template <class Trait, class... Ts>
class ws_switch_protocol_flow_bridge
: public ws_flow_bridge<Trait, net::web_socket::upper_layer> {
public:
using super = ws_flow_bridge<Trait, net::web_socket::upper_layer>;
using accept_event = typename Trait::template accept_event<Ts...>;
using producer_type = async::blocking_producer<accept_event>;
using pull_type = async::consumer_resource<typename Trait::output_type>;
using push_type = async::producer_resource<typename Trait::input_type>;
// Note: in general, this is *unsafe*. However, we exploit the fact that there
// is currently only one thread running in the multiplexer (which makes
// this safe).
using shared_producer_type = std::shared_ptr<producer_type>;
ws_switch_protocol_flow_bridge(shared_producer_type producer, pull_type pull,
push_type push)
: producer_(std::move(producer)),
pull_(std::move(pull)),
push_(std::move(push)) {
// nop
}
static auto make(shared_producer_type producer, pull_type pull,
push_type push) {
using impl_t = ws_switch_protocol_flow_bridge;
return std::make_unique<impl_t>(std::move(producer), std::move(pull),
std::move(push));
}
error start(net::web_socket::lower_layer* down_ptr) override {
CAF_ASSERT(down_ptr != nullptr);
super::down_ = down_ptr;
return super::init(&down_ptr->mpx(), std::move(pull_), std::move(push_));
}
private:
shared_producer_type producer_;
pull_type pull_;
push_type push_;
};
template <class OnRequest, class OnStart>
struct ws_switch_protocol_state {
ws_switch_protocol_state(OnRequest on_request_fn, OnStart on_start_fn)
: on_request(std::move(on_request_fn)), on_start(std::move(on_start_fn)) {
// nop
}
OnRequest on_request;
std::optional<OnStart> on_start;
};
template <class Trait, class State, class Out, class... Ts>
class ws_switch_protocol;
template <class Trait, class State, class... Out, class... Ts>
class ws_switch_protocol<Trait, State, type_list<Out...>, Ts...> {
public:
using accept_event = typename Trait::template accept_event<Out...>;
using producer_type = async::blocking_producer<accept_event>;
using shared_producer_type = std::shared_ptr<producer_type>;
explicit ws_switch_protocol(std::shared_ptr<State> state)
: state_(std::move(state)) {
// nop
}
ws_switch_protocol(ws_switch_protocol&&) = default;
ws_switch_protocol(const ws_switch_protocol&) = default;
ws_switch_protocol& operator=(ws_switch_protocol&&) = default;
ws_switch_protocol& operator=(const ws_switch_protocol&) = default;
void operator()(net::http::responder& res, Ts... args) {
namespace http = net::http;
auto& hdr = res.header();
// Sanity checking.
if (!hdr.field_equals(ignore_case, "Connection", "upgrade")
|| !hdr.field_equals(ignore_case, "Upgrade", "websocket")) {
res.respond(net::http::status::bad_request, "text/plain",
"Expected a WebSocket client handshake request.");
return;
}
auto sec_key = hdr.field("Sec-WebSocket-Key");
if (sec_key.empty()) {
res.respond(net::http::status::bad_request, "text/plain",
"Mandatory field Sec-WebSocket-Key missing or invalid.");
return;
}
// Call user-defined on_request callback.
net::web_socket::acceptor_impl<Trait, Out...> acc{hdr};
(state_->on_request)(acc, args...);
if (!acc.accepted()) {
if (auto& err = acc.reject_reason()) {
auto descr = to_string(err);
res.respond(http::status::bad_request, "text/plain", descr);
} else {
res.respond(http::status::bad_request, "text/plain", "Bad request.");
}
return;
}
if (!producer_->push(acc.app_event)) {
res.respond(http::status::internal_server_error, "text/plain",
"Upstream channel closed.");
return;
}
// Finalize the WebSocket handshake.
net::web_socket::handshake hs;
hs.assign_key(sec_key);
hs.write_response(res.down());
// Switch to the WebSocket framing protocol.
auto& [pull, push] = acc.ws_resources;
using net::web_socket::framing;
using bridge_t = ws_switch_protocol_flow_bridge<Trait, Out...>;
auto bridge = bridge_t::make(producer_, std::move(pull), std::move(push));
res.down()->switch_protocol(framing::make_server(std::move(bridge)));
}
void init() {
if (auto& on_start = state_->on_start; on_start) {
auto [pull, push] = async::make_spsc_buffer_resource<accept_event>();
using producer_t = async::blocking_producer<accept_event>;
producer_ = std::make_shared<producer_t>(producer_t{push.try_open()});
(*on_start)(std::move(pull));
on_start = std::nullopt;
}
}
private:
std::shared_ptr<State> state_;
shared_producer_type producer_;
};
} // namespace caf::detail
namespace caf::net::web_socket {
/// Binds a `switch_protocol` invocation to a trait class and a function object
/// for on_request.
template <class Trait, class OnRequest>
struct switch_protocol_bind_2 {
public:
switch_protocol_bind_2(OnRequest on_request)
: on_request_(std::move(on_request)) {
// nop
}
template <class OnStart>
auto on_start(OnStart on_start) && {
using on_request_trait = detail::get_callable_trait_t<OnRequest>;
using on_request_args = typename on_request_trait::arg_types;
return make(on_start, on_request_args{});
}
private:
template <class OnStart, class... Out, class... Ts>
auto make(OnStart& on_start,
detail::type_list<net::web_socket::acceptor<Out...>&, Ts...>) {
using namespace detail;
using state_t = ws_switch_protocol_state<OnRequest, OnStart>;
using impl_t = ws_switch_protocol<Trait, state_t, type_list<Out...>, Ts...>;
auto state = std::make_shared<state_t>(std::move(on_request_),
std::move(on_start));
static_assert(http_route_has_init_v<impl_t>);
return impl_t{std::move(state)};
}
OnRequest on_request_;
};
/// Binds a `switch_protocol` invocation to a trait class.
template <class Trait>
struct switch_protocol_bind_1 {
template <class OnRequest>
auto on_request(OnRequest on_request) {
return switch_protocol_bind_2<Trait, OnRequest>(std::move(on_request));
}
};
template <class Trait = default_trait>
auto switch_protocol() {
return switch_protocol_bind_1<Trait>{};
}
} // namespace caf::net::web_socket
...@@ -25,18 +25,17 @@ public: ...@@ -25,18 +25,17 @@ public:
virtual ptrdiff_t consume_binary(byte_span buf) = 0; virtual ptrdiff_t consume_binary(byte_span buf) = 0;
virtual ptrdiff_t consume_text(std::string_view buf) = 0; virtual ptrdiff_t consume_text(std::string_view buf) = 0;
virtual error start(lower_layer* down) = 0;
}; };
class upper_layer::server : public upper_layer { class upper_layer::server : public upper_layer {
public: public:
virtual ~server(); virtual ~server();
virtual error start(lower_layer* down, const http::request_header& hdr) = 0;
};
class upper_layer::client : public upper_layer { /// Asks the layer to accept a new client.
public: /// @warning the server calls this function *before* calling `start`.
virtual ~client(); virtual error accept(const http::request_header& hdr) = 0;
virtual error start(lower_layer* down) = 0;
}; };
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -50,7 +50,7 @@ void client::abort(const error& reason) { ...@@ -50,7 +50,7 @@ void client::abort(const error& reason) {
up_->abort(reason); up_->abort(reason);
} }
ptrdiff_t client::consume(byte_span buffer, byte_span delta) { ptrdiff_t client::consume(byte_span buffer, byte_span) {
CAF_LOG_TRACE(CAF_ARG2("buffer", buffer.size())); CAF_LOG_TRACE(CAF_ARG2("buffer", buffer.size()));
// Check whether we have received the HTTP header or else wait for more // Check whether we have received the HTTP header or else wait for more
// data. Abort when exceeding the maximum size. // data. Abort when exceeding the maximum size.
...@@ -89,7 +89,7 @@ bool client::handle_header(std::string_view http) { ...@@ -89,7 +89,7 @@ bool client::handle_header(std::string_view http) {
auto http_ok = hs_->is_valid_http_1_response(http); auto http_ok = hs_->is_valid_http_1_response(http);
hs_.reset(); hs_.reset();
if (http_ok) { if (http_ok) {
down_->switch_protocol(framing::make(std::move(up_))); down_->switch_protocol(framing::make_client(std::move(up_)));
return true; return true;
} }
CAF_LOG_DEBUG("received an invalid WebSocket handshake"); CAF_LOG_DEBUG("received an invalid WebSocket handshake");
......
...@@ -15,22 +15,7 @@ error framing::start(octet_stream::lower_layer* down) { ...@@ -15,22 +15,7 @@ error framing::start(octet_stream::lower_layer* down) {
std::random_device rd; std::random_device rd;
rng_.seed(rd()); rng_.seed(rd());
down_ = down; down_ = down;
if (!hdr_) { return up_->start(this);
using dptr_t = web_socket::upper_layer::client;
return static_cast<dptr_t*>(up_.get())->start(this);
}
using dptr_t = web_socket::upper_layer::server;
auto err = static_cast<dptr_t*>(up_.get())->start(this, *hdr_);
hdr_ = std::nullopt;
if (err) {
auto descr = to_string(err);
CAF_LOG_DEBUG("upper layer rejected a WebSocket connection:" << descr);
down_->begin_output();
http::v1::write_response(http::status::bad_request, "text/plain", descr,
down_->output_buffer());
down_->end_output();
}
return err;
} }
void framing::abort(const error& reason) { void framing::abort(const error& reason) {
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include "caf/config.hpp" #include "caf/config.hpp"
#include "caf/detail/base64.hpp" #include "caf/detail/base64.hpp"
#include "caf/hash/sha1.hpp" #include "caf/hash/sha1.hpp"
#include "caf/net/http/lower_layer.hpp"
#include "caf/net/http/status.hpp"
#include "caf/string_algorithms.hpp" #include "caf/string_algorithms.hpp"
#include <algorithm> #include <algorithm>
...@@ -118,6 +120,15 @@ void handshake::write_http_1_response(byte_buffer& buf) const { ...@@ -118,6 +120,15 @@ void handshake::write_http_1_response(byte_buffer& buf) const {
<< response_key() << "\r\n\r\n"; << response_key() << "\r\n\r\n";
} }
void handshake::write_response(http::lower_layer* down) const {
down->begin_header(http::status::switching_protocols);
down->add_header_field("Upgrade", "websocket");
down->add_header_field("Connection", "Upgrade");
down->add_header_field("Sec-WebSocket-Accept", response_key());
down->end_header();
down->send_payload({});
}
namespace { namespace {
template <class F> template <class F>
......
...@@ -80,14 +80,20 @@ bool server::handle_header(std::string_view http) { ...@@ -80,14 +80,20 @@ bool server::handle_header(std::string_view http) {
CAF_LOG_DEBUG("received invalid WebSocket handshake"); CAF_LOG_DEBUG("received invalid WebSocket handshake");
return false; return false;
} }
// Kindly ask the upper layer to accept a new WebSocket connection.
if (auto err = up_->accept(hdr)) {
write_response(http::status::bad_request, to_string(err));
return false;
}
// Finalize the WebSocket handshake. // Finalize the WebSocket handshake.
handshake hs; handshake hs;
hs.assign_key(sec_key); hs.assign_key(sec_key);
down_->begin_output(); down_->begin_output();
hs.write_http_1_response(down_->output_buffer()); hs.write_http_1_response(down_->output_buffer());
down_->end_output(); down_->end_output();
// All done. Switch to the framing protocol.
CAF_LOG_DEBUG("completed WebSocket handshake"); CAF_LOG_DEBUG("completed WebSocket handshake");
down_->switch_protocol(framing::make(std::move(up_), std::move(hdr))); down_->switch_protocol(framing::make_server(std::move(up_)));
return true; return true;
} }
......
...@@ -14,8 +14,4 @@ upper_layer::server::~server() { ...@@ -14,8 +14,4 @@ upper_layer::server::~server() {
// nop // nop
} }
upper_layer::client::~client() {
// nop
}
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
...@@ -14,7 +14,7 @@ namespace { ...@@ -14,7 +14,7 @@ namespace {
using svec = std::vector<std::string>; using svec = std::vector<std::string>;
class app_t : public net::web_socket::upper_layer::client { class app_t : public net::web_socket::upper_layer {
public: public:
static auto make() { static auto make() {
return std::make_unique<app_t>(); return std::make_unique<app_t>();
......
...@@ -32,9 +32,12 @@ public: ...@@ -32,9 +32,12 @@ public:
return std::make_unique<app_t>(); return std::make_unique<app_t>();
} }
error start(net::web_socket::lower_layer* down, error start(net::web_socket::lower_layer* down) override {
const net::http::request_header& hdr) override {
down->request_messages(); down->request_messages();
return none;
}
error accept(const net::http::request_header& hdr) override {
// Store the request information in cfg to evaluate them later. // Store the request information in cfg to evaluate them later.
auto& ws = cfg["web-socket"].as_dictionary(); auto& ws = cfg["web-socket"].as_dictionary();
put(ws, "method", to_rfc_string(hdr.method())); put(ws, "method", to_rfc_string(hdr.method()));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment