зеркало из https://github.com/microsoft/torchy.git
special case argmax return value
This commit is contained in:
Родитель
e99b86b92d
Коммит
72ae8f16fc
|
@ -1192,7 +1192,7 @@ at::Tensor wrap_argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & se
|
||||||
dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
|
dispatchKeySet = dispatchKeySet & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
|
||||||
return at::redispatch::argmax(dispatchKeySet, self, dim, keepdim);
|
return at::redispatch::argmax(dispatchKeySet, self, dim, keepdim);
|
||||||
}
|
}
|
||||||
auto tt = at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device());
|
auto tt = at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kLong), self.device());
|
||||||
auto tt_ptr = tt.getIntrusivePtr().get();
|
auto tt_ptr = tt.getIntrusivePtr().get();
|
||||||
unsigned trace_idx = trace.register_tensor((uintptr_t)tt_ptr, H_ARGMAX, dispatchKeySet);
|
unsigned trace_idx = trace.register_tensor((uintptr_t)tt_ptr, H_ARGMAX, dispatchKeySet);
|
||||||
trace.append_arg(trace_idx, self);trace.append_arg(trace_idx, dim);trace.append_arg(trace_idx, keepdim);
|
trace.append_arg(trace_idx, self);trace.append_arg(trace_idx, dim);trace.append_arg(trace_idx, keepdim);
|
||||||
|
|
6
gen.py
6
gen.py
|
@ -140,8 +140,10 @@ def gen_dispatch_wrapper(fn):
|
||||||
dtype = get_arg_of_type(args, 'at::ScalarType')
|
dtype = get_arg_of_type(args, 'at::ScalarType')
|
||||||
if not dtype:
|
if not dtype:
|
||||||
dtype = get_arg_of_type(args, 'c10::optional<at::ScalarType>')
|
dtype = get_arg_of_type(args, 'c10::optional<at::ScalarType>')
|
||||||
if fn.func.name.name.base in always_returns_bool:
|
|
||||||
dtype = 'scalarTypeToTypeMeta(kBool)'
|
fixed = fix_return_type.get(fn.func.name.name.base)
|
||||||
|
if fixed:
|
||||||
|
dtype = f'scalarTypeToTypeMeta({fixed})'
|
||||||
|
|
||||||
device = get_arg_of_type(args, 'at::Device')
|
device = get_arg_of_type(args, 'at::Device')
|
||||||
if not device:
|
if not device:
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
# Copyright (c) 2021-present The Torchy Authors.
|
# Copyright (c) 2021-present The Torchy Authors.
|
||||||
# Distributed under the MIT license that can be found in the LICENSE file.
|
# Distributed under the MIT license that can be found in the LICENSE file.
|
||||||
|
|
||||||
always_returns_bool = {
|
fix_return_type = {
|
||||||
'eq',
|
'argmax': 'kLong',
|
||||||
'greater',
|
'eq': 'kBool',
|
||||||
'gt',
|
'greater': 'kBool',
|
||||||
'isinf',
|
'gt': 'kBool',
|
||||||
'isfinite',
|
'isinf': 'kBool',
|
||||||
'isnan',
|
'isfinite': 'kBool',
|
||||||
'isneginf',
|
'isnan': 'kBool',
|
||||||
'isposinf',
|
'isneginf': 'kBool',
|
||||||
'less',
|
'isposinf': 'kBool',
|
||||||
'lt',
|
'less': 'kBool',
|
||||||
'ne',
|
'lt': 'kBool',
|
||||||
|
'ne': 'kBool',
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче