Python API integration (#392)
* Implement jit demo version and add tests (#378)
* Add signatures to detect reusable kernel + nnf rt mem reservation + jit optional kwargs (#379)
  * Add signatures to detect reusable kernel
  * Add reserve nnf workspace memory
  * Add jit optionally taking arguments (keyword-only). Reference: https://realpython.com/primer-on-python-decorators/#both-please-but-never-mind-the-bread
  * Fix how to reserve
  * Fix unit bug + don't check twice
  * Remove unused imports
  * Add comment to pytest
  * Fix missing kwargs
* Support JIT for class and class method + nnf config + 3-lvl-signature (#398)
  * Add docstring
  * Support JIT for class method
  * Support decorator for class
  * Fix signature bug + add failure case
  * test_jit.py clean code
  * Change relative path to __module__
  * Add nnfusion config
  * Impl. 3-level signature
  * Clean code + add test_config + del test_graph. test_keep_signature_but_change_compute_graph can pass now but it makes no sense to test it now, since different kernels with the same object signature will no longer replace each other.
  * Clean code
  * Add docstrings
  * Fix tune to follow docs
  * More pythonic: http://www.kr41.net/2016/03-23-dont_inherit_python_builtin_dict_type.html
  * Update docstring
  * Add test case
  * Fix bug: decorator for method with multiple instances
  * Clean code
* Add NNFusion-JIT docs

Co-authored-by: Wenxiang Hu <8460860+wenxcs@users.noreply.github.com>
Parent: 1ee7762413
Commit: fc58f6bff9

@@ -0,0 +1,117 @@
# How to use NNFusion JIT

NNFusion JIT is a JIT compiler for PyTorch built on the NNFusion CLI. It lets the user JIT-compile a function or a `torch.nn.Module` object using a decorator or an explicit function wrapper. All NNFusion optimizations, such as kernel tuning, can be applied through this interface.

## nnfusion.jit

`nnfusion.jit(obj=None, *, tune=None, tuning_steps=None, config=None)`

Lazily trace an object and optimize it with NNFusion compilation when the object is first called. It can be used as a decorator or as an explicit function wrapper. The inputs and outputs of the object must be `torch.Tensor` or a sequence of `torch.Tensor`.

### Parameters

+ obj (function, `torch.nn.Module` instance / method / class):
  + Target object to be traced. When `obj` is an instance or a class, it is equivalent to tracing its `forward` function.
+ tune (Optional[bool]):
  + Whether to tune the kernel. By default it follows `config`. If set, it overrides `config`.
+ tuning_steps (Optional[int]):
  + Number of kernel tuning steps. By default it follows `config`. If set, it overrides `config` and `tune`.
+ config (Optional[dict, nnfusion.Config]):
  + NNFusion compilation config. By default it is set to the default config `nnfusion.Config()`. Pass a `dict` to overwrite default values or directly pass an instance of `nnfusion.Config`.
  + For example, `@nnfusion.jit(tune=True, config={'kernel_tuning_steps': 42})`
  + For more information about the available flags, execute the command `nnfusion` in the terminal.

### Use Cases Demo

`nnfusion.jit` can be used as a function wrapper for standalone functions and `torch.nn.Module` instances/methods/classes.

It can also be used as a decorator for standalone functions and `torch.nn.Module` methods/classes.

```python
# Assumes: import nnfusion; from torch import nn

# Case 1: decorator for a standalone function
@nnfusion.jit
def foo(t1, t2):
    return t1 + t2, t1 - t2


# Case 2: decorator for a class method
class Net(nn.Linear):
    @nnfusion.jit
    def foo(self, x):
        return super().forward(x)


# Case 3: decorator for a class
@nnfusion.jit
class Net(nn.Linear):
    def this_will_not_be_traced(self, x):
        return super().forward(x)

    def forward(self, x):
        return super().forward(x)


# Case 4: function wrapper for a standalone function
def foo(t1, t2):
    return t1 + t2, t1 - t2

jitted_foo = nnfusion.jit(foo)


# Case 5: function wrapper for a torch.nn.Module method
class Net(nn.Linear):
    def foo(self, x):
        return super().forward(x)

model = Net(8, 8).eval()
model.foo = nnfusion.jit(model.foo)


# Case 6: function wrapper for a torch.nn.Module class
jitted_linear = nnfusion.jit(nn.Linear)
model = jitted_linear(8, 8).eval()


# Case 7: function wrapper for a torch.nn.Module instance
class Net(nn.Linear):
    def forward(self, x):
        return super().forward(x)

model = Net(8, 8).eval()
jitted_model = nnfusion.jit(model)
```

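After wrapping, a jitted object is called exactly like the original; tracing and NNFusion compilation happen on the first call with real tensors. A minimal sketch using `foo` from Case 1 (the CUDA placement follows the project's tests and is an assumption here):

```python
t1 = torch.randn(8, device="cuda")
t2 = torch.randn(8, device="cuda")
s, d = foo(t1, t2)  # first call traces, compiles (or reuses a cached kernel), then runs
s, d = foo(t1, t2)  # later calls dispatch directly to the compiled kernel
```
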
Optional keyword arguments can also be passed:

```python
@nnfusion.jit(tune=True)
def foo(t1, t2):
    return t1 + t2


def bar(t):
    return t + t, t * t

jitted_bar = nnfusion.jit(bar, tuning_steps=2000)
```

### Compiled kernel caching strategy

Compiled kernels are saved in `nnf-kernels/`. If a matching kernel is found before compilation, it is reused directly. Here, "matching" means having the same object signature (`__module__` and `__qualname__`), the same computational graph (ONNX model binary), and the same NNFusion compilation config (`config`).

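Concretely, these three levels map onto the on-disk cache layout built by `NNFusionRT` (see `runtime.py` below); the names in this sketch are placeholders:

```
nnf-kernels/
└── <object signature>/                     # e.g. "mymod-foo"
    └── <sha256 of exported ONNX model>/
        ├── model.onnx
        └── <sha256 of compile flags>/
            └── nnfusion_rt/cuda_codegen/   # built runtime library
```
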
## nnfusion.Config

`nnfusion.Config(*args, antares_mode=True, blockfusion_level=0, extern_result_memory=True, function_codegen=True, ir_based_fusion=False, kernel_fusion_level=0, kernel_tuning_steps=1000, **kwargs)`

NNFusion compilation config. Any other NNFusion compiler flags can be passed in (execute the command `nnfusion` in the terminal for more details); unknown flags are ignored. Use it as a `dict` with some default key-value pairs.

### Use Cases Demo

```python
config = nnfusion.Config(function_codegen=False,
                         new_flag=42)
config = nnfusion.Config({'function_codegen': False,
                          'new_flag': 42})
config = nnfusion.Config()
config['function_codegen'] = False
config['new_flag'] = 42
```
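Internally, a config is rendered into CLI flags by `Config.to_flag()` (defined in `config.py` below): flags are sorted by name, booleans are cast to ints, and each becomes `-f<flag>=<value>`. For example:

```python
config = nnfusion.Config(function_codegen=False)
print(config.to_flag())
# -fantares_mode=1 -fblockfusion_level=0 -fextern_result_memory=1
# -ffunction_codegen=0 -fir_based_fusion=0 -fkernel_fusion_level=0
# -fkernel_tuning_steps=1000
```
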

@@ -11,6 +11,7 @@ Welcome to the NNFusion Documents!
- [Compile-a-Tensorflow-model-with-NNFusion](Compile-a-Tensorflow-model-with-NNFusion.md)
- [Compile-a-model-with-kernel-tuning-enabled](Compile-a-model-with-kernel-tuning-enabled.md)
- [How to use NNFusion Python interface for inference/training](../src/python/example/README.md)
- [How to use NNFusion JIT](NNFusion-JIT.md)
4. [Guide-for-Contributors](Guide-for-Contributors.md)
- [Contribution-Guide](Contribution-Guide.md)
- [Coding-Guide](Coding-Guide.md)

@@ -2,4 +2,7 @@
# Licensed under the MIT License.

__version__ = "0.3.0"
__author__ = "Microsoft"

from .jit import jit
from .config import Config

@@ -0,0 +1,53 @@
from collections.abc import MutableMapping


class Config(MutableMapping):
    """
    NNFusion compilation config. Can pass in any other NNFusion compiler flags
    (execute the command `nnfusion` in the terminal for more details) and
    unknown flags will be ignored.
    Use it as a `dict` with some default key-value pairs.
    """
    def __init__(self,
                 *args,
                 antares_mode=True,
                 blockfusion_level=0,
                 extern_result_memory=True,
                 function_codegen=True,
                 ir_based_fusion=False,
                 kernel_fusion_level=0,
                 kernel_tuning_steps=1000,
                 **kwargs):
        locals_ = locals()
        # Collect the declared keyword-only flags (with their passed or
        # default values) via __kwdefaults__, then let explicit dict/kwargs
        # entries override them.
        self._storage = {
            flag: locals_[flag]
            for flag in self.__init__.__kwdefaults__
        }
        self._storage.update(dict(*args, **kwargs))

    @staticmethod
    def _parse_flag_value(flag, value):
        value = int(value) if isinstance(value, bool) else value
        return f'-f{flag}={value}'

    def to_flag(self):
        return ' '.join([
            self._parse_flag_value(flag, value)
            for flag, value in sorted(self._storage.items())
        ])

    def __iter__(self):
        return iter(self._storage)

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, key):
        return self._storage[key]

    def __setitem__(self, key, value):
        self._storage[key] = value

    def __delitem__(self, key):
        del self._storage[key]
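The commit message links to an article on why not to inherit from the built-in `dict` directly: built-in methods such as `dict.update` bypass overridden dunder methods. A minimal sketch of the pitfall (`LowerDict` is hypothetical, for illustration only):

```python
class LowerDict(dict):
    def __setitem__(self, key, value):
        super().__setitem__(key.lower(), value)

d = LowerDict()
d['Foo'] = 1      # goes through the override -> stored as 'foo'
d.update(Bar=2)   # dict.update does NOT call __setitem__ -> stored as 'Bar'
print(d)          # {'foo': 1, 'Bar': 2}
```

Deriving `Config` from `collections.abc.MutableMapping` with an internal `_storage` dict avoids this, because the `MutableMapping` mixin methods all route through `__setitem__`/`__getitem__`.
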

@@ -1,13 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import ctypes
import json
import os
import platform

import torch

from . import dtypes
from .data_format import cast_pytorch_tensor
from .description import IODescription
from .utils import cd


def find_nnf_rt(nnf_rt_dir):

@@ -80,11 +82,13 @@ class Executor(object):
        5: ("", ""),  # UNKNOWN
    }

    def __init__(self, nnf_rt_dir, device=None):
        """
        Parameters:
            nnf_rt_dir: A full string path to the nnfusion runtime,
                usually something like "codegen_root/nnfusion_rt/cuda_codegen".
            device: A device (`torch.device`) used for workspace memory
                reservation (if needed) by the nnfusion runtime.
        """
        nnf_rt_dir = os.path.abspath(nnf_rt_dir)
        self.libnnf_path = find_nnf_rt(nnf_rt_dir)

@@ -113,9 +117,14 @@ class Executor(object):
        init_func_name, free_func_name = self.device_type_map[device_type]
        self.nnf_rt_init = getattr(self.libnnf, init_func_name, None)
        self.nnf_rt_free = getattr(self.libnnf, free_func_name, None)

        if self.nnf_rt_init:
            with cd(nnf_rt_dir):
                workspace_ptr = self._maybe_reserve_mem(device)
                if workspace_ptr is not None:
                    self.nnf_rt_init(workspace_ptr)
                else:
                    self.nnf_rt_init()
            self.init_flag = True

        # parse input/output

@@ -218,4 +227,17 @@ class Executor(object):
    def feed_pointers(self, signature, params):
        self.kernel_entry.argtypes = signature
        self.kernel_entry(*params)

    def _maybe_reserve_mem(self, device):
        # Reserve workspace memory only if the generated runtime exposes
        # get_workspace_size and reports a non-zero size.
        get_workspace_size = getattr(self.libnnf, 'get_workspace_size', None)
        if get_workspace_size is None:
            return None

        n_byte = get_workspace_size()
        if not n_byte:
            return None

        # Allocate through torch so the workspace lives in PyTorch's allocator
        # and is kept alive by this Executor instance.
        self._reserved_mem = torch.empty(n_byte,
                                         dtype=torch.int8, device=device)
        return cast_pytorch_tensor(self._reserved_mem).pointer
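A minimal usage sketch of the new `device` parameter (the import path and the codegen directory are assumptions based on this module's location and the docstring above):

```python
import torch
from nnfusion.executor import Executor

# Pass the device that should host the nnfusion workspace buffer; if the
# generated runtime needs workspace memory, Executor reserves it there via
# _maybe_reserve_mem and hands the pointer to nnf_rt_init.
executor = Executor("codegen_root/nnfusion_rt/cuda_codegen",
                    device=torch.device("cuda"))
```
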

@@ -0,0 +1,219 @@
import copy
import functools
from inspect import isfunction, ismethod, isclass

import torch

from .jit_utils import TorchModule, get_signature
from .runtime import NNFusionRT
from .config import Config


def is_method_of_instance(obj, cls):
    return ismethod(obj) and isinstance(obj.__self__, cls)


def is_subclass_of_cls(obj, cls):
    return isclass(obj) and issubclass(obj, cls)


def get_nrt_forward(obj, signature, config, outputs, *inputs,
                    is_method=False):
    """
    Return a wrapped forward function that uses nnf as the runtime
    """
    if not isinstance(obj, torch.nn.Module):
        raise AssertionError(
            "Internal bug, please report to "
            "https://github.com/microsoft/nnfusion"
        )

    output_is_tensor = isinstance(outputs, torch.Tensor)
    if output_is_tensor:
        outputs = [outputs]

    nnf = NNFusionRT(obj, config, signature)
    nnf.compile(inputs, outputs)

    # TODO free outputs and only save desc?

    def forward(*inputs):
        results = [
            torch.empty_like(output)
            for output in outputs
        ]
        if is_method:
            obj, *inputs = inputs
            nnf.run_method(obj, inputs, results)
        else:
            inputs = list(inputs)
            nnf.run(inputs, results)

        if output_is_tensor:
            return results[0]
        return results
    return forward


def nrt_forward(obj, *inputs, config=None, signature=None, is_method=False):
    if signature is None:
        signature = get_signature(obj)

    if hasattr(obj, '_orig_forward'):
        # shallow copy is needed to avoid recursion:
        # call instance forward -> call nnf forward -> call instance forward
        obj_ = copy.copy(obj)
        obj_.forward = obj._orig_forward
        obj = obj_

    outputs = obj(*inputs)

    def jit_class_method_using_decorator():
        """
        Check if obj is a class method with the @nnfusion.jit decorator.
        Decorating a class method with the @ symbol and applying jit to it
        as a function are different cases.
        """
        return isinstance(inputs[0], torch.nn.Module)

    if jit_class_method_using_decorator():
        self, *inputs = inputs

        # shallow copy is needed to avoid recursion when using jit as decorator:
        # export onnx -> call forward to trace -> call nnf jit func -> export onnx
        self_ = copy.copy(self)

        def forward(*args):
            if forward.first_call:
                forward.first_call = False
                return obj(self, *args)
            # handle the case that the jit target function calls `forward`
            return self.forward(*args)
        forward.first_call = True
        self_.forward = forward

        return get_nrt_forward(self_, signature, config, outputs,
                               *inputs, is_method=True)

    if isfunction(obj) or is_method_of_instance(obj, torch.nn.Module):
        return get_nrt_forward(TorchModule(obj), signature, config, outputs,
                               *inputs)
    return get_nrt_forward(obj, signature, config, outputs, *inputs)


def parse_config(tune, tuning_steps, config):
    if config is None:
        config = Config()
    elif type(config) is dict:
        config = Config(config)

    if not type(config) is Config:
        raise TypeError(
            "Expected optional 'config' argument of type dict or "
            f"nnfusion.Config but found {config}"
        )

    if tuning_steps is not None:
        if not isinstance(tuning_steps, int):
            raise TypeError(
                "Expected optional 'tuning_steps' argument of type int "
                f"but found {tuning_steps}"
            )
        if tune is False:
            raise ValueError(
                f"Conflict is detected: tune={tune} and "
                f"tuning_steps={tuning_steps}"
            )
        tune = True
        config['kernel_tuning_steps'] = tuning_steps

    if tune is not None:
        if not isinstance(tune, bool):
            raise TypeError(
                "Expected optional 'tune' argument of type bool "
                f"but found {tune}"
            )
        config['antares_mode'] = tune

    return config


def check_obj_type(obj):
    if not (
        isfunction(obj)
        or isinstance(obj, torch.nn.Module)
        or is_subclass_of_cls(obj, torch.nn.Module)
        or is_method_of_instance(obj, torch.nn.Module)
    ):
        raise TypeError(
            "Expected function or torch.nn.Module instance/method/class "
            f"but found {obj}"
        )


def jit_class(obj, config):
    """
    Return a jitted class, using dynamic inheritance to override the forward
    function while keeping its signature.
    """
    class JITModule(obj):
        @jit(config=config,
             _signature='.'.join([get_signature(obj), 'forward']))
        def forward(self, *args, **kwargs):
            return super().forward(*args, **kwargs)
    return JITModule


def jit(obj=None, *, tune=None, tuning_steps=None, config=None, _signature=None):
    """
    Parameters:
        obj (function, `torch.nn.Module` instance/method/class):
            The target object to be traced. When `obj` is an instance or a
            class, it is equivalent to tracing its `forward` function.
        tune (Optional[bool]):
            Whether to tune the kernel. By default it follows `config`.
            If set, it overrides `config`.
        tuning_steps (Optional[int]):
            Number of kernel tuning steps. By default it follows `config`.
            If set, it overrides `config` and `tune`.
        config (Optional[dict, nnfusion.Config]):
            NNFusion compilation config.
            By default it will be set to `nnfusion.Config()`.
            Pass a `dict` to overwrite default values or directly pass an
            instance of `nnfusion.Config`.
            For example, `@nnfusion.jit(tune=True,
                                        config={'kernel_tuning_steps': 42})`
            For more information about the available flags, execute the
            command `nnfusion` in the terminal.
    """
    config = parse_config(tune, tuning_steps, config)

    def _jit(_obj):
        check_obj_type(_obj)

        if is_subclass_of_cls(_obj, torch.nn.Module):
            return jit_class(_obj, config)

        @functools.wraps(_obj)
        def wrapper(*args):  # TODO support kwargs?
            if wrapper.forward is None:
                wrapper.forward = nrt_forward(_obj, *args,
                                              config=config,
                                              signature=_signature)
            return wrapper.forward(*args)
        wrapper.forward = None

        if isinstance(_obj, torch.nn.Module):
            _obj._orig_forward = _obj.forward
            _obj.forward = wrapper
            return _obj
        return wrapper

    if obj is None:
        return _jit
    return _jit(obj)
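The `obj=None, *` signature above implements the standard "decorator with optional arguments" pattern referenced in the commit message (the realpython link). A minimal standalone sketch of just that dispatch, with the compilation details elided:

```python
import functools

def jit(obj=None, *, tune=None):
    def _jit(_obj):
        @functools.wraps(_obj)
        def wrapper(*args):
            # compile on first call, then dispatch to the compiled kernel
            return _obj(*args)
        return wrapper

    if obj is None:
        # called with arguments: @jit(tune=True) -> return the real decorator
        return _jit
    # called bare: @jit or jit(func) -> decorate directly
    return _jit(obj)
```
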

@@ -0,0 +1,34 @@
import inspect
import re

import torch


class TorchModule(torch.nn.Module):
    """Wrap a plain function as a torch.nn.Module so it can be traced."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def get_signature(obj):
    """
    Signature of an object, used to detect reusable kernels.
    """
    if isinstance(obj, torch.nn.Module):
        return get_signature(obj.__class__)

    if not (
        inspect.isfunction(obj)
        or inspect.ismethod(obj)
        or inspect.isclass(obj)
    ):
        raise Exception(f"Unsupported type {obj} for get_signature")

    signature = "-".join([obj.__module__, obj.__qualname__])

    # Remove special chars to avoid trouble when the signature is used in paths
    return re.sub("[<>]", "", signature)
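For illustration, assuming a function `foo` and a module class `Net` defined at the top level of a module `mymod` (hypothetical names):

```python
import torch
from nnfusion.jit_utils import get_signature  # import path is an assumption

def foo(t):
    return t + t

class Net(torch.nn.Module):
    pass

print(get_signature(foo))    # mymod-foo
print(get_signature(Net))    # mymod-Net
print(get_signature(Net()))  # mymod-Net  (instances resolve to their class)
```
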

@@ -0,0 +1,132 @@
import os
import tempfile
from pathlib import Path

import torch
import torch.onnx

from .data_format import cast_pytorch_tensor
from .executor import Executor
from .session import build, codegen, modify_nnfusion_rt
from .utils import get_sha256_of_file, get_sha256_of_str


class NNFusionRT:
    def __init__(self, model, config, signature, cache_dir="nnf-kernels"):
        """
        Parameters:
            model: the `torch.nn.Module` to be compiled.
            config: NNFusion compilation config.
            signature (str): signature of the model, used to reuse compiled
                kernels (if any).
            cache_dir: path to save compiled kernels.
        """
        self.model = model
        self.weight_dict = {
            name: cast_pytorch_tensor(tensor)
            for name, tensor in model.state_dict().items()
        }

        self.root_dir = Path(cache_dir) / signature
        self.root_dir.mkdir(parents=True, exist_ok=True)

        self.compile_flag = self._get_compile_flag(config)
        self.executor = None

    def compile(self, inputs, outputs):
        """
        Perform NNFusion codegen and compilation for the target input sizes.
        Skip if a kernel with the same signature is found.

        Parameters:
            inputs: a list of model inputs.
            outputs: a list of model outputs.
        """
        def export_onnx(fname):
            input_names = ["input" + str(i) for i in range(len(inputs))]
            output_names = ["output" + str(i) for i in range(len(outputs))]
            torch.onnx.export(self.model, inputs, fname,
                              input_names=input_names,
                              output_names=output_names,
                              export_params=False,
                              )  # , opset_version=11)

        def check_if_need_build():
            """
            Note that this function assumes no hash collision
            """
            need_build = False

            # Compare the exported ONNX file to check whether the graph changed
            with tempfile.TemporaryDirectory(dir=self.root_dir) as tmp:
                temp_onnx_path = Path(tmp) / "temp.onnx"
                export_onnx(temp_onnx_path)

                onnx_hash = get_sha256_of_file(temp_onnx_path)
                flag_hash = get_sha256_of_str(self.compile_flag)

                onnx_dir = self.root_dir / onnx_hash
                flag_dir = onnx_dir / flag_hash
                flag_dir.mkdir(parents=True, exist_ok=True)

                onnx_path = onnx_dir / "model.onnx"
                if not onnx_path.is_file():
                    os.link(temp_onnx_path, onnx_path)
                    need_build = True

            nnf_dir = flag_dir / "nnfusion_rt" / "cuda_codegen"
            if not nnf_dir.joinpath('libnnfusion_naive_rt.so').is_file():
                need_build = True

            return need_build, onnx_path, flag_dir, nnf_dir

        def do_compile(onnx_path, work_dir, nnf_dir):
            codegen(onnx_path, self.compile_flag, work_dir)
            modify_nnfusion_rt(nnf_dir)
            build(nnf_dir)

        need_build, onnx_path, work_dir, nnf_dir = check_if_need_build()
        if need_build:
            do_compile(onnx_path, work_dir, nnf_dir)

        self.executor = Executor(nnf_dir, device=inputs[0].device)

    def run(self, inputs, outputs, weights=None):
        """
        Perform the computation. The result will be saved in `outputs`.

        Parameters:
            inputs: the input tensor(s). Can be a list or tuple.
            outputs: the output tensor(s). Can be a list or tuple.
        """
        if weights is None:
            weights = self.weight_dict
        if not isinstance(inputs, (tuple, list)):
            inputs = [inputs]
        if not isinstance(outputs, (tuple, list)):
            outputs = [outputs]

        in_dict = dict(weights, **{
            f'input{i}': cast_pytorch_tensor(tensor)
            for i, tensor in enumerate(inputs)
        })
        out_dict = {
            f'output{i}': cast_pytorch_tensor(tensor)
            for i, tensor in enumerate(outputs)
        }
        self.executor(in_dict, out_dict, strict=False)

    def run_method(self, obj, inputs, outputs):
        weights = {
            name: cast_pytorch_tensor(tensor)
            for name, tensor in obj.state_dict().items()
        }
        return self.run(inputs, outputs, weights)

    def _get_compile_flag(self, config):
        return " ".join([
            "-f onnx",
            config.to_flag()
        ])
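A minimal sketch of how this class is driven, mirroring what `jit.py` above does. This is illustrative only: the import paths are assumptions, the signature string is a placeholder, and actually running it requires the NNFusion CLI and a CUDA device:

```python
import torch
from nnfusion.config import Config
from nnfusion.jit_utils import TorchModule
from nnfusion.runtime import NNFusionRT

t = torch.randn(8, device="cuda")
model = TorchModule(lambda x: x + x)              # wrap a plain function
nnf = NNFusionRT(model, Config(), signature="demo-add")
nnf.compile((t,), [model(t)])                     # codegen + build, or reuse cache
out = torch.empty_like(t)
nnf.run([t], [out])                               # writes the result into `out`
```
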

@@ -5,6 +5,7 @@ from contextlib import contextmanager
import os
import logging
import subprocess
import hashlib

logger = logging.getLogger(__name__)

@@ -31,4 +32,16 @@ def execute(command, redirect_stderr=True, shell=True, **kwargs):
    except subprocess.CalledProcessError as e:
        logger.error(e.output)
        raise e
    return output


def get_sha256_of_file(path, max_len=None):
    hash_sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_sha256.update(chunk)
    return hash_sha256.hexdigest()[:max_len]


def get_sha256_of_str(string, max_len=None):
    return hashlib.sha256(string.encode("utf-8")).hexdigest()[:max_len]

@@ -0,0 +1,18 @@
# NNFusion Python Tests

## Installation

```bash
cd path/to/nnfusion/test/python

python3 -m pip install -r requirements.txt
```

## Running tests

```bash
# A symlink is needed temporarily since installing locally (pip install -e .) is not yet supported.
ln -s ../../src/python/nnfusion nnfusion

pytest -v --forked
```

@@ -0,0 +1,2 @@
pytest>=7.0.0
pytest-xdist>=1.4.0

@@ -0,0 +1,26 @@
from nnfusion import Config


def test_config():
    # default
    config = Config()
    assert config['kernel_tuning_steps'] == 1000
    assert config['function_codegen'] is True

    # init with kwargs
    config = Config(kernel_tuning_steps=42,
                    function_codegen=False,
                    foo=True)
    assert config['kernel_tuning_steps'] == 42
    assert config['function_codegen'] is False
    assert config['foo'] is True

    # init with dict
    config = Config({
        'kernel_tuning_steps': 42,
        'function_codegen': False,
        'foo': True,
    })
    assert config['kernel_tuning_steps'] == 42
    assert config['function_codegen'] is False
    assert config['foo'] is True

@@ -0,0 +1,190 @@
import pytest
import torch
from torch.nn import functional as F

import nnfusion


def assert_allclose(output1, output2, rtol=1e-5, atol=1e-5):
    if not isinstance(output1, (tuple, list)):
        assert not isinstance(output2, (tuple, list))
        assert output1.size() == output2.size()
        assert torch.allclose(output1, output2, rtol, atol), (
            output1, output2, output1.sub(output2).max())
        return

    assert isinstance(output2, (tuple, list))
    assert len(output1) == len(output2)

    for out1, out2 in zip(output1, output2):
        assert_allclose(out1, out2, rtol, atol)


def compare_torch_and_nrt(obj, *inputs, step=1, run=None):
    assert step >= 1

    if step == 1:
        result_torch = obj(*inputs)
        result_nrt = nnfusion.jit(obj)(*inputs)
    else:
        def repeat(obj, *inputs):
            for _ in range(step):
                outputs, inputs = run(obj, *inputs)
            return outputs

        assert run is not None
        result_torch = repeat(obj, *inputs)
        result_nrt = repeat(nnfusion.jit(obj), *inputs)

    assert_allclose(result_torch, result_nrt)


def test_single_input_multi_output():
    def func(t):
        return t + t, t * t
    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t)


def test_multi_input_single_output():
    def func(t1, t2):
        return t1 + t2
    t = [torch.randn(8, device="cuda") for _ in range(2)]
    compare_torch_and_nrt(func, *t)


def test_multi_identical_input_single_output():
    def func(t1, t2):
        return t1 + t2
    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t, t)


@pytest.mark.xfail(reason=(
    "Probably identical tensors are fused during optimization. "
    "May need a copy at backend"))
def test_single_input_multi_identical_output():
    def func(t):
        return t, t
    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t)


@pytest.mark.xfail(reason="Compilation Error")
def test_single_input_multi_identical_output_advanced():
    def func(t):
        t2 = t + t
        return t2, t2
    t = torch.randn(8, device="cuda")
    compare_torch_and_nrt(func, t)


def test_jit_instance_using_function():
    model = torch.nn.Linear(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    compare_torch_and_nrt(model, t)


def test_jit_class_method_using_decorator():
    class Foo(torch.nn.Linear):
        @nnfusion.jit
        def foo(self, t):
            return t + t

        @nnfusion.jit
        def bar(self, t):
            return self.forward(t) + 1

    model = Foo(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    assert_allclose(t + t, model.foo(t))
    assert_allclose(F.linear(t, model.weight, model.bias) + 1, model.bar(t))

    class Bar(torch.nn.Linear):
        @nnfusion.jit
        def forward(self, t):
            return super().forward(t)

    model = Bar(8, 8).cuda().eval()
    assert_allclose(F.linear(t, model.weight, model.bias), model(t))


def test_jit_class_method_using_function():
    class Foo(torch.nn.Linear):
        def foo(self, t):
            return self.forward(t) + 1

    t = torch.randn(1, 8, device="cuda")
    model = Foo(8, 8).cuda().eval()
    model.foo = nnfusion.jit(model.foo)
    assert_allclose(F.linear(t, model.weight, model.bias) + 1, model.foo(t))


def test_jit_class_using_decorator():
    def func(t):
        return t + t

    @nnfusion.jit
    class Foo(torch.nn.Linear):
        @nnfusion.jit
        def foo(self, t):
            return func(t)

    model = Foo(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    assert_allclose(F.linear(t, model.weight, model.bias), model(t))
    assert_allclose(func(t), model.foo(t))


def test_jit_class_using_decorator_multi_instance():
    @nnfusion.jit
    class Foo(torch.nn.Linear):
        pass

    model1 = Foo(2, 2).cuda().eval()
    model2 = Foo(2, 2).cuda().eval()
    t = torch.randn(1, 2, device="cuda")
    assert_allclose(F.linear(t, model1.weight, model1.bias), model1(t))
    assert_allclose(F.linear(t, model2.weight, model2.bias), model2(t))


def test_jit_class_using_function():
    LinearJIT = nnfusion.jit(torch.nn.Linear)

    model = LinearJIT(8, 8).cuda().eval()
    t = torch.randn(1, 8, device="cuda")
    assert_allclose(F.linear(t, model.weight, model.bias), model(t))


def test_jit_with_kwargs():
    @nnfusion.jit(tune=True, config=nnfusion.Config(kernel_tuning_steps=5))
    def func(t):
        return t + t
    t = torch.randn(8, device="cuda")
    assert_allclose(t + t, func(t))


@pytest.mark.xfail(reason=(
    "nnfusion codegen and compile success exit with 0 "
    "but para_info.json is null"))
def test_nested_jit():
    @nnfusion.jit
    def func1(t): return t + t

    @nnfusion.jit
    def func2(t): return func1(t) + 1

    t = torch.randn(1, 8, device="cuda")
    assert_allclose(t + t, func1(t))
    assert_allclose(t + t + 1, func2(t))


@pytest.mark.parametrize("step", [1, 5])
def test_repeat(step):
    def run(func, *inputs):
        outputs = func(*inputs)
        next_inputs = outputs
        return outputs, next_inputs

    def func(t1, t2):
        return t1 + t2, t1 - t2

    t = [torch.randn(8, device="cuda") for _ in range(2)]
    compare_torch_and_nrt(func, *t, step=step, run=run)