// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "trace.h"
#include "backends/backends.h"
#include "stopwatch.h"
#include "tensor.h"
#include "utils.h"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <ostream>

using namespace at;
using namespace std;

static Interpreter interpreter;
static TorchScript torchscript;
static TorchyBackend *backend
  = getenv("TORCHY_FORCE_INTERPRETER") ? (TorchyBackend*)&interpreter
                                       : (TorchyBackend*)&torchscript;

bool TraceOpDef::operator==(const TraceOpDef &rhs) const {
  return id == rhs.id && observable == rhs.observable && dead == rhs.dead &&
         args == rhs.args;
}

void TraceOpDef::destroy() {
  args.clear();
}

void TraceOpRunTimeData::destroy() {
  ((ThreadLocalState*)&tls)->~ThreadLocalState();
}

void Trace::incref(unsigned idx) {
  assert(idx < next_op);
  auto &op = ops[idx];
  auto &rdata = data[idx];
  (void)op;
  assert(op.observable && !op.dead && rdata.refs > 0);
  ++rdata.refs;
  assert(rdata.refs != 0);
}

void Trace::decref(unsigned idx) {
  assert(idx < next_op);
  auto &op = ops[idx];
  auto &rdata = data[idx];
  assert(!op.dead && rdata.refs > 0);

  // We can't declare ops dead for sure because of aliasing. A non-inplace op
  // may return a new tensor, but sharing storage with another tensor
  // (e.g., view, reshape).
  // Just because the result of these ops is dead it doesn't mean there isn't
  // another tensor out there with the same storage.
  // See tests/unit/inplace_dead_alias.py
  //--rdata.refs;

  if (rdata.refs == 0) {
    TensorVisitor visitor([&](InputIdx idx) {
      if (idx.is_trace())
        decref(idx.trace_idx());
    }, inputs);
    for (auto &arg : op.args) {
      visit(visitor, arg);
    }
    assert(!op.observable && !rdata.hasTensors());
    op.dead = true;
    op.destroy();
    rdata.destroy();
  }
}

bool TraceOpRunTimeData::hasTensors() const {
  return find_if(tensors.begin(), tensors.end(),
                 [](auto t) { return t != 0; }) != tensors.end();
}

unsigned TraceOpRunTimeData::numTensors() const {
  return count_if(tensors.begin(), tensors.end(),
                  [](auto t) { return t != 0; });
}

uintptr_t TraceOpRunTimeData::someTensor() const {
  auto I = find_if(tensors.begin(), tensors.end(),
                   [](auto t) { return t != 0 && t != DUMMY_TORCHY; });
  return I == tensors.end() ? 0 : *I;
}

bool TraceCacheKey::operator==(const TraceCacheKey &rhs) const {
  return num_ops == rhs.num_ops &&
         equal(ops.get(), &ops[num_ops], rhs.ops.get());
}

size_t TraceCacheKeyHash::operator()(const TraceCacheKey &key) const {
  size_t hash = 0;
  // We only hash the prefix of the trace ops.
  // This enables us to quickly discover traces that need deoptimization
  // and prefixes of traces for speculative execution.
  unsigned bits = 11;
  unsigned max_rounds = (sizeof(size_t) * 8) / bits;
  for (unsigned i = 0; i < min(key.num_ops, max_rounds); ++i) {
    hash = (hash << bits) | key.ops[i].id;
  }
  return hash;
}
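
// Worked example for the hash above (assuming a 64-bit size_t, so
// max_rounds == 5, and op ids that fit in 11 bits): a trace whose first three
// op ids are 3, 7 and 12 hashes to
//   ((3 << 11 | 7) << 11) | 12 == 12597260
// Only the first max_rounds ops contribute, and a shorter trace's hash is a
// base-2^11 prefix of the hash of any trace that extends it, which is
// presumably what makes the deoptimization / speculative-execution prefix
// lookups mentioned above cheap.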

TraceCacheData::~TraceCacheData() {
  if (backend)
    backend->destroy(program);
}

namespace {
class printer {
  ostream &os;

public:
  printer(ostream &os) : os(os) {}

  template <typename T>
  ostream& operator()(const T &a) {
    return os << a;
  }

  ostream& operator()(const InputIdx &idx) {
    if (idx.is_input())
      return os << "in<" << idx.input_idx() << '>';
    return os << '%' << idx.trace_idx();
  }

  template <typename T>
  ostream& operator()(const optional<T> &a) {
    if (!a)
      return os << "(null)";
    return (*this)(*a);
  }

  template <typename T>
  ostream& operator()(const vector<T> &l) {
    os << '[';
    bool first = true;
    for (const auto &elem : l) {
      if (!first)
        os << ", ";
      first = false;
      (*this)(elem);
    }
    return os << ']';
  }
};
}

static void print_op(ostream &os, unsigned idx, const TraceOpDef &op,
                     const TraceOpRunTimeData *rdata) {
  os << '%' << idx << " = ";

  auto t = rdata ? rdata->someTensor() : 0;
  if (t && tensor_has_dtype(t))
    os << '<' << tensor_get_dtype(t) << "> ";

  os << op.id;

  if (op.dead) {
    os << " [dead]\n";
    return;
  }

  bool first = true;
  for (auto &arg : op.args) {
    os << (first ? " " : ", ");
    first = false;
    visit(printer(os), arg);
  }

  if (rdata) {
    auto n_tensors = rdata->numTensors();
    assert(n_tensors >= op.observable);
    assert(rdata->refs >= n_tensors);
    if (rdata->refs > 0)
      os << " #refs E/I=" << n_tensors << '/' << (rdata->refs - n_tensors);
  }

  if (op.observable)
    os << " #output";

  if (t && tensor_has_shape(t))
    os << " shape=" << tensor_get_shape(t);
  if (t && tensor_has_strides(t))
    os << " strides=" << tensor_get_strides(t);
  os << '\n';
}

void Trace::print(ostream &os, unsigned idx) const {
  assert(idx < next_op);
  print_op(os, idx, ops[idx], &data[idx]);
}

Trace::~Trace() {
  destroyed = true;
#if 0
  cerr << "NUM BUCKETS: " << cache.bucket_count() << '\n';
  unsigned collisions = 0;
  for (unsigned i = 0; i < cache.bucket_count(); ++i) {
    auto sz = cache.bucket_size(i);
    if (sz <= 1)
      continue;
    cerr << i << ": " << sz << '\n';
    collisions += sz;

    for (auto &p : cache) {
      auto &k = p.first;
      if (cache.bucket(k) == i) {
        cerr << "HASH: " << TraceCacheKeyHash()(k) << '\n';
        for (unsigned i = 0; i < k.num_ops; ++i) {
          print_op(cerr, i, k.ops[i], nullptr);
        }
      }
    }
  }
  cerr << "TOTAL COLLISIONS = " << collisions << endl;
#endif
}

bool Trace::is_input(const c10::TensorImpl &t) const {
  for (auto &in : inputs) {
    if (in.isTensor() && in.toTensor().unsafeGetTensorImpl() == &t)
      return true;
  }
  return false;
}

InputIdx Trace::get_tensor_idx(const Tensor &t) {
  auto idx = trace_idx(t);
  if (idx != -1u) {
    assert(idx < next_op);
    // check if this is also an input tensor
    // e.g. for inplace ops we have %0 = op %0 <-- but we want in<0> there
    if (idx != next_op-1) {
      incref(idx);
      return { idx, false };
    }
  }

  // check if the tensor is already an input
  unsigned i = 0;
  for (auto &in : inputs) {
    if (in.isTensor() &&
        in.toTensor().unsafeGetTensorImpl() == t.unsafeGetTensorImpl())
      return { i, true };
    ++i;
  }

  inputs.emplace_back(t);
  return { (unsigned)inputs.size()-1, true };
}
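
// Illustration of the two argument kinds recorded by get_tensor_idx and used
// by the append_arg overloads below (hypothetical trace):
//
//   %0 = add in<0>, in<1>
//   %1 = mul %0, in<0>
//
// %1's first argument is a trace reference (is_input() == false; recording it
// bumped %0's refcount via incref), while in<0>/in<1> are indices into
// Trace::inputs. For an in-place op the produced tensor is the op currently
// being recorded (idx == next_op-1), so it is registered as an in<n> input
// rather than as a self-referencing %n.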

void Trace::append_arg(const Tensor &t) {
  ops[next_op-1].args.emplace_back(get_tensor_idx(t));
}

void Trace::append_arg(ArrayRef<Tensor> arg) {
  vector<InputIdx> val;
  for (auto &t : arg) {
    val.emplace_back(get_tensor_idx(t));
  }
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(const optional<Tensor> &arg) {
  optional<InputIdx> val;
  if (arg)
    val = get_tensor_idx(*arg);
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(const List<optional<Tensor>> &arg) {
  vector<optional<InputIdx>> val;
  for (const auto &it : arg) {
    const optional<Tensor> &in = it;
    optional<InputIdx> elem;
    if (in)
      elem = get_tensor_idx(*in);
    val.emplace_back(move(elem));
  }
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(Storage &&arg) {
  ops[next_op-1].args.emplace_back(InputIdx(inputs.size(), true));
  inputs.emplace_back(move(arg));
}

void Trace::append_arg(optional<Generator> &&arg) {
  optional<InputIdx> val;
  if (arg) {
    val = InputIdx(inputs.size(), true);
    inputs.emplace_back(move(*arg));
  }
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(string_view arg) {
  ops[next_op-1].args.emplace_back(string(arg.data(), arg.size()));
}

void Trace::append_arg(optional<string_view> arg) {
  optional<string> copy;
  if (arg)
    copy = string(arg->data(), arg->size());
  ops[next_op-1].args.emplace_back(move(copy));
}

unsigned Trace::register_tensor(uintptr_t tensor, TorchOp op_id,
                                c10::DispatchKeySet ks) {
  assert(!flushing);

#ifndef TORCHY_RELEASE
  // FOR DEBUGGING ONLY. Can be used to binary search a trace that goes wrong.
  static unsigned call_count = 0;
  ++call_count;

  if (auto *limit = getenv("TORCHY_FLUSH_BEFORE")) {
    if (next_op != 0 && call_count <= (unsigned)atoi(limit))
      flush(STATS(FlushReason::DEBUG));
  }
  if (auto *limit = getenv("TORCHY_FLUSH_AFTER")) {
    if (next_op != 0 && call_count > (unsigned)atoi(limit))
      flush(STATS(FlushReason::DEBUG));
  }
  if (auto *limit = getenv("TORCHY_MAX_TRACE_LENGTH")) {
    if (next_op == (unsigned)atoi(limit))
      flush(STATS(FlushReason::TRACE_MAX_LENGTH));
  }
#endif

  if (next_op == MAX_TRACE_LENGTH)
    flush(STATS(FlushReason::TRACE_MAX_LENGTH));

  auto &op = ops[next_op];
  op.id = op_id;
  assert(op.args.empty());
  op.observable = true;
  op.dead = false;

  auto &rdata = data[next_op];
  rdata.tensors[0] = tensor;
  for (unsigned i = 1; i < rdata.tensors.size(); ++i) {
    rdata.tensors[i] = 0;
  }
  rdata.refs = 1;
  new (&rdata.tls) ThreadLocalState();
  rdata.dispatch_key = ks;
  rdata.inplace = op.inplace();
  return next_op++;
}

void Trace::add_shared(unsigned idx, uintptr_t ptr) {
  assert(idx < next_op);
  for (auto &tensor : data[idx].tensors) {
    if (tensor == 0) {
      tensor = ptr;
      incref(idx);
      return;
    }
  }
  // no more space for additional observers; just flush
  flush(STATS(FlushReason::OVERFLOW_SHARED_LIST));
}

void Trace::set_unobservable(unsigned idx, uintptr_t ptr) {
  // Technically this accesses memory that has been deallocated already,
  // but since it's a global and it's just a bool, whatever..
  if (destroyed)
    return;

  assert(idx < next_op);
  auto &op = ops[idx];
  auto &rdata = data[idx];

  bool found = false;
  for (auto &tensor : rdata.tensors) {
    if (tensor == ptr) {
      tensor = 0;
      found = true;
      break;
    }
  }
  assert(found); (void)found;

  op.observable = rdata.hasTensors();
  decref(idx);

  // reclaim the slot if this was the last created tensor
  if (rdata.refs == 0 && idx+1 == next_op) {
    op.destroy();
    rdata.destroy();
    --next_op;
  }
}

TraceCacheKey Trace::mk_trace_key() {
  TraceOpDef *new_ops = new TraceOpDef[next_op];
  // We can't move the args for the interpreter, as it traverses the args
  // in run() rather than in compile() (which is a NOP for it).
  std::uninitialized_copy_n(ops, next_op, new_ops);
  return TraceCacheKey(new_ops, next_op);
}
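
// flush() executes the pending trace:
//   1. build a TraceCacheKey over the recorded ops and look it up in `cache`
//   2. on a miss, compile the trace with the selected backend, falling back
//      to the interpreter if that backend can't handle it, and cache the
//      resulting program
//   3. run the cached program through its backend
//   4. destroy the per-op data, clear the inputs and reset next_op so that
//      recording can resume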
void Trace::flush(STATS(FlushReason reason)) {
  assert(!flushing);
  assert(next_op > 0);
  flushing = true;

  stats_register_trace(*this, reason);

#ifdef TORCHY_PRINT_TRACE_ON_FLUSH
  cerr << "Flush trace\n" << *this << endl;
#endif

  TraceCacheKey key = { ops, next_op };
  auto I = cache.find(key);
  key.ops.release(); // so the destructor doesn't kick in

  if (I == cache.end()) {
    STATS(StopWatch time);
    auto *program = backend->compile(*this);
    auto *used_backend = backend;

    // fall back to the interpreter if the default backend can't handle this
    if (!program) {
      program = interpreter.compile(*this);
      used_backend = &interpreter;
    }
    STATS(time.stop());
    stats_register_compile_time(time);

    I = cache.emplace(piecewise_construct,
                      forward_as_tuple(mk_trace_key()),
                      forward_as_tuple(program, used_backend)).first;
  }

  STATS(StopWatch run_time);
  I->second.backend->run(I->second.program, *this);
  STATS(run_time.stop());
  stats_register_trace_time(run_time);

  // reduce the reference count on tensors so they are deleted if possible
  for (unsigned i = 0; i < next_op; ++i) {
    ops[i].destroy();
    data[i].destroy();
  }

  inputs.clear();
  next_op = 0;
  flushing = false;
}

ostream& operator<<(ostream &os, const Trace &t) {
  if (t.next_op == 0)
    return os << "(empty)\n";

  for (unsigned i = 0; i < t.next_op; ++i) {
    t.print(os, i);
  }

  if (t.inputs.empty())
    return os;

  os << "\nInputs:\n";
  unsigned i = 0;
  for (auto &in : t.inputs) {
    os << "in<" << i++ << ">: ";
    if (in.isTensor()) {
      const auto &t = in.toTensor();
      os << "tensor(" << t.scalar_type() << " : " << t.sizes();
      if (!dynamic_cast<TorchyTensor*>(t.unsafeGetTensorImpl()))
        os << " / " << t.strides();
      os << ")\n";
    } else if (in.isGenerator()) {
      const auto &g = in.toGenerator();
      os << "generator(" << g.current_seed() << ", " << g.device() << ")\n";
    } else if (in.isStorage()) {
      os << "storage(" << in.toStorage().nbytes() << ")\n";
    } else {
      assert(0);
    }
  }
  return os;
}
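
// Example of the dump produced by operator<< (schematic: op and dtype names
// depend on how TorchOp and the tensor_get_* helpers render them; strides are
// only printed for non-lazy inputs):
//
//   %0 = <Float> add in<0>, in<1> #refs E/I=1/0 #output shape=[2, 3]
//   %1 = mul [dead]
//
//   Inputs:
//   in<0>: tensor(Float : [2, 3] / [3, 1])
//   in<1>: tensor(Float : [2, 3])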