Removed option to not preload clips; replaced clip_loader with standard batch indexing instead.

Parent: aa3eaabb37
Commit: e543ecce9c
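The commit replaces the old DataLoader-style clip_loader with plain index arithmetic over preloaded clip tensors. A minimal sketch of the new pattern follows; the get_batch_indices helper is copied from the data/utils.py file added below, while the clip tensor shape and batch size are illustrative assumptions rather than values from the repository.

import numpy as np
import torch

def get_batch_indices(index, last_element, batch_size):
    # helper added in data/utils.py: start/end of the index-th batch, clamped to the last element
    batch_start_index = index * batch_size
    batch_end_index = batch_start_index + batch_size
    if batch_end_index > last_element:
        batch_end_index = last_element
    return batch_start_index, batch_end_index

# stand-in for preloaded clips: 10 clips of 8 frames at 3x84x84 (assumed sizes)
clips = torch.randn(10, 8, 3, 84, 84)
batch_size = 4

num_batches = int(np.ceil(float(len(clips)) / float(batch_size)))
for batch in range(num_batches):
    start, end = get_batch_indices(batch, len(clips), batch_size)
    batch_clips = clips[start:end]            # plain tensor slicing replaces the old clip_loader iteration
    print(batch, tuple(batch_clips.shape))    # batches of 4, 4 and 2 clips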
@@ -24,7 +24,6 @@ class DataLoader():
                                 dataset_info['subsample_factor'],
                                 dataset_info['train_clip_methods'],
                                 dataset_info['clip_length'],
-                                dataset_info['preload_clips'],
                                 dataset_info['frame_size'],
                                 dataset_info['annotations_to_load'],
                                 dataset_info['filter_by_annotations'],
@@ -43,7 +42,6 @@ class DataLoader():
                                 dataset_info['subsample_factor'],
                                 dataset_info['test_clip_methods'],
                                 dataset_info['clip_length'],
-                                dataset_info['preload_clips'],
                                 dataset_info['frame_size'],
                                 dataset_info['annotations_to_load'],
                                 dataset_info['filter_by_annotations'],
@@ -61,7 +59,6 @@ class DataLoader():
                                 dataset_info['subsample_factor'],
                                 dataset_info['test_clip_methods'],
                                 dataset_info['clip_length'],
-                                dataset_info['preload_clips'],
                                 dataset_info['frame_size'],
                                 dataset_info['annotations_to_load'],
                                 dataset_info['filter_by_annotations'],
@@ -79,15 +76,15 @@ class DataLoader():
         return self.test_queue

     def config_user_centric_queue(self, root, way_method, object_cap, shot_method, shots, video_types, \
-                    subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
+                    subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
                     num_tasks, test_mode=False, with_cluster_labels=False, with_caps=False, shuffle=False, logfile=None):
         return UserEpisodicDatasetQueue(root, way_method, object_cap, shot_method, shots, video_types, \
-                    subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
+                    subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
                     num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile)

     def config_object_centric_queue(self, root, way_method, object_cap, shot_method, shots, video_types, \
-                    subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
+                    subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
                     num_tasks, test_mode=False, with_cluster_labels=False, with_caps=False, shuffle=False, logfile=None):
         return ObjectEpisodicDatasetQueue(root, way_method, object_cap, shot_method, shots, video_types, \
-                    subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
+                    subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
                     num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile)

@ -19,7 +19,7 @@ class ORBITDataset(Dataset):
|
|||
"""
|
||||
Base class for ORBIT dataset.
|
||||
"""
|
||||
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile=None):
|
||||
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile=None):
|
||||
"""
|
||||
Creates instance of ORBITDataset.
|
||||
:param root: (str) Path to train/validation/test folder in ORBIT dataset root folder.
|
||||
|
@ -28,11 +28,10 @@ class ORBITDataset(Dataset):
|
|||
:param shot_methods: (str, str) Method for sampling videos for context and target sets.
|
||||
:param shots: (int, int) Number of videos to sample for context and target sets.
|
||||
:param video_types: (str, str) Video types to sample for context and target sets.
|
||||
:param subsample_factor: (int) Factor to subsample video frames before sampling clips.
|
||||
:param subsample_factor: (int) Factor to subsample video frames if sampling frames uniformly.
|
||||
:param clip_methods: (str, str) Method for sampling clips of contiguous frames from videos for context and target sets.
|
||||
:param clip_length: (int) Number of contiguous frames per video clip.
|
||||
:param preload_clips: (bool) If True, preload clips from disk and return as tensors, otherwise return clip paths.
|
||||
:param frame_size: (int) Size in pixels of preloaded frames.
|
||||
:param frame_size: (int) Size in pixels of loaded frames.
|
||||
:param annotations_to_load: (list::str) Types of frame annotations to load from disk and return per task.
|
||||
:param filter_by_annotations (list::str) Types of frame annotations to filter by for context and target sets.
|
||||
:param test_mode: (bool) If True, returns task with target set grouped by video, otherwise returns task with target set as clips.
|
||||
|
@ -50,7 +49,6 @@ class ORBITDataset(Dataset):
|
|||
self.subsample_factor = subsample_factor
|
||||
self.context_clip_method, self.target_clip_method = clip_methods
|
||||
self.clip_length = clip_length
|
||||
self.preload_clips = preload_clips
|
||||
self.frame_size = frame_size
|
||||
self.test_mode = test_mode
|
||||
self.with_cluster_labels = with_cluster_labels
|
||||
|
@ -161,7 +159,7 @@ class ORBITDataset(Dataset):
|
|||
video_path = os.path.join(obj_path, video_types[set_type], video_name)
|
||||
frames = glob.glob(os.path.join(video_path, "*.jpg"))
|
||||
|
||||
if self.with_annotations or (self.filter_params[set_type]['criteria']):
|
||||
if self.with_annotations or self.filter_params[set_type]['criteria']:
|
||||
video_annotations = self.__load_video_annotations(video_name)
|
||||
self.frame2anns.update(video_annotations)
|
||||
if self.filter_params[set_type]['criteria']:
|
||||
|
@ -345,9 +343,8 @@ class ORBITDataset(Dataset):
|
|||
sampled_paths = frame_paths[sampled_idxs].reshape(-1, self.clip_length)
|
||||
paths.extend(sampled_paths)
|
||||
|
||||
if self.preload_clips:
|
||||
sampled_clips = self.load_clips(sampled_paths)
|
||||
clips += sampled_clips
|
||||
sampled_clips = self.load_clips(sampled_paths)
|
||||
clips += sampled_clips
|
||||
|
||||
if self.with_annotations:
|
||||
sampled_annotations = self.load_annotations(sampled_paths)
|
||||
|
@ -386,21 +383,24 @@ class ORBITDataset(Dataset):
|
|||
|
||||
return loaded_clips
|
||||
|
||||
def load_annotations(self, paths: np.ndarray) -> torch.Tensor:
|
||||
def load_annotations(self, paths: np.ndarray, without_clip_history=True) -> torch.Tensor:
|
||||
"""
|
||||
Function to load frame annotations, arrange in clips, from disk.
|
||||
Function to load frame annotations, arrange in clips.
|
||||
:param paths: (np.ndarray::str) Frame paths organised in clips of self.clip_length contiguous frames.
|
||||
:param without_clip_history: (bool) If True, only load annotations for last frame in every clip.
|
||||
:return: (torch.Tensor) Frame annotations arranged in clips.
|
||||
"""
|
||||
num_clips, clip_length = paths.shape
|
||||
frames_per_clip = 1 if without_clip_history else clip_length
|
||||
assert clip_length == self.clip_length
|
||||
|
||||
loaded_annotations = {
|
||||
annotation : torch.empty(num_clips, clip_length, self.annotation_dims.get(annotation, 1)) # The dimension defaults to 1 unless specified.
|
||||
annotation : torch.empty(num_clips, frames_per_clip, self.annotation_dims.get(annotation, 1)) # The dimension defaults to 1 unless specified.
|
||||
for annotation in self.annotations_to_load }
|
||||
|
||||
for clip_idx in range(num_clips):
|
||||
for frame_idx in range(clip_length):
|
||||
frames_to_load = [clip_length-1] if without_clip_history else range(clip_length)
|
||||
for frame_idx in frames_to_load:
|
||||
frame_path = paths[clip_idx, frame_idx]
|
||||
frame_name = os.path.basename(frame_path)
|
||||
for annotation in self.annotations_to_load:
|
||||
|
@ -411,11 +411,11 @@ class ORBITDataset(Dataset):
|
|||
loaded_annotations[annotation][clip_idx, frame_idx] = float('nan')
|
||||
|
||||
return loaded_annotations
|
||||
|
||||
|
||||
def load_and_transform_frame(self, frame_path):
|
||||
"""
|
||||
Function to load and transform frame.
|
||||
:param frame_path: (str) str to frame.
|
||||
:param frame_path: (str) Path to frame.
|
||||
:return: (torch.Tensor) Loaded and transformed frame.
|
||||
"""
|
||||
frame = Image.open(frame_path)
|
||||
|
@ -472,7 +472,7 @@ class ORBITDataset(Dataset):
|
|||
:param test_mode: (bool) If False, do not shuffle task, otherwise shuffle.
|
||||
:return: (torch.Tensor or list::torch.Tensor, np.ndarray::str or list::np.ndarray, torch.Tensor or list::torch.Tensor, dict::torch.Tensor or list::dict::torch.Tensor) Frame data, paths, video-level labels and annotations organised in clips (if train) or grouped and flattened by video (if test/validation).
|
||||
"""
|
||||
clips = torch.stack(clips) if self.preload_clips else torch.tensor(clips)
|
||||
clips = torch.stack(clips)
|
||||
paths = np.array(paths)
|
||||
labels = torch.tensor(labels)
|
||||
annotations = { ann: torch.stack(annotations[ann]) for ann in self.annotations_to_load }
|
||||
|
@ -484,7 +484,7 @@ class ORBITDataset(Dataset):
|
|||
# get all clips belonging to current video
|
||||
idxs = video_ids == video_id
|
||||
# flatten frames and paths from current video (assumed to be sorted)
|
||||
video_frames = clips[idxs].flatten(end_dim=1) if self.preload_clips else None
|
||||
video_frames = clips[idxs].flatten(end_dim=1)
|
||||
video_paths = paths[idxs].reshape(-1)
|
||||
frames_by_video.append(video_frames)
|
||||
paths_by_video.append(video_paths)
|
||||
|
@ -492,7 +492,7 @@ class ORBITDataset(Dataset):
|
|||
video_label = labels[idxs][0]
|
||||
labels_by_video.append(video_label)
|
||||
# get all frame annotations for current video
|
||||
video_anns = { ann : annotations[ann][idxs].flatten(end_dim=1) for ann in self.annotations_to_load } if self.with_annotations else None
|
||||
video_anns = { ann : annotations[ann][idxs].flatten(end_dim=1) for ann in self.annotations_to_load } if self.with_annotations else None
|
||||
annotations_by_video.append(video_anns)
|
||||
return frames_by_video, paths_by_video, labels_by_video, annotations_by_video
|
||||
else:
|
||||
|
@ -509,16 +509,10 @@ class ORBITDataset(Dataset):
|
|||
"""
|
||||
idxs = np.arange(len(paths))
|
||||
random.shuffle(idxs)
|
||||
if self.preload_clips:
|
||||
if self.with_annotations:
|
||||
return clips[idxs], paths[idxs], labels[idxs], { ann : annotations[ann][idxs] for ann in self.annotations_to_load }
|
||||
else:
|
||||
return clips[idxs], paths[idxs], labels[idxs], annotations
|
||||
if self.with_annotations:
|
||||
return clips[idxs], paths[idxs], labels[idxs], { ann : annotations[ann][idxs] for ann in self.annotations_to_load }
|
||||
else:
|
||||
if self.with_annotations:
|
||||
return clips, paths[idxs], labels[idxs], { ann : annotations[ann][idxs] for ann in self.annotations_to_load }
|
||||
else:
|
||||
return clips, paths[idxs], labels[idxs], annotations
|
||||
return clips[idxs], paths[idxs], labels[idxs], annotations
|
||||
|
||||
def get_label_map(self, objects, with_cluster_labels=False):
|
||||
"""
|
||||
|
@ -536,7 +530,7 @@ class ORBITDataset(Dataset):
|
|||
map_dict[old_label] = new_labels[i]
|
||||
return map_dict
|
||||
|
||||
def sample_task(self, task_objects: List[int], with_target_set: bool, task_id: str) -> Dict:
|
||||
def sample_task(self, task_objects: List[int], task_id: str) -> Dict:
|
||||
|
||||
# select way (number of classes/objects) randomly
|
||||
num_objects = len(task_objects)
|
||||
|
@ -570,17 +564,15 @@ class ORBITDataset(Dataset):
|
|||
context_video_ids.extend(cvi)
|
||||
context_annotations = self.extend_ann_dict(context_annotations, ca)
|
||||
|
||||
if with_target_set:
|
||||
tc, tp, tvi, ta = self.sample_clips_from_videos(target_videos, self.target_clip_method)
|
||||
target_clips.extend(tc)
|
||||
target_paths.extend(tp)
|
||||
target_labels.extend([label for _ in range(len(tp))])
|
||||
target_video_ids.extend(tvi)
|
||||
target_annotations = self.extend_ann_dict(target_annotations, ta)
|
||||
tc, tp, tvi, ta = self.sample_clips_from_videos(target_videos, self.target_clip_method)
|
||||
target_clips.extend(tc)
|
||||
target_paths.extend(tp)
|
||||
target_labels.extend([label for _ in range(len(tp))])
|
||||
target_video_ids.extend(tvi)
|
||||
target_annotations = self.extend_ann_dict(target_annotations, ta)
|
||||
|
||||
context_clips, context_paths, context_labels, context_annotations = self.prepare_set(context_clips, context_paths, context_labels, context_annotations, context_video_ids)
|
||||
if with_target_set:
|
||||
target_clips, target_paths, target_labels, target_annotations = self.prepare_set(target_clips, target_paths, target_labels, target_annotations, target_video_ids, test_mode=self.test_mode)
|
||||
target_clips, target_paths, target_labels, target_annotations = self.prepare_set(target_clips, target_paths, target_labels, target_annotations, target_video_ids, test_mode=self.test_mode)
|
||||
|
||||
task_dict = {
|
||||
# Data required for train / test
|
||||
|
@ -602,7 +594,7 @@ class UserEpisodicORBITDataset(ORBITDataset):
|
|||
"""
|
||||
Class for user-centric episodic sampling of ORBIT dataset.
|
||||
"""
|
||||
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
|
||||
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
|
||||
"""
|
||||
Creates instance of UserEpisodicORBITDataset.
|
||||
:param root: (str) Path to train/validation/test folder in ORBIT dataset root folder.
|
||||
|
@ -611,11 +603,10 @@ class UserEpisodicORBITDataset(ORBITDataset):
|
|||
:param shot_methods: (str, str) Method for sampling videos for context and target sets.
|
||||
:param shots: (int, int) Number of videos to sample for context and target sets.
|
||||
:param video_types: (str, str) Video types to sample for context and target sets.
|
||||
:param subsample_factor: (int) Factor to subsample video frames before sampling clips.
|
||||
:param subsample_factor: (int) Factor to subsample video frames if sampling frames uniformly.
|
||||
:param clip_methods: (str, str) Method for sampling clips of contiguous frames from videos for context and target sets.
|
||||
:param clip_length: (int) Number of contiguous frames per video clip.
|
||||
:param preload_clips: (bool) If True, preload clips from disk and return as tensors, otherwise return clip paths.
|
||||
:param frame_size: (int) Size in pixels of preloaded frames.
|
||||
:param frame_size: (int) Size in pixels of loaded frames.
|
||||
:param annotations_to_load: (list::str) Types of frame annotations to load from disk and return per task.
|
||||
:param filter_by_annotations (list::str) Types of frame annotations to filter by for context and target sets.
|
||||
:param test_mode: (bool) If True, returns task with target set grouped by video, otherwise returns task with target set as clips.
|
||||
|
@ -624,25 +615,24 @@ class UserEpisodicORBITDataset(ORBITDataset):
|
|||
:param logfile: (file object) File for printing out loaded data summaries.
|
||||
:return: Nothing.
|
||||
"""
|
||||
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
|
||||
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Function to get a user-centric task as a set of (context and target) clips and labels.
|
||||
:param index: (tuple) Task ID and whether to load task target set.
|
||||
:param index: (tuple) Task index.
|
||||
:return: (dict) Context and target set data for task.
|
||||
"""
|
||||
|
||||
task_id, with_target_set = index
|
||||
user = self.users[task_id] # get user (each task == user id)
|
||||
user = self.users[index] # get user (each task == user id)
|
||||
user_objects = self.user2objs[user] # get user's objects
|
||||
return self.sample_task(user_objects, with_target_set, user)
|
||||
return self.sample_task(user_objects, user)
|
||||
|
||||
class ObjectEpisodicORBITDataset(ORBITDataset):
|
||||
"""
|
||||
Class for object-centric episodic sampling of ORBIT dataset.
|
||||
"""
|
||||
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
|
||||
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
|
||||
"""
|
||||
Creates instance of ObjectEpisodicORBITDataset.
|
||||
:param root: (str) Path to train/validation/test folder in ORBIT dataset root folder.
|
||||
|
@ -651,11 +641,10 @@ class ObjectEpisodicORBITDataset(ORBITDataset):
|
|||
:param shot_methods: (str, str) Method for sampling videos for context and target sets.
|
||||
:param shots: (int, int) Number of videos to sample for context and target sets.
|
||||
:param video_types: (str, str) Video types to sample for context and target sets.
|
||||
:param subsample_factor: (int) Factor to subsample video frames before sampling clips.
|
||||
:param subsample_factor: (int) Factor to subsample video frames if sampling frames uniformly.
|
||||
:param clip_methods: (str, str) Method for sampling clips of contiguous frames from videos for context and target sets.
|
||||
:param clip_length: (int) Number of contiguous frames per video clip.
|
||||
:param preload_clips: (bool) If True, preload clips from disk and return as tensors, otherwise return clip paths.
|
||||
:param frame_size: (int) Size in pixels of preloaded frames.
|
||||
:param frame_size: (int) Size in pixels of loaded frames.
|
||||
:param annotations_to_load: (list::str) Types of frame annotations to load from disk and return per task.
|
||||
:param filter_by_annotations (list::str) Types of frame annotations to filter by for context and target sets.
|
||||
:param test_mode: (bool) If True, returns task with target set grouped by video, otherwise returns task with target set as clips.
|
||||
|
@ -664,14 +653,13 @@ class ObjectEpisodicORBITDataset(ORBITDataset):
|
|||
:param logfile: (file object) File for printing out loaded data summaries.
|
||||
:return: Nothing.
|
||||
"""
|
||||
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
|
||||
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Function to get an object-centric task as a set of (context and target) clips and labels.
|
||||
:param index: (tuple) Task ID and whether to load task target set.
|
||||
:param index: (tuple) Task index.
|
||||
:return: (dict) Context and target set data for task.
|
||||
"""
|
||||
_, with_target_set = index
|
||||
all_objects = range(0, len(self.obj2vids)) # task can consider all possible objects, not just 1 user's objects
|
||||
return self.sample_task(all_objects, with_target_set)
|
||||
return self.sample_task(all_objects)
|
||||
|
|
|
@@ -2,7 +2,6 @@
 # Licensed under the MIT license.

 import torch
-from typing import Optional
 from data.samplers import TaskSampler
 from data.datasets import UserEpisodicORBITDataset, ObjectEpisodicORBITDataset

@@ -11,22 +10,17 @@ class DatasetQueue:
     Class for a queue of tasks sampled from UserEpisodicORBITDataset/ObjectEpisodicORBITDataset.

     """
-    def __init__(self, num_tasks, shuffle, test_mode, override_num_workers: Optional[int]=None):
+    def __init__(self, num_tasks: int, shuffle: bool, num_workers: int) -> None:
         """
         Creates instance of DatasetQueue.
         :param num_tasks: (int) Number of tasks per user to add to the queue.
         :param shuffle: (bool) If True, shuffle tasks, else do not shuffle.
-        :param test_mode: (bool) If True, only return target set for first task per user.
-        :param num_workers: (Optional[int]) Number of workers to use. Overrides defaults (4 if test, 8 otherwise).
+        :param num_workers: (int) Number of workers to use.
         :return: Nothing.
         """
         self.num_tasks = num_tasks
         self.shuffle = shuffle
-        self.test_mode = test_mode
-        if override_num_workers is None:
-            self.num_workers = 4 if self.test_mode else 8
-        else:
-            self.num_workers = override_num_workers
+        self.num_workers = num_workers

         self.num_users = None
         self.collate_fn = self.unpack
@@ -50,23 +44,23 @@ class DatasetQueue:
             dataset=self.dataset,
             pin_memory=False,
             num_workers=self.num_workers,
-            sampler=TaskSampler(self.num_tasks, self.num_users, self.shuffle, self.test_mode),
+            sampler=TaskSampler(self.num_tasks, self.num_users, self.shuffle),
             collate_fn=self.collate_fn
             )

 class UserEpisodicDatasetQueue(DatasetQueue):
-    def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
-        DatasetQueue.__init__(self, num_tasks, shuffle, test_mode)
-        self.dataset = UserEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
+    def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
+        DatasetQueue.__init__(self, num_tasks, shuffle, num_workers=4 if test_mode else 8)
+        self.dataset = UserEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
         self.num_users = self.dataset.num_users

     def __len__(self):
         return self.dataset.num_users

 class ObjectEpisodicDatasetQueue(DatasetQueue):
-    def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
-        DatasetQueue.__init__(self, num_tasks, shuffle, test_mode)
-        self.dataset = ObjectEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
+    def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
+        DatasetQueue.__init__(self, num_tasks, shuffle, num_workers=4 if test_mode else 8)
+        self.dataset = ObjectEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
         self.num_users = self.dataset.num_users

     def __len__(self):

@@ -9,28 +9,22 @@ class TaskSampler(Sampler):
     """
     Sampler class for a fixed number of tasks per user.
     """
-    def __init__(self, tasks_per_user, num_users, shuffle, test_mode):
+    def __init__(self, tasks_per_user, num_users, shuffle):
         """
         Creates instances of TaskSampler.
         :param tasks_per_user: (int) Number of tasks to sample per user.
         :param num_users: (int) Total number of users.
         :param shuffle: (bool) If True, shuffle tasks, otherwise do not shuffle.
-        :param test_mode: (bool) If True, only load target set for first task per user.
         :return: Nothing.
         """
         self.tasks_per_user = tasks_per_user
         self.num_users = num_users
         self.shuffle = shuffle
-        self.test_mode = test_mode

     def __iter__(self):
         task_ids = []
         for user in range(self.num_users):
-            for task in range(self.tasks_per_user):
-                with_target_set = True
-                if self.test_mode and task > 0:
-                    with_target_set = False
-                task_ids.append((user, with_target_set))
+            task_ids.extend([user]*self.tasks_per_user)
         if self.shuffle:
             random.shuffle(task_ids)
         return iter(task_ids)

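With test_mode removed from the sampler, __iter__ now yields bare user indices rather than (user, with_target_set) tuples. A small illustrative check, using made-up arguments; the import path matches the data.samplers module referenced elsewhere in this diff.

from data.samplers import TaskSampler

sampler = TaskSampler(tasks_per_user=2, num_users=3, shuffle=False)
print(list(sampler))   # [0, 0, 1, 1, 2, 2]
# previously each element was a (user, with_target_set) tuple,
# e.g. [(0, True), (0, False), (1, True), (1, False), ...] when test_mode was set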
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import numpy as np
+import torch.nn as nn
+
+def attach_frame_history(frames, history_length):
+    """
+    Function to attach the immediate history of history_length frames to each frame in a tensor of frame data.
+    :param frames: (torch.Tensor) Frames.
+    :param history_length: (int) Number of frames of history to append to each frame.
+    :return: (torch.Tensor) Frames with attached frame history.
+    """
+    # pad with first frame so that frames 0 to history_length-1 can be evaluated
+    frame_0 = frames.narrow(0, 0, 1)
+    frames = torch.cat((frame_0.repeat(history_length-1, 1, 1, 1), frames), dim=0)
+
+    # for each frame, attach its immediate history of history_length frames
+    frames = [ frames ]
+    for l in range(1, history_length):
+        frames.append( frames[0].roll(shifts=-l, dims=0) )
+    frames_with_history = torch.stack(frames, dim=1) # of size num_clips x history_length
+
+    if history_length > 1:
+        return frames_with_history[:-(history_length-1)] # frames have wrapped around, remove last (history_length - 1) frames
+    else:
+        return frames_with_history
+
+def unpack_task(task_dict, device, context_to_device=True, target_to_device=False):
+
+    context_clips = task_dict['context_clips']
+    context_paths = task_dict['context_paths']
+    context_labels = task_dict['context_labels']
+    context_annotations = task_dict['context_annotations']
+    target_clips = task_dict['target_clips']
+    target_paths = task_dict['target_paths']
+    target_labels = task_dict['target_labels']
+    target_annotations = task_dict['target_annotations']
+    object_list = task_dict['object_list']
+
+    if context_to_device and isinstance(context_labels, torch.Tensor):
+        context_labels = context_labels.to(device)
+    if target_to_device and isinstance(target_labels, torch.Tensor):
+        target_labels = target_labels.to(device)
+
+    return context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list
+
+def get_batch_indices(index, last_element, batch_size):
+    batch_start_index = index * batch_size
+    batch_end_index = batch_start_index + batch_size
+    if batch_end_index > last_element:
+        batch_end_index = last_element
+    return batch_start_index, batch_end_index
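A quick shape check for attach_frame_history as added above; the tensor sizes are illustrative and the import path is the new data/utils.py module.

import torch
from data.utils import attach_frame_history

frames = torch.randn(5, 3, 84, 84)               # 5 RGB frames (assumed size)
with_history = attach_frame_history(frames, history_length=3)
print(tuple(with_history.shape))                 # (5, 3, 3, 84, 84)
# each frame now carries its 2 preceding frames plus itself,
# with the start of the video padded by repeating frame 0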
@ -28,11 +28,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
from data.utils import get_batch_indices
|
||||
from features import extractors
|
||||
from feature_adapters import FilmAdapter, NullAdapter
|
||||
from models.poolers import MeanPooler
|
||||
|
@ -40,7 +40,6 @@ from models.normalisation_layers import TaskNorm
|
|||
from models.set_encoder import SetEncoder, NullSetEncoder
|
||||
from models.classifiers import LinearClassifier, VersaClassifier, PrototypicalClassifier, MahalanobisClassifier
|
||||
from utils.optim import init_optimizer
|
||||
from utils.data import get_clip_loader
|
||||
|
||||
class FewShotRecogniser(nn.Module):
|
||||
"""
|
||||
|
@ -140,7 +139,6 @@ class FewShotRecogniser(nn.Module):
|
|||
:return: (torch.Tensor) Adapted frame features flattened across all clips.
|
||||
"""
|
||||
self._set_model_state(context)
|
||||
t1 = time.time()
|
||||
if self.use_two_gpus:
|
||||
clips = clips.cuda(1)
|
||||
features = self.feature_extractor(clips, feature_adapter_params).cuda(0)
|
||||
|
@ -148,16 +146,14 @@ class FewShotRecogniser(nn.Module):
|
|||
features = self.feature_extractor(clips, feature_adapter_params)
|
||||
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
ops_counter.compute_macs(self.feature_extractor, clips, feature_adapter_params)
|
||||
|
||||
return features
|
||||
|
||||
def _get_features_in_batches(self, clip_loader, feature_adapter_params, ops_counter=None, context=False):
|
||||
def _get_features_in_batches(self, clips, feature_adapter_params, ops_counter=None, context=False):
|
||||
"""
|
||||
Function that passes clips in batches through an adapted feature extractor to get adapted (and flattened) frame features.
|
||||
:param clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for clips, each composed of self.clip_length contiguous frames.
|
||||
:param clips: (torch.Tensor) Clips, each composed of self.clip_length contiguous frames.
|
||||
:param feature_adapter_params: (list::dict::torch.Tensor or list::dict::list::torch.Tensor) Parameters of all FiLM layers.
|
||||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:param context: (bool) True if clips are from context videos, otherwise False.
|
||||
|
@ -165,20 +161,32 @@ class FewShotRecogniser(nn.Module):
|
|||
"""
|
||||
features = []
|
||||
self._set_model_state(context)
|
||||
for batch_clips in clip_loader:
|
||||
|
||||
if self.adapt_features:
|
||||
extract_func = lambda clips: self.feature_extractor(clips, feature_adapter_params)
|
||||
else:
|
||||
extract_func = lambda clips: self.feature_extractor(clips)
|
||||
|
||||
num_clips = len(clips)
|
||||
num_batches = int(np.ceil(float(num_clips) / float(self.batch_size)))
|
||||
for batch in range(num_batches):
|
||||
batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.batch_size)
|
||||
batch_clips = clips[batch_start_index:batch_end_index]
|
||||
if len(batch_clips.shape) == 5:
|
||||
batch_clips = batch_clips.flatten(end_dim=1)
|
||||
|
||||
batch_clips = batch_clips.to(self.device, non_blocking=True)
|
||||
t1 = time.time()
|
||||
if self.use_two_gpus:
|
||||
batch_clips = batch_clips.cuda(1)
|
||||
batch_features = self.feature_extractor(batch_clips, feature_adapter_params).cuda(0)
|
||||
batch_features = extract_func(batch_clips).cuda(0)
|
||||
else:
|
||||
batch_features = self.feature_extractor(batch_clips, feature_adapter_params)
|
||||
batch_features = extract_func(batch_clips)
|
||||
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
# TODO add MACs to spatial transformer
|
||||
ops_counter.compute_macs(self.feature_extractor, batch_clips, feature_adapter_params)
|
||||
if self.adapt_features:
|
||||
ops_counter.compute_macs(self.feature_extractor, batch_clips, feature_adapter_params)
|
||||
else:
|
||||
ops_counter.compute_macs(self.feature_extractor, batch_clips)
|
||||
|
||||
features.append(batch_features)
|
||||
|
||||
|
@ -191,12 +199,9 @@ class FewShotRecogniser(nn.Module):
|
|||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:return: (list::dict::torch.Tensor or list::dict::list::torch.Tensor or None) Parameters of all FiLM layers.
|
||||
"""
|
||||
t1 = time.time()
|
||||
feature_adapter_params = self.feature_adapter(task_embedding)
|
||||
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
ops_counter.compute_macs(self.feature_adapter, task_embedding)
|
||||
|
||||
return feature_adapter_params
|
||||
|
@ -209,20 +214,17 @@ class FewShotRecogniser(nn.Module):
|
|||
:param reduction: (str) Method to aggregate clip encodings from self.set_encoder.
|
||||
:return: (torch.Tensor or None) Task embedding.
|
||||
"""
|
||||
t1 = time.time()
|
||||
reps = self.set_encoder(context_clips)
|
||||
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
ops_counter.compute_macs(self.set_encoder, context_clips)
|
||||
|
||||
return self.set_encoder.aggregate(reps, reduction=reduction, switch_device=self.use_two_gpus)
|
||||
|
||||
def _get_task_embedding_in_batches(self, context_clip_loader, ops_counter=None, reduction='mean'):
|
||||
def _get_task_embedding_in_batches(self, context_clips, ops_counter=None, reduction='mean'):
|
||||
"""
|
||||
Function that passes all of a task's context set through the set encoder to get a task embedding.
|
||||
:param context_clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for context clips, each composed of self.clip_length contiguous frames.
|
||||
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
|
||||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:param reduction: (str) Method to aggregate clip encodings from self.set_encoder.
|
||||
:return: (torch.Tensor or None) Task embedding.
|
||||
|
@ -231,14 +233,16 @@ class FewShotRecogniser(nn.Module):
|
|||
return None
|
||||
|
||||
reps = []
|
||||
for batch_clips in context_clip_loader:
|
||||
num_clips = len(context_clips)
|
||||
num_batches = int(np.ceil(float(num_clips) / float(self.batch_size)))
|
||||
for batch in range(num_batches):
|
||||
batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.batch_size)
|
||||
batch_clips = context_clips[batch_start_index:batch_end_index]
|
||||
batch_clips = batch_clips.to(self.device, non_blocking=True)
|
||||
t1 = time.time()
|
||||
batch_reps = self.set_encoder(batch_clips)
|
||||
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
ops_counter.compute_macs(self.set_encoder, batch_clips)
|
||||
|
||||
reps.append(batch_reps)
|
||||
|
@ -252,11 +256,8 @@ class FewShotRecogniser(nn.Module):
|
|||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:return: (torch.Tensor) Frame features pooled per clip i.e. as (num_clips) x (feat_dim).
|
||||
"""
|
||||
t1 = time.time()
|
||||
pooled_features = self.frame_pooler(features)
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
ops_counter.add_macs(features.size(0) * features.size(1))
|
||||
|
||||
return pooled_features
|
||||
|
@ -323,68 +324,54 @@ class MultiStepFewShotRecogniser(FewShotRecogniser):
|
|||
|
||||
self.num_grad_steps = num_grad_steps
|
||||
|
||||
def personalise(self, context_clips, context_clip_labels, learning_args, ops_counter=None):
|
||||
def personalise(self, context_clips, context_labels, learning_args, ops_counter=None):
|
||||
"""
|
||||
Function that learns a new task by taking a fixed number of gradient steps on the task's context set. For each task, a new linear classification layer is added (and FiLM layers if self.adapt_features == True).
|
||||
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
|
||||
:param context_clip_labels: (torch.Tensor) Video-level labels for each context clip.
|
||||
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
|
||||
:param context_labels: (torch.Tensor) Video-level labels for each context clip.
|
||||
:param learning_args: (float, func, str, float) Learning hyper-parameters including learning rate, loss function, optimiser type and factor to scale the extractor's learning rate.
|
||||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:return: Nothing.
|
||||
"""
|
||||
lr, loss_fn, optimizer_type, extractor_scale_factor = learning_args
|
||||
num_classes = len(torch.unique(context_clip_labels))
|
||||
num_classes = len(torch.unique(context_labels))
|
||||
self.init_classifier(num_classes)
|
||||
self.init_feature_adapter()
|
||||
inner_loop_optimizer = init_optimizer(self, lr, optimizer_type, extractor_scale_factor)
|
||||
|
||||
context_clip_loader = get_clip_loader((context_clips, context_clip_labels), self.batch_size, with_labels=True)
|
||||
batch_context_set_size = len(context_labels)
|
||||
num_batches = int(np.ceil(float(batch_context_set_size) / float(self.batch_size)))
|
||||
|
||||
for _ in range(self.num_grad_steps):
|
||||
for batch_context_clips, batch_context_labels in context_clip_loader:
|
||||
batch_context_clips = batch_context_clips.to(self.device)
|
||||
batch_context_labels = batch_context_labels.to(self.device)
|
||||
batch_context_logits = self.predict_a_batch(batch_context_clips, ops_counter=ops_counter, context=True)
|
||||
t1 = time.time()
|
||||
batch_context_loss = loss_fn(batch_context_logits, batch_context_labels)
|
||||
batch_context_loss.backward()
|
||||
for batch in range(num_batches):
|
||||
batch_start_index, batch_end_index = get_batch_indices(batch, batch_context_set_size, self.batch_size)
|
||||
batch_context_clips = context_clips[batch_start_index:batch_end_index].to(self.device)
|
||||
batch_context_labels = context_labels[batch_start_index:batch_end_index].to(self.device)
|
||||
batch_len = len(context_labels[batch_start_index:batch_end_index])
|
||||
|
||||
feature_adapter_params = self._get_feature_adapter_params(None, ops_counter)
|
||||
batch_context_features = self._get_features(batch_context_clips, feature_adapter_params, ops_counter, context=True)
|
||||
batch_context_features = self._pool_features(batch_context_features, ops_counter)
|
||||
batch_context_logits = self.classifier.predict(batch_context_features, ops_counter)
|
||||
loss = loss_fn(batch_context_logits, batch_context_labels)
|
||||
loss += 0.001 * self.feature_adapter.regularization_term(switch_device=self.use_two_gpus)
|
||||
loss *= batch_len/batch_context_set_size
|
||||
loss.backward()
|
||||
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
|
||||
t1 = time.time()
|
||||
inner_loop_optimizer.step()
|
||||
inner_loop_optimizer.zero_grad()
|
||||
if ops_counter:
|
||||
torch.cuda.synchronize()
|
||||
ops_counter.log_time(time.time() - t1)
|
||||
|
||||
def predict(self, clips, ops_counter=None, context=False):
|
||||
"""
|
||||
Function that processes target clips in batches to get logits over object classes for each clip.
|
||||
:param clips: (np.ndarray or torch.Tensor) Clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
|
||||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:param context: (bool) True if a context set is being processed, otherwise False.
|
||||
:return: (torch.Tensor) Logits over object classes for each clip in clips.
|
||||
"""
|
||||
clip_loader = get_clip_loader(clips, self.batch_size)
|
||||
task_embedding = None # multi-step methods do not use set encoder
|
||||
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
|
||||
features = self._get_features_in_batches(clip_loader, self.feature_adapter_params, ops_counter, context=context)
|
||||
features = self._pool_features(features, ops_counter)
|
||||
return self.classifier.predict(features, ops_counter)
|
||||
|
||||
def predict_a_batch(self, clips, ops_counter=None, context=False):
|
||||
"""
|
||||
Function that processes a batch of clips to get logits over object classes for each clip.
|
||||
:param clips: (torch.Tensor) Tensor of clips, each composed of self.clip_length contiguous frames.
|
||||
:param clips: (torch.Tensor) Clips, each composed of self.clip_length contiguous frames.
|
||||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:param context: (bool) True if a context set is being processed, otherwise False.
|
||||
:return: (torch.Tensor) Logits over object classes for each clip in clips.
|
||||
"""
|
||||
task_embedding = None # multi-step methods do not use set encoder
|
||||
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
|
||||
features = self._get_features(clips, self.feature_adapter_params, ops_counter, context=context)
|
||||
feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
|
||||
features = self._get_features_in_batches(clips, feature_adapter_params, ops_counter, context=context)
|
||||
features = self._pool_features(features, ops_counter)
|
||||
return self.classifier.predict(features, ops_counter)
|
||||
|
||||
|
@ -424,75 +411,72 @@ class SingleStepFewShotRecogniser(FewShotRecogniser):
|
|||
def personalise(self, context_clips, context_labels, ops_counter=None):
|
||||
"""
|
||||
Function that learns a new task by performing a forward pass of the task's context set.
|
||||
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
|
||||
:param context_clips: (torch.Tensor) Context clips each composed of self.clip_length contiguous frames.
|
||||
:param context_labels: (torch.Tensor) Video-level labels for each context clip.
|
||||
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
|
||||
:return: Nothing.
|
||||
"""
|
||||
context_clip_loader = get_clip_loader(context_clips, self.batch_size)
|
||||
task_embedding = self._get_task_embedding_in_batches(context_clip_loader, ops_counter)
|
||||
task_embedding = self._get_task_embedding_in_batches(context_clips, ops_counter)
|
||||
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
|
||||
context_features = self._get_features_in_batches(context_clip_loader, self.feature_adapter_params, ops_counter, context=True)
|
||||
context_features = self._get_features_in_batches(context_clips, self.feature_adapter_params, ops_counter, context=True)
|
||||
self.context_features = self._pool_features(context_features, ops_counter)
|
||||
self.classifier.configure(self.context_features, context_labels, ops_counter)
|
||||
|
||||
def personalise_with_lite(self, context_clips, context_labels):
|
||||
"""
|
||||
Function that learns a new task by performing a forward pass of the task's context set with LITE. Namely a random subset of the context set (self.num_lite_samples) is processed with back-propagation enabled, while the remainder is processed with back-propagation disabled.
|
||||
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
|
||||
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
|
||||
:param context_labels: (torch.Tensor) Video-level labels for each context clip.
|
||||
:return: Nothing.
|
||||
"""
|
||||
shuffled_idxs = np.random.permutation(len(context_clips))
|
||||
H = self.num_lite_samples
|
||||
context_clip_loader = get_clip_loader(context_clips[shuffled_idxs][:H], self.batch_size)
|
||||
task_embedding = self._get_task_embedding_with_lite(context_clip_loader, shuffled_idxs)
|
||||
task_embedding = self._get_task_embedding_with_lite(context_clips[shuffled_idxs][:H], shuffled_idxs)
|
||||
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding)
|
||||
self.context_features = self._get_pooled_features_with_lite(context_clip_loader, shuffled_idxs)
|
||||
self.context_features = self._get_pooled_features_with_lite(context_clips[shuffled_idxs][:H], shuffled_idxs)
|
||||
self.classifier.configure(self.context_features, context_labels[shuffled_idxs])
|
||||
|
||||
def _cache_context_outputs(self, context_clips):
|
||||
"""
|
||||
Function that performs a forward pass with a task's entire context set with back-propagation disabled and caches the individual 1) encodings from the set encoder and 2) adapted features from the adapted feature extractor, for each clip.
|
||||
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
|
||||
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
|
||||
:return: Nothing.
|
||||
"""
|
||||
context_clip_loader = get_clip_loader(context_clips, self.batch_size)
|
||||
with torch.set_grad_enabled(False):
|
||||
# cache encoding for each clip from self.set_encoder
|
||||
self.cached_set_encoder_reps = self._get_task_embedding_in_batches(context_clip_loader, reduction='none')
|
||||
self.cached_set_encoder_reps = self._get_task_embedding_in_batches(context_clips, reduction='none')
|
||||
|
||||
# get feature adapter parameters
|
||||
task_embedding = self.set_encoder.mean_pool(self.cached_set_encoder_reps)
|
||||
feature_adapter_params = self._get_feature_adapter_params(task_embedding)
|
||||
|
||||
# cache adapted features for each clip
|
||||
context_features = self._get_features_in_batches(context_clip_loader, feature_adapter_params, context=True)
|
||||
context_features = self._get_features_in_batches(context_clips, feature_adapter_params, context=True)
|
||||
self.cached_context_features = self._pool_features(context_features)
|
||||
|
||||
def _get_task_embedding_with_lite(self, context_clip_loader, idxs):
|
||||
def _get_task_embedding_with_lite(self, context_clips, idxs):
|
||||
"""
|
||||
Function that passes all of a task's context set through the set encoder to get a task embedding with LITE.
|
||||
:param context_clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for context clips, each composed of self.clip_length contiguous frames.
|
||||
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
|
||||
:param idxs: (torch.Tensor) Indices of elements in context_clips to process with back-propagation enabled.
|
||||
:return: (torch.Tensor or None) Task embedding.
|
||||
"""
|
||||
if isinstance(self.set_encoder, NullSetEncoder):
|
||||
return None
|
||||
H = self.num_lite_samples
|
||||
task_embedding_with_grads = self._get_task_embedding_in_batches(context_clip_loader, reduction='none')
|
||||
task_embedding_with_grads = self._get_task_embedding_in_batches(context_clips, reduction='none')
|
||||
task_embedding_without_grads = self.cached_set_encoder_reps[idxs][H:]
|
||||
return torch.cat((task_embedding_with_grads, task_embedding_without_grads)).mean(dim=0)
|
||||
|
||||
def _get_pooled_features_with_lite(self, context_clip_loader, idxs):
|
||||
def _get_pooled_features_with_lite(self, context_clips, idxs):
|
||||
"""
|
||||
Function that gets adapted clip features for a task's context set with LITE.
|
||||
:param context_clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for context clips, each composed of self.clip_length contiguous frames.
|
||||
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
|
||||
:param idxs: (torch.Tensor) Indices of elements in context_clips to process with back-propagation enabled.
|
||||
:return: (torch.Tensor) Adapted frame features pooled per clip i.e. as (num_clips) x (feat_dim).
|
||||
"""
|
||||
H = self.num_lite_samples
|
||||
context_features_with_grads = self._get_features_in_batches(context_clip_loader, self.feature_adapter_params, context=True)
|
||||
context_features_with_grads = self._get_features_in_batches(context_clips, self.feature_adapter_params, context=True)
|
||||
context_features_with_grads = self._pool_features(context_features_with_grads)
|
||||
context_features_without_grads = self.cached_context_features[idxs][H:]
|
||||
return torch.cat((context_features_with_grads, context_features_without_grads))
|
||||
|
@ -500,11 +484,10 @@ class SingleStepFewShotRecogniser(FewShotRecogniser):
|
|||
def predict(self, target_clips):
|
||||
"""
|
||||
Function that processes target clips in batches to get logits over object classes for each clip.
|
||||
:param target_clips: (np.ndarray or torch.Tensor) Target clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
|
||||
:param target_clips: (torch.Tensor) Target clips, each composed of self.clip_length contiguous frames.
|
||||
:return: (torch.Tensor) Logits over object classes for each clip in target_clips.
|
||||
"""
|
||||
target_clip_loader = get_clip_loader(target_clips, self.batch_size)
|
||||
target_features = self._get_features_in_batches(target_clip_loader, self.feature_adapter_params)
|
||||
target_features = self._get_features_in_batches(target_clips, self.feature_adapter_params)
|
||||
target_features = self._pool_features(target_features)
|
||||
return self.classifier.predict(target_features)
|
||||
|
||||
|
|
|
@ -36,11 +36,11 @@ import numpy as np
|
|||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from data.dataloaders import DataLoader
|
||||
from data.utils import unpack_task, attach_frame_history
|
||||
from models.few_shot_recognisers import MultiStepFewShotRecogniser
|
||||
from utils.args import parse_args
|
||||
from utils.ops_counter import OpsCounter
|
||||
from utils.optim import cross_entropy, init_optimizer
|
||||
from utils.data import unpack_task, attach_frame_history
|
||||
from utils.logging import print_and_log, get_log_files, stats_to_str
|
||||
from utils.eval_metrics import TrainEvaluator, ValidationEvaluator, TestEvaluator
|
||||
|
||||
|
@ -105,7 +105,6 @@ class Learner:
|
|||
'frame_size': self.args.frame_size,
|
||||
'annotations_to_load': self.args.annotations_to_load,
|
||||
'filter_by_annotations': [self.args.filter_context, self.args.filter_target],
|
||||
'preload_clips': self.args.preload_clips,
|
||||
'logfile': self.logfile
|
||||
}
|
||||
|
||||
|
@ -199,9 +198,9 @@ class Learner:
|
|||
|
||||
def train_task(self, task_dict):
|
||||
|
||||
context_clips, context_frames, context_labels, target_clips, target_frames, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True, preload_clips=self.args.preload_clips)
|
||||
context_clips, context_frames, context_labels, target_clips, target_frames, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True)
|
||||
|
||||
joint_context_clips = torch.cat((context_clips, target_clips)) if self.args.preload_clips else np.concatenate((context_clips, target_clips))
|
||||
joint_context_clips = torch.cat((context_clips, target_clips))
|
||||
joint_context_labels = torch.cat((context_labels, target_labels), dim=0)
|
||||
joint_context_logits = self.model.predict(joint_context_clips, context=True)
|
||||
self.train_evaluator.update_stats(joint_context_logits, joint_context_labels)
|
||||
|
@ -217,7 +216,7 @@ class Learner:
|
|||
# loop through validation tasks (num_validation_users * num_test_tasks_per_user)
|
||||
num_val_tasks = len(self.validation_queue) * self.args.num_test_tasks
|
||||
for step, task_dict in enumerate(self.validation_queue.get_tasks()):
|
||||
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False, preload_clips=self.args.preload_clips)
|
||||
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False)
|
||||
num_context_clips = len(context_clips)
|
||||
self.validation_evaluator.set_task_object_list(object_list)
|
||||
self.validation_evaluator.set_task_context_paths(context_paths)
|
||||
|
@ -279,7 +278,7 @@ class Learner:
|
|||
# loop through test tasks (num_test_users * num_test_tasks_per_user)
|
||||
num_test_tasks = len(self.test_queue) * self.args.num_test_tasks
|
||||
for step, task_dict in enumerate(self.test_queue.get_tasks()):
|
||||
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False, preload_clips=self.args.preload_clips)
|
||||
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False)
|
||||
num_context_clips = len(context_clips)
|
||||
self.test_evaluator.set_task_object_list(object_list)
|
||||
self.test_evaluator.set_task_context_paths(context_paths)
|
||||
|
|
|
@ -36,11 +36,11 @@ import numpy as np
|
|||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from data.dataloaders import DataLoader
|
||||
from data.utils import get_batch_indices, unpack_task, attach_frame_history
|
||||
from models.few_shot_recognisers import SingleStepFewShotRecogniser
|
||||
from utils.args import parse_args
|
||||
from utils.ops_counter import OpsCounter
|
||||
from utils.optim import cross_entropy, init_optimizer
|
||||
from utils.data import get_clip_loader, unpack_task, attach_frame_history
|
||||
from utils.logging import print_and_log, get_log_files, stats_to_str
|
||||
from utils.eval_metrics import TrainEvaluator, ValidationEvaluator, TestEvaluator
|
||||
|
||||
|
@ -107,7 +107,6 @@ class Learner:
|
|||
'frame_size': self.args.frame_size,
|
||||
'annotations_to_load': self.args.annotations_to_load,
|
||||
'filter_by_annotations': [self.args.filter_context, self.args.filter_target],
|
||||
'preload_clips': self.args.preload_clips,
|
||||
'logfile': self.logfile
|
||||
}
|
||||
|
||||
|
@ -190,7 +189,7 @@ class Learner:
|
|||
self.logfile.close()
|
||||
|
||||
def train_task(self, task_dict):
|
||||
context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True, preload_clips=self.args.preload_clips)
|
||||
context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True)
|
||||
|
||||
self.model.personalise(context_clips, context_labels)
|
||||
target_logits = self.model.predict(target_clips)
|
||||
|
@@ -206,32 +205,37 @@ class Learner:
         return task_loss

     def train_task_with_lite(self, task_dict):
-        context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device, preload_clips=self.args.preload_clips)
+        context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device)

         # compute and save personalise outputs of whole context set with back-propagation disabled
         self.model._cache_context_outputs(context_clips)

         task_loss = 0
-        target_logits = []
-        target_clip_loader = get_clip_loader((target_clips, target_labels), self.args.batch_size, with_labels=True)
-        for batch_target_clips, batch_target_labels in target_clip_loader:
+        target_logits, target_boxes_pred = [], []
+        num_clips = len(target_clips)
+        num_batches = int(np.ceil(float(num_clips) / float(self.args.batch_size)))
+        for batch in range(num_batches):
             self.model.personalise_with_lite(context_clips, context_labels)
-            batch_target_clips = batch_target_clips.to(device=self.device)
-            batch_target_labels = batch_target_labels.to(device=self.device)
-            batch_target_logits = self.model.predict_a_batch(batch_target_clips)
+            batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.args.batch_size)
+            batch_target_clips = target_clips[batch_start_index:batch_end_index].to(device=self.device)
+            batch_target_labels = target_labels[batch_start_index:batch_end_index].to(device=self.device)
+
+            batch_target_logits = self.model.predict_a_batch(batch_target_clips)
             target_logits.extend(batch_target_logits.detach())

             loss_scaling = len(context_labels) / (self.args.num_lite_samples * self.args.tasks_per_batch)
             batch_loss = loss_scaling * self.loss(batch_target_logits, batch_target_labels)
             batch_loss += 0.001 * self.model.feature_adapter.regularization_term(switch_device=self.args.use_two_gpus)
             batch_loss.backward(retain_graph=False)
             task_loss += batch_loss.detach()

             # reset task's params
             self.model._reset()

         target_logits = torch.stack(target_logits)
         self.train_evaluator.update_stats(target_logits, target_labels)

         return task_loss

     def validate(self):
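To see that the index-based loop covers the target set exactly as the removed clip loader did, here is a small self-contained sketch with dummy tensors; all names and sizes are illustrative only, and get_batch_indices is the assumed helper from the sketch above.

```python
# Illustration only: dummy tensors stand in for real target clips and labels.
import numpy as np
import torch

def get_batch_indices(index, length, batch_size):  # assumed helper, see sketch above
    return index * batch_size, min(length, (index + 1) * batch_size)

target_clips = torch.randn(10, 8, 3, 32, 32)   # 10 clips of 8 frames (sizes arbitrary for the demo)
target_labels = torch.randint(0, 5, (10,))
batch_size = 4

num_clips = len(target_clips)
num_batches = int(np.ceil(float(num_clips) / float(batch_size)))
seen = 0
for batch in range(num_batches):
    start, end = get_batch_indices(batch, num_clips, batch_size)
    batch_clips = target_clips[start:end]      # slice instead of pulling from a DataLoader
    batch_labels = target_labels[start:end]
    seen += len(batch_clips)

assert seen == num_clips  # every clip is visited exactly once, matching the old clip-loader behaviour
```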
@@ -241,7 +245,7 @@ class Learner:
             # loop through validation tasks (num_validation_users * num_test_tasks_per_user)
             num_val_tasks = len(self.validation_queue) * self.args.num_test_tasks
             for step, task_dict in enumerate(self.validation_queue.get_tasks()):
-                context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, preload_clips=self.args.preload_clips)
+                context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device)
                 num_context_clips = len(context_clips)
                 self.validation_evaluator.set_task_object_list(object_list)
                 self.validation_evaluator.set_task_context_paths(context_paths)
@@ -293,7 +297,7 @@ class Learner:
             # loop through test tasks (num_test_users * num_test_tasks_per_user)
             num_test_tasks = len(self.test_queue) * self.args.num_test_tasks
             for step, task_dict in enumerate(self.test_queue.get_tasks()):
-                context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, preload_clips=self.args.preload_clips)
+                context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device)
                 num_context_clips = len(context_clips)
                 self.test_evaluator.set_task_object_list(object_list)
                 self.test_evaluator.set_task_context_paths(context_paths)
@@ -80,8 +80,6 @@ def parse_args(learner='default'):
                        help="Method to sample clips per target video for a test/validation task (default: max).")
    parser.add_argument("--clip_length", type=int, default=1,
                        help="Number of frames to sample per clip (default: 1).")
-   parser.add_argument("--no_preload_clips", action="store_true",
-                       help="Do not preload clips per task from disk. Use if CPU memory is limited, but will mean slower training/testing.")
    parser.add_argument("--frame_size", type=int, default=224, choices=[84, 224],
                        help="Frame size (default: 224).")
    parser.add_argument("--annotations_to_load", nargs='+', type=str, default=[], choices=FRAME_ANNOTATION_OPTIONS+BOUNDING_BOX_OPTIONS,
@@ -130,7 +128,6 @@ def parse_args(learner='default'):
                        help="Learning rate for inner loop (MAML) or fine-tuning (FineTuner) (default: 0.1).")

    args = parser.parse_args()
-   args.preload_clips = not args.no_preload_clips
    verify_args(learner, args)
    return args

@@ -140,9 +137,6 @@ def verify_args(learner, args):
    cyellow = "\33[33m"
    cend = "\33[0m"

-   if len(args.annotations_to_load) and args.no_preload_clips:
-       sys.exit('{:}error: loading annotations with --annotations_to_load is currently not supported with --no_preload_clips{:}'.format(cred, cend))
-
    if 'train' in args.mode and not args.learn_extractor and not args.adapt_features:
        sys.exit('{:}error: at least one of "--learn_extractor" and "--adapt_features" must be used during training{:}'.format(cred, cend))

utils/data.py (133 lines removed)

@@ -1,133 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-
-import torch
-import numpy as np
-from PIL import Image
-import torch.nn as nn
-from torch.utils.data import Dataset, DataLoader
-import torchvision.transforms.functional as tv_F
-
-class DatasetFromClipPaths(Dataset):
-    def __init__(self, clip_paths, with_labels):
-        super().__init__()
-        #TODO currently doesn't support loading of annotations
-        self.with_labels = with_labels
-        if self.with_labels:
-            self.clip_paths, self.clip_labels = clip_paths
-        else:
-            self.clip_paths = clip_paths
-
-        self.normalize_stats = {'mean' : [0.500, 0.436, 0.396], 'std' : [0.145, 0.143, 0.138]} # orbit mean train frame
-
-    def __getitem__(self, index):
-        clip = []
-        for frame_path in self.clip_paths[index]:
-            frame = self.load_and_transform_frame(frame_path)
-            clip.append(frame)
-
-        if self.with_labels:
-            return torch.stack(clip, dim=0), self.clip_labels[index]
-        else:
-            return torch.stack(clip, dim=0)
-
-    def load_and_transform_frame(self, frame_path):
-        """
-        Function to load and transform frame.
-        :param frame_path: (str) Path to frame.
-        :return: (torch.Tensor) Loaded and transformed frame.
-        """
-        frame = Image.open(frame_path)
-        frame = tv_F.to_tensor(frame)
-        return tv_F.normalize(frame, mean=self.normalize_stats['mean'], std=self.normalize_stats['std'])
-
-    def __len__(self):
-        return len(self.clip_paths)
-
-def get_clip_loader(clips, batch_size, with_labels=False):
-    if isinstance(clips[0], np.ndarray):
-        clips_dataset = DatasetFromClipPaths(clips, with_labels=with_labels)
-        return DataLoader(clips_dataset,
-                          batch_size=batch_size,
-                          num_workers=8,
-                          pin_memory=True,
-                          prefetch_factor=8,
-                          persistent_workers=True)
-
-    elif isinstance(clips[0], torch.Tensor):
-        if with_labels:
-            return list(zip(clips[0].split(batch_size), clips[1].split(batch_size)))
-        else:
-            return clips.split(batch_size)
-
-def attach_frame_history(frames, history_length):
-
-    if isinstance(frames, np.ndarray):
-        return attach_frame_history_paths(frames, history_length)
-    elif isinstance(frames, torch.Tensor):
-        return attach_frame_history_tensor(frames, history_length)
-
-def attach_frame_history_paths(frame_paths, history_length):
-    """
-    Function to attach the immediate history of history_length frames to each frame in an array of frame paths.
-    :param frame_paths: (np.ndarray) Frame paths.
-    :param history_length: (int) Number of frames of history to append to each frame.
-    :return: (np.ndarray) Frame paths with attached frame history.
-    """
-    # pad with first frame so that frames 0 to history_length-1 can be evaluated
-    frame_paths = np.concatenate([np.repeat(frame_paths[0], history_length-1), frame_paths])
-
-    # for each frame path, attach its immediate history of history_length frames
-    frame_paths = [ frame_paths ]
-    for l in range(1, history_length):
-        frame_paths.append( np.roll(frame_paths[0], shift=-l, axis=0) )
-    frame_paths_with_history = np.stack(frame_paths, axis=1) # of size num_clips x history_length
-
-    if history_length > 1:
-        return frame_paths_with_history[:-(history_length-1)] # frames have wrapped around, remove last (history_length - 1) frames
-    else:
-        return frame_paths_with_history
-
-def attach_frame_history_tensor(frames, history_length):
-    """
-    Function to attach the immediate history of history_length frames to each frame in a tensor of frame data.
-    :param frames: (torch.Tensor) Frames.
-    :param history_length: (int) Number of frames of history to append to each frame.
-    :return: (torch.Tensor) Frames with attached frame history.
-    """
-    # pad with first frame so that frames 0 to history_length-1 can be evaluated
-    frame_0 = frames.narrow(0, 0, 1)
-    frames = torch.cat((frame_0.repeat(history_length-1, 1, 1, 1), frames), dim=0)
-
-    # for each frame, attach its immediate history of history_length frames
-    frames = [ frames ]
-    for l in range(1, history_length):
-        frames.append( frames[0].roll(shifts=-l, dims=0) )
-    frames_with_history = torch.stack(frames, dim=1) # of size num_clips x history_length
-
-    if history_length > 1:
-        return frames_with_history[:-(history_length-1)] # frames have wrapped around, remove last (history_length - 1) frames
-    else:
-        return frames_with_history
-
-def unpack_task(task_dict, device, context_to_device=True, target_to_device=False, preload_clips=False):
-
-    context_clips = task_dict['context_clips']
-    context_paths = task_dict['context_paths']
-    context_labels = task_dict['context_labels']
-    context_annotations = task_dict['context_annotations']
-    target_clips = task_dict['target_clips']
-    target_paths = task_dict['target_paths']
-    target_labels = task_dict['target_labels']
-    target_annotations = task_dict['target_annotations']
-    object_list = task_dict['object_list']
-
-    if context_to_device and isinstance(context_labels, torch.Tensor):
-        context_labels = context_labels.to(device)
-    if target_to_device and isinstance(target_labels, torch.Tensor):
-        target_labels = target_labels.to(device)
-
-    if preload_clips:
-        return context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list
-    else:
-        return context_paths, context_paths, context_labels, target_paths, target_paths, target_labels, object_list
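attach_frame_history keeps the same behaviour after the move to data.utils, so a small worked example may help: with history_length=3 each frame is stacked with the two frames before it, and the first frame is repeated to pad the start. The sketch reproduces the tensor variant above on five dummy frames; the tiny frame shape is chosen purely to keep the tensors readable.

```python
# Illustration of the frame-history logic kept from utils/data.py (now in data.utils).
import torch

def attach_frame_history_tensor(frames, history_length):
    # pad with first frame so that frames 0 to history_length-1 can be evaluated
    frame_0 = frames.narrow(0, 0, 1)
    frames = torch.cat((frame_0.repeat(history_length-1, 1, 1, 1), frames), dim=0)
    frames = [frames]
    for l in range(1, history_length):
        frames.append(frames[0].roll(shifts=-l, dims=0))
    frames_with_history = torch.stack(frames, dim=1)  # num_frames x history_length x ...
    return frames_with_history[:-(history_length-1)] if history_length > 1 else frames_with_history

frames = torch.arange(5.0).reshape(5, 1, 1, 1)  # 5 fake single-pixel frames with values 0..4
out = attach_frame_history_tensor(frames, history_length=3)
print(out.shape)           # torch.Size([5, 3, 1, 1, 1]) -- one 3-frame history per original frame
print(out[:, :, 0, 0, 0])  # rows: [0,0,0], [0,0,1], [0,1,2], [1,2,3], [2,3,4]
```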
@@ -7,7 +7,8 @@ from datetime import datetime

 def print_and_log(log_file, message):
     print(message)
-    log_file.write(message + '\n')
+    if log_file:
+        log_file.write(message + '\n')

 def get_log_files(checkpoint_dir, model_path):
     """
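With the new guard, print_and_log accepts log_file=None and simply prints, which is handy when no log file has been opened. A minimal usage sketch follows; the run.log path is purely illustrative.

```python
from utils.logging import print_and_log

# no log file attached: the message only goes to stdout
print_and_log(None, "stdout only")

# with an open file handle, the message is printed and appended to the log
with open("run.log", "a") as log_file:
    print_and_log(log_file, "printed and logged")
```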