address comments in segmentmodel

Linnea May 2022-04-15 13:32:34 -07:00
Parent bd70309d90
Commit c165e0bdb0
5 changed files with 44 additions and 66 deletions

Samples/BackgroundBlur/.gitignore (vendored)
View File

@@ -1 +0,0 @@
-*.onnx

View File

@@ -38,11 +38,10 @@ enum OnnxDataType : long {
 }OnnxDataType;
-int g_scale = 5;
-auto outputBindProperties = PropertySet();
+const int32_t opset = 12;
 /**** Style transfer model ****/
-void StyleTransfer::SetModels(int w, int h)
+void StyleTransfer::InitializeSession(int w, int h)
 {
 // TODO: Use w/h or use the 720x720 of the model
 SetImageSize(720, 720); // Model input size is fixed to 720x720
@@ -53,7 +52,6 @@ void StyleTransfer::Run(IDirect3DSurface src, IDirect3DSurface dest)
 {
 m_bSyncStarted = TRUE;
-assert(m_session.Device().AdapterId() == m_highPerfAdapter);
 VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src);
 VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest);
 SetVideoFrames(inVideoFrame, outVideoFrame);
@@ -72,11 +70,12 @@ void StyleTransfer::Run(IDirect3DSurface src, IDirect3DSurface dest)
 m_bSyncStarted = FALSE;
 }
 LearningModel StyleTransfer::GetModel()
 {
-auto rel = std::filesystem::current_path();
-rel.append("Assets\\mosaic.onnx");
-return LearningModel::LoadFromFilePath(rel + L"");
+auto model_path = std::filesystem::current_path();
+model_path.append("Assets\\mosaic.onnx");
+return LearningModel::LoadFromFilePath(model_path.c_str());
 }
@@ -85,9 +84,9 @@ BackgroundBlur::~BackgroundBlur() {
 if (m_session) m_session.Close();
 }
-void BackgroundBlur::SetModels(int w, int h)
+void BackgroundBlur::InitializeSession(int w, int h)
 {
-w /= g_scale; h /= g_scale;
+w /= m_scale; h /= m_scale;
 SetImageSize(w, h);
 auto joinOptions1 = LearningModelJoinOptions();
@@ -103,7 +102,7 @@ void BackgroundBlur::SetModels(int w, int h)
 joinOptions2.Link(L"FCN_out", L"InputScores");
 joinOptions2.Link(L"OutputImageForward", L"InputImage");
 joinOptions2.JoinedNodePrefix(L"Post_");
-//joinOptions2.PromoteUnlinkedOutputsToFusedOutputs(false); // Causes winrt originate error in FusedGraphKernel.cpp, but works on CPU
+//joinOptions2.PromoteUnlinkedOutputsToFusedOutputs(false); // TODO: Causes winrt originate error in FusedGraphKernel.cpp, but works on CPU
 auto modelExperimental2 = LearningModelExperimental(intermediateModel);
 LearningModel modelFused = modelExperimental2.JoinModel(PostProcess(1, 3, h, w, 1), joinOptions2);
@@ -115,16 +114,15 @@ void BackgroundBlur::SetModels(int w, int h)
 }
 LearningModel BackgroundBlur::GetModel()
 {
-auto rel = std::filesystem::current_path();
-rel.append("Assets\\fcn-resnet50-12.onnx");
-return LearningModel::LoadFromFilePath(rel + L"");
+auto model_path = std::filesystem::current_path();
+model_path.append("Assets\\fcn-resnet50-12.onnx");
+return LearningModel::LoadFromFilePath(model_path.c_str());
 }
 void BackgroundBlur::Run(IDirect3DSurface src, IDirect3DSurface dest)
 {
 m_bSyncStarted = TRUE;
-// Device validation
-assert(m_session.Device().AdapterId() == m_highPerfAdapter);
 VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src);
 VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest);
 SetVideoFrames(inVideoFrame, outVideoFrame);
@@ -145,7 +143,7 @@ void BackgroundBlur::Run(IDirect3DSurface src, IDirect3DSurface dest)
 LearningModel BackgroundBlur::PostProcess(long n, long c, long h, long w, long axis)
 {
-auto builder = LearningModelBuilder::Create(12)
+auto builder = LearningModelBuilder::Create(opset)
 .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputImage", TensorKind::Float, { n, c, h, w }))
 .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputScores", TensorKind::Float, { -1, -1, h, w })) // Different input type?
 .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"OutputImage", TensorKind::Float, { n, c, h, w }))
@@ -204,8 +202,7 @@ LearningModel BackgroundBlur::PostProcess(long n, long c, long h, long w, long a
 LearningModel Invert(long n, long c, long h, long w)
 {
-auto builder = LearningModelBuilder::Create(11)
+auto builder = LearningModelBuilder::Create(opset)
 // Loading in buffers and reshape
 .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, { n, c, h, w }))
 .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, { n, c, h, w }))
@@ -230,7 +227,7 @@ LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array<fl
 assert(means.size() == c);
 assert(stddev.size() == c);
-auto builder = LearningModelBuilder::Create(12)
+auto builder = LearningModelBuilder::Create(opset)
 .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", L"The NCHW image", TensorKind::Float, {1, c, h, w}))
 .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", L"The NCHW image normalized with mean and stddev.", TensorKind::Float, {1, c, h, w}))
 .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"OutputImageForward", L"The NCHW image forwarded through the model.", TensorKind::Float, {1, c, h, w}))
@@ -262,7 +259,7 @@ LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array<fl
 LearningModel ReshapeFlatBufferToNCHW(long n, long c, long h, long w)
 {
-auto builder = LearningModelBuilder::Create(11)
+auto builder = LearningModelBuilder::Create(opset)
 // Loading in buffers and reshape
 .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::UInt8, { 1, n * c * h * w }))
 .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, {n, c, h, w}))
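
Taken together, the SegmentModel.cpp hunks settle on two patterns: a single shared opset constant for every LearningModelBuilder::Create call, and model paths built with std::filesystem::path plus c_str() rather than the old rel + L"" trick, which leaned on path's implicit conversion to std::wstring. A minimal standalone sketch of the loading pattern (the helper name and the Microsoft.AI.MachineLearning header/namespace choice are assumptions for illustration, not code from this commit):

```cpp
#include <cstdint>
#include <filesystem>
#include <winrt/Microsoft.AI.MachineLearning.h>

using namespace winrt::Microsoft::AI::MachineLearning;

const int32_t opset = 12; // One opset shared by every builder-generated model.

// Hypothetical helper mirroring StyleTransfer::GetModel/BackgroundBlur::GetModel:
// build the asset path with std::filesystem and pass the wide C string that
// path::c_str() yields on Windows straight to LoadFromFilePath.
LearningModel LoadModelFromAssets(const char* fileName)
{
    auto model_path = std::filesystem::current_path();
    model_path.append("Assets");
    model_path.append(fileName);
    return LearningModel::LoadFromFilePath(model_path.c_str());
}
```

Called as LoadModelFromAssets("mosaic.onnx") or LoadModelFromAssets("fcn-resnet50-12.onnx"), such a helper would fold the two nearly identical GetModel bodies into one.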

View File

@@ -29,47 +29,37 @@ LearningModel Normalize0_1ThenZScore(long height, long width, long channels, con
 LearningModel ReshapeFlatBufferToNCHW(long n, long c, long h, long w);
 LearningModel Invert(long n, long c, long h, long w);
-class IStreamModel
+class StreamModelBase
 {
 public:
-IStreamModel():
+StreamModelBase():
 m_inputVideoFrame(NULL),
 m_outputVideoFrame(NULL),
 m_session(NULL),
 m_binding(NULL),
 m_bSyncStarted(FALSE)
 {}
-IStreamModel(int w, int h) :
+StreamModelBase(int w, int h) :
 m_inputVideoFrame(NULL),
 m_outputVideoFrame(NULL),
 m_session(NULL),
 m_binding(NULL),
 m_bSyncStarted(FALSE)
 {}
-~IStreamModel() {
+virtual ~StreamModelBase() {
 if(m_session) m_session.Close();
 if(m_binding) m_binding.Clear();
 if (m_inputVideoFrame) m_inputVideoFrame.Close();
 if (m_outputVideoFrame) m_outputVideoFrame.Close();
-if (m_device) m_device.Close();
 };
-virtual void SetModels(int w, int h) =0;
+virtual void InitializeSession(int w, int h) =0;
 virtual void Run(IDirect3DSurface src, IDirect3DSurface dest) =0;
-void SetUseGPU(bool use) {
-m_bUseGPU = use;
-}
-void SetDevice() {
-assert(m_session.Device().AdapterId() == m_highPerfAdapter);
-assert(m_session.Device().Direct3D11Device() != NULL);
-m_device = m_session.Device().Direct3D11Device();
-auto device = m_session.Device().AdapterId();
-}
 // Synchronous eval status
 BOOL m_bSyncStarted;
 VideoFrame m_outputVideoFrame;
+static const int m_scale = 5;
 protected:
// Cache input frames into a shareable d3d-backed VideoFrame
@@ -77,10 +67,7 @@ protected:
 {
 if (true || !m_bVideoFramesSet)
 {
-if (m_device == NULL)
-{
-SetDevice();
-}
+auto device = m_session.Device().Direct3D11Device();
 auto inDesc = inVideoFrame.Direct3DSurface().Description();
 auto outDesc = outVideoFrame.Direct3DSurface().Description();
 /*
@@ -88,8 +75,8 @@ protected:
 whereas every model created with LearningModelBuilder takes arguments in (height, width) order.
 */
 auto format = winrt::Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized;
-m_inputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, m_device);
-m_outputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, m_device);
+m_inputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, device);
+m_outputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, device);
 m_bVideoFramesSet = true;
 }
 // TODO: Fix bug in WinML so that the surfaces from capture engine are shareable, remove copy.
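
The two hunks above remove the cached m_device member: shareable frames are now allocated on the Direct3D device the session itself owns. A hedged sketch of that allocation path (the free function and its parameters are illustrative, not the sample's code):

```cpp
#include <cstdint>
#include <utility>
#include <winrt/Windows.Graphics.DirectX.h>
#include <winrt/Windows.Media.h>
#include <winrt/Microsoft.AI.MachineLearning.h>

using winrt::Windows::Media::VideoFrame;
using winrt::Microsoft::AI::MachineLearning::LearningModelSession;

// Allocate an input/output pair of GPU-backed frames on the session's own
// Direct3D device, so the frames can be bound without caching a device member.
std::pair<VideoFrame, VideoFrame> AllocateSessionFrames(
    LearningModelSession const& session, int32_t width, int32_t height)
{
    auto device = session.Device().Direct3D11Device();
    auto format = winrt::Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized;
    return { VideoFrame::CreateAsDirect3D11SurfaceBacked(format, width, height, device),
             VideoFrame::CreateAsDirect3D11SurfaceBacked(format, width, height, device) };
}
```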
@@ -104,8 +91,6 @@ protected:
 LearningModelSession CreateLearningModelSession(const LearningModel& model, bool closedModel = true) {
 auto device = m_bUseGPU ? LearningModelDevice(LearningModelDeviceKind::DirectXHighPerformance) : LearningModelDevice(LearningModelDeviceKind::Default);
-auto displayAdapter = winrt::Windows::Devices::Display::Core::DisplayAdapter::FromId(device.AdapterId());
-m_highPerfAdapter = device.AdapterId();
 auto options = LearningModelSessionOptions();
 options.BatchSizeOverride(0);
 options.CloseModelOnSessionCreation(closedModel);
@@ -119,9 +104,6 @@ protected:
 UINT32 m_imageWidthInPixels;
 UINT32 m_imageHeightInPixels;
-IDirect3DDevice m_device;
-// For debugging potential device issues
-winrt::Windows::Graphics::DisplayAdapterId m_highPerfAdapter{};
 // Learning Model Binding and Session.
 LearningModelSession m_session;
@@ -129,34 +111,34 @@ protected:
 };
-class StyleTransfer : public IStreamModel {
+class StyleTransfer : public StreamModelBase {
 public:
-StyleTransfer(int w, int h) : IStreamModel(w, h)
+StyleTransfer(int w, int h) : StreamModelBase(w, h)
 {
-SetModels(w, h);
+InitializeSession(w, h);
 }
-StyleTransfer() : IStreamModel() {};
-~StyleTransfer(){};
-void SetModels(int w, int h);
+StyleTransfer() : StreamModelBase() {};
+virtual ~StyleTransfer(){};
+void InitializeSession(int w, int h);
 void Run(IDirect3DSurface src, IDirect3DSurface dest);
 private:
 LearningModel GetModel();
 };
-class BackgroundBlur : public IStreamModel
+class BackgroundBlur : public StreamModelBase
 {
 public:
 BackgroundBlur(int w, int h) :
-IStreamModel(w, h)
+StreamModelBase(w, h)
 {
-SetModels(w, h);
+InitializeSession(w, h);
 }
 BackgroundBlur() :
-IStreamModel()
+StreamModelBase()
 {};
-~BackgroundBlur();
-void SetModels(int w, int h);
+virtual ~BackgroundBlur();
+void InitializeSession(int w, int h);
 void Run(IDirect3DSurface src, IDirect3DSurface dest);
 private:
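
The rename from IStreamModel to StreamModelBase arrives together with virtual destructors, and the two changes belong together: TransformAsync stores the models as std::vector<std::unique_ptr<StreamModelBase>> (see TransformAsync.h below), and deleting a derived model through a base-class pointer is undefined behavior unless the base destructor is virtual. A self-contained illustration with simplified stand-ins for the sample's classes:

```cpp
#include <memory>
#include <vector>

struct StreamModelBase {
    virtual ~StreamModelBase() = default;             // Virtual: deletion through a base
                                                      // pointer runs the derived destructor.
    virtual void InitializeSession(int w, int h) = 0;
};

struct BackgroundBlur : StreamModelBase {
    ~BackgroundBlur() override = default;             // The sample closes its session here.
    void InitializeSession(int w, int h) override {}  // The sample builds/joins models here.
};

int main() {
    std::vector<std::unique_ptr<StreamModelBase>> models;
    models.push_back(std::make_unique<BackgroundBlur>());
    models[0]->InitializeSession(1280, 720);          // Virtual dispatch to BackgroundBlur.
}   // unique_ptr deletes via StreamModelBase*; the virtual dtor makes that well-defined.
```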

View File

@@ -433,7 +433,7 @@ HRESULT TransformAsync::UpdateDX11Device()
 D3D11_FENCE_FLAG flag = D3D11_FENCE_FLAG_NONE;
 m_spDevice->CreateFence(m_fenceValue, flag, __uuidof(ID3D11Fence), m_spFence.put_void());
 // Probably don't need to save the event for the first frame to render, since that will be long anyway with the first Eval/Bind.
-// Actually prob will be long for the first little bit anyways bc of each IStreamModel to select, but oh well. It'll be fine.
+// Actually, the first little while will probably be slow anyway because each StreamModelBase has to be selected, but oh well. It'll be fine.
 }
 else
 {
@@ -799,7 +799,7 @@ HRESULT TransformAsync::InitializeTransform(void)
 CHECK_HR(hr = CSampleQueue::Create(&m_pOutputSampleQueue));
-// Set up circular queue of IStreamModels
+// Set up circular queue of StreamModelBases
 for (int i = 0; i < m_numThreads; i++) {
 // TODO: Add a dialog to select which model to use for real-time inference.
 m_models.push_back(std::make_unique<BackgroundBlur>());
@@ -898,7 +898,7 @@ HRESULT TransformAsync::UpdateFormatInfo()
 // Set the size of the SegmentModel
 for (int i = 0; i < m_numThreads; i++)
 {
-m_models[i]->SetModels(m_imageWidthInPixels, m_imageHeightInPixels);
+m_models[i]->InitializeSession(m_imageWidthInPixels, m_imageHeightInPixels);
 }
 }

View File

@@ -226,7 +226,7 @@ public:
 HRESULT NotifyRelease();
 #pragma endregion IMFVideoSampleAllocatorNotify
-// Uses the next available IStreamModel to run inference on pInputSample
+// Uses the next available StreamModelBase to run inference on pInputSample
 // and allocates a transformed output sample.
 HRESULT SubmitEval(IMFSample* pInputSample);
@@ -268,7 +268,7 @@ protected:
 HRESULT OnSetD3DManager(ULONG_PTR ulParam);
 // After the input type is set, updates MFT format information and sets
-// IStreamModel input sizes.
+// StreamModelBase input sizes.
 HRESULT UpdateFormatInfo();
 // Sets up the output sample allocator.
@@ -309,7 +309,7 @@ protected:
 com_ptr<IMFAttributes> m_spAttributes; // MFT Attributes.
 com_ptr<IMFAttributes> m_spAllocatorAttributes; // Output sample allocator attributes.
 bool m_bAllocatorInitialized; // True if the sample allocator has been initialized.
-volatile ULONG m_ulSampleCounter; // Frame number, can use to pick a IStreamModel.
+volatile ULONG m_ulSampleCounter; // Frame number; can be used to pick a StreamModelBase.
 volatile ULONG m_ulProcessedFrameNum; // Number of frames we've processed.
 volatile ULONG m_currFrameNumber; // The current frame to be processed.
@@ -347,7 +347,7 @@ protected:
 // Model Inference fields
 int m_numThreads = // Number of threads running inference in parallel.
 max(std::thread::hardware_concurrency(), 5);
-std::vector<std::unique_ptr<IStreamModel>> m_models; // m_numThreads number of models to run inference in parallel.
+std::vector<std::unique_ptr<StreamModelBase>> m_models; // m_numThreads models used to run inference in parallel.
 int modelIndex = 0;
 };
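
The comments in this header sketch how the pool is meant to be used: m_numThreads models are created up front and a frame counter picks the next one, wrapping around. A short sketch of that circular selection (the modulo step is an assumption drawn from the "circular queue" comment, not code shown in this diff):

```cpp
#include <memory>
#include <vector>

struct StreamModelBase { virtual ~StreamModelBase() = default; };
struct BackgroundBlur : StreamModelBase {};

// Pick the model for a frame by wrapping the frame counter around the pool,
// so up to models.size() frames can be in flight on distinct sessions.
StreamModelBase& ModelForFrame(
    std::vector<std::unique_ptr<StreamModelBase>> const& models,
    unsigned long frameNumber)
{
    return *models[frameNumber % models.size()];
}
```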