diff --git a/src/signalrclient/hub_connection_impl.cpp b/src/signalrclient/hub_connection_impl.cpp index 32b38d1..935eba6 100644 --- a/src/signalrclient/hub_connection_impl.cpp +++ b/src/signalrclient/hub_connection_impl.cpp @@ -364,6 +364,12 @@ namespace signalr for (const auto& val : messages) { + // Protocol received an unknown message type and gave us a null object, close the connection like we do in other client implementations + if (val == nullptr) + { + throw std::runtime_error("null message received"); + } + switch (val->message_type) { case message_type::invocation: @@ -405,6 +411,9 @@ namespace signalr case message_type::close: // TODO break; + default: + throw std::runtime_error("unknown message type '" + std::to_string(static_cast(val->message_type)) + "' received"); + break; } } } @@ -412,7 +421,7 @@ namespace signalr { if (m_logger.is_enabled(trace_level::error)) { - m_logger.log(trace_level::error, std::string("error occured when parsing response: ") + m_logger.log(trace_level::error, std::string("error occurred when parsing response: ") .append(e.what()) .append(". response: ") .append(response)); diff --git a/src/signalrclient/json_helpers.cpp b/src/signalrclient/json_helpers.cpp index a266dd7..ee3cfad 100644 --- a/src/signalrclient/json_helpers.cpp +++ b/src/signalrclient/json_helpers.cpp @@ -9,6 +9,8 @@ namespace signalr { + char record_separator = '\x1e'; + signalr::value createValue(const Json::Value& v) { switch (v.type()) diff --git a/src/signalrclient/json_helpers.h b/src/signalrclient/json_helpers.h index 2432e56..10a78fc 100644 --- a/src/signalrclient/json_helpers.h +++ b/src/signalrclient/json_helpers.h @@ -10,7 +10,7 @@ namespace signalr { - static constexpr char record_separator = '\x1e'; + extern char record_separator; signalr::value createValue(const Json::Value& v); diff --git a/src/signalrclient/json_hub_protocol.cpp b/src/signalrclient/json_hub_protocol.cpp index 18d46c9..52280eb 100644 --- a/src/signalrclient/json_hub_protocol.cpp +++ b/src/signalrclient/json_hub_protocol.cpp @@ -70,7 +70,10 @@ namespace signalr while (pos != std::string::npos) { auto hub_message = parse_message(message.c_str() + offset, pos - offset); - vec.emplace_back(std::move(hub_message)); + if (hub_message != nullptr) + { + vec.push_back(std::move(hub_message)); + } offset = pos + 1; pos = message.find(record_separator, offset); diff --git a/test/signalrclienttests/hub_connection_tests.cpp b/test/signalrclienttests/hub_connection_tests.cpp index 046a6ac..8188ffa 100644 --- a/test/signalrclienttests/hub_connection_tests.cpp +++ b/test/signalrclienttests/hub_connection_tests.cpp @@ -1112,7 +1112,7 @@ TEST(hub_invocation, hub_connection_closes_when_invocation_response_missing_argu auto log_entries = std::dynamic_pointer_cast(writer)->get_log_entries(); ASSERT_EQ(2, log_entries.size()) << dump_vector(log_entries); - ASSERT_TRUE(has_log_entry("[error ] error occured when parsing response: Field 'arguments' not found for 'invocation' message. response: { \"type\": 1, \"target\": \"broadcast\" }\x1e\n", log_entries)) << dump_vector(log_entries); + ASSERT_TRUE(has_log_entry("[error ] error occurred when parsing response: Field 'arguments' not found for 'invocation' message. response: { \"type\": 1, \"target\": \"broadcast\" }\x1e\n", log_entries)) << dump_vector(log_entries); ASSERT_TRUE(has_log_entry("[error ] connection closed with error: Field 'arguments' not found for 'invocation' message\n", log_entries)) << dump_vector(log_entries); } @@ -1156,7 +1156,7 @@ TEST(hub_invocation, hub_connection_closes_when_invocation_response_missing_targ auto log_entries = std::dynamic_pointer_cast(writer)->get_log_entries(); ASSERT_EQ(2, log_entries.size()) << dump_vector(log_entries); - ASSERT_TRUE(has_log_entry("[error ] error occured when parsing response: Field 'target' not found for 'invocation' message. response: { \"type\": 1, \"arguments\": [] }\x1e\n", log_entries)) << dump_vector(log_entries); + ASSERT_TRUE(has_log_entry("[error ] error occurred when parsing response: Field 'target' not found for 'invocation' message. response: { \"type\": 1, \"arguments\": [] }\x1e\n", log_entries)) << dump_vector(log_entries); ASSERT_TRUE(has_log_entry("[error ] connection closed with error: Field 'target' not found for 'invocation' message\n", log_entries)) << dump_vector(log_entries); } @@ -1891,6 +1891,87 @@ TEST(send, throws_if_protocol_fails) ASSERT_EQ(connection_state::connected, hub_connection->get_connection_state()); } +class empty_hub_protocol : public hub_protocol +{ + virtual std::string write_message(const hub_message*) const override + { + return std::string { }; + } + virtual std::vector> parse_messages(const std::string& str) const override + { + auto vec = std::vector>(); + if (str.find("\"target\"") != std::string::npos) + { + vec.push_back(std::unique_ptr(new invocation_message("1", "target", std::vector()))); + } + else + { + vec.push_back(std::unique_ptr()); + } + return vec; + } + + virtual const std::string& name() const override + { + return m_protocol_name; + } + + virtual int version() const override + { + return 1; + } + + virtual signalr::transfer_format transfer_format() const override + { + return signalr::transfer_format::text; + } + +private: + std::string m_protocol_name = "json"; +}; + +TEST(receive, close_connection_on_null_hub_message) +{ + auto websocket_client = create_test_websocket_client(); + + std::shared_ptr writer(std::make_shared()); + auto hub_connection = hub_connection_impl::create("", std::move(std::unique_ptr(new empty_hub_protocol())), signalr::trace_level::info, writer, nullptr, [websocket_client](const signalr_client_config& config) + { + websocket_client->set_config(config); + return websocket_client; + }, true); + + auto close_mre = manual_reset_event(); + hub_connection->set_disconnected([&close_mre](std::exception_ptr exception) + { + close_mre.set(exception); + }); + + auto mre = manual_reset_event(); + hub_connection->start([&mre](std::exception_ptr exception) + { + mre.set(exception); + }); + + ASSERT_FALSE(websocket_client->receive_loop_started.wait(5000)); + ASSERT_FALSE(websocket_client->handshake_sent.wait(5000)); + websocket_client->receive_message("{ }\x1e"); + + mre.get(); + + websocket_client->receive_message("{ \"type\": 134 }\x1e"); + + try + { + close_mre.get(); + ASSERT_TRUE(false); + } + catch (const std::exception& ex) + { + ASSERT_STREQ("null message received", ex.what()); + } +} + TEST(keepalive, sends_ping_messages) { signalr_client_config config; @@ -2022,3 +2103,83 @@ TEST(keepalive, resets_server_timeout_timer_on_any_message_from_server) } ASSERT_EQ(connection_state::disconnected, hub_connection.get_connection_state()); } + +class unknown_message_type_hub_protocol : public hub_protocol +{ + class custom_hub_message : public hub_message + { + public: + custom_hub_message() : hub_message(static_cast(100)) {} + }; + virtual std::string write_message(const hub_message*) const override + { + return std::string{ }; + } + virtual std::vector> parse_messages(const std::string& str) const override + { + auto vec = std::vector>(); + vec.push_back(std::unique_ptr(new custom_hub_message())); + return vec; + } + + virtual const std::string& name() const override + { + return m_protocol_name; + } + + virtual int version() const override + { + return 1; + } + + virtual signalr::transfer_format transfer_format() const override + { + return signalr::transfer_format::text; + } + +private: + std::string m_protocol_name = "json"; +}; + +TEST(receive, unknown_message_type_closes_connection) +{ + auto websocket_client = create_test_websocket_client(); + + std::shared_ptr writer(std::make_shared()); + auto hub_connection = hub_connection_impl::create("", std::move(std::unique_ptr(new unknown_message_type_hub_protocol())), signalr::trace_level::info, writer, + nullptr, [websocket_client](const signalr_client_config& config) + { + websocket_client->set_config(config); + return websocket_client; + }, true); + + auto disconnect_mre = manual_reset_event(); + hub_connection->set_disconnected([&disconnect_mre](std::exception_ptr ex) + { + disconnect_mre.set(ex); + }); + + auto mre = manual_reset_event(); + hub_connection->start([&mre](std::exception_ptr exception) + { + mre.set(exception); + }); + + ASSERT_FALSE(websocket_client->receive_loop_started.wait(5000)); + ASSERT_FALSE(websocket_client->handshake_sent.wait(5000)); + websocket_client->receive_message("{ }\x1e"); + + mre.get(); + + websocket_client->receive_message("{ \"type\": 101 }\x1e"); + + try + { + disconnect_mre.get(); + ASSERT_TRUE(false); + } + catch (const std::exception& ex) + { + ASSERT_STREQ("unknown message type '100' received", ex.what()); + } +} \ No newline at end of file diff --git a/test/signalrclienttests/json_hub_protocol_tests.cpp b/test/signalrclienttests/json_hub_protocol_tests.cpp index 4d81156..a0192a0 100644 --- a/test/signalrclienttests/json_hub_protocol_tests.cpp +++ b/test/signalrclienttests/json_hub_protocol_tests.cpp @@ -128,6 +128,15 @@ TEST(json_hub_protocol, extra_items_ignored_when_parsing) assert_hub_message_equality(&message, output[0].get()); } +TEST(json_hub_protocol, unknown_message_type_returns_null) +{ + ping_message message = ping_message(); + // adding ping message, just make sure other messages are still being parsed + auto output = json_hub_protocol().parse_messages("{\"type\":142}\x1e{\"type\":6}\x1e"); + ASSERT_EQ(1, output.size()); + assert_hub_message_equality(&message, output[0].get()); +} + std::vector> invalid_messages { { "\x1e", "* Line 1, Column 1\n Syntax error: value, object or array expected.\n* Line 1, Column 1\n A valid JSON document must be either an array or an object value.\n" }, diff --git a/test/signalrclienttests/messagepack_hub_protocol_tests.cpp b/test/signalrclienttests/messagepack_hub_protocol_tests.cpp index d3a6ee7..127dd86 100644 --- a/test/signalrclienttests/messagepack_hub_protocol_tests.cpp +++ b/test/signalrclienttests/messagepack_hub_protocol_tests.cpp @@ -119,6 +119,17 @@ TEST(messagepack_hub_protocol, extra_items_ignored_when_parsing) assert_hub_message_equality(&message, output[0].get()); } +TEST(messagepack_hub_protocol, unknown_message_type_returns_null) +{ + ping_message message = ping_message(); + auto payload = string_from_bytes({ 0x04, 0x93, 0x6E, 0x80, 0xC0, + // adding ping message, just make sure other messages are still being parsed + 0x02, 0x91, 0x06 }); + auto output = messagepack_hub_protocol().parse_messages(payload); + ASSERT_EQ(1, output.size()); + assert_hub_message_equality(&message, output[0].get()); +} + namespace { std::vector> invalid_messages