special case argmax return value

This commit is contained in:
Nuno Lopes 2021-05-27 18:20:48 +01:00
Родитель e99b86b92d
Коммит 72ae8f16fc
3 изменённых файлов: 18 добавлений и 15 удалений

Просмотреть файл

@ -1192,7 +1192,7 @@ at::Tensor wrap_argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & se
dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
return at::redispatch::argmax(dispatchKeySet, self, dim, keepdim);
}
auto tt = at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device());
auto tt = at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kLong), self.device());
auto tt_ptr = tt.getIntrusivePtr().get();
unsigned trace_idx = trace.register_tensor((uintptr_t)tt_ptr, H_ARGMAX, dispatchKeySet);
trace.append_arg(trace_idx, self);trace.append_arg(trace_idx, dim);trace.append_arg(trace_idx, keepdim);

6
gen.py
Просмотреть файл

@ -140,8 +140,10 @@ def gen_dispatch_wrapper(fn):
dtype = get_arg_of_type(args, 'at::ScalarType')
if not dtype:
dtype = get_arg_of_type(args, 'c10::optional<at::ScalarType>')
if fn.func.name.name.base in always_returns_bool:
dtype = 'scalarTypeToTypeMeta(kBool)'
fixed = fix_return_type.get(fn.func.name.name.base)
if fixed:
dtype = f'scalarTypeToTypeMeta({fixed})'
device = get_arg_of_type(args, 'at::Device')
if not device:

Просмотреть файл

@ -1,16 +1,17 @@
# Copyright (c) 2021-present The Torchy Authors.
# Distributed under the MIT license that can be found in the LICENSE file.
always_returns_bool = {
'eq',
'greater',
'gt',
'isinf',
'isfinite',
'isnan',
'isneginf',
'isposinf',
'less',
'lt',
'ne',
fix_return_type = {
'argmax': 'kLong',
'eq': 'kBool',
'greater': 'kBool',
'gt': 'kBool',
'isinf': 'kBool',
'isfinite': 'kBool',
'isnan': 'kBool',
'isneginf': 'kBool',
'isposinf': 'kBool',
'less': 'kBool',
'lt': 'kBool',
'ne': 'kBool',
}