This commit is contained in:
Gaizka Navarro 2016-02-17 11:14:45 +01:00
Родитель 73520b3160
Коммит e3dbc9cc83
1 изменённых файлов: 12 добавлений и 13 удалений

Просмотреть файл

@ -353,7 +353,16 @@ void ComputationNetwork::Read(const wstring& fileName)
} }
// save network to legacy DBN.exe format // save network to legacy DBN.exe format
class DbnLayer; class DbnLayer
{
public:
DbnLayer() : Node(nullptr), Bias(nullptr), Sigmoided(false) {}
ComputationNodeBasePtr Node;
ComputationNodeBasePtr Bias;
bool Sigmoided;
~DbnLayer() {};
};
template <class ElemType> template <class ElemType>
void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const
{ {
@ -450,7 +459,7 @@ void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wst
{ {
if (item == input) if (item == input)
{ {
RuntimeError("Cyclic dependency on node '%s'", item->GetName().c_str()); RuntimeError("Cyclic dependency on node '%ls'", item->GetName().c_str());
} }
} }
@ -544,7 +553,7 @@ void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wst
std::vector<ComputationNodeBasePtr> normalizationNodes = GetNodesWithType(net->GetAllNodes(), OperationNameOf(PerDimMeanVarNormalizationNode)); std::vector<ComputationNodeBasePtr> normalizationNodes = GetNodesWithType(net->GetAllNodes(), OperationNameOf(PerDimMeanVarNormalizationNode));
if (normalizationNodes.size() == 0) if (normalizationNodes.size() == 0)
{ {
RuntimeError("Model does not contain at least one node with the '%s' operation.", OperationNameOf(PerDimMeanVarNormalizationNode).c_str()); RuntimeError("Model does not contain at least one node with the '%ls' operation.", OperationNameOf(PerDimMeanVarNormalizationNode).c_str());
} }
ComputationNodeBasePtr meanNode = normalizationNodes.front()->GetInputs()[1]; ComputationNodeBasePtr meanNode = normalizationNodes.front()->GetInputs()[1];
@ -1333,14 +1342,4 @@ template void ComputationNetwork::SaveToDbnFile<double>(ComputationNetworkPtr ne
// register ComputationNetwork with the ScriptableObject system // register ComputationNetwork with the ScriptableObject system
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<ComputationNetwork> registerComputationNetwork(L"ComputationNetwork"); ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<ComputationNetwork> registerComputationNetwork(L"ComputationNetwork");
class DbnLayer
{
public:
DbnLayer() : Sigmoided(false), Node(nullptr), Bias(nullptr) {}
ComputationNodeBasePtr Node;
ComputationNodeBasePtr Bias;
bool Sigmoided;
};
} } } } } }