Knossos.register, and delayed compilation (#960)

Andrew Fitzgibbon 2021-07-24 13:38:12 +01:00 committed by GitHub
Parent 0daf53e905
Commit 6cff51b0a5
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 254 additions and 112 deletions
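The key API change: functions are now decorated with `knossos.register`, which returns a `KscStub` whose compilation is deferred until the first call or an explicit `compile()`. A minimal usage sketch follows; it is illustrative only and not part of this diff, and the function body and extension name are made up:

```python
import torch
import ksc.torch_frontend as knossos

@knossos.register
def scale2(x: torch.Tensor):
    # Hypothetical body; whether it translates depends on ts2ks support
    return x * 2.0

x = torch.randn(4)

# First call compiles (via KscStub.ensure_compiled) and then dispatches
# to the generated KscAutogradFunction.apply
y = scale2(x)

# Compilation can also be requested explicitly, as the benchmark harness does
scale2.compile(example_inputs=(x,), torch_extension_name="ksc_example_scale2")
```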

29 .vscode/launch.json vendored
View file

@@ -11,7 +11,7 @@
"name": "(gdb) Launch python for relu3",
"type": "cppdbg",
"request": "launch",
"program": "/usr/bin/python",
"program": "/anaconda/envs/knossos/bin/python",
"args": [
"src/bench/run-bench.py",
"examples/dl-activations/relu3",
@@ -30,6 +30,33 @@
}
]
},
{
"name": "(gdb) pytest",
"type": "cppdbg",
"request": "launch",
"program": "/anaconda/envs/knossos/bin/python",
"args": [
"-m",
"pytest",
"src/bench/",
"--modulepath=examples/dl-capsule/sqrl",
"--benchmarkname=sqrl",
],
"stopAtEntry": false,
"cwd": "${workspaceFolder}",
"environment": [
{"name":"PYTHONPATH", "value":"./src/python"}
],
"externalConsole": false,
"MIMode": "gdb",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"text": "-enable-pretty-printing",
"ignoreFailures": true
}
]
},
{
"name": "Python: Current File",
"type": "python",

View file

@@ -1,6 +1,8 @@
from math import sqrt, tanh, erf, exp
import torch
import ksc.torch_frontend as knossos
from ksc.torch_utils import elementwise_apply_hack
import ksc.compile
from ksc.torch_frontend import cpp_string_to_autograd_function
@@ -81,6 +83,7 @@ def sigmoid(x):
# Gelu and activations
@knossos.register
def gelu(x: float) -> float:
return 0.5 * x * (1.0 + erf(x / sqrt(2)))
@@ -124,6 +127,7 @@ def gelu_approx_tanh(x: float) -> float:
return 0.5 * (1 + tanh(x * (C * x * x + B))) * x
@knossos.register
def vgelu(x: torch.Tensor):
return elementwise_apply_hack("gelu", x)

View file

@@ -10,11 +10,13 @@ from ksc.torch_frontend import (
ksc_string_to_autograd_function,
cpp_string_to_autograd_function,
)
import ksc.torch_frontend as knossos
from ksc.torch_utils import elementwise_apply_hack
import torch._vmap_internals
# BEGINDOC
@knossos.register
def relu3(x: float) -> float:
"""
Like ReLu, but smoother
@@ -60,6 +62,7 @@ if False:
vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)
# run-bench: Knossos implementation
@knossos.register
def vrelu3(x: torch.Tensor):
return elementwise_apply_hack("relu3", x)
@@ -606,3 +609,10 @@ def relu3_in_fcdnn():
# Run training
# train_model(model)
if __name__ == "__main__":
y = relu3(0.3)
xs = next(vrelu3_bench_configs())
ys = vrelu3(xs)
print(ys.sum())

View file

@@ -1,7 +1,9 @@
import torch
import ksc.torch_frontend as knossos
# run-bench: Knossos source, and "nice" PyTorch implementation
# BEGINDOC
@knossos.register
def sqrl(x: torch.Tensor):
"""
sqrl: Squared Leaky Relu

View file

@@ -12,7 +12,7 @@ from collections import namedtuple
from contextlib import contextmanager
from typing import Callable
from ksc.torch_frontend import tsmod2ksmod
from ksc.torch_frontend import KscStub
from ksc import utils
@@ -97,7 +97,7 @@ def function_to_manual_cuda_benchmarks(func):
def functions_to_benchmark(
mod, benchmark_name, example_inputs, torch_extension_name_base
):
for fn_name, fn_obj in inspect.getmembers(mod, lambda m: inspect.isfunction(m)):
for fn_name, fn_obj in inspect.getmembers(mod):
if fn_name.startswith(benchmark_name):
if fn_name == benchmark_name + "_bench_configs":
continue
@@ -106,17 +106,15 @@ def functions_to_benchmark(
elif fn_name == benchmark_name + "_pytorch_nice":
yield BenchmarkFunction("PyTorch Nice", fn_obj)
elif fn_name == benchmark_name:
assert isinstance(fn_obj, KscStub)
torch_extension_name = (
"ksc_src_bench_" + torch_extension_name_base + "_" + benchmark_name
)
ks_mod = tsmod2ksmod(
mod,
benchmark_name,
torch_extension_name,
example_inputs,
generate_lm=False,
ks_compiled = fn_obj.compile(
torch_extension_name=torch_extension_name,
example_inputs=example_inputs,
)
yield BenchmarkFunction("Knossos", ks_mod.apply)
yield BenchmarkFunction("Knossos", ks_compiled.apply)
elif fn_name == benchmark_name + "_cuda_init":
if torch.cuda.is_available():
yield from function_to_manual_cuda_benchmarks(fn_obj)
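For reference, the harness above discovers benchmark entries by naming convention: `<name>_bench_configs`, `<name>_pytorch`, `<name>_pytorch_nice`, the bare `<name>` (now required to be a `KscStub`), and `<name>_cuda_init`. Below is a sketch of a minimal benchmark module under those assumptions; the module name, function names, and bodies are illustrative, not from this commit:

```python
# hypothetical module passed via --modulepath
import torch
import ksc.torch_frontend as knossos

def square_bench_configs():
    # One example input per yielded element
    yield torch.randn(16)

def square_pytorch(x: torch.Tensor):
    # Reference PyTorch implementation
    return x * x

@knossos.register
def square(x: torch.Tensor):
    # KscStub entry point; the harness calls square.compile(...) and
    # benchmarks the resulting .apply as the Knossos implementation
    return x * x
```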

View file

@@ -1,10 +1,6 @@
import time
from ksc import torch_frontend
import torch
import ksc.torch_frontend
from ksc.torch_frontend import tsmod2ksmod
class time_sampler:
def __init__(self, minimizing=False):
@@ -120,14 +116,9 @@ def bench(module_file, bench_name):
else:
print(f"Ignoring {fn_name}")
# TODO: elementwise_apply
torch_extension_name = "ksc_run_bench_" + bench_name
ks_compiled = tsmod2ksmod(
mod,
bench_name,
torch_extension_name,
ks_compiled = ks_raw.compile(
torch_extension_name="ksc_run_bench_" + bench_name,
example_inputs=(configs[0],),
generate_lm=False,
)
for arg in configs:

View file

@@ -1,3 +1,7 @@
from ksc.tracing.jitting import trace
from ksc.ks_function import KsFunction
from ksc.torch_frontend import ts2ks, ts2ks_fromgraph, ts2mod, tsmod2ksmod
from ksc.torch_frontend import (
ts2ks,
ts2ks_fromgraph,
register,
)

View file

@@ -1,8 +1,10 @@
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Optional
from types import ModuleType
from dataclasses import dataclass
from contextlib import contextmanager
import functools
import numpy
import inspect
import torch
import torch.onnx
@@ -464,6 +466,10 @@ def backward_template(py_mod, ctx, *args):
return torch_from_ks(outputs)
class KscAutogradFunction(torch.autograd.Function):
pass
def make_KscAutogradFunction(py_mod):
# We need to make a new class for every py_mod, as PyTorch requires forward and backward to be
# staticmethods. This is not too expensive, as each mod needs to be compiled anyway.
@@ -471,7 +477,7 @@ def make_KscAutogradFunction(py_mod):
backward = lambda ctx, args: backward_template(py_mod, ctx, args)
return type(
"KscAutogradFunction_" + py_mod.__name__,
(torch.autograd.Function,),
(KscAutogradFunction,),
{
"py_mod": py_mod,
"forward": staticmethod(forward),
@@ -602,53 +608,6 @@ def cpp_string_to_autograd_function(
return make_KscAutogradFunction(mod)
import inspect
def tsmod2ksmod(
module, function_name, torch_extension_name, example_inputs, generate_lm=True
):
global todo_stack
todo_stack = {function_name}
ksc_defs = []
while len(todo_stack) > 0:
print(f"tsmod2ksmod: Remaining: {todo_stack}")
for fn in inspect.getmembers(module, inspect.isfunction):
fn_name, fn_obj = fn
if fn_name in todo_stack:
todo_stack.remove(fn_name)
print(f"tsmod2ksmod: converting {fn_name}, remaining: {todo_stack}")
ts_fn = torch.jit.script(fn_obj)
ts_graph = ts_fn.graph
ksc_def = ts2ks_fromgraph(False, fn_name, ts_graph, example_inputs)
ksc_defs.insert(0, ksc_def)
elementwise = is_elementwise_operation(ksc_defs[-1])
if elementwise:
ksc_defs.pop()
entry_def = ksc_defs[-1]
return ksc_defs_to_autograd_function(
ksc_defs,
entry_def,
torch_extension_name,
elementwise=elementwise,
generate_lm=generate_lm,
)
def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
fn = torch.jit.script(function)
ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
return ksc_defs_to_autograd_function(
[ksc_def],
ksc_def,
torch_extension_name,
elementwise=False,
generate_lm=generate_lm,
)
def is_elementwise_operation(ksc_def):
"""
Inspect the body of a def to determine whether it is a
@@ -691,3 +650,159 @@ def is_elementwise_operation(ksc_def):
print(f"Num args {len(ksc_def.args)}")
return False
return is_map(ksc_def.body, ksc_def.args[0].name)
def _tsmod2ksmod(
module, function_obj, torch_extension_name, example_inputs, generate_lm=True
):
global todo_stack
todo_stack = {function_obj}
ksc_defs = []
while len(todo_stack) > 0:
print(f"tsmod2ksmod: Remaining: {todo_stack}")
todo = next(iter(todo_stack))
if isinstance(todo, str):
# String function name, try to find it in the caller's module
todo_fn = None
for module_fn_name, module_fn_obj in inspect.getmembers(module):
if module_fn_name == todo:
print(f"tsmod2ksmod: converting {todo}, remaining: {todo_stack}")
if isinstance(module_fn_obj, KscStub):
todo_fn = module_fn_obj.raw_f
else:
todo_fn = module_fn_obj
break
# Check we found it
if not todo_fn:
raise ValueError(f"Did not find string-named function {todo}")
else:
todo_fn = todo
todo_stack.remove(todo)
ts_fn = torch.jit.script(todo_fn)
ts_graph = ts_fn.graph
ksc_def = ts2ks_fromgraph(False, todo_fn.__name__, ts_graph, example_inputs)
ksc_defs.insert(0, ksc_def)
elementwise = is_elementwise_operation(ksc_defs[-1])
if elementwise:
ksc_defs.pop()
entry_def = ksc_defs[-1]
return ksc_defs_to_autograd_function(
ksc_defs,
entry_def,
torch_extension_name,
elementwise=elementwise,
generate_lm=generate_lm,
)
def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
fn = torch.jit.script(function)
ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
return ksc_defs_to_autograd_function(
[ksc_def],
ksc_def,
torch_extension_name,
elementwise=False,
generate_lm=generate_lm,
)
@dataclass
class KscStub:
raw_f: Callable
generate_lm: bool
f_module: ModuleType
compiled: Optional[KscAutogradFunction]
def __call__(self, *args):
"""
Call with pytorch tensors.
This calls the KscAutoGradFunction apply method, so is suitable
for use in the "forward/backward" pattern for gradient computation.
"""
self.ensure_compiled(args)
return self.compiled.apply(*args)
def _entry(self, *args):
"""
Directly call the Knossos compiled function.
Does not wrap torch tensors, or reset memory allocator.
For test use only
"""
self.ensure_compiled(args)
return self.compiled.py_mod.entry(*args)
def _entry_vjp(self, *args):
"""
Directly call the Knossos vjp function.
Does not wrap torch tensors, or reset memory allocator.
For test use only
"""
assert self.compiled # TODO: infer call args from vjp args
return self.compiled.py_mod.entry_vjp(*args)
def compile(self, example_inputs, torch_extension_name):
self.compiled = _tsmod2ksmod(
self.f_module,
self.raw_f,
torch_extension_name=torch_extension_name,
example_inputs=example_inputs,
generate_lm=self.generate_lm,
)
return self.compiled
def ensure_compiled(self, example_inputs):
if not self.compiled:
print(f"knossos.register: Compiling {self.raw_f.__name__}")
torch_extension_name = (
"KscStub_" + self.f_module.__name__ + "_" + self.raw_f.__name__
)
self.compile(example_inputs, torch_extension_name)
def optional_arg_decorator(register):
# https://stackoverflow.com/a/20966822
def wrapped_decorator(*args, **kwargs):
# Grab the caller's module here, as wrapped_decorator may be 1 or 2 deeper
module = inspect.getmodule(inspect.currentframe().f_back)
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return register(args[0], module)
# we have optional args
def real_decorator(f):
return register(f, module, *args, **kwargs)
return real_decorator
return wrapped_decorator
@optional_arg_decorator
def register(f: Callable, module: ModuleType, generate_lm=False) -> KscStub:
"""
Main Knossos entry point.
The @register decorator transforms a TorchScript function into a
KscAutogradFunction which implements the function and its
derivatives.
```
@knossos.register
def f(x : torch.Tensor) -> torch.Tensor:
return x * sin(x)
```
Endows f with the following behaviours
```
y = f(x) # Fast (C++/CUDA/...) computation of f(x)
vjp(f, x, dy) # Fast computation of dot(dy, [df_i/dx_j])
```
The implementation delays compilation until the first call, or
when "f.compile()" is explicitly called.
"""
return KscStub(f, generate_lm, module, None)
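The tests below exercise the stub's surface directly: `f(...)` compiles on demand and goes through autograd, `f.raw_f(...)` calls the original Python function, and `f._entry(...)` / `f._entry_vjp(...)` hit the compiled Knossos entry points without tensor wrapping. A small sketch of that pattern, with a hypothetical function and illustrative values:

```python
@knossos.register
def relux_like(x: float):
    # Hypothetical leaky-ReLU-style function
    return 0.1 * x if x < 0.0 else x * x

ks_ans = relux_like._entry(2.0)   # compiled Knossos primal, no tensor wrapping
py_ans = relux_like.raw_f(2.0)    # original Python function
assert abs(ks_ans - py_ans) < 1e-6

dx = relux_like._entry_vjp(1.3, 1.0)  # compiled vjp; assumes compilation already happened
```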

View file

@@ -39,4 +39,12 @@ def elementwise_apply(f: Callable[[float], float], x: torch.Tensor):
@torch.jit.ignore
def elementwise_apply_hack(f: str, x: torch.Tensor):
pass
# Convert string function name to callable
import inspect
module = inspect.getmodule(inspect.currentframe().f_back)
for fn_name, fn_obj in inspect.getmembers(module):
if fn_name == f:
return elementwise_apply_pt18(fn_obj, x)
assert False

View file

@@ -5,8 +5,6 @@ import torch
import inspect
import importlib
from ksc.torch_frontend import tsmod2ksmod
@pytest.mark.parametrize(
"module_file,bench_name",
@@ -22,21 +20,19 @@ def test_bench(module_file, bench_name):
module_dir, module_name = os.path.split(module_file)
sys.path.append(module_dir)
mod = importlib.import_module(module_name)
for fn in inspect.getmembers(mod, inspect.isfunction):
fn_name, fn_obj = fn
for fn_name, fn_obj in inspect.getmembers(mod):
if fn_name == bench_name + "_bench_configs":
configs = list(fn_obj())
elif fn_name == bench_name + "_pytorch":
pt_fast = fn_obj
elif fn_name == bench_name:
ks_raw = fn_obj
else:
print(f"Ignoring {fn_name}")
arg = configs[0]
torch_extension_name = "ksc_test_dl_activations_" + bench_name
ks_compiled = tsmod2ksmod(
mod, bench_name, torch_extension_name, example_inputs=(arg,), generate_lm=False
)
ks_compiled = ks_raw.compile((arg,), "ksc_test_dl_activations_" + bench_name)
pt_arg = arg.detach()
pt_arg.requires_grad = True

View file

@@ -6,6 +6,7 @@ import numpy
from ksc import utils
from ksc.type import Type
import ksc.torch_frontend as knossos
from ksc.torch_frontend import ts2mod
@@ -30,6 +31,7 @@ def grad_bar1(a: int, x: float, b: str):
return torch.sin(t) + t * torch.cos(t)
@knossos.register
def relux(x: float):
if x < 0.0:
return 0.1 * x
@@ -53,31 +55,19 @@ def f(x: float):
return r2
ks_relux = None
def compile_relux():
global ks_relux
if ks_relux is None:
print("Compiling relux")
torch_extension_name = "ksc_test_ts2k_relux"
ks_relux = ts2mod(relux, (1.0,), torch_extension_name)
def test_relux():
compile_relux()
ks_ans = ks_relux.py_mod.entry(2.0)
ans = relux(2.0)
def test_ts2k_relux():
ks_ans = relux._entry(2.0)
ans = relux.raw_f(2.0)
assert pytest.approx(ks_ans, 1e-6) == ans
def test_relux_grad():
compile_relux()
ks_ans = ks_relux.py_mod.entry_vjp(1.3, 1.0)
def test_ts2k_relux_grad():
ks_ans = relux._entry_vjp(1.3, 1.0)
ans = grad_relux(1.3)
assert pytest.approx(ks_ans, 1e-6) == ans
@knossos.register(generate_lm=True)
def bar(a: int, x: float):
y = torch.tensor([[1.1, -1.2], [2.1, 2.2]])
@@ -107,20 +97,19 @@ def grad_bar(a: int, x: float):
def test_bar():
a, x = 1, 12.34
torch_extension_name = "ksc_test_ts2k_bar"
ks_bar = ts2mod(bar, (a, x), torch_extension_name)
# Check primal
ks_ans = ks_bar.py_mod.entry(a, x)
ans = bar(a, x)
ks_ans = bar._entry(a, x)
ans = bar.raw_f(a, x)
assert pytest.approx(ks_ans, 1e-5) == ans
# Check grad
ks_ans = ks_bar.py_mod.entry_vjp((a, x), 1.0)
ks_ans = bar._entry_vjp((a, x), 1.0)
ans = grad_bar(a, x)
assert pytest.approx(ks_ans[1], 1e-5) == ans[1]
@knossos.register(generate_lm=True)
def far(x: torch.Tensor, y: torch.Tensor):
xx = torch.cat([x, y], dim=1)
xbar = torch.mean(xx)
@@ -134,25 +123,23 @@ def far(x: torch.Tensor, y: torch.Tensor):
def test_far():
x = torch.randn(2, 3)
y = torch.randn(2, 5)
torch_extension_name = "ksc_test_ts2k_far"
ks_far = ts2mod(far, (x, y), torch_extension_name)
ks_ans = ks_far.py_mod.entry(ks_far.adapt(x), ks_far.adapt(y))
ans = far(x, y)
ks_ans = far._entry(x, y)
ans = far.raw_f(x, y)
assert pytest.approx(ks_ans, 1e-5) == ans.item()
def test_cat():
@knossos.register(generate_lm=True)
def f(x: torch.Tensor, y: torch.Tensor):
return torch.cat([x, y], dim=1)
x = torch.randn(2, 3)
y = torch.randn(2, 5)
torch_extension_name = "ksc_test_ts2k_cat"
ks_f = ts2mod(f, (x, y), torch_extension_name)
ks_ans = ks_f.py_mod.entry(ks_f.adapt(x), ks_f.adapt(y))
ks_ans = f._entry(x, y)
ks_ans_np = numpy.array(ks_ans, copy=True)
py_ans = f(x, y)
py_ans = f.raw_f(x, y)
assert (ks_ans_np == py_ans.numpy()).all() # non-approx