зеркало из https://github.com/microsoft/caffe.git
removing all references to Blob.num property (that assumes Blob is 4D). Replacing it with accessing Blob.shape[0] - for Blobs with num_axes() != 4
This commit is contained in:
Родитель
4541f89005
Коммит
29bb23fc92
|
@ -98,7 +98,7 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
|
|||
# Set input according to defined shapes and make arrays single and
|
||||
# C-contiguous as Caffe expects.
|
||||
for in_, blob in kwargs.iteritems():
|
||||
if blob.shape[0] != self.blobs[in_].num:
|
||||
if blob.shape[0] != self.blobs[in_].shape[0]:
|
||||
raise Exception('Input is not batch sized')
|
||||
self.blobs[in_].data[...] = blob
|
||||
|
||||
|
@ -146,7 +146,7 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
|
|||
# Set top diffs according to defined shapes and make arrays single and
|
||||
# C-contiguous as Caffe expects.
|
||||
for top, diff in kwargs.iteritems():
|
||||
if diff.shape[0] != self.blobs[top].num:
|
||||
if diff.shape[0] != self.blobs[top].shape[0]:
|
||||
raise Exception('Diff is not batch sized')
|
||||
self.blobs[top].diff[...] = diff
|
||||
|
||||
|
@ -257,7 +257,7 @@ def _Net_batch(self, blobs):
|
|||
batch: {blob name: list of blobs} dict for a single batch.
|
||||
"""
|
||||
num = len(blobs.itervalues().next())
|
||||
batch_size = self.blobs.itervalues().next().num
|
||||
batch_size = self.blobs.itervalues().next().shape[0]
|
||||
remainder = num % batch_size
|
||||
num_batches = num / batch_size
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче