Compile changed files only (#944)
When we make a new extension, emit the C++ file only if it has changed; Ninja will then not rebuild the extension. This gives a 7x speedup on full tests (including benchmarks) and a 19x speedup on regular tests. Generated files now go to build/torch_extensions.
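The mechanics are visible in the diff below: each extension gets a stable build directory under build/torch_extensions, and the generated C++ is rewritten only when its content actually differs. A minimal sketch of how the two pieces combine (load_extension is a hypothetical wrapper; torch.utils.cpp_extension.load is the real PyTorch API; write_file_if_different in ksc/utils.py below is the helper the commit actually adds):

import os
from torch.utils import cpp_extension

def load_extension(name, cpp_str, ksc_runtime_dir):
    # Stable per-extension directory, so Ninja state survives across runs.
    build_directory = os.path.join("build", "torch_extensions", name)
    os.makedirs(build_directory, exist_ok=True)

    # Rewrite the generated C++ only when it differs. An untouched file keeps
    # its mtime, so Ninja sees an up-to-date target and skips the rebuild.
    source_path = os.path.join(build_directory, "ksc-main.cpp")
    existing = None
    if os.path.isfile(source_path):
        with open(source_path, "r") as f:
            existing = f.read()
    if existing != cpp_str:
        with open(source_path, "w") as f:
            f.write(cpp_str)

    return cpp_extension.load(
        name=name,
        sources=[source_path],
        build_directory=build_directory,
        extra_include_paths=[ksc_runtime_dir],
    )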
This commit is contained in:
Parent: b251b940a7
Commit: 90012065ec
@@ -5,6 +5,8 @@
 {
     "version": "0.2.0",
     "configurations": [
         {
             "name": "(gdb) Launch python for relu3",
             "type": "cppdbg",
@@ -35,6 +37,20 @@
             "program": "${file}",
             "console": "integratedTerminal"
         },
+        {
+            "name": "Python: Test config entry",
+            "type": "python",
+            "request": "test", // Probably not what you think -- see https://code.visualstudio.com/docs/python/testing#_debug-tests
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "Python: -m pytest",
+            "type": "python",
+            "request": "launch",
+            "module": "pytest",
+            "justMyCode": false,
+            "console": "integratedTerminal"
+        },
         {
             "name": "Python: run-bench relu3",
             "type": "python",
@@ -1,15 +1,16 @@
 import torch
 from torch import nn
+import os
+from ksc.torch_utils import elementwise_apply_hack
 from collections import OrderedDict
+from ksc import utils
 import ksc.expr as expr
 from ksc.type import Type
 from ksc.torch_frontend import (
     ksc_string_to_autograd_function,
     cpp_string_to_autograd_function,
 )
-from ksc.utils import get_ksc_paths
-from ksc.torch_utils import elementwise_apply_hack
 
 import torch._vmap_internals
 
 # BEGINDOC
@@ -57,6 +58,7 @@ def vrelu3_embedded_ks_checkpointed_map():
     (map (lam (ti : Float) ([sufrev [relu3 Float]] ti 1.0)) t))
     """,
     expr.StructuredName(("vrelu3", Type.Tensor(1, Type.Float))),
+    "ksc_dl_activations__manual__vrelu3_embedded_ks_checkpointed_map",
     generate_lm=False,
 )
@@ -113,6 +115,7 @@ def vrelu3_embedded_cpp_inlined_map():
     }
     """,
     "vrelu3",
+    "ksc_dl_activations__manual__vrelu3_embedded_cpp_inlined_map",
     generate_lm=False,
 )
@@ -160,6 +163,7 @@ def vrelu3_embedded_cpp_mask():
     }
     """,
     "vrelu3",
+    "ksc_dl_activations__manual__vrelu3_embedded_cpp_mask",
     generate_lm=False,
 )
@@ -212,6 +216,7 @@ def vrelu3_embedded_cpp_mask_bool_to_float():
     }
     """,
     "vrelu3",
+    "ksc_dl_activations__manual__vrelu3_embedded_cpp_mask_bool_to_float",
     generate_lm=False,
 )
@@ -242,6 +247,7 @@ def vrelu3_embedded_ks_checkpointed_map_handwritten_relu3():
     (map (lam (ti : Float) ([sufrev [relu3 Float]] ti 1.0)) t))
     """,
     expr.StructuredName(("vrelu3", Type.Tensor(1, Type.Float))),
+    "ksc_dl_activations__manual__vrelu3_embedded_ks_checkpointed_map_handwritten_relu3",
     generate_lm=False,
 )
@@ -268,6 +274,7 @@ def vrelu3_embedded_ks_checkpointed_map_handwritten_inlined_relu3():
     ddri)))) t dret))
     """,
     expr.StructuredName(("vrelu3", Type.Tensor(1, Type.Float))),
+    "ksc_dl_activations__manual__vrelu3_embedded_ks_checkpointed_map_handwritten_inlined_relu3",
     generate_lm=False,
 )
@@ -297,6 +304,7 @@ def vrelu3_embedded_ks_checkpointed_map_mask():
     ddri)))))) t dret))
     """,
     expr.StructuredName(("vrelu3", Type.Tensor(1, Type.Float))),
+    "ksc_dl_activations__manual__vrelu3_embedded_ks_checkpointed_map_mask",
     generate_lm=False,
 )
@@ -317,6 +325,7 @@ def vrelu3_embedded_INCORRECT_ks_upper_bound_via_map():
     (map (lam (ti : Float) ([sufrev [relu3 Float]] ti 1.0)) t))
     """,
     expr.StructuredName(("vrelu3", Type.Tensor(1, Type.Float))),
+    "ksc_dl_activations__manual__vrelu3_embedded_INCORRECT_ks_upper_bound_via_map",
     generate_lm=False,
 )
@@ -333,6 +342,7 @@ def vrelu3_embedded_INCORRECT_ks_upper_bound():
     dret)
     """,
     expr.StructuredName(("vrelu3", Type.Tensor(1, Type.Float))),
+    "ksc_dl_activations__manual__vrelu3_embedded_INCORRECT_ks_upper_bound",
     generate_lm=False,
 )
@@ -364,16 +374,23 @@ def relu3_pytorch_nice(x: float) -> float:
 
 
 def vrelu3_cuda_init():
-    __ksc_path, ksc_runtime_dir = get_ksc_paths()
+    __ksc_path, ksc_runtime_dir = utils.get_ksc_paths()
     this_dir = os.path.dirname(__file__)
+
+    # TODO: make this use compile.py?
+    # There's no real need, as there's nothing machine-generated
+    # or generated from a string
+    build_directory = utils.get_ksc_build_dir() + "/torch_extensions/vrelu3_cuda"
+    os.makedirs(build_directory, exist_ok=True)
     vrelu3_cuda = torch.utils.cpp_extension.load(
         "vrelu3_module",
         sources=[
             os.path.join(this_dir, "vrelu3_cuda.cpp"),
             os.path.join(this_dir, "vrelu3_cuda_kernel.cu"),
         ],
+        build_directory=build_directory,
         extra_include_paths=[ksc_runtime_dir],
         verbose=True,
     )
 
 class VReLu3Function(torch.autograd.Function):
@@ -5,6 +5,8 @@ import importlib
 import inspect
 import torch
+import os
+import re
 
 from pathlib import Path
 from collections import namedtuple
 from contextlib import contextmanager
@@ -92,7 +94,9 @@ def function_to_manual_cuda_benchmarks(func):
     )
 
 
-def functions_to_benchmark(mod, benchmark_name, example_inputs):
+def functions_to_benchmark(
+    mod, benchmark_name, example_inputs, torch_extension_name_base
+):
     for fn_name, fn_obj in inspect.getmembers(mod, lambda m: inspect.isfunction(m)):
         if fn_name.startswith(benchmark_name):
             if fn_name == benchmark_name + "_bench_configs":
@@ -102,8 +106,15 @@ def functions_to_benchmark(mod, benchmark_name, example_inputs):
         elif fn_name == benchmark_name + "_pytorch_nice":
             yield BenchmarkFunction("PyTorch Nice", fn_obj)
         elif fn_name == benchmark_name:
+            torch_extension_name = (
+                "ksc_src_bench_" + torch_extension_name_base + "_" + benchmark_name
+            )
             ks_mod = tsmod2ksmod(
-                mod, benchmark_name, example_inputs, generate_lm=False
+                mod,
+                benchmark_name,
+                torch_extension_name,
+                example_inputs,
+                generate_lm=False,
             )
             yield BenchmarkFunction("Knossos", ks_mod.apply)
         elif fn_name == benchmark_name + "_cuda_init":
@@ -134,6 +145,8 @@ def pytest_configure(config):
     benchmark_name = config.getoption("benchmarkname")
 
     module_dir, module_name = os.path.split(module_path)
+    torch_extension_name_base = os.path.basename(module_dir) + "__" + module_name
+    torch_extension_name_base = re.sub("[-.]", "_", torch_extension_name_base)
 
     with utils.add_to_path(module_dir):
         mod = importlib.import_module(module_name)
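For illustration, here is what the sanitization above produces for a hypothetical module path (hyphens and dots would otherwise be illegal in an extension name):

import os
import re

module_path = "examples/dl-activations/relu3"  # hypothetical input
module_dir, module_name = os.path.split(module_path)
torch_extension_name_base = os.path.basename(module_dir) + "__" + module_name
print(re.sub("[-.]", "_", torch_extension_name_base))  # dl_activations__relu3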
@@ -144,7 +157,9 @@ def pytest_configure(config):
 
     config.reference_func = getattr(mod, benchmark_name + "_pytorch")
     config.functions_to_benchmark = list(
-        functions_to_benchmark(mod, benchmark_name, example_inputs)
+        functions_to_benchmark(
+            mod, benchmark_name, example_inputs, torch_extension_name_base
+        )
     )
     # We want to group by tensor size, it's not clear how to metaprogram the group mark cleanly.
     # pytest meta programming conflates arguments and decoration. I've not been able to find a way to directly
@@ -121,8 +121,13 @@ def bench(module_file, bench_name):
             print(f"Ignoring {fn_name}")
 
     # TODO: elementwise_apply
+    torch_extension_name = "ksc_run_bench_" + bench_name
     ks_compiled = tsmod2ksmod(
-        mod, bench_name, example_inputs=(configs[0],), generate_lm=False
+        mod,
+        bench_name,
+        torch_extension_name,
+        example_inputs=(configs[0],),
+        generate_lm=False,
     )
 
     for arg in configs:
@@ -6,9 +6,9 @@ import sys
 from tempfile import NamedTemporaryFile
 from tempfile import gettempdir
 
+from ksc import utils
+from torch.utils import cpp_extension
-from torch.utils.cpp_extension import load_inline
-from ksc import utils
 
 preserve_temporary_files = False
@@ -83,6 +83,9 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
 
 
 def build_py_module_from_cpp(cpp_str, profiling=False, use_aten=False):
+    """
+    Build python module, independently of pytorch, non-ninja
+    """
     _ksc_path, ksc_runtime_dir = utils.get_ksc_paths()
     pybind11_path = utils.get_ksc_dir() + "/extern/pybind11"
@@ -216,7 +219,7 @@ def build_py_module_from_ks(ks_str, bindings_to_generate, use_aten=False):
 
 
 def build_module_using_pytorch_from_ks(
-    ks_str, bindings_to_generate, use_aten=False,
+    ks_str, bindings_to_generate, torch_extension_name, use_aten=False
 ):
     """Uses PyTorch C++ extension mechanism to build and load a module
@@ -231,25 +234,31 @@ def build_module_using_pytorch_from_ks(
     Each StructuredName must have a type attached
     """
     cpp_str = generate_cpp_for_py_module_from_ks(
-        ks_str, bindings_to_generate, "TORCH_EXTENSION_NAME", use_aten
+        ks_str, bindings_to_generate, torch_extension_name, use_aten
     )
 
-    return build_module_using_pytorch_from_cpp_backend(cpp_str, use_aten)
+    return build_module_using_pytorch_from_cpp_backend(
+        cpp_str, torch_extension_name, use_aten
+    )
 
 
 def build_module_using_pytorch_from_cpp(
-    cpp_str, bindings_to_generate, use_aten,
+    cpp_str, bindings_to_generate, torch_extension_name, use_aten
 ):
     cpp_pybind = generate_cpp_pybind_module_declaration(
-        bindings_to_generate, "TORCH_EXTENSION_NAME"
+        bindings_to_generate, torch_extension_name
     )
-    return build_module_using_pytorch_from_cpp_backend(cpp_str + cpp_pybind, use_aten)
+    return build_module_using_pytorch_from_cpp_backend(
+        cpp_str + cpp_pybind, torch_extension_name, use_aten
+    )
 
 
-def build_module_using_pytorch_from_cpp_backend(cpp_str, use_aten):
+def build_module_using_pytorch_from_cpp_backend(
+    cpp_str, torch_extension_name, use_aten
+):
     __ksc_path, ksc_runtime_dir = utils.get_ksc_paths()
 
-    cflags = [
+    extra_cflags = [
         "-DKS_INCLUDE_ATEN" if use_aten else "",
     ]
@@ -262,7 +271,7 @@ def build_module_using_pytorch_from_cpp_backend(cpp_str, use_aten):
     if cpp_compiler == None and sys.platform == "win32":
         cflags += ["/std:c++17", "/O2"]
     else:
-        cflags += [
+        extra_cflags += [
            "-std=c++17",
            "-g",
            "-O3",
@@ -272,11 +281,23 @@ def build_module_using_pytorch_from_cpp_backend(cpp_str, use_aten):
     verbose = True
 
     # https://pytorch.org/docs/stable/cpp_extension.html
-    module = load_inline(
-        name="dynamic_ksc_cpp",
-        cpp_sources=[cpp_str],
+    build_directory = (
+        utils.get_ksc_build_dir() + "/torch_extensions/" + torch_extension_name
+    )
+    os.makedirs(build_directory, exist_ok=True)
+
+    cpp_str = "#include <torch/extension.h>\n" + cpp_str
+
+    cpp_source_path = os.path.join(build_directory, "ksc-main.cpp")
+
+    utils.write_file_if_different(cpp_str, cpp_source_path, verbose)
+
+    module = cpp_extension.load(
+        name=torch_extension_name,
+        sources=[cpp_source_path],
         extra_include_paths=[ksc_runtime_dir],
-        extra_cflags=cflags,
+        extra_cflags=extra_cflags,
+        build_directory=build_directory,
         verbose=verbose,
     )
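This hunk carries the speedup. The old load_inline path wrote its generated source afresh on every call, so Ninja always saw changed inputs and recompiled. The new path writes the C++ through write_file_if_different into a stable per-extension directory: identical content leaves the file and its mtime untouched, Ninja finds the target up to date, and cpp_extension.load returns almost immediately. The resulting on-disk layout is roughly as follows (extension name illustrative; the .so suffix assumes Linux):

build/torch_extensions/ksc_demo_relu3/
    ksc-main.cpp       <- rewritten only when the generated C++ differs
    build.ninja        <- written by torch.utils.cpp_extension
    ksc_demo_relu3.so  <- rebuilt by Ninja only when ksc-main.cpp changes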
@@ -484,7 +484,9 @@ def make_KscAutogradFunction(py_mod, generate_lm):
     return newclass()
 
 
-def ksc_defs_to_module(ksc_defs, entry_def, derivatives_to_generate):
+def ksc_defs_to_module(
+    ksc_defs, entry_def, derivatives_to_generate, torch_extension_name
+):
     symtab = dict()
     ksc_dir = utils.get_ksc_dir()
     decls_prelude = list(parse_ks_filename(ksc_dir + "/src/runtime/prelude.ks"))
@@ -519,52 +521,72 @@ def ksc_defs_to_module(ksc_defs, entry_def, derivatives_to_generate):
 
     ks_str = "\n".join(map(pformat, defs_with_derivatives))
 
-    return ksc_string_to_module(ks_str, entry_def.name, derivatives_to_generate)
+    return ksc_string_to_module(
+        ks_str, entry_def.name, derivatives_to_generate, torch_extension_name
+    )
 
 
-def ksc_string_to_module(ks_str, entry_sn, derivatives_to_generate):
+def ksc_string_to_module(
+    ks_str, entry_sn, derivatives_to_generate, torch_extension_name
+):
     bindings_to_generate = [("entry", entry_sn)] + [
         (f"{der}_entry", StructuredName((der, entry_sn)))
         for der in derivatives_to_generate
     ]
 
     return build_module_using_pytorch_from_ks(
-        ks_str, bindings_to_generate, use_aten=True
+        ks_str, bindings_to_generate, torch_extension_name, use_aten=True
     )
 
 
-def cpp_string_to_module(cpp_str, entry_name, derivatives_to_generate):
+def cpp_string_to_module(
+    cpp_str, entry_name, derivatives_to_generate, torch_extension_name
+):
     bindings_to_generate = [("entry", entry_name)] + [
         (f"{der}_entry", f"{der}_{entry_name}") for der in derivatives_to_generate
     ]
 
     return build_module_using_pytorch_from_cpp(
-        cpp_str, bindings_to_generate, use_aten=True,
+        cpp_str, bindings_to_generate, torch_extension_name, use_aten=True,
     )
 
 
-def ksc_defs_to_autograd_function(ksc_defs, entry_def, generate_lm=True):
+def ksc_defs_to_autograd_function(
+    ksc_defs, entry_def, torch_extension_name, generate_lm=True
+):
     derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
-    mod = ksc_defs_to_module(ksc_defs, entry_def, derivatives_to_generate)
+    mod = ksc_defs_to_module(
+        ksc_defs, entry_def, derivatives_to_generate, torch_extension_name
+    )
     return make_KscAutogradFunction(mod, generate_lm)
 
 
-def ksc_string_to_autograd_function(ks_str, entry_sn, generate_lm):
+def ksc_string_to_autograd_function(
+    ks_str, entry_sn, torch_extension_name, generate_lm
+):
     derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
-    mod = ksc_string_to_module(ks_str, entry_sn, derivatives_to_generate)
+    mod = ksc_string_to_module(
+        ks_str, entry_sn, derivatives_to_generate, torch_extension_name
+    )
     return make_KscAutogradFunction(mod, generate_lm)
 
 
-def cpp_string_to_autograd_function(cpp_str, entry_name, generate_lm):
+def cpp_string_to_autograd_function(
+    cpp_str, entry_name, torch_extension_name, generate_lm
+):
     derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
-    mod = cpp_string_to_module(cpp_str, entry_name, derivatives_to_generate)
+    mod = cpp_string_to_module(
+        cpp_str, entry_name, derivatives_to_generate, torch_extension_name
+    )
     return make_KscAutogradFunction(mod, generate_lm)
 
 
 import inspect
 
 
-def tsmod2ksmod(module, function_name, example_inputs, generate_lm=True):
+def tsmod2ksmod(
+    module, function_name, torch_extension_name, example_inputs, generate_lm=True
+):
     global todo_stack
     todo_stack = {function_name}
     ksc_defs = []
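With the plumbing above, the extension name is threaded from the public entry points (tsmod2ksmod, ts2mod) down to the build backend, where it doubles as the build directory name. A sketch of the new call shape (the name here is illustrative; the tests further below show real uses):

torch_extension_name = "ksc_demo_relu3"  # becomes build/torch_extensions/ksc_demo_relu3
ks_fn = ts2mod(relu3, (0.5,), torch_extension_name, generate_lm=False)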
@@ -581,10 +603,14 @@ def tsmod2ksmod(module, function_name, example_inputs, generate_lm=True):
         ksc_defs.insert(0, ksc_def)
 
     entry_def = ksc_defs[-1]
-    return ksc_defs_to_autograd_function(ksc_defs, entry_def, generate_lm)
+    return ksc_defs_to_autograd_function(
+        ksc_defs, entry_def, torch_extension_name, generate_lm
+    )
 
 
-def ts2mod(function, example_inputs, generate_lm=True):
+def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
     fn = torch.jit.script(function)
     ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
-    return ksc_defs_to_autograd_function([ksc_def], ksc_def, generate_lm)
+    return ksc_defs_to_autograd_function(
+        [ksc_def], ksc_def, torch_extension_name, generate_lm
+    )
@@ -111,16 +111,20 @@ def get_ksc_dir():
     return os.path.dirname(d)
 
 
+def get_ksc_build_dir():
+    return get_ksc_dir() + "/build"
+
+
 def get_ksc_paths():
     if "KSC_RUNTIME_DIR" in os.environ:
         ksc_runtime_dir = os.environ["KSC_RUNTIME_DIR"]
     else:
         ksc_runtime_dir = get_ksc_dir() + "/src/runtime"
 
-    if "KSC_PATH" in os.environ:
+    if "KSC_PATH" in os.environ:  # TODO: We should deprecate this
         ksc_path = os.environ["KSC_PATH"]
     else:
-        ksc_path = get_ksc_dir() + "/build/bin/ksc"
+        ksc_path = get_ksc_build_dir() + "/bin/ksc"
 
     return ksc_path, ksc_runtime_dir
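A quick illustration of the lookup order in get_ksc_paths (the override path is illustrative):

import os
from ksc import utils

os.environ["KSC_RUNTIME_DIR"] = "/opt/ksc/runtime"  # illustrative override
ksc_path, ksc_runtime_dir = utils.get_ksc_paths()
# ksc_runtime_dir == "/opt/ksc/runtime" (taken from the environment)
# ksc_path == utils.get_ksc_build_dir() + "/bin/ksc" unless KSC_PATH is set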
@@ -185,3 +189,34 @@ def add_to_path(p):
     finally:
         sys.path = old_path
         sys.modules = old_modules
+
+
+import os.path
+
+
+def write_file_if_different(to_write, filename, verbose):
+    """
+    Write TO_WRITE to FILENAME unless it is identical to the current contents.
+    If VERBOSE, print info to stdout.
+    """
+    if os.path.isfile(filename):
+        # Read from file
+        with open(filename, "r") as f:
+            existing_contents = f.read()
+
+        # Compare to new
+        if existing_contents == to_write:
+            if verbose:
+                print(f"ksc.utils: File not changed: {filename}")
+            return
+
+        if verbose:
+            print(f"ksc.utils: File changed, overwriting {filename}")
+
+    else:
+        if verbose:
+            print(f"ksc.utils: New file {filename}")
+
+    # And overwrite if different
+    with open(filename, "w") as f:
+        f.write(to_write)
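A small usage sketch of the new helper: writing identical content twice leaves the file's mtime alone, which is precisely what lets Ninja conclude there is nothing to rebuild (the /tmp path is illustrative):

import os
import time
from ksc.utils import write_file_if_different

write_file_if_different("int x;\n", "/tmp/demo.cpp", verbose=True)  # "New file"
before = os.path.getmtime("/tmp/demo.cpp")
time.sleep(1)
write_file_if_different("int x;\n", "/tmp/demo.cpp", verbose=True)  # "File not changed"
assert os.path.getmtime("/tmp/demo.cpp") == before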
@@ -33,7 +33,10 @@ def test_bench(module_file, bench_name):
 
     arg = configs[0]
 
-    ks_compiled = tsmod2ksmod(mod, bench_name, example_inputs=(arg,), generate_lm=False)
+    torch_extension_name = "ksc_test_dl_activations_" + bench_name
+    ks_compiled = tsmod2ksmod(
+        mod, bench_name, torch_extension_name, example_inputs=(arg,), generate_lm=False
+    )
 
     ks_compiled.py_mod.logging(True)
 
@@ -60,7 +60,8 @@ def compile_relux():
     global ks_relux
     if ks_relux is None:
         print("Compiling relux")
-        ks_relux = ts2mod(relux, (1.0,))
+        torch_extension_name = "ksc_test_ts2k_relux"
+        ks_relux = ts2mod(relux, (1.0,), torch_extension_name)
 
 
 def test_ts2k_relux():
@@ -106,7 +107,8 @@ def grad_bar(a: int, x: float):
 
 def test_bar():
     a, x = 1, 12.34
-    ks_bar = ts2mod(bar, (a, x))
+    torch_extension_name = "ksc_test_ts2k_bar"
+    ks_bar = ts2mod(bar, (a, x), torch_extension_name)
 
     # Check primal
     ks_ans = ks_bar.py_mod.entry(a, x)
@@ -132,7 +134,8 @@ def far(x: torch.Tensor, y: torch.Tensor):
 def test_far():
     x = torch.randn(2, 3)
     y = torch.randn(2, 5)
-    ks_far = ts2mod(far, (x, y))
+    torch_extension_name = "ksc_test_ts2k_far"
+    ks_far = ts2mod(far, (x, y), torch_extension_name)
 
     ks_ans = ks_far.py_mod.entry(ks_far.adapt(x), ks_far.adapt(y))
     ans = far(x, y)
@@ -145,7 +148,8 @@ def test_cat():
 
     x = torch.randn(2, 3)
     y = torch.randn(2, 5)
-    ks_f = ts2mod(f, (x, y))
+    torch_extension_name = "ksc_test_ts2k_cat"
+    ks_f = ts2mod(f, (x, y), torch_extension_name)
     ks_ans = ks_f.py_mod.entry(ks_f.adapt(x), ks_f.adapt(y))
     ks_ans_np = numpy.array(ks_ans, copy=True)
     py_ans = f(x, y)
@@ -174,7 +178,8 @@ def grad_relu3(x: float) -> float:
 def test_relu3(generate_lm):
     x = 0.5
 
-    ks_relu3 = ts2mod(relu3, (x,), generate_lm)
+    torch_extension_name = "ksc_test_ts2k_relu3" + ("_lm" if generate_lm else "")
+    ks_relu3 = ts2mod(relu3, (x,), torch_extension_name, generate_lm)
 
     for x in [-0.1, 0.31221, 2.27160]:
         # Test function: ks == py
@@ -192,7 +192,8 @@ if __name__ == "__xmain__":
     # %44 : (Tensor, Tensor) = prim::TupleConstruct(%new_h.1, %new_cell.1)
     # return (%44)
 
-    ks_fun = ts2mod(lltm_forward_py, example_inputs=example_inputs)
+    torch_extension_name = "ksc_awf_timing"
+    ks_fun = ts2mod(lltm_forward_py, example_inputs, torch_extension_name)
 
     def torch_from_ks(ks_object):
         if isinstance(ks_object, tuple):