diff --git a/samples/apps/smallbank/tests/small_bank_client.py b/samples/apps/smallbank/tests/small_bank_client.py index 56df4ec88..0db1e1229 100644 --- a/samples/apps/smallbank/tests/small_bank_client.py +++ b/samples/apps/smallbank/tests/small_bank_client.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the Apache 2.0 License. -import client +import perfclient import sys import os @@ -11,7 +11,7 @@ if __name__ == "__main__": "-u", "--accounts", help="Number of accounts", default=10, type=int ) - args, unknown_args = client.cli_args(add=add, accept_unknown=True) + args, unknown_args = perfclient.cli_args(add=add, accept_unknown=True) unknown_args = [term for arg in unknown_args for term in arg.split(" ")] @@ -19,4 +19,4 @@ if __name__ == "__main__": return [*common_args, "--accounts", str(args.accounts)] + unknown_args args.package = "libsmallbankenc" - client.run(args.build_dir, get_command, args) + perfclient.run(args.build_dir, get_command, args) diff --git a/samples/apps/txregulator/clients/poll.py b/samples/apps/txregulator/clients/poll.py index baca71eae..d240d3abd 100644 --- a/samples/apps/txregulator/clients/poll.py +++ b/samples/apps/txregulator/clients/poll.py @@ -8,7 +8,6 @@ import time import csv from loguru import logger as LOG import argparse -from infra.jsonrpc import client import json @@ -33,7 +32,7 @@ def run(args): format="msgpack", cert="user{}_cert.pem".format(args.regulator_name), key="user{}_privk.pem".format(args.regulator_name), - cafile="networkcert.pem", + ca="networkcert.pem", ) as reg_c: with client( host=args.host, @@ -41,7 +40,7 @@ def run(args): format="msgpack", cert="user{}_cert.pem".format(args.bank_name), key="user{}_privk.pem".format(args.bank_name), - cafile="networkcert.pem", + ca="networkcert.pem", ) as c: while True: time.sleep(1) diff --git a/samples/apps/txregulator/tests/txregulatorclient.py b/samples/apps/txregulator/tests/txregulatorclient.py index 0263cf960..25668e4a3 100644 --- a/samples/apps/txregulator/tests/txregulatorclient.py +++ b/samples/apps/txregulator/tests/txregulatorclient.py @@ -2,6 +2,7 @@ # Licensed under the Apache 2.0 License. import e2e_args import infra.ccf +import infra.jsonrpc import logging from time import gmtime, strftime diff --git a/src/enclave/http.h b/src/enclave/http.h index 3c620a2ad..4e10e3b2e 100644 --- a/src/enclave/http.h +++ b/src/enclave/http.h @@ -73,7 +73,9 @@ namespace enclave { auto parsed = http_parser_execute(&parser, &settings, (const char*)data, size); + LOG_TRACE_FMT("Parsed {} bytes", parsed); + auto err = HTTP_PARSER_ERRNO(&parser); if (err) { @@ -83,9 +85,6 @@ namespace enclave http_errno_description(err))); } - LOG_TRACE_FMT( - "Parsed a {} request", http_method_str(http_method(parser.method))); - // TODO: check for http->upgrade to support websockets return parsed; } @@ -230,21 +229,25 @@ namespace enclave LOG_TRACE_FMT("recv called with {} bytes", size); - auto buf = read_all_available(); - - if (buf.size() == 0) - return; - - LOG_TRACE_FMT( - "Going to parse {} bytes: [{}]", - buf.size(), - std::string(buf.begin(), buf.end())); - - // TODO: This should return an error to the client if this fails - if (p.execute(buf.data(), buf.size()) == 0) + while (true) { - LOG_FAIL_FMT("Failed to parse request"); - return; + auto buf = read(4096, false); + if (buf.size() == 0) + { + return; + } + + LOG_TRACE_FMT( + "Going to parse {} bytes: \n[{}]", + buf.size(), + std::string(buf.begin(), buf.end())); + + // TODO: This should return an error to the client if this fails + if (p.execute(buf.data(), buf.size()) == 0) + { + LOG_FAIL_FMT("Failed to parse request"); + return; + } } } diff --git a/src/enclave/tlsendpoint.h b/src/enclave/tlsendpoint.h index 62f954323..1d7af13bc 100644 --- a/src/enclave/tlsendpoint.h +++ b/src/enclave/tlsendpoint.h @@ -193,30 +193,6 @@ namespace enclave return data; } - std::vector read_all_available() - { - constexpr auto read_size = 4096; - auto buf = read(read_size, false); - - if (buf.size() == read_size) - { - while (true) - { - const auto more = read(read_size, false); - - buf.insert(buf.end(), more.begin(), more.end()); - - if (more.size() != read_size) - { - break; - } - } - } - - LOG_TRACE_FMT("read_all_available returning {} bytes", buf.size()); - return buf; - } - void recv(const uint8_t* data, size_t size) { pending_read.insert(pending_read.end(), data, data + size); diff --git a/start_test_network.sh b/start_test_network.sh index dd56f0f9b..e21b00de1 100755 --- a/start_test_network.sh +++ b/start_test_network.sh @@ -18,4 +18,4 @@ source env/bin/activate pip install -q -U -r ../tests/requirements.txt echo "Python environment successfully setup" -python ../tests/start_network.py --package "$1" --label test_network +CURL_CLIENT=ON python ../tests/start_network.py --package "$1" --label test_network \ No newline at end of file diff --git a/tests/connections.py b/tests/connections.py index 63d32bc8a..9d663eedb 100644 --- a/tests/connections.py +++ b/tests/connections.py @@ -13,7 +13,6 @@ import multiprocessing from random import seed import infra.ccf import infra.proc -import infra.jsonrpc import json import contextlib import resource diff --git a/tests/e2e_batched.py b/tests/e2e_batched.py index 73b3d0be2..13f922a19 100644 --- a/tests/e2e_batched.py +++ b/tests/e2e_batched.py @@ -6,7 +6,6 @@ import time import infra.ccf import infra.proc -import infra.jsonrpc import infra.notification import infra.net import suite.test_requirements as reqs diff --git a/tests/e2e_logging.py b/tests/e2e_logging.py index 9a90c8d39..148d93e2a 100644 --- a/tests/e2e_logging.py +++ b/tests/e2e_logging.py @@ -1,17 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the Apache 2.0 License. -import os -import getpass -import time -import logging -import multiprocessing -import shutil -from random import seed import infra.ccf -import infra.proc import infra.jsonrpc import infra.notification -import infra.net import suite.test_requirements as reqs import e2e_args diff --git a/tests/e2e_logging_pbft.py b/tests/e2e_logging_pbft.py index ccb2027fb..75497921f 100644 --- a/tests/e2e_logging_pbft.py +++ b/tests/e2e_logging_pbft.py @@ -9,7 +9,6 @@ import shutil from random import seed import infra.ccf import infra.proc -import infra.jsonrpc import infra.notification import infra.net import e2e_args diff --git a/tests/election.py b/tests/election.py index 1873af785..ba03b256d 100644 --- a/tests/election.py +++ b/tests/election.py @@ -7,7 +7,6 @@ import time import math import infra.ccf import infra.proc -import infra.jsonrpc import e2e_args from loguru import logger as LOG diff --git a/tests/governance.py b/tests/governance.py index 31a194887..ccbad652b 100644 --- a/tests/governance.py +++ b/tests/governance.py @@ -11,7 +11,6 @@ import subprocess from random import seed import infra.ccf import infra.proc -import infra.jsonrpc import infra.notification import infra.net import e2e_args diff --git a/tests/infra/ccf.py b/tests/infra/ccf.py index ea282c286..4dc6983f7 100644 --- a/tests/infra/ccf.py +++ b/tests/infra/ccf.py @@ -7,11 +7,12 @@ import logging from contextlib import contextmanager from glob import glob from enum import Enum -import infra.jsonrpc +import infra.clients import infra.path import infra.proc import infra.node import infra.consortium +import infra.jsonrpc import ssl import random diff --git a/tests/infra/clients.py b/tests/infra/clients.py new file mode 100644 index 000000000..fb79bd659 --- /dev/null +++ b/tests/infra/clients.py @@ -0,0 +1,539 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the Apache 2.0 License. +import socket +import ssl +import msgpack +import struct +import select +import contextlib +import json +import time +import os +import subprocess +import tempfile +import base64 +import requests +from requests_http_signature import HTTPSignatureAuth +from enum import IntEnum +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import asymmetric + +from loguru import logger as LOG + + +def truncate(string, max_len=256): + if len(string) > max_len: + return string[: max_len - 3] + "..." + else: + return string + + +class Request: + def __init__(self, id, method, params, readonly_hint=None, jsonrpc="2.0"): + self.id = id + self.method = method + self.params = params + self.jsonrpc = jsonrpc + self.readonly_hint = readonly_hint + + def to_dict(self): + rpc = { + "id": self.id, + "method": self.method, + "jsonrpc": self.jsonrpc, + "params": self.params, + } + if self.readonly_hint is not None: + rpc["readonly"] = self.readonly_hint + return rpc + + def to_msgpack(self): + return msgpack.packb(self.to_dict(), use_bin_type=True) + + def to_json(self): + return json.dumps(self.to_dict()).encode() + + +class Response: + def __init__( + self, + id, + result=None, + error=None, + commit=None, + term=None, + global_commit=None, + jsonrpc="2.0", + ): + self.id = id + self.result = result + self.error = error + self.jsonrpc = jsonrpc + self.commit = commit + self.term = term + self.global_commit = global_commit + self._attrs = set(locals()) - {"self"} + + def to_dict(self): + d = { + "id": self.id, + "jsonrpc": self.jsonrpc, + "commit": self.commit, + "global_commit": self.global_commit, + "term": self.term, + } + if self.result is not None: + d["result"] = self.result + else: + d["error"] = self.error + return d + + def _from_parsed(self, parsed): + unexpected = parsed.keys() - self._attrs + if unexpected: + raise ValueError("Unexpected keys in response: {}".format(unexpected)) + for attr, value in parsed.items(): + setattr(self, attr, value) + + def from_msgpack(self, data): + parsed = msgpack.unpackb(data, raw=False) + self._from_parsed(parsed) + + def from_json(self, data): + parsed = json.loads(data.decode()) + self._from_parsed(parsed) + + +def human_readable_size(n): + suffixes = ("B", "KB", "MB", "GB") + i = 0 + while n >= 1024 and i < len(suffixes) - 1: + n /= 1024.0 + i += 1 + return f"{n:,.2f} {suffixes[i]}" + + +class FramedTLSClient: + def __init__(self, host, port, cert=None, key=None, ca=None): + self.host = host + self.port = port + self.cert = cert + self.key = key + self.ca = ca + self.context = None + self.sock = None + self.conn = None + + def connect(self): + if self.ca: + self.context = ssl.create_default_context(cafile=self.ca) + + # Auto detect EC curve to use based on server CA + ca_bytes = open(self.ca, "rb").read() + ca_curve = ( + x509.load_pem_x509_certificate(ca_bytes, default_backend()) + .public_key() + .curve + ) + if isinstance(ca_curve, asymmetric.ec.SECP256K1): + self.context.set_ecdh_curve("secp256k1") + else: + self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + if self.cert and self.key: + self.context.load_cert_chain(certfile=self.cert, keyfile=self.key) + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.conn = self.context.wrap_socket( + self.sock, server_side=False, server_hostname=self.host + ) + self.conn.connect((self.host, self.port)) + + def send(self, msg): + LOG.trace(f"Sending {human_readable_size(len(msg))} message") + frame = struct.pack("> Request:" + os.linesep) + json.dump(request.to_dict(), f, indent=2) + f.write(os.linesep) + + def log_response(self, response): + with open(self.path, "a") as f: + f.write("<< Response:" + os.linesep) + json.dump(response.to_dict(), f, indent=2) + f.write(os.linesep) + + +class FramedTLSJSONRPCClient: + def __init__( + self, + host, + port, + cert=None, + key=None, + ca=None, + version="2.0", + format="msgpack", + connection_timeout=3, + *args, + **kwargs, + ): + self.client = FramedTLSClient(host, int(port), cert, key, ca) + self.stream = Stream(version, format=format) + self.format = format + + while connection_timeout >= 0: + try: + self.connect() + break + except ssl.SSLError: + if connection_timeout < 0: + raise + connection_timeout -= 0.1 + time.sleep(0.1) + + def connect(self): + return self.client.connect() + + def disconnect(self): + return self.client.disconnect() + + def request(self, request): + self.client.send(getattr(request, "to_{}".format(self.format))()) + return request.id + + def tick(self): + msg = self.client.read() + self.stream.update(msg) + + def response(self, id): + self.tick() + return self.stream.response(id) + + +# We keep this around in a limited fashion still, because +# the resulting logs nicely illustrate manual usage in a way using requests doesn't +class CurlClient: + def __init__(self, host, port, cert, key, ca, version, format, *args, **kwargs): + self.host = host + self.port = port + self.cert = cert + self.key = key + self.ca = ca + self.format = format + self.stream = Stream(version, self.format) + + def request(self, request): + with tempfile.NamedTemporaryFile() as nf: + msg = getattr(request, "to_{}".format(self.format))() + LOG.debug("Going to send {}".format(msg)) + nf.write(msg) + nf.flush() + cmd = [ + "curl", + f"https://{self.host}:{self.port}/", + "-H", + "Content-Type: application/json", + "--data-binary", + f"@{nf.name}", + ] + if self.ca: + cmd.extend(["--cacert", self.ca]) + if self.key: + cmd.extend(["--key", self.key]) + if self.cert: + cmd.extend(["--cert", self.cert]) + LOG.debug(f"Running: {' '.join(cmd)}") + rc = subprocess.run(cmd, capture_output=True) + LOG.debug(f"Received {rc.stdout}") + if rc.returncode != 0: + LOG.error(rc.stderr) + raise RuntimeError("Curl failed") + self.stream.update(rc.stdout) + return request.id + + # TODO: Untested + def signed_request(self, request): + with tempfile.NamedTemporaryFile() as nf: + msg = getattr(request, "to_{}".format(self.format))() + LOG.debug("Going to send {}".format(msg)) + nf.write(msg) + nf.flush() + dgst = subprocess.run( + ["openssl", "dgst", "-sha256", "-sign", self.key, nf.name], + check=True, + capture_output=True, + ) + subprocess.run(["cat", nf.name], check=True) + cmd = [ + "curl", + "-v", + f"https://{self.host}:{self.port}/", + "-H", + "Content-Type: application/json", + "-H", + f"Authorize: {base64.b64encode(dgst.stdout).decode()}", + "--data-binary", + f"@{nf.name}", + ] + if self.ca: + cmd.extend(["--cacert", self.ca]) + if self.key: + cmd.extend(["--key", self.key]) + if self.cert: + cmd.extend(["--cert", self.cert]) + LOG.debug(f"Running: {' '.join(cmd)}") + rc = subprocess.run(cmd, capture_output=True) + LOG.debug(f"Received {rc.stdout.decode()}") + if rc.returncode != 0: + LOG.debug(f"ERR {rc.stderr.decode()}") + self.stream.update(rc.stdout) + return request.id + + def response(self, id): + return self.stream.response(id) + + def disconnect(self): + pass + + +class RequestClient: + def __init__( + self, + host, + port, + cert, + key, + ca, + version, + format, + connection_timeout, + request_timeout, + ): + self.host = host + self.port = port + self.cert = cert + self.key = key + self.ca = ca + self.stream = Stream(version, "json") + self.request_timeout = request_timeout + + def request(self, request): + rep = requests.post( + f"https://{self.host}:{self.port}/", + json=request.to_dict(), # TODO: For REST queries, use data= instead + cert=(self.cert, self.key), + verify=self.ca, + timeout=self.request_timeout, + ) + self.stream.update(rep.content) + return request.id + + def signed_request(self, request): + with open(self.key, "rb") as k: + rep = requests.post( + f"https://{self.host}:{self.port}/", + json=request.to_dict(), # TODO: For REST queries, use data= instead + cert=(self.cert, self.key), + verify=self.ca, + timeout=self.request_timeout, + # key_id needs to be specified but is unused + auth=HTTPSignatureAuth( + algorithm="ecdsa-sha256", key=k.read(), key_id="tls" + ), + ) + self.stream.update(rep.content) + return request.id + + def response(self, id): + return self.stream.response(id) + + def disconnect(self): + pass + + +class CCFClient: + def __init__(self, *args, **kwargs): + self.prefix = kwargs.pop("prefix") + self.description = kwargs.pop("description") + self.rpc_loggers = (RPCLogger(),) + self.name = "[{}:{}]".format(kwargs.get("host"), kwargs.get("port")) + + if os.getenv("HTTP"): + if os.getenv("CURL_CLIENT"): + self.client_impl = CurlClient(*args, **kwargs) + else: + self.client_impl = RequestClient(*args, **kwargs) + else: + self.client_impl = FramedTLSJSONRPCClient(*args, **kwargs) + + def disconnect(self): + self.client_impl.disconnect() + + def request(self, method, params, *args, **kwargs): + r = self.client_impl.stream.request( + f"{self.prefix}/{method}", params, *args, **kwargs + ) + if self.description: + description = " ({})".format(self.description) + for logger in self.rpc_loggers: + logger.log_request(r, self.name, description) + + self.client_impl.request(r) + return r.id + + def signed_request(self, method, params, *args, **kwargs): + r = self.client_impl.stream.request( + f"{self.prefix}/{method}", params, *args, **kwargs + ) + if self.description: + description = " ({}) [signed]".format(self.description) + for logger in self.rpc_loggers: + logger.log_request(r, self.name, description) + + return self.client_impl.signed_request(r) + + def response(self, id): + r = self.client_impl.response(id) + for logger in self.rpc_loggers: + logger.log_response(r) + return r + + def do(self, *args, **kwargs): + expected_result = None + expected_error_code = None + if "expected_result" in kwargs: + expected_result = kwargs.pop("expected_result") + if "expected_error_code" in kwargs: + expected_error_code = kwargs.pop("expected_error_code") + + id = self.request(*args, **kwargs) + r = self.response(id) + + if expected_result is not None: + assert expected_result == r.result + + if expected_error_code is not None: + assert expected_error_code == r.error["code"] + return r + + def rpc(self, *args, **kwargs): + if "signed" in kwargs and kwargs.pop("signed"): + id = self.signed_request(*args, **kwargs) + else: + id = self.request(*args, **kwargs) + return self.response(id) + + +@contextlib.contextmanager +def client( + host, + port, + cert=None, + key=None, + ca=None, + version="2.0", + format="json" if os.getenv("HTTP") else "msgpack", + description=None, + log_file=None, + prefix="users", + connection_timeout=3, + request_timeout=3, +): + c = CCFClient( + host=host, + port=port, + cert=cert, + key=key, + ca=ca, + version=version, + format=format, + description=description, + prefix=prefix, + connection_timeout=connection_timeout, + request_timeout=request_timeout, + ) + + if log_file is not None: + c.rpc_loggers += (RPCFileLogger(log_file),) + + try: + yield c + finally: + c.disconnect() diff --git a/tests/infra/jsonrpc.py b/tests/infra/jsonrpc.py index 672c8585a..e32aca230 100644 --- a/tests/infra/jsonrpc.py +++ b/tests/infra/jsonrpc.py @@ -1,25 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the Apache 2.0 License. -import socket -import ssl -import msgpack -import struct -import select -import contextlib -import json -import logging -import time -import os -import subprocess -import tempfile -import base64 -import requests -from enum import IntEnum -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import asymmetric -from loguru import logger as LOG +from enum import IntEnum # Values defined in node/rpc/jsonrpc.h class ErrorCode(IntEnum): @@ -45,452 +27,3 @@ class ErrorCode(IntEnum): RPC_NOT_FORWARDED = -32011 QUOTE_NOT_VERIFIED = -32012 SERVER_ERROR_END = -32099 - - -def truncate(string, max_len=256): - if len(string) > max_len: - return string[: max_len - 3] + "..." - else: - return string - - -class Request: - def __init__(self, id, method, params, readonly_hint=None, jsonrpc="2.0"): - self.id = id - self.method = method - self.params = params - self.jsonrpc = jsonrpc - self.readonly_hint = readonly_hint - - def to_dict(self): - rpc = { - "id": self.id, - "method": self.method, - "jsonrpc": self.jsonrpc, - "params": self.params, - } - if self.readonly_hint is not None: - rpc["readonly"] = self.readonly_hint - return rpc - - def to_msgpack(self): - return msgpack.packb(self.to_dict(), use_bin_type=True) - - def to_json(self): - return json.dumps(self.to_dict()).encode() - - -class Response: - def __init__( - self, - id, - result=None, - error=None, - commit=None, - term=None, - global_commit=None, - jsonrpc="2.0", - ): - self.id = id - self.result = result - self.error = error - self.jsonrpc = jsonrpc - self.commit = commit - self.term = term - self.global_commit = global_commit - self._attrs = set(locals()) - {"self"} - - def to_dict(self): - d = { - "id": self.id, - "jsonrpc": self.jsonrpc, - "commit": self.commit, - "global_commit": self.global_commit, - "term": self.term, - } - if self.result is not None: - d["result"] = self.result - else: - d["error"] = self.error - return d - - def _from_parsed(self, parsed): - unexpected = parsed.keys() - self._attrs - if unexpected: - raise ValueError("Unexpected keys in response: {}".format(unexpected)) - for attr, value in parsed.items(): - setattr(self, attr, value) - - def from_msgpack(self, data): - parsed = msgpack.unpackb(data, raw=False) - self._from_parsed(parsed) - - def from_json(self, data): - parsed = json.loads(data.decode()) - self._from_parsed(parsed) - - -def human_readable_size(n): - suffixes = ("B", "KB", "MB", "GB") - i = 0 - while n >= 1024 and i < len(suffixes) - 1: - n /= 1024.0 - i += 1 - return f"{n:,.2f} {suffixes[i]}" - - -class FramedTLSClient: - def __init__(self, host, port, cert=None, key=None, cafile=None): - self.host = host - self.port = port - self.cert = cert - self.key = key - self.cafile = cafile - self.context = None - self.sock = None - self.conn = None - - def connect(self): - if self.cafile: - self.context = ssl.create_default_context(cafile=self.cafile) - - # Auto detect EC curve to use based on server CA - ca_bytes = open(self.cafile, "rb").read() - ca_curve = ( - x509.load_pem_x509_certificate(ca_bytes, default_backend()) - .public_key() - .curve - ) - if isinstance(ca_curve, asymmetric.ec.SECP256K1): - self.context.set_ecdh_curve("secp256k1") - else: - self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - if self.cert and self.key: - self.context.load_cert_chain(certfile=self.cert, keyfile=self.key) - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.conn = self.context.wrap_socket( - self.sock, server_side=False, server_hostname=self.host - ) - self.conn.connect((self.host, self.port)) - - def send(self, msg): - LOG.trace(f"Sending {human_readable_size(len(msg))} message") - frame = struct.pack("> Request:" + os.linesep) - json.dump(request.to_dict(), f, indent=2) - f.write(os.linesep) - - def log_response(self, response): - with open(self.path, "a") as f: - f.write("<< Response:" + os.linesep) - json.dump(response.to_dict(), f, indent=2) - f.write(os.linesep) - - -class FramedTLSJSONRPCClient: - def __init__( - self, - host, - port, - cert=None, - key=None, - cafile=None, - version="2.0", - format="msgpack", - description=None, - prefix="users", - ): - self.client = FramedTLSClient(host, int(port), cert, key, cafile) - self.stream = Stream(version, format=format) - self.format = format - self.name = "[{}:{}]".format(host, port) - self.description = description - self.rpc_loggers = (RPCLogger(),) - self.prefix = prefix - - def connect(self): - return self.client.connect() - - def disconnect(self): - return self.client.disconnect() - - def request(self, method, params, *args, **kwargs): - r = self.stream.request(f"{self.prefix}/{method}", params, *args, **kwargs) - self.client.send(getattr(r, "to_{}".format(self.format))()) - description = "" - if self.description: - description = " ({})".format(self.description) - for logger in self.rpc_loggers: - logger.log_request(r, self.name, description) - return r.id - - def tick(self): - msg = self.client.read() - self.stream.update(msg) - - def response(self, id): - self.tick() - r = self.stream.response(id) - for logger in self.rpc_loggers: - logger.log_response(r) - return r - - def do(self, *args, **kwargs): - expected_result = None - expected_error_code = None - if "expected_result" in kwargs: - expected_result = kwargs.pop("expected_result") - if "expected_error_code" in kwargs: - expected_error_code = kwargs.pop("expected_error_code") - - id = self.request(*args, **kwargs) - r = self.response(id) - - if expected_result is not None: - assert expected_result == r.result - - if expected_error_code is not None: - assert expected_error_code == r.error["code"] - return r - - def rpc(self, *args, **kwargs): - id = self.request(*args, **kwargs) - return self.response(id) - - -# We use curl for now because we still use SNI to route to frontends -# and that's difficult to force in Python clients, whereas curl conveniently -# exposes --resolver -# We probably will keep this around in a limited fashion later still, because -# the resulting logs nicely illustrate manual usage in a way using requests doesn't -class CurlClient: - def __init__( - self, host, port, cert, key, cafile, version, format, description, prefix, - ): - self.host = host - self.port = port - self.cert = cert - self.key = key - self.cafile = cafile - self.version = version - self.format = format - self.stream = Stream(version, format=format) - self.pending = {} - self.prefix = prefix - - def signed_request(self, method, params): - r = self.stream.request(f"{self.prefix}/{method}", params) - with tempfile.NamedTemporaryFile() as nf: - msg = getattr(r, "to_{}".format(self.format))() - LOG.debug("Going to send {}".format(msg)) - nf.write(msg) - nf.flush() - dgst = subprocess.run( - ["openssl", "dgst", "-sha256", "-sign", self.key, nf.name], - check=True, - capture_output=True, - ) - subprocess.run(["cat", nf.name], check=True) - cmd = [ - "curl", - "-v", - f"https://{self.host}:{self.port}/", - "-H", - "Content-Type: application/json", - "-H", - f"Authorize: {base64.b64encode(dgst.stdout).decode()}", - "--data-binary", - f"@{nf.name}", - ] - if self.cafile: - cmd.extend(["--cacert", self.cafile]) - if self.key: - cmd.extend(["--key", self.key]) - if self.cert: - cmd.extend(["--cert", self.cert]) - LOG.debug(f"Running: {' '.join(cmd)}") - rc = subprocess.run(cmd, capture_output=True) - LOG.debug(f"Received {rc.stdout.decode()}") - if rc.returncode != 0: - LOG.debug(f"ERR {rc.stderr.decode()}") - self.stream.update(rc.stdout) - return r.id - - def request(self, method, params, *args, **kwargs): - r = self.stream.request(f"{self.prefix}/{method}", params, *args, **kwargs) - with tempfile.NamedTemporaryFile() as nf: - msg = getattr(r, "to_{}".format(self.format))() - LOG.debug("Going to send {}".format(msg)) - nf.write(msg) - nf.flush() - cmd = [ - "curl", - f"https://{self.host}:{self.port}/", - "-H", - "Content-Type: application/json", - "--data-binary", - f"@{nf.name}", - ] - if self.cafile: - cmd.extend(["--cacert", self.cafile]) - if self.key: - cmd.extend(["--key", self.key]) - if self.cert: - cmd.extend(["--cert", self.cert]) - LOG.debug(f"Running: {' '.join(cmd)}") - rc = subprocess.run(cmd, capture_output=True) - LOG.debug(f"Received {rc.stdout}") - if rc.returncode != 0: - LOG.error(rc.stderr) - raise RuntimeError("Curl failed") - self.stream.update(rc.stdout) - return r.id - - def response(self, id): - return self.stream.response(id) - - def do(self, *args, **kwargs): - expected_result = None - expected_error_code = None - if "expected_result" in kwargs: - expected_result = kwargs.pop("expected_result") - if "expected_error_code" in kwargs: - expected_error_code = kwargs.pop("expected_error_code") - - id = self.request(*args, **kwargs) - r = self.response(id) - - if expected_result is not None: - assert expected_result == r.result - - if expected_error_code is not None: - assert expected_error_code == r.error["code"] - return r - - def rpc(self, *args, **kwargs): - if "signed" in kwargs and kwargs.pop("signed"): - id = self.signed_request(*args, **kwargs) - return self.response(id) - else: - id = self.request(*args, **kwargs) - return self.response(id) - - -@contextlib.contextmanager -def client( - host, - port, - cert=None, - key=None, - cafile=None, - version="2.0", - format="json" if os.getenv("HTTP") else "msgpack", - description=None, - log_file=None, - connection_timeout=3, - prefix="users", -): - if os.getenv("HTTP"): - c = CurlClient( - host, port, cert, key, cafile, version, "json", description, prefix, - ) - yield c - else: - c = FramedTLSJSONRPCClient( - host, port, cert, key, cafile, version, format, description, prefix, - ) - - if log_file is not None: - c.rpc_loggers += (RPCFileLogger(log_file),) - - while connection_timeout >= 0: - try: - c.connect() - break - except ssl.SSLError: - if connection_timeout < 0: - raise - connection_timeout -= 0.1 - time.sleep(0.1) - try: - yield c - finally: - c.disconnect() diff --git a/tests/infra/node.py b/tests/infra/node.py index 782c5b22f..dd0458942 100644 --- a/tests/infra/node.py +++ b/tests/infra/node.py @@ -6,7 +6,7 @@ from enum import Enum import infra.remote import infra.net import infra.path -import infra.jsonrpc +import infra.clients import time from loguru import logger as LOG @@ -188,12 +188,12 @@ class Node: return self.remote.get_sealed_secrets() def user_client(self, format="msgpack", user_id=1, **kwargs): - return infra.jsonrpc.client( + return infra.clients.client( self.host, self.rpc_port, cert="user{}_cert.pem".format(user_id), key="user{}_privk.pem".format(user_id), - cafile="networkcert.pem", + ca="networkcert.pem", description="node {} (user)".format(self.node_id), format=format, prefix="users", @@ -201,12 +201,12 @@ class Node: ) def node_client(self, format="msgpack", timeout=3, **kwargs): - return infra.jsonrpc.client( + return infra.clients.client( self.host, self.rpc_port, cert=None, key=None, - cafile="networkcert.pem", + ca="networkcert.pem", description="node {} (node)".format(self.node_id), format=format, prefix="nodes", @@ -214,12 +214,12 @@ class Node: ) def member_client(self, format="msgpack", member_id=1, **kwargs): - return infra.jsonrpc.client( + return infra.clients.client( self.host, self.rpc_port, cert="member{}_cert.pem".format(member_id), key="member{}_privk.pem".format(member_id), - cafile="networkcert.pem", + ca="networkcert.pem", description="node {} (member)".format(self.node_id), format=format, prefix="members", diff --git a/tests/infra/runner.py b/tests/infra/runner.py index 9afe53281..cbe7353ea 100644 --- a/tests/infra/runner.py +++ b/tests/infra/runner.py @@ -9,7 +9,6 @@ from random import seed import infra.ccf import infra.proc import infra.remote_client -import infra.jsonrpc import infra.rates import os import re diff --git a/tests/node_suspension.py b/tests/node_suspension.py index 0769cce7d..9a4731d79 100644 --- a/tests/node_suspension.py +++ b/tests/node_suspension.py @@ -9,7 +9,6 @@ import shutil from random import seed import infra.ccf import infra.proc -import infra.jsonrpc import infra.notification import infra.net import e2e_args diff --git a/tests/client.py b/tests/perfclient.py similarity index 100% rename from tests/client.py rename to tests/perfclient.py diff --git a/tests/receipts.py b/tests/receipts.py index 5c1700ce1..23f20826d 100644 --- a/tests/receipts.py +++ b/tests/receipts.py @@ -9,7 +9,6 @@ import shutil from random import seed import infra.ccf import infra.proc -import infra.jsonrpc import infra.notification import infra.net import suite.test_requirements as reqs diff --git a/tests/recovery.py b/tests/recovery.py index 914ad4f54..c55babcf7 100644 --- a/tests/recovery.py +++ b/tests/recovery.py @@ -10,7 +10,6 @@ import e2e_args from random import seed import infra.ccf import infra.proc -import infra.jsonrpc import infra.remote import json import suite.test_requirements as reqs diff --git a/tests/requirements.txt b/tests/requirements.txt index e5e6f217f..54450ea69 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -5,4 +5,5 @@ msgpack loguru coincurve psutil -cimetrics>=0.2.1 \ No newline at end of file +cimetrics>=0.2.1 +requests-http-signature \ No newline at end of file diff --git a/tests/schema.py b/tests/schema.py index 707f3b827..077544de8 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -11,7 +11,6 @@ import shutil import random import infra.ccf import infra.proc -import infra.jsonrpc import e2e_args from loguru import logger as LOG