Knossos.register, and delayed compilation (#960)
Parent: 0daf53e905
Commit: 6cff51b0a5
@@ -11,7 +11,7 @@
             "name": "(gdb) Launch python for relu3",
             "type": "cppdbg",
             "request": "launch",
-            "program": "/usr/bin/python",
+            "program": "/anaconda/envs/knossos/bin/python",
            "args": [
                "src/bench/run-bench.py",
                "examples/dl-activations/relu3",
@@ -30,6 +30,33 @@
             ]
         },
+        {
+            "name": "(gdb) pytest",
+            "type": "cppdbg",
+            "request": "launch",
+            "program": "/anaconda/envs/knossos/bin/python",
+            "args": [
+                "-m",
+                "pytest",
+                "src/bench/",
+                "--modulepath=examples/dl-capsule/sqrl",
+                "--benchmarkname=sqrl",
+            ],
+            "stopAtEntry": false,
+            "cwd": "${workspaceFolder}",
+            "environment": [
+                {"name":"PYTHONPATH", "value":"./src/python"}
+            ],
+            "externalConsole": false,
+            "MIMode": "gdb",
+            "setupCommands": [
+                {
+                    "description": "Enable pretty-printing for gdb",
+                    "text": "-enable-pretty-printing",
+                    "ignoreFailures": true
+                }
+            ]
+        },
         {
             "name": "Python: Current File",
             "type": "python",
@@ -1,6 +1,8 @@
 from math import sqrt, tanh, erf, exp
 import torch

+import ksc.torch_frontend as knossos
+
 from ksc.torch_utils import elementwise_apply_hack
 import ksc.compile
 from ksc.torch_frontend import cpp_string_to_autograd_function
@@ -81,6 +83,7 @@ def sigmoid(x):


 # Gelu and activations
+@knossos.register
 def gelu(x: float) -> float:
     return 0.5 * x * (1.0 + erf(x / sqrt(2)))
@@ -124,6 +127,7 @@ def gelu_approx_tanh(x: float) -> float:
     return 0.5 * (1 + tanh(x * (C * x * x + B))) * x


+@knossos.register
 def vgelu(x: torch.Tensor):
     return elementwise_apply_hack("gelu", x)
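For orientation: with this change `gelu` is no longer wired up eagerly; `@knossos.register` wraps it in a `KscStub` (defined later in this diff) that compiles on first use. A minimal usage sketch, with the call behaviour inferred from the `KscStub` code below:

from math import sqrt, erf

import ksc.torch_frontend as knossos

@knossos.register
def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + erf(x / sqrt(2)))

y = gelu(0.3)            # first call compiles the extension, then dispatches
                         # to the generated KscAutogradFunction.apply
y_ref = gelu.raw_f(0.3)  # the plain Python body stays reachable for checks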
@@ -10,11 +10,13 @@ from ksc.torch_frontend import (
     ksc_string_to_autograd_function,
     cpp_string_to_autograd_function,
 )
+import ksc.torch_frontend as knossos
 from ksc.torch_utils import elementwise_apply_hack

 import torch._vmap_internals

 # BEGINDOC
+@knossos.register
 def relu3(x: float) -> float:
     """
     Like ReLu, but smoother
@@ -60,6 +62,7 @@ if False:
     vrelu3_pytorch_nice = torch._vmap_internals.vmap(relu3_pytorch_nice)

 # run-bench: Knossos implementation
+@knossos.register
 def vrelu3(x: torch.Tensor):
     return elementwise_apply_hack("relu3", x)
@@ -606,3 +609,10 @@ def relu3_in_fcdnn():

 # Run training
 # train_model(model)
+
+
+if __name__ == "__main__":
+    y = relu3(0.3)
+    xs = next(vrelu3_bench_configs())
+    ys = vrelu3(xs)
+    print(ys.sum())
@@ -1,7 +1,9 @@
 import torch
+import ksc.torch_frontend as knossos

 # run-bench: Knossos source, and "nice" PyTorch implementation
 # BEGINDOC
+@knossos.register
 def sqrl(x: torch.Tensor):
     """
     sqrl: Squared Leaky Relu
@@ -12,7 +12,7 @@ from collections import namedtuple
 from contextlib import contextmanager
 from typing import Callable

-from ksc.torch_frontend import tsmod2ksmod
+from ksc.torch_frontend import KscStub
 from ksc import utils
@@ -97,7 +97,7 @@ def function_to_manual_cuda_benchmarks(func):
 def functions_to_benchmark(
     mod, benchmark_name, example_inputs, torch_extension_name_base
 ):
-    for fn_name, fn_obj in inspect.getmembers(mod, lambda m: inspect.isfunction(m)):
+    for fn_name, fn_obj in inspect.getmembers(mod):
         if fn_name.startswith(benchmark_name):
             if fn_name == benchmark_name + "_bench_configs":
                 continue
@@ -106,17 +106,15 @@ def functions_to_benchmark(
             elif fn_name == benchmark_name + "_pytorch_nice":
                 yield BenchmarkFunction("PyTorch Nice", fn_obj)
             elif fn_name == benchmark_name:
+                assert isinstance(fn_obj, KscStub)
                 torch_extension_name = (
                     "ksc_src_bench_" + torch_extension_name_base + "_" + benchmark_name
                 )
-                ks_mod = tsmod2ksmod(
-                    mod,
-                    benchmark_name,
-                    torch_extension_name,
-                    example_inputs,
-                    generate_lm=False,
+                ks_compiled = fn_obj.compile(
+                    torch_extension_name=torch_extension_name,
+                    example_inputs=example_inputs,
                 )
-                yield BenchmarkFunction("Knossos", ks_mod.apply)
+                yield BenchmarkFunction("Knossos", ks_compiled.apply)
             elif fn_name == benchmark_name + "_cuda_init":
                 if torch.cuda.is_available():
                     yield from function_to_manual_cuda_benchmarks(fn_obj)
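The switch from `inspect.getmembers(mod, lambda m: inspect.isfunction(m))` to plain `inspect.getmembers(mod)` is needed because registered entry points are now `KscStub` dataclass instances, which `inspect.isfunction` rejects. A stand-alone illustration (`Stub` here is a hypothetical stand-in for `KscStub`):

import inspect
from dataclasses import dataclass
from typing import Callable

@dataclass
class Stub:
    raw_f: Callable

def f(x):
    return x

stub = Stub(f)
assert inspect.isfunction(f)         # plain functions pass the old filter
assert not inspect.isfunction(stub)  # dataclass instances do not, so the
                                     # filtered getmembers would skip them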
|
||||
|
|
|
@@ -1,10 +1,6 @@
 import time
-from ksc import torch_frontend
 import torch

-import ksc.torch_frontend
-from ksc.torch_frontend import tsmod2ksmod
-

 class time_sampler:
     def __init__(self, minimizing=False):
@@ -120,14 +116,9 @@ def bench(module_file, bench_name):
     else:
         print(f"Ignoring {fn_name}")

-    # TODO: elementwise_apply
-    torch_extension_name = "ksc_run_bench_" + bench_name
-    ks_compiled = tsmod2ksmod(
-        mod,
-        bench_name,
-        torch_extension_name,
+    ks_compiled = ks_raw.compile(
+        torch_extension_name="ksc_run_bench_" + bench_name,
         example_inputs=(configs[0],),
-        generate_lm=False,
     )

     for arg in configs:
@@ -1,3 +1,7 @@
 from ksc.tracing.jitting import trace
 from ksc.ks_function import KsFunction
-from ksc.torch_frontend import ts2ks, ts2ks_fromgraph, ts2mod, tsmod2ksmod
+from ksc.torch_frontend import (
+    ts2ks,
+    ts2ks_fromgraph,
+    register,
+)
@@ -1,8 +1,10 @@
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Optional
+from types import ModuleType
+from dataclasses import dataclass
 from contextlib import contextmanager

-import functools
 import numpy
+import inspect
 import torch
 import torch.onnx
@@ -464,6 +466,10 @@ def backward_template(py_mod, ctx, *args):
     return torch_from_ks(outputs)


+class KscAutogradFunction(torch.autograd.Function):
+    pass
+
+
 def make_KscAutogradFunction(py_mod):
     # We need to make a new class for every py_mod, as PyTorch requires forward and backward to be
     # staticmethods. This is not too expensive, as each mod needs to be compiled anyway.
@@ -471,7 +477,7 @@ def make_KscAutogradFunction(py_mod):
     backward = lambda ctx, args: backward_template(py_mod, ctx, args)
     return type(
         "KscAutogradFunction_" + py_mod.__name__,
-        (torch.autograd.Function,),
+        (KscAutogradFunction,),
        {
            "py_mod": py_mod,
            "forward": staticmethod(forward),
@@ -602,53 +608,6 @@ def cpp_string_to_autograd_function(
     return make_KscAutogradFunction(mod)
-
-
-import inspect
-
-
-def tsmod2ksmod(
-    module, function_name, torch_extension_name, example_inputs, generate_lm=True
-):
-    global todo_stack
-    todo_stack = {function_name}
-    ksc_defs = []
-    while len(todo_stack) > 0:
-        print(f"tsmod2ksmod: Remaining: {todo_stack}")
-        for fn in inspect.getmembers(module, inspect.isfunction):
-            fn_name, fn_obj = fn
-            if fn_name in todo_stack:
-                todo_stack.remove(fn_name)
-                print(f"tsmod2ksmod: converting {fn_name}, remaining: {todo_stack}")
-                ts_fn = torch.jit.script(fn_obj)
-                ts_graph = ts_fn.graph
-                ksc_def = ts2ks_fromgraph(False, fn_name, ts_graph, example_inputs)
-                ksc_defs.insert(0, ksc_def)
-
-    elementwise = is_elementwise_operation(ksc_defs[-1])
-    if elementwise:
-        ksc_defs.pop()
-
-    entry_def = ksc_defs[-1]
-    return ksc_defs_to_autograd_function(
-        ksc_defs,
-        entry_def,
-        torch_extension_name,
-        elementwise=elementwise,
-        generate_lm=generate_lm,
-    )
-
-
-def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
-    fn = torch.jit.script(function)
-    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
-    return ksc_defs_to_autograd_function(
-        [ksc_def],
-        ksc_def,
-        torch_extension_name,
-        elementwise=False,
-        generate_lm=generate_lm,
-    )


 def is_elementwise_operation(ksc_def):
     """
     Inspect the body of a def to determine whether it is a
@@ -691,3 +650,159 @@ def is_elementwise_operation(ksc_def):
         print(f"Num args {len(ksc_def.args)}")
         return False
     return is_map(ksc_def.body, ksc_def.args[0].name)
+
+
+def _tsmod2ksmod(
+    module, function_obj, torch_extension_name, example_inputs, generate_lm=True
+):
+    global todo_stack
+    todo_stack = {function_obj}
+    ksc_defs = []
+    while len(todo_stack) > 0:
+        print(f"tsmod2ksmod: Remaining: {todo_stack}")
+        todo = next(iter(todo_stack))
+        if isinstance(todo, str):
+            # String function name, try to find it in the caller's module
+            todo_fn = None
+            for module_fn_name, module_fn_obj in inspect.getmembers(module):
+                if module_fn_name == todo:
+                    print(f"tsmod2ksmod: converting {todo}, remaining: {todo_stack}")
+                    if isinstance(module_fn_obj, KscStub):
+                        todo_fn = module_fn_obj.raw_f
+                    else:
+                        todo_fn = module_fn_obj
+                    break
+            # Check we found it
+            if not todo_fn:
+                raise ValueError(f"Did not find string-named function {todo}")
+        else:
+            todo_fn = todo
+
+        todo_stack.remove(todo)
+
+        ts_fn = torch.jit.script(todo_fn)
+        ts_graph = ts_fn.graph
+        ksc_def = ts2ks_fromgraph(False, todo_fn.__name__, ts_graph, example_inputs)
+        ksc_defs.insert(0, ksc_def)
+
+    elementwise = is_elementwise_operation(ksc_defs[-1])
+    if elementwise:
+        ksc_defs.pop()
+
+    entry_def = ksc_defs[-1]
+    return ksc_defs_to_autograd_function(
+        ksc_defs,
+        entry_def,
+        torch_extension_name,
+        elementwise=elementwise,
+        generate_lm=generate_lm,
+    )
+
+
+def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
+    fn = torch.jit.script(function)
+    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
+    return ksc_defs_to_autograd_function(
+        [ksc_def],
+        ksc_def,
+        torch_extension_name,
+        elementwise=False,
+        generate_lm=generate_lm,
+    )
+
+
+@dataclass
+class KscStub:
+    raw_f: Callable
+    generate_lm: bool
+    f_module: ModuleType
+    compiled: Optional[KscAutogradFunction]
+
+    def __call__(self, *args):
+        """
+        Call with pytorch tensors.
+        This calls the KscAutogradFunction apply method, so is suitable
+        for use in the "forward/backward" pattern for gradient computation.
+        """
+        self.ensure_compiled(args)
+        return self.compiled.apply(*args)
+
+    def _entry(self, *args):
+        """
+        Directly call the Knossos compiled function.
+        Does not wrap torch tensors, or reset the memory allocator.
+        For test use only.
+        """
+        self.ensure_compiled(args)
+        return self.compiled.py_mod.entry(*args)
+
+    def _entry_vjp(self, *args):
+        """
+        Directly call the Knossos vjp function.
+        Does not wrap torch tensors, or reset the memory allocator.
+        For test use only.
+        """
+        assert self.compiled  # TODO: infer call args from vjp args
+        return self.compiled.py_mod.entry_vjp(*args)
+
+    def compile(self, example_inputs, torch_extension_name):
+        self.compiled = _tsmod2ksmod(
+            self.f_module,
+            self.raw_f,
+            torch_extension_name=torch_extension_name,
+            example_inputs=example_inputs,
+            generate_lm=self.generate_lm,
+        )
+
+        return self.compiled
+
+    def ensure_compiled(self, example_inputs):
+        if not self.compiled:
+            print(f"knossos.register: Compiling {self.raw_f.__name__}")
+            torch_extension_name = (
+                "KscStub_" + self.f_module.__name__ + "_" + self.raw_f.__name__
+            )
+            self.compile(example_inputs, torch_extension_name)
+
+
+def optional_arg_decorator(register):
+    # https://stackoverflow.com/a/20966822
+    def wrapped_decorator(*args, **kwargs):
+        # Grab the caller's module here, as wrapped_decorator may be 1 or 2 deeper
+        module = inspect.getmodule(inspect.currentframe().f_back)
+
+        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+            return register(args[0], module)
+
+        # we have optional args
+        def real_decorator(f):
+            return register(f, module, *args, **kwargs)
+
+        return real_decorator
+
+    return wrapped_decorator
+
+
+@optional_arg_decorator
+def register(f: Callable, module: ModuleType, generate_lm=False) -> KscStub:
+    """
+    Main Knossos entry point.
+
+    The @register decorator transforms a TorchScript function into a
+    KscAutogradFunction which implements the function and its
+    derivatives.
+    ```
+    @knossos.register
+    def f(x : torch.Tensor) -> torch.Tensor:
+        return x * sin(x)
+    ```
+    Endows f with the following behaviours:
+    ```
+    y = f(x)       # Fast (C++/CUDA/...) computation of f(x)
+    vjp(f, x, dy)  # Fast computation of dot(dy, [df_i/dx_j])
+    ```
+    The implementation delays compilation until the first call, or
+    until "f.compile()" is explicitly called.
+    """
+
+    return KscStub(f, generate_lm, module, None)
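Taken together, this gives a registered function two compilation paths. A minimal sketch of both, using the `KscStub` API above (the function `scale` and the extension name are illustrative):

import torch

import ksc.torch_frontend as knossos

@knossos.register
def scale(x: torch.Tensor):
    return 2.0 * x

x = torch.randn(4)

# Implicit path: the first call runs ensure_compiled, which derives a torch
# extension name from the module and function names, compiles, and then
# dispatches to KscAutogradFunction.apply.
y = scale(x)

# Explicit path: compile ahead of time under a chosen extension name, as the
# benchmark harness and tests in this commit do.
scale.compile(example_inputs=(x,), torch_extension_name="ksc_demo_scale")
y = scale(x)  # already compiled; this call does not recompile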
@@ -39,4 +39,12 @@ def elementwise_apply(f: Callable[[float], float], x: torch.Tensor):

 @torch.jit.ignore
 def elementwise_apply_hack(f: str, x: torch.Tensor):
-    pass
+    # Convert string function name to callable
+    import inspect
+
+    module = inspect.getmodule(inspect.currentframe().f_back)
+    for fn_name, fn_obj in inspect.getmembers(module):
+        if fn_name == f:
+            return elementwise_apply_pt18(fn_obj, x)
+
+    assert False
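`elementwise_apply_hack` resolves the string `f` by walking one stack frame up and scanning the caller's module, so callers such as `vgelu` can name "gelu" without holding a function reference that TorchScript would try to compile. A stand-alone illustration of the frame-walking lookup (`find_in_caller` is hypothetical):

import inspect

def find_in_caller(name: str):
    # Inspect the module of the code that called us and return the
    # attribute with the requested name, as elementwise_apply_hack does.
    module = inspect.getmodule(inspect.currentframe().f_back)
    for attr_name, attr_obj in inspect.getmembers(module):
        if attr_name == name:
            return attr_obj
    raise NameError(f"{name} not found in caller's module")

def greet():
    return "hello"

assert find_in_caller("greet") is greet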
|
@ -5,8 +5,6 @@ import torch
|
|||
import inspect
|
||||
import importlib
|
||||
|
||||
from ksc.torch_frontend import tsmod2ksmod
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_file,bench_name",
|
||||
|
@@ -22,21 +20,19 @@ def test_bench(module_file, bench_name):
     module_dir, module_name = os.path.split(module_file)
     sys.path.append(module_dir)
     mod = importlib.import_module(module_name)
-    for fn in inspect.getmembers(mod, inspect.isfunction):
-        fn_name, fn_obj = fn
+    for fn_name, fn_obj in inspect.getmembers(mod):
         if fn_name == bench_name + "_bench_configs":
             configs = list(fn_obj())
         elif fn_name == bench_name + "_pytorch":
             pt_fast = fn_obj
+        elif fn_name == bench_name:
+            ks_raw = fn_obj
         else:
             print(f"Ignoring {fn_name}")

     arg = configs[0]

-    torch_extension_name = "ksc_test_dl_activations_" + bench_name
-    ks_compiled = tsmod2ksmod(
-        mod, bench_name, torch_extension_name, example_inputs=(arg,), generate_lm=False
-    )
+    ks_compiled = ks_raw.compile((arg,), "ksc_test_dl_activations_" + bench_name)

     pt_arg = arg.detach()
     pt_arg.requires_grad = True
@@ -6,6 +6,7 @@ import numpy

 from ksc import utils
 from ksc.type import Type
+import ksc.torch_frontend as knossos
 from ksc.torch_frontend import ts2mod
@@ -30,6 +31,7 @@ def grad_bar1(a: int, x: float, b: str):
     return torch.sin(t) + t * torch.cos(t)


+@knossos.register
 def relux(x: float):
     if x < 0.0:
         return 0.1 * x
@@ -53,31 +55,19 @@ def f(x: float):
     return r2


-ks_relux = None
-
-
-def compile_relux():
-    global ks_relux
-    if ks_relux is None:
-        print("Compiling relux")
-        torch_extension_name = "ksc_test_ts2k_relux"
-        ks_relux = ts2mod(relux, (1.0,), torch_extension_name)
-
-
-def test_relux():
-    compile_relux()
-    ks_ans = ks_relux.py_mod.entry(2.0)
-    ans = relux(2.0)
+def test_ts2k_relux():
+    ks_ans = relux._entry(2.0)
+    ans = relux.raw_f(2.0)
     assert pytest.approx(ks_ans, 1e-6) == ans


-def test_relux_grad():
-    compile_relux()
-    ks_ans = ks_relux.py_mod.entry_vjp(1.3, 1.0)
+def test_ts2k_relux_grad():
+    ks_ans = relux._entry_vjp(1.3, 1.0)
     ans = grad_relux(1.3)
     assert pytest.approx(ks_ans, 1e-6) == ans


+@knossos.register(generate_lm=True)
 def bar(a: int, x: float):
     y = torch.tensor([[1.1, -1.2], [2.1, 2.2]])
@@ -107,20 +97,19 @@ def grad_bar(a: int, x: float):

 def test_bar():
     a, x = 1, 12.34
-    torch_extension_name = "ksc_test_ts2k_bar"
-    ks_bar = ts2mod(bar, (a, x), torch_extension_name)

     # Check primal
-    ks_ans = ks_bar.py_mod.entry(a, x)
-    ans = bar(a, x)
+    ks_ans = bar._entry(a, x)
+    ans = bar.raw_f(a, x)
     assert pytest.approx(ks_ans, 1e-5) == ans

     # Check grad
-    ks_ans = ks_bar.py_mod.entry_vjp((a, x), 1.0)
+    ks_ans = bar._entry_vjp((a, x), 1.0)
     ans = grad_bar(a, x)
     assert pytest.approx(ks_ans[1], 1e-5) == ans[1]


+@knossos.register(generate_lm=True)
 def far(x: torch.Tensor, y: torch.Tensor):
     xx = torch.cat([x, y], dim=1)
     xbar = torch.mean(xx)
|
@ -134,25 +123,23 @@ def far(x: torch.Tensor, y: torch.Tensor):
|
|||
def test_far():
|
||||
x = torch.randn(2, 3)
|
||||
y = torch.randn(2, 5)
|
||||
torch_extension_name = "ksc_test_ts2k_far"
|
||||
ks_far = ts2mod(far, (x, y), torch_extension_name)
|
||||
|
||||
ks_ans = ks_far.py_mod.entry(ks_far.adapt(x), ks_far.adapt(y))
|
||||
ans = far(x, y)
|
||||
ks_ans = far._entry(x, y)
|
||||
ans = far.raw_f(x, y)
|
||||
assert pytest.approx(ks_ans, 1e-5) == ans.item()
|
||||
|
||||
|
||||
def test_cat():
|
||||
@knossos.register(generate_lm=True)
|
||||
def f(x: torch.Tensor, y: torch.Tensor):
|
||||
return torch.cat([x, y], dim=1)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
y = torch.randn(2, 5)
|
||||
torch_extension_name = "ksc_test_ts2k_cat"
|
||||
ks_f = ts2mod(f, (x, y), torch_extension_name)
|
||||
ks_ans = ks_f.py_mod.entry(ks_f.adapt(x), ks_f.adapt(y))
|
||||
|
||||
ks_ans = f._entry(x, y)
|
||||
ks_ans_np = numpy.array(ks_ans, copy=True)
|
||||
py_ans = f(x, y)
|
||||
py_ans = f.raw_f(x, y)
|
||||
assert (ks_ans_np == py_ans.numpy()).all() # non-approx
|
||||
|
||||
|
||||
|
|