зеркало из https://github.com/microsoft/torchy.git
add autograd fallback handler
This commit is contained in:
Родитель
2273425c62
Коммит
026737afe8
|
@ -2,7 +2,6 @@
|
|||
// Distributed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
#include "dispatch.h"
|
||||
#include "ops.h"
|
||||
#include "tensor.h"
|
||||
#include "trace.h"
|
||||
#include <ATen/core/List.h>
|
||||
|
|
|
@ -5,3 +5,5 @@
|
|||
|
||||
#define DISPATCHKEY_NO_NS PrivateUse1
|
||||
#define DISPATCHKEY DispatchKey::DISPATCHKEY_NO_NS
|
||||
|
||||
#define AUTOGRADDISPATCHKEY_NO_NS AutogradPrivateUse1
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
#undef NDEBUG
|
||||
#include "tensor.h"
|
||||
#include "dispatch.h"
|
||||
#include "ops.h"
|
||||
#include "trace.h"
|
||||
#include <ATen/RedispatchFunctions.h>
|
||||
#include <torch/library.h>
|
||||
|
@ -235,12 +234,8 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
|
|||
#include "autogen/torch_library_table.h"
|
||||
}
|
||||
|
||||
#if 0
|
||||
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
|
||||
m.impl("isfinite", isfinite);
|
||||
m.impl("reshape", reshape);
|
||||
m.impl("to.device", to_device);
|
||||
TORCH_LIBRARY_IMPL(_, AUTOGRADDISPATCHKEY_NO_NS, m) {
|
||||
m.fallback(torch::CppFunction::makeFallthrough());
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
|
|
17
trace.cpp
17
trace.cpp
|
@ -50,6 +50,11 @@ public:
|
|||
};
|
||||
}
|
||||
|
||||
void TensorOp::incref() {
|
||||
assert(isObservable());
|
||||
++refs;
|
||||
}
|
||||
|
||||
void TensorOp::decref(TensorOp *ops) {
|
||||
assert(refs > 0);
|
||||
--refs;
|
||||
|
@ -176,6 +181,18 @@ void Trace::incref(const List<optional<Tensor>> &l) {
|
|||
}
|
||||
}
|
||||
|
||||
void Trace::set_unobservable(unsigned idx) {
|
||||
auto &op = ops[idx];
|
||||
assert(op.tensor);
|
||||
op.tensor = nullptr;
|
||||
op.decref(ops);
|
||||
|
||||
// reclaim slot if this was the last created tensor
|
||||
if (op.refs == 0 && idx+1 == next_op) {
|
||||
--next_op;
|
||||
}
|
||||
}
|
||||
|
||||
void Trace::flush() {
|
||||
#if 1
|
||||
cout << "Flush trace\n" << *this;
|
||||
|
|
27
trace.h
27
trace.h
|
@ -3,6 +3,7 @@
|
|||
// Copyright (c) 2021-present The Torchy Authors.
|
||||
// Distributed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
#include "ops.h"
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
#include <c10/util/variant.h>
|
||||
|
@ -51,18 +52,13 @@ struct TensorOp {
|
|||
TorchyTensor *tensor;
|
||||
std::vector<UnionInputTys> args;
|
||||
c10::DispatchKeySet dispatch_key;
|
||||
unsigned id;
|
||||
TorchOp id;
|
||||
unsigned refs;
|
||||
|
||||
void incref() {
|
||||
assert(isObservable());
|
||||
++refs;
|
||||
}
|
||||
|
||||
void incref();
|
||||
void decref(TensorOp *ops);
|
||||
|
||||
bool isObservable() const {
|
||||
assert(!tensor || refs > 0);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
|
@ -91,7 +87,7 @@ class Trace {
|
|||
std::vector<std::unique_ptr<unsigned char[]>> deep_copies;
|
||||
|
||||
template<typename T>
|
||||
at::ArrayRef<T> deep_copy(at::ArrayRef<T> arr) {
|
||||
at::ArrayRef<T> deep_copy(const at::ArrayRef<T> &arr) {
|
||||
if (arr.empty())
|
||||
return arr;
|
||||
size_t size = arr.size() * sizeof(T);
|
||||
|
@ -139,7 +135,7 @@ public:
|
|||
TensorOp* getOps() { return ops; }
|
||||
|
||||
template<typename... T>
|
||||
unsigned register_tensor(TorchyTensor *tensor, unsigned op_id,
|
||||
unsigned register_tensor(TorchyTensor *tensor, TorchOp op_id,
|
||||
c10::DispatchKeySet ks, T&... args) {
|
||||
assert(!flushing);
|
||||
if (next_op == MAX_TRACE_LENGTH)
|
||||
|
@ -155,18 +151,7 @@ public:
|
|||
return next_op++;
|
||||
}
|
||||
|
||||
void set_unobservable(unsigned idx) {
|
||||
auto &op = ops[idx];
|
||||
assert(op.tensor);
|
||||
op.tensor = nullptr;
|
||||
op.decref(ops);
|
||||
|
||||
// reclaim slot if this was the last created tensor
|
||||
if (op.refs == 0 && idx+1 == next_op) {
|
||||
--next_op;
|
||||
}
|
||||
}
|
||||
|
||||
void set_unobservable(unsigned idx);
|
||||
void flush();
|
||||
|
||||
friend std::ostream& operator<<(std::ostream &os, const Trace &t);
|
||||
|
|
Загрузка…
Ссылка в новой задаче