зеркало из https://github.com/microsoft/caffe.git
generalize power_wrapper to different networks and inputs
generalize from the imagenet reference to other detection networks: - refactor to configure for a given net and data source - determine dimensions of net input and output automatically from blobs
This commit is contained in:
Родитель
2f7fbfa9e4
Коммит
69671dd30b
|
@ -25,20 +25,19 @@ import skimage.transform
|
|||
import selective_search_ijcv_with_python as selective_search
|
||||
import caffe
|
||||
|
||||
IMAGE_DIM = 256
|
||||
CROPPED_DIM = 227
|
||||
IMAGE_CENTER = int((IMAGE_DIM - CROPPED_DIM) / 2)
|
||||
NET = None
|
||||
|
||||
IMAGE_DIM = None
|
||||
CROPPED_DIM = None
|
||||
IMAGE_CENTER = None
|
||||
|
||||
IMAGE_MEAN = None
|
||||
CROPPED_IMAGE_MEAN = None
|
||||
|
||||
NUM_OUTPUT = None
|
||||
|
||||
CROP_MODES = ['center_only', 'corners', 'selective_search']
|
||||
|
||||
# Load the imagenet mean file
|
||||
IMAGENET_MEAN = np.load(
|
||||
os.path.join(os.path.dirname(__file__), 'ilsvrc_2012_mean.npy'))
|
||||
CROPPED_IMAGENET_MEAN = IMAGENET_MEAN[IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
|
||||
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
|
||||
:]
|
||||
|
||||
|
||||
def load_image(filename):
|
||||
"""
|
||||
Input:
|
||||
|
@ -72,14 +71,14 @@ def format_image(image, window=None, cropped_size=False):
|
|||
if window is not None:
|
||||
image = image[window[0]:window[2], window[1]:window[3]]
|
||||
|
||||
# Resize to ImageNet size, convert to BGR, subtract mean.
|
||||
# Resize to input size, convert to BGR, subtract mean.
|
||||
image = image[:, :, ::-1]
|
||||
if cropped_size:
|
||||
image = skimage.transform.resize(image, (CROPPED_DIM, CROPPED_DIM)) * 255
|
||||
image -= CROPPED_IMAGENET_MEAN
|
||||
image -= CROPPED_IMAGE_MEAN
|
||||
else:
|
||||
image = skimage.transform.resize(image, (IMAGE_DIM, IMAGE_DIM)) * 255
|
||||
image -= IMAGENET_MEAN
|
||||
image -= IMAGE_MEAN
|
||||
|
||||
image = image.swapaxes(1, 2).swapaxes(0, 1)
|
||||
return image
|
||||
|
@ -235,19 +234,14 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
|
|||
return df_batches
|
||||
|
||||
|
||||
def compute_feats(images_df, layer='imagenet'):
|
||||
if layer == 'imagenet':
|
||||
num_output = 1000
|
||||
else:
|
||||
raise ValueError("Unknown layer requested: {}".format(layer))
|
||||
|
||||
def compute_feats(images_df):
|
||||
num = images_df.shape[0]
|
||||
input_blobs = [np.ascontiguousarray(
|
||||
np.concatenate(images_df['image'].values), dtype='float32')]
|
||||
output_blobs = [np.empty((num, num_output, 1, 1), dtype=np.float32)]
|
||||
output_blobs = [np.empty((num, NUM_OUTPUT, 1, 1), dtype=np.float32)]
|
||||
print(input_blobs[0].shape, output_blobs[0].shape)
|
||||
|
||||
caffenet.Forward(input_blobs, output_blobs)
|
||||
NET.Forward(input_blobs, output_blobs)
|
||||
feats = [output_blobs[0][i].flatten() for i in range(len(output_blobs[0]))]
|
||||
|
||||
# Add the features and delete the images.
|
||||
|
@ -256,6 +250,34 @@ def compute_feats(images_df, layer='imagenet'):
|
|||
return images_df
|
||||
|
||||
|
||||
def config(model_def, pretrained_model, gpu, image_dim, image_mean_file):
|
||||
global IMAGE_DIM, CROPPED_DIM, IMAGE_CENTER, IMAGE_MEAN, CROPPED_IMAGE_MEAN
|
||||
global NET, NUM_OUTPUT
|
||||
|
||||
# Initialize network by loading model definition and weights.
|
||||
t = time.time()
|
||||
print("Loading Caffe model.")
|
||||
NET = caffe.CaffeNet(model_def, pretrained_model)
|
||||
NET.set_phase_test()
|
||||
if gpu:
|
||||
NET.set_mode_gpu()
|
||||
print("Caffe model loaded in {:.3f} s".format(time.time() - t))
|
||||
|
||||
# Configure for input/output data
|
||||
IMAGE_DIM = image_dim
|
||||
CROPPED_DIM = NET.blobs()[0].width
|
||||
IMAGE_CENTER = int((IMAGE_DIM - CROPPED_DIM) / 2)
|
||||
|
||||
# Load the data set mean file
|
||||
IMAGE_MEAN = np.load(image_mean_file)
|
||||
|
||||
|
||||
CROPPED_IMAGE_MEAN = IMAGE_MEAN[IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
|
||||
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
|
||||
:]
|
||||
NUM_OUTPUT = NET.blobs()[-1].channels # number of output classes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse cmdline options
|
||||
gflags.DEFINE_string(
|
||||
|
@ -273,10 +295,19 @@ if __name__ == "__main__":
|
|||
gflags.DEFINE_string(
|
||||
"output", "", "Output DataFrame HDF5 filename.")
|
||||
gflags.DEFINE_string(
|
||||
"layer", "imagenet", "Layer to output.")
|
||||
"image_dim", 256, "Canonical (square) image dimension.")
|
||||
gflags.DEFINE_string(
|
||||
"image_mean_file",
|
||||
os.path.join(os.path.dirname(__file__), 'ilsvrc_2012_mean.npy'),
|
||||
"Data set image mean (numpy array).")
|
||||
FLAGS = gflags.FLAGS
|
||||
FLAGS(sys.argv)
|
||||
|
||||
|
||||
# Configure network, input, output
|
||||
config(FLAGS.model_def, FLAGS.pretrained_model, FLAGS.gpu, FLAGS.image_dim,
|
||||
FLAGS.image_mean_file)
|
||||
|
||||
# Load list of image filenames and assemble into batches.
|
||||
t = time.time()
|
||||
print('Assembling batches...')
|
||||
|
@ -287,15 +318,6 @@ if __name__ == "__main__":
|
|||
print('{} batches assembled in {:.3f} s'.format(len(image_batches),
|
||||
time.time() - t))
|
||||
|
||||
# Initialize network by loading model definition and weights.
|
||||
t = time.time()
|
||||
print("Loading Caffe model.")
|
||||
caffenet = caffe.CaffeNet(FLAGS.model_def, FLAGS.pretrained_model)
|
||||
caffenet.set_phase_test()
|
||||
if FLAGS.gpu:
|
||||
caffenet.set_mode_gpu()
|
||||
print("Caffe model loaded in {:.3f} s".format(time.time() - t))
|
||||
|
||||
# Process the batches.
|
||||
t = time.time()
|
||||
print 'Processing {} files in {} batches'.format(len(image_fnames),
|
||||
|
|
Загрузка…
Ссылка в новой задаче