Commit 539737e9 authored by Dominik Charousset's avatar Dominik Charousset

Implement WebSocket handshake response

parent d3a43ddb
......@@ -38,6 +38,10 @@ struct receive_policy {
return {size, size};
}
static constexpr receive_policy up_to(uint32_t max_size) {
return {1, max_size};
}
static constexpr receive_policy stop() {
return {0, 0};
}
......
......@@ -170,9 +170,9 @@ public:
};
access<Parent> this_layer{&parent, this};
for (size_t i = 0; max_read_size_ > 0 && i < max_consecutive_reads_; ++i) {
CAF_ASSERT(min_read_size_ > read_size_);
CAF_ASSERT(max_read_size_ > read_size_);
auto buf = read_buf_.data() + read_size_;
size_t len = min_read_size_ - read_size_;
size_t len = max_read_size_ - read_size_;
CAF_LOG_DEBUG(CAF_ARG2("missing", len));
auto num_bytes = read(parent.template handle<socket_type>(),
make_span(buf, len));
......
......@@ -20,8 +20,11 @@
#include <algorithm>
#include "caf/detail/encode_base64.hpp"
#include "caf/detail/move_if_not_ptr.hpp"
#include "caf/error.hpp"
#include "caf/hash/sha1.hpp"
#include "caf/net/receive_policy.hpp"
#include "caf/pec.hpp"
#include "caf/settings.hpp"
#include "caf/tag/stream_oriented.hpp"
......@@ -42,6 +45,24 @@ public:
using output_tag = tag::stream_oriented;
// -- constants --------------------------------------------------------------
static constexpr std::array<byte, 4> end_of_header{{
byte{'\r'},
byte{'\n'},
byte{'\r'},
byte{'\n'},
}};
// A handshake should usually fit into 200-300 Bytes. 2KB is more than enough.
static constexpr uint32_t max_header_size = 2048;
static constexpr string_view header_too_large
= "HTTP/1.1 431 Request Header Fields Too Large\r\n"
"Content-Type: text/plain\r\n"
"\r\n"
"Header exceeds 2048 Bytes.\r\n";
// -- constructors, destructors, and assignment operators --------------------
template <class... Ts>
......@@ -53,15 +74,6 @@ public:
// nop
}
// -- constants --------------------------------------------------------------
static constexpr std::array<byte, 4> end_of_header{{
byte{'\r'},
byte{'\n'},
byte{'\r'},
byte{'\n'},
}};
// -- properties -------------------------------------------------------------
auto& upper_layer() noexcept {
......@@ -74,9 +86,10 @@ public:
// -- initialization ---------------------------------------------------------
template <class Parent>
error init(Parent&, const settings& config) {
template <class LowerLayer>
error init(LowerLayer& down, const settings& config) {
cfg_ = config;
down.configure_read(net::receive_policy::up_to(max_header_size));
return none;
}
......@@ -84,24 +97,18 @@ public:
template <class LowerLayer>
bool prepare_send(LowerLayer& down) {
if (handshake_complete_)
return upper_layer_.prepare_send(down);
// TODO: implement me.
return false;
return handshake_complete_ ? upper_layer_.prepare_send(down) : true;
}
template <class LowerLayer>
bool done_sending(LowerLayer& down) {
if (handshake_complete_)
return upper_layer_.done_sending(down);
// TODO: implement me.
return false;
return handshake_complete_ ? upper_layer_.done_sending(down) : true;
}
template <class LowerLayer>
void abort(LowerLayer& down, const error& reason) {
if (handshake_complete_)
return upper_layer_.abort(down, reason);
upper_layer_.abort(down, reason);
}
template <class LowerLayer>
......@@ -111,8 +118,16 @@ public:
// TODO: we could avoid repeated scans by using the delta parameter.
auto i = std::search(buffer.begin(), buffer.end(),
end_of_header.begin(), end_of_header.end());
if (i == buffer.end())
if (i == buffer.end()) {
if (buffer.size() == max_header_size) {
write(down, header_too_large);
auto err = make_error(pec::too_many_characters,
"exceeded maximum header size");
down.abort_reason(std::move(err));
return -1;
}
return 0;
}
auto offset = static_cast<size_t>(std::distance(buffer.begin(), i));
offset += end_of_header.size();
// Take all but the last two bytes (to avoid an empty line) as input for
......@@ -121,10 +136,29 @@ public:
offset - 2};
if (!handle_header(down, header))
return -1;
return offset + upper_layer_.consume(down, buffer.subspan(offset), {});
ptrdiff_t sub_result = 0;
if (offset < buffer.size()) {
sub_result = upper_layer_.consume(down, buffer.subspan(offset), {});
if (sub_result < 0)
return sub_result;
}
return static_cast<ptrdiff_t>(offset) + sub_result;
}
bool handshake_complete() const noexcept {
return handshake_complete_;
}
private:
template <class LowerLayer>
static void write(LowerLayer& down, string_view output) {
auto out = as_bytes(make_span(output));
down.begin_output();
auto& buf = down.output_buffer();
buf.insert(buf.end(), out.begin(), out.end());
down.end_output();
}
template <class LowerLayer>
bool handle_header(LowerLayer& down, string_view input) {
// Parse the first line, i.e., "METHOD REQUEST-URI VERSION".
......@@ -156,7 +190,8 @@ private:
// Check whether the mandatory fields exist.
std::string sec_key;
if (auto skey_field = get_if<std::string>(&fields, "Sec-WebSocket-Key")) {
sec_key = detail::move_if_not_ptr(skey_field);
auto field_hash = hash::sha1::compute(*skey_field);
sec_key = detail::encode_base64(field_hash);
} else {
auto err = make_error(pec::missing_field,
"Mandatory field Sec-WebSocket-Key not found");
......@@ -169,7 +204,21 @@ private:
return false;
}
// Send server handshake.
down.begin_output();
auto& buf = down.output_buffer();
auto append = [&buf](string_view output) {
auto out = as_bytes(make_span(output));
buf.insert(buf.end(), out.begin(), out.end());
};
append("HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: ");
append(sec_key);
append("\r\n\r\n");
down.end_output();
// Done.
handshake_complete_ = true;
return true;
}
......@@ -187,8 +236,9 @@ private:
static string_view trim(string_view str) {
str.remove_prefix(std::min(str.find_first_not_of(' '), str.size()));
str.remove_suffix(str.size()
- std::min(str.find_last_not_of(' '), str.size()));
auto trim_pos = str.find_last_not_of(' ');
if (trim_pos != str.npos)
str.remove_suffix(str.size() - (trim_pos + 1));
return str;
}
......
#pragma once
#include "caf/error.hpp"
#include "caf/net/receive_policy.hpp"
#include "caf/net/test/host_fixture.hpp"
#include "caf/settings.hpp"
#include "caf/span.hpp"
#include "caf/string_view.hpp"
#include "caf/tag/stream_oriented.hpp"
#include "caf/test/dsl.hpp"
template <class UpperLayer>
class mock_stream_transport {
public:
// -- member types -----------------------------------------------------------
using output_tag = caf::tag::stream_oriented;
// -- interface for the upper layer ------------------------------------------
class access {
public:
explicit access(mock_stream_transport* transport) : transport_(transport) {
// nop
}
void begin_output() {
// nop
}
auto& output_buffer() {
return transport_->output;
}
constexpr void end_output() {
// nop
}
bool can_send_more() const noexcept {
return true;
}
void abort_reason(caf::error reason) {
transport_->abort_reason = std::move(reason);
}
void configure_read(caf::net::receive_policy policy) {
transport_->min_read_size = policy.min_size;
transport_->max_read_size = policy.max_size;
}
private:
mock_stream_transport* transport_;
};
friend class access;
// -- initialization ---------------------------------------------------------
caf::error init(const caf::settings& config) {
access this_layer{this};
return upper_layer.init(this_layer, config);
}
caf::error init() {
caf::settings config;
return init(config);
}
// -- buffer management ------------------------------------------------------
void push(caf::span<const caf::byte> bytes) {
input.insert(input.begin(), bytes.begin(), bytes.end());
}
void push(caf::string_view str) {
push(caf::as_bytes(caf::make_span(str)));
}
size_t unconsumed() const noexcept {
return read_buf_.size();
}
caf::string_view output_as_str() const noexcept {
return {reinterpret_cast<const char*>(output.data()), output.size()};
}
// -- event callbacks --------------------------------------------------------
ptrdiff_t handle_input() {
ptrdiff_t result = 0;
access this_layer{this};
while (max_read_size > 0) {
CAF_ASSERT(max_read_size > static_cast<size_t>(read_size_));
size_t len = max_read_size - static_cast<size_t>(read_size_);
CAF_LOG_DEBUG(CAF_ARG2("available capacity:", len));
auto num_bytes = std::min(input.size(), len);
if (num_bytes == 0)
return result;
auto delta_offset = static_cast<ptrdiff_t>(read_buf_.size());
read_buf_.insert(read_buf_.end(), input.begin(),
input.begin() + num_bytes);
input.erase(input.begin(), input.begin() + num_bytes);
read_size_ += static_cast<ptrdiff_t>(num_bytes);
if (static_cast<size_t>(read_size_) < min_read_size)
return result;
auto delta = make_span(read_buf_.data() + delta_offset,
read_size_ - delta_offset);
auto consumed = upper_layer.consume(this_layer, caf::make_span(read_buf_),
delta);
if (consumed > 0) {
result += static_cast<ptrdiff_t>(consumed);
read_buf_.erase(read_buf_.begin(), read_buf_.begin() + consumed);
read_size_ -= consumed;
} else if (consumed < 0) {
if (!abort_reason)
abort_reason = caf::sec::runtime_error;
upper_layer.abort(this_layer, abort_reason);
return -1;
}
}
return result;
}
// -- member variables -------------------------------------------------------
caf::error abort_reason;
UpperLayer upper_layer;
std::vector<caf::byte> output;
std::vector<caf::byte> input;
uint32_t min_read_size = 0;
uint32_t max_read_size = 0;
private:
std::vector<caf::byte> read_buf_;
ptrdiff_t read_size_ = 0;
};
......@@ -20,8 +20,7 @@
#include "caf/net/web_socket.hpp"
#include "caf/net/test/host_fixture.hpp"
#include "caf/test/dsl.hpp"
#include "net-test.hpp"
#include "caf/net/multiplexer.hpp"
#include "caf/net/socket_manager.hpp"
......@@ -29,16 +28,22 @@
#include "caf/net/stream_transport.hpp"
using namespace caf;
using namespace std::literals::string_literals;
namespace {
using byte_span = span<const byte>;
struct app {
using svec = std::vector<std::string>;
struct app_t {
std::vector<std::string> lines;
settings cfg;
template <class LowerLayer>
error init(LowerLayer&, const settings&) {
error init(LowerLayer&, const settings& init_cfg) {
cfg = init_cfg;
return none;
}
......@@ -63,11 +68,12 @@ struct app {
auto e = buffer.end();
if (auto i = std::find(buffer.begin(), e, nl); i != e) {
std::string str;
auto string_size = static_cast<size_t>(std::distance(buffer.begin(), e));
auto string_size = static_cast<size_t>(std::distance(buffer.begin(), i));
str.reserve(string_size);
auto num_bytes = string_size + 1; // Also consume the newline character.
std::transform(buffer.begin(), i, std::back_inserter(str),
[](byte x) { return static_cast<char>(x); });
lines.emplace_back(std::move(str));
return num_bytes + consume(down, buffer.subspan(num_bytes), {});
}
return 0;
......@@ -77,37 +83,113 @@ struct app {
struct fixture : host_fixture {
fixture() {
using namespace caf::net;
mpx = std::make_shared<multiplexer>();
mpx->set_thread_id();
std::tie(sock.self, sock.mgr) = unbox(make_stream_socket_pair());
auto ptr = make_socket_manager<app, web_socket, stream_transport>(sock.mgr,
mpx);
settings cfg;
if (auto err = ptr->init(cfg))
CAF_FAIL("initializing the socket manager failed: " << err);
mgr = ptr;
ws = std::addressof(transport.upper_layer);
app = std::addressof(ws->upper_layer());
if (auto err = transport.init())
CAF_FAIL("failed to initialize mock transport: " << err);
}
~fixture() {
close(sock.self);
}
net::multiplexer_ptr mpx;
mock_stream_transport<net::web_socket<app_t>> transport;
net::socket_manager_ptr mgr;
net::web_socket<app_t>* ws;
struct {
net::stream_socket self;
net::stream_socket mgr;
} sock;
app_t* app;
};
constexpr string_view opening_handshake
= "GET /chat HTTP/1.1\r\n"
"Host: server.example.com\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Origin: http://example.com\r\n"
"Sec-WebSocket-Protocol: chat, superchat\r\n"
"Sec-WebSocket-Version: 13\r\n"
"\r\n";
} // namespace
#define CHECK_SETTING(key, expected_value) \
if (CAF_CHECK(holds_alternative<std::string>(app->cfg, key))) \
CAF_CHECK_EQUAL(get<std::string>(app->cfg, key), expected_value);
CAF_TEST_FIXTURE_SCOPE(web_socket_tests, fixture)
CAF_TEST(todo) {
CAF_TEST(applications receive handshake data via config) {
transport.push(opening_handshake);
{
auto consumed = transport.handle_input();
if (consumed < 0)
CAF_FAIL("error handling input: " << transport.abort_reason);
CAF_CHECK_EQUAL(consumed, static_cast<ptrdiff_t>(opening_handshake.size()));
}
CAF_CHECK_EQUAL(transport.input.size(), 0u);
CAF_CHECK_EQUAL(transport.unconsumed(), 0u);
CAF_CHECK(ws->handshake_complete());
CHECK_SETTING("web-socket.method", "GET");
CHECK_SETTING("web-socket.request-uri", "/chat");
CHECK_SETTING("web-socket.http-version", "HTTP/1.1");
CHECK_SETTING("web-socket.fields.Host", "server.example.com");
CHECK_SETTING("web-socket.fields.Upgrade", "websocket");
CHECK_SETTING("web-socket.fields.Connection", "Upgrade");
CHECK_SETTING("web-socket.fields.Origin", "http://example.com");
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Protocol", "chat, superchat");
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Version", "13");
CHECK_SETTING("web-socket.fields.Sec-WebSocket-Key",
"dGhlIHNhbXBsZSBub25jZQ==");
}
CAF_TEST(the server responds with an HTTP response on success) {
transport.push(opening_handshake);
CAF_CHECK_EQUAL(transport.handle_input(),
static_cast<ptrdiff_t>(opening_handshake.size()));
CAF_CHECK(ws->handshake_complete());
CAF_CHECK(transport.output_as_str(),
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n");
}
CAF_TEST(handshakes may arrive in chunks) {
svec bufs;
size_t chunk_size = opening_handshake.size() / 3;
auto i = opening_handshake.begin();
bufs.emplace_back(i, i + chunk_size);
i += chunk_size;
bufs.emplace_back(i, i + chunk_size);
i += chunk_size;
bufs.emplace_back(i, opening_handshake.end());
transport.push(bufs[0]);
CAF_CHECK_EQUAL(transport.handle_input(), 0u);
CAF_CHECK(!ws->handshake_complete());
transport.push(bufs[1]);
CAF_CHECK_EQUAL(transport.handle_input(), 0u);
CAF_CHECK(!ws->handshake_complete());
transport.push(bufs[2]);
CAF_CHECK_EQUAL(transport.handle_input(), opening_handshake.size());
CAF_CHECK(ws->handshake_complete());
}
CAF_TEST(data may follow the handshake immediately) {
std::string buf{opening_handshake.begin(), opening_handshake.end()};
buf += "Hello WebSocket!\n";
buf += "Bye WebSocket!\n";
transport.push(buf);
CAF_CHECK_EQUAL(transport.handle_input(), static_cast<ptrdiff_t>(buf.size()));
CAF_CHECK(ws->handshake_complete());
CAF_CHECK_EQUAL(app->lines, svec({"Hello WebSocket!", "Bye WebSocket!"}));
}
CAF_TEST(data may arrive later) {
transport.push(opening_handshake);
CAF_CHECK_EQUAL(transport.handle_input(),
static_cast<ptrdiff_t>(opening_handshake.size()));
CAF_CHECK(ws->handshake_complete());
auto buf = "Hello WebSocket!\nBye WebSocket!\n"s;
transport.push(buf);
CAF_CHECK_EQUAL(transport.handle_input(), static_cast<ptrdiff_t>(buf.size()));
CAF_CHECK_EQUAL(app->lines, svec({"Hello WebSocket!", "Bye WebSocket!"}));
}
CAF_TEST_FIXTURE_SCOPE_END()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment