This commit is contained in:
Nuno Lopes 2021-05-28 14:30:09 +01:00
Родитель 72ae8f16fc
Коммит dc4ca34e0d
6 изменённых файлов: 140 добавлений и 19 удалений

Просмотреть файл

@ -3,6 +3,8 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.
//#define TORCHY_RELEASE
#define DUMMY_TORCHY 0x1
#ifdef TORCHY_RELEASE
@ -18,5 +20,8 @@
#endif
#define TORCHY_PRINT_TRACE_ON_FLUSH
#define TORCHY_ENABLE_STATS
#endif
#include "stats.h"

77
stats.cpp Normal file
Просмотреть файл

@ -0,0 +1,77 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.
#include "stats.h"
#ifdef TORCHY_ENABLE_STATS
#include <iostream>
using namespace std;
#define NUM_ELEMS(a) (sizeof(a) / sizeof(*a))
namespace {
const char* flush_reasons[] = {
"dim",
"has_storage",
"is_contiguous",
"numel",
"set_size",
"set_storage_offset",
"set_stride",
"size",
"sizes",
"storage",
"storage_offset",
"stride",
"strides",
"trace max length",
"unsupported operation",
};
static_assert(NUM_ELEMS(flush_reasons) == FlushReason::NUM_REASONS);
unsigned flush_reasons_count[FlushReason::NUM_REASONS] = {0};
struct PrintStats {
~PrintStats() {
cerr << "\n\n------------ STATISTICS ------------\n";
print_table("Trace Flush Reason", flush_reasons_count, flush_reasons,
FlushReason::NUM_REASONS):
cerr << endl;
}
void print_table(const char *header, unsigned *data, const char **labels,
unsigned size) {
cerr << header << ":";
unsigned max_label = 0;
for (unsigned i = 0; i < size; ++i) {
max_label = max(max_label, strlen(labels[i]));
}
for (unsigned i = 0; i < size; ++i) {
cerr << label[i] << ": ";
pad(label[i], max_label);
cerr << data[i] << '\n';
}
cerr << '\n';
}
void pad(const char *str, unsigned length) {
for (unsigned i = strlen(str); i < length; ++i) {
cerr << ' ';
}
}
};
PrintStats printer;
}
void inc_flush_reason(FlushReason reason) {
++flush_reasons_count[reason];
}
#endif

37
stats.h Normal file
Просмотреть файл

@ -0,0 +1,37 @@
#pragma once
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.
// Please update stats.cpp string version of this when adding a new reason!
enum class FlushReason {
DIM,
HAS_STORAGE,
IS_CONTIGUOUS,
NUMEL,
SET_SIZE,
SET_STORAGE_OFFSET,
SET_STRIDE,
SIZE,
SIZES,
STORAGE,
STORAGE_OFFSET,
STRIDE,
STRIDES,
TRACE_MAX_LENGTH,
UNSUPPORTED_OPERATION,
NUM_REASONS
};
#ifdef TORCHY_ENABLE_STATS
#define STATS(x) x
void inc_flush_reason(FlushReason reason);
#else
#define STATS(x)
#define inc_flush_reason(x) (void)0
#endif

Просмотреть файл

@ -97,10 +97,10 @@ public:
set_materialized(true);
}
void ensure_materialized() const {
void ensure_materialized(STATS(FlushReason reason)) const {
if (!trace.is_flushing() && !materialized()) {
assert(trace_idx != -1u);
trace.flush();
trace.flush(STATS(reason));
assert(!storage_ || materialized());
}
}
@ -117,42 +117,42 @@ public:
// an extra indirection. Another way is to templatize these.
IntArrayRef sizes() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::SIZES));
return TensorImpl::sizes();
}
IntArrayRef strides() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::STRIDES));
return TensorImpl::strides();
}
int64_t dim() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::DIM));
return TensorImpl::dim();
}
bool has_storage() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::HAS_STORAGE));
return TensorImpl::has_storage();
}
const Storage& storage() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::STORAGE));
return TensorImpl::storage();
}
int64_t numel() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::NUMEL));
return TensorImpl::numel();
}
bool is_contiguous(at::MemoryFormat memory_format) const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::IS_CONTIGUOUS));
return TensorImpl::is_contiguous(memory_format);
}
int64_t storage_offset() const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::STORAGE_OFFSET));
return TensorImpl::storage_offset();
}
@ -161,27 +161,27 @@ public:
}
void set_size(int64_t dim, int64_t new_size) override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::SET_SIZE));
TensorImpl::set_size(dim, new_size);
}
void set_stride(int64_t dim, int64_t new_stride) override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::SET_STRIDE));
TensorImpl::set_stride(dim, new_stride);
}
void set_storage_offset(int64_t storage_offset) override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::SET_STORAGE_OFFSET));
TensorImpl::set_storage_offset(storage_offset);
}
int64_t size(int64_t d) const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::SIZE));
return TensorImpl::size(d);
}
int64_t stride(int64_t d) const override {
ensure_materialized();
ensure_materialized(STATS(FlushReason::STRIDE));
return TensorImpl::stride(d);
}
@ -291,7 +291,7 @@ void finish_in_place(TorchyTensor *tt, unsigned idx) {
void ensure_materialized(const Tensor &t) {
if (auto tt = is_torchy(t))
tt->ensure_materialized();
tt->ensure_materialized(STATS(FlushReason::UNSUPPORTED_OPERATION));
}
void ensure_materialized(const optional<Tensor> &t) {

Просмотреть файл

@ -190,7 +190,7 @@ unsigned Trace::register_tensor(uintptr_t tensor, TorchOp op_id,
#endif
if (next_op == MAX_TRACE_LENGTH)
flush();
flush(FlushReason::TRACE_MAX_LENGTH);
auto &op = ops[next_op];
op.tensors[0] = tensor;
@ -250,10 +250,12 @@ void Trace::set_unobservable(unsigned idx, uintptr_t ptr) {
}
}
void Trace::flush() {
void Trace::flush(STATS(FlushReason reason)) {
assert(!flushing);
flushing = true;
inc_flush_reason(reason);
// trim set of observable tensors as the references in arguments keep the
// tensors alive and therefore we aren't notified the user's program
// can't observe these tensors anymore

Просмотреть файл

@ -153,7 +153,7 @@ public:
void add_shared(unsigned idx, uintptr_t ptr);
void set_unobservable(unsigned idx, uintptr_t ptr);
void flush();
void flush(STATS(FlushReason reason));
friend std::ostream& operator<<(std::ostream &os, const Trace &t);
};