This commit is contained in:
vincentpierre 2020-06-08 09:48:17 -07:00
Parent 427aaecf9e
Commit 99c163dcd0
4 changed files: 116 additions and 122 deletions

5
.gitignore (vendored)

@@ -18,3 +18,8 @@
.vscode
setup.cfg
setup.cfg.meta
# API files
*.api
*.api.meta


@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Updated the shared memory communication to be more modular
- Added materials and textures to the 3DBall Sample
- Modified the API to reflect changes made to mlagents_envs version 0.16.0
- Renamed MLAgentsWorld to Policy
## [0.2.0-preview] - 2019-02-24


@@ -70,30 +70,30 @@ Please note that this package is available as a preview, so it is not ready for
## API
One approach to designing ml-agents to be compativle with DOTS would be to use typical API used for example in [Unity.Physics](https://github.com/Unity-Technologies/Unity.Physics) where a "MLAgents World" holds data, processes it and the data can then be retrieved.
The user would access the `MLAgentsWorld` in the main thread :
One approach to designing ml-agents to be compatible with DOTS would be to use typical API used for example in [Unity.Physics](https://github.com/Unity-Technologies/Unity.Physics) where a `Policy` holds data, processes it and the data can then be retrieved.
The user would access the `Policy` in the main thread :
```csharp
var world = new MLAgentsWorld(
var policy = new Policy(
100, // The maximum number of agents that can request a decision per step
new int3[] { new int3(3, 0, 0) }, // The observation shapes (here, one observation of shape (3,0,0))
ActionType.CONTINUOUS, // Continuous = float, Discrete = int
3); // The number of actions
world.SubscribeWorldWithBarracudaModel(Name, Model, InferenceDevice);
policy.SubscribePolicyWithBarracudaModel(Name, Model, InferenceDevice);
```
The user could then in his own jobs add and retrieve data from the world. Here is an example of a job in which the user populates the sensor data :
The user could then in his own jobs add and retrieve data from the `Policy`. Here is an example of a job in which the user populates the sensor data :
```csharp
public struct UserCreateSensingJob : IJobParallelFor
{
public NativeArray<Entity> entities;
public MLAgentsWorld world;
public Policy policy;
public void Execute(int i)
{
world.RequestDecision(entities[i])
policy.RequestDecision(entities[i])
.SetReward(1.0f)
.SetObservation(0, new float3(3.0f, 0, 0)); // observation index and then observation struct
@@ -107,7 +107,7 @@ The job would be called this way or use the `Entities.ForEach` API :
protected override JobHandle OnUpdate(JobHandle inputDeps)
{
var job = new MyPopulationJob{
world = myWorld,
policy = myPolicy,
entities = ...,
sensors = ...,
reward = ...,
@@ -120,7 +120,7 @@ Note that this API can also be called outside of a job and used in the main thre
```csharp
var visObs = VisualObservationUtility.GetVisObs(camera, 84, 84, Allocator.TempJob);
world.RequestDecision(entities[i])
policy.RequestDecision(entities[i])
.SetReward(1.0f)
.SetObservationFromSlice(1, visObs.Slice());
```
@@ -139,46 +139,12 @@ public struct UserCreatedActionEventJob : IActuatorJob
```
The ActuatorEvent data contains a key (here an entity) to identify the Agent and a `GetAction` method to retrieve the data in the event. This is very similar to how collisions are currently handled in the Physics package.
## UI to create MLAgentsWorld
## UI to create a Policy
We currently offer a `MLAgentsWorldSpecs` struct that has a custom inspector drawer (you can add it to a MonoBehaviour to edit the properties of your MLAgentsWorld and even add a neural network for its behavior).
To generate the MLAgentsWorld with the given settings call `MLAgentsWorldSpecs.GetWorld()`.
We currently offer a `PolicySpecs` struct that has a custom inspector drawer (you can add it to a MonoBehaviour to edit the properties of your Policy and even add a neural network for its behavior).
To generate the Policy with the given settings call `PolicySpecs.GetPolicy()`.
## Communication Between C# and Python
In order to exchange data with Python, we use shared memory. Python will create a small file that contains information required for starting the communication. The path to the file will be randomly generated and passed by Python to the Unity Executable as command line argument. For in editor training, a default file will be used. Using shared memory allows for faster data exchange and will remove the need to serialize the data to an intermediate format.
__Note__ : The python code for communication is located in [ml-agents-envs~](./ml-agents-envs~).
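The file-based handshake described above can be sketched on the Python side with the standard `tempfile` and `mmap` modules. This is a minimal illustration only, not the code in ml-agents-envs~: the helper name `create_shared_file`, the `mlagents_` prefix, the use of the system temp directory and the little-endian byte order are all assumptions made for the example.

```python
import mmap
import os
import tempfile
import uuid


def create_shared_file(size: int) -> str:
    """Create a zero-filled file at a randomly generated path (hypothetical helper)."""
    # The "mlagents_" prefix and the temp directory location are assumptions.
    path = os.path.join(tempfile.gettempdir(), "mlagents_" + uuid.uuid4().hex)
    with open(path, "wb") as f:
        f.write(b"\x00" * size)
    return path


# 22 is the initial file length stated in the header description.
path = create_shared_file(22)

with open(path, "r+b") as f:
    # Map the whole file; a second process would map the same path.
    with mmap.mmap(f.fileno(), 0) as shared:
        # Write and read back the first 4-byte int (the file length field).
        # Little-endian byte order is an assumption.
        shared[0:4] = (22).to_bytes(4, "little")
        file_length = int.from_bytes(shared[0:4], "little")

os.remove(path)  # clean up the example file
```

In a real setup the generated path would be handed to the Unity executable on its command line; the exact argument name is not given here, so it is left out of the sketch.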
### Shared memory layout
#### Header
- int : 4 bytes : File Length : Size of the file (will change as the file grows) (start at 22)
- int : 4 bytes : Version number : Unity and Python expecting the same memory layout
- bool : 1 byte : mutex : Is it Python or Unitys turn to edit the file (Unity blocked = True, Python Blocked = False) (start at `False`)
- ushort : 1 byte : Command : [step, reset, change file, close] (starts at `step`)
- step : DEFAULT : Nothing special
- reset : RESET : Only from Python to Unity to signal a reset
- change file : CHANGE_FILE : Can be sent by both C# and Python : Means the file is too short and needs to be changed. Both processes will switch to a new file (append `_` at the end of the old path) and delete the old one after reading the message. Note that to change file, the process must : Create the new file (with more capacity), copy the content of the file at appropriate location, add contexts to the file recompute the offsets to specific locations in the file, set the change file command, flip the mutex on the old file, use the new file and only flip the mutex when ready
- int : 4 bytes : The total amount of data in the side channel (starts at 4 for the next message length int)
- int : 4 bytes : The length of the side channel data in bytes for the current step
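As a rough illustration of the first four header fields listed above (file length, version, mutex, command), the following sketch packs and unpacks them with Python's `struct` module. It assumes little-endian byte order, no padding and a 1-byte command (the list gives the command as "ushort : 1 byte", so the width is taken from the byte count). The function names and the version value `1` are hypothetical; the command order follows the `RemoteCommand` values (DEFAULT = 0, RESET = 1, CHANGE_FILE = 2, CLOSE = 3).

```python
import struct

# Assumed encoding: little-endian, unpadded.
# i = 4-byte int, ? = 1-byte bool, B = 1-byte unsigned command code.
PREFIX_FORMAT = "<ii?B"
PREFIX_SIZE = struct.calcsize(PREFIX_FORMAT)  # 4 + 4 + 1 + 1 bytes

# Index in this tuple is the numeric command code.
COMMANDS = ("step", "reset", "change file", "close")


def pack_prefix(file_length: int, version: int, unity_blocked: bool, command: str) -> bytes:
    """Serialize the first four header fields (hypothetical helper)."""
    return struct.pack(
        PREFIX_FORMAT, file_length, version, unity_blocked, COMMANDS.index(command)
    )


def unpack_prefix(buf: bytes):
    """Read the first four header fields back from a buffer (hypothetical helper)."""
    file_length, version, unity_blocked, code = struct.unpack_from(PREFIX_FORMAT, buf)
    return file_length, version, unity_blocked, COMMANDS[code]


# Initial state per the list: length 22, mutex False, command "step".
# The version number 1 is an assumption.
initial = pack_prefix(22, 1, False, "step")
fields = unpack_prefix(initial)
```

The two side channel ints that close the header are left out on purpose, since how they relate to the side channel data section is not fully specified in the list.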
#### Side channel data
- int : 4 bytes : The length of the side channel data for the current step (starts at 0)
- ??? : Side channel data (Size = total side channel capacity - 4 bytes )
#### RL Data section
- int : 4 bytes : The number of Agent groups in the simulation (starts at 0)
- For each group :
- string : 64 byte : group name
- int : 4 bytes : maximum number of Agents
- bool : 1 byte : is action discrete (False) or continuous (True)
- int : 4 bytes : action space size (continuous) / number of branches (discrete)
- If discrete only : array of action sizes for each branch (size = n_branches x 4)
- int : 4 bytes : number of observations
- For each observation :
- 3 int : shape (the shape of the tensor observation for one agent)
- 4 bytes : n_agents at current step
- ??? bytes : the data : obs,reward,done,max_step,agent_id,masks,action
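The per-group description above can be read back with a small parser. The sketch below is one possible reading, under stated assumptions: little-endian ints, a null-padded UTF-8 group name, and a branch-size array that is present only when the action space is discrete. It stops before the per-step data (`n_agents` and the observation/reward block), whose exact layout is not spelled out. `GroupSpec`, `read_group_spec` and the group name `Ball` are hypothetical.

```python
import struct
from typing import List, NamedTuple, Tuple

FIXED_FORMAT = "<64si?i"  # name, max agents, is continuous, action size
FIXED_SIZE = struct.calcsize(FIXED_FORMAT)


class GroupSpec(NamedTuple):
    name: str
    max_agents: int
    is_continuous: bool
    action_size: int  # number of branches when discrete
    branch_sizes: List[int]
    obs_shapes: List[Tuple[int, int, int]]


def read_group_spec(buf: bytes, offset: int = 0) -> Tuple[GroupSpec, int]:
    """Parse one group description; return it with the offset just past it."""
    raw_name, max_agents, is_continuous, action_size = struct.unpack_from(
        FIXED_FORMAT, buf, offset
    )
    offset += FIXED_SIZE
    branch_sizes: List[int] = []
    if not is_continuous:
        # Discrete only: one 4-byte int per branch.
        branch_sizes = list(struct.unpack_from(f"<{action_size}i", buf, offset))
        offset += 4 * action_size
    (n_obs,) = struct.unpack_from("<i", buf, offset)
    offset += 4
    obs_shapes = []
    for _ in range(n_obs):
        obs_shapes.append(struct.unpack_from("<3i", buf, offset))
        offset += 12
    name = raw_name.rstrip(b"\x00").decode("utf-8")
    spec = GroupSpec(name, max_agents, is_continuous, action_size, branch_sizes, obs_shapes)
    return spec, offset


# Example buffer: 100 agents, 3 continuous actions, one observation of shape (3, 0, 0).
example = (
    struct.pack(FIXED_FORMAT, b"Ball", 100, True, 3)
    + struct.pack("<i", 1)
    + struct.pack("<3i", 3, 0, 0)
)
spec, end = read_group_spec(example)
```

Returning the end offset lets a caller parse consecutive groups by feeding each result back in as the next starting position.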


@@ -1,15 +1,23 @@
// This file is generated. Do not modify by hand.
// XML documentation file not found. To check if public methods have XML comments,
// make sure the XML doc file is present and located next to the scraped dll
namespace Unity.AI.MLAgents
{
public class Academy : System.IDisposable
{
public Unity.AI.MLAgents.SideChannels.IFloatProperties FloatProperties;
public event System.Action OnEnvironmentReset;
public System.Action OnEnvironmentReset;
public Unity.AI.MLAgents.EnvironmentParameters EnvironmentParameters { get; }
public static Unity.AI.MLAgents.Academy Instance { get; }
public bool IsCommunicatorOn { get; }
public static bool IsInitialized { get; }
public Unity.AI.MLAgents.StatsRecorder StatsRecorder { get; }
public virtual void Dispose();
public void SubscribeSideChannel(Unity.AI.MLAgents.SideChannels.SideChannel channel);
public void SubscribeWorld(string policyId, Unity.AI.MLAgents.MLAgentsWorld world, Unity.AI.MLAgents.IWorldProcessor fallbackWorldProcessor = default(Unity.AI.MLAgents.IWorldProcessor), bool communicate = True);
public void RegisterPolicy(string policyId, Unity.AI.MLAgents.Policy policy, Unity.AI.MLAgents.IPolicyProcessor policyProcessor = default(Unity.AI.MLAgents.IPolicyProcessor), bool defaultRemote = True);
}
[System.Runtime.CompilerServices.Extension] public static class ActionHashMapUtils
{
[System.Runtime.CompilerServices.Extension] public static Unity.Collections.NativeHashMap<Unity.Entities.Entity, T> GenerateActionHashMap<T>(Unity.AI.MLAgents.Policy policy, Unity.Collections.Allocator allocator) where T : System.ValueType, new();
}
public enum ActionType
@@ -22,32 +30,45 @@ namespace Unity.AI.MLAgents
public struct ActuatorEvent
{
[Unity.Collections.ReadOnly] public int ActionSize;
[Unity.Collections.ReadOnly] public Unity.Collections.NativeSlice<float> ContinuousActionSlice;
[Unity.Collections.ReadOnly] public Unity.Collections.NativeSlice<int> DiscreteActionSlice;
[Unity.Collections.ReadOnly] public Unity.AI.MLAgents.ActionType ActionType;
[Unity.Collections.ReadOnly] public Unity.Entities.Entity Entity;
public void GetContinuousAction<T>(out T action) where T : System.ValueType, new();
public void GetDiscreteAction<T>(out T action) where T : System.ValueType, new();
public T GetAction<T>() where T : System.ValueType, new();
}
[System.Runtime.CompilerServices.Extension] public static class BarracudaWorldProcessorRegistringExtension
[System.Runtime.CompilerServices.Extension] public static class BarracudaPolicyProcessorRegistringExtension
{
[System.Runtime.CompilerServices.Extension] public static void SubscribeWorldWithBarracudaModel(Unity.AI.MLAgents.MLAgentsWorld world, string policyId, Barracuda.NNModel model, Unity.AI.MLAgents.InferenceDevice inferenceDevice = 0);
[System.Runtime.CompilerServices.Extension] public static void SubscribeWorldWithBarracudaModelForceNoCommunication<TH>(Unity.AI.MLAgents.MLAgentsWorld world, string policyId, Barracuda.NNModel model, Unity.AI.MLAgents.InferenceDevice inferenceDevice = 0);
[System.Runtime.CompilerServices.Extension] public static void RegisterPolicyWithBarracudaModel(Unity.AI.MLAgents.Policy policy, string policyId, Barracuda.NNModel model, Unity.AI.MLAgents.InferenceDevice inferenceDevice = 0);
[System.Runtime.CompilerServices.Extension] public static void RegisterPolicyWithBarracudaModelForceNoCommunication(Unity.AI.MLAgents.Policy policy, string policyId, Barracuda.NNModel model, Unity.AI.MLAgents.InferenceDevice inferenceDevice = 0);
}
public struct DecisionRequest
{
public Unity.AI.MLAgents.DecisionRequest HasTerminated(bool doneStatus, bool timedOut);
public Unity.AI.MLAgents.DecisionRequest SetDiscreteActionMask(int branch, int actionIndex);
public Unity.AI.MLAgents.DecisionRequest SetObservation(int sensorNumber, int sensor);
public Unity.AI.MLAgents.DecisionRequest SetObservation<T>(int sensorNumber, T sensor) where T : System.ValueType, new();
public Unity.AI.MLAgents.DecisionRequest SetObservationFromSlice(int sensorNumber, [Unity.Collections.ReadOnly] Unity.Collections.NativeSlice<float> obs);
public Unity.AI.MLAgents.DecisionRequest SetReward(float r);
}
[System.Runtime.CompilerServices.Extension] public static class HeristicWorldProcessorRegistringExtension
public sealed class EnvironmentParameters
{
[System.Runtime.CompilerServices.Extension] public static void SubscribeWorldWithHeuristic<TH>(Unity.AI.MLAgents.MLAgentsWorld world, string policyId, System.Func<TH> heuristic) where TH : System.ValueType, new();
[System.Runtime.CompilerServices.Extension] public static void SubscribeWorldWithHeuristicForceNoCommunication<TH>(Unity.AI.MLAgents.MLAgentsWorld world, string policyId, System.Func<TH> heuristic) where TH : System.ValueType, new();
public float GetWithDefault(string key, float defaultValue);
public System.Collections.Generic.IList<string> Keys();
public void RegisterCallback(string key, System.Action<float> action);
}
public struct EpisodeTermination
{
public Unity.AI.MLAgents.EpisodeTermination SetObservation(int sensorNumber, int sensor);
public Unity.AI.MLAgents.EpisodeTermination SetObservation<T>(int sensorNumber, T sensor) where T : System.ValueType, new();
public Unity.AI.MLAgents.EpisodeTermination SetObservationFromSlice(int sensorNumber, [Unity.Collections.ReadOnly] Unity.Collections.NativeSlice<float> obs);
public Unity.AI.MLAgents.EpisodeTermination SetReward(float r);
}
[System.Runtime.CompilerServices.Extension] public static class HeristicPolicyProcessorRegistringExtension
{
[System.Runtime.CompilerServices.Extension] public static void RegisterPolicyWithHeuristic<TH>(Unity.AI.MLAgents.Policy policy, string policyId, System.Func<TH> heuristic) where TH : System.ValueType, new();
[System.Runtime.CompilerServices.Extension] public static void RegisterPolicyWithHeuristicForceNoCommunication<TH>(Unity.AI.MLAgents.Policy policy, string policyId, System.Func<TH> heuristic) where TH : System.ValueType, new();
}
[Unity.Jobs.LowLevel.Unsafe.JobProducerType(typeof(Unity.AI.MLAgents.IActuatorJobExtensions.ActuatorDataJobProcess<>))] public interface IActuatorJob
@@ -57,7 +78,7 @@ namespace Unity.AI.MLAgents
[System.Runtime.CompilerServices.Extension] public static class IActuatorJobExtensions
{
[System.Runtime.CompilerServices.Extension] public static Unity.Jobs.JobHandle Schedule<T>(T jobData, Unity.AI.MLAgents.MLAgentsWorld mlagentsWorld, Unity.Jobs.JobHandle inputDeps) where T : Unity.AI.MLAgents.IActuatorJob, System.ValueType, new();
[System.Runtime.CompilerServices.Extension] public static Unity.Jobs.JobHandle Schedule<T>(T jobData, Unity.AI.MLAgents.Policy policy, Unity.Jobs.JobHandle inputDeps) where T : System.ValueType, Unity.AI.MLAgents.IActuatorJob, new();
}
public enum InferenceDevice
@@ -67,41 +88,48 @@ namespace Unity.AI.MLAgents
public int value__;
}
public interface IWorldProcessor : System.IDisposable
public interface IPolicyProcessor : System.IDisposable
{
public abstract bool IsConnected { get; }
public abstract Unity.AI.MLAgents.RemoteCommand ProcessWorld();
public abstract void Process();
}
public struct MLAgentsWorld : System.IDisposable
public class MLAgentsException : System.Exception
{
public MLAgentsException(string message) {}
}
public struct Policy : System.IDisposable
{
public bool IsCreated { get; }
public MLAgentsWorld(int maximumNumberAgents, Unity.AI.MLAgents.ActionType actionType, Unity.Mathematics.int3[] obsShapes, int actionSize, int[] discreteActionBranches = default(int[])) {}
public Policy(int maximumNumberAgents, Unity.Mathematics.int3[] obsShapes, Unity.AI.MLAgents.ActionType actionType, int actionSize, int[] discreteActionBranches = default(int[])) {}
public virtual void Dispose();
public Unity.AI.MLAgents.EpisodeTermination EndEpisode(Unity.Entities.Entity entity);
public Unity.AI.MLAgents.EpisodeTermination InterruptEpisode(Unity.Entities.Entity entity);
public Unity.AI.MLAgents.DecisionRequest RequestDecision(Unity.Entities.Entity entity);
}
public struct MLAgentsWorldSpecs
public struct PolicySpecs
{
public int ActionSize;
public Unity.AI.MLAgents.ActionType ActionType;
public int[] DiscreteActionBranches;
public Unity.AI.MLAgents.InferenceDevice InferenceDevice;
public Barracuda.NNModel Model;
public string Name;
public int NumberAgents;
public Unity.Mathematics.int3[] ObservationShapes;
public Unity.AI.MLAgents.MLAgentsWorld GenerateAndRegisterWorld();
public Unity.AI.MLAgents.MLAgentsWorld GenerateWorld();
public Unity.AI.MLAgents.Policy GetPolicy();
}
public enum RemoteCommand
public enum StatAggregationMethod
{
public const Unity.AI.MLAgents.RemoteCommand CHANGE_FILE = 2;
public const Unity.AI.MLAgents.RemoteCommand CLOSE = 3;
public const Unity.AI.MLAgents.RemoteCommand DEFAULT = 0;
public const Unity.AI.MLAgents.RemoteCommand RESET = 1;
public System.SByte value__;
public const Unity.AI.MLAgents.StatAggregationMethod Average = 0;
public const Unity.AI.MLAgents.StatAggregationMethod MostRecent = 1;
public int value__;
}
public sealed class StatsRecorder
{
public void Add(string key, float value, Unity.AI.MLAgents.StatAggregationMethod aggregationMethod = 0);
}
public static class TimeUtils
{
public static void DisableFixedRate(Unity.Entities.ComponentSystemGroup group);
public static void EnableFixedRateWithRepeat(Unity.Entities.ComponentSystemGroup group, float timeStep, int numberOfRepeat);
}
public static class VisualObservationUtility
@@ -112,66 +140,60 @@ namespace Unity.AI.MLAgents
namespace Unity.AI.MLAgents.SideChannels
{
public class EngineConfigurationChannel : Unity.AI.MLAgents.SideChannels.SideChannel
public class FloatPropertiesChannel : Unity.AI.MLAgents.SideChannels.SideChannel
{
public EngineConfigurationChannel() {}
public virtual int ChannelType();
public virtual void OnMessageReceived(byte[] data);
public FloatPropertiesChannel(System.Guid channelId = default(System.Guid)) {}
public float GetWithDefault(string key, float defaultValue);
public System.Collections.Generic.IList<string> Keys();
protected virtual void OnMessageReceived(Unity.AI.MLAgents.SideChannels.IncomingMessage msg);
public void RegisterCallback(string key, System.Action<float> action);
public void Set(string key, float value);
}
public class FloatPropertiesChannel : Unity.AI.MLAgents.SideChannels.SideChannel, Unity.AI.MLAgents.SideChannels.IFloatProperties
public class IncomingMessage : System.IDisposable
{
public FloatPropertiesChannel() {}
public virtual int ChannelType();
public virtual float GetPropertyWithDefault(string key, float defaultValue);
public virtual System.Collections.Generic.IList<string> ListProperties();
public virtual void OnMessageReceived(byte[] data);
public virtual void RegisterCallback(string key, System.Action<float> action);
public virtual void SetProperty(string key, float value);
public IncomingMessage(byte[] data) {}
public virtual void Dispose();
public byte[] GetRawBytes();
public bool ReadBoolean(bool defaultValue = False);
public float ReadFloat32(float defaultValue = 0);
public System.Collections.Generic.IList<float> ReadFloatList(System.Collections.Generic.IList<float> defaultValue = default(System.Collections.Generic.IList<float>));
public int ReadInt32(int defaultValue = 0);
public string ReadString(string defaultValue = default(string));
}
public interface IFloatProperties
public class OutgoingMessage : System.IDisposable
{
public abstract float GetPropertyWithDefault(string key, float defaultValue);
public abstract System.Collections.Generic.IList<string> ListProperties();
public abstract void RegisterCallback(string key, System.Action<float> action);
public abstract void SetProperty(string key, float value);
public OutgoingMessage() {}
public virtual void Dispose();
public void SetRawBytes(byte[] data);
public void WriteBoolean(bool b);
public void WriteFloat32(float f);
public void WriteFloatList(System.Collections.Generic.IList<float> floatList);
public void WriteInt32(int i);
public void WriteString(string s);
}
public class RawBytesChannel : Unity.AI.MLAgents.SideChannels.SideChannel
{
public RawBytesChannel(int channelId = 0) {}
public virtual int ChannelType();
public RawBytesChannel(System.Guid channelId) {}
public System.Collections.Generic.IList<byte[]> GetAndClearReceivedMessages();
public System.Collections.Generic.IList<byte[]> GetReceivedMessages();
public virtual void OnMessageReceived(byte[] data);
protected virtual void OnMessageReceived(Unity.AI.MLAgents.SideChannels.IncomingMessage msg);
public void SendRawBytes(byte[] data);
}
public abstract class SideChannel
{
public System.Collections.Generic.List<byte[]> MessageQueue;
public System.Guid ChannelId { get; protected set; }
protected SideChannel() {}
public abstract int ChannelType();
public abstract void OnMessageReceived(byte[] data);
protected void QueueMessageToSend(byte[] data);
protected abstract void OnMessageReceived(Unity.AI.MLAgents.SideChannels.IncomingMessage msg);
protected void QueueMessageToSend(Unity.AI.MLAgents.SideChannels.OutgoingMessage msg);
}
public enum SideChannelType
public static class SideChannelsManager
{
public const Unity.AI.MLAgents.SideChannels.SideChannelType EngineSettings = 2;
public const Unity.AI.MLAgents.SideChannels.SideChannelType FloatProperties = 1;
public const Unity.AI.MLAgents.SideChannels.SideChannelType Invalid = 0;
public const Unity.AI.MLAgents.SideChannels.SideChannelType RawBytesChannelStart = 1000;
public const Unity.AI.MLAgents.SideChannels.SideChannelType UserSideChannelStart = 2000;
public int value__;
}
public class StringLogSideChannel : Unity.AI.MLAgents.SideChannels.SideChannel
{
public StringLogSideChannel() {}
public virtual int ChannelType();
public virtual void OnMessageReceived(byte[] data);
public void SendDebugStatementToPython(string logString, string stackTrace, UnityEngine.LogType type);
public static void RegisterSideChannel(Unity.AI.MLAgents.SideChannels.SideChannel sideChannel);
public static void UnregisterSideChannel(Unity.AI.MLAgents.SideChannels.SideChannel sideChannel);
}
}