From 6cff51b0a530c7574f3f23f38a2e4494f82d0be3 Mon Sep 17 00:00:00 2001
From: Andrew Fitzgibbon
Date: Sat, 24 Jul 2021 13:38:12 +0100
Subject: [PATCH] Knossos.register, and delayed compilation (#960)

---
 .vscode/launch.json              |  29 ++++-
 examples/dl-activations/gelu.py  |   4 +
 examples/dl-activations/relu3.py |  10 ++
 examples/dl-capsule/sqrl.py      |   2 +
 src/bench/conftest.py            |  16 +--
 src/bench/run-bench.py           |  13 +-
 src/python/ksc/__init__.py       |   6 +-
 src/python/ksc/torch_frontend.py | 215 ++++++++++++++++++++++++-------
 src/python/ksc/torch_utils.py    |  10 +-
 test/ts2k/test_dl_activations.py |  12 +-
 test/ts2k/test_torch_frontend.py |  49 +++----
 11 files changed, 254 insertions(+), 112 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 0f260306..f7badf27 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -11,7 +11,7 @@
             "name": "(gdb) Launch python for relu3",
             "type": "cppdbg",
             "request": "launch",
-            "program": "/usr/bin/python",
+            "program": "/anaconda/envs/knossos/bin/python",
             "args": [
                 "src/bench/run-bench.py",
                 "examples/dl-activations/relu3",
@@ -30,6 +30,33 @@
                 }
             ]
         },
+        {
+            "name": "(gdb) pytest",
+            "type": "cppdbg",
+            "request": "launch",
+            "program": "/anaconda/envs/knossos/bin/python",
+            "args": [
+                "-m",
+                "pytest",
+                "src/bench/",
+                "--modulepath=examples/dl-capsule/sqrl",
+                "--benchmarkname=sqrl",
+            ],
+            "stopAtEntry": false,
+            "cwd": "${workspaceFolder}",
+            "environment": [
+                {"name":"PYTHONPATH", "value":"./src/python"}
+            ],
+            "externalConsole": false,
+            "MIMode": "gdb",
+            "setupCommands": [
+                {
+                    "description": "Enable pretty-printing for gdb",
+                    "text": "-enable-pretty-printing",
+                    "ignoreFailures": true
+                }
+            ]
+        },
         {
             "name": "Python: Current File",
             "type": "python",
diff --git a/examples/dl-activations/gelu.py b/examples/dl-activations/gelu.py
index d11f7a5c..cb5ed167 100644
--- a/examples/dl-activations/gelu.py
+++ b/examples/dl-activations/gelu.py
@@ -1,6 +1,8 @@
 from math import sqrt, tanh, erf, exp
 import torch
 
+import ksc.torch_frontend as knossos
+
 from ksc.torch_utils import elementwise_apply_hack
 import ksc.compile
 from ksc.torch_frontend import cpp_string_to_autograd_function
@@ -81,6 +83,7 @@ def sigmoid(x):
 
 
 # Gelu and activations
+@knossos.register
 def gelu(x: float) -> float:
     return 0.5 * x * (1.0 + erf(x / sqrt(2)))
 
@@ -124,6 +127,7 @@ def gelu_approx_tanh(x: float) -> float:
     return 0.5 * (1 + tanh(x * (C * x * x + B))) * x
 
 
+@knossos.register
 def vgelu(x: torch.Tensor):
     return elementwise_apply_hack("gelu", x)
 
diff --git a/examples/dl-activations/relu3.py b/examples/dl-activations/relu3.py
index 0545241b..22c23c1e 100644
--- a/examples/dl-activations/relu3.py
+++ b/examples/dl-activations/relu3.py
@@ -10,11 +10,13 @@ from ksc.torch_frontend import (
     ksc_string_to_autograd_function,
     cpp_string_to_autograd_function,
 )
+import ksc.torch_frontend as knossos
 from ksc.torch_utils import elementwise_apply_hack
 import torch._vmap_internals
 
 
 # BEGINDOC
+@knossos.register
 def relu3(x: float) -> float:
     """
     Like ReLu, but smoother
@@ -60,6 +62,7 @@ if False:
     vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)
 
 
 # run-bench: Knossos implementation
+@knossos.register
 def vrelu3(x: torch.Tensor):
     return elementwise_apply_hack("relu3", x)
 
@@ -606,3 +609,10 @@ def relu3_in_fcdnn():
 
     # Run training
     # train_model(model)
+
+
+if __name__ == "__main__":
+    y = relu3(0.3)
+    xs = next(vrelu3_bench_configs())
+    ys = vrelu3(xs)
+    print(ys.sum())
diff --git a/examples/dl-capsule/sqrl.py b/examples/dl-capsule/sqrl.py
index 78385eac..314d15ad 100644
--- a/examples/dl-capsule/sqrl.py
+++ b/examples/dl-capsule/sqrl.py
@@ -1,7 +1,9 @@
 import torch
+import ksc.torch_frontend as knossos
 
 # run-bench: Knossos source, and "nice" PyTorch implementation
 # BEGINDOC
+@knossos.register
 def sqrl(x: torch.Tensor):
     """
     sqrl: Squared Leaky Relu
diff --git a/src/bench/conftest.py b/src/bench/conftest.py
index 93ce3904..2190eb07 100644
--- a/src/bench/conftest.py
+++ b/src/bench/conftest.py
@@ -12,7 +12,7 @@ from collections import namedtuple
 from contextlib import contextmanager
 from typing import Callable
 
-from ksc.torch_frontend import tsmod2ksmod
+from ksc.torch_frontend import KscStub
 from ksc import utils
 
 
@@ -97,7 +97,7 @@ def function_to_manual_cuda_benchmarks(func):
 def functions_to_benchmark(
     mod, benchmark_name, example_inputs, torch_extension_name_base
 ):
-    for fn_name, fn_obj in inspect.getmembers(mod, lambda m: inspect.isfunction(m)):
+    for fn_name, fn_obj in inspect.getmembers(mod):
         if fn_name.startswith(benchmark_name):
             if fn_name == benchmark_name + "_bench_configs":
                 continue
@@ -106,17 +106,15 @@ def functions_to_benchmark(
         elif fn_name == benchmark_name + "_pytorch_nice":
             yield BenchmarkFunction("PyTorch Nice", fn_obj)
         elif fn_name == benchmark_name:
+            assert isinstance(fn_obj, KscStub)
             torch_extension_name = (
                 "ksc_src_bench_" + torch_extension_name_base + "_" + benchmark_name
             )
-            ks_mod = tsmod2ksmod(
-                mod,
-                benchmark_name,
-                torch_extension_name,
-                example_inputs,
-                generate_lm=False,
+            ks_compiled = fn_obj.compile(
+                torch_extension_name=torch_extension_name,
+                example_inputs=example_inputs,
             )
-            yield BenchmarkFunction("Knossos", ks_mod.apply)
+            yield BenchmarkFunction("Knossos", ks_compiled.apply)
         elif fn_name == benchmark_name + "_cuda_init":
             if torch.cuda.is_available():
                 yield from function_to_manual_cuda_benchmarks(fn_obj)
diff --git a/src/bench/run-bench.py b/src/bench/run-bench.py
index fb9ee0fe..a6a2fae6 100644
--- a/src/bench/run-bench.py
+++ b/src/bench/run-bench.py
@@ -1,10 +1,6 @@
 import time
-from ksc import torch_frontend
 import torch
 
-import ksc.torch_frontend
-from ksc.torch_frontend import tsmod2ksmod
-
 
 class time_sampler:
     def __init__(self, minimizing=False):
@@ -120,14 +116,9 @@ def bench(module_file, bench_name):
         else:
             print(f"Ignoring {fn_name}")
 
-    # TODO: elementwise_apply
-    torch_extension_name = "ksc_run_bench_" + bench_name
-    ks_compiled = tsmod2ksmod(
-        mod,
-        bench_name,
-        torch_extension_name,
+    ks_compiled = ks_raw.compile(
+        torch_extension_name="ksc_run_bench_" + bench_name,
         example_inputs=(configs[0],),
-        generate_lm=False,
     )
 
     for arg in configs:
diff --git a/src/python/ksc/__init__.py b/src/python/ksc/__init__.py
index bcc6b733..f0799494 100644
--- a/src/python/ksc/__init__.py
+++ b/src/python/ksc/__init__.py
@@ -1,3 +1,7 @@
 from ksc.tracing.jitting import trace
 from ksc.ks_function import KsFunction
-from ksc.torch_frontend import ts2ks, ts2ks_fromgraph, ts2mod, tsmod2ksmod
+from ksc.torch_frontend import (
+    ts2ks,
+    ts2ks_fromgraph,
+    register,
+)
diff --git a/src/python/ksc/torch_frontend.py b/src/python/ksc/torch_frontend.py
index 4cfb415c..59bff872 100644
--- a/src/python/ksc/torch_frontend.py
+++ b/src/python/ksc/torch_frontend.py
@@ -1,8 +1,10 @@
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Optional
+from types import ModuleType
+from dataclasses import dataclass
 from contextlib import contextmanager
-
 import functools
 import numpy
+import inspect
 
 import torch
 import torch.onnx
@@ -464,6 +466,10 @@ def backward_template(py_mod, ctx, *args):
     return torch_from_ks(outputs)
 
 
+class KscAutogradFunction(torch.autograd.Function):
+    pass
+
+
 def make_KscAutogradFunction(py_mod):
     # We need to make a new class for every py_mod, as PyTorch requires forward and backward to be
     # staticmethods. This is not too expensive, as each mod needs to be compiled anyway.
@@ -471,7 +477,7 @@ def make_KscAutogradFunction(py_mod):
     backward = lambda ctx, args: backward_template(py_mod, ctx, args)
     return type(
         "KscAutogradFunction_" + py_mod.__name__,
-        (torch.autograd.Function,),
+        (KscAutogradFunction,),
         {
             "py_mod": py_mod,
             "forward": staticmethod(forward),
@@ -602,53 +608,6 @@ def cpp_string_to_autograd_function(
     return make_KscAutogradFunction(mod)
 
 
-import inspect
-
-
-def tsmod2ksmod(
-    module, function_name, torch_extension_name, example_inputs, generate_lm=True
-):
-    global todo_stack
-    todo_stack = {function_name}
-    ksc_defs = []
-    while len(todo_stack) > 0:
-        print(f"tsmod2ksmod: Remaining: {todo_stack}")
-        for fn in inspect.getmembers(module, inspect.isfunction):
-            fn_name, fn_obj = fn
-            if fn_name in todo_stack:
-                todo_stack.remove(fn_name)
-                print(f"tsmod2ksmod: converting {fn_name}, remaining: {todo_stack}")
-                ts_fn = torch.jit.script(fn_obj)
-                ts_graph = ts_fn.graph
-                ksc_def = ts2ks_fromgraph(False, fn_name, ts_graph, example_inputs)
-                ksc_defs.insert(0, ksc_def)
-
-    elementwise = is_elementwise_operation(ksc_defs[-1])
-    if elementwise:
-        ksc_defs.pop()
-
-    entry_def = ksc_defs[-1]
-    return ksc_defs_to_autograd_function(
-        ksc_defs,
-        entry_def,
-        torch_extension_name,
-        elementwise=elementwise,
-        generate_lm=generate_lm,
-    )
-
-
-def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
-    fn = torch.jit.script(function)
-    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
-    return ksc_defs_to_autograd_function(
-        [ksc_def],
-        ksc_def,
-        torch_extension_name,
-        elementwise=False,
-        generate_lm=generate_lm,
-    )
-
-
 def is_elementwise_operation(ksc_def):
     """
     Inspect the body of a def to determine whether it is a
@@ -691,3 +650,159 @@ def is_elementwise_operation(ksc_def):
         print(f"Num args {len(ksc_def.args)}")
         return False
     return is_map(ksc_def.body, ksc_def.args[0].name)
+
+
+def _tsmod2ksmod(
+    module, function_obj, torch_extension_name, example_inputs, generate_lm=True
+):
+    global todo_stack
+    todo_stack = {function_obj}
+    ksc_defs = []
+    while len(todo_stack) > 0:
+        print(f"tsmod2ksmod: Remaining: {todo_stack}")
+        todo = next(iter(todo_stack))
+        if isinstance(todo, str):
+            # String function name, try to find it in the caller's module
+            todo_fn = None
+            for module_fn_name, module_fn_obj in inspect.getmembers(module):
+                if module_fn_name == todo:
+                    print(f"tsmod2ksmod: converting {todo}, remaining: {todo_stack}")
+                    if isinstance(module_fn_obj, KscStub):
+                        todo_fn = module_fn_obj.raw_f
+                    else:
+                        todo_fn = module_fn_obj
+                    break
+            # Check we found it
+            if not todo_fn:
+                raise ValueError(f"Did not find string-named function {todo}")
+        else:
+            todo_fn = todo
+
+        todo_stack.remove(todo)
+
+        ts_fn = torch.jit.script(todo_fn)
+        ts_graph = ts_fn.graph
+        ksc_def = ts2ks_fromgraph(False, todo_fn.__name__, ts_graph, example_inputs)
+        ksc_defs.insert(0, ksc_def)
+
+    elementwise = is_elementwise_operation(ksc_defs[-1])
+    if elementwise:
+        ksc_defs.pop()
+
+    entry_def = ksc_defs[-1]
+    return ksc_defs_to_autograd_function(
+        ksc_defs,
+        entry_def,
+        torch_extension_name,
+        elementwise=elementwise,
+        generate_lm=generate_lm,
+    )
+
+
+def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
+    fn = torch.jit.script(function)
+    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
+    return ksc_defs_to_autograd_function(
+        [ksc_def],
+        ksc_def,
+        torch_extension_name,
+        elementwise=False,
+        generate_lm=generate_lm,
+    )
+
+
+@dataclass
+class KscStub:
+    raw_f: Callable
+    generate_lm: bool
+    f_module: ModuleType
+    compiled: Optional[KscAutogradFunction]
+
+    def __call__(self, *args):
+        """
+        Call with pytorch tensors.
+        This calls the KscAutogradFunction apply method, so is suitable
+        for use in the "forward/backward" pattern for gradient computation.
+        """
+        self.ensure_compiled(args)
+        return self.compiled.apply(*args)
+
+    def _entry(self, *args):
+        """
+        Directly call the Knossos compiled function.
+        Does not wrap torch tensors, or reset the memory allocator.
+        For test use only.
+        """
+        self.ensure_compiled(args)
+        return self.compiled.py_mod.entry(*args)
+
+    def _entry_vjp(self, *args):
+        """
+        Directly call the Knossos vjp function.
+        Does not wrap torch tensors, or reset the memory allocator.
+        For test use only.
+        """
+        assert self.compiled  # TODO: infer call args from vjp args
+        return self.compiled.py_mod.entry_vjp(*args)
+
+    def compile(self, example_inputs, torch_extension_name):
+        self.compiled = _tsmod2ksmod(
+            self.f_module,
+            self.raw_f,
+            torch_extension_name=torch_extension_name,
+            example_inputs=example_inputs,
+            generate_lm=self.generate_lm,
+        )
+
+        return self.compiled
+
+    def ensure_compiled(self, example_inputs):
+        if not self.compiled:
+            print(f"knossos.register: Compiling {self.raw_f.__name__}")
+            torch_extension_name = (
+                "KscStub_" + self.f_module.__name__ + "_" + self.raw_f.__name__
+            )
+            self.compile(example_inputs, torch_extension_name)
+
+
+def optional_arg_decorator(register):
+    # https://stackoverflow.com/a/20966822
+    def wrapped_decorator(*args, **kwargs):
+        # Grab the caller's module here, as wrapped_decorator may be 1 or 2 frames deeper
+        module = inspect.getmodule(inspect.currentframe().f_back)
+
+        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+            return register(args[0], module)
+
+        # we have optional args
+        def real_decorator(f):
+            return register(f, module, *args, **kwargs)
+
+        return real_decorator
+
+    return wrapped_decorator
+
+
+@optional_arg_decorator
+def register(f: Callable, module: ModuleType, generate_lm=False) -> KscStub:
+    """
+    Main Knossos entry point.
+
+    The @register decorator transforms a TorchScript-compatible function
+    into a KscAutogradFunction which implements the function and its
+    derivatives.
+    ```
+    @knossos.register
+    def f(x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sin(x)
+    ```
+    Endows f with the following behaviours:
+    ```
+    y = f(x)       # Fast (C++/CUDA/...) computation of f(x)
+    vjp(f, x, dy)  # Fast computation of dot(dy, [df_i/dx_j])
+    ```
+    The implementation delays compilation until the first call, or
+    until `f.compile()` is explicitly called.
+ """ + + return KscStub(f, generate_lm, module, None) diff --git a/src/python/ksc/torch_utils.py b/src/python/ksc/torch_utils.py index 8b2f4f2f..f33ae517 100644 --- a/src/python/ksc/torch_utils.py +++ b/src/python/ksc/torch_utils.py @@ -39,4 +39,12 @@ def elementwise_apply(f: Callable[[float], float], x: torch.Tensor): @torch.jit.ignore def elementwise_apply_hack(f: str, x: torch.Tensor): - pass + # Convert string function name to callable + import inspect + + module = inspect.getmodule(inspect.currentframe().f_back) + for fn_name, fn_obj in inspect.getmembers(module): + if fn_name == f: + return elementwise_apply_pt18(fn_obj, x) + + assert False diff --git a/test/ts2k/test_dl_activations.py b/test/ts2k/test_dl_activations.py index 00ad9bd5..877013fe 100644 --- a/test/ts2k/test_dl_activations.py +++ b/test/ts2k/test_dl_activations.py @@ -5,8 +5,6 @@ import torch import inspect import importlib -from ksc.torch_frontend import tsmod2ksmod - @pytest.mark.parametrize( "module_file,bench_name", @@ -22,21 +20,19 @@ def test_bench(module_file, bench_name): module_dir, module_name = os.path.split(module_file) sys.path.append(module_dir) mod = importlib.import_module(module_name) - for fn in inspect.getmembers(mod, inspect.isfunction): - fn_name, fn_obj = fn + for fn_name, fn_obj in inspect.getmembers(mod): if fn_name == bench_name + "_bench_configs": configs = list(fn_obj()) elif fn_name == bench_name + "_pytorch": pt_fast = fn_obj + elif fn_name == bench_name: + ks_raw = fn_obj else: print(f"Ignoring {fn_name}") arg = configs[0] - torch_extension_name = "ksc_test_dl_activations_" + bench_name - ks_compiled = tsmod2ksmod( - mod, bench_name, torch_extension_name, example_inputs=(arg,), generate_lm=False - ) + ks_compiled = ks_raw.compile((arg,), "ksc_test_dl_activations_" + bench_name) pt_arg = arg.detach() pt_arg.requires_grad = True diff --git a/test/ts2k/test_torch_frontend.py b/test/ts2k/test_torch_frontend.py index 85cbcc16..1020afa5 100644 --- a/test/ts2k/test_torch_frontend.py +++ b/test/ts2k/test_torch_frontend.py @@ -6,6 +6,7 @@ import numpy from ksc import utils from ksc.type import Type +import ksc.torch_frontend as knossos from ksc.torch_frontend import ts2mod @@ -30,6 +31,7 @@ def grad_bar1(a: int, x: float, b: str): return torch.sin(t) + t * torch.cos(t) +@knossos.register def relux(x: float): if x < 0.0: return 0.1 * x @@ -53,31 +55,19 @@ def f(x: float): return r2 -ks_relux = None - - -def compile_relux(): - global ks_relux - if ks_relux is None: - print("Compiling relux") - torch_extension_name = "ksc_test_ts2k_relux" - ks_relux = ts2mod(relux, (1.0,), torch_extension_name) - - -def test_relux(): - compile_relux() - ks_ans = ks_relux.py_mod.entry(2.0) - ans = relux(2.0) +def test_ts2k_relux(): + ks_ans = relux._entry(2.0) + ans = relux.raw_f(2.0) assert pytest.approx(ks_ans, 1e-6) == ans -def test_relux_grad(): - compile_relux() - ks_ans = ks_relux.py_mod.entry_vjp(1.3, 1.0) +def test_ts2k_relux_grad(): + ks_ans = relux._entry_vjp(1.3, 1.0) ans = grad_relux(1.3) assert pytest.approx(ks_ans, 1e-6) == ans +@knossos.register(generate_lm=True) def bar(a: int, x: float): y = torch.tensor([[1.1, -1.2], [2.1, 2.2]]) @@ -107,20 +97,19 @@ def grad_bar(a: int, x: float): def test_bar(): a, x = 1, 12.34 - torch_extension_name = "ksc_test_ts2k_bar" - ks_bar = ts2mod(bar, (a, x), torch_extension_name) # Check primal - ks_ans = ks_bar.py_mod.entry(a, x) - ans = bar(a, x) + ks_ans = bar._entry(a, x) + ans = bar.raw_f(a, x) assert pytest.approx(ks_ans, 1e-5) == ans # Check grad - ks_ans = 
+    ks_ans = bar._entry_vjp((a, x), 1.0)
     ans = grad_bar(a, x)
     assert pytest.approx(ks_ans[1], 1e-5) == ans[1]
 
 
+@knossos.register(generate_lm=True)
 def far(x: torch.Tensor, y: torch.Tensor):
     xx = torch.cat([x, y], dim=1)
     xbar = torch.mean(xx)
@@ -134,25 +123,23 @@ def far(x: torch.Tensor, y: torch.Tensor):
 def test_far():
     x = torch.randn(2, 3)
     y = torch.randn(2, 5)
-    torch_extension_name = "ksc_test_ts2k_far"
-    ks_far = ts2mod(far, (x, y), torch_extension_name)
 
-    ks_ans = ks_far.py_mod.entry(ks_far.adapt(x), ks_far.adapt(y))
-    ans = far(x, y)
+    ks_ans = far._entry(x, y)
+    ans = far.raw_f(x, y)
     assert pytest.approx(ks_ans, 1e-5) == ans.item()
 
 
 def test_cat():
+    @knossos.register(generate_lm=True)
     def f(x: torch.Tensor, y: torch.Tensor):
         return torch.cat([x, y], dim=1)
 
     x = torch.randn(2, 3)
     y = torch.randn(2, 5)
-    torch_extension_name = "ksc_test_ts2k_cat"
-    ks_f = ts2mod(f, (x, y), torch_extension_name)
-    ks_ans = ks_f.py_mod.entry(ks_f.adapt(x), ks_f.adapt(y))
+
+    ks_ans = f._entry(x, y)
     ks_ans_np = numpy.array(ks_ans, copy=True)
-    py_ans = f(x, y)
+    py_ans = f.raw_f(x, y)
     assert (ks_ans_np == py_ans.numpy()).all()  # non-approx
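
For context, a minimal sketch of how the `knossos.register` entry point introduced by this patch is used, assembled from the docstring and tests above. The function names `scale_sin` and `cube`, the input shape, and the final `allclose` check are hypothetical illustrations, not part of the patch:

```python
import torch
import ksc.torch_frontend as knossos


# Decorating returns a KscStub; nothing is compiled at definition time.
@knossos.register
def scale_sin(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sin(x)


# Optional arguments go through the same decorator, e.g. generate_lm:
@knossos.register(generate_lm=True)
def cube(x: float) -> float:
    return x * x * x


x = torch.randn(4)

# First call: KscStub.ensure_compiled builds the torch extension, then
# dispatches to the compiled KscAutogradFunction via .apply.
y = scale_sin(x)

# Subsequent calls reuse the cached self.compiled module; no recompilation.
y2 = scale_sin(x)

# The undecorated Python function remains available for reference checks.
assert torch.allclose(y, scale_sin.raw_f(x))
```

Compilation can also be forced ahead of the first call with `scale_sin.compile(example_inputs=(x,), torch_extension_name=...)`, which is what the benchmark harness in src/bench/conftest.py and the tests in test/ts2k do to control the extension name.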