Modified some formatting and added comments

Parent: d891643238
Commit: c19ffc2a89
```diff
@@ -10,3 +10,11 @@
 *obj
 .pytest_cache
 *__pycache__
+
+# precommit
+.mypy_cache
+.pre-commit-config.yaml
+.pylintrc
+.vscode
+setup.cfg
+setup.cfg.meta
```
@@ -36,7 +36,7 @@ This package is available as a preview, so it is not ready for production use. T

```python
from mlagents_envs.environment import UnityEnvironment
```

with

```python
from mlagents_dots_envs.unity_environment import UnityEnvironment
```
@@ -45,7 +45,7 @@ This package is available as a preview, so it is not ready for production use. T

```python
from mlagents_envs.environment import UnityEnvironment
```

with

```python
from mlagents_dots_envs.unity_environment import UnityEnvironment
```
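For orientation, here is a minimal, hypothetical usage sketch of the replacement import. Connecting with no `file_name` targets the Editor, and the action size used below is an assumption for the example and must match the behavior of the actual scene.

```python
import numpy as np
from mlagents_dots_envs.unity_environment import UnityEnvironment

# Connect to the Unity Editor (pass file_name="..." to launch a built binary instead).
env = UnityEnvironment()
env.reset()
for name in env.get_behavior_names():
    decision_steps, terminal_steps = env.get_steps(name)
    # The action size (2) is an assumption for this sketch.
    env.set_actions(name, np.zeros((len(decision_steps), 2)))
env.step()
env.close()
```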
@@ -62,7 +62,7 @@ This package is available as a preview, so it is not ready for production use. T

## API

One approach to designing ml-agents to be compatible with DOTS would be to use a typical API like the one used in [Unity.Physics](https://github.com/Unity-Technologies/Unity.Physics), where an "MLAgents World" holds data, processes it, and the data can then be retrieved.
The user would access the `MLAgentsWorld` in the main thread:

@@ -71,9 +71,9 @@ var world = new MLAgentsWorld(

```csharp
var world = new MLAgentsWorld(
    new int3[] { new int3(3, 0, 0) }, // The observation shapes (here, one observation of shape (3,0,0))
    ActionType.CONTINUOUS,            // Continuous = float, Discrete = int
    3);                               // The number of actions

world.SubscribeWorldWithBarracudaModel(Name, Model, InferenceDevice);
```

The user could then add data to and retrieve data from the world in their own jobs. Here is an example of a job in which the user populates the sensor data:
@@ -117,7 +117,7 @@ world.RequestDecision(entities[i])

```csharp
    .SetObservationFromSlice(1, visObs.Slice());
```

In order to retrieve actions, we use a custom job:

```csharp
public struct UserCreatedActionEventJob : IActuatorJob
```
@@ -162,7 +162,7 @@ __Note__ : The python code for communication is located in [ml-agents-envs~](./m

#### RL Data section

- int : 4 bytes : The number of Agent groups in the simulation (starts at 0)
- For each group :
  - string : 64 bytes : group name
  - int : 4 bytes : maximum number of Agents
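To make the byte layout concrete, here is a hypothetical sketch of how the start of the RL data section could be parsed with `struct`. The field order follows the list above, while the helper name and the null-padded ASCII decoding of the 64-byte name are assumptions made for the example.

```python
import struct

def read_rl_data_header(buf: memoryview):
    # Number of Agent groups in the simulation: int, 4 bytes.
    offset = 0
    (n_groups,) = struct.unpack_from("<i", buf, offset)
    offset += 4
    groups = []
    for _ in range(n_groups):
        # Group name: 64 bytes (assumed null-padded ASCII for this sketch).
        raw_name = bytes(buf[offset:offset + 64])
        name = raw_name.split(b"\x00")[0].decode("ascii")
        offset += 64
        # Maximum number of Agents: int, 4 bytes.
        (max_agents,) = struct.unpack_from("<i", buf, offset)
        offset += 4
        groups.append((name, max_agents))
    return groups, offset
```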
@@ -1,5 +1,5 @@

Unity.AI.MLAgents copyright © 2020 Unity Technologies ApS

Licensed under the Unity Companion License for Unity-dependent projects--see [Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).

Unless expressly provided otherwise, the Software under this license is made available strictly on an “AS IS” BASIS WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. Please review the license for details on these and other terms and conditions.
README.md (14 changed lines)

@@ -1,7 +1,7 @@

# ML-Agents DOTS

[ML-Agents on DOTS Proposal Google Doc](https://docs.google.com/document/d/1QnGSjOfLpwaRopbMf9ZDC89oZJuG0Ii6ORA22a5TWzE/edit#heading=h.py1zfmz3396x)

[Documentation](./Documentation~/README.md)
@@ -1,80 +1,80 @@

```csharp
using Unity.Entities;
using Unity.Core;

namespace Unity.AI.MLAgents
{
    public static class TimeUtils
    {
        /// <summary>
        /// Configure the given ComponentSystemGroup to update at a fixed timestep, given by timeStep.
        /// If the interval between the current time and the last update is bigger than the timestep
        /// multiplied by the time scale,
        /// the group's systems will be updated more than once.
        /// </summary>
        /// <param name="group">The group whose UpdateCallback will be configured with a fixed time step update call</param>
        /// <param name="timeStep">The fixed time step (in seconds)</param>
        /// <param name="timeScale">How much time passes in the group compared to other systems</param>
        public static void EnableFixedRateWithCatchUp(ComponentSystemGroup group, float timeStep, float timeScale)
        {
            var manager = new FixedRateTimeScaleCatchUpManager(timeStep, timeScale);
            group.UpdateCallback = manager.UpdateCallback;
        }

        /// <summary>
        /// Disable fixed rate updates on the given group, by setting the UpdateCallback to null.
        /// </summary>
        /// <param name="group">The group whose UpdateCallback to set to null.</param>
        public static void DisableFixedRate(ComponentSystemGroup group)
        {
            group.UpdateCallback = null;
        }
    }

    internal class FixedRateTimeScaleCatchUpManager
    {
        protected float m_TimeScale;
        protected float m_FixedTimeStep;
        protected double m_LastFixedUpdateTime;
        protected int m_FixedUpdateCount;
        protected bool m_DidPushTime;

        internal FixedRateTimeScaleCatchUpManager(float fixedStep, float timeScale)
        {
            m_FixedTimeStep = fixedStep;
            m_TimeScale = timeScale;
        }

        internal bool UpdateCallback(ComponentSystemGroup group)
        {
            // if this is true, means we're being called a second or later time in a loop
            if (m_DidPushTime)
            {
                group.World.PopTime();
            }

            var elapsedTime = group.World.Time.ElapsedTime * m_TimeScale;
            if (m_LastFixedUpdateTime == 0.0)
                m_LastFixedUpdateTime = elapsedTime - m_FixedTimeStep;

            if (elapsedTime - m_LastFixedUpdateTime >= m_FixedTimeStep)
            {
                // Note that m_FixedTimeStep of 0.0f will never update
                m_LastFixedUpdateTime += m_FixedTimeStep;
                m_FixedUpdateCount++;
            }
            else
            {
                m_DidPushTime = false;
                return false;
            }

            group.World.PushTime(new TimeData(
                elapsedTime: m_LastFixedUpdateTime,
                deltaTime: m_FixedTimeStep));

            m_DidPushTime = true;
            return true;
        }
    }
}
```
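The catch-up behavior documented above can be illustrated outside Unity. The following standalone Python sketch reproduces the same condition (scaled elapsed time at least one fixed step ahead of the last fixed update); the frame timestamps, step, and scale values are made up for the example.

```python
def should_update(elapsed_time, last_fixed_update_time, fixed_time_step, time_scale):
    # Same test as UpdateCallback: compare scaled elapsed time against the last fixed update.
    return elapsed_time * time_scale - last_fixed_update_time >= fixed_time_step

last = 0.0
step = 0.02   # 50 Hz fixed step
scale = 2.0   # the group runs twice as fast as the rest of the world
for frame_time in [0.0, 0.016, 0.033, 0.050]:  # made-up frame timestamps (seconds)
    updates = 0
    while should_update(frame_time, last, step, scale):
        last += step  # advance the fixed clock, possibly several times per frame
        updates += 1
    print(f"t={frame_time:.3f}s -> {updates} fixed update(s)")
```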
@@ -1,78 +1,78 @@

```csharp
using Unity.Entities;
using Unity.Collections;
using Unity.Mathematics;
using Unity.Transforms;
using UnityEngine;
using Random = UnityEngine.Random;
using Unity.AI.MLAgents;

public class BalanceBallManager : MonoBehaviour
{
    public MLAgentsWorldSpecs MyWorldSpecs;

    private EntityManager manager;
    public GameObject prefabPlatform;
    public GameObject prefabBall;
    private Entity _prefabEntityPlatform;
    private Entity _prefabEntityBall;
    int currentIndex;

    void Awake()
    {
        var world = MyWorldSpecs.GetWorld();
        var ballSystem = World.DefaultGameObjectInjectionWorld.GetOrCreateSystem<BallSystem>();
        ballSystem.Enabled = true;
        ballSystem.BallWorld = world;

        manager = World.DefaultGameObjectInjectionWorld.EntityManager;

        BlobAssetStore blob = new BlobAssetStore();
        GameObjectConversionSettings settings = GameObjectConversionSettings.FromWorld(World.DefaultGameObjectInjectionWorld, blob);
        _prefabEntityPlatform = GameObjectConversionUtility.ConvertGameObjectHierarchy(prefabPlatform, settings);
        _prefabEntityBall = GameObjectConversionUtility.ConvertGameObjectHierarchy(prefabBall, settings);

        Spawn(1000);
        blob.Dispose();
    }

    void Spawn(int amount)
    {
        NativeArray<Entity> entitiesP = new NativeArray<Entity>(amount, Allocator.Temp);
        NativeArray<Entity> entitiesB = new NativeArray<Entity>(amount, Allocator.Temp);
        manager.Instantiate(_prefabEntityPlatform, entitiesP);
        manager.Instantiate(_prefabEntityBall, entitiesB);
        for (int i = 0; i < amount; i++)
        {
            float3 position = new float3((currentIndex % 10) - 5, (currentIndex / 10 % 10) - 5, currentIndex / 100) * 5f;
            float valX = Random.Range(-0.1f, 0.1f);
            float valZ = Random.Range(-0.1f, 0.1f);
            manager.SetComponentData(entitiesP[i],
                new Translation
                {
                    Value = position
                });
            manager.SetComponentData(entitiesB[i],
                new Translation
                {
                    Value = position + new float3(0, 0.2f, 0)
                });

            manager.SetComponentData(entitiesP[i],
                new Rotation
                {
                    Value = quaternion.EulerXYZ(valX, 0, valZ)
                });
            manager.AddComponent<AgentData>(entitiesP[i]);
            manager.SetComponentData(entitiesP[i], new AgentData {
                BallResetPosition = position + new float3(0, 0.2f, 0),
                BallRef = entitiesB[i],
                StepCount = 0
            });
            manager.AddComponent<Actuator>(entitiesP[i]);
            currentIndex++;
        }

        entitiesP.Dispose();
        entitiesB.Dispose();
    }
}
```
@@ -1,127 +1,127 @@

```csharp
using Unity.Entities;
using Unity.Jobs;
using Unity.Mathematics;
using Unity.Transforms;
using Unity.AI.MLAgents;
using Unity.Collections;
using Unity.Physics;

public struct AgentData : IComponentData
{
    public float3 BallResetPosition;
    public Entity BallRef;
    public int StepCount;
}

public struct Actuator : IComponentData
{
    public float2 Value;
}

public class BallSystem : JobComponentSystem
{
    private const int maxStep = 5000;

    private struct RotateJob : IActuatorJob
    {
        public ComponentDataFromEntity<Actuator> ComponentDataFromEntity;
        public void Execute(ActuatorEvent ev)
        {
            var a = ev.GetAction<Actuator>();
            ComponentDataFromEntity[ev.Entity] = a;
        }
    }

    public MLAgentsWorld BallWorld;

    // Update is called once per frame
    protected override JobHandle OnUpdate(JobHandle inputDeps)
    {
        if (!BallWorld.IsCreated)
        {
            return inputDeps;
        }
        var world = BallWorld;

        ComponentDataFromEntity<Translation> TranslationFromEntity = GetComponentDataFromEntity<Translation>(isReadOnly: false);
        ComponentDataFromEntity<PhysicsVelocity> VelFromEntity = GetComponentDataFromEntity<PhysicsVelocity>(isReadOnly: false);
        inputDeps = Entities
            .WithNativeDisableParallelForRestriction(TranslationFromEntity)
            .WithNativeDisableParallelForRestriction(VelFromEntity)
            .ForEach((Entity entity, ref Rotation rot, ref AgentData agentData) =>
            {
                var ballPos = TranslationFromEntity[agentData.BallRef].Value;
                var ballVel = VelFromEntity[agentData.BallRef].Linear;
                var platformVel = VelFromEntity[entity];
                bool taskFailed = false;
                bool interruption = false;
                if (ballPos.y - agentData.BallResetPosition.y < -0.7f)
                {
                    taskFailed = true;
                    agentData.StepCount = 0;
                }
                if (agentData.StepCount > maxStep)
                {
                    interruption = true;
                    agentData.StepCount = 0;
                }
                if (!interruption && !taskFailed)
                {
                    world.RequestDecision(entity)
                        .SetObservation(0, rot.Value)
                        .SetObservation(1, ballPos - agentData.BallResetPosition)
                        .SetObservation(2, ballVel)
                        .SetObservation(3, platformVel.Angular)
                        .SetReward(0.1f);
                }
                if (taskFailed)
                {
                    world.EndEpisode(entity)
                        .SetObservation(0, rot.Value)
                        .SetObservation(1, ballPos - agentData.BallResetPosition)
                        .SetObservation(2, ballVel)
                        .SetObservation(3, platformVel.Angular)
                        .SetReward(-1f);
                }
                else if (interruption)
                {
                    world.InterruptEpisode(entity)
                        .SetObservation(0, rot.Value)
                        .SetObservation(1, ballPos - agentData.BallResetPosition)
                        .SetObservation(2, ballVel)
                        .SetObservation(3, platformVel.Angular)
                        .SetReward(0.1f);
                }
                if (interruption || taskFailed)
                {
                    VelFromEntity[agentData.BallRef] = new PhysicsVelocity();
                    TranslationFromEntity[agentData.BallRef] = new Translation { Value = agentData.BallResetPosition };
                    rot.Value = quaternion.identity;
                }
                agentData.StepCount++;
            }).Schedule(inputDeps);

        var reactiveJob = new RotateJob
        {
            ComponentDataFromEntity = GetComponentDataFromEntity<Actuator>(isReadOnly: false)
        };
        inputDeps = reactiveJob.Schedule(world, inputDeps);

        inputDeps = Entities.ForEach((Actuator act, ref Rotation rotation) =>
        {
            var rot = math.mul(rotation.Value, quaternion.Euler(0.05f * new float3(act.Value.x, 0, act.Value.y)));
            rotation.Value = rot;
        }).Schedule(inputDeps);

        return inputDeps;
    }

    protected override void OnDestroy()
    {
        BallWorld.Dispose();
    }
}
```
@@ -1,106 +1,106 @@

```csharp
using Unity.Collections;
using Unity.Entities;
using Unity.Mathematics;
using Unity.Jobs;
using UnityEngine;
using Unity.AI.MLAgents;

[DisableAutoCreation]
public class SimpleSystem : JobComponentSystem
{
    private MLAgentsWorld world;
    private NativeArray<Entity> entities;
    private Camera camera;

    public const int N_Agents = 5;
    int counter;

    // Start is called before the first frame update
    protected override void OnCreate()
    {
        Application.targetFrameRate = -1;

        world = new MLAgentsWorld(N_Agents, new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) }, ActionType.DISCRETE, 2, new int[] { 2, 3 });
        world.RegisterWorldWithHeuristic("test", () => new int2(1, 1));
        entities = new NativeArray<Entity>(N_Agents, Allocator.Persistent);

        for (int i = 0; i < N_Agents; i++)
        {
            entities[i] = World.DefaultGameObjectInjectionWorld.EntityManager.CreateEntity();
        }
    }

    protected override void OnDestroy()
    {
        world.Dispose();
        entities.Dispose();
    }

    // Update is called once per frame
    protected override JobHandle OnUpdate(JobHandle inputDeps)
    {
        if (camera == null)
        {
            camera = Camera.main;
            camera = GameObject.FindObjectOfType<Camera>();
        }
        inputDeps.Complete();
        var reactiveJob = new UserCreatedActionEventJob
        {
            myNumber = 666
        };
        inputDeps = reactiveJob.Schedule(world, inputDeps);
        if (counter % 5 == 0)
        {
            var visObs = VisualObservationUtility.GetVisObs(camera, 84, 84, Allocator.TempJob);
            var senseJob = new UserCreateSensingJob
            {
                cameraObservation = visObs,
                entities = entities,
                world = world
            };
            inputDeps = senseJob.Schedule(N_Agents, 64, inputDeps);

            inputDeps.Complete();
            visObs.Dispose();
        }
        counter++;

        return inputDeps;
    }

    public struct UserCreateSensingJob : IJobParallelFor
    {
        [ReadOnly] public NativeArray<float> cameraObservation;
        public NativeArray<Entity> entities;
        public MLAgentsWorld world;

        public void Execute(int i)
        {
            world.RequestDecision(entities[i])
                .SetReward(1.0f)
                .SetObservation(0, new float3(entities[i].Index, 0, 0))
                .SetObservationFromSlice(1, cameraObservation.Slice());
        }
    }

    public struct UserCreatedActionEventJob : IActuatorJob
    {
        public int myNumber;
        public void Execute(ActuatorEvent data)
        {
            var tmp = data.GetAction<testAction>();
            Debug.Log(data.Entity.Index + " " + tmp.e1);
        }
    }

    public enum testEnum
    {
        A, B, C
    }
    public struct testAction
    {
        public testEnum e1;
        public testEnum e2;
    }
}
```
@@ -7,66 +7,67 @@ import subprocess

```python
from sys import platform
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.side_channel.side_channel import SideChannel
from typing import Dict, List, Optional
from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)


def validate_environment_path(env_path: str) -> Optional[str]:
    # Strip out executable extensions if passed
    env_path = (
        env_path.strip()
        .replace(".app", "")
        .replace(".exe", "")
        .replace(".x86_64", "")
        .replace(".x86", "")
    )
    true_filename = os.path.basename(os.path.normpath(env_path))
    logger.debug("The true file name is {}".format(true_filename))

    if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
        return None

    cwd = os.getcwd()
    launch_string = None
    true_filename = os.path.basename(os.path.normpath(env_path))
    if platform == "linux" or platform == "linux2":
        candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64")
        if len(candidates) == 0:
            candidates = glob.glob(os.path.join(cwd, env_path) + ".x86")
        if len(candidates) == 0:
            candidates = glob.glob(env_path + ".x86_64")
        if len(candidates) == 0:
            candidates = glob.glob(env_path + ".x86")
        if len(candidates) > 0:
            launch_string = candidates[0]

    elif platform == "darwin":
        candidates = glob.glob(
            os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename)
        )
        if len(candidates) == 0:
            candidates = glob.glob(
                os.path.join(env_path + ".app", "Contents", "MacOS", true_filename)
            )
        if len(candidates) == 0:
            candidates = glob.glob(
                os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*")
            )
        if len(candidates) == 0:
            candidates = glob.glob(
                os.path.join(env_path + ".app", "Contents", "MacOS", "*")
            )
        if len(candidates) > 0:
            launch_string = candidates[0]
    elif platform == "win32":
        candidates = glob.glob(os.path.join(cwd, env_path + ".exe"))
        if len(candidates) == 0:
            candidates = glob.glob(env_path + ".exe")
        if len(candidates) > 0:
            launch_string = candidates[0]
    return launch_string
    # TODO: END REMOVE


def get_side_channels(
```
@@ -77,9 +78,8 @@ def get_side_channels(

```python
    for _sc in side_c:
        if _sc.channel_id in side_channels_dict:
            raise UnityEnvironmentException(
                f"There cannot be two side channels with "
                f"the same channel id {_sc.channel_id}."
            )
        side_channels_dict[_sc.channel_id] = _sc
    return side_channels_dict
```
@@ -110,9 +110,7 @@ def executable_launcher(exec_name, memory_path, args):

```python
    elif platform == "darwin":
        candidates = glob.glob(
            os.path.join(cwd, exec_name + ".app", "Contents", "MacOS", true_filename)
        )
        if len(candidates) == 0:
            candidates = glob.glob(
```
@@ -137,9 +135,7 @@ def executable_launcher(exec_name, memory_path, args):

```python
    if launch_string is None:
        raise UnityEnvironmentException(
            "Couldn't launch the {0} environment. "
            "Provided filename does not match any environments.".format(true_filename)
        )
    else:
        logger.debug("This is the launch string {}".format(launch_string))
```
@@ -151,9 +147,10 @@ def executable_launcher(exec_name, memory_path, args):

```python
    return subprocess.Popen(
        subprocess_args,
        # start_new_session=True means that signals to the parent python process
        # (e.g. SIGINT from keyboard interrupt) will not be sent to the new
        # process on POSIX platforms. This is generally good since we want the
        # environment to have a chance to shutdown, but may be undesirable in
        # some cases; if so, we'll add a command-line toggle.
        # Note that on Windows, the CTRL_C signal will still be sent.
        start_new_session=True,
    )
```
@@ -198,6 +195,7 @@ def parse_side_channel_message(

```python
            ": {0}.".format(channel_id)
        )


def generate_side_channel_data(
    side_channels: Dict[uuid.UUID, SideChannel]
) -> bytearray:
```
@@ -217,7 +215,8 @@ def returncode_to_signal_name(returncode: int) -> Optional[str]:

```python
    E.g. returncode_to_signal_name(-2) -> "SIGINT"
    """
    try:
        # A negative value -N indicates that the child was terminated by
        # signal N (POSIX only).
        s = signal.Signals(-returncode)  # pylint: disable=no-member
        return s.name
    except Exception:
```
@@ -5,9 +5,7 @@ from mlagents_dots_envs.unity_environment import UnityEnvironment

```python
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfigurationChannel,
)
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel

sc = EngineConfigurationChannel()
sc2 = FloatPropertiesChannel()
```
@@ -15,31 +13,42 @@ env = UnityEnvironment(side_channels=[sc, sc2])

```python
sc.set_configuration_parameters()


for _ in range(10):
    s = ""
    for name in env.get_behavior_names():
        s += (
            name
            + " : "
            + str(len(env.get_steps(name)[0]))
            + " : "
            + str(len(env.get_steps(name)[1]))
            + "|"
        )
        env.set_actions(name, np.ones((len(env.get_steps(name)[0]), 2)))
    print(s)
    env.step()

print(env.get_steps("Ball_DOTS")[0].obs[0])

print("RESET")
env.reset()

# import time
# time.sleep(50)

sc2.set_property("test", 2)
sc2.set_property("test2", 2)
sc2.set_property("test3", 2)

for _ in range(100):
    s = ""
    for name in env.get_behavior_names():
        s += (
            name
            + " : "
            + str(len(env.get_steps(name)[0]))
            + " : "
            + str(len(env.get_steps(name)[1]))
            + "|"
        )
    print(s)
    env.step()
env.close()
```
@@ -1,16 +1,17 @@

```python
from abc import ABC
import os
import tempfile
import mmap
import numpy as np
import struct
import uuid
from typing import Tuple


class BasedSharedMem(ABC):
    DIRECTORY = "ml-agents"

    def __init__(self, file_name: str, create_file: bool = False, size: int = 0):
        directory = os.path.join(tempfile.gettempdir(), self.DIRECTORY)
        if not os.path.exists(directory):
            os.makedirs(directory)
```
@@ -26,16 +27,16 @@ class BasedSharedMem(ABC):

```python
        self.accessor = mmap.mmap(f.fileno(), 0)
        self._file_path = file_path

    def get_int(self, offset: int) -> Tuple[int, int]:
        return struct.unpack_from("<i", self.accessor, offset)[0], offset + 4

    def get_float(self, offset: int) -> Tuple[float, int]:
        return struct.unpack_from("<f", self.accessor, offset)[0], offset + 4

    def get_bool(self, offset: int) -> Tuple[bool, int]:
        return struct.unpack_from("<?", self.accessor, offset)[0], offset + 1

    def get_string(self, offset: int) -> Tuple[str, int]:
        string_len = struct.unpack_from("<B", self.accessor, offset)[0]
        byte_array = bytes(self.accessor[offset + 1 : offset + string_len + 1])
        result = byte_array.decode("ascii")
```
@@ -50,44 +51,43 @@ class BasedSharedMem(ABC):

```python
        ).reshape(shape)

    def get_uuid(self, offset: int) -> Tuple[uuid.UUID, int]:
        return uuid.UUID(bytes_le=self.accessor[offset : offset + 16]), offset + 16

    def set_int(self, offset: int, value: int) -> int:
        struct.pack_into("<i", self.accessor, offset, value)
        return offset + 4

    def set_float(self, offset: int, value: float) -> int:
        struct.pack_into("<f", self.accessor, offset, value)
        return offset + 4

    def set_bool(self, offset: int, value: bool) -> int:
        struct.pack_into("<?", self.accessor, offset, value)
        return offset + 1

    def set_string(self, offset: int, value: str) -> int:
        string_len = len(value)
        struct.pack_into("<B", self.accessor, offset, string_len)
        offset += 1
        self.accessor[offset : offset + string_len] = value.encode("ascii")
        return offset + string_len

    def set_uuid(self, offset: int, value: uuid.UUID) -> int:
        self.accessor[offset : offset + 16] = value.bytes_le
        return offset + 16

    def set_ndarray(self, offset: int, data: np.ndarray) -> None:
        bytes_data = data.tobytes()
        self.accessor[offset : offset + len(bytes_data)] = bytes_data

    def close(self) -> None:
        if self.accessor is not None:
            self.accessor.close()
            self.accessor = None  # type: ignore

    def delete(self) -> None:
        self.close()
        try:
            os.remove(self._file_path)
        except BaseException:
            pass
```
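As a side note, the accessor methods above all reduce to fixed-offset, little-endian reads and writes over a memory-mapped buffer. Here is a tiny self-contained illustration of that pattern, using a plain `bytearray` in place of the mmap; the offsets and values are made up.

```python
import struct

buf = bytearray(16)  # stands in for the mmap-backed accessor

# Write an int at offset 0 and a float at offset 4, little-endian as above.
struct.pack_into("<i", buf, 0, 42)
struct.pack_into("<f", buf, 4, 1.5)

# Read them back, advancing the offset the same way get_int / get_float do.
value, offset = struct.unpack_from("<i", buf, 0)[0], 0 + 4
fvalue, offset = struct.unpack_from("<f", buf, offset)[0], offset + 4
print(value, fvalue)  # 42 1.5
```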
@@ -1,20 +1,28 @@

```python
import numpy as np
from mlagents_dots_envs.shared_memory.base_shared_mem import BasedSharedMem
from mlagents_dots_envs.shared_memory.rl_data_offset import RLDataOffsets
from typing import Dict, List
from mlagents_envs.base_env import (
    DecisionSteps,
    TerminalSteps,
    BehaviorSpec,
    ActionType,
)


class DataSharedMem(BasedSharedMem):
    def __init__(
        self,
        file_name: str,
        create_file: bool = False,
        copy_from: "DataSharedMem" = None,
        side_channel_buffer_size: int = 0,
        rl_data_buffer_size: int = 0,
    ):
        self._offset_dict: Dict[str, RLDataOffsets] = {}
        if create_file and copy_from is None:
            size = side_channel_buffer_size + rl_data_buffer_size
            super(DataSharedMem, self).__init__(file_name, True, size)
            self._side_channel_buffer_size = side_channel_buffer_size
            self._rl_data_buffer_size = rl_data_buffer_size
            return
```
@@ -46,7 +54,7 @@ class DataSharedMem(BasedSharedMem):

```python
    def side_channel_data(self) -> bytearray:
        offset = 0
        len_data, offset = self.get_int(offset)
        return self.accessor[offset : offset + len_data]

    @side_channel_data.setter
    def side_channel_data(self, data: bytearray) -> None:
```
@@ -64,7 +72,7 @@ class DataSharedMem(BasedSharedMem):

```python
    def rl_data(self) -> bytearray:
        offset = self.rl_data_offset
        size = self._rl_data_buffer_size
        return self.accessor[offset : offset + size]

    @rl_data.setter
    def rl_data(self, data: bytearray) -> None:
```
@@ -72,7 +80,7 @@ class DataSharedMem(BasedSharedMem):

```python
            raise Exception("TODO")
        offset = self.rl_data_offset
        size = self._rl_data_buffer_size
        self.accessor[offset : offset + size] = data
        self._refresh_offsets()

    def get_decision_steps(self, key: str) -> DecisionSteps:
```
@@ -80,34 +88,48 @@ class DataSharedMem(BasedSharedMem):

```python
        offsets = self._offset_dict[key]
        n_agents, _ = self.get_int(offsets.decision_n_agents_offset)
        obs: List[np.ndarray] = []
        for obs_offset, obs_shape in zip(
            offsets.decision_obs_offset, offsets.obs_shapes
        ):
            obs_shape = (n_agents,) + obs_shape
            arr = self.get_ndarray(obs_offset, obs_shape, np.float32)
            obs.append(arr)
        return DecisionSteps(
            obs=obs,
            reward=self.get_ndarray(
                offsets.decision_rewards_offset, (n_agents), np.float32
            ),
            agent_id=self.get_ndarray(
                offsets.decision_agent_id_offset, (n_agents), np.int32
            ),
            action_mask=None,  # TODO
        )

    def get_terminal_steps(self, key: str) -> TerminalSteps:
        assert key in self._offset_dict
        offsets = self._offset_dict[key]
        n_agents, _ = self.get_int(offsets.termination_n_agents_offset)
        obs: List[np.ndarray] = []
        for obs_offset, obs_shape in zip(
            offsets.termination_obs_offset, offsets.obs_shapes
        ):
            obs_shape = (n_agents,) + obs_shape
            arr = self.get_ndarray(obs_offset, obs_shape, np.float32)
            obs.append(arr)
        return TerminalSteps(
            obs=obs,
            reward=self.get_ndarray(
                offsets.termination_reward_offset, (n_agents), np.float32
            ),
            agent_id=self.get_ndarray(
                offsets.termination_agent_id_offset, (n_agents), np.int32
            ),
            max_step=self.get_ndarray(
                offsets.termination_status_offset, (n_agents), np.bool
            ),
        )

    def set_actions(self, key: str, data: np.ndarray) -> None:
        assert key in self._offset_dict
        offsets = self._offset_dict[key]
        self.set_ndarray(offsets.action_offset, data)
```
@@ -1,6 +1,8 @@

```python
import os
import glob
from mlagents_dots_envs.shared_memory.base_shared_mem import BasedSharedMem


class MasterSharedMem(BasedSharedMem):
    """
    Always created by Python
```

@@ -14,11 +16,17 @@ class MasterSharedMem(BasedSharedMem):

```python
    - int : Communication file "side channel" size in bytes
    - int : Communication file "RL section" size in bytes
    """

    SIZE = 28
    VERSION = (0, 3, 0)

    def __init__(
        self, file_name: str, side_channel_size: int = 0, rl_data_size: int = 0
    ):

        super(MasterSharedMem, self).__init__(
            file_name, create_file=True, size=self.SIZE
        )
        for f in glob.glob(file_name + "_"):
            # Removing all the future files in case they were not correctly created
            os.remove(f)
```
@@ -76,24 +84,24 @@ class MasterSharedMem(BasedSharedMem):

```python
        self.set_bool(offset, True)

    @property
    def side_channel_size(self) -> int:
        offset = 20
        result, _ = self.get_int(offset)
        return result

    @side_channel_size.setter
    def side_channel_size(self, value: int) -> None:
        offset = 20
        self.set_int(offset, value)

    @property
    def rl_data_size(self) -> int:
        offset = 24
        result, _ = self.get_int(offset)
        return result

    @rl_data_size.setter
    def rl_data_size(self, value: int) -> None:
        offset = 24
        self.set_int(offset, value)
```
@@ -103,8 +111,8 @@ class MasterSharedMem(BasedSharedMem):

```python
        minor, offset = self.get_int(offset)
        bug, offset = self.get_int(offset)
        if (major, minor, bug) != self.VERSION:
            raise Exception(
                f"Incompatible versions of communicator between "
                + f"Unity {major}.{minor}.{bug} and Python "
                f"{self.VERSION[0]}.{self.VERSION[1]}.{self.VERSION[2]}"
            )
```
@@ -1,14 +1,16 @@

```python
import numpy as np

from typing import Tuple, Optional, NamedTuple, List
from mlagents_dots_envs.shared_memory.base_shared_mem import BasedSharedMem


class RLDataOffsets(NamedTuple):
    """
    Contains the offsets to the data for a section of the RL data
    """

    # data
    name: str
    max_n_agents: int
    is_action_continuous: bool
    action_size: int
```
@@ -47,28 +49,28 @@ class RLDataOffsets(NamedTuple):

```python
        # 4 bytes : n_agents at current step
        # ? Bytes : the data : obs,reward,done,max_step,agent_id,masks,action

        # Get the specs of the group
        name, offset = mem.get_string(offset)
        max_n_agents, offset = mem.get_int(offset)
        is_continuous, offset = mem.get_bool(offset)
        action_size, offset = mem.get_int(offset)
        discrete_branches = None
        if not is_continuous:
            discrete_branches = ()  # type: ignore
            for _ in range(action_size):
                branch_size, offset = mem.get_int(offset)
                discrete_branches += (branch_size,)  # type: ignore
        n_obs, offset = mem.get_int(offset)
        obs_shapes: List[Tuple[int, ...]] = []
        for _ in range(n_obs):
            shape = ()  # type: ignore
            for _ in range(3):
                s, offset = mem.get_int(offset)
                if s != 0:
                    shape += (s,)  # type: ignore
            obs_shapes += [shape]

        # Compute the offsets for decision steps
        # n_agents
        decision_n_agents_offset = offset
        _, offset = mem.get_int(offset)
```
@@ -90,7 +92,7 @@ class RLDataOffsets(NamedTuple):

```python
            mask_offset = offset
            offset += max_n_agents * int(np.sum(discrete_branches))

        # Compute the offsets for termination steps
        # n_agents
        termination_n_agents_offset = offset
        _, offset = mem.get_int(offset)
```
@@ -102,14 +104,14 @@ class RLDataOffsets(NamedTuple):

```python
        # rewards
        termination_reward_offset = offset
        offset += 4 * max_n_agents
        # status
        termination_status_offset = offset
        offset += max_n_agents
        # agent id
        termination_agent_id_offset = offset
        offset += 4 * max_n_agents

        # Compute the offsets for actions
        act_offset = offset
        offset += 4 * max_n_agents * action_size
```
@@ -1,15 +1,8 @@

```python
import mmap
import struct
import numpy as np
import tempfile
import os
import time
import uuid
from datetime import datetime
import string
from enum import IntEnum
import random
from typing import Tuple, Dict

from mlagents_dots_envs.shared_memory.master_shared_mem import MasterSharedMem
from mlagents_dots_envs.shared_memory.data_shared_mem import DataSharedMem
```
|
|||
from mlagents_envs.exception import UnityCommunicationException
|
||||
from mlagents_envs.base_env import DecisionSteps, TerminalSteps, BehaviorSpec
|
||||
|
||||
|
||||
class SharedMemCom:
|
||||
FILE_DEFAULT = "default"
|
||||
MAX_TIMEOUT_IN_SECONDS = 30
|
||||
|
@@ -34,9 +28,10 @@ class SharedMemCom:

```python
        self._data_mem = DataSharedMem(
            file_name + "_" * self._current_file_number,
            create_file=True,
            copy_from=None,
            side_channel_buffer_size=4,
            rl_data_buffer_size=0,
        )

    @property
    def communicator_id(self):
```
@@ -53,15 +48,16 @@ class SharedMemCom:

```python
    def write_side_channel_data(self, data: bytearray) -> None:
        capacity = self._master_mem.side_channel_size
        if len(data) >= capacity - 4:  # need 4 bytes for an integer size
            new_capacity = 2 * len(data) + 20
            self._current_file_number += 1
            self._master_mem.file_number = self._current_file_number
            tmp = self._data_mem
            self._data_mem = DataSharedMem(
                self._base_file_name + "_" * self._current_file_number,
                create_file=True,
                copy_from=tmp,
                side_channel_buffer_size=new_capacity,
                rl_data_buffer_size=self._master_mem.rl_data_size,
            )
            tmp.close()
            # Unity is responsible for destroying the old file
```
@@ -73,7 +69,7 @@ class SharedMemCom:

```python
        self._data_mem.side_channel_data = bytearray()
        return result

    def give_unity_control(self, reset: bool = False) -> None:
        self._master_mem.mark_python_blocked()
        if reset:
            self._master_mem.mark_reset()
```
@ -104,15 +100,19 @@ class SharedMemCom:
|
|||
self._data_mem = DataSharedMem(
|
||||
self._base_file_name + "_" * self._current_file_number,
|
||||
side_channel_buffer_size=self._master_mem.side_channel_size,
|
||||
rl_data_buffer_size=self._master_mem.rl_data_size)
|
||||
rl_data_buffer_size=self._master_mem.rl_data_size,
|
||||
)
|
||||
|
||||
def get_steps(self, key: str) -> Tuple[DecisionSteps, TerminalSteps]:
|
||||
return self._data_mem.get_decision_steps(key), self._data_mem.get_terminal_steps(key)
|
||||
return (
|
||||
self._data_mem.get_decision_steps(key),
|
||||
self._data_mem.get_terminal_steps(key),
|
||||
)
|
||||
|
||||
def get_n_decisions_requested(self, key: str) -> int:
|
||||
return self._data_mem.get_n_decisions_requested(key)
|
||||
|
||||
def set_actions(self, key: str, data:np.ndarray) -> None:
|
||||
def set_actions(self, key: str, data: np.ndarray) -> None:
|
||||
self._data_mem.set_actions(key, data)
|
||||
|
||||
@property
|
||||
|
|
|
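Taken together, these methods are the Python half of the blocking hand-off: write any pending actions and side-channel data, mark Python as blocked, wait for Unity to step, then read the new decision and terminal steps. The loop below is only a sketch of that sequence using the methods visible in this diff; `com` is assumed to be an already-constructed `SharedMemCom`, the behavior name and actions are placeholders, and in the package itself this orchestration is done by `UnityEnvironment`, not by user code.

```python
import numpy as np


def run_one_step(com, behavior_name: str, policy_actions: np.ndarray) -> None:
    # Hypothetical driver; the real sequencing lives in UnityEnvironment._step.
    n = com.get_n_decisions_requested(behavior_name)     # agents awaiting an action
    if n > 0:
        com.set_actions(behavior_name, policy_actions[:n])  # write into shared memory
    com.give_unity_control()   # flag Python as blocked (pass reset=True to reset)
    com.wait_for_unity()       # block until Unity finishes its step
    decision_steps, terminal_steps = com.get_steps(behavior_name)
```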
@@ -1,11 +1,7 @@
import atexit
import logging
import numpy as np
import uuid
import signal
import struct
import subprocess
from typing import Dict, List, Optional, Any, Tuple
from typing import List, Optional, Tuple

from mlagents_envs.side_channel.side_channel import SideChannel

@@ -15,15 +11,9 @@ from mlagents_envs.base_env import (
TerminalSteps,
BehaviorSpec,
BehaviorName,
ActionType,
)
from mlagents_envs.timers import timed, hierarchical_timer
from mlagents_envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
UnityActionException,
UnityTimeOutException,
)
from mlagents_envs.timers import timed
from mlagents_envs.exception import UnityCommunicationException, UnityActionException

from mlagents_dots_envs.shared_memory.shared_mem_com import SharedMemCom

@@ -33,40 +23,41 @@ from mlagents_dots_envs.env_utils import (
generate_side_channel_data,
parse_side_channel_message,
returncode_to_signal_name,
validate_environment_path as ttt
validate_environment_path as ttt,
)

from mlagents_envs.logging_util import get_logger

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mlagents.envs")

logger = get_logger(__name__)

class UnityEnvironment(BaseEnv):
API_VERSION = "API-14" # TODO : REMOVE
DEFAULT_EDITOR_PORT = 5004 # TODO : REMOVE
BASE_ENVIRONMENT_PORT = 5005 # TODO : REMOVE
PORT_COMMAND_LINE_ARG = "--mlagents-port" # TODO : REMOVE
API_VERSION = "API-14"  # TODO : REMOVE
DEFAULT_EDITOR_PORT = 5004  # TODO : REMOVE
BASE_ENVIRONMENT_PORT = 5005  # TODO : REMOVE
PORT_COMMAND_LINE_ARG = "--mlagents-port"  # TODO : REMOVE

@staticmethod
def validate_environment_path( path__):
def validate_environment_path(path__):
return ttt(path__)

def __init__(
self,
worker_id=None, # TODO : REMOVE
seed=None, # TODO : REMOVE
docker_training=None, # TODO : REMOVE
no_graphics=None, # TODO : REMOVE
base_port=None, # TODO : REMOVE
worker_id: Optional[int] = None,  # TODO : REMOVE
seed: Optional[int] = None,  # TODO : REMOVE
docker_training: Optional[bool] = None,  # TODO : REMOVE
no_graphics: Optional[bool] = None,  # TODO : REMOVE
base_port: Optional[int] = None,  # TODO : REMOVE
file_name: Optional[str] = None,
args: Optional[List[str]] = None,
side_channels: Optional[List[SideChannel]] = None,
):
"""
Starts a new unity environment and establishes a connection with the environment.
Starts a new unity environment and establishes a connection with it.

:string file_name: Name of Unity environment binary. If None, will try to connect to the Editor.
:string file_name: Name of Unity environment binary. If None, will try to
connect to the Editor.
:list args: Additional Unity command line arguments
:list side_channels: Additional side channels for non-RL communication with Unity
"""
@@ -93,15 +84,14 @@ class UnityEnvironment(BaseEnv):
self._communicator.give_unity_control()
self._communicator.wait_for_unity()

def reset(self) -> None:
self._step(reset = True)
self._step(reset=True)

@timed
def step(self) -> None:
self._step(reset = False)
self._step(reset=False)

def _step(self, reset:bool= False) -> None:
def _step(self, reset: bool = False) -> None:
if not self._communicator.active:
raise UnityCommunicationException("Communicator has stopped.")
channel_data = generate_side_channel_data(self._side_channels)
@@ -122,8 +112,8 @@ class UnityEnvironment(BaseEnv):
def _assert_behavior_exists(self, behavior_name: BehaviorName) -> None:
if behavior_name not in self._env_specs:
raise UnityActionException(
"The behavior {0} does not correspond to one existing "
"in the environment".format(behavior_name)
f"The behavior {behavior_name} does not correspond to one existing "
"in the environment"
)

def set_actions(self, behavior_name: BehaviorName, action: np.array) -> None:

@@ -131,16 +121,15 @@ class UnityEnvironment(BaseEnv):
expected_n_agents = self._communicator.get_n_decisions_requested(behavior_name)
if expected_n_agents == 0 and len(action) != 0:
raise UnityActionException(
"The behavior {0} does not need an input this step".format(behavior_name)
f"The behavior {behavior_name} does not need an input this step"
)
spec = self._env_specs[behavior_name]
expected_type = np.float32 if spec.is_action_continuous() else np.int32
expected_shape = (expected_n_agents, spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
"The behavior {0} needs an input of dimension {1} but received input of dimension {2}".format(
behavior_name, expected_shape, action.shape
)
f"The behavior {behavior_name} needs an input of dimension "
f"{expected_shape} but received input of dimension {action.shape}"
)
if action.dtype != expected_type:
action = action.astype(expected_type)
@@ -151,17 +140,19 @@ class UnityEnvironment(BaseEnv):
) -> None:
raise NotImplementedError("Method not implemented.")

def get_steps(self, behavior_name: BehaviorName) -> Tuple[DecisionSteps, TerminalSteps]:
def get_steps(
self, behavior_name: BehaviorName
) -> Tuple[DecisionSteps, TerminalSteps]:
self._assert_behavior_exists(behavior_name)
return self._communicator.get_steps(behavior_name)

def get_behavior_spec(self, behavior_name:BehaviorName) -> BehaviorSpec:
def get_behavior_spec(self, behavior_name: BehaviorName) -> BehaviorSpec:
self._assert_behavior_exists(behavior_name)
return self._env_specs[behavior_name]

def close(self):
"""
Sends a shutdown signal to the unity environment, and closes the socket connection.
Sends a shutdown signal to the unity environment, and closes the communication.
"""
self._communicator.close()
if self.proc1 is not None:
@@ -170,11 +161,13 @@ class UnityEnvironment(BaseEnv):
self.proc1.wait(timeout=5)
signal_name = returncode_to_signal_name(self.proc1.returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return_info = f"Environment shut down with return code {self.proc1.returncode}{signal_name}."
return_info = (
"Environment shut down with return "
f"code {self.proc1.returncode}{signal_name}."
)
logger.info(return_info)
except subprocess.TimeoutExpired:
logger.info("Environment timed out shutting down. Killing...")
self.proc1.kill()
# Set to None so we don't try to close multiple times.
self.proc1 = None

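Because the class keeps the `reset` / `get_steps` / `set_actions` / `step` surface of the standard `mlagents_envs` API, a training-side loop looks the same as with the gRPC-based environment. The snippet below is a minimal sketch under assumptions: the build path and behavior name are placeholders, the behavior is assumed to use continuous actions, and the all-zero actions stand in for a real policy.

```python
import numpy as np
from mlagents_dots_envs.unity_environment import UnityEnvironment

env = UnityEnvironment(file_name="path/to/build")  # placeholder path; None targets the Editor
env.reset()
spec = env.get_behavior_spec("MyBehavior")  # placeholder behavior name
for _ in range(100):
    decision_steps, terminal_steps = env.get_steps("MyBehavior")
    # Zero actions stand in for a policy; shape must be (n_agents, action_size).
    actions = np.zeros((len(decision_steps), spec.action_size), dtype=np.float32)
    env.set_actions("MyBehavior", actions)
    env.step()
env.close()
```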
@@ -1,8 +1,4 @@
import os
import sys
from setuptools import setup, find_packages
from setuptools.command.install import install
import mlagents_dots_envs

setup(
name="mlagents_dots_envs",

@@ -17,8 +13,10 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
],
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests", "examples"]),
packages=find_packages(
exclude=["*.tests", "*.tests.*", "tests.*", "tests", "examples"]
),
zip_safe=False,
install_requires=[ "mlagents_envs>=0.14.1"],
install_requires=["mlagents_envs>=0.14.1"],
python_requires=">=3.6",
)