зеркало из https://github.com/microsoft/caffe.git
Use 'six' library to ensure python3 compliance.
Use '//' instead of '/' for entire division.
This commit is contained in:
Родитель
c2769c1096
Коммит
666da79ad2
|
@ -14,6 +14,8 @@ from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \
|
||||||
RMSPropSolver, AdaDeltaSolver, AdamSolver
|
RMSPropSolver, AdaDeltaSolver, AdamSolver
|
||||||
import caffe.io
|
import caffe.io
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
# We directly update methods from Net here (rather than using composition or
|
# We directly update methods from Net here (rather than using composition or
|
||||||
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
|
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
|
||||||
# automatically have the improved interface.
|
# automatically have the improved interface.
|
||||||
|
@ -97,7 +99,7 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
|
||||||
raise Exception('Input blob arguments do not match net inputs.')
|
raise Exception('Input blob arguments do not match net inputs.')
|
||||||
# Set input according to defined shapes and make arrays single and
|
# Set input according to defined shapes and make arrays single and
|
||||||
# C-contiguous as Caffe expects.
|
# C-contiguous as Caffe expects.
|
||||||
for in_, blob in kwargs.iteritems():
|
for in_, blob in six.iteritems(kwargs):
|
||||||
if blob.shape[0] != self.blobs[in_].shape[0]:
|
if blob.shape[0] != self.blobs[in_].shape[0]:
|
||||||
raise Exception('Input is not batch sized')
|
raise Exception('Input is not batch sized')
|
||||||
self.blobs[in_].data[...] = blob
|
self.blobs[in_].data[...] = blob
|
||||||
|
@ -145,7 +147,7 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
|
||||||
raise Exception('Top diff arguments do not match net outputs.')
|
raise Exception('Top diff arguments do not match net outputs.')
|
||||||
# Set top diffs according to defined shapes and make arrays single and
|
# Set top diffs according to defined shapes and make arrays single and
|
||||||
# C-contiguous as Caffe expects.
|
# C-contiguous as Caffe expects.
|
||||||
for top, diff in kwargs.iteritems():
|
for top, diff in six.iteritems(kwargs):
|
||||||
if diff.shape[0] != self.blobs[top].shape[0]:
|
if diff.shape[0] != self.blobs[top].shape[0]:
|
||||||
raise Exception('Diff is not batch sized')
|
raise Exception('Diff is not batch sized')
|
||||||
self.blobs[top].diff[...] = diff
|
self.blobs[top].diff[...] = diff
|
||||||
|
@ -174,13 +176,13 @@ def _Net_forward_all(self, blobs=None, **kwargs):
|
||||||
all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
|
all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
|
||||||
for batch in self._batch(kwargs):
|
for batch in self._batch(kwargs):
|
||||||
outs = self.forward(blobs=blobs, **batch)
|
outs = self.forward(blobs=blobs, **batch)
|
||||||
for out, out_blob in outs.iteritems():
|
for out, out_blob in six.iteritems(outs):
|
||||||
all_outs[out].extend(out_blob.copy())
|
all_outs[out].extend(out_blob.copy())
|
||||||
# Package in ndarray.
|
# Package in ndarray.
|
||||||
for out in all_outs:
|
for out in all_outs:
|
||||||
all_outs[out] = np.asarray(all_outs[out])
|
all_outs[out] = np.asarray(all_outs[out])
|
||||||
# Discard padding.
|
# Discard padding.
|
||||||
pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
|
pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
|
||||||
if pad:
|
if pad:
|
||||||
for out in all_outs:
|
for out in all_outs:
|
||||||
all_outs[out] = all_outs[out][:-pad]
|
all_outs[out] = all_outs[out][:-pad]
|
||||||
|
@ -215,16 +217,16 @@ def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
|
||||||
for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
|
for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
|
||||||
batch_blobs = self.forward(blobs=blobs, **fb)
|
batch_blobs = self.forward(blobs=blobs, **fb)
|
||||||
batch_diffs = self.backward(diffs=diffs, **bb)
|
batch_diffs = self.backward(diffs=diffs, **bb)
|
||||||
for out, out_blobs in batch_blobs.iteritems():
|
for out, out_blobs in six.iteritems(batch_blobs):
|
||||||
all_outs[out].extend(out_blobs.copy())
|
all_outs[out].extend(out_blobs.copy())
|
||||||
for diff, out_diffs in batch_diffs.iteritems():
|
for diff, out_diffs in six.iteritems(batch_diffs):
|
||||||
all_diffs[diff].extend(out_diffs.copy())
|
all_diffs[diff].extend(out_diffs.copy())
|
||||||
# Package in ndarray.
|
# Package in ndarray.
|
||||||
for out, diff in zip(all_outs, all_diffs):
|
for out, diff in zip(all_outs, all_diffs):
|
||||||
all_outs[out] = np.asarray(all_outs[out])
|
all_outs[out] = np.asarray(all_outs[out])
|
||||||
all_diffs[diff] = np.asarray(all_diffs[diff])
|
all_diffs[diff] = np.asarray(all_diffs[diff])
|
||||||
# Discard padding at the end and package in ndarray.
|
# Discard padding at the end and package in ndarray.
|
||||||
pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
|
pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
|
||||||
if pad:
|
if pad:
|
||||||
for out, diff in zip(all_outs, all_diffs):
|
for out, diff in zip(all_outs, all_diffs):
|
||||||
all_outs[out] = all_outs[out][:-pad]
|
all_outs[out] = all_outs[out][:-pad]
|
||||||
|
@ -256,10 +258,10 @@ def _Net_batch(self, blobs):
|
||||||
------
|
------
|
||||||
batch: {blob name: list of blobs} dict for a single batch.
|
batch: {blob name: list of blobs} dict for a single batch.
|
||||||
"""
|
"""
|
||||||
num = len(blobs.itervalues().next())
|
num = len(six.next(six.itervalues(blobs)))
|
||||||
batch_size = self.blobs.itervalues().next().shape[0]
|
batch_size = six.next(six.itervalues(self.blobs)).shape[0]
|
||||||
remainder = num % batch_size
|
remainder = num % batch_size
|
||||||
num_batches = num / batch_size
|
num_batches = num // batch_size
|
||||||
|
|
||||||
# Yield full batches.
|
# Yield full batches.
|
||||||
for b in range(num_batches):
|
for b in range(num_batches):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче