Modified some formatting and added comments

vincentpierre 2020-04-15 14:38:11 -07:00
Parent d891643238
Commit c19ffc2a89
17 changed files: 662 additions and 623 deletions

.gitignore (vendored)
View file

@ -10,3 +10,11 @@
*obj
.pytest_cache
*__pycache__
# precommit
.mypy_cache
.pre-commit-config.yaml
.pylintrc
.vscode
setup.cfg
setup.cfg.meta

View file

@ -36,7 +36,7 @@ This package is available as a preview, so it is not ready for production use. T
```python
from mlagents_envs.environment import UnityEnvironment
```
with
```python
from mlagents_dots_envs.unity_environment import UnityEnvironment
```
@ -45,7 +45,7 @@ This package is available as a preview, so it is not ready for production use. T
```python
from mlagents_envs.environment import UnityEnvironment
```
with
```python
from mlagents_dots_envs.unity_environment import UnityEnvironment
```
@ -62,7 +62,7 @@ This package is available as a preview, so it is not ready for production use. T
## API
One approach to designing ml-agents to be compatible with DOTS would be to use the typical API used, for example, in [Unity.Physics](https://github.com/Unity-Technologies/Unity.Physics), where a "MLAgents World" holds data, processes it, and the data can then be retrieved.
The user would access the `MLAgentsWorld` in the main thread:
```csharp
@ -71,9 +71,9 @@ var world = new MLAgentsWorld(
new int3[] { new int3(3, 0, 0) }, // The observation shapes (here, one observation of shape (3,0,0))
ActionType.CONTINUOUS, // Continuous = float, Discrete = int
3); // The number of actions
world.SubscribeWorldWithBarracudaModel(Name, Model, InferenceDevice);
```
The user could then add and retrieve data from the world in their own jobs. Here is an example of a job in which the user populates the sensor data:
@ -117,7 +117,7 @@ world.RequestDecision(entities[i])
.SetObservationFromSlice(1, visObs.Slice());
```
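For reference, the full version of this job, as it appears in the SimpleSystem example elsewhere in this commit, looks like the following (`cameraObservation` here corresponds to the `visObs` array produced by `VisualObservationUtility.GetVisObs`):
```csharp
public struct UserCreateSensingJob : IJobParallelFor
{
    [ReadOnly] public NativeArray<float> cameraObservation;
    public NativeArray<Entity> entities;
    public MLAgentsWorld world;

    public void Execute(int i)
    {
        // Request a decision for the i-th agent and populate its data:
        // a reward, a vector observation (index 0) and a visual observation (index 1).
        world.RequestDecision(entities[i])
            .SetReward(1.0f)
            .SetObservation(0, new float3(entities[i].Index, 0, 0))
            .SetObservationFromSlice(1, cameraObservation.Slice());
    }
}
```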
In order to retrieve actions, we use a custom job:
```csharp
public struct UserCreatedActionEventJob : IActuatorJob
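{
    // Body adapted from the SimpleSystem example elsewhere in this commit:
    // the job is executed once per ActuatorEvent and retrieves the action
    // associated with the event's entity.
    public void Execute(ActuatorEvent data)
    {
        var action = data.GetAction<testAction>();
        Debug.Log(data.Entity.Index + " " + action.e1);
    }
}
```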
@ -162,7 +162,7 @@ __Note__ : The python code for communication is located in [ml-agents-envs~](./m
#### RL Data section
- int : 4 bytes : The number of Agent groups in the simulation (starts at 0)
- For each group:
- string : 64 bytes : group name
- int : 4 bytes : maximum number of Agents
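As a rough illustration of how a consumer might walk the start of this section, here is a hypothetical C# sketch. The buffer source, the ASCII encoding, and the zero-padding of the 64-byte name are assumptions, and the real section continues with per-group observation and action data:
```csharp
using System;
using System.Text;

static class RlDataHeaderSketch
{
    // Hypothetical reader for the layout described above; names are illustrative.
    public static void ReadGroupHeaders(byte[] buffer)
    {
        int offset = 0;
        // int : 4 bytes : the number of Agent groups in the simulation
        int groupCount = BitConverter.ToInt32(buffer, offset);
        offset += 4;
        for (int g = 0; g < groupCount; g++)
        {
            // string : 64 bytes : group name (assumed ASCII, zero-padded)
            string groupName = Encoding.ASCII.GetString(buffer, offset, 64).TrimEnd('\0');
            offset += 64;
            // int : 4 bytes : maximum number of Agents in the group
            int maxAgents = BitConverter.ToInt32(buffer, offset);
            offset += 4;
            Console.WriteLine($"{groupName}: up to {maxAgents} agents");
        }
    }
}
```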

View file

@ -1,5 +1,5 @@
Unity.AI.MLAgents copyright © 2020 Unity Technologies ApS
Licensed under the Unity Companion License for Unity-dependent projects--see [Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).
Unless expressly provided otherwise, the Software under this license is made available strictly on an “AS IS” BASIS WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. Please review the license for details on these and other terms and conditions.

View file

@ -1,7 +1,7 @@
# ML-Agents DOTS
[ML-Agents on DOTS Proposal Google Doc](https://docs.google.com/document/d/1QnGSjOfLpwaRopbMf9ZDC89oZJuG0Ii6ORA22a5TWzE/edit#heading=h.py1zfmz3396x)
[Documentation](./Documentation~/README.md)

View file

@ -1,80 +1,80 @@
using Unity.Entities;
using Unity.Core;
namespace Unity.AI.MLAgents
{
public static class TimeUtils
{
/// <summary>
/// Configure the given ComponentSystemGroup to update at a fixed timestep, given by timeStep.
/// If the interval between the current time and the last update is bigger than the timestep
/// multiplied by the time scale,
/// the group's systems will be updated more than once.
/// </summary>
/// <param name="group">The group whose UpdateCallback will be configured with a fixed time step update call</param>
/// <param name="timeStep">The fixed time step (in seconds)</param>
/// <param name="timeScale">How much time passes in the group compared to other systems</param>
public static void EnableFixedRateWithCatchUp(ComponentSystemGroup group, float timeStep, float timeScale)
{
var manager = new FixedRateTimeScaleCatchUpManager(timeStep, timeScale);
group.UpdateCallback = manager.UpdateCallback;
}
/// <summary>
/// Disable fixed rate updates on the given group, by setting the UpdateCallback to null.
/// </summary>
/// <param name="group">The group whose UpdateCallback to set to null.</param>
public static void DisableFixedRate(ComponentSystemGroup group)
{
group.UpdateCallback = null;
}
}
internal class FixedRateTimeScaleCatchUpManager
{
protected float m_TimeScale;
protected float m_FixedTimeStep;
protected double m_LastFixedUpdateTime;
protected int m_FixedUpdateCount;
protected bool m_DidPushTime;
internal FixedRateTimeScaleCatchUpManager(float fixedStep, float timeScale)
{
m_FixedTimeStep = fixedStep;
m_TimeScale = timeScale;
}
internal bool UpdateCallback(ComponentSystemGroup group)
{
// if this is true, means we're being called a second or later time in a loop
if (m_DidPushTime)
{
group.World.PopTime();
}
var elapsedTime = group.World.Time.ElapsedTime * m_TimeScale;
if (m_LastFixedUpdateTime == 0.0)
m_LastFixedUpdateTime = elapsedTime - m_FixedTimeStep;
if (elapsedTime - m_LastFixedUpdateTime >= m_FixedTimeStep)
{
// Note that m_FixedTimeStep of 0.0f will never update
m_LastFixedUpdateTime += m_FixedTimeStep;
m_FixedUpdateCount++;
}
else
{
m_DidPushTime = false;
return false;
}
group.World.PushTime(new TimeData(
elapsedTime: m_LastFixedUpdateTime,
deltaTime: m_FixedTimeStep));
m_DidPushTime = true;
return true;
}
}
}
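For context on how this utility is meant to be used, here is a minimal usage sketch; which group to configure and the concrete step and scale values are assumptions:
```csharp
// Run the simulation group at a fixed 0.02 s timestep, twice as fast as real time.
var group = World.DefaultGameObjectInjectionWorld
    .GetOrCreateSystem<SimulationSystemGroup>();
TimeUtils.EnableFixedRateWithCatchUp(group, timeStep: 0.02f, timeScale: 2.0f);

// Later, restore the default per-frame update behaviour.
TimeUtils.DisableFixedRate(group);
```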

View file

@ -1,78 +1,78 @@
using Unity.Entities;
using Unity.Collections;
using Unity.Mathematics;
using Unity.Transforms;
using UnityEngine;
using Random = UnityEngine.Random;
using Unity.AI.MLAgents;
public class BalanceBallManager : MonoBehaviour
{
public MLAgentsWorldSpecs MyWorldSpecs;
private EntityManager manager;
public GameObject prefabPlatform;
public GameObject prefabBall;
private Entity _prefabEntityPlatform;
private Entity _prefabEntityBall;
int currentIndex;
void Awake()
{
var world = MyWorldSpecs.GetWorld();
var ballSystem = World.DefaultGameObjectInjectionWorld.GetOrCreateSystem<BallSystem>();
ballSystem.Enabled = true;
ballSystem.BallWorld = world;
manager = World.DefaultGameObjectInjectionWorld.EntityManager;
BlobAssetStore blob = new BlobAssetStore();
GameObjectConversionSettings settings = GameObjectConversionSettings.FromWorld(World.DefaultGameObjectInjectionWorld, blob);
_prefabEntityPlatform = GameObjectConversionUtility.ConvertGameObjectHierarchy(prefabPlatform, settings);
_prefabEntityBall = GameObjectConversionUtility.ConvertGameObjectHierarchy(prefabBall, settings);
Spawn(1000);
blob.Dispose();
}
void Spawn(int amount)
{
NativeArray<Entity> entitiesP = new NativeArray<Entity>(amount, Allocator.Temp);
NativeArray<Entity> entitiesB = new NativeArray<Entity>(amount, Allocator.Temp);
manager.Instantiate(_prefabEntityPlatform, entitiesP);
manager.Instantiate(_prefabEntityBall, entitiesB);
for (int i = 0; i < amount; i++)
{
float3 position = new float3((currentIndex % 10) - 5, (currentIndex / 10 % 10) - 5, currentIndex / 100) * 5f;
float valX = Random.Range(-0.1f, 0.1f);
float valZ = Random.Range(-0.1f, 0.1f);
manager.SetComponentData(entitiesP[i],
new Translation
{
Value = position
});
manager.SetComponentData(entitiesB[i],
new Translation
{
Value = position + new float3(0, 0.2f, 0)
});
manager.SetComponentData(entitiesP[i],
new Rotation
{
Value = quaternion.EulerXYZ(valX, 0, valZ)
});
manager.AddComponent<AgentData>(entitiesP[i]);
manager.SetComponentData(entitiesP[i], new AgentData {
BallResetPosition = position + new float3(0, 0.2f, 0),
BallRef = entitiesB[i],
StepCount = 0
});
manager.AddComponent<Actuator>(entitiesP[i]);
currentIndex++;
}
entitiesP.Dispose();
entitiesB.Dispose();
}
}

View file

@ -1,127 +1,127 @@
using Unity.Entities;
using Unity.Jobs;
using Unity.Mathematics;
using Unity.Transforms;
using Unity.AI.MLAgents;
using Unity.Collections;
using Unity.Physics;
public struct AgentData : IComponentData
{
public float3 BallResetPosition;
public Entity BallRef;
public int StepCount;
}
public struct Actuator : IComponentData
{
public float2 Value;
}
public class BallSystem : JobComponentSystem
{
private const int maxStep = 5000;
private struct RotateJob : IActuatorJob
{
public ComponentDataFromEntity<Actuator> ComponentDataFromEntity;
public void Execute(ActuatorEvent ev)
{
var a = ev.GetAction<Actuator>();
ComponentDataFromEntity[ev.Entity] = a;
}
}
public MLAgentsWorld BallWorld;
// Update is called once per frame
protected override JobHandle OnUpdate(JobHandle inputDeps)
{
if (!BallWorld.IsCreated){
return inputDeps;
}
var world = BallWorld;
ComponentDataFromEntity<Translation> TranslationFromEntity = GetComponentDataFromEntity<Translation>(isReadOnly: false);
ComponentDataFromEntity<PhysicsVelocity> VelFromEntity = GetComponentDataFromEntity<PhysicsVelocity>(isReadOnly: false);
inputDeps = Entities
.WithNativeDisableParallelForRestriction(TranslationFromEntity)
.WithNativeDisableParallelForRestriction(VelFromEntity)
.ForEach((Entity entity, ref Rotation rot, ref AgentData agentData) =>
{
var ballPos = TranslationFromEntity[agentData.BallRef].Value;
var ballVel = VelFromEntity[agentData.BallRef].Linear;
var platformVel = VelFromEntity[entity];
bool taskFailed = false;
bool interruption = false;
if (ballPos.y - agentData.BallResetPosition.y < -0.7f)
{
taskFailed = true;
agentData.StepCount = 0;
}
if (agentData.StepCount > maxStep)
{
interruption = true;
agentData.StepCount = 0;
}
if (!interruption && !taskFailed)
{
world.RequestDecision(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
.SetObservation(3, platformVel.Angular)
.SetReward((0.1f));
}
if (taskFailed)
{
world.EndEpisode(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
.SetObservation(3, platformVel.Angular)
.SetReward(-1f);
}
else if (interruption)
{
world.InterruptEpisode(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
.SetObservation(3, platformVel.Angular)
.SetReward((0.1f));
}
if (interruption || taskFailed)
{
VelFromEntity[agentData.BallRef] = new PhysicsVelocity();
TranslationFromEntity[agentData.BallRef] = new Translation { Value = agentData.BallResetPosition };
rot.Value = quaternion.identity;
}
agentData.StepCount++;
}).Schedule(inputDeps);
var reactiveJob = new RotateJob
{
ComponentDataFromEntity = GetComponentDataFromEntity<Actuator>(isReadOnly: false)
};
inputDeps = reactiveJob.Schedule(world, inputDeps);
inputDeps = Entities.ForEach((Actuator act, ref Rotation rotation) =>
{
var rot = math.mul(rotation.Value, quaternion.Euler(0.05f * new float3(act.Value.x, 0, act.Value.y)));
rotation.Value = rot;
}).Schedule(inputDeps);
return inputDeps;
}
protected override void OnDestroy()
{
BallWorld.Dispose();
}
}

View file

@ -1,106 +1,106 @@
using Unity.Collections;
using Unity.Entities;
using Unity.Mathematics;
using Unity.Jobs;
using UnityEngine;
using Unity.AI.MLAgents;
[DisableAutoCreation]
public class SimpleSystem : JobComponentSystem
{
private MLAgentsWorld world;
private NativeArray<Entity> entities;
private Camera camera;
public const int N_Agents = 5;
int counter;
// Start is called before the first frame update
protected override void OnCreate()
{
Application.targetFrameRate = -1;
world = new MLAgentsWorld(N_Agents, new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) }, ActionType.DISCRETE, 2, new int[] { 2, 3 });
world.RegisterWorldWithHeuristic("test", () => new int2(1, 1));
entities = new NativeArray<Entity>(N_Agents, Allocator.Persistent);
for (int i = 0; i < N_Agents; i++)
{
entities[i] = World.DefaultGameObjectInjectionWorld.EntityManager.CreateEntity();
}
}
protected override void OnDestroy()
{
world.Dispose();
entities.Dispose();
}
// Update is called once per frame
protected override JobHandle OnUpdate(JobHandle inputDeps)
{
if (camera == null)
{
camera = Camera.main;
camera = GameObject.FindObjectOfType<Camera>();
}
inputDeps.Complete();
var reactiveJob = new UserCreatedActionEventJob
{
myNumber = 666
};
inputDeps = reactiveJob.Schedule(world, inputDeps);
if (counter % 5 == 0)
{
var visObs = VisualObservationUtility.GetVisObs(camera, 84, 84, Allocator.TempJob);
var senseJob = new UserCreateSensingJob
{
cameraObservation = visObs,
entities = entities,
world = world
};
inputDeps = senseJob.Schedule(N_Agents, 64, inputDeps);
inputDeps.Complete();
visObs.Dispose();
}
counter++;
return inputDeps;
}
public struct UserCreateSensingJob : IJobParallelFor
{
[ReadOnly] public NativeArray<float> cameraObservation;
public NativeArray<Entity> entities;
public MLAgentsWorld world;
public void Execute(int i)
{
world.RequestDecision(entities[i])
.SetReward(1.0f)
.SetObservation(0, new float3(entities[i].Index, 0, 0))
.SetObservationFromSlice(1, cameraObservation.Slice());
}
}
public struct UserCreatedActionEventJob : IActuatorJob
{
public int myNumber;
public void Execute(ActuatorEvent data)
{
var tmp =data.GetAction<testAction>();
Debug.Log(data.Entity.Index + " " + tmp.e1);
}
}
public enum testEnum
{
A, B, C
}
public struct testAction
{
public testEnum e1;
public testEnum e2;
}
}

View file

@ -7,66 +7,67 @@ import subprocess
from sys import platform
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.side_channel.side_channel import SideChannel
from typing import Dict, List, Optional, Any, Tuple
from typing import Dict, List, Optional
from mlagents_envs.logging_util import get_logger
logger = get_logger(__name__)
def validate_environment_path(env_path: str) -> Optional[str]:
# Strip out executable extensions if passed
env_path = (
env_path.strip()
.replace(".app", "")
.replace(".exe", "")
.replace(".x86_64", "")
.replace(".x86", "")
# Strip out executable extensions if passed
env_path = (
env_path.strip()
.replace(".app", "")
.replace(".exe", "")
.replace(".x86_64", "")
.replace(".x86", "")
)
true_filename = os.path.basename(os.path.normpath(env_path))
logger.debug("The true file name is {}".format(true_filename))
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None
cwd = os.getcwd()
launch_string = None
true_filename = os.path.basename(os.path.normpath(env_path))
if platform == "linux" or platform == "linux2":
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86")
if len(candidates) > 0:
launch_string = candidates[0]
elif platform == "darwin":
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename)
)
true_filename = os.path.basename(os.path.normpath(env_path))
logger.debug("The true file name is {}".format(true_filename))
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None
cwd = os.getcwd()
launch_string = None
true_filename = os.path.basename(os.path.normpath(env_path))
if platform == "linux" or platform == "linux2":
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86")
if len(candidates) > 0:
launch_string = candidates[0]
elif platform == "darwin":
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename)
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) > 0:
launch_string = candidates[0]
elif platform == "win32":
candidates = glob.glob(os.path.join(cwd, env_path + ".exe"))
if len(candidates) == 0:
candidates = glob.glob(env_path + ".exe")
if len(candidates) > 0:
launch_string = candidates[0]
return launch_string
# TODO: END REMOVE
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) > 0:
launch_string = candidates[0]
elif platform == "win32":
candidates = glob.glob(os.path.join(cwd, env_path + ".exe"))
if len(candidates) == 0:
candidates = glob.glob(env_path + ".exe")
if len(candidates) > 0:
launch_string = candidates[0]
return launch_string
# TODO: END REMOVE
def get_side_channels(
@ -77,9 +78,8 @@ def get_side_channels(
for _sc in side_c:
if _sc.channel_id in side_channels_dict:
raise UnityEnvironmentException(
"There cannot be two side channels with the same channel type {0}.".format(
_sc.channel_id
)
f"There cannot be two side channels with "
f"the same channel id {_sc.channel_id}."
)
side_channels_dict[_sc.channel_id] = _sc
return side_channels_dict
@ -110,9 +110,7 @@ def executable_launcher(exec_name, memory_path, args):
elif platform == "darwin":
candidates = glob.glob(
os.path.join(
cwd, exec_name + ".app", "Contents", "MacOS", true_filename
)
os.path.join(cwd, exec_name + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
@ -137,9 +135,7 @@ def executable_launcher(exec_name, memory_path, args):
if launch_string is None:
raise UnityEnvironmentException(
"Couldn't launch the {0} environment. "
"Provided filename does not match any environments.".format(
true_filename
)
"Provided filename does not match any environments.".format(true_filename)
)
else:
logger.debug("This is the launch string {}".format(launch_string))
@ -151,9 +147,10 @@ def executable_launcher(exec_name, memory_path, args):
return subprocess.Popen(
subprocess_args,
# start_new_session=True means that signals to the parent python process
# (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms.
# This is generally good since we want the environment to have a chance to shutdown,
# but may be undesirable in come cases; if so, we'll add a command-line toggle.
# (e.g. SIGINT from keyboard interrupt) will not be sent to the new
# process on POSIX platforms. This is generally good since we want the
# environment to have a chance to shutdown, but may be undesirable in
# some cases; if so, we'll add a command-line toggle.
# Note that on Windows, the CTRL_C signal will still be sent.
start_new_session=True,
)
@ -198,6 +195,7 @@ def parse_side_channel_message(
": {0}.".format(channel_id)
)
def generate_side_channel_data(
side_channels: Dict[uuid.UUID, SideChannel]
) -> bytearray:
@ -217,7 +215,8 @@ def returncode_to_signal_name(returncode: int) -> Optional[str]:
E.g. returncode_to_signal_name(-2) -> "SIGINT"
"""
try:
# A negative value -N indicates that the child was terminated by signal N (POSIX only).
# A negative value -N indicates that the child was terminated by
# signal N (POSIX only).
s = signal.Signals(-returncode) # pylint: disable=no-member
return s.name
except Exception:

View file

@ -5,9 +5,7 @@ from mlagents_dots_envs.unity_environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
)
from mlagents_envs.side_channel.float_properties_channel import (
FloatPropertiesChannel,
)
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
sc = EngineConfigurationChannel()
sc2 = FloatPropertiesChannel()
@ -15,31 +13,42 @@ env = UnityEnvironment(side_channels=[sc, sc2])
sc.set_configuration_parameters()
for i in range(10):
s = ""
for name in env.get_behavior_names():
s += name + " : " + str(len(env.get_steps(name)[0])) + " : " + str(len(env.get_steps(name)[1])) + "|"
env.set_actions(name, np.ones((len(env.get_steps(name)[0]), 2)))
print(s)
env.step()
for _ in range(10):
s = ""
for name in env.get_behavior_names():
s += (
name
+ " : "
+ str(len(env.get_steps(name)[0]))
+ " : "
+ str(len(env.get_steps(name)[1]))
+ "|"
)
env.set_actions(name, np.ones((len(env.get_steps(name)[0]), 2)))
print(s)
env.step()
print(env.get_steps("Ball_DOTS")[0].obs[0])
print("RESET")
env.reset()
# import time
# time.sleep(50)
sc2.set_property("test", 2)
sc2.set_property("test2", 2)
sc2.set_property("test3", 2)
for i in range(100):
s = ""
for name in env.get_behavior_names():
s += name +" : " + str(len(env.get_steps(name)[0])) +" : " + str(len(env.get_steps(name)[1]))+ "|"
print(s)
env.step()
for _ in range(100):
s = ""
for name in env.get_behavior_names():
s += (
name
+ " : "
+ str(len(env.get_steps(name)[0]))
+ " : "
+ str(len(env.get_steps(name)[1]))
+ "|"
)
print(s)
env.step()
env.close()

View file

@ -1,16 +1,17 @@
from abc import ABC, abstractmethod
from abc import ABC
import os
import tempfile
import mmap
import numpy as np
import struct
import uuid
from typing import Tuple, Optional, NamedTuple, List, Dict
from typing import Tuple
class BasedSharedMem(ABC):
DIRECTORY = "ml-agents"
def __init__(self, file_name: str, create_file: bool = False, size:int = 0):
def __init__(self, file_name: str, create_file: bool = False, size: int = 0):
directory = os.path.join(tempfile.gettempdir(), self.DIRECTORY)
if not os.path.exists(directory):
os.makedirs(directory)
@ -26,16 +27,16 @@ class BasedSharedMem(ABC):
self.accessor = mmap.mmap(f.fileno(), 0)
self._file_path = file_path
def get_int(self, offset:int) -> Tuple[int, int]:
def get_int(self, offset: int) -> Tuple[int, int]:
return struct.unpack_from("<i", self.accessor, offset)[0], offset + 4
def get_float(self, offset:int) -> Tuple[float, int]:
def get_float(self, offset: int) -> Tuple[float, int]:
return struct.unpack_from("<f", self.accessor, offset)[0], offset + 4
def get_bool(self, offset:int) -> Tuple[bool, int]:
def get_bool(self, offset: int) -> Tuple[bool, int]:
return struct.unpack_from("<?", self.accessor, offset)[0], offset + 1
def get_string(self, offset:int) -> Tuple[str, int]:
def get_string(self, offset: int) -> Tuple[str, int]:
string_len = struct.unpack_from("<B", self.accessor, offset)[0]
byte_array = bytes(self.accessor[offset + 1 : offset + string_len + 1])
result = byte_array.decode("ascii")
@ -50,44 +51,43 @@ class BasedSharedMem(ABC):
).reshape(shape)
def get_uuid(self, offset: int) -> Tuple[uuid.UUID, int]:
return uuid.UUID(bytes_le= self.accessor[offset:offset+16]), offset+16
return uuid.UUID(bytes_le=self.accessor[offset : offset + 16]), offset + 16
def set_int(self, offset:int, value:int) -> int:
def set_int(self, offset: int, value: int) -> int:
struct.pack_into("<i", self.accessor, offset, value)
return offset + 4
def set_float(self, offset:int, value:float) -> int:
def set_float(self, offset: int, value: float) -> int:
struct.pack_into("<f", self.accessor, offset, value)
return offset + 4
def set_bool(self, offset:int, value:bool) -> int:
def set_bool(self, offset: int, value: bool) -> int:
struct.pack_into("<?", self.accessor, offset, value)
return offset + 1
def set_string(self, offset:int, value:str) -> int:
def set_string(self, offset: int, value: str) -> int:
string_len = len(value)
struct.pack_into("<B", self.accessor, offset, string_len)
offset += 1
self.accessor[offset: offset + string_len] = value.encode("ascii")
self.accessor[offset : offset + string_len] = value.encode("ascii")
return offset + string_len
def set_uuid(self, offset: int, value: uuid.UUID) -> int:
self.accessor[offset: offset + 16] = value.bytes_le
self.accessor[offset : offset + 16] = value.bytes_le
return offset + 16
def set_ndarray(self, offset: int, data: np.ndarray) -> None:
bytes_data = data.tobytes()
self.accessor[offset: offset + len(bytes_data)] = bytes_data
self.accessor[offset : offset + len(bytes_data)] = bytes_data
def close(self) -> None:
if self.accessor is not None:
self.accessor.close()
self.accessor = None
self.accessor = None # type: ignore
def delete(self) -> None:
self.close()
try:
os.remove(self._file_path)
except:
except BaseException:
pass

View file

@ -1,20 +1,28 @@
import numpy as np
from mlagents_dots_envs.shared_memory.base_shared_mem import BasedSharedMem
from mlagents_dots_envs.shared_memory.rl_data_offset import RLDataOffsets
from typing import Dict, List, Optional, Any, Tuple
from mlagents_envs.base_env import DecisionSteps, TerminalSteps, BehaviorSpec, ActionType
from typing import Dict, List
from mlagents_envs.base_env import (
DecisionSteps,
TerminalSteps,
BehaviorSpec,
ActionType,
)
class DataSharedMem(BasedSharedMem):
def __init__(self,
file_name: str,
create_file: bool = False,
copy_from: "DataSharedMem" = None,
side_channel_buffer_size: int = 0,
rl_data_buffer_size: int = 0):
def __init__(
self,
file_name: str,
create_file: bool = False,
copy_from: "DataSharedMem" = None,
side_channel_buffer_size: int = 0,
rl_data_buffer_size: int = 0,
):
self._offset_dict: Dict[str, RLDataOffsets] = {}
if create_file and copy_from is None:
size = side_channel_buffer_size + rl_data_buffer_size
super(DataSharedMem, self).__init__( file_name, True, size)
super(DataSharedMem, self).__init__(file_name, True, size)
self._side_channel_buffer_size = side_channel_buffer_size
self._rl_data_buffer_size = rl_data_buffer_size
return
@ -46,7 +54,7 @@ class DataSharedMem(BasedSharedMem):
def side_channel_data(self) -> bytearray:
offset = 0
len_data, offset = self.get_int(offset)
return self.accessor[offset: offset + len_data]
return self.accessor[offset : offset + len_data]
@side_channel_data.setter
def side_channel_data(self, data: bytearray) -> None:
@ -64,7 +72,7 @@ class DataSharedMem(BasedSharedMem):
def rl_data(self) -> bytearray:
offset = self.rl_data_offset
size = self._rl_data_buffer_size
return self.accessor[offset: offset + size]
return self.accessor[offset : offset + size]
@rl_data.setter
def rl_data(self, data: bytearray) -> None:
@ -72,7 +80,7 @@ class DataSharedMem(BasedSharedMem):
raise Exception("TODO")
offset = self.rl_data_offset
size = self._rl_data_buffer_size
self.accessor[offset: offset + size] = data
self.accessor[offset : offset + size] = data
self._refresh_offsets()
def get_decision_steps(self, key: str) -> DecisionSteps:
@ -80,34 +88,48 @@ class DataSharedMem(BasedSharedMem):
offsets = self._offset_dict[key]
n_agents, _ = self.get_int(offsets.decision_n_agents_offset)
obs: List[np.ndarray] = []
for obs_offset, obs_shape in zip(offsets.decision_obs_offset, offsets.obs_shapes):
for obs_offset, obs_shape in zip(
offsets.decision_obs_offset, offsets.obs_shapes
):
obs_shape = (n_agents,) + obs_shape
arr = self.get_ndarray(obs_offset, obs_shape, np.float32)
obs.append(arr)
return DecisionSteps(
obs=obs,
reward=self.get_ndarray(offsets.decision_rewards_offset, (n_agents), np.float32),
agent_id=self.get_ndarray(offsets.decision_agent_id_offset, (n_agents), np.int32),
action_mask=None #TODO
reward=self.get_ndarray(
offsets.decision_rewards_offset, (n_agents), np.float32
),
agent_id=self.get_ndarray(
offsets.decision_agent_id_offset, (n_agents), np.int32
),
action_mask=None, # TODO
)
def get_terminal_steps(self, key) -> TerminalSteps:
def get_terminal_steps(self, key: str) -> TerminalSteps:
assert key in self._offset_dict
offsets = self._offset_dict[key]
n_agents, _ = self.get_int(offsets.termination_n_agents_offset)
obs: List[np.ndarray] = []
for obs_offset, obs_shape in zip(offsets.termination_obs_offset, offsets.obs_shapes):
for obs_offset, obs_shape in zip(
offsets.termination_obs_offset, offsets.obs_shapes
):
obs_shape = (n_agents,) + obs_shape
arr= self.get_ndarray(obs_offset, obs_shape, np.float32)
arr = self.get_ndarray(obs_offset, obs_shape, np.float32)
obs.append(arr)
return TerminalSteps(
obs=obs,
reward=self.get_ndarray(offsets.termination_reward_offset, (n_agents), np.float32),
agent_id=self.get_ndarray(offsets.termination_agent_id_offset, (n_agents), np.int32),
max_step=self.get_ndarray(offsets.termination_status_offset, (n_agents), np.bool),
reward=self.get_ndarray(
offsets.termination_reward_offset, (n_agents), np.float32
),
agent_id=self.get_ndarray(
offsets.termination_agent_id_offset, (n_agents), np.int32
),
max_step=self.get_ndarray(
offsets.termination_status_offset, (n_agents), np.bool
),
)
def set_actions(self, key: str, data:np.ndarray) -> None:
def set_actions(self, key: str, data: np.ndarray) -> None:
assert key in self._offset_dict
offsets = self._offset_dict[key]
self.set_ndarray(offsets.action_offset, data)

View file

@ -1,6 +1,8 @@
import os, glob
import os
import glob
from mlagents_dots_envs.shared_memory.base_shared_mem import BasedSharedMem
class MasterSharedMem(BasedSharedMem):
"""
Always created by Python
@ -14,11 +16,17 @@ class MasterSharedMem(BasedSharedMem):
- int : Communication file "side channel" size in bytes
- int : Communication file "RL section" size in bytes
"""
SIZE = 28
VERSION = (0,3,0)
def __init__(self, file_name: str, side_channel_size=0, rl_data_size=0):
super(MasterSharedMem, self).__init__(file_name, create_file=True, size=self.SIZE)
SIZE = 28
VERSION = (0, 3, 0)
def __init__(
self, file_name: str, side_channel_size: int = 0, rl_data_size: int = 0
):
super(MasterSharedMem, self).__init__(
file_name, create_file=True, size=self.SIZE
)
for f in glob.glob(file_name + "_"):
# Removing all the future files in case they were not correctly created
os.remove(f)
@ -76,24 +84,24 @@ class MasterSharedMem(BasedSharedMem):
self.set_bool(offset, True)
@property
def side_channel_size(self):
def side_channel_size(self) -> int:
offset = 20
result, _ = self.get_int(offset)
return result
@side_channel_size.setter
def side_channel_size(self, value:int):
def side_channel_size(self, value: int) -> None:
offset = 20
self.set_int(offset, value)
@property
def rl_data_size(self):
def rl_data_size(self) -> int:
offset = 24
result, _ = self.get_int(offset)
return result
@rl_data_size.setter
def rl_data_size(self, value:int):
def rl_data_size(self, value: int) -> None:
offset = 24
self.set_int(offset, value)
@ -103,8 +111,8 @@ class MasterSharedMem(BasedSharedMem):
minor, offset = self.get_int(offset)
bug, offset = self.get_int(offset)
if (major, minor, bug) != self.VERSION:
raise Exception(f"Incompatible versions of communicator between " +
f"Unity {major}.{minor}.{bug} and Python "
f"{self.VERSION[0]}.{self.VERSION[1]}.{self.VERSION[2]}")
raise Exception(
f"Incompatible versions of communicator between "
+ f"Unity {major}.{minor}.{bug} and Python "
f"{self.VERSION[0]}.{self.VERSION[1]}.{self.VERSION[2]}"
)

View file

@ -1,14 +1,16 @@
import numpy as np
from typing import Tuple, Optional, NamedTuple, List, Dict
from typing import Tuple, Optional, NamedTuple, List
from mlagents_dots_envs.shared_memory.base_shared_mem import BasedSharedMem
class RLDataOffsets(NamedTuple):
"""
Contains the offsets to the data for a section of the RL data
"""
# data
name:str
name: str
max_n_agents: int
is_action_continuous: bool
action_size: int
@ -47,28 +49,28 @@ class RLDataOffsets(NamedTuple):
# 4 bytes : n_agents at current step
# ? Bytes : the data : obs,reward,done,max_step,agent_id,masks,action
## Get the specs of the group
# Get the specs of the group
name, offset = mem.get_string(offset)
max_n_agents, offset = mem.get_int(offset)
is_continuous, offset = mem.get_bool(offset)
action_size, offset = mem.get_int(offset)
discrete_branches = None
if not is_continuous:
discrete_branches = ()
discrete_branches = () # type: ignore
for _ in range(action_size):
branch_size, offset = mem.get_int(offset)
discrete_branches += (branch_size,)
discrete_branches += (branch_size,) # type: ignore
n_obs, offset = mem.get_int(offset)
obs_shapes: List[Tuple[int, ...]] = []
for _ in range(n_obs):
shape = ()
shape = () # type: ignore
for _ in range(3):
s, offset = mem.get_int(offset)
if s != 0:
shape += (s,)
shape += (s,) # type: ignore
obs_shapes += [shape]
## Compute the offsets for decision steps
# Compute the offsets for decision steps
# n_agents
decision_n_agents_offset = offset
_, offset = mem.get_int(offset)
@ -90,7 +92,7 @@ class RLDataOffsets(NamedTuple):
mask_offset = offset
offset += max_n_agents * int(np.sum(discrete_branches))
## Compute the offsets for termination steps
# Compute the offsets for termination steps
# n_agents
termination_n_agents_offset = offset
_, offset = mem.get_int(offset)
@ -102,14 +104,14 @@ class RLDataOffsets(NamedTuple):
# rewards
termination_reward_offset = offset
offset += 4 * max_n_agents
#status
# status
termination_status_offset = offset
offset += max_n_agents
# agent id
termination_agent_id_offset = offset
offset += 4 * max_n_agents
## Compute the offsets for actions
# Compute the offsets for actions
act_offset = offset
offset += 4 * max_n_agents * action_size

View file

@ -1,15 +1,8 @@
import mmap
import struct
import numpy as np
import tempfile
import os
import time
import uuid
from datetime import datetime
import string
from enum import IntEnum
import random
from typing import Tuple, Optional, NamedTuple, List, Dict
from typing import Tuple, Dict
from mlagents_dots_envs.shared_memory.master_shared_mem import MasterSharedMem
from mlagents_dots_envs.shared_memory.data_shared_mem import DataSharedMem
@ -17,6 +10,7 @@ from mlagents_dots_envs.shared_memory.data_shared_mem import DataSharedMem
from mlagents_envs.exception import UnityCommunicationException
from mlagents_envs.base_env import DecisionSteps, TerminalSteps, BehaviorSpec
class SharedMemCom:
FILE_DEFAULT = "default"
MAX_TIMEOUT_IN_SECONDS = 30
@ -34,9 +28,10 @@ class SharedMemCom:
self._data_mem = DataSharedMem(
file_name + "_" * self._current_file_number,
create_file=True,
copy_from = None,
copy_from=None,
side_channel_buffer_size=4,
rl_data_buffer_size=0)
rl_data_buffer_size=0,
)
@property
def communicator_id(self):
@ -53,15 +48,16 @@ class SharedMemCom:
def write_side_channel_data(self, data: bytearray) -> None:
capacity = self._master_mem.side_channel_size
if len(data) >= capacity - 4: # need 4 bytes for an integer size
new_capacity = 2*len(data)+20
new_capacity = 2 * len(data) + 20
self._current_file_number += 1
self._master_mem.file_number = self._current_file_number
tmp = self._data_mem
self._data_mem = DataSharedMem(self._base_file_name + "_" * self._current_file_number,
self._data_mem = DataSharedMem(
self._base_file_name + "_" * self._current_file_number,
create_file=True,
copy_from=tmp,
side_channel_buffer_size=new_capacity,
rl_data_buffer_size=self._master_mem.rl_data_size
rl_data_buffer_size=self._master_mem.rl_data_size,
)
tmp.close()
# Unity is responsible for destroying the old file
@ -73,7 +69,7 @@ class SharedMemCom:
self._data_mem.side_channel_data = bytearray()
return result
def give_unity_control(self, reset:bool = False):
def give_unity_control(self, reset: bool = False) -> None:
self._master_mem.mark_python_blocked()
if reset:
self._master_mem.mark_reset()
@ -104,15 +100,19 @@ class SharedMemCom:
self._data_mem = DataSharedMem(
self._base_file_name + "_" * self._current_file_number,
side_channel_buffer_size=self._master_mem.side_channel_size,
rl_data_buffer_size=self._master_mem.rl_data_size)
rl_data_buffer_size=self._master_mem.rl_data_size,
)
def get_steps(self, key: str) -> Tuple[DecisionSteps, TerminalSteps]:
return self._data_mem.get_decision_steps(key), self._data_mem.get_terminal_steps(key)
return (
self._data_mem.get_decision_steps(key),
self._data_mem.get_terminal_steps(key),
)
def get_n_decisions_requested(self, key: str) -> int:
return self._data_mem.get_n_decisions_requested(key)
def set_actions(self, key: str, data:np.ndarray) -> None:
def set_actions(self, key: str, data: np.ndarray) -> None:
self._data_mem.set_actions(key, data)
@property

View file

@ -1,11 +1,7 @@
import atexit
import logging
import numpy as np
import uuid
import signal
import struct
import subprocess
from typing import Dict, List, Optional, Any, Tuple
from typing import List, Optional, Tuple
from mlagents_envs.side_channel.side_channel import SideChannel
@ -15,15 +11,9 @@ from mlagents_envs.base_env import (
TerminalSteps,
BehaviorSpec,
BehaviorName,
ActionType,
)
from mlagents_envs.timers import timed, hierarchical_timer
from mlagents_envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
UnityActionException,
UnityTimeOutException,
)
from mlagents_envs.timers import timed
from mlagents_envs.exception import UnityCommunicationException, UnityActionException
from mlagents_dots_envs.shared_memory.shared_mem_com import SharedMemCom
@ -33,40 +23,41 @@ from mlagents_dots_envs.env_utils import (
generate_side_channel_data,
parse_side_channel_message,
returncode_to_signal_name,
validate_environment_path as ttt
validate_environment_path as ttt,
)
from mlagents_envs.logging_util import get_logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mlagents.envs")
logger = get_logger(__name__)
class UnityEnvironment(BaseEnv):
API_VERSION = "API-14" # TODO : REMOVE
DEFAULT_EDITOR_PORT = 5004 # TODO : REMOVE
BASE_ENVIRONMENT_PORT = 5005 # TODO : REMOVE
PORT_COMMAND_LINE_ARG = "--mlagents-port" # TODO : REMOVE
API_VERSION = "API-14" # TODO : REMOVE
DEFAULT_EDITOR_PORT = 5004 # TODO : REMOVE
BASE_ENVIRONMENT_PORT = 5005 # TODO : REMOVE
PORT_COMMAND_LINE_ARG = "--mlagents-port" # TODO : REMOVE
@staticmethod
def validate_environment_path( path__):
def validate_environment_path(path__):
return ttt(path__)
def __init__(
self,
worker_id=None, # TODO : REMOVE
seed=None, # TODO : REMOVE
docker_training=None, # TODO : REMOVE
no_graphics=None, # TODO : REMOVE
base_port=None, # TODO : REMOVE
worker_id: Optional[int] = None, # TODO : REMOVE
seed: Optional[int] = None, # TODO : REMOVE
docker_trainin: Optional[bool] = None, # TODO : REMOVE
no_graphics: Optional[bool] = None, # TODO : REMOVE
base_port: Optional[int] = None, # TODO : REMOVE
file_name: Optional[str] = None,
args: Optional[List[str]] = None,
side_channels: Optional[List[SideChannel]] = None,
):
"""
Starts a new unity environment and establishes a connection with the environment.
Starts a new unity environment and establishes a connection with it.
:string file_name: Name of Unity environment binary. If None, will try to connect to the Editor.
:string file_name: Name of Unity environment binary. If None, will try to
connect to the Editor.
:list args: Addition Unity command line arguments
:list side_channels: Additional side channel for not-rl communication with Unity
"""
@ -93,15 +84,14 @@ class UnityEnvironment(BaseEnv):
self._communicator.give_unity_control()
self._communicator.wait_for_unity()
def reset(self) -> None:
self._step(reset = True)
self._step(reset=True)
@timed
def step(self) -> None:
self._step(reset = False)
self._step(reset=False)
def _step(self, reset:bool= False) -> None:
def _step(self, reset: bool = False) -> None:
if not self._communicator.active:
raise UnityCommunicationException("Communicator has stopped.")
channel_data = generate_side_channel_data(self._side_channels)
@ -122,8 +112,8 @@ class UnityEnvironment(BaseEnv):
def _assert_behavior_exists(self, behavior_name: BehaviorName) -> None:
if behavior_name not in self._env_specs:
raise UnityActionException(
"The behavior {0} does not correspond to one existing "
"in the environment".format(behavior_name)
f"The behavior {behavior_name} does not correspond to one existing "
f"in the environment"
)
def set_actions(self, behavior_name: BehaviorName, action: np.array) -> None:
@ -131,16 +121,15 @@ class UnityEnvironment(BaseEnv):
expected_n_agents = self._communicator.get_n_decisions_requested(behavior_name)
if expected_n_agents == 0 and len(action) != 0:
raise UnityActionException(
"The behavior {0} does not need an input this step".format(behavior_name)
f"The behavior {behavior_name} does not need an input this step"
)
spec = self._env_specs[behavior_name]
expected_type = np.float32 if spec.is_action_continuous() else np.int32
expected_shape = (expected_n_agents, spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
"The behavior {0} needs an input of dimension {1} but received input of dimension {2}".format(
behavior_name, expected_shape, action.shape
)
f"The behavior {behavior_name} needs an input of dimension"
f"{expected_shape} but received input of dimension {action.shape}"
)
if action.dtype != expected_type:
action = action.astype(expected_type)
@ -151,17 +140,19 @@ class UnityEnvironment(BaseEnv):
) -> None:
raise NotImplementedError("Method not implemented.")
def get_steps(self, behavior_name: BehaviorName) -> Tuple[DecisionSteps, TerminalSteps]:
def get_steps(
self, behavior_name: BehaviorName
) -> Tuple[DecisionSteps, TerminalSteps]:
self._assert_behavior_exists(behavior_name)
return self._communicator.get_steps(behavior_name)
def get_behavior_spec(self, behavior_name:BehaviorName) -> BehaviorSpec:
def get_behavior_spec(self, behavior_name: BehaviorName) -> BehaviorSpec:
self._assert_behavior_exists(behavior_name)
return self._env_specs[behavior_name]
def close(self):
"""
Sends a shutdown signal to the unity environment, and closes the socket connection.
Sends a shutdown signal to the unity environment, and closes the communication.
"""
self._communicator.close()
if self.proc1 is not None:
@ -170,11 +161,13 @@ class UnityEnvironment(BaseEnv):
self.proc1.wait(timeout=5)
signal_name = returncode_to_signal_name(self.proc1.returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return_info = f"Environment shut down with return code {self.proc1.returncode}{signal_name}."
return_info = (
f"Environment shut down with return"
f"code{self.proc1.returncode}{signal_name}."
)
logger.info(return_info)
except subprocess.TimeoutExpired:
logger.info("Environment timed out shutting down. Killing...")
self.proc1.kill()
# Set to None so we don't try to close multiple times.
self.proc1 = None

View file

@ -1,8 +1,4 @@
import os
import sys
from setuptools import setup, find_packages
from setuptools.command.install import install
import mlagents_dots_envs
setup(
name="mlagents_dots_envs",
@ -17,8 +13,10 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
],
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests", "examples"]),
packages=find_packages(
exclude=["*.tests", "*.tests.*", "tests.*", "tests", "examples"]
),
zip_safe=False,
install_requires=[ "mlagents_envs>=0.14.1"],
install_requires=["mlagents_envs>=0.14.1"],
python_requires=">=3.6",
)