From 026737afe8e49c9e41b04accf5c5eb79f8a35bbb Mon Sep 17 00:00:00 2001
From: Nuno Lopes <...>
Date: Tue, 27 Apr 2021 18:27:03 +0100
Subject: [PATCH] add autograd fallback handler

---
 backends/interpreter.cpp |  1 -
 dispatch.h               |  2 ++
 tensor.cpp               |  9 ++-------
 trace.cpp                | 17 +++++++++++++++++
 trace.h                  | 27 ++++++---------------------
 5 files changed, 27 insertions(+), 29 deletions(-)

diff --git a/backends/interpreter.cpp b/backends/interpreter.cpp
index efdf2f5..c202963 100644
--- a/backends/interpreter.cpp
+++ b/backends/interpreter.cpp
@@ -2,7 +2,6 @@
 // Distributed under the MIT license that can be found in the LICENSE file.
 
 #include "dispatch.h"
-#include "ops.h"
 #include "tensor.h"
 #include "trace.h"
 #include <...>
diff --git a/dispatch.h b/dispatch.h
index e4f0b35..bafd827 100644
--- a/dispatch.h
+++ b/dispatch.h
@@ -5,3 +5,5 @@
 
 #define DISPATCHKEY_NO_NS PrivateUse1
 #define DISPATCHKEY DispatchKey::DISPATCHKEY_NO_NS
+
+#define AUTOGRADDISPATCHKEY_NO_NS AutogradPrivateUse1
diff --git a/tensor.cpp b/tensor.cpp
index d07069e..05b7416 100644
--- a/tensor.cpp
+++ b/tensor.cpp
@@ -7,7 +7,6 @@
 #undef NDEBUG
 #include "tensor.h"
 #include "dispatch.h"
-#include "ops.h"
 #include "trace.h"
 #include <...>
 #include <...>
@@ -235,12 +234,8 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
 #include "autogen/torch_library_table.h"
 }
 
-#if 0
-TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
-  m.impl("isfinite", isfinite);
-  m.impl("reshape", reshape);
-  m.impl("to.device", to_device);
+TORCH_LIBRARY_IMPL(_, AUTOGRADDISPATCHKEY_NO_NS, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
 }
-#endif
 
 }
diff --git a/trace.cpp b/trace.cpp
index a325c17..8265318 100644
--- a/trace.cpp
+++ b/trace.cpp
@@ -50,6 +50,11 @@ public:
 };
 }
 
+void TensorOp::incref() {
+  assert(isObservable());
+  ++refs;
+}
+
 void TensorOp::decref(TensorOp *ops) {
   assert(refs > 0);
   --refs;
@@ -176,6 +181,18 @@ void Trace::incref(const List<...> &l) {
   }
 }
 
+void Trace::set_unobservable(unsigned idx) {
+  auto &op = ops[idx];
+  assert(op.tensor);
+  op.tensor = nullptr;
+  op.decref(ops);
+
+  // reclaim slot if this was the last created tensor
+  if (op.refs == 0 && idx+1 == next_op) {
+    --next_op;
+  }
+}
+
 void Trace::flush() {
 #if 1
   cout << "Flush trace\n" << *this;
diff --git a/trace.h b/trace.h
index a009815..6da1819 100644
--- a/trace.h
+++ b/trace.h
@@ -3,6 +3,7 @@
 // Copyright (c) 2021-present The Torchy Authors.
 // Distributed under the MIT license that can be found in the LICENSE file.
 
+#include "ops.h"
 #include <...>
 #include <...>
 #include <...>
@@ -51,18 +52,13 @@ struct TensorOp {
   TorchyTensor *tensor;
   std::vector<...> args;
   c10::DispatchKeySet dispatch_key;
-  unsigned id;
+  TorchOp id;
   unsigned refs;
 
-  void incref() {
-    assert(isObservable());
-    ++refs;
-  }
-
+  void incref();
   void decref(TensorOp *ops);
 
   bool isObservable() const {
-    assert(!tensor || refs > 0);
     return tensor;
   }
 
@@ -91,7 +87,7 @@ class Trace {
   std::vector<...> deep_copies;
 
   template <typename T>
-  at::ArrayRef<T> deep_copy(at::ArrayRef<T> arr) {
+  at::ArrayRef<T> deep_copy(const at::ArrayRef<T> &arr) {
     if (arr.empty())
       return arr;
     size_t size = arr.size() * sizeof(T);
@@ -139,7 +135,7 @@ public:
   TensorOp* getOps() { return ops; }
 
   template <typename... T>
-  unsigned register_tensor(TorchyTensor *tensor, unsigned op_id,
+  unsigned register_tensor(TorchyTensor *tensor, TorchOp op_id,
                            c10::DispatchKeySet ks, T&... args) {
     assert(!flushing);
     if (next_op == MAX_TRACE_LENGTH)
@@ -155,18 +151,7 @@ public:
     return next_op++;
   }
 
-  void set_unobservable(unsigned idx) {
-    auto &op = ops[idx];
-    assert(op.tensor);
-    op.tensor = nullptr;
-    op.decref(ops);
-
-    // reclaim slot if this was the last created tensor
-    if (op.refs == 0 && idx+1 == next_op) {
-      --next_op;
-    }
-  }
-
+  void set_unobservable(unsigned idx);
  void flush();
 
  friend std::ostream& operator<<(std::ostream &os, const Trace &t);