special case argmax return value

This commit is contained in:
Nuno Lopes 2021-05-27 18:20:48 +01:00
Родитель e99b86b92d
Коммит 72ae8f16fc
3 изменённых файлов: 18 добавлений и 15 удалений

Просмотреть файл

@ -1192,7 +1192,7 @@ at::Tensor wrap_argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & se
dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY); dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
return at::redispatch::argmax(dispatchKeySet, self, dim, keepdim); return at::redispatch::argmax(dispatchKeySet, self, dim, keepdim);
} }
auto tt = at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device()); auto tt = at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kLong), self.device());
auto tt_ptr = tt.getIntrusivePtr().get(); auto tt_ptr = tt.getIntrusivePtr().get();
unsigned trace_idx = trace.register_tensor((uintptr_t)tt_ptr, H_ARGMAX, dispatchKeySet); unsigned trace_idx = trace.register_tensor((uintptr_t)tt_ptr, H_ARGMAX, dispatchKeySet);
trace.append_arg(trace_idx, self);trace.append_arg(trace_idx, dim);trace.append_arg(trace_idx, keepdim); trace.append_arg(trace_idx, self);trace.append_arg(trace_idx, dim);trace.append_arg(trace_idx, keepdim);

6
gen.py
Просмотреть файл

@ -140,8 +140,10 @@ def gen_dispatch_wrapper(fn):
dtype = get_arg_of_type(args, 'at::ScalarType') dtype = get_arg_of_type(args, 'at::ScalarType')
if not dtype: if not dtype:
dtype = get_arg_of_type(args, 'c10::optional<at::ScalarType>') dtype = get_arg_of_type(args, 'c10::optional<at::ScalarType>')
if fn.func.name.name.base in always_returns_bool:
dtype = 'scalarTypeToTypeMeta(kBool)' fixed = fix_return_type.get(fn.func.name.name.base)
if fixed:
dtype = f'scalarTypeToTypeMeta({fixed})'
device = get_arg_of_type(args, 'at::Device') device = get_arg_of_type(args, 'at::Device')
if not device: if not device:

Просмотреть файл

@ -1,16 +1,17 @@
# Copyright (c) 2021-present The Torchy Authors. # Copyright (c) 2021-present The Torchy Authors.
# Distributed under the MIT license that can be found in the LICENSE file. # Distributed under the MIT license that can be found in the LICENSE file.
always_returns_bool = { fix_return_type = {
'eq', 'argmax': 'kLong',
'greater', 'eq': 'kBool',
'gt', 'greater': 'kBool',
'isinf', 'gt': 'kBool',
'isfinite', 'isinf': 'kBool',
'isnan', 'isfinite': 'kBool',
'isneginf', 'isnan': 'kBool',
'isposinf', 'isneginf': 'kBool',
'less', 'isposinf': 'kBool',
'lt', 'less': 'kBool',
'ne', 'lt': 'kBool',
'ne': 'kBool',
} }