274 строки
10 KiB
C++
274 строки
10 KiB
C++
//
|
|
// <copyright file="FullNetworkBuilder.h" company="Microsoft">
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// </copyright>
|
|
//
|
|
#pragma once
|
|
|
|
#include "ComputationNetwork.h"
|
|
#include "IComputationNetBuilder.h"
|
|
#include "basetypes.h"
|
|
#include <string>
|
|
#include "commandArgUtil.h"
|
|
#include "DataReader.h"
|
|
#include "matrix.h"
|
|
#include <stdexcept>
|
|
|
|
using namespace std;
|
|
|
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
|
|
|
template<class ElemType>
|
|
class FullNetworkBuilder : public IComputationNetBuilder<ElemType>
|
|
{
|
|
typedef ComputationNode<ElemType>* ComputationNodePtr;
|
|
|
|
public:
|
|
FullNetworkBuilder() {} // empty constructor, call Init immediately hereafter
|
|
FullNetworkBuilder(const ConfigParameters& config)
|
|
{
|
|
Init(config);
|
|
}
|
|
|
|
void Init(const std::wstring& networkConfig, short deviceId=AUTOPLACEMATRIX, const bool uniformInit = true, const ElemType initValueScale = 1.0f)
|
|
{
|
|
m_deviceId=deviceId;
|
|
m_networkConfig=networkConfig;
|
|
m_uniformInit=uniformInit;
|
|
m_initValueScale=initValueScale;
|
|
|
|
if (m_deviceId == AUTOPLACEMATRIX)
|
|
m_deviceId = Matrix<ElemType>::GetBestGPUDeviceId();
|
|
|
|
m_net.SetDeviceID(m_deviceId);
|
|
if (m_deviceId < 0)
|
|
fprintf(stderr,"FullNetworkBuilder Using CPU\n");
|
|
else
|
|
fprintf(stderr,"FullNetworkBuilder Using GPU %d\n", m_deviceId);
|
|
}
|
|
|
|
// Init - Builder Initialize for multiple data sets
|
|
// config - [in] configuration parameters for the network builder
|
|
virtual void Init(const ConfigParameters& config)
|
|
{
|
|
const std::wstring& networkConfig = config("networkDescriptionPath");
|
|
short deviceId = config("deviceId", "1000");//AUTOPLACEMATRIX);
|
|
bool uniformInit = config("uniformInit", "true");
|
|
ElemType initValueScale=config("initValueScale", "1.0");
|
|
|
|
Init(networkConfig, deviceId, uniformInit, initValueScale);
|
|
}
|
|
|
|
virtual ~FullNetworkBuilder()
|
|
{}
|
|
virtual ComputationNetwork<ElemType>& LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true)
|
|
{
|
|
if (m_net.GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
|
|
m_net.LoadFromFile(modelFileName);
|
|
|
|
m_net.ResetEvalTimeStamp();
|
|
return m_net;
|
|
}
|
|
|
|
ComputationNetwork<ElemType>& LoadNetworkFromConfig(const wstring& configFileName, bool forceLoad = true)
|
|
{
|
|
if (m_net.GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
|
|
LoadFromConfig(configFileName);
|
|
|
|
m_net.ResetEvalTimeStamp();
|
|
return m_net;
|
|
}
|
|
|
|
void SaveNetworkToConfig(const wstring& configFileName)
|
|
{
|
|
m_net.SaveToConfig(configFileName);
|
|
}
|
|
|
|
void LoadFromConfig(const std::wstring& fileName)
|
|
{
|
|
ConfigParameters configCN;
|
|
m_net.ClearNet();
|
|
configCN.LoadConfigFile(fileName);
|
|
ULONG randomSeed = 1;
|
|
|
|
std::map<std::wstring, ComputationNodePtr> origNameToNormNameMap;
|
|
|
|
ConfigParameters configNodes = configCN("NodeList");
|
|
for(auto iter = configNodes.begin(); iter != configNodes.end(); iter++)
|
|
{
|
|
std::wstring nodeName;
|
|
nodeName = msra::strfun::utf16(iter->first);
|
|
ConfigArray configNode = iter->second;
|
|
std::wstring opName = configNode[0];
|
|
|
|
if (opName == L"Input" || opName == InputValue<ElemType>::TypeName())
|
|
{
|
|
size_t rows = configNode.size() >= 2? configNode[1] : 1;
|
|
size_t cols = configNode.size() >= 3? configNode[2] : 1;
|
|
|
|
ComputationNodePtr input = m_net.Input(rows, cols, nodeName);
|
|
|
|
if (configNode.size()>=4)
|
|
{
|
|
std::wstring cmd = configNode[3];
|
|
transform(cmd.begin(), cmd.end(), cmd.begin(),tolower);
|
|
|
|
if (cmd == L"mvnorm") // do mean/var normalization
|
|
{
|
|
ComputationNodePtr meanPtr = m_net.Mean(input, nodeName + L"Mean");
|
|
ComputationNodePtr varPtr = m_net.InvStdDev(input, nodeName + L"InvStdDev");
|
|
ComputationNodePtr normPtr = m_net.PerDimMeanVarNormalization(input, meanPtr, varPtr, nodeName + L"MVNorm");
|
|
|
|
origNameToNormNameMap[nodeName] = normPtr;
|
|
}
|
|
else
|
|
throw new runtime_error("for InputValue nodes, only mvNorm is supported 4th argument\n");
|
|
}
|
|
}
|
|
else if (opName == L"Parameter" || opName == LearnableParameter<ElemType>::TypeName())
|
|
{
|
|
size_t rows = configNode.size() >= 2? configNode[1] : 1;
|
|
size_t cols = configNode.size() >= 3? configNode[2] : 1;
|
|
|
|
bool needGradient = false;
|
|
bool init = false;
|
|
ConfigArray initData;
|
|
|
|
// look for optional parameters
|
|
for (int i = 3; i < configNode.size(); ++i)
|
|
{
|
|
ConfigParameters configParam = configNode[i];
|
|
if (configParam.Exists("needGradient"))
|
|
needGradient = true;
|
|
else if (configParam.Exists("init"))
|
|
{
|
|
init = true;
|
|
initData = configParam["init"];
|
|
}
|
|
}
|
|
ComputationNodePtr nodePtr = m_net.Parameter(rows, cols, nodeName);
|
|
nodePtr->NeedGradient() = needGradient;
|
|
if (init)
|
|
{
|
|
m_net.InitLearnableParameters(nodePtr, m_uniformInit, randomSeed++, m_initValueScale);
|
|
}
|
|
}
|
|
else if (opName==L"Constant")
|
|
{
|
|
size_t rows = configNode.size() >= 2? configNode[1] : 1;
|
|
size_t cols = configNode.size() >= 3? configNode[2] : 1;
|
|
|
|
bool init = false;
|
|
ElemType val = configNode.size() >= 4? configNode[3] : 0;
|
|
|
|
ComputationNodePtr nodePtr = m_net.Parameter(rows, cols, nodeName);
|
|
nodePtr->NeedGradient() = false;
|
|
nodePtr->FunctionValues().SetValue(val);
|
|
}
|
|
else
|
|
{
|
|
m_net.CreateComputationNode(opName, nodeName);
|
|
}
|
|
}
|
|
|
|
//now link up all the nodes
|
|
ConfigParameters configRelation = configCN("Relation");
|
|
for(ConfigParameters::iterator iter = configRelation.begin();
|
|
iter != configRelation.end(); iter++)
|
|
{
|
|
std::wstring nodeName = msra::strfun::utf16(iter->first);
|
|
ConfigArray configNode = iter->second;
|
|
ComputationNodePtr nodePtr = m_net.GetNodeFromName(nodeName);
|
|
vector<ComputationNodePtr> inputs;
|
|
size_t numChildren = configNode.size();
|
|
for (size_t i=0; i < numChildren;++i)
|
|
{
|
|
ComputationNodePtr cnp;
|
|
if (origNameToNormNameMap.find(configNode[i])==origNameToNormNameMap.end())
|
|
cnp = m_net.GetNodeFromName(configNode[i]);
|
|
else
|
|
cnp = origNameToNormNameMap[configNode[i]];
|
|
|
|
inputs.push_back(cnp);
|
|
}
|
|
|
|
switch (numChildren)
|
|
{
|
|
case 1:
|
|
nodePtr->AttachInputs(inputs[0]);
|
|
break;
|
|
case 2:
|
|
nodePtr->AttachInputs(inputs[0], inputs[1]);
|
|
break;
|
|
case 3:
|
|
nodePtr->AttachInputs(inputs[0], inputs[1], inputs[2]);
|
|
break;
|
|
default:
|
|
throw std::logic_error("Invalid number of children.");
|
|
}
|
|
}
|
|
|
|
ConfigParameters configRoots = configCN("RootNodes");
|
|
ConfigArray configNode = configRoots("FeatureNodes");
|
|
for (size_t i=0; i<configNode.size(); i++)
|
|
{
|
|
std::wstring nodeName = configNode[i];
|
|
m_net.FeatureNodes().push_back(m_net.GetNodeFromName(nodeName));
|
|
}
|
|
|
|
configNode = configRoots("LabelNodes");
|
|
for (size_t i=0; i<configNode.size(); i++)
|
|
{
|
|
std::wstring nodeName = configNode[i];
|
|
m_net.LabelNodes().push_back(m_net.GetNodeFromName(nodeName));
|
|
}
|
|
|
|
configNode = configRoots("CriteriaNodes");
|
|
for (size_t i=0; i<configNode.size(); i++)
|
|
{
|
|
std::wstring nodeName = configNode[i];
|
|
m_net.FinalCriterionNodes().push_back(m_net.GetNodeFromName(nodeName));
|
|
}
|
|
|
|
configNode = configRoots("EvalNodes");
|
|
for (size_t i=0; i<configNode.size(); i++)
|
|
{
|
|
std::wstring nodeName = configNode[i];
|
|
m_net.EvaluationNodes().push_back(m_net.GetNodeFromName(nodeName));
|
|
}
|
|
|
|
configNode = configRoots("OutputNodes");
|
|
for (size_t i=0; i<configNode.size(); i++)
|
|
{
|
|
std::wstring nodeName = configNode[i];
|
|
m_net.OutputNodes().push_back(m_net.GetNodeFromName(nodeName));
|
|
}
|
|
|
|
m_net.PrintComputationTree(m_net.FinalCriterionNodes()[0],true,false);
|
|
}
|
|
|
|
virtual ComputationNetwork<ElemType>& BuildNetworkFromDescription()
|
|
{
|
|
if (m_net.GetTotalNumberOfNodes() < 1) //not built yet
|
|
{
|
|
ULONG randomSeed = 1;
|
|
|
|
LoadNetworkFromConfig(m_networkConfig);
|
|
}
|
|
|
|
m_net.ResetEvalTimeStamp();
|
|
return m_net;
|
|
}
|
|
|
|
|
|
private:
|
|
ComputationNetwork<ElemType> m_net;
|
|
std::wstring m_networkConfig;
|
|
bool m_uniformInit;
|
|
ElemType m_initValueScale;
|
|
|
|
short m_deviceId;
|
|
};
|
|
template class FullNetworkBuilder<float>;
|
|
template class FullNetworkBuilder<double>;
|
|
}}} |