зеркало из https://github.com/microsoft/caffe.git
Коммит
d842f4a24f
|
@ -70,10 +70,10 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
|
|||
# Set input according to defined shapes and make arrays single and
|
||||
# C-contiguous as Caffe expects.
|
||||
for in_, blob in kwargs.iteritems():
|
||||
if blob.shape[0] != self.blobs[in_].num:
|
||||
raise Exception('Input is not batch sized')
|
||||
if blob.ndim != 4:
|
||||
raise Exception('{} blob is not 4-d'.format(in_))
|
||||
if blob.shape[0] != self.blobs[in_].num:
|
||||
raise Exception('Input is not batch sized')
|
||||
self.blobs[in_].data[...] = blob
|
||||
|
||||
self._forward(start_ind, end_ind)
|
||||
|
@ -117,10 +117,10 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
|
|||
# Set top diffs according to defined shapes and make arrays single and
|
||||
# C-contiguous as Caffe expects.
|
||||
for top, diff in kwargs.iteritems():
|
||||
if diff.shape[0] != self.blobs[top].num:
|
||||
raise Exception('Diff is not batch sized')
|
||||
if diff.ndim != 4:
|
||||
raise Exception('{} diff is not 4-d'.format(top))
|
||||
if diff.shape[0] != self.blobs[top].num:
|
||||
raise Exception('Diff is not batch sized')
|
||||
self.blobs[top].diff[...] = diff
|
||||
|
||||
self._backward(start_ind, end_ind)
|
||||
|
@ -284,17 +284,16 @@ def _Net_preprocess(self, input_name, input_):
|
|||
caffe_in = input_.astype(np.float32)
|
||||
input_scale = self.input_scale.get(input_name)
|
||||
channel_order = self.channel_swap.get(input_name)
|
||||
mean = self.mean.get(input_name)
|
||||
in_size = self.blobs[input_name].data.shape[2:]
|
||||
if caffe_in.shape[:2] != in_size:
|
||||
caffe_in = caffe.io.resize_image(caffe_in, in_size)
|
||||
if input_scale:
|
||||
if input_scale is not None:
|
||||
caffe_in *= input_scale
|
||||
if channel_order:
|
||||
if channel_order is not None:
|
||||
caffe_in = caffe_in[:, :, channel_order]
|
||||
caffe_in = caffe_in.transpose((2, 0, 1))
|
||||
if mean is not None:
|
||||
caffe_in -= mean
|
||||
if hasattr(self, 'mean'):
|
||||
caffe_in -= self.mean.get(input_name, 0)
|
||||
return caffe_in
|
||||
|
||||
|
||||
|
@ -305,15 +304,14 @@ def _Net_deprocess(self, input_name, input_):
|
|||
decaf_in = input_.copy().squeeze()
|
||||
input_scale = self.input_scale.get(input_name)
|
||||
channel_order = self.channel_swap.get(input_name)
|
||||
mean = self.mean.get(input_name)
|
||||
if mean is not None:
|
||||
decaf_in += mean
|
||||
if hasattr(self, 'mean'):
|
||||
decaf_in += self.mean.get(input_name, 0)
|
||||
decaf_in = decaf_in.transpose((1,2,0))
|
||||
if channel_order:
|
||||
if channel_order is not None:
|
||||
channel_order_inverse = [channel_order.index(i)
|
||||
for i in range(decaf_in.shape[2])]
|
||||
decaf_in = decaf_in[:, :, channel_order_inverse]
|
||||
if input_scale:
|
||||
if input_scale is not None:
|
||||
decaf_in /= input_scale
|
||||
return decaf_in
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче