зеркало из https://github.com/microsoft/ZooTracer.git
275 строки
10 KiB
C++
275 строки
10 KiB
C++
#include <ztTrace.h>
|
|
#include <ztLog.h>
|
|
#include <sstream>
|
|
|
|
#include "TracePoint.h"
|
|
|
|
#include <ppl.h>
|
|
#include <agents.h>
|
|
using concurrency::task_group;
|
|
|
|
using namespace zt;
|
|
using std::vector;
|
|
using std::shared_ptr;
|
|
using std::make_shared;
|
|
using std::dynamic_pointer_cast;
|
|
using std::tuple;
|
|
using std::get;
|
|
using std::pair;
|
|
|
|
// Builds and holds one trace.
|
|
class zt::Trace::Implementation : concurrency::agent{
|
|
public:
|
|
Implementation(const KDTreeSource& kdtree_source, const TraceParameters&);
|
|
~Implementation();
|
|
|
|
const Patch* get_tracepoint(FrameIndex) const;
|
|
void fix(FrameIndex, Patch, const Descriptor&);
|
|
bool is_fixed(FrameIndex);
|
|
void occlude(FrameIndex);
|
|
bool is_forced_occluded(FrameIndex);
|
|
void reset(FrameIndex);
|
|
void rerun(const TraceParameters&);
|
|
void subscribe(Handler event_handler){ _subscription = event_handler; }
|
|
protected:
|
|
void run() override;
|
|
private:
|
|
TraceParameters _pars;
|
|
const KDTreeSource& _kdtree_source;
|
|
vector<shared_ptr<TracePoint>> _trace;
|
|
Handler _subscription; // notifies a change in a trace segment.
|
|
task_group _build; // builds trace segments in parallel.
|
|
void start_build(); // starts trace optimization.
|
|
void cancel_build(); // cancel building the trace.
|
|
using Msg = tuple<FrameIndex, shared_ptr<TracePoint>>;
|
|
static Msg make_msg(FrameIndex i, TracePoint* tp){ return std::make_tuple(i, shared_ptr<TracePoint>(tp)); }
|
|
concurrency::unbounded_buffer<Msg> _jobs; // accepts jobs from 'fix', 'occlude' and 'reset' .
|
|
void change_the_trace(FrameIndex, shared_ptr<TracePoint>); // make changes to the trace according to user input.
|
|
shared_ptr<TracePointAuto> _occlusion{ new TracePointAuto() };
|
|
shared_ptr<TracePointOccluded> _trace_end{ new TracePointOccluded() };
|
|
bool is_keyframe(FrameIndex frame){ return TracePoint::is_keyframe(*_trace[frame]); }
|
|
bool is_occluded(FrameIndex frame){ return TracePoint::is_occluded(*_trace[frame]); }
|
|
bool is_auto(FrameIndex frame){ return TracePoint::is_auto(*_trace[frame]); }
|
|
};
|
|
|
|
// Creates a trace over 'source' using the parameters 'pars'.
// All state lives in a newly allocated Implementation held by 'impl'.
Trace::Trace(const KDTreeSource& source, const TraceParameters& pars)
    : impl(new Implementation(source, pars))
{
}
|
|
// Destructor with an empty body, defined in this file after Implementation's definition.
Trace::~Trace()
{
}
|
|
|
|
// Get location of the traced object at the frame. Returns nullptr if the traced object is occluded at the frame.
|
|
const Patch* Trace::get_tracepoint(FrameIndex frame) const{ return impl->get_tracepoint(frame); }
|
|
// Requests that 'frame' become a key frame with the given location and descriptor.
// Forwards to the implementation, which queues the change for its agent.
void Trace::fix(FrameIndex frame, Patch location, const Descriptor& descriptor)
{
    impl->fix(frame, location, descriptor);
}
|
|
// Tells whether the trace point currently stored for 'frame' is a key frame.
bool Trace::is_fixed(FrameIndex frame)
{
    const bool fixed = impl->is_fixed(frame);
    return fixed;
}
|
|
// Requests that 'frame' be marked as occluded; forwards to the implementation.
void Trace::occlude(FrameIndex frame)
{
    impl->occlude(frame);
}
|
|
// Tells whether the trace point currently stored for 'frame' is an occluded point.
bool Trace::is_forced_occluded(FrameIndex frame)
{
    const bool occluded = impl->is_forced_occluded(frame);
    return occluded;
}
|
|
// Requests that 'frame' be returned to automatic tracing; forwards to the implementation.
void Trace::reset(FrameIndex frame)
{
    impl->reset(frame);
}
|
|
// Replaces the trace parameters with 'pars' and requests that the trace be rebuilt.
void Trace::rerun(const TraceParameters& pars)
{
    impl->rerun(pars);
}
|
|
|
|
// Links notification handler to the trace. The 'nullptr' value clears the subscription. The handler may be called from a different thread.
|
|
void Trace::subscribe(Handler event_handler){ impl->subscribe(event_handler); }
|
|
|
|
|
|
// Creates an all-automatic trace with one point per frame of 'kdtree_source'
// and starts the agent that processes edit jobs.
//   kdtree_source - stored by reference, so it must outlive this object.
//   pars          - initial trace parameters (copied).
Trace::Implementation::Implementation(const KDTreeSource& kdtree_source, const TraceParameters& pars)
    // Listed in declaration order, which is the order members are actually initialized in.
    : _pars(pars), _kdtree_source(kdtree_source), _subscription(nullptr)
{
    const int n = _kdtree_source.num_frames();
    if (n > 0) _trace.reserve(static_cast<size_t>(n)); // one allocation instead of repeated growth.
    for (int i = 0; i < n; ++i) _trace.push_back(std::make_shared<TracePointAuto>());
    start(); // starts the agent (see run() method).
}
|
|
// Shuts the agent down: queues a job with a negative frame index, which run()
// treats as the request to stop, then blocks until the agent has finished.
Trace::Implementation::~Implementation()
{
    const Msg stop_request = make_msg(-1, new TracePointOccluded());
    _jobs.enqueue(stop_request);
    concurrency::agent::wait(this);
    Log::write("trace destroyed");
}
|
|
|
|
|
|
// returns trace point location or nullptr if occluded.
|
|
const Patch* Trace::Implementation::get_tracepoint(FrameIndex frame) const{
|
|
if (frame < 0 || frame >= static_cast<FrameIndex>(_trace.size())) throw std::out_of_range("frame");
|
|
return _trace[frame]->location();
|
|
}
|
|
|
|
// Queues a job that turns 'frame' into a key frame with the given location and descriptor.
// The change itself is applied later by the agent (see run()).
// Throws std::out_of_range("frame") for a frame outside the trace, as get_tracepoint() does.
void Trace::Implementation::fix(FrameIndex frame, Patch location, const Descriptor& descriptor){
    // A negative frame must never reach the queue: run() treats it as the request to stop the agent.
    // A frame past the end would index '_trace' out of bounds on the agent thread.
    if (frame < 0 || frame >= static_cast<FrameIndex>(_trace.size())) throw std::out_of_range("frame");
    _jobs.enqueue(make_msg(frame, new TracePointKeyFrame{ location, descriptor }));
}
|
|
|
|
// Tells whether the trace point currently stored for 'frame' is a key frame.
// Throws std::out_of_range("frame") for a frame outside the trace, as get_tracepoint() does,
// instead of indexing '_trace' out of bounds.
bool Trace::Implementation::is_fixed(FrameIndex frame){
    if (frame < 0 || frame >= static_cast<FrameIndex>(_trace.size())) throw std::out_of_range("frame");
    return is_keyframe(frame);
}
|
|
|
|
// Queues a job that marks 'frame' as occluded, unless the stored point is already an occluded one.
// The change itself is applied later by the agent (see run()).
// Throws std::out_of_range("frame") for a frame outside the trace, as get_tracepoint() does.
void Trace::Implementation::occlude(FrameIndex frame){
    // Validate first: the check below indexes '_trace', and a negative frame
    // in the queue would be taken by run() as the request to stop the agent.
    if (frame < 0 || frame >= static_cast<FrameIndex>(_trace.size())) throw std::out_of_range("frame");
    if (!is_occluded(frame))
        _jobs.enqueue(make_msg(frame, new TracePointOccluded()));
}
|
|
|
|
// Tells whether the trace point currently stored for 'frame' is an occluded point.
// Throws std::out_of_range("frame") for a frame outside the trace, as get_tracepoint() does,
// instead of indexing '_trace' out of bounds.
bool Trace::Implementation::is_forced_occluded(FrameIndex frame){
    if (frame < 0 || frame >= static_cast<FrameIndex>(_trace.size())) throw std::out_of_range("frame");
    return is_occluded(frame);
}
|
|
|
|
// Queues a job that returns 'frame' to automatic tracing, unless the stored point is already automatic.
// The change itself is applied later by the agent (see run()).
// Throws std::out_of_range("frame") for a frame outside the trace, as get_tracepoint() does.
void Trace::Implementation::reset(FrameIndex frame){
    // Validate first: the check below indexes '_trace', and a negative frame
    // in the queue would be taken by run() as the request to stop the agent.
    if (frame < 0 || frame >= static_cast<FrameIndex>(_trace.size())) throw std::out_of_range("frame");
    if (!is_auto(frame))
        _jobs.enqueue(make_msg(frame, new TracePointAuto()));
}
|
|
|
|
void Trace::Implementation::rerun(const TraceParameters& pars){
|
|
_pars = pars;
|
|
_jobs.enqueue(make_msg(0, nullptr)); }
|
|
|
|
|
|
// agent message loop
|
|
void Trace::Implementation::run(){
|
|
try{
|
|
while (true)
|
|
{
|
|
auto msg = concurrency::receive(_jobs);
|
|
cancel_build();
|
|
if (get<0>(msg) < 0) break;
|
|
change_the_trace(get<0>(msg), move(get<1>(msg)));
|
|
while (concurrency::try_receive(_jobs, msg))
|
|
{
|
|
if (get<0>(msg) < 0) break;
|
|
change_the_trace(get<0>(msg), move(get<1>(msg)));
|
|
}
|
|
start_build();
|
|
}
|
|
}
|
|
catch (std::exception e)
|
|
{
|
|
Log::write("Exception in trace agent: " + std::string(e.what()));
|
|
}
|
|
done();
|
|
}
|
|
|
|
// Applies one job to the trace and refreshes the candidate matches of the automatic trace points.
//   frame - frame the job addresses; not used when 'tp' is null.
//   tp    - null:            prepare the whole trace again (sent by rerun());
//           occluded point:  store an occluded point for the frame;
//           automatic point: return the frame to automatic tracing;
//           key frame point: store the given key frame for the frame.
// Depending on the change, either every automatic point is rebuilt from all key frames
// (full preparation) or the matches for the one new key frame are added to the existing
// automatic points (incremental preparation).
// A job addressing a frame outside the trace is logged and ignored.
// Throws std::exception("internal error") if 'tp' is none of the kinds above.
void Trace::Implementation::change_the_trace(FrameIndex frame, shared_ptr<TracePoint> tp){
    using KeyFrameRef = pair<FrameIndex, shared_ptr<TracePointKeyFrame>>;
    const FrameIndex frame_count = static_cast<FrameIndex>(_trace.size());

    // Only jobs that carry a trace point address a frame; never index '_trace' out of bounds.
    if (nullptr != tp && (frame < 0 || frame >= frame_count)){
        Log::write("trace job ignored: frame is out of range");
        return;
    }

    bool add = false;
    bool all = false;
    if (nullptr == tp){
        add = all = true;
    }
    else if (TracePoint::is_occluded(*tp)){
        add = all = is_keyframe(frame);
        _trace[frame] = std::make_shared<TracePointOccluded>();
    }
    else if (TracePoint::is_auto(*tp)){
        // The frame may already be automatic when the job is processed, for example when
        // two reset requests for it were queued back to back. That is a no-op, not an error.
        if (!is_auto(frame)){
            add = all = true;
            _trace[frame] = std::make_shared<TracePointAuto>();
        }
    }
    else if (TracePoint::is_keyframe(*tp)){
        add = true;
        all = is_keyframe(frame);
        _trace[frame] = make_shared<TracePointKeyFrame>(std::move(dynamic_cast<TracePointKeyFrame&>(*tp.get())));
    }
    else throw std::exception("internal error");

    if (!add) return;

    // The set of key frames is the same for every frame prepared below, so collect it once
    // instead of rescanning the whole trace for each automatic frame.
    vector<KeyFrameRef> key_frames;
    int changed_key_frame = -1; // position of 'frame' within 'key_frames', if it is a key frame.
    for (FrameIndex k = 0; k < frame_count; k++){
        if (!is_keyframe(k)) continue;
        if (k == frame) changed_key_frame = static_cast<int>(key_frames.size());
        key_frames.push_back(KeyFrameRef(k, dynamic_pointer_cast<TracePointKeyFrame>(_trace[k])));
    }

    if (all){
        // full trace preparation
        for (FrameIndex i = 0; i < frame_count; i++){
            if (!is_auto(i)) continue;
            auto p = make_shared<TracePointAuto>(); // reset the point
            if (_kdtree_source.is_ready(i)){
                auto kdt = _kdtree_source[i];
                // Key frames are handed over one at a time: each add_matches call sees
                // the key frames up to and including the one being added.
                vector<KeyFrameRef> kfs;
                for (const KeyFrameRef& kf : key_frames){
                    const int key_frame = static_cast<int>(kfs.size());
                    kfs.push_back(kf);
                    auto kd_matches = kdt->getMatches(kf.second->descriptor(), _pars.num_matches(), _pars.match_ratio());
                    vector<pair<Patch, Descriptor>> tp_matches;
                    for (auto& m : kd_matches)
                        tp_matches.push_back(pair<Patch, Descriptor>(Patch(get<0>(m), get<1>(m)), get<3>(m)));
                    p->add_matches(kfs, key_frame, tp_matches, _pars.max_matches_per_frame(), _pars.appearance_threshold());
                }
            }
            _trace[i] = std::move(p);
        }
    }
    else{
        // incremental trace preparation
        // This path is taken only right after 'frame' was stored as a key frame above,
        // so the scan must have found it.
        assert(changed_key_frame >= 0);
        for (FrameIndex i = 0; i < frame_count; i++){
            if (!is_auto(i) || !_kdtree_source.is_ready(i)) continue;
            auto kdt = _kdtree_source[i];
            auto kd_matches = kdt->getMatches(key_frames[changed_key_frame].second->descriptor(), _pars.num_matches(), _pars.match_ratio());
            vector<pair<Patch, Descriptor>> tp_matches;
            for (auto& m : kd_matches)
                tp_matches.push_back(pair<Patch, Descriptor>(Patch(get<0>(m), get<1>(m)), get<3>(m)));
            dynamic_pointer_cast<TracePointAuto>(_trace[i])->add_matches(key_frames, changed_key_frame, tp_matches, _pars.max_matches_per_frame(), _pars.appearance_threshold());
        }
    }
}
|
|
|
|
// function object
|
|
class BuildTask{
|
|
public:
|
|
BuildTask(shared_ptr<TracePoint> before, FrameIndex segment_start, const OptimizationParameters& pars, Trace::Handler handler)
|
|
: _before(before), _segment_start(segment_start), _pars(pars), _handler(handler){}
|
|
|
|
void operator() () const {
|
|
FrameIndex segment_end = _segment_start + static_cast<FrameIndex>(_segment.size()) - 1;
|
|
{
|
|
std::stringstream ss;
|
|
ss << "starting segment " << _segment_start << " to " << segment_end;
|
|
Log::write(ss.str());
|
|
}
|
|
try{
|
|
if (_handler != nullptr) _handler(_segment_start, segment_end);
|
|
TracePointAuto::optimize(_segment, *_before, *_after, _pars);
|
|
if (_handler != nullptr) _handler(_segment_start, segment_end);
|
|
}
|
|
catch (std::exception e){
|
|
Log::write("exception in trace optimization: " + std::string(e.what()));
|
|
}
|
|
{
|
|
std::stringstream ss;
|
|
ss << "finished segment " << _segment_start << " to " << segment_end;
|
|
Log::write(ss.str());
|
|
}
|
|
}
|
|
|
|
void add(shared_ptr<TracePointAuto> trace_point){ _segment.push_back(trace_point); }
|
|
void after(shared_ptr<TracePoint> after){ _after = after; }
|
|
private:
|
|
FrameIndex _segment_start;
|
|
SegmentType _segment;
|
|
shared_ptr<TracePoint> _before;
|
|
shared_ptr<TracePoint> _after;
|
|
OptimizationParameters _pars;
|
|
Trace::Handler _handler;
|
|
};
|
|
|
|
void Trace::Implementation::start_build(){
|
|
shared_ptr<TracePoint> start = _trace_end;
|
|
FrameIndex segment_start = 0;
|
|
FrameIndex frame_end = static_cast<FrameIndex>(_trace.size());
|
|
while (segment_start < frame_end && _kdtree_source.is_ready(segment_start))
|
|
{
|
|
while (segment_start < frame_end && _kdtree_source.is_ready(segment_start) && is_keyframe(segment_start))
|
|
{
|
|
start = _trace[segment_start];
|
|
segment_start += 1;
|
|
}
|
|
if (segment_start < frame_end && _kdtree_source.is_ready(segment_start))
|
|
{
|
|
BuildTask task(start, segment_start, OptimizationParameters(_pars.lambda_d(), _pars.lambda_u(), _pars.lambda_o(), _pars.max_occlusion_duration()), _subscription);
|
|
while (segment_start < frame_end && _kdtree_source.is_ready(segment_start) && !(is_keyframe(segment_start)))
|
|
{
|
|
task.add(is_auto(segment_start) ? dynamic_pointer_cast<TracePointAuto>(_trace[segment_start]) : _occlusion);
|
|
++segment_start;
|
|
}
|
|
task.after(segment_start < frame_end && _kdtree_source.is_ready(segment_start) ? _trace[segment_start] : _trace_end);
|
|
_build.run(task);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cancels building the trace: requests cancellation of the tasks scheduled on
// '_build', then waits on the group before returning.
void Trace::Implementation::cancel_build()
{
    _build.cancel();
    _build.wait();
}