change each tab to four spaces for new code pulled.
This commit is contained in:
Родитель
b9f1268172
Коммит
29c42731b1
|
@ -28,20 +28,20 @@ using namespace std;
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template<class ElemType>
|
||||
void DecimateMinibatch(std::map<std::wstring, MSR::CNTK::Matrix<ElemType>*> &mb)
|
||||
{
|
||||
if ( numProcs > 1 ) for (auto it = mb.begin(); it != mb.end(); ++it)
|
||||
{
|
||||
MSR::CNTK::Matrix<ElemType> &mat = *(it->second);
|
||||
size_t nCols = mat.GetNumCols();
|
||||
size_t col_start = (nCols * myRank)/ numProcs;
|
||||
size_t col_end = (nCols*(myRank + 1)) / numProcs;
|
||||
if (col_end > nCols) col_end = nCols; // this shouldn't happen
|
||||
MSR::CNTK::Matrix<ElemType> tmp = mat.ColumnSlice(col_start, col_end - col_start);
|
||||
mat.SetValue(tmp);
|
||||
}
|
||||
}
|
||||
template<class ElemType>
|
||||
void DecimateMinibatch(std::map<std::wstring, MSR::CNTK::Matrix<ElemType>*> &mb)
|
||||
{
|
||||
if ( numProcs > 1 ) for (auto it = mb.begin(); it != mb.end(); ++it)
|
||||
{
|
||||
MSR::CNTK::Matrix<ElemType> &mat = *(it->second);
|
||||
size_t nCols = mat.GetNumCols();
|
||||
size_t col_start = (nCols * myRank)/ numProcs;
|
||||
size_t col_end = (nCols*(myRank + 1)) / numProcs;
|
||||
if (col_end > nCols) col_end = nCols; // this shouldn't happen
|
||||
MSR::CNTK::Matrix<ElemType> tmp = mat.ColumnSlice(col_start, col_end - col_start);
|
||||
mat.SetValue(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
enum class LearningRateSearchAlgorithm : int
|
||||
{
|
||||
|
@ -379,10 +379,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
fprintf(stderr, "Starting from checkpoint. Load Network From File %ls.\n", modelFileName.c_str());
|
||||
ComputationNetwork<ElemType>& net =
|
||||
startEpoch<0? netBuilder->BuildNetworkFromDescription() : netBuilder->LoadNetworkFromFile(modelFileName);
|
||||
// TODO: BUGBUG: if not starting from checkpoint, need to synchronize initial model
|
||||
// strategy should be to run the initializer above on myRank==0, and then broadcast parameters.
|
||||
// TODO: BUGBUG: if not starting from checkpoint, need to synchronize initial model
|
||||
// strategy should be to run the initializer above on myRank==0, and then broadcast parameters.
|
||||
|
||||
startEpoch = max(startEpoch, 0);
|
||||
startEpoch = max(startEpoch, 0);
|
||||
m_needRegularization = false;
|
||||
|
||||
TrainOrAdaptModel(startEpoch, net, net, nullptr, trainSetDataReader, validationSetDataReader);
|
||||
|
@ -494,9 +494,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
prevLearnRates[i] = std::numeric_limits<ElemType>::infinity();
|
||||
|
||||
//precompute mean and invStdDev nodes and save initial model
|
||||
if (PreCompute(net, trainSetDataReader, FeatureNodes, labelNodes, inputMatrices) || startEpoch == 0)
|
||||
if (0 == myRank) // only needs to be done by one process
|
||||
net.SaveToFile(GetModelNameForEpoch(int(startEpoch) - 1));
|
||||
if (PreCompute(net, trainSetDataReader, FeatureNodes, labelNodes, inputMatrices) || startEpoch == 0)
|
||||
if (0 == myRank) // only needs to be done by one process
|
||||
net.SaveToFile(GetModelNameForEpoch(int(startEpoch) - 1));
|
||||
|
||||
bool learnRateInitialized = false;
|
||||
if (startEpoch > 0)
|
||||
|
@ -598,26 +598,26 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
#ifdef MPI_SUPPORT
|
||||
// model reduction and averaging
|
||||
if ( numProcs > 0 )
|
||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
|
||||
{
|
||||
ComputationNodePtr node = (*nodeIter);
|
||||
Microsoft::MSR::CNTK::Matrix<ElemType> &mat = node->FunctionValues();
|
||||
ElemType *px = mat.CopyToArray();
|
||||
size_t nx = mat.GetNumElements();
|
||||
vector<ElemType> py = vector<ElemType>(nx, ElemType(0));
|
||||
// TODO: Replace this with the reduction-shuffle-dance
|
||||
MPI_Reduce(px, &(py[0]), (int)nx, sizeof(ElemType) == 4 ? MPI_FLOAT : MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
|
||||
if (myRank == 0)
|
||||
transform(py.begin(), py.end(), py.begin(), [](ElemType&val)->ElemType{return val / (ElemType)numProcs; });
|
||||
MPI_Bcast(&(py[0]), nx, sizeof(ElemType) == 4 ? MPI_FLOAT : MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), &(py[0]));
|
||||
delete px;
|
||||
}
|
||||
// model reduction and averaging
|
||||
if ( numProcs > 0 )
|
||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
|
||||
{
|
||||
ComputationNodePtr node = (*nodeIter);
|
||||
Microsoft::MSR::CNTK::Matrix<ElemType> &mat = node->FunctionValues();
|
||||
ElemType *px = mat.CopyToArray();
|
||||
size_t nx = mat.GetNumElements();
|
||||
vector<ElemType> py = vector<ElemType>(nx, ElemType(0));
|
||||
// TODO: Replace this with the reduction-shuffle-dance
|
||||
MPI_Reduce(px, &(py[0]), (int)nx, sizeof(ElemType) == 4 ? MPI_FLOAT : MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
|
||||
if (myRank == 0)
|
||||
transform(py.begin(), py.end(), py.begin(), [](ElemType&val)->ElemType{return val / (ElemType)numProcs; });
|
||||
MPI_Bcast(&(py[0]), nx, sizeof(ElemType) == 4 ? MPI_FLOAT : MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mat.SetValue(mat.GetNumRows(), mat.GetNumCols(), &(py[0]));
|
||||
delete px;
|
||||
}
|
||||
#endif
|
||||
|
||||
if ( 0 == myRank ) // only evaluate once, on the master process. TODO: This could be faster by farming out the validation parts
|
||||
if ( 0 == myRank ) // only evaluate once, on the master process. TODO: This could be faster by farming out the validation parts
|
||||
if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
|
||||
{
|
||||
SimpleEvaluator<ElemType> evalforvalidation(net);
|
||||
|
@ -632,8 +632,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
epochCriterion = vScore[0]; //the first one is the training criterion.
|
||||
}
|
||||
#ifdef MPI_SUPPORT
|
||||
// ensure all processes have the same epochCriterion
|
||||
MPI_Bcast(&epochCriterion, 1, sizeof(epochCriterion) == 4 ? MPI_FLOAT : MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
// ensure all processes have the same epochCriterion
|
||||
MPI_Bcast(&epochCriterion, 1, sizeof(epochCriterion) == 4 ? MPI_FLOAT : MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
#endif
|
||||
|
||||
bool loadedPrevModel = false;
|
||||
|
@ -661,8 +661,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
else
|
||||
{
|
||||
if ( myRank == 0 )
|
||||
net.SaveToFile(GetModelNameForEpoch(i, true));
|
||||
if ( myRank == 0 )
|
||||
net.SaveToFile(GetModelNameForEpoch(i, true));
|
||||
fprintf(stderr, "Finished training and saved final model\n\n");
|
||||
break;
|
||||
}
|
||||
|
@ -695,13 +695,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
//persist model and check-point info
|
||||
if (0 == myRank)
|
||||
{
|
||||
net.SaveToFile(GetModelNameForEpoch(i));
|
||||
SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion);
|
||||
if (!m_keepCheckPointFiles)
|
||||
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); //delete previous checkpiont file to save space
|
||||
}
|
||||
if (0 == myRank)
|
||||
{
|
||||
net.SaveToFile(GetModelNameForEpoch(i));
|
||||
SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion);
|
||||
if (!m_keepCheckPointFiles)
|
||||
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); //delete previous checkpiont file to save space
|
||||
}
|
||||
|
||||
if (learnRatePerSample < 1e-12)
|
||||
fprintf(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n", learnRatePerSample);
|
||||
|
@ -982,7 +982,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
while (trainSetDataReader->GetMinibatch(inputMatrices))
|
||||
{
|
||||
#ifdef MPI_SUPPORT
|
||||
DecimateMinibatch(inputMatrices);
|
||||
DecimateMinibatch(inputMatrices);
|
||||
#endif
|
||||
endReadMBTime=clock();
|
||||
startComputeMBTime=clock();
|
||||
|
|
|
@ -577,43 +577,43 @@ std::string TimeDateStamp()
|
|||
// Oh, my gosh, this is going to be ugly. MPI_INIT needs a char* argv[], so let's interface.
|
||||
int MPIAPI MPI_Init(_In_opt_ int *argc, _Inout_count_(*argc) wchar_t*** argv)
|
||||
{
|
||||
// this maps from the strings
|
||||
std::map<std::string, wchar_t*> recover_wstring;
|
||||
// this maps from the strings
|
||||
std::map<std::string, wchar_t*> recover_wstring;
|
||||
|
||||
// do the mapping to 8-bit encoding for MPI_Init()
|
||||
vector<vector<char>> argv_string_vector;
|
||||
transform(*argv, *argv + *argc, std::back_inserter(argv_string_vector),
|
||||
[&recover_wstring](wchar_t*pws)->vector<char>
|
||||
{
|
||||
std::string tmp = msra::strfun::utf8(std::wstring(pws));
|
||||
recover_wstring[tmp] = pws;
|
||||
vector<char> rv(tmp.begin(), tmp.end());
|
||||
rv.push_back('\0');
|
||||
return rv;
|
||||
}
|
||||
);
|
||||
vector<char*> argv_charptr_vector;
|
||||
transform(argv_string_vector.begin(), argv_string_vector.end(), std::back_inserter(argv_charptr_vector),
|
||||
[](std::vector<char>&cs)->char*{ return &(cs[0]); }
|
||||
);
|
||||
char** argv_char = &(argv_charptr_vector[0]);
|
||||
// do the mapping to 8-bit encoding for MPI_Init()
|
||||
vector<vector<char>> argv_string_vector;
|
||||
transform(*argv, *argv + *argc, std::back_inserter(argv_string_vector),
|
||||
[&recover_wstring](wchar_t*pws)->vector<char>
|
||||
{
|
||||
std::string tmp = msra::strfun::utf8(std::wstring(pws));
|
||||
recover_wstring[tmp] = pws;
|
||||
vector<char> rv(tmp.begin(), tmp.end());
|
||||
rv.push_back('\0');
|
||||
return rv;
|
||||
}
|
||||
);
|
||||
vector<char*> argv_charptr_vector;
|
||||
transform(argv_string_vector.begin(), argv_string_vector.end(), std::back_inserter(argv_charptr_vector),
|
||||
[](std::vector<char>&cs)->char*{ return &(cs[0]); }
|
||||
);
|
||||
char** argv_char = &(argv_charptr_vector[0]);
|
||||
|
||||
// Do the initialization
|
||||
int rv = MPI_Init(argc, &argv_char);
|
||||
// Do the initialization
|
||||
int rv = MPI_Init(argc, &argv_char);
|
||||
|
||||
// try and reconstruct how MPI_Init changed the argv
|
||||
transform(argv_char, argv_char + *argc, stdext::checked_array_iterator<wchar_t**>(*argv, *argc),
|
||||
[&recover_wstring](char*pc)->wchar_t*
|
||||
{
|
||||
auto it = recover_wstring.find(std::string(pc));
|
||||
if (it == recover_wstring.end())
|
||||
RuntimeError("Unexpected interaction between MPI_Init and command line parameters");
|
||||
return it->second;
|
||||
}
|
||||
);
|
||||
// try and reconstruct how MPI_Init changed the argv
|
||||
transform(argv_char, argv_char + *argc, stdext::checked_array_iterator<wchar_t**>(*argv, *argc),
|
||||
[&recover_wstring](char*pc)->wchar_t*
|
||||
{
|
||||
auto it = recover_wstring.find(std::string(pc));
|
||||
if (it == recover_wstring.end())
|
||||
RuntimeError("Unexpected interaction between MPI_Init and command line parameters");
|
||||
return it->second;
|
||||
}
|
||||
);
|
||||
|
||||
// pass through return value from internal call to MPI_Init()
|
||||
return rv;
|
||||
// pass through return value from internal call to MPI_Init()
|
||||
return rv;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -623,22 +623,22 @@ int wmain(int argc, wchar_t* argv[])
|
|||
{
|
||||
|
||||
#ifdef MPI_SUPPORT
|
||||
{
|
||||
int rc;
|
||||
rc = MPI_Init(&argc, &argv);
|
||||
if (rc != MPI_SUCCESS)
|
||||
{
|
||||
MPI_Abort(MPI_COMM_WORLD, rc);
|
||||
RuntimeError("Failure in MPI_Init: %d", rc);
|
||||
}
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &numProcs);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &myRank);
|
||||
fprintf(stderr, "MPI: RUNNING ON (%s), process %d/%d\n", getenv("COMPUTERNAME"), myRank, numProcs);
|
||||
fflush(stderr);
|
||||
}
|
||||
{
|
||||
int rc;
|
||||
rc = MPI_Init(&argc, &argv);
|
||||
if (rc != MPI_SUCCESS)
|
||||
{
|
||||
MPI_Abort(MPI_COMM_WORLD, rc);
|
||||
RuntimeError("Failure in MPI_Init: %d", rc);
|
||||
}
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &numProcs);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &myRank);
|
||||
fprintf(stderr, "MPI: RUNNING ON (%s), process %d/%d\n", getenv("COMPUTERNAME"), myRank, numProcs);
|
||||
fflush(stderr);
|
||||
}
|
||||
#else
|
||||
numProcs = 1;
|
||||
myRank = 0;
|
||||
numProcs = 1;
|
||||
myRank = 0;
|
||||
#endif
|
||||
|
||||
ConfigParameters config;
|
||||
|
@ -656,64 +656,64 @@ int wmain(int argc, wchar_t* argv[])
|
|||
logpath += (wstring)command[i];
|
||||
}
|
||||
logpath += L".log";
|
||||
if (numProcs > 1)
|
||||
{
|
||||
std::wostringstream oss;
|
||||
oss << myRank;
|
||||
logpath += L"rank" + oss.str();
|
||||
}
|
||||
if (numProcs > 1)
|
||||
{
|
||||
std::wostringstream oss;
|
||||
oss << myRank;
|
||||
logpath += L"rank" + oss.str();
|
||||
}
|
||||
RedirectStdErr(logpath);
|
||||
}
|
||||
|
||||
std::string timestamp = TimeDateStamp();
|
||||
|
||||
if (myRank == 0) // main process
|
||||
{
|
||||
//dump config info
|
||||
fprintf(stderr, "running on %s at %s\n", GetHostName().c_str(), timestamp.c_str());
|
||||
fprintf(stderr, "command line options: \n");
|
||||
for (int i = 1; i < argc; i++)
|
||||
fprintf(stderr, "%s ", WCharToString(argv[i]).c_str());
|
||||
if (myRank == 0) // main process
|
||||
{
|
||||
//dump config info
|
||||
fprintf(stderr, "running on %s at %s\n", GetHostName().c_str(), timestamp.c_str());
|
||||
fprintf(stderr, "command line options: \n");
|
||||
for (int i = 1; i < argc; i++)
|
||||
fprintf(stderr, "%s ", WCharToString(argv[i]).c_str());
|
||||
|
||||
// This simply merges all the different config parameters specified (eg, via config files or via command line directly),
|
||||
// and prints it.
|
||||
fprintf(stderr, "\n\n>>>>>>>>>>>>>>>>>>>> RAW CONFIG (VARIABLES NOT RESOLVED) >>>>>>>>>>>>>>>>>>>>\n");
|
||||
fprintf(stderr, "%s\n", rawConfigString.c_str());
|
||||
fprintf(stderr, "<<<<<<<<<<<<<<<<<<<< RAW CONFIG (VARIABLES NOT RESOLVED) <<<<<<<<<<<<<<<<<<<<\n");
|
||||
// This simply merges all the different config parameters specified (eg, via config files or via command line directly),
|
||||
// and prints it.
|
||||
fprintf(stderr, "\n\n>>>>>>>>>>>>>>>>>>>> RAW CONFIG (VARIABLES NOT RESOLVED) >>>>>>>>>>>>>>>>>>>>\n");
|
||||
fprintf(stderr, "%s\n", rawConfigString.c_str());
|
||||
fprintf(stderr, "<<<<<<<<<<<<<<<<<<<< RAW CONFIG (VARIABLES NOT RESOLVED) <<<<<<<<<<<<<<<<<<<<\n");
|
||||
|
||||
// Same as above, but all variables are resolved. If a parameter is set multiple times (eg, set in config, overriden at command line),
|
||||
// All of these assignments will appear, even though only the last assignment matters.
|
||||
fprintf(stderr, "\n>>>>>>>>>>>>>>>>>>>> RAW CONFIG WITH ALL VARIABLES RESOLVED >>>>>>>>>>>>>>>>>>>>\n");
|
||||
fprintf(stderr, "%s\n", config.ResolveVariables(rawConfigString).c_str());
|
||||
fprintf(stderr, "<<<<<<<<<<<<<<<<<<<< RAW CONFIG WITH ALL VARIABLES RESOLVED <<<<<<<<<<<<<<<<<<<<\n");
|
||||
// Same as above, but all variables are resolved. If a parameter is set multiple times (eg, set in config, overriden at command line),
|
||||
// All of these assignments will appear, even though only the last assignment matters.
|
||||
fprintf(stderr, "\n>>>>>>>>>>>>>>>>>>>> RAW CONFIG WITH ALL VARIABLES RESOLVED >>>>>>>>>>>>>>>>>>>>\n");
|
||||
fprintf(stderr, "%s\n", config.ResolveVariables(rawConfigString).c_str());
|
||||
fprintf(stderr, "<<<<<<<<<<<<<<<<<<<< RAW CONFIG WITH ALL VARIABLES RESOLVED <<<<<<<<<<<<<<<<<<<<\n");
|
||||
|
||||
// This outputs the final value each variable/parameter is assigned to in config (so if a parameter is set multiple times, only the last
|
||||
// value it is set to will appear).
|
||||
fprintf(stderr, "\n>>>>>>>>>>>>>>>>>>>> PROCESSED CONFIG WITH ALL VARIABLES RESOLVED >>>>>>>>>>>>>>>>>>>>\n");
|
||||
config.dumpWithResolvedVariables();
|
||||
fprintf(stderr, "<<<<<<<<<<<<<<<<<<<< PROCESSED CONFIG WITH ALL VARIABLES RESOLVED <<<<<<<<<<<<<<<<<<<<\n");
|
||||
// This outputs the final value each variable/parameter is assigned to in config (so if a parameter is set multiple times, only the last
|
||||
// value it is set to will appear).
|
||||
fprintf(stderr, "\n>>>>>>>>>>>>>>>>>>>> PROCESSED CONFIG WITH ALL VARIABLES RESOLVED >>>>>>>>>>>>>>>>>>>>\n");
|
||||
config.dumpWithResolvedVariables();
|
||||
fprintf(stderr, "<<<<<<<<<<<<<<<<<<<< PROCESSED CONFIG WITH ALL VARIABLES RESOLVED <<<<<<<<<<<<<<<<<<<<\n");
|
||||
|
||||
fprintf(stderr, "command: ");
|
||||
for (int i = 0; i < command.size(); i++)
|
||||
{
|
||||
fprintf(stderr, "%s ", command[i].c_str());
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "command: ");
|
||||
for (int i = 0; i < command.size(); i++)
|
||||
{
|
||||
fprintf(stderr, "%s ", command[i].c_str());
|
||||
}
|
||||
}
|
||||
|
||||
//run commands
|
||||
std::string type = config("precision", "float");
|
||||
// accept old precision key for backward compatibility
|
||||
if (config.Exists("type"))
|
||||
type = config("type", "float");
|
||||
if ( myRank == 0 )
|
||||
fprintf(stderr, "\nprecision = %s\n", type.c_str());
|
||||
if ( myRank == 0 )
|
||||
fprintf(stderr, "\nprecision = %s\n", type.c_str());
|
||||
if (type == "float")
|
||||
DoCommand<float>(config);
|
||||
else if (type == "double")
|
||||
DoCommand<double>(config);
|
||||
else
|
||||
RuntimeError("invalid precision specified: %s", type.c_str());
|
||||
}
|
||||
}
|
||||
catch(std::exception &err)
|
||||
{
|
||||
fprintf(stderr, "EXCEPTION occurred: %s", err.what());
|
||||
|
@ -731,7 +731,7 @@ int wmain(int argc, wchar_t* argv[])
|
|||
return EXIT_FAILURE;
|
||||
}
|
||||
#ifdef MPI_SUPPORT
|
||||
MPI_Finalize();
|
||||
MPI_Finalize();
|
||||
#endif
|
||||
return EXIT_SUCCESS;
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче