mirror of https://github.com/microsoft/torchy.git
add caching mechanism for compiled traces
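The caching works roughly as follows: the per-op definitions (TraceOpDef) are split from the per-run data (TraceOpRunTimeData), the definitions form a TraceCacheKey, and Trace::flush() looks that key up in an unordered_map before compiling. On a miss, the selected backend compiles the trace (TorchScript by default, with the interpreter as fallback when compilation is rejected) and the resulting program plus backend pointer are stored as TraceCacheData; on a hit, the stored program is re-run directly. A condensed sketch of that flow (editorial illustration; the real logic lives in Trace::flush() in trace.cpp below):

#include <unordered_map>
#include "trace.h"
#include "backends/backends.h"

// Sketch of the cached flush path: look the trace up by its op definitions,
// compile it once on a miss, then run the cached program on every flush.
void* compile_or_lookup(Trace &t, TraceCacheKey &&key,
                        std::unordered_map<TraceCacheKey, TraceCacheData,
                                           TraceCacheKeyHash> &cache,
                        TorchyBackend *preferred, Interpreter &interpreter) {
  auto I = cache.find(key);
  if (I == cache.end()) {
    TorchyBackend *used = preferred;
    void *program = preferred->compile(t);   // e.g. build a TorchScript graph once
    if (!program) {                          // backend rejected this trace
      program = interpreter.compile(t);      // interpreter needs no program (nullptr)
      used = &interpreter;
    }
    I = cache.emplace(std::move(key), TraceCacheData(program, used)).first;
  }
  I->second.backend->run(I->second.program, t);  // reuse the compiled program
  return I->second.program;
}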
This commit is contained in:
Parent
f0e6e2962e
Commit
5db8fa456d
File diff suppressed because it is too large
@@ -0,0 +1,24 @@
#pragma once

// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

class Trace;

struct TorchyBackend {
  virtual void* compile(const Trace &trace) = 0;
  virtual void run(const void *prog, Trace &trace) = 0;
  virtual void destroy(void *prog) = 0;
};

struct Interpreter final : public TorchyBackend {
  void* compile(const Trace &trace) override { return nullptr; }
  void run(const void *prog, Trace &trace) override;
  void destroy(void *prog) override {}
};

struct TorchScript final : public TorchyBackend {
  void* compile(const Trace &trace) override;
  void run(const void *prog, Trace &trace) override;
  void destroy(void *prog) override;
};
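For context, the contract the new interface expects from a backend is: compile a trace once into an opaque program, run it any number of times, and free it when its cache entry is dropped. A hypothetical toy backend (editorial illustration, not part of this commit) that only shows the contract:

// Hypothetical example backend; it does not execute the trace, it merely
// demonstrates how compile/run/destroy own the opaque program pointer.
struct CountingBackend final : public TorchyBackend {
  void* compile(const Trace &trace) override { return new unsigned(0); } // opaque program
  void run(const void *prog, Trace &trace) override {
    ++*(unsigned*)prog;                  // e.g. count how many times the program ran
  }
  void destroy(void *prog) override { delete (unsigned*)prog; }
};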
@@ -1,29 +1,37 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "backends.h"
#include "tensor.h"
#include "trace.h"

static void init_update_in_place(TensorOp &op) {
  for (auto tensor : op.tensors) {
static void init_update_in_place(const TraceOpRunTimeData &data) {
  for (auto tensor : data.tensors) {
    if (tensor != 0)
      init_update_in_place(tensor);
  }
}

static void end_update_in_place(TensorOp &op) {
  assert(op.tensors[0] != 0);
  end_update_in_place_first(op.tensors[0]);
static void end_update_in_place(const TraceOpRunTimeData &data) {
  bool first = true;
  unsigned first_idx;

  for (unsigned i = 1; i < op.tensors.size(); ++i) {
    if (op.tensors[i] != 0)
      end_update_in_place_copy(op.tensors[i], op.tensors[0]);
  for (unsigned i = 0; i < data.tensors.size(); ++i) {
    if (data.tensors[i] == 0)
      continue;
    if (first) {
      end_update_in_place_first(data.tensors[i]);
      first_idx = i;
    } else {
      end_update_in_place_copy(data.tensors[i], data.tensors[first_idx]);
    }
    first = false;
  }
}

#ifndef NDEBUG
static void finish_trace(TensorOp &op) {
  for (auto tensor : op.tensors) {
static void finish_trace(const TraceOpRunTimeData &data) {
  for (auto tensor : data.tensors) {
    if (tensor != 0)
      finish_trace(tensor);
  }

@@ -32,8 +40,8 @@ static void finish_trace(TensorOp &op) {
# define finish_trace(op) (void)0
#endif

static void set(TensorOp &op, const at::Tensor &t) {
  for (auto tensor : op.tensors) {
static void set(const TraceOpRunTimeData &data, const at::Tensor &t) {
  for (auto tensor : data.tensors) {
    if (tensor != 0)
      set(tensor, t);
  }
@@ -1,10 +1,8 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "autogen/ops_data.h"
#include "common.h"
#include "dispatch.h"
#include "trace.h"
#include <ATen/RedispatchFunctions.h>
#include <type_traits>

@@ -18,30 +16,44 @@ using namespace at;

namespace {

struct LoadState {
  InputData &inputs;
  Tensor *results;
  std::vector<std::unique_ptr<std::vector<Tensor>>> tmp_vectors;
  std::vector<std::unique_ptr<c10::List<c10::optional<Tensor>>>> tmp_lists;

  void reset() {
    tmp_vectors.clear();
    tmp_lists.clear();
  }
};

#define LOAD_ARGS UnionInputTy &arg, LoadState &load_state

template <typename T>
struct load {
  T operator()(UnionInputTy &arg) {
  T operator()(LOAD_ARGS) {
    return std::move(get<T>(arg));
  }
};

template <typename T>
struct load<T&> {
  T& operator()(UnionInputTy &arg) {
  T& operator()(LOAD_ARGS) {
    return get<T>(arg);
  }
};

template <typename T>
struct load<ArrayRef<T>> {
  ArrayRef<T> operator()(UnionInputTy &arg) {
  ArrayRef<T> operator()(LOAD_ARGS) {
    return get<std::vector<T>>(arg);
  }
};

template <typename T>
struct load<c10::optional<ArrayRef<T>>> {
  c10::optional<ArrayRef<T>> operator()(UnionInputTy &arg) {
  c10::optional<ArrayRef<T>> operator()(LOAD_ARGS) {
    auto &opt = get<c10::optional<std::vector<T>>>(arg);
    if (!opt)
      return c10::nullopt;

@@ -50,89 +62,131 @@ struct load<c10::optional<ArrayRef<T>>> {
};

template <>
c10::string_view load<c10::string_view>::operator()(UnionInputTy &arg) {
  return get<std::string>(arg);
}

template <>
c10::optional<c10::string_view>
load<c10::optional<c10::string_view>>::operator()(UnionInputTy &arg) {
  auto &opt = get<c10::optional<std::string>>(arg);
  if (!opt)
    return c10::nullopt;
  return *opt;
}

#include "autogen/interpreter_redispatch_tables.h"

struct DispatchKeyComputer {
  c10::DispatchKeySet ks;

  DispatchKeyComputer(c10::DispatchKeySet ks) : ks(ks) {}

  template <typename T>
  void operator()(const T&) {}

  void operator()(const at::Tensor &t) {
    ks = ks | t.key_set();
  }

  void operator()(const at::Generator &gen) {
    if (gen.defined())
      ks = ks | gen.key_set();
  }

  template<typename T>
  void operator()(const c10::optional<T> &opt) {
    if (opt)
      (*this)(*opt);
  }

  template<typename T>
  void operator()(const std::vector<T> &l) {
    for (const auto &elem : l) {
      (*this)(elem);
    }
  }

  template<typename T>
  void operator()(const at::List<T> &l) {
    for (const auto &it : l) {
      const T &elem = it;
      (*this)(elem);
    }
struct load<c10::string_view> {
  c10::string_view operator()(LOAD_ARGS) {
    return get<std::string>(arg);
  }
};

template <>
struct load<c10::optional<c10::string_view>> {
  c10::optional<c10::string_view> operator()(LOAD_ARGS) {
    auto &opt = get<c10::optional<std::string>>(arg);
    if (!opt)
      return c10::nullopt;
    return *opt;
  }
};

Tensor& get_tensor(InputIdx idx, LoadState &load_state) {
  return idx.is_input() ? load_state.inputs[idx.input_idx()].toTensor()
                        : load_state.results[idx.trace_idx()];
}

template <>
struct load<Tensor&> {
  Tensor& operator()(LOAD_ARGS) {
    return get_tensor(get<InputIdx>(arg), load_state);
  }
};

template <>
struct load<optional<Tensor>&> {
  optional<Tensor> operator()(LOAD_ARGS) {
    auto idx = get<optional<InputIdx>>(arg);
    if (!idx)
      return c10::nullopt;
    return get_tensor(*idx, load_state);
  }
};

template <>
struct load<TensorList> {
  TensorList operator()(LOAD_ARGS) {
    auto vect = std::make_unique<std::vector<Tensor>>();
    for (auto idx : get<std::vector<InputIdx>>(arg)) {
      vect->emplace_back(get_tensor(idx, load_state));
    }
    auto ret = vect.get();
    load_state.tmp_vectors.emplace_back(std::move(vect));
    return *ret;
  }
};

template <>
struct load<c10::List<c10::optional<Tensor>>&> {
  c10::List<c10::optional<Tensor>>& operator()(LOAD_ARGS) {
    auto lst = std::make_unique<c10::List<c10::optional<Tensor>>>();
    for (auto idx : get<std::vector<c10::optional<InputIdx>>>(arg)) {
      lst->push_back(
        idx ? make_optional(get_tensor(*idx, load_state)) : c10::nullopt);
    }
    auto ret = lst.get();
    load_state.tmp_lists.emplace_back(std::move(lst));
    return *ret;
  }
};

template <>
struct load<Storage> {
  Storage operator()(LOAD_ARGS) {
    // Storage input is never shared; can be moved
    return
      std::move(load_state.inputs[get<InputIdx>(arg).input_idx()]).toStorage();
  }
};

template <>
struct load<c10::optional<Generator>> {
  c10::optional<Generator> operator()(LOAD_ARGS) {
    auto idx = get<c10::optional<InputIdx>>(arg);
    if (!idx)
      return c10::nullopt;
    return load_state.inputs[idx->input_idx()].toGenerator();
  }
};

#include "autogen/interpreter_redispatch_tables.h"

}


namespace interpreter {
void Interpreter::run(const void *prog, Trace &t) {
  Tensor results[MAX_TRACE_LENGTH];
  LoadState load_state{t.getInputs(), results};

void run(Trace &t) {
  ThreadLocalState tls;
  auto *ops = t.getOps();
  auto *data = t.getRuntimeData();

  for (unsigned i = 0, e = t.numOps(); i < e; ++i) {
    auto &op = ops[i];
    if (!op.needsComputing())
    auto &rdata = data[i];
    if (op.dead)
      continue;

#ifdef DEBUG_DISPATCH
    std::cerr << "Dispatch " << op.id << std::endl;
#endif

    DispatchKeyComputer visitor(op.dispatch_key);
    for (auto &arg : op.args) {
      visit(visitor, arg);
    auto ks = rdata.dispatch_key;
    for (auto &arg : t.getInputs()) {
      if (arg.isTensor()) {
        ks = ks | arg.toTensor().key_set();
      } else if (arg.isGenerator()) {
        const auto &gen = arg.toGenerator();
        if (gen.defined())
          ks = ks | gen.key_set();
      } else {
        assert(arg.isStorage());
      }
    }
    auto ks
      = visitor.ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);
    ks = ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DISPATCHKEY);

    ThreadLocalState::setThreadLocalState(op.tls);
    ThreadLocalState::setThreadLocalState(rdata.tls);

    if (op.id >= FIRST_INPLACE_OP)
      init_update_in_place(op);
    if (rdata.inplace)
      init_update_in_place(rdata);

    switch (op.id) {

@@ -142,15 +196,17 @@ void run(Trace &t) {
      assert(0 && "Unhandled op");
    }

    // generated redispatch code only reaches here for in-place ops
    end_update_in_place(op);
    if (rdata.inplace) {
      end_update_in_place(rdata);
    } else {
      set(rdata, results[i]);
    }
    load_state.reset();
  }

  for (unsigned i = 0, e = t.numOps(); i < e; ++i) {
    finish_trace(ops[i]);
    finish_trace(data[i]);
  }

  ThreadLocalState::setThreadLocalState(tls);
}

}
@@ -1,9 +1,7 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "autogen/ops_data.h"
#include "common.h"
#include "trace.h"
#include <torch/csrc/jit/api/method.h>
#include <map>

@@ -25,17 +23,19 @@ std::string cut_overload(const char *fn) {
  return dot ? std::string(fn, dot - fn) : fn;
}

using ValueMap = std::map<const TensorImpl*, Value*>;

class ValGen {
  Graph &g;
  ValueMap &map;
  Stack &inputs;
  Value **results;
  std::vector<Value*> inputs;
  Value *none_val = nullptr;

public:
  ValGen(Graph &g, ValueMap &map, Stack &inputs)
    : g(g), map(map), inputs(inputs) {}
  ValGen(Graph &g, Value **results, const Stack &in_stack)
    : g(g), results(results) {
    for (unsigned i = 0, e = in_stack.size(); i < e; ++i) {
      inputs.push_back(g.addInput());
    }
  }

  Value* mk_none() {
    if (!none_val) {

@@ -51,13 +51,8 @@ public:
    return g.insertConstant(a);
  }

  Value* operator()(const Tensor &t) {
    auto &v = map[t.getIntrusivePtr().get()];
    if (!v) {
      v = g.addInput();
      inputs.emplace_back(t);
    }
    return v;
  Value* operator()(const InputIdx &in) {
    return in.is_input() ? inputs[in.input_idx()] : results[in.trace_idx()];
  }

  template<typename T>

@@ -75,55 +70,41 @@ public:
    g.appendNode(n);
    return n->output();
  }

  template<typename T>
  Value* operator()(const List<T> &l) {
    std::vector<Value*> vals;
    for (const auto &it : l) {
      const T &elem = it;
      vals.emplace_back((*this)(elem));
    }
    auto *n = g.createList(vals[0]->type(), vals);
    g.appendNode(n);
    return n->output();
  }

  // unsupported by TorchScript
  Value* operator()(const Storage&) {
    return nullptr;
  }
};

struct CompiledProgram {
  std::unique_ptr<GraphFunction> fn;
  uint8_t output_ops[MAX_TRACE_LENGTH];
  unsigned num_outputs = 0;
};

}


namespace torchscript {

bool run(Trace &t) {
void* TorchScript::compile(const Trace &t) {
  auto *ops = t.getOps();
  Value *results[MAX_TRACE_LENGTH];
  Value *outputs[MAX_TRACE_LENGTH];
  uint8_t output_ops[MAX_TRACE_LENGTH];
  unsigned num_outputs = 0;
  Stack stack;
  ValueMap val_map;
  Value *op_inputs[MAX_NUM_INPUTS];

  auto prog = std::make_unique<CompiledProgram>();
  auto &output_ops = prog->output_ops;
  auto &num_outputs = prog->num_outputs;

  auto graph = std::make_shared<Graph>();
  ValGen val_gen(*graph, val_map, stack);
  ValGen val_gen(*graph, results, t.getInputs());

  for (unsigned i = 0, e = t.numOps(); i < e; ++i) {
    auto &op = ops[i];
    if (!op.needsComputing())
    if (op.dead)
      continue;

    if (op.id >= FIRST_INPLACE_OP)
      init_update_in_place(op);

    unsigned num_inputs = 0;
    for (auto &arg : op.args) {
      auto *v = visit(val_gen, arg);
      if (!v) {
        stats_inc_torchscript_fail();
        return false;
        return nullptr;
      }
      op_inputs[num_inputs++] = v;
    }

@@ -131,23 +112,23 @@ bool run(Trace &t) {
    Node *n = graph->create(Symbol::aten(cut_overload(op_name(op.id))),
                            at::ArrayRef<Value*>(op_inputs, num_inputs));
    if (!n->maybeOperator()) {
#ifdef DEBUG_GRAPH
      std::cerr << "Op not supported by TorchScript: " << op_name(op.id)
                << std::endl;
      // prints a nice error msg with supported overloads
      n->getOperator();
#endif
      stats_inc_torchscript_fail();
      return nullptr;
    }
    graph->appendNode(n);

    Value *v = n->output();
    if (op.observable && op.id < FIRST_INPLACE_OP) {
    results[i] = v;
    if (op.observable && !op.inplace()) {
      output_ops[num_outputs] = i;
      outputs[num_outputs++] = v;
    }

    for (auto tt : op.tensors) {
      if (auto *t = is_impl(tt))
        val_map.emplace(t, v);
    }
  }

  if (num_outputs == 0) {

@@ -166,21 +147,42 @@ bool run(Trace &t) {

  assert((graph->lint(), true));

  GraphFunction fn("torchy", move(graph), {});
  prog->fn
    = std::make_unique<GraphFunction>("torchy", move(graph),
                                      std::function<void(GraphFunction&)>());

#ifdef DEBUG_GRAPH
  fn.optimized_graph()->print(std::cerr << "\nOptimized graph:\n");
  prog->fn->optimized_graph()->print(std::cerr << "\nOptimized graph:\n");
  std::cerr << '\n';
#endif

  fn.run(stack);
  return prog.release();
}

void TorchScript::run(const void *ptr, Trace &t) {
  auto prog = (const CompiledProgram*)ptr;
  auto &output_ops = prog->output_ops;
  auto num_outputs = prog->num_outputs;
  auto *data = t.getRuntimeData();

  // FIXME: we don't take TLS or dispatch keys into account here
  // may break more complicated programs..

  for (unsigned i = 0, e = t.numOps(); i < e; ++i) {
    auto &op = data[i];
    if (op.needsComputing() && op.inplace)
      init_update_in_place(op);
  }

  auto &stack = t.getInputs();
  prog->fn->run(stack);
  // inputs are consumed, and the output is passed back on the stack
  assert(stack.size() == 1);

  // patch tensors with the output
  if (num_outputs == 1) {
    assert(stack[0].isTensor());
    set(ops[output_ops[0]], std::move(stack[0]).toTensor());
    set(data[output_ops[0]], std::move(stack[0]).toTensor());
  }
  else if (num_outputs > 1) {
    assert(stack[0].isTuple());

@@ -190,17 +192,18 @@ bool run(Trace &t) {

    for (unsigned i = 0; i < num_outputs; ++i) {
      assert(elems[i].isTensor());
      set(ops[output_ops[i]], std::move(elems[i]).toTensor());
      set(data[output_ops[i]], std::move(elems[i]).toTensor());
    }
  }

  for (unsigned i = 0, e = t.numOps(); i < e; ++i) {
    auto &op = ops[i];
    if (op.id >= FIRST_INPLACE_OP)
    auto &op = data[i];
    if (op.needsComputing() && op.inplace)
      end_update_in_place(op);
    finish_trace(op);
  }
  return true;
}

void TorchScript::destroy(void *prog) {
  delete (CompiledProgram*)prog;
}
gen.py (6 changed lines)
@@ -427,20 +427,20 @@ def gen_interpreter_redispatch(fn):
    for i, arg in enumerate(dispatcher_exprs):
        type = arg.type.cpp_type(strip_ref=False)
        type = type.replace('const ', '')
        args.append(f'load<{type}>()(op.args[{i}])')
        args.append(f'load<{type}>()(op.args[{i}], load_state)')

    redispatch = f'<FN>(ks, {", ".join(args)})'
    rettype = dispatcher_sig.returns_type().cpp_type()

    if rettype == 'at::Tensor':
        code = f'set(op, {redispatch});\n continue;'
        code = f'results[i] = {redispatch};\n break;'
        inplace = False

    # in-place op
    else:
        assert rettype == 'at::Tensor &' or rettype == 'const at::Tensor &'
        inplace = True
        code = f'{redispatch};\n break;'
        code = f'results[i] = {redispatch};\n break;'

    signature = dispatcher_sig.type()
    fn_ptr = f'at::redispatch::{sig.name()}'
@@ -1,3 +1,24 @@
diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h
index dc69764b15..aeb12164fa 100644
--- a/c10/core/Scalar.h
+++ b/c10/core/Scalar.h
@@ -149,6 +149,16 @@ class C10_API Scalar {
     }
   }
 
+  bool operator==(const Scalar &rhs) const {
+    if (tag != rhs.tag)
+      return false;
+    switch (tag) {
+    case Tag::HAS_d: return v.d == rhs.v.d;
+    case Tag::HAS_z: return v.z == rhs.v.z;
+    default: return v.i == rhs.v.i;
+    }
+  }
+
  private:
   template <
       typename T,
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index ff29b68dc4..cfae0dc783 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
stats.cpp (18 changed lines)
@@ -92,6 +92,7 @@ array<unsigned, (unsigned)FlushReason::NUM_REASONS> flush_reasons_count;
array<unsigned, MAX_TRACE_LENGTH+1> trace_size;
array<unsigned, MAX_TRACE_LENGTH+1> num_trace_outputs;
array<unsigned, MAX_TRACE_LENGTH+1> num_trace_deads;
vector<pair<float, string>> trace_compile_time;
unordered_map<string, vector<float>> trace_run_time;
unordered_map<string, unordered_map<string, unsigned>> trace_successors;
string first_trace, current_trace, last_trace;

@@ -148,6 +149,17 @@ struct PrintStats {
           << p.first << "\n\n";
    }

    print_header("Slowest Trace Compilation");
    sort(trace_compile_time.begin(), trace_compile_time.end());
    {
      auto I = trace_compile_time.rbegin(), E = trace_compile_time.rend();
      for (unsigned i = 0; i < 10 && I != E; ++i, ++I) {
        cerr << "Trace compiled in "
             << unsigned(I->first * 1000000.0) << " us\n"
             << I->second << "\n\n";
      }
    }

    cerr << "Number of Torchscript failures:\t" << torchscript_failures
         << "\nNumber of unsupported ops:\t" << unsupported_wrappers
         << "\nNumber of traces:\t" << total

@@ -269,7 +281,7 @@ void stats_register_trace(const Trace &t, FlushReason reason) {
  for (unsigned i = 0; i < num_ops; ++i) {
    auto &op = ops[i];
    num_outputs += op.observable;
    num_deads += !op.needsComputing();
    num_deads += op.dead;
  }
  assert(num_outputs > 0);
  ++num_trace_outputs[num_outputs];

@@ -285,6 +297,10 @@ void stats_register_trace(const Trace &t, FlushReason reason) {
    first_trace = current_trace;
}

void stats_register_compile_time(const StopWatch &run_time) {
  trace_compile_time.emplace_back(run_time.seconds(), current_trace);
}

void stats_register_trace_time(const StopWatch &run_time) {
  trace_run_time[current_trace].emplace_back(run_time.seconds());
stats.h (2 changed lines)
@@ -33,6 +33,7 @@ class Trace;

#define STATS(x) x
void stats_register_trace(const Trace &t, FlushReason reason);
void stats_register_compile_time(const StopWatch &run_time);
void stats_register_trace_time(const StopWatch &run_time);
void stats_inc_unsupported_wrapper();
void stats_inc_torchscript_fail();

@@ -41,6 +42,7 @@ void stats_inc_torchscript_fail();

#define STATS(x)
#define stats_register_trace(t, reason) (void)0
#define stats_register_compile_time(run_time) (void)0
#define stats_register_trace_time(run_time) (void)0
#define stats_inc_unsupported_wrapper() (void)0
#define stats_inc_torchscript_fail() (void)0
trace.cpp (385 changed lines)
@@ -2,6 +2,7 @@
// Distributed under the MIT license that can be found in the LICENSE file.

#include "trace.h"
#include "backends/backends.h"
#include "stopwatch.h"
#include "tensor.h"
#include "utils.h"

@@ -15,90 +16,113 @@
using namespace at;
using namespace std;

static bool force_interpreter = getenv("TORCHY_FORCE_INTERPRETER");
static Interpreter interpreter;
static TorchScript torchscript;
static TorchyBackend *backend
  = getenv("TORCHY_FORCE_INTERPRETER")
      ? (TorchyBackend*)&interpreter : (TorchyBackend*)&torchscript;

namespace interpreter { void run(Trace &t); }
namespace torchscript { bool run(Trace &t); }
bool TraceOpDef::operator==(const TraceOpDef &rhs) const {
  return id == rhs.id && observable == rhs.observable && dead == rhs.dead &&
         args == rhs.args;
}

void TensorOp::destroy() {
void TraceOpDef::destroy() {
  args.clear();
}

void TraceOpRunTimeData::destroy() {
  tls.~ThreadLocalState();
}

void TensorOp::incref() {
  assert(observable);
  ++refs;
  assert(refs != 0);
void Trace::incref(unsigned idx) {
  assert(idx < next_op);
  auto &op = ops[idx];
  auto &rdata = data[idx];
  assert(op.observable && !op.dead && rdata.refs > 0);
  ++rdata.refs;
  assert(rdata.refs != 0);
}

void TensorOp::decref(TensorOp *ops) {
  assert(refs > 0);
  --refs;
void Trace::decref(unsigned idx) {
  assert(idx < next_op);
  auto &op = ops[idx];
  auto &rdata = data[idx];
  assert(!op.dead && rdata.refs > 0);
  --rdata.refs;

  if (refs == 0) {
    TensorVisitor visitor([&](const Tensor &t) {
      auto idx = trace_idx(t);
      if (idx != -1u)
        ops[idx].decref(ops);
    });
    for (auto &arg : args) {
  if (rdata.refs == 0) {
    TensorVisitor visitor([&](InputIdx idx) {
      if (idx.is_trace())
        decref(idx.trace_idx());
    }, inputs);
    for (auto &arg : op.args) {
      visit(visitor, arg);
    }
    assert(!observable && !hasTensors());
    destroy();
    assert(!op.observable && !rdata.hasTensors());
    op.dead = true;
    op.destroy();
    rdata.destroy();
  }
}

bool TensorOp::hasTensors() const {
bool TraceOpRunTimeData::hasTensors() const {
  return find_if(tensors.begin(), tensors.end(), [](auto t) { return t != 0; })
           != tensors.end();
}

unsigned TensorOp::numTensors() const {
unsigned TraceOpRunTimeData::numTensors() const {
  return count_if(tensors.begin(), tensors.end(), [](auto t) { return t!=0; });
}

uintptr_t TensorOp::someTensor() const {
uintptr_t TraceOpRunTimeData::someTensor() const {
  auto I = find_if(tensors.begin(), tensors.end(),
                   [](auto t) { return t != 0 && t != DUMMY_TORCHY; });
  return I == tensors.end() ? 0 : *I;
}

bool TensorOp::operator!=(const Tensor &t) const {
  auto *t_ptr = t.getIntrusivePtr().get();
  for (auto ptr : tensors) {
    if (ptr == (uintptr_t)t_ptr)
      return false;
  }
  return true;
bool TraceCacheKey::operator==(const TraceCacheKey &rhs) const {
  if (num_ops != rhs.num_ops)
    return false;
  return equal(ops.get(), &ops[num_ops], rhs.ops.get());
}

size_t TraceCacheKeyHash::operator()(const TraceCacheKey &key) const {
  size_t hash = 0;
  // we only hash the prefix of the trace ops
  // this enables us to discover quickly traces that need deoptimization
  // and prefix of traces for speculative execution
  unsigned bits = 11;
  unsigned max_rounds = (sizeof(size_t) * 8) / bits;
  for (unsigned i = 0; i < min(key.num_ops, max_rounds); ++i) {
    hash = (hash << bits) | key.ops[i].id;
  }
  return hash;
}
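// Worked example (editorial note, not in the source): with bits = 11, a trace
// whose first three op ids are 5, 17 and 3 hashes to ((5 << 11) | 17) << 11 | 3,
// i.e. (5 << 22) | (17 << 11) | 3 = 20,971,520 + 34,816 + 3 = 21,006,339.
// Only the first 64/11 = 5 ops contribute on a 64-bit size_t, so traces that
// share a prefix land in the same bucket, which is exactly what the comment
// above wants for deoptimization and speculative-execution lookups.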

TraceCacheData::~TraceCacheData() {
  if (backend)
    backend->destroy(program);
}


namespace {
using InputMap = map<const TensorImpl*, unsigned>;

class printer {
  ostream &os;
  InputMap &inputs;
  const TensorOp &op;
  unsigned op_idx;

public:
  printer(ostream &os, InputMap &inputs, const TensorOp &op, unsigned op_idx)
    : os(os), inputs(inputs), op(op), op_idx(op_idx) {}
  printer(ostream &os) : os(os) {}

  template<typename T>
  ostream& operator()(const T &a) {
    return os << a;
  }

  ostream& operator()(const Tensor &t) {
    auto idx = trace_idx(t);
    if (idx != -1u && (op != t || idx < op_idx))
      return os << '%' << idx;

    auto n = inputs.emplace(t.getIntrusivePtr().get(),
                            (unsigned)inputs.size()).first->second;
    return os << "in<" << n << '>';
  ostream& operator()(const InputIdx &idx) {
    if (idx.is_input())
      return os << "in<" << idx.input_idx() << '>';
    return os << '%' << idx.trace_idx();
  }

  template<typename T>

@@ -119,97 +143,120 @@ public:
    }
    return os << ']';
  }

  template<typename T>
  ostream& operator()(const List<T> &l) {
    os << '(';
    bool first = true;
    for (const auto &it : l) {
      if (!first) os << ", ";
      first = false;

      const T &elem = it;
      (*this)(elem);
    }
    return os << ')';
  }

  ostream& operator()(const Storage &s) {
    if (!s)
      return os << "storage(null)";
    return os << "storage(" << s.nbytes() << ')';
  }

  ostream& operator()(const Generator &g) {
    if (!g.defined())
      return os << "generator(null)";
    return os << "generator(" << g.current_seed() << ", " << g.device() << ')';
  }
};
}

void TensorOp::print(ostream &os, InputMap &inputs, unsigned idx) const {
  auto t = someTensor();
void Trace::print(ostream &os, unsigned idx) const {
  assert(idx < next_op);
  auto &op = ops[idx];
  auto &rdata = data[idx];

  auto t = rdata.someTensor();
  if (t && tensor_has_dtype(t))
    os << '<' << tensor_get_dtype(t) << "> ";
  os << id;
  os << op.id;

  if (!needsComputing()) {
  if (op.dead) {
    os << " [dead]";
    return;
  }

  bool first = true;
  for (auto &arg : args) {
  for (auto &arg : op.args) {
    os << (first ? " " : ", ");
    first = false;

    visit(printer(os, inputs, *this, idx), arg);
    visit(printer(os), arg);
  }

  auto n_tensors = numTensors();
  assert(n_tensors >= observable);
  assert(refs >= n_tensors);
  auto n_tensors = rdata.numTensors();
  assert(n_tensors >= op.observable);
  assert(rdata.refs >= n_tensors);

  if (refs > 0)
    os << " #refs E/I=" << n_tensors << '/' << (refs - n_tensors);
  if (rdata.refs > 0)
    os << " #refs E/I=" << n_tensors << '/' << (rdata.refs - n_tensors);

  if (observable)
  if (op.observable)
    os << " #output";

  if (t && tensor_has_shape(t))
    os << " shape=" << tensor_get_shape(t);
}


Trace::~Trace() {
  destroyed = true;
}

void Trace::incref(const Tensor &t) {
InputIdx Trace::get_tensor_idx(const Tensor &t) {
  auto idx = trace_idx(t);
  if (idx != -1u) {
    assert(idx < next_op);
    ops[idx].incref();
    // check if this is also an input tensor
    // e.g. for inplace ops we have %0 = op %0 <-- but we want in<0> there
    if (idx != next_op-1) {
      incref(idx);
      return { idx, false };
    }
  }
  inputs.emplace_back(t);
  return { (unsigned)inputs.size()-1, true };
}

void Trace::incref(const optional<Tensor> &t) {
  if (t)
    incref(*t);
void Trace::append_arg(const Tensor &t) {
  auto val = get_tensor_idx(t);
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::incref(const TensorList &l) {
  for (auto &t : l) {
    incref(t);
void Trace::append_arg(ArrayRef<Tensor> arg) {
  vector<InputIdx> val;
  for (auto &t : arg) {
    val.emplace_back(get_tensor_idx(t));
  }
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::incref(const List<optional<Tensor>> &l) {
  for (const auto &t : l) {
    const optional<Tensor> &opt = t;
    incref(opt);
void Trace::append_arg(optional<Tensor> arg) {
  optional<InputIdx> val;
  if (arg)
    val = get_tensor_idx(*arg);
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(const List<optional<Tensor>> &arg) {
  vector<optional<InputIdx>> val;
  for (const auto &it : arg) {
    const optional<Tensor> &in = it;
    optional<InputIdx> elem;
    if (in)
      elem = get_tensor_idx(*in);
    val.emplace_back(move(elem));
  }
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(Storage &&arg) {
  ops[next_op-1].args.emplace_back(InputIdx(inputs.size(), true));
  inputs.emplace_back(move(arg));
}

void Trace::append_arg(optional<Generator> &&arg) {
  optional<InputIdx> val;
  if (arg) {
    val = InputIdx(inputs.size(), true);
    inputs.emplace_back(move(*arg));
  }
  ops[next_op-1].args.emplace_back(move(val));
}

void Trace::append_arg(string_view arg) {
  ops[next_op-1].args.emplace_back(string(arg.data(), arg.size()));
}

void Trace::append_arg(optional<string_view> arg) {
  optional<string> copy;
  if (arg)
    copy = string(arg->data(), arg->size());
  ops[next_op-1].args.emplace_back(move(copy));
}

unsigned Trace::register_tensor(uintptr_t tensor, TorchOp op_id,

@@ -237,27 +284,29 @@ unsigned Trace::register_tensor(uintptr_t tensor, TorchOp op_id,
    flush(STATS(FlushReason::TRACE_MAX_LENGTH));

  auto &op = ops[next_op];
  op.tensors[0] = tensor;
  for (unsigned i = 1; i < op.tensors.size(); ++i) {
    op.tensors[i] = 0;
  }
  op.id = op_id;
  assert(op.args.empty());
  op.refs = 1;
  op.observable = true;
  op.tls = ThreadLocalState();
  op.dispatch_key = ks;
  op.dead = false;

  auto &rdata = data[next_op];
  rdata.tensors[0] = tensor;
  for (unsigned i = 1; i < rdata.tensors.size(); ++i) {
    rdata.tensors[i] = 0;
  }
  rdata.refs = 1;
  rdata.tls = ThreadLocalState();
  rdata.dispatch_key = ks;
  rdata.inplace = op.inplace();
  return next_op++;
}

void Trace::add_shared(unsigned idx, uintptr_t ptr) {
  assert(idx < next_op);
  auto &op = ops[idx];

  for (auto &tensor : op.tensors) {
  for (auto &tensor : data[idx].tensors) {
    if (tensor == 0) {
      tensor = ptr;
      op.incref();
      incref(idx);
      return;
    }
  }

@@ -273,10 +322,11 @@ void Trace::set_unobservable(unsigned idx, uintptr_t ptr) {
    return;

  assert(idx < next_op);
  auto &op = ops[idx];
  auto &op = ops[idx];
  auto &rdata = data[idx];

  bool found = false;
  for (auto &tensor : op.tensors) {
  for (auto &tensor : rdata.tensors) {
    if (tensor == ptr) {
      tensor = 0;
      found = true;

@@ -285,81 +335,72 @@ void Trace::set_unobservable(unsigned idx, uintptr_t ptr) {
  }
  assert(found); (void)found;

  op.observable = op.hasTensors();
  op.decref(ops);
  op.observable = rdata.hasTensors();
  decref(idx);

  // reclaim slot if this was the last created tensor
  if (op.refs == 0 && idx+1 == next_op) {
  if (rdata.refs == 0 && idx+1 == next_op) {
    op.destroy();
    rdata.destroy();
    --next_op;
  }
}

TraceCacheKey Trace::mk_trace_key() {
  TraceOpDef *new_ops = new TraceOpDef[next_op];
  // we can't move the args for the interpreter as it traverses the args
  // in run() rather than compile() -- a NOP
  std::uninitialized_copy_n(ops, next_op, new_ops);
  // FIXME: C++17 only
  //std::uninitialized_move_n(ops, next_op, new_ops);
  //for (unsigned i = 0; i < next_op; ++i) {
  //  new (new_ops + i) TraceOpDef(move(ops[i]));
  //}
  return TraceCacheKey(new_ops, next_op);
}

void Trace::flush(STATS(FlushReason reason)) {
  assert(!flushing);
  flushing = true;

  // trim set of observable tensors as the references in arguments keep the
  // tensors alive and therefore we aren't notified the user's program
  // can't observe these tensors anymore
  // TODO: benchmark: should we skip this when running the interpreter that
  // doesn't really benefit from this information?
  {
    // tensor impl -> (refs, trace idx)
    unordered_map<uintptr_t, pair<uint16_t, uint16_t>> refs;
    refs.reserve(next_op);

    TensorVisitor visitor([&](const Tensor &t) {
      auto I = refs.find((uintptr_t)t.getIntrusivePtr().get());
      // all refs are inputs -> not observable
      if (I != refs.end() && --I->second.first == 0) {
        refs.erase(I);
        auto &argop = ops[I->second.second];
        argop.observable = false;
        argop.decref(ops);
      }
    });

    for (unsigned i = 0; i < next_op; ++i) {
      auto &op = ops[i];

      if (i > 0) {
        for (auto &arg : op.args) {
          visit(visitor, arg);
        }
      }

      if (op.observable) {
        for (auto tensor : op.tensors) {
          if (tensor == 0 || tensor == DUMMY_TORCHY)
            continue;
          refs.emplace(tensor,
            // -1 as reclaim adds +1 ref
            make_pair(intrusive_ptr<TensorImpl>::unsafe_reclaim_from_nonowning(
              (TensorImpl*)tensor).use_count()-1, i));
        }
      }
    }
  }

  stats_register_trace(*this, reason);

#ifdef TORCHY_PRINT_TRACE_ON_FLUSH
  cerr << "Flush trace\n" << *this << endl;
#endif

  TraceCacheKey key = { ops, next_op };
  auto I = cache.find(key);
  key.ops.release(); // so the destructor doesn't kick in

  if (I == cache.end()) {
    STATS(StopWatch time);
    auto *program = backend->compile(*this);
    auto *used_backend = backend;

    // fallback to the interpreter if the default backend can't handle this
    if (!program) {
      program = interpreter.compile(*this);
      used_backend = &interpreter;
    }
    STATS(time.stop());
    stats_register_compile_time(time);

    I = cache.emplace(piecewise_construct, forward_as_tuple(mk_trace_key()),
                      forward_as_tuple(program, used_backend)).first;
  }

  STATS(StopWatch run_time);
  // try torchscript first; fallback to the interpreter if it can't handle this
  if (force_interpreter || !torchscript::run(*this))
    interpreter::run(*this);

  I->second.backend->run(I->second.program, *this);
  STATS(run_time.stop());

  stats_register_trace_time(run_time);

  // reduce reference count on tensors s.t. they are deleted if possible
  for (unsigned i = 0; i < next_op; ++i) {
    ops[i].destroy();
    data[i].destroy();
  }
  inputs.clear();

  next_op = 0;
  flushing = false;

@@ -369,22 +410,30 @@ ostream& operator<<(ostream &os, const Trace &t) {
  if (t.next_op == 0)
    return os << "(empty)\n";

  map<const TensorImpl*, unsigned> inputs_map;
  for (unsigned i = 0; i < t.next_op; ++i) {
    os << '%' << i << " = ";
    t.ops[i].print(os, inputs_map, i);
    t.print(os, i);
    os << '\n';
  }

  if (inputs_map.empty())
  if (t.inputs.empty())
    return os;

  os << "\nInputs' shapes:\n";
  for (unsigned i = 0, e = inputs_map.size(); i != e; ++i) {
    auto I = find_if(inputs_map.begin(), inputs_map.end(),
                     [=](auto &p) { return p.second == i; });
    assert(I != inputs_map.end());
    os << "in<" << i << ">: " << I->first->sizes() << '\n';
  os << "\nInputs:\n";
  unsigned i = 0;
  for (auto &in : t.inputs) {
    os << "in<" << i++ << ">: ";
    if (in.isTensor()) {
      os << "tensor(" << in.toTensor().sizes() << ')';
    } else if (in.isGenerator()) {
      const auto &g = in.toGenerator();
      os << "generator(" << g.current_seed() << ", " << g.device() << ')';
    } else if (in.isStorage()) {
      os << "storage(" << in.toStorage().nbytes() << ')';
    } else {
      assert(0);
    }
    os << '\n';
  }
  return os;
}
trace.h (158 changed lines)
@@ -4,7 +4,10 @@
// Distributed under the MIT license that can be found in the LICENSE file.

#include "config.h"
#include "autogen/ops_data.h"
#include "backends/backends.h"
#include "ops.h"
#include <ATen/core/ivalue.h>
#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <c10/core/DispatchKeySet.h>

@@ -15,28 +18,47 @@
#include <map>
#include <memory>
#include <ostream>
#include <unordered_map>
#include <vector>

#define MAX_TRACE_LENGTH 64

class InputIdx {
  int idx;
public:
  InputIdx(unsigned idx, bool input) : idx(input ? idx : ~idx) {}

  bool is_input() const { return idx >= 0; }
  bool is_trace() const { return idx < 0; }

  unsigned input_idx() const {
    assert(is_input());
    return idx;
  }

  unsigned trace_idx() const {
    assert(!is_input());
    return ~idx;
  }

  bool operator==(const InputIdx &rhs) const { return idx == rhs.idx; }
};
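// Example of the encoding (editorial note, not in the header): non-negative
// values index the trace's captured inputs, negative values are the bitwise-NOT
// of an earlier trace op's index, so no separate tag field is needed.
//   InputIdx(3, /*input=*/true)  -> idx ==  3, is_input(), input_idx() == 3
//   InputIdx(3, /*input=*/false) -> idx == ~3 == -4, is_trace(), trace_idx() == 3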

using UnionInputTy = c10::variant<
  bool,
  double,
  int64_t,
  InputIdx,
  at::Device,
  at::Dimname,
  at::MemoryFormat,
  at::ScalarType,
  at::Storage,
  at::Tensor,
  c10::List<c10::optional<at::Tensor>>,
  c10::optional<bool>,
  c10::optional<double>,
  c10::optional<int64_t>,
  c10::optional<at::Generator>,
  c10::optional<InputIdx>,
  c10::optional<at::MemoryFormat>,
  c10::optional<at::Scalar>,
  c10::optional<at::Tensor>,
  c10::optional<c10::Device>,
  c10::optional<c10::Layout>,
  c10::optional<c10::ScalarType>,

@@ -44,39 +66,46 @@ using UnionInputTy = c10::variant<
  c10::Scalar,
  std::string,
  std::vector<long>,
  std::vector<InputIdx>,
  std::vector<c10::optional<InputIdx>>,
  std::vector<at::Dimname>,
  std::vector<at::Tensor>,
  c10::optional<std::vector<double>>,
  c10::optional<std::vector<long>>,
  c10::optional<std::vector<at::Dimname>>
>;

struct TensorOp {
  // TODO: measure typical amount of sharing
  std::array<uintptr_t, 3> tensors;
// TODO: investigate if specializing this for the common case
// e.g. 2 tensors makes sense (would save space + 1 mem alloc)
struct TraceOpDef {
  std::vector<UnionInputTy> args;
  TorchOp id;
  bool observable;
  bool dead;

  bool inplace() const {
    return id >= FIRST_INPLACE_OP;
  }

  bool operator==(const TraceOpDef &rhs) const;

private:
  void destroy();

  friend class Trace;
};


struct TraceOpRunTimeData {
  std::array<uintptr_t, 3> tensors;
  at::ThreadLocalState tls;
  c10::DispatchKeySet dispatch_key;
  TorchOp id;
  uint16_t refs;
  bool observable;
  bool inplace;

  bool needsComputing() const {
    return refs > 0;
  }

  bool operator!=(const at::Tensor &t) const;

  void print(std::ostream &os,
             std::map<const at::TensorImpl*, unsigned> &inputs,
             unsigned idx) const;

private:
  void destroy();
  void incref();
  void decref(TensorOp *ops);
  bool hasTensors() const;
  unsigned numTensors() const;
  uintptr_t someTensor() const;

@@ -85,66 +114,95 @@ private:
};


struct TraceCacheKey {
  std::unique_ptr<TraceOpDef[]> ops;
  unsigned num_ops;
  // TODO: some backends may want shape information of inputs to specialize code

  TraceCacheKey(TraceOpDef *ops, unsigned num_ops)
    : ops(ops), num_ops(num_ops) {}

  bool operator==(const TraceCacheKey &rhs) const;
};

struct TraceCacheKeyHash {
  size_t operator()(const TraceCacheKey &key) const;
};

struct TraceCacheData {
  void *program = nullptr;
  TorchyBackend *backend = nullptr;

  TraceCacheData(void *program, TorchyBackend *backend)
    : program(program), backend(backend) {}

  TraceCacheData(TraceCacheData &&other) {
    std::swap(program, other.program);
    std::swap(backend, other.backend);
  }

  ~TraceCacheData();
};

using InputData = std::vector<c10::IValue>;

class Trace {
  TensorOp ops[MAX_TRACE_LENGTH];
  TraceOpDef ops[MAX_TRACE_LENGTH];
  TraceOpRunTimeData data[MAX_TRACE_LENGTH];
  InputData inputs;

  std::unordered_map<TraceCacheKey, TraceCacheData, TraceCacheKeyHash> cache;
  unsigned next_op = 0;
  bool flushing = false;
  bool destroyed = false;

  template <typename T>
  void incref(const T &t) {}
  void incref(unsigned idx);
  void decref(unsigned idx);
  InputIdx get_tensor_idx(const at::Tensor &t);
  TraceCacheKey mk_trace_key();

  void incref(const at::Tensor &t);
  void incref(const c10::optional<at::Tensor> &t);
  void incref(const at::TensorList &l);
  void incref(const c10::List<c10::optional<at::Tensor>> &l);
  void print(std::ostream &os, unsigned idx) const;

public:
  ~Trace();

  bool is_flushing() const { return flushing; }
  unsigned numOps() const { return next_op; }
  const TensorOp* getOps() const { return ops; }
  TensorOp* getOps() { return ops; }
  const TraceOpDef* getOps() const { return ops; }
  TraceOpDef* getOps() { return ops; }
  const TraceOpRunTimeData* getRuntimeData() const { return data; }
  InputData& getInputs() { return inputs; }
  const InputData& getInputs() const { return inputs; }

  template<typename A>
  void append_arg(A &&arg) {
    incref(arg);
    ops[next_op-1].args.emplace_back(std::forward<A>(arg));
  }

  template<typename T>
  void append_arg(at::ArrayRef<T> arg) {
    incref(arg);
    ops[next_op-1].args.emplace_back(arg.vec());
  }

  template<typename T>
  void append_arg(c10::optional<at::ArrayRef<T>> arg) {
    c10::optional<std::vector<T>> copy;
    if (arg) {
      incref(*arg);
      copy = arg->vec();
    }
    ops[next_op-1].args.emplace_back(std::move(copy));
  }

  void append_arg(c10::string_view arg) {
    ops[next_op-1].args.emplace_back(std::string(arg.data(), arg.size()));
  }

  void append_arg(c10::optional<c10::string_view> arg) {
    c10::optional<std::string> copy;
    if (arg)
      copy = std::string(arg->data(), arg->size());
      copy = arg->vec();
    ops[next_op-1].args.emplace_back(std::move(copy));
  }

  template<typename T>
  void append_arg(const c10::List<T> &arg) {
    incref(arg);
    ops[next_op-1].args.emplace_back(arg.copy());
  void append_arg(const at::Tensor &arg);
  void append_arg(at::Tensor &arg) {
    append_arg(const_cast<const at::Tensor&>(arg));
  }
  void append_arg(at::ArrayRef<at::Tensor> arg);
  void append_arg(c10::optional<at::Tensor> arg);
  void append_arg(const c10::List<c10::optional<at::Tensor>> &arg);
  void append_arg(at::Storage &&arg);
  void append_arg(c10::optional<at::Generator> &&arg);
  void append_arg(c10::string_view arg);
  void append_arg(c10::optional<c10::string_view> arg);

  unsigned register_tensor(uintptr_t tensor, TorchOp op_id,
                           c10::DispatchKeySet ks);
utils.h (22 changed lines)
@@ -3,20 +3,22 @@
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include <ATen/Tensor.h>
class InputIdx;

class TensorVisitor {
  std::function<void(const at::Tensor&)> visit;
  std::function<void(InputIdx)> visit;
  InputData &inputs;

public:
  TensorVisitor(std::function<void(const at::Tensor&)> &&visit)
    : visit(std::move(visit)) {}
  TensorVisitor(std::function<void(InputIdx)> &&visit, InputData &inputs)
    : visit(std::move(visit)), inputs(inputs) {}

  template <typename T>
  void operator()(const T&) {}

  void operator()(const at::Tensor &t) {
    visit(t);
  void operator()(const InputIdx &idx) {
    if (idx.is_trace() || inputs[idx.input_idx()].isTensor())
      visit(idx);
  }

  template<typename T>

@@ -31,12 +33,4 @@ public:
      (*this)(elem);
    }
  }

  template<typename T>
  void operator()(const at::List<T> &l) {
    for (const auto &it : l) {
      const T &elem = it;
      (*this)(elem);
    }
  }
};