This commit is contained in:
Super Daniel 2023-02-10 10:49:44 +08:00 коммит произвёл GitHub
Родитель 8371d228f5
Коммит efe1246323
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 4 добавлений и 0 удалений

Просмотреть файл

@ -86,11 +86,15 @@ class AutoMaskInference:
self.output_mask = self.in_masks[0]
else:
if isinstance(self.output, torch.Tensor):
if self.output.requires_grad:
self.output.retain_grad() # issue #5299
self.output_mask = torch.ones_like(self.output)
elif isinstance(self.output, list) or isinstance(self.output, tuple):
self.output_mask = []
for o_tensor in self.output:
if isinstance(o_tensor, torch.Tensor):
if o_tensor.requires_grad:
o_tensor.retain_grad() # issue #5299
self.output_mask.append(torch.ones_like(o_tensor))
else:
# if one of the outputs is not tensor, set the corresponding