59 строки
2.0 KiB
C++
59 строки
2.0 KiB
C++
//
|
|
// <copyright file="PTaskComputationNetwork.h" company="Microsoft">
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// </copyright>
|
|
//
|
|
#pragma once
|
|
|
|
#include <string>
|
|
#include "ComputationNetwork.h"
|
|
|
|
#include "PTask.h"
|
|
#include "PTaskGraphBuilder.h"
|
|
|
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
|
|
|
template<class ElemType>
|
|
class PTaskComputationNetwork : public ComputationNetwork<ElemType>
|
|
{
|
|
public:
|
|
PTaskComputationNetwork(short deviceId=AUTOPLACEMATRIX)
|
|
: ComputationNetwork<ElemType>(deviceId)
|
|
{
|
|
m_PTaskGraphBuilder = new PTaskGraphBuilder<ElemType>();
|
|
}
|
|
|
|
virtual ~PTaskComputationNetwork() { }
|
|
|
|
virtual void LoadFromFile(const std::wstring& fileName, FileOptions fileFormat = FileOptions::fileOptionsBinary) override
|
|
{
|
|
// Let the base class implementation deserialize all the state
|
|
// and construct its regular CN ...
|
|
this->ComputationNetwork<ElemType>::LoadFromFile(fileName, fileFormat);
|
|
|
|
// ... then use that state to create the corresponding PTask graph.
|
|
m_PTaskGraphBuilder->BuildFromComputationNetwork(this);
|
|
}
|
|
|
|
virtual void ComputeGradient(ComputationNodePtr rootNode) override
|
|
{
|
|
//printf("PTaskComputationNetwork::ComputeGradient called.\n");
|
|
this->ComputationNetwork<ElemType>::ComputeGradient(rootNode);
|
|
}
|
|
|
|
private:
|
|
// Copy constructor, should never be called.
|
|
PTaskComputationNetwork(const PTaskComputationNetwork<ElemType>& deepCopyFrom) {};
|
|
|
|
// Assignment operator, should never be called.
|
|
PTaskComputationNetwork<ElemType>& operator=(const PTaskComputationNetwork<ElemType>& deepCopyFrom)
|
|
{
|
|
assert(false);
|
|
return const_cast<PTaskComputationNetwork<ElemType>&>(deepCopyFrom); // return a value to avoid compile errors
|
|
};
|
|
|
|
PTaskGraphBuilder<ElemType>* m_PTaskGraphBuilder;
|
|
};
|
|
|
|
}}}
|