Clean pytorch_nice doc (#971)
This commit is contained in:
Родитель
9252a6b486
Коммит
0daf53e905
|
@ -42,14 +42,22 @@ def vrelu3_pytorch(x: torch.Tensor):
|
|||
|
||||
|
||||
# run-bench: PyTorch "nice" implementation
|
||||
def relu3_pytorch_nice(x: float) -> float:
|
||||
if x < 0.0:
|
||||
return torch.zeros_like(x) # Needed for PyTorch, not for Knossos [Note: zeros]
|
||||
elif x < 1.0:
|
||||
return 1 / 3 * x ** 3
|
||||
else:
|
||||
return x - 2 / 3
|
||||
# TODO: With torch 1.9.0 this leads to
|
||||
# RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
|
||||
# See https://msrcambridge.visualstudio.com/Knossos/_backlogs/backlog/Knossos%20Team/Goals/?workitem=19587
|
||||
if False:
|
||||
|
||||
def relu3_pytorch_nice(x: float) -> float:
|
||||
if x < 0.0:
|
||||
return torch.zeros_like(
|
||||
x
|
||||
) # Needed for PyTorch, not for Knossos [Note: zeros]
|
||||
elif x < 1.0:
|
||||
return 1 / 3 * x ** 3
|
||||
else:
|
||||
return x - 2 / 3
|
||||
|
||||
vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)
|
||||
|
||||
# run-bench: Knossos implementation
|
||||
def vrelu3(x: torch.Tensor):
|
||||
|
@ -478,12 +486,6 @@ def vrelu3_embedded_INCORRECT_ks_upper_bound():
|
|||
)
|
||||
|
||||
|
||||
# With torch 1.9.0 this leads to
|
||||
# RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
|
||||
# See https://msrcambridge.visualstudio.com/Knossos/_backlogs/backlog/Knossos%20Team/Goals/?workitem=19587
|
||||
# vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)
|
||||
|
||||
|
||||
def vrelu3_cuda_init():
|
||||
__ksc_path, ksc_runtime_dir = utils.get_ksc_paths()
|
||||
this_dir = os.path.dirname(__file__)
|
||||
|
|
Загрузка…
Ссылка в новой задаче