зеркало из https://github.com/microsoft/CCF.git
Support for simple (synchronously produced) streaming gRPC messages (#4258)
This commit is contained in:
Родитель
efb9e85891
Коммит
aacb484aab
|
@ -17,12 +17,10 @@ add_ccf_app(
|
||||||
external_executor
|
external_executor
|
||||||
SRCS external_executor.cpp
|
SRCS external_executor.cpp
|
||||||
INCLUDE_DIRS "${CMAKE_CURRENT_BINARY_DIR}/protobuf"
|
INCLUDE_DIRS "${CMAKE_CURRENT_BINARY_DIR}/protobuf"
|
||||||
LINK_LIBS_ENCLAVE executor_registration.enclave status.enclave
|
LINK_LIBS_ENCLAVE executor_registration.enclave kv.enclave status.enclave
|
||||||
protobuf.enclave
|
protobuf.enclave stringops.enclave
|
||||||
LINK_LIBS_VIRTUAL executor_registration.virtual status.virtual
|
LINK_LIBS_VIRTUAL executor_registration.virtual kv.virtual status.virtual
|
||||||
protobuf.virtual
|
protobuf.virtual stringops.virtual
|
||||||
LINK_LIBS_ENCLAVE kv.enclave status.enclave protobuf.enclave
|
|
||||||
LINK_LIBS_VIRTUAL kv.virtual status.virtual protobuf.virtual
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate an ephemeral signing key
|
# Generate an ephemeral signing key
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "executor_registration.pb.h"
|
#include "executor_registration.pb.h"
|
||||||
#include "kv.pb.h"
|
#include "kv.pb.h"
|
||||||
#include "node/endpoint_context_impl.h"
|
#include "node/endpoint_context_impl.h"
|
||||||
|
#include "stringops.pb.h"
|
||||||
|
|
||||||
#define FMT_HEADER_ONLY
|
#define FMT_HEADER_ONLY
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
@ -246,6 +247,72 @@ namespace externalexecutor
|
||||||
install_registry_service();
|
install_registry_service();
|
||||||
|
|
||||||
install_kv_service();
|
install_kv_service();
|
||||||
|
|
||||||
|
auto run_string_ops = [this](
|
||||||
|
ccf::endpoints::CommandEndpointContext& ctx,
|
||||||
|
std::vector<temp::OpIn>&& payload)
|
||||||
|
-> ccf::grpc::GrpcAdapterResponse<std::vector<temp::OpOut>> {
|
||||||
|
std::vector<temp::OpOut> results;
|
||||||
|
|
||||||
|
for (temp::OpIn& op : payload)
|
||||||
|
{
|
||||||
|
temp::OpOut& result = results.emplace_back();
|
||||||
|
switch (op.op_case())
|
||||||
|
{
|
||||||
|
case (temp::OpIn::OpCase::kEcho):
|
||||||
|
{
|
||||||
|
LOG_INFO_FMT("Got kEcho");
|
||||||
|
auto* echo_op = op.mutable_echo();
|
||||||
|
auto* echoed = result.mutable_echoed();
|
||||||
|
echoed->set_allocated_body(echo_op->release_body());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case (temp::OpIn::OpCase::kReverse):
|
||||||
|
{
|
||||||
|
LOG_INFO_FMT("Got kReverse");
|
||||||
|
auto* reverse_op = op.mutable_reverse();
|
||||||
|
std::string* s = reverse_op->release_body();
|
||||||
|
std::reverse(s->begin(), s->end());
|
||||||
|
auto* reversed = result.mutable_reversed();
|
||||||
|
reversed->set_allocated_body(s);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case (temp::OpIn::OpCase::kTruncate):
|
||||||
|
{
|
||||||
|
LOG_INFO_FMT("Got kTruncate");
|
||||||
|
auto* truncate_op = op.mutable_truncate();
|
||||||
|
std::string* s = truncate_op->release_body();
|
||||||
|
*s = s->substr(
|
||||||
|
truncate_op->start(),
|
||||||
|
truncate_op->end() - truncate_op->start());
|
||||||
|
auto* truncated = result.mutable_truncated();
|
||||||
|
truncated->set_allocated_body(s);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case (temp::OpIn::OpCase::OP_NOT_SET):
|
||||||
|
{
|
||||||
|
LOG_INFO_FMT("Got OP_NOT_SET");
|
||||||
|
// oneof may always be null. If the input OpIn was null, then the
|
||||||
|
// resulting OpOut is also null
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ccf::grpc::make_success(results);
|
||||||
|
};
|
||||||
|
|
||||||
|
make_command_endpoint(
|
||||||
|
"/temp.Test/RunOps",
|
||||||
|
HTTP_POST,
|
||||||
|
ccf::grpc_command_adapter<
|
||||||
|
std::vector<temp::OpIn>,
|
||||||
|
std::vector<temp::OpOut>>(run_string_ops),
|
||||||
|
ccf::no_auth_required)
|
||||||
|
.install();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace externalexecutor
|
} // namespace externalexecutor
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package temp;
|
||||||
|
|
||||||
|
option optimize_for = LITE_RUNTIME;
|
||||||
|
|
||||||
|
service Test
|
||||||
|
{
|
||||||
|
rpc RunOps(stream OpIn) returns (stream OpOut) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message OpIn
|
||||||
|
{
|
||||||
|
oneof op
|
||||||
|
{
|
||||||
|
EchoOp echo = 1;
|
||||||
|
ReverseOp reverse = 2;
|
||||||
|
TruncateOp truncate = 3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message EchoOp { string body = 1; }
|
||||||
|
message ReverseOp { string body = 1; }
|
||||||
|
message TruncateOp
|
||||||
|
{
|
||||||
|
string body = 1;
|
||||||
|
uint32 start = 2;
|
||||||
|
uint32 end = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message OpOut
|
||||||
|
{
|
||||||
|
oneof result
|
||||||
|
{
|
||||||
|
EchoResult echoed = 1;
|
||||||
|
ReverseResult reversed = 2;
|
||||||
|
TruncateResult truncated = 3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message EchoResult { string body = 1; }
|
||||||
|
message ReverseResult { string body = 1; }
|
||||||
|
message TruncateResult { string body = 1; }
|
|
@ -113,7 +113,41 @@ namespace ccf::grpc
|
||||||
http::headervalues::contenttype::GRPC));
|
http::headervalues::contenttype::GRPC));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto message_length = impl::read_message_frame(data, size);
|
if constexpr (nonstd::is_std_vector<In>::value)
|
||||||
|
{
|
||||||
|
using Message = typename In::value_type;
|
||||||
|
In messages;
|
||||||
|
while (size != 0)
|
||||||
|
{
|
||||||
|
const auto message_length = impl::read_message_frame(data, size);
|
||||||
|
if (message_length > size)
|
||||||
|
{
|
||||||
|
throw std::logic_error(fmt::format(
|
||||||
|
"Error in gRPC frame: only {} bytes remaining but message header "
|
||||||
|
"says messages is {} bytes",
|
||||||
|
size,
|
||||||
|
message_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
Message& msg = messages.emplace_back();
|
||||||
|
if (!msg.ParseFromArray(data, message_length))
|
||||||
|
{
|
||||||
|
throw std::logic_error(fmt::format(
|
||||||
|
"Error deserialising protobuf payload of type {}, size {} (message "
|
||||||
|
"{} in "
|
||||||
|
"stream)",
|
||||||
|
msg.GetTypeName(),
|
||||||
|
size,
|
||||||
|
messages.size()));
|
||||||
|
}
|
||||||
|
data += message_length;
|
||||||
|
size -= message_length;
|
||||||
|
}
|
||||||
|
return messages;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto message_length = impl::read_message_frame(data, size);
|
||||||
if (size != message_length)
|
if (size != message_length)
|
||||||
{
|
{
|
||||||
throw std::logic_error(fmt::format(
|
throw std::logic_error(fmt::format(
|
||||||
|
@ -121,17 +155,18 @@ namespace ccf::grpc
|
||||||
size,
|
size,
|
||||||
message_length));
|
message_length));
|
||||||
}
|
}
|
||||||
ctx->set_response_header(
|
|
||||||
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
|
|
||||||
|
|
||||||
In in;
|
In in;
|
||||||
if (!in.ParseFromArray(data, size))
|
if (!in.ParseFromArray(data, message_length))
|
||||||
{
|
{
|
||||||
throw std::logic_error(
|
throw std::logic_error(fmt::format(
|
||||||
fmt::format("Error deserialising protobuf payload of size {}", size));
|
"Error deserialising protobuf payload of type {}, size {}",
|
||||||
|
in.GetTypeName(),
|
||||||
|
size));
|
||||||
}
|
}
|
||||||
return in;
|
return in;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Out>
|
template <typename Out>
|
||||||
void set_grpc_response(
|
void set_grpc_response(
|
||||||
|
@ -141,22 +176,62 @@ namespace ccf::grpc
|
||||||
auto success_response = std::get_if<SuccessResponse<Out>>(&r);
|
auto success_response = std::get_if<SuccessResponse<Out>>(&r);
|
||||||
if (success_response != nullptr)
|
if (success_response != nullptr)
|
||||||
{
|
{
|
||||||
const auto& resp = success_response->body;
|
std::vector<uint8_t> r;
|
||||||
size_t r_size = impl::message_frame_length + resp.ByteSizeLong();
|
|
||||||
std::vector<uint8_t> r(r_size);
|
if constexpr (nonstd::is_std_vector<Out>::value)
|
||||||
|
{
|
||||||
|
using Message = typename Out::value_type;
|
||||||
|
const Out& messages = success_response->body;
|
||||||
|
size_t r_size = std::accumulate(
|
||||||
|
messages.begin(),
|
||||||
|
messages.end(),
|
||||||
|
0,
|
||||||
|
[](size_t current, const Message& msg) {
|
||||||
|
return current + impl::message_frame_length + msg.ByteSizeLong();
|
||||||
|
});
|
||||||
|
r.resize(r_size);
|
||||||
auto r_data = r.data();
|
auto r_data = r.data();
|
||||||
|
|
||||||
impl::write_message_frame(r_data, r_size, resp.ByteSizeLong());
|
for (const Message& msg : messages)
|
||||||
ctx->set_response_header(
|
{
|
||||||
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
|
const auto message_length = msg.ByteSizeLong();
|
||||||
|
impl::write_message_frame(r_data, r_size, message_length);
|
||||||
|
|
||||||
|
if (!msg.SerializeToArray(r_data, r_size))
|
||||||
|
{
|
||||||
|
throw std::logic_error(fmt::format(
|
||||||
|
"Error serialising protobuf response of type {}, size {}",
|
||||||
|
msg.GetTypeName(),
|
||||||
|
message_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
r_data += message_length;
|
||||||
|
r_size += message_length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const Out& resp = success_response->body;
|
||||||
|
const auto message_length = resp.ByteSizeLong();
|
||||||
|
size_t r_size = impl::message_frame_length + message_length;
|
||||||
|
r.resize(r_size);
|
||||||
|
auto r_data = r.data();
|
||||||
|
|
||||||
|
impl::write_message_frame(r_data, r_size, message_length);
|
||||||
|
|
||||||
if (!resp.SerializeToArray(r_data, r_size))
|
if (!resp.SerializeToArray(r_data, r_size))
|
||||||
{
|
{
|
||||||
throw std::logic_error(fmt::format(
|
throw std::logic_error(fmt::format(
|
||||||
"Error serialising protobuf response of size {}",
|
"Error serialising protobuf response of type {}, size {}",
|
||||||
resp.ByteSizeLong()));
|
resp.GetTypeName(),
|
||||||
|
message_length));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ctx->set_response_body(r);
|
ctx->set_response_body(r);
|
||||||
|
ctx->set_response_header(
|
||||||
|
http::headers::CONTENT_TYPE, http::headervalues::contenttype::GRPC);
|
||||||
|
|
||||||
ctx->set_response_trailer("grpc-status", success_response->status.code());
|
ctx->set_response_trailer("grpc-status", success_response->status.code());
|
||||||
ctx->set_response_trailer(
|
ctx->set_response_trailer(
|
||||||
"grpc-message", success_response->status.message());
|
"grpc-message", success_response->status.message());
|
||||||
|
@ -181,6 +256,10 @@ namespace ccf
|
||||||
using GrpcReadOnlyEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
|
using GrpcReadOnlyEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
|
||||||
endpoints::ReadOnlyEndpointContext& ctx, In&& payload)>;
|
endpoints::ReadOnlyEndpointContext& ctx, In&& payload)>;
|
||||||
|
|
||||||
|
template <typename In, typename Out>
|
||||||
|
using GrpcCommandEndpoint = std::function<grpc::GrpcAdapterResponse<Out>(
|
||||||
|
endpoints::CommandEndpointContext& ctx, In&& payload)>;
|
||||||
|
|
||||||
template <typename In, typename Out>
|
template <typename In, typename Out>
|
||||||
endpoints::EndpointFunction grpc_adapter(const GrpcEndpoint<In, Out>& f)
|
endpoints::EndpointFunction grpc_adapter(const GrpcEndpoint<In, Out>& f)
|
||||||
{
|
{
|
||||||
|
@ -199,4 +278,14 @@ namespace ccf
|
||||||
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
|
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename In, typename Out>
|
||||||
|
endpoints::CommandEndpointFunction grpc_command_adapter(
|
||||||
|
const GrpcCommandEndpoint<In, Out>& f)
|
||||||
|
{
|
||||||
|
return [f](endpoints::CommandEndpointContext& ctx) {
|
||||||
|
grpc::set_grpc_response<Out>(
|
||||||
|
f(ctx, grpc::get_grpc_payload<In>(ctx.rpc_ctx)), ctx.rpc_ctx);
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,12 @@ import kv_pb2 as KV
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import kv_pb2_grpc as Service
|
import kv_pb2_grpc as Service
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import stringops_pb2 as StringOps
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import stringops_pb2_grpc as StringOpsService
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
from google.protobuf.empty_pb2 import Empty as Empty
|
from google.protobuf.empty_pb2 import Empty as Empty
|
||||||
|
|
||||||
|
@ -123,6 +129,74 @@ def test_put_get(network, args):
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
@reqs.description("Test gRPC streaming APIs")
|
||||||
|
def test_streaming(network, args):
|
||||||
|
primary, _ = network.find_primary()
|
||||||
|
|
||||||
|
credentials = grpc.ssl_channel_credentials(
|
||||||
|
open(os.path.join(network.common_dir, "service_cert.pem"), "rb").read()
|
||||||
|
)
|
||||||
|
|
||||||
|
def echo_op(s):
|
||||||
|
return (StringOps.OpIn(echo=StringOps.EchoOp(body=s)), ("echoed", s))
|
||||||
|
|
||||||
|
def reverse_op(s):
|
||||||
|
return (
|
||||||
|
StringOps.OpIn(reverse=StringOps.ReverseOp(body=s)),
|
||||||
|
("reversed", s[::-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def truncate_op(s):
|
||||||
|
start = random.randint(0, len(s))
|
||||||
|
end = random.randint(start, len(s))
|
||||||
|
return (
|
||||||
|
StringOps.OpIn(truncate=StringOps.TruncateOp(body=s, start=start, end=end)),
|
||||||
|
("truncated", s[start:end]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def empty_op(s):
|
||||||
|
# oneof may always be null - generate some like this to make sure they're handled "correctly"
|
||||||
|
return (StringOps.OpIn(), None)
|
||||||
|
|
||||||
|
def generate_ops(n):
|
||||||
|
for _ in range(n):
|
||||||
|
s = f"I'm random string {n}: {random.random()}"
|
||||||
|
yield random.choice((echo_op, reverse_op, truncate_op, empty_op))(s)
|
||||||
|
|
||||||
|
def compare_op_results(stub, n_ops):
|
||||||
|
LOG.info(f"Sending streaming request containing {n_ops} operations")
|
||||||
|
ops = []
|
||||||
|
expected_results = []
|
||||||
|
for op, expected_result in generate_ops(n_ops):
|
||||||
|
ops.append(op)
|
||||||
|
expected_results.append(expected_result)
|
||||||
|
|
||||||
|
for actual_result in stub.RunOps(op for op in ops):
|
||||||
|
assert len(expected_results) > 0, "More responses than requests"
|
||||||
|
expected_result = expected_results.pop(0)
|
||||||
|
if expected_result is None:
|
||||||
|
assert not actual_result.HasField("result"), actual_result
|
||||||
|
else:
|
||||||
|
field_name, expected = expected_result
|
||||||
|
actual = getattr(actual_result, field_name).body
|
||||||
|
assert (
|
||||||
|
actual == expected
|
||||||
|
), f"Wrong {field_name} op: {actual} != {expected}"
|
||||||
|
|
||||||
|
assert len(expected_results) == 0, "Fewer responses than requests"
|
||||||
|
|
||||||
|
with grpc.secure_channel(
|
||||||
|
target=f"{primary.get_public_rpc_host()}:{primary.get_public_rpc_port()}",
|
||||||
|
credentials=credentials,
|
||||||
|
) as channel:
|
||||||
|
stub = StringOpsService.TestStub(channel)
|
||||||
|
|
||||||
|
compare_op_results(stub, 0)
|
||||||
|
compare_op_results(stub, 1)
|
||||||
|
compare_op_results(stub, 30)
|
||||||
|
compare_op_results(stub, 1000)
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
with infra.network.network(
|
with infra.network.network(
|
||||||
args.nodes,
|
args.nodes,
|
||||||
|
@ -132,12 +206,12 @@ def run(args):
|
||||||
) as network:
|
) as network:
|
||||||
network.start_and_open(args)
|
network.start_and_open(args)
|
||||||
test_put_get(network, args)
|
test_put_get(network, args)
|
||||||
|
test_streaming(network, args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = infra.e2e_args.cli_args()
|
args = infra.e2e_args.cli_args()
|
||||||
|
|
||||||
args.host_log_level = "trace"
|
|
||||||
args.package = "src/apps/external_executor/libexternal_executor"
|
args.package = "src/apps/external_executor/libexternal_executor"
|
||||||
args.http2 = True # gRPC interface
|
args.http2 = True # gRPC interface
|
||||||
args.nodes = infra.e2e_args.min_nodes(args, f=0)
|
args.nodes = infra.e2e_args.min_nodes(args, f=0)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче