From 0b677295d3965eef29ee916a7c80104f31291bb5 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Mon, 8 Jun 2020 09:35:22 -0700 Subject: [PATCH] Modified the *CODE* to rename world to policy --- ...rldSpecsDrawer.cs => PolicySpecsDrawer.cs} | 24 ++-- ...awer.cs.meta => PolicySpecsDrawer.cs.meta} | 2 +- Runtime/Academy.cs | 77 +++++++------ Runtime/ActionHashMapUtils.cs | 28 ++--- Runtime/ActuatorJob.cs | 40 +++---- Runtime/{World.meta => Policy.meta} | 0 Runtime/{World => Policy}/Counter.cs | 0 Runtime/{World => Policy}/Counter.cs.meta | 0 Runtime/{World => Policy}/DecisionRequest.cs | 50 ++++----- .../{World => Policy}/DecisionRequest.cs.meta | 0 .../{World => Policy}/EpisodeTermination.cs | 38 +++---- .../EpisodeTermination.cs.meta | 0 .../MLAgentsWorld.cs => Policy/Policy.cs} | 24 ++-- .../Policy.cs.meta} | 0 ...rldProcessor.meta => PolicyProcessor.meta} | 0 .../BarracudaPolicyProcessor.cs} | 76 ++++++------- .../BarracudaPolicyProcessor.cs.meta} | 0 .../HeuristicPolicyProcessor.cs | 104 ++++++++++++++++++ .../HeuristicPolicyProcessor.cs.meta} | 0 .../IPolicyProcessor.cs} | 12 +- .../IPolicyProcessor.cs.meta} | 0 .../NullPolicyProcessor.cs} | 10 +- .../NullPolicyProcessor.cs.meta} | 0 .../RemotePolicyProcessor.cs} | 14 +-- .../RemotePolicyProcessor.cs.meta} | 0 Runtime/Remote/RLDataOffsets.cs | 14 +-- Runtime/Remote/SharedMemoryBody.cs | 56 +++++----- Runtime/Remote/SharedMemoryCommunicator.cs | 26 ++--- .../{MLAgentsWorldSpecs.cs => PolicySpecs.cs} | 44 ++++---- ...WorldSpecs.cs.meta => PolicySpecs.cs.meta} | 0 .../WorldProcessor/HeuristicWorldProcessor.cs | 104 ------------------ Samples~/3DBall/Prefab/AgentBlue.mat | 5 +- Samples~/3DBall/Scene/3DBall.unity | 8 +- Samples~/3DBall/Script/BalanceBallManager.cs | 7 +- Samples~/3DBall/Script/BallSystem.cs | 18 +-- Samples~/Basic/Script/BasicAgent.cs | 18 +-- ...MLAgentsWorld.cs => TestMLAgentsPolicy.cs} | 54 ++++----- ...rld.cs.meta => TestMLAgentsPolicy.cs.meta} | 0 38 files changed, 426 insertions(+), 427 deletions(-) rename Editor/{MLAgentsWorldSpecsDrawer.cs => PolicySpecsDrawer.cs} (92%) rename Editor/{MLAgentsWorldSpecsDrawer.cs.meta => PolicySpecsDrawer.cs.meta} (83%) rename Runtime/{World.meta => Policy.meta} (100%) rename Runtime/{World => Policy}/Counter.cs (100%) rename Runtime/{World => Policy}/Counter.cs.meta (100%) rename Runtime/{World => Policy}/DecisionRequest.cs (72%) rename Runtime/{World => Policy}/DecisionRequest.cs.meta (100%) rename Runtime/{World => Policy}/EpisodeTermination.cs (73%) rename Runtime/{World => Policy}/EpisodeTermination.cs.meta (100%) rename Runtime/{World/MLAgentsWorld.cs => Policy/Policy.cs} (94%) rename Runtime/{World/MLAgentsWorld.cs.meta => Policy/Policy.cs.meta} (100%) rename Runtime/{WorldProcessor.meta => PolicyProcessor.meta} (100%) rename Runtime/{WorldProcessor/BarracudaWorldProcessor.cs => PolicyProcessor/BarracudaPolicyProcessor.cs} (64%) rename Runtime/{WorldProcessor/BarracudaWorldProcessor.cs.meta => PolicyProcessor/BarracudaPolicyProcessor.cs.meta} (100%) create mode 100644 Runtime/PolicyProcessor/HeuristicPolicyProcessor.cs rename Runtime/{WorldProcessor/HeuristicWorldProcessor.cs.meta => PolicyProcessor/HeuristicPolicyProcessor.cs.meta} (100%) rename Runtime/{WorldProcessor/IWorldProcessor.cs => PolicyProcessor/IPolicyProcessor.cs} (51%) rename Runtime/{WorldProcessor/IWorldProcessor.cs.meta => PolicyProcessor/IPolicyProcessor.cs.meta} (100%) rename Runtime/{WorldProcessor/NullWorldProcessor.cs => PolicyProcessor/NullPolicyProcessor.cs} (55%) rename 
Runtime/{WorldProcessor/NullWorldProcessor.cs.meta => PolicyProcessor/NullPolicyProcessor.cs.meta} (100%) rename Runtime/{WorldProcessor/RemoteWorldProcessor.cs => PolicyProcessor/RemotePolicyProcessor.cs} (57%) rename Runtime/{WorldProcessor/RemoteWorldProcessor.cs.meta => PolicyProcessor/RemotePolicyProcessor.cs.meta} (100%) rename Runtime/UI/{MLAgentsWorldSpecs.cs => PolicySpecs.cs} (51%) rename Runtime/UI/{MLAgentsWorldSpecs.cs.meta => PolicySpecs.cs.meta} (100%) delete mode 100644 Runtime/WorldProcessor/HeuristicWorldProcessor.cs rename Tests/Editor/{TestMLAgentsWorld.cs => TestMLAgentsPolicy.cs} (79%) rename Tests/Editor/{TestMLAgentsWorld.cs.meta => TestMLAgentsPolicy.cs.meta} (100%) diff --git a/Editor/MLAgentsWorldSpecsDrawer.cs b/Editor/PolicySpecsDrawer.cs similarity index 92% rename from Editor/MLAgentsWorldSpecsDrawer.cs rename to Editor/PolicySpecsDrawer.cs index 18546c4..41bc90c 100644 --- a/Editor/MLAgentsWorldSpecsDrawer.cs +++ b/Editor/PolicySpecsDrawer.cs @@ -12,7 +12,7 @@ namespace Unity.AI.MLAgents.Editor internal static class SpecsPropertyNames { public const string k_Name = "Name"; - public const string k_WorldProcessorType = "WorldProcessorType"; + public const string k_PolicyProcessorType = "PolicyProcessorType"; public const string k_NumberAgents = "NumberAgents"; public const string k_ActionType = "ActionType"; public const string k_ObservationShapes = "ObservationShapes"; @@ -27,8 +27,8 @@ namespace Unity.AI.MLAgents.Editor /// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the /// Inspector. /// - [CustomPropertyDrawer(typeof(MLAgentsWorldSpecs))] - internal class MLAgentsWorldSpecsDrawer : PropertyDrawer + [CustomPropertyDrawer(typeof(PolicySpecs))] + internal class PolicySpecsDrawer : PropertyDrawer { // The height of a line in the Unity Inspectors const float k_LineHeight = 21f; @@ -66,7 +66,7 @@ namespace Unity.AI.MLAgents.Editor new Rect(position.x - 3f, position.y, position.width + 6f, m_TotalHeight), new Color(0f, 0f, 0f, 0.1f)); - EditorGUI.LabelField(position, "ML-Agents World Specs : " + label.text); + EditorGUI.LabelField(position, "Policy Specs : " + label.text); position.y += k_LineHeight; EditorGUI.indentLevel++; @@ -74,13 +74,13 @@ namespace Unity.AI.MLAgents.Editor // Name EditorGUI.PropertyField(position, property.FindPropertyRelative(SpecsPropertyNames.k_Name), - new GUIContent("World Name", "The name of the World")); + new GUIContent("Name", "The name of the Policy")); position.y += k_LineHeight; - // WorldProcessorType + // PolicyProcessorType EditorGUI.PropertyField(position, - property.FindPropertyRelative(SpecsPropertyNames.k_WorldProcessorType), - new GUIContent("Processor Type", "The Policy for the World")); + property.FindPropertyRelative(SpecsPropertyNames.k_PolicyProcessorType), + new GUIContent("Policy Type", "The type of Policy")); position.y += k_LineHeight; // Number of Agents @@ -224,7 +224,7 @@ namespace Unity.AI.MLAgents.Editor var name = nameProperty.stringValue; if (name == "") { - m_Warnings.Add("Your World must have a non-empty name"); + m_Warnings.Add("Your Policy must have a non-empty name"); } // Max number of agents is not zero @@ -232,14 +232,14 @@ namespace Unity.AI.MLAgents.Editor var nAgents = nAgentsProperty.intValue; if (nAgents == 0) { - m_Warnings.Add("Your World must have a non-zero maximum number of Agents"); + m_Warnings.Add("Your Policy must have a non-zero maximum number of Agents"); } // At least one observation var observationShapes = 
property.FindPropertyRelative(SpecsPropertyNames.k_ObservationShapes); if (observationShapes.arraySize == 0) { - m_Warnings.Add("Your World must have at least one observation"); + m_Warnings.Add("Your Policy must have at least one observation"); } // Action Size is not zero @@ -247,7 +247,7 @@ namespace Unity.AI.MLAgents.Editor var actionSize = actionSizeProperty.intValue; if (actionSize == 0) { - m_Warnings.Add("Your World must have non-zero action size"); + m_Warnings.Add("Your Policy must have non-zero action size"); } //Model is not empty diff --git a/Editor/MLAgentsWorldSpecsDrawer.cs.meta b/Editor/PolicySpecsDrawer.cs.meta similarity index 83% rename from Editor/MLAgentsWorldSpecsDrawer.cs.meta rename to Editor/PolicySpecsDrawer.cs.meta index 37549a3..e7725e8 100644 --- a/Editor/MLAgentsWorldSpecsDrawer.cs.meta +++ b/Editor/PolicySpecsDrawer.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 671ee066644364cc191cb6c00ceaf1b4 +guid: 94a88b2001ca84a0aa8eb12a93dffc6e MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/Runtime/Academy.cs b/Runtime/Academy.cs index b6b127b..08c250b 100644 --- a/Runtime/Academy.cs +++ b/Runtime/Academy.cs @@ -11,7 +11,7 @@ namespace Unity.AI.MLAgents /// /// The Academy is a singleton that orchestrates the decision making of the /// decision making of the Agents. - /// It is used to register WorldProcessors to Worlds and to keep track of the + /// It is used to register PolicyProcessors to Policy and to keep track of the /// reset logic of the simulation. /// public class Academy : IDisposable @@ -50,7 +50,8 @@ namespace Unity.AI.MLAgents private bool m_FirstMessageReceived; private SharedMemoryCommunicator m_Communicator; - internal Dictionary m_WorldToProcessor; // Maybe we can put the processor in the world with an unsafe unmanaged memory pointer ? + internal Dictionary m_PolicyToProcessor; + // TODO : Maybe we can put the processor in the policy with an unsafe unmanaged memory pointer ? private EnvironmentParameters m_EnvironmentParameters; private StatsRecorder m_StatsRecorder; @@ -63,33 +64,29 @@ namespace Unity.AI.MLAgents public Action OnEnvironmentReset; /// - /// Registers a MLAgentsWorld to a decision making mechanism. - /// By default, the MLAgentsWorld will use a remote process for decision making when available. + /// Registers a Policy to a decision making mechanism. + /// By default, the Policy will use a remote process for decision making when available. /// - /// The string identifier of the MLAgentsWorld. There can only be one world per unique id. - /// The MLAgentsWorld that is being subscribed. - /// If the remote process is not available, the MLAgentsWorld will use this World processor for decision making. - /// If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallback worldProcessor otherwise. - public void RegisterWorld(string policyId, MLAgentsWorld world, IWorldProcessor worldProcessor = null, bool defaultRemote = true) + /// The string identifier of the Policy. There can only be one Policy per unique id. + /// The Policy that is being subscribed. + /// If the remote process is not available, the Policy will use this IPolicyProcessor for decision making. + /// If true, the Policy will default to using the remote process for communication making and use the fallback IPolicyProcessor otherwise. + public void RegisterPolicy(string policyId, Policy policy, IPolicyProcessor policyProcessor = null, bool defaultRemote = true) { - // Need to find a way to deregister ? 
- // Need a way to modify the World processor on the fly - // Automagically register world on creation ? - - IWorldProcessor processor = null; + IPolicyProcessor processor = null; if (m_Communicator != null && defaultRemote) { - processor = new RemoteWorldProcessor(world, policyId, m_Communicator); + processor = new RemotePolicyProcessor(policy, policyId, m_Communicator); } - else if (worldProcessor != null) + else if (policyProcessor != null) { - processor = worldProcessor; + processor = policyProcessor; } else { - processor = new NullWorldProcessor(world); + processor = new NullPolicyProcessor(policy); } - m_WorldToProcessor[world] = processor; + m_PolicyToProcessor[policy] = processor; } /// @@ -133,7 +130,7 @@ namespace Unity.AI.MLAgents Application.quitting += Dispose; OnEnvironmentReset = () => {}; - m_WorldToProcessor = new Dictionary(); + m_PolicyToProcessor = new Dictionary(); TryInitializeCommunicator(); SideChannelsManager.RegisterSideChannel(new EngineConfigurationChannel()); @@ -162,8 +159,8 @@ namespace Unity.AI.MLAgents } } - // We will make the assumption that a world can only be updated one at a time - internal void UpdateWorld(MLAgentsWorld world) + // We will make the assumption that a policy can only be updated one at a time + internal void UpdatePolicy(Policy policy) { if (!m_Initialized) { @@ -171,23 +168,23 @@ namespace Unity.AI.MLAgents } // If no agents requested a decision return - if (world.DecisionCounter.Count == 0 && world.TerminationCounter.Count == 0) + if (policy.DecisionCounter.Count == 0 && policy.TerminationCounter.Count == 0) { return; } - // Ensure the world does not have lingering actions: - if (world.ActionCounter.Count != 0) + // Ensure the policy does not have lingering actions: + if (policy.ActionCounter.Count != 0) { // This means something in the execution went wrong, this error should never appear throw new MLAgentsException("TODO : ActionCount is not 0"); } - var processor = m_WorldToProcessor[world]; + var processor = m_PolicyToProcessor[policy]; if (processor == null) { // Raise error - throw new MLAgentsException($"A world has not been correctly registered."); + throw new MLAgentsException($"A Policy has not been correctly registered."); } @@ -208,10 +205,10 @@ namespace Unity.AI.MLAgents if (!reset) // TODO : Comment out if we do not want to reset on first env.reset() { m_Communicator.WriteSideChannelData(SideChannelsManager.GetSideChannelMessage()); - processor.ProcessWorld(); + processor.Process(); reset = m_Communicator.ReadAndClearResetCommand(); - world.SetActionReady(); - world.ResetDecisionsAndTerminationCounters(); + policy.SetActionReady(); + policy.ResetDecisionsAndTerminationCounters(); SideChannelsManager.ProcessSideChannelData(m_Communicator.ReadAndClearSideChannelData()); } if (reset) @@ -222,16 +219,16 @@ namespace Unity.AI.MLAgents } else if (!processor.IsConnected) { - processor.ProcessWorld(); - world.SetActionReady(); - world.ResetDecisionsAndTerminationCounters(); + processor.Process(); + policy.SetActionReady(); + policy.ResetDecisionsAndTerminationCounters(); // TODO com.ReadAndClearSideChannelData(); // Remove side channel data } else { // The processor wants to communicate but the communicator is either null or inactive - world.ResetActionsCounter(); - world.ResetDecisionsAndTerminationCounters(); + policy.ResetActionsCounter(); + policy.ResetDecisionsAndTerminationCounters(); } if (m_Communicator == null) { @@ -246,16 +243,16 @@ namespace Unity.AI.MLAgents // Need to complete all of the jobs at this point. 
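
For orientation, here is a minimal usage sketch of the renamed registration API; it is not part of the patch. It follows the RegisterPolicy signature shown above and the Policy constructor documented later in Runtime/Policy/Policy.cs. The policy name, observation shape, and action size are illustrative, and the discrete-branches constructor argument is assumed to be omittable for continuous actions.

using Unity.AI.MLAgents;
using Unity.Mathematics;

public static class PolicyRegistrationExample
{
    public static Policy CreateAndRegister()
    {
        // One vector observation of 8 floats, 2 continuous actions, at most 100 agents.
        var policy = new Policy(
            100,
            new int3[] { new int3(8, 0, 0) },
            ActionType.CONTINUOUS,
            2);

        // With no explicit processor and defaultRemote left at true, the Academy uses a
        // RemotePolicyProcessor when a Python trainer is connected and falls back to a
        // NullPolicyProcessor otherwise (see RegisterPolicy above).
        Academy.Instance.RegisterPolicy("Ball3D", policy);
        return policy;
    }
}
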
ECSWorld.EntityManager.CompleteAllJobs(); } - ResetAllWorlds(); + ResetAllPolicies(); OnEnvironmentReset?.Invoke(); } - private void ResetAllWorlds() // This is problematic because it affects all worlds and is not thread safe... + private void ResetAllPolicies() // This is problematic because it affects all policies and is not thread safe... { - foreach (var w in m_WorldToProcessor.Keys) + foreach (var pol in m_PolicyToProcessor.Keys) { - w.ResetActionsCounter(); - w.ResetDecisionsAndTerminationCounters(); + pol.ResetActionsCounter(); + pol.ResetDecisionsAndTerminationCounters(); } } diff --git a/Runtime/ActionHashMapUtils.cs b/Runtime/ActionHashMapUtils.cs index e9b48c5..d0e2597 100644 --- a/Runtime/ActionHashMapUtils.cs +++ b/Runtime/ActionHashMapUtils.cs @@ -9,39 +9,39 @@ namespace Unity.AI.MLAgents public static class ActionHashMapUtils { /// - /// Retrieves the action data for a world in puts it into a HashMap. - /// This action deletes the action data from the world. + /// Retrieves the action data for a Policy in puts it into a HashMap. + /// This action deletes the action data from the Policy. /// - /// The MLAgentsWorld the data will be retrieved from. + /// The Policy the data will be retrieved from. /// The memory allocator of the create NativeHashMap. /// The type of the Action struct. It must match the Action Size - /// and Action Type of the world. + /// and Action Type of the Policy. /// A NativeHashMap from Entities to Actions with type T. - public static NativeHashMap GenerateActionHashMap(this MLAgentsWorld world, Allocator allocator) where T : struct + public static NativeHashMap GenerateActionHashMap(this Policy policy, Allocator allocator) where T : struct { #if ENABLE_UNITY_COLLECTIONS_CHECKS - if (world.ActionSize != UnsafeUtility.SizeOf() / 4) + if (policy.ActionSize != UnsafeUtility.SizeOf() / 4) { var receivedSize = UnsafeUtility.SizeOf() / 4; - throw new MLAgentsException($"Action space size does not match for action. Expected {world.ActionSize} but received {receivedSize}"); + throw new MLAgentsException($"Action space size does not match for action. Expected {policy.ActionSize} but received {receivedSize}"); } #endif - Academy.Instance.UpdateWorld(world); - int actionCount = world.ActionCounter.Count; + Academy.Instance.UpdatePolicy(policy); + int actionCount = policy.ActionCounter.Count; var result = new NativeHashMap(actionCount, allocator); - int size = world.ActionSize; + int size = policy.ActionSize; for (int i = 0; i < actionCount; i++) { - if (world.ActionType == ActionType.DISCRETE) + if (policy.ActionType == ActionType.DISCRETE) { - result.TryAdd(world.ActionAgentEntityIds[i], world.DiscreteActuators.Slice(i * size, size).SliceConvert()[0]); + result.TryAdd(policy.ActionAgentEntityIds[i], policy.DiscreteActuators.Slice(i * size, size).SliceConvert()[0]); } else { - result.TryAdd(world.ActionAgentEntityIds[i], world.ContinuousActuators.Slice(i * size, size).SliceConvert()[0]); + result.TryAdd(policy.ActionAgentEntityIds[i], policy.ContinuousActuators.Slice(i * size, size).SliceConvert()[0]); } } - world.ResetActionsCounter(); + policy.ResetActionsCounter(); return result; } } diff --git a/Runtime/ActuatorJob.cs b/Runtime/ActuatorJob.cs index 28f87d5..3e81ab1 100644 --- a/Runtime/ActuatorJob.cs +++ b/Runtime/ActuatorJob.cs @@ -34,7 +34,7 @@ namespace Unity.AI.MLAgents /// Retrieve the action that was decided for the Entity. 
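
A sketch (not part of the patch) of consuming actions through the GenerateActionHashMap extension shown above. The MoveAction struct is illustrative and must match the Policy's action type and size, here two continuous floats.

using Unity.AI.MLAgents;
using Unity.Collections;
using Unity.Entities;

public struct MoveAction
{
    public float Forward;
    public float Turn;
}

public static class ActionHashMapExample
{
    public static void ApplyAction(Policy policy, Entity agent)
    {
        // GenerateActionHashMap calls Academy.Instance.UpdatePolicy internally and resets
        // the action counter afterwards (see ActionHashMapUtils.cs above).
        var actions = policy.GenerateActionHashMap<MoveAction>(Allocator.Temp);
        if (actions.TryGetValue(agent, out var action))
        {
            // apply action.Forward and action.Turn to this agent here
        }
        actions.Dispose();
    }
}
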
/// This method uses a generic type, as such you must provide a /// type that is compatible with the Action Type and Action Size - /// for this MLAgentsWorld. + /// for this Policy. /// /// The type of action struct. /// The action struct for the Entity. @@ -59,7 +59,7 @@ namespace Unity.AI.MLAgents } /// - /// The signature of the Job used to retrieve actuators values from a world + /// The signature of the Job used to retrieve actuators values from a Policy /// [JobProducerType(typeof(IActuatorJobExtensions.ActuatorDataJobProcess<>))] public interface IActuatorJob @@ -73,31 +73,31 @@ namespace Unity.AI.MLAgents /// Schedule the Job that will generate the action data for the Entities that requested a decision. /// /// The IActuatorJob struct. - /// The MLAgentsWorld containing the data needed for decision making. + /// The Policy containing the data needed for decision making. /// The jobHandle for the job. /// The type of the IActuatorData struct. /// The updated jobHandle for the job. - public static unsafe JobHandle Schedule(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps) + public static unsafe JobHandle Schedule(this T jobData, Policy policy, JobHandle inputDeps) where T : struct, IActuatorJob { - inputDeps.Complete(); // TODO : FIND A BETTER WAY TO MAKE SURE ALL THE DATA IS IN THE WORLD - Academy.Instance.UpdateWorld(mlagentsWorld); - if (mlagentsWorld.ActionCounter.Count == 0) + inputDeps.Complete(); // TODO : FIND A BETTER WAY TO MAKE SURE ALL THE DATA IS IN THE POLICY + Academy.Instance.UpdatePolicy(policy); + if (policy.ActionCounter.Count == 0) { return inputDeps; } - return ScheduleImpl(jobData, mlagentsWorld, inputDeps); + return ScheduleImpl(jobData, policy, inputDeps); } // Passing this along - internal static unsafe JobHandle ScheduleImpl(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps) + internal static unsafe JobHandle ScheduleImpl(this T jobData, Policy policy, JobHandle inputDeps) where T : struct, IActuatorJob { // Creating a data struct that contains the data the user passed into the job (This is what T is here) var data = new ActionEventJobData { UserJobData = jobData, - world = mlagentsWorld // Need to create this before hand with the actuator data + Policy = policy // Need to create this before hand with the actuator data }; // Scheduling a Job out of thin air by using a pointer called jobReflectionData in the ActuatorSystemJobStruct @@ -105,11 +105,11 @@ namespace Unity.AI.MLAgents return JobsUtility.Schedule(ref parameters); } - // This is the struct containing all the data needed from both the user and the MLAgents world + // This is the struct containing all the data needed from both the user and the Policy internal unsafe struct ActionEventJobData where T : struct { public T UserJobData; - [NativeDisableContainerSafetyRestriction] public MLAgentsWorld world; + [NativeDisableContainerSafetyRestriction] public Policy Policy; } internal struct ActuatorDataJobProcess where T : struct, IActuatorJob @@ -132,11 +132,11 @@ namespace Unity.AI.MLAgents /// Calls the user implemented Execute method with ActuatorEvent struct public static unsafe void Execute(ref ActionEventJobData jobData, IntPtr listDataPtr, IntPtr unusedPtr, ref JobRanges ranges, int jobIndex) { - int size = jobData.world.ActionSize; - int actionCount = jobData.world.ActionCounter.Count; + int size = jobData.Policy.ActionSize; + int actionCount = jobData.Policy.ActionCounter.Count; // Continuous case - if (jobData.world.ActionType == ActionType.CONTINUOUS) + if 
(jobData.Policy.ActionType == ActionType.CONTINUOUS) { for (int i = 0; i < actionCount; i++) { @@ -144,8 +144,8 @@ namespace Unity.AI.MLAgents { ActionSize = size, ActionType = ActionType.CONTINUOUS, - Entity = jobData.world.ActionAgentEntityIds[i], - ContinuousActionSlice = jobData.world.ContinuousActuators.Slice(i * size, size) + Entity = jobData.Policy.ActionAgentEntityIds[i], + ContinuousActionSlice = jobData.Policy.ContinuousActuators.Slice(i * size, size) }); } } @@ -158,12 +158,12 @@ namespace Unity.AI.MLAgents { ActionSize = size, ActionType = ActionType.DISCRETE, - Entity = jobData.world.ActionAgentEntityIds[i], - DiscreteActionSlice = jobData.world.DiscreteActuators.Slice(i * size, size) + Entity = jobData.Policy.ActionAgentEntityIds[i], + DiscreteActionSlice = jobData.Policy.DiscreteActuators.Slice(i * size, size) }); } } - jobData.world.ResetActionsCounter(); + jobData.Policy.ResetActionsCounter(); } } } diff --git a/Runtime/World.meta b/Runtime/Policy.meta similarity index 100% rename from Runtime/World.meta rename to Runtime/Policy.meta diff --git a/Runtime/World/Counter.cs b/Runtime/Policy/Counter.cs similarity index 100% rename from Runtime/World/Counter.cs rename to Runtime/Policy/Counter.cs diff --git a/Runtime/World/Counter.cs.meta b/Runtime/Policy/Counter.cs.meta similarity index 100% rename from Runtime/World/Counter.cs.meta rename to Runtime/Policy/Counter.cs.meta diff --git a/Runtime/World/DecisionRequest.cs b/Runtime/Policy/DecisionRequest.cs similarity index 72% rename from Runtime/World/DecisionRequest.cs rename to Runtime/Policy/DecisionRequest.cs index 9b522af..42a077b 100644 --- a/Runtime/World/DecisionRequest.cs +++ b/Runtime/Policy/DecisionRequest.cs @@ -7,19 +7,19 @@ using System; namespace Unity.AI.MLAgents { /// - /// A DecisionRequest is a struct used to provide data about an Agent to a MLAgentsWorld. - /// This data will be used to generate a decision after the world is processed. + /// A DecisionRequest is a struct used to provide data about an Agent to a Policy. + /// This data will be used to generate a decision after the Policy is processed. /// Adding data is done through a builder pattern. 
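
A sketch (not part of the patch) of an IActuatorJob under the renamed API. It assumes the interface declares a single Execute(ActuatorEvent) method, which is what the ActuatorDataJobProcess wrapper above implies, and it reuses the illustrative MoveAction struct from the previous sketch.

using Unity.AI.MLAgents;

public struct ApplyMoveActionJob : IActuatorJob
{
    public void Execute(ActuatorEvent ev)
    {
        // GetAction copies the action slice for this agent into the user struct;
        // ev.Entity identifies which agent the action belongs to.
        var action = ev.GetAction<MoveAction>();
        // ... move the entity identified by ev.Entity using action.Forward / action.Turn
    }
}

// Typical scheduling from a system update, with myPolicy being the registered Policy:
// inputDeps = new ApplyMoveActionJob().Schedule(myPolicy, inputDeps);
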
/// public struct DecisionRequest { private int m_Index; - private MLAgentsWorld m_World; + private Policy m_Policy; - internal DecisionRequest(int index, MLAgentsWorld world) + internal DecisionRequest(int index, Policy policy) { this.m_Index = index; - this.m_World = world; + this.m_Policy = policy; } /// @@ -29,7 +29,7 @@ namespace Unity.AI.MLAgents /// The DecisionRequest struct public DecisionRequest SetReward(float r) { - m_World.DecisionRewards[m_Index] = r; + m_Policy.DecisionRewards[m_Index] = r; return this; } @@ -43,45 +43,45 @@ namespace Unity.AI.MLAgents public DecisionRequest SetDiscreteActionMask(int branch, int actionIndex) { #if ENABLE_UNITY_COLLECTIONS_CHECKS - if (m_World.ActionType == ActionType.CONTINUOUS) + if (m_Policy.ActionType == ActionType.CONTINUOUS) { throw new MLAgentsException("SetDiscreteActionMask can only be used with discrete acton space."); } - if (branch > m_World.DiscreteActionBranches.Length) + if (branch > m_Policy.DiscreteActionBranches.Length) { throw new MLAgentsException("Unknown action branch used when setting mask."); } - if (actionIndex > m_World.DiscreteActionBranches[branch]) + if (actionIndex > m_Policy.DiscreteActionBranches[branch]) { throw new MLAgentsException("Index is out of bounds for requested action mask."); } #endif - var trueMaskIndex = m_World.DiscreteActionBranches.CumSumAt(branch) + actionIndex; - m_World.DecisionActionMasks[trueMaskIndex + m_World.DiscreteActionBranches.Sum() * m_Index] = true; + var trueMaskIndex = m_Policy.DiscreteActionBranches.CumSumAt(branch) + actionIndex; + m_Policy.DecisionActionMasks[trueMaskIndex + m_Policy.DiscreteActionBranches.Sum() * m_Index] = true; return this; } /// /// Sets the observation for a decision request. /// - /// The index of the observation as provided when creating the associated MLAgentsWorld + /// The index of the observation as provided when creating the associated Policy /// A struct strictly containing floats used as observation data /// The DecisionRequest struct public DecisionRequest SetObservation(int sensorNumber, T sensor) where T : struct { int inputSize = UnsafeUtility.SizeOf() / sizeof(float); #if ENABLE_UNITY_COLLECTIONS_CHECKS - int3 s = m_World.SensorShapes[sensorNumber]; + int3 s = m_Policy.SensorShapes[sensorNumber]; int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z); if (inputSize != expectedInputSize) { throw new MLAgentsException( - $"Cannot set observation due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}"); + $"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}"); } #endif - int start = m_World.ObservationOffsets[sensorNumber]; + int start = m_Policy.ObservationOffsets[sensorNumber]; start += inputSize * m_Index; - var tmp = m_World.DecisionObs.Slice(start, inputSize).SliceConvert(); + var tmp = m_Policy.DecisionObs.Slice(start, inputSize).SliceConvert(); tmp[0] = sensor; return this; } @@ -89,12 +89,12 @@ namespace Unity.AI.MLAgents /// /// Sets the observation for a decision request using a categorical value. 
/// - /// The index of the observation as provided when creating the associated MLAgentsWorld + /// The index of the observation as provided when creating the associated Policy /// An integer containing the index of the categorical observation /// The DecisionRequest struct public DecisionRequest SetObservation(int sensorNumber, int sensor) { - int3 s = m_World.SensorShapes[sensorNumber]; + int3 s = m_Policy.SensorShapes[sensorNumber]; int maxValue = s.x; #if ENABLE_UNITY_COLLECTIONS_CHECKS @@ -109,33 +109,33 @@ namespace Unity.AI.MLAgents $"Categorical observation is out of bound for observation {sensorNumber} with maximum {maxValue} (received {sensor}."); } #endif - int start = m_World.ObservationOffsets[sensorNumber]; + int start = m_Policy.ObservationOffsets[sensorNumber]; start += maxValue * m_Index; - m_World.DecisionObs[start + sensor] = 1.0f; + m_Policy.DecisionObs[start + sensor] = 1.0f; return this; } /// /// Sets the observation for a decision request. /// - /// The index of the observation as provided when creating the associated MLAgentsWorld + /// The index of the observation as provided when creating the associated Policy /// A NativeSlice of floats containing the observation data /// The DecisionRequest struct public DecisionRequest SetObservationFromSlice(int sensorNumber, [ReadOnly] NativeSlice obs) { int inputSize = obs.Length; #if ENABLE_UNITY_COLLECTIONS_CHECKS - int3 s = m_World.SensorShapes[sensorNumber]; + int3 s = m_Policy.SensorShapes[sensorNumber]; int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z); if (inputSize != expectedInputSize) { throw new MLAgentsException( - $"Cannot set observation due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}"); + $"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}"); } #endif - int start = m_World.ObservationOffsets[sensorNumber]; + int start = m_Policy.ObservationOffsets[sensorNumber]; start += inputSize * m_Index; - m_World.DecisionObs.Slice(start, inputSize).CopyFrom(obs); + m_Policy.DecisionObs.Slice(start, inputSize).CopyFrom(obs); return this; } } diff --git a/Runtime/World/DecisionRequest.cs.meta b/Runtime/Policy/DecisionRequest.cs.meta similarity index 100% rename from Runtime/World/DecisionRequest.cs.meta rename to Runtime/Policy/DecisionRequest.cs.meta diff --git a/Runtime/World/EpisodeTermination.cs b/Runtime/Policy/EpisodeTermination.cs similarity index 73% rename from Runtime/World/EpisodeTermination.cs rename to Runtime/Policy/EpisodeTermination.cs index f62d4f6..db2a861 100644 --- a/Runtime/World/EpisodeTermination.cs +++ b/Runtime/Policy/EpisodeTermination.cs @@ -7,19 +7,19 @@ using System; namespace Unity.AI.MLAgents { /// - /// A EpisodeTermination is a struct used to provide data about an Agent to a MLAgentsWorld. + /// A EpisodeTermination is a struct used to provide data about an Agent to a Policy. /// This data will be used to notify of the end of the episode of an Agent. /// Adding data is done through a builder pattern. 
/// public struct EpisodeTermination { private int m_Index; - private MLAgentsWorld m_World; + private Policy m_Policy; - internal EpisodeTermination(int index, MLAgentsWorld world) + internal EpisodeTermination(int index, Policy policy) { this.m_Index = index; - this.m_World = world; + this.m_Policy = policy; } /// @@ -30,31 +30,31 @@ namespace Unity.AI.MLAgents /// The EpisodeTermination struct public EpisodeTermination SetReward(float r) { - m_World.TerminationRewards[m_Index] = r; + m_Policy.TerminationRewards[m_Index] = r; return this; } /// /// Sets the observation for of the end of the Episode. /// - /// The index of the observation as provided when creating the associated MLAgentsWorld + /// The index of the observation as provided when creating the associated Policy /// A struct strictly containing floats used as observation data /// The EpisodeTermination struct public EpisodeTermination SetObservation(int sensorNumber, T sensor) where T : struct { int inputSize = UnsafeUtility.SizeOf() / sizeof(float); #if ENABLE_UNITY_COLLECTIONS_CHECKS - int3 s = m_World.SensorShapes[sensorNumber]; + int3 s = m_Policy.SensorShapes[sensorNumber]; int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z); if (inputSize != expectedInputSize) { throw new MLAgentsException( - $"Cannot set observation due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}"); + $"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}"); } #endif - int start = m_World.ObservationOffsets[sensorNumber]; + int start = m_Policy.ObservationOffsets[sensorNumber]; start += inputSize * m_Index; - var tmp = m_World.TerminationObs.Slice(start, inputSize).SliceConvert(); + var tmp = m_Policy.TerminationObs.Slice(start, inputSize).SliceConvert(); tmp[0] = sensor; return this; } @@ -62,12 +62,12 @@ namespace Unity.AI.MLAgents /// /// Sets the observation for a termination request using a categorical value. /// - /// The index of the observation as provided when creating the associated MLAgentsWorld + /// The index of the observation as provided when creating the associated Policy /// An integer containing the index of the categorical observation /// The EpisodeTermination struct public EpisodeTermination SetObservation(int sensorNumber, int sensor) { - int3 s = m_World.SensorShapes[sensorNumber]; + int3 s = m_Policy.SensorShapes[sensorNumber]; int maxValue = s.x; #if ENABLE_UNITY_COLLECTIONS_CHECKS @@ -82,33 +82,33 @@ namespace Unity.AI.MLAgents $"Categorical observation is out of bound for observation {sensorNumber} with maximum {maxValue} (received {sensor}."); } #endif - int start = m_World.ObservationOffsets[sensorNumber]; + int start = m_Policy.ObservationOffsets[sensorNumber]; start += maxValue * m_Index; - m_World.TerminationObs[start + sensor] = 1.0f; + m_Policy.TerminationObs[start + sensor] = 1.0f; return this; } /// /// Sets the last observation the Agent perceives before ending the episode. 
/// - /// The index of the observation as provided when creating the associated MLAgentsWorld + /// The index of the observation as provided when creating the associated Policy /// A NativeSlice of floats containing the observation data /// The EpisodeTermination struct public EpisodeTermination SetObservationFromSlice(int sensorNumber, [ReadOnly] NativeSlice obs) { int inputSize = obs.Length; #if ENABLE_UNITY_COLLECTIONS_CHECKS - int3 s = m_World.SensorShapes[sensorNumber]; + int3 s = m_Policy.SensorShapes[sensorNumber]; int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z); if (inputSize != expectedInputSize) { throw new MLAgentsException( - $"Cannot set observation due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}"); + $"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}"); } #endif - int start = m_World.ObservationOffsets[sensorNumber]; + int start = m_Policy.ObservationOffsets[sensorNumber]; start += inputSize * m_Index; - m_World.TerminationObs.Slice(start, inputSize).CopyFrom(obs); + m_Policy.TerminationObs.Slice(start, inputSize).CopyFrom(obs); return this; } } diff --git a/Runtime/World/EpisodeTermination.cs.meta b/Runtime/Policy/EpisodeTermination.cs.meta similarity index 100% rename from Runtime/World/EpisodeTermination.cs.meta rename to Runtime/Policy/EpisodeTermination.cs.meta diff --git a/Runtime/World/MLAgentsWorld.cs b/Runtime/Policy/Policy.cs similarity index 94% rename from Runtime/World/MLAgentsWorld.cs rename to Runtime/Policy/Policy.cs index 081a457..642d529 100644 --- a/Runtime/World/MLAgentsWorld.cs +++ b/Runtime/Policy/Policy.cs @@ -9,14 +9,14 @@ using Unity.Collections.LowLevel.Unsafe; namespace Unity.AI.MLAgents { /// - /// MLAgentsWorld is a data container on which the user requests decisions. + /// Policy is a data container on which the user requests decisions. /// - public struct MLAgentsWorld : IDisposable + public struct Policy : IDisposable { /// - /// Indicates if the MLAgentsWorld has been instantiated + /// Indicates if the Policy has been instantiated /// - /// True if MLAgentsWorld was instantiated, False otherwise + /// True if the Policy was instantiated, False otherwise public bool IsCreated { get { return DecisionAgentIds.IsCreated;} @@ -60,13 +60,13 @@ namespace Unity.AI.MLAgents /// /// The maximum number of decisions that can be requested between each MLAgentsSystem update /// An array of int3 corresponding to the shape of the expected observations (one int3 per observation) - /// An ActionType enum (DISCRETE / CONTINUOUS) specifying the type of actions the MLAgentsWorld will produce - /// The number of actions the MLAgentsWorld is expected to generate for each decision. + /// An ActionType enum (DISCRETE / CONTINUOUS) specifying the type of actions the Policy will produce + /// The number of actions the Policy is expected to generate for each decision. /// - If CONTINUOUS ActionType : The number of floats the action contains /// - If DISCRETE ActionType : The number of branches (integer actions) the action contains /// For DISCRETE ActionType only : an array of int specifying the number of possible int values each /// action branch has. (Must be of the same length as actionSize - public MLAgentsWorld( + public Policy( int maximumNumberAgents, int3[] obsShapes, ActionType actionType, @@ -164,7 +164,7 @@ namespace Unity.AI.MLAgents } /// - /// Dispose of the MLAgentsWorld. 
+ /// Dispose of the Policy. /// public void Dispose() { @@ -199,7 +199,7 @@ namespace Unity.AI.MLAgents } /// - /// Requests a decision for a specific Entity to the MLAgentsWorld. The DecisionRequest + /// Requests a decision for a specific Entity to the Policy. The DecisionRequest /// struct this method returns can be used to provide data necessary for the Agent to /// take a decision for the Entity. /// @@ -211,7 +211,7 @@ namespace Unity.AI.MLAgents #if ENABLE_UNITY_COLLECTIONS_CHECKS if (!IsCreated) { - throw new MLAgentsException($"Invalid operation, cannot request a decision on a non-initialized MLAgentsWorld"); + throw new MLAgentsException($"Invalid operation, cannot request a decision on a non-initialized Policy"); } #endif var index = DecisionCounter.Increment() - 1; @@ -239,7 +239,7 @@ namespace Unity.AI.MLAgents #if ENABLE_UNITY_COLLECTIONS_CHECKS if (!IsCreated) { - throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized MLAgentsWorld"); + throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized Policy"); } #endif var index = TerminationCounter.Increment() - 1; @@ -267,7 +267,7 @@ namespace Unity.AI.MLAgents #if ENABLE_UNITY_COLLECTIONS_CHECKS if (!IsCreated) { - throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized MLAgentsWorld"); + throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized Policy"); } #endif var index = TerminationCounter.Increment() - 1; diff --git a/Runtime/World/MLAgentsWorld.cs.meta b/Runtime/Policy/Policy.cs.meta similarity index 100% rename from Runtime/World/MLAgentsWorld.cs.meta rename to Runtime/Policy/Policy.cs.meta diff --git a/Runtime/WorldProcessor.meta b/Runtime/PolicyProcessor.meta similarity index 100% rename from Runtime/WorldProcessor.meta rename to Runtime/PolicyProcessor.meta diff --git a/Runtime/WorldProcessor/BarracudaWorldProcessor.cs b/Runtime/PolicyProcessor/BarracudaPolicyProcessor.cs similarity index 64% rename from Runtime/WorldProcessor/BarracudaWorldProcessor.cs rename to Runtime/PolicyProcessor/BarracudaPolicyProcessor.cs index a834b4b..0fa08e0 100644 --- a/Runtime/WorldProcessor/BarracudaWorldProcessor.cs +++ b/Runtime/PolicyProcessor/BarracudaPolicyProcessor.cs @@ -21,23 +21,23 @@ namespace Unity.AI.MLAgents GPU = 1 } - public static class BarracudaWorldProcessorRegistringExtension + public static class BarracudaPolicyProcessorRegistringExtension { /// - /// Registers the given MLAgentsWorld to the Academy with a Neural + /// Registers the given Policy to the Academy with a Neural /// Network Model. If the input model is null, a default inactive /// processor will be registered instead. Note that if the simulation - /// connects to Python, the Neural Network will be ignored and the world + /// connects to Python, the Neural Network will be ignored and the Policy /// will exchange data with Python instead. /// - /// The MLAgentsWorld to register - /// The name of the world. This is useful for identification + /// The Policy to register + /// The name of the Policy. This is useful for identification /// and for training. 
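
A sketch (not part of the patch) of feeding per-agent data to a Policy each frame. SetReward and SetObservation are the builder methods shown in DecisionRequest.cs and EpisodeTermination.cs above; the entry points are assumed to be named RequestDecision and EndEpisode, matching the doc comments and error messages in Policy.cs, and the float3 observation assumes a Policy created with a single (3, 0, 0) observation shape.

using Unity.AI.MLAgents;
using Unity.Entities;
using Unity.Mathematics;

public static class DecisionFlowExample
{
    public static void Step(Policy policy, Entity agent, float3 position, float reward, bool done)
    {
        if (done)
        {
            // Terminations use the same builder pattern through EpisodeTermination.
            policy.EndEpisode(agent)
                .SetReward(reward)
                .SetObservation(0, position);
            return;
        }

        policy.RequestDecision(agent)
            .SetReward(reward)
            .SetObservation(0, position);
    }
}
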
/// The Neural Network model used by the processor /// The inference device specifying where to run inference /// (CPU or GPU) - public static void RegisterWorldWithBarracudaModel( - this MLAgentsWorld world, + public static void RegisterPolicyWithBarracudaModel( + this Policy policy, string policyId, NNModel model, InferenceDevice inferenceDevice = InferenceDevice.CPU @@ -45,30 +45,30 @@ namespace Unity.AI.MLAgents { if (model != null) { - var worldProcessor = new BarracudaWorldProcessor(world, model, inferenceDevice); - Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true); + var policyProcessor = new BarracudaPolicyProcessor(policy, model, inferenceDevice); + Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, true); } else { - Academy.Instance.RegisterWorld(policyId, world, null, true); + Academy.Instance.RegisterPolicy(policyId, policy, null, true); } } /// - /// Registers the given MLAgentsWorld to the Academy with a Neural + /// Registers the given Policy to the Academy with a Neural /// Network Model. If the input model is null, a default inactive /// processor will be registered instead. Note that if the simulation - /// connects to Python, the world will not connect to Python and run the + /// connects to Python, the Policy will not connect to Python and run the /// given Neural Network regardless. /// - /// The MLAgentsWorld to register - /// The name of the world. This is useful for identification + /// The Policy to register + /// The name of the Policy. This is useful for identification /// and for training. /// The Neural Network model used by the processor /// The inference device specifying where to run inference /// (CPU or GPU) - public static void RegisterWorldWithBarracudaModelForceNoCommunication( - this MLAgentsWorld world, + public static void RegisterPolicyWithBarracudaModelForceNoCommunication( + this Policy policy, string policyId, NNModel model, InferenceDevice inferenceDevice = InferenceDevice.CPU @@ -76,19 +76,19 @@ namespace Unity.AI.MLAgents { if (model != null) { - var worldProcessor = new BarracudaWorldProcessor(world, model, inferenceDevice); - Academy.Instance.RegisterWorld(policyId, world, worldProcessor, false); + var policyProcessor = new BarracudaPolicyProcessor(policy, model, inferenceDevice); + Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, false); } else { - Academy.Instance.RegisterWorld(policyId, world, null, false); + Academy.Instance.RegisterPolicy(policyId, policy, null, false); } } } - internal unsafe class BarracudaWorldProcessor : IWorldProcessor + internal unsafe class BarracudaPolicyProcessor : IPolicyProcessor { - private MLAgentsWorld m_World; + private Policy m_Policy; private Model m_BarracudaModel; private IWorker m_Engine; private const bool k_Verbose = false; @@ -98,9 +98,9 @@ namespace Unity.AI.MLAgents public bool IsConnected {get {return false;}} - internal BarracudaWorldProcessor(MLAgentsWorld world, NNModel model, InferenceDevice inferenceDevice) + internal BarracudaPolicyProcessor(Policy policy, NNModel model, InferenceDevice inferenceDevice) { - this.m_World = world; + this.m_Policy = policy; D.logEnabled = k_Verbose; m_Engine?.Dispose(); @@ -111,15 +111,15 @@ namespace Unity.AI.MLAgents m_Engine = WorkerFactory.CreateWorker( executionDevice, m_BarracudaModel, k_Verbose); - for (int i = 0; i < m_World.SensorShapes.Length; i++) + for (int i = 0; i < m_Policy.SensorShapes.Length; i++) { - if (m_World.SensorShapes[i].GetDimensions() == 1) - obsSize += 
m_World.SensorShapes[i].GetTotalTensorSize(); + if (m_Policy.SensorShapes[i].GetDimensions() == 1) + obsSize += m_Policy.SensorShapes[i].GetTotalTensorSize(); } - vectorObsArr = new float[m_World.DecisionAgentIds.Length * obsSize]; + vectorObsArr = new float[m_Policy.DecisionAgentIds.Length * obsSize]; } - public void ProcessWorld() + public void Process() { // TODO : Cover all cases // FOR VECTOR OBS ONLY @@ -128,27 +128,27 @@ namespace Unity.AI.MLAgents var input = new System.Collections.Generic.Dictionary(); - // var sensorData = m_World.DecisionObs.ToArray(); + // var sensorData = m_Policy.DecisionObs.ToArray(); int sensorOffset = 0; int vecObsOffset = 0; - foreach (var shape in m_World.SensorShapes) + foreach (var shape in m_Policy.SensorShapes) { if (shape.GetDimensions() == 1) { - for (int i = 0; i < m_World.DecisionCounter.Count; i++) + for (int i = 0; i < m_Policy.DecisionCounter.Count; i++) { fixed(void* arrPtr = vectorObsArr) { UnsafeUtility.MemCpy( (byte*)arrPtr + 4 * i * obsSize + 4 * vecObsOffset, - (byte*)m_World.DecisionObs.GetUnsafePtr() + 4 * sensorOffset + 4 * i * shape.GetTotalTensorSize(), + (byte*)m_Policy.DecisionObs.GetUnsafePtr() + 4 * sensorOffset + 4 * i * shape.GetTotalTensorSize(), shape.GetTotalTensorSize() * 4 ); } // Array.Copy(sensorData, sensorOffset + i * shape.GetTotalTensorSize(), vectorObsArr, i * obsSize + vecObsOffset, shape.GetTotalTensorSize()); } - sensorOffset += m_World.DecisionAgentIds.Length * shape.GetTotalTensorSize(); + sensorOffset += m_Policy.DecisionAgentIds.Length * shape.GetTotalTensorSize(); vecObsOffset += shape.GetTotalTensorSize(); } else @@ -158,7 +158,7 @@ namespace Unity.AI.MLAgents } input["vector_observation"] = new Tensor( - new TensorShape(m_World.DecisionCounter.Count, obsSize), + new TensorShape(m_Policy.DecisionCounter.Count, obsSize), vectorObsArr, "vector_observation"); @@ -166,18 +166,18 @@ namespace Unity.AI.MLAgents var actuatorT = m_Engine.CopyOutput("action"); - switch (m_World.ActionType) + switch (m_Policy.ActionType) { case ActionType.CONTINUOUS: - int count = m_World.DecisionCounter.Count * m_World.ActionSize; + int count = m_Policy.DecisionCounter.Count * m_Policy.ActionSize; var wholeData = actuatorT.data.Download(count); // var dest = new float[count]; // Array.Copy(wholeData, dest, count); - // m_World.ContinuousActuators.Slice(0, count).CopyFrom(dest); + // m_Policy.ContinuousActuators.Slice(0, count).CopyFrom(dest); fixed(void* arrPtr = wholeData) { UnsafeUtility.MemCpy( - m_World.ContinuousActuators.GetUnsafePtr(), + m_Policy.ContinuousActuators.GetUnsafePtr(), arrPtr, count * 4 ); diff --git a/Runtime/WorldProcessor/BarracudaWorldProcessor.cs.meta b/Runtime/PolicyProcessor/BarracudaPolicyProcessor.cs.meta similarity index 100% rename from Runtime/WorldProcessor/BarracudaWorldProcessor.cs.meta rename to Runtime/PolicyProcessor/BarracudaPolicyProcessor.cs.meta diff --git a/Runtime/PolicyProcessor/HeuristicPolicyProcessor.cs b/Runtime/PolicyProcessor/HeuristicPolicyProcessor.cs new file mode 100644 index 0000000..c87c28e --- /dev/null +++ b/Runtime/PolicyProcessor/HeuristicPolicyProcessor.cs @@ -0,0 +1,104 @@ +using System; +using Unity.Collections; +using Unity.Collections.LowLevel.Unsafe; + + +namespace Unity.AI.MLAgents +{ + public static class HeristicPolicyProcessorRegistringExtension + { + /// + /// Registers the given Policy to the Academy with a Heuristic. + /// Note that if the simulation connects to Python, the Heuristic will + /// be ignored and the Policy will exchange data with Python instead. 
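
A sketch (not part of the patch) of wiring a serialized Barracuda model to a Policy through the renamed extension method. The MonoBehaviour, field, and policy parameters are illustrative, and the Unity.Barracuda namespace is assumed for NNModel.

using Unity.AI.MLAgents;
using Unity.Barracuda;
using Unity.Mathematics;
using UnityEngine;

public class BallPolicyBootstrap : MonoBehaviour
{
    public NNModel model;   // trained model asset; may be left null while training
    Policy m_Policy;

    void Awake()
    {
        m_Policy = new Policy(100, new int3[] { new int3(8, 0, 0) }, ActionType.CONTINUOUS, 2);

        // Uses the remote (Python) processor when a trainer is connected; otherwise runs
        // Barracuda inference, or registers a null processor if the model is null
        // (see RegisterPolicyWithBarracudaModel above).
        m_Policy.RegisterPolicyWithBarracudaModel("Ball3D", model, InferenceDevice.GPU);
    }

    void OnDestroy()
    {
        if (m_Policy.IsCreated)
        {
            m_Policy.Dispose();
        }
    }
}
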
+ /// The Heuristic is a Function that returns an action struct. + /// + /// The Policy to register + /// The name of the Policy. This is useful for identification + /// and for training. + /// The Heuristic used to generate the actions. + /// Note that all agents in the Policy will receive the same action. + /// The type of the Action struct. It must match the Action + /// Size and Action Type of the Policy. + public static void RegisterPolicyWithHeuristic( + this Policy policy, + string policyId, + Func heuristic + ) where TH : struct + { + var policyProcessor = new HeuristicPolicyProcessor(policy, heuristic); + Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, true); + } + + /// + /// Registers the given Policy to the Academy with a Heuristic. + /// Note that if the simulation connects to Python, the Policy will not + /// exchange data with Python and use the Heuristic regardless. + /// The Heuristic is a Function that returns an action struct. + /// + /// The Policy to register + /// The name of the Policy. This is useful for identification + /// and for training. + /// The Heuristic used to generate the actions. + /// Note that all agents in the Policy will receive the same action. + /// The type of the Action struct. It must match the Action + /// Size and Action Type of the Policy. + public static void RegisterPolicyWithHeuristicForceNoCommunication( + this Policy policy, + string policyId, + Func heuristic + ) where TH : struct + { + var policyProcessor = new HeuristicPolicyProcessor(policy, heuristic); + Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, false); + } + } + + internal class HeuristicPolicyProcessor : IPolicyProcessor where T : struct + { + private Func m_Heuristic; + private Policy m_Policy; + + public bool IsConnected {get {return false;}} + + internal HeuristicPolicyProcessor(Policy policy, Func heuristic) + { + this.m_Policy = policy; + this.m_Heuristic = heuristic; + var structSize = UnsafeUtility.SizeOf() / sizeof(float); + if (structSize != policy.ActionSize) + { + throw new MLAgentsException( + $"The heuristic provided does not match the action size. 
Expected {policy.ActionSize} but received {structSize} from heuristic"); + } + } + + public void Process() + { + T action = m_Heuristic.Invoke(); + var totalCount = m_Policy.DecisionCounter.Count; + + // TODO : This can be parallelized + if (m_Policy.ActionType == ActionType.CONTINUOUS) + { + var s = m_Policy.ContinuousActuators.Slice(0, totalCount * m_Policy.ActionSize).SliceConvert(); + for (int i = 0; i < totalCount; i++) + { + s[i] = action; + } + } + else + { + var s = m_Policy.DiscreteActuators.Slice(0, totalCount * m_Policy.ActionSize).SliceConvert(); + for (int i = 0; i < totalCount; i++) + { + s[i] = action; + } + } + } + + public void Dispose() + { + } + } +} diff --git a/Runtime/WorldProcessor/HeuristicWorldProcessor.cs.meta b/Runtime/PolicyProcessor/HeuristicPolicyProcessor.cs.meta similarity index 100% rename from Runtime/WorldProcessor/HeuristicWorldProcessor.cs.meta rename to Runtime/PolicyProcessor/HeuristicPolicyProcessor.cs.meta diff --git a/Runtime/WorldProcessor/IWorldProcessor.cs b/Runtime/PolicyProcessor/IPolicyProcessor.cs similarity index 51% rename from Runtime/WorldProcessor/IWorldProcessor.cs rename to Runtime/PolicyProcessor/IPolicyProcessor.cs index 53f3fa0..2d05042 100644 --- a/Runtime/WorldProcessor/IWorldProcessor.cs +++ b/Runtime/PolicyProcessor/IPolicyProcessor.cs @@ -6,21 +6,21 @@ using Unity.Collections.LowLevel.Unsafe; namespace Unity.AI.MLAgents { /// - /// The interface for a world processor. A world processor updates the - /// action data of a world using the observation data present in it. + /// The interface for a Policy processor. A Policy processor updates the + /// action data of a Policy using the observation data present in it. /// - public interface IWorldProcessor : IDisposable + public interface IPolicyProcessor : IDisposable { /// - /// True if the World Processor is connected to the Python process + /// True if the Policy Processor is connected to the Python process /// bool IsConnected {get;} /// - /// This method is called once everytime the world needs to update its action + /// This method is called once everytime the policy needs to update its action /// data. The implementation of this mehtod must give new action data to all the /// agents that requested a decision. 
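
A sketch (not part of the patch) of registering a keyboard heuristic with the renamed extension method. It assumes the Policy was created as DISCRETE with a single branch of three values, so the one-int action struct satisfies the size check in HeuristicPolicyProcessor above; every agent in the Policy receives the same action, as the doc comment notes.

using Unity.AI.MLAgents;
using UnityEngine;

public struct BasicAction
{
    public int Direction;   // 0 = stay, 1 = left, 2 = right
}

public static class HeuristicRegistrationExample
{
    public static void Register(Policy policy)
    {
        policy.RegisterPolicyWithHeuristic("Basic", () => new BasicAction
        {
            Direction = Input.GetKey(KeyCode.RightArrow) ? 2
                      : Input.GetKey(KeyCode.LeftArrow) ? 1
                      : 0
        });
    }
}
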
/// - void ProcessWorld(); + void Process(); } } diff --git a/Runtime/WorldProcessor/IWorldProcessor.cs.meta b/Runtime/PolicyProcessor/IPolicyProcessor.cs.meta similarity index 100% rename from Runtime/WorldProcessor/IWorldProcessor.cs.meta rename to Runtime/PolicyProcessor/IPolicyProcessor.cs.meta diff --git a/Runtime/WorldProcessor/NullWorldProcessor.cs b/Runtime/PolicyProcessor/NullPolicyProcessor.cs similarity index 55% rename from Runtime/WorldProcessor/NullWorldProcessor.cs rename to Runtime/PolicyProcessor/NullPolicyProcessor.cs index 506c78c..7335c23 100644 --- a/Runtime/WorldProcessor/NullWorldProcessor.cs +++ b/Runtime/PolicyProcessor/NullPolicyProcessor.cs @@ -5,18 +5,18 @@ using Unity.Collections.LowLevel.Unsafe; namespace Unity.AI.MLAgents { - internal class NullWorldProcessor : IWorldProcessor + internal class NullPolicyProcessor : IPolicyProcessor { - private MLAgentsWorld m_World; + private Policy m_Policy; public bool IsConnected {get {return false;}} - internal NullWorldProcessor(MLAgentsWorld world) + internal NullPolicyProcessor(Policy policy) { - this.m_World = world; + this.m_Policy = policy; } - public void ProcessWorld() + public void Process() { } diff --git a/Runtime/WorldProcessor/NullWorldProcessor.cs.meta b/Runtime/PolicyProcessor/NullPolicyProcessor.cs.meta similarity index 100% rename from Runtime/WorldProcessor/NullWorldProcessor.cs.meta rename to Runtime/PolicyProcessor/NullPolicyProcessor.cs.meta diff --git a/Runtime/WorldProcessor/RemoteWorldProcessor.cs b/Runtime/PolicyProcessor/RemotePolicyProcessor.cs similarity index 57% rename from Runtime/WorldProcessor/RemoteWorldProcessor.cs rename to Runtime/PolicyProcessor/RemotePolicyProcessor.cs index ed69242..61643b5 100644 --- a/Runtime/WorldProcessor/RemoteWorldProcessor.cs +++ b/Runtime/PolicyProcessor/RemotePolicyProcessor.cs @@ -5,27 +5,27 @@ using Unity.Collections.LowLevel.Unsafe; namespace Unity.AI.MLAgents { - internal class RemoteWorldProcessor : IWorldProcessor + internal class RemotePolicyProcessor : IPolicyProcessor { - private MLAgentsWorld m_World; + private Policy m_Policy; private SharedMemoryCommunicator m_Communicator; private string m_PolicyId; public bool IsConnected {get {return true;}} - internal RemoteWorldProcessor(MLAgentsWorld world, string policyId, SharedMemoryCommunicator com) + internal RemotePolicyProcessor(Policy policy, string policyId, SharedMemoryCommunicator com) { - this.m_World = world; + this.m_Policy = policy; this.m_Communicator = com; this.m_PolicyId = policyId; } - public void ProcessWorld() + public void Process() { - m_Communicator.WriteWorld(m_PolicyId, m_World); + m_Communicator.WritePolicy(m_PolicyId, m_Policy); m_Communicator.SetUnityReady(); m_Communicator.WaitForPython(); - m_Communicator.LoadWorld(m_PolicyId, m_World); + m_Communicator.LoadPolicy(m_PolicyId, m_Policy); } public void Dispose() diff --git a/Runtime/WorldProcessor/RemoteWorldProcessor.cs.meta b/Runtime/PolicyProcessor/RemotePolicyProcessor.cs.meta similarity index 100% rename from Runtime/WorldProcessor/RemoteWorldProcessor.cs.meta rename to Runtime/PolicyProcessor/RemotePolicyProcessor.cs.meta diff --git a/Runtime/Remote/RLDataOffsets.cs b/Runtime/Remote/RLDataOffsets.cs index 78498aa..994b084 100644 --- a/Runtime/Remote/RLDataOffsets.cs +++ b/Runtime/Remote/RLDataOffsets.cs @@ -68,26 +68,26 @@ namespace Unity.AI.MLAgents startOffset); } - public static RLDataOffsets FromWorld(MLAgentsWorld world, string name, int offset) + public static RLDataOffsets FromPolicy(Policy policy, string name, 
int offset) { - bool isContinuous = world.ActionType == ActionType.CONTINUOUS; + bool isContinuous = policy.ActionType == ActionType.CONTINUOUS; int totalFloatObsPerAgent = 0; - foreach (int3 shape in world.SensorShapes) + foreach (int3 shape in policy.SensorShapes) { totalFloatObsPerAgent += shape.GetTotalTensorSize(); } int totalNumberOfMasks = 0; if (!isContinuous) { - totalNumberOfMasks = world.DiscreteActionBranches.Sum(); + totalNumberOfMasks = policy.DiscreteActionBranches.Sum(); } return ComputeOffsets( name, - world.DecisionAgentIds.Length, + policy.DecisionAgentIds.Length, isContinuous, - world.ActionSize, - world.SensorShapes.Length, + policy.ActionSize, + policy.SensorShapes.Length, totalFloatObsPerAgent, totalNumberOfMasks, offset diff --git a/Runtime/Remote/SharedMemoryBody.cs b/Runtime/Remote/SharedMemoryBody.cs index d5f1c1b..5052042 100644 --- a/Runtime/Remote/SharedMemoryBody.cs +++ b/Runtime/Remote/SharedMemoryBody.cs @@ -52,12 +52,12 @@ namespace Unity.AI.MLAgents } } - public bool ContainsWorld(string name) + public bool ContainsPolicy(string name) { return m_OffsetDict.ContainsKey(name); } - public void WriteWorld(string name, MLAgentsWorld world) + public void WritePolicy(string name, Policy policy) { if (!CanEdit) { @@ -69,48 +69,48 @@ namespace Unity.AI.MLAgents } var dataOffsets = m_OffsetDict[name]; int totalFloatObsPerAgent = 0; - foreach (int3 shape in world.SensorShapes) + foreach (int3 shape in policy.SensorShapes) { totalFloatObsPerAgent += shape.GetTotalTensorSize(); } // Decision data - var decisionCount = world.DecisionCounter.Count; + var decisionCount = policy.DecisionCounter.Count; SetInt(dataOffsets.DecisionNumberAgentsOffset, decisionCount); - SetArray(dataOffsets.DecisionObsOffset, world.DecisionObs, 4 * decisionCount * totalFloatObsPerAgent); - SetArray(dataOffsets.DecisionRewardsOffset, world.DecisionRewards, 4 * decisionCount); - SetArray(dataOffsets.DecisionAgentIdOffset, world.DecisionAgentIds, 4 * decisionCount); - if (world.ActionType == ActionType.DISCRETE) + SetArray(dataOffsets.DecisionObsOffset, policy.DecisionObs, 4 * decisionCount * totalFloatObsPerAgent); + SetArray(dataOffsets.DecisionRewardsOffset, policy.DecisionRewards, 4 * decisionCount); + SetArray(dataOffsets.DecisionAgentIdOffset, policy.DecisionAgentIds, 4 * decisionCount); + if (policy.ActionType == ActionType.DISCRETE) { - SetArray(dataOffsets.DecisionActionMasksOffset, world.DecisionActionMasks, decisionCount * world.DiscreteActionBranches.Sum()); + SetArray(dataOffsets.DecisionActionMasksOffset, policy.DecisionActionMasks, decisionCount * policy.DiscreteActionBranches.Sum()); } //Termination data - var terminationCount = world.TerminationCounter.Count; + var terminationCount = policy.TerminationCounter.Count; SetInt(dataOffsets.TerminationNumberAgentsOffset, terminationCount); - SetArray(dataOffsets.TerminationObsOffset, world.TerminationObs, 4 * terminationCount * totalFloatObsPerAgent); - SetArray(dataOffsets.TerminationRewardsOffset, world.TerminationRewards, 4 * terminationCount); - SetArray(dataOffsets.TerminationAgentIdOffset, world.TerminationAgentIds, 4 * terminationCount); - SetArray(dataOffsets.TerminationStatusOffset, world.TerminationStatus, terminationCount); + SetArray(dataOffsets.TerminationObsOffset, policy.TerminationObs, 4 * terminationCount * totalFloatObsPerAgent); + SetArray(dataOffsets.TerminationRewardsOffset, policy.TerminationRewards, 4 * terminationCount); + SetArray(dataOffsets.TerminationAgentIdOffset, policy.TerminationAgentIds, 4 * 
terminationCount); + SetArray(dataOffsets.TerminationStatusOffset, policy.TerminationStatus, terminationCount); } - public void WriteWorldSpecs(string name, MLAgentsWorld world) + public void WritePolicySpecs(string name, Policy policy) { - m_OffsetDict[name] = RLDataOffsets.FromWorld(world, name, m_CurrentEndOffset); + m_OffsetDict[name] = RLDataOffsets.FromPolicy(policy, name, m_CurrentEndOffset); var offset = m_CurrentEndOffset; offset = SetString(offset, name); // Name - offset = SetInt(offset, world.DecisionAgentIds.Length); // Max Agents - offset = SetBool(offset, world.ActionType == ActionType.CONTINUOUS); - offset = SetInt(offset, world.ActionSize); - if (world.ActionType == ActionType.DISCRETE) + offset = SetInt(offset, policy.DecisionAgentIds.Length); // Max Agents + offset = SetBool(offset, policy.ActionType == ActionType.CONTINUOUS); + offset = SetInt(offset, policy.ActionSize); + if (policy.ActionType == ActionType.DISCRETE) { - foreach (int branchSize in world.DiscreteActionBranches) + foreach (int branchSize in policy.DiscreteActionBranches) { offset = SetInt(offset, branchSize); } } - offset = SetInt(offset, world.SensorShapes.Length); - foreach (int3 shape in world.SensorShapes) + offset = SetInt(offset, policy.SensorShapes.Length); + foreach (int3 shape in policy.SensorShapes) { offset = SetInt(offset, shape.x); offset = SetInt(offset, shape.y); @@ -119,7 +119,7 @@ namespace Unity.AI.MLAgents m_CurrentEndOffset = offset; } - public void ReadWorld(string name, MLAgentsWorld world) + public void ReadPolicy(string name, Policy policy) { if (!CanEdit) { @@ -127,18 +127,18 @@ namespace Unity.AI.MLAgents } if (!m_OffsetDict.ContainsKey(name)) { - throw new MLAgentsException("World not registered"); + throw new MLAgentsException("Policy not registered"); } var dataOffsets = m_OffsetDict[name]; SetInt(dataOffsets.DecisionNumberAgentsOffset, 0); SetInt(dataOffsets.TerminationNumberAgentsOffset, 0); - if (world.ActionType == ActionType.DISCRETE) + if (policy.ActionType == ActionType.DISCRETE) { - GetArray(dataOffsets.ActionOffset, world.DiscreteActuators, 4 * world.DecisionCounter.Count * world.ActionSize); + GetArray(dataOffsets.ActionOffset, policy.DiscreteActuators, 4 * policy.DecisionCounter.Count * policy.ActionSize); } else { - GetArray(dataOffsets.ActionOffset, world.ContinuousActuators, 4 * world.DecisionCounter.Count * world.ActionSize); + GetArray(dataOffsets.ActionOffset, policy.ContinuousActuators, 4 * policy.DecisionCounter.Count * policy.ActionSize); } } diff --git a/Runtime/Remote/SharedMemoryCommunicator.cs b/Runtime/Remote/SharedMemoryCommunicator.cs index 4f19920..b8ac7ef 100644 --- a/Runtime/Remote/SharedMemoryCommunicator.cs +++ b/Runtime/Remote/SharedMemoryCommunicator.cs @@ -75,23 +75,23 @@ namespace Unity.AI.MLAgents } /// - /// Writes the data of a world into the shared memory file. + /// Writes the data of a policy into the shared memory file.
/// - public void WriteWorld(string worldName, MLAgentsWorld world) + public void WritePolicy(string policyName, Policy policy) { if (!m_SharedMemoryHeader.Active) { return; } - if (m_ShareMemoryBody.ContainsWorld(worldName)) + if (m_ShareMemoryBody.ContainsPolicy(policyName)) { - m_ShareMemoryBody.WriteWorld(worldName, world); + m_ShareMemoryBody.WritePolicy(policyName, policy); } else { - // The world needs to register + // The policy needs to register int oldTotalCapacity = m_SharedMemoryHeader.RLDataBufferSize; - int worldMemorySize = RLDataOffsets.FromWorld(world, worldName, 0).EndOfDataOffset; + int policyMemorySize = RLDataOffsets.FromPolicy(policy, policyName, 0).EndOfDataOffset; m_CurrentFileNumber += 1; m_SharedMemoryHeader.FileNumber = m_CurrentFileNumber; byte[] channelData = m_ShareMemoryBody.SideChannelData; @@ -102,9 +102,9 @@ namespace Unity.AI.MLAgents true, null, m_SharedMemoryHeader.SideChannelBufferSize, - oldTotalCapacity + worldMemorySize + oldTotalCapacity + policyMemorySize ); - m_SharedMemoryHeader.RLDataBufferSize = oldTotalCapacity + worldMemorySize; + m_SharedMemoryHeader.RLDataBufferSize = oldTotalCapacity + policyMemorySize; if (channelData != null) { m_ShareMemoryBody.SideChannelData = channelData; @@ -114,8 +114,8 @@ namespace Unity.AI.MLAgents m_ShareMemoryBody.RlData = rlData; } // TODO Need to write the offsets - m_ShareMemoryBody.WriteWorldSpecs(worldName, world); - m_ShareMemoryBody.WriteWorld(worldName, world); + m_ShareMemoryBody.WritePolicySpecs(policyName, policy); + m_ShareMemoryBody.WritePolicy(policyName, policy); } } @@ -180,11 +180,11 @@ namespace Unity.AI.MLAgents } /// - /// Loads the action data form the shared memory file to the world + /// Loads the action data from the shared memory file to the policy /// - public void LoadWorld(string worldName, MLAgentsWorld world) + public void LoadPolicy(string policyName, Policy policy) { - m_ShareMemoryBody.ReadWorld(worldName, world); + m_ShareMemoryBody.ReadPolicy(policyName, policy); } public void Dispose() diff --git a/Runtime/UI/MLAgentsWorldSpecs.cs b/Runtime/UI/PolicySpecs.cs similarity index 51% rename from Runtime/UI/MLAgentsWorldSpecs.cs rename to Runtime/UI/PolicySpecs.cs index e2ae138..e2a0df8 100644 --- a/Runtime/UI/MLAgentsWorldSpecs.cs +++ b/Runtime/UI/PolicySpecs.cs @@ -6,7 +6,7 @@ using UnityEngine; namespace Unity.AI.MLAgents { - internal enum WorldProcessorType + internal enum PolicyProcessorType { Default, InferenceOnly, @@ -14,16 +14,16 @@ namespace Unity.AI.MLAgents } /// - /// An editor friendly constructor for a MLAgentsWorld. - /// Keeps track of the behavior specs of a world, its name, + /// An editor friendly constructor for a Policy. + /// Keeps track of the behavior specs of a Policy, its name, /// its processor type and Neural Network Model. /// [Serializable] - public struct MLAgentsWorldSpecs + public struct PolicySpecs { [SerializeField] internal string Name; - [SerializeField] internal WorldProcessorType WorldProcessorType; + [SerializeField] internal PolicyProcessorType PolicyProcessorType; [SerializeField] internal int NumberAgents; [SerializeField] internal ActionType ActionType; @@ -34,47 +34,47 @@ namespace Unity.AI.MLAgents [SerializeField] internal NNModel Model; [SerializeField] internal InferenceDevice InferenceDevice; - private MLAgentsWorld m_World; + private Policy m_Policy; /// - /// Generates the world using the specified specs and registers its - /// processor to the Academy.
The world is only created and registed once, - /// if subsequent calls are made, the created world will be returned. + /// Generates a Policy using the specified specs and registers its + /// processor to the Academy. The policy is only created and registered once; + /// if subsequent calls are made, the created policy will be returned. /// /// - public MLAgentsWorld GetWorld() + public Policy GetPolicy() { - if (m_World.IsCreated) + if (m_Policy.IsCreated) { - return m_World; + return m_Policy; } - m_World = new MLAgentsWorld( + m_Policy = new Policy( NumberAgents, ObservationShapes, ActionType, ActionSize, DiscreteActionBranches ); - switch (WorldProcessorType) + switch (PolicyProcessorType) { - case WorldProcessorType.Default: - m_World.RegisterWorldWithBarracudaModel(Name, Model, InferenceDevice); + case PolicyProcessorType.Default: + m_Policy.RegisterPolicyWithBarracudaModel(Name, Model, InferenceDevice); break; - case WorldProcessorType.InferenceOnly: + case PolicyProcessorType.InferenceOnly: if (Model == null) { throw new MLAgentsException($"No model specified for {Name}"); } - m_World.RegisterWorldWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice); + m_Policy.RegisterPolicyWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice); break; - case WorldProcessorType.None: - Academy.Instance.RegisterWorld(Name, m_World, new NullWorldProcessor(m_World), false); + case PolicyProcessorType.None: + Academy.Instance.RegisterPolicy(Name, m_Policy, new NullPolicyProcessor(m_Policy), false); break; default: - throw new MLAgentsException($"Unknown WorldProcessor Type"); + throw new MLAgentsException($"Unknown IPolicyProcessor Type"); } - return m_World; + return m_Policy; } } } diff --git a/Runtime/UI/MLAgentsWorldSpecs.cs.meta b/Runtime/UI/PolicySpecs.cs.meta similarity index 100% rename from Runtime/UI/MLAgentsWorldSpecs.cs.meta rename to Runtime/UI/PolicySpecs.cs.meta diff --git a/Runtime/WorldProcessor/HeuristicWorldProcessor.cs b/Runtime/WorldProcessor/HeuristicWorldProcessor.cs deleted file mode 100644 index a9278fa..0000000 --- a/Runtime/WorldProcessor/HeuristicWorldProcessor.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System; -using Unity.Collections; -using Unity.Collections.LowLevel.Unsafe; - - -namespace Unity.AI.MLAgents { - public static class HeristicWorldProcessorRegistringExtension - { - /// - /// Registers the given MLAgentsWorld to the Academy with a Heuristic. - /// Note that if the simulation connects to Python, the Heuristic will - /// be ignored and the world will exchange data with Python instead. - /// The Heuristic is a Function that returns an action struct. - /// - /// The MLAgentsWorld to register - /// The name of the world. This is useful for identification - /// and for training. - /// The Heuristic used to generate the actions. - /// Note that all agents in the world will receive the same action. - /// The type of the Action struct. It must match the Action - /// Size and Action Type of the world. - public static void RegisterWorldWithHeuristic( - this MLAgentsWorld world, - string policyId, - Func heuristic - ) where TH : struct - { - var worldProcessor = new HeuristicWorldProcessor(world, heuristic); - Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true); - } - - /// - /// Registers the given MLAgentsWorld to the Academy with a Heuristic. - /// Note that if the simulation connects to Python, the world will not - /// exchange data with Python and use the Heuristic regardless.
- /// The Heuristic is a Function that returns an action struct. - /// - /// The MLAgentsWorld to register - /// The name of the world. This is useful for identification - /// and for training. - /// The Heuristic used to generate the actions. - /// Note that all agents in the world will receive the same action. - /// The type of the Action struct. It must match the Action - /// Size and Action Type of the world. - public static void RegisterWorldWithHeuristicForceNoCommunication( - this MLAgentsWorld world, - string policyId, - Func heuristic - ) where TH : struct - { - var worldProcessor = new HeuristicWorldProcessor(world, heuristic); - Academy.Instance.RegisterWorld(policyId, world, worldProcessor, false); - } - } - - internal class HeuristicWorldProcessor : IWorldProcessor where T : struct - { - private Func m_Heuristic; - private MLAgentsWorld m_World; - - public bool IsConnected {get {return false;}} - - internal HeuristicWorldProcessor(MLAgentsWorld world, Func heuristic) - { - this.m_World = world; - this.m_Heuristic = heuristic; - var structSize = UnsafeUtility.SizeOf() / sizeof(float); - if (structSize != world.ActionSize) - { - throw new MLAgentsException( - $"The heuristic provided does not match the action size. Expected {world.ActionSize} but received {structSize} from heuristic"); - } - } - - public void ProcessWorld() - { - T action = m_Heuristic.Invoke(); - var totalCount = m_World.DecisionCounter.Count; - - // TODO : This can be parallelized - if (m_World.ActionType == ActionType.CONTINUOUS) - { - var s = m_World.ContinuousActuators.Slice(0, totalCount * m_World.ActionSize).SliceConvert(); - for (int i = 0; i < totalCount; i++) - { - s[i] = action; - } - } - else - { - var s = m_World.DiscreteActuators.Slice(0, totalCount * m_World.ActionSize).SliceConvert(); - for (int i = 0; i < totalCount; i++) - { - s[i] = action; - } - } - } - - public void Dispose() - { - } - } -} diff --git a/Samples~/3DBall/Prefab/AgentBlue.mat b/Samples~/3DBall/Prefab/AgentBlue.mat index 98afe27..098e81a 100644 --- a/Samples~/3DBall/Prefab/AgentBlue.mat +++ b/Samples~/3DBall/Prefab/AgentBlue.mat @@ -4,8 +4,9 @@ Material: serializedVersion: 6 m_ObjectHideFlags: 0 - m_PrefabParentObject: {fileID: 0} - m_PrefabInternal: {fileID: 0} + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} m_Name: AgentBlue m_Shader: {fileID: 47, guid: 0000000000000000f000000000000000, type: 0} m_ShaderKeywords: _GLOSSYREFLECTIONS_OFF _SPECULARHIGHLIGHTS_OFF diff --git a/Samples~/3DBall/Scene/3DBall.unity b/Samples~/3DBall/Scene/3DBall.unity index 7d34bf8..3afbe92 100644 --- a/Samples~/3DBall/Scene/3DBall.unity +++ b/Samples~/3DBall/Scene/3DBall.unity @@ -249,9 +249,9 @@ MonoBehaviour: m_Script: {fileID: 11500000, guid: 99c8505fda4844b5b8a9505d069fdd8c, type: 3} m_Name: m_EditorClassIdentifier: - MyWorldSpecs: - Name: Ball_DOTS - WorldProcessorType: 0 + MyPolicySpecs: + Name: 3DBall + PolicyProcessorType: 0 NumberAgents: 1000 ActionType: 1 ObservationShapes: @@ -268,7 +268,7 @@ MonoBehaviour: y: 0 z: 0 ActionSize: 2 - DiscreteActionBranches: + DiscreteActionBranches: 0000000000000000 Model: {fileID: 5022602860645237092, guid: 10312addcce7c4e508c30d95eaf1a2ff, type: 3} InferenceDevice: 0 NumberBalls: 1000 diff --git a/Samples~/3DBall/Script/BalanceBallManager.cs b/Samples~/3DBall/Script/BalanceBallManager.cs index 7a1346e..c31e001 100644 --- a/Samples~/3DBall/Script/BalanceBallManager.cs +++ b/Samples~/3DBall/Script/BalanceBallManager.cs @@ -8,7 +8,7 @@ using 
Unity.AI.MLAgents; public class BalanceBallManager : MonoBehaviour { - public MLAgentsWorldSpecs MyWorldSpecs; + public PolicySpecs MyPolicySpecs; public int NumberBalls = 1000; @@ -23,12 +23,13 @@ public class BalanceBallManager : MonoBehaviour NativeArray entitiesB; BlobAssetStore blob; + void Awake() { - var world = MyWorldSpecs.GetWorld(); + var policy = MyPolicySpecs.GetPolicy(); var ballSystem = World.DefaultGameObjectInjectionWorld.GetOrCreateSystem(); ballSystem.Enabled = true; - ballSystem.BallWorld = world; + ballSystem.BallPolicy = policy; manager = World.DefaultGameObjectInjectionWorld.EntityManager; diff --git a/Samples~/3DBall/Script/BallSystem.cs b/Samples~/3DBall/Script/BallSystem.cs index 4c9d4d6..a5c5c6e 100644 --- a/Samples~/3DBall/Script/BallSystem.cs +++ b/Samples~/3DBall/Script/BallSystem.cs @@ -33,17 +33,17 @@ public class BallSystem : JobComponentSystem } } - public MLAgentsWorld BallWorld; + public Policy BallPolicy; // Update is called once per frame protected override JobHandle OnUpdate(JobHandle inputDeps) { - if (!BallWorld.IsCreated){ + if (!BallPolicy.IsCreated){ return inputDeps; } - var world = BallWorld; + var policy = BallPolicy; ComponentDataFromEntity TranslationFromEntity = GetComponentDataFromEntity(isReadOnly: false); ComponentDataFromEntity VelFromEntity = GetComponentDataFromEntity(isReadOnly: false); @@ -52,7 +52,7 @@ public class BallSystem : JobComponentSystem .WithNativeDisableParallelForRestriction(VelFromEntity) .ForEach((Entity entity, ref Rotation rot, ref AgentData agentData) => { - + var ballPos = TranslationFromEntity[agentData.BallRef].Value; var ballVel = VelFromEntity[agentData.BallRef].Linear; var platformVel = VelFromEntity[entity]; @@ -70,7 +70,7 @@ public class BallSystem : JobComponentSystem } if (!interruption && !taskFailed) { - world.RequestDecision(entity) + policy.RequestDecision(entity) .SetObservation(0, rot.Value) .SetObservation(1, ballPos - agentData.BallResetPosition) .SetObservation(2, ballVel) @@ -79,7 +79,7 @@ public class BallSystem : JobComponentSystem } if (taskFailed) { - world.EndEpisode(entity) + policy.EndEpisode(entity) .SetObservation(0, rot.Value) .SetObservation(1, ballPos - agentData.BallResetPosition) .SetObservation(2, ballVel) @@ -88,7 +88,7 @@ public class BallSystem : JobComponentSystem } else if (interruption) { - world.InterruptEpisode(entity) + policy.InterruptEpisode(entity) .SetObservation(0, rot.Value) .SetObservation(1, ballPos - agentData.BallResetPosition) .SetObservation(2, ballVel) @@ -109,7 +109,7 @@ public class BallSystem : JobComponentSystem { ComponentDataFromEntity = GetComponentDataFromEntity(isReadOnly: false) }; - inputDeps = reactiveJob.Schedule(world, inputDeps); + inputDeps = reactiveJob.Schedule(policy, inputDeps); inputDeps = Entities.ForEach((Actuator act, ref Rotation rotation) => { @@ -122,6 +122,6 @@ public class BallSystem : JobComponentSystem protected override void OnDestroy() { - BallWorld.Dispose(); + BallPolicy.Dispose(); } } diff --git a/Samples~/Basic/Script/BasicAgent.cs b/Samples~/Basic/Script/BasicAgent.cs index 16ecae6..e924c6c 100644 --- a/Samples~/Basic/Script/BasicAgent.cs +++ b/Samples~/Basic/Script/BasicAgent.cs @@ -5,8 +5,8 @@ using UnityEngine; public class BasicAgent : MonoBehaviour { - public MLAgentsWorldSpecs BasicSpecs; - private MLAgentsWorld m_World; + public PolicySpecs BasicSpecs; + private Policy m_Policy; private Entity m_Entity; public float timeBetweenDecisionsAtInference; @@ -33,8 +33,8 @@ public class BasicAgent : MonoBehaviour void 
Start() { m_Entity = World.DefaultGameObjectInjectionWorld.EntityManager.CreateEntity(); - m_World = BasicSpecs.GetWorld(); - m_World.RegisterWorldWithHeuristic("BASIC", () => { return 2; }); + m_Policy = BasicSpecs.GetPolicy(); + m_Policy.RegisterPolicyWithHeuristic("BASIC", () => { return 2; }); Academy.Instance.OnEnvironmentReset += BeginEpisode; BeginEpisode(); } @@ -62,12 +62,12 @@ public class BasicAgent : MonoBehaviour void StepAgent() { // Request a Decision for all agents - m_World.RequestDecision(m_Entity) + m_Policy.RequestDecision(m_Entity) .SetObservation(0, m_Position) .SetReward(-0.01f); // Get the action - NativeHashMap actions = m_World.GenerateActionHashMap(Allocator.Temp); + NativeHashMap actions = m_Policy.GenerateActionHashMap(Allocator.Temp); int action = 0; actions.TryGetValue(m_Entity, out action); @@ -87,7 +87,7 @@ public class BasicAgent : MonoBehaviour // See if the Agent terminated if (m_Position == k_SmallGoalPosition) { - m_World.EndEpisode(m_Entity) + m_Policy.EndEpisode(m_Entity) .SetObservation(0, m_Position) .SetReward(0.1f); BeginEpisode(); @@ -95,7 +95,7 @@ public class BasicAgent : MonoBehaviour if (m_Position == k_LargeGoalPosition) { - m_World.EndEpisode(m_Entity) + m_Policy.EndEpisode(m_Entity) .SetObservation(0, m_Position) .SetReward(1f); BeginEpisode(); @@ -104,6 +104,6 @@ public class BasicAgent : MonoBehaviour void OnDestroy() { - m_World.Dispose(); + m_Policy.Dispose(); } } diff --git a/Tests/Editor/TestMLAgentsWorld.cs b/Tests/Editor/TestMLAgentsPolicy.cs similarity index 79% rename from Tests/Editor/TestMLAgentsWorld.cs rename to Tests/Editor/TestMLAgentsPolicy.cs index 2c898ef..f5deeb0 100644 --- a/Tests/Editor/TestMLAgentsWorld.cs +++ b/Tests/Editor/TestMLAgentsPolicy.cs @@ -6,7 +6,7 @@ using Unity.Jobs; namespace Unity.AI.MLAgents.Tests.Editor { - public class TestMLAgentsWorld + public class TestMLAgentsPolicy { private World ECSWorld; private EntityManager entityManager; @@ -51,15 +51,15 @@ namespace Unity.AI.MLAgents.Tests.Editor } [Test] - public void TestWorldCreation() + public void TestPolicyCreation() { - var world = new MLAgentsWorld( + var policy = new Policy( 20, new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) }, ActionType.DISCRETE, 2, new int[] { 2, 3 }); - world.Dispose(); + policy.Dispose(); } private struct SingleActionEnumUpdate : IActuatorJob @@ -77,18 +77,18 @@ namespace Unity.AI.MLAgents.Tests.Editor [Test] public void TestManualDecisionSteppingWithHeuristic() { - var world = new MLAgentsWorld( + var policy = new Policy( 20, new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) }, ActionType.DISCRETE, 2, new int[] { 2, 3 }); - world.RegisterWorldWithHeuristic("test", () => new int2(0, 1)); + policy.RegisterPolicyWithHeuristic("test", () => new int2(0, 1)); var entity = entityManager.CreateEntity(); - world.RequestDecision(entity) + policy.RequestDecision(entity) .SetObservation(0, new float3(1, 2, 3)) .SetReward(1f); @@ -100,7 +100,7 @@ namespace Unity.AI.MLAgents.Tests.Editor ent = entities, action = actions }; - actionJob.Schedule(world, new JobHandle()).Complete(); + actionJob.Schedule(policy, new JobHandle()).Complete(); Assert.AreEqual(entity, entities[0]); Assert.AreEqual(new DiscreteAction_TWO_THREE @@ -109,7 +109,7 @@ namespace Unity.AI.MLAgents.Tests.Editor action_TWO = ThreeOptionEnum.Option_TWO }, actions[0]); - world.Dispose(); + policy.Dispose(); entities.Dispose(); actions.Dispose(); Academy.Instance.Dispose(); @@ -118,56 +118,56 @@ namespace Unity.AI.MLAgents.Tests.Editor [Test] public void 
TestTerminateEpisode() { - var world = new MLAgentsWorld( + var policy = new Policy( 20, new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) }, ActionType.DISCRETE, 2, new int[] { 2, 3 }); - world.RegisterWorldWithHeuristic("test", () => new int2(0, 1)); + policy.RegisterPolicyWithHeuristic("test", () => new int2(0, 1)); var entity = entityManager.CreateEntity(); - world.RequestDecision(entity) + policy.RequestDecision(entity) .SetObservation(0, new float3(1, 2, 3)) .SetReward(1f); - var hashMap = world.GenerateActionHashMap(Allocator.Temp); + var hashMap = policy.GenerateActionHashMap(Allocator.Temp); Assert.True(hashMap.TryGetValue(entity, out _)); hashMap.Dispose(); - world.EndEpisode(entity); - hashMap = world.GenerateActionHashMap(Allocator.Temp); + policy.EndEpisode(entity); + hashMap = policy.GenerateActionHashMap(Allocator.Temp); Assert.False(hashMap.TryGetValue(entity, out _)); hashMap.Dispose(); - world.Dispose(); + policy.Dispose(); Academy.Instance.Dispose(); } [Test] - public void TestMultiWorld() + public void TestMultiPolicy() { - var world1 = new MLAgentsWorld( + var policy1 = new Policy( 20, new int3[] { new int3(3, 0, 0) }, ActionType.DISCRETE, 2, new int[] { 2, 3 }); - var world2 = new MLAgentsWorld( + var policy2 = new Policy( 20, new int3[] { new int3(3, 0, 0) }, ActionType.DISCRETE, 2, new int[] { 2, 3 }); - world1.RegisterWorldWithHeuristic("test1", () => new DiscreteAction_TWO_THREE + policy1.RegisterPolicyWithHeuristic("test1", () => new DiscreteAction_TWO_THREE { action_ONE = TwoOptionEnum.Option_TWO, action_TWO = ThreeOptionEnum.Option_ONE }); - world2.RegisterWorldWithHeuristic("test2", () => new DiscreteAction_TWO_THREE + policy2.RegisterPolicyWithHeuristic("test2", () => new DiscreteAction_TWO_THREE { action_ONE = TwoOptionEnum.Option_ONE, action_TWO = ThreeOptionEnum.Option_TWO @@ -175,8 +175,8 @@ namespace Unity.AI.MLAgents.Tests.Editor var entity = entityManager.CreateEntity(); - world1.RequestDecision(entity); - world2.RequestDecision(entity); + policy1.RequestDecision(entity); + policy2.RequestDecision(entity); var entities = new NativeArray(1, Allocator.Persistent); var actions = new NativeArray(1, Allocator.Persistent); @@ -185,7 +185,7 @@ namespace Unity.AI.MLAgents.Tests.Editor ent = entities, action = actions }; - actionJob1.Schedule(world1, new JobHandle()).Complete(); + actionJob1.Schedule(policy1, new JobHandle()).Complete(); Assert.AreEqual(entity, entities[0]); Assert.AreEqual(new DiscreteAction_TWO_THREE @@ -203,7 +203,7 @@ namespace Unity.AI.MLAgents.Tests.Editor ent = entities, action = actions }; - actionJob2.Schedule(world2, new JobHandle()).Complete(); + actionJob2.Schedule(policy2, new JobHandle()).Complete(); Assert.AreEqual(entity, entities[0]); Assert.AreEqual(new DiscreteAction_TWO_THREE @@ -214,8 +214,8 @@ namespace Unity.AI.MLAgents.Tests.Editor entities.Dispose(); actions.Dispose(); - world1.Dispose(); - world2.Dispose(); + policy1.Dispose(); + policy2.Dispose(); Academy.Instance.Dispose(); } } diff --git a/Tests/Editor/TestMLAgentsWorld.cs.meta b/Tests/Editor/TestMLAgentsPolicy.cs.meta similarity index 100% rename from Tests/Editor/TestMLAgentsWorld.cs.meta rename to Tests/Editor/TestMLAgentsPolicy.cs.meta
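For reviewers skimming this patch, below is a minimal usage sketch of the renamed API, modeled on Samples~/Basic/Script/BasicAgent.cs above. It is illustrative only: the component name MinimalPolicyAgent is hypothetical, and the explicit generic arguments on RegisterPolicyWithHeuristic and GenerateActionHashMap are assumptions inferred from how the samples and tests call these methods.

using Unity.AI.MLAgents;
using Unity.Collections;
using Unity.Entities;
using UnityEngine;

// Hypothetical MonoBehaviour showing the renamed API surface (PolicySpecs / Policy).
public class MinimalPolicyAgent : MonoBehaviour
{
    // Configured in the Inspector through the PolicySpecsDrawer introduced by this patch.
    public PolicySpecs MySpecs;

    Policy m_Policy;
    Entity m_Entity;

    void Start()
    {
        m_Entity = World.DefaultGameObjectInjectionWorld.EntityManager.CreateEntity();
        // GetPolicy() creates the Policy once and registers its processor with the Academy;
        // subsequent calls return the same instance.
        m_Policy = MySpecs.GetPolicy();
        // Heuristic fallback, as in the Basic sample (generic argument assumed).
        m_Policy.RegisterPolicyWithHeuristic<int>("Minimal", () => 2);
    }

    void FixedUpdate()
    {
        // Queue a decision request with a single observation and a small step penalty.
        m_Policy.RequestDecision(m_Entity)
            .SetObservation(0, 0)
            .SetReward(-0.01f);

        // Collect the action computed for this entity (generic argument assumed).
        var actions = m_Policy.GenerateActionHashMap<int>(Allocator.Temp);
        actions.TryGetValue(m_Entity, out int action);
        actions.Dispose();
        Debug.Log($"Action received: {action}");
    }

    void OnDestroy()
    {
        m_Policy.Dispose();
    }
}

Whether the heuristic registration is actually exercised depends on the PolicyProcessorType selected in the specs and on whether a Python trainer is connected, as described in the doc comments above.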