This commit is contained in:
Nuno Lopes 2021-04-27 18:27:03 +01:00
Родитель 2273425c62
Коммит 026737afe8
5 изменённых файлов: 27 добавлений и 29 удалений

Просмотреть файл

@ -2,7 +2,6 @@
// Distributed under the MIT license that can be found in the LICENSE file.
#include "dispatch.h"
#include "ops.h"
#include "tensor.h"
#include "trace.h"
#include <ATen/core/List.h>

Просмотреть файл

@ -5,3 +5,5 @@
#define DISPATCHKEY_NO_NS PrivateUse1
#define DISPATCHKEY DispatchKey::DISPATCHKEY_NO_NS
#define AUTOGRADDISPATCHKEY_NO_NS AutogradPrivateUse1

Просмотреть файл

@ -7,7 +7,6 @@
#undef NDEBUG
#include "tensor.h"
#include "dispatch.h"
#include "ops.h"
#include "trace.h"
#include <ATen/RedispatchFunctions.h>
#include <torch/library.h>
@ -235,12 +234,8 @@ TORCH_LIBRARY_IMPL(aten, DISPATCHKEY_NO_NS, m) {
#include "autogen/torch_library_table.h"
}
#if 0
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl("isfinite", isfinite);
m.impl("reshape", reshape);
m.impl("to.device", to_device);
TORCH_LIBRARY_IMPL(_, AUTOGRADDISPATCHKEY_NO_NS, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
#endif
}

Просмотреть файл

@ -50,6 +50,11 @@ public:
};
}
void TensorOp::incref() {
assert(isObservable());
++refs;
}
void TensorOp::decref(TensorOp *ops) {
assert(refs > 0);
--refs;
@ -176,6 +181,18 @@ void Trace::incref(const List<optional<Tensor>> &l) {
}
}
void Trace::set_unobservable(unsigned idx) {
auto &op = ops[idx];
assert(op.tensor);
op.tensor = nullptr;
op.decref(ops);
// reclaim slot if this was the last created tensor
if (op.refs == 0 && idx+1 == next_op) {
--next_op;
}
}
void Trace::flush() {
#if 1
cout << "Flush trace\n" << *this;

27
trace.h
Просмотреть файл

@ -3,6 +3,7 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.
#include "ops.h"
#include <ATen/Tensor.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/variant.h>
@ -51,18 +52,13 @@ struct TensorOp {
TorchyTensor *tensor;
std::vector<UnionInputTys> args;
c10::DispatchKeySet dispatch_key;
unsigned id;
TorchOp id;
unsigned refs;
void incref() {
assert(isObservable());
++refs;
}
void incref();
void decref(TensorOp *ops);
bool isObservable() const {
assert(!tensor || refs > 0);
return tensor;
}
@ -91,7 +87,7 @@ class Trace {
std::vector<std::unique_ptr<unsigned char[]>> deep_copies;
template<typename T>
at::ArrayRef<T> deep_copy(at::ArrayRef<T> arr) {
at::ArrayRef<T> deep_copy(const at::ArrayRef<T> &arr) {
if (arr.empty())
return arr;
size_t size = arr.size() * sizeof(T);
@ -139,7 +135,7 @@ public:
TensorOp* getOps() { return ops; }
template<typename... T>
unsigned register_tensor(TorchyTensor *tensor, unsigned op_id,
unsigned register_tensor(TorchyTensor *tensor, TorchOp op_id,
c10::DispatchKeySet ks, T&... args) {
assert(!flushing);
if (next_op == MAX_TRACE_LENGTH)
@ -155,18 +151,7 @@ public:
return next_op++;
}
void set_unobservable(unsigned idx) {
auto &op = ops[idx];
assert(op.tensor);
op.tensor = nullptr;
op.decref(ops);
// reclaim slot if this was the last created tensor
if (op.refs == 0 && idx+1 == next_op) {
--next_op;
}
}
void set_unobservable(unsigned idx);
void flush();
friend std::ostream& operator<<(std::ostream &os, const Trace &t);