зеркало из https://github.com/microsoft/torchy.git
special case argmax return value
This commit is contained in:
Родитель
e99b86b92d
Коммит
72ae8f16fc
|
@ -1192,7 +1192,7 @@ at::Tensor wrap_argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & se
|
|||
dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
|
||||
return at::redispatch::argmax(dispatchKeySet, self, dim, keepdim);
|
||||
}
|
||||
auto tt = at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device());
|
||||
auto tt = at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kLong), self.device());
|
||||
auto tt_ptr = tt.getIntrusivePtr().get();
|
||||
unsigned trace_idx = trace.register_tensor((uintptr_t)tt_ptr, H_ARGMAX, dispatchKeySet);
|
||||
trace.append_arg(trace_idx, self);trace.append_arg(trace_idx, dim);trace.append_arg(trace_idx, keepdim);
|
||||
|
|
6
gen.py
6
gen.py
|
@ -140,8 +140,10 @@ def gen_dispatch_wrapper(fn):
|
|||
dtype = get_arg_of_type(args, 'at::ScalarType')
|
||||
if not dtype:
|
||||
dtype = get_arg_of_type(args, 'c10::optional<at::ScalarType>')
|
||||
if fn.func.name.name.base in always_returns_bool:
|
||||
dtype = 'scalarTypeToTypeMeta(kBool)'
|
||||
|
||||
fixed = fix_return_type.get(fn.func.name.name.base)
|
||||
if fixed:
|
||||
dtype = f'scalarTypeToTypeMeta({fixed})'
|
||||
|
||||
device = get_arg_of_type(args, 'at::Device')
|
||||
if not device:
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
# Copyright (c) 2021-present The Torchy Authors.
|
||||
# Distributed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
always_returns_bool = {
|
||||
'eq',
|
||||
'greater',
|
||||
'gt',
|
||||
'isinf',
|
||||
'isfinite',
|
||||
'isnan',
|
||||
'isneginf',
|
||||
'isposinf',
|
||||
'less',
|
||||
'lt',
|
||||
'ne',
|
||||
fix_return_type = {
|
||||
'argmax': 'kLong',
|
||||
'eq': 'kBool',
|
||||
'greater': 'kBool',
|
||||
'gt': 'kBool',
|
||||
'isinf': 'kBool',
|
||||
'isfinite': 'kBool',
|
||||
'isnan': 'kBool',
|
||||
'isneginf': 'kBool',
|
||||
'isposinf': 'kBool',
|
||||
'less': 'kBool',
|
||||
'lt': 'kBool',
|
||||
'ne': 'kBool',
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче