From 24485edd08014c0dd7b3e1966eb7e971e4cdf663 Mon Sep 17 00:00:00 2001 From: Miguel Alonso Jr Date: Fri, 4 Oct 2024 19:06:55 -0400 Subject: [PATCH] Fixed failing GPU test. --- .../mlagents/trainers/tests/torch_entities/test_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py index f7344a647..1a210987b 100644 --- a/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py @@ -1,5 +1,5 @@ import pytest -from mlagents.torch_utils import torch +from mlagents.torch_utils import torch, default_device import numpy as np from mlagents.trainers.torch_entities.utils import ModelUtils @@ -217,7 +217,7 @@ def test_predict_minimum_training(): argmin = argmin.squeeze() argmin = argmin.detach() sliced_oh = onehots[:, : num + 1] - inp = torch.cat([inp, sliced_oh], dim=2) + inp = torch.cat([inp, sliced_oh.to(default_device())], dim=2) embeddings = entity_embedding(inp, inp) masks = get_zero_entities_mask([inp])