Support for simple (synchronously produced) streaming gRPC messages (#4258)

This commit is contained in:
Eddy Ashton 2022-09-27 17:49:03 +01:00 коммит произвёл GitHub
Родитель efb9e85891
Коммит aacb484aab
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 305 добавлений и 34 удалений

Просмотреть файл

@ -17,12 +17,10 @@ add_ccf_app(
external_executor
SRCS external_executor.cpp
INCLUDE_DIRS "${CMAKE_CURRENT_BINARY_DIR}/protobuf"
LINK_LIBS_ENCLAVE executor_registration.enclave status.enclave
protobuf.enclave
LINK_LIBS_VIRTUAL executor_registration.virtual status.virtual
protobuf.virtual
LINK_LIBS_ENCLAVE kv.enclave status.enclave protobuf.enclave
LINK_LIBS_VIRTUAL kv.virtual status.virtual protobuf.virtual
LINK_LIBS_ENCLAVE executor_registration.enclave kv.enclave status.enclave
protobuf.enclave stringops.enclave
LINK_LIBS_VIRTUAL executor_registration.virtual kv.virtual status.virtual
protobuf.virtual stringops.virtual
)
# Generate an ephemeral signing key

Просмотреть файл

@ -14,6 +14,7 @@
#include "executor_registration.pb.h"
#include "kv.pb.h"
#include "node/endpoint_context_impl.h"
#include "stringops.pb.h"
#define FMT_HEADER_ONLY
#include <fmt/format.h>
@ -246,6 +247,72 @@ namespace externalexecutor
install_registry_service();
install_kv_service();
auto run_string_ops = [this](
ccf::endpoints::CommandEndpointContext& ctx,
std::vector<temp::OpIn>&& payload)
-> ccf::grpc::GrpcAdapterResponse<std::vector<temp::OpOut>> {
std::vector<temp::OpOut> results;
for (temp::OpIn& op : payload)
{
temp::OpOut& result = results.emplace_back();
switch (op.op_case())
{
case (temp::OpIn::OpCase::kEcho):
{
LOG_INFO_FMT("Got kEcho");
auto* echo_op = op.mutable_echo();
auto* echoed = result.mutable_echoed();
echoed->set_allocated_body(echo_op->release_body());
break;
}
case (temp::OpIn::OpCase::kReverse):
{
LOG_INFO_FMT("Got kReverse");
auto* reverse_op = op.mutable_reverse();
std::string* s = reverse_op->release_body();
std::reverse(s->begin(), s->end());
auto* reversed = result.mutable_reversed();
reversed->set_allocated_body(s);
break;
}
case (temp::OpIn::OpCase::kTruncate):
{
LOG_INFO_FMT("Got kTruncate");
auto* truncate_op = op.mutable_truncate();
std::string* s = truncate_op->release_body();
*s = s->substr(
truncate_op->start(),
truncate_op->end() - truncate_op->start());
auto* truncated = result.mutable_truncated();
truncated->set_allocated_body(s);
break;
}
case (temp::OpIn::OpCase::OP_NOT_SET):
{
LOG_INFO_FMT("Got OP_NOT_SET");
// oneof may always be null. If the input OpIn was null, then the
// resulting OpOut is also null
break;
}
}
}
return ccf::grpc::make_success(results);
};
make_command_endpoint(
"/temp.Test/RunOps",
HTTP_POST,
ccf::grpc_command_adapter<
std::vector<temp::OpIn>,
std::vector<temp::OpOut>>(run_string_ops),
ccf::no_auth_required)
.install();
}
};
} // namespace externalexecutor

Просмотреть файл

@ -0,0 +1,43 @@
syntax = "proto3";
package temp;
option optimize_for = LITE_RUNTIME;
service Test
{
rpc RunOps(stream OpIn) returns (stream OpOut) {}
}
message OpIn
{
oneof op
{
EchoOp echo = 1;
ReverseOp reverse = 2;
TruncateOp truncate = 3;
}
}
message EchoOp { string body = 1; }
message ReverseOp { string body = 1; }
message TruncateOp
{
string body = 1;
uint32 start = 2;
uint32 end = 3;
}
message OpOut
{
oneof result
{
EchoResult echoed = 1;
ReverseResult reversed = 2;
TruncateResult truncated = 3;
}
}
message EchoResult { string body = 1; }
message ReverseResult { string body = 1; }
message TruncateResult { string body = 1; }

Просмотреть файл

@ -113,24 +113,59 @@ namespace ccf::grpc
http::headervalues::contenttype::GRPC));
}
auto message_length = impl::read_message_frame(data, size);
if (size != message_length)
if constexpr (nonstd::is_std_vector<In>::value)
{
throw std::logic_error(fmt::format(
"Error in gRPC frame: frame size is {} but messages is {} bytes",
size,
message_length));
}
ctx->set_response_header(
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
using Message = typename In::value_type;
In messages;
while (size != 0)
{
const auto message_length = impl::read_message_frame(data, size);
if (message_length > size)
{
throw std::logic_error(fmt::format(
"Error in gRPC frame: only {} bytes remaining but message header "
"says messages is {} bytes",
size,
message_length));
}
In in;
if (!in.ParseFromArray(data, size))
{
throw std::logic_error(
fmt::format("Error deserialising protobuf payload of size {}", size));
Message& msg = messages.emplace_back();
if (!msg.ParseFromArray(data, message_length))
{
throw std::logic_error(fmt::format(
"Error deserialising protobuf payload of type {}, size {} (message "
"{} in "
"stream)",
msg.GetTypeName(),
size,
messages.size()));
}
data += message_length;
size -= message_length;
}
return messages;
}
else
{
const auto message_length = impl::read_message_frame(data, size);
if (size != message_length)
{
throw std::logic_error(fmt::format(
"Error in gRPC frame: frame size is {} but messages is {} bytes",
size,
message_length));
}
In in;
if (!in.ParseFromArray(data, message_length))
{
throw std::logic_error(fmt::format(
"Error deserialising protobuf payload of type {}, size {}",
in.GetTypeName(),
size));
}
return in;
}
return in;
}
template <typename Out>
@ -141,22 +176,62 @@ namespace ccf::grpc
auto success_response = std::get_if<SuccessResponse<Out>>(&r);
if (success_response != nullptr)
{
const auto& resp = success_response->body;
size_t r_size = impl::message_frame_length + resp.ByteSizeLong();
std::vector<uint8_t> r(r_size);
auto r_data = r.data();
std::vector<uint8_t> r;
impl::write_message_frame(r_data, r_size, resp.ByteSizeLong());
if constexpr (nonstd::is_std_vector<Out>::value)
{
using Message = typename Out::value_type;
const Out& messages = success_response->body;
size_t r_size = std::accumulate(
messages.begin(),
messages.end(),
0,
[](size_t current, const Message& msg) {
return current + impl::message_frame_length + msg.ByteSizeLong();
});
r.resize(r_size);
auto r_data = r.data();
for (const Message& msg : messages)
{
const auto message_length = msg.ByteSizeLong();
impl::write_message_frame(r_data, r_size, message_length);
if (!msg.SerializeToArray(r_data, r_size))
{
throw std::logic_error(fmt::format(
"Error serialising protobuf response of type {}, size {}",
msg.GetTypeName(),
message_length));
}
r_data += message_length;
r_size += message_length;
}
}
else
{
const Out& resp = success_response->body;
const auto message_length = resp.ByteSizeLong();
size_t r_size = impl::message_frame_length + message_length;
r.resize(r_size);
auto r_data = r.data();
impl::write_message_frame(r_data, r_size, message_length);
if (!resp.SerializeToArray(r_data, r_size))
{
throw std::logic_error(fmt::format(
"Error serialising protobuf response of type {}, size {}",
resp.GetTypeName(),
message_length));
}
}
ctx->set_response_body(r);
ctx->set_response_header(
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
if (!resp.SerializeToArray(r_data, r_size))
{
throw std::logic_error(fmt::format(
"Error serialising protobuf response of size {}",
resp.ByteSizeLong()));
}
ctx->set_response_body(r);
ctx->set_response_trailer("grpc-status", success_response->status.code());
ctx->set_response_trailer(
"grpc-message", success_response->status.message());
@ -181,6 +256,10 @@ namespace ccf
using GrpcReadOnlyEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
endpoints::ReadOnlyEndpointContext& ctx, In&& payload)>;
template <typename In, typename Out>
using GrpcCommandEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
endpoints::CommandEndpointContext& ctx, In&& payload)>;
template <typename In, typename Out>
endpoints::EndpointFunction grpc_adapter(const GrpcEndpoint<In, Out>& f)
{
@ -199,4 +278,14 @@ namespace ccf
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
};
}
template <typename In, typename Out>
endpoints::CommandEndpointFunction grpc_command_adapter(
const GrpcCommandEndpoint<In, Out>& f)
{
return [f](endpoints::CommandEndpointContext& ctx) {
grpc::set_grpc_response<Out>(
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
};
}
}

Просмотреть файл

@ -11,6 +11,12 @@ import kv_pb2 as KV
# pylint: disable=import-error
import kv_pb2_grpc as Service
# pylint: disable=import-error
import stringops_pb2 as StringOps
# pylint: disable=import-error
import stringops_pb2_grpc as StringOpsService
# pylint: disable=no-name-in-module
from google.protobuf.empty_pb2 import Empty as Empty
@ -123,6 +129,74 @@ def test_put_get(network, args):
return network
@reqs.description("Test gRPC streaming APIs")
def test_streaming(network, args):
primary, _ = network.find_primary()
credentials = grpc.ssl_channel_credentials(
open(os.path.join(network.common_dir, "service_cert.pem"), "rb").read()
)
def echo_op(s):
return (StringOps.OpIn(echo=StringOps.EchoOp(body=s)), ("echoed", s))
def reverse_op(s):
return (
StringOps.OpIn(reverse=StringOps.ReverseOp(body=s)),
("reversed", s[::-1]),
)
def truncate_op(s):
start = random.randint(0, len(s))
end = random.randint(start, len(s))
return (
StringOps.OpIn(truncate=StringOps.TruncateOp(body=s, start=start, end=end)),
("truncated", s[start:end]),
)
def empty_op(s):
# oneof may always be null - generate some like this to make sure they're handled "correctly"
return (StringOps.OpIn(), None)
def generate_ops(n):
for _ in range(n):
s = f"I'm random string {n}: {random.random()}"
yield random.choice((echo_op, reverse_op, truncate_op, empty_op))(s)
def compare_op_results(stub, n_ops):
LOG.info(f"Sending streaming request containing {n_ops} operations")
ops = []
expected_results = []
for op, expected_result in generate_ops(n_ops):
ops.append(op)
expected_results.append(expected_result)
for actual_result in stub.RunOps(op for op in ops):
assert len(expected_results) > 0, "More responses than requests"
expected_result = expected_results.pop(0)
if expected_result is None:
assert not actual_result.HasField("result"), actual_result
else:
field_name, expected = expected_result
actual = getattr(actual_result, field_name).body
assert (
actual == expected
), f"Wrong {field_name} op: {actual} != {expected}"
assert len(expected_results) == 0, "Fewer responses than requests"
with grpc.secure_channel(
target=f"{primary.get_public_rpc_host()}:{primary.get_public_rpc_port()}",
credentials=credentials,
) as channel:
stub = StringOpsService.TestStub(channel)
compare_op_results(stub, 0)
compare_op_results(stub, 1)
compare_op_results(stub, 30)
compare_op_results(stub, 1000)
def run(args):
with infra.network.network(
args.nodes,
@ -132,12 +206,12 @@ def run(args):
) as network:
network.start_and_open(args)
test_put_get(network, args)
test_streaming(network, args)
if __name__ == "__main__":
args = infra.e2e_args.cli_args()
args.host_log_level = "trace"
args.package = "src/apps/external_executor/libexternal_executor"
args.http2 = True # gRPC interface
args.nodes = infra.e2e_args.min_nodes(args, f=0)