Commit 215e4ed3 authored by Dominik Charousset's avatar Dominik Charousset

Make a pass over the router/route API

parent 300e68f7
...@@ -61,10 +61,10 @@ int caf_main(caf::actor_system& sys, const config& cfg) { ...@@ -61,10 +61,10 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
// Limit how many clients may be connected at any given time. // Limit how many clients may be connected at any given time.
.max_connections(max_connections) .max_connections(max_connections)
// Accept only requests for path "/". // Accept only requests for path "/".
.on_request([](ws::acceptor<>& acc, const http::request_header& hdr) { .on_request([](ws::acceptor<>& acc) {
// The hdr parameter is a dictionary with fields from the WebSocket // The hdr parameter is a dictionary with fields from the WebSocket
// handshake such as the path. // handshake such as the path.
auto path = hdr.path(); auto path = acc.header().path();
std::cout << "*** new client request for path " << path << '\n'; std::cout << "*** new client request for path " << path << '\n';
// Accept the WebSocket connection only if the path is "/". // Accept the WebSocket connection only if the path is "/".
if (path == "/") { if (path == "/") {
......
...@@ -140,13 +140,12 @@ int caf_main(caf::actor_system& sys, const config& cfg) { ...@@ -140,13 +140,12 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
// Limit how many clients may be connected at any given time. // Limit how many clients may be connected at any given time.
.max_connections(max_connections) .max_connections(max_connections)
// Forward the path from the WebSocket request to the worker. // Forward the path from the WebSocket request to the worker.
.on_request([](ws::acceptor<caf::cow_string>& acc, .on_request([](ws::acceptor<caf::cow_string>& acc) {
const http::request_header& hdr) {
// The hdr parameter is a dictionary with fields from the WebSocket // The hdr parameter is a dictionary with fields from the WebSocket
// handshake such as the path. This is only field we care about // handshake such as the path. This is only field we care about
// here. By passing the (copy-on-write) string to accept() here, we // here. By passing the (copy-on-write) string to accept() here, we
// make it available to the worker through the acceptor_resource. // make it available to the worker through the acceptor_resource.
acc.accept(caf::cow_string{hdr.path()}); acc.accept(caf::cow_string{acc.header().path()});
}) })
// When started, run our worker actor to handle incoming connections. // When started, run our worker actor to handle incoming connections.
.start([&sys](trait::acceptor_resource<caf::cow_string> events) { .start([&sys](trait::acceptor_resource<caf::cow_string> events) {
......
...@@ -185,7 +185,7 @@ int caf_main(caf::actor_system& sys, const config& cfg) { ...@@ -185,7 +185,7 @@ int caf_main(caf::actor_system& sys, const config& cfg) {
// Limit how many clients may be connected at any given time. // Limit how many clients may be connected at any given time.
.max_connections(max_connections) .max_connections(max_connections)
// Add handler for incoming connections. // Add handler for incoming connections.
.on_request([](ws::acceptor<>& acc, const http::request_header&) { .on_request([](ws::acceptor<>& acc) {
// Ignore all header fields and accept the connection. // Ignore all header fields and accept the connection.
acc.accept(); acc.accept();
}) })
......
...@@ -43,6 +43,7 @@ caf_add_component( ...@@ -43,6 +43,7 @@ caf_add_component(
src/net/http/request_header.cpp src/net/http/request_header.cpp
src/net/http/responder.cpp src/net/http/responder.cpp
src/net/http/response.cpp src/net/http/response.cpp
src/net/http/route.cpp
src/net/http/router.cpp src/net/http/router.cpp
src/net/http/server.cpp src/net/http/server.cpp
src/net/http/server_factory.cpp src/net/http/server_factory.cpp
......
...@@ -129,6 +129,7 @@ class lower_layer; ...@@ -129,6 +129,7 @@ class lower_layer;
class request; class request;
class request_header; class request_header;
class responder; class responder;
class route;
class router; class router;
class server; class server;
class upper_layer; class upper_layer;
...@@ -136,6 +137,8 @@ class upper_layer; ...@@ -136,6 +137,8 @@ class upper_layer;
enum class method : uint8_t; enum class method : uint8_t;
enum class status : uint16_t; enum class status : uint16_t;
using route_ptr = intrusive_ptr<route>;
} // namespace caf::net::http } // namespace caf::net::http
namespace caf::net::ssl { namespace caf::net::ssl {
......
...@@ -53,6 +53,12 @@ public: ...@@ -53,6 +53,12 @@ public:
/// @copydoc send_response /// @copydoc send_response
bool send_response(status code, std::string_view content_type, bool send_response(status code, std::string_view content_type,
std::string_view content); std::string_view content);
/// Asks the stream to swap the HTTP layer with `next` after returning from
/// `consume`.
/// @note may only be called from the upper layer in `consume`.
virtual void switch_protocol(std::unique_ptr<octet_stream::upper_layer> next)
= 0;
}; };
} // namespace caf::net::http } // namespace caf::net::http
// This file is part of CAF, the C++ Actor Framework. See the file LICENSE in
// the main distribution directory for license terms and copyright or visit
// https://github.com/actor-framework/actor-framework/blob/master/LICENSE.
#pragma once
#include "caf/byte_span.hpp"
#include "caf/intrusive_ptr.hpp"
#include "caf/net/fwd.hpp"
#include "caf/net/http/arg_parser.hpp"
#include "caf/net/http/request.hpp"
#include "caf/net/http/request_header.hpp"
#include "caf/net/http/responder.hpp"
#include "caf/ref_counted.hpp"
#include <string_view>
#include <tuple>
namespace caf::net::http {
/// Represents a single route for HTTP requests at a server.
class route : public ref_counted {
public:
virtual ~route();
/// Tries to match an HTTP request and processes the request on a match. The
/// route may send errors to the client or call `shutdown` on the `parent` for
/// severe errors.
/// @param hdr The HTTP request header from the client.
/// @param body The payload from the client.
/// @param parent Pointer to the object that uses this route.
/// @return `true` if the route matches the request, `false` otherwise.
virtual bool
exec(const request_header& hdr, const_byte_span body, router* parent)
= 0;
/// Called by the HTTP server when starting up. May be used to spin up workers
/// that the path dispatches to. The default implementation does nothing.
virtual void init();
};
} // namespace caf::net::http
namespace caf::detail {
/// Counts how many `<arg>` entries are in `path`.
size_t args_in_path(std::string_view path);
/// Splits `str` in the first component of a path and its remainder.
std::pair<std::string_view, std::string_view>
next_path_component(std::string_view str);
/// Matches two paths by splitting both inputs at '/' and then checking that
/// `predicate` holds for each resulting pair.
template <class F>
bool match_path(std::string_view lhs, std::string_view rhs, F&& predicate) {
std::string_view head1;
std::string_view tail1;
std::string_view head2;
std::string_view tail2;
std::tie(head1, tail1) = next_path_component(lhs);
std::tie(head2, tail2) = next_path_component(rhs);
if (!predicate(head1, head2))
return false;
while (!tail1.empty()) {
if (tail2.empty())
return false;
std::tie(head1, tail1) = next_path_component(tail1);
std::tie(head2, tail2) = next_path_component(tail2);
if (!predicate(head1, head2))
return false;
}
return tail2.empty();
}
/// Base type for HTTP routes that parse one or more arguments from the requests
/// and then forward them to a user-provided function object.
template <class... Ts>
class http_route_base : public net::http::route {
public:
explicit http_route_base(std::string&& path,
std::optional<net::http::method> method)
: path_(std::move(path)), method_(method) {
// nop
}
bool exec(const net::http::request_header& hdr, const_byte_span body,
net::http::router* parent) override {
if (method_ && *method_ != hdr.method())
return false;
// Try to match the path to the expected path and extract args.
std::string_view args[sizeof...(Ts)];
auto ok = match_path(path_, hdr.path(),
[pos = args](std::string_view lhs,
std::string_view rhs) mutable {
if (lhs == "<arg>") {
*pos++ = rhs;
return true;
} else {
return lhs == rhs;
}
});
if (!ok)
return false;
// Try to parse the arguments.
using iseq = std::make_index_sequence<sizeof...(Ts)>;
return exec_dis(hdr, body, parent, iseq{}, args);
}
template <size_t... Is>
bool exec_dis(const net::http::request_header& hdr, const_byte_span body,
net::http::router* parent, std::index_sequence<Is...>,
std::string_view* arr) {
return exec_impl(hdr, body, parent,
std::get<Is>(parsers_).parse(arr[Is])...);
}
template <class... Is>
bool exec_impl(const net::http::request_header& hdr, const_byte_span body,
net::http::router* parent, std::optional<Ts>&&... args) {
if ((args.has_value() && ...)) {
net::http::responder rp{&hdr, body, parent};
do_apply(rp, std::move(*args)...);
return true;
}
return false;
}
private:
virtual void do_apply(net::http::responder&, Ts&&...) = 0;
std::string path_;
std::optional<net::http::method> method_;
std::tuple<net::http::arg_parser_t<Ts>...> parsers_;
};
template <class F, class... Ts>
class http_route_impl : public http_route_base<Ts...> {
public:
using super = http_route_base<Ts...>;
http_route_impl(std::string&& path, std::optional<net::http::method> method,
F&& f)
: super(std::move(path), method), f_(std::move(f)) {
// nop
}
private:
void do_apply(net::http::responder& res, Ts&&... args) override {
f_(res, std::move(args)...);
}
F f_;
};
/// A simple implementation for `http::route` that does not parse any arguments
/// from the requests and simply calls the user-provided function object.
class http_simple_route_base : public net::http::route {
public:
http_simple_route_base(std::string&& path,
std::optional<net::http::method> method)
: path_(std::move(path)), method_(method) {
// nop
}
bool exec(const net::http::request_header& hdr, const_byte_span body,
net::http::router* parent) override;
private:
virtual void do_apply(net::http::responder&) = 0;
std::string path_;
std::optional<net::http::method> method_;
};
template <class F>
class http_simple_route_impl : public http_simple_route_base {
public:
using super = http_simple_route_base;
http_simple_route_impl(std::string&& path,
std::optional<net::http::method> method, F&& f)
: super(std::move(path), method), f_(std::move(f)) {
// nop
}
private:
void do_apply(net::http::responder& res) override {
f_(res);
}
F f_;
};
/// Represents an HTTP route that matches any path.
template <class F>
class http_catch_all_route_impl : public net::http::route {
public:
explicit http_catch_all_route_impl(F&& f) : f_(std::move(f)) {
// nop
}
bool exec(const net::http::request_header& hdr, const_byte_span body,
net::http::router* parent) override {
net::http::responder rp{&hdr, body, parent};
f_(rp);
return true;
}
private:
F f_;
};
/// Default policy class for
template <class F, class... Args>
net::http::route_ptr
make_http_route_impl(std::string& path, std::optional<net::http::method> method,
F& f, detail::type_list<net::http::responder&, Args...>) {
if constexpr (sizeof...(Args) == 0) {
using impl_t = http_simple_route_impl<F>;
return make_counted<impl_t>(std::move(path), method, std::move(f));
} else {
using impl_t = http_route_impl<F, Args...>;
return make_counted<impl_t>(std::move(path), method, std::move(f));
}
}
} // namespace caf::detail
namespace caf::net::http {
/// Creates a @ref route object from a function object.
/// @param path Description of the path, optionally with `<arg>` placeholders.
/// @param method The HTTP method for the path or `std::nullopt` for "any".
/// @param f The callback for the path.
/// @returns a @ref path on success, an error when failing to parse the path or
/// to match it to the signature of `f`.
template <class F>
expected<route_ptr>
make_route(std::string path, std::optional<http::method> method, F f) {
// F must have signature void (responder&, ...).
using f_trait = detail::get_callable_trait_t<F>;
using f_args = typename f_trait::arg_types;
static_assert(f_trait::num_args > 0, "F must take at least one argument");
using arg_0 = detail::tl_at_t<f_args, 0>;
static_assert(std::is_same_v<arg_0, responder&>,
"F must take 'responder&' as first argument");
// The path must be absolute.
if (path.empty() || path.front() != '/') {
return make_error(sec::invalid_argument,
"expected an absolute path, got: " + path);
}
// The path must have as many <arg> entries as F takes extra arguments.
auto num_args = detail::args_in_path(path);
if (num_args != f_trait::num_args - 1) {
auto msg = path;
msg += " defines ";
detail::print(msg, num_args);
msg += " arguments, but F accepts ";
detail::print(msg, f_trait::num_args - 1);
return make_error(sec::invalid_argument, std::move(msg));
}
// Create the object.
return detail::make_http_route_impl(path, method, f, f_args{});
}
/// Convenience function for calling `make_route(path, std::nullopt, f)`.
template <class F>
expected<route_ptr> make_route(std::string path, F f) {
return make_route(std::move(path), std::nullopt, std::move(f));
}
/// Creates a @ref route that matches all paths.
template <class F>
expected<route_ptr> make_route(F f) {
// F must have signature void (responder&).
using f_trait = detail::get_callable_trait_t<F>;
static_assert(std::is_same_v<typename f_trait::f_sig, void(responder&)>);
using impl_t = detail::http_catch_all_route_impl<F>;
return make_counted<impl_t>(std::move(f));
}
} // namespace caf::net::http
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "caf/net/http/arg_parser.hpp" #include "caf/net/http/arg_parser.hpp"
#include "caf/net/http/lower_layer.hpp" #include "caf/net/http/lower_layer.hpp"
#include "caf/net/http/responder.hpp" #include "caf/net/http/responder.hpp"
#include "caf/net/http/route.hpp"
#include "caf/net/http/upper_layer.hpp" #include "caf/net/http/upper_layer.hpp"
#include "caf/ref_counted.hpp" #include "caf/ref_counted.hpp"
...@@ -27,50 +28,6 @@ namespace caf::net::http { ...@@ -27,50 +28,6 @@ namespace caf::net::http {
/// user-defined handlers. /// user-defined handlers.
class CAF_NET_EXPORT router : public upper_layer { class CAF_NET_EXPORT router : public upper_layer {
public: public:
// -- member types -----------------------------------------------------------
class route : public ref_counted {
public:
virtual ~route();
/// Returns `true` if the route accepted the request, `false` otherwise.
virtual bool
exec(const request_header& hdr, const_byte_span body, router* parent)
= 0;
/// Counts how many `<arg>` entries are in `path`.
static size_t args_in_path(std::string_view path);
/// Splits `str` in the first component of a path and its remainder.
static std::pair<std::string_view, std::string_view>
next_component(std::string_view str);
/// Matches a two paths by splitting both inputs at '/' and then checking
/// that `predicate` holds for each resulting pair.
template <class F>
bool match_path(std::string_view lhs, std::string_view rhs, F&& predicate) {
std::string_view head1;
std::string_view tail1;
std::string_view head2;
std::string_view tail2;
std::tie(head1, tail1) = next_component(lhs);
std::tie(head2, tail2) = next_component(rhs);
if (!predicate(head1, head2))
return false;
while (!tail1.empty()) {
if (tail2.empty())
return false;
std::tie(head1, tail1) = next_component(tail1);
std::tie(head2, tail2) = next_component(tail2);
if (!predicate(head1, head2))
return false;
}
return tail2.empty();
}
};
using route_ptr = intrusive_ptr<route>;
// -- constructors and destructors ------------------------------------------- // -- constructors and destructors -------------------------------------------
router() = default; router() = default;
...@@ -85,34 +42,6 @@ public: ...@@ -85,34 +42,6 @@ public:
static std::unique_ptr<router> make(std::vector<route_ptr> routes); static std::unique_ptr<router> make(std::vector<route_ptr> routes);
/// Tries to create a new HTTP route.
/// @param path The path on this server for the new route.
/// @param f The function object for handling requests on the new route.
/// @return the @ref route object on success, an @ref error otherwise.
template <class F>
static expected<route_ptr> make_route(std::string path, F f) {
return make_route_dis(path, std::nullopt, f);
}
/// Tries to create a new HTTP route.
/// @param path The path on this server for the new route.
/// @param method The allowed HTTP method on the new route.
/// @param f The function object for handling requests on the new route.
/// @return the @ref route object on success, an @ref error otherwise.
template <class F>
static expected<route_ptr>
make_route(std::string path, http::method method, F f) {
return make_route_dis(path, method, f);
}
/// Create a new HTTP default "catch all" route.
/// @param f The function object for handling the requests.
/// @return the @ref route object.
template <class F>
static route_ptr make_route(F f) {
return make_counted<default_route_impl<F>>(std::move(f));
}
// -- properties ------------------------------------------------------------- // -- properties -------------------------------------------------------------
lower_layer* down() { lower_layer* down() {
...@@ -141,152 +70,6 @@ public: ...@@ -141,152 +70,6 @@ public:
void abort(const error& reason) override; void abort(const error& reason) override;
private: private:
template <class F, class... Ts>
class route_impl : public route {
public:
explicit route_impl(std::string&& path, std::optional<http::method> method,
F&& f)
: path_(std::move(path)), method_(method), f_(std::move(f)) {
// nop
}
bool exec(const request_header& hdr, const_byte_span body,
router* parent) override {
if (method_ && *method_ != hdr.method())
return false;
// Try to match the path to the expected path and extract args.
std::string_view args[sizeof...(Ts)];
auto ok = match_path(path_, hdr.path(),
[pos = args](std::string_view lhs,
std::string_view rhs) mutable {
if (lhs == "<arg>") {
*pos++ = rhs;
return true;
} else {
return lhs == rhs;
}
});
if (!ok)
return false;
// Try to parse the arguments.
using iseq = std::make_index_sequence<sizeof...(Ts)>;
return exec_dis(hdr, body, parent, iseq{}, args);
}
template <size_t... Is>
bool exec_dis(const request_header& hdr, const_byte_span body,
router* parent, std::index_sequence<Is...>,
std::string_view* arr) {
return exec_impl(hdr, body, parent,
std::get<Is>(parsers_).parse(arr[Is])...);
}
template <class... Is>
bool exec_impl(const request_header& hdr, const_byte_span body,
router* parent, std::optional<Ts>&&... args) {
if ((args.has_value() && ...)) {
responder rp{&hdr, body, parent};
f_(rp, std::move(*args)...);
return true;
}
return false;
}
private:
std::string path_;
std::optional<http::method> method_;
F f_;
std::tuple<arg_parser_t<Ts>...> parsers_;
};
template <class F>
class trivial_route_impl : public route {
public:
explicit trivial_route_impl(std::string&& path,
std::optional<http::method> method, F&& f)
: path_(std::move(path)), method_(method), f_(std::move(f)) {
// nop
}
bool exec(const request_header& hdr, const_byte_span body,
router* parent) override {
if (method_ && *method_ != hdr.method())
return false;
if (hdr.path() == path_) {
responder rp{&hdr, body, parent};
f_(rp);
return true;
}
return false;
}
private:
std::string path_;
std::optional<http::method> method_;
F f_;
};
template <class F>
class default_route_impl : public route {
public:
explicit default_route_impl(F&& f) : f_(std::move(f)) {
// nop
}
bool exec(const request_header& hdr, const_byte_span body,
router* parent) override {
responder rp{&hdr, body, parent};
f_(rp);
return true;
}
private:
F f_;
};
// Dispatches to make_route_impl after sanity checking.
template <class F>
static expected<route_ptr>
make_route_dis(std::string& path, std::optional<http::method> method, F& f) {
// F must have signature void (responder&, ...).
using f_trait = detail::get_callable_trait_t<F>;
using f_args = typename f_trait::arg_types;
static_assert(f_trait::num_args > 0, "F must take at least one argument");
using arg_0 = detail::tl_at_t<f_args, 0>;
static_assert(std::is_same_v<arg_0, responder&>,
"F must take 'responder&' as first argument");
// The path must be absolute.
if (path.empty() || path.front() != '/') {
return make_error(sec::invalid_argument,
"expected an absolute path, got: " + path);
}
// The path must has as many <arg> entries as F takes extra arguments.
auto num_args = route::args_in_path(path);
if (num_args != f_trait::num_args - 1) {
auto msg = path;
msg += " defines ";
detail::print(msg, num_args);
msg += " arguments, but F accepts ";
detail::print(msg, f_trait::num_args - 1);
return make_error(sec::invalid_argument, std::move(msg));
}
// Dispatch to the actual factory.
return make_route_impl(path, method, f, f_args{});
}
template <class F, class... Args>
static expected<route_ptr>
make_route_impl(std::string& path, std::optional<http::method> method, F& f,
detail::type_list<responder&, Args...>) {
if constexpr (sizeof...(Args) == 0) {
return make_counted<trivial_route_impl<F>>(std::move(path), method,
std::move(f));
} else {
return make_counted<route_impl<F, Args...>>(std::move(path), method,
std::move(f));
}
}
lower_layer* down_ = nullptr; lower_layer* down_ = nullptr;
std::vector<route_ptr> routes_; std::vector<route_ptr> routes_;
size_t request_id_ = 0; size_t request_id_ = 0;
......
...@@ -103,6 +103,8 @@ public: ...@@ -103,6 +103,8 @@ public:
bool send_end_of_chunks() override; bool send_end_of_chunks() override;
void switch_protocol(std::unique_ptr<octet_stream::upper_layer>) override;
// -- octet_stream::upper_layer implementation ------------------------------- // -- octet_stream::upper_layer implementation -------------------------------
error start(octet_stream::lower_layer* down) override; error start(octet_stream::lower_layer* down) override;
......
...@@ -69,7 +69,7 @@ class http_conn_factory ...@@ -69,7 +69,7 @@ class http_conn_factory
public: public:
using connection_handle = typename Transport::connection_handle; using connection_handle = typename Transport::connection_handle;
http_conn_factory(std::vector<net::http::router::route_ptr> routes, http_conn_factory(std::vector<net::http::route_ptr> routes,
size_t max_consecutive_reads) size_t max_consecutive_reads)
: routes_(std::move(routes)), : routes_(std::move(routes)),
max_consecutive_reads_(max_consecutive_reads) { max_consecutive_reads_(max_consecutive_reads) {
...@@ -90,7 +90,7 @@ public: ...@@ -90,7 +90,7 @@ public:
} }
private: private:
std::vector<net::http::router::route_ptr> routes_; std::vector<net::http::route_ptr> routes_;
size_t max_consecutive_reads_; size_t max_consecutive_reads_;
}; };
...@@ -114,7 +114,7 @@ public: ...@@ -114,7 +114,7 @@ public:
server_factory_config(const server_factory_config&) = default; server_factory_config(const server_factory_config&) = default;
std::vector<router::route_ptr> routes; std::vector<route_ptr> routes;
}; };
/// Factory type for the `with(...).accept(...).start(...)` DSL. /// Factory type for the `with(...).accept(...).start(...)` DSL.
...@@ -136,7 +136,7 @@ public: ...@@ -136,7 +136,7 @@ public:
auto& cfg = super::config(); auto& cfg = super::config();
if (cfg.failed()) if (cfg.failed())
return *this; return *this;
auto new_route = router::make_route(std::move(path), std::move(f)); auto new_route = make_route(std::move(path), std::move(f));
if (!new_route) { if (!new_route) {
cfg.fail(std::move(new_route.error())); cfg.fail(std::move(new_route.error()));
} else { } else {
...@@ -155,7 +155,7 @@ public: ...@@ -155,7 +155,7 @@ public:
auto& cfg = super::config(); auto& cfg = super::config();
if (cfg.failed()) if (cfg.failed())
return *this; return *this;
auto new_route = router::make_route(std::move(path), method, std::move(f)); auto new_route = make_route(std::move(path), method, std::move(f));
if (!new_route) { if (!new_route) {
cfg.fail(std::move(new_route.error())); cfg.fail(std::move(new_route.error()));
} else { } else {
...@@ -214,7 +214,7 @@ private: ...@@ -214,7 +214,7 @@ private:
auto [pull, push] = async::make_spsc_buffer_resource<request>(); auto [pull, push] = async::make_spsc_buffer_resource<request>();
auto producer = detail::http_request_producer::make(cfg.mpx, auto producer = detail::http_request_producer::make(cfg.mpx,
push.try_open()); push.try_open());
routes.push_back(router::make_route([producer](responder& res) { routes.push_back(make_route([producer](responder& res) {
if (!producer->push(std::move(res).to_request())) { if (!producer->push(std::move(res).to_request())) {
auto err = make_error(sec::runtime_error, "flow disconnected"); auto err = make_error(sec::runtime_error, "flow disconnected");
res.router()->shutdown(err); res.router()->shutdown(err);
......
...@@ -43,6 +43,9 @@ public: ...@@ -43,6 +43,9 @@ public:
/// returning from `consume()`. /// returning from `consume()`.
/// @note may only be called from the upper layer in `consume`. /// @note may only be called from the upper layer in `consume`.
virtual void switch_protocol(std::unique_ptr<upper_layer> next) = 0; virtual void switch_protocol(std::unique_ptr<upper_layer> next) = 0;
/// Queries whether `switch_protocol` has been called.
virtual bool switching_protocol() const noexcept = 0;
}; };
} // namespace caf::net::octet_stream } // namespace caf::net::octet_stream
...@@ -83,6 +83,8 @@ public: ...@@ -83,6 +83,8 @@ public:
void switch_protocol(upper_layer_ptr) override; void switch_protocol(upper_layer_ptr) override;
bool switching_protocol() const noexcept override;
// -- properties ------------------------------------------------------------- // -- properties -------------------------------------------------------------
auto& read_buffer() noexcept { auto& read_buffer() noexcept {
......
...@@ -22,6 +22,10 @@ public: ...@@ -22,6 +22,10 @@ public:
template <class Trait> template <class Trait>
using server_factory_type = server_factory<Trait, Ts...>; using server_factory_type = server_factory<Trait, Ts...>;
explicit acceptor(const http::request_header& hdr) : hdr_(hdr) {
// nop
}
virtual ~acceptor() = default; virtual ~acceptor() = default;
virtual void accept(Ts... xs) = 0; virtual void accept(Ts... xs) = 0;
...@@ -44,7 +48,12 @@ public: ...@@ -44,7 +48,12 @@ public:
return reject_reason_; return reject_reason_;
} }
const http::request_header& header() const noexcept {
return hdr_;
}
protected: protected:
const http::request_header& hdr_;
bool accepted_ = false; bool accepted_ = false;
error reject_reason_; error reject_reason_;
}; };
...@@ -54,6 +63,8 @@ class acceptor_impl : public acceptor<Ts...> { ...@@ -54,6 +63,8 @@ class acceptor_impl : public acceptor<Ts...> {
public: public:
using super = acceptor<Ts...>; using super = acceptor<Ts...>;
using super::super;
using input_type = typename Trait::input_type; using input_type = typename Trait::input_type;
using output_type = typename Trait::output_type; using output_type = typename Trait::output_type;
......
...@@ -31,19 +31,15 @@ public: ...@@ -31,19 +31,15 @@ public:
auto on_request(OnRequest on_request) { auto on_request(OnRequest on_request) {
// Type checking. // Type checking.
using fn_trait = detail::get_callable_trait_t<OnRequest>; using fn_trait = detail::get_callable_trait_t<OnRequest>;
static_assert(fn_trait::num_args == 2, static_assert(fn_trait::num_args == 1,
"on_request must take exactly two arguments"); "on_request must take exactly one argument");
using arg_types = typename fn_trait::arg_types; using arg_types = typename fn_trait::arg_types;
using arg1_t = detail::tl_at_t<arg_types, 0>; using arg1_t = detail::tl_at_t<arg_types, 0>;
using arg2_t = detail::tl_at_t<arg_types, 1>;
using acceptor_t = std::decay_t<arg1_t>; using acceptor_t = std::decay_t<arg1_t>;
static_assert(is_acceptor_v<acceptor_t>, static_assert(is_acceptor_v<acceptor_t>,
"on_request must take an acceptor as 1st argument"); "on_request must take an acceptor as its argument");
static_assert(std::is_same_v<arg1_t, acceptor_t&>, static_assert(std::is_same_v<arg1_t, acceptor_t&>,
"on_request must take the acceptor as mutable reference"); "on_request must take the acceptor as mutable reference");
static_assert(
std::is_same_v<arg2_t, const http::request_header&>,
"on_request must take 'const http::request_header&' as 2nd argument");
// Wrap the callback and return the factory object. // Wrap the callback and return the factory object.
using factory_t = typename acceptor_t::template server_factory_type<Trait>; using factory_t = typename acceptor_t::template server_factory_type<Trait>;
auto callback = make_shared_type_erased_callback(std::move(on_request)); auto callback = make_shared_type_erased_callback(std::move(on_request));
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include "caf/detail/ws_flow_bridge.hpp" #include "caf/detail/ws_flow_bridge.hpp"
#include "caf/fwd.hpp" #include "caf/fwd.hpp"
#include "caf/net/dsl/server_factory_base.hpp" #include "caf/net/dsl/server_factory_base.hpp"
#include "caf/net/http/request_header.hpp"
#include "caf/net/http/server.hpp" #include "caf/net/http/server.hpp"
#include "caf/net/multiplexer.hpp" #include "caf/net/multiplexer.hpp"
#include "caf/net/octet_stream/transport.hpp" #include "caf/net/octet_stream/transport.hpp"
...@@ -35,9 +34,7 @@ public: ...@@ -35,9 +34,7 @@ public:
using ws_acceptor_t = net::web_socket::acceptor<Ts...>; using ws_acceptor_t = net::web_socket::acceptor<Ts...>;
using on_request_cb_type using on_request_cb_type = shared_callback_ptr<void(ws_acceptor_t&)>;
= shared_callback_ptr<void(ws_acceptor_t&, //
const net::http::request_header&)>;
using accept_event = typename Trait::template accept_event<Ts...>; using accept_event = typename Trait::template accept_event<Ts...>;
...@@ -67,8 +64,8 @@ public: ...@@ -67,8 +64,8 @@ public:
const net::http::request_header& hdr) override { const net::http::request_header& hdr) override {
CAF_ASSERT(down_ptr != nullptr); CAF_ASSERT(down_ptr != nullptr);
super::down_ = down_ptr; super::down_ = down_ptr;
net::web_socket::acceptor_impl<Trait, Ts...> acc; net::web_socket::acceptor_impl<Trait, Ts...> acc{hdr};
(*on_request_)(acc, hdr); (*on_request_)(acc);
if (!acc.accepted()) { if (!acc.accepted()) {
return std::move(acc) // return std::move(acc) //
.reject_reason() .reject_reason()
...@@ -95,9 +92,7 @@ class ws_connection_factory ...@@ -95,9 +92,7 @@ class ws_connection_factory
public: public:
using ws_acceptor_t = net::web_socket::acceptor<Ts...>; using ws_acceptor_t = net::web_socket::acceptor<Ts...>;
using on_request_cb_type using on_request_cb_type = shared_callback_ptr<void(ws_acceptor_t&)>;
= shared_callback_ptr<void(ws_acceptor_t&, //
const net::http::request_header&)>;
using accept_event = typename Trait::template accept_event<Ts...>; using accept_event = typename Trait::template accept_event<Ts...>;
...@@ -156,8 +151,7 @@ public: ...@@ -156,8 +151,7 @@ public:
using config_type = typename super::config_type; using config_type = typename super::config_type;
using on_request_cb_type using on_request_cb_type = shared_callback_ptr<void(acceptor<Ts...>&)>;
= shared_callback_ptr<void(acceptor<Ts...>&, const http::request_header&)>;
server_factory(intrusive_ptr<config_type> cfg, on_request_cb_type on_request) server_factory(intrusive_ptr<config_type> cfg, on_request_cb_type on_request)
: super(std::move(cfg)), on_request_(std::move(on_request)) { : super(std::move(cfg)), on_request_(std::move(on_request)) {
......
#include "caf/net/http/route.hpp"
#include "caf/async/future.hpp"
#include "caf/disposable.hpp"
#include "caf/net/http/request.hpp"
#include "caf/net/http/responder.hpp"
#include "caf/net/multiplexer.hpp"
namespace caf::net::http {
route::~route() {
// nop
}
void route::init() {
// nop
}
} // namespace caf::net::http
namespace caf::detail {
size_t args_in_path(std::string_view str) {
size_t count = 0;
size_t start = 0;
size_t end = 0;
while (end != std::string_view::npos) {
end = str.find('/', start);
auto component
= str.substr(start, end == std::string_view::npos ? end : end - start);
if (component == "<arg>")
++count;
start = end + 1;
}
return count;
}
std::pair<std::string_view, std::string_view>
next_path_component(const std::string_view str) {
if (str.empty() || str.front() != '/') {
return {std::string_view{}, std::string_view{}};
}
size_t start = 1;
size_t end = str.find('/', start);
auto component
= str.substr(start, end == std::string_view::npos ? end : end - start);
auto remaining = end == std::string_view::npos ? std::string_view{}
: str.substr(end);
return {component, remaining};
}
bool http_simple_route_base::exec(const net::http::request_header& hdr,
const_byte_span body,
net::http::router* parent) {
if (method_ && *method_ != hdr.method())
return false;
if (hdr.path() == path_) {
net::http::responder rp{&hdr, body, parent};
do_apply(rp);
return true;
}
return false;
}
} // namespace caf::detail
...@@ -8,41 +8,6 @@ ...@@ -8,41 +8,6 @@
namespace caf::net::http { namespace caf::net::http {
// -- member types -------------------------------------------------------------
router::route::~route() {
// nop
}
std::pair<std::string_view, std::string_view>
router::route::next_component(const std::string_view str) {
if (str.empty() || str.front() != '/') {
return {std::string_view{}, std::string_view{}};
}
size_t start = 1;
size_t end = str.find('/', start);
auto component
= str.substr(start, end == std::string_view::npos ? end : end - start);
auto remaining = end == std::string_view::npos ? std::string_view{}
: str.substr(end);
return {component, remaining};
}
size_t router::route::args_in_path(std::string_view str) {
size_t count = 0;
size_t start = 0;
size_t end = 0;
while (end != std::string_view::npos) {
end = str.find('/', start);
auto component
= str.substr(start, end == std::string_view::npos ? end : end - start);
if (component == "<arg>")
++count;
start = end + 1;
}
return count;
}
// -- constructors and destructors --------------------------------------------- // -- constructors and destructors ---------------------------------------------
router::~router() { router::~router() {
......
...@@ -84,6 +84,10 @@ bool server::send_end_of_chunks() { ...@@ -84,6 +84,10 @@ bool server::send_end_of_chunks() {
return down_->end_output(); return down_->end_output();
} }
void server::switch_protocol(std::unique_ptr<octet_stream::upper_layer> next) {
down_->switch_protocol(std::move(next));
}
// -- octet_stream::upper_layer implementation --------------------------------- // -- octet_stream::upper_layer implementation ---------------------------------
error server::start(octet_stream::lower_layer* down) { error server::start(octet_stream::lower_layer* down) {
......
...@@ -98,6 +98,10 @@ void transport::switch_protocol(upper_layer_ptr next) { ...@@ -98,6 +98,10 @@ void transport::switch_protocol(upper_layer_ptr next) {
next_ = std::move(next); next_ = std::move(next);
} }
bool transport::switching_protocol() const noexcept {
return next_ != nullptr;
}
// -- implementation of transport ---------------------------------------------- // -- implementation of transport ----------------------------------------------
error transport::start(socket_manager* owner) { error transport::start(socket_manager* owner) {
......
...@@ -39,6 +39,10 @@ void mock_stream_transport::switch_protocol(upper_layer_ptr new_up) { ...@@ -39,6 +39,10 @@ void mock_stream_transport::switch_protocol(upper_layer_ptr new_up) {
next.swap(new_up); next.swap(new_up);
} }
bool mock_stream_transport::switching_protocol() const noexcept {
return next != nullptr;
}
void mock_stream_transport::configure_read(net::receive_policy policy) { void mock_stream_transport::configure_read(net::receive_policy policy) {
min_read_size = policy.min_size; min_read_size = policy.min_size;
max_read_size = policy.max_size; max_read_size = policy.max_size;
......
...@@ -50,6 +50,8 @@ public: ...@@ -50,6 +50,8 @@ public:
void switch_protocol(upper_layer_ptr) override; void switch_protocol(upper_layer_ptr) override;
bool switching_protocol() const noexcept override;
// -- initialization --------------------------------------------------------- // -- initialization ---------------------------------------------------------
caf::error start(caf::net::multiplexer* ptr) { caf::error start(caf::net::multiplexer* ptr) {
......
...@@ -15,6 +15,7 @@ using namespace std::literals; ...@@ -15,6 +15,7 @@ using namespace std::literals;
namespace http = caf::net::http; namespace http = caf::net::http;
using http::make_route;
using http::responder; using http::responder;
using http::router; using http::router;
...@@ -76,6 +77,10 @@ public: ...@@ -76,6 +77,10 @@ public:
return true; return true;
} }
void switch_protocol(std::unique_ptr<net::octet_stream::upper_layer>) {
// nop
}
private: private:
net::multiplexer* mpx_; net::multiplexer* mpx_;
}; };
...@@ -125,9 +130,9 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -125,9 +130,9 @@ SCENARIO("routes must have one <arg> entry per argument") {
GIVEN("a make_route call that has fewer arguments than the callback") { GIVEN("a make_route call that has fewer arguments than the callback") {
WHEN("evaluating the factory call") { WHEN("evaluating the factory call") {
THEN("the factory produces an error") { THEN("the factory produces an error") {
auto res1 = router::make_route("/", [](responder&, int) {}); auto res1 = make_route("/", [](responder&, int) {});
CHECK_EQ(res1, sec::invalid_argument); CHECK_EQ(res1, sec::invalid_argument);
auto res2 = router::make_route("/<arg>", [](responder&, int, int) {}); auto res2 = make_route("/<arg>", [](responder&, int, int) {});
CHECK_EQ(res2, sec::invalid_argument); CHECK_EQ(res2, sec::invalid_argument);
} }
} }
...@@ -135,9 +140,9 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -135,9 +140,9 @@ SCENARIO("routes must have one <arg> entry per argument") {
GIVEN("a make_route call that has more arguments than the callback") { GIVEN("a make_route call that has more arguments than the callback") {
WHEN("evaluating the factory call") { WHEN("evaluating the factory call") {
THEN("the factory produces an error") { THEN("the factory produces an error") {
auto res1 = router::make_route("/<arg>/<arg>", [](responder&) {}); auto res1 = make_route("/<arg>/<arg>", [](responder&) {});
CHECK_EQ(res1, sec::invalid_argument); CHECK_EQ(res1, sec::invalid_argument);
auto res2 = router::make_route("/<arg>/<arg>", [](responder&, int) {}); auto res2 = make_route("/<arg>/<arg>", [](responder&, int) {});
CHECK_EQ(res2, sec::invalid_argument); CHECK_EQ(res2, sec::invalid_argument);
} }
} }
...@@ -145,14 +150,14 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -145,14 +150,14 @@ SCENARIO("routes must have one <arg> entry per argument") {
GIVEN("a make_route call with the matching number of arguments") { GIVEN("a make_route call with the matching number of arguments") {
WHEN("evaluating the factory call") { WHEN("evaluating the factory call") {
THEN("the factory produces a valid callback") { THEN("the factory produces a valid callback") {
if (auto res = router::make_route("/", [](responder&) {}); CHECK(res)) { if (auto res = make_route("/", [](responder&) {}); CHECK(res)) {
set_get_request("/"); set_get_request("/");
CHECK((*res)->exec(hdr, {}, &rt)); CHECK((*res)->exec(hdr, {}, &rt));
set_get_request("/foo/bar"); set_get_request("/foo/bar");
CHECK(!(*res)->exec(hdr, {}, &rt)); CHECK(!(*res)->exec(hdr, {}, &rt));
} }
if (auto res = router::make_route("/foo/bar", http::method::get, if (auto res = make_route("/foo/bar", http::method::get,
[](responder&) {}); [](responder&) {});
CHECK(res)) { CHECK(res)) {
set_get_request("/"); set_get_request("/");
CHECK(!(*res)->exec(hdr, {}, &rt)); CHECK(!(*res)->exec(hdr, {}, &rt));
...@@ -165,7 +170,7 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -165,7 +170,7 @@ SCENARIO("routes must have one <arg> entry per argument") {
set_get_request("/foo/bar"); set_get_request("/foo/bar");
CHECK((*res)->exec(hdr, {}, &rt)); CHECK((*res)->exec(hdr, {}, &rt));
} }
if (auto res = router::make_route( if (auto res = make_route(
"/<arg>", [this](responder&, int x) { args = make_args(x); }); "/<arg>", [this](responder&, int x) { args = make_args(x); });
CHECK(res)) { CHECK(res)) {
set_get_request("/"); set_get_request("/");
...@@ -176,10 +181,9 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -176,10 +181,9 @@ SCENARIO("routes must have one <arg> entry per argument") {
if (CHECK((*res)->exec(hdr, {}, &rt))) if (CHECK((*res)->exec(hdr, {}, &rt)))
CHECK_EQ(args, make_args(42)); CHECK_EQ(args, make_args(42));
} }
if (auto res = router::make_route("/foo/<arg>/bar", if (auto res
[this](responder&, int x) { = make_route("/foo/<arg>/bar",
args = make_args(x); [this](responder&, int x) { args = make_args(x); });
});
CHECK(res)) { CHECK(res)) {
set_get_request("/"); set_get_request("/");
CHECK(!(*res)->exec(hdr, {}, &rt)); CHECK(!(*res)->exec(hdr, {}, &rt));
...@@ -189,10 +193,10 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -189,10 +193,10 @@ SCENARIO("routes must have one <arg> entry per argument") {
if (CHECK((*res)->exec(hdr, {}, &rt))) if (CHECK((*res)->exec(hdr, {}, &rt)))
CHECK_EQ(args, make_args(123)); CHECK_EQ(args, make_args(123));
} }
if (auto res = router::make_route("/foo/<arg>/bar", if (auto res = make_route("/foo/<arg>/bar",
[this](responder&, std::string x) { [this](responder&, std::string x) {
args = make_args(x); args = make_args(x);
}); });
CHECK(res)) { CHECK(res)) {
set_get_request("/"); set_get_request("/");
CHECK(!(*res)->exec(hdr, {}, &rt)); CHECK(!(*res)->exec(hdr, {}, &rt));
...@@ -202,11 +206,10 @@ SCENARIO("routes must have one <arg> entry per argument") { ...@@ -202,11 +206,10 @@ SCENARIO("routes must have one <arg> entry per argument") {
if (CHECK((*res)->exec(hdr, {}, &rt))) if (CHECK((*res)->exec(hdr, {}, &rt)))
CHECK_EQ(args, make_args("my-arg"s)); CHECK_EQ(args, make_args("my-arg"s));
} }
if (auto res if (auto res = make_route("/<arg>/<arg>/<arg>",
= router::make_route("/<arg>/<arg>/<arg>", [this](responder&, int x, bool y, int z) {
[this](responder&, int x, bool y, int z) { args = make_args(x, y, z);
args = make_args(x, y, z); });
});
CHECK(res)) { CHECK(res)) {
set_get_request("/"); set_get_request("/");
CHECK(!(*res)->exec(hdr, {}, &rt)); CHECK(!(*res)->exec(hdr, {}, &rt));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment