зеркало из https://github.com/microsoft/torchy.git
make the example work (almost) :D
This commit is contained in:
Родитель
167d2c2425
Коммит
48deb3c754
|
@ -7,7 +7,9 @@ WIP; don't use.
|
|||
|
||||
Install
|
||||
-------
|
||||
```
|
||||
$ python setup.py install
|
||||
```
|
||||
|
||||
|
||||
Run
|
||||
|
@ -15,7 +17,7 @@ Run
|
|||
Torchy shouldn't require any change beyond adding a call to `torchy.enable()`.
|
||||
Example:
|
||||
|
||||
```
|
||||
```python
|
||||
import torch
|
||||
import torchy
|
||||
|
||||
|
|
115
handlers.cpp
115
handlers.cpp
|
@ -37,6 +37,7 @@ using TorchTensorImpl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
|
|||
using UnionInputTys = c10::variant<
|
||||
IntArrayRef,
|
||||
c10::optional<int64_t>,
|
||||
c10::optional<ScalarType>,
|
||||
Scalar,
|
||||
Tensor
|
||||
>;
|
||||
|
@ -110,6 +111,11 @@ struct TensorOp {
|
|||
os << **o;
|
||||
else
|
||||
os << "(null)";
|
||||
} else if (auto s = get_if<c10::optional<ScalarType>>(&arg)) {
|
||||
if (*s)
|
||||
os << **s;
|
||||
else
|
||||
os << "(null)";
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
|
@ -235,6 +241,9 @@ public:
|
|||
} else if (!strcmp(op.id, "eq_Tensor")) {
|
||||
set(op.tensor, at::redispatch::eq(dispatch_key, get<Tensor>(op.args[0]),
|
||||
get<Tensor>(op.args[1])));
|
||||
} else if (!strcmp(op.id, "gt_Scalar")) {
|
||||
set(op.tensor, at::redispatch::gt(dispatch_key, get<Tensor>(op.args[0]),
|
||||
get<Scalar>(op.args[1])));
|
||||
} else if (!strcmp(op.id, "masked_select")) {
|
||||
set(op.tensor,
|
||||
at::redispatch::masked_select(dispatch_key,
|
||||
|
@ -262,6 +271,10 @@ public:
|
|||
set(op.tensor,
|
||||
at::redispatch::reshape(dispatch_key, get<Tensor>(op.args[0]),
|
||||
get<IntArrayRef>(op.args[1])));
|
||||
} else if (!strcmp(op.id, "sum")) {
|
||||
set(op.tensor,
|
||||
at::redispatch::sum(dispatch_key, get<Tensor>(op.args[0]),
|
||||
get<c10::optional<ScalarType>>(op.args[1])));
|
||||
} else if (!strcmp(op.id, "view")) {
|
||||
set(op.tensor,
|
||||
at::redispatch::view(dispatch_key, get<Tensor>(op.args[0]),
|
||||
|
@ -307,9 +320,6 @@ class TorchyTensor final : public TensorImpl {
|
|||
// FIXME: cant access: is_channels_last_3d_contiguous_
|
||||
is_non_overlapping_and_dense_ = tensor->is_non_overlapping_and_dense();
|
||||
is_wrapped_number_ = tensor->is_wrapped_number();
|
||||
|
||||
if (tensor->has_storage())
|
||||
storage_ = tensor->storage();
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -325,9 +335,12 @@ template<typename... T>
|
|||
void set(Tensor &&t) {
|
||||
trace_idx = -1u;
|
||||
tensor = t.unsafeReleaseIntrusivePtr();
|
||||
|
||||
key_set_ = key_set_ | tensor->key_set();
|
||||
data_type_ = tensor->dtype();
|
||||
storage_ = tensor->storage();
|
||||
refresh_non_virtual();
|
||||
|
||||
assert(dtype() == tensor->dtype());
|
||||
}
|
||||
|
||||
void ensure_tensor() const {
|
||||
|
@ -482,6 +495,13 @@ Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor &self,
|
|||
|
||||
Tensor as_strided(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef size,
|
||||
IntArrayRef stride, c10::optional<int64_t> storage_offset) {
|
||||
if (trace.is_flushing()) {
|
||||
ensure_materialized(self);
|
||||
return
|
||||
at::redispatch::as_strided(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
self, size, stride, storage_offset);
|
||||
}
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
"as_strided", self, size, stride,
|
||||
storage_offset);
|
||||
|
@ -525,6 +545,15 @@ Tensor& detach_(c10::DispatchKeySet ks, Tensor &self) {
|
|||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY), self);
|
||||
}
|
||||
|
||||
Tensor& div_out(c10::DispatchKeySet ks, const Tensor &self, const Tensor &other,
|
||||
Tensor &out) {
|
||||
// TODO: can be made lazy?
|
||||
ensure_materialized(self, other, out);
|
||||
return at::redispatch::div_out(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
out, self, other);
|
||||
}
|
||||
|
||||
Tensor empty_memory_format(c10::DispatchKeySet ks, IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
|
@ -557,7 +586,8 @@ Tensor empty_strided(c10::DispatchKeySet ks, IntArrayRef size,
|
|||
|
||||
Tensor eq_Tensor(c10::DispatchKeySet ks, const Tensor &self,
|
||||
const Tensor &other) {
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
|
||||
self.device(), ks,
|
||||
"eq_Tensor", self, other);
|
||||
}
|
||||
|
||||
|
@ -569,6 +599,38 @@ Tensor& eq_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
|
|||
out, self, other);
|
||||
}
|
||||
|
||||
Tensor& fill__Scalar(c10::DispatchKeySet ks, Tensor &self,
|
||||
const Scalar &value) {
|
||||
ensure_materialized(self);
|
||||
return
|
||||
at::redispatch::fill_(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
self, value);
|
||||
}
|
||||
|
||||
Tensor gt_Scalar(c10::DispatchKeySet ks, const Tensor &self,
|
||||
const Scalar &other) {
|
||||
if (trace.is_flushing()) {
|
||||
ensure_materialized(self);
|
||||
return
|
||||
at::redispatch::gt(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
self, other);
|
||||
}
|
||||
return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
|
||||
self.device(), ks,
|
||||
"gt_Scalar", self, other);
|
||||
}
|
||||
|
||||
Tensor& gt_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
|
||||
const Tensor &other, Tensor &out) {
|
||||
ensure_materialized(self, other, out);
|
||||
return
|
||||
at::redispatch::gt_out(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
out, self, other);
|
||||
}
|
||||
|
||||
Scalar _local_scalar_dense(c10::DispatchKeySet ks, const Tensor &self) {
|
||||
ensure_materialized(self);
|
||||
return at::redispatch::_local_scalar_dense(
|
||||
|
@ -607,13 +669,15 @@ Tensor mul_Tensor(c10::DispatchKeySet ks, const Tensor &self,
|
|||
|
||||
Tensor ne_Scalar(c10::DispatchKeySet ks, const Tensor &self,
|
||||
const Scalar &other) {
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
|
||||
self.device(), ks,
|
||||
"ne_Scalar", self, other);
|
||||
}
|
||||
|
||||
Tensor ne_Tensor(c10::DispatchKeySet ks, const Tensor &self,
|
||||
const Tensor &other) {
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
|
||||
self.device(), ks,
|
||||
"ne_Tensor", self, other);
|
||||
}
|
||||
|
||||
|
@ -627,6 +691,13 @@ Tensor& ne_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
|
|||
}
|
||||
|
||||
Tensor reshape(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef shape) {
|
||||
if (trace.is_flushing()) {
|
||||
ensure_materialized(self);
|
||||
return
|
||||
at::redispatch::reshape(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
self, shape);
|
||||
}
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
"reshape", self, shape);
|
||||
}
|
||||
|
@ -640,6 +711,19 @@ Tensor& resize_(c10::DispatchKeySet ks, Tensor &self, IntArrayRef size,
|
|||
self, size, memory_format);
|
||||
}
|
||||
|
||||
Tensor sum(c10::DispatchKeySet ks, const Tensor &self,
|
||||
c10::optional<ScalarType> dtype) {
|
||||
if (trace.is_flushing()) {
|
||||
ensure_materialized(self);
|
||||
return
|
||||
at::redispatch::sum(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
self, dtype);
|
||||
}
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
"sum", self, dtype);
|
||||
}
|
||||
|
||||
Tensor to_device(c10::DispatchKeySet ks, const Tensor &self,
|
||||
Device device, ScalarType dtype, bool non_blocking, bool copy,
|
||||
c10::optional<MemoryFormat> memory_format) {
|
||||
|
@ -651,6 +735,13 @@ Tensor to_device(c10::DispatchKeySet ks, const Tensor &self,
|
|||
}
|
||||
|
||||
Tensor view(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef size) {
|
||||
if (trace.is_flushing()) {
|
||||
ensure_materialized(self);
|
||||
return
|
||||
at::redispatch::view(
|
||||
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
|
||||
self, size);
|
||||
}
|
||||
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
|
||||
"view", self, size);
|
||||
}
|
||||
|
@ -664,10 +755,14 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
|
|||
m.impl("ceil.out", ceil_out);
|
||||
m.impl("copy_", copy_);
|
||||
m.impl("detach_", detach_); // FIXME: RegisterDefaultBackend
|
||||
m.impl("div.out", div_out);
|
||||
m.impl("empty.memory_format", empty_memory_format); // FIXME: not called
|
||||
m.impl("empty_strided", empty_strided); // FIXME: not called
|
||||
m.impl("eq.Tensor", eq_Tensor);
|
||||
m.impl("eq.Tensor_out", eq_Tensor_out);
|
||||
m.impl("fill_.Scalar", fill__Scalar);
|
||||
m.impl("gt.Scalar", gt_Scalar);
|
||||
m.impl("gt.Tensor_out", gt_Tensor_out);
|
||||
m.impl("_local_scalar_dense", _local_scalar_dense);
|
||||
m.impl("masked_select", masked_select);
|
||||
m.impl("max", max);
|
||||
|
@ -679,8 +774,14 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
|
|||
m.impl("ne.Tensor_out", ne_Tensor_out);
|
||||
m.impl("reshape", reshape); // FIXME: RegisterMath
|
||||
m.impl("resize_", resize_);
|
||||
m.impl("sum", sum);
|
||||
m.impl("to.device", to_device); // FIXME: RegisterMath
|
||||
m.impl("view", view);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
|
||||
m.impl("reshape", reshape);
|
||||
m.impl("to.device", to_device);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче