From d8916432384f3d744c023590f900ac8d267f3ffa Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Wed, 15 Apr 2020 13:15:09 -0700 Subject: [PATCH] More comments --- Runtime/Academy.cs | 8 +- Runtime/ActionHashMapUtils.cs | 10 +- Runtime/ActionType.cs | 25 +++ Runtime/ActionType.cs.meta | 11 ++ Runtime/ActuatorJob.cs | 4 +- Runtime/ArrayUtils.cs | 50 ------ Runtime/MLAgentsException.cs | 13 +- Runtime/MLAgentsWorld.cs | 10 -- Runtime/Remote/ArgParser.cs | 9 +- Runtime/Remote/BaseSharedMem.cs | 146 +++++++++++++++++- Runtime/UI/MLAgentsWorldSpecs.cs | 11 ++ .../WorldProcessor/BarracudaWorldProcessor.cs | 42 ++++- .../WorldProcessor/HeuristicWorldProcessor.cs | 26 ++++ Runtime/WorldProcessor/IWorldProcessor.cs | 13 ++ 14 files changed, 301 insertions(+), 77 deletions(-) create mode 100644 Runtime/ActionType.cs create mode 100644 Runtime/ActionType.cs.meta diff --git a/Runtime/Academy.cs b/Runtime/Academy.cs index aa2edee..128f6cb 100644 --- a/Runtime/Academy.cs +++ b/Runtime/Academy.cs @@ -12,6 +12,12 @@ using UnityEngine; namespace Unity.AI.MLAgents { + /// + /// The Academy is a singleton that orchestrates the + /// decision making of the Agents. + /// It is used to register WorldProcessors to Worlds and to keep track of the + /// reset logic of the simulation. + /// public class Academy : IDisposable { #region Singleton @@ -64,7 +70,7 @@ namespace Unity.AI.MLAgents /// The string identifier of the MLAgentsWorld. There can only be one world per unique id. /// The MLAgentsWorld that is being subscribed. /// If the remote process is not available, the MLAgentsWorld will use this World processor for decision making. - /// If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallbackWorldProcessor otherwise. + /// If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallback worldProcessor otherwise. 
public void RegisterWorld(string policyId, MLAgentsWorld world, IWorldProcessor worldProcessor = null, bool defaultRemote = true) { // Need to find a way to deregister ? diff --git a/Runtime/ActionHashMapUtils.cs b/Runtime/ActionHashMapUtils.cs index 3d65a63..d4fba0a 100644 --- a/Runtime/ActionHashMapUtils.cs +++ b/Runtime/ActionHashMapUtils.cs @@ -9,13 +9,15 @@ namespace Unity.AI.MLAgents internal static class ActionHashMapUtils { /// - /// Retrieves the action data for a world in puts it into a HashMap + /// Retrieves the action data for a world and puts it into a HashMap. + /// This action deletes the action data from the world. /// /// The MLAgentsWorld the data will be retrieved from. /// The memory allocator of the create NativeHashMap. - /// The type of the Action struct. It must match the Action Size and Action Type of the world. - /// A NativeHashMap from Entities to Action. - public static NativeHashMap GetActionHashMap(this MLAgentsWorld world, Allocator allocator) where T : struct + /// The type of the Action struct. It must match the Action Size + /// and Action Type of the world. + /// A NativeHashMap from Entities to Actions with type T. + public static NativeHashMap GenerateActionHashMap(this MLAgentsWorld world, Allocator allocator) where T : struct { #if ENABLE_UNITY_COLLECTIONS_CHECKS if (world.ActionSize != UnsafeUtility.SizeOf() / 4) diff --git a/Runtime/ActionType.cs b/Runtime/ActionType.cs new file mode 100644 index 0000000..de7515b --- /dev/null +++ b/Runtime/ActionType.cs @@ -0,0 +1,25 @@ +namespace Unity.AI.MLAgents +{ + /// + /// Indicates the action space type of an Agent. + /// + public enum ActionType : int + { + + /// + /// The Action is expected to be a struct of int or int enums. + /// More precisely, struct must be convertible to an array of ints by + /// memory copy. + /// The number of ints in the struct corresponds to the number of actions, + /// but the range of values each action can take is given by the "Branch" size. 
+ /// + DISCRETE = 0, + + /// + /// The Action is expected to be a struct with only float members. + /// More precisely, struct must be convertible to an array of float by + /// memory copy. + /// + CONTINUOUS = 1, + } +} diff --git a/Runtime/ActionType.cs.meta b/Runtime/ActionType.cs.meta new file mode 100644 index 0000000..c29eeac --- /dev/null +++ b/Runtime/ActionType.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a82fe272987e4492cb74d2d80dfb1b4a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/ActuatorJob.cs b/Runtime/ActuatorJob.cs index 01d16bb..4884c04 100644 --- a/Runtime/ActuatorJob.cs +++ b/Runtime/ActuatorJob.cs @@ -93,7 +93,6 @@ namespace Unity.AI.MLAgents internal static unsafe JobHandle ScheduleImpl(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps) where T : struct, IActuatorJob { - // inputDeps.Complete(); // Creating a data struct that contains the data the user passed into the job (This is what T is here) var data = new ActionEventJobData { @@ -124,7 +123,6 @@ namespace Unity.AI.MLAgents jobReflectionData = JobsUtility.CreateJobReflectionData(typeof(ActionEventJobData), typeof(T), JobType.Single, (ExecuteJobFunction)Execute); return jobReflectionData; } - #endregion @@ -135,6 +133,7 @@ namespace Unity.AI.MLAgents { int size = jobData.world.ActionSize; int actionCount = jobData.world.ActionCounter.Count; + // Continuous case if (jobData.world.ActionType == ActionType.CONTINUOUS) { @@ -149,6 +148,7 @@ namespace Unity.AI.MLAgents }); } } + // Discrete Case else { diff --git a/Runtime/ArrayUtils.cs b/Runtime/ArrayUtils.cs index 9a9debe..42c5690 100644 --- a/Runtime/ArrayUtils.cs +++ b/Runtime/ArrayUtils.cs @@ -56,53 +56,3 @@ namespace Unity.AI.MLAgents } } } - - -/* - -/// - /// Computes the total number of bytes necessary in the shared memory file to - /// send and receive data 
for a specific MLAgentsWorld. - /// - public static int GetRequiredCapacity(MLAgentsWorld world) - { - // # int : 4 bytes : maximum number of Agents - // # bool : 1 byte : is action discrete (False) or continuous (True) - // # int : 4 bytes : action space size (continuous) / number of branches (discrete) - // # -- If discrete only : array of action sizes for each branch - // # int : 4 bytes : number of observations - // # For each observation : - // # 3 int : shape (the shape of the tensor observation for one agent - // # start of the section that will change every step - // # 4 bytes : n_agents at current step - // # ? Bytes : the data : obs,reward,done,max_step,agent_id,masks,action - int capacity = 64; // Name - capacity += 4; // N Max Agents - capacity += 1; // discrete or continuous - capacity += 4; // action Size - if (world.ActionType == ActionType.DISCRETE) - { - capacity += 4 * world.ActionSize; // The action branches - } - capacity += 4; // number of observations - capacity += 3 * 4 * world.SensorShapes.Length; // The observation shapes - capacity += 4; // Number of agents for the current step - - var nAgents = world.Rewards.Length; - capacity += 4 * world.Sensors.Length; - capacity += 4 * nAgents; - capacity += nAgents; - capacity += nAgents; - capacity += 4 * nAgents; - if (world.ActionType == ActionType.DISCRETE) - { - foreach (int branch_size in world.DiscreteActionBranches) - { - capacity += branch_size * nAgents; - } - } - capacity += 4 * world.ActionSize * nAgents; - - return capacity; - } - */ diff --git a/Runtime/MLAgentsException.cs b/Runtime/MLAgentsException.cs index 57fd196..359fe02 100644 --- a/Runtime/MLAgentsException.cs +++ b/Runtime/MLAgentsException.cs @@ -2,11 +2,16 @@ using System; namespace Unity.AI.MLAgents { + /// /// Contains exceptions specific to ML-Agents. 
- internal class MLAgentsException : Exception + /// + public class MLAgentsException : Exception { - public MLAgentsException(string message) : base(message) - { - } + /// + /// This exception indicates an error occurred in the ML-Agents package + /// + /// Text message for the error + /// + public MLAgentsException(string message) : base(message){} } } diff --git a/Runtime/MLAgentsWorld.cs b/Runtime/MLAgentsWorld.cs index bc5ca50..eeacab0 100644 --- a/Runtime/MLAgentsWorld.cs +++ b/Runtime/MLAgentsWorld.cs @@ -6,16 +6,6 @@ using System; namespace Unity.AI.MLAgents { - /// - /// Indicates the action space type of an Agent action. - /// If CONTINUOUS, the Action is expected to be a struct containing only floats - /// If DISCRETE, the Action is expected to be a struct of int or int enums. - /// - public enum ActionType : int - { - DISCRETE = 0, - CONTINUOUS = 1, - } /// /// MLAgentsWorld is a data container on which the user requests decisions. diff --git a/Runtime/Remote/ArgParser.cs b/Runtime/Remote/ArgParser.cs index fdaab8c..170f6b6 100644 --- a/Runtime/Remote/ArgParser.cs +++ b/Runtime/Remote/ArgParser.cs @@ -8,7 +8,12 @@ namespace Unity.AI.MLAgents private const string k_MemoryFileArgument = "--memory-path"; private const string k_MemoryFilesDirectory = "ml-agents"; private const string k_DefaultFileName = "default"; - // Used to read Python-provided environment parameters + + /// + /// Used to read Python-provided environment parameters. Will return the + /// string corresponding to the path to the shared memory file. 
+ /// + /// Path to the shared memory file or null if no file is available public static string ReadSharedMemoryPathFromArgs() { var args = System.Environment.GetCommandLineArgs(); @@ -20,7 +25,7 @@ namespace Unity.AI.MLAgents } } #if UNITY_EDITOR - // Try connecting on the default editor port + // Try connecting on the default shared memory file var path = Path.Combine( Path.GetTempPath(), k_MemoryFilesDirectory, diff --git a/Runtime/Remote/BaseSharedMem.cs b/Runtime/Remote/BaseSharedMem.cs index cba6b6c..ea230eb 100644 --- a/Runtime/Remote/BaseSharedMem.cs +++ b/Runtime/Remote/BaseSharedMem.cs @@ -14,6 +14,16 @@ namespace Unity.AI.MLAgents private IntPtr accessorPointer; private string filePath; + /// + /// A bare-bones shared memory wrapper that opens or creates a new shared + /// memory file and connects to it. + /// + /// The name of the file. The file must be present + /// in the Temporary folder (depends on OS) in the directory + /// If true, the file will be created by the constructor. + /// An error will be thrown if the file already exists. + /// The size of the file to be created (will be ignored if + /// is true. public BaseSharedMem(string fileName, bool createFile, int size = 0) { var directoryPath = Path.Combine(Path.GetTempPath(), k_Directory); @@ -50,6 +60,12 @@ namespace Unity.AI.MLAgents accessorPointer = new IntPtr(ptr); } + /// + /// Returns the integer present at the specified offset. The offset must be + /// passed by reference and will be incremented to the next value. + /// + /// Where to read the value + /// public int GetInt(ref int offset) { var result = accessor.ReadInt32(offset); @@ -57,6 +73,12 @@ namespace Unity.AI.MLAgents return result; } + /// + /// Returns the float present at the specified offset. The offset must be + /// passed by reference and will be incremented to the next value. 
+ /// + /// Where to read the value + /// public float GetFloat(ref int offset) { var result = accessor.ReadSingle(offset); @@ -64,6 +86,12 @@ namespace Unity.AI.MLAgents return result; } + /// + /// Returns the boolean present at the specified offset. The offset must be + /// passed by reference and will be incremented to the next value. + /// + /// Where to read the value + /// public bool GetBool(ref int offset) { var result = accessor.ReadBoolean(offset); @@ -71,6 +99,15 @@ namespace Unity.AI.MLAgents return result; } + /// + /// Returns the string present at the specified offset. The offset must be + /// passed by reference and will be incremented to the next value. + /// Note, the strings must be represented with: an sbyte (8-bit signed) + /// indicating the size of the string in bytes followed by the bytes of the + /// string (in ASCII format). + /// + /// Where to read the value + /// public string GetString(ref int offset) { var length = accessor.ReadSByte(offset); @@ -81,6 +118,15 @@ namespace Unity.AI.MLAgents return Encoding.ASCII.GetString(bytes); } + /// + /// Copies the values in the memory file at offset to the NativeArray. + /// The offset must be passed by reference and will be incremented to the next value. + /// Where to read the value + /// The destination array + /// The number of bytes that must be copied, must be less than or equal + /// to the capacity of the array. + /// The type of the NativeArray. Must be a struct. + /// public void GetArray(ref int offset, NativeArray array, int length) where T : struct { IntPtr src = IntPtr.Add(accessorPointer, offset); @@ -89,6 +135,13 @@ namespace Unity.AI.MLAgents offset += length; } + /// + /// Returns the byte array present at the specified offset. The offset must be + /// passed by reference and will be incremented to the next value. 
+ /// + /// Where to read the value + /// The size of the array in bytes + /// public byte[] GetBytes(ref int offset, int length) { var result = new byte[length]; @@ -97,21 +150,44 @@ namespace Unity.AI.MLAgents return result; } + /// + /// Returns the integer present at the specified offset. + /// + /// Where to read the value. + /// public int GetInt(int offset) { return accessor.ReadInt32(offset); } + /// + /// Returns the float present at the specified offset. + /// + /// Where to read the value. + /// public float GetFloat(int offset) { return accessor.ReadSingle(offset); } + /// + /// Returns the boolean present at the specified offset. + /// + /// Where to read the value. + /// public bool GetBool(int offset) { return accessor.ReadBoolean(offset); } + /// + /// Returns the string present at the specified offset. + /// Note, the strings must be represented with: an sbyte (8-bit signed) + /// indicating the size of the string in bytes followed by the bytes of the + /// string (in ASCII format). + /// + /// Where to read the value + /// public string GetString(int offset) { var length = accessor.ReadSByte(offset); @@ -121,6 +197,14 @@ namespace Unity.AI.MLAgents return Encoding.ASCII.GetString(bytes); } + /// + /// Copies the values in the memory file at offset to the NativeArray. + /// Where to read the value + /// The destination array + /// The number of bytes that must be copied, must be less than or equal + /// to the capacity of the array. + /// The type of the NativeArray. Must be a struct. + /// public void GetArray(int offset, NativeArray array, int length) where T : struct { IntPtr src = IntPtr.Add(accessorPointer, offset); @@ -128,6 +212,12 @@ namespace Unity.AI.MLAgents Buffer.MemoryCopy(src.ToPointer(), dst.ToPointer(), length, length); } + /// + /// Returns the byte array present at the specified offset. 
+ /// + /// Where to read the value + /// The size of the array in bytes + /// public byte[] GetBytes(int offset, int length) { var result = new byte[length]; @@ -138,24 +228,50 @@ namespace Unity.AI.MLAgents return result; } + /// + /// Sets the integer at the specified offset in the shared memory. + /// + /// The position at which to write the value + /// The value to be written + /// The offset right after the written value. public int SetInt(int offset, int value) { accessor.Write(offset, value); return offset + 4; } + /// + /// Sets the float at the specified offset in the shared memory. + /// + /// The position at which to write the value + /// The value to be written + /// The offset right after the written value. public int SetFloat(int offset, float value) { accessor.Write(offset, value); return offset + 4; } + /// + /// Sets the boolean at the specified offset in the shared memory. + /// + /// The position at which to write the value + /// The value to be written + /// The offset right after the written value. public int SetBool(int offset, bool value) { accessor.Write(offset, value); return offset + 1; } + /// + /// Sets the string at the specified offset in the shared memory. + /// Note that the string is represented with an sbyte (8-bit signed) number + /// encoding the length of the string, followed by the string bytes in ASCII. + /// + /// The position at which to write the value + /// The value to be written + /// The offset right after the written value. public int SetString(int offset, string value) { var bytes = Encoding.ASCII.GetBytes(value); @@ -167,20 +283,38 @@ namespace Unity.AI.MLAgents return offset; } - public int SetArray(int offset, NativeArray arr, int length) where T : struct + /// + /// Copies the values present in a NativeArray to the shared memory file. + /// + /// The position at which to write the value + /// The NativeArray containing the data to write + /// The size of the array in bytes + /// The type of the NativeArray. 
Must be a struct. + /// The offset right after the written value. + public int SetArray(int offset, NativeArray array, int length) where T : struct { IntPtr dst = IntPtr.Add(accessorPointer, offset); - IntPtr src = new IntPtr(arr.GetUnsafePtr()); + IntPtr src = new IntPtr(array.GetUnsafePtr()); Buffer.MemoryCopy(src.ToPointer(), dst.ToPointer(), length, length); return offset + length; } + /// + /// Sets the byte array at the specified offset in the shared memory. + /// + /// The position at which to write the value + /// The value to be written + /// The offset right after the written value. public int SetBytes(int offset, byte[] value) { accessor.WriteArray(offset, value, 0, value.Length); return offset + value.Length; } + /// + /// Closes the shared memory reader and writer. This will not delete the file but + /// will remove access to it. + /// public void Close() { if (accessor.CanWrite) @@ -190,6 +324,11 @@ namespace Unity.AI.MLAgents } } + /// + /// Indicates whether this instance has write access to the file. + /// If it does not, it means the accessor was closed or the file was deleted. + /// + /// protected bool CanEdit { get @@ -198,6 +337,9 @@ namespace Unity.AI.MLAgents } } + /// + /// Closes and deletes the shared memory file + /// public void Delete() { Close(); diff --git a/Runtime/UI/MLAgentsWorldSpecs.cs b/Runtime/UI/MLAgentsWorldSpecs.cs index 979ae5a..e2ae138 100644 --- a/Runtime/UI/MLAgentsWorldSpecs.cs +++ b/Runtime/UI/MLAgentsWorldSpecs.cs @@ -13,6 +13,11 @@ namespace Unity.AI.MLAgents None } + /// + /// An editor friendly constructor for a MLAgentsWorld. + /// Keeps track of the behavior specs of a world, its name, + /// its processor type and Neural Network Model. + /// [Serializable] public struct MLAgentsWorldSpecs { @@ -31,6 +36,12 @@ namespace Unity.AI.MLAgents private MLAgentsWorld m_World; + /// + /// Generates the world using the specified specs and registers its + /// processor to the Academy. 
The world is only created and registered once; + /// if subsequent calls are made, the created world will be returned. + /// + /// public MLAgentsWorld GetWorld() { if (m_World.IsCreated) diff --git a/Runtime/WorldProcessor/BarracudaWorldProcessor.cs b/Runtime/WorldProcessor/BarracudaWorldProcessor.cs index ebc3239..11b9a20 100644 --- a/Runtime/WorldProcessor/BarracudaWorldProcessor.cs +++ b/Runtime/WorldProcessor/BarracudaWorldProcessor.cs @@ -5,14 +5,37 @@ using Barracuda; namespace Unity.AI.MLAgents { + /// + /// Where to perform inference. + /// public enum InferenceDevice { - CPU, - GPU + /// + /// CPU inference + /// + CPU = 0, + + /// + /// GPU inference + /// + GPU = 1 } public static class BarracudaWorldProcessorRegistringExtension { + /// + /// Registers the given MLAgentsWorld to the Academy with a Neural + /// Network Model. If the input model is null, a default inactive + /// processor will be registered instead. Note that if the simulation + /// connects to Python, the Neural Network will be ignored and the world + /// will exchange data with Python instead. + /// + /// The MLAgentsWorld to register + /// The name of the world. This is useful for identification + /// and for training. + /// The Neural Network model used by the processor + /// The inference device specifying where to run inference + /// (CPU or GPU) public static void RegisterWorldWithBarracudaModel( this MLAgentsWorld world, string policyId, @@ -31,6 +54,19 @@ namespace Unity.AI.MLAgents } } + /// + /// Registers the given MLAgentsWorld to the Academy with a Neural + /// Network Model. If the input model is null, a default inactive + /// processor will be registered instead. Note that even if the simulation + /// connects to Python, the world will not exchange data with Python and will run the + /// given Neural Network regardless. + /// + /// The MLAgentsWorld to register + /// The name of the world. This is useful for identification + /// and for training. 
+ /// The Neural Network model used by the processor + /// The inference device specifying where to run inference + /// (CPU or GPU) public static void RegisterWorldWithBarracudaModelForceNoCommunication( this MLAgentsWorld world, string policyId, @@ -49,6 +85,7 @@ namespace Unity.AI.MLAgents } } } + internal unsafe class BarracudaWorldProcessor : IWorldProcessor { MLAgentsWorld world; @@ -78,6 +115,7 @@ namespace Unity.AI.MLAgents public void ProcessWorld() { + // TODO : Cover all cases // FOR VECTOR OBS ONLY // For Continuous control only // No LSTM diff --git a/Runtime/WorldProcessor/HeuristicWorldProcessor.cs b/Runtime/WorldProcessor/HeuristicWorldProcessor.cs index bcc7a79..babbeb3 100644 --- a/Runtime/WorldProcessor/HeuristicWorldProcessor.cs +++ b/Runtime/WorldProcessor/HeuristicWorldProcessor.cs @@ -7,6 +7,19 @@ namespace Unity.AI.MLAgents { public static class HeristicWorldProcessorRegistringExtension { + /// + /// Registers the given MLAgentsWorld to the Academy with a Heuristic. + /// Note that if the simulation connects to Python, the Heuristic will + /// be ignored and the world will exchange data with Python instead. + /// The Heuristic is a Function that returns an action struct. + /// + /// The MLAgentsWorld to register + /// The name of the world. This is useful for identification + /// and for training. + /// The Heuristic used to generate the actions. + /// Note that all agents in the world will receive the same action. + /// The type of the Action struct. It must match the Action + /// Size and Action Type of the world. public static void RegisterWorldWithHeuristic( this MLAgentsWorld world, string policyId, @@ -17,6 +30,19 @@ namespace Unity.AI.MLAgents Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true); } + /// + /// Registers the given MLAgentsWorld to the Academy with a Heuristic. + /// Note that if the simulation connects to Python, the world will not + /// exchange data with Python and use the Heuristic regardless. 
+ /// The Heuristic is a Function that returns an action struct. + /// + /// The MLAgentsWorld to register + /// The name of the world. This is useful for identification + /// and for training. + /// The Heuristic used to generate the actions. + /// Note that all agents in the world will receive the same action. + /// The type of the Action struct. It must match the Action + /// Size and Action Type of the world. public static void RegisterWorldWithHeuristicForceNoCommunication( this MLAgentsWorld world, string policyId, diff --git a/Runtime/WorldProcessor/IWorldProcessor.cs b/Runtime/WorldProcessor/IWorldProcessor.cs index ae4cf57..53f3fa0 100644 --- a/Runtime/WorldProcessor/IWorldProcessor.cs +++ b/Runtime/WorldProcessor/IWorldProcessor.cs @@ -5,9 +5,22 @@ using Unity.Collections.LowLevel.Unsafe; namespace Unity.AI.MLAgents { + /// + /// The interface for a world processor. A world processor updates the + /// action data of a world using the observation data present in it. + /// public interface IWorldProcessor : IDisposable { + /// + /// True if the World Processor is connected to the Python process + /// bool IsConnected {get;} + + /// + /// This method is called once every time the world needs to update its action + /// data. The implementation of this method must give new action data to all the + /// agents that requested a decision. + /// void ProcessWorld(); } }