Fix mem leak caused by shared_ptr UserBackPropState
This commit is contained in:
Parent: 012108315d
Commit: 712d3ed86f
@@ -1822,12 +1822,16 @@ namespace CNTK {
     class UserBackPropState;
     typedef std::shared_ptr<UserBackPropState> UserBackPropStatePtr;
 
-    class UserBackPropState : public BackPropState {
+    class UserBackPropState : public BackPropState
+    {
+        template <typename T, typename ...CtorArgTypes>
+        friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
+
     public:
-        UserBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, PyObject* userData)
-            : BackPropState(function, computeDevice), m_userData(userData)
+        static BackPropStatePtr Create(const FunctionPtr& function, const DeviceDescriptor& computeDevice, PyObject* userData)
         {
-            Py_INCREF(m_userData);
+            return MakeSharedObject<UserBackPropState>(function, computeDevice, userData);
         }
 
         const PyObject* Data() const
@@ -1851,6 +1855,12 @@ namespace CNTK {
         }
 
+    private:
+        UserBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, PyObject* userData)
+            : BackPropState(function, computeDevice), m_userData(userData)
+        {
+            Py_INCREF(m_userData);
+        }
 
         const PyObject* m_userData;
     };
 }
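The two C++ hunks above make the UserBackPropState constructor private and befriend MakeSharedObject, so every instance is now created through the static Create factory and its lifetime is controlled by the C++ side (as the Python comment further below also notes). On the Python side this surfaces as the create class method used in the next hunks. A minimal sketch of the call-site change, assuming my_func is a CNTK function object and my_data an arbitrary Python object (both hypothetical names):

    from cntk import cntk_py, cpu

    # Old, leak-prone construction: cntk_py.UserBackPropState(my_func, cpu(), my_data)
    state = cntk_py.UserBackPropState.create(my_func, cpu(), my_data)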
@@ -1288,7 +1288,7 @@ class UserFunction(Function):
 
         # Since the state will frequently not be used, we cache the None-state
         # to speed up.
-        self._none_state = cntk_py.UserBackPropState(self, cpu(), None)
+        self._none_state = cntk_py.UserBackPropState.create(self, cpu(), None)
 
         # Memory management for user defined functions has to be controlled by
         # the C++ side. For more information:
@@ -1297,7 +1297,7 @@ class UserFunction(Function):
 
     def _get_none_state(self, device=cpu()):
         if self._none_state.device() != device:
-            self._none_state = cntk_py.UserBackPropState(self, device, None)
+            self._none_state = cntk_py.UserBackPropState.create(self, device, None)
 
         return self._none_state
 
@@ -1339,7 +1339,7 @@ class UserFunction(Function):
         if state is None:
             state = self._get_none_state(device)
         elif not isinstance(state, cntk_py.BackPropState):
-            state = cntk_py.UserBackPropState(self, device, state)
+            state = cntk_py.UserBackPropState.create(self, device, state)
 
         if self.as_numpy:
             for k,v in outputs.items():
@@ -35,6 +35,16 @@ def cntk_device(device_id):
     return gpu(device_id)
 
 
+def mem_used():
+    '''
+    Return the non-swapped physical memory the Python process is using.
+    '''
+    import os
+    import psutil
+    process = psutil.Process(os.getpid())
+    return process.memory_info().rss
+
+
 def _test_unary_op(precision, device_id, op_func,
                    value, expected_forward, expected_backward_all, op_param_dict={}):
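mem_used wraps psutil's resident set size (RSS) reading for the current process; the leak test below samples it once per minibatch. A short usage sketch (requires the psutil package; the workload in the middle is a placeholder):

    from cntk.ops.tests.ops_test_utils import mem_used

    before = mem_used()
    # ... run the workload to be measured ...
    print('RSS growth: %i bytes' % (mem_used() - before))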
@@ -13,7 +13,7 @@ import numpy as np
 from cntk import *
 from cntk.learners import *
 from cntk.ops import *
-from cntk.ops.tests.ops_test_utils import cntk_device
+from cntk.ops.tests.ops_test_utils import cntk_device, mem_used
 from cntk.ops.functions import UserFunction
 
 np.random.seed(0)
@@ -29,7 +29,7 @@ def generate_random_data_sample(sample_size, feature_dim, num_classes):
     # Make sure that the data is separable
     X = (np.random.randn(sample_size, feature_dim)+3) * (Y+1)
     X = X.astype(np.float32)
-    class_ind = [Y==class_number for class_number in range(num_classes)]
+    class_ind = [Y == class_number for class_number in range(num_classes)]
     Y = np.asarray(np.hstack(class_ind), dtype=np.float32)
     return X, Y
 
@@ -57,17 +57,20 @@ def fully_connected_classifier_net(inp, num_output_classes, hidden_layer_dim,
     r = linear_layer(h, num_output_classes)
     return r
 
 
 def print_training_progress(trainer, mb, frequency):
     training_loss = "NA"
     eval_error = "NA"
 
-    if mb%frequency == 0:
+    if mb % frequency == 0:
         training_loss = trainer.previous_minibatch_loss_average
         eval_error = trainer.previous_minibatch_evaluation_average
 
     return mb, training_loss, eval_error
 
-def train(nonlinearity, num_hidden_layers, device_id):
+
+def train(nonlinearity, num_hidden_layers, device_id,
+          minibatch_size=10, num_samples=10000):
     from cntk.cntk_py import always_allow_setting_default_device
     always_allow_setting_default_device()
     try_set_default_device(cntk_device(device_id))
@@ -93,27 +96,46 @@ def train(nonlinearity, num_hidden_layers, device_id):
     learner = sgd(z.parameters, lr_schedule)
     trainer = Trainer(z, (loss, eval_error), [learner])
 
-    minibatch_size = 25
-    num_samples = 2500
-    num_minibatches_to_train = num_samples / minibatch_size
+    num_minibatches_to_train = int(num_samples / minibatch_size)
 
     training_progress_output_freq = 20
 
-    losses = []
-    errors = []
+    # Preallocate so that we don't measure the memory increase
+    losses = [0]*num_minibatches_to_train
+    errors = [0]*num_minibatches_to_train
 
-    for i in range(0, int(num_minibatches_to_train)):
-        features, labels = generate_random_data_sample(minibatch_size, input_dim, num_output_classes)
+    mem = [0]*num_minibatches_to_train
+
+    # Accept at most 500K memory increase. This is in line with the non-UDF,
+    # pure CNTK usage.
+    MEMORY_THRESH = 500 * 1024
+
+    i = 0
+    while i < num_minibatches_to_train:
+        mem[i] = mem_used()
+
+        features, labels = generate_random_data_sample(minibatch_size,
+                                                       input_dim,
+                                                       num_output_classes)
 
-        # Specify the input variables mapping in the model to actual minibatch data for training
-        trainer.train_minibatch({inp : features, label : labels},
-                                device=cntk_device(device_id))
+        # Specify the input variables mapping in the model to actual minibatch
+        # data for training.
+        trainer.train_minibatch({inp: features, label: labels},
+                                device=cntk_device(device_id))
+
         batchsize, loss, error = print_training_progress(trainer, i,
                                                          training_progress_output_freq)
-        if not (loss == "NA" or error =="NA"):
-            losses.append(loss)
-            errors.append(error)
 
+        if not (loss == "NA" or error == "NA"):
+            losses[i] = loss
+            errors[i] = error
+
+        i += 1
+
+    mem_diff = mem[-1] - mem[10]
+    if mem_diff > MEMORY_THRESH:
+        raise ValueError('Memory leak detected with %s: %i' %
+                         (nonlinearity, mem_diff))
+
     return losses, errors
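The check compares RSS at the last minibatch against the sample taken at minibatch index 10, so one-time start-up allocations are excluded from the measured growth; any difference above the 500 KB threshold fails the test. A hypothetical invocation of the updated train() (the concrete nonlinearity the test passes is not shown in this diff; sigmoid is a stand-in, and device_id=-1 is assumed to select the CPU device via cntk_device()):

    # Stand-in activation; the real test exercises a user-defined function.
    losses, errors = train(sigmoid, num_hidden_layers=2, device_id=-1,
                           minibatch_size=10, num_samples=10000)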