Fix mem leak caused by shared_ptr UserBackPropState

Willi Richert 2017-05-08 13:45:53 +02:00
Parent 012108315d
Commit 712d3ed86f
4 changed files: 67 additions and 25 deletions
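Summary: the C++ change below makes the UserBackPropState constructor private and routes all creation through a static Create() factory built on MakeSharedObject, so the state object is allocated and released by the CNTK library itself instead of by a shared_ptr constructed on the binding side; as the comment in functions.py puts it, memory management for user-defined functions has to be controlled by the C++ side. A minimal sketch of what this means for callers, assuming the SWIG module is importable as cntk_py and func, device and user_data are already in scope:

# before this commit: the binding constructed the shared_ptr itself
#     state = cntk_py.UserBackPropState(func, device, user_data)
# after this commit: the C++ side owns allocation (and eventual deallocation)
state = cntk_py.UserBackPropState.create(func, device, user_data)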

View File

@@ -1822,12 +1822,16 @@ namespace CNTK {
     class UserBackPropState;
     typedef std::shared_ptr<UserBackPropState> UserBackPropStatePtr;
-    class UserBackPropState : public BackPropState {
+    class UserBackPropState : public BackPropState
+    {
+        template <typename T, typename ...CtorArgTypes>
+        friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
     public:
-        UserBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, PyObject* userData)
-            : BackPropState(function, computeDevice), m_userData(userData)
+        static BackPropStatePtr Create(const FunctionPtr& function, const DeviceDescriptor& computeDevice, PyObject* userData)
         {
-            Py_INCREF(m_userData);
+            return MakeSharedObject<UserBackPropState>(function, computeDevice, userData);
         }

         const PyObject* Data() const
@@ -1851,6 +1855,12 @@ namespace CNTK {
         }

     private:
+        UserBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, PyObject* userData)
+            : BackPropState(function, computeDevice), m_userData(userData)
+        {
+            Py_INCREF(m_userData);
+        }
+
         const PyObject* m_userData;
     };
}
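The now-private constructor still takes its own reference on the Python user data with Py_INCREF; the matching release is not part of this hunk. On the Python side the stored object can later be recovered through the Data() accessor shown above. A hedged sketch of that round trip, assuming the snake_case names exposed by the SWIG layer and that func (a CNTK Function) and device are already in scope:

from cntk import cntk_py

payload = {'iteration': 0}                                        # any Python object carried across backprop
state = cntk_py.UserBackPropState.create(func, device, payload)   # Py_INCREF(payload) happens in the ctor
restored = cntk_py.UserBackPropState.data(state)                  # the object stored behind Data()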

View File

@@ -1288,7 +1288,7 @@ class UserFunction(Function):
         # Since the state will frequently not be used, we cache the None-state
         # to speed up.
-        self._none_state = cntk_py.UserBackPropState(self, cpu(), None)
+        self._none_state = cntk_py.UserBackPropState.create(self, cpu(), None)

         # Memory management for user defined functions has to be controlled by
         # the C++ side. For more information:
@@ -1297,7 +1297,7 @@ class UserFunction(Function):
     def _get_none_state(self, device=cpu()):
         if self._none_state.device() != device:
-            self._none_state = cntk_py.UserBackPropState(self, device, None)
+            self._none_state = cntk_py.UserBackPropState.create(self, device, None)

         return self._none_state
@@ -1339,7 +1339,7 @@ class UserFunction(Function):
         if state is None:
             state = self._get_none_state(device)
         elif not isinstance(state, cntk_py.BackPropState):
-            state = cntk_py.UserBackPropState(self, device, state)
+            state = cntk_py.UserBackPropState.create(self, device, state)

         if self.as_numpy:
             for k,v in outputs.items():
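All three call sites in functions.py now go through the factory. The last one is the path taken when a user-defined function returns a plain Python object as its state from forward(); the wrapper then wraps it via UserBackPropState.create before handing it to the C++ engine. A hypothetical UserFunction that exercises exactly this path, modelled on the usual CNTK extension pattern rather than taken from this commit:

import numpy as np
from cntk import output_variable
from cntk.ops.functions import UserFunction

class MySigmoid(UserFunction):
    def __init__(self, arg, name='MySigmoid'):
        super(MySigmoid, self).__init__([arg], name=name)

    def forward(self, argument, device=None, outputs_to_retain=None):
        sigmoid_x = 1 / (1 + np.exp(-argument))
        # returning a plain numpy array as state makes the wrapper call
        # cntk_py.UserBackPropState.create(self, device, state)
        return sigmoid_x, sigmoid_x

    def backward(self, state, root_gradients):
        sigmoid_x = state
        return root_gradients * sigmoid_x * (1 - sigmoid_x)

    def infer_outputs(self):
        return [output_variable(self.inputs[0].shape, self.inputs[0].dtype,
                                self.inputs[0].dynamic_axes)]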

View File

@@ -35,6 +35,16 @@ def cntk_device(device_id):
     return gpu(device_id)

+def mem_used():
+    '''
+    Return the non-swapped physical memory the Python process is using.
+    '''
+    import os
+    import psutil
+
+    process = psutil.Process(os.getpid())
+    return process.memory_info().rss
+
 def _test_unary_op(precision, device_id, op_func,
                    value, expected_forward, expected_backward_all, op_param_dict={}):
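mem_used reports the resident set size via psutil, which is what the training loop further down samples once per minibatch. A small usage sketch, assuming psutil is installed and mem_used is imported as in the test file below:

before = mem_used()
buffers = [bytearray(1024 * 1024) for _ in range(10)]   # hold on to roughly 10 MB
grown = mem_used() - before
print('RSS grew by roughly %.1f MB' % (grown / (1024.0 * 1024)))
del buffers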

View File

@@ -13,7 +13,7 @@ import numpy as np
 from cntk import *
 from cntk.learners import *
 from cntk.ops import *
-from cntk.ops.tests.ops_test_utils import cntk_device
+from cntk.ops.tests.ops_test_utils import cntk_device, mem_used
 from cntk.ops.functions import UserFunction

 np.random.seed(0)
@@ -29,7 +29,7 @@ def generate_random_data_sample(sample_size, feature_dim, num_classes):
     # Make sure that the data is separable
     X = (np.random.randn(sample_size, feature_dim)+3) * (Y+1)
     X = X.astype(np.float32)
-    class_ind = [Y==class_number for class_number in range(num_classes)]
+    class_ind = [Y == class_number for class_number in range(num_classes)]
     Y = np.asarray(np.hstack(class_ind), dtype=np.float32)
     return X, Y
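The whitespace fix above touches the line that one-hot encodes the labels: comparing the integer label column against each class index yields one boolean column per class, and hstack glues them into the final label matrix. A small worked example with three classes:

import numpy as np
Y = np.array([[0], [2], [1]])             # integer labels, shape (3, 1)
class_ind = [Y == c for c in range(3)]    # one boolean column per class
one_hot = np.asarray(np.hstack(class_ind), dtype=np.float32)
# one_hot is [[1, 0, 0],
#             [0, 0, 1],
#             [0, 1, 0]]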
@@ -57,17 +57,20 @@ def fully_connected_classifier_net(inp, num_output_classes, hidden_layer_dim,
     r = linear_layer(h, num_output_classes)
     return r

 def print_training_progress(trainer, mb, frequency):
     training_loss = "NA"
     eval_error = "NA"
-    if mb%frequency == 0:
+    if mb % frequency == 0:
         training_loss = trainer.previous_minibatch_loss_average
         eval_error = trainer.previous_minibatch_evaluation_average
     return mb, training_loss, eval_error

-def train(nonlinearity, num_hidden_layers, device_id):
+def train(nonlinearity, num_hidden_layers, device_id,
+          minibatch_size=10, num_samples=10000):
     from cntk.cntk_py import always_allow_setting_default_device
     always_allow_setting_default_device()
     try_set_default_device(cntk_device(device_id))
@@ -93,27 +96,46 @@ def train(nonlinearity, num_hidden_layers, device_id):
     learner = sgd(z.parameters, lr_schedule)
     trainer = Trainer(z, (loss, eval_error), [learner])

-    minibatch_size = 25
-    num_samples = 2500
-    num_minibatches_to_train = num_samples / minibatch_size
+    num_minibatches_to_train = int(num_samples / minibatch_size)

     training_progress_output_freq = 20

-    losses = []
-    errors = []
+    # Preallocate so that we don't measure the memory increase
+    losses = [0]*num_minibatches_to_train
+    errors = [0]*num_minibatches_to_train

-    for i in range(0, int(num_minibatches_to_train)):
-        features, labels = generate_random_data_sample(minibatch_size, input_dim, num_output_classes)
+    mem = [0]*num_minibatches_to_train
+
+    # Accept at most 500K memory increase. This is in line with the non-UDF,
+    # pure CNTK usage.
+    MEMORY_THRESH = 500 * 1024
+
+    i = 0
+    while i < num_minibatches_to_train:
+        mem[i] = mem_used()
+
+        features, labels = generate_random_data_sample(minibatch_size,
+                                                       input_dim,
+                                                       num_output_classes)

-        # Specify the input variables mapping in the model to actual minibatch data for training
-        trainer.train_minibatch({inp : features, label : labels},
-                                device=cntk_device(device_id))
+        # Specify the input variables mapping in the model to actual minibatch
+        # data for training.
+        trainer.train_minibatch({inp: features, label: labels},
+                                device=cntk_device(device_id))

         batchsize, loss, error = print_training_progress(trainer, i,
                                                          training_progress_output_freq)
-        if not (loss == "NA" or error =="NA"):
-            losses.append(loss)
-            errors.append(error)
+        if not (loss == "NA" or error == "NA"):
+            losses[i] = loss
+            errors[i] = error
+
+        i += 1
+
+    mem_diff = mem[-1] - mem[10]
+
+    if mem_diff > MEMORY_THRESH:
+        raise ValueError('Memory leak detected with %s: %i' %
+                         (nonlinearity, mem_diff))

     return losses, errors
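With these changes train() doubles as a leak check: it samples the resident set size once per minibatch and raises once the growth between the 11th and the last sample exceeds MEMORY_THRESH (the first iterations are skipped, presumably to let warm-up allocations settle). The tests that call it are not part of this hunk; a hypothetical driver would look roughly like:

def test_udf_memory_stays_bounded(device_id):
    # MySigmoid is assumed to be a UserFunction subclass, e.g. the sketch shown earlier
    losses, errors = train(MySigmoid, num_hidden_layers=2, device_id=device_id)
    # reaching this point means the RSS growth stayed below MEMORY_THRESH
    assert len(losses) == len(errors)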