First pass at threading background blur
This commit is contained in:
Родитель
8c662b6f30
Коммит
69dbb27c2b
|
@ -73,6 +73,11 @@ LearningModel StyleTransfer::GetModel()
|
|||
rel.append("Assets\\mosaic.onnx");
|
||||
return LearningModel::LoadFromFilePath(rel + L"");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/******* Start of old Segment Model stuff *******/
|
||||
SegmentModel::SegmentModel() :
|
||||
m_sess(NULL),
|
||||
m_sessPreprocess(NULL),
|
||||
|
@ -101,30 +106,7 @@ SegmentModel::SegmentModel(UINT32 w, UINT32 h) :
|
|||
m_bindStyleTransfer(NULL),
|
||||
bindings(swapChainEntryCount)
|
||||
{
|
||||
|
||||
SetImageSize(w, h);
|
||||
m_sess = CreateLearningModelSession(Invert(1, 3, h, w));
|
||||
m_sessStyleTransfer = CreateLearningModelSession(StyleTransfer());
|
||||
m_bindStyleTransfer = LearningModelBinding(m_sessStyleTransfer);
|
||||
|
||||
// Initialize segmentation learningmodelsessions
|
||||
m_sessPreprocess = CreateLearningModelSession(Normalize0_1ThenZScore(h, w, 3, mean, stddev));
|
||||
m_sessFCN = CreateLearningModelSession(FCNResnet());
|
||||
m_sessPostprocess = CreateLearningModelSession(PostProcess(1, 3, h, w, 1));
|
||||
|
||||
// Initialize segmentation bindings
|
||||
m_bindPreprocess = LearningModelBinding(m_sessPreprocess);
|
||||
m_bindFCN = LearningModelBinding(m_sessFCN);
|
||||
m_bindPostprocess = LearningModelBinding(m_sessPostprocess);
|
||||
|
||||
// Create set of bindings to cycle through
|
||||
for (int i = 0; i < swapChainEntryCount; i++) {
|
||||
bindings.push_back(std::make_unique<SwapChainEntry>());
|
||||
bindings[i]->binding = LearningModelBinding(m_sessStyleTransfer);
|
||||
bindings[i]->binding.Bind(L"outputImage",
|
||||
VideoFrame(Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, 720, 720));
|
||||
}
|
||||
|
||||
SetModels(w, h);
|
||||
}
|
||||
|
||||
void SegmentModel::SetModels(UINT32 w, UINT32 h) {
|
||||
|
@ -147,12 +129,17 @@ void SegmentModel::SetModels(UINT32 w, UINT32 h) {
|
|||
m_bindFCN = LearningModelBinding(m_sessFCN);
|
||||
m_bindPostprocess = LearningModelBinding(m_sessPostprocess);
|
||||
|
||||
auto device = m_sessFCN.Device().Direct3D11Device();
|
||||
|
||||
// Create set of bindings to cycle through
|
||||
for (int i = 0; i < swapChainEntryCount; i++) {
|
||||
bindings.push_back(std::make_unique<SwapChainEntry>());
|
||||
bindings[i]->binding = LearningModelBinding(m_sessStyleTransfer);
|
||||
bindings[i]->binding.Bind(L"outputImage",
|
||||
VideoFrame(Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, 720, 720));
|
||||
bindings[i]->binding_model = LearningModelBinding(m_sessFCN);
|
||||
bindings[i]->binding_post = LearningModelBinding(m_sessPostprocess);
|
||||
bindings[i]->bind_pre = LearningModelBinding(m_sessPreprocess);
|
||||
bindings[i]->binding_post.Bind(L"OutputImage",
|
||||
VideoFrame::CreateAsDirect3D11SurfaceBacked(Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized, m_imageWidthInPixels , m_imageHeightInPixels ));
|
||||
bindings[i]->outputCache = VideoFrame::CreateAsDirect3D11SurfaceBacked(Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized, m_imageWidthInPixels , m_imageHeightInPixels );
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,7 +156,6 @@ void SegmentModel::Run(IDirect3DSurface src, IDirect3DSurface dest)
|
|||
VideoFrame input = VideoFrame::CreateWithDirect3D11Surface(src);
|
||||
VideoFrame output = VideoFrame::CreateWithDirect3D11Surface(dest);
|
||||
|
||||
|
||||
auto device = m_sessFCN.Device().Direct3D11Device();
|
||||
auto desc = input.Direct3DSurface().Description();
|
||||
auto descOut = output.Direct3DSurface().Description();
|
||||
|
@ -180,60 +166,20 @@ void SegmentModel::Run(IDirect3DSurface src, IDirect3DSurface dest)
|
|||
output.CopyToAsync(output2).get();
|
||||
std::vector<int64_t> shape = { 1, 3, m_imageHeightInPixels, m_imageWidthInPixels };
|
||||
|
||||
// 2. Preprocessing: z-score normalization
|
||||
auto now = std::chrono::high_resolution_clock::now();
|
||||
ITensor intermediateTensor = TensorFloat::Create(shape);
|
||||
hstring inputName = m_sessPreprocess.Model().InputFeatures().GetAt(0).Name();
|
||||
hstring outputName = m_sessPreprocess.Model().OutputFeatures().GetAt(0).Name();
|
||||
SubmitEval(input2, output);
|
||||
swapChainIndex = (++swapChainIndex) % swapChainEntryCount;
|
||||
|
||||
m_bindPreprocess.Bind(inputName, input2);
|
||||
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));
|
||||
m_bindPreprocess.Bind(outputName, intermediateTensor, outputBindProperties);
|
||||
m_sessPreprocess.EvaluateAsync(m_bindPreprocess, L"");
|
||||
auto timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
|
||||
OutputDebugString(L"Pre: ");
|
||||
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
|
||||
|
||||
// 3. Run through actual model
|
||||
now = std::chrono::high_resolution_clock::now();
|
||||
std::vector<int64_t> FCNResnetOutputShape = { 1, 21, m_imageHeightInPixels, m_imageWidthInPixels };
|
||||
ITensor FCNResnetOutput = TensorFloat::Create(FCNResnetOutputShape);
|
||||
|
||||
m_bindFCN.Bind(m_sessFCN.Model().InputFeatures().GetAt(0).Name(), intermediateTensor);
|
||||
m_bindFCN.Bind(m_sessFCN.Model().OutputFeatures().GetAt(0).Name(), FCNResnetOutput, outputBindProperties);
|
||||
m_sessFCN.EvaluateAsync(m_bindFCN, L"");
|
||||
timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
|
||||
OutputDebugString(L" | Model: ");
|
||||
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
|
||||
|
||||
// 4.Postprocessing: extract labels from FCN scores and use to compose background-blurred image
|
||||
now = std::chrono::high_resolution_clock::now();
|
||||
ITensor rawLabels = TensorFloat::Create({1, 1, m_imageHeightInPixels, m_imageWidthInPixels});
|
||||
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(false));
|
||||
m_bindPostprocess.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(0).Name(), input2); // InputImage
|
||||
m_bindPostprocess.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(1).Name(), FCNResnetOutput); // InputScores
|
||||
m_bindPostprocess.Bind(m_sessPostprocess.Model().OutputFeatures().GetAt(0).Name(), output2); // TODO: DisableTensorCPUSync to false now?
|
||||
// Retrieve final output
|
||||
//m_sessPostprocess.EvaluateAsync(m_bindPostprocess, L"").get();
|
||||
auto finalResults = m_sessPostprocess.EvaluateAsync(m_bindPostprocess, L"");
|
||||
output2 = finalResults.get().Outputs().Lookup(L"OutputImage").try_as<VideoFrame>();
|
||||
|
||||
timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
|
||||
OutputDebugString(L" | Post: ");
|
||||
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
|
||||
//timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
|
||||
/*OutputDebugString(L" | Post: ");
|
||||
OutputDebugString(std::to_wstring(timePassed.count()).c_str());*/
|
||||
|
||||
// Copy back to the correct surface for MFT
|
||||
output2.CopyToAsync(output).get();
|
||||
//output2.CopyToAsync(output).get();
|
||||
|
||||
OutputDebugString(L" | Ending run ]");
|
||||
|
||||
}
|
||||
|
||||
//winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> BindInputs()
|
||||
//{
|
||||
//
|
||||
//}
|
||||
|
||||
void SegmentModel::SubmitEval(VideoFrame input, VideoFrame output) {
|
||||
auto currentBinding = bindings[0].get();
|
||||
if (currentBinding->activetask == nullptr
|
||||
|
@ -244,22 +190,51 @@ void SegmentModel::SubmitEval(VideoFrame input, VideoFrame output) {
|
|||
OutputDebugString(std::to_wstring(swapChainIndex).c_str());
|
||||
OutputDebugString(L" | ");
|
||||
// submit an eval and wait for it to finish submitting work
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard{ Processing };
|
||||
currentBinding->binding.Bind(L"inputImage", input);
|
||||
|
||||
// 2. Preprocessing: z-score normalization
|
||||
std::vector<int64_t> shape = { 1, 3, m_imageHeightInPixels, m_imageWidthInPixels };
|
||||
ITensor intermediateTensor = TensorFloat::Create(shape);
|
||||
hstring inputName = m_sessPreprocess.Model().InputFeatures().GetAt(0).Name();
|
||||
hstring outputName = m_sessPreprocess.Model().OutputFeatures().GetAt(0).Name();
|
||||
|
||||
currentBinding->bind_pre.Bind(inputName, input);
|
||||
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));
|
||||
currentBinding->bind_pre.Bind(outputName, intermediateTensor, outputBindProperties);
|
||||
m_sessPreprocess.EvaluateAsync(currentBinding->bind_pre, L"");
|
||||
auto timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
|
||||
OutputDebugString(L"Pre: ");
|
||||
|
||||
// 3. Run through actual model
|
||||
std::vector<int64_t> FCNResnetOutputShape = { 1, 21, m_imageHeightInPixels, m_imageWidthInPixels };
|
||||
ITensor FCNResnetOutput = TensorFloat::Create(FCNResnetOutputShape);
|
||||
|
||||
currentBinding->binding_model.Bind(m_sessFCN.Model().InputFeatures().GetAt(0).Name(), intermediateTensor);
|
||||
currentBinding->binding_model.Bind(m_sessFCN.Model().OutputFeatures().GetAt(0).Name(), FCNResnetOutput, outputBindProperties);
|
||||
m_sessFCN.EvaluateAsync(currentBinding->binding_model, L"");
|
||||
OutputDebugString(L" | Model: ");
|
||||
|
||||
// 4. Postprocessing
|
||||
ITensor rawLabels = TensorFloat::Create({ 1, 1, m_imageHeightInPixels, m_imageWidthInPixels });
|
||||
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(false));
|
||||
currentBinding->binding_post.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(0).Name(), input); // InputImage
|
||||
currentBinding->binding_post.Bind(m_sessPostprocess.Model().InputFeatures().GetAt(1).Name(), FCNResnetOutput); // InputScores
|
||||
|
||||
}
|
||||
std::rotate(bindings.begin(), bindings.begin() + 1, bindings.end());
|
||||
finishedFrameIndex = (finishedFrameIndex - 1 + swapChainEntryCount) % swapChainEntryCount;
|
||||
currentBinding->activetask = m_sessStyleTransfer.EvaluateAsync(
|
||||
currentBinding->binding,
|
||||
// Wait only for the last evalasync
|
||||
currentBinding->activetask = m_sessPostprocess.EvaluateAsync(
|
||||
currentBinding->binding_post,
|
||||
std::to_wstring(swapChainIndex).c_str());
|
||||
currentBinding->activetask.Completed([&, currentBinding, now](auto&& asyncInfo, winrt::Windows::Foundation::AsyncStatus const) {
|
||||
OutputDebugString(L"PF Eval completed |");
|
||||
//auto results = asyncInfo.GetResults().Outputs().Lookup(L"OutputImage");
|
||||
VideoFrame evalOutput = asyncInfo.GetResults()
|
||||
.Outputs()
|
||||
.Lookup(L"outputImage")
|
||||
.try_as<VideoFrame>();
|
||||
.Lookup(L"OutputImage")
|
||||
.try_as<VideoFrame>(); // Must have a VF bound to output for winml to cast to VF
|
||||
int bindingIdx;
|
||||
bool finishedFrameUpdated;
|
||||
{
|
||||
|
@ -298,8 +273,6 @@ void SegmentModel::SubmitEval(VideoFrame input, VideoFrame output) {
|
|||
// return without waiting for the submit to finish, setup the completion handler
|
||||
}
|
||||
|
||||
|
||||
|
||||
void SegmentModel::RunStyleTransfer(IDirect3DSurface src, IDirect3DSurface dest)
|
||||
{
|
||||
OutputDebugString(L"\n[Starting RunStyleTransfer | ");
|
||||
|
|
|
@ -21,13 +21,17 @@ using namespace winrt::Windows::Media;
|
|||
|
||||
// Threading fields for style transfer
|
||||
struct SwapChainEntry {
|
||||
LearningModelBinding binding; // Just one for style transfer, for now
|
||||
LearningModelBinding bind_pre;
|
||||
LearningModelBinding binding_model;
|
||||
LearningModelBinding binding_post;
|
||||
winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> activetask;
|
||||
VideoFrame outputCache;
|
||||
SwapChainEntry() :
|
||||
binding(nullptr),
|
||||
bind_pre(nullptr),
|
||||
binding_model(nullptr),
|
||||
binding_post(nullptr),
|
||||
activetask(nullptr),
|
||||
outputCache(VideoFrame(winrt::Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8, 720, 720)) {}
|
||||
outputCache(NULL) {}
|
||||
};
|
||||
|
||||
|
||||
|
@ -76,8 +80,6 @@ private:
|
|||
LearningModelBinding m_bindPostprocess;
|
||||
LearningModelBinding m_bindStyleTransfer;
|
||||
|
||||
|
||||
|
||||
// Threaded style transfer fields
|
||||
void SubmitEval(VideoFrame, VideoFrame);
|
||||
winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> evalStatus;
|
||||
|
@ -116,7 +118,9 @@ public:
|
|||
void SetDevice() {
|
||||
m_device = m_session.Device().Direct3D11Device();
|
||||
}
|
||||
|
||||
protected:
|
||||
//virtual winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> BindInputs(VideoFrame input) = 0;
|
||||
|
||||
void SetVideoFrames(VideoFrame inVideoFrame, VideoFrame outVideoFrame)
|
||||
{
|
||||
|
@ -153,6 +157,15 @@ protected:
|
|||
return session;
|
||||
}
|
||||
|
||||
|
||||
//// Threaded style transfer fields
|
||||
//void SubmitEval(VideoFrame, VideoFrame);
|
||||
//winrt::Windows::Foundation::IAsyncOperation<LearningModelEvaluationResult> evalStatus;
|
||||
//std::vector <std::unique_ptr<SwapChainEntry>> bindings;
|
||||
//int swapChainIndex = 0;
|
||||
//int swapChainEntryCount = 5;
|
||||
//int finishedFrameIndex = 0;
|
||||
|
||||
bool m_bUseGPU = true;
|
||||
bool m_bVideoFramesSet = false;
|
||||
VideoFrame m_inputVideoFrame,
|
||||
|
|
|
@ -1430,7 +1430,7 @@ HRESULT TransformBlur::OnProcessOutput(IMFSample** ppOut)
|
|||
{
|
||||
// Do the copies inside runtest
|
||||
auto now = std::chrono::high_resolution_clock::now();
|
||||
m_segmentModel->Run(src, dest);
|
||||
m_segmentModel.Run(src, dest);
|
||||
auto timePassed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - now);
|
||||
OutputDebugString(std::to_wstring(timePassed.count()).c_str());
|
||||
}
|
||||
|
@ -1497,8 +1497,8 @@ HRESULT TransformBlur::UpdateFormatInfo()
|
|||
CHECK_HR(hr = GetImageSize(m_videoFOURCC, m_imageWidthInPixels, m_imageHeightInPixels, &m_cbImageSize));
|
||||
|
||||
// Set the size of the SegmentModel
|
||||
// m_segmentModel.SetModels(m_imageWidthInPixels, m_imageHeightInPixels);
|
||||
m_segmentModel = new StyleTransfer(m_imageWidthInPixels, m_imageHeightInPixels);
|
||||
m_segmentModel.SetModels(m_imageWidthInPixels, m_imageHeightInPixels);
|
||||
m_streamModel = std::make_unique<StyleTransfer>(m_imageWidthInPixels, m_imageHeightInPixels);
|
||||
}
|
||||
|
||||
done:
|
||||
|
|
|
@ -227,5 +227,7 @@ private:
|
|||
winrt::com_ptr<IMFVideoSampleAllocatorEx> m_spOutputSampleAllocator;
|
||||
|
||||
// Model fields
|
||||
IStreamModel* m_segmentModel;
|
||||
SegmentModel m_segmentModel;
|
||||
std::unique_ptr<IStreamModel> m_streamModel;
|
||||
|
||||
};
|
Загрузка…
Ссылка в новой задаче