More comments
This commit is contained in:
Родитель
f26b31e9a6
Коммит
d891643238
|
@ -12,6 +12,12 @@ using UnityEngine;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// The Academy is a singleton that orchestrates the
|
||||
/// decision making of the Agents.
|
||||
/// It is used to register WorldProcessors to Worlds and to keep track of the
|
||||
/// reset logic of the simulation.
|
||||
/// </summary>
|
||||
public class Academy : IDisposable
|
||||
{
|
||||
#region Singleton
|
||||
|
@ -64,7 +70,7 @@ namespace Unity.AI.MLAgents
|
|||
/// <param name="policyId"> The string identifier of the MLAgentsWorld. There can only be one world per unique id.</param>
|
||||
/// <param name="world"> The MLAgentsWorld that is being subscribed.</param>
|
||||
/// <param name="worldProcessor"> If the remote process is not available, the MLAgentsWorld will use this World processor for decision making.</param>
|
||||
/// <param name="defaultRemote"> If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallbackWorldProcessor otherwise.</param>
|
||||
/// <param name="defaultRemote"> If true, the MLAgentsWorld will default to using the remote process for decision making and use the fallback worldProcessor otherwise.</param>
|
||||
public void RegisterWorld(string policyId, MLAgentsWorld world, IWorldProcessor worldProcessor = null, bool defaultRemote = true)
|
||||
{
|
||||
// Need to find a way to deregister ?
|
||||
|
|
|
@ -9,13 +9,15 @@ namespace Unity.AI.MLAgents
|
|||
internal static class ActionHashMapUtils
|
||||
{
|
||||
/// <summary>
|
||||
/// Retrieves the action data for a world in puts it into a HashMap
|
||||
/// Retrieves the action data for a world and puts it into a HashMap.
|
||||
/// This action deletes the action data from the world.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld the data will be retrieved from.</param>
|
||||
/// <param name="allocator"> The memory allocator of the created NativeHashMap.</param>
|
||||
/// <typeparam name="T"> The type of the Action struct. It must match the Action Size and Action Type of the world.</typeparam>
|
||||
/// <returns> A NativeHashMap from Entities to Action.</returns>
|
||||
public static NativeHashMap<Entity, T> GetActionHashMap<T>(this MLAgentsWorld world, Allocator allocator) where T : struct
|
||||
/// <typeparam name="T"> The type of the Action struct. It must match the Action Size
|
||||
/// and Action Type of the world.</typeparam>
|
||||
/// <returns> A NativeHashMap from Entities to Actions with type T.</returns>
|
||||
public static NativeHashMap<Entity, T> GenerateActionHashMap<T>(this MLAgentsWorld world, Allocator allocator) where T : struct
|
||||
{
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
if (world.ActionSize != UnsafeUtility.SizeOf<T>() / 4)
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// Indicates the action space type of an Agent.
|
||||
/// </summary>
|
||||
public enum ActionType : int
|
||||
{
|
||||
|
||||
/// <summary>
|
||||
/// The Action is expected to be a struct of int or int enums.
|
||||
/// More precisely, the struct must be convertible to an array of ints by
|
||||
/// memory copy.
|
||||
/// The number of ints in the struct corresponds to the number of actions,
|
||||
/// but the range of values each action can take is given by the "Branch" size.
|
||||
/// </summary>
|
||||
DISCRETE = 0,
|
||||
|
||||
/// <summary>
|
||||
/// The Action is expected to be a struct with only float members.
|
||||
/// More precisely, the struct must be convertible to an array of floats by
|
||||
/// memory copy.
|
||||
/// </summary>
|
||||
CONTINUOUS = 1,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
fileFormatVersion: 2
|
||||
guid: a82fe272987e4492cb74d2d80dfb1b4a
|
||||
MonoImporter:
|
||||
externalObjects: {}
|
||||
serializedVersion: 2
|
||||
defaultReferences: []
|
||||
executionOrder: 0
|
||||
icon: {instanceID: 0}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -93,7 +93,6 @@ namespace Unity.AI.MLAgents
|
|||
internal static unsafe JobHandle ScheduleImpl<T>(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps)
|
||||
where T : struct, IActuatorJob
|
||||
{
|
||||
// inputDeps.Complete();
|
||||
// Creating a data struct that contains the data the user passed into the job (This is what T is here)
|
||||
var data = new ActionEventJobData<T>
|
||||
{
|
||||
|
@ -124,7 +123,6 @@ namespace Unity.AI.MLAgents
|
|||
jobReflectionData = JobsUtility.CreateJobReflectionData(typeof(ActionEventJobData<T>), typeof(T), JobType.Single, (ExecuteJobFunction)Execute);
|
||||
return jobReflectionData;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
|
||||
|
@ -135,6 +133,7 @@ namespace Unity.AI.MLAgents
|
|||
{
|
||||
int size = jobData.world.ActionSize;
|
||||
int actionCount = jobData.world.ActionCounter.Count;
|
||||
|
||||
// Continuous case
|
||||
if (jobData.world.ActionType == ActionType.CONTINUOUS)
|
||||
{
|
||||
|
@ -149,6 +148,7 @@ namespace Unity.AI.MLAgents
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Discrete Case
|
||||
else
|
||||
{
|
||||
|
|
|
@ -56,53 +56,3 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
|
||||
/// <summary>
|
||||
/// Computes the total number of bytes necessary in the shared memory file to
|
||||
/// send and receive data for a specific MLAgentsWorld.
|
||||
/// </summary>
|
||||
public static int GetRequiredCapacity(MLAgentsWorld world)
|
||||
{
|
||||
// # int : 4 bytes : maximum number of Agents
|
||||
// # bool : 1 byte : is action discrete (False) or continuous (True)
|
||||
// # int : 4 bytes : action space size (continuous) / number of branches (discrete)
|
||||
// # -- If discrete only : array of action sizes for each branch
|
||||
// # int : 4 bytes : number of observations
|
||||
// # For each observation :
|
||||
// # 3 int : shape (the shape of the tensor observation for one agent
|
||||
// # start of the section that will change every step
|
||||
// # 4 bytes : n_agents at current step
|
||||
// # ? Bytes : the data : obs,reward,done,max_step,agent_id,masks,action
|
||||
int capacity = 64; // Name
|
||||
capacity += 4; // N Max Agents
|
||||
capacity += 1; // discrete or continuous
|
||||
capacity += 4; // action Size
|
||||
if (world.ActionType == ActionType.DISCRETE)
|
||||
{
|
||||
capacity += 4 * world.ActionSize; // The action branches
|
||||
}
|
||||
capacity += 4; // number of observations
|
||||
capacity += 3 * 4 * world.SensorShapes.Length; // The observation shapes
|
||||
capacity += 4; // Number of agents for the current step
|
||||
|
||||
var nAgents = world.Rewards.Length;
|
||||
capacity += 4 * world.Sensors.Length;
|
||||
capacity += 4 * nAgents;
|
||||
capacity += nAgents;
|
||||
capacity += nAgents;
|
||||
capacity += 4 * nAgents;
|
||||
if (world.ActionType == ActionType.DISCRETE)
|
||||
{
|
||||
foreach (int branch_size in world.DiscreteActionBranches)
|
||||
{
|
||||
capacity += branch_size * nAgents;
|
||||
}
|
||||
}
|
||||
capacity += 4 * world.ActionSize * nAgents;
|
||||
|
||||
return capacity;
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -2,11 +2,16 @@ using System;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// Contains exceptions specific to ML-Agents.
|
||||
internal class MLAgentsException : Exception
|
||||
/// </summary>
|
||||
public class MLAgentsException : Exception
|
||||
{
|
||||
public MLAgentsException(string message) : base(message)
|
||||
{
|
||||
}
|
||||
/// <summary>
|
||||
/// This exception indicates an error occurred in the ML-Agents package
|
||||
/// </summary>
|
||||
/// <param name="message">Text message for the error</param>
|
||||
/// <returns></returns>
|
||||
public MLAgentsException(string message) : base(message){}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,16 +6,6 @@ using System;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// Indicates the action space type of an Agent action.
|
||||
/// If CONTINUOUS, the Action is expected to be a struct containing only floats
|
||||
/// If DISCRETE, the Action is expected to be a struct of int or int enums.
|
||||
/// </summary>
|
||||
public enum ActionType : int
|
||||
{
|
||||
DISCRETE = 0,
|
||||
CONTINUOUS = 1,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// MLAgentsWorld is a data container on which the user requests decisions.
|
||||
|
|
|
@ -8,7 +8,12 @@ namespace Unity.AI.MLAgents
|
|||
private const string k_MemoryFileArgument = "--memory-path";
|
||||
private const string k_MemoryFilesDirectory = "ml-agents";
|
||||
private const string k_DefaultFileName = "default";
|
||||
// Used to read Python-provided environment parameters
|
||||
|
||||
/// <summary>
|
||||
/// Used to read Python-provided environment parameters. Will return the
|
||||
/// string corresponding to the path to the shared memory file.
|
||||
/// </summary>
|
||||
/// <returns> Path to the shared memory file or null if no file is available</returns>
|
||||
public static string ReadSharedMemoryPathFromArgs()
|
||||
{
|
||||
var args = System.Environment.GetCommandLineArgs();
|
||||
|
@ -20,7 +25,7 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
#if UNITY_EDITOR
|
||||
// Try connecting on the default editor port
|
||||
// Try connecting on the default shared memory file
|
||||
var path = Path.Combine(
|
||||
Path.GetTempPath(),
|
||||
k_MemoryFilesDirectory,
|
||||
|
|
|
@ -14,6 +14,16 @@ namespace Unity.AI.MLAgents
|
|||
private IntPtr accessorPointer;
|
||||
private string filePath;
|
||||
|
||||
/// <summary>
|
||||
/// A bare bone shared memory wrapper that opens or creates a new shared
|
||||
/// memory file and connects to it.
|
||||
/// </summary>
|
||||
/// <param name="fileName"> The name of the file. The file must be present
|
||||
/// in the Temporary folder (depends on OS) in the directory <see cref="k_Directory"/></param>
|
||||
/// <param name="createFile"> If true, the file will be created by the constructor.
|
||||
/// An error will be thrown if the file already exists.</param>
|
||||
/// <param name="size"> The size of the file to be created (will be ignored if
|
||||
/// <see cref="createFile"/> is false).</param>
|
||||
public BaseSharedMem(string fileName, bool createFile, int size = 0)
|
||||
{
|
||||
var directoryPath = Path.Combine(Path.GetTempPath(), k_Directory);
|
||||
|
@ -50,6 +60,12 @@ namespace Unity.AI.MLAgents
|
|||
accessorPointer = new IntPtr(ptr);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the integer present at the specified offset. The <see cref="offset"/> must be
|
||||
/// passed by reference and will be incremented to the next value.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <returns></returns>
|
||||
public int GetInt(ref int offset)
|
||||
{
|
||||
var result = accessor.ReadInt32(offset);
|
||||
|
@ -57,6 +73,12 @@ namespace Unity.AI.MLAgents
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the float present at the specified offset. The <see cref="offset"/> must be
|
||||
/// passed by reference and will be incremented to the next value.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <returns></returns>
|
||||
public float GetFloat(ref int offset)
|
||||
{
|
||||
var result = accessor.ReadSingle(offset);
|
||||
|
@ -64,6 +86,12 @@ namespace Unity.AI.MLAgents
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the boolean present at the specified offset. The <see cref="offset"/> must be
|
||||
/// passed by reference and will be incremented to the next value.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <returns></returns>
|
||||
public bool GetBool(ref int offset)
|
||||
{
|
||||
var result = accessor.ReadBoolean(offset);
|
||||
|
@ -71,6 +99,15 @@ namespace Unity.AI.MLAgents
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the string present at the specified offset. The <see cref="offset"/> must be
|
||||
/// passed by reference and will be incremented to the next value.
|
||||
/// Note, the strings must be represented with: an sbyte (8-bit signed)
|
||||
/// indicating the size of the string in bytes followed by the bytes of the
|
||||
/// string (in ASCII format).
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <returns></returns>
|
||||
public string GetString(ref int offset)
|
||||
{
|
||||
var length = accessor.ReadSByte(offset);
|
||||
|
@ -81,6 +118,15 @@ namespace Unity.AI.MLAgents
|
|||
return Encoding.ASCII.GetString(bytes);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Copies the values in the memory file at offset to the NativeArray <see cref="array"/>.
|
||||
/// The <see cref="offset"/> must be passed by reference and will be incremented to the next value.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <param name="array"> The destination array</param>
|
||||
/// <param name="length"> The number of bytes that must be copied, must be less or equal
|
||||
/// to the capacity of the array.</param>
|
||||
/// <typeparam name="T"> The type of the NativeArray. Must be a struct.</typeparam>
|
||||
public void GetArray<T>(ref int offset, NativeArray<T> array, int length) where T : struct
|
||||
{
|
||||
IntPtr src = IntPtr.Add(accessorPointer, offset);
|
||||
|
@ -89,6 +135,13 @@ namespace Unity.AI.MLAgents
|
|||
offset += length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the byte array present at the specified offset. The <see cref="offset"/> must be
|
||||
/// passed by reference and will be incremented to the next value.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <param name="length"> The size of the array in bytes</param>
|
||||
/// <returns></returns>
|
||||
public byte[] GetBytes(ref int offset, int length)
|
||||
{
|
||||
var result = new byte[length];
|
||||
|
@ -97,21 +150,44 @@ namespace Unity.AI.MLAgents
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the integer present at the specified offset.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value.</param>
|
||||
/// <returns></returns>
|
||||
public int GetInt(int offset)
|
||||
{
|
||||
return accessor.ReadInt32(offset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the float present at the specified offset.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value.</param>
|
||||
/// <returns></returns>
|
||||
public float GetFloat(int offset)
|
||||
{
|
||||
return accessor.ReadSingle(offset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the boolean present at the specified offset.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value.</param>
|
||||
/// <returns></returns>
|
||||
public bool GetBool(int offset)
|
||||
{
|
||||
return accessor.ReadBoolean(offset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the string present at the specified offset.
|
||||
/// Note, the strings must be represented with: an sbyte (8-bit signed)
|
||||
/// indicating the size of the string in bytes followed by the bytes of the
|
||||
/// string (in ASCII format).
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <returns></returns>
|
||||
public string GetString(int offset)
|
||||
{
|
||||
var length = accessor.ReadSByte(offset);
|
||||
|
@ -121,6 +197,14 @@ namespace Unity.AI.MLAgents
|
|||
return Encoding.ASCII.GetString(bytes);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Copies the values in the memory file at offset to the NativeArray <see cref="array"/>.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <param name="array"> The destination array</param>
|
||||
/// <param name="length"> The number of bytes that must be copied, must be less or equal
|
||||
/// to the capacity of the array.</param>
|
||||
/// <typeparam name="T"> The type of the NativeArray. Must be a struct.</typeparam>
|
||||
public void GetArray<T>(int offset, NativeArray<T> array, int length) where T : struct
|
||||
{
|
||||
IntPtr src = IntPtr.Add(accessorPointer, offset);
|
||||
|
@ -128,6 +212,12 @@ namespace Unity.AI.MLAgents
|
|||
Buffer.MemoryCopy(src.ToPointer(), dst.ToPointer(), length, length);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the byte array present at the specified offset.
|
||||
/// </summary>
|
||||
/// <param name="offset"> Where to read the value</param>
|
||||
/// <param name="length"> The size of the array in bytes</param>
|
||||
/// <returns></returns>
|
||||
public byte[] GetBytes(int offset, int length)
|
||||
{
|
||||
var result = new byte[length];
|
||||
|
@ -138,24 +228,50 @@ namespace Unity.AI.MLAgents
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the integer at the specified offset in the shared memory.
|
||||
/// </summary>
|
||||
/// <param name="offset"> The position at which to write the value</param>
|
||||
/// <param name="value"> The value to be written</param>
|
||||
/// <returns> The offset right after the written value.</returns>
|
||||
public int SetInt(int offset, int value)
|
||||
{
|
||||
accessor.Write(offset, value);
|
||||
return offset + 4;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the float at the specified offset in the shared memory.
|
||||
/// </summary>
|
||||
/// <param name="offset"> The position at which to write the value</param>
|
||||
/// <param name="value"> The value to be written</param>
|
||||
/// <returns> The offset right after the written value.</returns>
|
||||
public int SetFloat(int offset, float value)
|
||||
{
|
||||
accessor.Write(offset, value);
|
||||
return offset + 4;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the boolean at the specified offset in the shared memory.
|
||||
/// </summary>
|
||||
/// <param name="offset"> The position at which to write the value</param>
|
||||
/// <param name="value"> The value to be written</param>
|
||||
/// <returns> The offset right after the written value.</returns>
|
||||
public int SetBool(int offset, bool value)
|
||||
{
|
||||
accessor.Write(offset, value);
|
||||
return offset + 1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the string at the specified offset in the shared memory.
|
||||
/// Note that the string is represented with an sbyte (8-bit signed) number
|
||||
/// encoding the length of the string and the string bytes in ASCII.
|
||||
/// </summary>
|
||||
/// <param name="offset"> The position at which to write the value</param>
|
||||
/// <param name="value"> The value to be written</param>
|
||||
/// <returns> The offset right after the written value.</returns>
|
||||
public int SetString(int offset, string value)
|
||||
{
|
||||
var bytes = Encoding.ASCII.GetBytes(value);
|
||||
|
@ -167,20 +283,38 @@ namespace Unity.AI.MLAgents
|
|||
return offset;
|
||||
}
|
||||
|
||||
public int SetArray<T>(int offset, NativeArray<T> arr, int length) where T : struct
|
||||
/// <summary>
|
||||
/// Copies the values present in a NativeArray to the shared memory file.
|
||||
/// </summary>
|
||||
/// <param name="offset"> The position at which to write the value</param>
|
||||
/// <param name="array"> The NativeArray containing the data to write</param>
|
||||
/// <param name="length"> The size of the array in bytes</param>
|
||||
/// <typeparam name="T"> The type of the NativeArray. Must be a struct.</typeparam>
|
||||
/// <returns> The offset right after the written value.</returns>
|
||||
public int SetArray<T>(int offset, NativeArray<T> array, int length) where T : struct
|
||||
{
|
||||
IntPtr dst = IntPtr.Add(accessorPointer, offset);
|
||||
IntPtr src = new IntPtr(arr.GetUnsafePtr());
|
||||
IntPtr src = new IntPtr(array.GetUnsafePtr());
|
||||
Buffer.MemoryCopy(src.ToPointer(), dst.ToPointer(), length, length);
|
||||
return offset + length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the byte array at the specified offset in the shared memory.
|
||||
/// </summary>
|
||||
/// <param name="offset"> The position at which to write the value</param>
|
||||
/// <param name="value"> The value to be written</param>
|
||||
/// <returns> The offset right after the written value.</returns>
|
||||
public int SetBytes(int offset, byte[] value)
|
||||
{
|
||||
accessor.WriteArray(offset, value, 0, value.Length);
|
||||
return offset + value.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Closes the shared memory reader and writer. This will not delete the file but
|
||||
/// remove access to it.
|
||||
/// </summary>
|
||||
public void Close()
|
||||
{
|
||||
if (accessor.CanWrite)
|
||||
|
@ -190,6 +324,11 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Indicates whether the <see cref="BaseSharedMem"/> has write access to the file.
|
||||
/// If it does not, it means the accessor was closed or the file was deleted.
|
||||
/// </summary>
|
||||
/// <value></value>
|
||||
protected bool CanEdit
|
||||
{
|
||||
get
|
||||
|
@ -198,6 +337,9 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Closes and deletes the shared memory file
|
||||
/// </summary>
|
||||
public void Delete()
|
||||
{
|
||||
Close();
|
||||
|
|
|
@ -13,6 +13,11 @@ namespace Unity.AI.MLAgents
|
|||
None
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// An editor friendly constructor for a MLAgentsWorld.
|
||||
/// Keeps track of the behavior specs of a world, its name,
|
||||
/// its processor type and Neural Network Model.
|
||||
/// </summary>
|
||||
[Serializable]
|
||||
public struct MLAgentsWorldSpecs
|
||||
{
|
||||
|
@ -31,6 +36,12 @@ namespace Unity.AI.MLAgents
|
|||
|
||||
private MLAgentsWorld m_World;
|
||||
|
||||
/// <summary>
|
||||
/// Generates the world using the specified specs and registers its
|
||||
/// processor to the Academy. The world is only created and registered once,
|
||||
/// if subsequent calls are made, the created world will be returned.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public MLAgentsWorld GetWorld()
|
||||
{
|
||||
if (m_World.IsCreated)
|
||||
|
|
|
@ -5,14 +5,37 @@ using Barracuda;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// Where to perform inference.
|
||||
/// </summary>
|
||||
public enum InferenceDevice
|
||||
{
|
||||
CPU,
|
||||
GPU
|
||||
/// <summary>
|
||||
/// CPU inference
|
||||
/// </summary>
|
||||
CPU = 0,
|
||||
|
||||
/// <summary>
|
||||
/// GPU inference
|
||||
/// </summary>
|
||||
GPU = 1
|
||||
}
|
||||
|
||||
public static class BarracudaWorldProcessorRegistringExtension
|
||||
{
|
||||
/// <summary>
|
||||
/// Registers the given MLAgentsWorld to the Academy with a Neural
|
||||
/// Network Model. If the input model is null, a default inactive
|
||||
/// processor will be registered instead. Note that if the simulation
|
||||
/// connects to Python, the Neural Network will be ignored and the world
|
||||
/// will exchange data with Python instead.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld to register</param>
|
||||
/// <param name="policyId"> The name of the world. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="model"> The Neural Network model used by the processor</param>
|
||||
/// <param name="inferenceDevice"> The inference device specifying where to run inference
|
||||
/// (CPU or GPU)</param>
|
||||
public static void RegisterWorldWithBarracudaModel(
|
||||
this MLAgentsWorld world,
|
||||
string policyId,
|
||||
|
@ -31,6 +54,19 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers the given MLAgentsWorld to the Academy with a Neural
|
||||
/// Network Model. If the input model is null, a default inactive
|
||||
/// processor will be registered instead. Note that if the simulation
|
||||
/// connects to Python, the world will not exchange data with Python and will run the
|
||||
/// given Neural Network regardless.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld to register</param>
|
||||
/// <param name="policyId"> The name of the world. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="model"> The Neural Network model used by the processor</param>
|
||||
/// <param name="inferenceDevice"> The inference device specifying where to run inference
|
||||
/// (CPU or GPU)</param>
|
||||
public static void RegisterWorldWithBarracudaModelForceNoCommunication(
|
||||
this MLAgentsWorld world,
|
||||
string policyId,
|
||||
|
@ -49,6 +85,7 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal unsafe class BarracudaWorldProcessor : IWorldProcessor
|
||||
{
|
||||
MLAgentsWorld world;
|
||||
|
@ -78,6 +115,7 @@ namespace Unity.AI.MLAgents
|
|||
|
||||
public void ProcessWorld()
|
||||
{
|
||||
// TODO : Cover all cases
|
||||
// FOR VECTOR OBS ONLY
|
||||
// For Continuous control only
|
||||
// No LSTM
|
||||
|
|
|
@ -7,6 +7,19 @@ namespace Unity.AI.MLAgents
|
|||
{
|
||||
public static class HeristicWorldProcessorRegistringExtension
|
||||
{
|
||||
/// <summary>
|
||||
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
|
||||
/// Note that if the simulation connects to Python, the Heuristic will
|
||||
/// be ignored and the world will exchange data with Python instead.
|
||||
/// The Heuristic is a Function that returns an action struct.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld to register</param>
|
||||
/// <param name="policyId"> The name of the world. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="heuristic"> The Heuristic used to generate the actions.
|
||||
/// Note that all agents in the world will receive the same action.</param>
|
||||
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
|
||||
/// Size and Action Type of the world.</typeparam>
|
||||
public static void RegisterWorldWithHeuristic<TH>(
|
||||
this MLAgentsWorld world,
|
||||
string policyId,
|
||||
|
@ -17,6 +30,19 @@ namespace Unity.AI.MLAgents
|
|||
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
|
||||
/// Note that if the simulation connects to Python, the world will not
|
||||
/// exchange data with Python and use the Heuristic regardless.
|
||||
/// The Heuristic is a Function that returns an action struct.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld to register</param>
|
||||
/// <param name="policyId"> The name of the world. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="heuristic"> The Heuristic used to generate the actions.
|
||||
/// Note that all agents in the world will receive the same action.</param>
|
||||
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
|
||||
/// Size and Action Type of the world.</typeparam>
|
||||
public static void RegisterWorldWithHeuristicForceNoCommunication<TH>(
|
||||
this MLAgentsWorld world,
|
||||
string policyId,
|
||||
|
|
|
@ -5,9 +5,22 @@ using Unity.Collections.LowLevel.Unsafe;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// The interface for a world processor. A world processor updates the
|
||||
/// action data of a world using the observation data present in it.
|
||||
/// </summary>
|
||||
public interface IWorldProcessor : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// True if the World Processor is connected to the Python process
|
||||
/// </summary>
|
||||
bool IsConnected {get;}
|
||||
|
||||
/// <summary>
|
||||
/// This method is called once every time the world needs to update its action
|
||||
/// data. The implementation of this method must give new action data to all the
|
||||
/// agents that requested a decision.
|
||||
/// </summary>
|
||||
void ProcessWorld();
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче