diff --git a/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.cpp b/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.cpp index 85da2ce6..6a6948ba 100644 --- a/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.cpp +++ b/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.cpp @@ -38,7 +38,7 @@ enum OnnxDataType : long { } OnnxDataType; -const int32_t opset = 12; +const int32_t opset = 13; /**** Style transfer model ****/ @@ -83,12 +83,52 @@ void WindowsHello::InitializeSession(int w, int h) { SetImageSize(w, h); // Preprocessing: grayscale model + m_sessionPreprocess = CreateLearningModelSession(GrayScale(1, 3, m_imageHeightInPixels, m_imageWidthInPixels)); + m_bindingPreprocess = LearningModelBinding(m_sessionPreprocess); m_session = CreateLearningModelSession(GetModel()); m_binding = LearningModelBinding(m_session); } +void WindowsHello::Run(IDirect3DSurface src, IDirect3DSurface dest) +{ + VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src); + VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest); + SetVideoFrames(inVideoFrame, outVideoFrame); + auto outputBindProperties = PropertySet(); + + // Shape validation + assert((UINT32)m_inputVideoFrame.Direct3DSurface().Description().Height == m_imageHeightInPixels); + assert((UINT32)m_inputVideoFrame.Direct3DSurface().Description().Width == m_imageWidthInPixels); + + // Preprocessing: grayscale + std::vector shape = { 1, 1, m_imageHeightInPixels, m_imageWidthInPixels }; + ITensor intermediateTensor = TensorFloat::Create(shape); + hstring inputName = m_sessionPreprocess.Model().InputFeatures().GetAt(0).Name(); + hstring outputName = m_sessionPreprocess.Model().OutputFeatures().GetAt(0).Name(); + + m_bindingPreprocess.Bind(inputName, m_inputVideoFrame); + outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true)); + m_bindingPreprocess.Bind(outputName, intermediateTensor, outputBindProperties); + m_sessionPreprocess.Evaluate(m_bindingPreprocess, L""); + + // Run through actual model + std::vector helloOutputShape = {1, -1, 4}; // TODO: How many bboxes? + ITensor helloOutput = TensorFloat::Create(helloOutputShape); + m_binding.Bind(m_session.Model().InputFeatures().GetAt(0).Name(), intermediateTensor); + m_binding.Bind(m_session.Model().OutputFeatures().GetAt(0).Name(), helloOutput, outputBindProperties); + m_session.Evaluate(m_binding, L""); + + //m_.CopyToAsync(outVideoFrame); + // Draw bboxes on videoframe?? For now just print them out + auto bbox = helloOutput.as().GetAsVectorView(); + m_inputVideoFrame.CopyToAsync(outVideoFrame).get(); + OutputDebugString(L"BOUNDING BOX: "); + OutputDebugString(std::to_wstring(bbox.GetAt(0)).c_str()); + OutputDebugString(L"\n"); +} + LearningModel WindowsHello::GetModel() { auto modelPath = std::filesystem::path(m_modelBasePath.c_str()); modelPath.append("retina_rgb.onnx"); @@ -226,6 +266,21 @@ LearningModel Invert(long n, long c, long h, long w) return builder.CreateModel(); } +LearningModel GrayScale(long n, long c, long h, long w) +{ + // Go from [n, c, h, w] -> [1, 1, h, w] + // Start with simple reduction of mean across channels + auto builder = LearningModelBuilder::Create(opset) + .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, { 1, c, h, w })) + .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, { 1, 1, h, w })) + .Operators().Add(LearningModelOperator(L"ReduceMean") + .SetInput(L"data", L"Input") + .SetConstant(L"axes", TensorFloat::CreateFromIterable({ 1 }, { 1.f })) + .SetOutput(L"reduced", L"Output") + ); + return builder.CreateModel(); +} + LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array& means, const std::array& stddev) { assert(means.size() == c); diff --git a/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.h b/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.h index ca72cc89..f3ee3e31 100644 --- a/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.h +++ b/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/SegmentModel.h @@ -110,7 +110,7 @@ protected: class StyleTransfer : public StreamModelBase { public: - StyleTransfer() : StreamModelBase() {}; + StyleTransfer() : StreamModelBase(){}; void InitializeSession(int w, int h); void Run(IDirect3DSurface src, IDirect3DSurface dest); private: @@ -119,12 +119,16 @@ private: class WindowsHello : public StreamModelBase { public: - WindowsHello() : StreamModelBase() {}; + WindowsHello() : + StreamModelBase(), + m_sessionPreprocess(nullptr), + m_bindingPreprocess(nullptr) {}; void InitializeSession(int w, int h); void Run(IDirect3DSurface src, IDirect3DSurface dest); private: LearningModel GetModel(); - + LearningModelSession m_sessionPreprocess; + LearningModelBinding m_bindingPreprocess; }; class BackgroundBlur : public StreamModelBase