Commit 7545d114 authored by Dominik Charousset's avatar Dominik Charousset

Implement heartbeats, drop magic number

parent 27301958
...@@ -26,10 +26,12 @@ ...@@ -26,10 +26,12 @@
#include "caf/byte.hpp" #include "caf/byte.hpp"
#include "caf/error.hpp" #include "caf/error.hpp"
#include "caf/net/basp/connection_state.hpp" #include "caf/net/basp/connection_state.hpp"
#include "caf/net/basp/constants.hpp"
#include "caf/net/basp/header.hpp" #include "caf/net/basp/header.hpp"
#include "caf/net/basp/message_type.hpp" #include "caf/net/basp/message_type.hpp"
#include "caf/net/endpoint_manager.hpp" #include "caf/net/endpoint_manager.hpp"
#include "caf/node_id.hpp" #include "caf/node_id.hpp"
#include "caf/serializer_impl.hpp"
#include "caf/span.hpp" #include "caf/span.hpp"
namespace caf { namespace caf {
...@@ -40,11 +42,20 @@ class application { ...@@ -40,11 +42,20 @@ class application {
public: public:
// -- member types ----------------------------------------------------------- // -- member types -----------------------------------------------------------
using buffer_type = std::vector<byte>;
// -- interface functions ---------------------------------------------------- // -- interface functions ----------------------------------------------------
template <class Parent> template <class Parent>
error init(Parent& parent) { error init(Parent& parent) {
// Initialize member variables.
system_ = &parent.system(); system_ = &parent.system();
// Write handshake.
if (auto err = generate_handshake())
return err;
auto hdr = to_bytes(header{message_type::handshake,
static_cast<uint32_t>(buf_.size()), version});
parent.write_packet(hdr, buf_);
return none; return none;
} }
...@@ -97,17 +108,23 @@ private: ...@@ -97,17 +108,23 @@ private:
error handle_handshake(header hdr, span<const byte> payload); error handle_handshake(header hdr, span<const byte> payload);
/// Writes the handshake payload to `buf_`.
error generate_handshake();
// -- member variables ------------------------------------------------------- // -- member variables -------------------------------------------------------
/// Stores a pointer to the parent actor system. /// Stores a pointer to the parent actor system.
actor_system* system_ = nullptr; actor_system* system_ = nullptr;
/// Stores what we are expecting to receive next. /// Stores what we are expecting to receive next.
connection_state state_ = connection_state::await_magic_number; connection_state state_ = connection_state::await_handshake_header;
/// Caches the last header;we need to store it when waiting for the payload. /// Caches the last header;we need to store it when waiting for the payload.
header hdr_; header hdr_;
/// Re-usable buffer for storing payloads.
buffer_type buf_;
/// Stores our own ID. /// Stores our own ID.
node_id id_; node_id id_;
......
...@@ -28,9 +28,6 @@ namespace basp { ...@@ -28,9 +28,6 @@ namespace basp {
/// Stores the state of a connection in a `basp::application`. /// Stores the state of a connection in a `basp::application`.
enum class connection_state { enum class connection_state {
/// Indicates that we have just accepted or opened a connection and await the
/// magic number.
await_magic_number,
/// Indicates that we successfully checked the magic number and now wait for /// Indicates that we successfully checked the magic number and now wait for
/// the handshake header. /// the handshake header.
await_handshake_header, await_handshake_header,
......
...@@ -30,9 +30,6 @@ namespace basp { ...@@ -30,9 +30,6 @@ namespace basp {
/// @note BASP is not backwards compatible. /// @note BASP is not backwards compatible.
constexpr uint64_t version = 1; constexpr uint64_t version = 1;
/// The very first thing clients send before the first header.
constexpr uint32_t magic_number = 0xCAFC0DE5;
/// @} /// @}
} // namespace basp } // namespace basp
......
...@@ -49,16 +49,6 @@ expected<std::vector<byte>> application::serialize(actor_system& sys, ...@@ -49,16 +49,6 @@ expected<std::vector<byte>> application::serialize(actor_system& sys,
error application::handle(span<const byte> bytes) { error application::handle(span<const byte> bytes) {
switch (state_) { switch (state_) {
case connection_state::await_magic_number: {
if (bytes.size() != sizeof(uint32_t))
return ec::unexpected_number_of_bytes;
auto xptr = reinterpret_cast<const uint32_t*>(bytes.data());
auto x = detail::from_network_order(*xptr);
if (x != magic_number)
return ec::invalid_magic_number;
state_ = connection_state::await_handshake_header;
return none;
}
case connection_state::await_handshake_header: { case connection_state::await_handshake_header: {
if (bytes.size() != header_size) if (bytes.size() != header_size)
return ec::unexpected_number_of_bytes; return ec::unexpected_number_of_bytes;
...@@ -87,6 +77,7 @@ error application::handle(span<const byte> bytes) { ...@@ -87,6 +77,7 @@ error application::handle(span<const byte> bytes) {
}; };
if (std::none_of(app_ids.begin(), app_ids.end(), predicate)) if (std::none_of(app_ids.begin(), app_ids.end(), predicate))
return ec::app_identifiers_mismatch; return ec::app_identifiers_mismatch;
state_ = connection_state::await_header;
return none; return none;
} }
case connection_state::await_header: { case connection_state::await_header: {
...@@ -109,8 +100,13 @@ error application::handle(span<const byte> bytes) { ...@@ -109,8 +100,13 @@ error application::handle(span<const byte> bytes) {
} }
} }
error application::handle(header, span<const byte>) { error application::handle(header hdr, span<const byte>) {
switch (hdr.type) {
case message_type::heartbeat:
return none;
default:
return ec::unimplemented; return ec::unimplemented;
}
} }
error application::handle_handshake(header hdr, span<const byte> payload) { error application::handle_handshake(header hdr, span<const byte> payload) {
...@@ -129,6 +125,13 @@ error application::handle_handshake(header hdr, span<const byte> payload) { ...@@ -129,6 +125,13 @@ error application::handle_handshake(header hdr, span<const byte> payload) {
return none; return none;
} }
error application::generate_handshake() {
serializer_impl<buffer_type> sink{system(), buf_};
return sink(system().node(),
get_or(system().config(), "middleman.app-identifiers",
defaults::middleman::app_identifiers));
}
} // namespace basp } // namespace basp
} // namespace net } // namespace net
} // namespace caf } // namespace caf
...@@ -33,7 +33,7 @@ int header::compare(header other) const noexcept { ...@@ -33,7 +33,7 @@ int header::compare(header other) const noexcept {
} }
header header::from_bytes(span<const byte> bytes) { header header::from_bytes(span<const byte> bytes) {
CAF_ASSERT(bytes.size() == header_size); CAF_ASSERT(bytes.size() >= header_size);
header result; header result;
auto ptr = bytes.data(); auto ptr = bytes.data();
result.type = *reinterpret_cast<const message_type*>(ptr); result.type = *reinterpret_cast<const message_type*>(ptr);
......
...@@ -63,10 +63,9 @@ struct fixture : test_coordinator_fixture<> { ...@@ -63,10 +63,9 @@ struct fixture : test_coordinator_fixture<> {
input = to_buf(xs...); input = to_buf(xs...);
} }
void handle_magic_number() { void write_packet(span<const byte> hdr, span<const byte> payload) {
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_magic_number); output.insert(output.end(), hdr.begin(), hdr.end());
set_input(basp::magic_number); output.insert(output.end(), payload.begin(), payload.end());
REQUIRE_OK(app.handle_data(*this, input));
} }
void handle_handshake() { void handle_handshake() {
...@@ -82,6 +81,24 @@ struct fixture : test_coordinator_fixture<> { ...@@ -82,6 +81,24 @@ struct fixture : test_coordinator_fixture<> {
REQUIRE_OK(app.handle_data(*this, payload)); REQUIRE_OK(app.handle_data(*this, payload));
} }
void consume_handshake() {
if (output.size() < basp::header_size)
CAF_FAIL("BASP application did not write a handshake header");
auto hdr = basp::header::from_bytes(output);
if (hdr.type != basp::message_type::handshake || hdr.payload_len == 0
|| hdr.operation_data != basp::version)
CAF_FAIL("invalid handshake header");
node_id nid;
std::vector<std::string> app_ids;
binary_deserializer source{sys, output};
source.skip(basp::header_size);
if (auto err = source(nid, app_ids))
CAF_FAIL("unable to deserialize payload: " << sys.render(err));
if (source.remaining() > 0)
CAF_FAIL("trailing bytes after reading payload");
output.clear();
}
actor_system& system() { actor_system& system() {
return sys; return sys;
} }
...@@ -99,36 +116,25 @@ struct fixture : test_coordinator_fixture<> { ...@@ -99,36 +116,25 @@ struct fixture : test_coordinator_fixture<> {
CAF_TEST_FIXTURE_SCOPE(application_tests, fixture) CAF_TEST_FIXTURE_SCOPE(application_tests, fixture)
CAF_TEST(invalid magic number) {
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_magic_number);
set_input(basp::magic_number + 1);
CAF_CHECK_EQUAL(app.handle_data(*this, input),
basp::ec::invalid_magic_number);
}
CAF_TEST(missing handshake) { CAF_TEST(missing handshake) {
handle_magic_number();
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header); CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header);
set_input(basp::header{basp::message_type::heartbeat, 0, 0}); set_input(basp::header{basp::message_type::heartbeat, 0, 0});
CAF_CHECK_EQUAL(app.handle_data(*this, input), basp::ec::missing_handshake); CAF_CHECK_EQUAL(app.handle_data(*this, input), basp::ec::missing_handshake);
} }
CAF_TEST(version mismatch) { CAF_TEST(version mismatch) {
handle_magic_number();
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header); CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header);
set_input(basp::header{basp::message_type::handshake, 0, 0}); set_input(basp::header{basp::message_type::handshake, 0, 0});
CAF_CHECK_EQUAL(app.handle_data(*this, input), basp::ec::version_mismatch); CAF_CHECK_EQUAL(app.handle_data(*this, input), basp::ec::version_mismatch);
} }
CAF_TEST(missing payload in handshake) { CAF_TEST(missing payload in handshake) {
handle_magic_number();
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header); CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header);
set_input(basp::header{basp::message_type::handshake, 0, basp::version}); set_input(basp::header{basp::message_type::handshake, 0, basp::version});
CAF_CHECK_EQUAL(app.handle_data(*this, input), basp::ec::missing_payload); CAF_CHECK_EQUAL(app.handle_data(*this, input), basp::ec::missing_payload);
} }
CAF_TEST(invalid handshake) { CAF_TEST(invalid handshake) {
handle_magic_number();
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header); CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header);
node_id no_nid; node_id no_nid;
std::vector<std::string> no_ids; std::vector<std::string> no_ids;
...@@ -141,7 +147,6 @@ CAF_TEST(invalid handshake) { ...@@ -141,7 +147,6 @@ CAF_TEST(invalid handshake) {
} }
CAF_TEST(app identifier mismatch) { CAF_TEST(app identifier mismatch) {
handle_magic_number();
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header); CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_handshake_header);
std::vector<std::string> wrong_ids{"YOLO!!!"}; std::vector<std::string> wrong_ids{"YOLO!!!"};
auto payload = to_buf(mars, wrong_ids); auto payload = to_buf(mars, wrong_ids);
...@@ -153,4 +158,14 @@ CAF_TEST(app identifier mismatch) { ...@@ -153,4 +158,14 @@ CAF_TEST(app identifier mismatch) {
basp::ec::app_identifiers_mismatch); basp::ec::app_identifiers_mismatch);
} }
CAF_TEST(heartbeat message) {
handle_handshake();
consume_handshake();
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_header);
auto bytes = to_bytes(basp::header{basp::message_type::heartbeat, 0, 0});
set_input(bytes);
REQUIRE_OK(app.handle_data(*this, input));
CAF_CHECK_EQUAL(app.state(), basp::connection_state::await_header);
}
CAF_TEST_FIXTURE_SCOPE_END() CAF_TEST_FIXTURE_SCOPE_END()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment