Implement vmap (#1019)
Parent: 7ab2d24250
Commit: a79b395997
@@ -5,6 +5,7 @@
{
"version": "0.2.0",
"configurations": [
{
@@ -39,6 +40,7 @@
"-m",
"pytest",
"src/bench/",
"-v",
"--modulepath=examples/dl-capsule/sqrl",
"--benchmarkname=sqrl",
],
@@ -42,7 +42,33 @@ def sqrl_bench_configs():
# vsqrl - vectorized sqrl
#
vsqrl = knossos.vmap(sqrl)
vsqrl = knossos.register_direct(sqrl, vmap=True, generate_lm=True) # TODO: Carbuncle


def sqrl_pytorch_where(x):
    """
    Replace "if" with "where" to get torch.vmap to work
    """
    y = torch.sum(x)
    t = torch.where(y < 0, -0.125 * x, 1 / 2 * x ** 2)
    tsint = torch.sin(t) * t
    return torch.mean(tsint)


import torch._vmap_internals

vsqrl_pytorch_nice = torch._vmap_internals.vmap(sqrl_pytorch_where)
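
# Illustrative usage sketch (not part of the benchmark suite): torch.vmap
# cannot trace a data-dependent Python `if`, but the `where` form above
# vectorises over a leading batch dimension, so vsqrl_pytorch_nice(x)[i]
# should match sqrl_pytorch_where(x[i]).
if __name__ == "__main__":
    _x = torch.rand(10, 3, 4)
    _batched = vsqrl_pytorch_nice(_x)      # one scalar result per batch element
    _single = sqrl_pytorch_where(_x[0])    # unbatched call on a single element
    assert torch.isclose(_batched[0], _single)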


def vsqrl_pytorch(x):
    """
    Hand-vectorized pytorch implementation, assuming x is rank 3
    """
    y = torch.sum(x, (1, 2), keepdim=True)
    y_lt_0 = (y < 0).repeat((1, *x.size()[1:]))
    t = torch.where(y_lt_0, -0.125 * x, 1 / 2 * x ** 2)
    tsint = torch.sin(t) * t
    return torch.mean(tsint, (1, 2))
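
# Shape sketch for the hand-vectorised version above (illustrative; assumes a
# batch of 10 elements, each of shape 3x4):
#   x        : (10, 3, 4)
#   y        : (10, 1, 1)   keepdim=True keeps one sum per batch element
#   y_lt_0   : (10, 3, 4)   the mask is repeated up to x's shape for torch.where
#   result   : (10,)        one mean per batch element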


# run-bench: Define a range of values at which to call the methods

@@ -149,7 +149,7 @@ def func_namer(benchmark_func):
def config_namer(config):
    return str(config.shape)
    return "x".join([str(x) for x in config.shape])


def pytest_configure(config):
@@ -2,10 +2,12 @@ export PYTHONPATH="./src/python"
# TODO: this should be a makefile. fred.csv: fred.py etc
BENCH="pytest src/bench/ \
-v\
--benchmark-autosave --benchmark-max-time=5.0\
--benchmark-name=short --benchmark-sort=name --benchmark-group-by=group,func\
--benchmark-columns=median,iqr,outliers,mean,stddev,min,max,iterations,rounds"

# $BENCH --modulepath=examples/dl-activations/relu3 --benchmarkname=vrelu3
$BENCH --modulepath=examples/dl-capsule/sqrl --benchmarkname=sqrl
$BENCH --modulepath=examples/dl-capsule/sqrl --benchmarkname=vsqrl
$BENCH --modulepath=examples/dl-activations/gelu --benchmarkname=vgelu
@@ -1,4 +1,6 @@
import time
import os
import psutil
import torch
@@ -52,6 +54,9 @@ def fun_and_grad_matches(f, g, arg):
    return True


all_messages = []


def timeit(msg, fn, arg):
    MAX_TIME = 5  # No need to run beyond MAX_TIME sec to get accurate benchmarks
    end_time = time.time() + MAX_TIME
@@ -59,6 +64,10 @@ def timeit(msg, fn, arg):
    forward_timer = time_sampler()
    backward_timer = time_sampler()
    nruns = 5000

    mem_used_start = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
    print(f"run-bench: Memory {mem_used_start} before {msg}")

    for _ in range(nruns):
        inference_timer.mark()
        with torch.no_grad():
@@ -75,16 +84,25 @@ def timeit(msg, fn, arg):
        backward_timer.record()

        if time.time() > end_time:
            print(f"# Ran to timeout: {fn} {msg} ")
            print(f"# Ran to timeout: {msg} ")
            break

    csum = grad[0].sum()

    print(
        f"{msg:20} {csum:12.6e} Runs: {inference_timer.ncalls} | Inference: {inference_timer.ms:10.3f} ms |"
    mem_used_end = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
    print(f"run-bench: Memory {mem_used_end} after {msg}")

    mem_used = mem_used_end - mem_used_start
    shape_str = "x".join([str(x) for x in arg.shape])
    msg = (
        f"{msg:20} {csum:12.5e} Runs: {inference_timer.ncalls:4d} | Inference: {inference_timer.ms:10.3f} ms |"
        f" Forward: {forward_timer.ms:10.3f} ms |"
        f" Backward {backward_timer.ms:10.3f} ms | {arg.shape}"
        f" Backward {backward_timer.ms:10.3f} ms |"
        f" Memory {mem_used:10.3f} MB |"
        f" {shape_str}"
    )
    print(msg)
    all_messages.append(msg)


def bench(module_file, bench_name):
@@ -103,7 +121,7 @@ def bench(module_file, bench_name):
    module_dir, module_name = os.path.split(module_file)
    sys.path.append(module_dir)
    mod = importlib.import_module(module_name)
    for fn in inspect.getmembers(mod, inspect.isfunction):
    for fn in inspect.getmembers(mod):
        fn_name, fn_obj = fn
        if fn_name == bench_name + "_bench_configs":
            configs = list(fn_obj())
@@ -139,7 +157,7 @@ def bench(module_file, bench_name):
    ):
        print(pt_value)
        print(ks_value)
        raise ValueError("Knossos != torch!")
        raise ValueError("Knossos != torch")

    pt_loss = pt_value.sum()
    pt_grad = torch.autograd.grad(pt_loss, pt_arg)[0]
@@ -148,7 +166,7 @@ def bench(module_file, bench_name):
    ks_grad = torch.autograd.grad(ks_loss, ks_arg)[0]

    if (
        not torch.isclose(pt_grad, ks_grad, rtol=1e-05, atol=1e-06, equal_nan=False)
        not torch.isclose(pt_grad, ks_grad, rtol=1e-05, atol=1e-05, equal_nan=False)
        .all()
        .numpy()
    ):
@@ -161,7 +179,7 @@ def bench(module_file, bench_name):
        )

        print(pd.DataFrame(cols, columns=["ARG", "PT", "KS", "Diff"]))
        raise ValueError("Knossos != torch!")
        raise ValueError("Knossos != torch")

    # ptfast should always work, and be the timing reference
    timeit(bench_name + " PyTorch fast", pt_fast, arg)
@@ -195,3 +213,6 @@ if __name__ == "__main__":
        print("run-bench: Will preserve temporary files")
        ksc.utils.preserve_temporary_files = True
    bench(args.module, args.bench)

    print("==================================")
    print(*all_messages, sep="\n")
@@ -5,6 +5,9 @@ from benchmark_shim import benchmark_semi_pedantic
cpu_device = torch.device("cpu")
torch.set_default_dtype(torch.float32)

# Most of the logic for these benchmark tests is in
# conftest.py in this directory


def assert_close(result, reference_result):
    assert torch.allclose(
@@ -270,10 +270,6 @@ def generate_cpp_elementwise_entry_point(cpp_function_name, decl):
    return cpp_declaration, cpp


def generate_cpp_vmap_entry_point(cpp_function_name, decl):
    raise NotImplementedError("vmap")


def generate_cpp_cuda_entry_point(cpp_function_name, decl):
    arg_types = arg_types_of_decl(decl)
    if not all(a == Type.Float for a in arg_types):
@@ -306,5 +302,113 @@ def generate_cpp_cuda_entry_point(cpp_function_name, decl):
        return {map_function_name}({join_args(', ', lambda i: f'arg{i}')}, functor_{cpp_function_name}{{}});
    }}
"""

    return cpp_declaration, cpp


def generate_cpp_vmap_entry_point(cpp_function_name, decl):
    def add_vmap_dimension(t: Type):
        if t.is_scalar:
            return Type.Tensor(1, t)
        if t.is_tensor:
            return Type.Tensor(t.tensor_rank + 1, t.tensor_elem_type)

        # Not clear how to do the remainder -- let's see a use case before deciding
        raise ValueError("Vmap understands only tensors for now")

    in_arg_types = arg_types_of_decl(decl)

    # Add a "vmap dimension" to each arg
    arg_types = tuple(add_vmap_dimension(a) for a in in_arg_types)
    ks_types = tuple(ks_cpp_type(a) for a in arg_types)

    num_args = len(arg_types)

    def join_args(callable, sep=", "):
        return sep.join(callable(k) for k in range(num_args))

    def concat_args(callable):
        return "".join(callable(k) for k in range(num_args))

    ks_name = utils.encode_name(decl.name.mangled())

    # torch::Tensor entry_my_kernel(torch::Tensor arg0, ..., torch::Tensor arg7)
    cpp_function = f"""
torch::Tensor {cpp_function_name}({join_args(lambda k: f'torch::Tensor arg{k}')})
"""

    cpp_declaration = f"{cpp_function};\n"

    cpp = f"""
{cpp_function} {{
    int64_t n = arg0.size(0);
"""

    # auto ks_arg0 = convert_argument<ks::tensor<2,float>>(arg0)
    # ...
    # auto ks_arg7 = convert_argument<ks::tensor<1,int>>(arg7)
    for k in range(num_args):
        cpp += f"""
    KS_ASSERT(arg{k}.is_contiguous());
    KS_ASSERT(arg{k}.scalar_type() == scalar_type_of_Float);
    KS_ASSERT(arg{k}.size(0) == n);

    auto ks_arg{k} = convert_argument<{ks_types[k]}>(arg{k});
"""

    # Difficulty: depending on the rank of ks_ret, ks_ret[i] returns either
    # Rank 1: a reference to the float at index i
    # Rank 2: an rvalue representing a view on subtensor i.
    # inplace add must act differently on each, so we switch on the return dimension here

    ks_return_type = add_vmap_dimension(decl.return_type)
    ks_return_dim = ks_return_type.tensor_rank
    if ks_return_dim == 1:
        cpp += f"""
    // Create Torch return value
    auto ret = torch::zeros({{n}});
    ks::Float* ret_ptr = ret.data_ptr<ks::Float>();

    KS_MARK(&g_alloc, mark);
    for (int i = 0; i != n; ++i) {{
        ret_ptr[i] = ks::{ks_name}(&g_alloc {concat_args(lambda k: f", ks_arg{k}[i]")});
        // We have copied the return value, can reset allocator
        KS_RESET(&g_alloc, mark);
    }}

    return ret;
}}
"""
    else:
        ks_sizes = ", ".join([f"size{d}" for d in range(ks_return_dim - 1)])
        cpp += f"""
    KS_ASSERT(n > 0); // TODO: Zero-size tensors

    // Make the first call to determine output size
    auto ret0 = ks::{ks_name}(&g_alloc {concat_args(lambda k: f", ks_arg{k}[0]")});

    // Create Torch return value
    auto [{ks_sizes}] = ret0.size();
    // TODO: use torch::empty here, and copydown below.
    auto ret = torch::zeros({{n, {ks_sizes}}});
    // And wrap it in ks - this is a view of the torch data, so convert_argument, not convert_return_value
    auto ks_ret = convert_argument<ks::tensor<{ks_return_dim}, Float>>(ret);

    // Place 0th value in the output
    auto ks_ret0 = ks_ret[0];
    inplace_add(&ks_ret0, ret0); // This would update a temporary in the 1D case

    // And then place the rest
    KS_MARK(&g_alloc, mark);
    for (int i = 1; i != n; ++i) {{
        auto val = ks::{ks_name}(&g_alloc {concat_args(lambda k: f", ks_arg{k}[i]")});
        auto ks_ret_view = ks_ret[i];
        inplace_add(&ks_ret_view, val);
        // We have copied the return value, can reset allocator
        KS_RESET(&g_alloc, mark);
    }}

    return ret;
}}
"""
    print(cpp)
    return cpp_declaration, cpp
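

# Conceptual Python model of the generated C++ above (illustrative only; this
# helper is not used by the compiler). The entry point loops the unbatched
# kernel over the leading "vmap" dimension: the first call determines the
# output shape, later calls fill the remaining slots, and the generated code
# additionally resets the arena allocator (KS_RESET) after each iteration.
def _vmap_entry_model(kernel, *args):
    import torch  # local import: the sketch is independent of the codegen above

    n = args[0].shape[0]
    assert all(a.shape[0] == n for a in args), "all args share the vmap dimension"
    ret0 = kernel(*(a[0] for a in args))        # first call fixes the output shape
    out = torch.zeros((n, *torch.as_tensor(ret0).shape))
    out[0] = ret0
    for i in range(1, n):                       # remaining calls fill rows 1..n-1
        out[i] = kernel(*(a[i] for a in args))
    return out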

@@ -699,6 +699,9 @@ class KscStub:
        """
        return self.ensure_compiled(args).apply(*args)

    def _reset_allocator(self, *args):
        self.ensure_compiled(args).py_mod.reset_allocator()
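
    # Typical usage, mirroring the updated tests in this commit: reset the
    # arena allocator before exercising a raw entry point, e.g.
    #   relux._reset_allocator(2.0)
    #   ks_ans = relux._entry(2.0)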

    def _entry(self, *args):
        """
        Directly call the Knossos compiled function.
@@ -44,5 +44,25 @@ struct Converter<ks::tensor<2, Float>, torch::Tensor>
    }
};

// TODO: common these?
template<>
struct Converter<ks::tensor<3, Float>, torch::Tensor>
{
    static ks::tensor<3, Float> to_ks(torch::Tensor arg) {
        KS_ASSERT(arg.sizes().size() == 3u);
        KS_ASSERT(arg.is_contiguous());
        KS_ASSERT(arg.scalar_type() == scalar_type_of_Float);
        ks::tensor<3, Float>::index_type size {(int)arg.size(0), (int)arg.size(1), (int)arg.size(2)};
        return ks::tensor<3, Float>(size, arg.data_ptr<Float>());
    }

    static torch::Tensor from_ks(ks::tensor<3, Float> ret) {
        auto [size0, size1, size2] = ret.size();
        torch::Tensor torch_ret = torch::empty({size0, size1, size2}, torch::TensorOptions().dtype(scalar_type_of_Float));
        std::memcpy(torch_ret.data_ptr(), ret.data(), size0 * size1 * size2 * sizeof(Float));
        return torch_ret;
    }
};

}
}
@@ -60,8 +60,8 @@ Each C++ function is annotated with one of the following macros:
#endif

#ifdef KS_ALLOCATOR
#define KS_MARK(alloc, markvar) ks::alloc_mark_t markvar = alloc->mark();
#define KS_RESET(alloc, markvar) alloc->reset(markvar);
#define KS_MARK(alloc, markvar) ks::alloc_mark_t markvar = (alloc)->mark();
#define KS_RESET(alloc, markvar) (alloc)->reset(markvar);
#define KS_COPYDOWN(alloc, markvar, expr) ks::copydown(alloc, markvar, expr)
#else
#define KS_MARK(alloc, markvar)
@@ -1255,7 +1255,7 @@ namespace ks {
    first iteration (using a copydown), then accumulating
    subsequent iterations into this result using inplace_add.

    e.g. for a 2-dimensional sumbuild, size {4, 3}, there is
    e.g. for a 2-dimensional sumbuild, size {4, 3}, there is
    the following sequence of calls to f (ignoring the allocator
    argument for simplicity):
@@ -9,9 +9,9 @@ import importlib
@pytest.mark.parametrize(
    "module_file,bench_name",
    [
        # ("examples/dl-capsule/sqrl", "vsqrl"),
        ("examples/dl-activations/relu3", "vrelu3"),
        ("examples/dl-capsule/sqrl", "sqrl"),
        ("examples/dl-capsule/sqrl", "vsqrl"),
    ],
)
def test_bench(module_file, bench_name):
@@ -0,0 +1,24 @@
import sys
import torch
import ksc

import importlib

sys.path.append(ksc.utils.get_ksc_dir() + "/examples/dl-capsule")

mod = importlib.import_module("sqrl")


def test_vsqrl_vs_sqrl():
    x = torch.rand(10, 3, 4)
    ans = mod.vsqrl_pytorch(x)
    for i in 0, 1, 4, 9:
        ans_i = mod.sqrl_pytorch(x[i])
        assert torch.isclose(ans[i], ans_i)


def test_knossos_vs_pytorch():
    x = torch.rand(10, 3, 4)
    ans_pt = mod.vsqrl_pytorch(x)
    ans_ks = mod.vsqrl(x)
    assert torch.isclose(ans_pt, ans_ks).all()
@@ -54,13 +54,14 @@ def f(x: float):
def test_ts2k_relux():
    relux._reset_allocator(2.0)
    ks_ans = relux._entry(2.0)
    ans = relux.raw_f(2.0)
    assert pytest.approx(ks_ans, 1e-6) == ans


def test_ts2k_relux_grad():
    relux.ensure_compiled((2.0,)) # TODO: remove when entry_vjp knows how to compile
    relux._reset_allocator((1.3,)) # TODO: remove when entry_vjp knows how to compile
    ks_ans = relux._entry_vjp(1.3, 1.0)
    ans = grad_relux(1.3)
    assert pytest.approx(ks_ans, 1e-6) == ans
@@ -113,11 +114,13 @@ def test_bar():
    a, x = 1, 12.34

    # Check primal
    bar._reset_allocator(a, x)
    ks_ans = bar._entry(a, x)
    ans = bar.raw_f(a, x)
    assert pytest.approx(ks_ans, 1e-5) == ans

    # Check grad
    bar._reset_allocator(a, x)
    ks_ans = bar._entry_vjp((a, x), 1.0)
    ans = grad_bar(a, x)
    assert pytest.approx(ks_ans[1], 1e-5) == ans[1]
@@ -138,6 +141,7 @@ def test_far():
    x = torch.randn(2, 3)
    y = torch.randn(2, 5)

    far._reset_allocator(x, y)
    ks_ans = far._entry(x, y)
    ans = far.raw_f(x, y)
    assert pytest.approx(ks_ans, 1e-5) == ans.item()