зеркало из https://github.com/microsoft/nni.git
Родитель
8371d228f5
Коммит
efe1246323
|
@ -86,11 +86,15 @@ class AutoMaskInference:
|
|||
self.output_mask = self.in_masks[0]
|
||||
else:
|
||||
if isinstance(self.output, torch.Tensor):
|
||||
if self.output.requires_grad:
|
||||
self.output.retain_grad() # issue #5299
|
||||
self.output_mask = torch.ones_like(self.output)
|
||||
elif isinstance(self.output, list) or isinstance(self.output, tuple):
|
||||
self.output_mask = []
|
||||
for o_tensor in self.output:
|
||||
if isinstance(o_tensor, torch.Tensor):
|
||||
if o_tensor.requires_grad:
|
||||
o_tensor.retain_grad() # issue #5299
|
||||
self.output_mask.append(torch.ones_like(o_tensor))
|
||||
else:
|
||||
# if one of the outputs is not tensor, set the corresponding
|
||||
|
|
Загрузка…
Ссылка в новой задаче