cleanedup all case-insensitive comparisons
This commit is contained in:
Родитель
f4549625d1
Коммит
16175a17c5
|
@ -14,6 +14,7 @@
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// EqualInsensitive - check to see if two nodes are equal up to the length of the first string (must be at least half as long as actual node name)
|
||||
// TODO: Allowing partial matches seems misguided. We should discourage that, or just remove it.
|
||||
// string1 - [in,out] string to compare, if comparision is equal insensitive but not sensitive, will replace with sensitive version
|
||||
// string2 - second string to compare
|
||||
// alternate - alternate naming of the string
|
||||
|
|
|
@ -310,7 +310,7 @@ public:
|
|||
for (NDLNode* param : m_parameters)
|
||||
{
|
||||
bool optParam = param->GetType() == ndlTypeOptionalParameter;
|
||||
if (optParam && !_stricmp(param->GetName().c_str(), name.c_str()))
|
||||
if (optParam && EqualCI(param->GetName().c_str(), name.c_str()))
|
||||
{
|
||||
auto paramValue = param->GetValue();
|
||||
auto resolveParamNode = m_parent->ParseVariable(paramValue, false);
|
||||
|
|
|
@ -1771,27 +1771,27 @@ template class SimpleNetworkBuilder<double>;
|
|||
|
||||
TrainingCriterion ParseTrainingCriterionString(wstring s)
|
||||
{
|
||||
if (!_wcsicmp(s.c_str(), L"crossEntropyWithSoftmax")) return TrainingCriterion::CrossEntropyWithSoftmax;
|
||||
else if (!_wcsicmp(s.c_str(), L"sequenceWithSoftmax")) return TrainingCriterion::SequenceWithSoftmax;
|
||||
else if (!_wcsicmp(s.c_str(), L"squareError")) return TrainingCriterion::SquareError;
|
||||
else if (!_wcsicmp(s.c_str(), L"logistic")) return TrainingCriterion::Logistic;
|
||||
else if (!_wcsicmp(s.c_str(), L"noiseContrastiveEstimation") || !_wcsicmp(s.c_str(), L"noiseContrastiveEstimationNode" /*spelling error, deprecated*/))
|
||||
if (EqualCI(s.c_str(), L"crossEntropyWithSoftmax")) return TrainingCriterion::CrossEntropyWithSoftmax;
|
||||
else if (EqualCI(s.c_str(), L"sequenceWithSoftmax")) return TrainingCriterion::SequenceWithSoftmax;
|
||||
else if (EqualCI(s.c_str(), L"squareError")) return TrainingCriterion::SquareError;
|
||||
else if (EqualCI(s.c_str(), L"logistic")) return TrainingCriterion::Logistic;
|
||||
else if (EqualCI(s.c_str(), L"noiseContrastiveEstimation") || EqualCI(s.c_str(), L"noiseContrastiveEstimationNode" /*spelling error, deprecated*/))
|
||||
return TrainingCriterion::NCECrossEntropyWithSoftmax;
|
||||
else if (!!_wcsicmp(s.c_str(), L"classCrossEntropyWithSoftmax")) // (twisted logic to keep compiler happy w.r.t. not returning from LogicError)
|
||||
else if (!EqualCI(s.c_str(), L"classCrossEntropyWithSoftmax")) // (twisted logic to keep compiler happy w.r.t. not returning from LogicError)
|
||||
LogicError("trainingCriterion: Invalid trainingCriterion value. Valid values are (crossEntropyWithSoftmax | squareError | logistic | classCrossEntropyWithSoftmax| sequenceWithSoftmax)");
|
||||
return TrainingCriterion::ClassCrossEntropyWithSoftmax;
|
||||
}
|
||||
|
||||
EvalCriterion ParseEvalCriterionString(wstring s)
|
||||
{
|
||||
if (!_wcsicmp(s.c_str(), L"errorPrediction")) return EvalCriterion::ErrorPrediction;
|
||||
else if (!_wcsicmp(s.c_str(), L"crossEntropyWithSoftmax")) return EvalCriterion::CrossEntropyWithSoftmax;
|
||||
else if (!_wcsicmp(s.c_str(), L"sequenceWithSoftmax")) return EvalCriterion::SequenceWithSoftmax;
|
||||
else if (!_wcsicmp(s.c_str(), L"classCrossEntropyWithSoftmax")) return EvalCriterion::ClassCrossEntropyWithSoftmax;
|
||||
else if (!_wcsicmp(s.c_str(), L"logistic")) return EvalCriterion::Logistic;
|
||||
else if (!_wcsicmp(s.c_str(), L"noiseContrastiveEstimation") || !_wcsicmp(s.c_str(), L"noiseContrastiveEstimationNode" /*spelling error, deprecated*/))
|
||||
if (EqualCI(s.c_str(), L"errorPrediction")) return EvalCriterion::ErrorPrediction;
|
||||
else if (EqualCI(s.c_str(), L"crossEntropyWithSoftmax")) return EvalCriterion::CrossEntropyWithSoftmax;
|
||||
else if (EqualCI(s.c_str(), L"sequenceWithSoftmax")) return EvalCriterion::SequenceWithSoftmax;
|
||||
else if (EqualCI(s.c_str(), L"classCrossEntropyWithSoftmax")) return EvalCriterion::ClassCrossEntropyWithSoftmax;
|
||||
else if (EqualCI(s.c_str(), L"logistic")) return EvalCriterion::Logistic;
|
||||
else if (EqualCI(s.c_str(), L"noiseContrastiveEstimation") || EqualCI(s.c_str(), L"noiseContrastiveEstimationNode" /*spelling error, deprecated*/))
|
||||
return EvalCriterion::NCECrossEntropyWithSoftmax;
|
||||
else if (!!_wcsicmp(s.c_str(), L"squareError"))
|
||||
else if (!EqualCI(s.c_str(), L"squareError"))
|
||||
LogicError("evalCriterion: Invalid trainingCriterion value. Valid values are (errorPrediction | crossEntropyWithSoftmax | squareError | logistic | sequenceWithSoftmax)");
|
||||
return EvalCriterion::SquareError;
|
||||
}
|
||||
|
|
|
@ -137,13 +137,13 @@ void SynchronousNodeEvaluator<ElemType>::Evaluate(NDLNode<ElemType>* node, const
|
|||
bool initOnCPUOnly = node->GetOptionalParameter("initOnCPUOnly", "false");
|
||||
int forcedRandomSeed = node->GetOptionalParameter("randomSeed", "-1" /*disabled*/);
|
||||
|
||||
if (!_wcsicmp(initString.c_str(), L"fixedValue"))
|
||||
if (EqualCI(initString.c_str(), L"fixedValue"))
|
||||
nodePtr->Value().SetValue(value);
|
||||
else if (!_wcsicmp(initString.c_str(), L"uniform"))
|
||||
else if (EqualCI(initString.c_str(), L"uniform"))
|
||||
m_net->InitLearnableParameters(nodePtr, true, forcedRandomSeed < 0 ? randomSeed++ : (unsigned long) forcedRandomSeed, initValueScale, initOnCPUOnly);
|
||||
else if (!_wcsicmp(initString.c_str(), L"gaussian"))
|
||||
else if (EqualCI(initString.c_str(), L"gaussian"))
|
||||
m_net->InitLearnableParameters(nodePtr, false, forcedRandomSeed < 0 ? randomSeed++ : (unsigned long) forcedRandomSeed, initValueScale, initOnCPUOnly);
|
||||
else if (!_wcsicmp(initString.c_str(), L"fromFile"))
|
||||
else if (EqualCI(initString.c_str(), L"fromFile"))
|
||||
{
|
||||
std::string initFromFilePath = node->GetOptionalParameter("initFromFilePath", "");
|
||||
if (initFromFilePath == "")
|
||||
|
|
|
@ -269,35 +269,19 @@ public:
|
|||
// loop through all the optional parameters processing them as necessary
|
||||
for (NDLNode<ElemType>* param : params)
|
||||
{
|
||||
// make sure it's a "tag" optional parameter, that's all we process currently
|
||||
if (_stricmp(param->GetName().c_str(), "tag"))
|
||||
// we only process the "tag" optional parameter for now
|
||||
if (!EqualCI(param->GetName().c_str(), "tag"))
|
||||
continue;
|
||||
|
||||
std::string value = param->GetValue();
|
||||
if (!_stricmp(value.c_str(), "feature"))
|
||||
{
|
||||
SetOutputNode(m_net->FeatureNodes(), compNode);
|
||||
}
|
||||
else if (!_stricmp(value.c_str(), "label"))
|
||||
{
|
||||
SetOutputNode(m_net->LabelNodes(), compNode);
|
||||
}
|
||||
else if (!_stricmp(value.c_str(), "criterion") || !_stricmp(value.c_str(), "criteria"))
|
||||
{
|
||||
SetOutputNode(m_net->FinalCriterionNodes(), compNode);
|
||||
}
|
||||
else if (!_stricmp(value.c_str(), "multiSeq"))
|
||||
{
|
||||
fprintf(stderr, "'multiSeq' tag is defunct.\n");
|
||||
}
|
||||
else if (!_strnicmp(value.c_str(), "eval", 4)) // only compare the first 4 characters. Yikes!!
|
||||
{
|
||||
SetOutputNode(m_net->EvaluationNodes(), compNode);
|
||||
}
|
||||
else if (!_stricmp(value.c_str(), "output"))
|
||||
{
|
||||
SetOutputNode(m_net->OutputNodes(), compNode);
|
||||
}
|
||||
if (EqualCI(value.c_str(), "feature")) SetOutputNode(m_net->FeatureNodes(), compNode);
|
||||
else if (EqualCI(value.c_str(), "label")) SetOutputNode(m_net->LabelNodes(), compNode);
|
||||
else if (EqualCI(value.c_str(), "criterion")) SetOutputNode(m_net->FinalCriterionNodes(), compNode);
|
||||
else if (!_strnicmp(value.c_str(), "eval", 4)) SetOutputNode(m_net->EvaluationNodes(), compNode); // only compare the first 4 characters. Yikes!!
|
||||
else if (EqualCI(value.c_str(), "output")) SetOutputNode(m_net->OutputNodes(), compNode);
|
||||
// legacy
|
||||
else if (EqualCI(value.c_str(), "criteria")) SetOutputNode(m_net->FinalCriterionNodes(), compNode); // legacy (mis-spelled)
|
||||
else if (EqualCI(value.c_str(), "multiSeq")) fprintf(stderr, "'multiSeq' tag is defunct.\n");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -284,7 +284,7 @@ void TestConfiguration(const ConfigParameters& configBase)
|
|||
if (initData.size() > 0)
|
||||
initValueScale = initData[0];
|
||||
if (initData.size() > 1)
|
||||
uniform = !_stricmp(initData[1], "uniform");
|
||||
uniform = EqualCI(initData[1], "uniform");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -188,9 +188,9 @@ DEVICEID_TYPE DeviceFromConfig(const ConfigParameters& config)
|
|||
ConfigValue val = config("deviceId", "auto");
|
||||
bool bLockGPU = config(L"lockGPU", true);
|
||||
|
||||
if (!_stricmp(val.c_str(), "cpu"))
|
||||
if (EqualCI(val.c_str(), "cpu"))
|
||||
return SelectDevice(CPUDEVICE, false);
|
||||
else if (!_stricmp(val.c_str(), "auto"))
|
||||
else if (EqualCI(val.c_str(), "auto"))
|
||||
return SelectDevice(DEVICEID_AUTO, bLockGPU);
|
||||
else
|
||||
return SelectDevice((int) val, bLockGPU);
|
||||
|
|
|
@ -1158,7 +1158,7 @@ public:
|
|||
bool Match(const std::string& key, const std::string& compareValue) const
|
||||
{
|
||||
std::string value = Find(key);
|
||||
return !_stricmp(compareValue.c_str(), value.c_str());
|
||||
return EqualCI(compareValue.c_str(), value.c_str());
|
||||
}
|
||||
bool Match(const std::wstring& key, const std::wstring& compareValue) const
|
||||
{
|
||||
|
|
|
@ -866,7 +866,7 @@ inline bool IConfigRecord::Match(const std::wstring &id, const std::wstring &com
|
|||
{
|
||||
auto *valp = Find(id);
|
||||
std::wstring val = valp ? *valp : std::wstring();
|
||||
return !_wcsicmp(compareValue.c_str(), val.c_str());
|
||||
return EqualCI(compareValue.c_str(), val.c_str());
|
||||
}
|
||||
inline const std::string IConfigRecord::ConfigName() const
|
||||
{
|
||||
|
|
|
@ -1158,39 +1158,39 @@ void SectionStats::Store()
|
|||
for (int i = 0; i < GetElementCount(); i++)
|
||||
{
|
||||
auto stat = GetElement<NumericStatistics>(i);
|
||||
if (!_stricmp(stat->statistic, "sum"))
|
||||
if (EqualCI(stat->statistic, "sum"))
|
||||
{
|
||||
stat->value = m_sum;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "count"))
|
||||
else if (EqualCI(stat->statistic, "count"))
|
||||
{
|
||||
stat->value = (double) m_count;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "mean"))
|
||||
else if (EqualCI(stat->statistic, "mean"))
|
||||
{
|
||||
stat->value = m_mean;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "max"))
|
||||
else if (EqualCI(stat->statistic, "max"))
|
||||
{
|
||||
stat->value = m_max;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "min"))
|
||||
else if (EqualCI(stat->statistic, "min"))
|
||||
{
|
||||
stat->value = m_min;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "range"))
|
||||
else if (EqualCI(stat->statistic, "range"))
|
||||
{
|
||||
stat->value = abs(m_max - m_min);
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "rootmeansquare"))
|
||||
else if (EqualCI(stat->statistic, "rootmeansquare"))
|
||||
{
|
||||
stat->value = m_rms;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "variance"))
|
||||
else if (EqualCI(stat->statistic, "variance"))
|
||||
{
|
||||
stat->value = m_variance;
|
||||
}
|
||||
else if (!_stricmp(stat->statistic, "stddev"))
|
||||
else if (EqualCI(stat->statistic, "stddev"))
|
||||
{
|
||||
stat->value = m_stddev;
|
||||
}
|
||||
|
@ -1255,7 +1255,7 @@ void SectionStats::SetCompute(const std::string& name, double value)
|
|||
for (int i = 0; i < GetElementCount(); i++)
|
||||
{
|
||||
auto stat = GetElement<NumericStatistics>(i);
|
||||
if (!_stricmp(stat->statistic, name.c_str()))
|
||||
if (EqualCI(stat->statistic, name.c_str()))
|
||||
{
|
||||
stat->value = value;
|
||||
break;
|
||||
|
@ -1271,7 +1271,7 @@ double SectionStats::GetCompute(const std::string& name)
|
|||
for (int i = 0; i < GetElementCount(); i++)
|
||||
{
|
||||
auto stat = GetElement<NumericStatistics>(i);
|
||||
if (!_stricmp(stat->statistic, name.c_str()))
|
||||
if (EqualCI(stat->statistic, name.c_str()))
|
||||
{
|
||||
return stat->value;
|
||||
}
|
||||
|
|
|
@ -126,7 +126,7 @@ void BinaryReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig
|
|||
|
||||
// determine if partial minibatches are desired
|
||||
std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
|
||||
m_partialMinibatch = !_stricmp(minibatchMode.c_str(), "Partial");
|
||||
m_partialMinibatch = EqualCI(minibatchMode.c_str(), "Partial");
|
||||
|
||||
// Initial load is complete
|
||||
DisplayProperties();
|
||||
|
|
|
@ -130,7 +130,7 @@ Section* BinaryWriter<ElemType>::CreateSection(const ConfigParameters& config, S
|
|||
wstring type = config(L"sectionType");
|
||||
for (int i = 0; i < sectionTypeMax; i++)
|
||||
{
|
||||
if (!_wcsicmp(type.c_str(), SectionTypeStrings[i]))
|
||||
if (EqualCI(type.c_str(), SectionTypeStrings[i]))
|
||||
{
|
||||
foundType = SectionType(i);
|
||||
break;
|
||||
|
|
|
@ -197,7 +197,7 @@ void DSSMReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig)
|
|||
}
|
||||
|
||||
std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
|
||||
m_partialMinibatch = !_stricmp(minibatchMode.c_str(), "Partial");
|
||||
m_partialMinibatch = EqualCI(minibatchMode.c_str(), "Partial");
|
||||
|
||||
// Get the config parameters for query feature and doc feature
|
||||
ConfigParameters configFeaturesQuery = readerConfig(m_featuresNameQuery, "");
|
||||
|
|
|
@ -158,7 +158,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
m_featDims[i] = m_featDims[i] * (1 + numContextLeft[i] + numContextRight[i]);
|
||||
|
||||
wstring type = thisFeature(L"type", L"real");
|
||||
if (!_wcsicmp(type.c_str(), L"real"))
|
||||
if (EqualCI(type.c_str(), L"real"))
|
||||
{
|
||||
m_nameToTypeMap[featureNames[i]] = InputOutputTypes::real;
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
else
|
||||
type = (const wstring&) thisLabel(L"type", L"category"); // outputs should default to category
|
||||
|
||||
if (!_wcsicmp(type.c_str(), L"category"))
|
||||
if (EqualCI(type.c_str(), L"category"))
|
||||
m_nameToTypeMap[labelNames[i]] = InputOutputTypes::category;
|
||||
else
|
||||
InvalidArgument("label type must be 'category'");
|
||||
|
@ -289,9 +289,9 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
if (readerConfig.Exists(L"randomize"))
|
||||
{
|
||||
wstring randomizeString = readerConfig.CanBeString(L"randomize") ? readerConfig(L"randomize") : wstring();
|
||||
if (!_wcsicmp(randomizeString.c_str(), L"none"))
|
||||
if (EqualCI(randomizeString.c_str(), L"none"))
|
||||
randomize = randomizeNone;
|
||||
else if (!_wcsicmp(randomizeString.c_str(), L"auto"))
|
||||
else if (EqualCI(randomizeString.c_str(), L"auto"))
|
||||
randomize = randomizeAuto;
|
||||
else
|
||||
randomize = readerConfig(L"randomize");
|
||||
|
@ -302,7 +302,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
// determine if we partial minibatches are desired
|
||||
wstring minibatchMode(readerConfig(L"minibatchMode", L"partial"));
|
||||
m_partialMinibatch = !_wcsicmp(minibatchMode.c_str(), L"partial");
|
||||
m_partialMinibatch = EqualCI(minibatchMode.c_str(), L"partial");
|
||||
|
||||
// get the read method, defaults to "blockRandomize" other option is "rollingWindow"
|
||||
wstring readMethod(readerConfig(L"readMethod", L"blockRandomize"));
|
||||
|
@ -447,7 +447,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
labelsmulti.push_back(std::move(labels));
|
||||
}
|
||||
|
||||
if (!_wcsicmp(readMethod.c_str(), L"blockRandomize"))
|
||||
if (EqualCI(readMethod.c_str(), L"blockRandomize"))
|
||||
{
|
||||
// construct all the parameters we don't need, but need to be passed to the constructor...
|
||||
|
||||
|
@ -458,7 +458,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(infilesmulti, labelsmulti, m_featDims, m_labelDims, numContextLeft, numContextRight, randomize, *m_lattices, m_latticeMap, m_frameMode));
|
||||
m_frameSource->setverbosity(m_verbosity);
|
||||
}
|
||||
else if (!_wcsicmp(readMethod.c_str(), L"rollingWindow"))
|
||||
else if (EqualCI(readMethod.c_str(), L"rollingWindow"))
|
||||
{
|
||||
std::wstring pageFilePath;
|
||||
std::vector<std::wstring> pagePaths;
|
||||
|
@ -585,7 +585,7 @@ void HTKMLFReader<ElemType>::PrepareForWriting(const ConfigRecordType& readerCon
|
|||
realDims[i] = realDims[i] * (1 + numContextLeft[i] + numContextRight[i]);
|
||||
|
||||
wstring type = thisFeature(L"type", L"real");
|
||||
if (!_wcsicmp(type.c_str(), L"real"))
|
||||
if (EqualCI(type.c_str(), L"real"))
|
||||
{
|
||||
m_nameToTypeMap[featureNames[i]] = InputOutputTypes::real;
|
||||
}
|
||||
|
|
|
@ -109,7 +109,7 @@ void HTKMLFReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig
|
|||
|
||||
// Checks if partial minibatches are allowed.
|
||||
std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
|
||||
m_partialMinibatch = !_stricmp(minibatchMode.c_str(), "Partial");
|
||||
m_partialMinibatch = EqualCI(minibatchMode.c_str(), "Partial");
|
||||
|
||||
// Figures out if we have to do minibatch buffering and how.
|
||||
if (m_doSeqTrain)
|
||||
|
@ -195,12 +195,12 @@ void HTKMLFReader<ElemType>::PrepareForSequenceTraining(const ConfigRecordType&
|
|||
if (temp.ExistsCurrent(L"type"))
|
||||
{
|
||||
wstring type = temp(L"type");
|
||||
if (!_wcsicmp(type.c_str(), L"readerDeriv") || !_wcsicmp(type.c_str(), L"seqTrainDeriv") /*for back compatibility */)
|
||||
if (EqualCI(type.c_str(), L"readerDeriv") || EqualCI(type.c_str(), L"seqTrainDeriv") /*for back compatibility */)
|
||||
{
|
||||
m_nameToTypeMap[id] = InputOutputTypes::readerDeriv;
|
||||
hasDrive = true;
|
||||
}
|
||||
else if (!_wcsicmp(type.c_str(), L"readerObj") || !_wcsicmp(type.c_str(), L"seqTrainObj") /*for back compatibility */)
|
||||
else if (EqualCI(type.c_str(), L"readerObj") || EqualCI(type.c_str(), L"seqTrainObj") /*for back compatibility */)
|
||||
{
|
||||
m_nameToTypeMap[id] = InputOutputTypes::readerObj;
|
||||
hasObj = true;
|
||||
|
@ -305,7 +305,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
// Figures out the category.
|
||||
wstring type = thisFeature(L"type", L"real");
|
||||
if (!_wcsicmp(type.c_str(), L"real"))
|
||||
if (EqualCI(type.c_str(), L"real"))
|
||||
{
|
||||
m_nameToTypeMap[featureNames[i]] = InputOutputTypes::real;
|
||||
}
|
||||
|
@ -348,7 +348,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
type = (const wstring&) thisLabel(L"labelType"); // let's deprecate this eventually and just use "type"...
|
||||
else
|
||||
type = (const wstring&) thisLabel(L"type", L"category"); // outputs should default to category
|
||||
if (!_wcsicmp(type.c_str(), L"category"))
|
||||
if (EqualCI(type.c_str(), L"category"))
|
||||
m_nameToTypeMap[labelNames[i]] = InputOutputTypes::category;
|
||||
else
|
||||
InvalidArgument("label type must be Category");
|
||||
|
@ -402,11 +402,11 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
if (readerConfig.Exists(L"randomize"))
|
||||
{
|
||||
const std::string& randomizeString = readerConfig(L"randomize");
|
||||
if (!_stricmp(randomizeString.c_str(), "none"))
|
||||
if (EqualCI(randomizeString.c_str(), "none"))
|
||||
{
|
||||
randomize = randomizeNone;
|
||||
}
|
||||
else if (!_stricmp(randomizeString.c_str(), "auto"))
|
||||
else if (EqualCI(randomizeString.c_str(), "auto"))
|
||||
{
|
||||
randomize = randomizeAuto;
|
||||
}
|
||||
|
@ -464,7 +464,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
// option is "rollingWindow". We only support "blockRandomize" in
|
||||
// sequence training.
|
||||
std::string readMethod(readerConfig(L"readMethod", "blockRandomize"));
|
||||
if (!_stricmp(readMethod.c_str(), "blockRandomize"))
|
||||
if (EqualCI(readMethod.c_str(), "blockRandomize"))
|
||||
{
|
||||
// construct all the parameters we don't need, but need to be passed to the constructor...
|
||||
std::pair<std::vector<wstring>, std::vector<wstring>> latticetocs;
|
||||
|
@ -480,7 +480,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
scriptpaths, infilesmulti, labelsmulti, m_featDims, m_labelDims,
|
||||
numContextLeft, numContextRight, randomize, *m_lattices, m_latticeMap, m_framemode);
|
||||
}
|
||||
else if (!_stricmp(readMethod.c_str(), "rollingWindow"))
|
||||
else if (EqualCI(readMethod.c_str(), "rollingWindow"))
|
||||
{
|
||||
// "rollingWindow" is not supported in sequence training.
|
||||
if (m_doSeqTrain)
|
||||
|
|
|
@ -445,9 +445,9 @@ void SequenceParser<float, std::string>::StoreLabel(float /*finalResult*/)
|
|||
if (m_spaceDelimitedMax <= m_spaceDelimitedStart)
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
std::string label((LPCSTR) &m_fileBuffer[m_spaceDelimitedStart - m_bufferStart], m_spaceDelimitedMax - m_spaceDelimitedStart);
|
||||
if (!m_beginSequence && !_stricmp(label.c_str(), m_beginTag.c_str()))
|
||||
if (!m_beginSequence && EqualCI(label.c_str(), m_beginTag.c_str()))
|
||||
m_beginSequence = true;
|
||||
if (!m_endSequence && !_stricmp(label.c_str(), m_endTag.c_str()))
|
||||
if (!m_endSequence && EqualCI(label.c_str(), m_endTag.c_str()))
|
||||
m_endSequence = true;
|
||||
m_labels->push_back(move(label));
|
||||
m_labelsConvertedThisLine++;
|
||||
|
@ -489,9 +489,9 @@ void SequenceParser<double, std::string>::StoreLabel(double /*finalResult*/)
|
|||
if (m_spaceDelimitedMax <= m_spaceDelimitedStart)
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
std::string label((LPCSTR) &m_fileBuffer[m_spaceDelimitedStart - m_bufferStart], m_spaceDelimitedMax - m_spaceDelimitedStart);
|
||||
if (!m_beginSequence && !_stricmp(label.c_str(), m_beginTag.c_str()))
|
||||
if (!m_beginSequence && EqualCI(label.c_str(), m_beginTag.c_str()))
|
||||
m_beginSequence = true;
|
||||
if (!m_endSequence && !_stricmp(label.c_str(), m_endTag.c_str()))
|
||||
if (!m_endSequence && EqualCI(label.c_str(), m_endTag.c_str()))
|
||||
m_endSequence = true;
|
||||
m_labels->push_back(move(label));
|
||||
m_labelsConvertedThisLine++;
|
||||
|
|
|
@ -167,7 +167,7 @@ bool SequenceReader<ElemType>::EnsureDataAvailable(size_t mbStartSample, bool /*
|
|||
continue; // empty input
|
||||
|
||||
// check for end of sequence marker
|
||||
if (!bSentenceStart && (!_stricmp(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()) || ((label - 1) % m_mbSize == 0)))
|
||||
if (!bSentenceStart && (EqualCI(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()) || ((label - 1) % m_mbSize == 0)))
|
||||
{
|
||||
// ignore those cases where $</s> is put in the begining, because those are used for initialization purpose
|
||||
spos.flags |= seqFlagStopLabel;
|
||||
|
@ -183,7 +183,7 @@ bool SequenceReader<ElemType>::EnsureDataAvailable(size_t mbStartSample, bool /*
|
|||
RuntimeError("read sentence length is longer than the minibatch size. should be smaller. increase the minibatch size to at least %d", (int) epochSample);
|
||||
}
|
||||
|
||||
if (!_stricmp(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()))
|
||||
if (EqualCI(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()))
|
||||
continue; // ignore sentence ending
|
||||
}
|
||||
|
||||
|
@ -235,7 +235,7 @@ bool SequenceReader<ElemType>::EnsureDataAvailable(size_t mbStartSample, bool /*
|
|||
{
|
||||
// this is the next word (label was incremented above)
|
||||
labelValue = labelTemp[label];
|
||||
if (!_stricmp(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()))
|
||||
if (EqualCI(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()))
|
||||
{
|
||||
labelValue = labelInfo.endSequence;
|
||||
}
|
||||
|
@ -1529,11 +1529,11 @@ void BatchSequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& reade
|
|||
if (readerConfig.Exists(L"randomize"))
|
||||
{
|
||||
string randomizeString = readerConfig(L"randomize");
|
||||
if (!_stricmp(randomizeString.c_str(), "none"))
|
||||
if (EqualCI(randomizeString.c_str(), "none"))
|
||||
{
|
||||
;
|
||||
}
|
||||
else if (!_stricmp(randomizeString.c_str(), "auto"))
|
||||
else if (EqualCI(randomizeString.c_str(), "auto"))
|
||||
{
|
||||
;
|
||||
}
|
||||
|
@ -1812,7 +1812,7 @@ bool BatchSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample*/
|
|||
{
|
||||
// this is the next word (label was incremented above)
|
||||
labelValue = m_labelTemp[label];
|
||||
if (!_stricmp(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()))
|
||||
if (EqualCI(labelValue.c_str(), m_labelInfo[labelInfoIn].endSequence.c_str()))
|
||||
{
|
||||
labelValue = labelInfo.endSequence;
|
||||
}
|
||||
|
|
|
@ -365,7 +365,7 @@ void BatchLUSequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& rea
|
|||
|
||||
// determine label type desired
|
||||
wstring labelType(labelConfig(L"labelType", L"category"));
|
||||
if (!_wcsicmp(labelType.c_str(), L"category"))
|
||||
if (EqualCI(labelType.c_str(), L"category"))
|
||||
{
|
||||
m_labelInfo[index].type = labelCategory;
|
||||
}
|
||||
|
@ -429,11 +429,11 @@ void BatchLUSequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& rea
|
|||
if (readerConfig.Exists(L"randomize"))
|
||||
{
|
||||
string randomizeString = readerConfig(L"randomize");
|
||||
if (!_stricmp(randomizeString.c_str(), "none"))
|
||||
if (EqualCI(randomizeString.c_str(), "none"))
|
||||
{
|
||||
;
|
||||
}
|
||||
else if (!_stricmp(randomizeString.c_str(), "auto") || !_stricmp(randomizeString.c_str(), "true"))
|
||||
else if (EqualCI(randomizeString.c_str(), "auto") || EqualCI(randomizeString.c_str(), "true"))
|
||||
{
|
||||
mRandomize = true;
|
||||
}
|
||||
|
|
|
@ -683,7 +683,7 @@ void LibSVMBinaryReader<ElemType>::InitFromConfig(const ConfigRecordType& reader
|
|||
|
||||
m_partialMinibatch = true;
|
||||
std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
|
||||
m_partialMinibatch = !_stricmp(minibatchMode.c_str(), "Partial");
|
||||
m_partialMinibatch = EqualCI(minibatchMode.c_str(), "Partial");
|
||||
|
||||
std::wstring file = readerConfig(L"file", L"");
|
||||
|
||||
|
|
|
@ -336,11 +336,11 @@ void UCIFastReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfi
|
|||
if (readerConfig.Exists(L"randomize"))
|
||||
{
|
||||
string randomizeString = readerConfig(L"randomize");
|
||||
if (!_stricmp(randomizeString.c_str(), "none"))
|
||||
if (EqualCI(randomizeString.c_str(), "none"))
|
||||
{
|
||||
m_randomizeRange = randomizeNone;
|
||||
}
|
||||
else if (!_stricmp(randomizeString.c_str(), "auto"))
|
||||
else if (EqualCI(randomizeString.c_str(), "auto"))
|
||||
{
|
||||
m_randomizeRange = randomizeAuto;
|
||||
}
|
||||
|
@ -356,7 +356,7 @@ void UCIFastReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfi
|
|||
|
||||
// determine if we partial minibatches are desired
|
||||
std::string minibatchMode(readerConfig(L"minibatchMode", "partial"));
|
||||
m_partialMinibatch = !_stricmp(minibatchMode.c_str(), "partial");
|
||||
m_partialMinibatch = EqualCI(minibatchMode.c_str(), "partial");
|
||||
|
||||
// get start and dimensions for labels and features
|
||||
size_t startLabels = configLabels(L"start", (size_t) 0);
|
||||
|
@ -373,15 +373,15 @@ void UCIFastReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfi
|
|||
labelType = (wstring) configLabels(L"labelType", L"category");
|
||||
|
||||
// convert to lower case for case insensitive comparison
|
||||
if (!_wcsicmp(labelType.c_str(), L"category"))
|
||||
if (EqualCI(labelType.c_str(), L"category"))
|
||||
{
|
||||
m_labelType = labelCategory;
|
||||
}
|
||||
else if (!_wcsicmp(labelType.c_str(), L"regression"))
|
||||
else if (EqualCI(labelType.c_str(), L"regression"))
|
||||
{
|
||||
m_labelType = labelRegression;
|
||||
}
|
||||
else if (!_wcsicmp(labelType.c_str(), L"none"))
|
||||
else if (EqualCI(labelType.c_str(), L"none"))
|
||||
{
|
||||
m_labelType = labelNone;
|
||||
dimLabels = 0; // override for no labels
|
||||
|
|
|
@ -2349,51 +2349,41 @@ template class SGD<double>;
|
|||
|
||||
static AdaptationRegType ParseAdaptationRegType(const wstring& s)
|
||||
{
|
||||
if (!_wcsicmp(s.c_str(), L"") || !_wcsicmp(s.c_str(), L"none"))
|
||||
return AdaptationRegType::None;
|
||||
else if (!_wcsicmp(s.c_str(), L"kl") || !_wcsicmp(s.c_str(), L"klReg"))
|
||||
return AdaptationRegType::KL;
|
||||
if (EqualCI(s.c_str(), L"") || EqualCI(s.c_str(), L"none")) return AdaptationRegType::None;
|
||||
else if (EqualCI(s.c_str(), L"kl") || EqualCI(s.c_str(), L"klReg")) return AdaptationRegType::KL;
|
||||
else
|
||||
InvalidArgument("ParseAdaptationRegType: Invalid Adaptation Regularization Type. Valid values are (none | kl)");
|
||||
}
|
||||
|
||||
static GradientsUpdateType ParseGradUpdateType(const wstring& s)
|
||||
{
|
||||
if (!_wcsicmp(s.c_str(), L"") || !_wcsicmp(s.c_str(), L"none") || !_wcsicmp(s.c_str(), L"normal") || !_wcsicmp(s.c_str(), L"simple"))
|
||||
return GradientsUpdateType::None;
|
||||
else if (!_wcsicmp(s.c_str(), L"adagrad"))
|
||||
return GradientsUpdateType::AdaGrad;
|
||||
else if (!_wcsicmp(s.c_str(), L"rmsProp"))
|
||||
return GradientsUpdateType::RmsProp;
|
||||
else if (!_wcsicmp(s.c_str(), L"fsAdagrad"))
|
||||
return GradientsUpdateType::FSAdaGrad;
|
||||
else
|
||||
InvalidArgument("ParseGradUpdateType: Invalid Gradient Updating Type. Valid values are (none | adagrad | rmsProp | fsAdagrad )");
|
||||
if (EqualCI(s.c_str(), L"") || EqualCI(s.c_str(), L"none")) return GradientsUpdateType::None;
|
||||
else if (EqualCI(s.c_str(), L"adagrad")) return GradientsUpdateType::AdaGrad;
|
||||
else if (EqualCI(s.c_str(), L"rmsProp")) return GradientsUpdateType::RmsProp;
|
||||
else if (EqualCI(s.c_str(), L"fsAdagrad")) return GradientsUpdateType::FSAdaGrad;
|
||||
// legacy
|
||||
else if (EqualCI(s.c_str(), L"normal") || EqualCI(s.c_str(), L"simple")) return GradientsUpdateType::None;
|
||||
else InvalidArgument("ParseGradUpdateType: Invalid Gradient Updating Type. Valid values are (none | adagrad | rmsProp | fsAdagrad )");
|
||||
}
|
||||
|
||||
static ParallelizationMethod ParseParallelizationMethod(const wstring& s)
|
||||
{
|
||||
if (!_wcsicmp(s.c_str(), L"") || !_wcsicmp(s.c_str(), L"none"))
|
||||
return ParallelizationMethod::None;
|
||||
else if (!_wcsicmp(s.c_str(), L"DataParallelSGD"))
|
||||
return ParallelizationMethod::DataParallelSGD;
|
||||
else if (!_wcsicmp(s.c_str(), L"ModelAveragingSGD"))
|
||||
return ParallelizationMethod::ModelAveragingSGD;
|
||||
else
|
||||
InvalidArgument("ParseParallelizationMethod: Invalid Parallelization Method. Valid values are (none | dataParallelSGD | modelAveragingSGD)");
|
||||
if (EqualCI(s.c_str(), L"") || EqualCI(s.c_str(), L"none")) return ParallelizationMethod::None;
|
||||
else if (EqualCI(s.c_str(), L"DataParallelSGD")) return ParallelizationMethod::DataParallelSGD;
|
||||
else if (EqualCI(s.c_str(), L"ModelAveragingSGD")) return ParallelizationMethod::ModelAveragingSGD;
|
||||
else InvalidArgument("ParseParallelizationMethod: Invalid Parallelization Method. Valid values are (none | dataParallelSGD | modelAveragingSGD)");
|
||||
}
|
||||
|
||||
static LearningRateSearchAlgorithm ParseLearningRateSearchType(const wstring& s)
|
||||
{
|
||||
// TODO: why allow so many variants?
|
||||
if (!_wcsicmp(s.c_str(), L"false") || !_wcsicmp(s.c_str(), L"none"))
|
||||
return LearningRateSearchAlgorithm::None;
|
||||
else if (!_wcsicmp(s.c_str(), L"searchBeforeEpoch") || !_wcsicmp(s.c_str(), L"beforeEpoch" /*legacy, deprecated*/) || !_wcsicmp(s.c_str(), L"before" /*legacy, deprecated*/))
|
||||
return LearningRateSearchAlgorithm::SearchBeforeEpoch;
|
||||
else if (!_wcsicmp(s.c_str(), L"adjustAfterEpoch") || !_wcsicmp(s.c_str(), L"afterEpoch" /*legacy, deprecated*/) || !_wcsicmp(s.c_str(), L"after" /*legacy, deprecated*/))
|
||||
return LearningRateSearchAlgorithm::AdjustAfterEpoch;
|
||||
else
|
||||
InvalidArgument("autoAdjustLR: Invalid learning rate search type. Valid values are (none | searchBeforeEpoch | adjustAfterEpoch)");
|
||||
if (EqualCI(s.c_str(), L"false") || EqualCI(s.c_str(), L"none")) return LearningRateSearchAlgorithm::None;
|
||||
else if (EqualCI(s.c_str(), L"searchBeforeEpoch")) return LearningRateSearchAlgorithm::SearchBeforeEpoch;
|
||||
else if (EqualCI(s.c_str(), L"adjustAfterEpoch")) return LearningRateSearchAlgorithm::AdjustAfterEpoch;
|
||||
// legacy
|
||||
else if (EqualCI(s.c_str(), L"beforeEpoch" /*legacy, deprecated*/) || EqualCI(s.c_str(), L"before" /*legacy, deprecated*/)) return LearningRateSearchAlgorithm::SearchBeforeEpoch;
|
||||
else if (EqualCI(s.c_str(), L"afterEpoch" /*legacy, deprecated*/) || EqualCI(s.c_str(), L"after" /*legacy, deprecated*/)) return LearningRateSearchAlgorithm::AdjustAfterEpoch;
|
||||
else InvalidArgument("autoAdjustLR: Invalid learning rate search type. Valid values are (none | searchBeforeEpoch | adjustAfterEpoch)");
|
||||
}
|
||||
|
||||
template <class ConfigRecordType>
|
||||
|
|
Загрузка…
Ссылка в новой задаче