Modified the *CODE* to rename world to policy
This commit is contained in:
Родитель
17a3f90d08
Коммит
0b677295d3
|
@ -12,7 +12,7 @@ namespace Unity.AI.MLAgents.Editor
|
|||
internal static class SpecsPropertyNames
|
||||
{
|
||||
public const string k_Name = "Name";
|
||||
public const string k_WorldProcessorType = "WorldProcessorType";
|
||||
public const string k_PolicyProcessorType = "PolicyProcessorType";
|
||||
public const string k_NumberAgents = "NumberAgents";
|
||||
public const string k_ActionType = "ActionType";
|
||||
public const string k_ObservationShapes = "ObservationShapes";
|
||||
|
@ -27,8 +27,8 @@ namespace Unity.AI.MLAgents.Editor
|
|||
/// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the
|
||||
/// Inspector.
|
||||
/// </summary>
|
||||
[CustomPropertyDrawer(typeof(MLAgentsWorldSpecs))]
|
||||
internal class MLAgentsWorldSpecsDrawer : PropertyDrawer
|
||||
[CustomPropertyDrawer(typeof(PolicySpecs))]
|
||||
internal class PolicySpecsDrawer : PropertyDrawer
|
||||
{
|
||||
// The height of a line in the Unity Inspectors
|
||||
const float k_LineHeight = 21f;
|
||||
|
@ -66,7 +66,7 @@ namespace Unity.AI.MLAgents.Editor
|
|||
new Rect(position.x - 3f, position.y, position.width + 6f, m_TotalHeight),
|
||||
new Color(0f, 0f, 0f, 0.1f));
|
||||
|
||||
EditorGUI.LabelField(position, "ML-Agents World Specs : " + label.text);
|
||||
EditorGUI.LabelField(position, "Policy Specs : " + label.text);
|
||||
position.y += k_LineHeight;
|
||||
EditorGUI.indentLevel++;
|
||||
|
||||
|
@ -74,13 +74,13 @@ namespace Unity.AI.MLAgents.Editor
|
|||
// Name
|
||||
EditorGUI.PropertyField(position,
|
||||
property.FindPropertyRelative(SpecsPropertyNames.k_Name),
|
||||
new GUIContent("World Name", "The name of the World"));
|
||||
new GUIContent("Name", "The name of the Policy"));
|
||||
position.y += k_LineHeight;
|
||||
|
||||
// WorldProcessorType
|
||||
// PolicyProcessorType
|
||||
EditorGUI.PropertyField(position,
|
||||
property.FindPropertyRelative(SpecsPropertyNames.k_WorldProcessorType),
|
||||
new GUIContent("Processor Type", "The Policy for the World"));
|
||||
property.FindPropertyRelative(SpecsPropertyNames.k_PolicyProcessorType),
|
||||
new GUIContent("Policy Type", "The type of Policy"));
|
||||
position.y += k_LineHeight;
|
||||
|
||||
// Number of Agents
|
||||
|
@ -224,7 +224,7 @@ namespace Unity.AI.MLAgents.Editor
|
|||
var name = nameProperty.stringValue;
|
||||
if (name == "")
|
||||
{
|
||||
m_Warnings.Add("Your World must have a non-empty name");
|
||||
m_Warnings.Add("Your Policy must have a non-empty name");
|
||||
}
|
||||
|
||||
// Max number of agents is not zero
|
||||
|
@ -232,14 +232,14 @@ namespace Unity.AI.MLAgents.Editor
|
|||
var nAgents = nAgentsProperty.intValue;
|
||||
if (nAgents == 0)
|
||||
{
|
||||
m_Warnings.Add("Your World must have a non-zero maximum number of Agents");
|
||||
m_Warnings.Add("Your Policy must have a non-zero maximum number of Agents");
|
||||
}
|
||||
|
||||
// At least one observation
|
||||
var observationShapes = property.FindPropertyRelative(SpecsPropertyNames.k_ObservationShapes);
|
||||
if (observationShapes.arraySize == 0)
|
||||
{
|
||||
m_Warnings.Add("Your World must have at least one observation");
|
||||
m_Warnings.Add("Your Policy must have at least one observation");
|
||||
}
|
||||
|
||||
// Action Size is not zero
|
||||
|
@ -247,7 +247,7 @@ namespace Unity.AI.MLAgents.Editor
|
|||
var actionSize = actionSizeProperty.intValue;
|
||||
if (actionSize == 0)
|
||||
{
|
||||
m_Warnings.Add("Your World must have non-zero action size");
|
||||
m_Warnings.Add("Your Policy must have non-zero action size");
|
||||
}
|
||||
|
||||
//Model is not empty
|
|
@ -1,5 +1,5 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 671ee066644364cc191cb6c00ceaf1b4
|
||||
guid: 94a88b2001ca84a0aa8eb12a93dffc6e
|
||||
MonoImporter:
|
||||
externalObjects: {}
|
||||
serializedVersion: 2
|
|
@ -11,7 +11,7 @@ namespace Unity.AI.MLAgents
|
|||
/// <summary>
|
||||
/// The Academy is a singleton that orchestrates the decision making of the
|
||||
/// decision making of the Agents.
|
||||
/// It is used to register WorldProcessors to Worlds and to keep track of the
|
||||
/// It is used to register PolicyProcessors to Policy and to keep track of the
|
||||
/// reset logic of the simulation.
|
||||
/// </summary>
|
||||
public class Academy : IDisposable
|
||||
|
@ -50,7 +50,8 @@ namespace Unity.AI.MLAgents
|
|||
private bool m_FirstMessageReceived;
|
||||
private SharedMemoryCommunicator m_Communicator;
|
||||
|
||||
internal Dictionary<MLAgentsWorld, IWorldProcessor> m_WorldToProcessor; // Maybe we can put the processor in the world with an unsafe unmanaged memory pointer ?
|
||||
internal Dictionary<Policy, IPolicyProcessor> m_PolicyToProcessor;
|
||||
// TODO : Maybe we can put the processor in the policy with an unsafe unmanaged memory pointer ?
|
||||
|
||||
private EnvironmentParameters m_EnvironmentParameters;
|
||||
private StatsRecorder m_StatsRecorder;
|
||||
|
@ -63,33 +64,29 @@ namespace Unity.AI.MLAgents
|
|||
public Action OnEnvironmentReset;
|
||||
|
||||
/// <summary>
|
||||
/// Registers a MLAgentsWorld to a decision making mechanism.
|
||||
/// By default, the MLAgentsWorld will use a remote process for decision making when available.
|
||||
/// Registers a Policy to a decision making mechanism.
|
||||
/// By default, the Policy will use a remote process for decision making when available.
|
||||
/// </summary>
|
||||
/// <param name="policyId"> The string identifier of the MLAgentsWorld. There can only be one world per unique id.</param>
|
||||
/// <param name="world"> The MLAgentsWorld that is being subscribed.</param>
|
||||
/// <param name="worldProcessor"> If the remote process is not available, the MLAgentsWorld will use this World processor for decision making.</param>
|
||||
/// <param name="defaultRemote"> If true, the MLAgentsWorld will default to using the remote process for communication making and use the fallback worldProcessor otherwise.</param>
|
||||
public void RegisterWorld(string policyId, MLAgentsWorld world, IWorldProcessor worldProcessor = null, bool defaultRemote = true)
|
||||
/// <param name="policyId"> The string identifier of the Policy. There can only be one Policy per unique id.</param>
|
||||
/// <param name="policy"> The Policy that is being subscribed.</param>
|
||||
/// <param name="policyProcessor"> If the remote process is not available, the Policy will use this IPolicyProcessor for decision making.</param>
|
||||
/// <param name="defaultRemote"> If true, the Policy will default to using the remote process for communication making and use the fallback IPolicyProcessor otherwise.</param>
|
||||
public void RegisterPolicy(string policyId, Policy policy, IPolicyProcessor policyProcessor = null, bool defaultRemote = true)
|
||||
{
|
||||
// Need to find a way to deregister ?
|
||||
// Need a way to modify the World processor on the fly
|
||||
// Automagically register world on creation ?
|
||||
|
||||
IWorldProcessor processor = null;
|
||||
IPolicyProcessor processor = null;
|
||||
if (m_Communicator != null && defaultRemote)
|
||||
{
|
||||
processor = new RemoteWorldProcessor(world, policyId, m_Communicator);
|
||||
processor = new RemotePolicyProcessor(policy, policyId, m_Communicator);
|
||||
}
|
||||
else if (worldProcessor != null)
|
||||
else if (policyProcessor != null)
|
||||
{
|
||||
processor = worldProcessor;
|
||||
processor = policyProcessor;
|
||||
}
|
||||
else
|
||||
{
|
||||
processor = new NullWorldProcessor(world);
|
||||
processor = new NullPolicyProcessor(policy);
|
||||
}
|
||||
m_WorldToProcessor[world] = processor;
|
||||
m_PolicyToProcessor[policy] = processor;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -133,7 +130,7 @@ namespace Unity.AI.MLAgents
|
|||
Application.quitting += Dispose;
|
||||
OnEnvironmentReset = () => {};
|
||||
|
||||
m_WorldToProcessor = new Dictionary<MLAgentsWorld, IWorldProcessor>();
|
||||
m_PolicyToProcessor = new Dictionary<Policy, IPolicyProcessor>();
|
||||
|
||||
TryInitializeCommunicator();
|
||||
SideChannelsManager.RegisterSideChannel(new EngineConfigurationChannel());
|
||||
|
@ -162,8 +159,8 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
|
||||
// We will make the assumption that a world can only be updated one at a time
|
||||
internal void UpdateWorld(MLAgentsWorld world)
|
||||
// We will make the assumption that a policy can only be updated one at a time
|
||||
internal void UpdatePolicy(Policy policy)
|
||||
{
|
||||
if (!m_Initialized)
|
||||
{
|
||||
|
@ -171,23 +168,23 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
|
||||
// If no agents requested a decision return
|
||||
if (world.DecisionCounter.Count == 0 && world.TerminationCounter.Count == 0)
|
||||
if (policy.DecisionCounter.Count == 0 && policy.TerminationCounter.Count == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure the world does not have lingering actions:
|
||||
if (world.ActionCounter.Count != 0)
|
||||
// Ensure the policy does not have lingering actions:
|
||||
if (policy.ActionCounter.Count != 0)
|
||||
{
|
||||
// This means something in the execution went wrong, this error should never appear
|
||||
throw new MLAgentsException("TODO : ActionCount is not 0");
|
||||
}
|
||||
|
||||
var processor = m_WorldToProcessor[world];
|
||||
var processor = m_PolicyToProcessor[policy];
|
||||
if (processor == null)
|
||||
{
|
||||
// Raise error
|
||||
throw new MLAgentsException($"A world has not been correctly registered.");
|
||||
throw new MLAgentsException($"A Policy has not been correctly registered.");
|
||||
}
|
||||
|
||||
|
||||
|
@ -208,10 +205,10 @@ namespace Unity.AI.MLAgents
|
|||
if (!reset) // TODO : Comment out if we do not want to reset on first env.reset()
|
||||
{
|
||||
m_Communicator.WriteSideChannelData(SideChannelsManager.GetSideChannelMessage());
|
||||
processor.ProcessWorld();
|
||||
processor.Process();
|
||||
reset = m_Communicator.ReadAndClearResetCommand();
|
||||
world.SetActionReady();
|
||||
world.ResetDecisionsAndTerminationCounters();
|
||||
policy.SetActionReady();
|
||||
policy.ResetDecisionsAndTerminationCounters();
|
||||
SideChannelsManager.ProcessSideChannelData(m_Communicator.ReadAndClearSideChannelData());
|
||||
}
|
||||
if (reset)
|
||||
|
@ -222,16 +219,16 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
else if (!processor.IsConnected)
|
||||
{
|
||||
processor.ProcessWorld();
|
||||
world.SetActionReady();
|
||||
world.ResetDecisionsAndTerminationCounters();
|
||||
processor.Process();
|
||||
policy.SetActionReady();
|
||||
policy.ResetDecisionsAndTerminationCounters();
|
||||
// TODO com.ReadAndClearSideChannelData(); // Remove side channel data
|
||||
}
|
||||
else
|
||||
{
|
||||
// The processor wants to communicate but the communicator is either null or inactive
|
||||
world.ResetActionsCounter();
|
||||
world.ResetDecisionsAndTerminationCounters();
|
||||
policy.ResetActionsCounter();
|
||||
policy.ResetDecisionsAndTerminationCounters();
|
||||
}
|
||||
if (m_Communicator == null)
|
||||
{
|
||||
|
@ -246,16 +243,16 @@ namespace Unity.AI.MLAgents
|
|||
// Need to complete all of the jobs at this point.
|
||||
ECSWorld.EntityManager.CompleteAllJobs();
|
||||
}
|
||||
ResetAllWorlds();
|
||||
ResetAllPolicies();
|
||||
OnEnvironmentReset?.Invoke();
|
||||
}
|
||||
|
||||
private void ResetAllWorlds() // This is problematic because it affects all worlds and is not thread safe...
|
||||
private void ResetAllPolicies() // This is problematic because it affects all policies and is not thread safe...
|
||||
{
|
||||
foreach (var w in m_WorldToProcessor.Keys)
|
||||
foreach (var pol in m_PolicyToProcessor.Keys)
|
||||
{
|
||||
w.ResetActionsCounter();
|
||||
w.ResetDecisionsAndTerminationCounters();
|
||||
pol.ResetActionsCounter();
|
||||
pol.ResetDecisionsAndTerminationCounters();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,39 +9,39 @@ namespace Unity.AI.MLAgents
|
|||
public static class ActionHashMapUtils
|
||||
{
|
||||
/// <summary>
|
||||
/// Retrieves the action data for a world in puts it into a HashMap.
|
||||
/// This action deletes the action data from the world.
|
||||
/// Retrieves the action data for a Policy in puts it into a HashMap.
|
||||
/// This action deletes the action data from the Policy.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld the data will be retrieved from.</param>
|
||||
/// <param name="policy"> The Policy the data will be retrieved from.</param>
|
||||
/// <param name="allocator"> The memory allocator of the create NativeHashMap.</param>
|
||||
/// <typeparam name="T"> The type of the Action struct. It must match the Action Size
|
||||
/// and Action Type of the world.</typeparam>
|
||||
/// and Action Type of the Policy.</typeparam>
|
||||
/// <returns> A NativeHashMap from Entities to Actions with type T.</returns>
|
||||
public static NativeHashMap<Entity, T> GenerateActionHashMap<T>(this MLAgentsWorld world, Allocator allocator) where T : struct
|
||||
public static NativeHashMap<Entity, T> GenerateActionHashMap<T>(this Policy policy, Allocator allocator) where T : struct
|
||||
{
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
if (world.ActionSize != UnsafeUtility.SizeOf<T>() / 4)
|
||||
if (policy.ActionSize != UnsafeUtility.SizeOf<T>() / 4)
|
||||
{
|
||||
var receivedSize = UnsafeUtility.SizeOf<T>() / 4;
|
||||
throw new MLAgentsException($"Action space size does not match for action. Expected {world.ActionSize} but received {receivedSize}");
|
||||
throw new MLAgentsException($"Action space size does not match for action. Expected {policy.ActionSize} but received {receivedSize}");
|
||||
}
|
||||
#endif
|
||||
Academy.Instance.UpdateWorld(world);
|
||||
int actionCount = world.ActionCounter.Count;
|
||||
Academy.Instance.UpdatePolicy(policy);
|
||||
int actionCount = policy.ActionCounter.Count;
|
||||
var result = new NativeHashMap<Entity, T>(actionCount, allocator);
|
||||
int size = world.ActionSize;
|
||||
int size = policy.ActionSize;
|
||||
for (int i = 0; i < actionCount; i++)
|
||||
{
|
||||
if (world.ActionType == ActionType.DISCRETE)
|
||||
if (policy.ActionType == ActionType.DISCRETE)
|
||||
{
|
||||
result.TryAdd(world.ActionAgentEntityIds[i], world.DiscreteActuators.Slice(i * size, size).SliceConvert<T>()[0]);
|
||||
result.TryAdd(policy.ActionAgentEntityIds[i], policy.DiscreteActuators.Slice(i * size, size).SliceConvert<T>()[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
result.TryAdd(world.ActionAgentEntityIds[i], world.ContinuousActuators.Slice(i * size, size).SliceConvert<T>()[0]);
|
||||
result.TryAdd(policy.ActionAgentEntityIds[i], policy.ContinuousActuators.Slice(i * size, size).SliceConvert<T>()[0]);
|
||||
}
|
||||
}
|
||||
world.ResetActionsCounter();
|
||||
policy.ResetActionsCounter();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ namespace Unity.AI.MLAgents
|
|||
/// Retrieve the action that was decided for the Entity.
|
||||
/// This method uses a generic type, as such you must provide a
|
||||
/// type that is compatible with the Action Type and Action Size
|
||||
/// for this MLAgentsWorld.
|
||||
/// for this Policy.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"> The type of action struct.</typeparam>
|
||||
/// <returns> The action struct for the Entity.</returns>
|
||||
|
@ -59,7 +59,7 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// The signature of the Job used to retrieve actuators values from a world
|
||||
/// The signature of the Job used to retrieve actuators values from a Policy
|
||||
/// </summary>
|
||||
[JobProducerType(typeof(IActuatorJobExtensions.ActuatorDataJobProcess<>))]
|
||||
public interface IActuatorJob
|
||||
|
@ -73,31 +73,31 @@ namespace Unity.AI.MLAgents
|
|||
/// Schedule the Job that will generate the action data for the Entities that requested a decision.
|
||||
/// </summary>
|
||||
/// <param name="jobData"> The IActuatorJob struct.</param>
|
||||
/// <param name="mlagentsWorld"> The MLAgentsWorld containing the data needed for decision making.</param>
|
||||
/// <param name="policy"> The Policy containing the data needed for decision making.</param>
|
||||
/// <param name="inputDeps"> The jobHandle for the job.</param>
|
||||
/// <typeparam name="T"> The type of the IActuatorData struct.</typeparam>
|
||||
/// <returns> The updated jobHandle for the job.</returns>
|
||||
public static unsafe JobHandle Schedule<T>(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps)
|
||||
public static unsafe JobHandle Schedule<T>(this T jobData, Policy policy, JobHandle inputDeps)
|
||||
where T : struct, IActuatorJob
|
||||
{
|
||||
inputDeps.Complete(); // TODO : FIND A BETTER WAY TO MAKE SURE ALL THE DATA IS IN THE WORLD
|
||||
Academy.Instance.UpdateWorld(mlagentsWorld);
|
||||
if (mlagentsWorld.ActionCounter.Count == 0)
|
||||
inputDeps.Complete(); // TODO : FIND A BETTER WAY TO MAKE SURE ALL THE DATA IS IN THE POLICY
|
||||
Academy.Instance.UpdatePolicy(policy);
|
||||
if (policy.ActionCounter.Count == 0)
|
||||
{
|
||||
return inputDeps;
|
||||
}
|
||||
return ScheduleImpl(jobData, mlagentsWorld, inputDeps);
|
||||
return ScheduleImpl(jobData, policy, inputDeps);
|
||||
}
|
||||
|
||||
// Passing this along
|
||||
internal static unsafe JobHandle ScheduleImpl<T>(this T jobData, MLAgentsWorld mlagentsWorld, JobHandle inputDeps)
|
||||
internal static unsafe JobHandle ScheduleImpl<T>(this T jobData, Policy policy, JobHandle inputDeps)
|
||||
where T : struct, IActuatorJob
|
||||
{
|
||||
// Creating a data struct that contains the data the user passed into the job (This is what T is here)
|
||||
var data = new ActionEventJobData<T>
|
||||
{
|
||||
UserJobData = jobData,
|
||||
world = mlagentsWorld // Need to create this before hand with the actuator data
|
||||
Policy = policy // Need to create this before hand with the actuator data
|
||||
};
|
||||
|
||||
// Scheduling a Job out of thin air by using a pointer called jobReflectionData in the ActuatorSystemJobStruct
|
||||
|
@ -105,11 +105,11 @@ namespace Unity.AI.MLAgents
|
|||
return JobsUtility.Schedule(ref parameters);
|
||||
}
|
||||
|
||||
// This is the struct containing all the data needed from both the user and the MLAgents world
|
||||
// This is the struct containing all the data needed from both the user and the Policy
|
||||
internal unsafe struct ActionEventJobData<T> where T : struct
|
||||
{
|
||||
public T UserJobData;
|
||||
[NativeDisableContainerSafetyRestriction] public MLAgentsWorld world;
|
||||
[NativeDisableContainerSafetyRestriction] public Policy Policy;
|
||||
}
|
||||
|
||||
internal struct ActuatorDataJobProcess<T> where T : struct, IActuatorJob
|
||||
|
@ -132,11 +132,11 @@ namespace Unity.AI.MLAgents
|
|||
/// Calls the user implemented Execute method with ActuatorEvent struct
|
||||
public static unsafe void Execute(ref ActionEventJobData<T> jobData, IntPtr listDataPtr, IntPtr unusedPtr, ref JobRanges ranges, int jobIndex)
|
||||
{
|
||||
int size = jobData.world.ActionSize;
|
||||
int actionCount = jobData.world.ActionCounter.Count;
|
||||
int size = jobData.Policy.ActionSize;
|
||||
int actionCount = jobData.Policy.ActionCounter.Count;
|
||||
|
||||
// Continuous case
|
||||
if (jobData.world.ActionType == ActionType.CONTINUOUS)
|
||||
if (jobData.Policy.ActionType == ActionType.CONTINUOUS)
|
||||
{
|
||||
for (int i = 0; i < actionCount; i++)
|
||||
{
|
||||
|
@ -144,8 +144,8 @@ namespace Unity.AI.MLAgents
|
|||
{
|
||||
ActionSize = size,
|
||||
ActionType = ActionType.CONTINUOUS,
|
||||
Entity = jobData.world.ActionAgentEntityIds[i],
|
||||
ContinuousActionSlice = jobData.world.ContinuousActuators.Slice(i * size, size)
|
||||
Entity = jobData.Policy.ActionAgentEntityIds[i],
|
||||
ContinuousActionSlice = jobData.Policy.ContinuousActuators.Slice(i * size, size)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -158,12 +158,12 @@ namespace Unity.AI.MLAgents
|
|||
{
|
||||
ActionSize = size,
|
||||
ActionType = ActionType.DISCRETE,
|
||||
Entity = jobData.world.ActionAgentEntityIds[i],
|
||||
DiscreteActionSlice = jobData.world.DiscreteActuators.Slice(i * size, size)
|
||||
Entity = jobData.Policy.ActionAgentEntityIds[i],
|
||||
DiscreteActionSlice = jobData.Policy.DiscreteActuators.Slice(i * size, size)
|
||||
});
|
||||
}
|
||||
}
|
||||
jobData.world.ResetActionsCounter();
|
||||
jobData.Policy.ResetActionsCounter();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,19 +7,19 @@ using System;
|
|||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// A DecisionRequest is a struct used to provide data about an Agent to a MLAgentsWorld.
|
||||
/// This data will be used to generate a decision after the world is processed.
|
||||
/// A DecisionRequest is a struct used to provide data about an Agent to a Policy.
|
||||
/// This data will be used to generate a decision after the Policy is processed.
|
||||
/// Adding data is done through a builder pattern.
|
||||
/// </summary>
|
||||
public struct DecisionRequest
|
||||
{
|
||||
private int m_Index;
|
||||
private MLAgentsWorld m_World;
|
||||
private Policy m_Policy;
|
||||
|
||||
internal DecisionRequest(int index, MLAgentsWorld world)
|
||||
internal DecisionRequest(int index, Policy policy)
|
||||
{
|
||||
this.m_Index = index;
|
||||
this.m_World = world;
|
||||
this.m_Policy = policy;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -29,7 +29,7 @@ namespace Unity.AI.MLAgents
|
|||
/// <returns> The DecisionRequest struct </returns>
|
||||
public DecisionRequest SetReward(float r)
|
||||
{
|
||||
m_World.DecisionRewards[m_Index] = r;
|
||||
m_Policy.DecisionRewards[m_Index] = r;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -43,45 +43,45 @@ namespace Unity.AI.MLAgents
|
|||
public DecisionRequest SetDiscreteActionMask(int branch, int actionIndex)
|
||||
{
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
if (m_World.ActionType == ActionType.CONTINUOUS)
|
||||
if (m_Policy.ActionType == ActionType.CONTINUOUS)
|
||||
{
|
||||
throw new MLAgentsException("SetDiscreteActionMask can only be used with discrete acton space.");
|
||||
}
|
||||
if (branch > m_World.DiscreteActionBranches.Length)
|
||||
if (branch > m_Policy.DiscreteActionBranches.Length)
|
||||
{
|
||||
throw new MLAgentsException("Unknown action branch used when setting mask.");
|
||||
}
|
||||
if (actionIndex > m_World.DiscreteActionBranches[branch])
|
||||
if (actionIndex > m_Policy.DiscreteActionBranches[branch])
|
||||
{
|
||||
throw new MLAgentsException("Index is out of bounds for requested action mask.");
|
||||
}
|
||||
#endif
|
||||
var trueMaskIndex = m_World.DiscreteActionBranches.CumSumAt(branch) + actionIndex;
|
||||
m_World.DecisionActionMasks[trueMaskIndex + m_World.DiscreteActionBranches.Sum() * m_Index] = true;
|
||||
var trueMaskIndex = m_Policy.DiscreteActionBranches.CumSumAt(branch) + actionIndex;
|
||||
m_Policy.DecisionActionMasks[trueMaskIndex + m_Policy.DiscreteActionBranches.Sum() * m_Index] = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the observation for a decision request.
|
||||
/// </summary>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
|
||||
/// <param name="sensor"> A struct strictly containing floats used as observation data </param>
|
||||
/// <returns> The DecisionRequest struct </returns>
|
||||
public DecisionRequest SetObservation<T>(int sensorNumber, T sensor) where T : struct
|
||||
{
|
||||
int inputSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
int3 s = m_World.SensorShapes[sensorNumber];
|
||||
int3 s = m_Policy.SensorShapes[sensorNumber];
|
||||
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
|
||||
if (inputSize != expectedInputSize)
|
||||
{
|
||||
throw new MLAgentsException(
|
||||
$"Cannot set observation due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
|
||||
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
|
||||
}
|
||||
#endif
|
||||
int start = m_World.ObservationOffsets[sensorNumber];
|
||||
int start = m_Policy.ObservationOffsets[sensorNumber];
|
||||
start += inputSize * m_Index;
|
||||
var tmp = m_World.DecisionObs.Slice(start, inputSize).SliceConvert<T>();
|
||||
var tmp = m_Policy.DecisionObs.Slice(start, inputSize).SliceConvert<T>();
|
||||
tmp[0] = sensor;
|
||||
return this;
|
||||
}
|
||||
|
@ -89,12 +89,12 @@ namespace Unity.AI.MLAgents
|
|||
/// <summary>
|
||||
/// Sets the observation for a decision request using a categorical value.
|
||||
/// </summary>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
|
||||
/// <param name="sensor"> An integer containing the index of the categorical observation </param>
|
||||
/// <returns> The DecisionRequest struct </returns>
|
||||
public DecisionRequest SetObservation(int sensorNumber, int sensor)
|
||||
{
|
||||
int3 s = m_World.SensorShapes[sensorNumber];
|
||||
int3 s = m_Policy.SensorShapes[sensorNumber];
|
||||
int maxValue = s.x;
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
|
||||
|
@ -109,33 +109,33 @@ namespace Unity.AI.MLAgents
|
|||
$"Categorical observation is out of bound for observation {sensorNumber} with maximum {maxValue} (received {sensor}.");
|
||||
}
|
||||
#endif
|
||||
int start = m_World.ObservationOffsets[sensorNumber];
|
||||
int start = m_Policy.ObservationOffsets[sensorNumber];
|
||||
start += maxValue * m_Index;
|
||||
m_World.DecisionObs[start + sensor] = 1.0f;
|
||||
m_Policy.DecisionObs[start + sensor] = 1.0f;
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the observation for a decision request.
|
||||
/// </summary>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
|
||||
/// <param name="obs"> A NativeSlice of floats containing the observation data </param>
|
||||
/// <returns> The DecisionRequest struct </returns>
|
||||
public DecisionRequest SetObservationFromSlice(int sensorNumber, [ReadOnly] NativeSlice<float> obs)
|
||||
{
|
||||
int inputSize = obs.Length;
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
int3 s = m_World.SensorShapes[sensorNumber];
|
||||
int3 s = m_Policy.SensorShapes[sensorNumber];
|
||||
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
|
||||
if (inputSize != expectedInputSize)
|
||||
{
|
||||
throw new MLAgentsException(
|
||||
$"Cannot set observation due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
|
||||
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
|
||||
}
|
||||
#endif
|
||||
int start = m_World.ObservationOffsets[sensorNumber];
|
||||
int start = m_Policy.ObservationOffsets[sensorNumber];
|
||||
start += inputSize * m_Index;
|
||||
m_World.DecisionObs.Slice(start, inputSize).CopyFrom(obs);
|
||||
m_Policy.DecisionObs.Slice(start, inputSize).CopyFrom(obs);
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -7,19 +7,19 @@ using System;
|
|||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// A EpisodeTermination is a struct used to provide data about an Agent to a MLAgentsWorld.
|
||||
/// A EpisodeTermination is a struct used to provide data about an Agent to a Policy.
|
||||
/// This data will be used to notify of the end of the episode of an Agent.
|
||||
/// Adding data is done through a builder pattern.
|
||||
/// </summary>
|
||||
public struct EpisodeTermination
|
||||
{
|
||||
private int m_Index;
|
||||
private MLAgentsWorld m_World;
|
||||
private Policy m_Policy;
|
||||
|
||||
internal EpisodeTermination(int index, MLAgentsWorld world)
|
||||
internal EpisodeTermination(int index, Policy policy)
|
||||
{
|
||||
this.m_Index = index;
|
||||
this.m_World = world;
|
||||
this.m_Policy = policy;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -30,31 +30,31 @@ namespace Unity.AI.MLAgents
|
|||
/// <returns> The EpisodeTermination struct </returns>
|
||||
public EpisodeTermination SetReward(float r)
|
||||
{
|
||||
m_World.TerminationRewards[m_Index] = r;
|
||||
m_Policy.TerminationRewards[m_Index] = r;
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the observation for of the end of the Episode.
|
||||
/// </summary>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
|
||||
/// <param name="sensor"> A struct strictly containing floats used as observation data </param>
|
||||
/// <returns> The EpisodeTermination struct </returns>
|
||||
public EpisodeTermination SetObservation<T>(int sensorNumber, T sensor) where T : struct
|
||||
{
|
||||
int inputSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
int3 s = m_World.SensorShapes[sensorNumber];
|
||||
int3 s = m_Policy.SensorShapes[sensorNumber];
|
||||
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
|
||||
if (inputSize != expectedInputSize)
|
||||
{
|
||||
throw new MLAgentsException(
|
||||
$"Cannot set observation due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
|
||||
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : { expectedInputSize }, received size : { inputSize}");
|
||||
}
|
||||
#endif
|
||||
int start = m_World.ObservationOffsets[sensorNumber];
|
||||
int start = m_Policy.ObservationOffsets[sensorNumber];
|
||||
start += inputSize * m_Index;
|
||||
var tmp = m_World.TerminationObs.Slice(start, inputSize).SliceConvert<T>();
|
||||
var tmp = m_Policy.TerminationObs.Slice(start, inputSize).SliceConvert<T>();
|
||||
tmp[0] = sensor;
|
||||
return this;
|
||||
}
|
||||
|
@ -62,12 +62,12 @@ namespace Unity.AI.MLAgents
|
|||
/// <summary>
|
||||
/// Sets the observation for a termination request using a categorical value.
|
||||
/// </summary>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
|
||||
/// <param name="sensor"> An integer containing the index of the categorical observation </param>
|
||||
/// <returns> The EpisodeTermination struct </returns>
|
||||
public EpisodeTermination SetObservation(int sensorNumber, int sensor)
|
||||
{
|
||||
int3 s = m_World.SensorShapes[sensorNumber];
|
||||
int3 s = m_Policy.SensorShapes[sensorNumber];
|
||||
int maxValue = s.x;
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
|
||||
|
@ -82,33 +82,33 @@ namespace Unity.AI.MLAgents
|
|||
$"Categorical observation is out of bound for observation {sensorNumber} with maximum {maxValue} (received {sensor}.");
|
||||
}
|
||||
#endif
|
||||
int start = m_World.ObservationOffsets[sensorNumber];
|
||||
int start = m_Policy.ObservationOffsets[sensorNumber];
|
||||
start += maxValue * m_Index;
|
||||
m_World.TerminationObs[start + sensor] = 1.0f;
|
||||
m_Policy.TerminationObs[start + sensor] = 1.0f;
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the last observation the Agent perceives before ending the episode.
|
||||
/// </summary>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated MLAgentsWorld </param>
|
||||
/// <param name="sensorNumber"> The index of the observation as provided when creating the associated Policy </param>
|
||||
/// <param name="obs"> A NativeSlice of floats containing the observation data </param>
|
||||
/// <returns> The EpisodeTermination struct </returns>
|
||||
public EpisodeTermination SetObservationFromSlice(int sensorNumber, [ReadOnly] NativeSlice<float> obs)
|
||||
{
|
||||
int inputSize = obs.Length;
|
||||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
int3 s = m_World.SensorShapes[sensorNumber];
|
||||
int3 s = m_Policy.SensorShapes[sensorNumber];
|
||||
int expectedInputSize = s.x * math.max(1, s.y) * math.max(1, s.z);
|
||||
if (inputSize != expectedInputSize)
|
||||
{
|
||||
throw new MLAgentsException(
|
||||
$"Cannot set observation due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
|
||||
$"Cannot set observation {sensorNumber} due to incompatible size of the input. Expected size : {expectedInputSize}, received size : { inputSize}");
|
||||
}
|
||||
#endif
|
||||
int start = m_World.ObservationOffsets[sensorNumber];
|
||||
int start = m_Policy.ObservationOffsets[sensorNumber];
|
||||
start += inputSize * m_Index;
|
||||
m_World.TerminationObs.Slice(start, inputSize).CopyFrom(obs);
|
||||
m_Policy.TerminationObs.Slice(start, inputSize).CopyFrom(obs);
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -9,14 +9,14 @@ using Unity.Collections.LowLevel.Unsafe;
|
|||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// MLAgentsWorld is a data container on which the user requests decisions.
|
||||
/// Policy is a data container on which the user requests decisions.
|
||||
/// </summary>
|
||||
public struct MLAgentsWorld : IDisposable
|
||||
public struct Policy : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Indicates if the MLAgentsWorld has been instantiated
|
||||
/// Indicates if the Policy has been instantiated
|
||||
/// </summary>
|
||||
/// <value> True if MLAgentsWorld was instantiated, False otherwise</value>
|
||||
/// <value> True if the Policy was instantiated, False otherwise</value>
|
||||
public bool IsCreated
|
||||
{
|
||||
get { return DecisionAgentIds.IsCreated;}
|
||||
|
@ -60,13 +60,13 @@ namespace Unity.AI.MLAgents
|
|||
/// </summary>
|
||||
/// <param name="maximumNumberAgents"> The maximum number of decisions that can be requested between each MLAgentsSystem update </param>
|
||||
/// <param name="obsShapes"> An array of int3 corresponding to the shape of the expected observations (one int3 per observation) </param>
|
||||
/// <param name="actionType"> An ActionType enum (DISCRETE / CONTINUOUS) specifying the type of actions the MLAgentsWorld will produce </param>
|
||||
/// <param name="actionSize"> The number of actions the MLAgentsWorld is expected to generate for each decision.
|
||||
/// <param name="actionType"> An ActionType enum (DISCRETE / CONTINUOUS) specifying the type of actions the Policy will produce </param>
|
||||
/// <param name="actionSize"> The number of actions the Policy is expected to generate for each decision.
|
||||
/// - If CONTINUOUS ActionType : The number of floats the action contains
|
||||
/// - If DISCRETE ActionType : The number of branches (integer actions) the action contains </param>
|
||||
/// <param name="discreteActionBranches"> For DISCRETE ActionType only : an array of int specifying the number of possible int values each
|
||||
/// action branch has. (Must be of the same length as actionSize </param>
|
||||
public MLAgentsWorld(
|
||||
public Policy(
|
||||
int maximumNumberAgents,
|
||||
int3[] obsShapes,
|
||||
ActionType actionType,
|
||||
|
@ -164,7 +164,7 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dispose of the MLAgentsWorld.
|
||||
/// Dispose of the Policy.
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
|
@ -199,7 +199,7 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Requests a decision for a specific Entity to the MLAgentsWorld. The DecisionRequest
|
||||
/// Requests a decision for a specific Entity to the Policy. The DecisionRequest
|
||||
/// struct this method returns can be used to provide data necessary for the Agent to
|
||||
/// take a decision for the Entity.
|
||||
/// </summary>
|
||||
|
@ -211,7 +211,7 @@ namespace Unity.AI.MLAgents
|
|||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
if (!IsCreated)
|
||||
{
|
||||
throw new MLAgentsException($"Invalid operation, cannot request a decision on a non-initialized MLAgentsWorld");
|
||||
throw new MLAgentsException($"Invalid operation, cannot request a decision on a non-initialized Policy");
|
||||
}
|
||||
#endif
|
||||
var index = DecisionCounter.Increment() - 1;
|
||||
|
@ -239,7 +239,7 @@ namespace Unity.AI.MLAgents
|
|||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
if (!IsCreated)
|
||||
{
|
||||
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized MLAgentsWorld");
|
||||
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized Policy");
|
||||
}
|
||||
#endif
|
||||
var index = TerminationCounter.Increment() - 1;
|
||||
|
@ -267,7 +267,7 @@ namespace Unity.AI.MLAgents
|
|||
#if ENABLE_UNITY_COLLECTIONS_CHECKS
|
||||
if (!IsCreated)
|
||||
{
|
||||
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized MLAgentsWorld");
|
||||
throw new MLAgentsException($"Invalid operation, cannot end episode on a non-initialized Policy");
|
||||
}
|
||||
#endif
|
||||
var index = TerminationCounter.Increment() - 1;
|
|
@ -21,23 +21,23 @@ namespace Unity.AI.MLAgents
|
|||
GPU = 1
|
||||
}
|
||||
|
||||
public static class BarracudaWorldProcessorRegistringExtension
|
||||
public static class BarracudaPolicyProcessorRegistringExtension
|
||||
{
|
||||
/// <summary>
|
||||
/// Registers the given MLAgentsWorld to the Academy with a Neural
|
||||
/// Registers the given Policy to the Academy with a Neural
|
||||
/// Network Model. If the input model is null, a default inactive
|
||||
/// processor will be registered instead. Note that if the simulation
|
||||
/// connects to Python, the Neural Network will be ignored and the world
|
||||
/// connects to Python, the Neural Network will be ignored and the Policy
|
||||
/// will exchange data with Python instead.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld to register</param>
|
||||
/// <param name="policyId"> The name of the world. This is useful for identification
|
||||
/// <param name="policy"> The Policy to register</param>
|
||||
/// <param name="policyId"> The name of the Policy. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="model"> The Neural Network model used by the processor</param>
|
||||
/// <param name="inferenceDevice"> The inference device specifying where to run inference
|
||||
/// (CPU or GPU)</param>
|
||||
public static void RegisterWorldWithBarracudaModel(
|
||||
this MLAgentsWorld world,
|
||||
public static void RegisterPolicyWithBarracudaModel(
|
||||
this Policy policy,
|
||||
string policyId,
|
||||
NNModel model,
|
||||
InferenceDevice inferenceDevice = InferenceDevice.CPU
|
||||
|
@ -45,30 +45,30 @@ namespace Unity.AI.MLAgents
|
|||
{
|
||||
if (model != null)
|
||||
{
|
||||
var worldProcessor = new BarracudaWorldProcessor(world, model, inferenceDevice);
|
||||
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true);
|
||||
var policyProcessor = new BarracudaPolicyProcessor(policy, model, inferenceDevice);
|
||||
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
Academy.Instance.RegisterWorld(policyId, world, null, true);
|
||||
Academy.Instance.RegisterPolicy(policyId, policy, null, true);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers the given MLAgentsWorld to the Academy with a Neural
|
||||
/// Registers the given Policy to the Academy with a Neural
|
||||
/// Network Model. If the input model is null, a default inactive
|
||||
/// processor will be registered instead. Note that if the simulation
|
||||
/// connects to Python, the world will not connect to Python and run the
|
||||
/// connects to Python, the Policy will not connect to Python and run the
|
||||
/// given Neural Network regardless.
|
||||
/// </summary>
|
||||
/// <param name="world"> The MLAgentsWorld to register</param>
|
||||
/// <param name="policyId"> The name of the world. This is useful for identification
|
||||
/// <param name="policy"> The Policy to register</param>
|
||||
/// <param name="policyId"> The name of the Policy. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="model"> The Neural Network model used by the processor</param>
|
||||
/// <param name="inferenceDevice"> The inference device specifying where to run inference
|
||||
/// (CPU or GPU)</param>
|
||||
public static void RegisterWorldWithBarracudaModelForceNoCommunication(
|
||||
this MLAgentsWorld world,
|
||||
public static void RegisterPolicyWithBarracudaModelForceNoCommunication(
|
||||
this Policy policy,
|
||||
string policyId,
|
||||
NNModel model,
|
||||
InferenceDevice inferenceDevice = InferenceDevice.CPU
|
||||
|
@ -76,19 +76,19 @@ namespace Unity.AI.MLAgents
|
|||
{
|
||||
if (model != null)
|
||||
{
|
||||
var worldProcessor = new BarracudaWorldProcessor(world, model, inferenceDevice);
|
||||
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, false);
|
||||
var policyProcessor = new BarracudaPolicyProcessor(policy, model, inferenceDevice);
|
||||
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
Academy.Instance.RegisterWorld(policyId, world, null, false);
|
||||
Academy.Instance.RegisterPolicy(policyId, policy, null, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal unsafe class BarracudaWorldProcessor : IWorldProcessor
|
||||
internal unsafe class BarracudaPolicyProcessor : IPolicyProcessor
|
||||
{
|
||||
private MLAgentsWorld m_World;
|
||||
private Policy m_Policy;
|
||||
private Model m_BarracudaModel;
|
||||
private IWorker m_Engine;
|
||||
private const bool k_Verbose = false;
|
||||
|
@ -98,9 +98,9 @@ namespace Unity.AI.MLAgents
|
|||
|
||||
public bool IsConnected {get {return false;}}
|
||||
|
||||
internal BarracudaWorldProcessor(MLAgentsWorld world, NNModel model, InferenceDevice inferenceDevice)
|
||||
internal BarracudaPolicyProcessor(Policy policy, NNModel model, InferenceDevice inferenceDevice)
|
||||
{
|
||||
this.m_World = world;
|
||||
this.m_Policy = policy;
|
||||
D.logEnabled = k_Verbose;
|
||||
m_Engine?.Dispose();
|
||||
|
||||
|
@ -111,15 +111,15 @@ namespace Unity.AI.MLAgents
|
|||
|
||||
m_Engine = WorkerFactory.CreateWorker(
|
||||
executionDevice, m_BarracudaModel, k_Verbose);
|
||||
for (int i = 0; i < m_World.SensorShapes.Length; i++)
|
||||
for (int i = 0; i < m_Policy.SensorShapes.Length; i++)
|
||||
{
|
||||
if (m_World.SensorShapes[i].GetDimensions() == 1)
|
||||
obsSize += m_World.SensorShapes[i].GetTotalTensorSize();
|
||||
if (m_Policy.SensorShapes[i].GetDimensions() == 1)
|
||||
obsSize += m_Policy.SensorShapes[i].GetTotalTensorSize();
|
||||
}
|
||||
vectorObsArr = new float[m_World.DecisionAgentIds.Length * obsSize];
|
||||
vectorObsArr = new float[m_Policy.DecisionAgentIds.Length * obsSize];
|
||||
}
|
||||
|
||||
public void ProcessWorld()
|
||||
public void Process()
|
||||
{
|
||||
// TODO : Cover all cases
|
||||
// FOR VECTOR OBS ONLY
|
||||
|
@ -128,27 +128,27 @@ namespace Unity.AI.MLAgents
|
|||
|
||||
var input = new System.Collections.Generic.Dictionary<string, Tensor>();
|
||||
|
||||
// var sensorData = m_World.DecisionObs.ToArray();
|
||||
// var sensorData = m_Policy.DecisionObs.ToArray();
|
||||
int sensorOffset = 0;
|
||||
int vecObsOffset = 0;
|
||||
foreach (var shape in m_World.SensorShapes)
|
||||
foreach (var shape in m_Policy.SensorShapes)
|
||||
{
|
||||
if (shape.GetDimensions() == 1)
|
||||
{
|
||||
for (int i = 0; i < m_World.DecisionCounter.Count; i++)
|
||||
for (int i = 0; i < m_Policy.DecisionCounter.Count; i++)
|
||||
{
|
||||
fixed(void* arrPtr = vectorObsArr)
|
||||
{
|
||||
UnsafeUtility.MemCpy(
|
||||
(byte*)arrPtr + 4 * i * obsSize + 4 * vecObsOffset,
|
||||
(byte*)m_World.DecisionObs.GetUnsafePtr() + 4 * sensorOffset + 4 * i * shape.GetTotalTensorSize(),
|
||||
(byte*)m_Policy.DecisionObs.GetUnsafePtr() + 4 * sensorOffset + 4 * i * shape.GetTotalTensorSize(),
|
||||
shape.GetTotalTensorSize() * 4
|
||||
);
|
||||
}
|
||||
|
||||
// Array.Copy(sensorData, sensorOffset + i * shape.GetTotalTensorSize(), vectorObsArr, i * obsSize + vecObsOffset, shape.GetTotalTensorSize());
|
||||
}
|
||||
sensorOffset += m_World.DecisionAgentIds.Length * shape.GetTotalTensorSize();
|
||||
sensorOffset += m_Policy.DecisionAgentIds.Length * shape.GetTotalTensorSize();
|
||||
vecObsOffset += shape.GetTotalTensorSize();
|
||||
}
|
||||
else
|
||||
|
@ -158,7 +158,7 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
|
||||
input["vector_observation"] = new Tensor(
|
||||
new TensorShape(m_World.DecisionCounter.Count, obsSize),
|
||||
new TensorShape(m_Policy.DecisionCounter.Count, obsSize),
|
||||
vectorObsArr,
|
||||
"vector_observation");
|
||||
|
||||
|
@ -166,18 +166,18 @@ namespace Unity.AI.MLAgents
|
|||
|
||||
var actuatorT = m_Engine.CopyOutput("action");
|
||||
|
||||
switch (m_World.ActionType)
|
||||
switch (m_Policy.ActionType)
|
||||
{
|
||||
case ActionType.CONTINUOUS:
|
||||
int count = m_World.DecisionCounter.Count * m_World.ActionSize;
|
||||
int count = m_Policy.DecisionCounter.Count * m_Policy.ActionSize;
|
||||
var wholeData = actuatorT.data.Download(count);
|
||||
// var dest = new float[count];
|
||||
// Array.Copy(wholeData, dest, count);
|
||||
// m_World.ContinuousActuators.Slice(0, count).CopyFrom(dest);
|
||||
// m_Policy.ContinuousActuators.Slice(0, count).CopyFrom(dest);
|
||||
fixed(void* arrPtr = wholeData)
|
||||
{
|
||||
UnsafeUtility.MemCpy(
|
||||
m_World.ContinuousActuators.GetUnsafePtr(),
|
||||
m_Policy.ContinuousActuators.GetUnsafePtr(),
|
||||
arrPtr,
|
||||
count * 4
|
||||
);
|
|
@ -0,0 +1,104 @@
|
|||
using System;
|
||||
using Unity.Collections;
|
||||
using Unity.Collections.LowLevel.Unsafe;
|
||||
|
||||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
public static class HeristicPolicyProcessorRegistringExtension
|
||||
{
|
||||
/// <summary>
|
||||
/// Registers the given Policy to the Academy with a Heuristic.
|
||||
/// Note that if the simulation connects to Python, the Heuristic will
|
||||
/// be ignored and the Policy will exchange data with Python instead.
|
||||
/// The Heuristic is a Function that returns an action struct.
|
||||
/// </summary>
|
||||
/// <param name="policy"> The Policy to register</param>
|
||||
/// <param name="policyId"> The name of the Policy. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="heuristic"> The Heuristic used to generate the actions.
|
||||
/// Note that all agents in the Policy will receive the same action.</param>
|
||||
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
|
||||
/// Size and Action Type of the Policy.</typeparam>
|
||||
public static void RegisterPolicyWithHeuristic<TH>(
|
||||
this Policy policy,
|
||||
string policyId,
|
||||
Func<TH> heuristic
|
||||
) where TH : struct
|
||||
{
|
||||
var policyProcessor = new HeuristicPolicyProcessor<TH>(policy, heuristic);
|
||||
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, true);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers the given Policy to the Academy with a Heuristic.
|
||||
/// Note that if the simulation connects to Python, the Policy will not
|
||||
/// exchange data with Python and use the Heuristic regardless.
|
||||
/// The Heuristic is a Function that returns an action struct.
|
||||
/// </summary>
|
||||
/// <param name="policy"> The Policy to register</param>
|
||||
/// <param name="policyId"> The name of the Policy. This is useful for identification
|
||||
/// and for training.</param>
|
||||
/// <param name="heuristic"> The Heuristic used to generate the actions.
|
||||
/// Note that all agents in the Policy will receive the same action.</param>
|
||||
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
|
||||
/// Size and Action Type of the Policy.</typeparam>
|
||||
public static void RegisterPolicyWithHeuristicForceNoCommunication<TH>(
|
||||
this Policy policy,
|
||||
string policyId,
|
||||
Func<TH> heuristic
|
||||
) where TH : struct
|
||||
{
|
||||
var policyProcessor = new HeuristicPolicyProcessor<TH>(policy, heuristic);
|
||||
Academy.Instance.RegisterPolicy(policyId, policy, policyProcessor, false);
|
||||
}
|
||||
}
|
||||
|
||||
internal class HeuristicPolicyProcessor<T> : IPolicyProcessor where T : struct
|
||||
{
|
||||
private Func<T> m_Heuristic;
|
||||
private Policy m_Policy;
|
||||
|
||||
public bool IsConnected {get {return false;}}
|
||||
|
||||
internal HeuristicPolicyProcessor(Policy policy, Func<T> heuristic)
|
||||
{
|
||||
this.m_Policy = policy;
|
||||
this.m_Heuristic = heuristic;
|
||||
var structSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
|
||||
if (structSize != policy.ActionSize)
|
||||
{
|
||||
throw new MLAgentsException(
|
||||
$"The heuristic provided does not match the action size. Expected {policy.ActionSize} but received {structSize} from heuristic");
|
||||
}
|
||||
}
|
||||
|
||||
public void Process()
|
||||
{
|
||||
T action = m_Heuristic.Invoke();
|
||||
var totalCount = m_Policy.DecisionCounter.Count;
|
||||
|
||||
// TODO : This can be parallelized
|
||||
if (m_Policy.ActionType == ActionType.CONTINUOUS)
|
||||
{
|
||||
var s = m_Policy.ContinuousActuators.Slice(0, totalCount * m_Policy.ActionSize).SliceConvert<T>();
|
||||
for (int i = 0; i < totalCount; i++)
|
||||
{
|
||||
s[i] = action;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var s = m_Policy.DiscreteActuators.Slice(0, totalCount * m_Policy.ActionSize).SliceConvert<T>();
|
||||
for (int i = 0; i < totalCount; i++)
|
||||
{
|
||||
s[i] = action;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,21 +6,21 @@ using Unity.Collections.LowLevel.Unsafe;
|
|||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
/// <summary>
|
||||
/// The interface for a world processor. A world processor updates the
|
||||
/// action data of a world using the observation data present in it.
|
||||
/// The interface for a Policy processor. A Policy processor updates the
|
||||
/// action data of a Policy using the observation data present in it.
|
||||
/// </summary>
|
||||
public interface IWorldProcessor : IDisposable
|
||||
public interface IPolicyProcessor : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// True if the World Processor is connected to the Python process
|
||||
/// True if the Policy Processor is connected to the Python process
|
||||
/// </summary>
|
||||
bool IsConnected {get;}
|
||||
|
||||
/// <summary>
|
||||
/// This method is called once everytime the world needs to update its action
|
||||
/// This method is called once everytime the policy needs to update its action
|
||||
/// data. The implementation of this mehtod must give new action data to all the
|
||||
/// agents that requested a decision.
|
||||
/// </summary>
|
||||
void ProcessWorld();
|
||||
void Process();
|
||||
}
|
||||
}
|
|
@ -5,18 +5,18 @@ using Unity.Collections.LowLevel.Unsafe;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
internal class NullWorldProcessor : IWorldProcessor
|
||||
internal class NullPolicyProcessor : IPolicyProcessor
|
||||
{
|
||||
private MLAgentsWorld m_World;
|
||||
private Policy m_Policy;
|
||||
|
||||
public bool IsConnected {get {return false;}}
|
||||
|
||||
internal NullWorldProcessor(MLAgentsWorld world)
|
||||
internal NullPolicyProcessor(Policy policy)
|
||||
{
|
||||
this.m_World = world;
|
||||
this.m_Policy = policy;
|
||||
}
|
||||
|
||||
public void ProcessWorld()
|
||||
public void Process()
|
||||
{
|
||||
}
|
||||
|
|
@ -5,27 +5,27 @@ using Unity.Collections.LowLevel.Unsafe;
|
|||
|
||||
namespace Unity.AI.MLAgents
|
||||
{
|
||||
internal class RemoteWorldProcessor : IWorldProcessor
|
||||
internal class RemotePolicyProcessor : IPolicyProcessor
|
||||
{
|
||||
private MLAgentsWorld m_World;
|
||||
private Policy m_Policy;
|
||||
private SharedMemoryCommunicator m_Communicator;
|
||||
private string m_PolicyId;
|
||||
|
||||
public bool IsConnected {get {return true;}}
|
||||
|
||||
internal RemoteWorldProcessor(MLAgentsWorld world, string policyId, SharedMemoryCommunicator com)
|
||||
internal RemotePolicyProcessor(Policy policy, string policyId, SharedMemoryCommunicator com)
|
||||
{
|
||||
this.m_World = world;
|
||||
this.m_Policy = policy;
|
||||
this.m_Communicator = com;
|
||||
this.m_PolicyId = policyId;
|
||||
}
|
||||
|
||||
public void ProcessWorld()
|
||||
public void Process()
|
||||
{
|
||||
m_Communicator.WriteWorld(m_PolicyId, m_World);
|
||||
m_Communicator.WritePolicy(m_PolicyId, m_Policy);
|
||||
m_Communicator.SetUnityReady();
|
||||
m_Communicator.WaitForPython();
|
||||
m_Communicator.LoadWorld(m_PolicyId, m_World);
|
||||
m_Communicator.LoadPolicy(m_PolicyId, m_Policy);
|
||||
}
|
||||
|
||||
public void Dispose()
|
|
@ -68,26 +68,26 @@ namespace Unity.AI.MLAgents
|
|||
startOffset);
|
||||
}
|
||||
|
||||
public static RLDataOffsets FromWorld(MLAgentsWorld world, string name, int offset)
|
||||
public static RLDataOffsets FromPolicy(Policy policy, string name, int offset)
|
||||
{
|
||||
bool isContinuous = world.ActionType == ActionType.CONTINUOUS;
|
||||
bool isContinuous = policy.ActionType == ActionType.CONTINUOUS;
|
||||
int totalFloatObsPerAgent = 0;
|
||||
foreach (int3 shape in world.SensorShapes)
|
||||
foreach (int3 shape in policy.SensorShapes)
|
||||
{
|
||||
totalFloatObsPerAgent += shape.GetTotalTensorSize();
|
||||
}
|
||||
int totalNumberOfMasks = 0;
|
||||
if (!isContinuous)
|
||||
{
|
||||
totalNumberOfMasks = world.DiscreteActionBranches.Sum();
|
||||
totalNumberOfMasks = policy.DiscreteActionBranches.Sum();
|
||||
}
|
||||
|
||||
return ComputeOffsets(
|
||||
name,
|
||||
world.DecisionAgentIds.Length,
|
||||
policy.DecisionAgentIds.Length,
|
||||
isContinuous,
|
||||
world.ActionSize,
|
||||
world.SensorShapes.Length,
|
||||
policy.ActionSize,
|
||||
policy.SensorShapes.Length,
|
||||
totalFloatObsPerAgent,
|
||||
totalNumberOfMasks,
|
||||
offset
|
||||
|
|
|
@ -52,12 +52,12 @@ namespace Unity.AI.MLAgents
|
|||
}
|
||||
}
|
||||
|
||||
public bool ContainsWorld(string name)
|
||||
public bool ContainsPolicy(string name)
|
||||
{
|
||||
return m_OffsetDict.ContainsKey(name);
|
||||
}
|
||||
|
||||
public void WriteWorld(string name, MLAgentsWorld world)
|
||||
public void WritePolicy(string name, Policy policy)
|
||||
{
|
||||
if (!CanEdit)
|
||||
{
|
||||
|
@ -69,48 +69,48 @@ namespace Unity.AI.MLAgents
}
var dataOffsets = m_OffsetDict[name];
int totalFloatObsPerAgent = 0;
foreach (int3 shape in world.SensorShapes)
foreach (int3 shape in policy.SensorShapes)
{
totalFloatObsPerAgent += shape.GetTotalTensorSize();
}

// Decision data
var decisionCount = world.DecisionCounter.Count;
var decisionCount = policy.DecisionCounter.Count;
SetInt(dataOffsets.DecisionNumberAgentsOffset, decisionCount);
SetArray(dataOffsets.DecisionObsOffset, world.DecisionObs, 4 * decisionCount * totalFloatObsPerAgent);
SetArray(dataOffsets.DecisionRewardsOffset, world.DecisionRewards, 4 * decisionCount);
SetArray(dataOffsets.DecisionAgentIdOffset, world.DecisionAgentIds, 4 * decisionCount);
if (world.ActionType == ActionType.DISCRETE)
SetArray(dataOffsets.DecisionObsOffset, policy.DecisionObs, 4 * decisionCount * totalFloatObsPerAgent);
SetArray(dataOffsets.DecisionRewardsOffset, policy.DecisionRewards, 4 * decisionCount);
SetArray(dataOffsets.DecisionAgentIdOffset, policy.DecisionAgentIds, 4 * decisionCount);
if (policy.ActionType == ActionType.DISCRETE)
{
SetArray(dataOffsets.DecisionActionMasksOffset, world.DecisionActionMasks, decisionCount * world.DiscreteActionBranches.Sum());
SetArray(dataOffsets.DecisionActionMasksOffset, policy.DecisionActionMasks, decisionCount * policy.DiscreteActionBranches.Sum());
}

//Termination data
var terminationCount = world.TerminationCounter.Count;
var terminationCount = policy.TerminationCounter.Count;
SetInt(dataOffsets.TerminationNumberAgentsOffset, terminationCount);
SetArray(dataOffsets.TerminationObsOffset, world.TerminationObs, 4 * terminationCount * totalFloatObsPerAgent);
SetArray(dataOffsets.TerminationRewardsOffset, world.TerminationRewards, 4 * terminationCount);
SetArray(dataOffsets.TerminationAgentIdOffset, world.TerminationAgentIds, 4 * terminationCount);
SetArray(dataOffsets.TerminationStatusOffset, world.TerminationStatus, terminationCount);
SetArray(dataOffsets.TerminationObsOffset, policy.TerminationObs, 4 * terminationCount * totalFloatObsPerAgent);
SetArray(dataOffsets.TerminationRewardsOffset, policy.TerminationRewards, 4 * terminationCount);
SetArray(dataOffsets.TerminationAgentIdOffset, policy.TerminationAgentIds, 4 * terminationCount);
SetArray(dataOffsets.TerminationStatusOffset, policy.TerminationStatus, terminationCount);
}

public void WriteWorldSpecs(string name, MLAgentsWorld world)
public void WritePolicySpecs(string name, Policy policy)
{
m_OffsetDict[name] = RLDataOffsets.FromWorld(world, name, m_CurrentEndOffset);
m_OffsetDict[name] = RLDataOffsets.FromPolicy(policy, name, m_CurrentEndOffset);
var offset = m_CurrentEndOffset;
offset = SetString(offset, name); // Name
offset = SetInt(offset, world.DecisionAgentIds.Length); // Max Agents
offset = SetBool(offset, world.ActionType == ActionType.CONTINUOUS);
offset = SetInt(offset, world.ActionSize);
if (world.ActionType == ActionType.DISCRETE)
offset = SetInt(offset, policy.DecisionAgentIds.Length); // Max Agents
offset = SetBool(offset, policy.ActionType == ActionType.CONTINUOUS);
offset = SetInt(offset, policy.ActionSize);
if (policy.ActionType == ActionType.DISCRETE)
{
foreach (int branchSize in world.DiscreteActionBranches)
foreach (int branchSize in policy.DiscreteActionBranches)
{
offset = SetInt(offset, branchSize);
}
}
offset = SetInt(offset, world.SensorShapes.Length);
foreach (int3 shape in world.SensorShapes)
offset = SetInt(offset, policy.SensorShapes.Length);
foreach (int3 shape in policy.SensorShapes)
{
offset = SetInt(offset, shape.x);
offset = SetInt(offset, shape.y);
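The byte counts passed to SetArray above appear to treat each observation value, reward, and agent id as a 32-bit value, hence the factor of 4 times the agent count times the per-agent observation size. A minimal, self-contained sketch of that arithmetic, with hypothetical shapes and counts, not the package code:

// Sketch of the buffer-size arithmetic, assuming float32/int32 elements.
// Shapes, counts, and the zero-dimension handling below are hypothetical.
using System;
using System.Linq;

static class OffsetSizeSketch
{
    static void Main()
    {
        // Hypothetical observation shapes: one 8-float vector, one 84x84x3 visual.
        var shapes = new (int x, int y, int z)[] { (8, 0, 0), (84, 84, 3) };

        // Total floats per agent, roughly mirroring shape.GetTotalTensorSize().
        int totalFloatObsPerAgent = shapes.Sum(s =>
            Math.Max(1, s.x) * Math.Max(1, s.y) * Math.Max(1, s.z));

        int decisionCount = 16; // hypothetical number of agents requesting a decision

        int decisionObsBytes = 4 * decisionCount * totalFloatObsPerAgent; // float32 obs
        int decisionRewardBytes = 4 * decisionCount;                      // one float per agent
        int decisionAgentIdBytes = 4 * decisionCount;                     // one int per agent

        Console.WriteLine($"{totalFloatObsPerAgent} floats per agent, " +
            $"{decisionObsBytes + decisionRewardBytes + decisionAgentIdBytes} bytes for the decision block");
    }
}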
@ -119,7 +119,7 @@ namespace Unity.AI.MLAgents
m_CurrentEndOffset = offset;
}

public void ReadWorld(string name, MLAgentsWorld world)
public void ReadPolicy(string name, Policy policy)
{
if (!CanEdit)
{
@ -127,18 +127,18 @@ namespace Unity.AI.MLAgents
}
if (!m_OffsetDict.ContainsKey(name))
{
throw new MLAgentsException("World not registered");
throw new MLAgentsException("Policy not registered");
}
var dataOffsets = m_OffsetDict[name];
SetInt(dataOffsets.DecisionNumberAgentsOffset, 0);
SetInt(dataOffsets.TerminationNumberAgentsOffset, 0);
if (world.ActionType == ActionType.DISCRETE)
if (policy.ActionType == ActionType.DISCRETE)
{
GetArray(dataOffsets.ActionOffset, world.DiscreteActuators, 4 * world.DecisionCounter.Count * world.ActionSize);
GetArray(dataOffsets.ActionOffset, policy.DiscreteActuators, 4 * policy.DecisionCounter.Count * policy.ActionSize);
}
else
{
GetArray(dataOffsets.ActionOffset, world.ContinuousActuators, 4 * world.DecisionCounter.Count * world.ActionSize);
GetArray(dataOffsets.ActionOffset, policy.ContinuousActuators, 4 * policy.DecisionCounter.Count * policy.ActionSize);
}
}
@ -75,23 +75,23 @@ namespace Unity.AI.MLAgents
}

/// <summary>
/// Writes the data of a world into the shared memory file.
/// Writes the data of a Policy into the shared memory file.
/// </summary>
public void WriteWorld(string worldName, MLAgentsWorld world)
public void WritePolicy(string policyName, Policy policy)
{
if (!m_SharedMemoryHeader.Active)
{
return;
}
if (m_ShareMemoryBody.ContainsWorld(worldName))
if (m_ShareMemoryBody.ContainsPolicy(policyName))
{
m_ShareMemoryBody.WriteWorld(worldName, world);
m_ShareMemoryBody.WritePolicy(policyName, policy);
}
else
{
// The world needs to register
// The policy needs to register
int oldTotalCapacity = m_SharedMemoryHeader.RLDataBufferSize;
int worldMemorySize = RLDataOffsets.FromWorld(world, worldName, 0).EndOfDataOffset;
int policyMemorySize = RLDataOffsets.FromPolicy(policy, policyName, 0).EndOfDataOffset;
m_CurrentFileNumber += 1;
m_SharedMemoryHeader.FileNumber = m_CurrentFileNumber;
byte[] channelData = m_ShareMemoryBody.SideChannelData;
@ -102,9 +102,9 @@ namespace Unity.AI.MLAgents
true,
null,
m_SharedMemoryHeader.SideChannelBufferSize,
oldTotalCapacity + worldMemorySize
oldTotalCapacity + policyMemorySize
);
m_SharedMemoryHeader.RLDataBufferSize = oldTotalCapacity + worldMemorySize;
m_SharedMemoryHeader.RLDataBufferSize = oldTotalCapacity + policyMemorySize;
if (channelData != null)
{
m_ShareMemoryBody.SideChannelData = channelData;
@ -114,8 +114,8 @@ namespace Unity.AI.MLAgents
m_ShareMemoryBody.RlData = rlData;
}
// TODO Need to write the offsets
m_ShareMemoryBody.WriteWorldSpecs(worldName, world);
m_ShareMemoryBody.WriteWorld(worldName, world);
m_ShareMemoryBody.WritePolicySpecs(policyName, policy);
m_ShareMemoryBody.WritePolicy(policyName, policy);
}
}
@ -180,11 +180,11 @@ namespace Unity.AI.MLAgents
}

/// <summary>
/// Loads the action data form the shared memory file to the world
/// Loads the action data from the shared memory file to the policy
/// </summary>
public void LoadWorld(string worldName, MLAgentsWorld world)
public void LoadPolicy(string policyName, Policy policy)
{
m_ShareMemoryBody.ReadWorld(worldName, world);
m_ShareMemoryBody.ReadPolicy(policyName, policy);
}

public void Dispose()
@ -6,7 +6,7 @@ using UnityEngine;
namespace Unity.AI.MLAgents
{
internal enum WorldProcessorType
internal enum PolicyProcessorType
{
Default,
InferenceOnly,
@ -14,16 +14,16 @@ namespace Unity.AI.MLAgents
}

/// <summary>
/// An editor friendly constructor for a MLAgentsWorld.
/// Keeps track of the behavior specs of a world, its name,
/// An editor friendly constructor for a Policy.
/// Keeps track of the behavior specs of a Policy, its name,
/// its processor type and Neural Network Model.
/// </summary>
[Serializable]
public struct MLAgentsWorldSpecs
public struct PolicySpecs
{
[SerializeField] internal string Name;

[SerializeField] internal WorldProcessorType WorldProcessorType;
[SerializeField] internal PolicyProcessorType PolicyProcessorType;

[SerializeField] internal int NumberAgents;
[SerializeField] internal ActionType ActionType;
@ -34,47 +34,47 @@ namespace Unity.AI.MLAgents
[SerializeField] internal NNModel Model;
[SerializeField] internal InferenceDevice InferenceDevice;

private MLAgentsWorld m_World;
private Policy m_Policy;

/// <summary>
/// Generates the world using the specified specs and registers its
/// processor to the Academy. The world is only created and registed once,
/// if subsequent calls are made, the created world will be returned.
/// Generates a Policy using the specified specs and registers its
/// processor to the Academy. The policy is only created and registered once;
/// if subsequent calls are made, the created policy will be returned.
/// </summary>
/// <returns></returns>
public MLAgentsWorld GetWorld()
public Policy GetPolicy()
{
if (m_World.IsCreated)
if (m_Policy.IsCreated)
{
return m_World;
return m_Policy;
}
m_World = new MLAgentsWorld(
m_Policy = new Policy(
NumberAgents,
ObservationShapes,
ActionType,
ActionSize,
DiscreteActionBranches
);
switch (WorldProcessorType)
switch (PolicyProcessorType)
{
case WorldProcessorType.Default:
m_World.RegisterWorldWithBarracudaModel(Name, Model, InferenceDevice);
case PolicyProcessorType.Default:
m_Policy.RegisterPolicyWithBarracudaModel(Name, Model, InferenceDevice);
break;
case WorldProcessorType.InferenceOnly:
case PolicyProcessorType.InferenceOnly:
if (Model == null)
{
throw new MLAgentsException($"No model specified for {Name}");
}
m_World.RegisterWorldWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice);
m_Policy.RegisterPolicyWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice);
break;
case WorldProcessorType.None:
Academy.Instance.RegisterWorld(Name, m_World, new NullWorldProcessor(m_World), false);
case PolicyProcessorType.None:
Academy.Instance.RegisterPolicy(Name, m_Policy, new NullPolicyProcessor(m_Policy), false);
break;
default:
throw new MLAgentsException($"Unknown WorldProcessor Type");
throw new MLAgentsException($"Unknown IPolicyProcessor Type");
}

return m_World;
return m_Policy;
}
}
}
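A hedged usage sketch (not part of this commit): a component that exposes a PolicySpecs field in the Inspector, resolves it to a Policy once at startup, and disposes it on teardown, in the same way the samples further down do. The class and field names here are hypothetical.

// Hypothetical MonoBehaviour showing the GetPolicy()/Dispose() lifecycle.
// Assumes the Unity.AI.MLAgents package from this commit.
using Unity.AI.MLAgents;
using UnityEngine;

public class PolicySpecsUsageSketch : MonoBehaviour
{
    // Hypothetical serialized field, edited through the Inspector.
    public PolicySpecs MySpecs;

    Policy m_Policy;

    void Awake()
    {
        // GetPolicy() creates and registers the Policy on the first call and
        // returns the same instance on any later call.
        m_Policy = MySpecs.GetPolicy();
    }

    void OnDestroy()
    {
        m_Policy.Dispose();
    }
}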
@ -1,104 +0,0 @@
using System;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;

namespace Unity.AI.MLAgents
{
public static class HeristicWorldProcessorRegistringExtension
{
/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the Heuristic will
/// be ignored and the world will exchange data with Python instead.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the world will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the world.</typeparam>
public static void RegisterWorldWithHeuristic<TH>(
this MLAgentsWorld world,
string policyId,
Func<TH> heuristic
) where TH : struct
{
var worldProcessor = new HeuristicWorldProcessor<TH>(world, heuristic);
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, true);
}

/// <summary>
/// Registers the given MLAgentsWorld to the Academy with a Heuristic.
/// Note that if the simulation connects to Python, the world will not
/// exchange data with Python and use the Heuristic regardless.
/// The Heuristic is a Function that returns an action struct.
/// </summary>
/// <param name="world"> The MLAgentsWorld to register</param>
/// <param name="policyId"> The name of the world. This is useful for identification
/// and for training.</param>
/// <param name="heuristic"> The Heuristic used to generate the actions.
/// Note that all agents in the world will receive the same action.</param>
/// <typeparam name="TH"> The type of the Action struct. It must match the Action
/// Size and Action Type of the world.</typeparam>
public static void RegisterWorldWithHeuristicForceNoCommunication<TH>(
this MLAgentsWorld world,
string policyId,
Func<TH> heuristic
) where TH : struct
{
var worldProcessor = new HeuristicWorldProcessor<TH>(world, heuristic);
Academy.Instance.RegisterWorld(policyId, world, worldProcessor, false);
}
}

internal class HeuristicWorldProcessor<T> : IWorldProcessor where T : struct
{
private Func<T> m_Heuristic;
private MLAgentsWorld m_World;

public bool IsConnected {get {return false;}}

internal HeuristicWorldProcessor(MLAgentsWorld world, Func<T> heuristic)
{
this.m_World = world;
this.m_Heuristic = heuristic;
var structSize = UnsafeUtility.SizeOf<T>() / sizeof(float);
if (structSize != world.ActionSize)
{
throw new MLAgentsException(
$"The heuristic provided does not match the action size. Expected {world.ActionSize} but received {structSize} from heuristic");
}
}

public void ProcessWorld()
{
T action = m_Heuristic.Invoke();
var totalCount = m_World.DecisionCounter.Count;

// TODO : This can be parallelized
if (m_World.ActionType == ActionType.CONTINUOUS)
{
var s = m_World.ContinuousActuators.Slice(0, totalCount * m_World.ActionSize).SliceConvert<T>();
for (int i = 0; i < totalCount; i++)
{
s[i] = action;
}
}
else
{
var s = m_World.DiscreteActuators.Slice(0, totalCount * m_World.ActionSize).SliceConvert<T>();
for (int i = 0; i < totalCount; i++)
{
s[i] = action;
}
}
}

public void Dispose()
{
}
}
}
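The deleted extension above registered a heuristic against an MLAgentsWorld. For comparison, a hedged sketch of the policy-based registration used elsewhere in this commit; the policy parameters and names below are hypothetical.

// Sketch only: builds a hypothetical Policy and registers a heuristic for it,
// matching the constructor and RegisterPolicyWithHeuristic calls seen in the tests.
using Unity.AI.MLAgents;
using Unity.Mathematics;

public static class HeuristicRegistrationSketch
{
    public static Policy CreateAndRegister()
    {
        // Hypothetical policy: up to 20 agents, one 3-float observation,
        // discrete actions with two branches of sizes 2 and 3.
        var policy = new Policy(
            20,
            new int3[] { new int3(3, 0, 0) },
            ActionType.DISCRETE,
            2,
            new int[] { 2, 3 });

        // Every agent that requested a decision receives the same heuristic action.
        policy.RegisterPolicyWithHeuristic("heuristic_sketch", () => new int2(0, 1));
        return policy;
    }
}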
@ -4,8 +4,9 @@
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_Name: AgentBlue
m_Shader: {fileID: 47, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords: _GLOSSYREFLECTIONS_OFF _SPECULARHIGHLIGHTS_OFF
@ -249,9 +249,9 @@ MonoBehaviour:
m_Script: {fileID: 11500000, guid: 99c8505fda4844b5b8a9505d069fdd8c, type: 3}
m_Name:
m_EditorClassIdentifier:
MyWorldSpecs:
Name: Ball_DOTS
WorldProcessorType: 0
MyPolicySpecs:
Name: 3DBall
PolicyProcessorType: 0
NumberAgents: 1000
ActionType: 1
ObservationShapes:
@ -268,7 +268,7 @@ MonoBehaviour:
y: 0
z: 0
ActionSize: 2
DiscreteActionBranches:
DiscreteActionBranches: 0000000000000000
Model: {fileID: 5022602860645237092, guid: 10312addcce7c4e508c30d95eaf1a2ff, type: 3}
InferenceDevice: 0
NumberBalls: 1000
@ -8,7 +8,7 @@ using Unity.AI.MLAgents;
public class BalanceBallManager : MonoBehaviour
{
public MLAgentsWorldSpecs MyWorldSpecs;
public PolicySpecs MyPolicySpecs;

public int NumberBalls = 1000;
@ -23,12 +23,13 @@ public class BalanceBallManager : MonoBehaviour
NativeArray<Entity> entitiesB;
BlobAssetStore blob;

void Awake()
{
var world = MyWorldSpecs.GetWorld();
var policy = MyPolicySpecs.GetPolicy();
var ballSystem = World.DefaultGameObjectInjectionWorld.GetOrCreateSystem<BallSystem>();
ballSystem.Enabled = true;
ballSystem.BallWorld = world;
ballSystem.BallPolicy = policy;

manager = World.DefaultGameObjectInjectionWorld.EntityManager;
@ -33,17 +33,17 @@ public class BallSystem : JobComponentSystem
}
}

public MLAgentsWorld BallWorld;
public Policy BallPolicy;

// Update is called once per frame
protected override JobHandle OnUpdate(JobHandle inputDeps)
{

if (!BallWorld.IsCreated){
if (!BallPolicy.IsCreated){
return inputDeps;
}
var world = BallWorld;
var policy = BallPolicy;

ComponentDataFromEntity<Translation> TranslationFromEntity = GetComponentDataFromEntity<Translation>(isReadOnly: false);
ComponentDataFromEntity<PhysicsVelocity> VelFromEntity = GetComponentDataFromEntity<PhysicsVelocity>(isReadOnly: false);
@ -52,7 +52,7 @@ public class BallSystem : JobComponentSystem
.WithNativeDisableParallelForRestriction(VelFromEntity)
.ForEach((Entity entity, ref Rotation rot, ref AgentData agentData) =>
{

var ballPos = TranslationFromEntity[agentData.BallRef].Value;
var ballVel = VelFromEntity[agentData.BallRef].Linear;
var platformVel = VelFromEntity[entity];
@ -70,7 +70,7 @@ public class BallSystem : JobComponentSystem
}
if (!interruption && !taskFailed)
{
world.RequestDecision(entity)
policy.RequestDecision(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
@ -79,7 +79,7 @@ public class BallSystem : JobComponentSystem
}
if (taskFailed)
{
world.EndEpisode(entity)
policy.EndEpisode(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
@ -88,7 +88,7 @@ public class BallSystem : JobComponentSystem
}
else if (interruption)
{
world.InterruptEpisode(entity)
policy.InterruptEpisode(entity)
.SetObservation(0, rot.Value)
.SetObservation(1, ballPos - agentData.BallResetPosition)
.SetObservation(2, ballVel)
@ -109,7 +109,7 @@ public class BallSystem : JobComponentSystem
{
ComponentDataFromEntity = GetComponentDataFromEntity<Actuator>(isReadOnly: false)
};
inputDeps = reactiveJob.Schedule(world, inputDeps);
inputDeps = reactiveJob.Schedule(policy, inputDeps);

inputDeps = Entities.ForEach((Actuator act, ref Rotation rotation) =>
{
@ -122,6 +122,6 @@ public class BallSystem : JobComponentSystem
protected override void OnDestroy()
{
BallWorld.Dispose();
BallPolicy.Dispose();
}
}
@ -5,8 +5,8 @@ using UnityEngine;
public class BasicAgent : MonoBehaviour
{
public MLAgentsWorldSpecs BasicSpecs;
private MLAgentsWorld m_World;
public PolicySpecs BasicSpecs;
private Policy m_Policy;
private Entity m_Entity;

public float timeBetweenDecisionsAtInference;
@ -33,8 +33,8 @@ public class BasicAgent : MonoBehaviour
void Start()
{
m_Entity = World.DefaultGameObjectInjectionWorld.EntityManager.CreateEntity();
m_World = BasicSpecs.GetWorld();
m_World.RegisterWorldWithHeuristic<int>("BASIC", () => { return 2; });
m_Policy = BasicSpecs.GetPolicy();
m_Policy.RegisterPolicyWithHeuristic<int>("BASIC", () => { return 2; });
Academy.Instance.OnEnvironmentReset += BeginEpisode;
BeginEpisode();
}
|
@ -62,12 +62,12 @@ public class BasicAgent : MonoBehaviour
|
|||
void StepAgent()
|
||||
{
|
||||
// Request a Decision for all agents
|
||||
m_World.RequestDecision(m_Entity)
|
||||
m_Policy.RequestDecision(m_Entity)
|
||||
.SetObservation(0, m_Position)
|
||||
.SetReward(-0.01f);
|
||||
|
||||
// Get the action
|
||||
NativeHashMap<Entity, int> actions = m_World.GenerateActionHashMap<int>(Allocator.Temp);
|
||||
NativeHashMap<Entity, int> actions = m_Policy.GenerateActionHashMap<int>(Allocator.Temp);
|
||||
int action = 0;
|
||||
actions.TryGetValue(m_Entity, out action);
|
||||
|
||||
|
@ -87,7 +87,7 @@ public class BasicAgent : MonoBehaviour
// See if the Agent terminated
if (m_Position == k_SmallGoalPosition)
{
m_World.EndEpisode(m_Entity)
m_Policy.EndEpisode(m_Entity)
.SetObservation(0, m_Position)
.SetReward(0.1f);
BeginEpisode();
|
@ -95,7 +95,7 @@ public class BasicAgent : MonoBehaviour
|
|||
|
||||
if (m_Position == k_LargeGoalPosition)
|
||||
{
|
||||
m_World.EndEpisode(m_Entity)
|
||||
m_Policy.EndEpisode(m_Entity)
|
||||
.SetObservation(0, m_Position)
|
||||
.SetReward(1f);
|
||||
BeginEpisode();
|
||||
|
@ -104,6 +104,6 @@ public class BasicAgent : MonoBehaviour
|
|||
|
||||
void OnDestroy()
|
||||
{
|
||||
m_World.Dispose();
|
||||
m_Policy.Dispose();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ using Unity.Jobs;
namespace Unity.AI.MLAgents.Tests.Editor
{
public class TestMLAgentsWorld
public class TestMLAgentsPolicy
{
private World ECSWorld;
private EntityManager entityManager;
@ -51,15 +51,15 @@ namespace Unity.AI.MLAgents.Tests.Editor
}

[Test]
public void TestWorldCreation()
public void TestPolicyCreation()
{
var world = new MLAgentsWorld(
var policy = new Policy(
20,
new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
world.Dispose();
policy.Dispose();
}

private struct SingleActionEnumUpdate : IActuatorJob
@ -77,18 +77,18 @@ namespace Unity.AI.MLAgents.Tests.Editor
[Test]
public void TestManualDecisionSteppingWithHeuristic()
{
var world = new MLAgentsWorld(
var policy = new Policy(
20,
new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });

world.RegisterWorldWithHeuristic("test", () => new int2(0, 1));
policy.RegisterPolicyWithHeuristic("test", () => new int2(0, 1));

var entity = entityManager.CreateEntity();

world.RequestDecision(entity)
policy.RequestDecision(entity)
.SetObservation(0, new float3(1, 2, 3))
.SetReward(1f);
@ -100,7 +100,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
ent = entities,
action = actions
};
actionJob.Schedule(world, new JobHandle()).Complete();
actionJob.Schedule(policy, new JobHandle()).Complete();

Assert.AreEqual(entity, entities[0]);
Assert.AreEqual(new DiscreteAction_TWO_THREE
@ -109,7 +109,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
action_TWO = ThreeOptionEnum.Option_TWO
}, actions[0]);

world.Dispose();
policy.Dispose();
entities.Dispose();
actions.Dispose();
Academy.Instance.Dispose();
@ -118,56 +118,56 @@ namespace Unity.AI.MLAgents.Tests.Editor
[Test]
public void TestTerminateEpisode()
{
var world = new MLAgentsWorld(
var policy = new Policy(
20,
new int3[] { new int3(3, 0, 0), new int3(84, 84, 3) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });

world.RegisterWorldWithHeuristic("test", () => new int2(0, 1));
policy.RegisterPolicyWithHeuristic("test", () => new int2(0, 1));

var entity = entityManager.CreateEntity();

world.RequestDecision(entity)
policy.RequestDecision(entity)
.SetObservation(0, new float3(1, 2, 3))
.SetReward(1f);

var hashMap = world.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
var hashMap = policy.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
Assert.True(hashMap.TryGetValue(entity, out _));
hashMap.Dispose();

world.EndEpisode(entity);
hashMap = world.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
policy.EndEpisode(entity);
hashMap = policy.GenerateActionHashMap<DiscreteAction_TWO_THREE>(Allocator.Temp);
Assert.False(hashMap.TryGetValue(entity, out _));
hashMap.Dispose();

world.Dispose();
policy.Dispose();
Academy.Instance.Dispose();
}

[Test]
public void TestMultiWorld()
public void TestMultiPolicy()
{
var world1 = new MLAgentsWorld(
var policy1 = new Policy(
20,
new int3[] { new int3(3, 0, 0) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });
var world2 = new MLAgentsWorld(
var policy2 = new Policy(
20,
new int3[] { new int3(3, 0, 0) },
ActionType.DISCRETE,
2,
new int[] { 2, 3 });

world1.RegisterWorldWithHeuristic("test1", () => new DiscreteAction_TWO_THREE
policy1.RegisterPolicyWithHeuristic("test1", () => new DiscreteAction_TWO_THREE
{
action_ONE = TwoOptionEnum.Option_TWO,
action_TWO = ThreeOptionEnum.Option_ONE
});
world2.RegisterWorldWithHeuristic("test2", () => new DiscreteAction_TWO_THREE
policy2.RegisterPolicyWithHeuristic("test2", () => new DiscreteAction_TWO_THREE
{
action_ONE = TwoOptionEnum.Option_ONE,
action_TWO = ThreeOptionEnum.Option_TWO
@ -175,8 +175,8 @@ namespace Unity.AI.MLAgents.Tests.Editor
var entity = entityManager.CreateEntity();

world1.RequestDecision(entity);
world2.RequestDecision(entity);
policy1.RequestDecision(entity);
policy2.RequestDecision(entity);

var entities = new NativeArray<Entity>(1, Allocator.Persistent);
var actions = new NativeArray<DiscreteAction_TWO_THREE>(1, Allocator.Persistent);
@ -185,7 +185,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
ent = entities,
action = actions
};
actionJob1.Schedule(world1, new JobHandle()).Complete();
actionJob1.Schedule(policy1, new JobHandle()).Complete();

Assert.AreEqual(entity, entities[0]);
Assert.AreEqual(new DiscreteAction_TWO_THREE
@ -203,7 +203,7 @@ namespace Unity.AI.MLAgents.Tests.Editor
ent = entities,
action = actions
};
actionJob2.Schedule(world2, new JobHandle()).Complete();
actionJob2.Schedule(policy2, new JobHandle()).Complete();

Assert.AreEqual(entity, entities[0]);
Assert.AreEqual(new DiscreteAction_TWO_THREE
@ -214,8 +214,8 @@ namespace Unity.AI.MLAgents.Tests.Editor
entities.Dispose();
actions.Dispose();

world1.Dispose();
world2.Dispose();
policy1.Dispose();
policy2.Dispose();
Academy.Instance.Dispose();
}
}