using Unity.Collections;
using Unity.Entities;
using Unity.Collections.LowLevel.Unsafe;

namespace Unity.AI.MLAgents
{
    /// <summary>
    /// Extension methods on <see cref="Policy"/> that copy the policy's pending actions
    /// into caller-supplied NativeHashMaps keyed by agent Entity.
    /// </summary>
    public static class ActionHashMapUtils
    {
        /// <summary>
        /// Retrieves both the continuous and the discrete action data of a Policy and
        /// puts them into the provided hash maps.
        /// Entries of terminated agents are removed from both maps, the policy is updated
        /// through <c>Academy.Instance.UpdatePolicy</c>, every agent listed in the policy's
        /// action buffers gets its entry replaced, and finally the policy's action counter
        /// is reset.
        /// </summary>
        /// <param name="policy">The Policy the data will be retrieved from.</param>
        /// <param name="continuousActionMap">
        /// Caller-owned map that receives one <typeparamref name="TC"/> per agent.
        /// </param>
        /// <param name="discreteActionMap">
        /// Caller-owned map that receives one <typeparamref name="TD"/> per agent.
        /// </param>
        /// <typeparam name="TC">
        /// The type of the continuous Action struct. Its size in bytes divided by 4 must
        /// equal <c>policy.ContinuousActionSize</c>.
        /// </typeparam>
        /// <typeparam name="TD">
        /// The type of the discrete Action struct. Its size in bytes divided by 4 must
        /// equal <c>policy.DiscreteActionBranches.Length</c>.
        /// </typeparam>
        /// <exception cref="MLAgentsException">
        /// Thrown (only when ENABLE_UNITY_COLLECTIONS_CHECKS is defined) if either struct
        /// size does not match the corresponding action size of the policy.
        /// </exception>
        public static void GenerateActionHashMap<TC, TD>(
            this Policy policy,
            NativeHashMap<Entity, TC> continuousActionMap,
            NativeHashMap<Entity, TD> discreteActionMap)
            where TC : struct
            where TD : struct
        {
            int contSize = policy.ContinuousActionSize;
            int discSize = policy.DiscreteActionBranches.Length;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
            // Sizes are compared in units of 4 bytes
            // (presumably one float / one int per action value - TODO confirm).
            int receivedContSize = UnsafeUtility.SizeOf<TC>() / 4;
            if (contSize != receivedContSize)
            {
                throw new MLAgentsException($"Continuous action space size does not match for action. Expected {contSize} but received {receivedContSize}");
            }
            int receivedDiscSize = UnsafeUtility.SizeOf<TD>() / 4;
            if (discSize != receivedDiscSize)
            {
                throw new MLAgentsException($"Discrete action space size does not match for action. Expected {discSize} but received {receivedDiscSize}");
            }
#endif
            for (int i = 0; i < policy.TerminationCounter.Count; i++)
            {
                // Remove the action of terminated agents
                var terminated = policy.TerminationAgentEntityIds[i];
                continuousActionMap.Remove(terminated);
                discreteActionMap.Remove(terminated);
            }

            Academy.Instance.UpdatePolicy(policy);

            int actionCount = policy.ActionCounter.Count;
            for (int i = 0; i < actionCount; i++)
            {
                var agent = policy.ActionAgentEntityIds[i];

                // Remove before TryAdd: TryAdd does not overwrite an existing key,
                // so a stale entry has to be dropped for the new action to be stored.
                continuousActionMap.Remove(agent);
                // Reinterpret the agent's contSize values as a single TC.
                continuousActionMap.TryAdd(agent, policy.ContinuousActuators.Slice(i * contSize, contSize).SliceConvert<TC>()[0]);

                discreteActionMap.Remove(agent);
                // Reinterpret the agent's discSize values as a single TD.
                discreteActionMap.TryAdd(agent, policy.DiscreteActuators.Slice(i * discSize, discSize).SliceConvert<TD>()[0]);
            }
            policy.ResetActionsCounter();
        }

        /// <summary>
        /// Retrieves the continuous action data of a Policy and puts it into the provided
        /// hash map.
        /// Entries of terminated agents are removed from the map, the policy is updated
        /// through <c>Academy.Instance.UpdatePolicy</c>, every agent listed in the policy's
        /// action buffers gets its entry replaced, and finally the policy's action counter
        /// is reset.
        /// </summary>
        /// <param name="policy">The Policy the data will be retrieved from.</param>
        /// <param name="continuousActionMap">
        /// Caller-owned map that receives one <typeparamref name="TC"/> per agent.
        /// </param>
        /// <typeparam name="TC">
        /// The type of the continuous Action struct. Its size in bytes divided by 4 must
        /// equal <c>policy.ContinuousActionSize</c>.
        /// </typeparam>
        /// <exception cref="MLAgentsException">
        /// Thrown (only when ENABLE_UNITY_COLLECTIONS_CHECKS is defined) if the struct
        /// size does not match the continuous action size of the policy.
        /// </exception>
        public static void GenerateContinuousActionHashMap<TC>(
            this Policy policy,
            NativeHashMap<Entity, TC> continuousActionMap)
            where TC : struct
        {
            int contSize = policy.ContinuousActionSize;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
            // Sizes are compared in units of 4 bytes
            // (presumably one float per action value - TODO confirm).
            int receivedSize = UnsafeUtility.SizeOf<TC>() / 4;
            if (contSize != receivedSize)
            {
                throw new MLAgentsException($"Continuous action space size does not match for action. Expected {contSize} but received {receivedSize}");
            }
#endif
            for (int i = 0; i < policy.TerminationCounter.Count; i++)
            {
                // Remove the action of terminated agents
                continuousActionMap.Remove(policy.TerminationAgentEntityIds[i]);
            }

            Academy.Instance.UpdatePolicy(policy);

            int actionCount = policy.ActionCounter.Count;
            for (int i = 0; i < actionCount; i++)
            {
                var agent = policy.ActionAgentEntityIds[i];

                // Remove before TryAdd: TryAdd does not overwrite an existing key,
                // so a stale entry has to be dropped for the new action to be stored.
                continuousActionMap.Remove(agent);
                // Reinterpret the agent's contSize values as a single TC.
                continuousActionMap.TryAdd(agent, policy.ContinuousActuators.Slice(i * contSize, contSize).SliceConvert<TC>()[0]);
            }
            policy.ResetActionsCounter();
        }

        /// <summary>
        /// Retrieves the discrete action data of a Policy and puts it into the provided
        /// hash map.
        /// Entries of terminated agents are removed from the map, the policy is updated
        /// through <c>Academy.Instance.UpdatePolicy</c>, every agent listed in the policy's
        /// action buffers gets its entry replaced, and finally the policy's action counter
        /// is reset.
        /// </summary>
        /// <param name="policy">The Policy the data will be retrieved from.</param>
        /// <param name="discreteActionMap">
        /// Caller-owned map that receives one <typeparamref name="TD"/> per agent.
        /// </param>
        /// <typeparam name="TD">
        /// The type of the discrete Action struct. Its size in bytes divided by 4 must
        /// equal <c>policy.DiscreteActionBranches.Length</c>.
        /// </typeparam>
        /// <exception cref="MLAgentsException">
        /// Thrown (only when ENABLE_UNITY_COLLECTIONS_CHECKS is defined) if the struct
        /// size does not match the number of discrete action branches of the policy.
        /// </exception>
        public static void GenerateDiscreteActionHashMap<TD>(
            this Policy policy,
            NativeHashMap<Entity, TD> discreteActionMap)
            where TD : struct
        {
            int discSize = policy.DiscreteActionBranches.Length;
#if ENABLE_UNITY_COLLECTIONS_CHECKS
            // Sizes are compared in units of 4 bytes
            // (presumably one int per action branch - TODO confirm).
            int receivedSize = UnsafeUtility.SizeOf<TD>() / 4;
            if (discSize != receivedSize)
            {
                throw new MLAgentsException($"Discrete action space size does not match for action. Expected {discSize} but received {receivedSize}");
            }
#endif
            for (int i = 0; i < policy.TerminationCounter.Count; i++)
            {
                // Remove the action of terminated agents
                discreteActionMap.Remove(policy.TerminationAgentEntityIds[i]);
            }

            Academy.Instance.UpdatePolicy(policy);

            int actionCount = policy.ActionCounter.Count;
            for (int i = 0; i < actionCount; i++)
            {
                var agent = policy.ActionAgentEntityIds[i];

                // Remove before TryAdd: TryAdd does not overwrite an existing key,
                // so a stale entry has to be dropped for the new action to be stored.
                discreteActionMap.Remove(agent);
                // Reinterpret the agent's discSize values as a single TD.
                discreteActionMap.TryAdd(agent, policy.DiscreteActuators.Slice(i * discSize, discSize).SliceConvert<TD>()[0]);
            }
            policy.ResetActionsCounter();
        }
    }
}