diff --git a/README.md b/README.md
index d1dc099..13a2c5c 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,9 @@ WIP; don't use.
 
 Install
 -------
+```
 $ python setup.py install
+```
 
 Run
 ---
@@ -15,7 +17,7 @@ Run
 Torchy shouldn't require any change beyond adding a call to `torchy.enable()`.
 Example:
 
-```
+```python
 import torch
 import torchy
 
diff --git a/handlers.cpp b/handlers.cpp
index 2212216..ab96bd3 100644
--- a/handlers.cpp
+++ b/handlers.cpp
@@ -37,6 +37,7 @@ using TorchTensorImpl = c10::intrusive_ptr<TensorImpl>;
 using UnionInputTys = c10::variant<
   IntArrayRef,
   c10::optional<int64_t>,
+  c10::optional<ScalarType>,
   Scalar,
   Tensor
 >;
@@ -110,6 +111,11 @@ struct TensorOp {
       os << **o;
     else
       os << "(null)";
+  } else if (auto s = get_if<c10::optional<ScalarType>>(&arg)) {
+    if (*s)
+      os << **s;
+    else
+      os << "(null)";
   } else {
     assert(false);
   }
@@ -235,6 +241,9 @@ public:
     } else if (!strcmp(op.id, "eq_Tensor")) {
       set(op.tensor, at::redispatch::eq(dispatch_key, get<Tensor>(op.args[0]),
                                         get<Tensor>(op.args[1])));
+    } else if (!strcmp(op.id, "gt_Scalar")) {
+      set(op.tensor, at::redispatch::gt(dispatch_key, get<Tensor>(op.args[0]),
+                                        get<Scalar>(op.args[1])));
     } else if (!strcmp(op.id, "masked_select")) {
       set(op.tensor,
           at::redispatch::masked_select(dispatch_key,
@@ -262,6 +271,10 @@ public:
       set(op.tensor,
           at::redispatch::reshape(dispatch_key, get<Tensor>(op.args[0]),
                                   get<IntArrayRef>(op.args[1])));
+    } else if (!strcmp(op.id, "sum")) {
+      set(op.tensor,
+          at::redispatch::sum(dispatch_key, get<Tensor>(op.args[0]),
+                              get<c10::optional<ScalarType>>(op.args[1])));
     } else if (!strcmp(op.id, "view")) {
       set(op.tensor,
           at::redispatch::view(dispatch_key, get<Tensor>(op.args[0]),
@@ -307,9 +320,6 @@ class TorchyTensor final : public TensorImpl {
     // FIXME: cant access: is_channels_last_3d_contiguous_
     is_non_overlapping_and_dense_ = tensor->is_non_overlapping_and_dense();
     is_wrapped_number_ = tensor->is_wrapped_number();
-
-    if (tensor->has_storage())
-      storage_ = tensor->storage();
   }
 
 public:
@@ -325,9 +335,12 @@ template
   void set(Tensor &&t) {
     trace_idx = -1u;
     tensor = t.unsafeReleaseIntrusivePtr();
+    key_set_ = key_set_ | tensor->key_set();
 
-    data_type_ = tensor->dtype();
+    storage_ = tensor->storage();
     refresh_non_virtual();
+
+    assert(dtype() == tensor->dtype());
   }
 
   void ensure_tensor() const {
@@ -482,6 +495,13 @@ Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor &self,
 Tensor as_strided(c10::DispatchKeySet ks, const Tensor &self,
                   IntArrayRef size, IntArrayRef stride,
                   c10::optional<int64_t> storage_offset) {
+  if (trace.is_flushing()) {
+    ensure_materialized(self);
+    return
+      at::redispatch::as_strided(
+        ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+        self, size, stride, storage_offset);
+  }
   return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
                                                "as_strided", self, size,
                                                stride, storage_offset);
@@ -525,6 +545,15 @@ Tensor& detach_(c10::DispatchKeySet ks, Tensor &self) {
     ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY), self);
 }
 
+Tensor& div_out(c10::DispatchKeySet ks, const Tensor &self, const Tensor &other,
+                Tensor &out) {
+  // TODO: can be made lazy?
+  ensure_materialized(self, other, out);
+  return at::redispatch::div_out(
+    ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+    out, self, other);
+}
+
 Tensor empty_memory_format(c10::DispatchKeySet ks, IntArrayRef size,
                            c10::optional<ScalarType> dtype,
                            c10::optional<Layout> layout,
@@ -557,7 +586,8 @@ Tensor empty_strided(c10::DispatchKeySet ks, IntArrayRef size,
 
 Tensor eq_Tensor(c10::DispatchKeySet ks, const Tensor &self,
                  const Tensor &other) {
-  return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+  return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                               self.device(), ks,
                                                "eq_Tensor", self, other);
 }
 
@@ -569,6 +599,38 @@ Tensor& eq_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
     out, self, other);
 }
 
+Tensor& fill__Scalar(c10::DispatchKeySet ks, Tensor &self,
+                     const Scalar &value) {
+  ensure_materialized(self);
+  return
+    at::redispatch::fill_(
+      ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+      self, value);
+}
+
+Tensor gt_Scalar(c10::DispatchKeySet ks, const Tensor &self,
+                 const Scalar &other) {
+  if (trace.is_flushing()) {
+    ensure_materialized(self);
+    return
+      at::redispatch::gt(
+        ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+        self, other);
+  }
+  return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                               self.device(), ks,
+                                               "gt_Scalar", self, other);
+}
+
+Tensor& gt_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
+                      const Tensor &other, Tensor &out) {
+  ensure_materialized(self, other, out);
+  return
+    at::redispatch::gt_out(
+      ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+      out, self, other);
+}
+
 Scalar _local_scalar_dense(c10::DispatchKeySet ks, const Tensor &self) {
   ensure_materialized(self);
   return at::redispatch::_local_scalar_dense(
@@ -607,13 +669,15 @@ Tensor mul_Tensor(c10::DispatchKeySet ks, const Tensor &self,
 
 Tensor ne_Scalar(c10::DispatchKeySet ks, const Tensor &self,
                  const Scalar &other) {
-  return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+  return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                               self.device(), ks,
                                                "ne_Scalar", self, other);
 }
 
 Tensor ne_Tensor(c10::DispatchKeySet ks, const Tensor &self,
                  const Tensor &other) {
-  return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+  return at::detail::make_tensor<TorchyTensor>(scalarTypeToTypeMeta(kBool),
+                                               self.device(), ks,
                                                "ne_Tensor", self, other);
 }
 
@@ -627,6 +691,13 @@ Tensor& ne_Tensor_out(c10::DispatchKeySet ks, const Tensor &self,
 }
 
 Tensor reshape(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef shape) {
+  if (trace.is_flushing()) {
+    ensure_materialized(self);
+    return
+      at::redispatch::reshape(
+        ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+        self, shape);
+  }
   return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
                                                "reshape", self, shape);
 }
@@ -640,6 +711,19 @@ Tensor& resize_(c10::DispatchKeySet ks, Tensor &self, IntArrayRef size,
     self, size, memory_format);
 }
 
+Tensor sum(c10::DispatchKeySet ks, const Tensor &self,
+           c10::optional<ScalarType> dtype) {
+  if (trace.is_flushing()) {
+    ensure_materialized(self);
+    return
+      at::redispatch::sum(
+        ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+        self, dtype);
+  }
+  return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
+                                               "sum", self, dtype);
+}
+
 Tensor to_device(c10::DispatchKeySet ks, const Tensor &self, Device device,
                  ScalarType dtype, bool non_blocking, bool copy,
                  c10::optional<MemoryFormat> memory_format) {
@@ -651,6 +735,13 @@ Tensor to_device(c10::DispatchKeySet ks, const Tensor &self,
 }
 
 Tensor view(c10::DispatchKeySet ks, const Tensor &self, IntArrayRef size) {
+  if (trace.is_flushing()) {
+    ensure_materialized(self);
+    return
+      at::redispatch::view(
+        ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY),
+        self, size);
+  }
   return at::detail::make_tensor<TorchyTensor>(self.dtype(), self.device(), ks,
                                                "view", self, size);
 }
@@ -664,10 +755,14 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
   m.impl("ceil.out", ceil_out);
   m.impl("copy_", copy_);
   m.impl("detach_", detach_); // FIXME: RegisterDefaultBackend
+  m.impl("div.out", div_out);
   m.impl("empty.memory_format", empty_memory_format); // FIXME: not called
   m.impl("empty_strided", empty_strided); // FIXME: not called
   m.impl("eq.Tensor", eq_Tensor);
   m.impl("eq.Tensor_out", eq_Tensor_out);
+  m.impl("fill_.Scalar", fill__Scalar);
+  m.impl("gt.Scalar", gt_Scalar);
+  m.impl("gt.Tensor_out", gt_Tensor_out);
   m.impl("_local_scalar_dense", _local_scalar_dense);
   m.impl("masked_select", masked_select);
   m.impl("max", max);
@@ -679,8 +774,14 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
   m.impl("ne.Tensor_out", ne_Tensor_out);
   m.impl("reshape", reshape); // FIXME: RegisterMath
   m.impl("resize_", resize_);
+  m.impl("sum", sum);
   m.impl("to.device", to_device); // FIXME: RegisterMath
   m.impl("view", view);
 }
 
+TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
+  m.impl("reshape", reshape);
+  m.impl("to.device", to_device);
+}
+
 }
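
For context on the pattern these handlers follow: an op that can stay lazy returns a `TorchyTensor` that merely records the op id and its arguments in a trace, and the interpreter's `strcmp` chain replays the trace through `at::redispatch` once a value is actually needed; the `trace.is_flushing()` branches added above handle the case where a handler is re-entered during that replay and must run eagerly. Below is a minimal standalone C++17 sketch of the record-then-flush idea; `Trace`, `Op`, `record`, and `flush` are hypothetical simplified names, not torchy's actual types.

```cpp
#include <cassert>
#include <cstdio>
#include <cstring>
#include <variant>
#include <vector>

// Stand-in for UnionInputTys: the closed set of argument types an op may take.
using Arg = std::variant<double, int>;

struct Op {
  const char *id;          // e.g. "add", "gt_Scalar"
  std::vector<Arg> args;
  double result = 0.0;     // filled in when the trace is flushed
};

struct Trace {
  std::vector<Op> ops;
  bool flushing = false;

  // Recording an op is cheap: nothing executes here.
  unsigned record(const char *id, std::vector<Arg> args) {
    ops.push_back(Op{id, std::move(args)});
    return unsigned(ops.size() - 1);
  }

  // Analogue of the interpreter's strcmp chain: replay each recorded op.
  void flush() {
    flushing = true;
    for (auto &op : ops) {
      if (!std::strcmp(op.id, "add"))
        op.result = std::get<double>(op.args[0]) + std::get<double>(op.args[1]);
      else if (!std::strcmp(op.id, "gt_Scalar"))
        op.result = std::get<double>(op.args[0]) > std::get<double>(op.args[1]);
      else
        assert(false && "op not handled by the interpreter");
    }
    flushing = false;
  }
};

int main() {
  Trace trace;
  unsigned a = trace.record("add", {2.0, 3.0});        // deferred
  unsigned g = trace.record("gt_Scalar", {5.0, 4.0});  // deferred
  trace.flush();                                       // everything runs here
  std::printf("add -> %g, gt_Scalar -> %g\n",
              trace.ops[a].result, trace.ops[g].result);
  return 0;
}
```

The sketch omits one wrinkle the real handlers deal with: a handler invoked while the trace is already flushing must materialize its inputs and dispatch immediately instead of recording, which is exactly what the added `if (trace.is_flushing())` fast paths in `as_strided`, `gt_Scalar`, `reshape`, `sum`, and `view` do.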