Clean up and some extra documentation

Linnea May 2021-11-11 11:05:43 -08:00
Parent 04e052370e
Commit 7cf979e8de
6 changed files with 19 additions and 126 deletions

View file

@@ -268,6 +268,11 @@ void OnStartStream(HWND hwnd)
{
HRESULT hr = S_OK;
hr = g_pPlayer->StartStream();
if (!SUCCEEDED(hr))
{
NotifyError(hwnd, L"Could not find an adequate video capture device.", hr);
UpdateUI(hwnd, Closed);
}
}

View file

@@ -73,42 +73,36 @@ void SegmentModel::Run(const BYTE** pSrc, BYTE** pDest, DWORD cbImageSize)
LearningModelSession normalizationSession = CreateLearningModelSession(Normalize0_1ThenZScore(m_imageHeightInPixels, m_imageWidthInPixels, 3, mean, std));
ITensor intermediateTensor = TensorFloat::Create(shape);
auto normalizationBinding = Evaluate(normalizationSession, std::vector<ITensor*>{&tensorizedImg}, & intermediateTensor);
// intermediateTensor.as<TensorFloat>().GetAsVectorView().GetMany(0, testpixels);
// 3. Run through actual model
std::vector<int64_t> FCNResnetOutputShape = { 1, 21, m_imageHeightInPixels, m_imageWidthInPixels };
LearningModelSession FCNResnetSession = CreateLearningModelSession(FCNResnet());
ITensor FCNResnetOutput = TensorFloat::Create(FCNResnetOutputShape);
auto FCNResnetBinding = Evaluate(FCNResnetSession, std::vector<ITensor*>{&intermediateTensor}, & FCNResnetOutput);
//FCNResnetOutput.as<TensorFloat>().GetAsVectorView().GetMany(0, testpixels);
// 4. Extract labels with argmax
ITensor rawLabels = TensorFloat::Create({1, 1, m_imageHeightInPixels, m_imageWidthInPixels});
LearningModelSession argmaxSession = CreateLearningModelSession(Argmax(1, m_imageHeightInPixels, m_imageWidthInPixels));
auto argmaxBinding = Evaluate(argmaxSession, std::vector<ITensor*>{&FCNResnetOutput}, & rawLabels);
//rawLabels.as<TensorFloat>().GetAsVectorView().GetMany(0, testpixels);
// 5. Get the foreground
ITensor foreground = TensorUInt8Bit::Create(std::vector<int64_t>{1, m_imageHeightInPixels, m_imageWidthInPixels, 3});
LearningModelSession foregroundSession = CreateLearningModelSession(GetBackground(1, 3, m_imageHeightInPixels, m_imageWidthInPixels));
auto foregroundBinding = Evaluate(foregroundSession, std::vector<ITensor*>{&tensorizedImg, &rawLabels}, & foreground);
// Enable TensorCpuSync for the last Evaluate so we can extract the result and hand it back to the buffer
// Will remove once we can just pass D3D resources back to the MFT
auto foregroundBinding = Evaluate(foregroundSession, std::vector<ITensor*>{&tensorizedImg, &rawLabels}, & foreground, true);
// TODO: Move data over to CPU somehow?
UINT32 outCapacity = 0;
if (useGPU)
{
// v1: just get the reference- should fail
auto reference = foreground.as<TensorUInt8Bit>().CreateReference().data();
// v2: get the buffer from tensornative
auto f = foreground.as<ITensorNative>();
foreground.as<ITensorNative>()->GetBuffer(pDest, &outCapacity);
// v3: get from a d3dresource
ID3D12Resource* res = NULL;
/*ID3D12Resource* res = NULL;
HRESULT hr = foreground.as<ITensorNative>()->GetD3D12Resource(&res);
UINT DstRowPitch = 0, DstDepthPitch = 0, SrcSubresource = 0;
hr = res->ReadFromSubresource((void*)*pDest, DstRowPitch, DstDepthPitch, SrcSubresource, NULL);
hr = res->ReadFromSubresource((void*)*pDest, DstRowPitch, DstDepthPitch, SrcSubresource, NULL);*/
return;
}
else
@@ -142,7 +136,6 @@ void SegmentModel::RunTest(const BYTE** pSrc, BYTE** pDest, DWORD cbImageSize)
hr = res->ReadFromSubresource((void*)*pDest, DstRowPitch, DstDepthPitch, SrcSubresource, NULL);*/
return;
}
;
}
void SegmentModel::SetImageSize(UINT32 w, UINT32 h)
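A note on the v2 path above: ITensorNative::GetBuffer does not copy into the caller's buffer; it returns a pointer to the tensor's own storage, so the bytes still have to be copied out. A minimal sketch of that copy-out for a CPU-backed tensor (the helper name and the size check are illustrative, not part of this commit):

#include <winrt/Microsoft.AI.MachineLearning.h>
#include "Microsoft.AI.MachineLearning.Native.h"
#include <cstring>

using winrt::Microsoft::AI::MachineLearning::ITensor;

// Hypothetical helper: copy a CPU-backed tensor's bytes into a caller-owned buffer.
HRESULT CopyTensorToBuffer(ITensor const& tensor, BYTE* pDest, UINT32 cbDest)
{
    BYTE* pData = nullptr;
    UINT32 capacity = 0;
    // GetBuffer exposes the tensor's underlying storage (no copy yet).
    HRESULT hr = tensor.as<ITensorNative>()->GetBuffer(&pData, &capacity);
    if (FAILED(hr)) return hr;
    if (capacity > cbDest) return E_NOT_SUFFICIENT_BUFFER;
    memcpy(pDest, pData, capacity);
    return S_OK;
}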

View file

@@ -9,7 +9,6 @@
using namespace winrt::Microsoft::AI::MachineLearning;
using namespace winrt::Microsoft::AI::MachineLearning::Experimental;
// TODO: Implement IUnknown?
class SegmentModel {
public:
LearningModelSession m_sess;

View file

@@ -502,11 +502,6 @@ HRESULT TransformBlur::SetInputType(
// Find a decoder configuration
if (m_pD3DDeviceManager)
{
UINT numDevices = 0;
UINT numFormats = 0;
GUID* pguids = NULL;
D3DFORMAT* d3dFormats = NULL;
UINT numProfiles = m_pD3DVideoDevice->GetVideoDecoderProfileCount();
for (UINT i = 0; i < numProfiles; i++)
{
@@ -515,18 +510,14 @@ HRESULT TransformBlur::SetInputType(
BOOL rgbSupport;
hr = m_pD3DVideoDevice->CheckVideoDecoderFormat(&pDecoderProfile, DXGI_FORMAT_AYUV, &rgbSupport);
// If H.264 and the profile supports a YUV/RGB format
// DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8X8_UNORM
hr = m_pD3DVideoDevice->CheckVideoDecoderFormat(&pDecoderProfile, DXGI_FORMAT_B8G8R8X8_UNORM, &rgbSupport);
if (rgbSupport) {
// D3D11_DECODER_PROFILE_H264_VLD_NOFGT
OutputDebugString(L"supports AYUV!");
OutputDebugString(L"supports RGB32!\n");
}
}
// TODO: Move to done
CoTaskMemFree(pguids);
CoTaskMemFree(d3dFormats);
}
// The type is OK.
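The probe above only shows part of the loop (pDecoderProfile is fetched off-screen). Assembled end to end, the check is roughly the following sketch, assuming m_pD3DVideoDevice is a valid ID3D11VideoDevice*:

// Probe each decoder profile for H.264 plus RGB32 output support.
GUID decoderProfile = GUID_NULL;
BOOL rgbSupport = FALSE;
UINT numProfiles = m_pD3DVideoDevice->GetVideoDecoderProfileCount();
for (UINT i = 0; i < numProfiles; i++)
{
    if (FAILED(m_pD3DVideoDevice->GetVideoDecoderProfile(i, &decoderProfile)))
        continue;
    if (decoderProfile != D3D11_DECODER_PROFILE_H264_VLD_NOFGT)
        continue;
    if (SUCCEEDED(m_pD3DVideoDevice->CheckVideoDecoderFormat(
            &decoderProfile, DXGI_FORMAT_B8G8R8X8_UNORM, &rgbSupport)) && rgbSupport)
    {
        OutputDebugString(L"supports RGB32!\n");
        break;
    }
}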

View file

@@ -184,8 +184,6 @@ HRESULT CPlayer::OpenURL(const WCHAR* sURL)
goto done;
}
// Set the topology on the media session.
hr = m_pSession->SetTopology(0, pTopology);
if (FAILED(hr))
@@ -295,12 +293,9 @@ HRESULT EnumerateCaptureFormats(IMFMediaSource* pSource)
GUID subtype = GUID_NULL;
CHECK_HR(hr = pType->GetGUID(MF_MT_SUBTYPE, &subtype));
if (subtype == MFVideoFormat_RGB24) {
//OutputDebugString(L"This device supports RGB!");
if (subtype == MFVideoFormat_RGB32) {
SetDeviceFormat(pSource, i);
//LogMediaType(pType);
//OutputDebugString(L"\n");
break;
}
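SetDeviceFormat is a helper defined elsewhere in this file; its usual shape (patterned on the Media Foundation capture-format sample, error handling trimmed) is roughly:

// Sketch: select capture format dwFormatIndex on the source's first stream.
HRESULT SetDeviceFormat(IMFMediaSource* pSource, DWORD dwFormatIndex)
{
    IMFPresentationDescriptor* pPD = NULL;
    IMFStreamDescriptor* pSD = NULL;
    IMFMediaTypeHandler* pHandler = NULL;
    IMFMediaType* pType = NULL;
    BOOL fSelected = FALSE;

    HRESULT hr = pSource->CreatePresentationDescriptor(&pPD);
    if (SUCCEEDED(hr)) hr = pPD->GetStreamDescriptorByIndex(0, &fSelected, &pSD);
    if (SUCCEEDED(hr)) hr = pSD->GetMediaTypeHandler(&pHandler);
    if (SUCCEEDED(hr)) hr = pHandler->GetMediaTypeByIndex(dwFormatIndex, &pType);
    if (SUCCEEDED(hr)) hr = pHandler->SetCurrentMediaType(pType);

    SafeRelease(&pPD);
    SafeRelease(&pSD);
    SafeRelease(&pHandler);
    SafeRelease(&pType);
    return hr;
}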
@@ -319,7 +314,7 @@ done:
}
// Open a URL for playback.
// Start streaming playback
HRESULT CPlayer::StartStream()
{
// 1. Create a new media session.
@@ -362,11 +357,6 @@ HRESULT CPlayer::StartStream()
goto done;
}
/*IMFTopoLoader* loader;
MFCreateTopoLoader(&loader);
IMFTopology* pOut;
loader->Load(pTopology, &pOut, NULL);*/
// Set the topology on the media session.
hr = m_pSession->SetTopology(0, pTopology);
if (FAILED(hr))
@@ -564,9 +554,6 @@ HRESULT CPlayer::HandleEvent(UINT_PTR pEventPtr)
case MESessionTopologyStatus:
hr = OnTopologyStatus(pEvent);
break;
case MESessionTopologySet:
// hr = OnTopologySet(pEvent);
break;
case MEEndOfPresentation:
hr = OnPresentationEnded(pEvent);
break;
@@ -680,71 +667,6 @@ done:
return S_OK;
}
HRESULT CPlayer::OnTopologySet(IMFMediaEvent* pEvent) {
IDirect3DDeviceManager9* man = NULL;
HRESULT hr = S_OK;
IMFTransform* pMFT = NULL;
// Query the topology nodes for 1) the video renderer service and 2) a D3D-aware MFT -> give them the D3D9 manager and maybe set MF_TOPONODE_D3DAWARE
// The topology has to be queued to the media session before the D3D manager can be found
hr = MFGetService(m_pSession, MR_VIDEO_ACCELERATION_SERVICE, IID_IDirect3DDeviceManager9, (void**)&man);
if (hr == S_OK)
{
OutputDebugString(L"Found the d3d9 manager");
PROPVARIANT var;
IMFTopology* pTopology = NULL;
WORD pNumNodes = 0;
IMFTopologyNode* pNode = NULL;
UINT32 aware = FALSE;
PropVariantInit(&var);
hr = pEvent->GetValue(&var);
if (SUCCEEDED(hr))
{
if (var.vt != VT_UNKNOWN)
{
hr = E_UNEXPECTED;
}
}
if (SUCCEEDED(hr))
{
hr = var.punkVal->QueryInterface(__uuidof(IMFTopology), (void**)&pTopology);
}
PropVariantClear(&var);
//m_pSession->GetFullTopology(MFSESSION_GETFULLTOPOLOGY_CURRENT, NULL, &pTopology);
CHECK_HR(hr = pTopology->GetNodeCount(&pNumNodes));
MF_TOPOLOGY_TYPE pType;
for (WORD i = 0; i < pNumNodes; i++) {
pTopology->GetNode(i, &pNode);
// TODO: Instantiate outside loop?
pNode->GetNodeType(&pType);
if (pType != NULL && pType == MF_TOPOLOGY_TRANSFORM_NODE)
{
IMFAttributes* pAttr = NULL;
// Get the underlying MFT
CHECK_HR(hr = pNode->GetObject((IUnknown**)&pMFT));
pMFT->GetAttributes(&pAttr);
// UINT32 p_aware = FALSE;
aware = MFGetAttributeUINT32(pAttr, MF_SA_D3D_AWARE, FALSE);
if (aware) {
pMFT->ProcessMessage(MFT_MESSAGE_SET_D3D_MANAGER, (ULONG_PTR)man);
break;
}
}
}
// TODO: Can we cache the MFT so we can add this?
}
done:
// TODO: Release d3dManager?
SafeRelease(&pMFT);
return hr;
}
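The deleted handler above located a D3D-aware MFT in the topology and handed it the IDirect3DDeviceManager9. The surviving AddTransformNode path takes an IMFDXGIDeviceManager instead; the equivalent handoff is sketched below (pMFT, pAttr, and pDXGIManager are assumed valid; the names are illustrative):

// A D3D11-aware MFT advertises MF_SA_D3D11_AWARE on its attribute store.
UINT32 aware = MFGetAttributeUINT32(pAttr, MF_SA_D3D11_AWARE, FALSE);
if (aware)
{
    // Hand the DXGI device manager to the MFT so it can share the D3D11 device.
    hr = pMFT->ProcessMessage(MFT_MESSAGE_SET_D3D_MANAGER,
                              reinterpret_cast<ULONG_PTR>(pDXGIManager));
}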
// Create a new instance of the media session.
HRESULT CPlayer::CreateSession()
{
@@ -960,7 +882,6 @@ done:
// Create an activation object for a renderer, based on the stream media type.
HRESULT CreateMediaSinkActivate(
IMFStreamDescriptor* pSourceSD, // Pointer to the stream descriptor.
HWND hVideoWindow, // Handle to the video clipping window.
@@ -1072,7 +993,6 @@ done:
// BindOutputNode
// Sets the IMFStreamSink pointer on an output node.
HRESULT BindOutputNode(IMFTopologyNode* pNode)
{
IUnknown* pNodeObject = NULL;
@@ -1203,7 +1123,6 @@ done:
HRESULT AddTransformNode(
IMFTopology* pTopology, // Topology.
IMFDXGIDeviceManager* d3dManager,
//const CLSID& clsid, // CLSID of the MFT.
IMFTopologyNode** ppNode // Receives the node pointer.
)
{
@@ -1218,7 +1137,7 @@ HRESULT AddTransformNode(
// Create the node.
hr = MFCreateTopologyNode(MF_TOPOLOGY_TRANSFORM_NODE, &pNode);
// Set the CLSID attribute.
// Set the object of the node to the MFT
if (SUCCEEDED(hr))
{
hr = pNode->SetObject(pMFT);
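For context, the usual remainder of the transform-node setup after SetObject, sketched under the assumption that pTopology and pNode are the variables above:

if (SUCCEEDED(hr))
{
    // Tell the topology loader the node handles its own D3D device sharing.
    hr = pNode->SetUINT32(MF_TOPONODE_D3DAWARE, TRUE);
}
if (SUCCEEDED(hr))
{
    // Add the node to the topology so it can be wired into the branch.
    hr = pTopology->AddNode(pNode);
}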
@@ -1310,7 +1229,6 @@ HRESULT AddBranchToPartialTopology(
goto done;
}
// Check if a video stream, then create transform node.
// Get the media type handler for the stream.
CHECK_HR(hr = pSD->GetMediaTypeHandler(&pHandler));
@@ -1363,9 +1281,6 @@ done:
return hr;
}
// Create a playback topology from a media source.
HRESULT CreatePlaybackTopology(
IMFMediaSource* pSource, // Media source.

View file

@@ -30,8 +30,6 @@
#include "resource.h"
template <class T> void SafeRelease(T** ppT)
{
if (*ppT)
@@ -107,7 +105,6 @@ protected:
virtual HRESULT OnTopologyStatus(IMFMediaEvent* pEvent);
virtual HRESULT OnPresentationEnded(IMFMediaEvent* pEvent);
virtual HRESULT OnNewPresentation(IMFMediaEvent* pEvent);
virtual HRESULT OnTopologySet(IMFMediaEvent* pEvent);
// Override to handle additional session events.
virtual HRESULT OnSessionEvent(IMFMediaEvent*, MediaEventType)
@@ -115,18 +112,11 @@ protected:
return S_OK;
}
/*HRESULT AddTransformNode(
IMFTopology* pTopology, // Topology.
const CLSID& clsid, // CLSID of the MFT.
IMFTopologyNode** ppNode // Receives the node pointer.
);*/
protected:
long m_nRefCount; // Reference count.
IMFMediaSession* m_pSession;
IMFMediaSource* m_pSource;
IMFMediaSession* m_pSession;
IMFMediaSource* m_pSource;
IMFVideoDisplayControl* m_pVideoDisplay;
HWND m_hwndVideo; // Video window.