Merge remote-tracking branch 'origin/master' into thilow/WordRnn3
This commit is contained in:
Commit eac4bdce59
@ -119,6 +119,12 @@
|
|||
<LinkIncremental>$(DebugBuild)</LinkIncremental>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemDefinitionGroup>
|
||||
<ClCompile>
|
||||
<PreprocessorDefinitions>HAS_MPI=1</PreprocessorDefinitions>
|
||||
</ClCompile>
|
||||
</ItemDefinitionGroup>
|
||||
|
||||
<ItemDefinitionGroup Condition="'$(ConfigurationType)' == 'StaticLibrary'">
|
||||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
@ -32,7 +32,7 @@ TrainConvNet = {
|
|||
x2s = SplitDimension(x2, 3, 1)
|
||||
# 3D convolution with a filter that has a non 1-size only in the 3rd axis, and does not reduce since the reduction dimension is fake and 1
|
||||
W = ParameterTensor{(1:1:2*n+1:1), learningRateMultiplier = 0, initValue = alpha/(2*n+1)}
|
||||
y = Convolution (W, x2s, (1:1:2*n+1), mapDims = 1, stride = 1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, transpose = false, maxTempMemSizeInSamples = 0)
|
||||
y = Convolution (W, x2s, (1:1:2*n+1), mapDims = 1, stride = 1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, maxTempMemSizeInSamples = 0)
|
||||
# reshape back to remove the fake singleton reduction dimension
|
||||
b = FlattenDimensions(y, 3, 2)
|
||||
den = Exp (beta .* Log(k + b))
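The block above appears to implement the denominator of local response normalization: a frozen (2n+1)-tap box filter (initValue = alpha/(2n+1), learningRateMultiplier = 0) sums the squared activations of neighboring channels, and Exp (beta .* Log(k + b)) raises k plus that smoothed sum to the power beta. A small NumPy sketch of the same quantity, assuming x2 holds the squared activations as the surrounding example suggests (all names below are illustrative, not part of the diff):

import numpy as np

def lrn_denominator(x2, n, k, alpha, beta):
    # x2: squared activations with channels on the first axis, shape (C, ...)
    # truncating the window at the borders matches autoPadding = true (zero padding)
    C = x2.shape[0]
    den = np.empty_like(x2)
    for c in range(C):
        lo, hi = max(0, c - n), min(C, c + n + 1)
        den[c] = (k + alpha / (2 * n + 1) * x2[lo:hi].sum(axis=0)) ** beta
    return den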
@ -117,7 +117,7 @@ def convnetlrn_cifar10_dataaug(reader_train, reader_test, epoch_size=50000, max_
|
|||
}
|
||||
|
||||
cntk.utils.log_number_of_parameters(z) ; print()
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training')
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
# perform model training
|
||||
for epoch in range(max_epochs): # loop over epochs
@ -84,10 +84,10 @@ def convnet_cifar10(debug_output=False):
|
|||
}
|
||||
|
||||
cntk.utils.log_number_of_parameters(z) ; print()
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training')
|
||||
max_epochs = 30
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
max_epochs = 30
|
||||
for epoch in range(max_epochs): # loop over epochs
|
||||
sample_count = 0
|
||||
while sample_count < epoch_size: # loop over minibatches in the epoch
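The same two-line change recurs in the training scripts below: the ProgressPrinter now receives num_epochs, so the epoch count has to be known before the printer is constructed (hence max_epochs moving up in this hunk). A minimal sketch of the resulting setup, assuming the cntk.utils API used above and a per-script minibatch loop supplied by the caller; epoch_summary is assumed to be available on ProgressPrinter as in other CNTK examples of this era:

import cntk

def run_training(z, trainer, train_one_epoch, max_epochs=30):
    # train_one_epoch is a hypothetical stand-in for each script's minibatch loop
    cntk.utils.log_number_of_parameters(z) ; print()
    progress_printer = cntk.utils.ProgressPrinter(tag='Training', num_epochs=max_epochs)
    for epoch in range(max_epochs):                       # loop over epochs
        train_one_epoch(trainer, progress_printer)        # script-specific inner loop
        progress_printer.epoch_summary(with_metric=True)  # assumed helper; prints per-epoch stats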
@ -95,7 +95,7 @@ def convnet_cifar10_dataaug(reader_train, reader_test, epoch_size = 50000, max_e
|
|||
}
|
||||
|
||||
cntk.utils.log_number_of_parameters(z) ; print()
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training')
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
# perform model training
|
||||
for epoch in range(max_epochs): # loop over epochs
@ -74,10 +74,10 @@ def convnet_mnist(debug_output=False):
|
|||
}
|
||||
|
||||
cntk.utils.log_number_of_parameters(z) ; print()
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training')
|
||||
max_epochs = 40
|
||||
progress_printer = cntk.utils.ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
max_epochs = 40
|
||||
for epoch in range(max_epochs): # loop over epochs
|
||||
sample_count = 0
|
||||
while sample_count < epoch_size: # loop over minibatches in the epoch
@ -97,7 +97,7 @@ def train_and_evaluate(reader_train, reader_test, network_name, epoch_size, max_
|
|||
}
|
||||
|
||||
log_number_of_parameters(z) ; print()
|
||||
progress_printer = ProgressPrinter(tag='Training')
|
||||
progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
# perform model training
@ -7,6 +7,9 @@
|
|||
from __future__ import print_function
|
||||
import zipfile
|
||||
import os
|
||||
from sys import platform
|
||||
import shutil
|
||||
|
||||
try:
|
||||
from urllib.request import urlretrieve
|
||||
except ImportError:
@ -26,6 +29,15 @@ def download_grocery_data():
|
|||
print('Extracting ' + filename + '...')
|
||||
with zipfile.ZipFile(filename) as myzip:
|
||||
myzip.extractall(dataset_folder)
|
||||
if platform != "win32":
|
||||
testfile = os.path.join(dataset_folder, "grocery", "test.txt")
|
||||
unixfile = os.path.join(dataset_folder, "grocery", "test_unix.txt")
|
||||
out = open(unixfile, 'w')
|
||||
with open(testfile) as f:
|
||||
for line in f:
|
||||
out.write(line.replace('\\', '/'))
|
||||
out.close()
|
||||
shutil.move(unixfile, testfile)
|
||||
finally:
|
||||
os.remove(filename)
|
||||
print('Done.')
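On non-Windows platforms the extraction step above rewrites the Windows-style backslashes in test.txt so the image paths resolve. A compact equivalent of that rewrite, as a sketch only (file names follow the diff; the helper name is hypothetical):

import os
import shutil

def convert_path_separators(dataset_folder):
    testfile = os.path.join(dataset_folder, "grocery", "test.txt")
    unixfile = os.path.join(dataset_folder, "grocery", "test_unix.txt")
    with open(testfile) as src, open(unixfile, 'w') as dst:
        for line in src:
            dst.write(line.replace('\\', '/'))   # backslashes -> forward slashes
    shutil.move(unixfile, testfile)              # replace the original file in place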
@ -34,4 +46,4 @@ def download_grocery_data():
|
|||
|
||||
if __name__ == "__main__":
|
||||
download_grocery_data()
@ -13,8 +13,7 @@ import cntk
|
|||
# variables and stuff #
|
||||
########################
|
||||
|
||||
cntk_dir = os.path.dirname(os.path.abspath(__file__)) + "/../../../.." # data resides in the CNTK folder
|
||||
data_dir = cntk_dir + "/Examples/LanguageUnderstanding/ATIS/Data" # under Examples/LanguageUnderstanding/ATIS
|
||||
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "Data")
|
||||
vocab_size = 943 ; num_labels = 129 ; num_intents = 26 # number of words in vocab, slot labels, and intent labels
|
||||
|
||||
model_dir = "./Models"
@ -92,8 +91,8 @@ def train(reader, model, max_epochs):
|
|||
|
||||
# process minibatches and perform model training
|
||||
cntk.utils.log_number_of_parameters(z) ; print()
|
||||
progress_printer = cntk.ProgressPrinter(freq=100, first=10, tag='Training') # more detailed logging
|
||||
#progress_printer = ProgressPrinter(tag='Training')
|
||||
progress_printer = cntk.ProgressPrinter(freq=100, first=10, tag='Training', num_epochs=max_epochs) # more detailed logging
|
||||
#progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
t = 0
@ -139,7 +139,7 @@ def create_inputs(vocab_dim):
|
|||
return input_sequence, label_sequence
|
||||
|
||||
# Creates and trains a character-level language model
|
||||
def train_lm(training_file, max_num_minibatches):
|
||||
def train_lm(training_file, epochs, max_num_minibatches):
|
||||
|
||||
# load the data and vocab
|
||||
data, char_to_ix, ix_to_char, data_size, vocab_dim = load_data_and_vocab(training_file)
@ -168,46 +168,34 @@ def train_lm(training_file, max_num_minibatches):
|
|||
trainer = Trainer(z, (ce, errs), learner)
|
||||
|
||||
sample_freq = 1000
|
||||
epochs = 50
|
||||
minibatches_per_epoch = int((data_size / minibatch_size))
|
||||
minibatches = min(epochs * minibatches_per_epoch, max_num_minibatches)
|
||||
minibatches_per_epoch = min(data_size // minibatch_size, max_num_minibatches // epochs)
|
||||
|
||||
# print out some useful training information
|
||||
log_number_of_parameters(z) ; print()
|
||||
progress_printer = ProgressPrinter(freq=100, tag='Training')
|
||||
log_number_of_parameters(z)
|
||||
print ("Running %d epochs with %d minibatches per epoch" % (epochs, minibatches_per_epoch))
|
||||
print()
|
||||
|
||||
e = 0
|
||||
p = 0
|
||||
for i in range(0, minibatches):
|
||||
|
||||
if p + minibatch_size+1 >= data_size:
|
||||
p = 0
|
||||
e += 1
|
||||
model_filename = "models/shakespeare_epoch%d.dnn" % e
|
||||
z.save(model_filename)
|
||||
print("Saved model to '%s'" % model_filename)
|
||||
|
||||
# get the data
|
||||
features, labels = get_data(p, minibatch_size, data, char_to_ix, vocab_dim)
|
||||
progress_printer = ProgressPrinter(freq=100, tag='Training')
|
||||
|
||||
for e in range(0, epochs):
|
||||
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
|
||||
# If it's the start of the data, we specify that we are looking at a new sequence (True)
|
||||
mask = [False]
|
||||
if p == 0:
|
||||
mask = [True]
|
||||
arguments = ({input_sequence : features, label_sequence : labels}, mask)
|
||||
trainer.train_minibatch(arguments)
|
||||
mask = [True]
|
||||
for b in range(0, minibatches_per_epoch):
|
||||
# get the data
|
||||
features, labels = get_data(b, minibatch_size, data, char_to_ix, vocab_dim)
|
||||
arguments = ({input_sequence : features, label_sequence : labels}, mask)
|
||||
mask = [False]
|
||||
trainer.train_minibatch(arguments)
|
||||
|
||||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
if i % sample_freq == 0:
|
||||
print(sample(z, ix_to_char, vocab_dim, char_to_ix))
|
||||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
global_minibatch = e*minibatches_per_epoch + b
|
||||
if global_minibatch % sample_freq == 0:
|
||||
print(sample(z, ix_to_char, vocab_dim, char_to_ix))
|
||||
|
||||
p += minibatch_size
|
||||
|
||||
# Do a final save of the model
|
||||
model_filename = "models/shakespeare_epoch%d.dnn" % e
|
||||
z.save(model_filename)
|
||||
model_filename = "models/shakespeare_epoch%d.dnn" % (e+1)
|
||||
z.save_model(model_filename)
|
||||
print("Saved model to '%s'" % model_filename)
|
||||
|
||||
|
||||
def load_and_sample(model_filename, vocab_filename, prime_text='', use_hardmax=False, length=1000, temperature=1.0):
@ -223,13 +211,13 @@ def load_and_sample(model_filename, vocab_filename, prime_text='', use_hardmax=F
|
|||
|
||||
return sample(model, ix_to_char, len(chars), char_to_ix, prime_text=prime_text, use_hardmax=use_hardmax, length=length, temperature=temperature)
|
||||
|
||||
def train_and_eval_char_rnn(max_num_minibatches=sys.maxsize):
|
||||
# train the LM
|
||||
train_lm("data/tinyshakespeare.txt", max_num_minibatches)
|
||||
def train_and_eval_char_rnn(epochs=50, max_num_minibatches=sys.maxsize):
|
||||
# train the LM
|
||||
train_lm("data/tinyshakespeare.txt", epochs, max_num_minibatches)
|
||||
|
||||
# load and sample
|
||||
text = "T"
|
||||
return load_and_sample("models/shakespeare_epoch0.dnn", "data/tinyshakespeare.txt.vocab", prime_text=text, use_hardmax=False, length=100, temperature=0.95)
|
||||
return load_and_sample("models/shakespeare_epoch%d.dnn" % (epochs), "data/tinyshakespeare.txt.vocab", prime_text=text, use_hardmax=False, length=100, temperature=0.95)
|
||||
|
||||
if __name__=='__main__':
|
||||
# Specify the target device to be used for computing, if you do not want to
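The rewritten train_lm above replaces the single running minibatch counter with an explicit epoch/minibatch nesting: minibatches_per_epoch is capped by both the data size and max_num_minibatches // epochs, the sequence-reset mask is True only for the first minibatch of each epoch, and the model is saved once after training. A stripped-down sketch of that control flow; get_data, sample, trainer, z, input_sequence, label_sequence, minibatch_size and sample_freq are taken from the surrounding script and assumed to be defined:

minibatches_per_epoch = min(data_size // minibatch_size, max_num_minibatches // epochs)
for e in range(epochs):
    mask = [True]                                     # a new sequence starts each epoch
    for b in range(minibatches_per_epoch):
        features, labels = get_data(b, minibatch_size, data, char_to_ix, vocab_dim)
        trainer.train_minibatch(({input_sequence: features, label_sequence: labels}, mask))
        mask = [False]                                # later minibatches continue the sequence
        progress_printer.update_with_trainer(trainer, with_metric=True)
        if (e * minibatches_per_epoch + b) % sample_freq == 0:
            print(sample(z, ix_to_char, vocab_dim, char_to_ix))
z.save_model("models/shakespeare_epoch%d.dnn" % (e + 1))  # final save, as in the diff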
@ -23,7 +23,7 @@ from _cntk_py import set_computation_network_trace_level
|
|||
|
||||
# Paths relative to current python file.
|
||||
abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
data_path = os.path.join(abs_path, "..", "..", "Datasets", "UCF11")
|
||||
data_path = os.path.join(abs_path, "..", "..", "DataSets", "UCF11")
|
||||
model_path = os.path.join(abs_path, "Models")
# Define the reader for both training and evaluation action.
@ -194,14 +194,14 @@ def conv3d_ucf11(train_reader, test_reader, max_epochs=30):
|
|||
lr_per_sample = [0.01]*10+[0.001]*10+[0.0001]
|
||||
lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size, unit=UnitType.sample)
|
||||
momentum_time_constant = 4096
|
||||
mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant, epoch_size=epoch_size)
|
||||
mm_schedule = momentum_as_time_constant_schedule([momentum_time_constant], epoch_size=epoch_size)
|
||||
|
||||
# Instantiate the trainer object to drive the model training
|
||||
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule, True)
|
||||
trainer = Trainer(z, (ce, pe), learner)
|
||||
|
||||
log_number_of_parameters(z) ; print()
|
||||
progress_printer = ProgressPrinter(tag='Training')
|
||||
progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
for epoch in range(max_epochs): # loop over epochs
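Besides the ProgressPrinter change, this hunk wraps the momentum time constant in a list, presumably so that it is interpreted as a per-epoch_size schedule entry in the same way as the learning-rate schedule. The learner construction then reads as follows (values and calls copied from the diff; imports as in the script):

lr_per_sample = [0.01]*10 + [0.001]*10 + [0.0001]
lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size, unit=UnitType.sample)
momentum_time_constant = 4096
mm_schedule = momentum_as_time_constant_schedule([momentum_time_constant], epoch_size=epoch_size)

learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule, True)
trainer = Trainer(z, (ce, pe), learner)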
Makefile
@ -77,7 +77,10 @@ endif
|
|||
|
||||
# The mpic++ wrapper only adds MPI specific flags to the g++ command line.
|
||||
# The actual compiler/linker flags added can be viewed by running 'mpic++ --showme:compile' and 'mpic++ --showme:link'
|
||||
ifneq ($(HAS_MPI),0)
|
||||
CXX = $(MPI_PATH)/bin/mpic++
|
||||
endif
|
||||
|
||||
SSE_FLAGS = -msse4.1 -mssse3
|
||||
|
||||
PROTOC = $(PROTOBUF_PATH)/bin/protoc
@ -90,8 +93,8 @@ SOURCEDIR:= Source
|
|||
INCLUDEPATH:= $(addprefix $(SOURCEDIR)/, Common/Include CNTKv2LibraryDll CNTKv2LibraryDll/API CNTKv2LibraryDll/proto Math CNTK ActionsLib ComputationNetworkLib SGDLib SequenceTrainingLib CNTK/BrainScript Readers/ReaderLib PerformanceProfilerDll)
|
||||
INCLUDEPATH+=$(PROTOBUF_PATH)/include
|
||||
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
|
||||
COMMON_FLAGS:= -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11
|
||||
CPPFLAGS:=
|
||||
COMMON_FLAGS:= -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11
|
||||
CPPFLAGS:=
|
||||
CXXFLAGS:= $(SSE_FLAGS) -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new
|
||||
LIBPATH:=
|
||||
LIBS_LIST:=
@ -270,7 +273,7 @@ RPATH=-Wl,-rpath,
|
|||
# Build info
|
||||
########################################
|
||||
|
||||
BUILDINFO:= $(SOURCEDIR)/CNTK/buildinfo.h
|
||||
BUILDINFO:= $(SOURCEDIR)/CNTKv2LibraryDll/buildinfo.h
|
||||
GENBUILD:=Tools/generate_build_info
|
||||
|
||||
BUILDINFO_OUTPUT := $(shell $(GENBUILD) $(BUILD_TOP)/Config.make && echo Success)
@ -1 +1 @@
|
|||
Subproject commit 7cfb2870164c49dfdfd219e4bcab102c8ef0cd96
|
||||
Subproject commit c8b77d6e325a4786547b27624890276c1483aed1
|
|
@ -36,6 +36,7 @@
|
|||
#include "BrainScriptEvaluator.h"
|
||||
#include "BrainScriptParser.h"
|
||||
#include "PerformanceProfiler.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
#include <string>
|
||||
#include <chrono>
|
||||
|
@ -372,55 +373,6 @@ std::string TimeDateStamp()
|
|||
return buf;
|
||||
}
|
||||
|
||||
void PrintBuiltInfo()
|
||||
{
|
||||
LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
|
||||
LOGPRINTF(stderr, "Build info: \n\n");
|
||||
LOGPRINTF(stderr, "\t\tBuilt time: %s %s\n", __DATE__, __TIME__);
|
||||
LOGPRINTF(stderr, "\t\tLast modified date: %s\n", __TIMESTAMP__);
|
||||
#ifdef _BUILDTYPE_
|
||||
LOGPRINTF(stderr, "\t\tBuild type: %s\n", _BUILDTYPE_);
|
||||
#endif
|
||||
#ifdef _BUILDTARGET_
|
||||
LOGPRINTF(stderr, "\t\tBuild target: %s\n", _BUILDTARGET_);
|
||||
#endif
|
||||
#ifdef _WITH_1BITSGD_
|
||||
LOGPRINTF(stderr, "\t\tWith 1bit-SGD: %s\n", _WITH_1BITSGD_);
|
||||
#endif
|
||||
#ifdef _WITH_ASGD_
|
||||
LOGPRINTF(stderr, "\t\tWith ASGD: %s\n", _WITH_ASGD_);
|
||||
#endif
|
||||
#ifdef _MATHLIB_
|
||||
LOGPRINTF(stderr, "\t\tMath lib: %s\n", _MATHLIB_);
|
||||
#endif
|
||||
#ifdef _CUDA_PATH_
|
||||
LOGPRINTF(stderr, "\t\tCUDA_PATH: %s\n", _CUDA_PATH_);
|
||||
#endif
|
||||
#ifdef _CUB_PATH_
|
||||
LOGPRINTF(stderr, "\t\tCUB_PATH: %s\n", _CUB_PATH_);
|
||||
#endif
|
||||
#ifdef _CUDNN_PATH_
|
||||
LOGPRINTF(stderr, "\t\tCUDNN_PATH: %s\n", _CUDNN_PATH_);
|
||||
#endif
|
||||
#ifdef _GIT_EXIST
|
||||
LOGPRINTF(stderr, "\t\tBuild Branch: %s\n", _BUILDBRANCH_);
|
||||
LOGPRINTF(stderr, "\t\tBuild SHA1: %s\n", _BUILDSHA1_);
|
||||
#endif
|
||||
#ifdef _BUILDER_
|
||||
LOGPRINTF(stderr, "\t\tBuilt by %s on %s\n", _BUILDER_, _BUILDMACHINE_);
|
||||
#endif
|
||||
#ifdef _BUILDPATH_
|
||||
LOGPRINTF(stderr, "\t\tBuild Path: %s\n", _BUILDPATH_);
|
||||
#endif
|
||||
#ifdef _MPI_NAME_
|
||||
LOGPRINTF(stderr, "\t\tMPI distribution: %s\n", _MPI_NAME_);
|
||||
#endif
|
||||
#ifdef _MPI_VERSION_
|
||||
LOGPRINTF(stderr, "\t\tMPI version: %s\n", _MPI_VERSION_);
|
||||
#endif
|
||||
LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
|
||||
}
|
||||
|
||||
void PrintUsageInfo()
|
||||
{
|
||||
LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
|
||||
|
@ -598,7 +550,7 @@ int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapp
|
|||
|
||||
RedirectStdErr(logpath);
|
||||
LOGPRINTF(stderr, "%ls\n", startupMessage.c_str());
|
||||
PrintBuiltInfo();
|
||||
::CNTK::PrintBuiltInfo();
|
||||
}
|
||||
|
||||
// echo gpu info to log
|
||||
|
@ -764,7 +716,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[])
|
|||
}
|
||||
|
||||
// full config info
|
||||
PrintBuiltInfo();
|
||||
::CNTK::PrintBuiltInfo();
|
||||
PrintGpuInfo();
|
||||
|
||||
#ifdef _DEBUG
|
||||
|
@ -857,7 +809,7 @@ int wmain1(int argc, wchar_t* argv[]) // called from wmain which is a wrapper th
|
|||
{
|
||||
if (argc <= 1)
|
||||
{
|
||||
PrintBuiltInfo(); // print build info directly in case that user provides zero argument (convenient for checking build type)
|
||||
::CNTK::PrintBuiltInfo(); // print build info directly in case that user provides zero argument (convenient for checking build type)
|
||||
LOGPRINTF(stderr, "No command-line argument given.\n");
|
||||
PrintUsageInfo();
|
||||
fflush(stderr);
|
||||
@ -85,7 +85,8 @@
|
|||
<StackReserveSize>100000000</StackReserveSize>
|
||||
</Link>
|
||||
<PreBuildEvent>
|
||||
<Command>prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
<Command>
|
||||
</Command>
|
||||
</PreBuildEvent>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(ReleaseBuild)">
|
||||
|
@ -113,7 +114,8 @@
|
|||
<StackReserveSize>100000000</StackReserveSize>
|
||||
</Link>
|
||||
<PreBuildEvent>
|
||||
<Command>prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
<Command>
|
||||
</Command>
|
||||
</PreBuildEvent>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
@ -4727,6 +4727,9 @@ namespace CNTK
|
|||
bool keepExistingCheckpoints = false,
|
||||
size_t maxNumberOfTrainingSamples = std::numeric_limits<size_t>::max(),
|
||||
size_t progressFrequency = std::numeric_limits<size_t>::max());
|
||||
|
||||
|
||||
CNTK_API void PrintBuiltInfo();
|
||||
}
|
||||
|
||||
|
||||
@ -106,6 +106,12 @@
|
|||
<DelayLoadDLLs>Math.dll; msmpi.dll; PerformanceProfilerDll.dll </DelayLoadDLLs>
|
||||
<OptimizeReferences Condition="'$(Configuration)|$(Platform)'=='Debug_CpuOnly|x64'">false</OptimizeReferences>
|
||||
</Link>
|
||||
<PreBuildEvent>
|
||||
<Command Condition="'$(Configuration)|$(Platform)'=='Release_CpuOnly|x64'">prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
</PreBuildEvent>
|
||||
<PreBuildEvent>
|
||||
<Command Condition="'$(Configuration)|$(Platform)'=='Debug_CpuOnly|x64'">prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
</PreBuildEvent>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
|
@ -118,6 +124,15 @@
|
|||
<Command>if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying NVidia GDK extension DLL to target folder</Message>
|
||||
</PostBuildEvent>
|
||||
<PreBuildEvent>
|
||||
<Command Condition="'$(Configuration)|$(Platform)'=='Release|x64'">prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
</PreBuildEvent>
|
||||
<PreBuildEvent>
|
||||
<Command Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
</PreBuildEvent>
|
||||
<PreBuildEvent>
|
||||
<Command Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">prebuild.bat "$(Configuration)" "$(CNTK_MKL_SEQUENTIAL)" "$(CNTK_ENABLE_1BitSGD)" "$(CudaPath)" "$(CUDNN_PATH)" "$(CUB_PATH)" "$(CNTK_ENABLE_ASGD)"</Command>
|
||||
</PreBuildEvent>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="API\CNTKLibrary.h" />
|
||||
@ -17,8 +17,11 @@
|
|||
#include "PerformanceProfiler.h"
|
||||
#include "MPIWrapper.h"
|
||||
#include "Basics.h"
|
||||
#include "ProgressTracing.h"
|
||||
#include "buildinfo.h"
|
||||
|
||||
extern bool g_shareNodeValueMatrices;
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
|
@ -617,6 +620,56 @@ namespace CNTK
|
|||
va_end(args);
|
||||
}
|
||||
|
||||
|
||||
void PrintBuiltInfo()
|
||||
{
|
||||
LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
|
||||
LOGPRINTF(stderr, "Build info: \n\n");
|
||||
LOGPRINTF(stderr, "\t\tBuilt time: %s %s\n", __DATE__, __TIME__);
|
||||
LOGPRINTF(stderr, "\t\tLast modified date: %s\n", __TIMESTAMP__);
|
||||
#ifdef _BUILDTYPE_
|
||||
LOGPRINTF(stderr, "\t\tBuild type: %s\n", _BUILDTYPE_);
|
||||
#endif
|
||||
#ifdef _BUILDTARGET_
|
||||
LOGPRINTF(stderr, "\t\tBuild target: %s\n", _BUILDTARGET_);
|
||||
#endif
|
||||
#ifdef _WITH_1BITSGD_
|
||||
LOGPRINTF(stderr, "\t\tWith 1bit-SGD: %s\n", _WITH_1BITSGD_);
|
||||
#endif
|
||||
#ifdef _WITH_ASGD_
|
||||
LOGPRINTF(stderr, "\t\tWith ASGD: %s\n", _WITH_ASGD_);
|
||||
#endif
|
||||
#ifdef _MATHLIB_
|
||||
LOGPRINTF(stderr, "\t\tMath lib: %s\n", _MATHLIB_);
|
||||
#endif
|
||||
#ifdef _CUDA_PATH_
|
||||
LOGPRINTF(stderr, "\t\tCUDA_PATH: %s\n", _CUDA_PATH_);
|
||||
#endif
|
||||
#ifdef _CUB_PATH_
|
||||
LOGPRINTF(stderr, "\t\tCUB_PATH: %s\n", _CUB_PATH_);
|
||||
#endif
|
||||
#ifdef _CUDNN_PATH_
|
||||
LOGPRINTF(stderr, "\t\tCUDNN_PATH: %s\n", _CUDNN_PATH_);
|
||||
#endif
|
||||
#ifdef _GIT_EXIST
|
||||
LOGPRINTF(stderr, "\t\tBuild Branch: %s\n", _BUILDBRANCH_);
|
||||
LOGPRINTF(stderr, "\t\tBuild SHA1: %s\n", _BUILDSHA1_);
|
||||
#endif
|
||||
#ifdef _BUILDER_
|
||||
LOGPRINTF(stderr, "\t\tBuilt by %s on %s\n", _BUILDER_, _BUILDMACHINE_);
|
||||
#endif
|
||||
#ifdef _BUILDPATH_
|
||||
LOGPRINTF(stderr, "\t\tBuild Path: %s\n", _BUILDPATH_);
|
||||
#endif
|
||||
#ifdef _MPI_NAME_
|
||||
LOGPRINTF(stderr, "\t\tMPI distribution: %s\n", _MPI_NAME_);
|
||||
#endif
|
||||
#ifdef _MPI_VERSION_
|
||||
LOGPRINTF(stderr, "\t\tMPI version: %s\n", _MPI_VERSION_);
|
||||
#endif
|
||||
LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
|
||||
}
|
||||
|
||||
template CNTK_API __declspec_noreturn void ThrowFormatted<std::runtime_error>(const char* format, ...);
|
||||
template CNTK_API __declspec_noreturn void ThrowFormatted<std::logic_error>(const char* format, ...);
|
||||
template CNTK_API __declspec_noreturn void ThrowFormatted<std::invalid_argument>(const char* format, ...);
|
||||
@ -340,16 +340,16 @@ namespace CNTK
|
|||
if (dataType == DataType::Float)
|
||||
{
|
||||
if (inputData == outputData)
|
||||
m_mpi->AllReduceAsync<float>(static_cast<float*>(outputData), numElements, &allReduceRequests[i]);
|
||||
m_mpi->AllReduceAsync(static_cast<float*>(outputData), numElements, &allReduceRequests[i]);
|
||||
else
|
||||
m_mpi->AllReduceAsync<float>(static_cast<float*>(inputData), static_cast<float*>(outputData), numElements, &allReduceRequests[i]);
|
||||
m_mpi->AllReduceAsync(static_cast<float*>(inputData), static_cast<float*>(outputData), numElements, &allReduceRequests[i]);
|
||||
}
|
||||
else if (dataType == DataType::Double)
|
||||
{
|
||||
if (inputData == outputData)
|
||||
m_mpi->AllReduceAsync<double>(static_cast<double*>(outputData), numElements, &allReduceRequests[i]);
|
||||
m_mpi->AllReduceAsync(static_cast<double*>(outputData), numElements, &allReduceRequests[i]);
|
||||
else
|
||||
m_mpi->AllReduceAsync<double>(static_cast<double*>(inputData), static_cast<double*>(outputData), numElements, &allReduceRequests[i]);
|
||||
m_mpi->AllReduceAsync(static_cast<double*>(inputData), static_cast<double*>(outputData), numElements, &allReduceRequests[i]);
|
||||
}
|
||||
else
|
||||
LogicError("Unknown DataType");
|
||||
@ -1,14 +1,14 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#if HAS_MPI
|
||||
// Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#ms-mpi or
|
||||
// https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Linux#open-mpi for setup instructions
|
||||
// of an MPI implementation on your platform.
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// Suppress warning for non-ASCII characters in MS-MPI headers
|
||||
#pragma warning(push)
|
||||
|
@ -18,7 +18,25 @@
|
|||
#else
|
||||
#include "mpi.h"
|
||||
#endif
|
||||
#pragma comment(lib, "msmpi.lib")
|
||||
#else
|
||||
// Note: the following macros/typedefs define some of the MPI related functions and constants such that code
|
||||
// using this functionality will compile cleanly - but will not actually perform the MPI operation.
|
||||
// The clean way to go is to move any code related to mpi into the mpiwrapper class implementation and decide
|
||||
// in this class whether to use mpi.h or not.
|
||||
typedef void *MPI_Comm;
|
||||
typedef enum _MPI_Datatype { MPI_CHAR, MPI_INT, MPI_FLOAT, MPI_DOUBLE, MPI_UNSIGNED, MPI_LONG_LONG_INT } MPI_Datatype;
|
||||
|
||||
#define MPI_IN_PLACE ((void*)(int)-1)
|
||||
#define MPI_SUM ((MPI_Op)0x58000003)
|
||||
|
||||
#define MPI_STATUSES_IGNORE (MPI_Status*)1
|
||||
#define MPI_STATUS_IGNORE (MPI_Status*)1
|
||||
#define MPI_UNDEFINED (-32766)
|
||||
|
||||
typedef int MPI_Op;
|
||||
typedef int MPI_Request;
|
||||
typedef void *MPI_Status;
|
||||
#endif
|
||||
|
||||
#include <errno.h>
|
||||
#include <string>
|
||||
|
@ -28,8 +46,6 @@
|
|||
|
||||
#include "CommonMatrix.h"
|
||||
|
||||
#define FFLUSH_SUCCESS 0
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
struct MpiFail : public std::string
|
||||
|
@ -40,481 +56,126 @@ struct MpiFail : public std::string
|
|||
}
|
||||
};
|
||||
|
||||
static int operator||(int rc, const MpiFail &what)
|
||||
{
|
||||
if (rc == MPI_SUCCESS)
|
||||
{
|
||||
return rc;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s, MPI error %d\n", what.c_str(), rc);
|
||||
fflush(stderr);
|
||||
|
||||
// (special case: we use that code to indicate a missing msmpi.dll...)
|
||||
if (rc != MPI_ERR_INTERN)
|
||||
{
|
||||
char errbuf[MPI_MAX_ERROR_STRING + 1] = {0};
|
||||
int len;
|
||||
MPI_Error_string(rc, &errbuf[0], &len);
|
||||
fprintf(stderr, "%s, MPI error %d: %s\n", what.c_str(), rc, errbuf);
|
||||
fflush(stderr);
|
||||
|
||||
// we abort through this, so that the MPI system gets the memo
|
||||
MPI_Abort(MPI_COMM_WORLD, rc);
|
||||
|
||||
// TODO: or does that only signal an issue, and we should still terminate ourselves?
|
||||
// BUGBUG: We'd also need to Abort through the other sub-set communicator
|
||||
}
|
||||
RuntimeError("%s", what.c_str());
|
||||
}
|
||||
extern int operator||(int rc, const MpiFail &what);
|
||||
|
||||
class MPIWrapper;
|
||||
typedef std::shared_ptr<MPIWrapper> MPIWrapperPtr;
|
||||
|
||||
extern "C" void GetMpiWrapper(MPIWrapper **mpi);
|
||||
|
||||
// Note: This is now a pure interface, so please don't add
|
||||
// any functionality to this class.
|
||||
// Instead, make your own implementation class, add/change
|
||||
// functions there as needed and use a private interface to
|
||||
// these functions.
|
||||
// In case you need to add functions that affect all
|
||||
// implementations, add a pure virtual function here and
|
||||
// update any affected implementation.
|
||||
class MPIWrapper : public std::enable_shared_from_this<MPIWrapper>
|
||||
{
|
||||
int m_myRank;
|
||||
std::wstring m_myName;
|
||||
int m_numMPINodes;
|
||||
size_t m_numNodesInUse;
|
||||
bool m_multiHost;
|
||||
|
||||
// MPI communicator that reflects the current subset selection
|
||||
MPI_Comm m_currentComm;
|
||||
|
||||
static MPIWrapperPtr s_mpi;
|
||||
|
||||
// MPI_Init() with delay-loading the msmpi.dll (possibly causing a failure if missing; we want to catch that)
|
||||
int MPI_Init_DL()
|
||||
{
|
||||
#ifdef WIN32
|
||||
__try
|
||||
#endif
|
||||
{
|
||||
// don't initialize if that has been done already
|
||||
int flag = 0;
|
||||
MPI_Initialized(&flag);
|
||||
if (flag)
|
||||
return MPI_SUCCESS;
|
||||
|
||||
int argc = 0;
|
||||
char **argv = NULL;
|
||||
// TODO(qiwye) Multiverso(parameter server) will benefit from MPI_THREAD_MULTIPLE .
|
||||
int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
|
||||
int provided;
|
||||
int ret = MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
|
||||
if (provided != requiredThreadLevelSupport)
|
||||
LogicError("Failed to initialize MPI with the desired level of thread support");
|
||||
|
||||
return ret;
|
||||
}
|
||||
#ifdef WIN32
|
||||
__except (EXCEPTION_EXECUTE_HANDLER)
|
||||
{
|
||||
fprintf(stderr, "mpihelper: msmpi.dll missing\n");
|
||||
return MPI_ERR_INTERN;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Workaround for the issue with MPI hanging when we have non-0 exit codes from CNTK processes
|
||||
// OpenMPI has a confirmed race condition on killing child process vs. handling their non-zero exit statuses, resulting
|
||||
// in a deadlock, where all processes killed but MPI is still waiting.
|
||||
// This happens when several perfectly synchronized processes (for example on MPI barrier)
|
||||
// simulatenously exit with non-0 exit code.
|
||||
// As a workaround, we simply sleep 50*rank miliseconds, effectively "de-synchronizing processes" at exit,
|
||||
// allowing MPI to sequentially handle terminations
|
||||
static int s_myRank;
|
||||
static void MPIWorkaroundAtExit()
|
||||
{
|
||||
Sleep(s_myRank * 50);
|
||||
}
|
||||
|
||||
public:
|
||||
MPIWrapper()
|
||||
: m_currentComm(MPI_COMM_WORLD)
|
||||
{
|
||||
static bool initialized = false;
|
||||
if (initialized)
|
||||
{
|
||||
LogicError("MPIWrapper: this is a singleton class that can only be instantiated once per process");
|
||||
}
|
||||
MPIWrapper() {}
|
||||
virtual ~MPIWrapper() {}
|
||||
|
||||
initialized = true;
|
||||
|
||||
if (GetMathLibTraceLevel() > 0)
|
||||
{
|
||||
fprintf(stderr, "MPIWrapper: initializing MPI\n");
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
MPI_Init_DL() || MpiFail("mpiaggregator: MPI_Init");
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &m_myRank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &m_numMPINodes);
|
||||
m_numNodesInUse = m_numMPINodes;
|
||||
m_multiHost = true;
|
||||
|
||||
// Verify that the environment variable used by GetTotalNumberOfMPINodes()
|
||||
// matches what the MPI API says. There're actually two possible cases:
|
||||
// 1) when we're running with mpiexec both values have to match;
|
||||
// 2) when we're running without mpiexec, the former will return 0, and
|
||||
// the later will be set to 1.
|
||||
assert((GetTotalNumberOfMPINodes() == 0 && m_numNodesInUse == 1) ||
|
||||
(GetTotalNumberOfMPINodes() == m_numNodesInUse));
|
||||
|
||||
char name[BUFSIZ];
|
||||
int length;
|
||||
MPI_Get_processor_name(name, &length);
|
||||
m_myName = std::wstring(name, name+length);
|
||||
|
||||
// Applying MPI workaround
|
||||
s_myRank = m_myRank;
|
||||
atexit(&MPIWrapper::MPIWorkaroundAtExit);
|
||||
|
||||
// by default we use all of them
|
||||
RequestNodes("MPIWrapper");
|
||||
|
||||
if (GetMathLibTraceLevel() > 0)
|
||||
{
|
||||
if (m_numMPINodes > 1)
|
||||
fprintf(stderr, "mpihelper: we are cog %d in a gearbox of %d\n", (int) m_myRank, (int) m_numMPINodes);
|
||||
else
|
||||
fprintf(stderr, "mpihelper: only one MPI process: MPI operation will be boring\n");
|
||||
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
// do an initial handshake
|
||||
Ping("mpihelper");
|
||||
|
||||
// stagger the jobs just a little to get a sort-of deterministic order e.g. in GPU allocation when running on one machine
|
||||
// continue 0.5 seconds apart
|
||||
::Sleep((DWORD)(500 * CurrentNodeRank()));
|
||||
}
|
||||
static MPIWrapperPtr GetInstance(bool create = false);
|
||||
static void DeleteInstance();
|
||||
static MPIWrapperPtr s_mpi;
|
||||
|
||||
// Note that specifically, this function is such that it does not require
|
||||
// MPI initialization. Moreover, it can be used without actually loading any
|
||||
// MPI libs.
|
||||
// TODO: Once we move to dynamic loading for MPI libs on Linux, move it to utilities.
|
||||
static int GetTotalNumberOfMPINodes()
|
||||
{
|
||||
#ifdef WIN32
|
||||
const char* p = std::getenv("PMI_SIZE");
|
||||
#else
|
||||
const char* p = std::getenv("OMPI_COMM_WORLD_SIZE");
|
||||
#endif
|
||||
if (!p)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::stoi(string(p));
|
||||
}
|
||||
}
|
||||
static int GetTotalNumberOfMPINodes();
|
||||
|
||||
// Note: we don't clear the sub-communication here although we should, because in case of a crash, this prevents the EXE from terminating.
|
||||
// It's OK since this class is a singleton anyway that gets instantiated exactly once at program startup.
|
||||
~MPIWrapper()
|
||||
{
|
||||
if (GetMathLibTraceLevel() > 0)
|
||||
{
|
||||
fprintf(stderr, "~MPIWrapper\n");
|
||||
}
|
||||
|
||||
// Do not finalize in event of an exception since calling MPI_Finalize without
|
||||
// all pending communications being finished results in a hang
|
||||
int rc = fflush(stderr);
|
||||
if (!std::uncaught_exception())
|
||||
{
|
||||
if (rc != FFLUSH_SUCCESS)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
RuntimeError("MPIWrapper: Failed to flush stderr, %d", ::GetLastError());
|
||||
#else
|
||||
RuntimeError("MPIWrapper: Failed to flush stderr, %d", errno);
|
||||
#endif
|
||||
}
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void Ping(const char *msg) const
|
||||
{
|
||||
#undef USE2NDCOMM
|
||||
#ifndef USE2NDCOMM
|
||||
if (NumNodesInUse() != m_numMPINodes)
|
||||
{
|
||||
fprintf(stderr, "ping [%s]: cannot be applied to subset (%d) of nodes, skipping\n", msg, (int) NumNodesInUse());
|
||||
fflush(stderr);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
std::array<int, 1> handshake;
|
||||
handshake[0] = 1;
|
||||
|
||||
if (GetMathLibTraceLevel() > 0)
|
||||
{
|
||||
fprintf(stderr, "ping [%s]: %d nodes pinging each other\n", msg, (int) NumNodesInUse());
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
AllReduce(handshake);
|
||||
|
||||
if (GetMathLibTraceLevel() > 0)
|
||||
{
|
||||
fprintf(stderr, "ping [%s]: all %d nodes responded\n", msg, handshake[0]);
|
||||
fflush(stderr);
|
||||
}
|
||||
}
|
||||
|
||||
void RequestNodes(const char *msg, size_t requestednodes = SIZE_MAX /*default: all*/)
|
||||
{
|
||||
Ping("requestnodes (before change)");
|
||||
|
||||
// undo current split
|
||||
#ifdef USE2NDCOMM
|
||||
if (m_currentComm != MPI_COMM_WORLD /*no subset*/ && m_currentComm != MPI_COMM_NULL /*idle nodes*/)
|
||||
{
|
||||
fprintf(stderr, "requestnodes: MPI_Comm_free %x\n", (int) m_currentComm);
|
||||
fflush(stderr);
|
||||
MPI_Comm_free(&m_currentComm) || MpiFail("requestnodes: MPI_Comm_free"); // will leave MPI_COMM_NULL here
|
||||
}
|
||||
#endif
|
||||
// reset to MPI_COMM_WORLD
|
||||
m_currentComm = MPI_COMM_WORLD;
|
||||
// create a new split (unless all nodes were requested)
|
||||
if (requestednodes < (size_t) m_numMPINodes)
|
||||
{
|
||||
#ifdef USE2NDCOMM
|
||||
fprintf(stderr, "requestnodes: MPI_Comm_split %d\n", (node() < requestednodes) ? 1 : MPI_UNDEFINED);
|
||||
fflush(stderr);
|
||||
MPI_Comm_split(communicator(), (node() < requestednodes) ? 1 : MPI_UNDEFINED, 0, &m_currentComm) || MpiFail("requestnodes: MPI_Comm_split");
|
||||
fprintf(stderr, "requestnodes: MPI_Comm_split -> %x\n", (int) m_currentComm);
|
||||
fflush(stderr);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
// leave m_currentComm as MPI_COMM_WORLD
|
||||
// and clip to #nodes
|
||||
requestednodes = m_numMPINodes;
|
||||
}
|
||||
|
||||
m_numNodesInUse = requestednodes;
|
||||
|
||||
if (GetMathLibTraceLevel() > 0)
|
||||
{
|
||||
fprintf(stderr, "requestnodes [%s]: using %d out of %d MPI nodes (%d requested); we (%d) are %s\n",
|
||||
msg, (int) m_numNodesInUse, (int) m_numMPINodes, (int) requestednodes,
|
||||
(int) CurrentNodeRank(), IsIdle() ? "out (idle)" : "in (participating)");
|
||||
fflush(stderr);
|
||||
}
|
||||
Ping("requestnodes (after change)");
|
||||
|
||||
// If all ranks run on a single host, we can enable optimized communication
|
||||
// paths (e.g. NCCL). To determine if a single machine is being used, we
|
||||
// check that MPI_Get_processor_name matches for all ranks.
|
||||
const int nameMax = MPI_MAX_PROCESSOR_NAME + 1;
|
||||
char myName[nameMax] = {0};
|
||||
int myNameLen = 0;
|
||||
MPI_Get_processor_name(myName, &myNameLen) || MpiFail("requestnodes: MPI_Get_processor_name");
|
||||
myName[myNameLen] = '\0';
|
||||
|
||||
std::vector<char> nameBuffer(m_numNodesInUse * nameMax);
|
||||
char* allNames = nameBuffer.data();
|
||||
MPI_Allgather(myName, nameMax, MPI_CHAR, allNames, nameMax, MPI_CHAR, m_currentComm)
|
||||
|| MpiFail("requestnodes: MPI_Allgather");
|
||||
|
||||
m_multiHost = false;
|
||||
for(size_t i=1; i<m_numNodesInUse; i++)
|
||||
{
|
||||
if (strcmp(allNames, allNames+i*nameMax) != 0)
|
||||
{
|
||||
m_multiHost = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "requestnodes [%s]: using %d out of %d MPI nodes on %s (%d requested); we (%d) are %s\n",
|
||||
msg, (int) m_numNodesInUse, (int) m_numMPINodes, m_multiHost ? "multiple hosts" : "a single host",
|
||||
(int) requestednodes, (int) CurrentNodeRank(), IsIdle() ? "out (idle)" : "in (participating)");
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
static MPIWrapperPtr GetInstance(bool create = false)
|
||||
{
|
||||
if (create)
|
||||
{
|
||||
if (s_mpi != nullptr)
|
||||
LogicError("Creating MPIWrapper instance after a GetInstance call has been already made!");
|
||||
else
|
||||
s_mpi = std::make_shared<MPIWrapper>();
|
||||
}
|
||||
|
||||
return s_mpi;
|
||||
}
|
||||
|
||||
static void DeleteInstance()
|
||||
{
|
||||
s_mpi = nullptr;
|
||||
}
|
||||
|
||||
MPI_Comm Communicator() const
|
||||
{
|
||||
return m_currentComm;
|
||||
}
|
||||
size_t NumNodesInUse() const
|
||||
{
|
||||
return m_numNodesInUse;
|
||||
}
|
||||
size_t CurrentNodeRank() const
|
||||
{
|
||||
return m_myRank;
|
||||
}
|
||||
std::wstring CurrentNodeName() const
|
||||
{
|
||||
return m_myName;
|
||||
}
|
||||
bool IsMainNode() const
|
||||
{
|
||||
return m_myRank == 0;
|
||||
} // we are the chosen one--do extra stuff like saving the model to disk
|
||||
bool IsIdle() const
|
||||
{
|
||||
return CurrentNodeRank() >= NumNodesInUse();
|
||||
} // user had requested to not use this many nodes
|
||||
bool UsingAllNodes() const
|
||||
{
|
||||
return NumNodesInUse() == m_numMPINodes;
|
||||
} // all nodes participate (used to check whether we can use MPI_Allreduce directly)
|
||||
size_t MainNodeRank() const
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool IsMultiHost()
|
||||
{
|
||||
return m_multiHost;
|
||||
}
|
||||
virtual size_t NumNodesInUse() const = 0;
|
||||
virtual size_t CurrentNodeRank() const = 0;
|
||||
virtual bool IsMainNode() const = 0;
|
||||
virtual std::wstring CurrentNodeName() const = 0;
|
||||
virtual bool IsIdle() const = 0;
|
||||
virtual bool UsingAllNodes() const = 0;
|
||||
virtual size_t MainNodeRank() const = 0;
|
||||
virtual bool IsMultiHost() const = 0;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// data-exchange functions (wrappers around MPI functions)
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
virtual int Finalize(void) = 0;
|
||||
virtual int Wait(MPI_Request* request, MPI_Status* status) = 0;
|
||||
virtual int Waitany(int count, MPI_Request array_of_requests[], int* index, MPI_Status* status) = 0;
|
||||
virtual int Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]) = 0;
|
||||
virtual int Isend(const void* buf, int count, MPI_Datatype datatype, int dest, int tag, /*MPI_Comm comm,*/ MPI_Request* request) = 0;
|
||||
virtual int Recv(void* buf, int count, MPI_Datatype datatype, int source, int tag, /*MPI_Comm comm,*/ MPI_Status* status) = 0;
|
||||
virtual int Irecv(void* buf, int count, MPI_Datatype datatype, int source, int tag, /*MPI_Comm comm,*/ MPI_Request* request) = 0;
|
||||
virtual int Iallreduce(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, /*MPI_Comm comm,*/ MPI_Request* request) = 0;
|
||||
virtual int Abort(int errorcode) = 0;
|
||||
virtual int Error_string(int errorcode, char* string, int* resultlen) = 0;
|
||||
|
||||
|
||||
// helpers to determine the MPI_Datatype of a pointer
|
||||
static MPI_Datatype GetDataType(char *)
|
||||
{
|
||||
return MPI_CHAR;
|
||||
}
|
||||
static MPI_Datatype GetDataType(int *)
|
||||
{
|
||||
return MPI_INT;
|
||||
}
|
||||
static MPI_Datatype GetDataType(float *)
|
||||
{
|
||||
return MPI_FLOAT;
|
||||
}
|
||||
static MPI_Datatype GetDataType(double *)
|
||||
{
|
||||
return MPI_DOUBLE;
|
||||
}
|
||||
static MPI_Datatype GetDataType(size_t *)
|
||||
{
|
||||
return sizeof(size_t) == 4 ? MPI_UNSIGNED : MPI_LONG_LONG_INT;
|
||||
}
|
||||
static MPI_Datatype GetDataType(char *);
|
||||
static MPI_Datatype GetDataType(int *);
|
||||
static MPI_Datatype GetDataType(float *);
|
||||
static MPI_Datatype GetDataType(double *);
|
||||
static MPI_Datatype GetDataType(size_t *);
|
||||
|
||||
// allreduce of a vector
|
||||
template <typename VECTORLIKEOBJECT>
|
||||
void AllReduce(VECTORLIKEOBJECT &accumulator) const
|
||||
{
|
||||
auto *dataptr = accumulator.data();
|
||||
size_t totalnumelements = accumulator.size();
|
||||
|
||||
// use MPI to compute the sum over all elements in (dataptr, totalnumelements) and redistribute to all nodes
|
||||
AllReduce<typename VECTORLIKEOBJECT::value_type>(dataptr, totalnumelements);
|
||||
}
|
||||
virtual void AllReduce(std::vector<size_t>& accumulator) const = 0;
|
||||
virtual void AllReduce(std::vector<int>& accumulator) const = 0;
|
||||
virtual void AllReduce(std::vector<double>& accumulator) const = 0;
|
||||
virtual void AllReduce(std::vector<float>& accumulator) const = 0;
|
||||
|
||||
// for raw pointer
|
||||
template <class ElemType>
|
||||
void AllReduce(ElemType* sendData, size_t numElements, MPI_Op op = MPI_SUM) const
|
||||
{
|
||||
AllReduce<ElemType>(static_cast<ElemType*>(MPI_IN_PLACE), sendData, numElements, op);
|
||||
}
|
||||
virtual void AllReduce(size_t* sendData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduce(int* sendData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduce(double* sendData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduce(float* sendData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void AllReduceAsync(ElemType* sendData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const
|
||||
{
|
||||
AllReduceAsync<ElemType>(static_cast<ElemType*>(MPI_IN_PLACE), sendData, numElements, request, op);
|
||||
}
|
||||
virtual void AllReduce(size_t* sendData, size_t* receiveData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduce(int* sendData, int* receiveData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduce(double* sendData, double* receiveData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduce(float* sendData, float* receiveData, size_t numElements, MPI_Op op = MPI_SUM) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void AllGatherAsync(const ElemType *sendData, size_t numSendElements, ElemType *receiveData, size_t numRecvElements, MPI_Request* request) const
|
||||
{
|
||||
MPI_Iallgather(sendData, (int)numSendElements, GetDataType(receiveData), receiveData, (int)numRecvElements, GetDataType(receiveData), Communicator(), request) || MpiFail("AllReduceAsync: MPI_Iallgather");
|
||||
}
|
||||
virtual void AllReduceAsync(size_t* sendData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduceAsync(int* sendData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduceAsync(double* sendData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduceAsync(float* sendData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void AllGather(const ElemType *sendData, size_t numSendElements, ElemType *receiveData, size_t numRecvElements) const
|
||||
{
|
||||
MPI_Allgather(sendData, (int)numSendElements, GetDataType(receiveData), receiveData, (int)numRecvElements, GetDataType(receiveData), Communicator()) || MpiFail("AllReduceAsync: MPI_Allgather");
|
||||
}
|
||||
virtual void AllReduceAsync(size_t* sendData, size_t* receiveData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduceAsync(int* sendData, int* receiveData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduceAsync(double* sendData, double* receiveData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
virtual void AllReduceAsync(float* sendData, float* receiveData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void AllReduceAsync(ElemType *sendData, ElemType *receiveData, size_t numElements, MPI_Request* request, MPI_Op op = MPI_SUM) const
|
||||
{
|
||||
MPI_Iallreduce(sendData, receiveData, (int)numElements, GetDataType(sendData), op, Communicator(), request) || MpiFail("AllReduceAsync: MPI_Iallreduce");
|
||||
}
|
||||
virtual void Bcast(size_t* sendData, size_t numElements, size_t srcRank) = 0;
|
||||
virtual void Bcast(double* sendData, size_t numElements, size_t srcRank) = 0;
|
||||
virtual void Bcast(float* sendData, size_t numElements, size_t srcRank) = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void AllReduce(ElemType *sendData, ElemType *receiveData, size_t numElements, MPI_Op op = MPI_SUM) const
|
||||
{
|
||||
MPI_Allreduce(sendData, receiveData, (int)numElements, GetDataType(sendData), op, Communicator()) || MpiFail("AllReduce: MPI_Allreduce");
|
||||
}
|
||||
virtual void AllGatherAsync(const size_t *sendData, size_t numSendElements, size_t *receiveData, size_t numRecvElements, MPI_Request* request) const = 0;
|
||||
virtual void AllGatherAsync(const int *sendData, size_t numSendElements, int *receiveData, size_t numRecvElements, MPI_Request* request) const = 0;
|
||||
virtual void AllGatherAsync(const float *sendData, size_t numSendElements, float *receiveData, size_t numRecvElements, MPI_Request* request) const = 0;
|
||||
virtual void AllGatherAsync(const double *sendData, size_t numSendElements, double *receiveData, size_t numRecvElements, MPI_Request* request) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void Gather(const ElemType *sendData, size_t numSendElements, ElemType *receiveData, size_t numRecvElements, size_t rootRank) const
|
||||
{
|
||||
MPI_Gather(sendData, (int)numSendElements, GetDataType(receiveData), receiveData, (int)numRecvElements, GetDataType(receiveData), (int)rootRank, Communicator()) || MpiFail("AllReduceAsync: MPI_Gather");
|
||||
}
|
||||
virtual void AllGather(const size_t *sendData, size_t numSendElements, size_t *receiveData, size_t numRecvElements) const = 0;
|
||||
virtual void AllGather(const int *sendData, size_t numSendElements, int *receiveData, size_t numRecvElements) const = 0;
|
||||
virtual void AllGather(const float *sendData, size_t numSendElements, float *receiveData, size_t numRecvElements) const = 0;
|
||||
virtual void AllGather(const double *sendData, size_t numSendElements, double *receiveData, size_t numRecvElements) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void Gatherv(const ElemType *sendData, size_t numSendElements, ElemType *receiveData, int recvCounts[], int offsets[], size_t rootRank) const
|
||||
{
|
||||
MPI_Gatherv(sendData, (int)numSendElements, GetDataType(receiveData), receiveData, recvCounts, offsets, GetDataType(receiveData), (int)rootRank, Communicator()) || MpiFail("AllReduceAsync: MPI_Gatherv");
|
||||
}
|
||||
virtual void Gather(const size_t *sendData, size_t numSendElements, size_t *receiveData, size_t numRecvElements, size_t rootRank) const = 0;
|
||||
virtual void Gather(const int *sendData, size_t numSendElements, int *receiveData, size_t numRecvElements, size_t rootRank) const = 0;
|
||||
virtual void Gather(const float *sendData, size_t numSendElements, float *receiveData, size_t numRecvElements, size_t rootRank) const = 0;
|
||||
virtual void Gather(const double *sendData, size_t numSendElements, double *receiveData, size_t numRecvElements, size_t rootRank) const = 0;
|
||||
|
||||
template <class ElemType>
|
||||
void Bcast(ElemType *pData, size_t nData, size_t srcRank)
|
||||
{
|
||||
MPI_Bcast(pData, (int) nData, GetDataType(pData), (int) srcRank, Communicator()) || MpiFail("Bcast: MPI_Bcast");
|
||||
}
|
||||
|
||||
// wait for an async request to finish
|
||||
void Wait(MPI_Request* request)
|
||||
{
|
||||
MPI_Wait(request, MPI_STATUSES_IGNORE) || MpiFail("Wait: MPI_Wait");
|
||||
}
|
||||
|
||||
void WaitAny(MPI_Request* requests, int numRequests, int* index)
|
||||
{
|
||||
MPI_Waitany(numRequests, requests, index, MPI_STATUSES_IGNORE) || MpiFail("WaitAny: MPI_Waitany");
|
||||
}
|
||||
virtual void Gatherv(const size_t *sendData, size_t numSendElements, size_t *receiveData, int recvCounts[], int offsets[], size_t rootRank) const = 0;
|
||||
virtual void Gatherv(const char *sendData, size_t numSendElements, char *receiveData, int recvCounts[], int offsets[], size_t rootRank) const = 0;
|
||||
virtual void Gatherv(const int *sendData, size_t numSendElements, int *receiveData, int recvCounts[], int offsets[], size_t rootRank) const = 0;
|
||||
virtual void Gatherv(const float *sendData, size_t numSendElements, float *receiveData, int recvCounts[], int offsets[], size_t rootRank) const = 0;
|
||||
virtual void Gatherv(const double *sendData, size_t numSendElements, double *receiveData, int recvCounts[], int offsets[], size_t rootRank) const = 0;
|
||||
|
||||
// wait for all ranks to reach here
|
||||
void WaitAll()
|
||||
{
|
||||
MPI_Barrier(m_currentComm) || MpiFail("waitall: MPI_Barrier");
|
||||
}
|
||||
|
||||
void WaitAll(std::vector<MPI_Request>& requests)
|
||||
{
|
||||
MPI_Waitall((int)requests.size(), &requests[0], MPI_STATUSES_IGNORE) || MpiFail("waitall: MPI_Waitall");
|
||||
}
|
||||
virtual int WaitAll() = 0;
|
||||
virtual void WaitAny(MPI_Request* requests, int numRequests, int* index) = 0;
|
||||
virtual void Wait(MPI_Request* request) = 0;
|
||||
virtual int WaitAll(std::vector<MPI_Request>& requests) = 0;
|
||||
};
|
||||
|
||||
}}}
|
||||
The diff for this file is not shown because it is too large.
@ -159,11 +159,13 @@ template<class ElemType>
|
|||
/*static*/ void ComputationNode<ElemType>::BroadcastToPacked(const Matrix<ElemType>& dataToBroadcast,
|
||||
const MBLayoutPtr& inputLayout,
|
||||
Matrix<ElemType>& broadcastTo,
|
||||
const MBLayoutPtr& targetLayout,
|
||||
const FrameRange& targetFrameRange,
|
||||
const std::shared_ptr<Matrix<ElemType>>& tempIndicesStorage)
|
||||
{
|
||||
auto targetLayout = targetFrameRange.m_pMBLayout;
|
||||
|
||||
// Generate the gather indices
|
||||
std::vector<ElemType> gatherIndicesVector(targetLayout->GetNumCols(), -1);
|
||||
std::vector<ElemType> gatherIndicesVector(broadcastTo.GetNumCols(), -1);
|
||||
auto& layoutSequences = targetLayout->GetAllSequences();
|
||||
int numLayoutSequences = (int)layoutSequences.size();
@ -175,11 +177,18 @@ template<class ElemType>
|
|||
for (int layoutSequenceIdx = 0; layoutSequenceIdx < numLayoutSequences; ++layoutSequenceIdx)
|
||||
{
|
||||
auto sequenceInfo = layoutSequences[layoutSequenceIdx];
|
||||
if (sequenceInfo.seqId != GAP_SEQUENCE_ID)
|
||||
|
||||
if ((sequenceInfo.seqId != GAP_SEQUENCE_ID) &&
|
||||
(targetFrameRange.IsAllFrames() || ((sequenceInfo.tBegin <= (ptrdiff_t)(targetFrameRange.timeIdxInSeq + targetFrameRange.m_timeOffset)) && (sequenceInfo.tEnd > (targetFrameRange.timeIdxInSeq + targetFrameRange.m_timeOffset)))))
|
||||
{
|
||||
auto srcSequenceInfo = inputLayout->FindSequence(sequenceInfo.seqId);
|
||||
auto gatherFromIndex = inputLayout->GetColumnIndex(srcSequenceInfo, 0);
|
||||
auto currentSequenceColumnIndices = targetLayout->GetColumnIndices(sequenceInfo);
|
||||
std::vector<size_t> currentSequenceColumnIndices;
|
||||
if (targetFrameRange.IsAllFrames())
|
||||
currentSequenceColumnIndices = targetLayout->GetColumnIndices(sequenceInfo);
|
||||
else
|
||||
currentSequenceColumnIndices.push_back(sequenceInfo.s);
|
||||
|
||||
for (auto i : currentSequenceColumnIndices)
|
||||
gatherIndicesVector[i] = (ElemType)gatherFromIndex;
|
||||
}
@ -187,9 +196,9 @@ template<class ElemType>
|
|||
|
||||
auto gatherIdxMatrix = tempIndicesStorage;
|
||||
if (!gatherIdxMatrix)
|
||||
gatherIdxMatrix = std::make_shared<Matrix<ElemType>>(1, targetLayout->GetNumCols(), gatherIndicesVector.data(), broadcastTo.GetDeviceId());
|
||||
gatherIdxMatrix = std::make_shared<Matrix<ElemType>>(1, broadcastTo.GetNumCols(), gatherIndicesVector.data(), broadcastTo.GetDeviceId());
|
||||
else
|
||||
gatherIdxMatrix->SetValue(1, targetLayout->GetNumCols(), broadcastTo.GetDeviceId(), gatherIndicesVector.data());
|
||||
gatherIdxMatrix->SetValue(1, broadcastTo.GetNumCols(), broadcastTo.GetDeviceId(), gatherIndicesVector.data());
|
||||
|
||||
broadcastTo.DoGatherColumnsOf(0, *gatherIdxMatrix, dataToBroadcast, 1);
|
||||
}
@ -1445,7 +1445,7 @@ public:
|
|||
static void BroadcastToPacked(const Matrix<ElemType>& dataToBroadcast,
|
||||
const MBLayoutPtr& inputLayout,
|
||||
Matrix<ElemType>& broadcastTo,
|
||||
const MBLayoutPtr& targetLayout,
|
||||
const FrameRange& targetFrameRange,
|
||||
const std::shared_ptr<Matrix<ElemType>>& tempIndicesStorage);
|
||||
|
||||
// -----------------------------------------------------------------------
@ -279,7 +279,7 @@ public:
|
|||
// this does a deep value-level comparison
|
||||
m_layoutsMatch = InputRef(0).GetMBLayout() && (*m_pMBLayout == *InputRef(0).GetMBLayout());
|
||||
if (InputRef(0).GetMBLayout() && !m_layoutsMatch &&
|
||||
((InputRef(0).GetMBLayout()->GetNumTimeSteps() != 1) || (InputRef(0).GetMBLayout()->GetNumSequences() != m_pMBLayout->GetNumSequences()) || !fr.IsAllFrames()))
|
||||
((InputRef(0).GetMBLayout()->GetNumTimeSteps() != 1) || (InputRef(0).GetMBLayout()->GetNumSequences() != m_pMBLayout->GetNumSequences())))
|
||||
{
|
||||
InvalidArgument("%ls %ls operation discovered that %ls %ls operation produced an MB layout that is incompatible with that of %ls %ls.",
|
||||
NodeName().c_str(), OperationName().c_str(),
@ -300,7 +300,7 @@ public:
|
|||
{
|
||||
// Broadcast along the sequence
|
||||
auto result = ValueFor(fr);
|
||||
ComputationNode<ElemType>::BroadcastToPacked(InputRef(0).Value(), InputRef(0).GetMBLayout(), result, m_pMBLayout, m_tempGatherIndices);
|
||||
ComputationNode<ElemType>::BroadcastToPacked(InputRef(0).Value(), InputRef(0).GetMBLayout(), result, fr, m_tempGatherIndices);
|
||||
}
|
||||
}
@@ -308,37 +308,35 @@ public:
    {
        if (inputIndex == 0)
        {
            size_t rank = GetSampleLayout().GetRank();

            // if reduction then mask the respective input(s) (zero out the gaps)
            if (Input(inputIndex)->ReducesInTimeWrt(shared_from_this()))
                MaskMissingGradientColumnsToZero(fr);

            TensorView<ElemType> gradient;
            TensorView<ElemType> inputGradient;
            if (!InputRef(0).GetMBLayout() || m_layoutsMatch)
            {
                size_t rank = GetSampleLayout().GetRank();
                auto gradient = GradientTensorFor(rank, fr);
                auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, InputRef(inputIndex).GetMBLayout() ? fr.WithLayout(InputRef(inputIndex).GetMBLayout()) : fr.AllowBroadcast());

                // if reduction then mask the respective input(s) (zero out the gaps)
                if (Input(inputIndex)->ReducesInTimeWrt(shared_from_this()))
                    MaskMissingGradientColumnsToZero(fr);

                if (Input(inputIndex)->ParentOverwritesGradient())
                    inputGradient.AssignCopyOf(gradient);
                else
                    inputGradient.AddCopyOf(gradient);

                // TODO: Once we do in-place, the above must include a copy-to-self check (pay special attention to adding vs. copying).
                gradient = GradientTensorFor(rank, fr);
                inputGradient = Input(inputIndex)->GradientTensorFor(rank, InputRef(inputIndex).GetMBLayout() ? fr.WithLayout(InputRef(inputIndex).GetMBLayout()) : fr.AllowBroadcast());
            }
            else
            {
                assert(fr.IsAllFrames());
                // Broadcasting along the sequence
                if (!fr.IsAllFrames())
                    InvalidArgument("%ls %ls operation does not support broadcasting the left operand to the right operand's MB layout, inside a recurrent loop.", NodeName().c_str(), OperationName().c_str());

                MaskMissingGradientColumnsToZero(fr);
                auto unpackedGradientTensor = ComputationNode<ElemType>::Unpack(GetSampleLayout(), GradientFor(fr), m_pMBLayout, m_tempUnpackedData, m_tempScatterIndices, /*batchMajor=*/ true, /*maskGaps=*/ true);

                size_t rank = GetSampleLayout().GetRank();
                auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, FrameRange(InputRef(inputIndex).GetMBLayout(), 0));
                if (Input(inputIndex)->ParentOverwritesGradient())
                    inputGradient.AssignCopyOf(unpackedGradientTensor);
                else
                    inputGradient.AddCopyOf(unpackedGradientTensor);
                gradient = ComputationNode<ElemType>::Unpack(GetSampleLayout(), GradientFor(fr), m_pMBLayout, m_tempUnpackedData, m_tempScatterIndices, /*batchMajor=*/ true, /*maskGaps=*/ true);
                inputGradient = Input(inputIndex)->GradientTensorFor(rank, FrameRange(InputRef(inputIndex).GetMBLayout(), 0));
            }

            if (Input(inputIndex)->ParentOverwritesGradient())
                inputGradient.AssignCopyOf(gradient);
            else
                inputGradient.AddCopyOf(gradient);

            // TODO: Once we do in-place, the above must include a copy-to-self check (pay special attention to adding vs. copying).
        }
    }

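The BackpropTo rewrite above hoists the two TensorView locals out of the branches so that the final assign-or-accumulate step appears once, for both the matching-layout path and the sequence-broadcast path. The same pattern in isolation, with hypothetical names rather than CNTK's TensorView API:

#include <cstddef>
#include <vector>

// Propagate a parent gradient into an input gradient buffer, either
// overwriting it or accumulating into it, mirroring the
// ParentOverwritesGradient() distinction used in the node code above.
// Assumes both buffers have the same length.
static void PropagateGradient(const std::vector<float>& parentGrad,
                              std::vector<float>& inputGrad,
                              bool parentOverwritesGradient)
{
    if (parentOverwritesGradient)
        inputGrad = parentGrad;                            // AssignCopyOf
    else
        for (std::size_t i = 0; i < inputGrad.size(); ++i) // AddCopyOf
            inputGrad[i] += parentGrad[i];
}
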
@@ -60,8 +60,8 @@
    </ClCompile>
    <Link>
      <AdditionalLibraryDirectories>$(OutDir)</AdditionalLibraryDirectories>
      <AdditionalDependencies>EvalDLL.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
      <DelayLoadDLLs>EvalDll.dll</DelayLoadDLLs>
      <AdditionalDependencies>EvalDLL.lib;Math.lib;Common.lib;$(MSMPI_LIB64)msmpi.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
      <DelayLoadDLLs>EvalDll.dll;Math.dll</DelayLoadDLLs>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="$(DebugBuild)">

@@ -2374,7 +2374,9 @@ ElemType GPUMatrix<ElemType>::AbsoluteMax() const
        int resInd = 0;
        cublasIdamax(cuHandle, (CUDA_LONG)GetNumElements(), reinterpret_cast<double*>(Data()), 1, &resInd);
        resInd--;

        CUDA_CALL(cudaMemcpy(reinterpret_cast<double*>(&res), Data() + resInd, sizeof(double), cudaMemcpyDeviceToHost));

        return res;
    }
}

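The added resInd-- matters because the cuBLAS amax routines follow the BLAS (Fortran) convention and return a 1-based index. A minimal host-side sketch of the same sequence, assuming a valid cuBLAS handle and a device buffer of n doubles (status checks omitted for brevity):

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Return the element of d_x with the largest absolute value.
double DeviceAbsoluteMax(cublasHandle_t handle, const double* d_x, int n)
{
    int idx = 0;
    cublasIdamax(handle, n, d_x, 1, &idx); // idx is 1-based
    idx--;                                 // convert to 0-based before indexing
    double result = 0.0;
    cudaMemcpy(&result, d_x + idx, sizeof(double), cudaMemcpyDeviceToHost);
    return result;
}
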
@@ -224,7 +224,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
            {
                if (dest != m_myRank)
                {
                    MPI_Isend(&sentSignal, 1, MPI_INT, dest, m_numSyncPerformed, m_pMPI->Communicator() , &sendRequests[dest]);
                    m_pMPI->Isend(&sentSignal, 1, MPI_INT, dest, m_numSyncPerformed, &sendRequests[dest]);
                }
            }
            // 2. recv others

@@ -234,7 +234,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                {
                    int recvSignal = 0;
                    MPI_Status status;
                    MPI_Recv(&recvSignal, 1, MPI_INT, src, m_numSyncPerformed, m_pMPI->Communicator(), &status);
                    m_pMPI->Recv(&recvSignal, 1, MPI_INT, src, m_numSyncPerformed, &status);
                    m_MAworkerStatus[src] = (MAWorkerStatus)recvSignal;
#if 0
                    assert(status.MPI_SOURCE == src);

@@ -247,7 +247,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
            {
                if (dest != m_myRank)
                {
                    MPI_Wait(&sendRequests[dest], MPI_STATUS_IGNORE);
                    m_pMPI->Wait(&sendRequests[dest], MPI_STATUS_IGNORE);
                }
            }
            retval = true;

@@ -266,7 +266,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
            {
                if (dest != m_myRank)
                {
                    MPI_Isend(&sentSignal, 1, MPI_INT, dest, m_numSyncPerformed, m_pMPI->Communicator(), &sendRequests[dest]);
                    m_pMPI->Isend(&sentSignal, 1, MPI_INT, dest, m_numSyncPerformed, &sendRequests[dest]);
                }
            }
            // 2. recv status from others (blocking call)

@@ -276,7 +276,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                {
                    int recvSignal = 0;
                    MPI_Status status;
                    MPI_Recv(&recvSignal, 1, MPI_INT, src, m_numSyncPerformed, m_pMPI->Communicator(), &status);
                    m_pMPI->Recv(&recvSignal, 1, MPI_INT, src, m_numSyncPerformed, &status);
#if 0
                    // for debugging purpose, to be removed when mature
                    assert(status.MPI_SOURCE == src);

@@ -290,7 +290,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
            {
                if (dest != m_myRank)
                {
                    MPI_Wait(&sendRequests[dest], MPI_STATUS_IGNORE);
                    m_pMPI->Wait(&sendRequests[dest], MPI_STATUS_IGNORE);
                }
            }
            // 4. check peer status again

@@ -318,7 +318,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
        size_t m_numWorkers;
        size_t m_myRank;
        MASGDPerfStats m_perfReporter;
        MPIWrapperPtr m_pMPI;
        MPIWrapperPtr m_pMPI;
        DEVICEID_TYPE m_deviceId;
    };

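The MASGD edits above route the raw MPI_* calls through the MPIWrapper, so the communicator (and, in CNTK, error handling) lives in one place instead of being repeated at every call site. A minimal sketch of that wrapping idea, not CNTK's actual MPIWrapper interface:

#include <mpi.h>

// Thin wrapper that owns the communicator so call sites no longer pass it around.
class MpiComm
{
    MPI_Comm m_comm;
public:
    explicit MpiComm(MPI_Comm comm = MPI_COMM_WORLD) : m_comm(comm) {}

    int Isend(const void* buf, int count, MPI_Datatype type, int dest, int tag, MPI_Request* req)
    {
        return MPI_Isend(buf, count, type, dest, tag, m_comm, req);
    }
    int Recv(void* buf, int count, MPI_Datatype type, int src, int tag, MPI_Status* status)
    {
        return MPI_Recv(buf, count, type, src, tag, m_comm, status);
    }
    int Wait(MPI_Request* req, MPI_Status* status)
    {
        return MPI_Wait(req, status);
    }
};
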
@@ -209,8 +209,15 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
    }

    std::vector<ComputationNodeBasePtr> additionalNodesToEvaluate;
    auto& outputNodes = net->OutputNodes();
    additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), outputNodes.cbegin(), outputNodes.cend());

    // Do not include the output nodes in the matrix sharing structure when using forward value matrix
    // sharing, since the output nodes are only used for AttemptUtteranceDerivativeFeatures functionality
    // which does not work properly with forward value matrix sharing.
    if (!Globals::ShouldEnableShareNodeValueMatrices())
    {
        auto& outputNodes = net->OutputNodes();
        additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), outputNodes.cbegin(), outputNodes.cend());
    }

    auto preComputeNodesList = net->GetNodesRequiringPreComputation();
    additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), preComputeNodesList.cbegin(), preComputeNodesList.cend());

@@ -2063,6 +2070,10 @@ void SGD<ElemType>::AttemptUtteranceDerivativeFeatures(ComputationNetworkPtr net
    if (outputNodes.empty())
        LogicError("no output node was found.");

    if (Globals::ShouldEnableShareNodeValueMatrices())
        InvalidArgument("AttemptUtteranceDerivativeFeatures cannot be used together with forward value memory sharing. "
                        "Set 'shareNodeValueMatrices=false' at the top level of your CNTK config file to get around this error");

    // BUGBUG (Issue #95): This is no longer correct once we have multiple input layouts.
    trainSetDataReader->CopyMBLayoutTo(net->GetMBLayoutPtrOfNetwork());
    net->ForwardProp(outputNodes[0]); // only evaluate the first output

@@ -238,14 +238,14 @@ private:
            {
                int source = (j >= MyRank()) ? (j + 1) : j;
                // We use a tag of 'numGradMatrices' for the pre-aggregation header
                MPI_Irecv(m_recvHeaders[j], m_recvHeaders[j]->Size(), MPI_CHAR, source, numGradMatrices, m_mpi->Communicator(), &(recvHeaderRequests[j])) || MpiFail("MPI_Irecv");
                m_mpi->Irecv(m_recvHeaders[j], m_recvHeaders[j]->Size(), MPI_CHAR, source, numGradMatrices, &(recvHeaderRequests[j])) || MpiFail("MPI_Irecv");
            }
        }

        // Send the headers from all nodes but the main node
        MPI_Request sendHeaderRequest;
        if (!m_mpi->IsMainNode())
            MPI_Isend(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank(), numGradMatrices, m_mpi->Communicator(), &sendHeaderRequest) || MpiFail("MPI_Isend");
            m_mpi->Isend(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank(), numGradMatrices, &sendHeaderRequest) || MpiFail("MPI_Isend");

        // Perform async allreduce on the gradient data
        std::vector<MPI_Request> allReduceRequests(numGradMatrices);

@@ -261,9 +261,9 @@ private:
            }

            // On Windows this async MPI_Iallreduce call requires MS MPI v7 or higher to be installed
            MPI_Iallreduce(MPI_IN_PLACE, reductionBuffer, gradients[i]->GetNumElements(),
            m_mpi->Iallreduce(MPI_IN_PLACE, reductionBuffer, gradients[i]->GetNumElements(),
                           MPIWrapper::GetDataType(reductionBuffer), MPI_SUM,
                           m_mpi->Communicator(), &allReduceRequests[i]) || MpiFail("MPI_Iallreduce");
                           &allReduceRequests[i]) || MpiFail("MPI_Iallreduce");
        }
    }
    else

@@ -276,7 +276,7 @@ private:
        while (numNodesHeadersReceivedFrom < (NumProc() - 1))
        {
            int idx = MPI_UNDEFINED;
            MPI_Waitany(recvHeaderRequests.size(), recvHeaderRequests.data(), &idx, MPI_STATUS_IGNORE) || MpiFail("MPI_Waitany");
            m_mpi->Waitany(recvHeaderRequests.size(), recvHeaderRequests.data(), &idx, MPI_STATUS_IGNORE) || MpiFail("MPI_Waitany");
            if (idx == MPI_UNDEFINED)
            {
                break;

@@ -293,7 +293,7 @@ private:
        // Initiate receive of the aggregate header
        MPI_Request recvAggHeaderRequest;
        if (!m_mpi->IsMainNode())
            MPI_Irecv(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank(), numGradMatrices + 1 + numGradMatrices, m_mpi->Communicator(), &recvAggHeaderRequest) || MpiFail("MPI_Irecv");
            m_mpi->Irecv(headerCPU, headerCPU->Size(), MPI_CHAR, m_mpi->MainNodeRank(), numGradMatrices + 1 + numGradMatrices, &recvAggHeaderRequest) || MpiFail("MPI_Irecv");

        // Intiate send of the aggregate header from main node
        std::vector<MPI_Request> sendAggHeaderRequests(NumProc() - 1);

@@ -303,7 +303,7 @@ private:
        {
            int dest = (j >= MyRank()) ? (j + 1) : j;
            // TODO: Should we use MPI_Bcast instead for better performance
            MPI_Isend(headerCPU, headerCPU->Size(), MPI_CHAR, dest, numGradMatrices + 1 + numGradMatrices, m_mpi->Communicator(), &(sendAggHeaderRequests[j])) || MpiFail("MPI_Isend");
            m_mpi->Isend(headerCPU, headerCPU->Size(), MPI_CHAR, dest, numGradMatrices + 1 + numGradMatrices, &(sendAggHeaderRequests[j])) || MpiFail("MPI_Isend");
        }
    }

@@ -312,7 +312,7 @@ private:
    {
        for (size_t i = 0; i < numGradMatrices; ++i)
        {
            MPI_Wait(&allReduceRequests[i], MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
            m_mpi->Wait(&allReduceRequests[i], MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
            if (deviceId >= 0)
                m_gpuDataTransferers[i]->CopyCPUToGPUAsync(m_intermediateCPUBuffers[i].get(), gradients[i]->GetNumElements(), gradients[i]->Data());
        }

@@ -320,7 +320,7 @@ private:

    // Wait to receive aggregate header
    if (!m_mpi->IsMainNode())
        MPI_Wait(&recvAggHeaderRequest, MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
        m_mpi->Wait(&recvAggHeaderRequest, MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");

    // Wait for all the transfers to finish
    if (m_nccl.IsSupported())

@@ -333,9 +333,9 @@ private:

    // Wait for completion of the async send requests
    if (!m_mpi->IsMainNode())
        MPI_Wait(&sendHeaderRequest, MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
        m_mpi->Wait(&sendHeaderRequest, MPI_STATUSES_IGNORE) || MpiFail("MPI_Wait");
    else
        MPI_Waitall(sendAggHeaderRequests.size(), sendAggHeaderRequests.data(), MPI_STATUSES_IGNORE) || MpiFail("MPI_Waitall");
        m_mpi->Waitall(sendAggHeaderRequests.size(), sendAggHeaderRequests.data(), MPI_STATUSES_IGNORE) || MpiFail("MPI_Waitall");

    if (showSyncPerfStats)
    {

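The aggregator now issues the gradient reduction through the wrapper as well; underneath it is an in-place, non-blocking MPI_Iallreduce that must later be matched with a wait (an MPI-3 feature, hence the MS-MPI v7+ note in the hunk). A hedged sketch of the pattern outside CNTK:

#include <mpi.h>
#include <vector>

// Sum 'buffer' element-wise across all ranks, overlapping the reduction
// with other work (e.g. CPU<->GPU transfers) before waiting on it.
void AllReduceInPlace(std::vector<double>& buffer, MPI_Comm comm)
{
    MPI_Request request;
    MPI_Iallreduce(MPI_IN_PLACE, buffer.data(), (int)buffer.size(),
                   MPI_DOUBLE, MPI_SUM, comm, &request);
    // ... other work can proceed here while the reduction is in flight ...
    MPI_Wait(&request, MPI_STATUS_IGNORE);
}
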
@@ -17,5 +17,5 @@ def test_char_rnn(device_id):
    set_default_device(cntk_device(device_id))

    # Just run and verify it does not crash
    output = train_and_eval_char_rnn(200)
    output = train_and_eval_char_rnn(1, 200)
    print(output)

@@ -483,7 +483,7 @@
    " }\n",
    "\n",
    " log_number_of_parameters(z) ; print()\n",
    " progress_printer = ProgressPrinter(tag='Training')\n",
    " progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)\n",
    "\n",
    " # perform model training\n",
    " batch_index = 0\n",

@@ -437,8 +437,8 @@
    "\n",
    " # process minibatches and perform model training\n",
    " log_number_of_parameters(model)\n",
    " progress_printer = ProgressPrinter(tag='Training')\n",
    " #progress_printer = ProgressPrinter(freq=100, first=10, tag='Training') # more detailed logging\n",
    " progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)\n",
    " #progress_printer = ProgressPrinter(freq=100, first=10, tag='Training', num_epochs=max_epochs) # more detailed logging\n",
    "\n",
    " t = 0\n",
    " for epoch in range(max_epochs): # loop over epochs\n",

@@ -560,7 +560,7 @@
    " dummy_learner = adam_sgd(criterion.parameters, \n",
    " lr=lr_schedule, momentum=momentum_as_time_constant, low_memory=True)\n",
    " evaluator = Trainer(model, criterion, dummy_learner)\n",
    " progress_printer = ProgressPrinter(tag='Evaluation')\n",
    " progress_printer = ProgressPrinter(tag='Evaluation', num_epochs=0)\n",
    "\n",
    " while True:\n",
    " minibatch_size = 500\n",

@@ -278,6 +278,7 @@
%ignore_function CNTK::Internal::DisableProfiler;
%ignore_function CNTK::Internal::AreEquivalent;
%ignore_function CNTK::Internal::AreEqual;
%ignore_function CNTK::PrintBuiltInfo;

// map the pointer to array
%apply float INPUT[] { float *dataBuffer }

@@ -1086,7 +1086,7 @@ def log_add_exp(left, right, name=''):
def times(left, right, output_rank=1, infer_input_rank_to_map=-1, name=''):
    '''
    The output of this operation is the matrix product of the two input matrices.
    It supports broadcasting. Sparse is supported in the right operand, if it is a matrix.
    It supports broadcasting. Sparse is supported in the left operand, if it is a matrix.
    The operator '@' has been overloaded such that in Python 3.5 and later X @ W equals times(X, W).

    Example:

@@ -431,3 +431,51 @@ def test_op_scatter_sparse(device_id):
    res = a_last_scatter_dense.eval({a : input_data})
    assert np.array_equal(res[0], np.asarray([[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]))
    assert np.array_equal(res[1], np.asarray([[0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0]]))


def test_op_broadcast_as(device_id, precision):
    from .. import sequence

    a_data = [AA([1], dtype=PRECISION_TO_TYPE[precision]), AA([2], dtype=PRECISION_TO_TYPE[precision]), AA([3], dtype=PRECISION_TO_TYPE[precision])]
    b_data = [AA([[2]], dtype=PRECISION_TO_TYPE[precision]), AA([[2], [3]], dtype=PRECISION_TO_TYPE[precision]), AA([[2], [3], [4]], dtype=PRECISION_TO_TYPE[precision])]

    a = I(shape=(1,),
          dtype=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
          name='a',
          dynamic_axes=[Axis.default_batch_axis()])

    b = I(shape=(1,),
          dtype=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
          name='b')

    broadcast_a_as_b = sequence.broadcast_as(a, b)

    res = broadcast_a_as_b.eval({a: a_data, b: b_data})
    assert np.array_equal(res[0], np.asarray([[1.]]))
    assert np.array_equal(res[1], np.asarray([[2.], [2.]]))
    assert np.array_equal(res[2], np.asarray([[3.], [3.], [3.]]))


def test_op_broadcast_as_in_loop(device_id):
    from .. import sequence, placeholder_variable, past_value

    a_data = [AA([1]), AA([2]), AA([3])]
    b_data = [AA([[2]]), AA([[2], [3]]), AA([[2], [3], [4]])]

    a = I(shape=(1,),
          name='a',
          dynamic_axes=[Axis.default_batch_axis()])

    b = I(shape=(1,),
          name='b')

    out_placeholder = placeholder_variable()
    out_delayed = past_value(out_placeholder, time_step=5)
    out_delayed_plus_b = out_delayed + b
    out = sequence.broadcast_as(a, out_delayed_plus_b)
    out.replace_placeholder(out)

    res = out.eval({a: a_data, b: b_data})
    assert np.array_equal(res[0], np.asarray([[1.]]))
    assert np.array_equal(res[1], np.asarray([[2.], [2.]]))
    assert np.array_equal(res[2], np.asarray([[3.], [3.], [3.]]))

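The new tests exercise sequence.broadcast_as, which repeats a per-sequence value at every step of the reference sequence. The same semantics can be sketched with plain containers (hypothetical helper, not CNTK's API); with the test's inputs it reproduces the expected outputs [[1]], [[2],[2]], [[3],[3],[3]]:

#include <cstddef>
#include <vector>

// For each sequence s, repeat its single value once per step of the
// reference sequence, mirroring what sequence.broadcast_as does above.
std::vector<std::vector<float>> BroadcastAs(const std::vector<float>& perSequenceValue,
                                            const std::vector<std::size_t>& referenceLengths)
{
    std::vector<std::vector<float>> result;
    for (std::size_t s = 0; s < perSequenceValue.size(); ++s)
        result.push_back(std::vector<float>(referenceLengths[s], perSequenceValue[s]));
    return result;
}
// BroadcastAs({1, 2, 3}, {1, 2, 3}) yields {{1}, {2, 2}, {3, 3, 3}}.
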
@@ -13,7 +13,7 @@ import numpy as np
from cntk import *
from cntk.learner import *
from cntk.ops import *
from .ops_test_utils import cntk_device
from cntk.ops.tests.ops_test_utils import cntk_device
from cntk.ops.functions import UserFunction

from cntk.utils import get_train_eval_criterion, get_train_loss

@@ -46,8 +46,9 @@ def linear_layer(input_var, output_dim):

def dense_layer(input, output_dim, nonlinearity):
    r = linear_layer(input, output_dim)
    r = nonlinearity(r)
    if isinstance(nonlinearity, UserFunction):
        r = user_function(nonlinearity(r))
        r = user_function(r)
    return r

def fully_connected_classifier_net(input, num_output_classes, hidden_layer_dim,

@@ -160,4 +161,7 @@ def measure_runtime(device_id):
    print("%i\t%.2f\t%.2f"%(num_hidden_layers, min(timings_my_sigmoid), min(timings_sigmoid)))

if __name__=='__main__':
    measure_runtime()
    print("CPU")
    measure_runtime(-1)
    print("GPU")
    measure_runtime(0)

@@ -7,8 +7,7 @@ from __future__ import print_function
import os
import time

from cntk.cntk_py import TensorBoardFileWriter

from cntk.cntk_py import TensorBoardFileWriter, print_built_info

# TODO: Let's switch to import logging in the future instead of print. [ebarsoum]
class ProgressPrinter(object):

@@ -66,6 +65,9 @@ class ProgressPrinter(object):
        self.gen_heartbeat = gen_heartbeat
        self.num_epochs = num_epochs
        self.trainer = None

        # print out data about CNTK build
        print_built_info()

        # Create TensorBoardFileWriter if the path to a log directory was provided.
        self.tensorboard_writer = None

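Threading num_epochs through ProgressPrinter, as the example and notebook edits in this commit do, lets the logger report each epoch against a known total instead of a bare counter. The idea sketched outside CNTK (hypothetical logger, in C++ for consistency with the other sketches here):

#include <cstdio>

// Minimal progress logger that knows the planned number of epochs,
// so each summary can read "Epoch i of N" rather than just "Epoch i".
struct ProgressLogger
{
    int numEpochs;
    explicit ProgressLogger(int totalEpochs) : numEpochs(totalEpochs) {}

    void EpochSummary(int epoch, double avgLoss) const
    {
        std::printf("Finished Epoch[%d of %d]: loss = %.6f\n", epoch + 1, numEpochs, avgLoss);
    }
};
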
@@ -88,6 +88,9 @@ declare -A py_paths

mathlib=

have_mpi=yes
default_use_mpi=$have_mpi

default_use_1bitsgd=no
enable_1bitsgd=$default_use_1bitsgd

@@ -337,6 +340,7 @@ function show_help ()
    echo " --asgd[=(yes|no)] use ASGD powered by Multiverso $(show_default $(default_use_asgd))"
    echo " --cuda[=(yes|no)] use cuda GPU $(show_default $(default_use_cuda))"
    echo " --python[=(yes|no)] with Python bindings $(show_default $(default_use_python))"
    echo " --mpi[=(yes|no)] use MPI communication $(show_default ${default_use_mpi})"
    echo " --with-cuda[=directory] $(show_default $(find_cuda))"
    echo " --with-cub[=directory] $(show_default $(find_cub))"
    echo " --with-gdk-include[=directory] $(show_default $(find_gdk_include))"

@@ -485,7 +489,7 @@
            echo "Cannot find Python $py_version directory."
            echo "Please specify a value for $key"
            exit 1
          fi
        fi
      else
        if check_python $py_version "$optarg"
        then

@@ -519,6 +523,17 @@
      fi
      ;;

    --mpi*)
      if test x$optarg = xyes || test x$optarg = xno
      then
        have_mpi=$optarg
      else
        echo "Invalid value for --mpi $optarg"
        show_help
        exit
      fi
      ;;

    --with-cuda*)
      enable_cuda=yes
      if test x$optarg = x

@@ -1063,6 +1078,11 @@ case $mathlib in
        echo OPENBLAS_PATH=$openblas_path >> $config
        ;;
esac
if test $have_mpi = yes ; then
    echo HAS_MPI=1 >> $config
else
    echo HAS_MPI=0 >> $config
fi
if test $enable_cuda = yes ; then
    echo CUDA_PATH=$cuda_path >> $config
    echo GDK_INCLUDE_PATH=$gdk_include_path >> $config

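The configure change writes HAS_MPI into the generated build configuration, mirroring the HAS_MPI=1 preprocessor definition added on the Windows project side, so sources can compile their MPI-dependent paths conditionally. A hedged sketch of how such a flag is typically consumed (hypothetical helper functions, not CNTK's):

// Compiled with -DHAS_MPI=1 when configure detected --mpi=yes, -DHAS_MPI=0 otherwise.
#if HAS_MPI
#include <mpi.h>
static void InitDistributed(int* argc, char*** argv) { MPI_Init(argc, argv); }
static void ShutdownDistributed()                    { MPI_Finalize(); }
#else
static void InitDistributed(int*, char***) {} // single-process build: no-op
static void ShutdownDistributed()          {}
#endif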