Modified the *CODE* to rename world to policy

vincentpierre 2020-06-08 09:35:22 -07:00
Parent 17a3f90d08
Commit 0b677295d3
38 changed files: 426 additions and 427 deletions

View file

@@ -12,7 +12,7 @@ namespace Unity.AI.MLAgents.Editor
internal static class SpecsPropertyNames
{
public const string k_Name = "Name";
public const string k_WorldProcessorType = "WorldProcessorType";
public const string k_PolicyProcessorType = "PolicyProcessorType";
public const string k_NumberAgents = "NumberAgents";
public const string k_ActionType = "ActionType";
public const string k_ObservationShapes = "ObservationShapes";
@@ -27,8 +27,8 @@ namespace Unity.AI.MLAgents.Editor
/// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the
/// Inspector.
/// </summary>
[CustomPropertyDrawer(typeof(MLAgentsWorldSpecs))]
internal class MLAgentsWorldSpecsDrawer : PropertyDrawer
[CustomPropertyDrawer(typeof(PolicySpecs))]
internal class PolicySpecsDrawer : PropertyDrawer
{
// The height of a line in the Unity Inspectors
const float k_LineHeight = 21f;
@@ -66,7 +66,7 @@ namespace Unity.AI.MLAgents.Editor
new Rect(position.x - 3f, position.y, position.width + 6f, m_TotalHeight),
new Color(0f, 0f, 0f, 0.1f));
EditorGUI.LabelField(position, "ML-Agents World Specs : " + label.text);
EditorGUI.LabelField(position, "Policy Specs : " + label.text);
position.y += k_LineHeight;
EditorGUI.indentLevel++;
@@ -74,13 +74,13 @@ namespace Unity.AI.MLAgents.Editor
// Name
EditorGUI.PropertyField(position,
property.FindPropertyRelative(SpecsPropertyNames.k_Name),
new GUIContent("World Name", "The name of the World"));
new GUIContent("Name", "The name of the Policy"));
position.y += k_LineHeight;
// WorldProcessorType
// PolicyProcessorType
EditorGUI.PropertyField(position,
property.FindPropertyRelative(SpecsPropertyNames.k_WorldProcessorType),
new GUIContent("Processor Type", "The Policy for the World"));
property.FindPropertyRelative(SpecsPropertyNames.k_PolicyProcessorType),
new GUIContent("Policy Type", "The type of Policy"));
position.y += k_LineHeight;
// Number of Agents
@@ -224,7 +224,7 @@ namespace Unity.AI.MLAgents.Editor
var name = nameProperty.stringValue;
if (name == "")
{
m_Warnings.Add("Your World must have a non-empty name");
m_Warnings.Add("Your Policy must have a non-empty name");
}
// Max number of agents is not zero
@@ -232,14 +232,14 @@ namespace Unity.AI.MLAgents.Editor
var nAgents = nAgentsProperty.intValue;
if (nAgents == 0)
{
m_Warnings.Add("Your World must have a non-zero maximum number of Agents");
m_Warnings.Add("Your Policy must have a non-zero maximum number of Agents");
}
// At least one observation
var observationShapes = property.FindPropertyRelative(SpecsPropertyNames.k_ObservationShapes);
if (observationShapes.arraySize == 0)
{
m_Warnings.Add("Your World must have at least one observation");
m_Warnings.Add("Your Policy must have at least one observation");
}
// Action Size is not zero
@@ -247,7 +247,7 @@ namespace Unity.AI.MLAgents.Editor
var actionSize = actionSizeProperty.intValue;
if (actionSize == 0)
{
m_Warnings.Add("Your World must have non-zero action size");
m_Warnings.Add("Your Policy must have non-zero action size");
}
//Model is not empty

View file

@@ -1,5 +1,5 @@
fileFormatVersion: 2
guid: 671ee066644364cc191cb6c00ceaf1b4
guid: 94a88b2001ca84a0aa8eb12a93dffc6e
MonoImporter:
externalObjects: {}
serializedVersion: 2

View file

@@ -11,7 +11,7 @@ namespace Unity.AI.MLAgents
/// <summary>
/// The Academy is a singleton that orchestrates the decision making of the
/// Agents.
/// It is used to register WorldProcessors to Worlds and to keep track of the
/// It is used to register PolicyProcessors to Policies and to keep track of the
/// reset logic of the simulation.
/// </summary>
public class Academy : IDisposable
@@ -50,7 +50,8 @@ namespace Unity.AI.MLAgents
private bool m_FirstMessageReceived;
private SharedMemoryCommunicator m_Communicator;
internal Dictionary<MLAgentsWorld, IWorldProcessor> m_WorldToProcessor; // Maybe we can put the processor in the world with an unsafe unmanaged memory pointer ?
internal Dictionary<Policy, IPolicyProcessor> m_PolicyToProcessor;
// TODO : Maybe we can put the processor in the policy with an unsafe unmanaged memory pointer ?
private EnvironmentParameters m_EnvironmentParameters;
private StatsRecorder m_StatsRecorder;
@@ -63,33 +64,29 @@ namespace Unity.AI.MLAgents
public Action OnEnvironmentReset;
/// <summary>
/// Registers a MLAgentsWorld to a decision making mechanism.
/// By default, the MLAgentsWorld will use a remote process for decision making when available.
/// Registers a Policy to a decision making mechanism.
/// By default, the Policy will use a remote process for decision making when available.
/// </summary>
/// <param name="policyId"> The string identifier of the MLAgentsWorld. There can only be one world per unique id.</param>
/// <param name="world"> The MLAgentsWorld that is being subscribed.</param>
/// <param name="worldProcessor"> If the remote process is not available, the MLAgentsWorld will use this World processor for decision making.</param>
/// <param name="defaultRemote"> If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallback worldProcessor otherwise.</param>
public void RegisterWorld(string policyId, MLAgentsWorld world, IWorldProcessor worldProcessor = null, bool defaultRemote = true)
/// <param name="policyId"> The string identifier of the Policy. There can only be one Policy per unique id.</param>
/// <param name="policy"> The Policy that is being subscribed.</param>
/// <param name="policyProcessor"> If the remote process is not available, the Policy will use this IPolicyProcessor for decision making.</param>
/// <param name="defaultRemote"> If true, the Policy will default to using the remote process for communication making and use the fallback IPolicyProcessor otherwise.</param>
public void RegisterPolicy(string policyId, Policy policy, IPolicyProcessor policyProcessor = null, bool defaultRemote = true)
{
// Need to find a way to deregister ?
// Need a way to modify the World processor on the fly
// Automagically register world on creation ?
IWorldProcessor processor = null;
IPolicyProcessor processor = null;
if (m_Communicator != null && defaultRemote)
{
processor = new RemoteWorldProcessor(world, policyId, m_Communicator);
processor = new RemotePolicyProcessor(policy, policyId, m_Communicator);
}
else if (worldProcessor != null)
else if (policyProcessor != null)
{
processor = worldProcessor;
processor = policyProcessor;
}
else
{
processor = new NullWorldProcessor(world);
processor = new NullPolicyProcessor(policy);
}
m_WorldToProcessor[world] = processor;
m_PolicyToProcessor[policy] = processor;
}
/// <summary>
@@ -133,7 +130,7 @@ namespace Unity.AI.MLAgents
Application.quitting += Dispose;
OnEnvironmentReset = () => {};
m_WorldToProcessor = new Dictionary<MLAgentsWorld, IWorldProcessor>();
m_PolicyToProcessor = new Dictionary<Policy, IPolicyProcessor>();
TryInitializeCommunicator();
SideChannelsManager.RegisterSideChannel(new EngineConfigurationChannel());
@@ -162,8 +159,8 @@ namespace Unity.AI.MLAgents
}
}
// We will make the assumption that a world can only be updated one at a time
internal void UpdateWorld(MLAgentsWorld world)
// We will make the assumption that a policy can only be updated one at a time
internal void UpdatePolicy(Policy policy)
{
if (!m_Initialized)
{
@@ -171,23 +168,23 @@ namespace Unity.AI.MLAgents
}
// If no agents requested a decision return
if (world.DecisionCounter.Count == 0 && world.TerminationCounter.Count == 0)
if (policy.DecisionCounter.Count == 0 && policy.TerminationCounter.Count == 0)
{
return;
}
// Ensure the world does not have lingering actions:
if (world.ActionCounter.Count != 0)
// Ensure the policy does not have lingering actions:
if (policy.ActionCounter.Count != 0)
{
// This means something in the execution went wrong; this error should never appear
throw new MLAgentsException("TODO : ActionCount is not 0");
}
var processor = m_WorldToProcessor[world];
var processor = m_PolicyToProcessor[policy];
if (processor == null)
{
// Raise error
throw new MLAgentsException($"A world has not been correctly registered.");
throw new MLAgentsException($"A Policy has not been correctly registered.");
}
@@ -208,10 +205,10 @@ namespace Unity.AI.MLAgents
if (!reset) // TODO : Comment out if we do not want to reset on first env.reset()
{
m_Communicator.WriteSideChannelData(SideChannelsManager.GetSideChannelMessage());
processor.ProcessWorld();
processor.Process();
reset = m_Communicator.ReadAndClearResetCommand();
world.SetActionReady();
world.ResetDecisionsAndTerminationCounters();
policy.SetActionReady();
policy.ResetDecisionsAndTerminationCounters();
SideChannelsManager.ProcessSideChannelData(m_Communicator.ReadAndClearSideChannelData());
}
if (reset)
@@ -222,16 +219,16 @@ namespace Unity.AI.MLAgents
}
else if (!processor.IsConnected)
{
processor.ProcessWorld();
world.SetActionReady();
world.ResetDecisionsAndTerminationCounters();
processor.Process();
policy.SetActionReady();
policy.ResetDecisionsAndTerminationCounters();
// TODO com.ReadAndClearSideChannelData(); // Remove side channel data
}
else
{
// The processor wants to communicate but the communicator is either null or inactive
world.ResetActionsCounter();
world.ResetDecisionsAndTerminationCounters();
policy.ResetActionsCounter();
policy.ResetDecisionsAndTerminationCounters();
}
if (m_Communicator == null)
{
@@ -246,16 +243,16 @@ namespace Unity.AI.MLAgents
// Need to complete all of the jobs at this point.
ECSWorld.EntityManager.CompleteAllJobs();
}
ResetAllWorlds();
ResetAllPolicies();
OnEnvironmentReset?.Invoke();
}
private void ResetAllWorlds() // This is problematic because it affects all worlds and is not thread safe...
private void ResetAllPolicies() // This is problematic because it affects all policies and is not thread safe...
{
foreach (var w in m_WorldToProcessor.Keys)
foreach (var pol in m_PolicyToProcessor.Keys)
{
w.ResetActionsCounter();
w.ResetDecisionsAndTerminationCounters();
pol.ResetActionsCounter();
pol.ResetDecisionsAndTerminationCounters();
}
}
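
For orientation, here is a hedged call-site sketch of the renamed registration API. It is not part of this commit; the behavior name, shapes, and sizes are invented, and omitting discreteActionBranches for a continuous action space is an assumption:

// Hypothetical call site, not part of this commit.
// One vector observation of 8 floats, 2 continuous actions, up to 64 agents.
var policy = new Policy(64, new int3[] { new int3(8, 0, 0) }, ActionType.CONTINUOUS, 2);
Academy.Instance.RegisterPolicy(
    "Walker",                  // policyId: there can only be one Policy per unique id
    policy,
    policyProcessor: null,     // fallback processor when no trainer is connected
    defaultRemote: true);      // prefer the remote (Python) process when available

With defaultRemote left at true, the Policy exchanges data with a connected Python trainer and only falls back to the supplied processor (or a NullPolicyProcessor) otherwise.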

View file

@@ -9,39 +9,39 @@ namespace Unity.AI.MLAgents
public static class ActionHashMapUtils
{
/// <summary>
/// Retrieves the action data for a world in puts it into a HashMap.
/// This action deletes the action data from the world.
/// Retrieves the action data for a Policy and puts it into a HashMap.
/// This action deletes the action data from the Policy.
/// </summary>
/// <param name="world"> The MLAgentsWorld the data will be retrieved from.</param>
/// <param name="policy"> The Policy the data will be retrieved from.</param>
/// <param name="allocator"> The memory allocator of the create NativeHashMap.</param>
/// <typeparam name="T"> The type of the Action struct. It must match the Action Size
/// and Action Type of the world.</typeparam>
/// and Action Type of the Policy.</typeparam>
/// <returns> A NativeHashMap from Entities to Actions with type T.</returns>
public static NativeHashMap<Entity, T> GenerateActionHashMap<T>(this MLAgentsWorld world, Allocator allocator) where T : struct
public static NativeHashMap<Entity, T> GenerateActionHashMap<T>(this Policy policy, Allocator allocator) where T : struct
{
#if ENABLE_UNITY_COLLECTIONS_CHECKS
if (world.ActionSize != UnsafeUtility.SizeOf<T>() / 4)
if (policy.ActionSize != UnsafeUtility.SizeOf<T>() / 4)
{
var receivedSize = UnsafeUtility.SizeOf<T>() / 4;
throw new MLAgentsException($"Action space size does not match for action. Expected {world.ActionSize} but received {receivedSize}");
throw new MLAgentsException($"Action space size does not match for action. Expected {policy.ActionSize} but received {receivedSize}");
}
#endif
Academy.Instance.UpdateWorld(world);
int actionCount = world.ActionCounter.Count;
Academy.Instance.UpdatePolicy(policy);
int actionCount = policy.ActionCounter.Count;
var result = new NativeHashMap<Entity, T>(actionCount, allocator);
int size = world.ActionSize;
int size = policy.ActionSize;
for (int i = 0; i < actionCount; i++)
{
if (world.ActionType == ActionType.DISCRETE)
if (policy.ActionType == ActionType.DISCRETE)
{
result.TryAdd(world.ActionAgentEntityIds[i], world.DiscreteActuators.Slice(i * size, size).SliceConvert<T>()[0]);
result.TryAdd(policy.ActionAgentEntityIds[i], policy.DiscreteActuators.Slice(i * size, size).SliceConvert<T>()[0]);
}
else
{
result.TryAdd(world.ActionAgentEntityIds[i], world.ContinuousActuators.Slice(i * size, size).SliceConvert<T>()[0]);
result.TryAdd(policy.ActionAgentEntityIds[i], policy.ContinuousActuators.Slice(i * size, size).SliceConvert<T>()[0]);
}
}
world.ResetActionsCounter();
policy.ResetActionsCounter();
return result;
}
}
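
A hedged usage sketch of the extension above (the Move struct and agentEntity are invented; as the check above enforces, the struct's float count must equal the Policy's ActionSize):

// Assumes a Policy with ActionType.CONTINUOUS and ActionSize == 2.
struct Move { public float Forward; public float Turn; }   // 2 floats == ActionSize 2

var actions = policy.GenerateActionHashMap<Move>(Allocator.Temp);
if (actions.TryGetValue(agentEntity, out Move move))
{
    // apply move.Forward / move.Turn to the agent
}
actions.Dispose();

Note that, per the summary above, reading the actions this way also clears them from the Policy.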

View file

@@ -34,7 +34,7 @@ namespace Unity.AI.MLAgents
/// Retrieve the action that was decided for the Entity.
/// This method uses a generic type, as such you must provide a
/// type that is compatible with the Action Type and Action Size
/// for this MLAgentsWorld.
/// for this Policy.
/// </summary>
/// <typeparam name="T"> The type of action struct.</typeparam>
/// <returns> The action struct for the Entity.</returns>
@@ -59,7 +59,7 @@ namespace Unity.AI.MLAgents
}
/// <summary>
/// The signature of the Job used to retrieve actuators values from a world
/// The signature of the Job used to retrieve actuators values from a Policy
/// </summary>
[JobProducerType(typeof(IActuatorJobExtensions.ActuatorDataJobProcess<>))]
public interface IActuatorJob
@@ -73,31 +73,31 @@ namespace Unity.AI.MLAgents
/// Schedule the Job that will generate the action data for the Entities that requested a decision.
/// </summary>
/// <param name="jobData"> The IActuatorJob struct.</param>
/// <param name="mlagentsWorld"> The MLAgentsWorld containing the data needed for decision making.</param>
/// <param name="policy"> The Policy containing the data needed for decision making.</param>
/// <param name="inputDeps"> The jobHandle for the job.</param>
/// <typeparam name="T"> The type of the IActuatorData struct.</typeparam>
/// <returns> The updated jobHandle for the job.</returns>
public static unsafe JobHandle Schedule<T>(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps)
public static unsafe JobHandle Schedule<T>(this T jobData, Policy policy, JobHandle inputDeps)
where T : struct, IActuatorJob
{
inputDeps.Complete(); // TODO : FIND A BETTER WAY TO MAKE SURE ALL THE DATA IS IN THE WORLD
Academy.Instance.UpdateWorld(mlagentsWorld);
if (mlagentsWorld.ActionCounter.Count == 0)
inputDeps.Complete(); // TODO : FIND A BETTER WAY TO MAKE SURE ALL THE DATA IS IN THE POLICY
Academy.Instance.UpdatePolicy(policy);
if (policy.ActionCounter.Count == 0)
{
return inputDeps;
}
return ScheduleImpl(jobData, mlagentsWorld, inputDeps);
return ScheduleImpl(jobData, policy, inputDeps);
}
// Passing this along
internal static unsafe JobHandle ScheduleImpl<T>(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps)
internal static unsafe JobHandle ScheduleImpl<T>(this T jobData, Policy policy, JobHandle inputDeps)
where T : struct, IActuatorJob
{
// Creating a data struct that contains the data the user passed into the job (This is what T is here)
var data = new ActionEventJobData<T>
{
UserJobData = jobData,
world = mlagentsWorld // Need to create this before hand with the actuator data
Policy = policy // Need to create this before hand with the actuator data
};
// Scheduling a Job out of thin air by using a pointer called jobReflectionData in the ActuatorSystemJobStruct
@@ -105,11 +105,11 @@ namespace Unity.AI.MLAgents
return JobsUtility.Schedule(ref parameters);
}
// This is the struct containing all the data needed from both the user and the MLAgents world
// This is the struct containing all the data needed from both the user and the Policy
internal unsafe struct ActionEventJobData<T> where T : struct
{
public T UserJobData;
[NativeDisableContainerSafetyRestriction] public MLAgentsWorld world;
[NativeDisableContainerSafetyRestriction] public Policy Policy;
}
internal struct ActuatorDataJobProcess<T> where T : struct, IActuatorJob
@@ -132,11 +132,11 @@ namespace Unity.AI.MLAgents
/// Calls the user implemented Execute method with ActuatorEvent struct
public static unsafe void Execute(ref ActionEventJobData<T> jobData, IntPtr listDataPtr, IntPtr unusedPtr, ref JobRanges ranges, int jobIndex)
{
int size = jobData.world.ActionSize;
int actionCount = jobData.world.ActionCounter.Count;
int size = jobData.Policy.ActionSize;
int actionCount = jobData.Policy.ActionCounter.Count;
// Continuous case
if (jobData.world.ActionType == ActionType.CONTINUOUS)
if (jobData.Policy.ActionType == ActionType.CONTINUOUS)
{
for (int i = 0; i < actionCount; i++)
{
@@ -144,8 +144,8 @@ namespace Unity.AI.MLAgents
{
ActionSize = size,
ActionType = ActionType.CONTINUOUS,
Entity = jobData.world.ActionAgentEntityIds[i],
ContinuousActionSlice = jobData.world.ContinuousActuators.Slice(i * size, size)
Entity = jobData.Policy.ActionAgentEntityIds[i],
ContinuousActionSlice = jobData.Policy.ContinuousActuators.Slice(i * size, size)
});
}
}
@@ -158,12 +158,12 @@ namespace Unity.AI.MLAgents
{
ActionSize = size,
ActionType = ActionType.DISCRETE,
Entity = jobData.world.ActionAgentEntityIds[i],
DiscreteActionSlice = jobData.world.DiscreteActuators.Slice(i * size, size)
Entity = jobData.Policy.ActionAgentEntityIds[i],
DiscreteActionSlice = jobData.Policy.DiscreteActuators.Slice(i * size, size)
});
}
}
jobData.world.ResetActionsCounter();
jobData.Policy.ResetActionsCounter();
}
}
}
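
A hedged sketch of the job-based consumer under the new names. The ActuatorEvent accessor name is an assumption inferred from the doc comment above ("Retrieve the action that was decided for the Entity"); it is not shown in this diff:

// Hypothetical actuator job; Move is a struct matching the Policy's action spec.
struct MoveActuatorJob : IActuatorJob
{
    public void Execute(ActuatorEvent ev)
    {
        var move = ev.GetAction<Move>();   // assumed accessor; must match ActionType/Size
        // ... route the action to ev.Entity ...
    }
}

// Inside a system update, after decisions have been requested:
inputDeps = new MoveActuatorJob().Schedule(policy, inputDeps);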

View file

View file

View file

View file

@@ -7,19 +7,19 @@ using System;
namespace Unity.AI.MLAgents
{
/// <summary>
/// A DecisionRequest is a struct used to provide data about an Agent to a MLAgentsWorld.
/// This data will be used to generate a decision after the world is processed.
/// A DecisionRequest is a struct used to provide data about an Agent to a Policy.
/// This data will be used to generate a decision after the Policy is processed.
/// Adding data is done through a builder pattern.
/// </summary>
public struct DecisionRequest
{
private int m_Index;
private MLAgentsWorld m_World;
private Policy m_Policy;
internal DecisionRequest(int index, MLAgentsWorld world)
internal DecisionRequest(int index, Policy policy)
{
this.m_Index = index;
this.m_World = world;
this.m_Policy = policy;
}
/// <summary>
@@ -29,7 +29,7 @@ namespace Unity.AI.MLAgents
/// <returns> The DecisionRequest struct </returns>
public DecisionRequest SetReward(float r)
{
m_World.DecisionRewards[m_Index] = r;
m_Policy.DecisionRewards[m_Index] = r;
return this;
}
@@ -43,45 +43,45 @@ namespace Unity.AI.MLAgents
public DecisionRequest SetDiscreteActionMask(int branch, int actionIndex)
{
#if ENABLE_UNITY_COLLECTIONS_CHECKS
if (m_World.ActionType == ActionType.CONTINUOUS)
if (m_Policy.ActionType == ActionType.CONTINUOUS)
{
throw new MLAgentsException("SetDiscreteActionMask can only be used with discrete acton space.");
}
if (branch > m_World.DiscreteActionBranches.Length)
if (branch > m_Policy.DiscreteActionBranches.Length)
{
throw new MLAgentsException("Unknown action branch used when setting mask.");
}
if (actionIndex > m_World.DiscreteActionBranches[branch])
if (actionIndex > m_Policy.DiscreteActionBranches[branch])
{
throw new MLAgentsException("Index is out of bounds for requested action mask.");
}
#endif
var trueMaskIndex = m_World.DiscreteActionBranches.CumSumAt(branch) + actionIndex;
m_World.DecisionActionMasks[trueMaskIndex + m_World.DiscreteActionBranches.Sum() * m_Index] = true;
var trueMaskIndex = m_Policy.DiscreteActionBranches.CumSumAt(branch) + actionIndex;
m_Policy.DecisionActionMasks[trueMaskIndex + m_Policy.DiscreteActionBranches.Sum() * m_Index] = true;
return this;
}
/// <summary>
/// Sets the observation for a decision request.
/// </summary>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
/// <param name="sensor"> A struct strictly containing floats used as observation data </param>
/// <returns> The DecisionRequest struct </returns>
public DecisionRequest SetObservation<T>(int sensorNumber, T sensor) where T : struct
{
int inputSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
#if ENABLE_UNITY_COLLECTIONS_CHECKS
int3 s = m_World.SensorShapes[sensorNumber];
int3 s = m_Policy.SensorShapes[sensorNumber];
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
if (inputSize != expectedInputSize)
{
throw new MLAgentsException(
$"Cannot set observation due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
}
#endif
int start = m_World.ObservationOffsets[sensorNumber];
int start = m_Policy.ObservationOffsets[sensorNumber];
start += inputSize * m_Index;
var tmp = m_World.DecisionObs.Slice(start, inputSize).SliceConvert<T>();
var tmp = m_Policy.DecisionObs.Slice(start, inputSize).SliceConvert<T>();
tmp[0] = sensor;
return this;
}
@@ -89,12 +89,12 @@ namespace Unity.AI.MLAgents
/// <summary>
/// Sets the observation for a decision request using a categorical value.
/// </summary>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
/// <param name="sensor"> An integer containing the index of the categorical observation </param>
/// <returns> The DecisionRequest struct </returns>
public DecisionRequest SetObservation(int sensorNumber, int sensor)
{
int3 s = m_World.SensorShapes[sensorNumber];
int3 s = m_Policy.SensorShapes[sensorNumber];
int maxValue = s.x;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
@@ -109,33 +109,33 @@ namespace Unity.AI.MLAgents
$"Categorical observation is out of bound for observation {sensorNumber} with maximum {maxValue} (received {sensor}.");
}
#endif
int start = m_World.ObservationOffsets[sensorNumber];
int start = m_Policy.ObservationOffsets[sensorNumber];
start += maxValue * m_Index;
m_World.DecisionObs[start + sensor] = 1.0f;
m_Policy.DecisionObs[start + sensor] = 1.0f;
return this;
}
/// <summary>
/// Sets the observation for a decision request.
/// </summary>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
/// <param name="obs"> A NativeSlice of floats containing the observation data </param>
/// <returns> The DecisionRequest struct </returns>
public DecisionRequest SetObservationFromSlice(int sensorNumber, [ReadOnly] NativeSlice<float> obs)
{
int inputSize = obs.Length;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
int3 s = m_World.SensorShapes[sensorNumber];
int3 s = m_Policy.SensorShapes[sensorNumber];
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
if (inputSize != expectedInputSize)
{
throw new MLAgentsException(
$"Cannot set observation due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
}
#endif
int start = m_World.ObservationOffsets[sensorNumber];
int start = m_Policy.ObservationOffsets[sensorNumber];
start += inputSize * m_Index;
m_World.DecisionObs.Slice(start, inputSize).CopyFrom(obs);
m_Policy.DecisionObs.Slice(start, inputSize).CopyFrom(obs);
return this;
}
}
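
At a call site, the builder pattern described in the summary reads as below (hedged sketch; the method name RequestDecision is taken from the Policy changes later in this commit, and the sensor layout is invented):

// Sensor 0 was declared as a vector observation of 3 floats, i.e. int3(3, 0, 0).
policy.RequestDecision(agentEntity)
    .SetReward(0.1f)
    .SetObservation(0, new float3(px, py, pz))   // any struct of exactly 3 floats works
    .SetDiscreteActionMask(0, 2);                // forbid index 2 on branch 0 (discrete only)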

View file

@@ -7,19 +7,19 @@ using System;
namespace Unity.AI.MLAgents
{
/// <summary>
/// A EpisodeTermination is a struct used to provide data about an Agent to a MLAgentsWorld.
/// An EpisodeTermination is a struct used to provide data about an Agent to a Policy.
/// This data will be used to notify of the end of the episode of an Agent.
/// Adding data is done through a builder pattern.
/// </summary>
public struct EpisodeTermination
{
private int m_Index;
private MLAgentsWorld m_World;
private Policy m_Policy;
internal EpisodeTermination(int index, MLAgentsWorld world)
internal EpisodeTermination(int index, Policy policy)
{
this.m_Index = index;
this.m_World = world;
this.m_Policy = policy;
}
/// <summary>
@@ -30,31 +30,31 @@ namespace Unity.AI.MLAgents
/// <returns> The EpisodeTermination struct </returns>
public EpisodeTermination SetReward(float r)
{
m_World.TerminationRewards[m_Index] = r;
m_Policy.TerminationRewards[m_Index] = r;
return this;
}
/// <summary>
/// Sets the observation for the end of the Episode.
/// </summary>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
/// <param name="sensor"> A struct strictly containing floats used as observation data </param>
/// <returns> The EpisodeTermination struct </returns>
public EpisodeTermination SetObservation<T>(int sensorNumber, T sensor) where T : struct
{
int inputSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
#if ENABLE_UNITY_COLLECTIONS_CHECKS
int3 s = m_World.SensorShapes[sensorNumber];
int3 s = m_Policy.SensorShapes[sensorNumber];
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
if (inputSize != expectedInputSize)
{
throw new MLAgentsException(
$"Cannot set observation due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
}
#endif
int start = m_World.ObservationOffsets[sensorNumber];
int start = m_Policy.ObservationOffsets[sensorNumber];
start += inputSize * m_Index;
var tmp = m_World.TerminationObs.Slice(start, inputSize).SliceConvert<T>();
var tmp = m_Policy.TerminationObs.Slice(start, inputSize).SliceConvert<T>();
tmp[0] = sensor;
return this;
}
@@ -62,12 +62,12 @@ namespace Unity.AI.MLAgents
/// <summary>
/// Sets the observation for a termination request using a categorical value.
/// </summary>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
/// <param name="sensor"> An integer containing the index of the categorical observation </param>
/// <returns> The EpisodeTermination struct </returns>
public EpisodeTermination SetObservation(int sensorNumber, int sensor)
{
int3 s = m_World.SensorShapes[sensorNumber];
int3 s = m_Policy.SensorShapes[sensorNumber];
int maxValue = s.x;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
@@ -82,33 +82,33 @@ namespace Unity.AI.MLAgents
$"Categorical observation is out of bound for observation {sensorNumber} with maximum {maxValue} (received {sensor}.");
}
#endif
int start = m_World.ObservationOffsets[sensorNumber];
int start = m_Policy.ObservationOffsets[sensorNumber];
start += maxValue * m_Index;
m_World.TerminationObs[start + sensor] = 1.0f;
m_Policy.TerminationObs[start + sensor] = 1.0f;
return this;
}
/// <summary>
/// Sets the last observation the Agent perceives before ending the episode.
/// </summary>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
/// <param name="obs"> A NativeSlice of floats containing the observation data </param>
/// <returns> The EpisodeTermination struct </returns>
public EpisodeTermination SetObservationFromSlice(int sensorNumber, [ReadOnly] NativeSlice<float> obs)
{
int inputSize = obs.Length;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
int3 s = m_World.SensorShapes[sensorNumber];
int3 s = m_Policy.SensorShapes[sensorNumber];
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
if (inputSize != expectedInputSize)
{
throw new MLAgentsException(
$"Cannot set observation due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
}
#endif
int start = m_World.ObservationOffsets[sensorNumber];
int start = m_Policy.ObservationOffsets[sensorNumber];
start += inputSize * m_Index;
m_World.TerminationObs.Slice(start, inputSize).CopyFrom(obs);
m_Policy.TerminationObs.Slice(start, inputSize).CopyFrom(obs);
return this;
}
}

View file

@@ -9,14 +9,14 @@ using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
/// <summary>
/// MLAgentsWorld is a data container on which the user requests decisions.
/// Policy is a data container on which the user requests decisions.
/// </summary>
public struct MLAgentsWorld : IDisposable
public struct Policy : IDisposable
{
/// <summary>
/// Indicates if the MLAgentsWorld has been instantiated
/// Indicates if the Policy has been instantiated
/// </summary>
/// <value> True if MLAgentsWorld was instantiated, False otherwise</value>
/// <value> True if the Policy was instantiated, False otherwise</value>
public bool IsCreated
{
get { return DecisionAgentIds.IsCreated;}
@@ -60,13 +60,13 @@ namespace Unity.AI.MLAgents
/// </summary>
/// <param name="maximumNumberAgents"> The maximum number of decisions that can be requested between each MLAgentsSystem update </param>
/// <param name="obsShapes"> An array of int3 corresponding to the shape of the expected observations (one int3 per observation) </param>
/// <param name="actionType"> An ActionType enum (DISCRETE / CONTINUOUS) specifying the type of actions the MLAgentsWorld will produce </param>
/// <param name="actionSize"> The number of actions the MLAgentsWorld is expected to generate for each decision.
/// <param name="actionType"> An ActionType enum (DISCRETE / CONTINUOUS) specifying the type of actions the Policy will produce </param>
/// <param name="actionSize"> The number of actions the Policy is expected to generate for each decision.
/// - If CONTINUOUS ActionType : The number of floats the action contains
/// - If DISCRETE ActionType : The number of branches (integer actions) the action contains </param>
/// <param name="discreteActionBranches"> For DISCRETE ActionType only : an array of int specifying the number of possible int values each
/// action branch has. (Must be of the same length as actionSize) </param>
public MLAgentsWorld(
public Policy(
int maximumNumberAgents,
int3[] obsShapes,
ActionType actionType,
@@ -164,7 +164,7 @@ namespace Unity.AI.MLAgents
}
/// <summary>
/// Dispose of the MLAgentsWorld.
/// Dispose of the Policy.
/// </summary>
public void Dispose()
{
@@ -199,7 +199,7 @@ namespace Unity.AI.MLAgents
}
/// <summary>
/// Requests a decision for a specific Entity to the MLAgentsWorld. The DecisionRequest
/// Requests a decision for a specific Entity to the Policy. The DecisionRequest
/// struct this method returns can be used to provide data necessary for the Agent to
/// take a decision for the Entity.
/// </summary>
@@ -211,7 +211,7 @@ namespace Unity.AI.MLAgents
#if ENABLE_UNITY_COLLECTIONS_CHECKS
if (!IsCreated)
{
throw new MLAgentsException($"Invalid operation, cannot request a decision on a non-initialized MLAgentsWorld");
throw new MLAgentsException($"Invalid operation, cannot request a decision on a non-initialized Policy");
}
#endif
var index = DecisionCounter.Increment() - 1;
@@ -239,7 +239,7 @@ namespace Unity.AI.MLAgents
#if ENABLE_UNITY_COLLECTIONS_CHECKS
if (!IsCreated)
{
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized MLAgentsWorld");
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized Policy");
}
#endif
var index = TerminationCounter.Increment() - 1;
@@ -267,7 +267,7 @@ namespace Unity.AI.MLAgents
#if ENABLE_UNITY_COLLECTIONS_CHECKS
if (!IsCreated)
{
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized MLAgentsWorld");
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized Policy");
}
#endif
var index = TerminationCounter.Increment() - 1;
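
Putting the constructor and lifetime together, a minimal hedged sketch (shapes and sizes invented; omitting discreteActionBranches for a continuous action space is an assumption):

// One vector observation of 10 floats; 2 continuous actions; up to 64 agents.
var policy = new Policy(
    64,                                  // maximumNumberAgents
    new int3[] { new int3(10, 0, 0) },
    ActionType.CONTINUOUS,
    2);

// ... request decisions and consume actions during simulation ...

policy.Dispose();   // the Policy owns native containers and must be disposed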

View file

View file

View file

@@ -21,23 +21,23 @@ namespace Unity.AI.MLAgents
GPU = 1
}
public static class BarracudaWorldProcessorRegistringExtension
public static class BarracudaPolicyProcessorRegistringExtension
{
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Neural
/// Registers the given Policy to the Academy with a Neural
/// Network Model. If the input model is null, a default inactive
/// processor will be registered instead. Note that if the simulation
/// connects to Python, the Neural Network will be ignored and the world
/// connects to Python, the Neural Network will be ignored and the Policy
/// will exchange data with Python instead.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// <param name="policy"> The Policy to register</param>
/// <param name="policyId"> The name of the Policy. This is useful for identification
/// and for training.</param>
/// <param name="model"> The Neural Network model used by the processor</param>
/// <param name="inferenceDevice"> The inference device specifying where to run inference
/// (CPU or GPU)</param>
public static void RegisterWorldWithBarracudaModel(
this MLAgentsWorld world,
public static void RegisterPolicyWithBarracudaModel(
this Policy policy,
string policyId,
NNModel model,
InferenceDevice inferenceDevice = InferenceDevice.CPU
@@ -45,30 +45,30 @@ namespace Unity.AI.MLAgents
{
if (model != null)
{
var worldProcessor = new BarracudaWorldProcessor(world, model, inferenceDevice);
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true);
var policyProcessor = new BarracudaPolicyProcessor(policy, model, inferenceDevice);
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, true);
}
else
{
Academy.Instance.RegisterWorld(policyId, world, null, true);
Academy.Instance.RegisterPolicy(policyId, policy, null, true);
}
}
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Neural
/// Registers the given Policy to the Academy with a Neural
/// Network Model. If the input model is null, a default inactive
/// processor will be registered instead. Note that if the simulation
/// connects to Python, the world will not connect to Python and run the
/// connects to Python, the Policy will not connect to Python and run the
/// given Neural Network regardless.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// <param name="policy"> The Policy to register</param>
/// <param name="policyId"> The name of the Policy. This is useful for identification
/// and for training.</param>
/// <param name="model"> The Neural Network model used by the processor</param>
/// <param name="inferenceDevice"> The inference device specifying where to run inference
/// (CPU or GPU)</param>
public static void RegisterWorldWithBarracudaModelForceNoCommunication(
this MLAgentsWorld world,
public static void RegisterPolicyWithBarracudaModelForceNoCommunication(
this Policy policy,
string policyId,
NNModel model,
InferenceDevice inferenceDevice = InferenceDevice.CPU
@@ -76,19 +76,19 @@ namespace Unity.AI.MLAgents
{
if (model != null)
{
var worldProcessor = new BarracudaWorldProcessor(world, model, inferenceDevice);
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, false);
var policyProcessor = new BarracudaPolicyProcessor(policy, model, inferenceDevice);
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, false);
}
else
{
Academy.Instance.RegisterWorld(policyId, world, null, false);
Academy.Instance.RegisterPolicy(policyId, policy, null, false);
}
}
}
internal unsafe class BarracudaWorldProcessor : IWorldProcessor
internal unsafe class BarracudaPolicyProcessor : IPolicyProcessor
{
private MLAgentsWorld m_World;
private Policy m_Policy;
private Model m_BarracudaModel;
private IWorker m_Engine;
private const bool k_Verbose = false;
@@ -98,9 +98,9 @@ namespace Unity.AI.MLAgents
public bool IsConnected {get {return false;}}
internal BarracudaWorldProcessor(MLAgentsWorld world, NNModel model, InferenceDevice inferenceDevice)
internal BarracudaPolicyProcessor(Policy policy, NNModel model, InferenceDevice inferenceDevice)
{
this.m_World = world;
this.m_Policy = policy;
D.logEnabled = k_Verbose;
m_Engine?.Dispose();
@@ -111,15 +111,15 @@ namespace Unity.AI.MLAgents
m_Engine = WorkerFactory.CreateWorker(
executionDevice, m_BarracudaModel, k_Verbose);
for (int i = 0; i < m_World.SensorShapes.Length; i++)
for (int i = 0; i < m_Policy.SensorShapes.Length; i++)
{
if (m_World.SensorShapes[i].GetDimensions() == 1)
obsSize += m_World.SensorShapes[i].GetTotalTensorSize();
if (m_Policy.SensorShapes[i].GetDimensions() == 1)
obsSize += m_Policy.SensorShapes[i].GetTotalTensorSize();
}
vectorObsArr = new float[m_World.DecisionAgentIds.Length * obsSize];
vectorObsArr = new float[m_Policy.DecisionAgentIds.Length * obsSize];
}
public void ProcessWorld()
public void Process()
{
// TODO : Cover all cases
// FOR VECTOR OBS ONLY
@@ -128,27 +128,27 @@ namespace Unity.AI.MLAgents
var input = new System.Collections.Generic.Dictionary<string, Tensor>();
// var sensorData = m_World.DecisionObs.ToArray();
// var sensorData = m_Policy.DecisionObs.ToArray();
int sensorOffset = 0;
int vecObsOffset = 0;
foreach (var shape in m_World.SensorShapes)
foreach (var shape in m_Policy.SensorShapes)
{
if (shape.GetDimensions() == 1)
{
for (int i = 0; i < m_World.DecisionCounter.Count; i++)
for (int i = 0; i < m_Policy.DecisionCounter.Count; i++)
{
fixed(void* arrPtr = vectorObsArr)
{
UnsafeUtility.MemCpy(
(byte*)arrPtr + 4 * i * obsSize + 4 * vecObsOffset,
(byte*)m_World.DecisionObs.GetUnsafePtr() + 4 * sensorOffset + 4 * i * shape.GetTotalTensorSize(),
(byte*)m_Policy.DecisionObs.GetUnsafePtr() + 4 * sensorOffset + 4 * i * shape.GetTotalTensorSize(),
shape.GetTotalTensorSize() * 4
);
}
// Array.Copy(sensorData, sensorOffset + i * shape.GetTotalTensorSize(), vectorObsArr, i * obsSize + vecObsOffset, shape.GetTotalTensorSize());
}
sensorOffset += m_World.DecisionAgentIds.Length * shape.GetTotalTensorSize();
sensorOffset += m_Policy.DecisionAgentIds.Length * shape.GetTotalTensorSize();
vecObsOffset += shape.GetTotalTensorSize();
}
else
@@ -158,7 +158,7 @@ namespace Unity.AI.MLAgents
}
input["vector_observation"] = new Tensor(
new TensorShape(m_World.DecisionCounter.Count, obsSize),
new TensorShape(m_Policy.DecisionCounter.Count, obsSize),
vectorObsArr,
"vector_observation");
@@ -166,18 +166,18 @@ namespace Unity.AI.MLAgents
var actuatorT = m_Engine.CopyOutput("action");
switch (m_World.ActionType)
switch (m_Policy.ActionType)
{
case ActionType.CONTINUOUS:
int count = m_World.DecisionCounter.Count * m_World.ActionSize;
int count = m_Policy.DecisionCounter.Count * m_Policy.ActionSize;
var wholeData = actuatorT.data.Download(count);
// var dest = new float[count];
// Array.Copy(wholeData, dest, count);
// m_World.ContinuousActuators.Slice(0, count).CopyFrom(dest);
// m_Policy.ContinuousActuators.Slice(0, count).CopyFrom(dest);
fixed(void* arrPtr = wholeData)
{
UnsafeUtility.MemCpy(
m_World.ContinuousActuators.GetUnsafePtr(),
m_Policy.ContinuousActuators.GetUnsafePtr(),
arrPtr,
count * 4
);
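
Call-site sketch for the two extension methods above (hypothetical names; the NNModel would typically be assigned in the Inspector). Note that the null-model branch still registers the Policy, just without a fallback processor:

// Prefer Python when a trainer is connected, otherwise run the model on CPU.
policy.RegisterPolicyWithBarracudaModel("Walker", walkerModel, InferenceDevice.CPU);

// Or skip Python entirely and always run local inference:
policy.RegisterPolicyWithBarracudaModelForceNoCommunication("Walker", walkerModel);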

View file

@@ -0,0 +1,104 @@
using System;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
public static class HeristicPolicyProcessorRegistringExtension
{
/// <summary>
/// Registers the given Policy to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the Heuristic will
/// be ignored and the Policy will exchange data with Python instead.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="policy"> The Policy to register</param>
/// <param name="policyId"> The name of the Policy. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the Policy will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the Policy.</typeparam>
public static void RegisterPolicyWithHeuristic<TH>(
this Policy policy,
string policyId,
Func<TH> heuristic
) where TH : struct
{
var policyProcessor = new HeuristicPolicyProcessor<TH>(policy, heuristic);
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, true);
}
/// <summary>
/// Registers the given Policy to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the Policy will not
/// exchange data with Python and use the Heuristic regardless.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="policy"> The Policy to register</param>
/// <param name="policyId"> The name of the Policy. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the Policy will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the Policy.</typeparam>
public static void RegisterPolicyWithHeuristicForceNoCommunication<TH>(
this Policy policy,
string policyId,
Func<TH> heuristic
) where TH : struct
{
var policyProcessor = new HeuristicPolicyProcessor<TH>(policy, heuristic);
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, false);
}
}
internal class HeuristicPolicyProcessor<T> : IPolicyProcessor where T : struct
{
private Func<T> m_Heuristic;
private Policy m_Policy;
public bool IsConnected {get {return false;}}
internal HeuristicPolicyProcessor(Policy policy, Func<T> heuristic)
{
this.m_Policy = policy;
this.m_Heuristic = heuristic;
var structSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
if (structSize != policy.ActionSize)
{
throw new MLAgentsException(
$"The heuristic provided does not match the action size. Expected {policy.ActionSize} but received {structSize} from heuristic");
}
}
public void Process()
{
T action = m_Heuristic.Invoke();
var totalCount = m_Policy.DecisionCounter.Count;
// TODO : This can be parallelized
if (m_Policy.ActionType == ActionType.CONTINUOUS)
{
var s = m_Policy.ContinuousActuators.Slice(0, totalCount * m_Policy.ActionSize).SliceConvert<T>();
for (int i = 0; i < totalCount; i++)
{
s[i] = action;
}
}
else
{
var s = m_Policy.DiscreteActuators.Slice(0, totalCount * m_Policy.ActionSize).SliceConvert<T>();
for (int i = 0; i < totalCount; i++)
{
s[i] = action;
}
}
}
public void Dispose()
{
}
}
}
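
A hedged usage sketch (Steer is an invented one-float struct). As the doc comments note, every agent that requested a decision receives the same heuristic action on a given update:

struct Steer { public float Angle; }   // 1 float == ActionSize 1, CONTINUOUS

policy.RegisterPolicyWithHeuristic("Kart", () => new Steer
{
    Angle = Input.GetAxis("Horizontal")   // UnityEngine input used as the heuristic
});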

View file

@@ -6,21 +6,21 @@ using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
/// <summary>
/// The interface for a world processor. A world processor updates the
/// action data of a world using the observation data present in it.
/// The interface for a Policy processor. A Policy processor updates the
/// action data of a Policy using the observation data present in it.
/// </summary>
public interface IWorldProcessor : IDisposable
public interface IPolicyProcessor : IDisposable
{
/// <summary>
/// True if the World Processor is connected to the Python process
/// True if the Policy Processor is connected to the Python process
/// </summary>
bool IsConnected {get;}
/// <summary>
/// This method is called once everytime the world needs to update its action
/// This method is called once every time the policy needs to update its action
/// data. The implementation of this method must give new action data to all the
/// agents that requested a decision.
/// </summary>
void ProcessWorld();
void Process();
}
}
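
For concreteness, a hedged sketch of a custom processor against the renamed interface. It mirrors how the package's own processors drive the Policy's counters and actuator buffers; member accessibility from user code is an assumption:

internal class RandomPolicyProcessor : IPolicyProcessor
{
    private Policy m_Policy;
    private Unity.Mathematics.Random m_Rng;

    public bool IsConnected { get { return false; } }   // never talks to Python

    internal RandomPolicyProcessor(Policy policy, uint seed = 1)
    {
        m_Policy = policy;
        m_Rng = new Unity.Mathematics.Random(seed);
    }

    public void Process()
    {
        // Write a random continuous action for every agent awaiting a decision.
        int n = m_Policy.DecisionCounter.Count * m_Policy.ActionSize;
        for (int i = 0; i < n; i++)
        {
            m_Policy.ContinuousActuators[i] = m_Rng.NextFloat(-1f, 1f);
        }
    }

    public void Dispose() {}
}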

View file

@@ -5,18 +5,18 @@ using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
internal class NullWorldProcessor : IWorldProcessor
internal class NullPolicyProcessor : IPolicyProcessor
{
private MLAgentsWorld m_World;
private Policy m_Policy;
public bool IsConnected {get {return false;}}
internal NullWorldProcessor(MLAgentsWorld world)
internal NullPolicyProcessor(Policy policy)
{
this.m_World = world;
this.m_Policy = policy;
}
public void ProcessWorld()
public void Process()
{
}

View file

@@ -5,27 +5,27 @@ using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
internal class RemoteWorldProcessor : IWorldProcessor
internal class RemotePolicyProcessor : IPolicyProcessor
{
private MLAgentsWorld m_World;
private Policy m_Policy;
private SharedMemoryCommunicator m_Communicator;
private string m_PolicyId;
public bool IsConnected {get {return true;}}
internal RemoteWorldProcessor(MLAgentsWorld world, string policyId, SharedMemoryCommunicator com)
internal RemotePolicyProcessor(Policy policy, string policyId, SharedMemoryCommunicator com)
{
this.m_World = world;
this.m_Policy = policy;
this.m_Communicator = com;
this.m_PolicyId = policyId;
}
public void ProcessWorld()
public void Process()
{
m_Communicator.WriteWorld(m_PolicyId, m_World);
m_Communicator.WritePolicy(m_PolicyId, m_Policy);
m_Communicator.SetUnityReady();
m_Communicator.WaitForPython();
m_Communicator.LoadWorld(m_PolicyId, m_World);
m_Communicator.LoadPolicy(m_PolicyId, m_Policy);
}
public void Dispose()

View file

@@ -68,26 +68,26 @@ namespace Unity.AI.MLAgents
startOffset);
}
public static RLDataOffsets FromWorld(MLAgentsWorld world, string name, int offset)
public static RLDataOffsets FromPolicy(Policy policy, string name, int offset)
{
bool isContinuous = world.ActionType == ActionType.CONTINUOUS;
bool isContinuous = policy.ActionType == ActionType.CONTINUOUS;
int totalFloatObsPerAgent = 0;
foreach (int3 shape in world.SensorShapes)
foreach (int3 shape in policy.SensorShapes)
{
totalFloatObsPerAgent += shape.GetTotalTensorSize();
}
int totalNumberOfMasks = 0;
if (!isContinuous)
{
totalNumberOfMasks = world.DiscreteActionBranches.Sum();
totalNumberOfMasks = policy.DiscreteActionBranches.Sum();
}
return ComputeOffsets(
name,
world.DecisionAgentIds.Length,
policy.DecisionAgentIds.Length,
isContinuous,
world.ActionSize,
world.SensorShapes.Length,
policy.ActionSize,
policy.SensorShapes.Length,
totalFloatObsPerAgent,
totalNumberOfMasks,
offset

View file

@@ -52,12 +52,12 @@ namespace Unity.AI.MLAgents
}
}
public bool ContainsWorld(string name)
public bool ContainsPolicy(string name)
{
return m_OffsetDict.ContainsKey(name);
}
public void WriteWorld(string name, MLAgentsWorld world)
public void WritePolicy(string name, Policy policy)
{
if (!CanEdit)
{
@@ -69,48 +69,48 @@ namespace Unity.AI.MLAgents
}
var dataOffsets = m_OffsetDict[name];
int totalFloatObsPerAgent = 0;
foreach (int3 shape in world.SensorShapes)
foreach (int3 shape in policy.SensorShapes)
{
totalFloatObsPerAgent += shape.GetTotalTensorSize();
}
// Decision data
var decisionCount = world.DecisionCounter.Count;
var decisionCount = policy.DecisionCounter.Count;
SetInt(dataOffsets.DecisionNumberAgentsOffset, decisionCount);
SetArray(dataOffsets.DecisionObsOffset, world.DecisionObs, 4 * decisionCount * totalFloatObsPerAgent);
SetArray(dataOffsets.DecisionRewardsOffset, world.DecisionRewards, 4 * decisionCount);
SetArray(dataOffsets.DecisionAgentIdOffset, world.DecisionAgentIds, 4 * decisionCount);
if (world.ActionType == ActionType.DISCRETE)
SetArray(dataOffsets.DecisionObsOffset, policy.DecisionObs, 4 * decisionCount * totalFloatObsPerAgent);
SetArray(dataOffsets.DecisionRewardsOffset, policy.DecisionRewards, 4 * decisionCount);
SetArray(dataOffsets.DecisionAgentIdOffset, policy.DecisionAgentIds, 4 * decisionCount);
if (policy.ActionType == ActionType.DISCRETE)
{
SetArray(dataOffsets.DecisionActionMasksOffset, world.DecisionActionMasks, decisionCount * world.DiscreteActionBranches.Sum());
SetArray(dataOffsets.DecisionActionMasksOffset, policy.DecisionActionMasks, decisionCount * policy.DiscreteActionBranches.Sum());
}
//Termination data
var terminationCount = world.TerminationCounter.Count;
var terminationCount = policy.TerminationCounter.Count;
SetInt(dataOffsets.TerminationNumberAgentsOffset, terminationCount);
SetArray(dataOffsets.TerminationObsOffset, world.TerminationObs, 4 * terminationCount * totalFloatObsPerAgent);
SetArray(dataOffsets.TerminationRewardsOffset, world.TerminationRewards, 4 * terminationCount);
SetArray(dataOffsets.TerminationAgentIdOffset, world.TerminationAgentIds, 4 * terminationCount);
SetArray(dataOffsets.TerminationStatusOffset, world.TerminationStatus, terminationCount);
SetArray(dataOffsets.TerminationObsOffset, policy.TerminationObs, 4 * terminationCount * totalFloatObsPerAgent);
SetArray(dataOffsets.TerminationRewardsOffset, policy.TerminationRewards, 4 * terminationCount);
SetArray(dataOffsets.TerminationAgentIdOffset, policy.TerminationAgentIds, 4 * terminationCount);
SetArray(dataOffsets.TerminationStatusOffset, policy.TerminationStatus, terminationCount);
}
public void WriteWorldSpecs(string name, MLAgentsWorld world)
public void WritePolicySpecs(string name, Policy policy)
{
m_OffsetDict[name] = RLDataOffsets.FromWorld(world, name, m_CurrentEndOffset);
m_OffsetDict[name] = RLDataOffsets.FromPolicy(policy, name, m_CurrentEndOffset);
var offset = m_CurrentEndOffset;
offset = SetString(offset, name); // Name
offset = SetInt(offset, world.DecisionAgentIds.Length); // Max Agents
offset = SetBool(offset, world.ActionType == ActionType.CONTINUOUS);
offset = SetInt(offset, world.ActionSize);
if (world.ActionType == ActionType.DISCRETE)
offset = SetInt(offset, policy.DecisionAgentIds.Length); // Max Agents
offset = SetBool(offset, policy.ActionType == ActionType.CONTINUOUS);
offset = SetInt(offset, policy.ActionSize);
if (policy.ActionType == ActionType.DISCRETE)
{
foreach (int branchSize in world.DiscreteActionBranches)
foreach (int branchSize in policy.DiscreteActionBranches)
{
offset = SetInt(offset, branchSize);
}
}
offset = SetInt(offset, world.SensorShapes.Length);
foreach (int3 shape in world.SensorShapes)
offset = SetInt(offset, policy.SensorShapes.Length);
foreach (int3 shape in policy.SensorShapes)
{
offset = SetInt(offset, shape.x);
offset = SetInt(offset, shape.y);
@@ -119,7 +119,7 @@ namespace Unity.AI.MLAgents
m_CurrentEndOffset = offset;
}
public void ReadWorld(string name, MLAgentsWorld world)
public void ReadPolicy(string name, Policy policy)
{
if (!CanEdit)
{
@@ -127,18 +127,18 @@ namespace Unity.AI.MLAgents
}
if (!m_OffsetDict.ContainsKey(name))
{
throw new MLAgentsException("World not registered");
throw new MLAgentsException("Policy not registered");
}
var dataOffsets = m_OffsetDict[name];
SetInt(dataOffsets.DecisionNumberAgentsOffset, 0);
SetInt(dataOffsets.TerminationNumberAgentsOffset, 0);
if (world.ActionType == ActionType.DISCRETE)
if (policy.ActionType == ActionType.DISCRETE)
{
GetArray(dataOffsets.ActionOffset, world.DiscreteActuators, 4 * world.DecisionCounter.Count * world.ActionSize);
GetArray(dataOffsets.ActionOffset, policy.DiscreteActuators, 4 * policy.DecisionCounter.Count * policy.ActionSize);
}
else
{
GetArray(dataOffsets.ActionOffset, world.ContinuousActuators, 4 * world.DecisionCounter.Count * world.ActionSize);
GetArray(dataOffsets.ActionOffset, policy.ContinuousActuators, 4 * policy.DecisionCounter.Count * policy.ActionSize);
}
}
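
For concreteness, a hedged walk-through of the spec block WritePolicySpecs above emits for an invented Policy (string encoding and exact byte widths are not shown in this diff, so this is indicative only):

// Invented Policy: name "Kart", 8 max agents, DISCRETE with 2 branches {3, 2},
// one observation of shape (4, 0, 0). WritePolicySpecs writes, in order:
//   name            "Kart"       (string)
//   max agents      8            (int; length of DecisionAgentIds)
//   is continuous   false        (bool; ActionType.DISCRETE)
//   action size     2            (int; number of branches)
//   branch sizes    3, 2         (one int per branch, DISCRETE only)
//   #observations   1            (int; SensorShapes.Length)
//   shape           4, 0, 0      (three ints per observation)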

View file

@@ -75,23 +75,23 @@ namespace Unity.AI.MLAgents
}
/// <summary>
/// Writes the data of a world into the shared memory file.
/// Writes the data of a Policy into the shared memory file.
/// </summary>
public void WriteWorld(string worldName, MLAgentsWorld world)
public void WritePolicy(string policyName, Policy policy)
{
if (!m_SharedMemoryHeader.Active)
{
return;
}
if (m_ShareMemoryBody.ContainsWorld(worldName))
if (m_ShareMemoryBody.ContainsPolicy(policyName))
{
m_ShareMemoryBody.WriteWorld(worldName, world);
m_ShareMemoryBody.WritePolicy(policyName, policy);
}
else
{
// The world needs to register
// The policy needs to register
int oldTotalCapacity = m_SharedMemoryHeader.RLDataBufferSize;
int worldMemorySize = RLDataOffsets.FromWorld(world, worldName, 0).EndOfDataOffset;
int policyMemorySize = RLDataOffsets.FromPolicy(policy, policyName, 0).EndOfDataOffset;
m_CurrentFileNumber += 1;
m_SharedMemoryHeader.FileNumber = m_CurrentFileNumber;
byte[] channelData = m_ShareMemoryBody.SideChannelData;
@@ -102,9 +102,9 @@ namespace Unity.AI.MLAgents
true,
null,
m_SharedMemoryHeader.SideChannelBufferSize,
oldTotalCapacity + worldMemorySize
oldTotalCapacity + policyMemorySize
);
m_SharedMemoryHeader.RLDataBufferSize = oldTotalCapacity + worldMemorySize;
m_SharedMemoryHeader.RLDataBufferSize = oldTotalCapacity + policyMemorySize;
if (channelData != null)
{
m_ShareMemoryBody.SideChannelData = channelData;
@@ -114,8 +114,8 @@ namespace Unity.AI.MLAgents
m_ShareMemoryBody.RlData = rlData;
}
// TODO Need to write the offsets
m_ShareMemoryBody.WriteWorldSpecs(worldName, world);
m_ShareMemoryBody.WriteWorld(worldName, world);
m_ShareMemoryBody.WritePolicySpecs(policyName, policy);
m_ShareMemoryBody.WritePolicy(policyName, policy);
}
}
@@ -180,11 +180,11 @@ namespace Unity.AI.MLAgents
}
/// <summary>
/// Loads the action data form the shared memory file to the world
/// Loads the action data from the shared memory file to the policy
/// </summary>
public void LoadWorld(string worldName, MLAgentsWorld world)
public void LoadPolicy(string policyName, Policy policy)
{
m_ShareMemoryBody.ReadWorld(worldName, world);
m_ShareMemoryBody.ReadPolicy(policyName, policy);
}
public void Dispose()

View file

@@ -6,7 +6,7 @@ using UnityEngine;
namespace Unity.AI.MLAgents
{
internal enum WorldProcessorType
internal enum PolicyProcessorType
{
Default,
InferenceOnly,
@@ -14,16 +14,16 @@ namespace Unity.AI.MLAgents
}
/// <summary>
/// An editor friendly constructor for a MLAgentsWorld.
/// Keeps track of the behavior specs of a world, its name,
/// An editor-friendly constructor for a Policy.
/// Keeps track of the behavior specs of a Policy, its name,
/// its processor type and Neural Network Model.
/// </summary>
[Serializable]
public struct MLAgentsWorldSpecs
public struct PolicySpecs
{
[SerializeField] internal string Name;
[SerializeField] internal WorldProcessorType WorldProcessorType;
[SerializeField] internal PolicyProcessorType PolicyProcessorType;
[SerializeField] internal int NumberAgents;
[SerializeField] internal ActionType ActionType;
@@ -34,47 +34,47 @@ namespace Unity.AI.MLAgents
[SerializeField] internal NNModel Model;
[SerializeField] internal InferenceDevice InferenceDevice;
private MLAgentsWorld m_World;
private Policy m_Policy;
/// <summary>
/// Generates the world using the specified specs and registers its
/// processor to the Academy. The world is only created and registed once,
/// if subsequent calls are made, the created world will be returned.
/// Generates a Policy using the specified specs and registers its
/// processor to the Academy. The policy is only created and registered once;
/// subsequent calls return the previously created policy.
/// </summary>
/// <returns></returns>
public MLAgentsWorld GetWorld()
public Policy GetPolicy()
{
if (m_World.IsCreated)
if (m_Policy.IsCreated)
{
return m_World;
return m_Policy;
}
m_World = new MLAgentsWorld(
m_Policy = new Policy(
NumberAgents,
ObservationShapes,
ActionType,
ActionSize,
DiscreteActionBranches
);
switch (WorldProcessorType)
switch (PolicyProcessorType)
{
case WorldProcessorType.Default:
m_World.RegisterWorldWithBarracudaModel(Name, Model, InferenceDevice);
case PolicyProcessorType.Default:
m_Policy.RegisterPolicyWithBarracudaModel(Name, Model, InferenceDevice);
break;
case WorldProcessorType.InferenceOnly:
case PolicyProcessorType.InferenceOnly:
if (Model == null)
{
throw new MLAgentsException($"No model specified for {Name}");
}
m_World.RegisterWorldWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice);
m_Policy.RegisterPolicyWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice);
break;
case WorldProcessorType.None:
Academy.Instance.RegisterWorld(Name, m_World, new NullWorldProcessor(m_World), false);
case PolicyProcessorType.None:
Academy.Instance.RegisterPolicy(Name, m_Policy, new NullPolicyProcessor(m_Policy), false);
break;
default:
throw new MLAgentsException($"Unknown WorldProcessor Type");
throw new MLAgentsException($"Unknown IPolicyProcessor Type");
}
return m_World;
return m_Policy;
}
}
}

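Read together with the sample changes below, the intended usage is: expose a PolicySpecs field, edit it in the Inspector through the drawer, then call GetPolicy() once at startup. A minimal sketch (class and field names are illustrative, mirroring BalanceBallManager later in this diff):

    using Unity.AI.MLAgents;
    using UnityEngine;

    public class ExampleManager : MonoBehaviour  // illustrative name
    {
        // Edited in the Inspector via PolicySpecsDrawer.
        public PolicySpecs MyPolicySpecs;

        void Awake()
        {
            // First call creates the Policy and registers its processor
            // with the Academy; subsequent calls return the same instance.
            var policy = MyPolicySpecs.GetPolicy();
        }
    }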
View file

View file

@ -1,104 +0,0 @@
using System;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
namespace Unity.AI.MLAgents
{
public static class HeristicWorldProcessorRegistringExtension
{
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the Heuristic will
/// be ignored and the world will exchange data with Python instead.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the world will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the world.</typeparam>
public static void RegisterWorldWithHeuristic<TH>(
this MLAgentsWorld world,
string policyId,
Func<TH> heuristic
) where TH : struct
{
var worldProcessor = new HeuristicWorldProcessor<TH>(world, heuristic);
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true);
}
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the world will not
/// exchange data with Python and use the Heuristic regardless.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the world will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the world.</typeparam>
public static void RegisterWorldWithHeuristicForceNoCommunication<TH>(
this MLAgentsWorld world,
string policyId,
Func<TH> heuristic
) where TH : struct
{
var worldProcessor = new HeuristicWorldProcessor<TH>(world, heuristic);
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, false);
}
}
internal class HeuristicWorldProcessor<T> : IWorldProcessor where T : struct
{
private Func<T> m_Heuristic;
private MLAgentsWorld m_World;
public bool IsConnected {get {return false;}}
internal HeuristicWorldProcessor(MLAgentsWorld world, Func<T> heuristic)
{
this.m_World = world;
this.m_Heuristic = heuristic;
var structSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
if (structSize != world.ActionSize)
{
throw new MLAgentsException(
$"The heuristic provided does not match the action size. Expected {world.ActionSize} but received {structSize} from heuristic");
}
}
public void ProcessWorld()
{
T action = m_Heuristic.Invoke();
var totalCount = m_World.DecisionCounter.Count;
// TODO : This can be parallelized
if (m_World.ActionType == ActionType.CONTINUOUS)
{
var s = m_World.ContinuousActuators.Slice(0, totalCount * m_World.ActionSize).SliceConvert<T>();
for (int i = 0; i < totalCount; i++)
{
s[i] = action;
}
}
else
{
var s = m_World.DiscreteActuators.Slice(0, totalCount * m_World.ActionSize).SliceConvert<T>();
for (int i = 0; i < totalCount; i++)
{
s[i] = action;
}
}
}
public void Dispose()
{
}
}
}

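The extension methods in this deleted file live on under policy names: RegisterPolicyWithHeuristic is exercised by the tests below, and the force-no-communication variant presumably follows the same rename. A minimal sketch of the renamed call, inside a setup method, with illustrative sizes and behavior name:

    using Unity.AI.MLAgents;
    using Unity.Mathematics;

    // Illustrative policy: up to 10 agents, one size-3 vector observation,
    // two discrete action branches with 2 and 3 options respectively.
    var policy = new Policy(
        10,
        new int3[] { new int3(3, 0, 0) },
        ActionType.DISCRETE,
        2,
        new int[] { 2, 3 });

    // Every agent that requested a decision receives the same heuristic action.
    policy.RegisterPolicyWithHeuristic("my_behavior", () => new int2(0, 1));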
View file

@ -4,8 +4,9 @@
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_Name: AgentBlue
m_Shader: {fileID: 47, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords: _GLOSSYREFLECTIONS_OFF _SPECULARHIGHLIGHTS_OFF

View file

@ -249,9 +249,9 @@ MonoBehaviour:
m_Script: {fileID: 11500000, guid: 99c8505fda4844b5b8a9505d069fdd8c, type: 3}
m_Name:
m_EditorClassIdentifier:
MyWorldSpecs:
Name: Ball_DOTS
WorldProcessorType: 0
MyPolicySpecs:
Name: 3DBall
PolicyProcessorType: 0
NumberAgents: 1000
ActionType: 1
ObservationShapes:
@ -268,7 +268,7 @@ MonoBehaviour:
y: 0
z: 0
ActionSize: 2
DiscreteActionBranches:
DiscreteActionBranches: 0000000000000000
Model: {fileID: 5022602860645237092, guid: 10312addcce7c4e508c30d95eaf1a2ff, type: 3}
InferenceDevice: 0
NumberBalls: 1000

View file

@ -8,7 +8,7 @@ using Unity.AI.MLAgents;
public class BalanceBallManager : MonoBehaviour
{
public MLAgentsWorldSpecs MyWorldSpecs;
public PolicySpecs MyPolicySpecs;
public int NumberBalls = 1000;
@ -23,12 +23,13 @@ public class BalanceBallManager : MonoBehaviour
NativeArray<Entity> entitiesB;
BlobAssetStore blob;
void Awake()
{
var world = MyWorldSpecs.GetWorld();
var policy = MyPolicySpecs.GetPolicy();
var ballSystem = World.DefaultGameObjectInjectionWorld.GetOrCreateSystem<BallSystem>();
ballSystem.Enabled = true;
ballSystem.BallWorld = world;
ballSystem.BallPolicy = policy;
manager = World.DefaultGameObjectInjectionWorld.EntityManager;

View file

@ -33,17 +33,17 @@ public class BallSystem : JobComponentSystem
}
}
public MLAgentsWorld BallWorld;
public Policy BallPolicy;
// Update is called once per frame
protected override JobHandle OnUpdate(JobHandle inputDeps)
{
if (!BallWorld.IsCreated){
if (!BallPolicy.IsCreated){
return inputDeps;
}
var world = BallWorld;
var policy = BallPolicy;
ComponentDataFromEntity<Translation> TranslationFromEntity = GetComponentDataFromEntity<Translation>(isReadOnly: false);
ComponentDataFromEntity<PhysicsVelocity> VelFromEntity = GetComponentDataFromEntity<PhysicsVelocity>(isReadOnly: false);
@ -52,7 +52,7 @@ public class BallSystem : JobComponentSystem
.WithNativeDisableParallelForRestriction(VelFromEntity)
.ForEach((Entity entity, ref Rotation rot, ref AgentData agentData) =>
{
var ballPos = TranslationFromEntity[agentData.BallRef].Value;
var ballVel = VelFromEntity[agentData.BallRef].Linear;
var platformVel = VelFromEntity[entity];
@ -70,7 +70,7 @@ public class BallSystem : JobComponentSystem
}
if (!interruption && !taskFailed)
{
world.RequestDecision(entity)
policy.RequestDecision(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
@ -79,7 +79,7 @@ public class BallSystem : JobComponentSystem
}
if (taskFailed)
{
world.EndEpisode(entity)
policy.EndEpisode(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
@ -88,7 +88,7 @@ public class BallSystem : JobComponentSystem
}
else if (interruption)
{
world.InterruptEpisode(entity)
policy.InterruptEpisode(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
@ -109,7 +109,7 @@ public class BallSystem : JobComponentSystem
{
ComponentDataFromEntity = GetComponentDataFromEntity<Actuator>(isReadOnly: false)
};
inputDeps = reactiveJob.Schedule(world, inputDeps);
inputDeps = reactiveJob.Schedule(policy, inputDeps);
inputDeps = Entities.ForEach((Actuator act, ref Rotation rotation) =>
{
@ -122,6 +122,6 @@ public class BallSystem : JobComponentSystem
protected override void OnDestroy()
{
BallWorld.Dispose();
BallPolicy.Dispose();
}
}

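For completeness, the reactiveJob scheduled above implements IActuatorJob. Its callback shape is not shown in this diff, so the Execute signature below is an assumption inferred from the test job further down (ActuatorEvent is taken to expose the entity and a typed action getter):

    using Unity.AI.MLAgents;
    using Unity.Collections;
    using Unity.Entities;

    struct CopyActionJob : IActuatorJob  // illustrative name
    {
        public NativeArray<Entity> Entities;
        public NativeArray<DiscreteAction_TWO_THREE> Actions;

        // Assumed callback: invoked once per agent whose decision was
        // answered this step.
        public void Execute(ActuatorEvent ev)
        {
            Entities[0] = ev.Entity;
            Actions[0] = ev.GetAction<DiscreteAction_TWO_THREE>();
        }
    }

    // Scheduling changes only by the argument rename:
    // inputDeps = new CopyActionJob { ... }.Schedule(policy, inputDeps);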
View file

@ -5,8 +5,8 @@ using UnityEngine;
public class BasicAgent : MonoBehaviour
{
public MLAgentsWorldSpecs BasicSpecs;
private MLAgentsWorld m_World;
public PolicySpecs BasicSpecs;
private Policy m_Policy;
private Entity m_Entity;
public float timeBetweenDecisionsAtInference;
@ -33,8 +33,8 @@ public class BasicAgent : MonoBehaviour
void Start()
{
m_Entity = World.DefaultGameObjectInjectionWorld.EntityManager.CreateEntity();
m_World = BasicSpecs.GetWorld();
m_World.RegisterWorldWithHeuristic<int>("BASIC", () => { return 2; });
m_Policy = BasicSpecs.GetPolicy();
m_Policy.RegisterPolicyWithHeuristic<int>("BASIC", () => { return 2; });
Academy.Instance.OnEnvironmentReset += BeginEpisode;
BeginEpisode();
}
@ -62,12 +62,12 @@ public class BasicAgent : MonoBehaviour
void StepAgent()
{
// Request a Decision for all agents
m_World.RequestDecision(m_Entity)
m_Policy.RequestDecision(m_Entity)
.SetObservation(0, m_Position)
.SetReward(-0.01f);
// Get the action
NativeHashMap<Entity, int> actions = m_World.GenerateActionHashMap<int>(Allocator.Temp);
NativeHashMap<Entity, int> actions = m_Policy.GenerateActionHashMap<int>(Allocator.Temp);
int action = 0;
actions.TryGetValue(m_Entity, out action);
@ -87,7 +87,7 @@ public class BasicAgent : MonoBehaviour
// See if the Agent terminated
if (m_Position == k_SmallGoalPosition)
{
m_World.EndEpisode(m_Entity)
m_Policy.EndEpisode(m_Entity)
.SetObservation(0, m_Position)
.SetReward(0.1f);
BeginEpisode();
@ -95,7 +95,7 @@ public class BasicAgent : MonoBehaviour
if (m_Position == k_LargeGoalPosition)
{
m_World.EndEpisode(m_Entity)
m_Policy.EndEpisode(m_Entity)
.SetObservation(0, m_Position)
.SetReward(1f);
BeginEpisode();
@ -104,6 +104,6 @@ public class BasicAgent : MonoBehaviour
void OnDestroy()
{
m_World.Dispose();
m_Policy.Dispose();
}
}

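Distilled from BasicAgent above, the main-thread alternative to an actuator job is a two-step poll (names as in the sample):

    // Temp-allocated map from entity to its most recent action; the lookup
    // fails if that entity was not issued a decision this step.
    NativeHashMap<Entity, int> actions = m_Policy.GenerateActionHashMap<int>(Allocator.Temp);
    int action = 0;
    actions.TryGetValue(m_Entity, out action);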
View file

@ -6,7 +6,7 @@ using Unity.Jobs;
namespace Unity.AI.MLAgents.Tests.Editor
{
public class TestMLAgentsWorld
public class TestMLAgentsPolicy
{
private World ECSWorld;
private EntityManager entityManager;
@ -51,15 +51,15 @@ namespace Unity.AI.MLAgents.Tests.Editor
}
[Test]
public void TestWorldCreation()
public void TestPolicyCreation()
{
var world = new MLAgentsWorld(
var policy = new Policy(
20,
new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
world.Dispose();
policy.Dispose();
}
private struct SingleActionEnumUpdate : IActuatorJob
@ -77,18 +77,18 @@ namespace Unity.AI.MLAgents.Tests.Editor
[Test]
public void TestManualDecisionSteppingWithHeuristic()
{
var world = new MLAgentsWorld(
var policy = new Policy(
20,
new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
world.RegisterWorldWithHeuristic("test", () => new int2(0, 1));
policy.RegisterPolicyWithHeuristic("test", () => new int2(0, 1));
var entity = entityManager.CreateEntity();
world.RequestDecision(entity)
policy.RequestDecision(entity)
.SetObservation(0, new float3(1, 2, 3))
.SetReward(1f);
@ -100,7 +100,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
ent = entities,
action = actions
};
actionJob.Schedule(world, new JobHandle()).Complete();
actionJob.Schedule(policy, new JobHandle()).Complete();
Assert.AreEqual(entity, entities[0]);
Assert.AreEqual(new DiscreteAction_TWO_THREE
@ -109,7 +109,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
action_TWO = ThreeOptionEnum.Option_TWO
}, actions[0]);
world.Dispose();
policy.Dispose();
entities.Dispose();
actions.Dispose();
Academy.Instance.Dispose();
@ -118,56 +118,56 @@ namespace Unity.AI.MLAgents.Tests.Editor
[Test]
public void TestTerminateEpisode()
{
var world = new MLAgentsWorld(
var policy = new Policy(
20,
new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
world.RegisterWorldWithHeuristic("test", () => new int2(0, 1));
policy.RegisterPolicyWithHeuristic("test", () => new int2(0, 1));
var entity = entityManager.CreateEntity();
world.RequestDecision(entity)
policy.RequestDecision(entity)
.SetObservation(0, new float3(1, 2, 3))
.SetReward(1f);
var hashMap = world.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
var hashMap = policy.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
Assert.True(hashMap.TryGetValue(entity, out _));
hashMap.Dispose();
world.EndEpisode(entity);
hashMap = world.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
policy.EndEpisode(entity);
hashMap = policy.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
Assert.False(hashMap.TryGetValue(entity, out _));
hashMap.Dispose();
world.Dispose();
policy.Dispose();
Academy.Instance.Dispose();
}
[Test]
public void TestMultiWorld()
public void TestMultiPolicy()
{
var world1 = new MLAgentsWorld(
var policy1 = new Policy(
20,
new int3[] { new int3(3, 0, 0) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
var world2 = new MLAgentsWorld(
var policy2 = new Policy(
20,
new int3[] { new int3(3, 0, 0) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
world1.RegisterWorldWithHeuristic("test1", () => new DiscreteAction_TWO_THREE
policy1.RegisterPolicyWithHeuristic("test1", () => new DiscreteAction_TWO_THREE
{
action_ONE = TwoOptionEnum.Option_TWO,
action_TWO = ThreeOptionEnum.Option_ONE
});
world2.RegisterWorldWithHeuristic("test2", () => new DiscreteAction_TWO_THREE
policy2.RegisterPolicyWithHeuristic("test2", () => new DiscreteAction_TWO_THREE
{
action_ONE = TwoOptionEnum.Option_ONE,
action_TWO = ThreeOptionEnum.Option_TWO
@ -175,8 +175,8 @@ namespace Unity.AI.MLAgents.Tests.Editor
var entity = entityManager.CreateEntity();
world1.RequestDecision(entity);
world2.RequestDecision(entity);
policy1.RequestDecision(entity);
policy2.RequestDecision(entity);
var entities = new NativeArray<Entity>(1, Allocator.Persistent);
var actions = new NativeArray<DiscreteAction_TWO_THREE>(1, Allocator.Persistent);
@ -185,7 +185,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
ent = entities,
action = actions
};
actionJob1.Schedule(world1, new JobHandle()).Complete();
actionJob1.Schedule(policy1, new JobHandle()).Complete();
Assert.AreEqual(entity, entities[0]);
Assert.AreEqual(new DiscreteAction_TWO_THREE
@ -203,7 +203,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
ent = entities,
action = actions
};
actionJob2.Schedule(world2, new JobHandle()).Complete();
actionJob2.Schedule(policy2, new JobHandle()).Complete();
Assert.AreEqual(entity, entities[0]);
Assert.AreEqual(new DiscreteAction_TWO_THREE
@ -214,8 +214,8 @@ namespace Unity.AI.MLAgents.Tests.Editor
entities.Dispose();
actions.Dispose();
world1.Dispose();
world2.Dispose();
policy1.Dispose();
policy2.Dispose();
Academy.Instance.Dispose();
}
}