This commit is contained in:
Linnea May 2022-11-22 14:40:44 -08:00
Родитель d77cf0d69f
Коммит 84ccefdd37
2 изменённых файлов: 63 добавлений и 4 удалений

Просмотреть файл

@ -38,7 +38,7 @@ enum OnnxDataType : long {
} OnnxDataType;
const int32_t opset = 12;
const int32_t opset = 13;
/**** Style transfer model ****/
@ -83,12 +83,52 @@ void WindowsHello::InitializeSession(int w, int h) {
SetImageSize(w, h);
// Preprocessing: grayscale model
m_sessionPreprocess = CreateLearningModelSession(GrayScale(1, 3, m_imageHeightInPixels, m_imageWidthInPixels));
m_bindingPreprocess = LearningModelBinding(m_sessionPreprocess);
m_session = CreateLearningModelSession(GetModel());
m_binding = LearningModelBinding(m_session);
}
void WindowsHello::Run(IDirect3DSurface src, IDirect3DSurface dest)
{
VideoFrame inVideoFrame = VideoFrame::CreateWithDirect3D11Surface(src);
VideoFrame outVideoFrame = VideoFrame::CreateWithDirect3D11Surface(dest);
SetVideoFrames(inVideoFrame, outVideoFrame);
auto outputBindProperties = PropertySet();
// Shape validation
assert((UINT32)m_inputVideoFrame.Direct3DSurface().Description().Height == m_imageHeightInPixels);
assert((UINT32)m_inputVideoFrame.Direct3DSurface().Description().Width == m_imageWidthInPixels);
// Preprocessing: grayscale
std::vector<int64_t> shape = { 1, 1, m_imageHeightInPixels, m_imageWidthInPixels };
ITensor intermediateTensor = TensorFloat::Create(shape);
hstring inputName = m_sessionPreprocess.Model().InputFeatures().GetAt(0).Name();
hstring outputName = m_sessionPreprocess.Model().OutputFeatures().GetAt(0).Name();
m_bindingPreprocess.Bind(inputName, m_inputVideoFrame);
outputBindProperties.Insert(L"DisableTensorCpuSync", PropertyValue::CreateBoolean(true));
m_bindingPreprocess.Bind(outputName, intermediateTensor, outputBindProperties);
m_sessionPreprocess.Evaluate(m_bindingPreprocess, L"");
// Run through actual model
std::vector<int64_t> helloOutputShape = {1, -1, 4}; // TODO: How many bboxes?
ITensor helloOutput = TensorFloat::Create(helloOutputShape);
m_binding.Bind(m_session.Model().InputFeatures().GetAt(0).Name(), intermediateTensor);
m_binding.Bind(m_session.Model().OutputFeatures().GetAt(0).Name(), helloOutput, outputBindProperties);
m_session.Evaluate(m_binding, L"");
//m_.CopyToAsync(outVideoFrame);
// Draw bboxes on videoframe?? For now just print them out
auto bbox = helloOutput.as<TensorFloat>().GetAsVectorView();
m_inputVideoFrame.CopyToAsync(outVideoFrame).get();
OutputDebugString(L"BOUNDING BOX: ");
OutputDebugString(std::to_wstring(bbox.GetAt(0)).c_str());
OutputDebugString(L"\n");
}
LearningModel WindowsHello::GetModel() {
auto modelPath = std::filesystem::path(m_modelBasePath.c_str());
modelPath.append("retina_rgb.onnx");
@ -226,6 +266,21 @@ LearningModel Invert(long n, long c, long h, long w)
return builder.CreateModel();
}
LearningModel GrayScale(long n, long c, long h, long w)
{
// Go from [n, c, h, w] -> [1, 1, h, w]
// Start with simple reduction of mean across channels
auto builder = LearningModelBuilder::Create(opset)
.Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, { 1, c, h, w }))
.Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, { 1, 1, h, w }))
.Operators().Add(LearningModelOperator(L"ReduceMean")
.SetInput(L"data", L"Input")
.SetConstant(L"axes", TensorFloat::CreateFromIterable({ 1 }, { 1.f }))
.SetOutput(L"reduced", L"Output")
);
return builder.CreateModel();
}
LearningModel Normalize0_1ThenZScore(long h, long w, long c, const std::array<float, 3>& means, const std::array<float, 3>& stddev)
{
assert(means.size() == c);

Просмотреть файл

@ -110,7 +110,7 @@ protected:
class StyleTransfer : public StreamModelBase
{
public:
StyleTransfer() : StreamModelBase() {};
StyleTransfer() : StreamModelBase(){};
void InitializeSession(int w, int h);
void Run(IDirect3DSurface src, IDirect3DSurface dest);
private:
@ -119,12 +119,16 @@ private:
class WindowsHello : public StreamModelBase {
public:
WindowsHello() : StreamModelBase() {};
WindowsHello() :
StreamModelBase(),
m_sessionPreprocess(nullptr),
m_bindingPreprocess(nullptr) {};
void InitializeSession(int w, int h);
void Run(IDirect3DSurface src, IDirect3DSurface dest);
private:
LearningModel GetModel();
LearningModelSession m_sessionPreprocess;
LearningModelBinding m_bindingPreprocess;
};
class BackgroundBlur : public StreamModelBase