Address comments in SegmentModel
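
Rename IStreamModel to StreamModelBase (it carries shared state and logic, so it is a base class rather than a pure interface) and SetModels to InitializeSession to match what the method actually does. Replace the magic opset numbers passed to LearningModelBuilder::Create with a shared const int32_t opset = 12, turn the g_scale global into the static member StreamModelBase::m_scale, make the base destructor virtual, drop the m_device / m_highPerfAdapter debug members and the asserts that referenced them, and load model files through std::filesystem::path::c_str().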
Parent: bd70309d90
Commit: c165e0bdb0
@@ -1 +0,0 @@
-*.onnx
@@ -38,11 +38,10 @@ enum OnnxDataType : long {
 }OnnxDataType;
 
-
-int g_scale = 5;
 auto outputBindProperties = PropertySet();
+const int32_t opset = 12;
 
 /**** Style transfer model ****/
-void StyleTransfer::SetModels(int w, int h)
+void StyleTransfer::InitializeSession(int w, int h)
 {
     // TODO: Use w/h or use the 720x720 of the model
     SetImageSize(720, 720); // Size: model input sizes are fixed to 720x720
@@ -53,7 +52,6 @@ void StyleTransfer::Run(IDirect3DSurface src, IDirect3DSurface dest)
 {
     m_bSyncStarted = TRUE;
 
-    assert(m_session.Device().AdapterId() == m_highPerfAdapter);
     VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src);
     VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest);
     SetVideoFrames(inVideoFrame, outVideoFrame);
@@ -72,11 +70,12 @@ void StyleTransfer::Run(IDirect3DSurface src, IDirect3DSurface dest)
 
     m_bSyncStarted = FALSE;
 }
 
 LearningModel StyleTransfer::GetModel()
 {
-    auto rel = std::filesystem::current_path();
-    rel.append("Assets\\mosaic.onnx");
-    return LearningModel::LoadFromFilePath(rel + L"");
+    auto model_path = std::filesystem::current_path();
+    model_path.append("Assets\\mosaic.onnx");
+    return LearningModel::LoadFromFilePath(model_path.c_str());
 }
 
 
@@ -85,9 +84,9 @@ BackgroundBlur::~BackgroundBlur() {
     if (m_session) m_session.Close();
 }
 
-void BackgroundBlur::SetModels(int w, int h)
+void BackgroundBlur::InitializeSession(int w, int h)
 {
-    w /= g_scale; h /= g_scale;
+    w /= m_scale; h /= m_scale;
     SetImageSize(w, h);
 
     auto joinOptions1 = LearningModelJoinOptions();
@@ -103,7 +102,7 @@ void BackgroundBlur::SetModels(int w, int h)
     joinOptions2.Link(L"FCN_out", L"InputScores");
     joinOptions2.Link(L"OutputImageForward", L"InputImage");
     joinOptions2.JoinedNodePrefix(L"Post_");
-    //joinOptions2.PromoteUnlinkedOutputsToFusedOutputs(false); // Causes winrt originate error in FusedGraphKernel.cpp, but works on CPU
+    //joinOptions2.PromoteUnlinkedOutputsToFusedOutputs(false); // TODO: Causes winrt originate error in FusedGraphKernel.cpp, but works on CPU
     auto modelExperimental2 = LearningModelExperimental(intermediateModel);
     LearningModel modelFused = modelExperimental2.JoinModel(PostProcess(1, 3, h, w, 1), joinOptions2);
 
@@ -115,16 +114,15 @@ void BackgroundBlur::SetModels(int w, int h)
 }
 LearningModel BackgroundBlur::GetModel()
 {
-    auto rel = std::filesystem::current_path();
-    rel.append("Assets\\fcn-resnet50-12.onnx");
-    return LearningModel::LoadFromFilePath(rel + L"");
+    auto model_path = std::filesystem::current_path();
+    model_path.append("Assets\\fcn-resnet50-12.onnx");
+    return LearningModel::LoadFromFilePath(model_path.c_str());
 }
 
 void BackgroundBlur::Run(IDirect3DSurface src, IDirect3DSurface dest)
 {
     m_bSyncStarted = TRUE;
-    // Device validation
-    assert(m_session.Device().AdapterId() == m_highPerfAdapter);
     VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src);
     VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest);
     SetVideoFrames(inVideoFrame, outVideoFrame);
@@ -145,7 +143,7 @@ void BackgroundBlur::Run(IDirect3DSurface src, IDirect3DSurface dest)
 
 LearningModel BackgroundBlur::PostProcess(long n, long c, long h, long w, long axis)
 {
-    auto builder = LearningModelBuilder::Create(12)
+    auto builder = LearningModelBuilder::Create(opset)
         .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputImage", TensorKind::Float, { n, c, h, w }))
         .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"InputScores", TensorKind::Float, { -1, -1, h, w })) // Different input type?
         .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"OutputImage", TensorKind::Float, { n, c, h, w }))
@@ -204,8 +202,7 @@ LearningModel BackgroundBlur::PostProcess(long n, long c, long h, long w, long axis)
 
 LearningModel Invert(long n, long c, long h, long w)
 {
-
-    auto builder = LearningModelBuilder::Create(11)
+    auto builder = LearningModelBuilder::Create(opset)
         // Loading in buffers and reshape
         .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, { n, c, h, w }))
         .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, { n, c, h, w }))
@@ -230,7 +227,7 @@ LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array<fl
     assert(means.size() == c);
     assert(stddev.size() == c);
 
-    auto builder = LearningModelBuilder::Create(12)
+    auto builder = LearningModelBuilder::Create(opset)
         .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", L"The NCHW image", TensorKind::Float, {1, c, h, w}))
         .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", L"The NCHW image normalized with mean and stddev.", TensorKind::Float, {1, c, h, w}))
         .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"OutputImageForward", L"The NCHW image forwarded through the model.", TensorKind::Float, {1, c, h, w}))
@@ -262,7 +259,7 @@ LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array<fl
 
 LearningModel ReshapeFlatBufferToNCHW(long n, long c, long h, long w)
 {
-    auto builder = LearningModelBuilder::Create(11)
+    auto builder = LearningModelBuilder::Create(opset)
         // Loading in buffers and reshape
         .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::UInt8, { 1, n * c * h * w }))
         .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, {n, c, h, w}))

@@ -29,47 +29,37 @@ LearningModel Normalize0_1ThenZScore(long height, long width, long channels, con
 LearningModel ReshapeFlatBufferToNCHW(long n, long c, long h, long w);
 LearningModel Invert(long n, long c, long h, long w);
 
-class IStreamModel
+class StreamModelBase
 {
 public:
-    IStreamModel():
+    StreamModelBase():
         m_inputVideoFrame(NULL),
         m_outputVideoFrame(NULL),
         m_session(NULL),
         m_binding(NULL),
         m_bSyncStarted(FALSE)
     {}
-    IStreamModel(int w, int h) :
+    StreamModelBase(int w, int h) :
         m_inputVideoFrame(NULL),
         m_outputVideoFrame(NULL),
         m_session(NULL),
         m_binding(NULL),
         m_bSyncStarted(FALSE)
     {}
-    ~IStreamModel() {
+    virtual ~StreamModelBase() {
         if(m_session) m_session.Close();
        if(m_binding) m_binding.Clear();
        if (m_inputVideoFrame) m_inputVideoFrame.Close();
        if (m_outputVideoFrame) m_outputVideoFrame.Close();
-        if (m_device) m_device.Close();
     };
 
-    virtual void SetModels(int w, int h) =0;
+    virtual void InitializeSession(int w, int h) =0;
     virtual void Run(IDirect3DSurface src, IDirect3DSurface dest) =0;
 
-    void SetUseGPU(bool use) {
-        m_bUseGPU = use;
-    }
-    void SetDevice() {
-        assert(m_session.Device().AdapterId() == m_highPerfAdapter);
-        assert(m_session.Device().Direct3D11Device() != NULL);
-        m_device = m_session.Device().Direct3D11Device();
-        auto device = m_session.Device().AdapterId();
-    }
 
     // Synchronous eval status
     BOOL m_bSyncStarted;
     VideoFrame m_outputVideoFrame;
+    static const int m_scale = 5;
 
 protected:
     // Cache input frames into a shareable d3d-backed VideoFrame
@@ -77,10 +67,7 @@ protected:
     {
         if (true || !m_bVideoFramesSet)
         {
-            if (m_device == NULL)
-            {
-                SetDevice();
-            }
+            auto device = m_session.Device().Direct3D11Device();
             auto inDesc = inVideoFrame.Direct3DSurface().Description();
             auto outDesc = outVideoFrame.Direct3DSurface().Description();
             /*
@@ -88,8 +75,8 @@ protected:
             whereas every model created with LearningModelBuilder takes arguments in (height, width) order.
             */
             auto format = winrt::Windows::Graphics::DirectX::DirectXPixelFormat::B8G8R8X8UIntNormalized;
-            m_inputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, m_device);
-            m_outputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, m_device);
+            m_inputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, device);
+            m_outputVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(format, m_imageWidthInPixels, m_imageHeightInPixels, device);
             m_bVideoFramesSet = true;
         }
         // TODO: Fix bug in WinML so that the surfaces from capture engine are shareable, remove copy.
@@ -104,8 +91,6 @@ protected:
 
     LearningModelSession CreateLearningModelSession(const LearningModel& model, bool closedModel = true) {
         auto device = m_bUseGPU ? LearningModelDevice(LearningModelDeviceKind::DirectXHighPerformance) : LearningModelDevice(LearningModelDeviceKind::Default);
-        auto displayAdapter = winrt::Windows::Devices::Display::Core::DisplayAdapter::FromId(device.AdapterId());
-        m_highPerfAdapter = device.AdapterId();
         auto options = LearningModelSessionOptions();
         options.BatchSizeOverride(0);
         options.CloseModelOnSessionCreation(closedModel);
@@ -119,9 +104,6 @@ protected:
 
     UINT32 m_imageWidthInPixels;
     UINT32 m_imageHeightInPixels;
-    IDirect3DDevice m_device;
-    // For debugging potential device issues
-    winrt::Windows::Graphics::DisplayAdapterId m_highPerfAdapter{};
 
     // Learning Model Binding and Session.
     LearningModelSession m_session;
@@ -129,34 +111,34 @@ protected:
 };
 
 
-class StyleTransfer : public IStreamModel {
+class StyleTransfer : public StreamModelBase {
 public:
-    StyleTransfer(int w, int h) : IStreamModel(w, h)
+    StyleTransfer(int w, int h) : StreamModelBase(w, h)
     {
-        SetModels(w, h);
+        InitializeSession(w, h);
     }
-    StyleTransfer() : IStreamModel() {};
-    ~StyleTransfer(){};
-    void SetModels(int w, int h);
+    StyleTransfer() : StreamModelBase() {};
+    virtual ~StyleTransfer(){};
+    void InitializeSession(int w, int h);
     void Run(IDirect3DSurface src, IDirect3DSurface dest);
 private:
     LearningModel GetModel();
 };
 
 
-class BackgroundBlur : public IStreamModel
+class BackgroundBlur : public StreamModelBase
 {
 public:
     BackgroundBlur(int w, int h) :
-        IStreamModel(w, h)
+        StreamModelBase(w, h)
     {
-        SetModels(w, h);
+        InitializeSession(w, h);
     }
     BackgroundBlur() :
-        IStreamModel()
+        StreamModelBase()
     {};
-    ~BackgroundBlur();
-    void SetModels(int w, int h);
+    virtual ~BackgroundBlur();
+    void InitializeSession(int w, int h);
     void Run(IDirect3DSurface src, IDirect3DSurface dest);
 
 private:
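Aside: the GetModel hunks above replace `rel + L""` with path::c_str(). On Windows, std::filesystem::path::c_str() yields a const wchar_t*, which converts directly to the winrt::hstring parameter of LearningModel::LoadFromFilePath. A minimal standalone sketch of that pattern, assuming the WinML projection headers; the LoadModelFromAssets helper name and its parameter are illustrative, not from this diff:

    #include <filesystem>
    #include <winrt/Windows.AI.MachineLearning.h>

    using winrt::Windows::AI::MachineLearning::LearningModel;

    // Hypothetical helper: load an ONNX model shipped in the app's Assets folder.
    LearningModel LoadModelFromAssets(const wchar_t* assetName)
    {
        auto model_path = std::filesystem::current_path();
        model_path.append("Assets");
        model_path.append(assetName);
        // path::c_str() is const wchar_t* on Windows and converts cleanly
        // to the winrt::hstring that LoadFromFilePath expects.
        return LearningModel::LoadFromFilePath(model_path.c_str());
    }

    // Usage, mirroring StyleTransfer::GetModel in this diff:
    //   LearningModel model = LoadModelFromAssets(L"mosaic.onnx");
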
@@ -433,7 +433,7 @@ HRESULT TransformAsync::UpdateDX11Device()
         D3D11_FENCE_FLAG flag = D3D11_FENCE_FLAG_NONE;
         m_spDevice->CreateFence(m_fenceValue, flag, __uuidof(ID3D11Fence), m_spFence.put_void());
         // Probably don't need to save the event for the first frame to render, since that will be long anyways w first Eval/Bind.
-        // Actually prob will be long for the first little bit anyways bc of each IStreamModel to select, but oh well. It'll be fine.
+        // Actually prob will be long for the first little bit anyways bc of each StreamModelBase to select, but oh well. It'll be fine.
     }
     else
     {
@@ -799,7 +799,7 @@ HRESULT TransformAsync::InitializeTransform(void)
 
     CHECK_HR(hr = CSampleQueue::Create(&m_pOutputSampleQueue));
 
-    // Set up circular queue of IStreamModels
+    // Set up circular queue of StreamModelBases
     for (int i = 0; i < m_numThreads; i++) {
         // TODO: Have a dialog to select which model to use for real-time inference.
         m_models.push_back(std::make_unique<BackgroundBlur>());
@@ -898,7 +898,7 @@ HRESULT TransformAsync::UpdateFormatInfo()
         // Set the size of the SegmentModel
         for (int i = 0; i < m_numThreads; i++)
         {
-            m_models[i]->SetModels(m_imageWidthInPixels, m_imageHeightInPixels);
+            m_models[i]->InitializeSession(m_imageWidthInPixels, m_imageHeightInPixels);
         }
     }
 
@@ -226,7 +226,7 @@ public:
     HRESULT NotifyRelease();
 #pragma endregion IMFVideoSampleAllocatorNotify
 
-    // Uses the next available IStreamModel to run inference on pInputSample
+    // Uses the next available StreamModelBase to run inference on pInputSample
     // and allocates a transformed output sample.
     HRESULT SubmitEval(IMFSample* pInputSample);
 
@@ -268,7 +268,7 @@ protected:
     HRESULT OnSetD3DManager(ULONG_PTR ulParam);
 
     // After the input type is set, updates MFT format information and sets
-    // IStreamModel input sizes.
+    // StreamModelBase input sizes.
     HRESULT UpdateFormatInfo();
 
     // Sets up the output sample allocator.
@@ -309,7 +309,7 @@ protected:
     com_ptr<IMFAttributes> m_spAttributes;          // MFT attributes.
     com_ptr<IMFAttributes> m_spAllocatorAttributes; // Output sample allocator attributes.
     bool m_bAllocatorInitialized;                   // True if the sample allocator has been initialized.
-    volatile ULONG m_ulSampleCounter;               // Frame number, can use to pick an IStreamModel.
+    volatile ULONG m_ulSampleCounter;               // Frame number, can use to pick a StreamModelBase.
     volatile ULONG m_ulProcessedFrameNum;           // Number of frames we've processed.
     volatile ULONG m_currFrameNumber;               // The current frame to be processed.
 
@@ -347,7 +347,7 @@ protected:
     // Model inference fields
     int m_numThreads =                              // Number of threads running inference in parallel.
         max(std::thread::hardware_concurrency(), 5);
-    std::vector<std::unique_ptr<IStreamModel>> m_models; // m_numThreads models that run inference in parallel.
+    std::vector<std::unique_ptr<StreamModelBase>> m_models; // m_numThreads models that run inference in parallel.
     int modelIndex = 0;
 
 };
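
For orientation, here is how the renamed pieces fit together end to end. This is a condensed, non-compiling outline assembled from the TransformAsync hunks above: m_models, m_numThreads, modelIndex, BackgroundBlur, and the pixel-size members are from this diff, while srcSurface/destSurface and the round-robin increment are illustrative assumptions (SubmitEval wraps the real Run call with sample allocation):

    // Build the circular queue of models (TransformAsync::InitializeTransform).
    for (int i = 0; i < m_numThreads; i++) {
        m_models.push_back(std::make_unique<BackgroundBlur>());
    }

    // Once the input media type is known, size each session
    // (TransformAsync::UpdateFormatInfo; this call used to be SetModels).
    for (int i = 0; i < m_numThreads; i++) {
        m_models[i]->InitializeSession(m_imageWidthInPixels, m_imageHeightInPixels);
    }

    // Per frame, pick the next model and run inference on the D3D surfaces.
    m_models[modelIndex]->Run(srcSurface, destSurface);
    modelIndex = (modelIndex + 1) % m_numThreads; // assumed rotation policy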