Commit 687502fc authored by Samir Halilcevic's avatar Samir Halilcevic

Validate UTF-8 on WebSocket text payloads

parent d0cdb60c
...@@ -32,6 +32,27 @@ constexpr std::byte head(std::byte value) noexcept { ...@@ -32,6 +32,27 @@ constexpr std::byte head(std::byte value) noexcept {
} }
} }
// Takes the last N bits of `value`.
template <size_t N>
constexpr std::byte tail(std::byte value) noexcept {
if constexpr (N == 1) {
return value & 0b0000'0001_b;
} else if constexpr (N == 2) {
return value & 0b0000'0011_b;
} else if constexpr (N == 3) {
return value & 0b0000'0111_b;
} else if constexpr (N == 4) {
return value & 0b0000'1111_b;
} else if constexpr (N == 5) {
return value & 0b0001'1111_b;
} else if constexpr (N == 6) {
return value & 0b0011'1111_b;
} else {
static_assert(N == 7);
return value & 0b0111'1111_b;
}
}
// Checks whether `value` is an UTF-8 continuation byte. // Checks whether `value` is an UTF-8 continuation byte.
constexpr bool is_continuation_byte(std::byte value) noexcept { constexpr bool is_continuation_byte(std::byte value) noexcept {
return head<2>(value) == 0b1000'0000_b; return head<2>(value) == 0b1000'0000_b;
...@@ -42,9 +63,6 @@ constexpr bool is_continuation_byte(std::byte value) noexcept { ...@@ -42,9 +63,6 @@ constexpr bool is_continuation_byte(std::byte value) noexcept {
bool validate_rfc3629(const std::byte* first, const std::byte* last) { bool validate_rfc3629(const std::byte* first, const std::byte* last) {
while (first != last) { while (first != last) {
auto x = *first++; auto x = *first++;
// Null byte (terminator) is not allowed.
if (x == 0_b)
return false;
// First bit is zero: ASCII character. // First bit is zero: ASCII character.
if (head<1>(x) == 0b0000'0000_b) if (head<1>(x) == 0b0000'0000_b)
continue; continue;
...@@ -72,18 +90,25 @@ bool validate_rfc3629(const std::byte* first, const std::byte* last) { ...@@ -72,18 +90,25 @@ bool validate_rfc3629(const std::byte* first, const std::byte* last) {
return false; return false;
continue; continue;
} }
// 1111'0bxx: 4-byte sequence. // 1111'0xxx: 4-byte sequence.
if (head<5>(x) == 0b1111'0000_b) { if (head<5>(x) == 0b1111'0000_b) {
uint64_t code_point = std::to_integer<uint64_t>(tail<3>(x)) << 18;
if (first == last || !is_continuation_byte(*first)) if (first == last || !is_continuation_byte(*first))
return false; return false;
// No non-shortest form. // No non-shortest form.
if (x == 0b1111'0000_b && head<4>(*first) == 0b1000'0000_b) if (x == 0b1111'0000_b && head<4>(*first) == 0b1000'0000_b)
return false; return false;
++first; code_point |= std::to_integer<uint64_t>(tail<6>(*first++)) << 12;
if (first == last || !is_continuation_byte(*first++)) if (first == last || !is_continuation_byte(*first))
return false; return false;
if (first == last || !is_continuation_byte(*first++)) code_point |= std::to_integer<uint64_t>(tail<6>(*first++)) << 6;
if (first == last || !is_continuation_byte(*first))
return false;
code_point |= std::to_integer<uint64_t>(tail<6>(*first++));
// Out of valid UTF range
if (code_point >= 0x110000) {
return false; return false;
}
continue; continue;
} }
return false; return false;
......
...@@ -96,11 +96,19 @@ constexpr std::byte valid_three_byte_1[] = {0xe0_b, 0xa0_b, 0x80_b}; ...@@ -96,11 +96,19 @@ constexpr std::byte valid_three_byte_1[] = {0xe0_b, 0xa0_b, 0x80_b};
// Largest valid 3-byte sequence. // Largest valid 3-byte sequence.
constexpr std::byte valid_three_byte_2[] = {0xef_b, 0xbf_b, 0xbf_b}; constexpr std::byte valid_three_byte_2[] = {0xef_b, 0xbf_b, 0xbf_b};
// UTF-8 standard covers code points in the sequences [0x0 - 0x110000)
// Theoretically, a larger value can be encoded in the 4-byte sequence.
// Smallest valid 4-byte sequence. // Smallest valid 4-byte sequence.
constexpr std::byte valid_four_byte_1[] = {0xf0_b, 0x90_b, 0x80_b, 0x80_b}; constexpr std::byte valid_four_byte_1[] = {0xf0_b, 0x90_b, 0x80_b, 0x80_b};
// Largest valid 4-byte sequence. // Largest valid 4-byte sequence - code point 0x10FFFF.
constexpr std::byte valid_four_byte_2[] = {0xf7_b, 0xbf_b, 0xbf_b, 0xbf_b}; constexpr std::byte valid_four_byte_2[] = {0xf4_b, 0x8f_b, 0xbf_b, 0xbf_b};
// Smallest invalid 4-byte sequence - code point 0x110000.
constexpr std::byte invalid_four_byte_9[] = {0xf4_b, 0x90_b, 0x80_b, 0x80_b};
// Largest invalid 4-byte sequence - invalid code point.
constexpr std::byte invalid_four_byte_10[] = {0xf7_b, 0xbf_b, 0xbf_b, 0xbf_b};
// Single line ASCII text. // Single line ASCII text.
constexpr std::string_view ascii_1 = "Hello World!"; constexpr std::string_view ascii_1 = "Hello World!";
...@@ -151,4 +159,6 @@ TEST_CASE("invalid UTF-8 input") { ...@@ -151,4 +159,6 @@ TEST_CASE("invalid UTF-8 input") {
CHECK(!valid_utf8(invalid_four_byte_6)); CHECK(!valid_utf8(invalid_four_byte_6));
CHECK(!valid_utf8(invalid_four_byte_7)); CHECK(!valid_utf8(invalid_four_byte_7));
CHECK(!valid_utf8(invalid_four_byte_8)); CHECK(!valid_utf8(invalid_four_byte_8));
CHECK(!valid_utf8(invalid_four_byte_9));
CHECK(!valid_utf8(invalid_four_byte_10));
} }
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "caf/net/web_socket/framing.hpp" #include "caf/net/web_socket/framing.hpp"
#include "caf/detail/rfc3629.hpp"
#include "caf/logger.hpp" #include "caf/logger.hpp"
#include "caf/net/http/v1.hpp" #include "caf/net/http/v1.hpp"
...@@ -202,6 +203,10 @@ ptrdiff_t framing::handle(uint8_t opcode, byte_span payload, ...@@ -202,6 +203,10 @@ ptrdiff_t framing::handle(uint8_t opcode, byte_span payload,
case detail::rfc6455::text_frame: { case detail::rfc6455::text_frame: {
std::string_view text{reinterpret_cast<const char*>(payload.data()), std::string_view text{reinterpret_cast<const char*>(payload.data()),
payload.size()}; payload.size()};
if (!detail::rfc3629::valid(text)) {
abort_and_shutdown(sec::runtime_error, "invalid UTF-8 sequence");
return -1;
}
if (up_->consume_text(text) < 0) if (up_->consume_text(text) < 0)
return -1; return -1;
break; break;
......
...@@ -24,6 +24,8 @@ void lower_layer::shutdown(const error& reason) { ...@@ -24,6 +24,8 @@ void lower_layer::shutdown(const error& reason) {
shutdown(status::normal_close, to_string(reason)); shutdown(status::normal_close, to_string(reason));
} else if (reason.code() == static_cast<uint8_t>(sec::protocol_error)) { } else if (reason.code() == static_cast<uint8_t>(sec::protocol_error)) {
shutdown(status::protocol_error, to_string(reason)); shutdown(status::protocol_error, to_string(reason));
} else if (reason.code() == static_cast<uint8_t>(sec::runtime_error)) {
shutdown(status::inconsistent_data, to_string(reason));
} else { } else {
shutdown(status::unexpected_condition, to_string(reason)); shutdown(status::unexpected_condition, to_string(reason));
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment