зеркало из https://github.com/microsoft/CCF.git
Support for simple (synchronously produced) streaming gRPC messages (#4258)
This commit is contained in:
Родитель
efb9e85891
Коммит
aacb484aab
|
@ -17,12 +17,10 @@ add_ccf_app(
|
|||
external_executor
|
||||
SRCS external_executor.cpp
|
||||
INCLUDE_DIRS "${CMAKE_CURRENT_BINARY_DIR}/protobuf"
|
||||
LINK_LIBS_ENCLAVE executor_registration.enclave status.enclave
|
||||
protobuf.enclave
|
||||
LINK_LIBS_VIRTUAL executor_registration.virtual status.virtual
|
||||
protobuf.virtual
|
||||
LINK_LIBS_ENCLAVE kv.enclave status.enclave protobuf.enclave
|
||||
LINK_LIBS_VIRTUAL kv.virtual status.virtual protobuf.virtual
|
||||
LINK_LIBS_ENCLAVE executor_registration.enclave kv.enclave status.enclave
|
||||
protobuf.enclave stringops.enclave
|
||||
LINK_LIBS_VIRTUAL executor_registration.virtual kv.virtual status.virtual
|
||||
protobuf.virtual stringops.virtual
|
||||
)
|
||||
|
||||
# Generate an ephemeral signing key
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "executor_registration.pb.h"
|
||||
#include "kv.pb.h"
|
||||
#include "node/endpoint_context_impl.h"
|
||||
#include "stringops.pb.h"
|
||||
|
||||
#define FMT_HEADER_ONLY
|
||||
#include <fmt/format.h>
|
||||
|
@ -246,6 +247,72 @@ namespace externalexecutor
|
|||
install_registry_service();
|
||||
|
||||
install_kv_service();
|
||||
|
||||
auto run_string_ops = [this](
|
||||
ccf::endpoints::CommandEndpointContext& ctx,
|
||||
std::vector<temp::OpIn>&& payload)
|
||||
-> ccf::grpc::GrpcAdapterResponse<std::vector<temp::OpOut>> {
|
||||
std::vector<temp::OpOut> results;
|
||||
|
||||
for (temp::OpIn& op : payload)
|
||||
{
|
||||
temp::OpOut& result = results.emplace_back();
|
||||
switch (op.op_case())
|
||||
{
|
||||
case (temp::OpIn::OpCase::kEcho):
|
||||
{
|
||||
LOG_INFO_FMT("Got kEcho");
|
||||
auto* echo_op = op.mutable_echo();
|
||||
auto* echoed = result.mutable_echoed();
|
||||
echoed->set_allocated_body(echo_op->release_body());
|
||||
break;
|
||||
}
|
||||
|
||||
case (temp::OpIn::OpCase::kReverse):
|
||||
{
|
||||
LOG_INFO_FMT("Got kReverse");
|
||||
auto* reverse_op = op.mutable_reverse();
|
||||
std::string* s = reverse_op->release_body();
|
||||
std::reverse(s->begin(), s->end());
|
||||
auto* reversed = result.mutable_reversed();
|
||||
reversed->set_allocated_body(s);
|
||||
break;
|
||||
}
|
||||
|
||||
case (temp::OpIn::OpCase::kTruncate):
|
||||
{
|
||||
LOG_INFO_FMT("Got kTruncate");
|
||||
auto* truncate_op = op.mutable_truncate();
|
||||
std::string* s = truncate_op->release_body();
|
||||
*s = s->substr(
|
||||
truncate_op->start(),
|
||||
truncate_op->end() - truncate_op->start());
|
||||
auto* truncated = result.mutable_truncated();
|
||||
truncated->set_allocated_body(s);
|
||||
break;
|
||||
}
|
||||
|
||||
case (temp::OpIn::OpCase::OP_NOT_SET):
|
||||
{
|
||||
LOG_INFO_FMT("Got OP_NOT_SET");
|
||||
// oneof may always be null. If the input OpIn was null, then the
|
||||
// resulting OpOut is also null
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ccf::grpc::make_success(results);
|
||||
};
|
||||
|
||||
make_command_endpoint(
|
||||
"/temp.Test/RunOps",
|
||||
HTTP_POST,
|
||||
ccf::grpc_command_adapter<
|
||||
std::vector<temp::OpIn>,
|
||||
std::vector<temp::OpOut>>(run_string_ops),
|
||||
ccf::no_auth_required)
|
||||
.install();
|
||||
}
|
||||
};
|
||||
} // namespace externalexecutor
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package temp;
|
||||
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
||||
service Test
|
||||
{
|
||||
rpc RunOps(stream OpIn) returns (stream OpOut) {}
|
||||
}
|
||||
|
||||
message OpIn
|
||||
{
|
||||
oneof op
|
||||
{
|
||||
EchoOp echo = 1;
|
||||
ReverseOp reverse = 2;
|
||||
TruncateOp truncate = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message EchoOp { string body = 1; }
|
||||
message ReverseOp { string body = 1; }
|
||||
message TruncateOp
|
||||
{
|
||||
string body = 1;
|
||||
uint32 start = 2;
|
||||
uint32 end = 3;
|
||||
}
|
||||
|
||||
message OpOut
|
||||
{
|
||||
oneof result
|
||||
{
|
||||
EchoResult echoed = 1;
|
||||
ReverseResult reversed = 2;
|
||||
TruncateResult truncated = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message EchoResult { string body = 1; }
|
||||
message ReverseResult { string body = 1; }
|
||||
message TruncateResult { string body = 1; }
|
|
@ -113,24 +113,59 @@ namespace ccf::grpc
|
|||
http::headervalues::contenttype::GRPC));
|
||||
}
|
||||
|
||||
auto message_length = impl::read_message_frame(data, size);
|
||||
if (size != message_length)
|
||||
if constexpr (nonstd::is_std_vector<In>::value)
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error in gRPC frame: frame size is {} but messages is {} bytes",
|
||||
size,
|
||||
message_length));
|
||||
}
|
||||
ctx->set_response_header(
|
||||
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
|
||||
using Message = typename In::value_type;
|
||||
In messages;
|
||||
while (size != 0)
|
||||
{
|
||||
const auto message_length = impl::read_message_frame(data, size);
|
||||
if (message_length > size)
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error in gRPC frame: only {} bytes remaining but message header "
|
||||
"says messages is {} bytes",
|
||||
size,
|
||||
message_length));
|
||||
}
|
||||
|
||||
In in;
|
||||
if (!in.ParseFromArray(data, size))
|
||||
{
|
||||
throw std::logic_error(
|
||||
fmt::format("Error deserialising protobuf payload of size {}", size));
|
||||
Message& msg = messages.emplace_back();
|
||||
if (!msg.ParseFromArray(data, message_length))
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error deserialising protobuf payload of type {}, size {} (message "
|
||||
"{} in "
|
||||
"stream)",
|
||||
msg.GetTypeName(),
|
||||
size,
|
||||
messages.size()));
|
||||
}
|
||||
data += message_length;
|
||||
size -= message_length;
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto message_length = impl::read_message_frame(data, size);
|
||||
if (size != message_length)
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error in gRPC frame: frame size is {} but messages is {} bytes",
|
||||
size,
|
||||
message_length));
|
||||
}
|
||||
|
||||
In in;
|
||||
if (!in.ParseFromArray(data, message_length))
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error deserialising protobuf payload of type {}, size {}",
|
||||
in.GetTypeName(),
|
||||
size));
|
||||
}
|
||||
return in;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
template <typename Out>
|
||||
|
@ -141,22 +176,62 @@ namespace ccf::grpc
|
|||
auto success_response = std::get_if<SuccessResponse<Out>>(&r);
|
||||
if (success_response != nullptr)
|
||||
{
|
||||
const auto& resp = success_response->body;
|
||||
size_t r_size = impl::message_frame_length + resp.ByteSizeLong();
|
||||
std::vector<uint8_t> r(r_size);
|
||||
auto r_data = r.data();
|
||||
std::vector<uint8_t> r;
|
||||
|
||||
impl::write_message_frame(r_data, r_size, resp.ByteSizeLong());
|
||||
if constexpr (nonstd::is_std_vector<Out>::value)
|
||||
{
|
||||
using Message = typename Out::value_type;
|
||||
const Out& messages = success_response->body;
|
||||
size_t r_size = std::accumulate(
|
||||
messages.begin(),
|
||||
messages.end(),
|
||||
0,
|
||||
[](size_t current, const Message& msg) {
|
||||
return current + impl::message_frame_length + msg.ByteSizeLong();
|
||||
});
|
||||
r.resize(r_size);
|
||||
auto r_data = r.data();
|
||||
|
||||
for (const Message& msg : messages)
|
||||
{
|
||||
const auto message_length = msg.ByteSizeLong();
|
||||
impl::write_message_frame(r_data, r_size, message_length);
|
||||
|
||||
if (!msg.SerializeToArray(r_data, r_size))
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error serialising protobuf response of type {}, size {}",
|
||||
msg.GetTypeName(),
|
||||
message_length));
|
||||
}
|
||||
|
||||
r_data += message_length;
|
||||
r_size += message_length;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const Out& resp = success_response->body;
|
||||
const auto message_length = resp.ByteSizeLong();
|
||||
size_t r_size = impl::message_frame_length + message_length;
|
||||
r.resize(r_size);
|
||||
auto r_data = r.data();
|
||||
|
||||
impl::write_message_frame(r_data, r_size, message_length);
|
||||
|
||||
if (!resp.SerializeToArray(r_data, r_size))
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error serialising protobuf response of type {}, size {}",
|
||||
resp.GetTypeName(),
|
||||
message_length));
|
||||
}
|
||||
}
|
||||
|
||||
ctx->set_response_body(r);
|
||||
ctx->set_response_header(
|
||||
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
|
||||
|
||||
if (!resp.SerializeToArray(r_data, r_size))
|
||||
{
|
||||
throw std::logic_error(fmt::format(
|
||||
"Error serialising protobuf response of size {}",
|
||||
resp.ByteSizeLong()));
|
||||
}
|
||||
ctx->set_response_body(r);
|
||||
ctx->set_response_trailer("grpc-status", success_response->status.code());
|
||||
ctx->set_response_trailer(
|
||||
"grpc-message", success_response->status.message());
|
||||
|
@ -181,6 +256,10 @@ namespace ccf
|
|||
using GrpcReadOnlyEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
|
||||
endpoints::ReadOnlyEndpointContext& ctx, In&& payload)>;
|
||||
|
||||
template <typename In, typename Out>
|
||||
using GrpcCommandEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
|
||||
endpoints::CommandEndpointContext& ctx, In&& payload)>;
|
||||
|
||||
template <typename In, typename Out>
|
||||
endpoints::EndpointFunction grpc_adapter(const GrpcEndpoint<In, Out>& f)
|
||||
{
|
||||
|
@ -199,4 +278,14 @@ namespace ccf
|
|||
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename In, typename Out>
|
||||
endpoints::CommandEndpointFunction grpc_command_adapter(
|
||||
const GrpcCommandEndpoint<In, Out>& f)
|
||||
{
|
||||
return [f](endpoints::CommandEndpointContext& ctx) {
|
||||
grpc::set_grpc_response<Out>(
|
||||
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,12 @@ import kv_pb2 as KV
|
|||
# pylint: disable=import-error
|
||||
import kv_pb2_grpc as Service
|
||||
|
||||
# pylint: disable=import-error
|
||||
import stringops_pb2 as StringOps
|
||||
|
||||
# pylint: disable=import-error
|
||||
import stringops_pb2_grpc as StringOpsService
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from google.protobuf.empty_pb2 import Empty as Empty
|
||||
|
||||
|
@ -123,6 +129,74 @@ def test_put_get(network, args):
|
|||
return network
|
||||
|
||||
|
||||
@reqs.description("Test gRPC streaming APIs")
|
||||
def test_streaming(network, args):
|
||||
primary, _ = network.find_primary()
|
||||
|
||||
credentials = grpc.ssl_channel_credentials(
|
||||
open(os.path.join(network.common_dir, "service_cert.pem"), "rb").read()
|
||||
)
|
||||
|
||||
def echo_op(s):
|
||||
return (StringOps.OpIn(echo=StringOps.EchoOp(body=s)), ("echoed", s))
|
||||
|
||||
def reverse_op(s):
|
||||
return (
|
||||
StringOps.OpIn(reverse=StringOps.ReverseOp(body=s)),
|
||||
("reversed", s[::-1]),
|
||||
)
|
||||
|
||||
def truncate_op(s):
|
||||
start = random.randint(0, len(s))
|
||||
end = random.randint(start, len(s))
|
||||
return (
|
||||
StringOps.OpIn(truncate=StringOps.TruncateOp(body=s, start=start, end=end)),
|
||||
("truncated", s[start:end]),
|
||||
)
|
||||
|
||||
def empty_op(s):
|
||||
# oneof may always be null - generate some like this to make sure they're handled "correctly"
|
||||
return (StringOps.OpIn(), None)
|
||||
|
||||
def generate_ops(n):
|
||||
for _ in range(n):
|
||||
s = f"I'm random string {n}: {random.random()}"
|
||||
yield random.choice((echo_op, reverse_op, truncate_op, empty_op))(s)
|
||||
|
||||
def compare_op_results(stub, n_ops):
|
||||
LOG.info(f"Sending streaming request containing {n_ops} operations")
|
||||
ops = []
|
||||
expected_results = []
|
||||
for op, expected_result in generate_ops(n_ops):
|
||||
ops.append(op)
|
||||
expected_results.append(expected_result)
|
||||
|
||||
for actual_result in stub.RunOps(op for op in ops):
|
||||
assert len(expected_results) > 0, "More responses than requests"
|
||||
expected_result = expected_results.pop(0)
|
||||
if expected_result is None:
|
||||
assert not actual_result.HasField("result"), actual_result
|
||||
else:
|
||||
field_name, expected = expected_result
|
||||
actual = getattr(actual_result, field_name).body
|
||||
assert (
|
||||
actual == expected
|
||||
), f"Wrong {field_name} op: {actual} != {expected}"
|
||||
|
||||
assert len(expected_results) == 0, "Fewer responses than requests"
|
||||
|
||||
with grpc.secure_channel(
|
||||
target=f"{primary.get_public_rpc_host()}:{primary.get_public_rpc_port()}",
|
||||
credentials=credentials,
|
||||
) as channel:
|
||||
stub = StringOpsService.TestStub(channel)
|
||||
|
||||
compare_op_results(stub, 0)
|
||||
compare_op_results(stub, 1)
|
||||
compare_op_results(stub, 30)
|
||||
compare_op_results(stub, 1000)
|
||||
|
||||
|
||||
def run(args):
|
||||
with infra.network.network(
|
||||
args.nodes,
|
||||
|
@ -132,12 +206,12 @@ def run(args):
|
|||
) as network:
|
||||
network.start_and_open(args)
|
||||
test_put_get(network, args)
|
||||
test_streaming(network, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = infra.e2e_args.cli_args()
|
||||
|
||||
args.host_log_level = "trace"
|
||||
args.package = "src/apps/external_executor/libexternal_executor"
|
||||
args.http2 = True # gRPC interface
|
||||
args.nodes = infra.e2e_args.min_nodes(args, f=0)
|
||||
|
|
Загрузка…
Ссылка в новой задаче