Andrew Fitzgibbon 2021-08-27 22:04:37 +01:00 committed by GitHub
Parent 7ab2d24250
Commit a79b395997
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 229 additions and 20 deletions

.vscode/launch.json vendored
View file

@@ -5,6 +5,7 @@
{
"version": "0.2.0",
"configurations": [
{
@@ -39,6 +40,7 @@
"-m",
"pytest",
"src/bench/",
"-v",
"--modulepath=examples/dl-capsule/sqrl",
"--benchmarkname=sqrl",
],

View file

@@ -42,7 +42,33 @@ def sqrl_bench_configs():
# vsqrl - vectorized sqrl
#
-vsqrl = knossos.vmap(sqrl)
+vsqrl = knossos.register_direct(sqrl, vmap=True, generate_lm=True)  # TODO: Carbuncle


def sqrl_pytorch_where(x):
    """
    Replace "if" with "where" to get torch.vmap to work
    """
    y = torch.sum(x)
    t = torch.where(y < 0, -0.125 * x, 1 / 2 * x ** 2)
    tsint = torch.sin(t) * t
    return torch.mean(tsint)


import torch._vmap_internals

vsqrl_pytorch_nice = torch._vmap_internals.vmap(sqrl_pytorch_where)


def vsqrl_pytorch(x):
    """
    Hand-vectorized pytorch implementation, assuming x is rank 3
    """
    y = torch.sum(x, (1, 2), keepdim=True)
    y_lt_0 = (y < 0).repeat((1, *x.size()[1:]))
    t = torch.where(y_lt_0, -0.125 * x, 1 / 2 * x ** 2)
    tsint = torch.sin(t) * t
    return torch.mean(tsint, (1, 2))
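A quick usage sketch for the vmap-based wrapper (an annotation, not part of this diff; the 10x3x4 shape is borrowed from the new tests below):

    x = torch.rand(10, 3, 4)
    out = vsqrl_pytorch_nice(x)  # torch.vmap maps sqrl_pytorch_where over dim 0
    assert out.shape == (10,)
    assert torch.isclose(out[0], sqrl_pytorch_where(x[0]))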
# run-bench: Define a range of values at which to call the methods
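And a minimal consistency check for the hand-vectorized version above, under the same rank-3 assumption (again an annotation, not part of the commit):

    x = torch.rand(10, 3, 4)
    assert torch.allclose(
        vsqrl_pytorch(x),
        torch.stack([sqrl_pytorch_where(xi) for xi in x]),
    )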

View file

@@ -149,7 +149,7 @@ def func_namer(benchmark_func):
def config_namer(config):
-    return str(config.shape)
+    return "x".join([str(x) for x in config.shape])
def pytest_configure(config):

View file

@@ -2,10 +2,12 @@ export PYTHONPATH="./src/python"
# TODO: this should be a makefile. fred.csv: fred.py etc
BENCH="pytest src/bench/ \
-v\
--benchmark-autosave --benchmark-max-time=5.0\
--benchmark-name=short --benchmark-sort=name --benchmark-group-by=group,func\
--benchmark-columns=median,iqr,outliers,mean,stddev,min,max,iterations,rounds"
# $BENCH --modulepath=examples/dl-activations/relu3 --benchmarkname=vrelu3
$BENCH --modulepath=examples/dl-capsule/sqrl --benchmarkname=sqrl
$BENCH --modulepath=examples/dl-capsule/sqrl --benchmarkname=vsqrl
$BENCH --modulepath=examples/dl-activations/gelu --benchmarkname=vgelu

View file

@@ -1,4 +1,6 @@
import time
import os
import psutil
import torch
@@ -52,6 +54,9 @@ def fun_and_grad_matches(f, g, arg):
return True
+all_messages = []


def timeit(msg, fn, arg):
    MAX_TIME = 5  # No need to run beyond MAX_TIME sec to get accurate benchmarks
    end_time = time.time() + MAX_TIME
@@ -59,6 +64,10 @@ def timeit(msg, fn, arg):
    forward_timer = time_sampler()
    backward_timer = time_sampler()
    nruns = 5000
+    mem_used_start = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
+    print(f"run-bench: Memory {mem_used_start} before {msg}")
    for _ in range(nruns):
        inference_timer.mark()
        with torch.no_grad():
@@ -75,16 +84,25 @@ def timeit(msg, fn, arg):
        backward_timer.record()
        if time.time() > end_time:
-            print(f"# Ran to timeout: {fn} {msg} ")
+            print(f"# Ran to timeout: {msg} ")
            break
    csum = grad[0].sum()
-    print(
-        f"{msg:20} {csum:12.6e} Runs: {inference_timer.ncalls} | Inference: {inference_timer.ms:10.3f} ms |"
+    mem_used_end = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
+    print(f"run-bench: Memory {mem_used_end} after {msg}")
+    mem_used = mem_used_end - mem_used_start
+    shape_str = "x".join([str(x) for x in arg.shape])
+    msg = (
+        f"{msg:20} {csum:12.5e} Runs: {inference_timer.ncalls:4d} | Inference: {inference_timer.ms:10.3f} ms |"
        f" Forward: {forward_timer.ms:10.3f} ms |"
-        f" Backward {backward_timer.ms:10.3f} ms | {arg.shape}"
+        f" Backward {backward_timer.ms:10.3f} ms |"
+        f" Memory {mem_used:10.3f} MB |"
+        f" {shape_str}"
    )
+    print(msg)
+    all_messages.append(msg)
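The memory figures are a simple resident-set-size delta around the workload. A standalone sketch of the same technique (an annotation, not part of this diff):

    import os
    import psutil

    def rss_delta_mb(fn, *args):
        proc = psutil.Process(os.getpid())
        before = proc.memory_info().rss / 1024 ** 2  # MB, as in timeit above
        fn(*args)
        return proc.memory_info().rss / 1024 ** 2 - before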
def bench(module_file, bench_name):
@@ -103,7 +121,7 @@ def bench(module_file, bench_name):
    module_dir, module_name = os.path.split(module_file)
    sys.path.append(module_dir)
    mod = importlib.import_module(module_name)
-    for fn in inspect.getmembers(mod, inspect.isfunction):
+    for fn in inspect.getmembers(mod):
        fn_name, fn_obj = fn
        if fn_name == bench_name + "_bench_configs":
            configs = list(fn_obj())
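Dropping the inspect.isfunction predicate presumably lets non-function callables through — for example the object returned by knossos.register_direct in sqrl.py above, which inspect.isfunction would have filtered out of benchmark discovery.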
@@ -139,7 +157,7 @@ def bench(module_file, bench_name):
        ):
            print(pt_value)
            print(ks_value)
-            raise ValueError("Knossos != torch!")
+            raise ValueError("Knossos != torch")

        pt_loss = pt_value.sum()
        pt_grad = torch.autograd.grad(pt_loss, pt_arg)[0]
@@ -148,7 +166,7 @@ def bench(module_file, bench_name):
        ks_grad = torch.autograd.grad(ks_loss, ks_arg)[0]

        if (
-            not torch.isclose(pt_grad, ks_grad, rtol=1e-05, atol=1e-06, equal_nan=False)
+            not torch.isclose(pt_grad, ks_grad, rtol=1e-05, atol=1e-05, equal_nan=False)
            .all()
            .numpy()
        ):
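torch.isclose checks |pt_grad - ks_grad| <= atol + rtol * |ks_grad| elementwise, so raising atol from 1e-06 to 1e-05 admits slightly larger absolute error in gradient entries near zero.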
@@ -161,7 +179,7 @@ def bench(module_file, bench_name):
            )
            print(pd.DataFrame(cols, columns=["ARG", "PT", "KS", "Diff"]))
-            raise ValueError("Knossos != torch!")
+            raise ValueError("Knossos != torch")

    # ptfast should always work, and be the timing reference
    timeit(bench_name + " PyTorch fast", pt_fast, arg)
@@ -195,3 +213,6 @@ if __name__ == "__main__":
print("run-bench: Will preserve temporary files")
ksc.utils.preserve_temporary_files = True
bench(args.module, args.bench)
print("==================================")
print(*all_messages, sep="\n")

View file

@@ -5,6 +5,9 @@ from benchmark_shim import benchmark_semi_pedantic
cpu_device = torch.device("cpu")
torch.set_default_dtype(torch.float32)

+# Most of the logic for these benchmark tests is in
+# conftest.py in this directory


def assert_close(result, reference_result):
    assert torch.allclose(

View file

@@ -270,10 +270,6 @@ def generate_cpp_elementwise_entry_point(cpp_function_name, decl):
    return cpp_declaration, cpp


-def generate_cpp_vmap_entry_point(cpp_function_name, decl):
-    raise NotImplementedError("vmap")


def generate_cpp_cuda_entry_point(cpp_function_name, decl):
    arg_types = arg_types_of_decl(decl)
    if not all(a == Type.Float for a in arg_types):
@@ -306,5 +302,113 @@ def generate_cpp_cuda_entry_point(cpp_function_name, decl):
return {map_function_name}({join_args(', ', lambda i: f'arg{i}')}, functor_{cpp_function_name}{{}});
}}
"""
return cpp_declaration, cpp
def generate_cpp_vmap_entry_point(cpp_function_name, decl):
    def add_vmap_dimension(t: Type):
        if t.is_scalar:
            return Type.Tensor(1, t)
        if t.is_tensor:
            return Type.Tensor(t.tensor_rank + 1, t.tensor_elem_type)
        # Not clear how to do the remainder -- let's see a use case before deciding
        raise ValueError("Vmap understands only tensors for now")
    in_arg_types = arg_types_of_decl(decl)

    # Add a "vmap dimension" to each arg
    arg_types = tuple(add_vmap_dimension(a) for a in in_arg_types)
    ks_types = tuple(ks_cpp_type(a) for a in arg_types)
    num_args = len(arg_types)

    def join_args(callable, sep=", "):
        return sep.join(callable(k) for k in range(num_args))

    def concat_args(callable):
        return "".join(callable(k) for k in range(num_args))

    ks_name = utils.encode_name(decl.name.mangled())

    # torch::Tensor entry_my_kernel(torch::Tensor arg0, ..., torch::Tensor arg7)
    cpp_function = f"""
torch::Tensor {cpp_function_name}({join_args(lambda k: f'torch::Tensor arg{k}')})
"""
    cpp_declaration = f"{cpp_function};\n"
    cpp = f"""
{cpp_function} {{
    int64_t n = arg0.size(0);
"""
    # auto ks_arg0 = convert_argument<ks::tensor<2,float>>(arg0)
    # ...
    # auto ks_arg7 = convert_argument<ks::tensor<1,int>>(arg7)
    for k in range(num_args):
        cpp += f"""
    KS_ASSERT(arg{k}.is_contiguous());
    KS_ASSERT(arg{k}.scalar_type() == scalar_type_of_Float);
    KS_ASSERT(arg{k}.size(0) == n);
    auto ks_arg{k} = convert_argument<{ks_types[k]}>(arg{k});
"""
    # Difficulty: depending on the rank of ks_ret, ks_ret[i] returns either
    #   Rank 1: a reference to the float at index i
    #   Rank 2: an rvalue representing a view on subtensor i.
    # inplace add must act differently on each, so we switch on the return dimension here
    ks_return_type = add_vmap_dimension(decl.return_type)
    ks_return_dim = ks_return_type.tensor_rank
    if ks_return_dim == 1:
        cpp += f"""
    // Create Torch return value
    auto ret = torch::zeros({{n}});
    ks::Float* ret_ptr = ret.data_ptr<ks::Float>();
    KS_MARK(&g_alloc, mark);
    for (int i = 0; i != n; ++i) {{
        ret_ptr[i] = ks::{ks_name}(&g_alloc {concat_args(lambda k: f", ks_arg{k}[i]")});
        // We have copied the return value, can reset allocator
        KS_RESET(&g_alloc, mark);
    }}
    return ret;
}}
"""
    else:
        ks_sizes = ", ".join([f"size{d}" for d in range(ks_return_dim - 1)])
        cpp += f"""
    KS_ASSERT(n > 0); // TODO: Zero-size tensors

    // Make the first call to determine output size
    auto ret0 = ks::{ks_name}(&g_alloc {concat_args(lambda k: f", ks_arg{k}[0]")});

    // Create Torch return value
    auto [{ks_sizes}] = ret0.size();
    // TODO: use torch::empty here, and copydown below.
    auto ret = torch::zeros({{n, {ks_sizes}}});

    // And wrap it in ks - this is a view of the torch data, so convert_argument, not convert_return_value
    auto ks_ret = convert_argument<ks::tensor<{ks_return_dim}, Float>>(ret);

    // Place 0th value in the output
    auto ks_ret0 = ks_ret[0];
    inplace_add(&ks_ret0, ret0); // This would update a temporary in the 1D case

    // And then place the rest
    KS_MARK(&g_alloc, mark);
    for (int i = 1; i != n; ++i) {{
        auto val = ks::{ks_name}(&g_alloc {concat_args(lambda k: f", ks_arg{k}[i]")});
        auto ks_ret_view = ks_ret[i];
        inplace_add(&ks_ret_view, val);
        // We have copied the return value, can reset allocator
        KS_RESET(&g_alloc, mark);
    }}
    return ret;
}}
"""
    print(cpp)
    return cpp_declaration, cpp
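At shape level, the generated wrapper computes what this Python model does (an annotation assuming all arguments share a leading dimension of size n, not the generated code itself):

    def vmap_entry_model(f, *args):
        n = args[0].shape[0]
        return torch.stack([f(*(a[i] for a in args)) for i in range(n)])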

View file

@@ -699,6 +699,9 @@ class KscStub:
"""
return self.ensure_compiled(args).apply(*args)
def _reset_allocator(self, *args):
self.ensure_compiled(args).py_mod.reset_allocator()
def _entry(self, *args):
"""
Directly call the Knossos compiled function.
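The new _reset_allocator presumably pairs with the KS_MARK/KS_RESET allocator macros below: the ts2k tests call it before each raw _entry/_entry_vjp call so the compiled module's allocator starts from a clean state.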

View file

@@ -44,5 +44,25 @@ struct Converter<ks::tensor<2, Float>, torch::Tensor>
    }
};

// TODO: common these?
template<>
struct Converter<ks::tensor<3, Float>, torch::Tensor>
{
    static ks::tensor<3, Float> to_ks(torch::Tensor arg) {
        KS_ASSERT(arg.sizes().size() == 3u);
        KS_ASSERT(arg.is_contiguous());
        KS_ASSERT(arg.scalar_type() == scalar_type_of_Float);
        ks::tensor<3, Float>::index_type size {(int)arg.size(0), (int)arg.size(1), (int)arg.size(2)};
        return ks::tensor<3, Float>(size, arg.data_ptr<Float>());
    }

    static torch::Tensor from_ks(ks::tensor<3, Float> ret) {
        auto [size0, size1, size2] = ret.size();
        torch::Tensor torch_ret = torch::empty({size0, size1, size2}, torch::TensorOptions().dtype(scalar_type_of_Float));
        std::memcpy(torch_ret.data_ptr(), ret.data(), size0 * size1 * size2 * sizeof(Float));
        return torch_ret;
    }
};

}
}

View file

@@ -60,8 +60,8 @@ Each C++ function is annotated with one of the following macros:
#endif
#ifdef KS_ALLOCATOR
-#define KS_MARK(alloc, markvar) ks::alloc_mark_t markvar = alloc->mark();
-#define KS_RESET(alloc, markvar) alloc->reset(markvar);
+#define KS_MARK(alloc, markvar) ks::alloc_mark_t markvar = (alloc)->mark();
+#define KS_RESET(alloc, markvar) (alloc)->reset(markvar);
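// (Annotation, not part of this commit.) The added parentheses matter because
// call sites pass expressions such as &g_alloc: without them, alloc->mark()
// expands to &g_alloc->mark(), which parses as &(g_alloc->mark()).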
#define KS_COPYDOWN(alloc, markvar, expr) ks::copydown(alloc, markvar, expr)
#else
#define KS_MARK(alloc, markvar)
@@ -1255,7 +1255,7 @@ namespace ks {
first iteration (using a copydown), then accumulating
subsequent iterations into this result using inplace_add.
e.g. for a 2-dimensional sumbuild, size {4, 3}, there is
the following sequence of calls to f (ignoring the allocator
argument for simplicity):
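A rough Python model of the strategy described above (an annotation, not the ks implementation): the first call's result becomes the accumulator and every subsequent result is added into it in place:

    import numpy as np

    def sumbuild_model(f, n0, n1):
        acc = np.asarray(f(0, 0), dtype=float).copy()  # "copydown" of the first call
        for i in range(n0):
            for j in range(n1):
                if (i, j) != (0, 0):
                    acc += f(i, j)  # inplace_add for subsequent iterations
        return acc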

View file

@@ -9,9 +9,9 @@ import importlib
@pytest.mark.parametrize(
    "module_file,bench_name",
    [
-        # ("examples/dl-capsule/sqrl", "vsqrl"),
        ("examples/dl-activations/relu3", "vrelu3"),
        ("examples/dl-capsule/sqrl", "sqrl"),
+        ("examples/dl-capsule/sqrl", "vsqrl"),
    ],
)
def test_bench(module_file, bench_name):
def test_bench(module_file, bench_name):

test/ts2k/test_sqrl.py Normal file
View file

@@ -0,0 +1,24 @@
import sys
import torch
import ksc
import importlib

sys.path.append(ksc.utils.get_ksc_dir() + "/examples/dl-capsule")
mod = importlib.import_module("sqrl")


def test_vsqrl_vs_sqrl():
    x = torch.rand(10, 3, 4)
    ans = mod.vsqrl_pytorch(x)
    for i in 0, 1, 4, 9:
        ans_i = mod.sqrl_pytorch(x[i])
        assert torch.isclose(ans[i], ans_i)


def test_knossos_vs_pytorch():
    x = torch.rand(10, 3, 4)
    ans_pt = mod.vsqrl_pytorch(x)
    ans_ks = mod.vsqrl(x)
    assert torch.isclose(ans_pt, ans_ks).all()

View file

@@ -54,13 +54,14 @@ def f(x: float):
def test_ts2k_relux():
+    relux._reset_allocator(2.0)
    ks_ans = relux._entry(2.0)
    ans = relux.raw_f(2.0)
    assert pytest.approx(ks_ans, 1e-6) == ans


def test_ts2k_relux_grad():
-    relux.ensure_compiled((2.0,))  # TODO: remove when entry_vjp knows how to compile
+    relux._reset_allocator((1.3,))  # TODO: remove when entry_vjp knows how to compile
    ks_ans = relux._entry_vjp(1.3, 1.0)
    ans = grad_relux(1.3)
    assert pytest.approx(ks_ans, 1e-6) == ans
@@ -113,11 +114,13 @@ def test_bar():
    a, x = 1, 12.34

    # Check primal
+    bar._reset_allocator(a, x)
    ks_ans = bar._entry(a, x)
    ans = bar.raw_f(a, x)
    assert pytest.approx(ks_ans, 1e-5) == ans

    # Check grad
+    bar._reset_allocator(a, x)
    ks_ans = bar._entry_vjp((a, x), 1.0)
    ans = grad_bar(a, x)
    assert pytest.approx(ks_ans[1], 1e-5) == ans[1]
@@ -138,6 +141,7 @@ def test_far():
    x = torch.randn(2, 3)
    y = torch.randn(2, 5)
+    far._reset_allocator(x, y)
    ks_ans = far._entry(x, y)
    ans = far.raw_f(x, y)
    assert pytest.approx(ks_ans, 1e-5) == ans.item()