Use internal PyTorch vmap for vrelu3 (#881)

* Use internal vmap implementation

* Adjust how candidate functions are filtered
Wrapper *values* will sometimes show the internal function name.
Better to filter on the identifier.

* Comment out vrelu3_pytorch_nice for now
It uses functionality currently unsupported by vmap:
"RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback."

* Add note about vmap stability

* More explicit call, discussion:
https://github.com/microsoft/knossos-ksc/pull/881#discussion_r654523527
Colin Gravill 2021-06-21 16:14:04 +02:00 committed by GitHub
Parent 53fcdb7f89
Commit bafd9c1dea
3 changed files with 31 additions and 27 deletions


@@ -6,7 +6,7 @@ from collections import OrderedDict
import ksc.expr as expr
from ksc.type import Type
from ksc.torch_frontend import ksc_string_to_autograd_function
import torch._vmap_internals
# BEGINDOC
def relu3(x: float) -> float:
@@ -107,7 +107,10 @@ def relu3_pytorch_nice(x: float) -> float:
        return x - 2 / 3
vrelu3_pytorch_nice = torch.vmap(relu3_pytorch_nice)
# With torch 1.9.0 this leads to
# RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
# See https://msrcambridge.visualstudio.com/Knossos/_backlogs/backlog/Knossos%20Team/Goals/?workitem=19587
# vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)
def vrelu3_cuda_init():
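
For context on the commented-out line above: under vmap, a Python `if` on a scalar comparison becomes a bool() conversion on a batched tensor, which dispatches to aten::is_nonzero, and torch 1.9 has no batching rule for it. A minimal sketch (hypothetical function, not part of this diff) that reproduces the error:

import torch
import torch._vmap_internals

def branchy(x):
    # Data-dependent control flow: bool(x < 0.0) dispatches to
    # aten::is_nonzero when x is a batched tensor.
    if x < 0.0:
        return 0.0 * x
    return x * x

v_branchy = torch._vmap_internals.vmap(branchy)
# v_branchy(torch.randn(8))
# -> RuntimeError: Batching rule not implemented for aten::is_nonzero.
#    We could not generate a fallback.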


@@ -87,31 +87,30 @@ def function_to_manual_cuda_benchmarks(func):
def functions_to_benchmark(mod, benchmark_name, example_inputs):
    for fn in inspect.getmembers(
        mod, lambda m: inspect.isfunction(m) and m.__name__.startswith(benchmark_name)
    ):
        fn_name, fn_obj = fn
        if fn_name == benchmark_name + "_bench_configs":
            continue
        elif fn_name == benchmark_name + "_pytorch":
            yield from function_to_torch_benchmarks(fn_obj)
        elif fn_name == benchmark_name + "_pytorch_nice":
            yield BenchmarkFunction("PyTorch Nice", fn_obj)
        elif fn_name == benchmark_name:
            ks_mod = tsmod2ksmod(mod, benchmark_name, example_inputs, generate_lm=False)
            yield BenchmarkFunction("Knossos", ks_mod.apply)
        elif fn_name == benchmark_name + "_cuda_init":
            if torch.cuda.is_available():
                yield from function_to_manual_cuda_benchmarks(fn_obj)
        elif fn_name.startswith(benchmark_name + "_ks_embedded_"):
            n = len(benchmark_name + "_ks_embedded_")
            benchmark_display_name = "Knossos embedded " + fn_name[n:]
            yield BenchmarkFunction(benchmark_display_name, fn_obj().apply)
        else:
            # perhaps we should just allow anything that matches the pattern?
            # would make it easier to add arbitrary comparisons e.g. TF
            print(f"Ignoring {fn_name}")
    for fn_name, fn_obj in inspect.getmembers(mod, lambda m: inspect.isfunction(m)):
        if fn_name.startswith(benchmark_name):
            if fn_name == benchmark_name + "_bench_configs":
                continue
            elif fn_name == benchmark_name + "_pytorch":
                yield from function_to_torch_benchmarks(fn_obj)
            elif fn_name == benchmark_name + "_pytorch_nice":
                yield BenchmarkFunction("PyTorch Nice", fn_obj)
            elif fn_name == benchmark_name:
                ks_mod = tsmod2ksmod(
                    mod, benchmark_name, example_inputs, generate_lm=False
                )
                yield BenchmarkFunction("Knossos", ks_mod.apply)
            elif fn_name == benchmark_name + "_cuda_init":
                if torch.cuda.is_available():
                    yield from function_to_manual_cuda_benchmarks(fn_obj)
            elif fn_name.startswith(benchmark_name + "_ks_embedded_"):
                n = len(benchmark_name + "_ks_embedded_")
                benchmark_display_name = "Knossos embedded " + fn_name[n:]
                yield BenchmarkFunction(benchmark_display_name, fn_obj().apply)
            else:
                # perhaps we should just allow anything that matches the pattern?
                # would make it easier to add arbitrary comparisons e.g. TF
                print(f"Ignoring {fn_name}")
def func_namer(benchmark_func):
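
Why the rewrite above helps: `inspect.getmembers` returns (identifier, value) pairs, and a wrapper value keeps its own `__name__`, so the old predicate on `m.__name__` could drop a valid candidate whose module attribute matches the benchmark prefix. A self-contained illustration (module and function names here are made up):

import inspect
import types

def _wrap(inner):
    def _call(*args, **kwargs):  # wrapper value: __name__ is "_call"
        return inner(*args, **kwargs)
    return _call

mod = types.ModuleType("fake_mod")
mod.vrelu3_pytorch = _wrap(lambda x: x)  # identifier is "vrelu3_pytorch"

# Filtering on __name__ (the old predicate) misses the wrapped function...
old = [name for name, m in inspect.getmembers(mod, inspect.isfunction)
       if m.__name__.startswith("vrelu3")]
# ...while filtering on the identifier (the new predicate) finds it.
new = [name for name, m in inspect.getmembers(mod, inspect.isfunction)
       if name.startswith("vrelu3")]
assert old == [] and new == ["vrelu3_pytorch"]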


@@ -7,6 +7,8 @@ import torch
@torch.jit.ignore
def elementwise_apply_pt18(f, x: torch.Tensor) -> torch.Tensor:
    # TODO: torch.vmap in 1.9
    # NOTE: torch.vmap still isn't stable in PyTorch 1.9; it can be called via the internal API torch._vmap_internals
    # https://github.com/pytorch/pytorch/issues/42368
    y = torch.zeros_like(x)
    sz = x.shape
    if len(sz) == 1:
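
For contrast, a hedged sketch of what this manual fallback does versus the internal vmap that could eventually replace it (the loop body is assumed, since the hunk is truncated here):

import torch
import torch._vmap_internals

def elementwise_apply_manual(f, x: torch.Tensor) -> torch.Tensor:
    # Explicit per-element loop: works with any scalar f, but slow.
    y = torch.zeros_like(x)
    for i in range(x.shape[0]):
        y[i] = f(x[i])
    return y

def elementwise_apply_batched(f, x: torch.Tensor) -> torch.Tensor:
    # Internal-API batching, pending a stable torch.vmap
    # (https://github.com/pytorch/pytorch/issues/42368).
    return torch._vmap_internals.vmap(f)(x)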