Use internal PyTorch vmap for vrelu3 (#881)
* Use internal vmap implementation * Adjust how candidate functions are filtered: wrapper *values* will sometimes show the internal function name, so it is better to filter on the identifier * Comment out vrelu3_pytorch_nice for now: it uses functionality currently unsupported by vmap ("RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.") * Add note about vmap stability * More explicit call, discussion: https://github.com/microsoft/knossos-ksc/pull/881#discussion_r654523527
This commit is contained in:
Родитель
53fcdb7f89
Коммит
bafd9c1dea
|
@ -6,7 +6,7 @@ from collections import OrderedDict
|
|||
import ksc.expr as expr
|
||||
from ksc.type import Type
|
||||
from ksc.torch_frontend import ksc_string_to_autograd_function
|
||||
|
||||
import torch._vmap_internals
|
||||
|
||||
# BEGINDOC
|
||||
def relu3(x: float) -> float:
|
||||
|
@ -107,7 +107,10 @@ def relu3_pytorch_nice(x: float) -> float:
|
|||
return x - 2 / 3
|
||||
|
||||
|
||||
vrelu3_pytorch_nice = torch.vmap(relu3_pytorch_nice)
|
||||
# With torch 1.9.0 this leads to
|
||||
# RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
|
||||
# See https://msrcambridge.visualstudio.com/Knossos/_backlogs/backlog/Knossos%20Team/Goals/?workitem=19587
|
||||
# vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)
|
||||
|
||||
|
||||
def vrelu3_cuda_init():
|
||||
|
|
|
@ -87,31 +87,30 @@ def function_to_manual_cuda_benchmarks(func):
|
|||
|
||||
|
||||
def functions_to_benchmark(mod, benchmark_name, example_inputs):
|
||||
for fn in inspect.getmembers(
|
||||
mod, lambda m: inspect.isfunction(m) and m.__name__.startswith(benchmark_name)
|
||||
):
|
||||
fn_name, fn_obj = fn
|
||||
if fn_name == benchmark_name + "_bench_configs":
|
||||
continue
|
||||
elif fn_name == benchmark_name + "_pytorch":
|
||||
yield from function_to_torch_benchmarks(fn_obj)
|
||||
|
||||
elif fn_name == benchmark_name + "_pytorch_nice":
|
||||
yield BenchmarkFunction("PyTorch Nice", fn_obj)
|
||||
elif fn_name == benchmark_name:
|
||||
ks_mod = tsmod2ksmod(mod, benchmark_name, example_inputs, generate_lm=False)
|
||||
yield BenchmarkFunction("Knossos", ks_mod.apply)
|
||||
elif fn_name == benchmark_name + "_cuda_init":
|
||||
if torch.cuda.is_available():
|
||||
yield from function_to_manual_cuda_benchmarks(fn_obj)
|
||||
elif fn_name.startswith(benchmark_name + "_ks_embedded_"):
|
||||
n = len(benchmark_name + "_ks_embedded_")
|
||||
benchmark_display_name = "Knossos embedded " + fn_name[n:]
|
||||
yield BenchmarkFunction(benchmark_display_name, fn_obj().apply)
|
||||
else:
|
||||
# perhaps we should just allow anything that matches the pattern?
|
||||
# would make it easier to add arbitrary comparisons e.g. TF
|
||||
print(f"Ignoring {fn_name}")
|
||||
for fn_name, fn_obj in inspect.getmembers(mod, lambda m: inspect.isfunction(m)):
|
||||
if fn_name.startswith(benchmark_name):
|
||||
if fn_name == benchmark_name + "_bench_configs":
|
||||
continue
|
||||
elif fn_name == benchmark_name + "_pytorch":
|
||||
yield from function_to_torch_benchmarks(fn_obj)
|
||||
elif fn_name == benchmark_name + "_pytorch_nice":
|
||||
yield BenchmarkFunction("PyTorch Nice", fn_obj)
|
||||
elif fn_name == benchmark_name:
|
||||
ks_mod = tsmod2ksmod(
|
||||
mod, benchmark_name, example_inputs, generate_lm=False
|
||||
)
|
||||
yield BenchmarkFunction("Knossos", ks_mod.apply)
|
||||
elif fn_name == benchmark_name + "_cuda_init":
|
||||
if torch.cuda.is_available():
|
||||
yield from function_to_manual_cuda_benchmarks(fn_obj)
|
||||
elif fn_name.startswith(benchmark_name + "_ks_embedded_"):
|
||||
n = len(benchmark_name + "_ks_embedded_")
|
||||
benchmark_display_name = "Knossos embedded " + fn_name[n:]
|
||||
yield BenchmarkFunction(benchmark_display_name, fn_obj().apply)
|
||||
else:
|
||||
# perhaps we should just allow anything that matches the pattern?
|
||||
# would make it easier to add arbitrary comparisons e.g. TF
|
||||
print(f"Ignoring {fn_name}")
|
||||
|
||||
|
||||
def func_namer(benchmark_func):
|
||||
|
|
|
@ -7,6 +7,8 @@ import torch
|
|||
@torch.jit.ignore
|
||||
def elementwise_apply_pt18(f, x: torch.Tensor) -> torch.Tensor:
|
||||
# TODO: torch.vmap in 1.9
|
||||
# NOTE: torch.vmap still isn't stable in PyTorch 1.9; it can be called via the internal API: torch._vmap_internals
|
||||
# https://github.com/pytorch/pytorch/issues/42368
|
||||
y = torch.zeros_like(x)
|
||||
sz = x.shape
|
||||
if len(sz) == 1:
|
||||
|
|
Загрузка…
Ссылка в новой задаче