# Copyright (c) 2021-present The Torchy Authors.
# Distributed under the MIT license that can be found in the LICENSE file.

# Path to a PyTorch source checkout; used for its codegen tools and
# operator schemas.
PYTORCH = '../pytorch'

from typings_data import *
import sys
sys.path.append(PYTORCH)
from tools.codegen.gen import *
from tools.codegen.api import types

yaml_path = PYTORCH + '/aten/src/ATen/native/native_functions.yaml'
native_functions = parse_native_yaml(yaml_path)

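# For reference, an entry in native_functions.yaml looks roughly like this
# (illustrative only; the exact schema is defined by upstream PyTorch and
# changes across versions):
#
#   - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
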
# op name -> dtype inference rule; overrides the defaults from typings_data
dtype_exceptions = {
}

# op name -> shape inference rule; overrides the defaults from typings_data
shape_exceptions = {
  'arange.start_out'  : 'ARANGE',
  'arange.start_step' : 'ARANGE',
  'cat'               : 'CAT',
  'conv2d'            : 'CONV2D',
  'embedding'         : 'EMBEDDING',
  'max_pool2d'        : 'CONV2D',
  'mkldnn_convolution': 'CONV2D2',
  'stack'             : 'STACK',
  'stack.out'         : 'STACK',
  'transpose_'        : '',
}

# op name -> strides inference rule; overrides the defaults from typings_data
strides_exceptions = {
  '_s_where'          : 'CONTIGUOUS',
  'clone'             : 'CLONE',
  'conv2d'            : 'CONTIGUOUS',
  'embedding'         : 'CONTIGUOUS',
  'max_pool2d'        : 'CONTIGUOUS',
  'mkldnn_convolution': 'CONTIGUOUS',
}

def get_dtype_infer_fn(fn):
  name = str(fn.func.name)
  return dtype_exceptions.get(name, type_inference.get(name))

def get_shape_infer_fn(fn):
  name = str(fn.func.name)
  return shape_exceptions.get(name, shape_inference.get(name))

def get_strides_infer_fn(fn):
  name = str(fn.func.name)
  return strides_exceptions.get(name, strides_inference.get(name))


@with_native_function
def skip_fn(fn):
  allowed_ret_types = {
    'at::Tensor',
    'at::Tensor &',
    'const at::Tensor &',
  }
  dispatcher_sig = DispatcherSignature.from_schema(fn.func)
  rettype = dispatcher_sig.returns_type().cpp_type()
  return rettype not in allowed_ret_types

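# skip_fn() rejects anything whose dispatcher return type is not a single
# tensor (or a reference to one): ops returning e.g.
# std::tuple<at::Tensor,at::Tensor> or void only get the pass-through
# wrapper emitted by gen_dispatch_wrapper() below.
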
def wrapper_name(fn):
  return 'wrap_' + str(fn.func.name).replace('.', '_')

def fn_enum(fn):
  return 'H_' + str(fn.func.name).replace('.', '_').upper()

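# Name-mangling examples for the two helpers above:
#   'add.Tensor' -> 'wrap_add_Tensor' / 'H_ADD_TENSOR'
#   'mul_'       -> 'wrap_mul_'       / 'H_MUL_'
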
def get_arg_of_type(args, type):
  for arg in args:
    if arg.type.cpp_type(strip_ref=True) == type:
      return arg
  return None

def maybe_tensor(type):
  tensor_types = {
    'at::Tensor',
    'at::TensorList',
    'c10::List<c10::optional<at::Tensor>>',
    'c10::optional<at::Tensor>',
  }
  return type.remove_const_ref().cpp_type() in tensor_types

def to_scalar_type(v):
  ty = v.type.remove_const_ref().cpp_type()
  if ty == 'at::Tensor':
    return f'{v.expr}.scalar_type()'
  if ty == 'c10::optional<at::ScalarType>':
    return f'{v.expr}.value_or(ScalarType::Undefined)'
  print('to_scalar_type', ty)
  sys.exit(1)

def is_type_arg(arg):
  type = arg.type.remove_const_ref().cpp_type()
  dispatch_types = [
    'at::Scalar',
    'at::ScalarType',
    'c10::optional<at::Scalar>',
    'c10::optional<at::ScalarType>',
  ]
  return 'Tensor' in type or type in dispatch_types

def to_dtype(arg):
  type = arg.type.cpp_type()
  if type == 'at::ScalarType':
    return arg.expr
  return f'{arg.expr}.dtype()'

def to_scalartype(arg):
  type = arg.type.remove_const_ref().cpp_type()
  if type == 'at::ScalarType':
    return arg.expr
  if type == 'at::Scalar':
    return f'{arg.expr}.type()'
  return f'{arg.expr}.scalar_type()'


def mk_dtype_infer(type, all_args):
  args = [arg for arg in all_args if is_type_arg(arg)]

  if type[0:3] == 'ALL':
    return f'k{type[4:]}'
  if type == 'BOOL2INT':
    return f'bool_to_int({args[0].expr}.scalar_type())'
  if type == 'EQ_PROMOTED':
    return f'promote_tys({", ".join(t.expr for t in args)})'
  if type == 'EQ_PROMOTED_BUGGY':
    return f'promote_buggy({", ".join(t.expr for t in args)})'
  if type == 'EQ_PROMOTED_BUGGY2':
    return f'promote_buggy({args[0].expr}, {args[1].expr})'
  if type == 'EQ_PROMOTED_CONST':
    return f'promote_const({", ".join(t.expr for t in args)})'
  if type == 'EQ_SECOND':
    return to_dtype(args[1])
  if type == 'EQ_THIRD':
    return to_dtype(args[2])
  if type == 'EQ_FOURTH':
    return to_dtype(args[3])
  if type == 'BOOLBYTE':
    return f'bool_byte({args[0].expr}.scalar_type())'
  if type == 'INTEGRAL2INT':
    return f'integrals_to_int({args[0].expr}.scalar_type())'
  if type == 'TO_COMPLEX':
    return f'to_complex({args[0].expr}.scalar_type())'
  if type == 'TO_DOUBLE2':
    return f'to_double2({to_scalartype(args[0])}, {to_scalartype(args[1])})'
  if type == 'TO_FLOAT':
    return f'to_float({args[0].expr}.scalar_type())'
  if type == 'TO_FLOAT_DOUBLE':
    return f'to_float_double({args[0].expr}.scalar_type())'
  if type == 'TO_FLOAT2':
    return f'to_float2({args[0].expr}, {args[1].expr})'
  if type == 'TO_FLOAT3':
    return f'to_float3({args[0].expr}, {args[1].expr}, {args[2].expr})'
  if type == 'TO_QINT':
    return f'toQIntType({args[0].expr}.scalar_type())'
  if type == 'TO_REAL2':
    return f'to_real2({args[0].expr}, {args[1].expr})'
  if type == 'TO_REAL_FLOAT':
    return f'to_real_float({to_scalar_type(args[0])})'
  if type == 'TO_VALUE_TYPE':
    return f'toValueType({args[0].expr}.scalar_type())'
  if type == 'OPTIONAL_OR21':
    return f'optional_or_else({args[1].expr}, {args[0].expr}.scalar_type())'
  if type == 'OPTIONAL_OR31':
    return f'optional_or_else({args[2].expr}, {args[0].expr}.scalar_type())'
  if type == 'OPTIONAL_O21LONG':
    return f'optional_or_longelse({args[1].expr}, {args[0].expr}.scalar_type())'
  if type == 'FIRST_OR_DEFAULT':
    return args[0].expr
  if type == 'SECOND_OR_DEFAULT':
    return args[1].expr
  if type == 'FIRST_OR_LONG':
    return f'optional_or_else({args[0].expr}, kLong)'
  if type == 'SECOND_OR_LONG_DEFAULT':
    return f'optional_or_longdefault({args[1].expr}, {args[0].expr}.type())'

  print('mk_dtype_infer', type)
  sys.exit(1)


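# mk_dtype_infer() maps a rule name to a C++ expression over the wrapper's
# argument names; e.g., for hypothetical arguments 'self' and 'other':
#   mk_dtype_infer('EQ_PROMOTED', args) -> 'promote_tys(self, other)'
#   mk_dtype_infer('ALL Bool', args)    -> 'kBool'
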
def get_dtype_arg(all_tensors, args, name):
  tensors = [a.expr for a in all_tensors if a.type.remove_const_ref().cpp_type() == 'at::Tensor']
  tensor_lst = [a.expr for a in all_tensors if a.type.remove_const_ref().cpp_type() == 'at::TensorList']
  dtype = 'nullopt'
  device = 'nullopt'
  if tensors:
    dtype = f'{tensors[0]}.dtype()'
    device = f'{tensors[0]}.device()'
  elif tensor_lst:
    device = f'device_of({tensor_lst[0]})'

  device_arg = get_arg_of_type(args, 'at::Device')
  if device_arg:
    device = device_arg.expr

  device_arg = get_arg_of_type(args, 'c10::optional<at::Device>')
  if device_arg:
    device = device_arg.expr

  # look up the dtype-inference rule for this op by name
  dtype_fn = dtype_exceptions.get(str(name), type_inference.get(str(name)))
  if dtype_fn:
    dtype = mk_dtype_infer(dtype_fn, args)
  else:
    dtype_arg = get_arg_of_type(args, 'at::ScalarType')
    if dtype_arg:
      dtype = dtype_arg.expr

    dtype_arg = get_arg_of_type(args, 'c10::optional<at::ScalarType>')
    if dtype_arg:
      dtype = dtype_arg.expr

  tensor_arg = get_arg_of_type(args, 'at::TensorList')
  if dtype == 'nullopt' and tensor_arg:
    return tensor_arg.expr

  return f'{dtype}, {device}'


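# get_dtype_arg() returns the text spliced into register_new_tensor(); for a
# hypothetical op whose first tensor argument is 'self' and which has no
# dtype/device overrides, that is the string 'self.dtype(), self.device()'.
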
def fn_output(fn):
  if fn.func.arguments.out:
    assert len(fn.func.arguments.out) == 1
    return fn.func.arguments.out[0].name
  else:
    assert fn.func.arguments.self_arg.argument.is_write
    return fn.func.arguments.self_arg.argument.name


# Cheap-to-copy C++ types (references, array views, scalars) are passed
# through as-is; owning types get wrapped in std::move.
def move_if_needed(expr, arg):
  basic_types = {
    'bool',
    'int64_t',
    'double',
    'at::Device',
    'at::Dimname',
    'at::DimnameList',
    'at::IntArrayRef',
    'at::MemoryFormat',
    'at::ScalarType',
    'at::Layout',
    'at::TensorList',
    'c10::string_view',
  }
  free_copy_types = (
    types.ArrayRefCType,
    types.ConstRefCType,
    types.MutRefCType,
  )
  def free(type):
    return isinstance(type, free_copy_types) or \
           type.cpp_type() in basic_types

  if free(arg.type.type) or \
     (isinstance(arg.type.type, types.OptionalCType) and free(arg.type.type.elem)):
    return expr
  return f'std::move({expr})'

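# Examples (hypothetical argument names and C++ types):
#   move_if_needed('self', <const at::Tensor &>)         -> 'self'
#   move_if_needed('alpha', <c10::optional<at::Scalar>>) -> 'std::move(alpha)'
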
def is_shape_arg(arg):
  type = arg.type.cpp_type()
  dispatch_types = [
    'bool',
    'int64_t',
    'at::IntArrayRef',
    'c10::optional<int64_t>',
    'c10::optional<at::MemoryFormat>',
  ]
  return 'Tensor' in type or type in dispatch_types

def mk_shape_infer(shape, all_args):
  args = [arg for arg in all_args if is_shape_arg(arg)]

  if shape == 'ALL []':
    return 'IntArrayRef()'
  if shape == 'ALL [0]':
    return 'IntArrayRef(0)'
  if shape == 'ALL [1]':
    return 'IntArrayRef(1)'
  if shape == 'EQ_FIRST':
    return args[0].expr
  if shape == 'EQ_SECOND':
    return args[1].expr
  if shape == 'EQ_THIRD':
    return args[2].expr
  if shape == 'STD_PROMOTE':
    args = [arg.expr for arg in all_args if 'Tensor' in arg.type.cpp_type() or 'at::IntArrayRef' in arg.type.cpp_type()]
    return f'shape_std_promote({", ".join(args)})'
  if shape == 'PROMOTE_1_2':
    return f'shape_std_promote({args[0].expr}, {args[1].expr})'
  if shape == 'PROMOTE_1_2_3':
    return f'shape_std_promote({args[0].expr}, {args[1].expr}, {args[2].expr})'
  if shape == 'MATMUL_1ST_2ND':
    return f'shape_matmul({args[0].expr}, {args[1].expr})'
  if shape == 'MATMUL_2ND_3RD':
    return f'shape_matmul({args[1].expr}, {args[2].expr})'
  if shape == 'MUL_1ST_2ND':
    if args[0].type.cpp_type() == 'at::TensorList':
      return f'shape_mul({args[0].expr})'
    return f'shape_mul({args[0].expr}, {args[1].expr})'
  if shape == 'MULT_1ST_2ND':
    return f'shape_mult({args[0].expr}, {args[1].expr})'
  if shape == 'MULLAST_1ST_2ND':
    return f'shape_mul_last({args[0].expr}, {args[1].expr})'
  if shape == 'PICK_1ST_2ND':
    return f'shape_pick_1st({args[1].expr})'
  if shape == 'JOIN_2_3':
    return f'shape_join({args[1].expr}, {args[2].expr})'
  if shape == 'PAD1':
    return f'shape_pad1({args[0].expr})'
  if shape == 'DROP1':
    return f'shape_drop1({args[0].expr})'
  if shape == 'DROP2':
    return f'shape_drop2({args[0].expr})'
  if shape == 'TRANSPOSE':
    return f'shape_transpose({args[0].expr}, {all_args[1].expr}, {all_args[2].expr})'
  if shape == 'RESHAPE':
    return f'shape_reshape({args[0].expr}, {args[1].expr})'
  if shape == 'SELECT':
    return f'shape_select({args[0].expr}, {args[1].expr})'
  if shape == 'UNSQUEEZE':
    return f'shape_unsqueeze({args[0].expr}, {all_args[1].expr})'
  if shape == 'FLATTEN':
    return f'shape_flatten({args[0].expr}, {all_args[1].expr}, {all_args[2].expr})'
  if shape == 'ARANGE':
    return f'shape_arange({all_args[0].expr}, {all_args[1].expr}, {all_args[2].expr})'
  if shape == 'EMBEDDING':
    return f'shape_embedding({args[0].expr}, {args[1].expr})'
  if shape == 'SLICE':
    return f'shape_slice({args[0].expr}, {all_args[1].expr}, {all_args[2].expr}, {all_args[3].expr}, {all_args[4].expr})'
  if shape == 'STACK':
    return f'shape_stack({args[0].expr}, {all_args[1].expr})'
  if shape == 'CAT':
    return f'shape_cat({args[0].expr}, {args[1].expr})'
  if shape == 'ARGMAX':
    return f'shape_argmax({args[0].expr}, {args[1].expr}, {args[2].expr})'
  if shape == 'CONV2D':
    off = 0 if args[2].type.cpp_type() == 'at::IntArrayRef' else 1
    return f'shape_conv2d({args[0].expr}, {args[1].expr}, {args[2+off].expr}, {args[3+off].expr}, {args[4+off].expr})'
  if shape == 'CONV2D2':
    return f'shape_conv2d({args[0].expr}, {args[1].expr}, {args[4].expr}, {args[3].expr}, {args[5].expr})'
  if shape == 'POOL2D':
    return f'shape_pool2d({args[0].expr}, {args[1].expr})'
  if shape == 'TRANSPOSE2D':
    return f'shape_transpose2d({args[0].expr})'
  if shape == 'REDUCE':
    return f'shape_reduce({args[0].expr}, {args[1].expr}, {args[2].expr})'
  if shape == 'PERMUTE':
    return f'shape_permute({args[0].expr}, {args[1].expr})'
  if shape == 'UNFOLD':
    return f'shape_unfold({args[0].expr}, {all_args[1].expr}, {all_args[2].expr}, {all_args[3].expr})'
  if shape == 'NARROW':
    return f'shape_narrow({args[0].expr}, {args[1].expr}, {args[2].expr}, {args[3].expr})'

  print('mk_shape_infer', shape)
  return 'nullopt'
  #exit()


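# Like mk_dtype_infer(), this produces C++ shape expressions; e.g., for
# hypothetical arguments 'self' and 'other':
#   mk_shape_infer('STD_PROMOTE', args) -> 'shape_std_promote(self, other)'
#   mk_shape_infer('EQ_FIRST', args)    -> 'self'
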
def mk_strides_infer(strides, all_args, ret):
  args = [arg for arg in all_args if is_shape_arg(arg)]

  if strides == 'ALL []':
    return 'IntArrayRef()'
  if strides == 'ALL [0]':
    return 'IntArrayRef(0)'
  if strides == 'EQ_FIRST':
    return args[0].expr
  if strides == 'EQ_SECOND':
    return args[1].expr
  if strides == 'CONTIGUOUS':
    return f'strides_contiguous({ret})'
  if strides == 'STD_PROMOTE':
    args = [arg.expr for arg in all_args if 'Tensor' in arg.type.cpp_type() or 'at::IntArrayRef' in arg.type.cpp_type()]
    return f'strides_std_promote({ret}, {", ".join(args)})'
  if strides == 'VIEW':
    return f'strides_view({args[0].expr}, {ret})'
  if strides == 'TRANSPOSE2D':
    return f'strides_transpose2d({args[0].expr})'
  if strides == 'TRANSPOSE':
    return f'strides_transpose({args[0].expr}, {args[1].expr}, {args[2].expr})'
  if strides == 'CLONE':
    return f'strides_clone({args[0].expr}, {args[1].expr})'
  if strides == 'CLONE1':
    return f'strides_clone({args[0].expr})'
  if strides == 'CLONE2':
    dtype = all_args[1].expr if 'ScalarType' in all_args[1].type.cpp_type() else all_args[2].expr
    return f'strides_clone2({args[0].expr}, {dtype}, {args[2].expr}, {args[3].expr})'
  if strides == 'CLONE3':
    return f'strides_clone2({args[0].expr}, kFloat, true, {args[2].expr})'
  if strides == 'CLONE_BOOL':
    return f'strides_clone_bool({args[0].expr}, {args[1].expr})'
  if strides == 'PERMUTE':
    return f'strides_permute({args[0].expr}, {args[1].expr})'
  if strides == 'EXPAND':
    return f'strides_expand({args[0].expr}, {args[1].expr})'
  if strides == 'SLICE':
    return f'strides_slice({args[0].expr}, {args[1].expr}, {args[4].expr})'
  if strides == 'FLATTEN':
    return f'strides_flatten({args[0].expr}, {ret})'
  if strides == 'SELECT':
    return f'strides_select({args[0].expr}, {args[1].expr})'
  if strides == 'UNSQUEEZE':
    return f'strides_unsqueeze({args[0].expr}, {args[1].expr})'

  print('mk_strides_infer', strides)
  return 'nullopt'


@with_native_function
def gen_dispatch_wrapper(fn):
  sig_group = CppSignatureGroup.from_native_function(fn, method=False, fallback_binding=fn.manual_cpp_binding)
  sig = sig_group.faithful_signature if sig_group.faithful_signature else sig_group.signature
  dispatcher_sig = DispatcherSignature.from_schema(fn.func)

  rettype = dispatcher_sig.returns_type().cpp_type()
  fndecl = sig.defn(prefix='wrap_', is_redispatching_fn=True)
  fndecl = fndecl.replace('wrap_' + sig.name(), wrapper_name(fn))

  args = translate(sig.arguments(), dispatcher_sig.arguments())
  register_args = ''.join([f'trace.append_arg({move_if_needed(a.expr, a)});' for a in args])

  rargs = ', '.join(['dispatchKeySet'] + [move_if_needed(a.expr, a) for a in args])
  redispatch = f'at::redispatch::{sig.name()}({rargs})'
  tensor_args = [a for a in args if maybe_tensor(a.type)]

  dispatchkey = "dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);"

  shape_fn = get_shape_infer_fn(fn)
  strides_fn = get_strides_infer_fn(fn)

  # emit pass-through wrapper for unsupported functions
  if skip_fn(fn):
    return f'''
{fndecl} {{
  stats_inc_unsupported_wrapper();
  {dispatchkey}
  return {redispatch};
}}'''

  # returns a tensor and takes tensors as arguments
  # e.g. add(x, y)
  if rettype == 'at::Tensor':
    dtype_device = get_dtype_arg(tensor_args, args, fn.func.name)

    set_shape = ''
    if shape_fn:
      set_shape = f'set_shape(tt, {mk_shape_infer(shape_fn, args)});\n  '

    set_strides = ''
    if strides_fn:
      set_strides = f'set_strides(tt, {mk_strides_infer(strides_fn, args, "tt")});\n  '

    return f'''
{fndecl} {{
  if (trace.is_flushing()) {{
    {dispatchkey}
    return {redispatch};
  }}
  auto tt = register_new_tensor(dispatchKeySet, {fn_enum(fn)}, {dtype_device});
  {set_shape}{set_strides}{register_args}
  return tt;
}}'''

  # in-place op; returns one of its arguments
  # e.g. mul_ or mul_out
  assert rettype == 'at::Tensor &' or rettype == 'const at::Tensor &'
  assert tensor_args
  ret = fn_output(fn)

  keeps_shape = 'false'
  if (shape_fn == 'EQ_FIRST' and len(tensor_args) >= 1 and tensor_args[0].expr == ret) or\
     (shape_fn == 'EQ_SECOND' and len(tensor_args) >= 2 and tensor_args[1].expr == ret) or\
     (shape_fn == 'EQ_THIRD' and len(tensor_args) >= 3 and tensor_args[2].expr == ret):
    keeps_shape = 'true'
  elif shape_fn:
    keeps_shape = f'eq_shapes({ret}, {mk_shape_infer(shape_fn, args)})'

  keeps_strides = 'false'
  if (strides_fn == 'EQ_FIRST' and len(tensor_args) >= 1 and tensor_args[0].expr == ret) or\
     (strides_fn == 'EQ_SECOND' and len(tensor_args) >= 2 and tensor_args[1].expr == ret) or\
     (strides_fn == 'EQ_THIRD' and len(tensor_args) >= 3 and tensor_args[2].expr == ret):
    keeps_strides = 'true'
  elif strides_fn:
    keeps_strides = f'eq_shapes({ret}, {mk_strides_infer(strides_fn, args, ret)})'

  return f'''
{fndecl} {{
  if (trace.is_flushing()) {{
    {dispatchkey}
    return {redispatch};
  }}
  bool flush = register_in_place({ret}, {fn_enum(fn)}, dispatchKeySet, {keeps_shape}, {keeps_strides});
  {register_args}
  if (flush)
    trace.flush(STATS(FlushReason::INPLACE_SHARED));
  else
    update_trace_idx({ret});
  return {ret};
}}'''


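# Sketch of a wrapper emitted by gen_dispatch_wrapper() for a tensor-returning
# op (hypothetical signature and inference rules; real output depends on the
# schema and on typings_data):
#
#   at::Tensor wrap_add_Tensor(c10::DispatchKeySet dispatchKeySet,
#                              const at::Tensor &self, const at::Tensor &other,
#                              const at::Scalar &alpha) {
#     if (trace.is_flushing()) {
#       dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
#       return at::redispatch::add(dispatchKeySet, self, other, alpha);
#     }
#     auto tt = register_new_tensor(dispatchKeySet, H_ADD_TENSOR,
#                                   promote_tys(self, other), self.device());
#     trace.append_arg(self);trace.append_arg(other);trace.append_arg(alpha);
#     return tt;
#   }
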
@with_native_function
def gen_torch_library_table(fn):
  return f'm.impl("{fn.func.name}", {wrapper_name(fn)});'

enum_names = {}

@with_native_function
def gen_ops_names(fn):
  enum_names[fn_enum(fn)] = fn.func.name


# (inplace, code, redispatch_signature) -> (enum, fn_ptr)*
interpreter_code = {}

@with_native_function
def gen_interpreter_redispatch(fn):
  sig_group = CppSignatureGroup.from_native_function(fn, method=False, fallback_binding=fn.manual_cpp_binding)
  sig = sig_group.faithful_signature if sig_group.faithful_signature else sig_group.signature
  dispatcher_sig = DispatcherSignature.from_schema(fn.func)

  dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
  args = []
  for i, arg in enumerate(dispatcher_exprs):
    type = arg.type.cpp_type(strip_ref=False)
    type = type.replace('const ', '')
    args.append(f'load<{type}>()(op.args[{i}], load_state)')

  redispatch = f'<FN>(ks, {", ".join(args)})'
  rettype = dispatcher_sig.returns_type().cpp_type()

  if rettype == 'at::Tensor':
    code = f'results[i] = {redispatch};\n  break;'
    inplace = False

  # in-place op
  else:
    assert rettype == 'at::Tensor &' or rettype == 'const at::Tensor &'
    inplace = True
    code = f'results[i] = {redispatch};\n  break;'

  signature = dispatcher_sig.type()
  fn_ptr = f'at::redispatch::{sig.name()}'
  key = inplace, code, signature
  interpreter_code.setdefault(key, []).append((fn_enum(fn), fn_ptr))


fd1 = open('autogen/dispatch_wrappers.h', 'w')
fd2 = open('autogen/torch_library_table.h', 'w')
fd3 = open('autogen/ops_enum.h', 'w')
fd4 = open('autogen/ops_names.h', 'w')
fd5 = open('autogen/interpreter_redispatch.h', 'w')
fd6 = open('autogen/interpreter_redispatch_tables.h', 'w')
fd7 = open('autogen/ops_data.h', 'w')

total = 0
for fn in native_functions.native_functions:
  total += 1
  print(gen_dispatch_wrapper(fn), file=fd1)
  print(gen_torch_library_table(fn), file=fd2)

  if skip_fn(fn):
    continue
  gen_ops_names(fn)
  gen_interpreter_redispatch(fn)

print(f'Total redispatched functions: {total}')
print(f'Distinct signatures: {len(interpreter_code)}')

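# The emission loop below folds ops that share a redispatch signature into a
# single switch case; a one-op group comes out roughly as (hypothetical op):
#
#   case H_ABS:
#     results[i] = at::redispatch::abs(ks, load<at::Tensor>()(op.args[0], load_state));
#     break;
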
table_id = 0
# sorting the keys puts all in-place ops last (False < True on the inplace flag)
interpreter_code = sorted(interpreter_code.items())
is_first_inplace = True

for ((inplace, code, sig), entries) in interpreter_code:
  if inplace and is_first_inplace:
    is_first_inplace = False
    print(f'#define FIRST_INPLACE_OP {entries[0][0]}', file=fd7)

  for (enum, ptr) in entries:
    print(f'case {enum}:', file=fd5)
    print(f'{enum},', file=fd3)
    print(f'"{enum_names[enum]}",', file=fd4)

  if len(entries) == 1:
    code = code.replace('<FN>', entries[0][1])
    print(f'  {code}\n', file=fd5)
  elif len(entries) == 2:
    ptr = sig.replace(' (', '(*ptr)(DispatchKeySet, ')
    print(f'  {{{ptr} = {entries[0][1]};', file=fd5)
    print(f'  if (op.id == {entries[1][0]}) ptr = {entries[1][1]};', file=fd5)
    code = code.replace('<FN>', 'ptr')
    print(f'  {code}}}\n', file=fd5)
  else:
    table = f'redispatch_ptrs_{table_id}'
    table_id += 1
    code = code.replace('<FN>', f'{table}[op.id - {entries[0][0]}]')
    print(f'  {code}\n', file=fd5)

    table = sig.replace(' (', f'(*const {table}[])(DispatchKeySet, ')
    print(f'{table} = {{', file=fd6)
    for (enum, ptr) in entries:
      print(f'  {ptr},', file=fd6)
    print(f'}};\n', file=fd6)

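# Groups with three or more ops instead index a function-pointer table in
# interpreter_redispatch_tables.h, shaped roughly like (hypothetical entries):
#
#   at::Tensor(*const redispatch_ptrs_0[])(DispatchKeySet, const at::Tensor &) = {
#     at::redispatch::abs,
#     at::redispatch::acos,
#   };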