This commit is contained in:
vincentpierre 2020-04-15 13:15:09 -07:00
Родитель f26b31e9a6
Коммит d891643238
14 изменённых файлов: 301 добавлений и 77 удалений

Просмотреть файл

@ -12,6 +12,12 @@ using UnityEngine;
namespace Unity.AI.MLAgents
{
/// <summary>
/// The Academy is a singleton that orchestrates the decision making of the
/// Agents.
/// It is used to register WorldProcessors to Worlds and to keep track of the
/// reset logic of the simulation.
/// </summary>
public class Academy : IDisposable
{
#region Singleton
@ -64,7 +70,7 @@ namespace Unity.AI.MLAgents
/// <param name="policyId"> The string identifier of the MLAgentsWorld. There can only be one world per unique id.</param>
/// <param name="world"> The MLAgentsWorld that is being subscribed.</param>
/// <param name="worldProcessor"> If the remote process is not available, the MLAgentsWorld will use this World processor for decision making.</param>
/// <param name="defaultRemote"> If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallbackWorldProcessor otherwise.</param>
/// <param name="defaultRemote"> If true, the MLAgentsWorld will default to using the remote process for decision making and use the fallback worldProcessor otherwise.</param>
public void RegisterWorld(string policyId, MLAgentsWorld world, IWorldProcessor worldProcessor = null, bool defaultRemote = true)
{
// Need to find a way to deregister ?

Просмотреть файл

@ -9,13 +9,15 @@ namespace Unity.AI.MLAgents
internal static class ActionHashMapUtils
{
/// <summary>
/// Retrieves the action data for a world in puts it into a HashMap
/// Retrieves the action data for a world and puts it into a HashMap.
/// This action deletes the action data from the world.
/// </summary>
/// <param name="world"> The MLAgentsWorld the data will be retrieved from.</param>
/// <param name="allocator"> The memory allocator of the created NativeHashMap.</param>
/// <typeparam name="T"> The type of the Action struct. It must match the Action Size and Action Type of the world.</typeparam>
/// <returns> A NativeHashMap from Entities to Action.</returns>
public static NativeHashMap<Entity, T> GetActionHashMap<T>(this MLAgentsWorld world, Allocator allocator) where T : struct
/// <typeparam name="T"> The type of the Action struct. It must match the Action Size
/// and Action Type of the world.</typeparam>
/// <returns> A NativeHashMap from Entities to Actions with type T.</returns>
public static NativeHashMap<Entity, T> GenerateActionHashMap<T>(this MLAgentsWorld world, Allocator allocator) where T : struct
{
#if ENABLE_UNITY_COLLECTIONS_CHECKS
if (world.ActionSize != UnsafeUtility.SizeOf<T>() / 4)

25
Runtime/ActionType.cs Normal file
Просмотреть файл

@ -0,0 +1,25 @@
namespace Unity.AI.MLAgents
{
/// <summary>
/// Indicates the action space type of an Agent.
/// </summary>
public enum ActionType : int
{
/// <summary>
/// The Action is expected to be a struct of ints or int enums.
/// More precisely, the struct must be convertible to an array of ints by
/// memory copy.
/// The number of ints in the struct corresponds to the number of actions,
/// while the range of values each action can take is given by the "Branch" size.
/// </summary>
DISCRETE = 0,
/// <summary>
/// The Action is expected to be a struct with only float members.
/// More precisely, the struct must be convertible to an array of floats by
/// memory copy.
/// </summary>
CONTINUOUS = 1,
}
}

Просмотреть файл

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: a82fe272987e4492cb74d2d80dfb1b4a
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

Просмотреть файл

@ -93,7 +93,6 @@ namespace Unity.AI.MLAgents
internal static unsafe JobHandle ScheduleImpl<T>(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps)
where T : struct, IActuatorJob
{
// inputDeps.Complete();
// Creating a data struct that contains the data the user passed into the job (This is what T is here)
var data = new ActionEventJobData<T>
{
@ -124,7 +123,6 @@ namespace Unity.AI.MLAgents
jobReflectionData = JobsUtility.CreateJobReflectionData(typeof(ActionEventJobData<T>), typeof(T), JobType.Single, (ExecuteJobFunction)Execute);
return jobReflectionData;
}
#endregion
@ -135,6 +133,7 @@ namespace Unity.AI.MLAgents
{
int size = jobData.world.ActionSize;
int actionCount = jobData.world.ActionCounter.Count;
// Continuous case
if (jobData.world.ActionType == ActionType.CONTINUOUS)
{
@ -149,6 +148,7 @@ namespace Unity.AI.MLAgents
});
}
}
// Discrete Case
else
{

Просмотреть файл

@ -56,53 +56,3 @@ namespace Unity.AI.MLAgents
}
}
}
/*
/// <summary>
/// Computes the total number of bytes necessary in the shared memory file to
/// send and receive data for a specific MLAgentsWorld.
/// </summary>
public static int GetRequiredCapacity(MLAgentsWorld world)
{
// # int : 4 bytes : maximum number of Agents
// # bool : 1 byte : is action discrete (False) or continuous (True)
// # int : 4 bytes : action space size (continuous) / number of branches (discrete)
// # -- If discrete only : array of action sizes for each branch
// # int : 4 bytes : number of observations
// # For each observation :
// # 3 int : shape (the shape of the tensor observation for one agent
// # start of the section that will change every step
// # 4 bytes : n_agents at current step
// # ? Bytes : the data : obs,reward,done,max_step,agent_id,masks,action
int capacity = 64; // Name
capacity += 4; // N Max Agents
capacity += 1; // discrete or continuous
capacity += 4; // action Size
if (world.ActionType == ActionType.DISCRETE)
{
capacity += 4 * world.ActionSize; // The action branches
}
capacity += 4; // number of observations
capacity += 3 * 4 * world.SensorShapes.Length; // The observation shapes
capacity += 4; // Number of agents for the current step
var nAgents = world.Rewards.Length;
capacity += 4 * world.Sensors.Length;
capacity += 4 * nAgents;
capacity += nAgents;
capacity += nAgents;
capacity += 4 * nAgents;
if (world.ActionType == ActionType.DISCRETE)
{
foreach (int branch_size in world.DiscreteActionBranches)
{
capacity += branch_size * nAgents;
}
}
capacity += 4 * world.ActionSize * nAgents;
return capacity;
}
*/

Просмотреть файл

@ -2,11 +2,16 @@ using System;
namespace Unity.AI.MLAgents
{
/// <summary>
/// Contains exceptions specific to ML-Agents.
internal class MLAgentsException : Exception
/// </summary>
public class MLAgentsException : Exception
{
public MLAgentsException(string message) : base(message)
{
}
/// <summary>
/// This exception indicates that an error occurred in the ML-Agents package.
/// </summary>
/// <param name="message">Text message for the error</param>
/// <returns></returns>
public MLAgentsException(string message) : base(message){}
}
}

Просмотреть файл

@ -6,16 +6,6 @@ using System;
namespace Unity.AI.MLAgents
{
/// <summary>
/// Indicates the action space type of an Agent action.
/// If CONTINUOUS, the Action is expected to be a struct containing only floats
/// If DISCRETE, the Action is expected to be a struct of int or int enums.
/// </summary>
public enum ActionType : int
{
DISCRETE = 0,
CONTINUOUS = 1,
}
/// <summary>
/// MLAgentsWorld is a data container on which the user requests decisions.

Просмотреть файл

@ -8,7 +8,12 @@ namespace Unity.AI.MLAgents
private const string k_MemoryFileArgument = "--memory-path";
private const string k_MemoryFilesDirectory = "ml-agents";
private const string k_DefaultFileName = "default";
// Used to read Python-provided environment parameters
/// <summary>
/// Used to read Python-provided environment parameters. Will return the
/// string corresponding to the path to the shared memory file.
/// </summary>
/// <returns> Path to the shared memory file or null if no file is available</returns>
public static string ReadSharedMemoryPathFromArgs()
{
var args = System.Environment.GetCommandLineArgs();
@ -20,7 +25,7 @@ namespace Unity.AI.MLAgents
}
}
#if UNITY_EDITOR
// Try connecting on the default editor port
// Try connecting on the default shared memory file
var path = Path.Combine(
Path.GetTempPath(),
k_MemoryFilesDirectory,

Просмотреть файл

@ -14,6 +14,16 @@ namespace Unity.AI.MLAgents
private IntPtr accessorPointer;
private string filePath;
/// <summary>
/// A bare bone shared memory wrapper that opens or creates a new shared
/// memory file and connects to it.
/// </summary>
/// <param name="fileName"> The name of the file. The file must be present
/// in the Temporary folder (depends on OS) in the directory <see cref="k_Directory"/></param>
/// <param name="createFile"> If true, the file will be created by the constructor.
/// An error will be thrown if the file already exists.</param>
/// <param name="size"> The size of the file to be created (will be ignored if
/// <paramref name="createFile"/> is false).</param>
public BaseSharedMem(string fileName, bool createFile, int size = 0)
{
var directoryPath = Path.Combine(Path.GetTempPath(), k_Directory);
@ -50,6 +60,12 @@ namespace Unity.AI.MLAgents
accessorPointer = new IntPtr(ptr);
}
/// <summary>
/// Returns the integer present at the specified offset. The <paramref name="offset"/> must be
/// passed by reference and will be incremented to the next value.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <returns></returns>
public int GetInt(ref int offset)
{
var result = accessor.ReadInt32(offset);
@ -57,6 +73,12 @@ namespace Unity.AI.MLAgents
return result;
}
/// <summary>
/// Returns the float present at the specified offset. The <paramref name="offset"/> must be
/// passed by reference and will be incremented to the next value.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <returns></returns>
public float GetFloat(ref int offset)
{
var result = accessor.ReadSingle(offset);
@ -64,6 +86,12 @@ namespace Unity.AI.MLAgents
return result;
}
/// <summary>
/// Returns the boolean present at the specified offset. The <paramref name="offset"/> must be
/// passed by reference and will be incremented to the next value.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <returns></returns>
public bool GetBool(ref int offset)
{
var result = accessor.ReadBoolean(offset);
@ -71,6 +99,15 @@ namespace Unity.AI.MLAgents
return result;
}
/// <summary>
/// Returns the string present at the specified offset. The <paramref name="offset"/> must be
/// passed by reference and will be incremented to the next value.
/// Note, the strings must be represented with: an sbyte (8-bit signed)
/// indicating the size of the string in bytes followed by the bytes of the
/// string (in ASCII format).
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <returns></returns>
public string GetString(ref int offset)
{
var length = accessor.ReadSByte(offset);
@ -81,6 +118,15 @@ namespace Unity.AI.MLAgents
return Encoding.ASCII.GetString(bytes);
}
/// <summary>
/// Copies the values in the memory file at offset to the NativeArray <paramref name="array"/>.
/// The <paramref name="offset"/> must be passed by reference and will be incremented to the next value.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <param name="array"> The destination array</param>
/// <param name="length"> The number of bytes that must be copied, must be less or equal
/// to the capacity of the array.</param>
/// <typeparam name="T"> The type of the NativeArray. Must be a struct.</typeparam>
public void GetArray<T>(ref int offset, NativeArray<T> array, int length) where T : struct
{
IntPtr src = IntPtr.Add(accessorPointer, offset);
@ -89,6 +135,13 @@ namespace Unity.AI.MLAgents
offset += length;
}
/// <summary>
/// Returns the byte array present at the specified offset. The <paramref name="offset"/> must be
/// passed by reference and will be incremented to the next value.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <param name="length"> The size of the array in bytes</param>
/// <returns></returns>
public byte[] GetBytes(ref int offset, int length)
{
var result = new byte[length];
@ -97,21 +150,44 @@ namespace Unity.AI.MLAgents
return result;
}
/// <summary>
/// Returns the integer present at the specified offset.
/// </summary>
/// <param name="offset"> Where to read the value.</param>
/// <returns></returns>
public int GetInt(int offset)
{
    // The offset is taken by value, so this overload reads without advancing it.
    int value = accessor.ReadInt32(offset);
    return value;
}
/// <summary>
/// Returns the float present at the specified offset.
/// </summary>
/// <param name="offset"> Where to read the value.</param>
/// <returns></returns>
public float GetFloat(int offset)
{
    // The offset is taken by value, so this overload reads without advancing it.
    float value = accessor.ReadSingle(offset);
    return value;
}
/// <summary>
/// Returns the boolean present at the specified offset.
/// </summary>
/// <param name="offset"> Where to read the value.</param>
/// <returns></returns>
public bool GetBool(int offset)
{
    // The offset is taken by value, so this overload reads without advancing it.
    bool value = accessor.ReadBoolean(offset);
    return value;
}
/// <summary>
/// Returns the string present at the specified offset.
/// Note, the strings must be represented with: an sbyte (8-bit signed)
/// indicating the size of the string in bytes followed by the bytes of the
/// string (in ASCII format).
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <returns></returns>
public string GetString(int offset)
{
var length = accessor.ReadSByte(offset);
@ -121,6 +197,14 @@ namespace Unity.AI.MLAgents
return Encoding.ASCII.GetString(bytes);
}
/// <summary>
/// Copies the values in the memory file at offset to the NativeArray <paramref name="array"/>.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <param name="array"> The destination array</param>
/// <param name="length"> The number of bytes that must be copied, must be less or equal
/// to the capacity of the array.</param>
/// <typeparam name="T"> The type of the NativeArray. Must be a struct.</typeparam>
public void GetArray<T>(int offset, NativeArray<T> array, int length) where T : struct
{
IntPtr src = IntPtr.Add(accessorPointer, offset);
@ -128,6 +212,12 @@ namespace Unity.AI.MLAgents
Buffer.MemoryCopy(src.ToPointer(), dst.ToPointer(), length, length);
}
/// <summary>
/// Returns the byte array present at the specified offset.
/// </summary>
/// <param name="offset"> Where to read the value</param>
/// <param name="length"> The size of the array in bytes</param>
/// <returns></returns>
public byte[] GetBytes(int offset, int length)
{
var result = new byte[length];
@ -138,24 +228,50 @@ namespace Unity.AI.MLAgents
return result;
}
/// <summary>
/// Sets the integer at the specified offset in the shared memory.
/// </summary>
/// <param name="offset"> The position at which to write the value</param>
/// <param name="value"> The value to be written</param>
/// <returns> The offset right after the written value.</returns>
public int SetInt(int offset, int value)
{
    // Store the value, then report the position immediately after it.
    accessor.Write(offset, value);
    int nextOffset = offset + sizeof(int);
    return nextOffset;
}
/// <summary>
/// Sets the float at the specified offset in the shared memory.
/// </summary>
/// <param name="offset"> The position at which to write the value</param>
/// <param name="value"> The value to be written</param>
/// <returns> The offset right after the written value.</returns>
public int SetFloat(int offset, float value)
{
    // Store the value, then report the position immediately after it.
    accessor.Write(offset, value);
    int nextOffset = offset + sizeof(float);
    return nextOffset;
}
/// <summary>
/// Sets the boolean at the specified offset in the shared memory.
/// </summary>
/// <param name="offset"> The position at which to write the value</param>
/// <param name="value"> The value to be written</param>
/// <returns> The offset right after the written value.</returns>
public int SetBool(int offset, bool value)
{
    // Store the value, then report the position immediately after it.
    accessor.Write(offset, value);
    int nextOffset = offset + sizeof(bool);
    return nextOffset;
}
/// <summary>
/// Sets the string at the specified offset in the shared memory.
/// Note that the string is represented with an sbyte (8-bit signed) number
/// encoding the length of the string and the string bytes in ASCII.
/// </summary>
/// <param name="offset"> The position at which to write the value</param>
/// <param name="value"> The value to be written</param>
/// <returns> The offset right after the written value.</returns>
public int SetString(int offset, string value)
{
var bytes = Encoding.ASCII.GetBytes(value);
@ -167,20 +283,38 @@ namespace Unity.AI.MLAgents
return offset;
}
public int SetArray<T>(int offset, NativeArray<T> arr, int length) where T : struct
/// <summary>
/// Copies the values present in a NativeArray to the shared memory file.
/// </summary>
/// <param name="offset"> The position at which to write the value</param>
/// <param name="array"> The NativeArray containing the data to write</param>
/// <param name="length"> The size of the array in bytes</param>
/// <typeparam name="T"> The type of the NativeArray. Must be a struct.</typeparam>
/// <returns> The offset right after the written value.</returns>
public int SetArray<T>(int offset, NativeArray<T> array, int length) where T : struct
{
IntPtr dst = IntPtr.Add(accessorPointer, offset);
IntPtr src = new IntPtr(arr.GetUnsafePtr());
IntPtr src = new IntPtr(array.GetUnsafePtr());
Buffer.MemoryCopy(src.ToPointer(), dst.ToPointer(), length, length);
return offset + length;
}
/// <summary>
/// Sets the byte array at the specified offset in the shared memory.
/// </summary>
/// <param name="offset"> The position at which to write the value</param>
/// <param name="value"> The value to be written</param>
/// <returns> The offset right after the written value.</returns>
public int SetBytes(int offset, byte[] value)
{
    // Copy the whole buffer, then report the position immediately after it.
    int byteCount = value.Length;
    accessor.WriteArray(offset, value, 0, byteCount);
    return offset + byteCount;
}
/// <summary>
/// Closes the shared memory reader and writer. This will not delete the file but
/// will remove access to it.
/// </summary>
public void Close()
{
if (accessor.CanWrite)
@ -190,6 +324,11 @@ namespace Unity.AI.MLAgents
}
}
/// <summary>
/// Indicates whether the <see cref="BaseSharedMem"/> has write access to the file.
/// If it does not, it means the accessor was closed or the file was deleted.
/// </summary>
/// <value></value>
protected bool CanEdit
{
get
@ -198,6 +337,9 @@ namespace Unity.AI.MLAgents
}
}
/// <summary>
/// Closes and deletes the shared memory file
/// </summary>
public void Delete()
{
Close();

Просмотреть файл

@ -13,6 +13,11 @@ namespace Unity.AI.MLAgents
None
}
/// <summary>
/// An editor friendly constructor for a MLAgentsWorld.
/// Keeps track of the behavior specs of a world, its name,
/// its processor type and Neural Network Model.
/// </summary>
[Serializable]
public struct MLAgentsWorldSpecs
{
@ -31,6 +36,12 @@ namespace Unity.AI.MLAgents
private MLAgentsWorld m_World;
/// <summary>
/// Generates the world using the specified specs and registers its
/// processor to the Academy. The world is only created and registered once;
/// if subsequent calls are made, the created world will be returned.
/// </summary>
/// <returns></returns>
public MLAgentsWorld GetWorld()
{
if (m_World.IsCreated)

Просмотреть файл

@ -5,14 +5,37 @@ using Barracuda;
namespace Unity.AI.MLAgents
{
/// <summary>
/// Where to perform inference.
/// </summary>
public enum InferenceDevice
{
CPU,
GPU
/// <summary>
/// CPU inference
/// </summary>
CPU = 0,
/// <summary>
/// GPU inference
/// </summary>
GPU = 1
}
public static class BarracudaWorldProcessorRegistringExtension
{
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Neural
/// Network Model. If the input model is null, a default inactive
/// processor will be registered instead. Note that if the simulation
/// connects to Python, the Neural Network will be ignored and the world
/// will exchange data with Python instead.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="model"> The Neural Network model used by the processor</param>
/// <param name="inferenceDevice"> The inference device specifying where to run inference
/// (CPU or GPU)</param>
public static void RegisterWorldWithBarracudaModel(
this MLAgentsWorld world,
string policyId,
@ -31,6 +54,19 @@ namespace Unity.AI.MLAgents
}
}
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Neural
/// Network Model. If the input model is null, a default inactive
/// processor will be registered instead. Note that even if the simulation
/// connects to Python, the world will not communicate with Python and will
/// run the given Neural Network regardless.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="model"> The Neural Network model used by the processor</param>
/// <param name="inferenceDevice"> The inference device specifying where to run inference
/// (CPU or GPU)</param>
public static void RegisterWorldWithBarracudaModelForceNoCommunication(
this MLAgentsWorld world,
string policyId,
@ -49,6 +85,7 @@ namespace Unity.AI.MLAgents
}
}
}
internal unsafe class BarracudaWorldProcessor : IWorldProcessor
{
MLAgentsWorld world;
@ -78,6 +115,7 @@ namespace Unity.AI.MLAgents
public void ProcessWorld()
{
// TODO : Cover all cases
// FOR VECTOR OBS ONLY
// For Continuous control only
// No LSTM

Просмотреть файл

@ -7,6 +7,19 @@ namespace Unity.AI.MLAgents
{
public static class HeristicWorldProcessorRegistringExtension
{
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the Heuristic will
/// be ignored and the world will exchange data with Python instead.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the world will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the world.</typeparam>
public static void RegisterWorldWithHeuristic<TH>(
this MLAgentsWorld world,
string policyId,
@ -17,6 +30,19 @@ namespace Unity.AI.MLAgents
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true);
}
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
/// Note that even if the simulation connects to Python, the world will not
/// exchange data with Python and will use the Heuristic regardless.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the world will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the world.</typeparam>
public static void RegisterWorldWithHeuristicForceNoCommunication<TH>(
this MLAgentsWorld world,
string policyId,

Просмотреть файл

@ -5,9 +5,22 @@ using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
/// <summary>
/// The interface for a world processor. A world processor updates the
/// action data of a world using the observation data present in it.
/// </summary>
public interface IWorldProcessor : IDisposable
{
/// <summary>
/// True if the World Processor is connected to the Python process.
/// </summary>
bool IsConnected {get;}
/// <summary>
/// This method is called once every time the world needs to update its action
/// data. The implementation of this method must give new action data to all the
/// agents that requested a decision.
/// </summary>
void ProcessWorld();
}
}