* Implement jit demo version and add tests (#378)

* Add signatures to detect reusable kernel + nnf rt mem reservation + jit optional kwargs (#379)

* Add signatures to detect reusable kernel

* Add nnf workspace memory reservation

* Allow jit to optionally take arguments (keyword-only)

Reference
https://realpython.com/primer-on-python-decorators/#both-please-but-never-mind-the-bread

* Fix how to reserve

* Fix unit bug + don't check twice

* Remove unused imports

* Add comment to pytest

* Fix missing kwargs

* Support JIT for class and class method + nnf config + 3-lvl-signature (#398)

* Add docstring

* Support JIT for class method

* Support decorator for class

* Fix signature bug + add failure case

* test_jit.py clean code

* Change relative path to __module__

* Add nnfusion config

* Impl. 3 level signature

* Clean code + add test_config + del test_graph

test_keep_signature_but_change_compute_graph can pass now, but testing
it no longer makes sense, since different kernels with the same object
signature will no longer replace each other.

* Clean code

* Add docstrings

* Clean code

* Fix tune to follow docs

* More pythonic

http://www.kr41.net/2016/03-23-dont_inherit_python_builtin_dict_type.html

* Update docstring

* Add test case

* Update docstring

* Fix bug: decorator for method with multi instances

* Clean code --amend

* Add NNFusion-JIT docs

Co-authored-by: Wenxiang Hu <8460860+wenxcs@users.noreply.github.com>
This commit is contained in:
siahuat0727 2022-04-02 16:40:15 +08:00 committed by GitHub
Parent 1ee7762413
Commit fc58f6bff9
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 840 additions and 10 deletions

117
docs/NNFusion-JIT.md Normal file
View file

@@ -0,0 +1,117 @@
# How to use NNFusion JIT
NNFusion JIT is a JIT compiler for PyTorch built on the NNFusion CLI. It allows the user to JIT-compile a function or a `torch.nn.Module` object using a decorator or an explicit function wrapper. All NNFusion optimizations, such as kernel tuning, can be applied through this interface.
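As a minimal quick-start sketch (the function name and shapes here are illustrative; the full set of supported forms is demonstrated under "Use Cases Demo" below):
```python
import torch
import nnfusion

@nnfusion.jit
def add(t1, t2):
    return t1 + t2

t = torch.randn(8, device="cuda")
out = add(t, t)  # the first call triggers tracing and NNFusion compilation
```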
## nnfusion.jit
`nnfusion.jit(obj=None, *, tune=None, tuning_steps=None, config=None)`
Lazily traces an object and optimizes it with NNFusion compilation; tracing and compilation are deferred until the object is first called. It can be used as a decorator or an explicit function wrapper. The inputs and outputs of the object must be `torch.Tensor` or a sequence of `torch.Tensor`.
### Parameters
+ obj (function, `torch.nn.Module` instance/method/class):
  + Target object to be traced. When `obj` is an instance or a class, it is equivalent to tracing its `forward` function.
+ tune (Optional[bool]):
  + Whether to tune the kernel. By default it follows `config`. If set, it overrides `config`.
+ tuning_steps (Optional[int]):
  + Number of kernel tuning steps. By default it follows `config`. If set, it overrides `config` and `tune`.
+ config (Optional[dict, nnfusion.Config]):
  + NNFusion compilation config. Defaults to `nnfusion.Config()`. Pass a `dict` to override individual defaults, or directly pass an `nnfusion.Config` instance.
  + For example, `@nnfusion.jit(tune=True, config={'kernel_tuning_steps': 42})`
  + For the full list of flags, run the `nnfusion` command in the terminal.
### Use Cases Demo
`nnfusion.jit` can be used as a function wrapper for standalone functions and `torch.nn.Module` instances/methods/classes.
It can also be used as a decorator for standalone functions and `torch.nn.Module` methods/classes.
```python
import nnfusion
from torch import nn

# Case 1: decorator for a standalone function
@nnfusion.jit
def foo(t1, t2):
    return t1 + t2, t1 - t2

# Case 2: decorator for a class method
class Net(nn.Linear):
    @nnfusion.jit
    def foo(self, x):
        return super().forward(x)

# Case 3: decorator for a class
@nnfusion.jit
class Net(nn.Linear):
    def this_will_not_be_traced(self, x):
        return super().forward(x)

    def forward(self, x):
        return super().forward(x)

# Case 4: function wrapper for a standalone function
def foo(t1, t2):
    return t1 + t2, t1 - t2
jitted_foo = nnfusion.jit(foo)

# Case 5: function wrapper for a torch.nn.Module method
class Net(nn.Linear):
    def foo(self, x):
        return super().forward(x)
model = Net(8, 8).eval()
model.foo = nnfusion.jit(model.foo)

# Case 6: function wrapper for a torch.nn.Module class
jitted_linear = nnfusion.jit(nn.Linear)
model = jitted_linear(8, 8).eval()

# Case 7: function wrapper for a torch.nn.Module instance
class Net(nn.Linear):
    def forward(self, x):
        return super().forward(x)
model = Net(8, 8).eval()
jitted_model = nnfusion.jit(model)
```
Optional keyword arguments can also be passed:
```python
@nnfusion.jit(tune=True)
def foo(t1, t2):
    return t1 + t2

def bar(t):
    return t + t, t * t
jitted_bar = nnfusion.jit(bar, tuning_steps=2000)
```
### Compiled kernel caching strategy
The compiled kernels are saved in `nnf-kernels/`. If a matching kernel is found before compilation, it is reused directly. Here, "matching" means having the same object signature (`__module__` and `__qualname__`), the same computational graph (ONNX model binary), and the same NNFusion compilation config (`config`).
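For illustration, the three-level on-disk layout produced by `NNFusionRT` in this commit looks roughly like this (hash values shown schematically):
```
nnf-kernels/
└── <object signature>/              # e.g. <__module__>-<__qualname__>
    └── <sha256 of exported ONNX>/
        ├── model.onnx
        └── <sha256 of compile flags>/
            └── nnfusion_rt/cuda_codegen/...
```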
## nnfusion.Config
`nnfusion.Config(*args, antares_mode=True, blockfusion_level=0, extern_result_memory=True, function_codegen=True, ir_based_fusion=False, kernel_fusion_level=0, kernel_tuning_steps=1000, **kwargs)`
NNFusion compilation config. Any other NNFusion compiler flag can also be passed in (run `nnfusion` in the terminal for the full list); unknown flags are ignored. It behaves like a `dict` with some default key-value pairs.
### Use Cases Demo
```python
config = nnfusion.Config(function_codegen=False,
                         new_flag=42)
config = nnfusion.Config({'function_codegen': False,
                          'new_flag': 42})
config = nnfusion.Config()
config['function_codegen'] = False
config['new_flag'] = 42
```
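Each key-value pair is translated into a `-f<flag>=<value>` NNFusion CLI flag, with booleans converted to `0`/`1` (see `Config.to_flag()` in the diff below). A sketch of the resulting flag string:
```python
config = nnfusion.Config({'kernel_tuning_steps': 42, 'function_codegen': False})
print(config.to_flag())
# -fantares_mode=1 -fblockfusion_level=0 -fextern_result_memory=1
# -ffunction_codegen=0 -fir_based_fusion=0 -fkernel_fusion_level=0
# -fkernel_tuning_steps=42
```
(The output is a single space-joined line, wrapped here for readability.)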

View file

@@ -11,6 +11,7 @@ Welcome to the NNFusion Documents!
- [Compile-a-Tensorflow-model-with-NNFusion](Compile-a-Tensorflow-model-with-NNFusion.md)
- [Compile-a-model-with-kernel-tuning-enabled](Compile-a-model-with-kernel-tuning-enabled.md)
- [How to use NNFusion Python interface for inference/training](../src/python/example/README.md)
- [How to use NNFusion JIT](NNFusion-JIT.md)
4. [Guide-for-Contributors](Guide-for-Contributors.md)
- [Contribution-Guide](Contribution-Guide.md)
- [Coding-Guide](Coding-Guide.md)

View file

@@ -2,4 +2,7 @@
# Licensed under the MIT License.
__version__ = "0.3.0"
__author__ = "Microsoft"

from .jit import jit
from .config import Config

View file

@@ -0,0 +1,53 @@
from collections.abc import MutableMapping


class Config(MutableMapping):
    """
    NNFusion compilation config. Can pass in any other NNFusion compiler flags
    (execute the command `nnfusion` in the terminal for more details) and
    unknown flags will be ignored.
    Use it as a `dict` with some default key-value pairs.
    """

    def __init__(self,
                 *args,
                 antares_mode=True,
                 blockfusion_level=0,
                 extern_result_memory=True,
                 function_codegen=True,
                 ir_based_fusion=False,
                 kernel_fusion_level=0,
                 kernel_tuning_steps=1000,
                 **kwargs):
        locals_ = locals()
        self._storage = {
            flag: locals_[flag]
            for flag in self.__init__.__kwdefaults__
        }
        self._storage.update(dict(*args, **kwargs))

    @staticmethod
    def _parse_flag_value(flag, value):
        value = int(value) if isinstance(value, bool) else value
        return f'-f{flag}={value}'

    def to_flag(self):
        return ' '.join([
            self._parse_flag_value(flag, value)
            for flag, value in sorted(self._storage.items())
        ])

    def __iter__(self):
        return iter(self._storage)

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, key):
        return self._storage[key]

    def __setitem__(self, key, value):
        self._storage[key] = value

    def __delitem__(self, key):
        del self._storage[key]

View file

@@ -1,13 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import ctypes
import json
import os
import platform

import torch

from . import dtypes
from .data_format import cast_pytorch_tensor
from .description import IODescription
from .utils import cd
def find_nnf_rt(nnf_rt_dir):
@@ -80,11 +82,13 @@ class Executor(object):
        5: ("", ""),  # UNKNOWN
    }

    def __init__(self, nnf_rt_dir, device=None):
        """
        Parameters:
        nnf_rt_dir: A full string path to the nnfusion runtime directory,
            usually like "codegen_root/nnfusion_rt/cuda_codegen".
        device: A `torch.device` used for workspace memory reservation
            (if needed) by the nnfusion runtime.
        """
        nnf_rt_dir = os.path.abspath(nnf_rt_dir)
        self.libnnf_path = find_nnf_rt(nnf_rt_dir)
@@ -113,9 +117,14 @@ class Executor(object):
        init_func_name, free_func_name = self.device_type_map[device_type]
        self.nnf_rt_init = getattr(self.libnnf, init_func_name, None)
        self.nnf_rt_free = getattr(self.libnnf, free_func_name, None)

        if self.nnf_rt_init:
            with cd(nnf_rt_dir):
                workspace_ptr = self._maybe_reserve_mem(device)
                if workspace_ptr is not None:
                    self.nnf_rt_init(workspace_ptr)
                else:
                    self.nnf_rt_init()
            self.init_flag = True

        # parse input/output
@@ -218,4 +227,17 @@ class Executor(object):
    def feed_pointers(self, signature, params):
        self.kernel_entry.argtypes = signature
        self.kernel_entry(*params)

    def _maybe_reserve_mem(self, device):
        get_workspace_size = getattr(self.libnnf, 'get_workspace_size', None)
        if get_workspace_size is None:
            return None
        n_byte = get_workspace_size()
        if not n_byte:
            return None
        self._reserved_mem = torch.empty(n_byte,
                                         dtype=torch.int8, device=device)
        return cast_pytorch_tensor(self._reserved_mem).pointer
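
A hedged sketch of the new `device` parameter in use, with the path taken from the docstring above and the device choice mirroring the `NNFusionRT.compile` call later in this commit:
```python
import torch
from nnfusion.executor import Executor

# Reserve the runtime's workspace (if it requests one) on the GPU.
executor = Executor("codegen_root/nnfusion_rt/cuda_codegen",
                    device=torch.device("cuda"))
```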

219
src/python/nnfusion/jit.py Normal file
View file

@@ -0,0 +1,219 @@
import copy
import functools
from inspect import isfunction, ismethod, isclass

import torch

from .jit_utils import TorchModule, get_signature
from .runtime import NNFusionRT
from .config import Config


def is_method_of_instance(obj, cls):
    return ismethod(obj) and isinstance(obj.__self__, cls)


def is_subclass_of_cls(obj, cls):
    return isclass(obj) and issubclass(obj, cls)


def get_nrt_forward(obj, signature, config, outputs, *inputs,
                    is_method=False):
    """
    Return a wrapped forward function that uses nnf as the runtime.
    """
    if not isinstance(obj, torch.nn.Module):
        raise AssertionError(
            "Internal bug, please report to "
            "https://github.com/microsoft/nnfusion"
        )

    output_is_tensor = isinstance(outputs, torch.Tensor)
    if output_is_tensor:
        outputs = [outputs]

    nnf = NNFusionRT(obj, config, signature)
    nnf.compile(inputs, outputs)

    # TODO free outputs and only save desc?
    def forward(*inputs):
        results = [
            torch.empty_like(output)
            for output in outputs
        ]
        if is_method:
            obj, *inputs = inputs
            nnf.run_method(obj, inputs, results)
        else:
            inputs = list(inputs)
            nnf.run(inputs, results)
        if output_is_tensor:
            return results[0]
        return results
    return forward


def nrt_forward(obj, *inputs, config=None, signature=None, is_method=False):
    if signature is None:
        signature = get_signature(obj)

    if hasattr(obj, '_orig_forward'):
        # A shallow copy is needed to avoid recursion:
        # call instance forward -> call nnf_forward -> call instance forward
        obj_ = copy.copy(obj)
        obj_.forward = obj._orig_forward
        obj = obj_

    outputs = obj(*inputs)

    def jit_class_method_using_decorator():
        """
        Check if obj is a class method decorated with @nnfusion.jit.
        Decorating a class method with the @ symbol and applying jit to it
        as a function are different cases.
        """
        return isinstance(inputs[0], torch.nn.Module)

    if jit_class_method_using_decorator():
        self, *inputs = inputs

        # A shallow copy is needed to avoid recursion when using jit as a
        # decorator:
        # export onnx -> call forward to trace -> call nnf jit func
        #     -> export onnx
        self_ = copy.copy(self)

        def forward(*args):
            if forward.first_call:
                forward.first_call = False
                return obj(self, *args)
            # Handle the case where the jit target function calls `forward`
            return self.forward(*args)
        forward.first_call = True
        self_.forward = forward

        return get_nrt_forward(self_, signature, config, outputs,
                               *inputs, is_method=True)

    if isfunction(obj) or is_method_of_instance(obj, torch.nn.Module):
        return get_nrt_forward(TorchModule(obj), signature, config, outputs,
                               *inputs)
    return get_nrt_forward(obj, signature, config, outputs, *inputs)


def parse_config(tune, tuning_steps, config):
    if config is None:
        config = Config()
    elif type(config) is dict:
        config = Config(config)
    if not type(config) is Config:
        raise TypeError(
            "Expected optional 'config' argument of type dict or "
            f"nnfusion.Config but found {config}"
        )
    if tuning_steps is not None:
        if not isinstance(tuning_steps, int):
            raise TypeError(
                "Expected optional 'tuning_steps' argument of type int "
                f"but found {tuning_steps}"
            )
        if tune is False:
            raise ValueError(
                f"Conflict detected: tune={tune} and "
                f"tuning_steps={tuning_steps}"
            )
        tune = True
        config['kernel_tuning_steps'] = tuning_steps
    if tune is not None:
        if not isinstance(tune, bool):
            raise TypeError(
                "Expected optional 'tune' argument of type bool "
                f"but found {tune}"
            )
        config['antares_mode'] = tune
    return config


def check_obj_type(obj):
    if not (
        isfunction(obj)
        or isinstance(obj, torch.nn.Module)
        or is_subclass_of_cls(obj, torch.nn.Module)
        or is_method_of_instance(obj, torch.nn.Module)
    ):
        raise TypeError(
            "Expected function or torch.nn.Module instance/method/class "
            f"but found {obj}"
        )


def jit_class(obj, config):
    """
    Return a jitted class using dynamic inheritance to override the forward
    function while keeping its signature.
    """
    class JITModule(obj):
        @jit(config=config,
             _signature='.'.join([get_signature(obj), 'forward']))
        def forward(self, *args, **kwargs):
            return super().forward(*args, **kwargs)
    return JITModule


def jit(obj=None, *, tune=None, tuning_steps=None, config=None,
        _signature=None):
    """
    Parameters:
    obj (function, `torch.nn.Module` instance/method/class):
        The target object to be traced. When `obj` is an instance or a
        class, it is equivalent to tracing its `forward` function.
    tune (Optional[bool]):
        Whether to tune the kernel. By default it follows `config`.
        If set, it overrides `config`.
    tuning_steps (Optional[int]):
        Number of kernel tuning steps. By default it follows `config`.
        If set, it overrides `config` and `tune`.
    config (Optional[dict, nnfusion.Config]):
        NNFusion compilation config.
        By default it will be set to `nnfusion.Config()`.
        Pass a `dict` to override individual defaults or directly pass an
        instance of `nnfusion.Config`.
        For example, `@nnfusion.jit(tune=True,
                                    config={'kernel_tuning_steps': 42})`
        For the full list of flags, run the `nnfusion` command in the
        terminal.
    """
    config = parse_config(tune, tuning_steps, config)

    def _jit(_obj):
        check_obj_type(_obj)

        if is_subclass_of_cls(_obj, torch.nn.Module):
            return jit_class(_obj, config)

        @functools.wraps(_obj)
        def wrapper(*args):  # TODO support kwargs?
            if wrapper.forward is None:
                wrapper.forward = nrt_forward(_obj, *args,
                                              config=config,
                                              signature=_signature)
            return wrapper.forward(*args)
        wrapper.forward = None

        if isinstance(_obj, torch.nn.Module):
            _obj._orig_forward = _obj.forward
            _obj.forward = wrapper
            return _obj
        return wrapper

    if obj is None:
        return _jit
    return _jit(obj)

View file

@@ -0,0 +1,34 @@
import inspect
import re

import torch


class TorchModule(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def get_signature(obj):
    """
    Signature of an object, used to detect reusable kernels.
    """
    if isinstance(obj, torch.nn.Module):
        return get_signature(obj.__class__)
    if not (
        inspect.isfunction(obj)
        or inspect.ismethod(obj)
        or inspect.isclass(obj)
    ):
        raise Exception(f"Unsupported type {obj} for get_signature")
    signature = "-".join([obj.__module__, obj.__qualname__])
    # Remove special chars to avoid trouble when dealing with paths
    return re.sub("[<>]", "", signature)
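
For example, assuming standard PyTorch module names, the resulting signature string looks like this:
```python
import torch

# torch.nn.Linear has __module__ == 'torch.nn.modules.linear'
# and __qualname__ == 'Linear':
get_signature(torch.nn.Linear)  # -> 'torch.nn.modules.linear-Linear'
```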

View file

@@ -0,0 +1,132 @@
import os
import tempfile
from pathlib import Path

import torch
import torch.onnx

from .data_format import cast_pytorch_tensor
from .executor import Executor
from .session import build, codegen, modify_nnfusion_rt
from .utils import get_sha256_of_file, get_sha256_of_str


class NNFusionRT:
    def __init__(self, model, config, signature, cache_dir="nnf-kernels"):
        """
        Parameters:
        model: the `torch.nn.Module` to be compiled.
        config: NNFusion compilation config.
        signature (str): signature of the model so that a compiled kernel
            can be reused (if any).
        cache_dir: path to save compiled kernels.
        """
        self.model = model
        self.weight_dict = {
            name: cast_pytorch_tensor(tensor)
            for name, tensor in model.state_dict().items()
        }

        self.root_dir = Path(cache_dir) / signature
        self.root_dir.mkdir(parents=True, exist_ok=True)

        self.compile_flag = self._get_compile_flag(config)
        self.executor = None

    def compile(self, inputs, outputs):
        """
        Perform nnfusion codegen and compilation for the target input sizes.
        Skip if a kernel with the same signature is found.

        Parameters:
        inputs: a list of model inputs.
        outputs: a list of model outputs.
        """
        def export_onnx(fname):
            input_names = ["input" + str(i) for i in range(len(inputs))]
            output_names = ["output" + str(i) for i in range(len(outputs))]
            torch.onnx.export(self.model, inputs, fname,
                              input_names=input_names,
                              output_names=output_names,
                              export_params=False,
                              )  # , opset_version=11)

        def check_if_need_build():
            """
            Note that this function assumes no hash collision.
            """
            need_build = False

            # Compare the onnx file to check if it was modified
            with tempfile.TemporaryDirectory(dir=self.root_dir) as tmp:
                temp_onnx_path = Path(tmp) / "temp.onnx"
                export_onnx(temp_onnx_path)

                onnx_hash = get_sha256_of_file(temp_onnx_path)
                flag_hash = get_sha256_of_str(self.compile_flag)
                onnx_dir = self.root_dir / onnx_hash
                flag_dir = onnx_dir / flag_hash
                flag_dir.mkdir(parents=True, exist_ok=True)

                onnx_path = onnx_dir / "model.onnx"
                if not onnx_path.is_file():
                    os.link(temp_onnx_path, onnx_path)
                    need_build = True

            nnf_dir = flag_dir / "nnfusion_rt" / "cuda_codegen"
            if not nnf_dir.joinpath('libnnfusion_naive_rt.so').is_file():
                need_build = True

            return need_build, onnx_path, flag_dir, nnf_dir

        def do_compile(onnx_path, work_dir, nnf_dir):
            codegen(onnx_path, self.compile_flag, work_dir)
            modify_nnfusion_rt(nnf_dir)
            build(nnf_dir)

        need_build, onnx_path, work_dir, nnf_dir = check_if_need_build()
        if need_build:
            do_compile(onnx_path, work_dir, nnf_dir)

        self.executor = Executor(nnf_dir, device=inputs[0].device)

    def run(self, inputs, outputs, weights=None):
        """
        Perform the computation. The result will be saved in `outputs`.

        Parameters:
        inputs: the input tensor(s). Can be a list or tuple.
        outputs: the output tensor(s). Can be a list or tuple.
        """
        if weights is None:
            weights = self.weight_dict
        if not isinstance(inputs, (tuple, list)):
            inputs = [inputs]
        if not isinstance(outputs, (tuple, list)):
            outputs = [outputs]

        in_dict = dict(weights, **{
            f'input{i}': cast_pytorch_tensor(tensor)
            for i, tensor in enumerate(inputs)
        })
        out_dict = {
            f'output{i}': cast_pytorch_tensor(tensor)
            for i, tensor in enumerate(outputs)
        }
        self.executor(in_dict, out_dict, strict=False)

    def run_method(self, obj, inputs, outputs):
        weights = {
            name: cast_pytorch_tensor(tensor)
            for name, tensor in obj.state_dict().items()
        }
        return self.run(inputs, outputs, weights)

    def _get_compile_flag(self, config):
        return " ".join([
            "-f onnx",
            config.to_flag()
        ])
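
A minimal sketch of how `jit.py` above drives this class; the model, shapes, and the signature label are illustrative assumptions:
```python
import torch
from nnfusion.config import Config
from nnfusion.runtime import NNFusionRT

model = torch.nn.Linear(8, 8).cuda().eval()
x = torch.randn(1, 8, device="cuda")
out = torch.empty(1, 8, device="cuda")

nnf = NNFusionRT(model, Config(), signature="example-Linear")
nnf.compile([x], [out])  # codegen + build, or reuse a cached kernel
nnf.run([x], [out])      # writes the result into `out`
```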

View file

@@ -5,6 +5,7 @@ from contextlib import contextmanager
import os
import logging
import subprocess
import hashlib
logger = logging.getLogger(__name__)
@@ -31,4 +32,16 @@ def execute(command, redirect_stderr=True, shell=True, **kwargs):
    except subprocess.CalledProcessError as e:
        logger.error(e.output)
        raise e
    return output


def get_sha256_of_file(path, max_len=None):
    hash_sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_sha256.update(chunk)
    return hash_sha256.hexdigest()[:max_len]


def get_sha256_of_str(string, max_len=None):
    return hashlib.sha256(string.encode("utf-8")).hexdigest()[:max_len]
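
For instance, `max_len` simply truncates the hex digest (an illustrative call, not from the diff):
```python
get_sha256_of_str("-f onnx -fantares_mode=1")[:8] == \
    get_sha256_of_str("-f onnx -fantares_mode=1", max_len=8)  # True
```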

18
test/python/README.md Normal file
View file

@@ -0,0 +1,18 @@
# NNFusion Python Tests
## Installation
```bash
cd path/to/nnfusion/test/python
python3 -m pip install -r requirements.txt
```
## Running tests
```bash
# A link is needed temporarily since installing locally (pip install -e .) is not yet supported.
ln -s ../../src/python/nnfusion nnfusion
pytest -v --forked
```

View file

@@ -0,0 +1,2 @@
pytest>=7.0.0
pytest-xdist>=1.4.0

View file

@@ -0,0 +1,26 @@
from nnfusion import Config


def test_config():
    # default
    config = Config()
    assert config['kernel_tuning_steps'] == 1000
    assert config['function_codegen'] is True

    # init with kwargs
    config = Config(kernel_tuning_steps=42,
                    function_codegen=False,
                    foo=True)
    assert config['kernel_tuning_steps'] == 42
    assert config['function_codegen'] is False
    assert config['foo'] is True

    # init with dict
    config = Config({
        'kernel_tuning_steps': 42,
        'function_codegen': False,
        'foo': True,
    })
    assert config['kernel_tuning_steps'] == 42
    assert config['function_codegen'] is False
    assert config['foo'] is True

190
test/python/test_jit.py Normal file
View file

@@ -0,0 +1,190 @@
import pytest
import torch
from torch.nn import functional as F

import nnfusion


def assert_allclose(output1, output2, rtol=1e-5, atol=1e-5):
    if not isinstance(output1, (tuple, list)):
        assert not isinstance(output2, (tuple, list))
        assert output1.size() == output2.size()
        assert torch.allclose(output1, output2, rtol, atol), (
            output1, output2, output1.sub(output2).max())
        return
    assert isinstance(output2, (tuple, list))
    assert len(output1) == len(output2)
    for out1, out2 in zip(output1, output2):
        assert_allclose(out1, out2, rtol, atol)


def compare_torch_and_nrt(obj, *inputs, step=1, run=None):
    assert step >= 1
    if step == 1:
        result_torch = obj(*inputs)
        result_nrt = nnfusion.jit(obj)(*inputs)
    else:
        def repeat(obj, *inputs):
            for _ in range(step):
                outputs, inputs = run(obj, *inputs)
            return outputs

        assert run is not None
        result_torch = repeat(obj, *inputs)
        result_nrt = repeat(nnfusion.jit(obj), *inputs)
    assert_allclose(result_torch, result_nrt)
def test_single_input_multi_output():
    def func(t):
        return t + t, t * t

    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t)


def test_multi_input_single_output():
    def func(t1, t2):
        return t1 + t2

    t = [torch.randn(8, device="cuda") for _ in range(2)]
    compare_torch_and_nrt(func, *t)


def test_multi_identical_input_single_output():
    def func(t1, t2):
        return t1 + t2

    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t, t)


@pytest.mark.xfail(reason=(
    "Probably identical tensors are fused during optimization. "
    "May need a copy at the backend"))
def test_single_input_multi_identical_output():
    def func(t):
        return t, t

    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t)


@pytest.mark.xfail(reason="Compilation Error")
def test_single_input_multi_identical_output_advanced():
    def func(t):
        t2 = t + t
        return t2, t2

    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t)


def test_jit_instance_using_function():
    model = torch.nn.Linear(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    compare_torch_and_nrt(model, t)


def test_jit_class_method_using_decorator():
    class Foo(torch.nn.Linear):
        @nnfusion.jit
        def foo(self, t):
            return t + t

        @nnfusion.jit
        def bar(self, t):
            return self.forward(t) + 1

    model = Foo(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    assert_allclose(t + t, model.foo(t))
    assert_allclose(F.linear(t, model.weight, model.bias) + 1, model.bar(t))

    class Bar(torch.nn.Linear):
        @nnfusion.jit
        def forward(self, t):
            return super().forward(t)

    model = Bar(8, 8).cuda().eval()
    assert_allclose(F.linear(t, model.weight, model.bias), model(t))
def test_jit_class_method_using_function():
    class Foo(torch.nn.Linear):
        def foo(self, t):
            return self.forward(t) + 1

    t = torch.randn(1, 8, device="cuda")
    model = Foo(8, 8).cuda().eval()
    # Apply jit as a function wrapper (cf. Case 5 in docs/NNFusion-JIT.md)
    model.foo = nnfusion.jit(model.foo)
    assert_allclose(F.linear(t, model.weight, model.bias) + 1, model.foo(t))
def test_jit_class_using_decorator():
    def func(t):
        return t + t

    @nnfusion.jit
    class Foo(torch.nn.Linear):
        @nnfusion.jit
        def foo(self, t):
            return func(t)

    model = Foo(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    assert_allclose(F.linear(t, model.weight, model.bias), model(t))
    assert_allclose(func(t), model.foo(t))


def test_jit_class_using_decorator_multi_instance():
    @nnfusion.jit
    class Foo(torch.nn.Linear):
        pass

    model1 = Foo(2, 2).cuda().eval()
    model2 = Foo(2, 2).cuda().eval()
    t = torch.randn(1, 2, device="cuda")
    assert_allclose(F.linear(t, model1.weight, model1.bias), model1(t))
    assert_allclose(F.linear(t, model2.weight, model2.bias), model2(t))


def test_jit_class_using_function():
    LinearJIT = nnfusion.jit(torch.nn.Linear)

    model = LinearJIT(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    assert_allclose(F.linear(t, model.weight, model.bias), model(t))


def test_jit_with_kwargs():
    @nnfusion.jit(tune=True, config=nnfusion.Config(kernel_tuning_steps=5))
    def func(t):
        return t + t

    t = torch.randn(8, device="cuda")
    assert_allclose(t + t, func(t))


@pytest.mark.xfail(reason=(
    "nnfusion codegen and compilation exit successfully with 0 "
    "but para_info.json is null"))
def test_nested_jit():
    @nnfusion.jit
    def func1(t):
        return t + t

    @nnfusion.jit
    def func2(t):
        return func1(t) + 1

    t = torch.randn(1, 8, device="cuda")
    assert_allclose(t + t, func1(t))
    assert_allclose(t + t + 1, func2(t))


@pytest.mark.parametrize("step", [1, 5])
def test_repeat(step):
    def run(func, *inputs):
        outputs = func(*inputs)
        next_inputs = outputs
        return outputs, next_inputs

    def func(t1, t2):
        return t1 + t2, t1 - t2

    t = [torch.randn(8, device="cuda") for _ in range(2)]
    compare_torch_and_nrt(func, *t, step=step, run=run)