Commit ef0d2b4c authored by Samir Halilcevic's avatar Samir Halilcevic

Refactor framing to isolate closing handshake

parent 5d0b63bc
...@@ -139,10 +139,12 @@ private: ...@@ -139,10 +139,12 @@ private:
template <class T> template <class T>
void ship_frame(std::vector<T>& buf); void ship_frame(std::vector<T>& buf);
void abort_and_shutdown(caf::error reason) { // Sends closing message, can be error status, or closing handshake
abort(reason); void ship_closing_message(status code, std::string_view desc);
shutdown(reason);
} // Signal abort to the upper layer and shutdown to the lower layer,
// with closing message
void abort_and_close_connection(sec reason, std::string_view msg);
// -- member variables ------------------------------------------------------- // -- member variables -------------------------------------------------------
......
...@@ -30,8 +30,8 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) { ...@@ -30,8 +30,8 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) {
auto hdr_bytes = detail::rfc6455::decode_header(buffer, hdr); auto hdr_bytes = detail::rfc6455::decode_header(buffer, hdr);
if (hdr_bytes < 0) { if (hdr_bytes < 0) {
CAF_LOG_DEBUG("decoded malformed data: hdr_bytes < 0"); CAF_LOG_DEBUG("decoded malformed data: hdr_bytes < 0");
abort_and_shutdown(make_error( abort_and_close_connection(sec::protocol_error,
sec::protocol_error, "negative header size on WebSocket connection")); "negative header size on WebSocket connection");
return -1; return -1;
} }
if (hdr_bytes == 0) { if (hdr_bytes == 0) {
...@@ -41,8 +41,8 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) { ...@@ -41,8 +41,8 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) {
// Make sure the entire frame (including header) fits into max_frame_size. // Make sure the entire frame (including header) fits into max_frame_size.
if (hdr.payload_len >= (max_frame_size - static_cast<size_t>(hdr_bytes))) { if (hdr.payload_len >= (max_frame_size - static_cast<size_t>(hdr_bytes))) {
CAF_LOG_DEBUG("WebSocket frame too large"); CAF_LOG_DEBUG("WebSocket frame too large");
abort_and_shutdown( abort_and_close_connection(sec::protocol_error,
make_error(sec::protocol_error, "WebSocket frame too large")); "WebSocket frame too large");
return -1; return -1;
} }
// Wait for more data if necessary. // Wait for more data if necessary.
...@@ -61,25 +61,27 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) { ...@@ -61,25 +61,27 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) {
if (opcode_ == nil_code) { if (opcode_ == nil_code) {
// Call upper layer. // Call upper layer.
if (hdr.opcode == detail::rfc6455::connection_close) { if (hdr.opcode == detail::rfc6455::connection_close) {
abort_and_shutdown(make_error(sec::connection_closed)); // TODO
abort_and_close_connection(sec::connection_closed, "");
return -1; return -1;
} else if (!handle(hdr.opcode, payload)) { } else if (!handle(hdr.opcode, payload)) {
return -1; return -1;
} }
} else if (hdr.opcode != detail::rfc6455::continuation_frame) { } else if (hdr.opcode != detail::rfc6455::continuation_frame) {
CAF_LOG_DEBUG("expected a WebSocket continuation_frame"); CAF_LOG_DEBUG("expected a WebSocket continuation_frame");
abort_and_shutdown(make_error(sec::protocol_error, abort_and_close_connection(sec::protocol_error,
"expected a WebSocket continuation_frame")); "expected a WebSocket continuation_frame");
return -1; return -1;
} else if (payload_buf_.size() + payload_len > max_frame_size) { } else if (payload_buf_.size() + payload_len > max_frame_size) {
CAF_LOG_DEBUG("fragmented WebSocket payload exceeds maximum size"); CAF_LOG_DEBUG("fragmented WebSocket payload exceeds maximum size");
abort_and_shutdown(make_error(sec::protocol_error, abort_and_close_connection(sec::protocol_error,
"fragmented WebSocket payload " "fragmented WebSocket payload "
"exceeds maximum size")); "exceeds maximum size");
return -1; return -1;
} else { } else {
if (hdr.opcode == detail::rfc6455::connection_close) { if (hdr.opcode == detail::rfc6455::connection_close) {
abort_and_shutdown(make_error(sec::connection_closed)); // TODO
abort_and_close_connection(sec::connection_closed, "");
return -1; return -1;
} else { } else {
// End of fragmented input. // End of fragmented input.
...@@ -98,23 +100,23 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) { ...@@ -98,23 +100,23 @@ ptrdiff_t framing::consume(byte_span buffer, byte_span) {
if (hdr.opcode == detail::rfc6455::continuation_frame) { if (hdr.opcode == detail::rfc6455::continuation_frame) {
CAF_LOG_DEBUG("received WebSocket continuation " CAF_LOG_DEBUG("received WebSocket continuation "
"frame without prior opcode"); "frame without prior opcode");
abort_and_shutdown(make_error(sec::protocol_error, abort_and_close_connection(sec::protocol_error,
"received WebSocket continuation " "received WebSocket continuation "
"frame without prior opcode")); "frame without prior opcode");
return -1; return -1;
} }
opcode_ = hdr.opcode; opcode_ = hdr.opcode;
} else if (hdr.opcode != detail::rfc6455::continuation_frame) { } else if (hdr.opcode != detail::rfc6455::continuation_frame) {
CAF_LOG_DEBUG("expected a continuation frame"); CAF_LOG_DEBUG("expected a continuation frame");
abort_and_shutdown(make_error(sec::protocol_error, // abort_and_close_connection(sec::protocol_error,
"expected a continuation frame")); "expected a continuation frame");
return -1; return -1;
} else if (payload_buf_.size() + payload_len > max_frame_size) { } else if (payload_buf_.size() + payload_len > max_frame_size) {
// Reject assembled payloads that exceed max_frame_size. // Reject assembled payloads that exceed max_frame_size.
CAF_LOG_DEBUG("fragmented WebSocket payload exceeds maximum size"); CAF_LOG_DEBUG("fragmented WebSocket payload exceeds maximum size");
abort_and_shutdown(make_error(sec::protocol_error, abort_and_close_connection(sec::protocol_error,
"fragmented WebSocket payload " "fragmented WebSocket payload "
"exceeds maximum size")); "exceeds maximum size");
return -1; return -1;
} }
payload_buf_.insert(payload_buf_.end(), payload.begin(), payload.end()); payload_buf_.insert(payload_buf_.end(), payload.begin(), payload.end());
...@@ -153,22 +155,7 @@ void framing::write_later() { ...@@ -153,22 +155,7 @@ void framing::write_later() {
} }
void framing::shutdown(status code, std::string_view msg) { void framing::shutdown(status code, std::string_view msg) {
auto code_val = static_cast<uint16_t>(code); ship_closing_message(code, msg);
uint32_t mask_key = 0;
byte_buffer payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<std::byte>((code_val & 0xFF00) >> 8));
payload.push_back(static_cast<std::byte>(code_val & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<std::byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down_->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down_->output_buffer());
down_->end_output();
down_->shutdown(); down_->shutdown();
} }
...@@ -253,4 +240,37 @@ void framing::ship_frame(std::vector<T>& buf) { ...@@ -253,4 +240,37 @@ void framing::ship_frame(std::vector<T>& buf) {
buf.clear(); buf.clear();
} }
void framing::ship_closing_message(status code, std::string_view msg) {
auto code_val = static_cast<uint16_t>(code);
uint32_t mask_key = 0;
byte_buffer payload;
payload.reserve(msg.size() + 2);
payload.push_back(static_cast<std::byte>((code_val & 0xFF00) >> 8));
payload.push_back(static_cast<std::byte>(code_val & 0x00FF));
for (auto c : msg)
payload.push_back(static_cast<std::byte>(c));
if (mask_outgoing_frames) {
mask_key = static_cast<uint32_t>(rng_());
detail::rfc6455::mask_data(mask_key, payload);
}
down_->begin_output();
detail::rfc6455::assemble_frame(detail::rfc6455::connection_close, mask_key,
payload, down_->output_buffer());
down_->end_output();
}
void framing::abort_and_close_connection(sec reason, std::string_view msg) {
if (msg.empty())
up_->abort(make_error(reason));
else
up_->abort(make_error(reason, std::string(msg)));
status code = status::unexpected_condition;
if (reason == sec::connection_closed)
code = status::normal_close;
else if (reason == sec::protocol_error)
code = status::protocol_error;
ship_closing_message(code, to_string(reason));
down_->shutdown();
}
} // namespace caf::net::web_socket } // namespace caf::net::web_socket
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment