diff --git a/nni/compression/pytorch/speedup/infer_mask.py b/nni/compression/pytorch/speedup/infer_mask.py index ead1ffb73..57c01d752 100644 --- a/nni/compression/pytorch/speedup/infer_mask.py +++ b/nni/compression/pytorch/speedup/infer_mask.py @@ -86,11 +86,15 @@ class AutoMaskInference: self.output_mask = self.in_masks[0] else: if isinstance(self.output, torch.Tensor): + if self.output.requires_grad: + self.output.retain_grad() # issue #5299 self.output_mask = torch.ones_like(self.output) elif isinstance(self.output, list) or isinstance(self.output, tuple): self.output_mask = [] for o_tensor in self.output: if isinstance(o_tensor, torch.Tensor): + if o_tensor.requires_grad: + o_tensor.retain_grad() # issue #5299 self.output_mask.append(torch.ones_like(o_tensor)) else: # if one of the outputs is not tensor, set the corresponding