make the example work (almost) :D

Nuno Lopes 2021-04-16 18:14:12 +01:00
Parent 167d2c2425
Commit 48deb3c754
2 changed files with 111 additions and 8 deletions

View file

@@ -7,7 +7,9 @@ WIP; don't use.
Install
-------
```
$ python setup.py install
```
Run
@@ -15,7 +17,7 @@ Run
Torchy shouldn't require any change beyond adding a call to `torchy.enable()`.
Example:
-```
+```python
import torch
import torchy
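
# Hedged sketch: the diff truncates the example at this point. Per the text
# above, enabling Torchy is the only change required; the remaining lines are
# illustrative, not verbatim README content.
torchy.enable()

x = torch.rand(4)
print((x > 0.5).sum())  # gt.Scalar and sum are among the ops this commit adds
```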

View file

@@ -37,6 +37,7 @@ using TorchTensorImpl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using UnionInputTys = c10::variant<
IntArrayRef,
c10::optional<int64_t>,
c10::optional<ScalarType>,
Scalar,
Tensor
>;
@@ -110,6 +111,11 @@ struct TensorOp {
os << **o;
else
os << "(null)";
} else if (auto s = get_if<c10::optional<ScalarType>>(&arg)) {
if (*s)
os << **s;
else
os << "(null)";
} else {
assert(false);
}
@@ -235,6 +241,9 @@ public:
} else if (!strcmp(op.id, "eq_Tensor")) {
set(op.tensor, at::redispatch::eq(dispatch_key, get<Tensor>(op.args[0]),
get<Tensor>(op.args[1])));
} else if (!strcmp(op.id, "gt_Scalar")) {
set(op.tensor, at::redispatch::gt(dispatch_key, get<Tensor>(op.args[0]),
get<Scalar>(op.args[1])));
} else if (!strcmp(op.id, "masked_select")) {
set(op.tensor,
at::redispatch::masked_select(dispatch_key,
@@ -262,6 +271,10 @@ public:
set(op.tensor,
at::redispatch::reshape(dispatch_key, get<Tensor>(op.args[0]),
get<IntArrayRef>(op.args[1])));
} else if (!strcmp(op.id, "sum")) {
set(op.tensor,
at::redispatch::sum(dispatch_key, get<Tensor>(op.args[0]),
get<c10::optional<ScalarType>>(op.args[1])));
} else if (!strcmp(op.id, "view")) {
set(op.tensor,
at::redispatch::view(dispatch_key, get<Tensor>(op.args[0]),
@@ -307,9 +320,6 @@ class TorchyTensor final : public TensorImpl {
// FIXME: cant access: is_channels_last_3d_contiguous_
is_non_overlapping_and_dense_ = tensor->is_non_overlapping_and_dense();
is_wrapped_number_ = tensor->is_wrapped_number();
if (tensor->has_storage())
storage_ = tensor->storage();
}
public:
@@ -325,9 +335,12 @@ template<typename... T>
void set(Tensor &&t) {
trace_idx = -1u;
tensor = t.unsafeReleaseIntrusivePtr();
key_set_ = key_set_ | tensor->key_set();
data_type_ = tensor->dtype();
storage_ = tensor->storage();
refresh_non_virtual();
assert(dtype() == tensor->dtype());
}
void ensure_tensor() const {
@@ -482,6 +495,13 @@ Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor &self,
Tensor as_strided(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef size,
IntArrayRef stride, c10::optional<int64_t> storage_offset) {
if (trace.is_flushing()) {
ensure_materialized(self);
return
at::redispatch::as_strided(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
self, size, stride, storage_offset);
}
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
"as_strided", self, size, stride,
storage_offset);
@@ -525,6 +545,15 @@ Tensor& detach_(c10::DispatchKeySet ks, Tensor &self) {
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY), self);
}
Tensor& div_out(c10::DispatchKeySet ks, const Tensor &self, const Tensor &other,
Tensor &out) {
// TODO: can be made lazy?
ensure_materialized(self, other, out);
return at::redispatch::div_out(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
out, self, other);
}
Tensor empty_memory_format(c10::DispatchKeySet ks, IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
@@ -557,7 +586,8 @@ Tensor empty_strided(c10::DispatchKeySet ks, IntArrayRef size,
Tensor eq_Tensor(c10::DispatchKeySet ks, const Tensor &self,
const Tensor &other) {
-return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                             self.device(), ks,
"eq_Tensor", self, other);
}
@@ -569,6 +599,38 @@ Tensor& eq_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
out, self, other);
}
Tensor& fill__Scalar(c10::DispatchKeySet ks, Tensor &self,
const Scalar &value) {
ensure_materialized(self);
return
at::redispatch::fill_(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
self, value);
}
Tensor gt_Scalar(c10::DispatchKeySet ks, const Tensor &self,
const Scalar &other) {
if (trace.is_flushing()) {
ensure_materialized(self);
return
at::redispatch::gt(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
self, other);
}
return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
self.device(), ks,
"gt_Scalar", self, other);
}
Tensor& gt_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
const Tensor &other, Tensor &out) {
ensure_materialized(self, other, out);
return
at::redispatch::gt_out(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
out, self, other);
}
Scalar _local_scalar_dense(c10::DispatchKeySet ks, const Tensor &self) {
ensure_materialized(self);
return at::redispatch::_local_scalar_dense(
@@ -607,13 +669,15 @@ Tensor mul_Tensor(c10::DispatchKeySet ks, const Tensor &self,
Tensor ne_Scalar(c10::DispatchKeySet ks, const Tensor &self,
const Scalar &other) {
-return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                             self.device(), ks,
"ne_Scalar", self, other);
}
Tensor ne_Tensor(c10::DispatchKeySet ks, const Tensor &self,
const Tensor &other) {
-return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                             self.device(), ks,
"ne_Tensor", self, other);
}
@@ -627,6 +691,13 @@ Tensor& ne_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
}
Tensor reshape(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef shape) {
if (trace.is_flushing()) {
ensure_materialized(self);
return
at::redispatch::reshape(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
self, shape);
}
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
"reshape", self, shape);
}
@@ -640,6 +711,19 @@ Tensor& resize_(c10::DispatchKeySet ks, Tensor &self, IntArrayRef size,
self, size, memory_format);
}
Tensor sum(c10::DispatchKeySet ks, const Tensor &self,
c10::optional<ScalarType> dtype) {
if (trace.is_flushing()) {
ensure_materialized(self);
return
at::redispatch::sum(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
self, dtype);
}
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
"sum", self, dtype);
}
Tensor to_device(c10::DispatchKeySet ks, const Tensor &self,
Device device, ScalarType dtype, bool non_blocking, bool copy,
c10::optional<MemoryFormat> memory_format) {
@@ -651,6 +735,13 @@ Tensor to_device(c10::DispatchKeySet ks, const Tensor &self,
}
Tensor view(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef size) {
if (trace.is_flushing()) {
ensure_materialized(self);
return
at::redispatch::view(
ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
self, size);
}
return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
"view", self, size);
}
@@ -664,10 +755,14 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
m.impl("ceil.out", ceil_out);
m.impl("copy_", copy_);
m.impl("detach_", detach_); // FIXME: RegisterDefaultBackend
m.impl("div.out", div_out);
m.impl("empty.memory_format", empty_memory_format); // FIXME: not called
m.impl("empty_strided", empty_strided); // FIXME: not called
m.impl("eq.Tensor", eq_Tensor);
m.impl("eq.Tensor_out", eq_Tensor_out);
m.impl("fill_.Scalar", fill__Scalar);
m.impl("gt.Scalar", gt_Scalar);
m.impl("gt.Tensor_out", gt_Tensor_out);
m.impl("_local_scalar_dense", _local_scalar_dense);
m.impl("masked_select", masked_select);
m.impl("max", max);
@@ -679,8 +774,14 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
m.impl("ne.Tensor_out", ne_Tensor_out);
m.impl("reshape", reshape); // FIXME: RegisterMath
m.impl("resize_", resize_);
m.impl("sum", sum);
m.impl("to.device", to_device); // FIXME: RegisterMath
m.impl("view", view);
}
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl("reshape", reshape);
m.impl("to.device", to_device);
}
}
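
For orientation, every op wrapper this commit adds follows the same two-path shape seen in `gt_Scalar`, `sum`, and `view` above: while the trace is flushing, materialize the inputs and redispatch past Torchy's key; otherwise record the call as a lazy `TorchyTensor`. Below is a minimal sketch of that pattern. It reuses the diff's own names (`trace`, `ensure_materialized`, `TorchyTensor`, `DISPATCHKEY`), but `neg` is a hypothetical op chosen for illustration; the commit does not actually add it.

```cpp
// Sketch only: `neg` is hypothetical; the helpers are those used in the diff.
Tensor neg(c10::DispatchKeySet ks, const Tensor &self) {
  if (trace.is_flushing()) {
    // Trace replay: force `self` into a concrete tensor, then mask out our
    // dispatch key so the next (real) backend kernel handles the call.
    ensure_materialized(self);
    return at::redispatch::neg(
        ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY), self);
  }
  // Lazy path: record the op and defer computation. Comparison ops (eq, ne,
  // gt) pass scalarTypeToTypeMeta(kBool) as the dtype here instead of
  // self.dtype(), which is the result-dtype fix this commit makes.
  return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
                                               "neg", self);
}
```

The interpreter's `strcmp` chain would then need a matching `"neg"` case that calls `at::redispatch::neg` when the trace flushes, mirroring the `gt_Scalar` and `sum` cases added above.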