Fixed failing GPU test.
This commit is contained in:
Родитель
9d34978a43
Коммит
24485edd08
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from mlagents.torch_utils import torch
|
||||
from mlagents.torch_utils import torch, default_device
|
||||
import numpy as np
|
||||
|
||||
from mlagents.trainers.torch_entities.utils import ModelUtils
|
||||
|
@ -217,7 +217,7 @@ def test_predict_minimum_training():
|
|||
argmin = argmin.squeeze()
|
||||
argmin = argmin.detach()
|
||||
sliced_oh = onehots[:, : num + 1]
|
||||
inp = torch.cat([inp, sliced_oh], dim=2)
|
||||
inp = torch.cat([inp, sliced_oh.to(default_device())], dim=2)
|
||||
|
||||
embeddings = entity_embedding(inp, inp)
|
||||
masks = get_zero_entities_mask([inp])
|
||||
|
|
Загрузка…
Ссылка в новой задаче