зеркало из https://github.com/microsoft/caffe.git
read single input, load/save csv, and record windows
Load `--input_file` as a list of image filenames if .txt OR read as CSV with fields filename,ymin,xmin,ymax,xmax (with labeled header) if .csv. Save `--output_file` as HDF5 if .h5 or CSV if .csv. For CSV, enumerate the class probabilities as numbered class fields. Record crop windows in original image coordinates if `center_only` or `corners` crop mode selected. Previously, these modes didn't report locations.
This commit is contained in:
Родитель
4db88bbe2f
Коммит
1c5df72b03
|
@ -70,7 +70,10 @@ def format_image(image, window=None, cropped_size=False):
|
||||||
Output:
|
Output:
|
||||||
image: (3 x H x W) ndarray
|
image: (3 x H x W) ndarray
|
||||||
Resized to either IMAGE_DIM or CROPPED_DIM.
|
Resized to either IMAGE_DIM or CROPPED_DIM.
|
||||||
|
dims: (H, W) of the original image
|
||||||
"""
|
"""
|
||||||
|
dims = image.shape[:2]
|
||||||
|
|
||||||
# Crop a subimage if window is provided.
|
# Crop a subimage if window is provided.
|
||||||
if window is not None:
|
if window is not None:
|
||||||
image = image[window[0]:window[2], window[1]:window[3]]
|
image = image[window[0]:window[2], window[1]:window[3]]
|
||||||
|
@ -85,26 +88,54 @@ def format_image(image, window=None, cropped_size=False):
|
||||||
image -= IMAGE_MEAN
|
image -= IMAGE_MEAN
|
||||||
|
|
||||||
image = image.swapaxes(1, 2).swapaxes(0, 1)
|
image = image.swapaxes(1, 2).swapaxes(0, 1)
|
||||||
return image
|
return image, dims
|
||||||
|
|
||||||
|
|
||||||
def _assemble_images_list(image_windows):
|
def _image_coordinates(dims, window):
|
||||||
"""
|
"""
|
||||||
For each image, collect the crops for the given windows
|
Calculate the original image coordinates of a
|
||||||
|
window in the canonical (IMAGE_DIM x IMAGE_DIM) coordinates
|
||||||
|
|
||||||
Input:
|
Input:
|
||||||
image_windows: list
|
dims: (H, W) of the original image
|
||||||
|
window: (ymin, xmin, ymax, xmax) in the (IMAGE_DIM x IMAGE_DIM) frame
|
||||||
|
|
||||||
|
Output:
|
||||||
|
image_window: (ymin, xmin, ymax, xmax) in the original image frame
|
||||||
|
"""
|
||||||
|
h, w = dims
|
||||||
|
h_scale, w_scale = h / IMAGE_DIM, w / IMAGE_DIM
|
||||||
|
image_window = window * np.array((1. / h_scale, 1. / w_scale,
|
||||||
|
h_scale, w_scale))
|
||||||
|
return image_window.round().astype(int)
|
||||||
|
|
||||||
|
|
||||||
|
def _assemble_images_list(input_df):
|
||||||
|
"""
|
||||||
|
For each image, collect the crops for the given windows.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
input_df: pandas.DataFrame
|
||||||
|
with 'filename', 'ymin', 'xmin', 'ymax', 'xmax' columns
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
images_df: pandas.DataFrame
|
images_df: pandas.DataFrame
|
||||||
With 'image', 'window', 'filename' columns
|
with 'image', 'window', 'filename' columns
|
||||||
"""
|
"""
|
||||||
|
# unpack sequence of (image filename, windows)
|
||||||
|
windows = input_df[['ymin', 'xmin', 'ymax', 'xmax']].values
|
||||||
|
image_windows = (
|
||||||
|
(ix, windows[input_df.index.get_loc(ix)]) for ix in input_df.index.unique()
|
||||||
|
)
|
||||||
|
|
||||||
|
# extract windows
|
||||||
data = []
|
data = []
|
||||||
for image_fname, windows in image_windows.iteritems():
|
for image_fname, windows in image_windows:
|
||||||
image = load_image(image_fname)
|
image = load_image(image_fname)
|
||||||
for window in windows:
|
for window in windows:
|
||||||
|
window_image, _ = format_image(image, window, cropped_size=True)
|
||||||
data.append({
|
data.append({
|
||||||
'image': format_image(image, window, cropped_size=True)[np.newaxis, :],
|
'image': window_image[np.newaxis, :],
|
||||||
'window': window,
|
'window': window,
|
||||||
'filename': image_fname
|
'filename': image_fname
|
||||||
})
|
})
|
||||||
|
@ -112,6 +143,7 @@ def _assemble_images_list(image_windows):
|
||||||
images_df = pd.DataFrame(data)
|
images_df = pd.DataFrame(data)
|
||||||
return images_df
|
return images_df
|
||||||
|
|
||||||
|
|
||||||
def _assemble_images_center_only(image_fnames):
|
def _assemble_images_center_only(image_fnames):
|
||||||
"""
|
"""
|
||||||
For each image, square the image and crop its center.
|
For each image, square the image and crop its center.
|
||||||
|
@ -121,22 +153,23 @@ def _assemble_images_center_only(image_fnames):
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
images_df: pandas.DataFrame
|
images_df: pandas.DataFrame
|
||||||
With 'image', 'filename' columns.
|
With 'image', 'window', 'filename' columns.
|
||||||
"""
|
"""
|
||||||
all_images = []
|
crop_start, crop_end = IMAGE_CENTER, IMAGE_CENTER + CROPPED_DIM
|
||||||
for image_filename in image_fnames:
|
crop_window = np.array((crop_start, crop_start, crop_end, crop_end))
|
||||||
image = format_image(load_image(image_filename))
|
|
||||||
all_images.append(np.ascontiguousarray(
|
|
||||||
image[np.newaxis, :,
|
|
||||||
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
|
|
||||||
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM],
|
|
||||||
dtype=np.float32
|
|
||||||
))
|
|
||||||
|
|
||||||
images_df = pd.DataFrame({
|
data = []
|
||||||
'image': all_images,
|
for image_fname in image_fnames:
|
||||||
'filename': image_fnames
|
image, dims = format_image(load_image(image_fname))
|
||||||
})
|
data.append({
|
||||||
|
'image': image[np.newaxis, :,
|
||||||
|
crop_start:crop_end,
|
||||||
|
crop_start:crop_end],
|
||||||
|
'window': _image_coordinates(dims, crop_window),
|
||||||
|
'filename': image_fname
|
||||||
|
})
|
||||||
|
|
||||||
|
images_df = pd.DataFrame(data)
|
||||||
return images_df
|
return images_df
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,32 +183,37 @@ def _assemble_images_corners(image_fnames):
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
images_df: pandas.DataFrame
|
images_df: pandas.DataFrame
|
||||||
With 'image', 'filename' columns.
|
With 'image', 'window', 'filename' columns.
|
||||||
"""
|
"""
|
||||||
all_images = []
|
# make crops
|
||||||
for image_filename in image_fnames:
|
indices = [0, IMAGE_DIM - CROPPED_DIM]
|
||||||
image = format_image(load_image(image_filename))
|
crops = np.empty((5, 4), dtype=int)
|
||||||
indices = [0, IMAGE_DIM - CROPPED_DIM]
|
curr = 0
|
||||||
|
for i in indices:
|
||||||
|
for j in indices:
|
||||||
|
crops[curr] = (i, j, i + CROPPED_DIM, j + CROPPED_DIM)
|
||||||
|
curr += 1
|
||||||
|
crops[4] = (IMAGE_CENTER, IMAGE_CENTER,
|
||||||
|
IMAGE_CENTER + CROPPED_DIM, IMAGE_CENTER + CROPPED_DIM)
|
||||||
|
all_crops = np.tile(crops, (2, 1))
|
||||||
|
|
||||||
images = np.empty((10, 3, CROPPED_DIM, CROPPED_DIM), dtype=np.float32)
|
data = []
|
||||||
|
for image_fname in image_fnames:
|
||||||
|
image, dims = format_image(load_image(image_fname))
|
||||||
|
image_crops = np.empty((10, 3, CROPPED_DIM, CROPPED_DIM), dtype=np.float32)
|
||||||
curr = 0
|
curr = 0
|
||||||
for i in indices:
|
for crop in crops:
|
||||||
for j in indices:
|
image_crops[curr] = image[:, crop[0]:crop[2], crop[1]:crop[3]]
|
||||||
images[curr] = image[:, i:i + CROPPED_DIM, j:j + CROPPED_DIM]
|
curr += 1
|
||||||
curr += 1
|
image_crops[5:] = image_crops[:5, :, :, ::-1] # flip for mirrors
|
||||||
images[4] = image[
|
for i in range(len(all_crops)):
|
||||||
:,
|
data.append({
|
||||||
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
|
'image': image_crops[i][np.newaxis, :],
|
||||||
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM
|
'window': _image_coordinates(dims, all_crops[i]),
|
||||||
]
|
'filename': image_fname
|
||||||
images[5:] = images[:5, :, :, ::-1] # flipped versions
|
})
|
||||||
|
|
||||||
all_images.append(images)
|
images_df = pd.DataFrame(data)
|
||||||
|
|
||||||
images_df = pd.DataFrame({
|
|
||||||
'image': [row[np.newaxis, :] for row in images for images in all_images],
|
|
||||||
'filename': np.repeat(image_fnames, 10)
|
|
||||||
})
|
|
||||||
return images_df
|
return images_df
|
||||||
|
|
||||||
|
|
||||||
|
@ -197,8 +235,9 @@ def _assemble_images_selective_search(image_fnames):
|
||||||
for image_fname, windows in zip(image_fnames, windows_list):
|
for image_fname, windows in zip(image_fnames, windows_list):
|
||||||
image = load_image(image_fname)
|
image = load_image(image_fname)
|
||||||
for window in windows:
|
for window in windows:
|
||||||
|
window_image, _ = format_image(image, window, cropped_size=True)
|
||||||
data.append({
|
data.append({
|
||||||
'image': format_image(image, window, cropped_size=True)[np.newaxis, :],
|
'image': window_image[np.newaxis, :],
|
||||||
'window': window,
|
'window': window,
|
||||||
'filename': image_fname
|
'filename': image_fname
|
||||||
})
|
})
|
||||||
|
@ -207,50 +246,43 @@ def _assemble_images_selective_search(image_fnames):
|
||||||
return images_df
|
return images_df
|
||||||
|
|
||||||
|
|
||||||
def assemble_batches(image_fnames, crop_mode='center_only'):
|
def assemble_batches(inputs, crop_mode='center_only'):
|
||||||
"""
|
"""
|
||||||
Assemble DataFrame of image crops for feature computation.
|
Assemble DataFrame of image crops for feature computation.
|
||||||
|
|
||||||
Input:
|
Input:
|
||||||
image_fnames: list of string
|
inputs: list of filenames (center_only, corners, and selective_search mode)
|
||||||
|
OR input DataFrame (list mode)
|
||||||
crop_mode: string
|
crop_mode: string
|
||||||
'list': the crops are lines in a (image_filename ymin xmin ymax xmax)
|
'list': take the image windows from the input as-is
|
||||||
format file. Set this mode by --crop_mode=list:/path/to/windows_file
|
'center_only': take the CROPPED_DIM middle of the image windows
|
||||||
'center_only': the CROPPED_DIM middle of the image is taken as is
|
|
||||||
'corners': take CROPPED_DIM-sized boxes at 4 corners and center of
|
'corners': take CROPPED_DIM-sized boxes at 4 corners and center of
|
||||||
the image, as well as their flipped versions: a total of 10.
|
the image windows, as well as their flipped versions: a total of 10.
|
||||||
'selective_search': run Selective Search region proposal on the
|
'selective_search': run Selective Search region proposal on the
|
||||||
image, and take each enclosing subwindow.
|
image windows, and take each enclosing subwindow.
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
df_batches: list of DataFrames, each one of BATCH_SIZE rows.
|
df_batches: list of DataFrames, each one of BATCH_SIZE rows.
|
||||||
Each row has 'image', 'filename', and 'window' info.
|
Each row has 'image', 'filename', and 'window' info.
|
||||||
Column 'image' contains (X x 3 x CROPPED_DIM x CROPPED_DIM) ndarrays.
|
Column 'image' contains (X x 3 x CROPPED_DIM x CROPPED_DIM) ndarrays.
|
||||||
Column 'filename' contains source filenames.
|
Column 'filename' contains source filenames.
|
||||||
|
Column 'window' contains [ymin, xmin, ymax, xmax] ndarrays.
|
||||||
If 'filename' is None, then the row is just for padding.
|
If 'filename' is None, then the row is just for padding.
|
||||||
|
|
||||||
Note: for increased efficiency, increase the batch size (to the limit of gpu
|
Note: for increased efficiency, increase the batch size (to the limit of gpu
|
||||||
memory) to avoid the communication cost
|
memory) to avoid the communication cost
|
||||||
"""
|
"""
|
||||||
if crop_mode.startswith('list'):
|
if crop_mode == 'list':
|
||||||
from collections import defaultdict
|
images_df = _assemble_images_list(inputs)
|
||||||
image_windows = defaultdict(list)
|
|
||||||
crop_mode, windows_file = crop_mode.split(':')
|
|
||||||
with open(windows_file, 'r') as f:
|
|
||||||
for line in f:
|
|
||||||
parts = line.split(' ')
|
|
||||||
image_fname, window = parts[0], parts[1:]
|
|
||||||
image_windows[image_fname].append([int(x) for x in window])
|
|
||||||
images_df = _assemble_images_list(image_windows)
|
|
||||||
|
|
||||||
elif crop_mode == 'center_only':
|
elif crop_mode == 'center_only':
|
||||||
images_df = _assemble_images_center_only(image_fnames)
|
images_df = _assemble_images_center_only(inputs)
|
||||||
|
|
||||||
elif crop_mode == 'corners':
|
elif crop_mode == 'corners':
|
||||||
images_df = _assemble_images_corners(image_fnames)
|
images_df = _assemble_images_corners(inputs)
|
||||||
|
|
||||||
elif crop_mode == 'selective_search':
|
elif crop_mode == 'selective_search':
|
||||||
images_df = _assemble_images_selective_search(image_fnames)
|
images_df = _assemble_images_selective_search(inputs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Unknown mode: not in {}".format(CROP_MODES))
|
raise Exception("Unknown mode: not in {}".format(CROP_MODES))
|
||||||
|
@ -261,10 +293,11 @@ def assemble_batches(image_fnames, crop_mode='center_only'):
|
||||||
remainder = N % BATCH_SIZE
|
remainder = N % BATCH_SIZE
|
||||||
if remainder > 0:
|
if remainder > 0:
|
||||||
zero_image = np.zeros_like(images_df['image'].iloc[0])
|
zero_image = np.zeros_like(images_df['image'].iloc[0])
|
||||||
|
zero_window = np.zeros((1, 4), dtype=int)
|
||||||
remainder_df = pd.DataFrame([{
|
remainder_df = pd.DataFrame([{
|
||||||
'filename': None,
|
'filename': None,
|
||||||
'image': zero_image,
|
'image': zero_image,
|
||||||
'window': [0, 0, 0, 0]
|
'window': zero_window
|
||||||
}] * (BATCH_SIZE - remainder))
|
}] * (BATCH_SIZE - remainder))
|
||||||
images_df = images_df.append(remainder_df)
|
images_df = images_df.append(remainder_df)
|
||||||
N = images_df.shape[0]
|
N = images_df.shape[0]
|
||||||
|
@ -328,9 +361,9 @@ if __name__ == "__main__":
|
||||||
gflags.DEFINE_string(
|
gflags.DEFINE_string(
|
||||||
"crop_mode", "center_only", "Crop mode, from {}".format(CROP_MODES))
|
"crop_mode", "center_only", "Crop mode, from {}".format(CROP_MODES))
|
||||||
gflags.DEFINE_string(
|
gflags.DEFINE_string(
|
||||||
"images_file", "", "Image filenames file.")
|
"input_file", "", "Input txt/csv filename.")
|
||||||
gflags.DEFINE_string(
|
gflags.DEFINE_string(
|
||||||
"output_file", "", "Output DataFrame HDF5 filename.")
|
"output_file", "", "Output h5/csv filename.")
|
||||||
gflags.DEFINE_string(
|
gflags.DEFINE_string(
|
||||||
"images_dim", 256, "Canonical dimension of (square) images.")
|
"images_dim", 256, "Canonical dimension of (square) images.")
|
||||||
gflags.DEFINE_string(
|
gflags.DEFINE_string(
|
||||||
|
@ -344,18 +377,29 @@ if __name__ == "__main__":
|
||||||
config(FLAGS.model_def, FLAGS.pretrained_model, FLAGS.gpu, FLAGS.images_dim,
|
config(FLAGS.model_def, FLAGS.pretrained_model, FLAGS.gpu, FLAGS.images_dim,
|
||||||
FLAGS.images_mean_file)
|
FLAGS.images_mean_file)
|
||||||
|
|
||||||
# Load list of image filenames and assemble into batches.
|
# Load input
|
||||||
|
# .txt = list of filenames
|
||||||
|
# .csv = dataframe that must include a header
|
||||||
|
# with column names filename, ymin, xmin, ymax, xmax
|
||||||
t = time.time()
|
t = time.time()
|
||||||
print('Assembling batches...')
|
print('Loading input and assembling batches...')
|
||||||
with open(FLAGS.images_file) as f:
|
if FLAGS.input_file.lower().endswith('txt'):
|
||||||
image_fnames = [_.strip() for _ in f.readlines()]
|
with open(FLAGS.input_file) as f:
|
||||||
image_batches = assemble_batches(image_fnames, FLAGS.crop_mode)
|
inputs = [_.strip() for _ in f.readlines()]
|
||||||
|
elif FLAGS.input_file.lower().endswith('csv'):
|
||||||
|
inputs = pd.read_csv(FLAGS.input_file, sep=',', dtype={'filename': str})
|
||||||
|
inputs.set_index('filename', inplace=True)
|
||||||
|
else:
|
||||||
|
raise Exception("Uknown input file type: not in txt or csv")
|
||||||
|
|
||||||
|
# Assemble into batches
|
||||||
|
image_batches = assemble_batches(inputs, FLAGS.crop_mode)
|
||||||
print('{} batches assembled in {:.3f} s'.format(len(image_batches),
|
print('{} batches assembled in {:.3f} s'.format(len(image_batches),
|
||||||
time.time() - t))
|
time.time() - t))
|
||||||
|
|
||||||
# Process the batches.
|
# Process the batches.
|
||||||
t = time.time()
|
t = time.time()
|
||||||
print 'Processing {} files in {} batches'.format(len(image_fnames),
|
print 'Processing {} files in {} batches'.format(len(inputs),
|
||||||
len(image_batches))
|
len(image_batches))
|
||||||
dfs_with_feats = []
|
dfs_with_feats = []
|
||||||
for i in range(len(image_batches)):
|
for i in range(len(image_batches)):
|
||||||
|
@ -367,11 +411,29 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Concatenate, dropping the padding rows.
|
# Concatenate, dropping the padding rows.
|
||||||
df = pd.concat(dfs_with_feats).dropna(subset=['filename'])
|
df = pd.concat(dfs_with_feats).dropna(subset=['filename'])
|
||||||
|
df.set_index('filename', inplace=True)
|
||||||
print("Processing complete after {:.3f} s.".format(time.time() - t))
|
print("Processing complete after {:.3f} s.".format(time.time() - t))
|
||||||
|
|
||||||
# Write our the results.
|
# Label coordinates
|
||||||
|
coord_cols = ['ymin', 'xmin', 'ymax', 'xmax']
|
||||||
|
df[coord_cols] = pd.DataFrame(data=np.vstack(df['window']),
|
||||||
|
index=df.index,
|
||||||
|
columns=coord_cols)
|
||||||
|
del(df['window'])
|
||||||
|
|
||||||
|
# Write out the results.
|
||||||
t = time.time()
|
t = time.time()
|
||||||
df.to_hdf(FLAGS.output_file, 'df', mode='w')
|
if FLAGS.output_file.lower().endswith('csv'):
|
||||||
|
# enumerate the class probabilities
|
||||||
|
class_cols = ['class{}'.format(x) for x in range(NUM_OUTPUT)]
|
||||||
|
df[class_cols] = pd.DataFrame(data=np.vstack(df['feat']),
|
||||||
|
index=df.index,
|
||||||
|
columns=class_cols)
|
||||||
|
df.to_csv(FLAGS.output_file, sep=',',
|
||||||
|
cols=coord_cols + class_cols,
|
||||||
|
header=True)
|
||||||
|
else:
|
||||||
|
df.to_hdf(FLAGS.output_file, 'df', mode='w')
|
||||||
print("Done. Saving to {} took {:.3f} s.".format(
|
print("Done. Saving to {} took {:.3f} s.".format(
|
||||||
FLAGS.output_file, time.time() - t))
|
FLAGS.output_file, time.time() - t))
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче