First pass at threading background blur

This commit is contained in:
Linnea May 2022-01-21 12:28:31 -08:00
Parent 8c662b6f30
Commit 69dbb27c2b
4 changed files with 79 additions and 91 deletions

View file

@@ -73,6 +73,11 @@ LearningModel StyleTransfer::GetModel()
rel.append("Assets\\mosaic.onnx");
return LearningModel::LoadFromFilePath(rel + L"");
}
/******* Start of old Segment Model stuff *******/
SegmentModel::SegmentModel() :
m_sess(NULL),
m_sessPreprocess(NULL),
@@ -101,30 +106,7 @@ SegmentModel::SegmentModel(UINT32 w, UINT32 h) :
m_bindStyleTransfer(NULL),
bindings(swapChainEntryCount)
{
SetImageSize(w, h);
m_sess = CreateLearningModelSession(Invert(1, 3, h, w));
m_sessStyleTransfer = CreateLearningModelSession(StyleTransfer());
m_bindStyleTransfer = LearningModelBinding(m_sessStyleTransfer);
// Initialize segmentation LearningModelSessions
m_sessPreprocess = CreateLearningModelSession(Normalize0_1ThenZScore(h, w, 3, mean, stddev));
m_sessFCN = CreateLearningModelSession(FCNResnet());
m_sessPostprocess = CreateLearningModelSession(PostProcess(1, 3, h, w, 1));
// Initialize segmentation bindings
m_bindPreprocess = LearningModelBinding(m_sessPreprocess);
m_bindFCN = LearningModelBinding(m_sessFCN);
m_bindPostprocess = LearningModelBinding(m_sessPostprocess);
// Create set of bindings to cycle through
for (int i = 0; i < swapChainEntryCount; i++) {
bindings.push_back(std::make_unique<SwapChainEntry>());
bindings[i]->binding = LearningModelBinding(m_sessStyleTransfer);
bindings[i]->binding.Bind(L"outputImage",
VideoFrame(Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, 720, 720));
}
SetModels(w, h);
}
void SegmentModel::SetModels(UINT32 w, UINT32 h) {
@@ -147,12 +129,17 @@ void SegmentModel::SetModels(UINT32 w, UINT32 h) {
m_bindFCN = LearningModelBinding(m_sessFCN);
m_bindPostprocess = LearningModelBinding(m_sessPostprocess);
auto device = m_sessFCN.Device().Direct3D11Device();
// Create set of bindings to cycle through
for (int i = 0; i < swapChainEntryCount; i++) {
bindings.push_back(std::make_unique<SwapChainEntry>());
bindings[i]->binding = LearningModelBinding(m_sessStyleTransfer);
bindings[i]->binding.Bind(L"outputImage",
VideoFrame(Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, 720, 720));
bindings[i]->binding_model = LearningModelBinding(m_sessFCN);
bindings[i]->binding_post = LearningModelBinding(m_sessPostprocess);
bindings[i]->bind_pre = LearningModelBinding(m_sessPreprocess);
bindings[i]->binding_post.Bind(L"OutputImage",
VideoFrame::CreateAsDirect3D11SurfaceBacked(Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized, m_imageWidthInPixels, m_imageHeightInPixels));
bindings[i]->outputCache = VideoFrame::CreateAsDirect3D11SurfaceBacked(Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized, m_imageWidthInPixels, m_imageHeightInPixels);
}
}
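The binding pool built here is the commit's core threading idea: a fixed set of swapChainEntryCount binding sets, each with its own pre-bound output frame, so several EvaluateAsync calls can be in flight at once without sharing state. One caveat in the code above: the member initializer `bindings(swapChainEntryCount)` pre-sizes the vector with null entries, and the loop then push_backs new entries past them while `bindings[i]` indexes into the null front slots. A minimal sketch of the intended pattern, growing the pool with push_back alone (PoolEntry and MakeBindingPool are illustrative names, not from the commit):

#include <winrt/Windows.AI.MachineLearning.h>
#include <winrt/Windows.Graphics.Imaging.h>
#include <winrt/Windows.Media.h>
#include <memory>
#include <vector>

using namespace winrt::Windows::AI::MachineLearning;
using winrt::Windows::Graphics::Imaging::BitmapPixelFormat;
using winrt::Windows::Media::VideoFrame;

struct PoolEntry {
    LearningModelBinding binding{ nullptr };
    VideoFrame output{ nullptr };
};

// Each entry owns its binding and destination frame, so one in-flight
// evaluation never writes into another entry's output.
std::vector<std::unique_ptr<PoolEntry>> MakeBindingPool(
    LearningModelSession const& session, int poolSize, int32_t w, int32_t h)
{
    std::vector<std::unique_ptr<PoolEntry>> pool;
    for (int i = 0; i < poolSize; ++i) {
        auto entry = std::make_unique<PoolEntry>();
        entry->binding = LearningModelBinding(session);
        entry->output = VideoFrame(BitmapPixelFormat::Bgra8, w, h);
        entry->binding.Bind(L"outputImage", entry->output);
        pool.push_back(std::move(entry));
    }
    return pool;
}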
@@ -169,7 +156,6 @@ void SegmentModel::Run(IDirect3DSurface src, IDirect3DSurface dest)
VideoFrame input = VideoFrame::CreateWithDirect3D11Surface(src);
VideoFrame output = VideoFrame::CreateWithDirect3D11Surface(dest);
auto device = m_sessFCN.Device().Direct3D11Device();
auto desc = input.Direct3DSurface().Description();
auto descOut = output.Direct3DSurface().Description();
@@ -180,60 +166,20 @@ void SegmentModel::Run(IDirect3DSurface src, IDirect3DSurface dest)
output.CopyToAsync(output2).get();
std::vector<int64_t> shape = { 1, 3, m_imageHeightInPixels, m_imageWidthInPixels };
// 2. Preprocessing: z-score normalization
auto now = std::chrono::high_resolution_clock::now();
ITensor intermediateTensor = TensorFloat::Create(shape);
hstring inputName = m_sessPreprocess.Model().InputFeatures().GetAt(0).Name();
hstring outputName = m_sessPreprocess.Model().OutputFeatures().GetAt(0).Name();
SubmitEval(input2, output);
swapChainIndex = (swapChainIndex + 1) % swapChainEntryCount;
m_bindPreprocess.Bind(inputName, input2);
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));
m_bindPreprocess.Bind(outputName, intermediateTensor, outputBindProperties);
m_sessPreprocess.EvaluateAsync(m_bindPreprocess, L"");
auto timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
OutputDebugString(L"Pre: ");
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
// 3. Run through actual model
now = std::chrono::high_resolution_clock::now();
std::vector<int64_t> FCNResnetOutputShape = { 1, 21, m_imageHeightInPixels, m_imageWidthInPixels };
ITensor FCNResnetOutput = TensorFloat::Create(FCNResnetOutputShape);
m_bindFCN.Bind(m_sessFCN.Model().InputFeatures().GetAt(0).Name(), intermediateTensor);
m_bindFCN.Bind(m_sessFCN.Model().OutputFeatures().GetAt(0).Name(), FCNResnetOutput, outputBindProperties);
m_sessFCN.EvaluateAsync(m_bindFCN, L"");
timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
OutputDebugString(L" | Model: ");
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
// 4. Postprocessing: extract labels from FCN scores and use to compose background-blurred image
now = std::chrono::high_resolution_clock::now();
ITensor rawLabels = TensorFloat::Create({1, 1, m_imageHeightInPixels, m_imageWidthInPixels});
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(false));
m_bindPostprocess.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(0).Name(), input2); // InputImage
m_bindPostprocess.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(1).Name(), FCNResnetOutput); // InputScores
m_bindPostprocess.Bind(m_sessPostprocess.Model().OutputFeatures().GetAt(0).Name(), output2); // TODO: DisableTensorCPUSync to false now?
// Retrieve final output
//m_sessPostprocess.EvaluateAsync(m_bindPostprocess, L"").get();
auto finalResults = m_sessPostprocess.EvaluateAsync(m_bindPostprocess, L"");
output2 = finalResults.get().Outputs().Lookup(L"OutputImage").try_as<VideoFrame>();
timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
OutputDebugString(L" | Post: ");
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
//timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
/*OutputDebugString(L" | Post: ");
OutputDebugString(std::to_wstring(timePassed.count()).c_str());*/
// Copy back to the correct surface for MFT
output2.CopyToAsync(output).get();
//output2.CopyToAsync(output).get();
OutputDebugString(L" | Ending run ]");
}
//winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> BindInputs()
//{
//
//}
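The three-stage pipeline above (z-score preprocess, FCN-ResNet, post-process) leans on the DisableTensorCpuSync bind property: with it set to true on an output, WinML leaves the tensor on the GPU, so the next session can consume it without a CPU readback and the intermediate EvaluateAsync calls can be fired without being awaited. A minimal sketch of the first two stages of that chain, reusing the shapes and bind names from the diff (QueueSegmentationEvals is an illustrative helper, not part of the commit):

#include <winrt/Windows.AI.MachineLearning.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Foundation.Collections.h>
#include <winrt/Windows.Media.h>
#include <vector>

using namespace winrt;
using namespace winrt::Windows::AI::MachineLearning;
using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;

void QueueSegmentationEvals(
    LearningModelSession const& sessPre, LearningModelBinding const& bindPre,
    LearningModelSession const& sessFCN, LearningModelBinding const& bindFCN,
    Windows::Media::VideoFrame const& input, int64_t h, int64_t w)
{
    PropertySet gpuOnly;
    gpuOnly.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));

    // Intermediate NCHW tensor handed from preprocess to the FCN; it is
    // never synchronized back to the CPU.
    std::vector<int64_t> shape = { 1, 3, h, w };
    TensorFloat intermediate = TensorFloat::Create(shape);
    bindPre.Bind(sessPre.Model().InputFeatures().GetAt(0).Name(), input);
    bindPre.Bind(sessPre.Model().OutputFeatures().GetAt(0).Name(), intermediate, gpuOnly);
    sessPre.EvaluateAsync(bindPre, L"pre");   // queued, not awaited

    // 21 class-score planes from FCN-ResNet, also kept GPU-side.
    std::vector<int64_t> scoreShape = { 1, 21, h, w };
    TensorFloat scores = TensorFloat::Create(scoreShape);
    bindFCN.Bind(sessFCN.Model().InputFeatures().GetAt(0).Name(), intermediate);
    bindFCN.Bind(sessFCN.Model().OutputFeatures().GetAt(0).Name(), scores, gpuOnly);
    sessFCN.EvaluateAsync(bindFCN, L"fcn");   // only the final post-process eval is awaited
}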
void SegmentModel::SubmitEval(VideoFrame input, VideoFrame output) {
auto currentBinding = bindings[0].get();
if (currentBinding->activetask == nullptr
@@ -244,22 +190,51 @@ void SegmentModel::SubmitEval(VideoFrame input, VideoFrame output) {
OutputDebugString(std::to_wstring(swapChainIndex).c_str());
OutputDebugString(L" | ");
// submit an eval and wait for it to finish submitting work
{
std::lock_guard<std::mutex> guard{ Processing };
currentBinding->binding.Bind(L"inputImage", input);
// 2. Preprocessing: z-score normalization
std::vector<int64_t> shape = { 1, 3, m_imageHeightInPixels, m_imageWidthInPixels };
ITensor intermediateTensor = TensorFloat::Create(shape);
hstring inputName = m_sessPreprocess.Model().InputFeatures().GetAt(0).Name();
hstring outputName = m_sessPreprocess.Model().OutputFeatures().GetAt(0).Name();
currentBinding->bind_pre.Bind(inputName, input);
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));
currentBinding->bind_pre.Bind(outputName, intermediateTensor, outputBindProperties);
m_sessPreprocess.EvaluateAsync(currentBinding->bind_pre, L"");
auto timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
OutputDebugString(L"Pre: ");
// 3. Run through actual model
std::vector<int64_t> FCNResnetOutputShape = { 1, 21, m_imageHeightInPixels, m_imageWidthInPixels };
ITensor FCNResnetOutput = TensorFloat::Create(FCNResnetOutputShape);
currentBinding->binding_model.Bind(m_sessFCN.Model().InputFeatures().GetAt(0).Name(), intermediateTensor);
currentBinding->binding_model.Bind(m_sessFCN.Model().OutputFeatures().GetAt(0).Name(), FCNResnetOutput, outputBindProperties);
m_sessFCN.EvaluateAsync(currentBinding->binding_model, L"");
OutputDebugString(L" | Model: ");
// 4. Postprocessing
ITensor rawLabels = TensorFloat::Create({ 1, 1, m_imageHeightInPixels, m_imageWidthInPixels });
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(false));
currentBinding->binding_post.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(0).Name(), input); // InputImage
currentBinding->binding_post.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(1).Name(), FCNResnetOutput); // InputScores
}
std::rotate(bindings.begin(), bindings.begin() + 1, bindings.end());
finishedFrameIndex = (finishedFrameIndex - 1 + swapChainEntryCount) % swapChainEntryCount;
currentBinding->activetask = m_sessStyleTransfer.EvaluateAsync(
currentBinding->binding,
// Wait only for the last evalasync
currentBinding->activetask = m_sessPostprocess.EvaluateAsync(
currentBinding->binding_post,
std::to_wstring(swapChainIndex).c_str());
currentBinding->activetask.Completed([&, currentBinding, now](auto&& asyncInfo, winrt::Windows::Foundation::AsyncStatus const) {
OutputDebugString(L"PF Eval completed |");
//auto results = asyncInfo.GetResults().Outputs().Lookup(L"OutputImage");
VideoFrame evalOutput = asyncInfo.GetResults()
.Outputs()
.Lookup(L"outputImage")
.try_as<VideoFrame>();
.Lookup(L"OutputImage")
.try_as<VideoFrame>(); // Must have a VF bound to output for winml to cast to VF
int bindingIdx;
bool finishedFrameUpdated;
{
@@ -298,8 +273,6 @@ void SegmentModel::SubmitEval(VideoFrame input, VideoFrame output) {
// return without waiting for the submit to finish, setup the completion handler
}
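SubmitEval's handling of the final stage is a submit-and-rotate pattern: only the post-process EvaluateAsync gets a Completed handler, which stashes the output VideoFrame in the entry's cache; the pool is then rotated so the next frame picks up the oldest entry, while finishedFrameIndex tracks which cached frame is safe to hand back to the MFT. A stripped-down sketch of that pattern (Entry, g_poolLock, and SubmitAndRotate are illustrative names):

#include <winrt/Windows.AI.MachineLearning.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Media.h>
#include <algorithm>
#include <memory>
#include <mutex>
#include <vector>

using namespace winrt;
using namespace winrt::Windows::AI::MachineLearning;
using winrt::Windows::Media::VideoFrame;

struct Entry {
    LearningModelBinding binding{ nullptr };
    Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> activetask{ nullptr };
    VideoFrame outputCache{ nullptr };
};

std::mutex g_poolLock;

void SubmitAndRotate(LearningModelSession const& sessPost,
                     std::vector<std::unique_ptr<Entry>>& pool)
{
    Entry* current = pool.front().get();
    current->activetask = sessPost.EvaluateAsync(current->binding, L"");
    current->activetask.Completed(
        [current](auto&& asyncInfo, Windows::Foundation::AsyncStatus) {
            // WinML returns whatever object was bound to the output, so a
            // VideoFrame must be bound for this cast to succeed.
            current->outputCache = asyncInfo.GetResults()
                                            .Outputs()
                                            .Lookup(L"OutputImage")
                                            .try_as<VideoFrame>();
        });
    // Move the just-submitted entry to the back; the next call reuses the
    // oldest entry, whose previous eval has most likely completed by then.
    std::lock_guard<std::mutex> guard{ g_poolLock };
    std::rotate(pool.begin(), pool.begin() + 1, pool.end());
}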
void SegmentModel::RunStyleTransfer(IDirect3DSurface src, IDirect3DSurface dest)
{
OutputDebugString(L"\n[Starting RunStyleTransfer | ");

View file

@@ -21,13 +21,17 @@ using namespace winrt::Windows::Media;
// Threading fields for style transfer
struct SwapChainEntry {
LearningModelBinding binding; // Just one for style transfer, for now
LearningModelBinding bind_pre;
LearningModelBinding binding_model;
LearningModelBinding binding_post;
winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> activetask;
VideoFrame outputCache;
SwapChainEntry() :
binding(nullptr),
bind_pre(nullptr),
binding_model(nullptr),
binding_post(nullptr),
activetask(nullptr),
outputCache(VideoFrame(winrt::Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, 720, 720)) {}
outputCache(NULL) {}
};
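With outputCache now default-initialized to null, each entry's cache is created later in SetModels as a Direct3D11-surface-backed frame at the actual stream dimensions, rather than the old fixed 720x720 CPU frame. Roughly, assuming the width and height come from the negotiated media type:

#include <winrt/Windows.Graphics.DirectX.h>
#include <winrt/Windows.Media.h>

using winrt::Windows::Graphics::DirectX::DirectXPixelFormat;
using winrt::Windows::Media::VideoFrame;

// Surface-backed frames keep the composited output on the GPU, so copying
// it into the MFT's output sample avoids a CPU round trip.
VideoFrame MakeOutputCache(int32_t width, int32_t height)
{
    return VideoFrame::CreateAsDirect3D11SurfaceBacked(
        DirectXPixelFormat::B8G8R8X8UIntNormalized, width, height);
}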
@@ -76,8 +80,6 @@ private:
LearningModelBinding m_bindPostprocess;
LearningModelBinding m_bindStyleTransfer;
// Threaded style transfer fields
void SubmitEval(VideoFrame, VideoFrame);
winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> evalStatus;
@@ -116,7 +118,9 @@ public:
void SetDevice() {
m_device = m_session.Device().Direct3D11Device();
}
protected:
//virtual winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> BindInputs(VideoFrame input) = 0;
void SetVideoFrames(VideoFrame inVideoFrame, VideoFrame outVideoFrame)
{
@@ -153,6 +157,15 @@ protected:
return session;
}
//// Threaded style transfer fields
//void SubmitEval(VideoFrame, VideoFrame);
//winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> evalStatus;
//std::vector <std::unique_ptr<SwapChainEntry>> bindings;
//int swapChainIndex = 0;
//int swapChainEntryCount = 5;
//int finishedFrameIndex = 0;
bool m_bUseGPU = true;
bool m_bVideoFramesSet = false;
VideoFrame m_inputVideoFrame,

View file

@@ -1430,7 +1430,7 @@ HRESULT TransformBlur::OnProcessOutput(IMFSample** ppOut)
{
// Do the copies inside runtest
auto now = std::chrono::high_resolution_clock::now();
m_segmentModel->Run(src, dest);
m_segmentModel.Run(src, dest);
auto timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
}
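The chrono bracketing around Run() is the same pattern as the per-stage timings earlier in the diff. A small generic version of it follows (TimeMs is a hypothetical helper, not in the commit); note that because the earlier stages are submitted with EvaluateAsync, such wall-clock numbers mostly measure submission cost, and only an awaited eval reflects GPU completion:

#include <windows.h>
#include <chrono>
#include <string>

// Run the given work and report elapsed wall-clock milliseconds to the
// debugger, mirroring the OutputDebugString timing in the diff.
template <typename Fn>
long long TimeMs(wchar_t const* label, Fn&& work)
{
    auto start = std::chrono::high_resolution_clock::now();
    work();
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::high_resolution_clock::now() - start).count();
    OutputDebugStringW(label);
    OutputDebugStringW(std::to_wstring(ms).c_str());
    return ms;
}

// Usage: TimeMs(L"Run: ", [&] { m_segmentModel.Run(src, dest); });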
@@ -1497,8 +1497,8 @@ HRESULT TransformBlur::UpdateFormatInfo()
CHECK_HR(hr = GetImageSize(m_videoFOURCC, m_imageWidthInPixels, m_imageHeightInPixels, &m_cbImageSize));
// Set the size of the SegmentModel
// m_segmentModel.SetModels(m_imageWidthInPixels, m_imageHeightInPixels);
m_segmentModel = new StyleTransfer(m_imageWidthInPixels, m_imageHeightInPixels);
m_segmentModel.SetModels(m_imageWidthInPixels, m_imageHeightInPixels);
m_streamModel = std::make_unique<StyleTransfer>(m_imageWidthInPixels, m_imageHeightInPixels);
}
done:

View file

@@ -227,5 +227,7 @@ private:
winrt::com_ptr<IMFVideoSampleAllocatorEx> m_spOutputSampleAllocator;
// Model fields
IStreamModel* m_segmentModel;
SegmentModel m_segmentModel;
std::unique_ptr<IStreamModel> m_streamModel;
};
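The member change above swaps the raw IStreamModel* for a by-value SegmentModel plus a std::unique_ptr<IStreamModel>, so the style-transfer model's lifetime is tied to the MFT and re-creating it on a format change releases the previous instance (and its sessions) automatically. A reduced sketch of that ownership shape, with IStreamModel trimmed to one method for illustration:

#include <memory>

struct IStreamModel {
    virtual ~IStreamModel() = default;
    virtual void SetModels(unsigned width, unsigned height) = 0;
};

struct StyleTransfer : IStreamModel {
    StyleTransfer(unsigned w, unsigned h) { SetModels(w, h); }
    void SetModels(unsigned w, unsigned h) override
    {
        m_width = w; m_height = h;  // a real model would rebuild its sessions here
    }
    unsigned m_width = 0, m_height = 0;
};

class TransformBlur {
    std::unique_ptr<IStreamModel> m_streamModel;
public:
    void UpdateFormatInfo(unsigned w, unsigned h)
    {
        // make_unique builds the new model; assigning to the unique_ptr
        // destroys the old one, so no raw-pointer delete is needed.
        m_streamModel = std::make_unique<StyleTransfer>(w, h);
    }
};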