Crosstalk Caffe Bug Fixes
Resolved bugs caused by: 1. Scaling in batch normalization being applied twice. 2. Splicing not taking a default axis into account. 3. Newer protobuf versions not being supported. 4. Stdout not being ignored during validation.
This commit is contained in:
Родитель
08e367ef09
Коммит
0b384cba0d
|
@ -228,7 +228,7 @@ namespace CNTK
|
|||
break;
|
||||
|
||||
default:
|
||||
fprintf(stderr, "TensorBoardFileWriter: Unsupported data type: %d ", dtype);
|
||||
fprintf(stderr, "TensorBoardFileWriter: Unsupported data type: %d ", static_cast<int>(dtype));
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
@ -214,13 +214,16 @@ class SetupCaffeParameters(object):
|
|||
'''
|
||||
cntk_layer_def.parameters = cntkmodel.CntkSpliceParameters()
|
||||
|
||||
cntk_layer_def.parameters.axis = caffe_parameters.axis
|
||||
if caffe_parameters is not None:
|
||||
cntk_layer_def.parameters.axis = caffe_parameters.axis
|
||||
else:
|
||||
cntk_layer_def.parameters.axis = 1
|
||||
output_tensor = inputs_info[0].tensor[:]
|
||||
output_tensor[0] = 0
|
||||
for input_info in inputs_info:
|
||||
if not output_tensor[1:] == input_info.tensor[1:]:
|
||||
raise IndexError('Non-align tensor information\n')
|
||||
output_tensor[0] += input_info.tensor[caffe_parameters.axis - 1]
|
||||
output_tensor[0] += input_info.tensor[cntk_layer_def.parameters.axis - 1]
|
||||
cntk_layer_def.tensor = output_tensor
|
||||
|
||||
@staticmethod
|
||||
|
@ -541,9 +544,9 @@ class CaffeAdapter(baseadapter.Adapter):
|
|||
@staticmethod
|
||||
def _get_layer_parameters(raw_layer):
|
||||
convert_name = raw_layer.type.lower() + 'param'
|
||||
for term in dir(raw_layer):
|
||||
if term.lower().replace('_', '') == convert_name:
|
||||
return getattr(raw_layer, term)
|
||||
for (descriptor, attr) in raw_layer.ListFields():
|
||||
if descriptor.name.lower().replace('_','') == convert_name:
|
||||
return attr
|
||||
return None
|
||||
|
||||
def _try_special_case_wrapper(self, raw_layer):
|
||||
|
|
|
@ -203,9 +203,9 @@ class ApiSetup(object):
|
|||
mean_tensor = cntk_layer.parameter_tensor[0]
|
||||
variance_tensor = cntk_layer.parameter_tensor[1]
|
||||
global_scale = cntk_layer.parameter_tensor[2].data[0]
|
||||
scale_init = 1 / global_scale if global_scale != 0 else 0
|
||||
mean_init = np.asarray(mean_tensor.data, dtype=np.float32) * scale_init
|
||||
var_init = np.asarray(variance_tensor.data, dtype=np.float32) * scale_init
|
||||
moving_average_factor = 1 / global_scale if global_scale != 0 else 0
|
||||
mean_init = np.asarray(mean_tensor.data, dtype=np.float32) * moving_average_factor
|
||||
var_init = np.asarray(variance_tensor.data, dtype=np.float32) * moving_average_factor
|
||||
if len(cntk_layer.parameter_tensor) == 5:
|
||||
scale_tensor = cntk_layer.parameter_tensor[3]
|
||||
bias_tensor = cntk_layer.parameter_tensor[4]
|
||||
|
|
|
@ -47,9 +47,9 @@ class CaffeValidCore(ValidCore):
|
|||
caffe_solver = CaffeResolver()
|
||||
caffe = caffe_solver.caffe
|
||||
if not caffe_solver.runtime():
|
||||
sys.stdout.write('No caffe runtime support, ignore validation...\n')
|
||||
sys.__stdout__.write('No caffe runtime support, ignore validation...\n')
|
||||
return
|
||||
sys.stdout.write('Start valid feature map...\n')
|
||||
sys.__stdout__.write('Start valid feature map...\n')
|
||||
caffe.set_mode_gpu()
|
||||
caffe.set_device(0)
|
||||
net = caffe.Net(source_solver.model_path, source_solver.weights_path, caffe.TEST)
|
||||
|
@ -66,7 +66,7 @@ class CaffeValidCore(ValidCore):
|
|||
test_result = net.blobs[target].data
|
||||
power_error = np.power(gt_result.flatten() - test_result.flatten(), 2).sum()
|
||||
rsme_diff = np.sqrt(power_error / gt_result.size)
|
||||
sys.stdout.write('Validating %s with RMSE = %s, MAX = %s, MIN = %s\n' %
|
||||
(target, str(rsme_diff), str(gt_result.max()), str(gt_result.min())))
|
||||
sys.stdout.write('Validation finished...\n')
|
||||
sys.stdout.flush()
|
||||
sys.__stdout__.write(('Validating %s with RMSE = %s, MAX = %s, MIN = %s\n' %
|
||||
(target, str(rsme_diff), str(gt_result.max()), str(gt_result.min()))))
|
||||
sys.__stdout__.write('Validation finished...\n')
|
||||
sys.__stdout__.flush()
|
||||
|
|
Загрузка…
Ссылка в новой задаче