diff --git a/samples/apps/txregulator/clients/loader.py b/samples/apps/txregulator/clients/loader.py index 010ea7ef9..c3ce7d9f1 100644 --- a/samples/apps/txregulator/clients/loader.py +++ b/samples/apps/txregulator/clients/loader.py @@ -23,7 +23,7 @@ class AppUser: network.consortium.add_users(primary, [self.name]) with primary.user_client(user_id=self.name) as client: - self.ccf_id = client.rpc("whoAmI", {}).result["caller_id"] + self.ccf_id = client.get("whoAmI").result["caller_id"] def __str__(self): return f"{self.ccf_id} ({self.name})" diff --git a/samples/apps/txregulator/clients/poll.py b/samples/apps/txregulator/clients/poll.py index 32184670e..6beb8eb2f 100644 --- a/samples/apps/txregulator/clients/poll.py +++ b/samples/apps/txregulator/clients/poll.py @@ -44,7 +44,7 @@ def run(args): ) as c: while True: time.sleep(1) - resp = reg_c.rpc("REG_poll_flagged", {}).to_dict() + resp = reg_c.rpc("REG_poll_flagged").to_dict() if "result" in resp: flagged_txs = resp["result"] diff --git a/samples/apps/txregulator/tests/txregulatorclient.py b/samples/apps/txregulator/tests/txregulatorclient.py index 5bbb58ffc..0d92c8014 100644 --- a/samples/apps/txregulator/tests/txregulatorclient.py +++ b/samples/apps/txregulator/tests/txregulatorclient.py @@ -23,7 +23,7 @@ class AppUser: network.consortium.add_users(primary, [self.name]) with primary.user_client(user_id=self.name) as client: - self.ccf_id = client.rpc("whoAmI", {}).result["caller_id"] + self.ccf_id = client.get("whoAmI").result["caller_id"] def __str__(self): return f"{self.ccf_id} ({self.name})" @@ -99,20 +99,20 @@ def run(args): # Check permissions are enforced with primary.user_client(user_id=regulator.name) as c: check( - c.rpc("REG_register", {}), + c.rpc("REG_register"), error=check_status(http.HTTPStatus.FORBIDDEN), ) check( - c.rpc("BK_register", {}), error=check_status(http.HTTPStatus.FORBIDDEN), + c.rpc("BK_register"), error=check_status(http.HTTPStatus.FORBIDDEN), ) with primary.user_client(user_id=banks[0].name) as c: check( - c.rpc("REG_register", {}), + c.rpc("REG_register"), error=check_status(http.HTTPStatus.FORBIDDEN), ) check( - c.rpc("BK_register", {}), error=check_status(http.HTTPStatus.FORBIDDEN), + c.rpc("BK_register"), error=check_status(http.HTTPStatus.FORBIDDEN), ) # As permissioned manager, register regulator and banks @@ -227,7 +227,7 @@ def run(args): with primary.user_client(user_id=bank.name) as c: # try to poll flagged but fail as you are not a regulator check( - c.rpc("REG_poll_flagged", {}), + c.rpc("REG_poll_flagged"), error=check_status(http.HTTPStatus.FORBIDDEN), ) @@ -248,7 +248,7 @@ def run(args): with primary.node_client() as mc: with primary.user_client(user_id=regulator.name) as c: # assert that the flagged txs that we poll for are correct - resp = c.rpc("REG_poll_flagged", {}) + resp = c.rpc("REG_poll_flagged") poll_flagged_ids = [] for poll_flagged in resp.result: # poll flagged is a list [tx_id, regulator_id] diff --git a/src/http/http_consts.h b/src/http/http_consts.h index 18563f1ab..afa96a3bd 100644 --- a/src/http/http_consts.h +++ b/src/http/http_consts.h @@ -7,17 +7,18 @@ namespace http namespace headers { // All HTTP headers are expected to be lowercase + static constexpr auto ALLOW = "allow"; static constexpr auto AUTHORIZATION = "authorization"; - static constexpr auto DIGEST = "digest"; - static constexpr auto CONTENT_TYPE = "content-type"; static constexpr auto CONTENT_LENGTH = "content-length"; + static constexpr auto CONTENT_TYPE = "content-type"; + static constexpr auto DIGEST = "digest"; static constexpr auto LOCATION = "location"; static constexpr auto WWW_AUTHENTICATE = "www-authenticate"; static constexpr auto CCF_COMMIT = "x-ccf-commit"; - static constexpr auto CCF_TERM = "x-ccf-term"; static constexpr auto CCF_GLOBAL_COMMIT = "x-ccf-global-commit"; static constexpr auto CCF_READ_ONLY = "x-ccf-read-only"; + static constexpr auto CCF_TERM = "x-ccf-term"; } namespace headervalues diff --git a/src/http/http_endpoint.h b/src/http/http_endpoint.h index 8de3b5ab5..e06a0afed 100644 --- a/src/http/http_endpoint.h +++ b/src/http/http_endpoint.h @@ -179,7 +179,7 @@ namespace http void handle_request( http_method verb, const std::string_view& path, - const std::string_view& query, + const std::string& query, http::HeaderMap&& headers, std::vector&& body) override { diff --git a/src/http/http_parser.h b/src/http/http_parser.h index f64597407..03aee1364 100644 --- a/src/http/http_parser.h +++ b/src/http/http_parser.h @@ -20,7 +20,7 @@ namespace http virtual void handle_request( http_method method, const std::string_view& path, - const std::string_view& query, + const std::string& query, HeaderMap&& headers, std::vector&& body) = 0; }; @@ -32,6 +32,51 @@ namespace http http_status status, HeaderMap&& headers, std::vector&& body) = 0; }; + static uint8_t hex_char_to_int(char c) + { + if (c <= '9') + { + return c - '0'; + } + else if (c <= 'F') + { + return c - 'A' + 10; + } + else if (c <= 'f') + { + return c - 'a' + 10; + } + return c; + } + + static void url_unescape(std::string& s) + { + char const* src = s.c_str(); + char const* end = s.c_str() + s.size(); + char* dst = s.data(); + + while (src < end) + { + char const c = *src++; + if (c == '%' && (src + 1) < end && isxdigit(src[0]) && isxdigit(src[1])) + { + const auto a = hex_char_to_int(*src++); + const auto b = hex_char_to_int(*src++); + *dst++ = (a << 4) | b; + } + else if (c == '+') + { + *dst++ = ' '; + } + else + { + *dst++ = c; + } + } + + s.resize(dst - s.data()); + } + struct SimpleRequestProcessor : public http::RequestProcessor { public: @@ -49,7 +94,7 @@ namespace http virtual void handle_request( http_method method, const std::string_view& path, - const std::string_view& query, + const std::string& query, http::HeaderMap&& headers, std::vector&& body) override { @@ -342,10 +387,12 @@ namespace http else { const auto [path, query] = parse_url(url); + std::string unescaped_query(query); + url_unescape(unescaped_query); proc.handle_request( http_method(parser.method), path, - query, + unescaped_query, std::move(headers), std::move(body_buf)); } diff --git a/src/http/http_rpc_context.h b/src/http/http_rpc_context.h index 03244b1cd..f223457ea 100644 --- a/src/http/http_rpc_context.h +++ b/src/http/http_rpc_context.h @@ -115,11 +115,12 @@ namespace http } const auto canonical_request_header = fmt::format( - "{} {} HTTP/1.1\r\n" + "{} {}{} HTTP/1.1\r\n" "{}" "\r\n", http_method_str(verb), - fmt::format("{}{}", whole_path, query), + whole_path, + query.empty() ? "" : fmt::format("?{}", query), http::get_header_string(request_headers)); serialised_request.resize( diff --git a/src/http/test/http_test.cpp b/src/http/test/http_test.cpp index 4333889b0..3fd74554a 100644 --- a/src/http/test/http_test.cpp +++ b/src/http/test/http_test.cpp @@ -274,4 +274,38 @@ DOCTEST_TEST_CASE("Pessimal transport") sp.received.pop(); } +} + +DOCTEST_TEST_CASE("Escaping") +{ + { + const std::string unescaped = + "This has many@many+many \\% \" AWKWARD :;-=?!& ++ characters %20%20"; + const std::string escaped = + "This+has+many%40many%2Bmany+%5C%25+%22+AWKWARD+%3A%3B-%3D%3F%21%26+%2B%" + "2b+" + "characters+%2520%2520"; + + std::string s = escaped; + http::url_unescape(s); + DOCTEST_REQUIRE(s == unescaped); + } + + { + const std::string request = + "GET /foo/bar?this=that&awkward=escaped+string+%3A%3B-%3D%3F%21%22 " + "HTTP/1.1\r\n\r\n"; + + http::SimpleRequestProcessor sp; + http::RequestParser p(sp); + + const std::vector req(request.begin(), request.end()); + auto parsed = p.execute(req.data(), req.size()); + + DOCTEST_CHECK(!sp.received.empty()); + const auto& m = sp.received.front(); + DOCTEST_CHECK(m.method == HTTP_GET); + DOCTEST_CHECK(m.path == "/foo/bar"); + DOCTEST_CHECK(m.query == "this=that&awkward=escaped string :;-=?!\""); + } } \ No newline at end of file diff --git a/src/node/rpc/commonhandlerregistry.h b/src/node/rpc/commonhandlerregistry.h index b508d6b84..a5f1ca60b 100644 --- a/src/node/rpc/commonhandlerregistry.h +++ b/src/node/rpc/commonhandlerregistry.h @@ -83,8 +83,7 @@ namespace ccf }; auto who_am_i = - [this]( - Store::Tx& tx, CallerId caller_id, const nlohmann::json& params) { + [this](Store::Tx& tx, CallerId caller_id, nlohmann::json&& params) { if (certs == nullptr) { return make_error( @@ -241,25 +240,32 @@ namespace ccf .set_auto_schema(); install(GeneralProcs::GET_METRICS, json_adapter(get_metrics), Read) .set_auto_schema() - .set_execute_locally(true); + .set_execute_locally(true) + .set_http_get_only(); install(GeneralProcs::MK_SIGN, json_adapter(make_signature), Write) .set_auto_schema(); install(GeneralProcs::WHO_AM_I, json_adapter(who_am_i), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(GeneralProcs::WHO_IS, json_adapter(who_is), Read) .set_auto_schema(); install( GeneralProcs::GET_PRIMARY_INFO, json_adapter(get_primary_info), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install( GeneralProcs::GET_NETWORK_INFO, json_adapter(get_network_info), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(GeneralProcs::LIST_METHODS, json_adapter(list_methods_fn), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(GeneralProcs::GET_SCHEMA, json_adapter(get_schema), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(GeneralProcs::GET_RECEIPT, json_adapter(get_receipt), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(GeneralProcs::VERIFY_RECEIPT, json_adapter(verify_receipt), Read) .set_auto_schema(); } diff --git a/src/node/rpc/frontend.h b/src/node/rpc/frontend.h index 125198f0b..781601d45 100644 --- a/src/node/rpc/frontend.h +++ b/src/node/rpc/frontend.h @@ -403,6 +403,28 @@ namespace ccf if (handler == nullptr) { ctx->set_response_status(HTTP_STATUS_NOT_FOUND); + ctx->set_response_header( + http::headers::CONTENT_TYPE, http::headervalues::contenttype::TEXT); + ctx->set_response_body(fmt::format("Unknown RPC: {}", method)); + return ctx->serialise_response(); + } + + if (!(handler->allowed_verbs_mask & + verb_to_mask(ctx->get_request_verb()))) + { + ctx->set_response_status(HTTP_STATUS_METHOD_NOT_ALLOWED); + std::string allow_header_value; + bool first = true; + for (size_t verb = 0; verb <= HTTP_SOURCE; ++verb) + { + if (handler->allowed_verbs_mask & verb_to_mask(verb)) + { + allow_header_value += fmt::format( + "{}{}", (first ? "" : ", "), http_method_str((http_method)verb)); + first = false; + } + } + ctx->set_response_header(http::headers::ALLOW, allow_header_value); return ctx->serialise_response(); } diff --git a/src/node/rpc/handlerregistry.h b/src/node/rpc/handlerregistry.h index 1464c5254..9717bdd7e 100644 --- a/src/node/rpc/handlerregistry.h +++ b/src/node/rpc/handlerregistry.h @@ -4,11 +4,14 @@ #include "ds/json_schema.h" #include "enclave/rpccontext.h" +#include "http/http_consts.h" #include "node/certs.h" #include "serialization.h" #include +#include #include +#include namespace ccf { @@ -19,6 +22,11 @@ namespace ccf CallerId caller_id; }; + static uint64_t verb_to_mask(size_t verb) + { + return 1ul << verb; + } + using HandleFunction = std::function; class HandlerRegistry @@ -119,6 +127,34 @@ namespace ccf execute_locally = v; return *this; } + + // Bit mask. Bit i is 1 iff the http_method with value i is allowed. + // Default is that all verbs are allowed + uint64_t allowed_verbs_mask = ~0; + + Handler& set_allowed_verbs(std::set&& allowed_verbs) + { + // Reset mask to disallow everything + allowed_verbs_mask = 0; + + // Set bit for each allowed verb + for (const auto& verb : allowed_verbs) + { + allowed_verbs_mask |= verb_to_mask(verb); + } + + return *this; + } + + Handler& set_http_get_only() + { + return set_allowed_verbs({HTTP_GET}); + } + + Handler& set_http_post_only() + { + return set_allowed_verbs({HTTP_POST}); + } }; protected: diff --git a/src/node/rpc/jsonhandler.h b/src/node/rpc/jsonhandler.h index b32357e9f..b923899aa 100644 --- a/src/node/rpc/jsonhandler.h +++ b/src/node/rpc/jsonhandler.h @@ -182,7 +182,10 @@ namespace ccf const auto pack = detect_json_pack(ctx); nlohmann::json params = nullptr; - if (!ctx->get_request_body().empty()) + if ( + !ctx->get_request_body().empty() + // Body of GET is ignored + && ctx->get_request_verb() != HTTP_GET) { params = get_params_from_body(ctx, pack); } diff --git a/src/node/rpc/memberfrontend.h b/src/node/rpc/memberfrontend.h index cbfdb366b..ef5805ed7 100644 --- a/src/node/rpc/memberfrontend.h +++ b/src/node/rpc/memberfrontend.h @@ -465,7 +465,7 @@ namespace ccf auto read = [this]( Store::Tx& tx, CallerId caller_id, - const nlohmann::json& params) { + nlohmann::json&& params) { if (!check_member_status( tx, caller_id, {MemberStatus::ACTIVE, MemberStatus::ACCEPTED})) { @@ -496,8 +496,7 @@ namespace ccf .set_auto_schema(); auto query = - [this]( - Store::Tx& tx, CallerId caller_id, const nlohmann::json& params) { + [this](Store::Tx& tx, CallerId caller_id, nlohmann::json&& params) { if (!check_member_accepted(tx, caller_id)) { return make_error(HTTP_STATUS_FORBIDDEN, "Member is not accepted"); @@ -510,7 +509,7 @@ namespace ccf install(MemberProcs::QUERY, json_adapter(query), Read) .set_auto_schema(); - auto propose = [this](RequestArgs& args, const nlohmann::json& params) { + auto propose = [this](RequestArgs& args, nlohmann::json&& params) { if (!check_member_active(args.tx, args.caller_id)) { return make_error(HTTP_STATUS_FORBIDDEN, "Member is not active"); @@ -534,7 +533,7 @@ namespace ccf install(MemberProcs::PROPOSE, json_adapter(propose), Write) .set_auto_schema(); - auto withdraw = [this](RequestArgs& args, const nlohmann::json& params) { + auto withdraw = [this](RequestArgs& args, nlohmann::json&& params) { if (!check_member_active(args.tx, args.caller_id)) { return make_error(HTTP_STATUS_FORBIDDEN, "Member is not active"); @@ -586,7 +585,7 @@ namespace ccf .set_auto_schema() .set_require_client_signature(true); - auto vote = [this](RequestArgs& args, const nlohmann::json& params) { + auto vote = [this](RequestArgs& args, nlohmann::json&& params) { if (!check_member_active(args.tx, args.caller_id)) { return make_error(HTTP_STATUS_FORBIDDEN, "Member is not active"); @@ -634,8 +633,7 @@ namespace ccf .set_require_client_signature(true); auto complete = - [this]( - Store::Tx& tx, CallerId caller_id, const nlohmann::json& params) { + [this](Store::Tx& tx, CallerId caller_id, nlohmann::json&& params) { if (!check_member_active(tx, caller_id)) { return make_error(HTTP_STATUS_FORBIDDEN, "Member is not active"); @@ -661,7 +659,7 @@ namespace ccf .set_require_client_signature(true); //! A member acknowledges state - auto ack = [this](RequestArgs& args, const nlohmann::json& params) { + auto ack = [this](RequestArgs& args, nlohmann::json&& params) { const auto signed_request = args.rpc_ctx->get_signed_request(); auto [ma_view, sig_view] = @@ -707,8 +705,7 @@ namespace ccf //! A member asks for a fresher state digest auto update_state_digest = - [this]( - Store::Tx& tx, CallerId caller_id, const nlohmann::json& params) { + [this](Store::Tx& tx, CallerId caller_id, nlohmann::json&& params) { auto [ma_view, sig_view] = tx.get_view(this->network.member_acks, this->network.signatures); auto ma = ma_view->get(caller_id); @@ -737,7 +734,7 @@ namespace ccf .set_auto_schema(); auto get_encrypted_recovery_share = - [this](RequestArgs& args, const nlohmann::json& params) { + [this](RequestArgs& args, nlohmann::json&& params) { // This check should depend on whether new shares are emitted when a // new member is added (status = Accepted) or when the new member acks // (status = Active). @@ -781,7 +778,7 @@ namespace ccf auto submit_recovery_share = [this]( RequestArgs& args, - const nlohmann::json& params) { + nlohmann::json&& params) { // Only active members can submit their shares for recovery if (!check_member_active(args.tx, args.caller_id)) { @@ -831,7 +828,7 @@ namespace ccf Write) .set_auto_schema(); - auto create = [this](Store::Tx& tx, const nlohmann::json& params) { + auto create = [this](Store::Tx& tx, nlohmann::json&& params) { LOG_DEBUG_FMT("Processing create RPC"); const auto in = params.get(); diff --git a/src/node/rpc/nodefrontend.h b/src/node/rpc/nodefrontend.h index 4504d83f7..a920f504e 100644 --- a/src/node/rpc/nodefrontend.h +++ b/src/node/rpc/nodefrontend.h @@ -300,11 +300,14 @@ namespace ccf install(NodeProcs::JOIN, json_adapter(accept), Write); install(NodeProcs::GET_SIGNED_INDEX, json_adapter(get_signed_index), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(NodeProcs::GET_NODE_QUOTE, json_adapter(get_quote), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); install(NodeProcs::GET_QUOTES, json_adapter(get_quotes), Read) - .set_auto_schema(); + .set_auto_schema() + .set_http_get_only(); } }; diff --git a/src/node/rpc/test/frontend_test.cpp b/src/node/rpc/test/frontend_test.cpp index 86c2fc670..21bb5f88e 100644 --- a/src/node/rpc/test/frontend_test.cpp +++ b/src/node/rpc/test/frontend_test.cpp @@ -97,7 +97,7 @@ public: "get_caller", json_adapter(get_caller_function), HandlerRegistry::Read); auto failable_function = - [this](Store::Tx& tx, CallerId caller_id, const nlohmann::json& params) { + [this](Store::Tx& tx, CallerId caller_id, nlohmann::json&& params) { const auto it = params.find("error"); if (it != params.end()) { @@ -113,6 +113,31 @@ public: } }; +class TestRestrictedVerbsFrontend : public SimpleUserRpcFrontend +{ +public: + TestRestrictedVerbsFrontend(Store& tables) : SimpleUserRpcFrontend(tables) + { + open(); + + auto get_only = [this](RequestArgs& args) { + args.rpc_ctx->set_response_status(HTTP_STATUS_OK); + }; + install("get_only", get_only, HandlerRegistry::Read).set_http_get_only(); + + auto post_only = [this](RequestArgs& args) { + args.rpc_ctx->set_response_status(HTTP_STATUS_OK); + }; + install("post_only", post_only, HandlerRegistry::Read).set_http_post_only(); + + auto put_or_delete = [this](RequestArgs& args) { + args.rpc_ctx->set_response_status(HTTP_STATUS_OK); + }; + install("put_or_delete", put_or_delete, HandlerRegistry::Read) + .set_allowed_verbs({HTTP_PUT, HTTP_DELETE}); + } +}; + class TestMemberFrontend : public MemberRpcFrontend { public: @@ -807,6 +832,81 @@ TEST_CASE("MinimalHandleFunction") } } +TEST_CASE("Restricted verbs") +{ + prepare_callers(); + TestRestrictedVerbsFrontend frontend(*network.tables); + + for (auto verb = HTTP_DELETE; verb <= HTTP_SOURCE; + verb = (http_method)(size_t(verb) + 1)) + { + INFO(http_method_str(verb)); + + { + http::Request get("get_only", verb); + const auto serialized_get = get.build_request(); + auto rpc_ctx = enclave::make_rpc_context(user_session, serialized_get); + const auto serialized_response = frontend.process(rpc_ctx).value(); + const auto response = parse_response(serialized_response); + if (verb == HTTP_GET) + { + CHECK(response.status == HTTP_STATUS_OK); + } + else + { + CHECK(response.status == HTTP_STATUS_METHOD_NOT_ALLOWED); + const auto it = response.headers.find(http::headers::ALLOW); + REQUIRE(it != response.headers.end()); + const auto v = it->second; + CHECK(v.find(http_method_str(HTTP_GET)) != std::string::npos); + } + } + + { + http::Request get("post_only", verb); + const auto serialized_post = get.build_request(); + auto rpc_ctx = enclave::make_rpc_context(user_session, serialized_post); + const auto serialized_response = frontend.process(rpc_ctx).value(); + const auto response = parse_response(serialized_response); + if (verb == HTTP_POST) + { + CHECK(response.status == HTTP_STATUS_OK); + } + else + { + CHECK(response.status == HTTP_STATUS_METHOD_NOT_ALLOWED); + const auto it = response.headers.find(http::headers::ALLOW); + REQUIRE(it != response.headers.end()); + const auto v = it->second; + CHECK(v.find(http_method_str(HTTP_POST)) != std::string::npos); + } + } + + { + http::Request get("put_or_delete", verb); + const auto serialized_put_or_delete = get.build_request(); + auto rpc_ctx = + enclave::make_rpc_context(user_session, serialized_put_or_delete); + const auto serialized_response = frontend.process(rpc_ctx).value(); + const auto response = parse_response(serialized_response); + if (verb == HTTP_PUT || verb == HTTP_DELETE) + { + CHECK(response.status == HTTP_STATUS_OK); + } + else + { + CHECK(response.status == HTTP_STATUS_METHOD_NOT_ALLOWED); + const auto it = response.headers.find(http::headers::ALLOW); + REQUIRE(it != response.headers.end()); + const auto v = it->second; + CHECK(v.find(http_method_str(HTTP_PUT)) != std::string::npos); + CHECK(v.find(http_method_str(HTTP_DELETE)) != std::string::npos); + CHECK(v.find(http_method_str(verb)) == std::string::npos); + } + } + } +} + TEST_CASE("Signed read requests can be executed on backup") { prepare_callers(); diff --git a/tests/e2e_logging.py b/tests/e2e_logging.py index b3ce72a11..abcea94b8 100644 --- a/tests/e2e_logging.py +++ b/tests/e2e_logging.py @@ -110,9 +110,9 @@ def test_forwarding_frontends(network, args): with primary.node_client() as nc: check_commit = infra.checker.Checker(nc) with backup.node_client() as c: - check_commit(c.do("mkSign", params={}), result=True) + check_commit(c.rpc("mkSign"), result=True) with backup.member_client() as c: - check_commit(c.do("mkSign", params={}), result=True) + check_commit(c.rpc("mkSign"), result=True) return network @@ -143,7 +143,7 @@ def test_update_lua(network, args): member_id=1, remote_node=primary, app_script=new_app_file ) with primary.user_client() as c: - check(c.rpc("ping", params={}), result="pong") + check(c.rpc("ping"), result="pong") LOG.debug("Check that former endpoints no longer exists") for endpoint in [ @@ -153,7 +153,7 @@ def test_update_lua(network, args): "LOG_get_pub", ]: check( - c.rpc(endpoint, params={}), + c.rpc(endpoint), error=lambda status, msg: status == http.HTTPStatus.NOT_FOUND.value, ) else: diff --git a/tests/e2e_scenarios.py b/tests/e2e_scenarios.py index e4f076043..d0a75b0b6 100644 --- a/tests/e2e_scenarios.py +++ b/tests/e2e_scenarios.py @@ -42,7 +42,7 @@ def run(args): check = infra.checker.Checker() check_commit = infra.checker.Checker(mc) with primary.user_client() as uc: - check_commit(uc.do("mkSign", params={}), result=True) + check_commit(uc.rpc("mkSign"), result=True) for connection in scenario["connections"]: with ( diff --git a/tests/election.py b/tests/election.py index 7c97adee9..fa07d256e 100644 --- a/tests/election.py +++ b/tests/election.py @@ -26,7 +26,7 @@ def wait_for_index_globally_committed(index, term, nodes): up_to_date_f = [] for f in nodes: with f.node_client() as c: - res = c.request("getCommit", {"commit": index}) + res = c.get("getCommit", {"commit": index}) if res.result["term"] == term and (res.global_commit >= index): up_to_date_f.append(f.node_id) if len(up_to_date_f) == len(nodes): @@ -45,6 +45,7 @@ def run(args): with infra.ccf.network( hosts, args.binary_dir, args.debug_nodes, args.perf_nodes, pdb=args.pdb ) as network: + check = infra.checker.Checker() network.start_and_join(args) current_term = None @@ -76,15 +77,15 @@ def run(args): ) commit_index = None with primary.user_client() as c: - res = c.do( + res = c.rpc( "LOG_record", { "id": current_term, "msg": "This log is committed in term {}".format(current_term), }, readonly_hint=None, - expected_result=True, ) + check(res, result=True) commit_index = res.commit LOG.debug("Waiting for transaction to be committed by all nodes") diff --git a/tests/governance.py b/tests/governance.py index ce13f3f1d..1983bb8c4 100644 --- a/tests/governance.py +++ b/tests/governance.py @@ -31,7 +31,7 @@ def run(args): with primary.node_client() as mc: check_commit = infra.checker.Checker(mc) check = infra.checker.Checker() - r = mc.rpc("getQuotes", {}) + r = mc.get("getQuotes") quotes = r.result["quotes"] assert len(quotes) == len(hosts) primary_quote = quotes[0] diff --git a/tests/infra/ccf.py b/tests/infra/ccf.py index f60817d3a..3c5213914 100644 --- a/tests/infra/ccf.py +++ b/tests/infra/ccf.py @@ -376,7 +376,7 @@ class Network: for _ in range(timeout): try: with node.node_client() as c: - r = c.request("getSignedIndex", {}) + r = c.get("getSignedIndex") if r.result["state"] == state: break except ConnectionRefusedError: @@ -411,7 +411,7 @@ class Network: for node in self.get_joined_nodes(): with node.node_client(request_timeout=request_timeout) as c: try: - res = c.do("getPrimaryInfo", {}) + res = c.get("getPrimaryInfo") if res.error is None: primary_id = res.result["primary_id"] term = res.term @@ -453,7 +453,7 @@ class Network: which added the nodes). """ with primary.node_client() as c: - res = c.do("getCommit", {}) + res = c.get("getCommit") local_commit_leader = res.commit term_leader = res.term @@ -461,7 +461,7 @@ class Network: caught_up_nodes = [] for node in self.get_joined_nodes(): with node.node_client() as c: - resp = c.request("getCommit", {}) + resp = c.get("getCommit") if resp.error is not None: # Node may not have joined the network yet, try again break @@ -486,7 +486,7 @@ class Network: commits = [] for node in self.get_joined_nodes(): with node.node_client() as c: - r = c.request("getCommit", {}) + r = c.get("getCommit") commits.append(r.commit) if [commits[0]] * len(commits) == commits: break diff --git a/tests/infra/checker.py b/tests/infra/checker.py index 07a7aac08..bbf01fb59 100644 --- a/tests/infra/checker.py +++ b/tests/infra/checker.py @@ -21,12 +21,12 @@ def wait_for_global_commit(node_client, commit_index, term, mksign=False, timeou # Forcing a signature accelerates this process for common operations # (e.g. governance proposals) if mksign: - r = node_client.rpc("mkSign", params={}) + r = node_client.rpc("mkSign") if r.error is not None: raise RuntimeError(f"mkSign returned an error: {r.error}") for i in range(timeout * 10): - r = node_client.rpc("getCommit", {"commit": commit_index}) + r = node_client.get("getCommit", {"commit": commit_index}) if r.global_commit >= commit_index and r.result["term"] == term: return time.sleep(0.1) diff --git a/tests/infra/clients.py b/tests/infra/clients.py index c4de3bbe0..00ae2e251 100644 --- a/tests/infra/clients.py +++ b/tests/infra/clients.py @@ -12,6 +12,7 @@ import subprocess import tempfile import base64 import requests +import urllib.parse from requests_http_signature import HTTPSignatureAuth from http.client import HTTPResponse from io import BytesIO @@ -36,10 +37,11 @@ CCF_READ_ONLY_HEADER = "x-ccf-read-only" class Request: - def __init__(self, method, params, readonly_hint=None): + def __init__(self, method, params=None, readonly_hint=None, http_verb="POST"): self.method = method self.params = params self.readonly_hint = readonly_hint + self.http_verb = http_verb def int_or_none(v): @@ -113,8 +115,8 @@ def human_readable_size(n): class RPCLogger: def log_request(self, request, name, description): LOG.info( - f"{name} {request.method} " - + truncate(f"{request.params}") + f"{name} {request.http_verb} /{request.method}" + + (truncate(f" {request.params}") if request.params is not None else "") + ( f" (RO hint: {request.readonly_hint})" if request.readonly_hint is not None @@ -145,7 +147,7 @@ class RPCFileLogger(RPCLogger): def log_request(self, request, name, description): with open(self.path, "a") as f: - f.write(f">> Request: {request.method}" + os.linesep) + f.write(f">> Request: {request.http_verb} /{request.method}" + os.linesep) json.dump(request.params, f, indent=2) f.write(os.linesep) @@ -160,6 +162,13 @@ class CCFConnectionException(Exception): pass +def build_query_string(params): + return "&".join( + f"{urllib.parse.quote_plus(k)}={urllib.parse.quote_plus(json.dumps(v))}" + for k, v in params.items() + ) + + class CurlClient: """ We keep this around in a limited fashion still, because @@ -190,25 +199,39 @@ class CurlClient: def _just_request(self, request, is_signed=False): with tempfile.NamedTemporaryFile() as nf: - msg = json.dumps(request.params).encode() - LOG.debug(f"Going to call {request.method} with {msg}") - nf.write(msg) - nf.flush() if is_signed: cmd = [os.path.join(self.binary_dir, "scurl.sh")] else: cmd = ["curl"] + url = f"https://{self.host}:{self.port}/{request.method}" + + is_get = request.http_verb == "GET" + if is_get: + if request.params is not None: + url += f"?{build_query_string(request.params)}" + cmd += [ - f"https://{self.host}:{self.port}/{request.method}", + url, + "-X", + request.http_verb, "-H", "Content-Type: application/json", - "--data-binary", - f"@{nf.name}", "-i", f"-m {self.request_timeout}", ] + if not is_get: + msg = ( + json.dumps(request.params).encode() + if request.params is not None + else bytes() + ) + LOG.debug(f"Writing request body: {msg}") + nf.write(msg) + nf.flush() + cmd.extend(["--data-binary", f"@{nf.name}"]) + if request.readonly_hint: cmd.extend(["-H", f"{CCF_READ_ONLY_HEADER}: true"]) @@ -296,13 +319,21 @@ class RequestClient: if request.readonly_hint: extra_headers[CCF_READ_ONLY_HEADER] = "true" - response = self.session.post( - f"https://{self.host}:{self.port}/{request.method}", - json=request.params, - timeout=self.request_timeout, - auth=auth_value, - headers=extra_headers, - ) + request_args = { + "method": request.http_verb, + "url": f"https://{self.host}:{self.port}/{request.method}", + "auth": auth_value, + "headers": extra_headers, + } + + is_get = request.http_verb == "GET" + if request.params is not None: + if is_get: + request_args["params"] = build_query_string(request.params) + else: + request_args["json"] = request.params + + response = self.session.request(timeout=self.request_timeout, **request_args) return Response.from_requests_response(response) def _request(self, request, is_signed=False): @@ -382,8 +413,8 @@ class CCFClient: logger.log_response(response) return response - def request(self, method, params, *args, **kwargs): - r = Request(f"{self.prefix}/{method}", params, *args, **kwargs) + def request(self, method, *args, **kwargs): + r = Request(f"{self.prefix}/{method}", *args, **kwargs) description = "" if self.description: description = f" ({self.description})" @@ -392,8 +423,8 @@ class CCFClient: return self._response(self.client_impl.request(r)) - def signed_request(self, method, params, *args, **kwargs): - r = Request(f"{self.prefix}/{method}", params, *args, **kwargs) + def signed_request(self, method, *args, **kwargs): + r = Request(f"{self.prefix}/{method}", *args, **kwargs) description = "" if self.description: @@ -403,29 +434,15 @@ class CCFClient: return self._response(self.client_impl.signed_request(r)) - def do(self, *args, **kwargs): - expected_result = None - expected_error_code = None - if "expected_result" in kwargs: - expected_result = kwargs.pop("expected_result") - if "expected_error_code" in kwargs: - expected_error_code = kwargs.pop("expected_error_code") - - r = self.rpc(*args, **kwargs) - - if expected_result is not None: - assert expected_result == r.result - - if expected_error_code is not None: - assert expected_error_code == r.error["code"] - return r - def rpc(self, *args, **kwargs): if "signed" in kwargs and kwargs.pop("signed"): return self.signed_request(*args, **kwargs) else: return self.request(*args, **kwargs) + def get(self, *args, **kwargs): + return self.rpc(*args, http_verb="GET", **kwargs) + @contextlib.contextmanager def client( diff --git a/tests/infra/consortium.py b/tests/infra/consortium.py index 1d4daf0d2..0ae7e7b88 100644 --- a/tests/infra/consortium.py +++ b/tests/infra/consortium.py @@ -136,7 +136,7 @@ class Consortium: def update_ack_state_digest(self, member_id, remote_node): with remote_node.member_client(member_id=member_id) as mc: - res = mc.rpc("updateAckStateDigest", params={}) + res = mc.rpc("updateAckStateDigest") return bytearray(res.result["state_digest"]) def ack(self, member_id, remote_node): @@ -156,7 +156,7 @@ class Consortium: """ with remote_node.member_client(member_id=member_id) as c: - rep = c.do("query", {"text": script}) + rep = c.rpc("query", {"text": script}) return rep.result def propose_retire_node(self, member_id, remote_node, node_id): @@ -312,7 +312,7 @@ class Consortium: def get_decrypt_and_submit_shares(self, remote_node): for m in self.members: with remote_node.member_client(member_id=m) as mc: - r = mc.rpc("getEncryptedRecoveryShare", params={}) + r = mc.rpc("getEncryptedRecoveryShare") # For now, members rely on a copy of the original network encryption public key ctx = infra.crypto.CryptoBoxCtx( @@ -356,7 +356,7 @@ class Consortium: # When opening the service in PBFT, the first transaction to be # completed when f = 1 takes a significant amount of time with remote_node.member_client(request_timeout=(30 if pbft_open else 3)) as c: - rep = c.do( + rep = c.rpc( "query", { "text": """tables = ... @@ -378,7 +378,7 @@ class Consortium: def _check_node_exists(self, remote_node, node_id, node_status=None): with remote_node.member_client() as c: - rep = c.do("read", {"table": "ccf.nodes", "key": node_id}) + rep = c.rpc("read", {"table": "ccf.nodes", "key": node_id}) if rep.error is not None or ( node_status and rep.result["status"] != node_status.name diff --git a/tests/infra/node.py b/tests/infra/node.py index bf0579249..2e3c57eb3 100644 --- a/tests/infra/node.py +++ b/tests/infra/node.py @@ -204,7 +204,7 @@ class Node: # is not yet endorsed by the network certificate try: with self.node_client(connection_timeout=timeout) as nc: - rep = nc.do("getCommit", {}) + rep = nc.get("getCommit") assert ( rep.error is None and rep.result is not None ), f"An error occured after node {self.node_id} joined the network" diff --git a/tests/infra/rates.py b/tests/infra/rates.py index 3c7ca9315..5af84b217 100644 --- a/tests/infra/rates.py +++ b/tests/infra/rates.py @@ -55,7 +55,7 @@ class TxRates: def process_next(self): with self.primary.user_client() as client: - rv = client.rpc("getCommit", {}) + rv = client.get("getCommit") result = rv.to_dict() next_commit = result["result"]["commit"] more_to_process = self.commit != next_commit @@ -65,7 +65,7 @@ class TxRates: def get_metrics(self): with self.primary.user_client() as client: - rv = client.rpc("getMetrics", {}) + rv = client.get("getMetrics") result = rv.to_dict() result = result["result"] self.all_metrics = result diff --git a/tests/receipts.py b/tests/receipts.py index 779bbd0d1..75d85d365 100644 --- a/tests/receipts.py +++ b/tests/receipts.py @@ -34,7 +34,7 @@ def test(network, args, notifications_queue=None): check_commit(c.rpc("LOG_record", {"id": 42, "msg": msg}), result=True) r = c.rpc("LOG_get", {"id": 42}) check(r, result={"msg": msg}) - r = c.rpc("getReceipt", {"commit": r.commit}) + r = c.get("getReceipt", {"commit": r.commit}) check( c.rpc("verifyReceipt", {"receipt": r.result["receipt"]}), result={"valid": True}, diff --git a/tests/reconfiguration.py b/tests/reconfiguration.py index 3416d1ede..b6e12e227 100644 --- a/tests/reconfiguration.py +++ b/tests/reconfiguration.py @@ -17,7 +17,7 @@ def check_can_progress(node): with node.node_client() as mc: check_commit = infra.checker.Checker(mc) with node.node_client() as c: - check_commit(c.rpc("mkSign", params={}), result=True) + check_commit(c.rpc("mkSign"), result=True) @reqs.description("Adding a valid node from primary") diff --git a/tests/rekey.py b/tests/rekey.py index 1beaf20aa..a8c27e73a 100644 --- a/tests/rekey.py +++ b/tests/rekey.py @@ -18,7 +18,7 @@ def test(network, args): # Retrieve current index version to check for sealed secrets later with primary.node_client() as nc: check_commit = infra.checker.Checker(nc) - res = nc.rpc("mkSign", params={}) + res = nc.rpc("mkSign") check_commit(res, result=True) version_before_rekey = res.commit diff --git a/tests/schema.py b/tests/schema.py index 3a0b78066..26e8e4556 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -4,6 +4,7 @@ import os import sys import getpass import json +import http import time import logging import multiprocessing @@ -25,14 +26,19 @@ def run(args): methods_without_schema = set() def fetch_schema(client): - list_response = client.rpc("listMethods", {}) - check(list_response) + list_response = client.get("listMethods") + check( + list_response, error=lambda status, msg: status == http.HTTPStatus.OK.value + ) methods = list_response.result["methods"] for method in methods: schema_found = False - schema_response = client.rpc("getSchema", {"method": method}) - check(schema_response) + schema_response = client.get(f"getSchema", params={"method": method}) + check( + schema_response, + error=lambda status, msg: status == http.HTTPStatus.OK.value, + ) if schema_response.result is not None: for schema_type in ["params", "result"]: diff --git a/tests/suite/test_requirements.py b/tests/suite/test_requirements.py index dc2a6ba2b..17da2e85f 100644 --- a/tests/suite/test_requirements.py +++ b/tests/suite/test_requirements.py @@ -50,7 +50,7 @@ def supports_methods(*methods): def check(network, args, *nargs, **kwargs): primary, term = network.find_primary() with primary.user_client() as c: - response = c.rpc("listMethods", {}) + response = c.get("listMethods") supported_methods = response.result["methods"] missing = {*methods}.difference(supported_methods) if missing: