start hello model but messy
This commit is contained in:
Родитель
d77cf0d69f
Коммит
84ccefdd37
|
@ -38,7 +38,7 @@ enum OnnxDataType : long {
|
|||
} OnnxDataType;
|
||||
|
||||
|
||||
const int32_t opset = 12;
|
||||
const int32_t opset = 13;
|
||||
|
||||
|
||||
/**** Style transfer model ****/
|
||||
|
@ -83,12 +83,52 @@ void WindowsHello::InitializeSession(int w, int h) {
|
|||
SetImageSize(w, h);
|
||||
|
||||
// Preprocessing: grayscale model
|
||||
m_sessionPreprocess = CreateLearningModelSession(GrayScale(1, 3, m_imageHeightInPixels, m_imageWidthInPixels));
|
||||
m_bindingPreprocess = LearningModelBinding(m_sessionPreprocess);
|
||||
|
||||
m_session = CreateLearningModelSession(GetModel());
|
||||
m_binding = LearningModelBinding(m_session);
|
||||
|
||||
}
|
||||
|
||||
/**
 * Runs one frame through the WindowsHello pipeline.
 *
 * Steps, in order:
 *   1. Wrap the source/destination surfaces in VideoFrames and register them
 *      through SetVideoFrames.
 *   2. Evaluate the preprocessing session to produce a single-channel tensor
 *      of shape [1, 1, height, width].
 *   3. Evaluate the main session on that tensor.
 *   4. Copy the unmodified input frame to the output frame and write the first
 *      value of the model output to the debugger output.
 *
 * @param src   Surface holding the input frame.
 * @param dest  Surface that receives the output frame.
 */
void WindowsHello::Run(IDirect3DSurface src, IDirect3DSurface dest)
{
    VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src);
    VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest);
    SetVideoFrames(inVideoFrame, outVideoFrame);

    // Shape validation: the sessions were built for a fixed image size.
    assert(static_cast<UINT32>(m_inputVideoFrame.Direct3DSurface().Description().Height) == m_imageHeightInPixels);
    assert(static_cast<UINT32>(m_inputVideoFrame.Direct3DSurface().Description().Width) == m_imageWidthInPixels);

    // Preprocessing: grayscale
    std::vector<int64_t> shape = { 1, 1, m_imageHeightInPixels, m_imageWidthInPixels };
    ITensor intermediateTensor = TensorFloat::Create(shape);
    hstring inputName = m_sessionPreprocess.Model().InputFeatures().GetAt(0).Name();
    hstring outputName = m_sessionPreprocess.Model().OutputFeatures().GetAt(0).Name();

    // The intermediate tensor is only consumed by the next session, so the
    // CPU sync of this output is switched off.
    auto intermediateBindProperties = PropertySet();
    intermediateBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));

    m_bindingPreprocess.Bind(inputName, m_inputVideoFrame);
    m_bindingPreprocess.Bind(outputName, intermediateTensor, intermediateBindProperties);
    m_sessionPreprocess.Evaluate(m_bindingPreprocess, L"");

    // Run through actual model
    // NOTE(review): -1 stands in for the unknown number of boxes; confirm that
    // TensorFloat::Create accepts a negative dimension. If it does not, leave
    // the output unbound and read it from the evaluation result instead.
    std::vector<int64_t> helloOutputShape = { 1, -1, 4 }; // TODO: How many bboxes?
    ITensor helloOutput = TensorFloat::Create(helloOutputShape);
    m_binding.Bind(m_session.Model().InputFeatures().GetAt(0).Name(), intermediateTensor);
    // This output is read on the CPU below, so it is bound WITHOUT
    // DisableTensorCpuSync (the original reused the property set here).
    m_binding.Bind(m_session.Model().OutputFeatures().GetAt(0).Name(), helloOutput);
    m_session.Evaluate(m_binding, L"");

    // Draw bboxes on videoframe?? For now just print them out
    auto bbox = helloOutput.as<TensorFloat>().GetAsVectorView();
    m_inputVideoFrame.CopyToAsync(outVideoFrame).get();

    OutputDebugString(L"BOUNDING BOX: ");
    if (bbox.Size() > 0)
    {
        OutputDebugString(std::to_wstring(bbox.GetAt(0)).c_str());
    }
    else
    {
        // An empty output would make GetAt(0) fail; report it instead.
        OutputDebugString(L"<none>");
    }
    OutputDebugString(L"\n");
}
|
||||
|
||||
LearningModel WindowsHello::GetModel() {
|
||||
auto modelPath = std::filesystem::path(m_modelBasePath.c_str());
|
||||
modelPath.append("retina_rgb.onnx");
|
||||
|
@ -226,6 +266,21 @@ LearningModel Invert(long n, long c, long h, long w)
|
|||
return builder.CreateModel();
|
||||
}
|
||||
|
||||
/**
 * Builds a one-operator model that converts a multi-channel image tensor to a
 * single channel by taking the mean across the channel axis.
 *
 * @param n  Batch size. Currently not used: the input feature is declared with
 *           a leading dimension of 1.
 * @param c  Number of input channels.
 * @param h  Image height.
 * @param w  Image width.
 * @return   A LearningModel with a float input "Input" of shape [1, c, h, w]
 *           and a float output "Output" of shape [1, 1, h, w].
 */
LearningModel GrayScale(long n, long c, long h, long w)
{
    // Go from [1, c, h, w] -> [1, 1, h, w]
    // Start with simple reduction of mean across channels
    auto builder = LearningModelBuilder::Create(opset)
        .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, { 1, c, h, w }))
        .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, { 1, 1, h, w }))
        .Operators().Add(LearningModelOperator(L"ReduceMean")
            .SetInput(L"data", L"Input")
            // At opset 12/13 ReduceMean takes `axes` as an int64 attribute,
            // not as a float constant input. Axis 1 is the channel axis.
            // keepdims is left at its default so the reduced axis stays as 1.
            .SetAttribute(L"axes", TensorInt64Bit::CreateFromIterable({ 1 }, { 1 }))
            .SetOutput(L"reduced", L"Output")
        );
    return builder.CreateModel();
}
|
||||
|
||||
LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array<float, 3>& means, const std::array<float, 3>& stddev)
|
||||
{
|
||||
assert(means.size() == c);
|
||||
|
|
|
@ -110,7 +110,7 @@ protected:
|
|||
class StyleTransfer : public StreamModelBase
|
||||
{
|
||||
public:
|
||||
StyleTransfer() : StreamModelBase() {};
|
||||
StyleTransfer() : StreamModelBase(){};
|
||||
void InitializeSession(int w, int h);
|
||||
void Run(IDirect3DSurface src, IDirect3DSurface dest);
|
||||
private:
|
||||
|
@ -119,12 +119,16 @@ private:
|
|||
|
||||
class WindowsHello : public StreamModelBase {
|
||||
public:
|
||||
WindowsHello() : StreamModelBase() {};
|
||||
WindowsHello() :
|
||||
StreamModelBase(),
|
||||
m_sessionPreprocess(nullptr),
|
||||
m_bindingPreprocess(nullptr) {};
|
||||
void InitializeSession(int w, int h);
|
||||
void Run(IDirect3DSurface src, IDirect3DSurface dest);
|
||||
private:
|
||||
LearningModel GetModel();
|
||||
|
||||
LearningModelSession m_sessionPreprocess;
|
||||
LearningModelBinding m_bindingPreprocess;
|
||||
};
|
||||
|
||||
class BackgroundBlur : public StreamModelBase
|
||||
|
|
Загрузка…
Ссылка в новой задаче