diff --git a/.gitattributes b/.gitattributes index b2f6100a2..f699908b5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2,4 +2,6 @@ *.unity binary *.prefab binary *.meta binary +*/CommunicatorObjects/* binary +*/communicator_objects/* binary *.md text diff --git a/Dockerfile b/Dockerfile index f5f5992ec..0b2d72277 100644 --- a/Dockerfile +++ b/Dockerfile @@ -129,4 +129,7 @@ RUN pip install --trusted-host pypi.python.org -r requirements.txt WORKDIR /execute COPY python /execute/python +# port 5005 is the port used in in Editor training. +EXPOSE 5005 + ENTRYPOINT ["python", "python/learn.py"] diff --git a/python/communicator_objects/__init__.py b/python/communicator_objects/__init__.py new file mode 100644 index 000000000..210506ec5 --- /dev/null +++ b/python/communicator_objects/__init__.py @@ -0,0 +1,19 @@ +from .agent_action_proto_pb2 import * +from .agent_info_proto_pb2 import * +from .brain_parameters_proto_pb2 import * +from .brain_type_proto_pb2 import * +from .command_proto_pb2 import * +from .engine_configuration_proto_pb2 import * +from .environment_parameters_proto_pb2 import * +from .header_pb2 import * +from .resolution_proto_pb2 import * +from .space_type_proto_pb2 import * +from .unity_input_pb2 import * +from .unity_message_pb2 import * +from .unity_output_pb2 import * +from .unity_rl_initialization_input_pb2 import * +from .unity_rl_initialization_output_pb2 import * +from .unity_rl_input_pb2 import * +from .unity_rl_output_pb2 import * +from .unity_to_external_pb2 import * +from .unity_to_external_pb2_grpc import * diff --git a/python/communicator_objects/agent_action_proto_pb2.py b/python/communicator_objects/agent_action_proto_pb2.py new file mode 100644 index 000000000..2bb928b6a --- /dev/null +++ b/python/communicator_objects/agent_action_proto_pb2.py @@ -0,0 +1,85 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/agent_action_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/agent_action_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_AGENTACTIONPROTO = _descriptor.Descriptor( + name='AgentActionProto', + full_name='communicator_objects.AgentActionProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, + number=1, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, + number=3, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=71, + serialized_end=153, +) + +DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( + DESCRIPTOR = _AGENTACTIONPROTO, + __module__ = 'communicator_objects.agent_action_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) + )) +_sym_db.RegisterMessage(AgentActionProto) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/agent_info_proto_pb2.py b/python/communicator_objects/agent_info_proto_pb2.py new file mode 100644 index 000000000..944f5b381 --- /dev/null +++ b/python/communicator_objects/agent_info_proto_pb2.py @@ -0,0 +1,134 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/agent_info_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/agent_info_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n+communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\"\xfd\x01\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_AGENTINFOPROTO = _descriptor.Descriptor( + name='AgentInfoProto', + full_name='communicator_objects.AgentInfoProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0, + number=1, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1, + number=2, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3, + number=4, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5, + number=6, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=6, + number=7, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='done', full_name='communicator_objects.AgentInfoProto.done', index=7, + number=8, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8, + number=9, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='id', full_name='communicator_objects.AgentInfoProto.id', index=9, + number=10, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=70, + serialized_end=323, +) + +DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict( + DESCRIPTOR = _AGENTINFOPROTO, + __module__ = 'communicator_objects.agent_info_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) + )) +_sym_db.RegisterMessage(AgentInfoProto) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/brain_parameters_proto_pb2.py b/python/communicator_objects/brain_parameters_proto_pb2.py new file mode 100644 index 000000000..8fee109a2 --- /dev/null +++ b/python/communicator_objects/brain_parameters_proto_pb2.py @@ -0,0 +1,135 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/brain_parameters_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 +from communicator_objects import brain_type_proto_pb2 as communicator__objects_dot_brain__type__proto__pb2 +from communicator_objects import space_type_proto_pb2 as communicator__objects_dot_space__type__proto__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/brain_parameters_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n1communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto\x1a+communicator_objects/brain_type_proto.proto\x1a+communicator_objects/space_type_proto.proto\"\xc6\x03\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x01(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12K\n\x1dvector_observation_space_type\x18\x07 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x08 \x01(\t\x12\x38\n\nbrain_type\x18\t \x01(\x0e\x32$.communicator_objects.BrainTypeProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,communicator__objects_dot_brain__type__proto__pb2.DESCRIPTOR,communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,]) + + + + +_BRAINPARAMETERSPROTO = _descriptor.Descriptor( + name='BrainParametersProto', + full_name='communicator_objects.BrainParametersProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='vector_observation_size', full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4, + number=5, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5, + number=6, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='vector_observation_space_type', full_name='communicator_objects.BrainParametersProto.vector_observation_space_type', index=6, + number=7, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=7, + number=8, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='brain_type', full_name='communicator_objects.BrainParametersProto.brain_type', index=8, + number=9, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=211, + serialized_end=665, +) + +_BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO +_BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO +_BRAINPARAMETERSPROTO.fields_by_name['vector_observation_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO +_BRAINPARAMETERSPROTO.fields_by_name['brain_type'].enum_type = communicator__objects_dot_brain__type__proto__pb2._BRAINTYPEPROTO +DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +BrainParametersProto = _reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict( + DESCRIPTOR = _BRAINPARAMETERSPROTO, + __module__ = 'communicator_objects.brain_parameters_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) + )) +_sym_db.RegisterMessage(BrainParametersProto) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/brain_type_proto_pb2.py b/python/communicator_objects/brain_type_proto_pb2.py new file mode 100644 index 000000000..8ac6022aa --- /dev/null +++ b/python/communicator_objects/brain_type_proto_pb2.py @@ -0,0 +1,71 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/brain_type_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/brain_type_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n+communicator_objects/brain_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*G\n\x0e\x42rainTypeProto\x12\n\n\x06Player\x10\x00\x12\r\n\tHeuristic\x10\x01\x12\x0c\n\x08\x45xternal\x10\x02\x12\x0c\n\x08Internal\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) + +_BRAINTYPEPROTO = _descriptor.EnumDescriptor( + name='BrainTypeProto', + full_name='communicator_objects.BrainTypeProto', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='Player', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='Heuristic', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='External', index=2, number=2, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='Internal', index=3, number=3, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=114, + serialized_end=185, +) +_sym_db.RegisterEnumDescriptor(_BRAINTYPEPROTO) + +BrainTypeProto = enum_type_wrapper.EnumTypeWrapper(_BRAINTYPEPROTO) +Player = 0 +Heuristic = 1 +External = 2 +Internal = 3 + + +DESCRIPTOR.enum_types_by_name['BrainTypeProto'] = _BRAINTYPEPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/command_proto_pb2.py b/python/communicator_objects/command_proto_pb2.py new file mode 100644 index 000000000..bdc97c1d2 --- /dev/null +++ b/python/communicator_objects/command_proto_pb2.py @@ -0,0 +1,64 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/command_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/command_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n(communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + +_COMMANDPROTO = _descriptor.EnumDescriptor( + name='CommandProto', + full_name='communicator_objects.CommandProto', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='STEP', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='RESET', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='QUIT', index=2, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=66, + serialized_end=111, +) +_sym_db.RegisterEnumDescriptor(_COMMANDPROTO) + +CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) +STEP = 0 +RESET = 1 +QUIT = 2 + + +DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/engine_configuration_proto_pb2.py b/python/communicator_objects/engine_configuration_proto_pb2.py new file mode 100644 index 000000000..cf5ee4687 --- /dev/null +++ b/python/communicator_objects/engine_configuration_proto_pb2.py @@ -0,0 +1,106 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/engine_configuration_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/engine_configuration_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n5communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( + name='EngineConfigurationProto', + full_name='communicator_objects.EngineConfigurationProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, + number=4, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, + number=6, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=80, + serialized_end=229, +) + +DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( + DESCRIPTOR = _ENGINECONFIGURATIONPROTO, + __module__ = 'communicator_objects.engine_configuration_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) + )) +_sym_db.RegisterMessage(EngineConfigurationProto) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/environment_parameters_proto_pb2.py b/python/communicator_objects/environment_parameters_proto_pb2.py new file mode 100644 index 000000000..1732614a6 --- /dev/null +++ b/python/communicator_objects/environment_parameters_proto_pb2.py @@ -0,0 +1,120 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/environment_parameters_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/environment_parameters_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n7communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( + name='FloatParametersEntry', + full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=209, + serialized_end=263, +) + +_ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( + name='EnvironmentParametersProto', + full_name='communicator_objects.EnvironmentParametersProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=82, + serialized_end=263, +) + +_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO +_ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY +DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( + + FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( + DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, + __module__ = 'communicator_objects.environment_parameters_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) + )) + , + DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, + __module__ = 'communicator_objects.environment_parameters_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) + )) +_sym_db.RegisterMessage(EnvironmentParametersProto) +_sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True +_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/header_pb2.py b/python/communicator_objects/header_pb2.py new file mode 100644 index 000000000..1a8566dd2 --- /dev/null +++ b/python/communicator_objects/header_pb2.py @@ -0,0 +1,78 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/header.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/header.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n!communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_HEADER = _descriptor.Descriptor( + name='Header', + full_name='communicator_objects.Header', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='communicator_objects.Header.status', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='message', full_name='communicator_objects.Header.message', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=59, + serialized_end=100, +) + +DESCRIPTOR.message_types_by_name['Header'] = _HEADER +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( + DESCRIPTOR = _HEADER, + __module__ = 'communicator_objects.header_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.Header) + )) +_sym_db.RegisterMessage(Header) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/resolution_proto_pb2.py b/python/communicator_objects/resolution_proto_pb2.py new file mode 100644 index 000000000..e46853048 --- /dev/null +++ b/python/communicator_objects/resolution_proto_pb2.py @@ -0,0 +1,85 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/resolution_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/resolution_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n+communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_RESOLUTIONPROTO = _descriptor.Descriptor( + name='ResolutionProto', + full_name='communicator_objects.ResolutionProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='width', full_name='communicator_objects.ResolutionProto.width', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='height', full_name='communicator_objects.ResolutionProto.height', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=69, + serialized_end=137, +) + +DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( + DESCRIPTOR = _RESOLUTIONPROTO, + __module__ = 'communicator_objects.resolution_proto_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) + )) +_sym_db.RegisterMessage(ResolutionProto) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/space_type_proto_pb2.py b/python/communicator_objects/space_type_proto_pb2.py new file mode 100644 index 000000000..ab58acbf1 --- /dev/null +++ b/python/communicator_objects/space_type_proto_pb2.py @@ -0,0 +1,61 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/space_type_proto.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/space_type_proto.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n+communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) + +_SPACETYPEPROTO = _descriptor.EnumDescriptor( + name='SpaceTypeProto', + full_name='communicator_objects.SpaceTypeProto', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='discrete', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='continuous', index=1, number=1, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=114, + serialized_end=160, +) +_sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) + +SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) +discrete = 0 +continuous = 1 + + +DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_input_pb2.py b/python/communicator_objects/unity_input_pb2.py new file mode 100644 index 000000000..d22921c75 --- /dev/null +++ b/python/communicator_objects/unity_input_pb2.py @@ -0,0 +1,90 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_input.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import unity_rl_input_pb2 as communicator__objects_dot_unity__rl__input__pb2 +from communicator_objects import unity_rl_initialization_input_pb2 as communicator__objects_dot_unity__rl__initialization__input__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_input.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n&communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a)communicator_objects/unity_rl_input.proto\x1a\x38\x63ommunicator_objects/unity_rl_initialization_input.proto\"\xb0\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInput\x12\x19\n\x11\x63ustom_data_input\x18\x03 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) + + + + +_UNITYINPUT = _descriptor.Descriptor( + name='UnityInput', + full_name='communicator_objects.UnityInput', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='custom_data_input', full_name='communicator_objects.UnityInput.custom_data_input', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=166, + serialized_end=342, +) + +_UNITYINPUT.fields_by_name['rl_input'].message_type = communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT +_UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT +DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( + DESCRIPTOR = _UNITYINPUT, + __module__ = 'communicator_objects.unity_input_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) + )) +_sym_db.RegisterMessage(UnityInput) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_message_pb2.py b/python/communicator_objects/unity_message_pb2.py new file mode 100644 index 000000000..5288a492e --- /dev/null +++ b/python/communicator_objects/unity_message_pb2.py @@ -0,0 +1,92 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_message.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import unity_output_pb2 as communicator__objects_dot_unity__output__pb2 +from communicator_objects import unity_input_pb2 as communicator__objects_dot_unity__input__pb2 +from communicator_objects import header_pb2 as communicator__objects_dot_header__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_message.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n(communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\'communicator_objects/unity_output.proto\x1a&communicator_objects/unity_input.proto\x1a!communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_unity__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__input__pb2.DESCRIPTOR,communicator__objects_dot_header__pb2.DESCRIPTOR,]) + + + + +_UNITYMESSAGE = _descriptor.Descriptor( + name='UnityMessage', + full_name='communicator_objects.UnityMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='header', full_name='communicator_objects.UnityMessage.header', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=183, + serialized_end=355, +) + +_UNITYMESSAGE.fields_by_name['header'].message_type = communicator__objects_dot_header__pb2._HEADER +_UNITYMESSAGE.fields_by_name['unity_output'].message_type = communicator__objects_dot_unity__output__pb2._UNITYOUTPUT +_UNITYMESSAGE.fields_by_name['unity_input'].message_type = communicator__objects_dot_unity__input__pb2._UNITYINPUT +DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( + DESCRIPTOR = _UNITYMESSAGE, + __module__ = 'communicator_objects.unity_message_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) + )) +_sym_db.RegisterMessage(UnityMessage) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_output_pb2.py b/python/communicator_objects/unity_output_pb2.py new file mode 100644 index 000000000..8700e72ad --- /dev/null +++ b/python/communicator_objects/unity_output_pb2.py @@ -0,0 +1,90 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_output.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import unity_rl_output_pb2 as communicator__objects_dot_unity__rl__output__pb2 +from communicator_objects import unity_rl_initialization_output_pb2 as communicator__objects_dot_unity__rl__initialization__output__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_output.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n\'communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a*communicator_objects/unity_rl_output.proto\x1a\x39\x63ommunicator_objects/unity_rl_initialization_output.proto\"\xb6\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutput\x12\x1a\n\x12\x63ustom_data_output\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) + + + + +_UNITYOUTPUT = _descriptor.Descriptor( + name='UnityOutput', + full_name='communicator_objects.UnityOutput', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='custom_data_output', full_name='communicator_objects.UnityOutput.custom_data_output', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=169, + serialized_end=351, +) + +_UNITYOUTPUT.fields_by_name['rl_output'].message_type = communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT +_UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT +DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( + DESCRIPTOR = _UNITYOUTPUT, + __module__ = 'communicator_objects.unity_output_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) + )) +_sym_db.RegisterMessage(UnityOutput) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_rl_initialization_input_pb2.py b/python/communicator_objects/unity_rl_initialization_input_pb2.py new file mode 100644 index 000000000..9110cd187 --- /dev/null +++ b/python/communicator_objects/unity_rl_initialization_input_pb2.py @@ -0,0 +1,71 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_rl_initialization_input.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_rl_initialization_input.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n8communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( + name='UnityRLInitializationInput', + full_name='communicator_objects.UnityRLInitializationInput', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=82, + serialized_end=124, +) + +DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( + DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, + __module__ = 'communicator_objects.unity_rl_initialization_input_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) + )) +_sym_db.RegisterMessage(UnityRLInitializationInput) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_rl_initialization_output_pb2.py b/python/communicator_objects/unity_rl_initialization_output_pb2.py new file mode 100644 index 000000000..835c15242 --- /dev/null +++ b/python/communicator_objects/unity_rl_initialization_output_pb2.py @@ -0,0 +1,104 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_rl_initialization_output.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import brain_parameters_proto_pb2 as communicator__objects_dot_brain__parameters__proto__pb2 +from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_rl_initialization_output.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n9communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x31\x63ommunicator_objects/brain_parameters_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) + + + + +_UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( + name='UnityRLInitializationOutput', + full_name='communicator_objects.UnityRLInitializationOutput', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, + number=6, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=192, + serialized_end=422, +) + +_UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO +_UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO +DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( + DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT, + __module__ = 'communicator_objects.unity_rl_initialization_output_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) + )) +_sym_db.RegisterMessage(UnityRLInitializationOutput) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_rl_input_pb2.py b/python/communicator_objects/unity_rl_input_pb2.py new file mode 100644 index 000000000..819046d89 --- /dev/null +++ b/python/communicator_objects/unity_rl_input_pb2.py @@ -0,0 +1,188 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_rl_input.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import agent_action_proto_pb2 as communicator__objects_dot_agent__action__proto__pb2 +from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 +from communicator_objects import command_proto_pb2 as communicator__objects_dot_command__proto__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_rl_input.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n)communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a-communicator_objects/agent_action_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\x1a(communicator_objects/command_proto.proto\"\xb4\x03\n\x0cUnityRLInput\x12K\n\ragent_actions\x18\x01 \x03(\x0b\x32\x34.communicator_objects.UnityRLInput.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1al\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x46\n\x05value\x18\x02 \x01(\x0b\x32\x37.communicator_objects.UnityRLInput.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_agent__action__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_command__proto__pb2.DESCRIPTOR,]) + + + + +_UNITYRLINPUT_LISTAGENTACTIONPROTO = _descriptor.Descriptor( + name='ListAgentActionProto', + full_name='communicator_objects.UnityRLInput.ListAgentActionProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='communicator_objects.UnityRLInput.ListAgentActionProto.value', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=463, + serialized_end=540, +) + +_UNITYRLINPUT_AGENTACTIONSENTRY = _descriptor.Descriptor( + name='AgentActionsEntry', + full_name='communicator_objects.UnityRLInput.AgentActionsEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=542, + serialized_end=650, +) + +_UNITYRLINPUT = _descriptor.Descriptor( + name='UnityRLInput', + full_name='communicator_objects.UnityRLInput', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='agent_actions', full_name='communicator_objects.UnityRLInput.agent_actions', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='environment_parameters', full_name='communicator_objects.UnityRLInput.environment_parameters', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_training', full_name='communicator_objects.UnityRLInput.is_training', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='command', full_name='communicator_objects.UnityRLInput.command', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_UNITYRLINPUT_LISTAGENTACTIONPROTO, _UNITYRLINPUT_AGENTACTIONSENTRY, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=214, + serialized_end=650, +) + +_UNITYRLINPUT_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__action__proto__pb2._AGENTACTIONPROTO +_UNITYRLINPUT_LISTAGENTACTIONPROTO.containing_type = _UNITYRLINPUT +_UNITYRLINPUT_AGENTACTIONSENTRY.fields_by_name['value'].message_type = _UNITYRLINPUT_LISTAGENTACTIONPROTO +_UNITYRLINPUT_AGENTACTIONSENTRY.containing_type = _UNITYRLINPUT +_UNITYRLINPUT.fields_by_name['agent_actions'].message_type = _UNITYRLINPUT_AGENTACTIONSENTRY +_UNITYRLINPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO +_UNITYRLINPUT.fields_by_name['command'].enum_type = communicator__objects_dot_command__proto__pb2._COMMANDPROTO +DESCRIPTOR.message_types_by_name['UnityRLInput'] = _UNITYRLINPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityRLInput = _reflection.GeneratedProtocolMessageType('UnityRLInput', (_message.Message,), dict( + + ListAgentActionProto = _reflection.GeneratedProtocolMessageType('ListAgentActionProto', (_message.Message,), dict( + DESCRIPTOR = _UNITYRLINPUT_LISTAGENTACTIONPROTO, + __module__ = 'communicator_objects.unity_rl_input_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.ListAgentActionProto) + )) + , + + AgentActionsEntry = _reflection.GeneratedProtocolMessageType('AgentActionsEntry', (_message.Message,), dict( + DESCRIPTOR = _UNITYRLINPUT_AGENTACTIONSENTRY, + __module__ = 'communicator_objects.unity_rl_input_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.AgentActionsEntry) + )) + , + DESCRIPTOR = _UNITYRLINPUT, + __module__ = 'communicator_objects.unity_rl_input_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput) + )) +_sym_db.RegisterMessage(UnityRLInput) +_sym_db.RegisterMessage(UnityRLInput.ListAgentActionProto) +_sym_db.RegisterMessage(UnityRLInput.AgentActionsEntry) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +_UNITYRLINPUT_AGENTACTIONSENTRY.has_options = True +_UNITYRLINPUT_AGENTACTIONSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_rl_output_pb2.py b/python/communicator_objects/unity_rl_output_pb2.py new file mode 100644 index 000000000..4f1763d55 --- /dev/null +++ b/python/communicator_objects/unity_rl_output_pb2.py @@ -0,0 +1,170 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_rl_output.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import agent_info_proto_pb2 as communicator__objects_dot_agent__info__proto__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_rl_output.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n*communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,]) + + + + +_UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( + name='ListAgentInfoProto', + full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=225, + serialized_end=298, +) + +_UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( + name='AgentInfosEntry', + full_name='communicator_objects.UnityRLOutput.AgentInfosEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=300, + serialized_end=405, +) + +_UNITYRLOUTPUT = _descriptor.Descriptor( + name='UnityRLOutput', + full_name='communicator_objects.UnityRLOutput', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=114, + serialized_end=405, +) + +_UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO +_UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT +_UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO +_UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT +_UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY +DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict( + + ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict( + DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO, + __module__ = 'communicator_objects.unity_rl_output_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) + )) + , + + AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict( + DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY, + __module__ = 'communicator_objects.unity_rl_output_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) + )) + , + DESCRIPTOR = _UNITYRLOUTPUT, + __module__ = 'communicator_objects.unity_rl_output_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) + )) +_sym_db.RegisterMessage(UnityRLOutput) +_sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) +_sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) +_UNITYRLOUTPUT_AGENTINFOSENTRY.has_options = True +_UNITYRLOUTPUT_AGENTINFOSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_to_external_pb2.py b/python/communicator_objects/unity_to_external_pb2.py new file mode 100644 index 000000000..f9e61b344 --- /dev/null +++ b/python/communicator_objects/unity_to_external_pb2.py @@ -0,0 +1,58 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: communicator_objects/unity_to_external.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='communicator_objects/unity_to_external.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n,communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a(communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + , + dependencies=[communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) + + + +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) + +_UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( + name='UnityToExternal', + full_name='communicator_objects.UnityToExternal', + file=DESCRIPTOR, + index=0, + options=None, + serialized_start=112, + serialized_end=215, + methods=[ + _descriptor.MethodDescriptor( + name='Exchange', + full_name='communicator_objects.UnityToExternal.Exchange', + index=0, + containing_service=None, + input_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, + output_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, + options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) + +DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL + +# @@protoc_insertion_point(module_scope) diff --git a/python/communicator_objects/unity_to_external_pb2_grpc.py b/python/communicator_objects/unity_to_external_pb2_grpc.py new file mode 100644 index 000000000..52c2d67bd --- /dev/null +++ b/python/communicator_objects/unity_to_external_pb2_grpc.py @@ -0,0 +1,46 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 + + +class UnityToExternalStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Exchange = channel.unary_unary( + '/communicator_objects.UnityToExternal/Exchange', + request_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, + response_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, + ) + + +class UnityToExternalServicer(object): + # missing associated documentation comment in .proto file + pass + + def Exchange(self, request, context): + """Sends the academy parameters + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_UnityToExternalServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Exchange': grpc.unary_unary_rpc_method_handler( + servicer.Exchange, + request_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, + response_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'communicator_objects.UnityToExternal', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/python/learn.py b/python/learn.py index 25049533e..1b8f3591c 100755 --- a/python/learn.py +++ b/python/learn.py @@ -31,6 +31,7 @@ if __name__ == '__main__': _USAGE = ''' Usage: learn () [options] + learn [options] learn --help Options: diff --git a/python/requirements.txt b/python/requirements.txt index 584553894..ed3ebf84d 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -6,3 +6,5 @@ jupyter pytest>=3.2.2 docopt pyyaml +protobuf==3.5.2 +grpcio==1.11.0 diff --git a/python/tests/mock_communicator.py b/python/tests/mock_communicator.py new file mode 100755 index 000000000..21129b192 --- /dev/null +++ b/python/tests/mock_communicator.py @@ -0,0 +1,89 @@ + +from unityagents.communicator import Communicator +from communicator_objects import UnityMessage, UnityOutput, UnityInput,\ + ResolutionProto, BrainParametersProto, UnityRLInitializationOutput,\ + AgentInfoProto, UnityRLOutput + + +class MockCommunicator(Communicator): + def __init__(self, discrete=False, visual_input=False): + """ + Python side of the grpc communication. Python is the client and Unity the server + + :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. + :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. + """ + self.is_discrete = discrete + self.steps = 0 + self.visual_input = visual_input + self.has_been_closed = False + + def initialize(self, inputs: UnityInput) -> UnityOutput: + if self.visual_input: + resolutions = [ResolutionProto( + width=30, + height=40, + gray_scale=False)] + else: + resolutions = [] + bp = BrainParametersProto( + vector_observation_size=3, + num_stacked_vector_observations=2, + vector_action_size=2, + camera_resolutions=resolutions, + vector_action_descriptions=["", ""], + vector_action_space_type=int(not self.is_discrete), + vector_observation_space_type=1, + brain_name="RealFakeBrain", + brain_type=2 + ) + rl_init = UnityRLInitializationOutput( + name="RealFakeAcademy", + version="API-4", + log_path="", + brain_parameters=[bp] + ) + return UnityOutput( + rl_initialization_output=rl_init + ) + + def exchange(self, inputs: UnityInput) -> UnityOutput: + dict_agent_info = {} + if self.is_discrete: + vector_action = [1] + else: + vector_action = [1, 2] + list_agent_info = [] + for i in range(3): + list_agent_info.append( + AgentInfoProto( + stacked_vector_observation=[1, 2, 3, 1, 2, 3], + reward=1, + stored_vector_actions=vector_action, + stored_text_actions="", + text_observation="", + memories=[], + done=(i == 2), + max_step_reached=False, + id=i + )) + dict_agent_info["RealFakeBrain"] = \ + UnityRLOutput.ListAgentInfoProto(value=list_agent_info) + global_done = False + try: + global_done = (inputs.rl_input.agent_actions["RealFakeBrain"].value[0].vector_actions[0] == -1) + except: + pass + result = UnityRLOutput( + global_done=global_done, + agentInfos=dict_agent_info + ) + return UnityOutput( + rl_output=result + ) + + def close(self): + """ + Sends a shutdown signal to the unity environment, and closes the grpc connection. + """ + self.has_been_closed = True diff --git a/python/tests/test_bc.py b/python/tests/test_bc.py index 02f58c038..0db01ac14 100644 --- a/python/tests/test_bc.py +++ b/python/tests/test_bc.py @@ -6,97 +6,53 @@ import tensorflow as tf from unitytrainers.bc.models import BehavioralCloningModel from unityagents import UnityEnvironment +from .mock_communicator import MockCommunicator -def test_cc_bc_model(): - c_action_c_state_start = '''{ - "AcademyName": "RealFakeAcademy", - "resetParameters": {}, - "brainNames": ["RealFakeBrain"], - "externalBrainNames": ["RealFakeBrain"], - "logPath":"RealFakePath", - "apiNumber":"API-3", - "brainParameters": [{ - "vectorObservationSize": 3, - "numStackedVectorObservations": 2, - "vectorActionSize": 2, - "memorySize": 0, - "cameraResolutions": [], - "vectorActionDescriptions": ["",""], - "vectorActionSpaceType": 1, - "vectorObservationSpaceType": 1 - }] - }'''.encode() - +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_cc_bc_model(mock_communicator, mock_launcher): tf.reset_default_graph() - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - # End of mock - with tf.Session() as sess: - with tf.variable_scope("FakeGraphScope"): - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = c_action_c_state_start - env = UnityEnvironment(' ') + with tf.Session() as sess: + with tf.variable_scope("FakeGraphScope"): + mock_communicator.return_value = MockCommunicator( + discrete=False, visual_input=False) + env = UnityEnvironment(' ') + model = BehavioralCloningModel(env.brains["RealFakeBrain"]) + init = tf.global_variables_initializer() + sess.run(init) - model = BehavioralCloningModel(env.brains["RealFakeBrain"]) - init = tf.global_variables_initializer() - sess.run(init) - - run_list = [model.sample_action, model.policy] - feed_dict = {model.batch_size: 2, - model.sequence_length: 1, - model.vector_in: np.array([[1, 2, 3, 1, 2, 3], - [3, 4, 5, 3, 4, 5]])} - sess.run(run_list, feed_dict=feed_dict) - env.close() + run_list = [model.sample_action, model.policy] + feed_dict = {model.batch_size: 2, + model.sequence_length: 1, + model.vector_in: np.array([[1, 2, 3, 1, 2, 3], + [3, 4, 5, 3, 4, 5]])} + sess.run(run_list, feed_dict=feed_dict) + env.close() -def test_dc_bc_model(): - d_action_c_state_start = '''{ - "AcademyName": "RealFakeAcademy", - "resetParameters": {}, - "brainNames": ["RealFakeBrain"], - "externalBrainNames": ["RealFakeBrain"], - "logPath":"RealFakePath", - "apiNumber":"API-3", - "brainParameters": [{ - "vectorObservationSize": 3, - "numStackedVectorObservations": 2, - "vectorActionSize": 2, - "memorySize": 0, - "cameraResolutions": [{"width":30,"height":40,"blackAndWhite":false}], - "vectorActionDescriptions": ["",""], - "vectorActionSpaceType": 0, - "vectorObservationSpaceType": 1 - }] - }'''.encode() - +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_dc_bc_model(mock_communicator, mock_launcher): tf.reset_default_graph() - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - with tf.Session() as sess: - with tf.variable_scope("FakeGraphScope"): - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = d_action_c_state_start - env = UnityEnvironment(' ') + with tf.Session() as sess: + with tf.variable_scope("FakeGraphScope"): + mock_communicator.return_value = MockCommunicator( + discrete=True, visual_input=True) + env = UnityEnvironment(' ') + model = BehavioralCloningModel(env.brains["RealFakeBrain"]) + init = tf.global_variables_initializer() + sess.run(init) - model = BehavioralCloningModel(env.brains["RealFakeBrain"]) - init = tf.global_variables_initializer() - sess.run(init) - - run_list = [model.sample_action, model.policy] - feed_dict = {model.batch_size: 2, - model.dropout_rate: 1.0, - model.sequence_length: 1, - model.vector_in: np.array([[1, 2, 3, 1, 2, 3], - [3, 4, 5, 3, 4, 5]]), - model.visual_in[0]: np.ones([2, 40, 30, 3])} - sess.run(run_list, feed_dict=feed_dict) - env.close() + run_list = [model.sample_action, model.policy] + feed_dict = {model.batch_size: 2, + model.dropout_rate: 1.0, + model.sequence_length: 1, + model.vector_in: np.array([[1, 2, 3, 1, 2, 3], + [3, 4, 5, 3, 4, 5]]), + model.visual_in[0]: np.ones([2, 40, 30, 3])} + sess.run(run_list, feed_dict=feed_dict) + env.close() if __name__ == '__main__': diff --git a/python/tests/test_ppo.py b/python/tests/test_ppo.py index c8a23317f..c45972c98 100644 --- a/python/tests/test_ppo.py +++ b/python/tests/test_ppo.py @@ -6,99 +6,56 @@ import tensorflow as tf from unitytrainers.ppo.models import PPOModel from unityagents import UnityEnvironment +from .mock_communicator import MockCommunicator -def test_ppo_model_continuous(): - c_action_c_state_start = '''{ - "AcademyName": "RealFakeAcademy", - "resetParameters": {}, - "brainNames": ["RealFakeBrain"], - "externalBrainNames": ["RealFakeBrain"], - "logPath":"RealFakePath", - "apiNumber":"API-3", - "brainParameters": [{ - "vectorObservationSize": 3, - "numStackedVectorObservations": 2, - "vectorActionSize": 2, - "memorySize": 0, - "cameraResolutions": [], - "vectorActionDescriptions": ["",""], - "vectorActionSpaceType": 1, - "vectorObservationSpaceType": 1 - }] - }'''.encode() - +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_ppo_model_continuous(mock_communicator, mock_launcher): tf.reset_default_graph() - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - # End of mock - with tf.Session() as sess: - with tf.variable_scope("FakeGraphScope"): - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = c_action_c_state_start - env = UnityEnvironment(' ') + with tf.Session() as sess: + with tf.variable_scope("FakeGraphScope"): + mock_communicator.return_value = MockCommunicator( + discrete=False, visual_input=False) + env = UnityEnvironment(' ') - model = PPOModel(env.brains["RealFakeBrain"]) - init = tf.global_variables_initializer() - sess.run(init) + model = PPOModel(env.brains["RealFakeBrain"]) + init = tf.global_variables_initializer() + sess.run(init) - run_list = [model.output, model.probs, model.value, model.entropy, - model.learning_rate] - feed_dict = {model.batch_size: 2, - model.sequence_length: 1, - model.vector_in: np.array([[1, 2, 3, 1, 2, 3], - [3, 4, 5, 3, 4, 5]])} - sess.run(run_list, feed_dict=feed_dict) - env.close() + run_list = [model.output, model.probs, model.value, model.entropy, + model.learning_rate] + feed_dict = {model.batch_size: 2, + model.sequence_length: 1, + model.vector_in: np.array([[1, 2, 3, 1, 2, 3], + [3, 4, 5, 3, 4, 5]])} + sess.run(run_list, feed_dict=feed_dict) + env.close() -def test_ppo_model_discrete(): - d_action_c_state_start = '''{ - "AcademyName": "RealFakeAcademy", - "resetParameters": {}, - "brainNames": ["RealFakeBrain"], - "externalBrainNames": ["RealFakeBrain"], - "logPath":"RealFakePath", - "apiNumber":"API-3", - "brainParameters": [{ - "vectorObservationSize": 3, - "numStackedVectorObservations": 2, - "vectorActionSize": 2, - "memorySize": 0, - "cameraResolutions": [{"width":30,"height":40,"blackAndWhite":false}], - "vectorActionDescriptions": ["",""], - "vectorActionSpaceType": 0, - "vectorObservationSpaceType": 1 - }] - }'''.encode() - +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_ppo_model_discrete(mock_communicator, mock_launcher): tf.reset_default_graph() - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - # End of mock - with tf.Session() as sess: - with tf.variable_scope("FakeGraphScope"): - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = d_action_c_state_start - env = UnityEnvironment(' ') - model = PPOModel(env.brains["RealFakeBrain"]) - init = tf.global_variables_initializer() - sess.run(init) + with tf.Session() as sess: + with tf.variable_scope("FakeGraphScope"): + mock_communicator.return_value = MockCommunicator( + discrete=True, visual_input=True) + env = UnityEnvironment(' ') + model = PPOModel(env.brains["RealFakeBrain"]) + init = tf.global_variables_initializer() + sess.run(init) - run_list = [model.output, model.all_probs, model.value, model.entropy, - model.learning_rate] - feed_dict = {model.batch_size: 2, - model.sequence_length: 1, - model.vector_in: np.array([[1, 2, 3, 1, 2, 3], - [3, 4, 5, 3, 4, 5]]), - model.visual_in[0]: np.ones([2, 40, 30, 3]) - } - sess.run(run_list, feed_dict=feed_dict) - env.close() + run_list = [model.output, model.all_probs, model.value, model.entropy, + model.learning_rate] + feed_dict = {model.batch_size: 2, + model.sequence_length: 1, + model.vector_in: np.array([[1, 2, 3, 1, 2, 3], + [3, 4, 5, 3, 4, 5]]), + model.visual_in[0]: np.ones([2, 40, 30, 3]) + } + sess.run(run_list, feed_dict=feed_dict) + env.close() if __name__ == '__main__': diff --git a/python/tests/test_unityagents.py b/python/tests/test_unityagents.py index 55e1aaf87..3d438afcb 100755 --- a/python/tests/test_unityagents.py +++ b/python/tests/test_unityagents.py @@ -7,80 +7,9 @@ import numpy as np from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, \ BrainInfo, Curriculum +from .mock_communicator import MockCommunicator -def append_length(partial_string): - return struct.pack("I", len(partial_string.encode())) + partial_string.encode() - - -dummy_start = '''{ - "AcademyName": "RealFakeAcademy", - "resetParameters": {}, - "brainNames": ["RealFakeBrain"], - "externalBrainNames": ["RealFakeBrain"], - "logPath":"RealFakePath", - "apiNumber":"API-3", - "brainParameters": [{ - "vectorObservationSize": 3, - "numStackedVectorObservations": 2, - "vectorActionSize": 2, - "memorySize": 0, - "cameraResolutions": [], - "vectorActionDescriptions": ["",""], - "vectorActionSpaceType": 1, - "vectorObservationSpaceType": 1 - }] -}'''.encode() - -dummy_reset = [ - 'CONFIG_REQUEST'.encode(), - append_length( - ''' - { - "brain_name": "RealFakeBrain", - "agents": [1,2], - "vectorObservations": [1,2,3,4,5,6,1,2,3,4,5,6], - "rewards": [1,2], - "previousVectorActions": [1,2,3,4], - "previousTextActions":["",""], - "memories": [], - "dones": [false, false], - "maxes": [false, false], - "textObservations" :[" "," "] - }'''), - append_length('END_OF_MESSAGE:False')] - -dummy_step = ['actions'.encode(), - append_length(''' -{ - "brain_name": "RealFakeBrain", - "agents": [1,2,3], - "vectorObservations": [1,2,3,4,5,6,7,8,9,1,2,3,4,5,6,7,8,9], - "rewards": [1,2,3], - "previousVectorActions": [1,2,3,4,5,6], - "previousTextActions":["","",""], - "memories": [], - "dones": [false, false, false], - "maxes": [false, false, false], - "textObservations" :[" "," ", " "] -}'''), - append_length('END_OF_MESSAGE:False'), - 'actions'.encode(), - append_length(''' -{ - "brain_name": "RealFakeBrain", - "agents": [1,2,3], - "vectorObservations": [1,2,3,4,5,6,7,8,9,1,2,3,4,5,6,7,8,9], - "rewards": [1,2,3], - "previousVectorActions": [1,2,3,4,5,6], - "previousTextActions":["","",""], - "memories": [], - "dones": [false, false, true], - "maxes": [false, false, false], - "textObservations" :[" "," ", " "] -}'''), - append_length('END_OF_MESSAGE:True')] - dummy_curriculum = json.loads('''{ "measure" : "reward", "thresholds" : [10, 20, 50], @@ -112,90 +41,81 @@ def test_handles_bad_filename(): UnityEnvironment(' ') -def test_initialization(): - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - env = UnityEnvironment(' ') - with pytest.raises(UnityActionException): - env.step([0]) - assert env.brain_names[0] == 'RealFakeBrain' - env.close() +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_initialization(mock_communicator, mock_launcher): + mock_communicator.return_value = MockCommunicator( + discrete=False, visual_input=False) + env = UnityEnvironment(' ') + with pytest.raises(UnityActionException): + env.step([0]) + assert env.brain_names[0] == 'RealFakeBrain' + env.close() -def test_reset(): - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - env = UnityEnvironment(' ') - brain = env.brains['RealFakeBrain'] - mock_socket.recv.side_effect = dummy_reset - brain_info = env.reset() - env.close() - assert not env.global_done - assert isinstance(brain_info, dict) - assert isinstance(brain_info['RealFakeBrain'], BrainInfo) - assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) - assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) - assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations - assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ - len(brain_info['RealFakeBrain'].agents) - assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ - brain.vector_observation_space_size * brain.num_stacked_vector_observations +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_reset(mock_communicator, mock_launcher): + mock_communicator.return_value = MockCommunicator( + discrete=False, visual_input=False) + env = UnityEnvironment(' ') + brain = env.brains['RealFakeBrain'] + brain_info = env.reset() + env.close() + assert not env.global_done + assert isinstance(brain_info, dict) + assert isinstance(brain_info['RealFakeBrain'], BrainInfo) + assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) + assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) + assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations + assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ + len(brain_info['RealFakeBrain'].agents) + assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ + brain.vector_observation_space_size * brain.num_stacked_vector_observations +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_step(mock_communicator, mock_launcher): + mock_communicator.return_value = MockCommunicator( + discrete=False, visual_input=False) + env = UnityEnvironment(' ') + brain = env.brains['RealFakeBrain'] + brain_info = env.reset() + brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) + with pytest.raises(UnityActionException): + env.step([0]) + brain_info = env.step([-1] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) + with pytest.raises(UnityActionException): + env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) + env.close() + assert env.global_done + assert isinstance(brain_info, dict) + assert isinstance(brain_info['RealFakeBrain'], BrainInfo) + assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) + assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) + assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations + assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ + len(brain_info['RealFakeBrain'].agents) + assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ + brain.vector_observation_space_size * brain.num_stacked_vector_observations -def test_step(): - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - env = UnityEnvironment(' ') - brain = env.brains['RealFakeBrain'] - mock_socket.recv.side_effect = dummy_reset - brain_info = env.reset() - mock_socket.recv.side_effect = dummy_step - brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) - with pytest.raises(UnityActionException): - env.step([0]) - brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) - with pytest.raises(UnityActionException): - env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) - env.close() - assert env.global_done - assert isinstance(brain_info, dict) - assert isinstance(brain_info['RealFakeBrain'], BrainInfo) - assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) - assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) - assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations - assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ - len(brain_info['RealFakeBrain'].agents) - assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ - brain.vector_observation_space_size * brain.num_stacked_vector_observations - assert not brain_info['RealFakeBrain'].local_done[0] - assert brain_info['RealFakeBrain'].local_done[2] + print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done)) + assert not brain_info['RealFakeBrain'].local_done[0] + assert brain_info['RealFakeBrain'].local_done[2] -def test_close(): - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - env = UnityEnvironment(' ') - assert env._loaded - env.close() - assert not env._loaded - mock_socket.close.assert_called_once() +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_close(mock_communicator, mock_launcher): + comm = MockCommunicator( + discrete=False, visual_input=False) + mock_communicator.return_value = comm + env = UnityEnvironment(' ') + assert env._loaded + env.close() + assert not env._loaded + assert comm.has_been_closed def test_curriculum(): diff --git a/python/tests/test_unitytrainers.py b/python/tests/test_unitytrainers.py index 977cc946e..edf3ddebb 100644 --- a/python/tests/test_unitytrainers.py +++ b/python/tests/test_unitytrainers.py @@ -8,6 +8,7 @@ from unitytrainers.models import * from unitytrainers.ppo.trainer import PPOTrainer from unitytrainers.bc.trainer import BehavioralCloningTrainer from unityagents import UnityEnvironmentException +from .mock_communicator import MockCommunicator dummy_start = '''{ "AcademyName": "RealFakeAcademy", @@ -100,74 +101,68 @@ default: ''') -def test_initialization(): - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, - 1, 1, 1, '', "tests/test_unitytrainers.py") - assert(tc.env.brain_names[0] == 'RealFakeBrain') +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_initialization(mock_communicator, mock_launcher): + mock_communicator.return_value = MockCommunicator( + discrete=True, visual_input=True) + tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, + 1, 1, 1, '', "tests/test_unitytrainers.py") + assert(tc.env.brain_names[0] == 'RealFakeBrain') -def test_load_config(): +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_load_config(mock_communicator, mock_launcher): open_name = 'unitytrainers.trainer_controller' + '.open' with mock.patch('yaml.load') as mock_load: with mock.patch(open_name, create=True) as _: - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_load.return_value = dummy_config - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - mock_load.return_value = dummy_config - tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, - 1, 1, 1, '','') - config = tc._load_config() - assert(len(config) == 1) - assert(config['default']['trainer'] == "ppo") + mock_load.return_value = dummy_config + mock_communicator.return_value = MockCommunicator( + discrete=True, visual_input=True) + mock_load.return_value = dummy_config + tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, + 1, 1, 1, '','') + config = tc._load_config() + assert(len(config) == 1) + assert(config['default']['trainer'] == "ppo") -def test_initialize_trainers(): +@mock.patch('unityagents.UnityEnvironment.executable_launcher') +@mock.patch('unityagents.UnityEnvironment.get_communicator') +def test_initialize_trainers(mock_communicator, mock_launcher): open_name = 'unitytrainers.trainer_controller' + '.open' with mock.patch('yaml.load') as mock_load: with mock.patch(open_name, create=True) as _: - with mock.patch('subprocess.Popen'): - with mock.patch('socket.socket') as mock_socket: - with mock.patch('glob.glob') as mock_glob: - mock_glob.return_value = ['FakeLaunchPath'] - mock_socket.return_value.accept.return_value = (mock_socket, 0) - mock_socket.recv.return_value.decode.return_value = dummy_start - tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, - 1, 1, 1, '', "tests/test_unitytrainers.py") + mock_communicator.return_value = MockCommunicator( + discrete=True, visual_input=True) + tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, + 1, 1, 1, '', "tests/test_unitytrainers.py") - # Test for PPO trainer - mock_load.return_value = dummy_config - config = tc._load_config() - tf.reset_default_graph() - with tf.Session() as sess: - tc._initialize_trainers(config, sess) - assert(len(tc.trainers) == 1) - assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer)) + # Test for PPO trainer + mock_load.return_value = dummy_config + config = tc._load_config() + tf.reset_default_graph() + with tf.Session() as sess: + tc._initialize_trainers(config, sess) + assert(len(tc.trainers) == 1) + assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer)) - # Test for Behavior Cloning Trainer - mock_load.return_value = dummy_bc_config - config = tc._load_config() - tf.reset_default_graph() - with tf.Session() as sess: - tc._initialize_trainers(config, sess) - assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer)) + # Test for Behavior Cloning Trainer + mock_load.return_value = dummy_bc_config + config = tc._load_config() + tf.reset_default_graph() + with tf.Session() as sess: + tc._initialize_trainers(config, sess) + assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer)) - # Test for proper exception when trainer name is incorrect - mock_load.return_value = dummy_bad_config - config = tc._load_config() - tf.reset_default_graph() - with tf.Session() as sess: - with pytest.raises(UnityEnvironmentException): - tc._initialize_trainers(config, sess) + # Test for proper exception when trainer name is incorrect + mock_load.return_value = dummy_bad_config + config = tc._load_config() + tf.reset_default_graph() + with tf.Session() as sess: + with pytest.raises(UnityEnvironmentException): + tc._initialize_trainers(config, sess) def assert_array(a, b): diff --git a/python/unityagents/communicator.py b/python/unityagents/communicator.py new file mode 100755 index 000000000..c1c566354 --- /dev/null +++ b/python/unityagents/communicator.py @@ -0,0 +1,37 @@ +import logging + +from communicator_objects import UnityOutput, UnityInput + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("unityagents") + + +class Communicator(object): + def __init__(self, worker_id=0, + base_port=5005): + """ + Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. + + :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. + :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. + """ + + def initialize(self, inputs: UnityInput) -> UnityOutput: + """ + Used to exchange initialization parameters between Python and the Environment + :param inputs: The initialization input that will be sent to the environment. + :return: UnityOutput: The initialization output sent by Unity + """ + + def exchange(self, inputs: UnityInput) -> UnityOutput: + """ + Used to send an input and receive an output from the Environment + :param inputs: The UnityInput that needs to be sent the Environment + :return: The UnityOutputs generated by the Environment + """ + + def close(self): + """ + Sends a shutdown signal to the unity environment, and closes the connection. + """ + diff --git a/python/unityagents/environment.py b/python/unityagents/environment.py index 9a1dc2e3b..a01ce4e09 100755 --- a/python/unityagents/environment.py +++ b/python/unityagents/environment.py @@ -1,27 +1,32 @@ import atexit -import io import glob -import json +import io import logging import numpy as np import os -import socket import subprocess -import struct from .brain import BrainInfo, BrainParameters, AllBrainInfo from .exception import UnityEnvironmentException, UnityActionException, UnityTimeOutException from .curriculum import Curriculum -from PIL import Image +from communicator_objects import UnityRLInput, UnityRLOutput, AgentActionProto,\ + EnvironmentParametersProto, UnityRLInitializationInput, UnityRLInitializationOutput,\ + UnityInput, UnityOutput + +from .rpc_communicator import RpcCommunicator +from .socket_communicator import SocketCommunicator + + from sys import platform +from PIL import Image logging.basicConfig(level=logging.INFO) logger = logging.getLogger("unityagents") class UnityEnvironment(object): - def __init__(self, file_name, worker_id=0, + def __init__(self, file_name=None, worker_id=0, base_port=5005, curriculum=None, seed=0, docker_training=False): """ @@ -35,146 +40,72 @@ class UnityEnvironment(object): :param docker_training: Informs this class whether the process is being run within a container. """ - atexit.register(self.close) + atexit.register(self._close) self.port = base_port + worker_id self._buffer_size = 12000 - self._version_ = "API-3" - self._loaded = False - self._open_socket = False + self._version_ = "API-4" + self._loaded = False # If true, this means the environment was successfully loaded + self.proc1 = None # The process that is started. If None, no process was started - try: - # Establish communication socket - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self._socket.bind(("localhost", self.port)) - self._open_socket = True - except socket.error: - self._open_socket = True - self.close() - raise socket.error("Couldn't launch new environment because worker number {} is still in use. " - "You may need to manually close a previously opened environment " - "or use a different worker number.".format(str(worker_id))) + self.communicator = self.get_communicator(worker_id, base_port) - cwd = os.getcwd() - file_name = (file_name.strip() - .replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', '')) - true_filename = os.path.basename(os.path.normpath(file_name)) - logger.debug('The true file name is {}'.format(true_filename)) - launch_string = None - if platform == "linux" or platform == "linux2": - candidates = glob.glob(os.path.join(cwd, file_name) + '.x86_64') - if len(candidates) == 0: - candidates = glob.glob(os.path.join(cwd, file_name) + '.x86') - if len(candidates) == 0: - candidates = glob.glob(file_name + '.x86_64') - if len(candidates) == 0: - candidates = glob.glob(file_name + '.x86') - if len(candidates) > 0: - launch_string = candidates[0] - - elif platform == 'darwin': - candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', true_filename)) - if len(candidates) == 0: - candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', true_filename)) - if len(candidates) == 0: - candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', '*')) - if len(candidates) == 0: - candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', '*')) - if len(candidates) > 0: - launch_string = candidates[0] - elif platform == 'win32': - candidates = glob.glob(os.path.join(cwd, file_name + '.exe')) - if len(candidates) == 0: - candidates = glob.glob(file_name + '.exe') - if len(candidates) > 0: - launch_string = candidates[0] - if launch_string is None: - self.close() - raise UnityEnvironmentException("Couldn't launch the {0} environment. " - "Provided filename does not match any environments." - .format(true_filename)) + # If the environment name is 'editor', a new environment will not be launched + # and the communicator will directly try to connect to an existing unity environment. + if file_name is not None: + self.executable_launcher(file_name, docker_training) else: - logger.debug("This is the launch string {}".format(launch_string)) - # Launch Unity environment - if docker_training == False: - proc1 = subprocess.Popen( - [launch_string, - '--port', str(self.port), - '--seed', str(seed)]) - else: - """ - Comments for future maintenance: - xvfb-run is a wrapper around Xvfb, a virtual xserver where all - rendering is done to virtual memory. It automatically creates a - new virtual server automatically picking a server number `auto-servernum`. - The server is passed the arguments using `server-args`, we are telling - Xvfb to create Screen number 0 with width 640, height 480 and depth 24 bits. - Note that 640 X 480 are the default width and height. The main reason for - us to add this is because we'd like to change the depth from the default - of 8 bits to 24. - Unfortunately, this means that we will need to pass the arguments through - a shell which is why we set `shell=True`. Now, this adds its own - complications. E.g SIGINT can bounce off the shell and not get propagated - to the child processes. This is why we add `exec`, so that the shell gets - launched, the arguments are passed to `xvfb-run`. `exec` replaces the shell - we created with `xvfb`. - """ - docker_ls = ("exec xvfb-run --auto-servernum" - " --server-args='-screen 0 640x480x24'" - " {0} --port {1} --seed {2}").format(launch_string, - str(self.port), - str(seed)) - proc1 = subprocess.Popen(docker_ls, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True) - self._socket.settimeout(30) - try: - try: - self._socket.listen(1) - self._conn, _ = self._socket.accept() - self._conn.settimeout(30) - p = self._conn.recv(self._buffer_size).decode('utf-8') - p = json.loads(p) - except socket.timeout as e: - raise UnityTimeOutException( - "The Unity environment took too long to respond. Make sure {} does not need user interaction to " - "launch and that the Academy and the external Brain(s) are attached to objects in the Scene." - .format(str(file_name))) + logger.info("Ready to connect with the Editor.") + self._loaded = True - if "apiNumber" not in p: - self._unity_version = "API-1" - else: - self._unity_version = p["apiNumber"] - if self._unity_version != self._version_: - raise UnityEnvironmentException( - "The API number is not compatible between Unity and python. Python API : {0}, Unity API : " - "{1}.\nPlease go to https://github.com/Unity-Technologies/ml-agents to download the latest version " - "of ML-Agents.".format(self._version_, self._unity_version)) - self._data = {} - self._global_done = None - self._academy_name = p["AcademyName"] - self._log_path = p["logPath"] - # Need to instantiate new AllBrainInfo - self._brains = {} - self._brain_names = p["brainNames"] - self._external_brain_names = p["externalBrainNames"] - self._external_brain_names = [] if self._external_brain_names is None else self._external_brain_names - self._num_brains = len(self._brain_names) - self._num_external_brains = len(self._external_brain_names) - self._resetParameters = p["resetParameters"] - self._curriculum = Curriculum(curriculum, self._resetParameters) - for i in range(self._num_brains): - self._brains[self._brain_names[i]] = BrainParameters(self._brain_names[i], p["brainParameters"][i]) - self._loaded = True - logger.info("\n'{0}' started successfully!\n{1}".format(self._academy_name, str(self))) - if self._num_external_brains == 0: - logger.warning(" No External Brains found in the Unity Environment. " - "You will not be able to pass actions to your agent(s).") - except UnityEnvironmentException: - proc1.kill() - self.close() + rl_init_parameters_in = UnityRLInitializationInput( + seed=seed + ) + try: + aca_params = self.send_academy_parameters(rl_init_parameters_in) + except UnityTimeOutException: + self._close() raise + # TODO : think of a better way to expose the academyParameters + self._unity_version = aca_params.version + if self._unity_version != self._version_: + raise UnityEnvironmentException( + "The API number is not compatible between Unity and python. Python API : {0}, Unity API : " + "{1}.\nPlease go to https://github.com/Unity-Technologies/ml-agents to download the latest version " + "of ML-Agents.".format(self._version_, self._unity_version)) + self._n_agents = {} + self._global_done = None + self._academy_name = aca_params.name + self._log_path = aca_params.log_path + self._brains = {} + self._brain_names = [] + self._external_brain_names = [] + for brain_param in aca_params.brain_parameters: + self._brain_names += [brain_param.brain_name] + resolution = [{ + "height": x.height, + "width": x.width, + "blackAndWhite": x.gray_scale + } for x in brain_param.camera_resolutions] + self._brains[brain_param.brain_name] = \ + BrainParameters(brain_param.brain_name, { + "vectorObservationSize": brain_param.vector_observation_size, + "numStackedVectorObservations": brain_param.num_stacked_vector_observations, + "cameraResolutions": resolution, + "vectorActionSize": brain_param.vector_action_size, + "vectorActionDescriptions": brain_param.vector_action_descriptions, + "vectorActionSpaceType": brain_param.vector_action_space_type, + "vectorObservationSpaceType": brain_param.vector_observation_space_type + }) + if brain_param.brain_type == 2: + self._external_brain_names += [brain_param.brain_name] + self._num_brains = len(self._brain_names) + self._num_external_brains = len(self._external_brain_names) + self._resetParameters = dict(aca_params.environment_parameters.float_parameters) # TODO + self._curriculum = Curriculum(curriculum, self._resetParameters) + logger.info("\n'{0}' started successfully!\n{1}".format(self._academy_name, str(self))) + if self._num_external_brains == 0: + logger.warning(" No External Brains found in the Unity Environment. " + "You will not be able to pass actions to your agent(s).") @property def curriculum(self): @@ -212,20 +143,81 @@ class UnityEnvironment(object): def external_brain_names(self): return self._external_brain_names - @staticmethod - def _process_pixels(image_bytes=None, bw=False): - """ - Converts byte array observation image into numpy array, re-sizes it, and optionally converts it to grey scale - :param image_bytes: input byte array corresponding to image - :return: processed numpy array of observation from environment - """ - s = bytearray(image_bytes) - image = Image.open(io.BytesIO(s)) - s = np.array(image) / 255.0 - if bw: - s = np.mean(s, axis=2) - s = np.reshape(s, [s.shape[0], s.shape[1], 1]) - return s + def executable_launcher(self, file_name, docker_training): + cwd = os.getcwd() + file_name = (file_name.strip() + .replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', '')) + true_filename = os.path.basename(os.path.normpath(file_name)) + logger.debug('The true file name is {}'.format(true_filename)) + launch_string = None + if platform == "linux" or platform == "linux2": + candidates = glob.glob(os.path.join(cwd, file_name) + '.x86_64') + if len(candidates) == 0: + candidates = glob.glob(os.path.join(cwd, file_name) + '.x86') + if len(candidates) == 0: + candidates = glob.glob(file_name + '.x86_64') + if len(candidates) == 0: + candidates = glob.glob(file_name + '.x86') + if len(candidates) > 0: + launch_string = candidates[0] + + elif platform == 'darwin': + candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', true_filename)) + if len(candidates) == 0: + candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', true_filename)) + if len(candidates) == 0: + candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', '*')) + if len(candidates) == 0: + candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', '*')) + if len(candidates) > 0: + launch_string = candidates[0] + elif platform == 'win32': + candidates = glob.glob(os.path.join(cwd, file_name + '.exe')) + if len(candidates) == 0: + candidates = glob.glob(file_name + '.exe') + if len(candidates) > 0: + launch_string = candidates[0] + if launch_string is None: + self._close() + raise UnityEnvironmentException("Couldn't launch the {0} environment. " + "Provided filename does not match any environments." + .format(true_filename)) + else: + logger.debug("This is the launch string {}".format(launch_string)) + # Launch Unity environment + if not docker_training: + self.proc1 = subprocess.Popen( + [launch_string, + '--port', str(self.port)]) + else: + """ + Comments for future maintenance: + xvfb-run is a wrapper around Xvfb, a virtual xserver where all + rendering is done to virtual memory. It automatically creates a + new virtual server automatically picking a server number `auto-servernum`. + The server is passed the arguments using `server-args`, we are telling + Xvfb to create Screen number 0 with width 640, height 480 and depth 24 bits. + Note that 640 X 480 are the default width and height. The main reason for + us to add this is because we'd like to change the depth from the default + of 8 bits to 24. + Unfortunately, this means that we will need to pass the arguments through + a shell which is why we set `shell=True`. Now, this adds its own + complications. E.g SIGINT can bounce off the shell and not get propagated + to the child processes. This is why we add `exec`, so that the shell gets + launched, the arguments are passed to `xvfb-run`. `exec` replaces the shell + we created with `xvfb`. + """ + docker_ls = ("exec xvfb-run --auto-servernum" + " --server-args='-screen 0 640x480x24'" + " {0} --port {1}").format(launch_string, str(self.port)) + self.proc1 = subprocess.Popen(docker_ls, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True) + + def get_communicator(self, worker_id, base_port): + return RpcCommunicator(worker_id, base_port) + # return SocketCommunicator(worker_id, base_port) def __str__(self): _new_reset_param = self._curriculum.get_config() @@ -241,40 +233,6 @@ class UnityEnvironment(object): for k in self._resetParameters])) + '\n' + \ '\n'.join([str(self._brains[b]) for b in self._brains]) - def _recv_bytes(self): - try: - s = self._conn.recv(self._buffer_size) - message_length = struct.unpack("I", bytearray(s[:4]))[0] - s = s[4:] - while len(s) != message_length: - s += self._conn.recv(self._buffer_size) - except socket.timeout as e: - raise UnityTimeOutException("The environment took too long to respond.", self._log_path) - return s - - def _get_state_image(self, bw): - """ - Receives observation from socket, and confirms. - :param bw: - :return: - """ - s = self._recv_bytes() - s = self._process_pixels(image_bytes=s, bw=bw) - self._conn.send(b"RECEIVED") - return s - - def _get_state_dict(self): - """ - Receives dictionary of state information from socket, and confirms. - :return: - """ - state = self._recv_bytes().decode('utf-8') - if state[:14] == "END_OF_MESSAGE": - return {}, state[15:] == 'True' - self._conn.send(b"RECEIVED") - state_dict = json.loads(state) - return state_dict, None - def reset(self, train_mode=True, config=None, lesson=None) -> AllBrainInfo: """ Sends a signal to reset the unity environment. @@ -295,111 +253,20 @@ class UnityEnvironment(object): raise UnityEnvironmentException("The parameter '{0}' is not a valid parameter.".format(k)) if self._loaded: - self._conn.send(b"RESET") - try: - self._conn.recv(self._buffer_size) - except socket.timeout as e: - raise UnityTimeOutException("The environment took too long to respond.", self._log_path) - self._conn.send(json.dumps({"train_model": train_mode, "parameters": config}).encode('utf-8')) - return self._get_state() + outputs = self.communicator.exchange( + self._generate_reset_input(train_mode, config) + ) + if outputs is None: + raise KeyboardInterrupt + rl_output = outputs.rl_output + s = self._get_state(rl_output) + self._global_done = s[1] + for _b in self._external_brain_names: + self._n_agents[_b] = len(s[0][_b].agents) + return s[0] else: raise UnityEnvironmentException("No Unity environment is loaded.") - def _get_state(self) -> AllBrainInfo: - """ - Collects experience information from all external brains in environment at current step. - :return: a dictionary of BrainInfo objects. - """ - self._data = {} - while True: - state_dict, end_of_message = self._get_state_dict() - if end_of_message is not None: - self._global_done = end_of_message - for _b in self._brain_names: - if _b not in self._data: - self._data[_b] = BrainInfo([], np.array([]), [], np.array([]), - [], [], [], np.array([]), [], max_reached=[]) - return self._data - b = state_dict["brain_name"] - n_agent = len(state_dict["agents"]) - try: - if self._brains[b].vector_observation_space_type == "continuous": - vector_obs = np.array(state_dict["vectorObservations"]).reshape( - (n_agent, self._brains[b].vector_observation_space_size - * self._brains[b].num_stacked_vector_observations)) - else: - vector_obs = np.array(state_dict["vectorObservations"]).reshape( - (n_agent, self._brains[b].num_stacked_vector_observations)) - except UnityActionException: - raise UnityActionException("Brain {0} has an invalid vector observation. " - "Expecting {1} {2} vector observations but received {3}." - .format(b, n_agent if self._brains[b].vector_observation_space_type == "discrete" - else str(self._brains[b].vector_observation_space_size * n_agent - * self._brains[b].num_stacked_vector_observations), - self._brains[b].vector_observation_space_type, - len(state_dict["vectorObservations"]))) - - memories = np.array(state_dict["memories"]).reshape((n_agent, -1)) - text_obs = state_dict["textObservations"] - rewards = state_dict["rewards"] - dones = state_dict["dones"] - agents = state_dict["agents"] - maxes = state_dict["maxes"] - - if n_agent > 0: - vector_actions = np.array(state_dict["previousVectorActions"]).reshape((n_agent, -1)) - text_actions = state_dict["previousTextActions"] - else: - vector_actions = np.array([]) - text_actions = [] - observations = [] - for o in range(self._brains[b].number_visual_observations): - obs_n = [] - for a in range(n_agent): - obs_n.append(self._get_state_image(self._brains[b].camera_resolutions[o]['blackAndWhite'])) - - observations.append(np.array(obs_n)) - self._data[b] = BrainInfo(observations, vector_obs, text_obs, memories, rewards, - agents, dones, vector_actions, text_actions, max_reached=maxes) - - def _send_action(self, vector_action ,memory, text_action): - """ - Send dictionary of actions, memories, and value estimates over socket. - :param vector_action: a dictionary of lists of vector actions. - :param memory: a dictionary of lists of of memories. - :param text_action: a dictionary of lists of text actions. - """ - try: - self._conn.recv(self._buffer_size) - except socket.timeout as e: - raise UnityTimeOutException("The environment took too long to respond.", self._log_path) - action_message = {"vector_action": vector_action, "memory": memory, "text_action": text_action} - self._conn.send(self._append_length(json.dumps(action_message).encode('utf-8'))) - - @staticmethod - def _append_length(message): - return struct.pack("I", len(message)) + message - - @staticmethod - def _flatten(arr): - """ - Converts dictionary of arrays to list for transmission over socket. - :param arr: numpy vector. - :return: flattened list. - """ - if isinstance(arr, (int, np.int_, float, np.float_)): - arr = [float(arr)] - if isinstance(arr, np.ndarray): - arr = arr.tolist() - if len(arr) == 0: - return arr - if isinstance(arr[0], np.ndarray): - arr = [item for sublist in arr for item in sublist.tolist()] - if isinstance(arr[0], list): - arr = [item for sublist in arr for item in sublist] - arr = [float(x) for x in arr] - return arr - def step(self, vector_action=None, memory=None, text_action=None) -> AllBrainInfo: """ Provides the environment with an action, moves the environment dynamics forward accordingly, and returns @@ -455,7 +322,7 @@ class UnityEnvironment(object): "in the environment".format(brain_name)) for b in self._external_brain_names: - n_agent = len(self._data[b].agents) + n_agent = self._n_agents[b] if b not in vector_action: # raise UnityActionException("You need to input an action for the brain {0}".format(b)) if self._brains[b].vector_action_space_type == "discrete": @@ -475,7 +342,7 @@ class UnityEnvironment(object): text_action[b] = [""] * n_agent else: if text_action[b] is None: - text_action[b] = [] + text_action[b] = [""] * n_agent if isinstance(text_action[b], str): text_action[b] = [text_action[b]] * n_agent if not ((len(text_action[b]) == n_agent) or len(text_action[b]) == 0): @@ -494,9 +361,17 @@ class UnityEnvironment(object): self._brains[b].vector_action_space_type, str(vector_action[b]))) - self._conn.send(b"STEP") - self._send_action(vector_action, memory, text_action) - return self._get_state() + outputs = self.communicator.exchange( + self._generate_step_input(vector_action, memory, text_action) + ) + if outputs is None: + raise KeyboardInterrupt + rl_output = outputs.rl_output + s = self._get_state(rl_output) + self._global_done = s[1] + for _b in self._external_brain_names: + self._n_agents[_b] = len(s[0][_b].agents) + return s[0] elif not self._loaded: raise UnityEnvironmentException("No Unity environment is loaded.") elif self._global_done: @@ -509,11 +384,120 @@ class UnityEnvironment(object): """ Sends a shutdown signal to the unity environment, and closes the socket connection. """ - if self._loaded & self._open_socket: - self._conn.send(b"EXIT") - self._conn.close() - if self._open_socket: - self._socket.close() - self._loaded = False + if self._loaded: + self._close() else: raise UnityEnvironmentException("No Unity environment is loaded.") + + def _close(self): + self._loaded = False + self.communicator.close() + if self.proc1 is not None: + self.proc1.kill() + + @staticmethod + def _flatten(arr): + """ + Converts arrays to list. + :param arr: numpy vector. + :return: flattened list. + """ + if isinstance(arr, (int, np.int_, float, np.float_)): + arr = [float(arr)] + if isinstance(arr, np.ndarray): + arr = arr.tolist() + if len(arr) == 0: + return arr + if isinstance(arr[0], np.ndarray): + arr = [item for sublist in arr for item in sublist.tolist()] + if isinstance(arr[0], list): + arr = [item for sublist in arr for item in sublist] + arr = [float(x) for x in arr] + return arr + + @staticmethod + def _process_pixels(image_bytes, gray_scale): + """ + Converts byte array observation image into numpy array, re-sizes it, and optionally converts it to grey scale + :param image_bytes: input byte array corresponding to image + :return: processed numpy array of observation from environment + """ + s = bytearray(image_bytes) + image = Image.open(io.BytesIO(s)) + s = np.array(image) / 255.0 + if gray_scale: + s = np.mean(s, axis=2) + s = np.reshape(s, [s.shape[0], s.shape[1], 1]) + return s + + def _get_state(self, output: UnityRLOutput) -> (AllBrainInfo, bool): + """ + Collects experience information from all external brains in environment at current step. + :return: a dictionary of BrainInfo objects. + """ + _data = {} + global_done = output.global_done + for b in output.agentInfos: + agent_info_list = output.agentInfos[b].value + vis_obs = [] + for i in range(self.brains[b].number_visual_observations): + obs = [ + self._process_pixels(x.visual_observations[i], self.brains[b].camera_resolutions[i]['blackAndWhite']) + for x in agent_info_list] + vis_obs += [np.array(obs)] + memory_size = max([len(x.memories) for x in agent_info_list]) + if memory_size == 0: + memory = np.zeros((0,0)) + else: + [x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list] + memory = np.array([x.memories for x in agent_info_list]) + _data[b] = BrainInfo( + visual_observation=vis_obs, + vector_observation=np.array([x.stacked_vector_observation for x in agent_info_list]), + text_observations=[x.text_observation for x in agent_info_list], + memory=memory, + reward=[x.reward for x in agent_info_list], + agents=[x.id for x in agent_info_list], + local_done=[x.done for x in agent_info_list], + vector_action=np.array([x.stored_vector_actions for x in agent_info_list]), + text_action=[x.stored_text_actions for x in agent_info_list], + max_reached=[x.max_step_reached for x in agent_info_list] + ) + return _data, global_done + + def _generate_step_input(self, vector_action, memory, text_action) -> UnityRLInput: + rl_in = UnityRLInput() + for b in vector_action: + n_agents = self._n_agents[b] + if n_agents == 0: + continue + _a_s = len(vector_action[b]) // n_agents + _m_s = len(memory[b]) // n_agents + for i in range(n_agents): + action = AgentActionProto( + vector_actions=vector_action[b][i*_a_s: (i+1)*_a_s], + memories=memory[b][i*_m_s: (i+1)*_m_s], + text_actions=text_action[b][i] + ) + rl_in.agent_actions[b].value.extend([action]) + rl_in.command = 0 + return self.wrap_unity_input(rl_in) + + def _generate_reset_input(self, training, config) -> UnityRLInput: + rl_in = UnityRLInput() + rl_in.is_training = training + rl_in.environment_parameters.CopyFrom(EnvironmentParametersProto()) + for key in config: + rl_in.environment_parameters.float_parameters[key] = config[key] + rl_in.command = 1 + return self.wrap_unity_input(rl_in) + + def send_academy_parameters(self, init_parameters: UnityRLInitializationInput) -> UnityRLInitializationOutput: + inputs = UnityInput() + inputs.rl_initialization_input.CopyFrom(init_parameters) + return self.communicator.initialize(inputs).rl_initialization_output + + def wrap_unity_input(self, rl_input: UnityRLInput) -> UnityOutput: + result = UnityInput() + result.rl_input.CopyFrom(rl_input) + return result diff --git a/python/unityagents/rpc_communicator.py b/python/unityagents/rpc_communicator.py new file mode 100755 index 000000000..95303bffe --- /dev/null +++ b/python/unityagents/rpc_communicator.py @@ -0,0 +1,97 @@ +import logging +import grpc + +from multiprocessing import Pipe +from concurrent.futures import ThreadPoolExecutor + +from .communicator import Communicator +from communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server +from communicator_objects import UnityMessage, UnityInput, UnityOutput +from .exception import UnityTimeOutException + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("unityagents") + + +class UnityToExternalServicerImplementation(UnityToExternalServicer): + parent_conn, child_conn = Pipe() + + def Initialize(self, request, context): + self.child_conn.send(request) + return self.child_conn.recv() + + def Exchange(self, request, context): + self.child_conn.send(request) + return self.child_conn.recv() + + +class RpcCommunicator(Communicator): + def __init__(self, worker_id=0, + base_port=5005): + """ + Python side of the grpc communication. Python is the server and Unity the client + + + :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. + :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. + """ + self.port = base_port + worker_id + self.worker_id = worker_id + self.server = None + self.unity_to_external = None + self.is_open = False + + def initialize(self, inputs: UnityInput) -> UnityOutput: + try: + # Establish communication grpc + self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) + self.unity_to_external = UnityToExternalServicerImplementation() + add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) + self.server.add_insecure_port('[::]:'+str(self.port)) + self.server.start() + except : + raise UnityTimeOutException( + "Couldn't start socket communication because worker number {} is still in use. " + "You may need to manually close a previously opened environment " + "or use a different worker number.".format(str(self.worker_id))) + if not self.unity_to_external.parent_conn.poll(30): + raise UnityTimeOutException( + "The Unity environment took too long to respond. Make sure that :\n" + "\t The environment does not need user interaction to launch\n" + "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" + "\t The environment and the Python interface have compatible versions.") + aca_param = self.unity_to_external.parent_conn.recv().unity_output + self.is_open = True + message = UnityMessage() + message.header.status = 200 + message.unity_input.CopyFrom(inputs) + self.unity_to_external.parent_conn.send(message) + self.unity_to_external.parent_conn.recv() + return aca_param + + def exchange(self, inputs: UnityInput) -> UnityOutput: + message = UnityMessage() + message.header.status = 200 + message.unity_input.CopyFrom(inputs) + self.unity_to_external.parent_conn.send(message) + output = self.unity_to_external.parent_conn.recv() + if output.header.status != 200: + return None + return output.unity_output + + def close(self): + """ + Sends a shutdown signal to the unity environment, and closes the grpc connection. + """ + if self.is_open: + message_input = UnityMessage() + message_input.header.status = 400 + self.unity_to_external.parent_conn.send(message_input) + self.unity_to_external.parent_conn.close() + self.server.stop(False) + self.is_open = False + + + + diff --git a/python/unityagents/socket_communicator.py b/python/unityagents/socket_communicator.py new file mode 100755 index 000000000..f55550767 --- /dev/null +++ b/python/unityagents/socket_communicator.py @@ -0,0 +1,98 @@ +import logging +import socket +import struct + +from .communicator import Communicator +from communicator_objects import UnityMessage, UnityOutput, UnityInput +from .exception import UnityTimeOutException + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("unityagents") + + +class SocketCommunicator(Communicator): + def __init__(self, worker_id=0, + base_port=5005): + """ + Python side of the socket communication + + :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. + :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. + """ + + self.port = base_port + worker_id + self._buffer_size = 12000 + self.worker_id = worker_id + self._socket = None + self._conn = None + + def initialize(self, inputs: UnityInput) -> UnityOutput: + try: + # Establish communication socket + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._socket.bind(("localhost", self.port)) + except: + raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. " + "You may need to manually close a previously opened environment " + "or use a different worker number.".format(str(self.worker_id))) + try: + self._socket.settimeout(30) + self._socket.listen(1) + self._conn, _ = self._socket.accept() + self._conn.settimeout(30) + except : + raise UnityTimeOutException( + "The Unity environment took too long to respond. Make sure that :\n" + "\t The environment does not need user interaction to launch\n" + "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" + "\t The environment and the Python interface have compatible versions.") + message = UnityMessage() + message.header.status = 200 + message.unity_input.CopyFrom(inputs) + self._communicator_send(message.SerializeToString()) + initialization_output = UnityMessage() + initialization_output.ParseFromString(self._communicator_receive()) + return initialization_output.unity_output + + def _communicator_receive(self): + try: + s = self._conn.recv(self._buffer_size) + message_length = struct.unpack("I", bytearray(s[:4]))[0] + s = s[4:] + while len(s) != message_length: + s += self._conn.recv(self._buffer_size) + except socket.timeout as e: + raise UnityTimeOutException("The environment took too long to respond.") + return s + + def _communicator_send(self, message): + self._conn.send(struct.pack("I", len(message)) + message) + + def exchange(self, inputs: UnityInput) -> UnityOutput: + message = UnityMessage() + message.header.status = 200 + message.unity_input.CopyFrom(inputs) + self._communicator_send(message.SerializeToString()) + outputs = UnityMessage() + outputs.ParseFromString(self._communicator_receive()) + if outputs.header.status != 200: + return None + return outputs.unity_output + + def close(self): + """ + Sends a shutdown signal to the unity environment, and closes the socket connection. + """ + if self._socket is not None and self._conn is not None: + message_input = UnityMessage() + message_input.header.status = 400 + self._communicator_send(message_input.SerializeToString()) + if self._socket is not None: + self._socket.close() + self._socket = None + if self._socket is not None: + self._conn.close() + self._conn = None + diff --git a/python/unitytrainers/trainer_controller.py b/python/unitytrainers/trainer_controller.py index e3458e5d8..890da8a2d 100644 --- a/python/unitytrainers/trainer_controller.py +++ b/python/unitytrainers/trainer_controller.py @@ -35,11 +35,12 @@ class TrainerController(object): :param trainer_config_path: Fully qualified path to location of trainer configuration file """ self.trainer_config_path = trainer_config_path - env_path = (env_path.strip() - .replace('.app', '') - .replace('.exe', '') - .replace('.x86_64', '') - .replace('.x86', '')) # Strip out executable extensions if passed + if env_path is not None: + env_path = (env_path.strip() + .replace('.app', '') + .replace('.exe', '') + .replace('.x86_64', '') + .replace('.x86', '')) # Strip out executable extensions if passed # Recognize and use docker volume if one is passed as an argument if docker_target_name == '': self.docker_training = False @@ -51,8 +52,9 @@ class TrainerController(object): self.model_path = '/{docker_target_name}/models/{run_id}'.format( docker_target_name=docker_target_name, run_id=run_id) - env_path = '/{docker_target_name}/{env_name}'.format(docker_target_name=docker_target_name, - env_name=env_path) + if env_path is not None : + env_path = '/{docker_target_name}/{env_name}'.format(docker_target_name=docker_target_name, + env_name=env_path) if curriculum_file is None: self.curriculum_file = None else: @@ -78,7 +80,10 @@ class TrainerController(object): self.env = UnityEnvironment(file_name=env_path, worker_id=self.worker_id, curriculum=self.curriculum_file, seed=self.seed, docker_training=self.docker_training) - self.env_name = os.path.basename(os.path.normpath(env_path)) # Extract out name of environment + if env_path is None: + self.env_name = 'editor_'+self.env.academy_name + else: + self.env_name = os.path.basename(os.path.normpath(env_path)) # Extract out name of environment def _get_progress(self): if self.curriculum_file is not None: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer.meta new file mode 100644 index 000000000..af0fdcb10 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: e44343d7e31b04d47bd5f7329c918ffe +folderAsset: yes +timeCreated: 1521839636 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll new file mode 100755 index 000000000..6ea720de8 Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll.meta new file mode 100644 index 000000000..e08504227 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll.meta @@ -0,0 +1,30 @@ +fileFormatVersion: 2 +guid: 0836ffd04a4924861a2d58aa4b111937 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Grpc.Core.dll b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Grpc.Core.dll new file mode 100644 index 000000000..601f87c27 Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Grpc.Core.dll differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Grpc.Core.dll.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Grpc.Core.dll.meta new file mode 100644 index 000000000..2a461c726 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/Grpc.Core.dll.meta @@ -0,0 +1,30 @@ +fileFormatVersion: 2 +guid: cbf24ddeec4054edc9ad4c8295556878 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/System.Interactive.Async.dll b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/System.Interactive.Async.dll new file mode 100755 index 000000000..364a99c32 Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/System.Interactive.Async.dll differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta new file mode 100644 index 000000000..1ee8b2e13 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta @@ -0,0 +1,32 @@ +fileFormatVersion: 2 +guid: 9502ce7e38c5947dba996570732b6e9f +timeCreated: 1521661784 +licenseType: Free +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes.meta new file mode 100644 index 000000000..6995400ae --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: b8022add2e5264884a117894eeaf9809 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux.meta new file mode 100644 index 000000000..97848b129 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: 50c3602c6f6244621861928757e31463 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native.meta new file mode 100644 index 000000000..a8b33def0 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: ba192b1e561564e1583e0a87334f8682 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so new file mode 100755 index 000000000..9bf86dc2d Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta new file mode 100644 index 000000000..62496d62f --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta @@ -0,0 +1,102 @@ +fileFormatVersion: 2 +guid: c9d901caf522f4dc5815786fa764a5da +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + '': Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 1 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 1 + Exclude Win: 0 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86_64 + DefaultValueInitialized: true + OS: Linux + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: x86_64 + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: x86_64 + - first: + Standalone: OSXUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so new file mode 100755 index 000000000..fce304168 Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta new file mode 100644 index 000000000..f612ded0b --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta @@ -0,0 +1,102 @@ +fileFormatVersion: 2 +guid: 7dfb52431a6d941c89758cf0a217e3ab +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + '': Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 1 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 1 + Exclude Win: 0 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86 + DefaultValueInitialized: true + OS: Linux + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: OSXUniversal + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx.meta new file mode 100644 index 000000000..69cbe8ef6 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: f43fa6e62fb4c4105b270be1ae7bbbfd +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native.meta new file mode 100644 index 000000000..24fab959d --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: 55aee008fb6a3411aa96f2f9911f9207 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle new file mode 100755 index 000000000..58390e6cb Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta new file mode 100644 index 000000000..61d5f9812 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta @@ -0,0 +1,106 @@ +fileFormatVersion: 2 +guid: 7eeb863bd08ba4388829c23da03a714f +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + '': Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 1 + Exclude Linux64: 1 + Exclude LinuxUniversal: 1 + Exclude OSXUniversal: 0 + Exclude Win: 1 + Exclude Win64: 1 + Exclude iOS: 1 + - first: + '': OSXIntel + second: + enabled: 1 + settings: {} + - first: + '': OSXIntel64 + second: + enabled: 1 + settings: {} + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86_64 + DefaultValueInitialized: true + OS: OSX + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 0 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 0 + settings: + CPU: x86_64 + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win.meta new file mode 100644 index 000000000..b1e54c9a4 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: a961485c3484a4002ac4961a8481f6cc +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native.meta new file mode 100644 index 000000000..42e4968ae --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native.meta @@ -0,0 +1,10 @@ +fileFormatVersion: 2 +guid: af9f9f367bbc543b8ba41e58dcdd6e66 +folderAsset: yes +timeCreated: 1521595360 +licenseType: Free +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll new file mode 100755 index 000000000..b2e48711b Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta new file mode 100644 index 000000000..888979c7c --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta @@ -0,0 +1,102 @@ +fileFormatVersion: 2 +guid: f4d9429fe43154fbd9d158c129e0ff33 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + '': Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 0 + Exclude Win: 1 + Exclude Win64: 0 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86_64 + DefaultValueInitialized: true + OS: Windows + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: None + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: x86_64 + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Win64 + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll new file mode 100755 index 000000000..45d5c324a Binary files /dev/null and b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll differ diff --git a/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta new file mode 100644 index 000000000..9c7036f37 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta @@ -0,0 +1,102 @@ +fileFormatVersion: 2 +guid: d74134114def74fb4ae781c015deaa95 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + isPreloaded: 0 + isOverridable: 0 + platformData: + - first: + '': Any + second: + enabled: 0 + settings: + Exclude Android: 1 + Exclude Editor: 0 + Exclude Linux: 0 + Exclude Linux64: 0 + Exclude LinuxUniversal: 0 + Exclude OSXUniversal: 0 + Exclude Win: 0 + Exclude Win64: 1 + Exclude iOS: 1 + - first: + Android: Android + second: + enabled: 0 + settings: + CPU: ARMv7 + - first: + Any: + second: + enabled: 0 + settings: {} + - first: + Editor: Editor + second: + enabled: 1 + settings: + CPU: x86 + DefaultValueInitialized: true + OS: Windows + - first: + Facebook: Win + second: + enabled: 0 + settings: + CPU: AnyCPU + - first: + Facebook: Win64 + second: + enabled: 0 + settings: + CPU: None + - first: + Standalone: Linux + second: + enabled: 1 + settings: + CPU: x86 + - first: + Standalone: Linux64 + second: + enabled: 1 + settings: + CPU: x86_64 + - first: + Standalone: LinuxUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: OSXUniversal + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win + second: + enabled: 1 + settings: + CPU: AnyCPU + - first: + Standalone: Win64 + second: + enabled: 0 + settings: + CPU: None + - first: + iPhone: iOS + second: + enabled: 0 + settings: + AddToEmbeddedBinaries: false + CompileFlags: + FrameworkDependencies: + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/Academy.cs b/unity-environment/Assets/ML-Agents/Scripts/Academy.cs index 6e58022ab..3ba887ac3 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/Academy.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/Academy.cs @@ -1,5 +1,10 @@ using System.Collections.Generic; using UnityEngine; +using System.IO; +using System.Linq; +#if UNITY_EDITOR +using UnityEditor; +#endif /** * Welcome to Unity Machine Learning Agents (ML-Agents). @@ -87,6 +92,8 @@ public class EnvironmentConfiguration "docs/Learning-Environment-Design-Academy.md")] public abstract class Academy : MonoBehaviour { + private const string kApiVersion = "API-4"; + // Fields provided in the Inspector [SerializeField] @@ -130,6 +137,11 @@ public abstract class Academy : MonoBehaviour /// Training or Inference mode. bool isCommunicatorOn; + /// Keeps track of the id of the last communicator message received. + /// Remains 0 if there are no communicators. Is used to ensure that + /// the same message is not used multiple times. + private ulong lastCommunicatorMessageNumber; + /// If true, the Academy will use inference settings. This field is /// initialized in depending on the presence /// or absence of a communicator. Furthermore, it can be modified by an @@ -159,8 +171,15 @@ public abstract class Academy : MonoBehaviour /// engine settings at the next environment step. bool modeSwitched; - /// Pointer to the communicator currently in use by the Academy. - Communicator communicator; + /// Pointer to the batcher currently in use by the Academy. + MLAgents.Batcher brainBatcher; + + /// Used to write error messages. + StreamWriter logWriter; + + /// The path to where the log should be written. + string logPath; + // Flag used to keep track of the first time the Academy is reset. bool firstAcademyReset; @@ -208,32 +227,100 @@ public abstract class Academy : MonoBehaviour InitializeEnvironment(); } + // Used to read Python-provided environment parameters + private int ReadArgs() + { + var args = System.Environment.GetCommandLineArgs(); + var inputPort = ""; + for (var i = 0; i < args.Length; i++) + { + if (args[i] == "--port") + { + inputPort = args[i + 1]; + } + } + return int.Parse(inputPort); + } + /// /// Initializes the environment, configures it and initialized the Academy. /// - void InitializeEnvironment() + private void InitializeEnvironment() { // Retrieve Brain and initialize Academy - List brains = GetBrains(gameObject); + var brains = GetBrains(gameObject); InitializeAcademy(); + MLAgents.Communicator communicator= null; - // Check for existence of communicator - communicator = new ExternalCommunicator(this); - if (!communicator.CommunicatorHandShake()) + // Try to launch the communicator by usig the arguments passed at launch + try + { + communicator = new MLAgents.RPCCommunicator( + new MLAgents.CommunicatorParameters + { + port = ReadArgs() + }); + } + // If it fails, we check if there are any external brains in the scene + // If there are : Launch the communicator on the default port + // If there arn't, there is no need for a communicator and it is set + // to null + catch { communicator = null; + var externalBrain = brains.FirstOrDefault(b => b.brainType == BrainType.External); + if (externalBrain != null) + { + communicator = new MLAgents.RPCCommunicator( + new MLAgents.CommunicatorParameters + { + port = 5005 + }); + } + } + brainBatcher = new MLAgents.Batcher(communicator); + // Initialize Brains and communicator (if present) - foreach (Brain brain in brains) + foreach (var brain in brains) { - brain.InitializeBrain(this, communicator); + brain.InitializeBrain(this, brainBatcher); } if (communicator != null) { isCommunicatorOn = true; - communicator.InitializeCommunicator(); - communicator.UpdateCommand(); + + var academyParameters = new MLAgents.CommunicatorObjects.UnityRLInitializationOutput(); + academyParameters.Name = gameObject.name; + academyParameters.Version = kApiVersion; + foreach (var brain in brains) + { + var bp = brain.brainParameters; + academyParameters.BrainParameters.Add( + MLAgents.Batcher.BrainParametersConvertor( + bp, + brain.gameObject.name, + (MLAgents.CommunicatorObjects.BrainTypeProto) + brain.brainType)); + + } + academyParameters.EnvironmentParameters = + new MLAgents.CommunicatorObjects.EnvironmentParametersProto(); + foreach (var key in resetParameters.Keys) + { + academyParameters.EnvironmentParameters.FloatParameters.Add( + key, resetParameters[key] + ); + } + var pythonParameters = brainBatcher.SendAcademyParameters(academyParameters); + Random.InitState(pythonParameters.Seed); + Application.logMessageReceived += HandleLog; + logPath = Path.GetFullPath(".") + "/unity-environment.log"; + logWriter = new StreamWriter(logPath, false); + logWriter.WriteLine(System.DateTime.Now.ToString()); + logWriter.WriteLine(" "); + logWriter.Close(); } // If a communicator is enabled/provided, then we assume we are in @@ -253,6 +340,15 @@ public abstract class Academy : MonoBehaviour ConfigureEnvironment(); } + void HandleLog(string logString, string stackTrace, LogType type) + { + logWriter = new StreamWriter(logPath, true); + logWriter.WriteLine(type.ToString()); + logWriter.WriteLine(logString); + logWriter.WriteLine(stackTrace); + logWriter.Close(); + } + /// /// Configures the environment settings depending on the training/inference /// mode and the corresponding parameters passed in the Editor. @@ -398,15 +494,6 @@ public abstract class Academy : MonoBehaviour return isCommunicatorOn; } - /// - /// Returns the Communicator currently used by the Academy. - /// - /// The commincator currently in use (may be null). - public Communicator GetCommunicator() - { - return communicator; - } - /// /// Forces the full reset. The done flags are not affected. Is either /// called the first reset at inference and every external reset @@ -430,24 +517,32 @@ public abstract class Academy : MonoBehaviour ConfigureEnvironment(); modeSwitched = false; } - - if (isCommunicatorOn) + if ((isCommunicatorOn) && + (lastCommunicatorMessageNumber != brainBatcher.GetNumberMessageReceived())) { - if (communicator.GetCommand() == ExternalCommand.RESET) + lastCommunicatorMessageNumber = brainBatcher.GetNumberMessageReceived(); + if (brainBatcher.GetCommand() == + MLAgents.CommunicatorObjects.CommandProto.Reset) { // Update reset parameters. - Dictionary NewResetParameters = - communicator.GetResetParameters(); - foreach (KeyValuePair kv in NewResetParameters) + var newResetParameters = brainBatcher.GetEnvironmentParameters(); + if (newResetParameters != null) { - resetParameters[kv.Key] = kv.Value; + foreach (var kv in newResetParameters.FloatParameters) + { + resetParameters[kv.Key] = kv.Value; + } } + SetIsInference(!brainBatcher.GetIsTraining()); ForcedFullReset(); - communicator.SetCommand(ExternalCommand.STEP); } - if (communicator.GetCommand() == ExternalCommand.QUIT) + if (brainBatcher.GetCommand() == + MLAgents.CommunicatorObjects.CommandProto.Quit) { +#if UNITY_EDITOR + EditorApplication.isPlaying = false; +#endif Application.Quit(); return; } @@ -465,6 +560,8 @@ public abstract class Academy : MonoBehaviour AgentSetStatus(maxStepReached, done, stepCount); + brainBatcher.RegisterAcademyDoneFlag(done); + if (done) { EnvironmentReset(); @@ -480,12 +577,6 @@ public abstract class Academy : MonoBehaviour AgentAct(); - if (done) - { - done = false; - maxStepReached = false; - } - stepCount += 1; } @@ -496,6 +587,8 @@ public abstract class Academy : MonoBehaviour { stepCount = 0; episodeCount++; + done = false; + maxStepReached = false; AcademyReset(); } diff --git a/unity-environment/Assets/ML-Agents/Scripts/Agent.cs b/unity-environment/Assets/ML-Agents/Scripts/Agent.cs index 2a1c6158e..e70d1da80 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/Agent.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/Agent.cs @@ -456,6 +456,8 @@ public abstract class Agent : MonoBehaviour action.vectorActions = new float[1]; info.storedVectorActions = new float[1]; } + if (info.textObservation==null) + info.textObservation = ""; action.textActions = ""; info.memories = new List(); action.memories = new List(); diff --git a/unity-environment/Assets/ML-Agents/Scripts/Batcher.cs b/unity-environment/Assets/ML-Agents/Scripts/Batcher.cs new file mode 100644 index 000000000..a994974d6 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/Batcher.cs @@ -0,0 +1,352 @@ +using System.Collections.Generic; +using System.Linq; +using UnityEngine; +using Google.Protobuf; + +namespace MLAgents +{ + /// + /// The batcher is an RL specific class that makes sure that the information each object in + /// Unity (Academy and Brains) wants to send to External is appropriately batched together + /// and sent only when necessary. + /// + /// The Batcher will only send a Message to the Communicator when either : + /// 1 - The academy is done + /// 2 - At least one brain has data to send + /// + /// At each step, the batcher will keep track of the brains that queried the batcher for that + /// step. The batcher can only send the batched data when all the Brains have queried the + /// Batcher. + /// + public class Batcher + { + /// The default number of agents in the scene + private const int NumAgents = 32; + + /// Keeps track of which brains have data to send on the current step + Dictionary m_hasData = + new Dictionary(); + /// Keeps track of which brains queried the batcher on the current step + Dictionary m_hasQueried = + new Dictionary(); + /// Keeps track of the agents of each brain on the current step + Dictionary> m_currentAgents = + new Dictionary>(); + /// The Communicator of the batcher, sends a message at most once per step + Communicator m_communicator; + /// The current UnityRLOutput to be sent when all the brains queried the batcher + CommunicatorObjects.UnityRLOutput m_currentUnityRLOutput = + new CommunicatorObjects.UnityRLOutput(); + /// Keeps track of the done flag of the Academy + bool m_academyDone; + /// Keeps track of last CommandProto sent by External + CommunicatorObjects.CommandProto m_command; + /// Keeps track of last EnvironmentParametersProto sent by External + CommunicatorObjects.EnvironmentParametersProto m_environmentParameters; + /// Keeps track of last training mode sent by External + bool m_isTraining; + + /// Keeps track of the number of messages received + private ulong m_messagesReceived; + + /// + /// Initializes a new instance of the Batcher class. + /// + /// The communicator to be used by the batcher. + public Batcher(Communicator communicator) + { + this.m_communicator = communicator; + } + + /// + /// Sends the academy parameters through the Communicator. + /// Is used by the academy to send the AcademyParameters to the communicator. + /// + /// The External Initialization Parameters received. + /// The Unity Initialization Paramters to be sent. + public CommunicatorObjects.UnityRLInitializationInput SendAcademyParameters( + CommunicatorObjects.UnityRLInitializationOutput academyParameters) + { + CommunicatorObjects.UnityInput input; + var initializationInput = new CommunicatorObjects.UnityInput(); + try + { + initializationInput = m_communicator.Initialize( + new CommunicatorObjects.UnityOutput + { + RlInitializationOutput = academyParameters + }, + out input); + } + catch + { + throw new UnityAgentsException( + "The Communicator was unable to connect. Please make sure the External " + + "process is ready to accept communication with Unity."); + } + + var firstRlInput = input.RlInput; + m_command = firstRlInput.Command; + m_environmentParameters = firstRlInput.EnvironmentParameters; + m_isTraining = firstRlInput.IsTraining; + return initializationInput.RlInitializationInput; + } + + /// + /// Registers the done flag of the academy to the next output to be sent + /// to the communicator. + /// + /// If set to true + /// The academy done state will be sent to External at the next Exchange. + public void RegisterAcademyDoneFlag(bool done) + { + m_academyDone = done; + } + + /// + /// Gets the command. Is used by the academy to get reset or quit signals. + /// + /// The current command. + public CommunicatorObjects.CommandProto GetCommand() + { + return m_command; + } + + /// + /// Gets the number of messages received so far. Can be used to check for new messages. + /// + /// The number of messages received since start of the simulation + public ulong GetNumberMessageReceived() + { + return m_messagesReceived; + } + + /// + /// Gets the environment parameters. Is used by the academy to update + /// the environment parameters. + /// + /// The environment parameters. + public CommunicatorObjects.EnvironmentParametersProto GetEnvironmentParameters() + { + return m_environmentParameters; + } + + /// + /// Gets the last training_mode flag External sent + /// + /// true, if training mode is requested, false otherwise. + public bool GetIsTraining() + { + return m_isTraining; + } + + /// + /// Adds the brain to the list of brains which will be sending information to External. + /// + /// Brain key. + public void SubscribeBrain(string brainKey) + { + m_hasQueried[brainKey] = false; + m_hasData[brainKey] = false; + m_currentAgents[brainKey] = new List(NumAgents); + m_currentUnityRLOutput.AgentInfos.Add( + brainKey, + new CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto()); + } + + /// + /// Converts a AgentInfo to a protobuffer generated AgentInfoProto + /// + /// The protobuf verison of the AgentInfo. + /// The AgentInfo to convert. + public static CommunicatorObjects.AgentInfoProto + AgentInfoConvertor(AgentInfo info) + { + + var agentInfoProto = new CommunicatorObjects.AgentInfoProto + { + StackedVectorObservation = { info.stackedVectorObservation }, + StoredVectorActions = { info.storedVectorActions }, + Memories = { info.memories }, + StoredTextActions = info.storedTextActions, + TextObservation = info.textObservation, + Reward = info.reward, + MaxStepReached = info.maxStepReached, + Done = info.done, + Id = info.id, + }; + foreach (Texture2D obs in info.visualObservations) + { + agentInfoProto.VisualObservations.Add( + ByteString.CopyFrom(obs.EncodeToJPG()) + ); + } + return agentInfoProto; + } + + /// + /// Converts a Brain into to a Protobuff BrainInfoProto so it can be sent + /// + /// The BrainInfoProto generated. + /// The BrainParameters. + /// The name of the brain. + /// The type of brain. + public static CommunicatorObjects.BrainParametersProto BrainParametersConvertor( + BrainParameters brainParameters, string name, CommunicatorObjects.BrainTypeProto type) + { + + var brainParametersProto = new CommunicatorObjects.BrainParametersProto + { + VectorObservationSize = brainParameters.vectorObservationSize, + NumStackedVectorObservations = brainParameters.numStackedVectorObservations, + VectorActionSize = brainParameters.vectorActionSize, + VectorActionSpaceType = + (CommunicatorObjects.SpaceTypeProto)brainParameters.vectorActionSpaceType, + VectorObservationSpaceType = + (CommunicatorObjects.SpaceTypeProto)brainParameters.vectorObservationSpaceType, + BrainName = name, + BrainType = type + }; + brainParametersProto.VectorActionDescriptions.AddRange( + brainParameters.vectorActionDescriptions); + foreach (resolution res in brainParameters.cameraResolutions) + { + brainParametersProto.CameraResolutions.Add( + new CommunicatorObjects.ResolutionProto + { + Width = res.width, + Height = res.height, + GrayScale = res.blackAndWhite + }); + } + return brainParametersProto; + } + + /// + /// Sends the brain info. If at least one brain has an agent in need of + /// a decision or if the academy is done, the data is sent via + /// Communicator. Else, a new step is realized. The data can only be + /// sent once all the brains that subscribed to the batcher have tried + /// to send information. + /// + /// Brain key. + /// Agent info. + public void SendBrainInfo( + string brainKey, Dictionary agentInfo) + { + // If no communicator is initialized, the Batcher will not transmit + // BrainInfo + if (m_communicator == null) + { + return; + } + + // The brain tried called GiveBrainInfo, update m_hasQueried + m_hasQueried[brainKey] = true; + // Populate the currentAgents dictionary + m_currentAgents[brainKey].Clear(); + foreach (Agent agent in agentInfo.Keys) + { + m_currentAgents[brainKey].Add(agent); + } + // If at least one agent has data to send, then append data to + // the message and update hasSentState + if (m_currentAgents[brainKey].Count > 0) + { + foreach (Agent agent in m_currentAgents[brainKey]) + { + CommunicatorObjects.AgentInfoProto agentInfoProto = + AgentInfoConvertor(agentInfo[agent]); + m_currentUnityRLOutput.AgentInfos[brainKey].Value.Add(agentInfoProto); + } + m_hasData[brainKey] = true; + } + + // If any agent needs to send data, then the whole message + // must be sent + if (m_hasQueried.Values.All(x => x)) + { + if (m_hasData.Values.Any(x => x) || m_academyDone) + { + m_currentUnityRLOutput.GlobalDone = m_academyDone; + SendBatchedMessageHelper(); + } + // The message was just sent so we must reset hasSentState and + // triedSendState + foreach (string k in m_currentAgents.Keys) + { + m_hasData[k] = false; + m_hasQueried[k] = false; + } + } + } + + /// + /// Helper method that sends the curent UnityRLOutput, receives the next UnityInput and + /// Applies the appropriate AgentAction to the agents. + /// + void SendBatchedMessageHelper() + { + var input = m_communicator.Exchange( + new CommunicatorObjects.UnityOutput{ + RlOutput = m_currentUnityRLOutput + }); + m_messagesReceived += 1; + + foreach (string k in m_currentUnityRLOutput.AgentInfos.Keys) + { + m_currentUnityRLOutput.AgentInfos[k].Value.Clear(); + } + if (input == null) + { + m_command = CommunicatorObjects.CommandProto.Quit; + return; + } + + CommunicatorObjects.UnityRLInput rlInput = input.RlInput; + + if (rlInput == null) + { + m_command = CommunicatorObjects.CommandProto.Quit; + return; + } + + m_command = rlInput.Command; + m_environmentParameters = rlInput.EnvironmentParameters; + m_isTraining = rlInput.IsTraining; + + if (rlInput.AgentActions == null) + { + return; + } + + foreach (var brainName in rlInput.AgentActions.Keys) + { + if (!m_currentAgents[brainName].Any()) + { + continue; + } + if (!rlInput.AgentActions[brainName].Value.Any()) + { + continue; + } + for (var i = 0; i < m_currentAgents[brainName].Count(); i++) + { + var agent = m_currentAgents[brainName][i]; + var action = rlInput.AgentActions[brainName].Value[i]; + agent.UpdateVectorAction( + action.VectorActions.ToArray()); + agent.UpdateMemoriesAction( + action.Memories.ToList()); + agent.UpdateTextAction( + action.TextActions); + } + } + + } + + } +} + + + diff --git a/unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/Batcher.cs.meta old mode 100755 new mode 100644 similarity index 70% rename from unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs.meta rename to unity-environment/Assets/ML-Agents/Scripts/Batcher.cs.meta index 27fe3cd73..e7a87c640 --- a/unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs.meta +++ b/unity-environment/Assets/ML-Agents/Scripts/Batcher.cs.meta @@ -1,8 +1,9 @@ fileFormatVersion: 2 -guid: 9685de855ca1541409f4187c5ab7601d -timeCreated: 1504820023 +guid: 4243d5dc0ad5746cba578575182f8c17 +timeCreated: 1523045876 licenseType: Free MonoImporter: + externalObjects: {} serializedVersion: 2 defaultReferences: [] executionOrder: 0 diff --git a/unity-environment/Assets/ML-Agents/Scripts/Brain.cs b/unity-environment/Assets/ML-Agents/Scripts/Brain.cs index 33ecbf489..be279fd8a 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/Brain.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/Brain.cs @@ -1,4 +1,4 @@ -using System.Collections; +using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI; @@ -203,10 +203,10 @@ public class Brain : MonoBehaviour } /// This is called by the Academy at the start of the environemnt. - public void InitializeBrain(Academy aca, Communicator communicator) + public void InitializeBrain(Academy aca, MLAgents.Batcher brainBatcher) { UpdateCoreBrains(); - coreBrain.InitializeCoreBrain(communicator); + coreBrain.InitializeCoreBrain(brainBatcher); aca.BrainDecideAction += DecideAction; isInitialized = true; } diff --git a/unity-environment/Assets/ML-Agents/Scripts/Communicator.cs b/unity-environment/Assets/ML-Agents/Scripts/Communicator.cs index 0e3eb4015..61f2f50ce 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/Communicator.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/Communicator.cs @@ -1,80 +1,74 @@ using System.Collections; using System.Collections.Generic; using UnityEngine; +using MLAgents.CommunicatorObjects; -/** \brief AcademyParameters is a structure containing basic information about the - * training environment. */ -/** The AcademyParameters will be sent via socket at the start of the Environment. - * This structure does not need to be modified. - */ -public struct AcademyParameters +namespace MLAgents { - /**< \brief The name of the Academy. If the communicator is External, - * it will be the name of the Academy GameObject */ - public string AcademyName; + public struct CommunicatorParameters + { + public int port; + } - /**< \brief The API number for the communicator. */ - public string apiNumber; + /** + This is the interface of the Communicators. + This does not need to be modified nor implemented to create a Unity environment. - /**< \brief The location of the logfile*/ - public string logPath; + When the Unity Communicator is initialized, it will wait for the External Communicator + to be initialized as well. The two communicators will then exchange their first messages + that will usually contain information for initialization (information that does not need + to be resent at each new exchange). - /**< \brief The default reset parameters are sent via socket*/ - public Dictionary resetParameters; + By convention a Unity input is from External to Unity and a Unity output is from Unity to + External. Inputs and outputs are relative to Unity. - /**< \brief A list of the all the brains names sent via socket*/ - public List brainNames; + By convention, when the Unity Communicator and External Communicator call exchange, the + exchange is NOT simultaneous but sequential. This means that when a side of the + communication calls exchange, the other will receive the result of its previous + exchange call. + This is what happens when A calls exchange a single time: + A sends data_1 to B -> B receives data_1 -> B generates and sends data_2 -> A receives data_2 + When A calls exchange, it sends data_1 and receives data_2 - /**< \brief A list of the External brains parameters sent via socket*/ - public List brainParameters; - - /**< \brief A list of the External brains names sent via socket*/ - public List externalBrainNames; -} - -public enum ExternalCommand -{ - STEP, - RESET, - QUIT -} - -/** - * This is the interface used to generate coordinators. - * This does not need to be modified nor implemented to create a - * Unity environment. - */ -public interface Communicator -{ - - /// Implement this method to allow brains to subscribe to the - /// decisions made outside of Unity - void SubscribeBrain(Brain brain); - - /// First contact between Communicator and external process - bool CommunicatorHandShake(); - - /// Implement this method to initialize the communicator - void InitializeCommunicator(); - - /// Implement this method to receive actions from outside of Unity and - /// update the actions of the brains that subscribe - void UpdateActions(); - - /// Implement this method to return the ExternalCommand that - /// was given outside of Unity - ExternalCommand GetCommand(); - - void UpdateCommand(); - void SetCommand(ExternalCommand c); - - /// Implement this method to return the new dictionary of resetParameters - /// that was given outside of Unity - Dictionary GetResetParameters(); - - - - Dictionary GetHasTried(); - Dictionary GetSent(); + Since the messages are sent back and forth with exchange and simultaneously when calling + initialize, External sends two messages at initialization. + The structure of the messages is as follows: + UnityMessage + ...Header + ...UnityOutput + ......UnityRLOutput + ......UnityRLInitializationOutput + ...UnityInput + ......UnityRLIntput + ......UnityRLInitializationIntput + + UnityOutput and UnityInput can be extended to provide functionalities beyond RL + UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities + */ + public interface Communicator + { + /// + /// Initialize the communicator by sending the first UnityOutput and receiving the + /// first UnityInput. The second UnityInput is stored in the unityInput argument. + /// + /// The first Unity Input. + /// The first Unity Output. + /// The second Unity input. + UnityInput Initialize(UnityOutput unityOutput, + out UnityInput unityInput); + + /// + /// Send a UnityOutput and receives a UnityInput. + /// + /// The next UnityInput. + /// The UnityOutput to be sent. + UnityInput Exchange(UnityOutput unityOutput); + + /// + /// Close the communicator gracefully on both sides of the communication. + /// + void Close(); + + } } diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects.meta new file mode 100644 index 000000000..cef92044c --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 7ebeef5df83b74a048b7f99681672f3b +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs new file mode 100644 index 000000000..3096b50fd --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs @@ -0,0 +1,217 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/agent_action_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/agent_action_proto.proto + public static partial class AgentActionProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/agent_action_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AgentActionProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci1jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb25fcHJvdG8ucHJv", + "dG8SFGNvbW11bmljYXRvcl9vYmplY3RzIlIKEEFnZW50QWN0aW9uUHJvdG8S", + "FgoOdmVjdG9yX2FjdGlvbnMYASADKAISFAoMdGV4dF9hY3Rpb25zGAIgASgJ", + "EhAKCG1lbW9yaWVzGAMgAygCQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JP", + "YmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class AgentActionProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentActionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.AgentActionProtoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentActionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentActionProto(AgentActionProto other) : this() { + vectorActions_ = other.vectorActions_.Clone(); + textActions_ = other.textActions_; + memories_ = other.memories_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentActionProto Clone() { + return new AgentActionProto(this); + } + + /// Field number for the "vector_actions" field. + public const int VectorActionsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_vectorActions_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField vectorActions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VectorActions { + get { return vectorActions_; } + } + + /// Field number for the "text_actions" field. + public const int TextActionsFieldNumber = 2; + private string textActions_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TextActions { + get { return textActions_; } + set { + textActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "memories" field. + public const int MemoriesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_memories_codec + = pb::FieldCodec.ForFloat(26); + private readonly pbc::RepeatedField memories_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Memories { + get { return memories_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AgentActionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AgentActionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!vectorActions_.Equals(other.vectorActions_)) return false; + if (TextActions != other.TextActions) return false; + if(!memories_.Equals(other.memories_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= vectorActions_.GetHashCode(); + if (TextActions.Length != 0) hash ^= TextActions.GetHashCode(); + hash ^= memories_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + vectorActions_.WriteTo(output, _repeated_vectorActions_codec); + if (TextActions.Length != 0) { + output.WriteRawTag(18); + output.WriteString(TextActions); + } + memories_.WriteTo(output, _repeated_memories_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += vectorActions_.CalculateSize(_repeated_vectorActions_codec); + if (TextActions.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions); + } + size += memories_.CalculateSize(_repeated_memories_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AgentActionProto other) { + if (other == null) { + return; + } + vectorActions_.Add(other.vectorActions_); + if (other.TextActions.Length != 0) { + TextActions = other.TextActions; + } + memories_.Add(other.memories_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + vectorActions_.AddEntriesFrom(input, _repeated_vectorActions_codec); + break; + } + case 18: { + TextActions = input.ReadString(); + break; + } + case 26: + case 29: { + memories_.AddEntriesFrom(input, _repeated_memories_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs.meta new file mode 100644 index 000000000..1a1eda655 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 93eec67e32dc3484ca9b8e3ea98909c7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs new file mode 100644 index 000000000..d85a8a7a0 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs @@ -0,0 +1,402 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/agent_info_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/agent_info_proto.proto + public static partial class AgentInfoProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/agent_info_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AgentInfoProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Citjb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9pbmZvX3Byb3RvLnByb3Rv", + "EhRjb21tdW5pY2F0b3Jfb2JqZWN0cyL9AQoOQWdlbnRJbmZvUHJvdG8SIgoa", + "c3RhY2tlZF92ZWN0b3Jfb2JzZXJ2YXRpb24YASADKAISGwoTdmlzdWFsX29i", + "c2VydmF0aW9ucxgCIAMoDBIYChB0ZXh0X29ic2VydmF0aW9uGAMgASgJEh0K", + "FXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMoAhIbChNzdG9yZWRfdGV4dF9h", + "Y3Rpb25zGAUgASgJEhAKCG1lbW9yaWVzGAYgAygCEg4KBnJld2FyZBgHIAEo", + "AhIMCgRkb25lGAggASgIEhgKEG1heF9zdGVwX3JlYWNoZWQYCSABKAgSCgoC", + "aWQYCiABKAVCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy", + "b3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "VisualObservations", "TextObservation", "StoredVectorActions", "StoredTextActions", "Memories", "Reward", "Done", "MaxStepReached", "Id" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class AgentInfoProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentInfoProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.AgentInfoProtoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoProto(AgentInfoProto other) : this() { + stackedVectorObservation_ = other.stackedVectorObservation_.Clone(); + visualObservations_ = other.visualObservations_.Clone(); + textObservation_ = other.textObservation_; + storedVectorActions_ = other.storedVectorActions_.Clone(); + storedTextActions_ = other.storedTextActions_; + memories_ = other.memories_.Clone(); + reward_ = other.reward_; + done_ = other.done_; + maxStepReached_ = other.maxStepReached_; + id_ = other.id_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AgentInfoProto Clone() { + return new AgentInfoProto(this); + } + + /// Field number for the "stacked_vector_observation" field. + public const int StackedVectorObservationFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_stackedVectorObservation_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField stackedVectorObservation_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField StackedVectorObservation { + get { return stackedVectorObservation_; } + } + + /// Field number for the "visual_observations" field. + public const int VisualObservationsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_visualObservations_codec + = pb::FieldCodec.ForBytes(18); + private readonly pbc::RepeatedField visualObservations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VisualObservations { + get { return visualObservations_; } + } + + /// Field number for the "text_observation" field. + public const int TextObservationFieldNumber = 3; + private string textObservation_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TextObservation { + get { return textObservation_; } + set { + textObservation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "stored_vector_actions" field. + public const int StoredVectorActionsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_storedVectorActions_codec + = pb::FieldCodec.ForFloat(34); + private readonly pbc::RepeatedField storedVectorActions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField StoredVectorActions { + get { return storedVectorActions_; } + } + + /// Field number for the "stored_text_actions" field. + public const int StoredTextActionsFieldNumber = 5; + private string storedTextActions_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string StoredTextActions { + get { return storedTextActions_; } + set { + storedTextActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "memories" field. + public const int MemoriesFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_memories_codec + = pb::FieldCodec.ForFloat(50); + private readonly pbc::RepeatedField memories_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Memories { + get { return memories_; } + } + + /// Field number for the "reward" field. + public const int RewardFieldNumber = 7; + private float reward_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Reward { + get { return reward_; } + set { + reward_ = value; + } + } + + /// Field number for the "done" field. + public const int DoneFieldNumber = 8; + private bool done_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Done { + get { return done_; } + set { + done_ = value; + } + } + + /// Field number for the "max_step_reached" field. + public const int MaxStepReachedFieldNumber = 9; + private bool maxStepReached_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MaxStepReached { + get { return maxStepReached_; } + set { + maxStepReached_ = value; + } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 10; + private int id_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Id { + get { return id_; } + set { + id_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AgentInfoProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AgentInfoProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!stackedVectorObservation_.Equals(other.stackedVectorObservation_)) return false; + if(!visualObservations_.Equals(other.visualObservations_)) return false; + if (TextObservation != other.TextObservation) return false; + if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false; + if (StoredTextActions != other.StoredTextActions) return false; + if(!memories_.Equals(other.memories_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false; + if (Done != other.Done) return false; + if (MaxStepReached != other.MaxStepReached) return false; + if (Id != other.Id) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= stackedVectorObservation_.GetHashCode(); + hash ^= visualObservations_.GetHashCode(); + if (TextObservation.Length != 0) hash ^= TextObservation.GetHashCode(); + hash ^= storedVectorActions_.GetHashCode(); + if (StoredTextActions.Length != 0) hash ^= StoredTextActions.GetHashCode(); + hash ^= memories_.GetHashCode(); + if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward); + if (Done != false) hash ^= Done.GetHashCode(); + if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode(); + if (Id != 0) hash ^= Id.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + stackedVectorObservation_.WriteTo(output, _repeated_stackedVectorObservation_codec); + visualObservations_.WriteTo(output, _repeated_visualObservations_codec); + if (TextObservation.Length != 0) { + output.WriteRawTag(26); + output.WriteString(TextObservation); + } + storedVectorActions_.WriteTo(output, _repeated_storedVectorActions_codec); + if (StoredTextActions.Length != 0) { + output.WriteRawTag(42); + output.WriteString(StoredTextActions); + } + memories_.WriteTo(output, _repeated_memories_codec); + if (Reward != 0F) { + output.WriteRawTag(61); + output.WriteFloat(Reward); + } + if (Done != false) { + output.WriteRawTag(64); + output.WriteBool(Done); + } + if (MaxStepReached != false) { + output.WriteRawTag(72); + output.WriteBool(MaxStepReached); + } + if (Id != 0) { + output.WriteRawTag(80); + output.WriteInt32(Id); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += stackedVectorObservation_.CalculateSize(_repeated_stackedVectorObservation_codec); + size += visualObservations_.CalculateSize(_repeated_visualObservations_codec); + if (TextObservation.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TextObservation); + } + size += storedVectorActions_.CalculateSize(_repeated_storedVectorActions_codec); + if (StoredTextActions.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(StoredTextActions); + } + size += memories_.CalculateSize(_repeated_memories_codec); + if (Reward != 0F) { + size += 1 + 4; + } + if (Done != false) { + size += 1 + 1; + } + if (MaxStepReached != false) { + size += 1 + 1; + } + if (Id != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AgentInfoProto other) { + if (other == null) { + return; + } + stackedVectorObservation_.Add(other.stackedVectorObservation_); + visualObservations_.Add(other.visualObservations_); + if (other.TextObservation.Length != 0) { + TextObservation = other.TextObservation; + } + storedVectorActions_.Add(other.storedVectorActions_); + if (other.StoredTextActions.Length != 0) { + StoredTextActions = other.StoredTextActions; + } + memories_.Add(other.memories_); + if (other.Reward != 0F) { + Reward = other.Reward; + } + if (other.Done != false) { + Done = other.Done; + } + if (other.MaxStepReached != false) { + MaxStepReached = other.MaxStepReached; + } + if (other.Id != 0) { + Id = other.Id; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + stackedVectorObservation_.AddEntriesFrom(input, _repeated_stackedVectorObservation_codec); + break; + } + case 18: { + visualObservations_.AddEntriesFrom(input, _repeated_visualObservations_codec); + break; + } + case 26: { + TextObservation = input.ReadString(); + break; + } + case 34: + case 37: { + storedVectorActions_.AddEntriesFrom(input, _repeated_storedVectorActions_codec); + break; + } + case 42: { + StoredTextActions = input.ReadString(); + break; + } + case 50: + case 53: { + memories_.AddEntriesFrom(input, _repeated_memories_codec); + break; + } + case 61: { + Reward = input.ReadFloat(); + break; + } + case 64: { + Done = input.ReadBool(); + break; + } + case 72: { + MaxStepReached = input.ReadBool(); + break; + } + case 80: { + Id = input.ReadInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta new file mode 100644 index 000000000..2888079b0 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 9a2cd47d5b7a84d45b66748c405edf5a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs new file mode 100644 index 000000000..31755e34f --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs @@ -0,0 +1,394 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/brain_parameters_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/brain_parameters_proto.proto + public static partial class BrainParametersProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/brain_parameters_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static BrainParametersProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjFjb21tdW5pY2F0b3Jfb2JqZWN0cy9icmFpbl9wYXJhbWV0ZXJzX3Byb3Rv", + "LnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxorY29tbXVuaWNhdG9yX29i", + "amVjdHMvcmVzb2x1dGlvbl9wcm90by5wcm90bxorY29tbXVuaWNhdG9yX29i", + "amVjdHMvYnJhaW5fdHlwZV9wcm90by5wcm90bxorY29tbXVuaWNhdG9yX29i", + "amVjdHMvc3BhY2VfdHlwZV9wcm90by5wcm90byLGAwoUQnJhaW5QYXJhbWV0", + "ZXJzUHJvdG8SHwoXdmVjdG9yX29ic2VydmF0aW9uX3NpemUYASABKAUSJwof", + "bnVtX3N0YWNrZWRfdmVjdG9yX29ic2VydmF0aW9ucxgCIAEoBRIaChJ2ZWN0", + "b3JfYWN0aW9uX3NpemUYAyABKAUSQQoSY2FtZXJhX3Jlc29sdXRpb25zGAQg", + "AygLMiUuY29tbXVuaWNhdG9yX29iamVjdHMuUmVzb2x1dGlvblByb3RvEiIK", + "GnZlY3Rvcl9hY3Rpb25fZGVzY3JpcHRpb25zGAUgAygJEkYKGHZlY3Rvcl9h", + "Y3Rpb25fc3BhY2VfdHlwZRgGIAEoDjIkLmNvbW11bmljYXRvcl9vYmplY3Rz", + "LlNwYWNlVHlwZVByb3RvEksKHXZlY3Rvcl9vYnNlcnZhdGlvbl9zcGFjZV90", + "eXBlGAcgASgOMiQuY29tbXVuaWNhdG9yX29iamVjdHMuU3BhY2VUeXBlUHJv", + "dG8SEgoKYnJhaW5fbmFtZRgIIAEoCRI4CgpicmFpbl90eXBlGAkgASgOMiQu", + "Y29tbXVuaWNhdG9yX29iamVjdHMuQnJhaW5UeXBlUHJvdG9CH6oCHE1MQWdl", + "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.BrainTypeProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.SpaceTypeProtoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.BrainParametersProto), global::MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorObservationSize", "NumStackedVectorObservations", "VectorActionSize", "CameraResolutions", "VectorActionDescriptions", "VectorActionSpaceType", "VectorObservationSpaceType", "BrainName", "BrainType" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class BrainParametersProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BrainParametersProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.BrainParametersProtoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BrainParametersProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BrainParametersProto(BrainParametersProto other) : this() { + vectorObservationSize_ = other.vectorObservationSize_; + numStackedVectorObservations_ = other.numStackedVectorObservations_; + vectorActionSize_ = other.vectorActionSize_; + cameraResolutions_ = other.cameraResolutions_.Clone(); + vectorActionDescriptions_ = other.vectorActionDescriptions_.Clone(); + vectorActionSpaceType_ = other.vectorActionSpaceType_; + vectorObservationSpaceType_ = other.vectorObservationSpaceType_; + brainName_ = other.brainName_; + brainType_ = other.brainType_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BrainParametersProto Clone() { + return new BrainParametersProto(this); + } + + /// Field number for the "vector_observation_size" field. + public const int VectorObservationSizeFieldNumber = 1; + private int vectorObservationSize_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int VectorObservationSize { + get { return vectorObservationSize_; } + set { + vectorObservationSize_ = value; + } + } + + /// Field number for the "num_stacked_vector_observations" field. + public const int NumStackedVectorObservationsFieldNumber = 2; + private int numStackedVectorObservations_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumStackedVectorObservations { + get { return numStackedVectorObservations_; } + set { + numStackedVectorObservations_ = value; + } + } + + /// Field number for the "vector_action_size" field. + public const int VectorActionSizeFieldNumber = 3; + private int vectorActionSize_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int VectorActionSize { + get { return vectorActionSize_; } + set { + vectorActionSize_ = value; + } + } + + /// Field number for the "camera_resolutions" field. + public const int CameraResolutionsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_cameraResolutions_codec + = pb::FieldCodec.ForMessage(34, global::MLAgents.CommunicatorObjects.ResolutionProto.Parser); + private readonly pbc::RepeatedField cameraResolutions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField CameraResolutions { + get { return cameraResolutions_; } + } + + /// Field number for the "vector_action_descriptions" field. + public const int VectorActionDescriptionsFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_vectorActionDescriptions_codec + = pb::FieldCodec.ForString(42); + private readonly pbc::RepeatedField vectorActionDescriptions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VectorActionDescriptions { + get { return vectorActionDescriptions_; } + } + + /// Field number for the "vector_action_space_type" field. + public const int VectorActionSpaceTypeFieldNumber = 6; + private global::MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceType_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceType { + get { return vectorActionSpaceType_; } + set { + vectorActionSpaceType_ = value; + } + } + + /// Field number for the "vector_observation_space_type" field. + public const int VectorObservationSpaceTypeFieldNumber = 7; + private global::MLAgents.CommunicatorObjects.SpaceTypeProto vectorObservationSpaceType_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.SpaceTypeProto VectorObservationSpaceType { + get { return vectorObservationSpaceType_; } + set { + vectorObservationSpaceType_ = value; + } + } + + /// Field number for the "brain_name" field. + public const int BrainNameFieldNumber = 8; + private string brainName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string BrainName { + get { return brainName_; } + set { + brainName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "brain_type" field. + public const int BrainTypeFieldNumber = 9; + private global::MLAgents.CommunicatorObjects.BrainTypeProto brainType_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.BrainTypeProto BrainType { + get { return brainType_; } + set { + brainType_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BrainParametersProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BrainParametersProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (VectorObservationSize != other.VectorObservationSize) return false; + if (NumStackedVectorObservations != other.NumStackedVectorObservations) return false; + if (VectorActionSize != other.VectorActionSize) return false; + if(!cameraResolutions_.Equals(other.cameraResolutions_)) return false; + if(!vectorActionDescriptions_.Equals(other.vectorActionDescriptions_)) return false; + if (VectorActionSpaceType != other.VectorActionSpaceType) return false; + if (VectorObservationSpaceType != other.VectorObservationSpaceType) return false; + if (BrainName != other.BrainName) return false; + if (BrainType != other.BrainType) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (VectorObservationSize != 0) hash ^= VectorObservationSize.GetHashCode(); + if (NumStackedVectorObservations != 0) hash ^= NumStackedVectorObservations.GetHashCode(); + if (VectorActionSize != 0) hash ^= VectorActionSize.GetHashCode(); + hash ^= cameraResolutions_.GetHashCode(); + hash ^= vectorActionDescriptions_.GetHashCode(); + if (VectorActionSpaceType != 0) hash ^= VectorActionSpaceType.GetHashCode(); + if (VectorObservationSpaceType != 0) hash ^= VectorObservationSpaceType.GetHashCode(); + if (BrainName.Length != 0) hash ^= BrainName.GetHashCode(); + if (BrainType != 0) hash ^= BrainType.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (VectorObservationSize != 0) { + output.WriteRawTag(8); + output.WriteInt32(VectorObservationSize); + } + if (NumStackedVectorObservations != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumStackedVectorObservations); + } + if (VectorActionSize != 0) { + output.WriteRawTag(24); + output.WriteInt32(VectorActionSize); + } + cameraResolutions_.WriteTo(output, _repeated_cameraResolutions_codec); + vectorActionDescriptions_.WriteTo(output, _repeated_vectorActionDescriptions_codec); + if (VectorActionSpaceType != 0) { + output.WriteRawTag(48); + output.WriteEnum((int) VectorActionSpaceType); + } + if (VectorObservationSpaceType != 0) { + output.WriteRawTag(56); + output.WriteEnum((int) VectorObservationSpaceType); + } + if (BrainName.Length != 0) { + output.WriteRawTag(66); + output.WriteString(BrainName); + } + if (BrainType != 0) { + output.WriteRawTag(72); + output.WriteEnum((int) BrainType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (VectorObservationSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(VectorObservationSize); + } + if (NumStackedVectorObservations != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumStackedVectorObservations); + } + if (VectorActionSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(VectorActionSize); + } + size += cameraResolutions_.CalculateSize(_repeated_cameraResolutions_codec); + size += vectorActionDescriptions_.CalculateSize(_repeated_vectorActionDescriptions_codec); + if (VectorActionSpaceType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceType); + } + if (VectorObservationSpaceType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorObservationSpaceType); + } + if (BrainName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BrainName); + } + if (BrainType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) BrainType); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BrainParametersProto other) { + if (other == null) { + return; + } + if (other.VectorObservationSize != 0) { + VectorObservationSize = other.VectorObservationSize; + } + if (other.NumStackedVectorObservations != 0) { + NumStackedVectorObservations = other.NumStackedVectorObservations; + } + if (other.VectorActionSize != 0) { + VectorActionSize = other.VectorActionSize; + } + cameraResolutions_.Add(other.cameraResolutions_); + vectorActionDescriptions_.Add(other.vectorActionDescriptions_); + if (other.VectorActionSpaceType != 0) { + VectorActionSpaceType = other.VectorActionSpaceType; + } + if (other.VectorObservationSpaceType != 0) { + VectorObservationSpaceType = other.VectorObservationSpaceType; + } + if (other.BrainName.Length != 0) { + BrainName = other.BrainName; + } + if (other.BrainType != 0) { + BrainType = other.BrainType; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + VectorObservationSize = input.ReadInt32(); + break; + } + case 16: { + NumStackedVectorObservations = input.ReadInt32(); + break; + } + case 24: { + VectorActionSize = input.ReadInt32(); + break; + } + case 34: { + cameraResolutions_.AddEntriesFrom(input, _repeated_cameraResolutions_codec); + break; + } + case 42: { + vectorActionDescriptions_.AddEntriesFrom(input, _repeated_vectorActionDescriptions_codec); + break; + } + case 48: { + vectorActionSpaceType_ = (global::MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum(); + break; + } + case 56: { + vectorObservationSpaceType_ = (global::MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum(); + break; + } + case 66: { + BrainName = input.ReadString(); + break; + } + case 72: { + brainType_ = (global::MLAgents.CommunicatorObjects.BrainTypeProto) input.ReadEnum(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta new file mode 100644 index 000000000..3a620addc --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 91e3353985a4c4c08a8004648a81de4f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs new file mode 100644 index 000000000..4b36687c3 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs @@ -0,0 +1,52 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/brain_type_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/brain_type_proto.proto + public static partial class BrainTypeProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/brain_type_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static BrainTypeProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Citjb21tdW5pY2F0b3Jfb2JqZWN0cy9icmFpbl90eXBlX3Byb3RvLnByb3Rv", + "EhRjb21tdW5pY2F0b3Jfb2JqZWN0cxorY29tbXVuaWNhdG9yX29iamVjdHMv", + "cmVzb2x1dGlvbl9wcm90by5wcm90bypHCg5CcmFpblR5cGVQcm90bxIKCgZQ", + "bGF5ZXIQABINCglIZXVyaXN0aWMQARIMCghFeHRlcm5hbBACEgwKCEludGVy", + "bmFsEANCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3Rv", + "Mw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.BrainTypeProto), }, null)); + } + #endregion + + } + #region Enums + public enum BrainTypeProto { + [pbr::OriginalName("Player")] Player = 0, + [pbr::OriginalName("Heuristic")] Heuristic = 1, + [pbr::OriginalName("External")] External = 2, + [pbr::OriginalName("Internal")] Internal = 3, + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs.meta new file mode 100644 index 000000000..5a1ee111f --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d2e4f3cea300049b7a4cd65fbee2ee95 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs new file mode 100644 index 000000000..3f9a5c4ee --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs @@ -0,0 +1,49 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/command_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/command_proto.proto + public static partial class CommandProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/command_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CommandProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cihjb21tdW5pY2F0b3Jfb2JqZWN0cy9jb21tYW5kX3Byb3RvLnByb3RvEhRj", + "b21tdW5pY2F0b3Jfb2JqZWN0cyotCgxDb21tYW5kUHJvdG8SCAoEU1RFUBAA", + "EgkKBVJFU0VUEAESCAoEUVVJVBACQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0", + "b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.CommandProto), }, null)); + } + #endregion + + } + #region Enums + public enum CommandProto { + [pbr::OriginalName("STEP")] Step = 0, + [pbr::OriginalName("RESET")] Reset = 1, + [pbr::OriginalName("QUIT")] Quit = 2, + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs.meta new file mode 100644 index 000000000..2098f9b95 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 19e8be280f78249c188fde36f0855094 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs new file mode 100644 index 000000000..489064a24 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs @@ -0,0 +1,316 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/engine_configuration_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/engine_configuration_proto.proto + public static partial class EngineConfigurationProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/engine_configuration_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static EngineConfigurationProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjVjb21tdW5pY2F0b3Jfb2JqZWN0cy9lbmdpbmVfY29uZmlndXJhdGlvbl9w", + "cm90by5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilQEKGEVuZ2luZUNv", + "bmZpZ3VyYXRpb25Qcm90bxINCgV3aWR0aBgBIAEoBRIOCgZoZWlnaHQYAiAB", + "KAUSFQoNcXVhbGl0eV9sZXZlbBgDIAEoBRISCgp0aW1lX3NjYWxlGAQgASgC", + "EhkKEXRhcmdldF9mcmFtZV9yYXRlGAUgASgFEhQKDHNob3dfbW9uaXRvchgG", + "IAEoCEIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.EngineConfigurationProto), global::MLAgents.CommunicatorObjects.EngineConfigurationProto.Parser, new[]{ "Width", "Height", "QualityLevel", "TimeScale", "TargetFrameRate", "ShowMonitor" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class EngineConfigurationProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EngineConfigurationProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.EngineConfigurationProtoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EngineConfigurationProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EngineConfigurationProto(EngineConfigurationProto other) : this() { + width_ = other.width_; + height_ = other.height_; + qualityLevel_ = other.qualityLevel_; + timeScale_ = other.timeScale_; + targetFrameRate_ = other.targetFrameRate_; + showMonitor_ = other.showMonitor_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EngineConfigurationProto Clone() { + return new EngineConfigurationProto(this); + } + + /// Field number for the "width" field. + public const int WidthFieldNumber = 1; + private int width_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Width { + get { return width_; } + set { + width_ = value; + } + } + + /// Field number for the "height" field. + public const int HeightFieldNumber = 2; + private int height_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Height { + get { return height_; } + set { + height_ = value; + } + } + + /// Field number for the "quality_level" field. + public const int QualityLevelFieldNumber = 3; + private int qualityLevel_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int QualityLevel { + get { return qualityLevel_; } + set { + qualityLevel_ = value; + } + } + + /// Field number for the "time_scale" field. + public const int TimeScaleFieldNumber = 4; + private float timeScale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float TimeScale { + get { return timeScale_; } + set { + timeScale_ = value; + } + } + + /// Field number for the "target_frame_rate" field. + public const int TargetFrameRateFieldNumber = 5; + private int targetFrameRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int TargetFrameRate { + get { return targetFrameRate_; } + set { + targetFrameRate_ = value; + } + } + + /// Field number for the "show_monitor" field. + public const int ShowMonitorFieldNumber = 6; + private bool showMonitor_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ShowMonitor { + get { return showMonitor_; } + set { + showMonitor_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as EngineConfigurationProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(EngineConfigurationProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Width != other.Width) return false; + if (Height != other.Height) return false; + if (QualityLevel != other.QualityLevel) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TimeScale, other.TimeScale)) return false; + if (TargetFrameRate != other.TargetFrameRate) return false; + if (ShowMonitor != other.ShowMonitor) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Width != 0) hash ^= Width.GetHashCode(); + if (Height != 0) hash ^= Height.GetHashCode(); + if (QualityLevel != 0) hash ^= QualityLevel.GetHashCode(); + if (TimeScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TimeScale); + if (TargetFrameRate != 0) hash ^= TargetFrameRate.GetHashCode(); + if (ShowMonitor != false) hash ^= ShowMonitor.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Width != 0) { + output.WriteRawTag(8); + output.WriteInt32(Width); + } + if (Height != 0) { + output.WriteRawTag(16); + output.WriteInt32(Height); + } + if (QualityLevel != 0) { + output.WriteRawTag(24); + output.WriteInt32(QualityLevel); + } + if (TimeScale != 0F) { + output.WriteRawTag(37); + output.WriteFloat(TimeScale); + } + if (TargetFrameRate != 0) { + output.WriteRawTag(40); + output.WriteInt32(TargetFrameRate); + } + if (ShowMonitor != false) { + output.WriteRawTag(48); + output.WriteBool(ShowMonitor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Width != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width); + } + if (Height != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height); + } + if (QualityLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(QualityLevel); + } + if (TimeScale != 0F) { + size += 1 + 4; + } + if (TargetFrameRate != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TargetFrameRate); + } + if (ShowMonitor != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(EngineConfigurationProto other) { + if (other == null) { + return; + } + if (other.Width != 0) { + Width = other.Width; + } + if (other.Height != 0) { + Height = other.Height; + } + if (other.QualityLevel != 0) { + QualityLevel = other.QualityLevel; + } + if (other.TimeScale != 0F) { + TimeScale = other.TimeScale; + } + if (other.TargetFrameRate != 0) { + TargetFrameRate = other.TargetFrameRate; + } + if (other.ShowMonitor != false) { + ShowMonitor = other.ShowMonitor; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Width = input.ReadInt32(); + break; + } + case 16: { + Height = input.ReadInt32(); + break; + } + case 24: { + QualityLevel = input.ReadInt32(); + break; + } + case 37: { + TimeScale = input.ReadFloat(); + break; + } + case 40: { + TargetFrameRate = input.ReadInt32(); + break; + } + case 48: { + ShowMonitor = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta new file mode 100644 index 000000000..fcd861ce9 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: fac934345fc664df8823b494ea9b1ca8 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs new file mode 100644 index 000000000..b340a79b9 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs @@ -0,0 +1,169 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/environment_parameters_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/environment_parameters_proto.proto + public static partial class EnvironmentParametersProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/environment_parameters_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static EnvironmentParametersProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cjdjb21tdW5pY2F0b3Jfb2JqZWN0cy9lbnZpcm9ubWVudF9wYXJhbWV0ZXJz", + "X3Byb3RvLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyK1AQoaRW52aXJv", + "bm1lbnRQYXJhbWV0ZXJzUHJvdG8SXwoQZmxvYXRfcGFyYW1ldGVycxgBIAMo", + "CzJFLmNvbW11bmljYXRvcl9vYmplY3RzLkVudmlyb25tZW50UGFyYW1ldGVy", + "c1Byb3RvLkZsb2F0UGFyYW1ldGVyc0VudHJ5GjYKFEZsb2F0UGFyYW1ldGVy", + "c0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoAjoCOAFCH6oCHE1M", + "QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.EnvironmentParametersProto), global::MLAgents.CommunicatorObjects.EnvironmentParametersProto.Parser, new[]{ "FloatParameters" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }) + })); + } + #endregion + + } + #region Messages + public sealed partial class EnvironmentParametersProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EnvironmentParametersProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.EnvironmentParametersProtoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EnvironmentParametersProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EnvironmentParametersProto(EnvironmentParametersProto other) : this() { + floatParameters_ = other.floatParameters_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EnvironmentParametersProto Clone() { + return new EnvironmentParametersProto(this); + } + + /// Field number for the "float_parameters" field. + public const int FloatParametersFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_floatParameters_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForFloat(21), 10); + private readonly pbc::MapField floatParameters_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField FloatParameters { + get { return floatParameters_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as EnvironmentParametersProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(EnvironmentParametersProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!FloatParameters.Equals(other.FloatParameters)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= FloatParameters.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + floatParameters_.WriteTo(output, _map_floatParameters_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += floatParameters_.CalculateSize(_map_floatParameters_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(EnvironmentParametersProto other) { + if (other == null) { + return; + } + floatParameters_.Add(other.floatParameters_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + floatParameters_.AddEntriesFrom(input, _map_floatParameters_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs.meta new file mode 100644 index 000000000..a1f4ceb10 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 312dc062dfab44416a31b8b273cda29a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs new file mode 100644 index 000000000..c932d5e97 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs @@ -0,0 +1,202 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/header.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/header.proto + public static partial class HeaderReflection { + + #region Descriptor + /// File descriptor for communicator_objects/header.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static HeaderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiFjb21tdW5pY2F0b3Jfb2JqZWN0cy9oZWFkZXIucHJvdG8SFGNvbW11bmlj", + "YXRvcl9vYmplY3RzIikKBkhlYWRlchIOCgZzdGF0dXMYASABKAUSDwoHbWVz", + "c2FnZRgCIAEoCUIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IG", + "cHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.Header), global::MLAgents.CommunicatorObjects.Header.Parser, new[]{ "Status", "Message" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class Header : pb::IMessage
{ + private static readonly pb::MessageParser
_parser = new pb::MessageParser
(() => new Header()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser
Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.HeaderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Header() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Header(Header other) : this() { + status_ = other.status_; + message_ = other.message_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Header Clone() { + return new Header(this); + } + + /// Field number for the "status" field. + public const int StatusFieldNumber = 1; + private int status_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Status { + get { return status_; } + set { + status_ = value; + } + } + + /// Field number for the "message" field. + public const int MessageFieldNumber = 2; + private string message_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Message { + get { return message_; } + set { + message_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Header); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Header other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Status != other.Status) return false; + if (Message != other.Message) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Status != 0) hash ^= Status.GetHashCode(); + if (Message.Length != 0) hash ^= Message.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Status != 0) { + output.WriteRawTag(8); + output.WriteInt32(Status); + } + if (Message.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Message); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Status != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Status); + } + if (Message.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Message); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Header other) { + if (other == null) { + return; + } + if (other.Status != 0) { + Status = other.Status; + } + if (other.Message.Length != 0) { + Message = other.Message; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Status = input.ReadInt32(); + break; + } + case 18: { + Message = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs.meta new file mode 100644 index 000000000..956bcafad --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e582b089dfedc438d9cbce9d4017b807 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs new file mode 100644 index 000000000..1de1ecdcf --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs @@ -0,0 +1,230 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/resolution_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/resolution_proto.proto + public static partial class ResolutionProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/resolution_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ResolutionProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Citjb21tdW5pY2F0b3Jfb2JqZWN0cy9yZXNvbHV0aW9uX3Byb3RvLnByb3Rv", + "EhRjb21tdW5pY2F0b3Jfb2JqZWN0cyJECg9SZXNvbHV0aW9uUHJvdG8SDQoF", + "d2lkdGgYASABKAUSDgoGaGVpZ2h0GAIgASgFEhIKCmdyYXlfc2NhbGUYAyAB", + "KAhCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ResolutionProto), global::MLAgents.CommunicatorObjects.ResolutionProto.Parser, new[]{ "Width", "Height", "GrayScale" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class ResolutionProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResolutionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResolutionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResolutionProto(ResolutionProto other) : this() { + width_ = other.width_; + height_ = other.height_; + grayScale_ = other.grayScale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResolutionProto Clone() { + return new ResolutionProto(this); + } + + /// Field number for the "width" field. + public const int WidthFieldNumber = 1; + private int width_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Width { + get { return width_; } + set { + width_ = value; + } + } + + /// Field number for the "height" field. + public const int HeightFieldNumber = 2; + private int height_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Height { + get { return height_; } + set { + height_ = value; + } + } + + /// Field number for the "gray_scale" field. + public const int GrayScaleFieldNumber = 3; + private bool grayScale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool GrayScale { + get { return grayScale_; } + set { + grayScale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ResolutionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ResolutionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Width != other.Width) return false; + if (Height != other.Height) return false; + if (GrayScale != other.GrayScale) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Width != 0) hash ^= Width.GetHashCode(); + if (Height != 0) hash ^= Height.GetHashCode(); + if (GrayScale != false) hash ^= GrayScale.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Width != 0) { + output.WriteRawTag(8); + output.WriteInt32(Width); + } + if (Height != 0) { + output.WriteRawTag(16); + output.WriteInt32(Height); + } + if (GrayScale != false) { + output.WriteRawTag(24); + output.WriteBool(GrayScale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Width != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width); + } + if (Height != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height); + } + if (GrayScale != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ResolutionProto other) { + if (other == null) { + return; + } + if (other.Width != 0) { + Width = other.Width; + } + if (other.Height != 0) { + Height = other.Height; + } + if (other.GrayScale != false) { + GrayScale = other.GrayScale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Width = input.ReadInt32(); + break; + } + case 16: { + Height = input.ReadInt32(); + break; + } + case 24: { + GrayScale = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs.meta new file mode 100644 index 000000000..2e06dd6f2 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ca2454611610e4136a412b5cd6afee4d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs new file mode 100644 index 000000000..c46430f8c --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs @@ -0,0 +1,49 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/space_type_proto.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/space_type_proto.proto + public static partial class SpaceTypeProtoReflection { + + #region Descriptor + /// File descriptor for communicator_objects/space_type_proto.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SpaceTypeProtoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Citjb21tdW5pY2F0b3Jfb2JqZWN0cy9zcGFjZV90eXBlX3Byb3RvLnByb3Rv", + "EhRjb21tdW5pY2F0b3Jfb2JqZWN0cxorY29tbXVuaWNhdG9yX29iamVjdHMv", + "cmVzb2x1dGlvbl9wcm90by5wcm90byouCg5TcGFjZVR5cGVQcm90bxIMCghk", + "aXNjcmV0ZRAAEg4KCmNvbnRpbnVvdXMQAUIfqgIcTUxBZ2VudHMuQ29tbXVu", + "aWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.SpaceTypeProto), }, null)); + } + #endregion + + } + #region Enums + public enum SpaceTypeProto { + [pbr::OriginalName("discrete")] Discrete = 0, + [pbr::OriginalName("continuous")] Continuous = 1, + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta new file mode 100644 index 000000000..2eefdf17e --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: bf7e44e20999448ef846526541819077 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs new file mode 100644 index 000000000..fdee29b8a --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs @@ -0,0 +1,218 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_input.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_input.proto + public static partial class UnityInputReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_input.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityInputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiZjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9pbnB1dC5wcm90bxIUY29t", + "bXVuaWNhdG9yX29iamVjdHMaKWNvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5", + "X3JsX2lucHV0LnByb3RvGjhjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9y", + "bF9pbml0aWFsaXphdGlvbl9pbnB1dC5wcm90byKVAQoKVW5pdHlJbnB1dBI0", + "CghybF9pbnB1dBgBIAEoCzIiLmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5", + "UkxJbnB1dBJRChdybF9pbml0aWFsaXphdGlvbl9pbnB1dBgCIAEoCzIwLmNv", + "bW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxJbml0aWFsaXphdGlvbklucHV0", + "Qh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityRlInputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityRlInitializationInputReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityInput), global::MLAgents.CommunicatorObjects.UnityInput.Parser, new[]{ "RlInput", "RlInitializationInput" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class UnityInput : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityInput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityInputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityInput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityInput(UnityInput other) : this() { + RlInput = other.rlInput_ != null ? other.RlInput.Clone() : null; + RlInitializationInput = other.rlInitializationInput_ != null ? other.RlInitializationInput.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityInput Clone() { + return new UnityInput(this); + } + + /// Field number for the "rl_input" field. + public const int RlInputFieldNumber = 1; + private global::MLAgents.CommunicatorObjects.UnityRLInput rlInput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.UnityRLInput RlInput { + get { return rlInput_; } + set { + rlInput_ = value; + } + } + + /// Field number for the "rl_initialization_input" field. + public const int RlInitializationInputFieldNumber = 2; + private global::MLAgents.CommunicatorObjects.UnityRLInitializationInput rlInitializationInput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.UnityRLInitializationInput RlInitializationInput { + get { return rlInitializationInput_; } + set { + rlInitializationInput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityInput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityInput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(RlInput, other.RlInput)) return false; + if (!object.Equals(RlInitializationInput, other.RlInitializationInput)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (rlInput_ != null) hash ^= RlInput.GetHashCode(); + if (rlInitializationInput_ != null) hash ^= RlInitializationInput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (rlInput_ != null) { + output.WriteRawTag(10); + output.WriteMessage(RlInput); + } + if (rlInitializationInput_ != null) { + output.WriteRawTag(18); + output.WriteMessage(RlInitializationInput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (rlInput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInput); + } + if (rlInitializationInput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInitializationInput); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityInput other) { + if (other == null) { + return; + } + if (other.rlInput_ != null) { + if (rlInput_ == null) { + rlInput_ = new global::MLAgents.CommunicatorObjects.UnityRLInput(); + } + RlInput.MergeFrom(other.RlInput); + } + if (other.rlInitializationInput_ != null) { + if (rlInitializationInput_ == null) { + rlInitializationInput_ = new global::MLAgents.CommunicatorObjects.UnityRLInitializationInput(); + } + RlInitializationInput.MergeFrom(other.RlInitializationInput); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (rlInput_ == null) { + rlInput_ = new global::MLAgents.CommunicatorObjects.UnityRLInput(); + } + input.ReadMessage(rlInput_); + break; + } + case 18: { + if (rlInitializationInput_ == null) { + rlInitializationInput_ = new global::MLAgents.CommunicatorObjects.UnityRLInitializationInput(); + } + input.ReadMessage(rlInitializationInput_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs.meta new file mode 100644 index 000000000..3f4456acb --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: c97e6e2cde58d404cba31008c0489454 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs new file mode 100644 index 000000000..1653ef76b --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs @@ -0,0 +1,253 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_message.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_message.proto + public static partial class UnityMessageReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_message.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityMessageReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cihjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9tZXNzYWdlLnByb3RvEhRj", + "b21tdW5pY2F0b3Jfb2JqZWN0cxonY29tbXVuaWNhdG9yX29iamVjdHMvdW5p", + "dHlfb3V0cHV0LnByb3RvGiZjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9p", + "bnB1dC5wcm90bxohY29tbXVuaWNhdG9yX29iamVjdHMvaGVhZGVyLnByb3Rv", + "IqwBCgxVbml0eU1lc3NhZ2USLAoGaGVhZGVyGAEgASgLMhwuY29tbXVuaWNh", + "dG9yX29iamVjdHMuSGVhZGVyEjcKDHVuaXR5X291dHB1dBgCIAEoCzIhLmNv", + "bW11bmljYXRvcl9vYmplY3RzLlVuaXR5T3V0cHV0EjUKC3VuaXR5X2lucHV0", + "GAMgASgLMiAuY29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlJbnB1dEIfqgIc", + "TUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityOutputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityInputReflection.Descriptor, global::MLAgents.CommunicatorObjects.HeaderReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityMessage), global::MLAgents.CommunicatorObjects.UnityMessage.Parser, new[]{ "Header", "UnityOutput", "UnityInput" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class UnityMessage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityMessage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityMessageReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityMessage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityMessage(UnityMessage other) : this() { + Header = other.header_ != null ? other.Header.Clone() : null; + UnityOutput = other.unityOutput_ != null ? other.UnityOutput.Clone() : null; + UnityInput = other.unityInput_ != null ? other.UnityInput.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityMessage Clone() { + return new UnityMessage(this); + } + + /// Field number for the "header" field. + public const int HeaderFieldNumber = 1; + private global::MLAgents.CommunicatorObjects.Header header_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.Header Header { + get { return header_; } + set { + header_ = value; + } + } + + /// Field number for the "unity_output" field. + public const int UnityOutputFieldNumber = 2; + private global::MLAgents.CommunicatorObjects.UnityOutput unityOutput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.UnityOutput UnityOutput { + get { return unityOutput_; } + set { + unityOutput_ = value; + } + } + + /// Field number for the "unity_input" field. + public const int UnityInputFieldNumber = 3; + private global::MLAgents.CommunicatorObjects.UnityInput unityInput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.UnityInput UnityInput { + get { return unityInput_; } + set { + unityInput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityMessage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityMessage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Header, other.Header)) return false; + if (!object.Equals(UnityOutput, other.UnityOutput)) return false; + if (!object.Equals(UnityInput, other.UnityInput)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (header_ != null) hash ^= Header.GetHashCode(); + if (unityOutput_ != null) hash ^= UnityOutput.GetHashCode(); + if (unityInput_ != null) hash ^= UnityInput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (header_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Header); + } + if (unityOutput_ != null) { + output.WriteRawTag(18); + output.WriteMessage(UnityOutput); + } + if (unityInput_ != null) { + output.WriteRawTag(26); + output.WriteMessage(UnityInput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (header_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Header); + } + if (unityOutput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(UnityOutput); + } + if (unityInput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(UnityInput); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityMessage other) { + if (other == null) { + return; + } + if (other.header_ != null) { + if (header_ == null) { + header_ = new global::MLAgents.CommunicatorObjects.Header(); + } + Header.MergeFrom(other.Header); + } + if (other.unityOutput_ != null) { + if (unityOutput_ == null) { + unityOutput_ = new global::MLAgents.CommunicatorObjects.UnityOutput(); + } + UnityOutput.MergeFrom(other.UnityOutput); + } + if (other.unityInput_ != null) { + if (unityInput_ == null) { + unityInput_ = new global::MLAgents.CommunicatorObjects.UnityInput(); + } + UnityInput.MergeFrom(other.UnityInput); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (header_ == null) { + header_ = new global::MLAgents.CommunicatorObjects.Header(); + } + input.ReadMessage(header_); + break; + } + case 18: { + if (unityOutput_ == null) { + unityOutput_ = new global::MLAgents.CommunicatorObjects.UnityOutput(); + } + input.ReadMessage(unityOutput_); + break; + } + case 26: { + if (unityInput_ == null) { + unityInput_ = new global::MLAgents.CommunicatorObjects.UnityInput(); + } + input.ReadMessage(unityInput_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs.meta new file mode 100644 index 000000000..9fa8d23cf --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 10dca984632854b079476d5fb6df329c +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs new file mode 100644 index 000000000..2fc8a7814 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs @@ -0,0 +1,219 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_output.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_output.proto + public static partial class UnityOutputReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_output.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityOutputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cidjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9vdXRwdXQucHJvdG8SFGNv", + "bW11bmljYXRvcl9vYmplY3RzGipjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0", + "eV9ybF9vdXRwdXQucHJvdG8aOWNvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5", + "X3JsX2luaXRpYWxpemF0aW9uX291dHB1dC5wcm90byKaAQoLVW5pdHlPdXRw", + "dXQSNgoJcmxfb3V0cHV0GAEgASgLMiMuY29tbXVuaWNhdG9yX29iamVjdHMu", + "VW5pdHlSTE91dHB1dBJTChhybF9pbml0aWFsaXphdGlvbl9vdXRwdXQYAiAB", + "KAsyMS5jb21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMSW5pdGlhbGl6YXRp", + "b25PdXRwdXRCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy", + "b3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityRlOutputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityRlInitializationOutputReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityOutput), global::MLAgents.CommunicatorObjects.UnityOutput.Parser, new[]{ "RlOutput", "RlInitializationOutput" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class UnityOutput : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityOutput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityOutputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityOutput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityOutput(UnityOutput other) : this() { + RlOutput = other.rlOutput_ != null ? other.RlOutput.Clone() : null; + RlInitializationOutput = other.rlInitializationOutput_ != null ? other.RlInitializationOutput.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityOutput Clone() { + return new UnityOutput(this); + } + + /// Field number for the "rl_output" field. + public const int RlOutputFieldNumber = 1; + private global::MLAgents.CommunicatorObjects.UnityRLOutput rlOutput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.UnityRLOutput RlOutput { + get { return rlOutput_; } + set { + rlOutput_ = value; + } + } + + /// Field number for the "rl_initialization_output" field. + public const int RlInitializationOutputFieldNumber = 2; + private global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput rlInitializationOutput_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput RlInitializationOutput { + get { return rlInitializationOutput_; } + set { + rlInitializationOutput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityOutput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityOutput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(RlOutput, other.RlOutput)) return false; + if (!object.Equals(RlInitializationOutput, other.RlInitializationOutput)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (rlOutput_ != null) hash ^= RlOutput.GetHashCode(); + if (rlInitializationOutput_ != null) hash ^= RlInitializationOutput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (rlOutput_ != null) { + output.WriteRawTag(10); + output.WriteMessage(RlOutput); + } + if (rlInitializationOutput_ != null) { + output.WriteRawTag(18); + output.WriteMessage(RlInitializationOutput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (rlOutput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlOutput); + } + if (rlInitializationOutput_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInitializationOutput); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityOutput other) { + if (other == null) { + return; + } + if (other.rlOutput_ != null) { + if (rlOutput_ == null) { + rlOutput_ = new global::MLAgents.CommunicatorObjects.UnityRLOutput(); + } + RlOutput.MergeFrom(other.RlOutput); + } + if (other.rlInitializationOutput_ != null) { + if (rlInitializationOutput_ == null) { + rlInitializationOutput_ = new global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput(); + } + RlInitializationOutput.MergeFrom(other.RlInitializationOutput); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (rlOutput_ == null) { + rlOutput_ = new global::MLAgents.CommunicatorObjects.UnityRLOutput(); + } + input.ReadMessage(rlOutput_); + break; + } + case 18: { + if (rlInitializationOutput_ == null) { + rlInitializationOutput_ = new global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput(); + } + input.ReadMessage(rlInitializationOutput_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs.meta new file mode 100644 index 000000000..1516896ad --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 546f38fe479d240eabdf11ac55ecf7d4 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs new file mode 100644 index 000000000..a8fda8ec9 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs @@ -0,0 +1,174 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_rl_initialization_input.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_rl_initialization_input.proto + public static partial class UnityRlInitializationInputReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_rl_initialization_input.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlInitializationInputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cjhjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbml0aWFsaXphdGlv", + "bl9pbnB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiKgoaVW5pdHlS", + "TEluaXRpYWxpemF0aW9uSW5wdXQSDAoEc2VlZBgBIAEoBUIfqgIcTUxBZ2Vu", + "dHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationInput), global::MLAgents.CommunicatorObjects.UnityRLInitializationInput.Parser, new[]{ "Seed" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class UnityRLInitializationInput : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInitializationInput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityRlInitializationInputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationInput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationInput(UnityRLInitializationInput other) : this() { + seed_ = other.seed_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationInput Clone() { + return new UnityRLInitializationInput(this); + } + + /// Field number for the "seed" field. + public const int SeedFieldNumber = 1; + private int seed_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Seed { + get { return seed_; } + set { + seed_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLInitializationInput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLInitializationInput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Seed != other.Seed) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Seed != 0) hash ^= Seed.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Seed != 0) { + output.WriteRawTag(8); + output.WriteInt32(Seed); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Seed != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Seed); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLInitializationInput other) { + if (other == null) { + return; + } + if (other.Seed != 0) { + Seed = other.Seed; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Seed = input.ReadInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta new file mode 100644 index 000000000..4325bd79f --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d9c1712ba119a47458082c7190c838b0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs new file mode 100644 index 000000000..2b1cf26b7 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs @@ -0,0 +1,294 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_rl_initialization_output.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_rl_initialization_output.proto + public static partial class UnityRlInitializationOutputReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_rl_initialization_output.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlInitializationOutputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cjljb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbml0aWFsaXphdGlv", + "bl9vdXRwdXQucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjFjb21tdW5p", + "Y2F0b3Jfb2JqZWN0cy9icmFpbl9wYXJhbWV0ZXJzX3Byb3RvLnByb3RvGjdj", + "b21tdW5pY2F0b3Jfb2JqZWN0cy9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzX3By", + "b3RvLnByb3RvIuYBChtVbml0eVJMSW5pdGlhbGl6YXRpb25PdXRwdXQSDAoE", + "bmFtZRgBIAEoCRIPCgd2ZXJzaW9uGAIgASgJEhAKCGxvZ19wYXRoGAMgASgJ", + "EkQKEGJyYWluX3BhcmFtZXRlcnMYBSADKAsyKi5jb21tdW5pY2F0b3Jfb2Jq", + "ZWN0cy5CcmFpblBhcmFtZXRlcnNQcm90bxJQChZlbnZpcm9ubWVudF9wYXJh", + "bWV0ZXJzGAYgASgLMjAuY29tbXVuaWNhdG9yX29iamVjdHMuRW52aXJvbm1l", + "bnRQYXJhbWV0ZXJzUHJvdG9CH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9i", + "amVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.BrainParametersProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.EnvironmentParametersProtoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput), global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput.Parser, new[]{ "Name", "Version", "LogPath", "BrainParameters", "EnvironmentParameters" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// The request message containing the academy's parameters. + /// + public sealed partial class UnityRLInitializationOutput : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInitializationOutput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityRlInitializationOutputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationOutput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationOutput(UnityRLInitializationOutput other) : this() { + name_ = other.name_; + version_ = other.version_; + logPath_ = other.logPath_; + brainParameters_ = other.brainParameters_.Clone(); + EnvironmentParameters = other.environmentParameters_ != null ? other.EnvironmentParameters.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInitializationOutput Clone() { + return new UnityRLInitializationOutput(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 2; + private string version_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Version { + get { return version_; } + set { + version_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "log_path" field. + public const int LogPathFieldNumber = 3; + private string logPath_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string LogPath { + get { return logPath_; } + set { + logPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "brain_parameters" field. + public const int BrainParametersFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_brainParameters_codec + = pb::FieldCodec.ForMessage(42, global::MLAgents.CommunicatorObjects.BrainParametersProto.Parser); + private readonly pbc::RepeatedField brainParameters_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField BrainParameters { + get { return brainParameters_; } + } + + /// Field number for the "environment_parameters" field. + public const int EnvironmentParametersFieldNumber = 6; + private global::MLAgents.CommunicatorObjects.EnvironmentParametersProto environmentParameters_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.EnvironmentParametersProto EnvironmentParameters { + get { return environmentParameters_; } + set { + environmentParameters_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLInitializationOutput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLInitializationOutput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Version != other.Version) return false; + if (LogPath != other.LogPath) return false; + if(!brainParameters_.Equals(other.brainParameters_)) return false; + if (!object.Equals(EnvironmentParameters, other.EnvironmentParameters)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Version.Length != 0) hash ^= Version.GetHashCode(); + if (LogPath.Length != 0) hash ^= LogPath.GetHashCode(); + hash ^= brainParameters_.GetHashCode(); + if (environmentParameters_ != null) hash ^= EnvironmentParameters.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Version.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Version); + } + if (LogPath.Length != 0) { + output.WriteRawTag(26); + output.WriteString(LogPath); + } + brainParameters_.WriteTo(output, _repeated_brainParameters_codec); + if (environmentParameters_ != null) { + output.WriteRawTag(50); + output.WriteMessage(EnvironmentParameters); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Version.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Version); + } + if (LogPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(LogPath); + } + size += brainParameters_.CalculateSize(_repeated_brainParameters_codec); + if (environmentParameters_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(EnvironmentParameters); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLInitializationOutput other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Version.Length != 0) { + Version = other.Version; + } + if (other.LogPath.Length != 0) { + LogPath = other.LogPath; + } + brainParameters_.Add(other.brainParameters_); + if (other.environmentParameters_ != null) { + if (environmentParameters_ == null) { + environmentParameters_ = new global::MLAgents.CommunicatorObjects.EnvironmentParametersProto(); + } + EnvironmentParameters.MergeFrom(other.EnvironmentParameters); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Version = input.ReadString(); + break; + } + case 26: { + LogPath = input.ReadString(); + break; + } + case 42: { + brainParameters_.AddEntriesFrom(input, _repeated_brainParameters_codec); + break; + } + case 50: { + if (environmentParameters_ == null) { + environmentParameters_ = new global::MLAgents.CommunicatorObjects.EnvironmentParametersProto(); + } + input.ReadMessage(environmentParameters_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta new file mode 100644 index 000000000..dbe55dd87 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: cfac266f05f674dbd8dc50e8e9b29753 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs new file mode 100644 index 000000000..7e8569942 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs @@ -0,0 +1,397 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_rl_input.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_rl_input.proto + public static partial class UnityRlInputReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_rl_input.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlInputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ciljb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbnB1dC5wcm90bxIU", + "Y29tbXVuaWNhdG9yX29iamVjdHMaLWNvbW11bmljYXRvcl9vYmplY3RzL2Fn", + "ZW50X2FjdGlvbl9wcm90by5wcm90bxo3Y29tbXVuaWNhdG9yX29iamVjdHMv", + "ZW52aXJvbm1lbnRfcGFyYW1ldGVyc19wcm90by5wcm90bxooY29tbXVuaWNh", + "dG9yX29iamVjdHMvY29tbWFuZF9wcm90by5wcm90byK0AwoMVW5pdHlSTElu", + "cHV0EksKDWFnZW50X2FjdGlvbnMYASADKAsyNC5jb21tdW5pY2F0b3Jfb2Jq", + "ZWN0cy5Vbml0eVJMSW5wdXQuQWdlbnRBY3Rpb25zRW50cnkSUAoWZW52aXJv", + "bm1lbnRfcGFyYW1ldGVycxgCIAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3Rz", + "LkVudmlyb25tZW50UGFyYW1ldGVyc1Byb3RvEhMKC2lzX3RyYWluaW5nGAMg", + "ASgIEjMKB2NvbW1hbmQYBCABKA4yIi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5D", + "b21tYW5kUHJvdG8aTQoUTGlzdEFnZW50QWN0aW9uUHJvdG8SNQoFdmFsdWUY", + "ASADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5BZ2VudEFjdGlvblByb3Rv", + "GmwKEUFnZW50QWN0aW9uc0VudHJ5EgsKA2tleRgBIAEoCRJGCgV2YWx1ZRgC", + "IAEoCzI3LmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxJbnB1dC5MaXN0", + "QWdlbnRBY3Rpb25Qcm90bzoCOAFCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRv", + "ck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentActionProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.EnvironmentParametersProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.CommandProtoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInput), global::MLAgents.CommunicatorObjects.UnityRLInput.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInput.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInput.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null), + null, }) + })); + } + #endregion + + } + #region Messages + public sealed partial class UnityRLInput : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityRlInputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInput(UnityRLInput other) : this() { + agentActions_ = other.agentActions_.Clone(); + EnvironmentParameters = other.environmentParameters_ != null ? other.EnvironmentParameters.Clone() : null; + isTraining_ = other.isTraining_; + command_ = other.command_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLInput Clone() { + return new UnityRLInput(this); + } + + /// Field number for the "agent_actions" field. + public const int AgentActionsFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_agentActions_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::MLAgents.CommunicatorObjects.UnityRLInput.Types.ListAgentActionProto.Parser), 10); + private readonly pbc::MapField agentActions_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField AgentActions { + get { return agentActions_; } + } + + /// Field number for the "environment_parameters" field. + public const int EnvironmentParametersFieldNumber = 2; + private global::MLAgents.CommunicatorObjects.EnvironmentParametersProto environmentParameters_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.EnvironmentParametersProto EnvironmentParameters { + get { return environmentParameters_; } + set { + environmentParameters_ = value; + } + } + + /// Field number for the "is_training" field. + public const int IsTrainingFieldNumber = 3; + private bool isTraining_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsTraining { + get { return isTraining_; } + set { + isTraining_ = value; + } + } + + /// Field number for the "command" field. + public const int CommandFieldNumber = 4; + private global::MLAgents.CommunicatorObjects.CommandProto command_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::MLAgents.CommunicatorObjects.CommandProto Command { + get { return command_; } + set { + command_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLInput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLInput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!AgentActions.Equals(other.AgentActions)) return false; + if (!object.Equals(EnvironmentParameters, other.EnvironmentParameters)) return false; + if (IsTraining != other.IsTraining) return false; + if (Command != other.Command) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= AgentActions.GetHashCode(); + if (environmentParameters_ != null) hash ^= EnvironmentParameters.GetHashCode(); + if (IsTraining != false) hash ^= IsTraining.GetHashCode(); + if (Command != 0) hash ^= Command.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + agentActions_.WriteTo(output, _map_agentActions_codec); + if (environmentParameters_ != null) { + output.WriteRawTag(18); + output.WriteMessage(EnvironmentParameters); + } + if (IsTraining != false) { + output.WriteRawTag(24); + output.WriteBool(IsTraining); + } + if (Command != 0) { + output.WriteRawTag(32); + output.WriteEnum((int) Command); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += agentActions_.CalculateSize(_map_agentActions_codec); + if (environmentParameters_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(EnvironmentParameters); + } + if (IsTraining != false) { + size += 1 + 1; + } + if (Command != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Command); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLInput other) { + if (other == null) { + return; + } + agentActions_.Add(other.agentActions_); + if (other.environmentParameters_ != null) { + if (environmentParameters_ == null) { + environmentParameters_ = new global::MLAgents.CommunicatorObjects.EnvironmentParametersProto(); + } + EnvironmentParameters.MergeFrom(other.EnvironmentParameters); + } + if (other.IsTraining != false) { + IsTraining = other.IsTraining; + } + if (other.Command != 0) { + Command = other.Command; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + agentActions_.AddEntriesFrom(input, _map_agentActions_codec); + break; + } + case 18: { + if (environmentParameters_ == null) { + environmentParameters_ = new global::MLAgents.CommunicatorObjects.EnvironmentParametersProto(); + } + input.ReadMessage(environmentParameters_); + break; + } + case 24: { + IsTraining = input.ReadBool(); + break; + } + case 32: { + command_ = (global::MLAgents.CommunicatorObjects.CommandProto) input.ReadEnum(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the UnityRLInput message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public sealed partial class ListAgentActionProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListAgentActionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityRLInput.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentActionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentActionProto(ListAgentActionProto other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentActionProto Clone() { + return new ListAgentActionProto(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::MLAgents.CommunicatorObjects.AgentActionProto.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ListAgentActionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ListAgentActionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ListAgentActionProto other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs.meta new file mode 100644 index 000000000..4cdc003df --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 0283aaaebbbaf4c438db36396a5e3885 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs new file mode 100644 index 000000000..cc63ef9e3 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs @@ -0,0 +1,329 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_rl_output.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_rl_output.proto + public static partial class UnityRlOutputReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_rl_output.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityRlOutputReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cipjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9vdXRwdXQucHJvdG8S", + "FGNvbW11bmljYXRvcl9vYmplY3RzGitjb21tdW5pY2F0b3Jfb2JqZWN0cy9h", + "Z2VudF9pbmZvX3Byb3RvLnByb3RvIqMCCg1Vbml0eVJMT3V0cHV0EhMKC2ds", + "b2JhbF9kb25lGAEgASgIEkcKCmFnZW50SW5mb3MYAiADKAsyMy5jb21tdW5p", + "Y2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0LkFnZW50SW5mb3NFbnRyeRpJ", + "ChJMaXN0QWdlbnRJbmZvUHJvdG8SMwoFdmFsdWUYASADKAsyJC5jb21tdW5p", + "Y2F0b3Jfb2JqZWN0cy5BZ2VudEluZm9Qcm90bxppCg9BZ2VudEluZm9zRW50", + "cnkSCwoDa2V5GAEgASgJEkUKBXZhbHVlGAIgASgLMjYuY29tbXVuaWNhdG9y", + "X29iamVjdHMuVW5pdHlSTE91dHB1dC5MaXN0QWdlbnRJbmZvUHJvdG86AjgB", + "Qh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentInfoProtoReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutput), global::MLAgents.CommunicatorObjects.UnityRLOutput.Parser, new[]{ "GlobalDone", "AgentInfos" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null), + null, }) + })); + } + #endregion + + } + #region Messages + public sealed partial class UnityRLOutput : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLOutput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityRlOutputReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLOutput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLOutput(UnityRLOutput other) : this() { + globalDone_ = other.globalDone_; + agentInfos_ = other.agentInfos_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public UnityRLOutput Clone() { + return new UnityRLOutput(this); + } + + /// Field number for the "global_done" field. + public const int GlobalDoneFieldNumber = 1; + private bool globalDone_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool GlobalDone { + get { return globalDone_; } + set { + globalDone_ = value; + } + } + + /// Field number for the "agentInfos" field. + public const int AgentInfosFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_agentInfos_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::MLAgents.CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto.Parser), 18); + private readonly pbc::MapField agentInfos_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField AgentInfos { + get { return agentInfos_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as UnityRLOutput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(UnityRLOutput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (GlobalDone != other.GlobalDone) return false; + if (!AgentInfos.Equals(other.AgentInfos)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (GlobalDone != false) hash ^= GlobalDone.GetHashCode(); + hash ^= AgentInfos.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (GlobalDone != false) { + output.WriteRawTag(8); + output.WriteBool(GlobalDone); + } + agentInfos_.WriteTo(output, _map_agentInfos_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (GlobalDone != false) { + size += 1 + 1; + } + size += agentInfos_.CalculateSize(_map_agentInfos_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(UnityRLOutput other) { + if (other == null) { + return; + } + if (other.GlobalDone != false) { + GlobalDone = other.GlobalDone; + } + agentInfos_.Add(other.agentInfos_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + GlobalDone = input.ReadBool(); + break; + } + case 18: { + agentInfos_.AddEntriesFrom(input, _map_agentInfos_codec); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the UnityRLOutput message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public sealed partial class ListAgentInfoProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListAgentInfoProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::MLAgents.CommunicatorObjects.UnityRLOutput.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentInfoProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentInfoProto(ListAgentInfoProto other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListAgentInfoProto Clone() { + return new ListAgentInfoProto(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ListAgentInfoProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ListAgentInfoProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ListAgentInfoProto other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta new file mode 100644 index 000000000..8a7368711 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a6665911e84e24b7e970f63662f55713 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs new file mode 100644 index 000000000..0e7e85041 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs @@ -0,0 +1,40 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_to_external.proto +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from communicator_objects/unity_to_external.proto + public static partial class UnityToExternalReflection { + + #region Descriptor + /// File descriptor for communicator_objects/unity_to_external.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static UnityToExternalReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cixjb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV90b19leHRlcm5hbC5wcm90", + "bxIUY29tbXVuaWNhdG9yX29iamVjdHMaKGNvbW11bmljYXRvcl9vYmplY3Rz", + "L3VuaXR5X21lc3NhZ2UucHJvdG8yZwoPVW5pdHlUb0V4dGVybmFsElQKCEV4", + "Y2hhbmdlEiIuY29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlNZXNzYWdlGiIu", + "Y29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlNZXNzYWdlIgBCH6oCHE1MQWdl", + "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityMessageReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null)); + } + #endregion + + } +} + +#endregion Designer generated code diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs.meta new file mode 100644 index 000000000..970211f9f --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 553c6b5d2feba4ef69206f0e0a2a92a3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs new file mode 100644 index 000000000..8be5278b2 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs @@ -0,0 +1,133 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: communicator_objects/unity_to_external.proto +// +#pragma warning disable 1591 +#region Designer generated code + +using System; +using System.Threading; +using System.Threading.Tasks; +using grpc = global::Grpc.Core; + +namespace MLAgents.CommunicatorObjects { + public static partial class UnityToExternal + { + static readonly string __ServiceName = "communicator_objects.UnityToExternal"; + + static readonly grpc::Marshaller __Marshaller_UnityMessage = grpc::Marshallers.Create((arg) => global::Google.Protobuf.MessageExtensions.ToByteArray(arg), global::MLAgents.CommunicatorObjects.UnityMessage.Parser.ParseFrom); + + static readonly grpc::Method __Method_Exchange = new grpc::Method( + grpc::MethodType.Unary, + __ServiceName, + "Exchange", + __Marshaller_UnityMessage, + __Marshaller_UnityMessage); + + /// Service descriptor + public static global::Google.Protobuf.Reflection.ServiceDescriptor Descriptor + { + get { return global::MLAgents.CommunicatorObjects.UnityToExternalReflection.Descriptor.Services[0]; } + } + + /// Base class for server-side implementations of UnityToExternal + public abstract partial class UnityToExternalBase + { + /// + /// Sends the academy parameters + /// + /// The request received from the client. + /// The context of the server-side call handler being invoked. + /// The response to send back to the client (wrapped by a task). + public virtual global::System.Threading.Tasks.Task Exchange(global::MLAgents.CommunicatorObjects.UnityMessage request, grpc::ServerCallContext context) + { + throw new grpc::RpcException(new grpc::Status(grpc::StatusCode.Unimplemented, "")); + } + + } + + /// Client for UnityToExternal + public partial class UnityToExternalClient : grpc::ClientBase + { + /// Creates a new client for UnityToExternal + /// The channel to use to make remote calls. + public UnityToExternalClient(grpc::Channel channel) : base(channel) + { + } + /// Creates a new client for UnityToExternal that uses a custom CallInvoker. + /// The callInvoker to use to make remote calls. + public UnityToExternalClient(grpc::CallInvoker callInvoker) : base(callInvoker) + { + } + /// Protected parameterless constructor to allow creation of test doubles. + protected UnityToExternalClient() : base() + { + } + /// Protected constructor to allow creation of configured clients. + /// The client configuration. + protected UnityToExternalClient(ClientBaseConfiguration configuration) : base(configuration) + { + } + + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The initial metadata to send with the call. This parameter is optional. + /// An optional deadline for the call. The call will be cancelled if deadline is hit. + /// An optional token for canceling the call. + /// The response received from the server. + public virtual global::MLAgents.CommunicatorObjects.UnityMessage Exchange(global::MLAgents.CommunicatorObjects.UnityMessage request, grpc::Metadata headers = null, DateTime? deadline = null, CancellationToken cancellationToken = default(CancellationToken)) + { + return Exchange(request, new grpc::CallOptions(headers, deadline, cancellationToken)); + } + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The options for the call. + /// The response received from the server. + public virtual global::MLAgents.CommunicatorObjects.UnityMessage Exchange(global::MLAgents.CommunicatorObjects.UnityMessage request, grpc::CallOptions options) + { + return CallInvoker.BlockingUnaryCall(__Method_Exchange, null, options, request); + } + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The initial metadata to send with the call. This parameter is optional. + /// An optional deadline for the call. The call will be cancelled if deadline is hit. + /// An optional token for canceling the call. + /// The call object. + public virtual grpc::AsyncUnaryCall ExchangeAsync(global::MLAgents.CommunicatorObjects.UnityMessage request, grpc::Metadata headers = null, DateTime? deadline = null, CancellationToken cancellationToken = default(CancellationToken)) + { + return ExchangeAsync(request, new grpc::CallOptions(headers, deadline, cancellationToken)); + } + /// + /// Sends the academy parameters + /// + /// The request to send to the server. + /// The options for the call. + /// The call object. + public virtual grpc::AsyncUnaryCall ExchangeAsync(global::MLAgents.CommunicatorObjects.UnityMessage request, grpc::CallOptions options) + { + return CallInvoker.AsyncUnaryCall(__Method_Exchange, null, options, request); + } + /// Creates a new instance of client from given ClientBaseConfiguration. + protected override UnityToExternalClient NewInstance(ClientBaseConfiguration configuration) + { + return new UnityToExternalClient(configuration); + } + } + + /// Creates service definition that can be registered with a server + /// An object implementing the server-side handling logic. + public static grpc::ServerServiceDefinition BindService(UnityToExternalBase serviceImpl) + { + return grpc::ServerServiceDefinition.CreateBuilder() + .AddMethod(__Method_Exchange, serviceImpl.Exchange).Build(); + } + + } +} +#endregion diff --git a/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs.meta new file mode 100644 index 000000000..4b7f96a0d --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d3ea7da815b0b4d938c13e621f57db04 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrain.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrain.cs index 6c018a8ed..4fff4fce5 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrain.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrain.cs @@ -11,7 +11,7 @@ public interface CoreBrain /// Implement setBrain so let the coreBrain know what brain is using it void SetBrain(Brain b); /// Implement this method to initialize CoreBrain - void InitializeCoreBrain(Communicator communicator); + void InitializeCoreBrain(MLAgents.Batcher brainBatcher); /// Implement this method to define the logic for deciding actions void DecideAction(Dictionary agentInfo); /// Implement this method to define what should be displayed in the brain Inspector diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainExternal.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainExternal.cs index 353939e07..5006e6c4b 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainExternal.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainExternal.cs @@ -8,7 +8,7 @@ public class CoreBrainExternal : ScriptableObject, CoreBrain /**< Reference to the brain that uses this CoreBrainExternal */ public Brain brain; - ExternalCommunicator coord; + MLAgents.Batcher brainBatcher; /// Creates the reference to the brain public void SetBrain(Brain b) @@ -18,20 +18,20 @@ public class CoreBrainExternal : ScriptableObject, CoreBrain /// Generates the communicator for the Academy if none was present and /// subscribe to ExternalCommunicator if it was present. - public void InitializeCoreBrain(Communicator communicator) + public void InitializeCoreBrain(MLAgents.Batcher brainBatcher) { - if (communicator == null) + if (brainBatcher == null) { - coord = null; + brainBatcher = null; throw new UnityAgentsException(string.Format("The brain {0} was set to" + " External mode" + " but Unity was unable to read the" + " arguments passed at launch.", brain.gameObject.name)); } - else if (communicator is ExternalCommunicator) + else { - coord = (ExternalCommunicator)communicator; - coord.SubscribeBrain(brain); + this.brainBatcher = brainBatcher; + this.brainBatcher.SubscribeBrain(brain.gameObject.name); } } @@ -40,11 +40,11 @@ public class CoreBrainExternal : ScriptableObject, CoreBrain /// sends them to the agents public void DecideAction(Dictionary agentInfo) { - if (coord != null) + if (brainBatcher != null) { - coord.GiveBrainInfo(brain, agentInfo); + brainBatcher.SendBrainInfo(brain.gameObject.name, agentInfo); } - return ; + return; } /// Nothing needs to appear in the inspector diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainHeuristic.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainHeuristic.cs index 5630d66d4..689bbf506 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainHeuristic.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainHeuristic.cs @@ -15,7 +15,7 @@ public class CoreBrainHeuristic : ScriptableObject, CoreBrain /**< Reference to the brain that uses this CoreBrainHeuristic */ public Brain brain; - ExternalCommunicator coord; + MLAgents.Batcher brainBatcher; /**< Reference to the Decision component used to decide the actions */ public Decision decision; @@ -27,28 +27,28 @@ public class CoreBrainHeuristic : ScriptableObject, CoreBrain } /// Create the reference to decision - public void InitializeCoreBrain(Communicator communicator) + public void InitializeCoreBrain(MLAgents.Batcher brainBatcher) { decision = brain.gameObject.GetComponent(); - if ((communicator == null) + if ((brainBatcher == null) || (!broadcast)) { - coord = null; + this.brainBatcher = null; } - else if (communicator is ExternalCommunicator) + else { - coord = (ExternalCommunicator)communicator; - coord.SubscribeBrain(brain); + this.brainBatcher = brainBatcher; ; + this.brainBatcher.SubscribeBrain(brain.gameObject.name); } } /// Uses the Decision Component to decide that action to take public void DecideAction(Dictionary agentInfo) { - if (coord!=null) + if (brainBatcher != null) { - coord.GiveBrainInfo(brain, agentInfo); + brainBatcher.SendBrainInfo(brain.gameObject.name, agentInfo); } if (decision == null) diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs index 405370706..3018ec2ea 100644 --- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs @@ -40,7 +40,7 @@ public class CoreBrainInternal : ScriptableObject, CoreBrain } - ExternalCommunicator coord; + MLAgents.Batcher brainBatcher; [Tooltip("This must be the bytes file corresponding to the pretrained TensorFlow graph.")] /// Modify only in inspector : Reference to the Graph asset @@ -91,7 +91,7 @@ public class CoreBrainInternal : ScriptableObject, CoreBrain } /// Loads the tensorflow graph model to generate a TFGraph object - public void InitializeCoreBrain(Communicator communicator) + public void InitializeCoreBrain(MLAgents.Batcher brainBatcher) { #if ENABLE_TENSORFLOW #if UNITY_ANDROID @@ -104,15 +104,15 @@ public class CoreBrainInternal : ScriptableObject, CoreBrain } #endif - if ((communicator == null) - || (!broadcast)) + if ((brainBatcher == null) + || (!broadcast)) { - coord = null; + this.brainBatcher = null; } - else if (communicator is ExternalCommunicator) + else { - coord = (ExternalCommunicator)communicator; - coord.SubscribeBrain(brain); + this.brainBatcher = brainBatcher; + this.brainBatcher.SubscribeBrain(brain.gameObject.name); } if (graphModel != null) @@ -164,9 +164,9 @@ public class CoreBrainInternal : ScriptableObject, CoreBrain public void DecideAction(Dictionary agentInfo) { #if ENABLE_TENSORFLOW - if (coord != null) + if (brainBatcher != null) { - coord.GiveBrainInfo(brain, agentInfo); + brainBatcher.SendBrainInfo(brain.gameObject.name, agentInfo); } int currentBatchSize = agentInfo.Count(); List agentList = agentInfo.Keys.ToList(); diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainPlayer.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainPlayer.cs index 8e8ad29ea..f260a4d0e 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainPlayer.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainPlayer.cs @@ -28,7 +28,7 @@ public class CoreBrainPlayer : ScriptableObject, CoreBrain public float value; } - ExternalCommunicator coord; + MLAgents.Batcher brainBatcher; [SerializeField] [Tooltip("The list of keys and the value they correspond to for continuous control.")] @@ -51,17 +51,18 @@ public class CoreBrainPlayer : ScriptableObject, CoreBrain } /// Nothing to implement - public void InitializeCoreBrain(Communicator communicator) + /// Nothing to implement + public void InitializeCoreBrain(MLAgents.Batcher brainBatcher) { - if ((communicator == null) + if ((brainBatcher == null) || (!broadcast)) { - coord = null; + this.brainBatcher = null; } - else if (communicator is ExternalCommunicator) + else { - coord = (ExternalCommunicator)communicator; - coord.SubscribeBrain(brain); + this.brainBatcher = brainBatcher; + this.brainBatcher.SubscribeBrain(brain.gameObject.name); } } @@ -69,10 +70,10 @@ public class CoreBrainPlayer : ScriptableObject, CoreBrain /// decide action public void DecideAction(Dictionary agentInfo) { - if (coord != null) - { - coord.GiveBrainInfo(brain, agentInfo); - } + if (brainBatcher != null) + { + brainBatcher.SendBrainInfo(brain.gameObject.name, agentInfo); + } if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous) { foreach (Agent agent in agentInfo.Keys) diff --git a/unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs b/unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs deleted file mode 100644 index d85f5acdd..000000000 --- a/unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs +++ /dev/null @@ -1,452 +0,0 @@ -using System.Collections; -using System.Collections.Generic; -using UnityEngine; - -using Newtonsoft.Json; -using System.Linq; -using System.Net.Sockets; -using System.Text; -using System.IO; - - -/// Responsible for communication with Python API. -public class ExternalCommunicator : Communicator -{ - - ExternalCommand command = ExternalCommand.QUIT; - Academy academy; - - Dictionary> current_agents; - - List brains; - Dictionary hasSentState; - Dictionary triedSendState; - - const int messageLength = 12000; - const int defaultNumAgents = 32; - const int defaultNumObservations = 32; - - - int comPort; - int randomSeed; - Socket sender; - byte[] messageHolder; - byte[] lengthHolder; - - StreamWriter logWriter; - string logPath; - - const string _version_ = "API-3"; - - /// Placeholder for state information to send. - [System.Serializable] - [HideInInspector] - public struct StepMessage - { - public string brain_name; - public List agents; - public List vectorObservations; - public List rewards; - public List previousVectorActions; - public List previousTextActions; - public List memories; - public List textObservations; - public List dones; - public List maxes; - } - - StepMessage sMessage; - string sMessageString; - - AgentMessage rMessage; - StringBuilder rMessageString = new StringBuilder(messageLength); - - /// Placeholder for returned message. - struct AgentMessage - { - public Dictionary> vector_action { get; set; } - public Dictionary> memory { get; set; } - public Dictionary> text_action { get; set; } - } - - /// Placeholder for reset parameter message - struct ResetParametersMessage - { - public Dictionary parameters { get; set; } - public bool train_model { get; set; } - } - - /// Consrtuctor for the External Communicator - public ExternalCommunicator(Academy aca) - { - academy = aca; - brains = new List(); - current_agents = new Dictionary>(); - - hasSentState = new Dictionary(); - triedSendState = new Dictionary(); - - } - - /// Adds the brain to the list of brains which have already decided their - /// actions. - public void SubscribeBrain(Brain brain) - { - brains.Add(brain); - triedSendState[brain.gameObject.name] = false; - hasSentState[brain.gameObject.name] = false; - } - - /// Attempts to make handshake with external API. - public bool CommunicatorHandShake() - { - try - { - ReadArgs(); - } - catch - { - return false; - } - return true; - } - - /// Contains the logic for the initializtation of the socket. - public void InitializeCommunicator() - { - Application.logMessageReceived += HandleLog; - logPath = Path.GetFullPath(".") + "/unity-environment.log"; - logWriter = new StreamWriter(logPath, false); - logWriter.WriteLine(System.DateTime.Now.ToString()); - logWriter.WriteLine(" "); - logWriter.Close(); - messageHolder = new byte[messageLength]; - lengthHolder = new byte[4]; - - // Create a TCP/IP socket. - sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - sender.Connect("localhost", comPort); - - var accParamerters = new AcademyParameters(); - - accParamerters.brainParameters = new List(); - accParamerters.brainNames = new List(); - accParamerters.externalBrainNames = new List(); - accParamerters.apiNumber = _version_; - accParamerters.logPath = logPath; - foreach (Brain b in brains) - { - accParamerters.brainParameters.Add(b.brainParameters); - accParamerters.brainNames.Add(b.gameObject.name); - if (b.brainType == BrainType.External) - { - accParamerters.externalBrainNames.Add(b.gameObject.name); - } - } - accParamerters.AcademyName = academy.gameObject.name; - accParamerters.resetParameters = academy.resetParameters; - - SendParameters(accParamerters); - - sMessage = new StepMessage(); - sMessage.agents = new List(defaultNumAgents); - sMessage.vectorObservations = new List(defaultNumAgents * defaultNumObservations); - sMessage.rewards = new List(defaultNumAgents); - sMessage.memories = new List(defaultNumAgents * defaultNumObservations); - sMessage.dones = new List(defaultNumAgents); - sMessage.previousVectorActions = new List(defaultNumAgents * defaultNumObservations); - sMessage.previousTextActions = new List(defaultNumAgents); - sMessage.maxes = new List(defaultNumAgents); - sMessage.textObservations = new List(defaultNumAgents); - - // Initialize the list of brains the Communicator must listen to - // Issue : This assumes all brains are broadcasting. - foreach (string k in accParamerters.brainNames) - { - current_agents[k] = new List(defaultNumAgents); - hasSentState[k] = false; - triedSendState[k] = false; - } - - } - - void HandleLog(string logString, string stackTrace, LogType type) - { - logWriter = new StreamWriter(logPath, true); - logWriter.WriteLine(type.ToString()); - logWriter.WriteLine(logString); - logWriter.WriteLine(stackTrace); - logWriter.Close(); - } - - /// Listens to the socket for a command and returns the corresponding - /// External Command. - public void UpdateCommand() - { - int location = sender.Receive(messageHolder); - string message = Encoding.ASCII.GetString(messageHolder, 0, location); - switch (message) - { - case "STEP": - command = ExternalCommand.STEP; - break; - case "RESET": - command = ExternalCommand.RESET; - break; - case "QUIT": - command = ExternalCommand.QUIT; - break; - default: - command = ExternalCommand.QUIT; - break; - } - } - - public ExternalCommand GetCommand() - { - return command; - } - - public void SetCommand(ExternalCommand c) - { - command = c; - } - - /// Listens to the socket for the new resetParameters - public Dictionary GetResetParameters() - { - sender.Send(Encoding.ASCII.GetBytes("CONFIG_REQUEST")); - Receive(); - var resetParams = JsonConvert.DeserializeObject(rMessageString.ToString()); - academy.SetIsInference(!resetParams.train_model); - return resetParams.parameters; - } - - - /// Used to read Python-provided environment parameters - private void ReadArgs() - { - string[] args = System.Environment.GetCommandLineArgs(); - var inputPort = ""; - var inputSeed = ""; - for (int i = 0; i < args.Length; i++) - { - if (args[i] == "--port") - { - inputPort = args[i + 1]; - } - if (args[i] == "--seed") - { - inputSeed = args[i + 1]; - } - } - comPort = int.Parse(inputPort); - randomSeed = int.Parse(inputSeed); - Random.InitState(randomSeed); - } - - /// Sends Academy parameters to external agent - private void SendParameters(AcademyParameters envParams) - { - string envMessage = JsonConvert.SerializeObject(envParams, Formatting.Indented); - sender.Send(Encoding.ASCII.GetBytes(envMessage)); - } - - /// Receives messages from external agent - private void Receive() - { - int location = sender.Receive(messageHolder); - rMessageString.Clear(); - rMessageString.Append(Encoding.ASCII.GetString(messageHolder, 0, location)); - } - - /// Receives a message and can reconstruct a message if was too long - private void ReceiveAll() - { - sender.Receive(lengthHolder); - int totalLength = System.BitConverter.ToInt32(lengthHolder, 0); - int location = 0; - rMessageString.Clear(); - while (location != totalLength) - { - int fragment = sender.Receive(messageHolder); - location += fragment; - rMessageString.Append(Encoding.ASCII.GetString(messageHolder, 0, fragment)); - } - } - - /// Ends connection and closes environment - private void OnApplicationQuit() - { - sender.Close(); - sender.Shutdown(SocketShutdown.Both); - } - - /// Contains logic for coverting texture into bytearray to send to - /// external agent. - private byte[] TexToByteArray(Texture2D tex) - { - byte[] bytes = tex.EncodeToPNG(); - Object.DestroyImmediate(tex); - Resources.UnloadUnusedAssets(); - return bytes; - } - - private byte[] AppendLength(byte[] input) - { - byte[] newArray = new byte[input.Length + 4]; - input.CopyTo(newArray, 4); - System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0); - return newArray; - } - - /// Collects the information from the brains and sends it accross the socket - public void GiveBrainInfo(Brain brain, Dictionary agentInfo) - { - var brainName = brain.gameObject.name; - triedSendState[brainName] = true; - - - current_agents[brainName].Clear(); - foreach (Agent agent in agentInfo.Keys) - { - current_agents[brainName].Add(agent); - } - if (current_agents[brainName].Count() > 0) - { - hasSentState[brainName] = true; - sMessage.brain_name = brainName; - sMessage.agents.Clear(); - sMessage.vectorObservations.Clear(); - sMessage.rewards.Clear(); - sMessage.memories.Clear(); - sMessage.dones.Clear(); - sMessage.previousVectorActions.Clear(); - sMessage.previousTextActions.Clear(); - sMessage.maxes.Clear(); - sMessage.textObservations.Clear(); - - int memorySize = 0; - foreach (Agent agent in current_agents[brainName]) - { - memorySize = Mathf.Max(agentInfo[agent].memories.Count, memorySize); - } - - foreach (Agent agent in current_agents[brainName]) - { - sMessage.agents.Add(agentInfo[agent].id); - sMessage.vectorObservations.AddRange(agentInfo[agent].stackedVectorObservation); - sMessage.rewards.Add(agentInfo[agent].reward); - sMessage.memories.AddRange(agentInfo[agent].memories); - for (int j = 0; j < memorySize - agentInfo[agent].memories.Count; j++) - { - sMessage.memories.Add(0f); - } - sMessage.dones.Add(agentInfo[agent].done); - sMessage.previousVectorActions.AddRange(agentInfo[agent].storedVectorActions.ToList()); - sMessage.previousTextActions.Add(agentInfo[agent].storedTextActions); - sMessage.maxes.Add(agentInfo[agent].maxStepReached); - sMessage.textObservations.Add(agentInfo[agent].textObservation); - - } - - - - sMessageString = JsonUtility.ToJson(sMessage); - sender.Send(AppendLength(Encoding.ASCII.GetBytes(sMessageString))); - Receive(); - int i = 0; - foreach (resolution res in brain.brainParameters.cameraResolutions) - { - foreach (Agent agent in current_agents[brainName]) - { - sender.Send(AppendLength(TexToByteArray(agentInfo[agent].visualObservations[i]))); - Receive(); - } - i++; - } - - - } - if (triedSendState.Values.All(x => x)) - { - if (hasSentState.Values.Any(x => x) || academy.IsDone()) - { - // if all the brains listed have sent their state - sender.Send(AppendLength(Encoding.ASCII.GetBytes("END_OF_MESSAGE:" + (academy.IsDone() ? "True" : "False")))); - - - UpdateCommand(); - if (GetCommand() == ExternalCommand.STEP) - { - UpdateActions(); - } - } - - foreach (string k in current_agents.Keys) - { - hasSentState[k] = false; - triedSendState[k] = false; - } - } - - } - - public Dictionary GetHasTried() - { - return triedSendState; - } - - public Dictionary GetSent() - { - return hasSentState; - } - - /// Listens for actions, memories, and values and sends them - /// to the corrensponding brains. - public void UpdateActions() - { - sender.Send(Encoding.ASCII.GetBytes("STEPPING")); - ReceiveAll(); - rMessage = JsonConvert.DeserializeObject(rMessageString.ToString()); - - foreach (Brain brain in brains) - { - if (brain.brainType == BrainType.External) - { - var brainName = brain.gameObject.name; - - if (current_agents[brainName].Count() == 0) - { - continue; - } - var memorySize = rMessage.memory[brainName].Count() / current_agents[brainName].Count(); - - for (int i = 0; i < current_agents[brainName].Count(); i++) - { - if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous) - { - current_agents[brainName][i].UpdateVectorAction(rMessage.vector_action[brainName].GetRange( - i * brain.brainParameters.vectorActionSize, brain.brainParameters.vectorActionSize).ToArray()); - } - else - { - current_agents[brainName][i].UpdateVectorAction(rMessage.vector_action[brainName].GetRange(i, 1).ToArray()); - - } - - current_agents[brainName][i].UpdateMemoriesAction( - rMessage.memory[brainName].GetRange(i * memorySize, memorySize)); - - if (rMessage.text_action[brainName].Count > 0) - current_agents[brainName][i].UpdateTextAction(rMessage.text_action[brainName][i]); - - } - - } - } - } - - - -} \ No newline at end of file diff --git a/unity-environment/Assets/ML-Agents/Scripts/RpcCommunicator.cs b/unity-environment/Assets/ML-Agents/Scripts/RpcCommunicator.cs new file mode 100644 index 000000000..a34408b20 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/RpcCommunicator.cs @@ -0,0 +1,149 @@ +using Grpc.Core; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +#if UNITY_EDITOR +using UnityEditor; +#endif +using UnityEngine; +using MLAgents.CommunicatorObjects; + +namespace MLAgents +{ + /// Responsible for communication with External using gRPC. + public class RPCCommunicator : Communicator + { + /// If true, the communication is active. + bool m_isOpen; + + /// The Unity to External client. + UnityToExternal.UnityToExternalClient m_client; + + /// The communicator parameters sent at construction + CommunicatorParameters m_communicatorParameters; + + /// + /// Initializes a new instance of the RPCCommunicator class. + /// + /// Communicator parameters. + public RPCCommunicator(CommunicatorParameters communicatorParameters) + { + this.m_communicatorParameters = communicatorParameters; + } + + /// + /// Initialize the communicator by sending the first UnityOutput and receiving the + /// first UnityInput. The second UnityInput is stored in the unityInput argument. + /// + /// The first Unity Input. + /// The first Unity Output. + /// The second Unity input. + public UnityInput Initialize(UnityOutput unityOutput, + out UnityInput unityInput) + { + m_isOpen = true; + var channel = new Channel( + "localhost:"+m_communicatorParameters.port, + ChannelCredentials.Insecure); + + m_client = new UnityToExternal.UnityToExternalClient(channel); + var result = m_client.Exchange(WrapMessage(unityOutput, 200)); + unityInput = m_client.Exchange(WrapMessage(null, 200)).UnityInput; +#if UNITY_EDITOR + EditorApplication.playModeStateChanged += HandleOnPlayModeChanged; +#endif + return result.UnityInput; + } + + /// + /// Close the communicator gracefully on both sides of the communication. + /// + public void Close() + { + if (!m_isOpen) + { + return; + } + + try + { + m_client.Exchange(WrapMessage(null, 400)); + m_isOpen = false; + } + catch + { + return; + } + } + + /// + /// Send a UnityOutput and receives a UnityInput. + /// + /// The next UnityInput. + /// The UnityOutput to be sent. + public UnityInput Exchange(UnityOutput unityOutput) + { + if (!m_isOpen) + { + return null; + } + try + { + var message = m_client.Exchange(WrapMessage(unityOutput, 200)); + if (message.Header.Status == 200) + { + return message.UnityInput; + } + else + { + m_isOpen = false; + return null; + } + } + catch + { + m_isOpen = false; + return null; + } + } + + /// + /// Wraps the UnityOuptut into a message with the appropriate status. + /// + /// The UnityMessage corresponding. + /// The UnityOutput to be wrapped. + /// The status of the message. + private static UnityMessage WrapMessage(UnityOutput content, int status) + { + return new UnityMessage + { + Header = new Header { Status = status }, + UnityOutput = content + }; + } + + /// + /// When the Unity application quits, the communicator must be closed + /// + private void OnApplicationQuit() + { + Close(); + } + +#if UNITY_EDITOR + /// + /// When the editor exits, the communicator must be closed + /// + /// State. + private void HandleOnPlayModeChanged(PlayModeStateChange state) + { + // This method is run whenever the playmode state is changed. + if (state==PlayModeStateChange.ExitingPlayMode) + { + Close(); + } + } +#endif + + } +} diff --git a/unity-environment/Assets/ML-Agents/Scripts/RpcCommunicator.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/RpcCommunicator.cs.meta new file mode 100644 index 000000000..d1903d74c --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/RpcCommunicator.cs.meta @@ -0,0 +1,13 @@ +fileFormatVersion: 2 +guid: 57a3dc12d3b88408688bb490b65a838e +timeCreated: 1523046536 +licenseType: Free +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/SocketCommunicator.cs b/unity-environment/Assets/ML-Agents/Scripts/SocketCommunicator.cs new file mode 100644 index 000000000..5edf61dbc --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/SocketCommunicator.cs @@ -0,0 +1,166 @@ +using Google.Protobuf; +using Grpc.Core; +using System.Net.Sockets; +using UnityEngine; +using MLAgents.CommunicatorObjects; +using System.Threading.Tasks; +#if UNITY_EDITOR +using UnityEditor; +#endif + +namespace MLAgents +{ + + public class SocketCommunicator : Communicator + { + private const float TimeOut = 10f; + private const int MessageLength = 12000; + byte[] m_messageHolder = new byte[MessageLength]; + int m_comPort; + Socket m_sender; + byte[] m_lengthHolder = new byte[4]; + CommunicatorParameters communicatorParameters; + + + public SocketCommunicator(CommunicatorParameters communicatorParameters) + { + this.communicatorParameters = communicatorParameters; + } + + /// + /// Initialize the communicator by sending the first UnityOutput and receiving the + /// first UnityInput. The second UnityInput is stored in the unityInput argument. + /// + /// The first Unity Input. + /// The first Unity Output. + /// The second Unity input. + public UnityInput Initialize(UnityOutput unityOutput, + out UnityInput unityInput) + { + + m_sender = new Socket( + AddressFamily.InterNetwork, + SocketType.Stream, + ProtocolType.Tcp); + m_sender.Connect("localhost", communicatorParameters.port); + + UnityMessage initializationInput = + UnityMessage.Parser.ParseFrom(Receive()); + + Send(WrapMessage(unityOutput, 200).ToByteArray()); + + unityInput = UnityMessage.Parser.ParseFrom(Receive()).UnityInput; +#if UNITY_EDITOR + EditorApplication.playModeStateChanged += HandleOnPlayModeChanged; +#endif + return initializationInput.UnityInput; + + } + + /// + /// Uses the socke to receive a byte[] from External. Reassemble a message that was split + /// by External if it was too long. + /// + /// The byte[] sent by External. + byte[] Receive() + { + m_sender.Receive(m_lengthHolder); + int totalLength = System.BitConverter.ToInt32(m_lengthHolder, 0); + int location = 0; + byte[] result = new byte[totalLength]; + while (location != totalLength) + { + int fragment = m_sender.Receive(m_messageHolder); + System.Buffer.BlockCopy( + m_messageHolder, 0, result, location, fragment); + location += fragment; + } + return result; + } + + /// + /// Send the specified input via socket to External. Split the message into smaller + /// parts if it is too long. + /// + /// The byte[] to be sent. + void Send(byte[] input) + { + byte[] newArray = new byte[input.Length + 4]; + input.CopyTo(newArray, 4); + System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0); + m_sender.Send(newArray); + } + + /// + /// Close the communicator gracefully on both sides of the communication. + /// + public void Close() + { + Send(WrapMessage(null, 400).ToByteArray()); + } + + /// + /// Send a UnityOutput and receives a UnityInput. + /// + /// The next UnityInput. + /// The UnityOutput to be sent. + public UnityInput Exchange(UnityOutput unityOutput) + { + Send(WrapMessage(unityOutput, 200).ToByteArray()); + byte[] received = null; + var task = Task.Run(() => received = Receive()); + if (!task.Wait(System.TimeSpan.FromSeconds(TimeOut))) + { + throw new UnityAgentsException( + "The communicator took too long to respond."); + } + + var message = UnityMessage.Parser.ParseFrom(received); + + if (message.Header.Status != 200) + { + return null; + } + return message.UnityInput; + } + + /// + /// Wraps the UnityOuptut into a message with the appropriate status. + /// + /// The UnityMessage corresponding. + /// The UnityOutput to be wrapped. + /// The status of the message. + private static UnityMessage WrapMessage(UnityOutput content, int status) + { + return new UnityMessage + { + Header = new Header { Status = status }, + UnityOutput = content + }; + } + + /// + /// When the Unity application quits, the communicator must be closed + /// + private void OnApplicationQuit() + { + Close(); + } + +#if UNITY_EDITOR + /// + /// When the editor exits, the communicator must be closed + /// + /// State. + void HandleOnPlayModeChanged(PlayModeStateChange state) + { + // This method is run whenever the playmode state is changed. + if (state == PlayModeStateChange.ExitingPlayMode) + { + Close(); + } + } +#endif + + } +} diff --git a/unity-environment/Assets/ML-Agents/Scripts/SocketCommunicator.cs.meta b/unity-environment/Assets/ML-Agents/Scripts/SocketCommunicator.cs.meta new file mode 100644 index 000000000..23ac272f1 --- /dev/null +++ b/unity-environment/Assets/ML-Agents/Scripts/SocketCommunicator.cs.meta @@ -0,0 +1,13 @@ +fileFormatVersion: 2 +guid: f0901c57c84a54f25aa5955165072493 +timeCreated: 1523046536 +licenseType: Free +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/unity-environment/Assets/ML-Agents/Scripts/UnityAgentsException.cs b/unity-environment/Assets/ML-Agents/Scripts/UnityAgentsException.cs index a623b051f..dd6ddde7d 100755 --- a/unity-environment/Assets/ML-Agents/Scripts/UnityAgentsException.cs +++ b/unity-environment/Assets/ML-Agents/Scripts/UnityAgentsException.cs @@ -11,7 +11,7 @@ public class UnityAgentsException : System.Exception /// The simulation will end since no steps will be taken. public UnityAgentsException(string message) : base(message) { - Time.timeScale = 0f; + } /// A constructor is needed for serialization when an exception propagates