Removed the option to not preload clips; replaced clip_loader with standard batch indexing.

Daniela Massiceti 2022-11-25 05:10:28 +00:00
Parent aa3eaabb37
Commit e543ecce9c
11 changed files with 209 additions and 334 deletions
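The core of the change: the old get_clip_loader (a torch DataLoader over clip paths or tensors) is dropped, and callers now slice preloaded clip tensors directly using get_batch_indices (added in data/utils.py below). A minimal sketch of the new pattern, assuming clips is a preloaded tensor of shape (num_clips, clip_length, C, H, W) and batch_size is already set; this is illustrative, not the exact code:

import numpy as np

num_clips = len(clips)
num_batches = int(np.ceil(float(num_clips) / float(batch_size)))
for batch in range(num_batches):
    # get_batch_indices clamps the end index to num_clips for the final, possibly smaller, batch
    start, end = get_batch_indices(batch, num_clips, batch_size)
    batch_clips = clips[start:end]   # plain tensor slicing, no DataLoader/clip_loader
    # ... forward pass on batch_clips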

View file

@ -24,7 +24,6 @@ class DataLoader():
dataset_info['subsample_factor'],
dataset_info['train_clip_methods'],
dataset_info['clip_length'],
dataset_info['preload_clips'],
dataset_info['frame_size'],
dataset_info['annotations_to_load'],
dataset_info['filter_by_annotations'],
@ -43,7 +42,6 @@ class DataLoader():
dataset_info['subsample_factor'],
dataset_info['test_clip_methods'],
dataset_info['clip_length'],
dataset_info['preload_clips'],
dataset_info['frame_size'],
dataset_info['annotations_to_load'],
dataset_info['filter_by_annotations'],
@ -61,7 +59,6 @@ class DataLoader():
dataset_info['subsample_factor'],
dataset_info['test_clip_methods'],
dataset_info['clip_length'],
dataset_info['preload_clips'],
dataset_info['frame_size'],
dataset_info['annotations_to_load'],
dataset_info['filter_by_annotations'],
@ -79,15 +76,15 @@ class DataLoader():
return self.test_queue
def config_user_centric_queue(self, root, way_method, object_cap, shot_method, shots, video_types, \
subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
num_tasks, test_mode=False, with_cluster_labels=False, with_caps=False, shuffle=False, logfile=None):
return UserEpisodicDatasetQueue(root, way_method, object_cap, shot_method, shots, video_types, \
subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile)
def config_object_centric_queue(self, root, way_method, object_cap, shot_method, shots, video_types, \
subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
num_tasks, test_mode=False, with_cluster_labels=False, with_caps=False, shuffle=False, logfile=None):
return ObjectEpisodicDatasetQueue(root, way_method, object_cap, shot_method, shots, video_types, \
subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, \
subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, \
num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile)

View file

@ -19,7 +19,7 @@ class ORBITDataset(Dataset):
"""
Base class for ORBIT dataset.
"""
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile=None):
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile=None):
"""
Creates instance of ORBITDataset.
:param root: (str) Path to train/validation/test folder in ORBIT dataset root folder.
@ -28,11 +28,10 @@ class ORBITDataset(Dataset):
:param shot_methods: (str, str) Method for sampling videos for context and target sets.
:param shots: (int, int) Number of videos to sample for context and target sets.
:param video_types: (str, str) Video types to sample for context and target sets.
:param subsample_factor: (int) Factor to subsample video frames before sampling clips.
:param subsample_factor: (int) Factor to subsample video frames if sampling frames uniformly.
:param clip_methods: (str, str) Method for sampling clips of contiguous frames from videos for context and target sets.
:param clip_length: (int) Number of contiguous frames per video clip.
:param preload_clips: (bool) If True, preload clips from disk and return as tensors, otherwise return clip paths.
:param frame_size: (int) Size in pixels of preloaded frames.
:param frame_size: (int) Size in pixels of loaded frames.
:param annotations_to_load: (list::str) Types of frame annotations to load from disk and return per task.
:param filter_by_annotations: (list::str) Types of frame annotations to filter by for context and target sets.
:param test_mode: (bool) If True, returns task with target set grouped by video, otherwise returns task with target set as clips.
@ -50,7 +49,6 @@ class ORBITDataset(Dataset):
self.subsample_factor = subsample_factor
self.context_clip_method, self.target_clip_method = clip_methods
self.clip_length = clip_length
self.preload_clips = preload_clips
self.frame_size = frame_size
self.test_mode = test_mode
self.with_cluster_labels = with_cluster_labels
@ -161,7 +159,7 @@ class ORBITDataset(Dataset):
video_path = os.path.join(obj_path, video_types[set_type], video_name)
frames = glob.glob(os.path.join(video_path, "*.jpg"))
if self.with_annotations or (self.filter_params[set_type]['criteria']):
if self.with_annotations or self.filter_params[set_type]['criteria']:
video_annotations = self.__load_video_annotations(video_name)
self.frame2anns.update(video_annotations)
if self.filter_params[set_type]['criteria']:
@ -345,9 +343,8 @@ class ORBITDataset(Dataset):
sampled_paths = frame_paths[sampled_idxs].reshape(-1, self.clip_length)
paths.extend(sampled_paths)
if self.preload_clips:
sampled_clips = self.load_clips(sampled_paths)
clips += sampled_clips
sampled_clips = self.load_clips(sampled_paths)
clips += sampled_clips
if self.with_annotations:
sampled_annotations = self.load_annotations(sampled_paths)
@ -386,21 +383,24 @@ class ORBITDataset(Dataset):
return loaded_clips
def load_annotations(self, paths: np.ndarray) -> torch.Tensor:
def load_annotations(self, paths: np.ndarray, without_clip_history=True) -> torch.Tensor:
"""
Function to load frame annotations, arrange in clips, from disk.
Function to load frame annotations, arrange in clips.
:param paths: (np.ndarray::str) Frame paths organised in clips of self.clip_length contiguous frames.
:param without_clip_history: (bool) If True, only load annotations for last frame in every clip.
:return: (torch.Tensor) Frame annotations arranged in clips.
"""
num_clips, clip_length = paths.shape
frames_per_clip = 1 if without_clip_history else clip_length
assert clip_length == self.clip_length
loaded_annotations = {
annotation : torch.empty(num_clips, clip_length, self.annotation_dims.get(annotation, 1)) # The dimension defaults to 1 unless specified.
annotation : torch.empty(num_clips, frames_per_clip, self.annotation_dims.get(annotation, 1)) # The dimension defaults to 1 unless specified.
for annotation in self.annotations_to_load }
for clip_idx in range(num_clips):
for frame_idx in range(clip_length):
frames_to_load = [clip_length-1] if without_clip_history else range(clip_length)
for frame_idx in frames_to_load:
frame_path = paths[clip_idx, frame_idx]
frame_name = os.path.basename(frame_path)
for annotation in self.annotations_to_load:
@ -411,11 +411,11 @@ class ORBITDataset(Dataset):
loaded_annotations[annotation][clip_idx, frame_idx] = float('nan')
return loaded_annotations
def load_and_transform_frame(self, frame_path):
"""
Function to load and transform frame.
:param frame_path: (str) str to frame.
:param frame_path: (str) Path to frame.
:return: (torch.Tensor) Loaded and transformed frame.
"""
frame = Image.open(frame_path)
@ -472,7 +472,7 @@ class ORBITDataset(Dataset):
:param test_mode: (bool) If False, do not shuffle task, otherwise shuffle.
:return: (torch.Tensor or list::torch.Tensor, np.ndarray::str or list::np.ndarray, torch.Tensor or list::torch.Tensor, dict::torch.Tensor or list::dict::torch.Tensor) Frame data, paths, video-level labels and annotations organised in clips (if train) or grouped and flattened by video (if test/validation).
"""
clips = torch.stack(clips) if self.preload_clips else torch.tensor(clips)
clips = torch.stack(clips)
paths = np.array(paths)
labels = torch.tensor(labels)
annotations = { ann: torch.stack(annotations[ann]) for ann in self.annotations_to_load }
@ -484,7 +484,7 @@ class ORBITDataset(Dataset):
# get all clips belonging to current video
idxs = video_ids == video_id
# flatten frames and paths from current video (assumed to be sorted)
video_frames = clips[idxs].flatten(end_dim=1) if self.preload_clips else None
video_frames = clips[idxs].flatten(end_dim=1)
video_paths = paths[idxs].reshape(-1)
frames_by_video.append(video_frames)
paths_by_video.append(video_paths)
@ -492,7 +492,7 @@ class ORBITDataset(Dataset):
video_label = labels[idxs][0]
labels_by_video.append(video_label)
# get all frame annotations for current video
video_anns = { ann : annotations[ann][idxs].flatten(end_dim=1) for ann in self.annotations_to_load } if self.with_annotations else None
video_anns = { ann : annotations[ann][idxs].flatten(end_dim=1) for ann in self.annotations_to_load } if self.with_annotations else None
annotations_by_video.append(video_anns)
return frames_by_video, paths_by_video, labels_by_video, annotations_by_video
else:
@ -509,16 +509,10 @@ class ORBITDataset(Dataset):
"""
idxs = np.arange(len(paths))
random.shuffle(idxs)
if self.preload_clips:
if self.with_annotations:
return clips[idxs], paths[idxs], labels[idxs], { ann : annotations[ann][idxs] for ann in self.annotations_to_load }
else:
return clips[idxs], paths[idxs], labels[idxs], annotations
if self.with_annotations:
return clips[idxs], paths[idxs], labels[idxs], { ann : annotations[ann][idxs] for ann in self.annotations_to_load }
else:
if self.with_annotations:
return clips, paths[idxs], labels[idxs], { ann : annotations[ann][idxs] for ann in self.annotations_to_load }
else:
return clips, paths[idxs], labels[idxs], annotations
return clips[idxs], paths[idxs], labels[idxs], annotations
def get_label_map(self, objects, with_cluster_labels=False):
"""
@ -536,7 +530,7 @@ class ORBITDataset(Dataset):
map_dict[old_label] = new_labels[i]
return map_dict
def sample_task(self, task_objects: List[int], with_target_set: bool, task_id: str) -> Dict:
def sample_task(self, task_objects: List[int], task_id: str) -> Dict:
# select way (number of classes/objects) randomly
num_objects = len(task_objects)
@ -570,17 +564,15 @@ class ORBITDataset(Dataset):
context_video_ids.extend(cvi)
context_annotations = self.extend_ann_dict(context_annotations, ca)
if with_target_set:
tc, tp, tvi, ta = self.sample_clips_from_videos(target_videos, self.target_clip_method)
target_clips.extend(tc)
target_paths.extend(tp)
target_labels.extend([label for _ in range(len(tp))])
target_video_ids.extend(tvi)
target_annotations = self.extend_ann_dict(target_annotations, ta)
tc, tp, tvi, ta = self.sample_clips_from_videos(target_videos, self.target_clip_method)
target_clips.extend(tc)
target_paths.extend(tp)
target_labels.extend([label for _ in range(len(tp))])
target_video_ids.extend(tvi)
target_annotations = self.extend_ann_dict(target_annotations, ta)
context_clips, context_paths, context_labels, context_annotations = self.prepare_set(context_clips, context_paths, context_labels, context_annotations, context_video_ids)
if with_target_set:
target_clips, target_paths, target_labels, target_annotations = self.prepare_set(target_clips, target_paths, target_labels, target_annotations, target_video_ids, test_mode=self.test_mode)
target_clips, target_paths, target_labels, target_annotations = self.prepare_set(target_clips, target_paths, target_labels, target_annotations, target_video_ids, test_mode=self.test_mode)
task_dict = {
# Data required for train / test
@ -602,7 +594,7 @@ class UserEpisodicORBITDataset(ORBITDataset):
"""
Class for user-centric episodic sampling of ORBIT dataset.
"""
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
"""
Creates instance of UserEpisodicORBITDataset.
:param root: (str) Path to train/validation/test folder in ORBIT dataset root folder.
@ -611,11 +603,10 @@ class UserEpisodicORBITDataset(ORBITDataset):
:param shot_methods: (str, str) Method for sampling videos for context and target sets.
:param shots: (int, int) Number of videos to sample for context and target sets.
:param video_types: (str, str) Video types to sample for context and target sets.
:param subsample_factor: (int) Factor to subsample video frames before sampling clips.
:param subsample_factor: (int) Factor to subsample video frames if sampling frames uniformly.
:param clip_methods: (str, str) Method for sampling clips of contiguous frames from videos for context and target sets.
:param clip_length: (int) Number of contiguous frames per video clip.
:param preload_clips: (bool) If True, preload clips from disk and return as tensors, otherwise return clip paths.
:param frame_size: (int) Size in pixels of preloaded frames.
:param frame_size: (int) Size in pixels of loaded frames.
:param annotations_to_load: (list::str) Types of frame annotations to load from disk and return per task.
:param filter_by_annotations: (list::str) Types of frame annotations to filter by for context and target sets.
:param test_mode: (bool) If True, returns task with target set grouped by video, otherwise returns task with target set as clips.
@ -624,25 +615,24 @@ class UserEpisodicORBITDataset(ORBITDataset):
:param logfile: (file object) File for printing out loaded data summaries.
:return: Nothing.
"""
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
def __getitem__(self, index):
"""
Function to get a user-centric task as a set of (context and target) clips and labels.
:param index: (tuple) Task ID and whether to load task target set.
:param index: (int) Task index.
:return: (dict) Context and target set data for task.
"""
task_id, with_target_set = index
user = self.users[task_id] # get user (each task == user id)
user = self.users[index] # get user (each task == user id)
user_objects = self.user2objs[user] # get user's objects
return self.sample_task(user_objects, with_target_set, user)
return self.sample_task(user_objects, user)
class ObjectEpisodicORBITDataset(ORBITDataset):
"""
Class for object-centric episodic sampling of ORBIT dataset.
"""
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
def __init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile):
"""
Creates instance of ObjectEpisodicORBITDataset.
:param root: (str) Path to train/validation/test folder in ORBIT dataset root folder.
@ -651,11 +641,10 @@ class ObjectEpisodicORBITDataset(ORBITDataset):
:param shot_methods: (str, str) Method for sampling videos for context and target sets.
:param shots: (int, int) Number of videos to sample for context and target sets.
:param video_types: (str, str) Video types to sample for context and target sets.
:param subsample_factor: (int) Factor to subsample video frames before sampling clips.
:param subsample_factor: (int) Factor to subsample video frames if sampling frames uniformly.
:param clip_methods: (str, str) Method for sampling clips of contiguous frames from videos for context and target sets.
:param clip_length: (int) Number of contiguous frames per video clip.
:param preload_clips: (bool) If True, preload clips from disk and return as tensors, otherwise return clip paths.
:param frame_size: (int) Size in pixels of preloaded frames.
:param frame_size: (int) Size in pixels of loaded frames.
:param annotations_to_load: (list::str) Types of frame annotations to load from disk and return per task.
:param filter_by_annotations: (list::str) Types of frame annotations to filter by for context and target sets.
:param test_mode: (bool) If True, returns task with target set grouped by video, otherwise returns task with target set as clips.
@ -664,14 +653,13 @@ class ObjectEpisodicORBITDataset(ORBITDataset):
:param logfile: (file object) File for printing out loaded data summaries.
:return: Nothing.
"""
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
ORBITDataset.__init__(self, root, way_method, object_cap, shot_methods, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
def __getitem__(self, index):
"""
Function to get an object-centric task as a set of (context and target) clips and labels.
:param index: (tuple) Task ID and whether to load task target set.
:param index: (int) Task index.
:return: (dict) Context and target set data for task.
"""
_, with_target_set = index
all_objects = range(0, len(self.obj2vids)) # task can consider all possible objects, not just 1 user's objects
return self.sample_task(all_objects, with_target_set)
return self.sample_task(all_objects)

View file

@ -2,7 +2,6 @@
# Licensed under the MIT license.
import torch
from typing import Optional
from data.samplers import TaskSampler
from data.datasets import UserEpisodicORBITDataset, ObjectEpisodicORBITDataset
@ -11,22 +10,17 @@ class DatasetQueue:
Class for a queue of tasks sampled from UserEpisodicORBITDataset/ObjectEpisodicORBITDataset.
"""
def __init__(self, num_tasks, shuffle, test_mode, override_num_workers: Optional[int]=None):
def __init__(self, num_tasks: int, shuffle: bool, num_workers: int) -> None:
"""
Creates instance of DatasetQueue.
:param num_tasks: (int) Number of tasks per user to add to the queue.
:param shuffle: (bool) If True, shuffle tasks, else do not shuffle.
:param test_mode: (bool) If True, only return target set for first task per user.
:param num_workers: (Optional[int]) Number of workers to use. Overrides defaults (4 if test, 8 otherwise).
:param num_workers: (int) Number of workers to use.
:return: Nothing.
"""
self.num_tasks = num_tasks
self.shuffle = shuffle
self.test_mode = test_mode
if override_num_workers is None:
self.num_workers = 4 if self.test_mode else 8
else:
self.num_workers = override_num_workers
self.num_workers = num_workers
self.num_users = None
self.collate_fn = self.unpack
@ -50,23 +44,23 @@ class DatasetQueue:
dataset=self.dataset,
pin_memory=False,
num_workers=self.num_workers,
sampler=TaskSampler(self.num_tasks, self.num_users, self.shuffle, self.test_mode),
sampler=TaskSampler(self.num_tasks, self.num_users, self.shuffle),
collate_fn=self.collate_fn
)
class UserEpisodicDatasetQueue(DatasetQueue):
def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
DatasetQueue.__init__(self, num_tasks, shuffle, test_mode)
self.dataset = UserEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
DatasetQueue.__init__(self, num_tasks, shuffle, num_workers=4 if test_mode else 8)
self.dataset = UserEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
self.num_users = self.dataset.num_users
def __len__(self):
return self.dataset.num_users
class ObjectEpisodicDatasetQueue(DatasetQueue):
def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
DatasetQueue.__init__(self, num_tasks, shuffle, test_mode)
self.dataset = ObjectEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, preload_clips, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
def __init__(self, root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, num_tasks, test_mode, with_cluster_labels, with_caps, shuffle, logfile):
DatasetQueue.__init__(self, num_tasks, shuffle, num_workers=4 if test_mode else 8)
self.dataset = ObjectEpisodicORBITDataset(root, way_method, object_cap, shot_method, shots, video_types, subsample_factor, clip_methods, clip_length, frame_size, annotations_to_load, filter_by_annotations, test_mode, with_cluster_labels, with_caps, logfile)
self.num_users = self.dataset.num_users
def __len__(self):

View file

@ -9,28 +9,22 @@ class TaskSampler(Sampler):
"""
Sampler class for a fixed number of tasks per user.
"""
def __init__(self, tasks_per_user, num_users, shuffle, test_mode):
def __init__(self, tasks_per_user, num_users, shuffle):
"""
Creates instances of TaskSampler.
:param tasks_per_user: (int) Number of tasks to sample per user.
:param num_users: (int) Total number of users.
:param shuffle: (bool) If True, shuffle tasks, otherwise do not shuffle.
:param test_mode: (bool) If True, only load target set for first task per user.
:return: Nothing.
"""
self.tasks_per_user = tasks_per_user
self.num_users = num_users
self.shuffle = shuffle
self.test_mode = test_mode
def __iter__(self):
task_ids = []
for user in range(self.num_users):
for task in range(self.tasks_per_user):
with_target_set = True
if self.test_mode and task > 0:
with_target_set = False
task_ids.append((user, with_target_set))
task_ids.extend([user]*self.tasks_per_user)
if self.shuffle:
random.shuffle(task_ids)
return iter(task_ids)
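For illustration, the simplified sampler now yields plain integer user indices rather than (user, with_target_set) tuples; a small hedged example based on the code above (values are illustrative):

sampler = TaskSampler(tasks_per_user=2, num_users=3, shuffle=False)
list(iter(sampler))   # [0, 0, 1, 1, 2, 2]; each index is passed to the dataset's __getitem__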

data/utils.py (new file, 54 additions)
View file

@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import numpy as np
import torch.nn as nn
def attach_frame_history(frames, history_length):
"""
Function to attach the immediate history of history_length frames to each frame in a tensor of frame data.
:param frames: (torch.Tensor) Frames.
:param history_length: (int) Number of frames of history to append to each frame.
:return: (torch.Tensor) Frames with attached frame history.
"""
# pad with first frame so that frames 0 to history_length-1 can be evaluated
frame_0 = frames.narrow(0, 0, 1)
frames = torch.cat((frame_0.repeat(history_length-1, 1, 1, 1), frames), dim=0)
# for each frame, attach its immediate history of history_length frames
frames = [ frames ]
for l in range(1, history_length):
frames.append( frames[0].roll(shifts=-l, dims=0) )
frames_with_history = torch.stack(frames, dim=1) # shape: (num_frames + history_length - 1) x history_length x C x H x W before trimming below
if history_length > 1:
return frames_with_history[:-(history_length-1)] # frames have wrapped around, remove last (history_length - 1) frames
else:
return frames_with_history
def unpack_task(task_dict, device, context_to_device=True, target_to_device=False):
context_clips = task_dict['context_clips']
context_paths = task_dict['context_paths']
context_labels = task_dict['context_labels']
context_annotations = task_dict['context_annotations']
target_clips = task_dict['target_clips']
target_paths = task_dict['target_paths']
target_labels = task_dict['target_labels']
target_annotations = task_dict['target_annotations']
object_list = task_dict['object_list']
if context_to_device and isinstance(context_labels, torch.Tensor):
context_labels = context_labels.to(device)
if target_to_device and isinstance(target_labels, torch.Tensor):
target_labels = target_labels.to(device)
return context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list
def get_batch_indices(index, last_element, batch_size):
batch_start_index = index * batch_size
batch_end_index = batch_start_index + batch_size
if batch_end_index > last_element:
batch_end_index = last_element
return batch_start_index, batch_end_index
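A brief usage sketch for the helpers added in this file, assuming frames is a 4-D tensor of frame data as produced by the dataset; the shapes follow from the code above:

frames = torch.randn(10, 3, 224, 224)                        # 10 frames of size 3 x 224 x 224
with_history = attach_frame_history(frames, history_length=4)
with_history.shape                                           # torch.Size([10, 4, 3, 224, 224]): each frame with its 3 preceding frames (padded with frame 0 at the start)

get_batch_indices(0, 10, batch_size=4)                       # (0, 4)
get_batch_indices(2, 10, batch_size=4)                       # (8, 10): end index clamped to last_element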

View file

@ -28,11 +28,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import time
import torch
import numpy as np
import torch.nn as nn
from data.utils import get_batch_indices
from features import extractors
from feature_adapters import FilmAdapter, NullAdapter
from models.poolers import MeanPooler
@ -40,7 +40,6 @@ from models.normalisation_layers import TaskNorm
from models.set_encoder import SetEncoder, NullSetEncoder
from models.classifiers import LinearClassifier, VersaClassifier, PrototypicalClassifier, MahalanobisClassifier
from utils.optim import init_optimizer
from utils.data import get_clip_loader
class FewShotRecogniser(nn.Module):
"""
@ -140,7 +139,6 @@ class FewShotRecogniser(nn.Module):
:return: (torch.Tensor) Adapted frame features flattened across all clips.
"""
self._set_model_state(context)
t1 = time.time()
if self.use_two_gpus:
clips = clips.cuda(1)
features = self.feature_extractor(clips, feature_adapter_params).cuda(0)
@ -148,16 +146,14 @@ class FewShotRecogniser(nn.Module):
features = self.feature_extractor(clips, feature_adapter_params)
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
ops_counter.compute_macs(self.feature_extractor, clips, feature_adapter_params)
return features
def _get_features_in_batches(self, clip_loader, feature_adapter_params, ops_counter=None, context=False):
def _get_features_in_batches(self, clips, feature_adapter_params, ops_counter=None, context=False):
"""
Function that passes clips in batches through an adapted feature extractor to get adapted (and flattened) frame features.
:param clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for clips, each composed of self.clip_length contiguous frames.
:param clips: (torch.Tensor) Clips, each composed of self.clip_length contiguous frames.
:param feature_adapter_params: (list::dict::torch.Tensor or list::dict::list::torch.Tensor) Parameters of all FiLM layers.
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:param context: (bool) True if clips are from context videos, otherwise False.
@ -165,20 +161,32 @@ class FewShotRecogniser(nn.Module):
"""
features = []
self._set_model_state(context)
for batch_clips in clip_loader:
if self.adapt_features:
extract_func = lambda clips: self.feature_extractor(clips, feature_adapter_params)
else:
extract_func = lambda clips: self.feature_extractor(clips)
num_clips = len(clips)
num_batches = int(np.ceil(float(num_clips) / float(self.batch_size)))
for batch in range(num_batches):
batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.batch_size)
batch_clips = clips[batch_start_index:batch_end_index]
if len(batch_clips.shape) == 5:
batch_clips = batch_clips.flatten(end_dim=1)
batch_clips = batch_clips.to(self.device, non_blocking=True)
t1 = time.time()
if self.use_two_gpus:
batch_clips = batch_clips.cuda(1)
batch_features = self.feature_extractor(batch_clips, feature_adapter_params).cuda(0)
batch_features = extract_func(batch_clips).cuda(0)
else:
batch_features = self.feature_extractor(batch_clips, feature_adapter_params)
batch_features = extract_func(batch_clips)
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
# TODO add MACs to spatial transformer
ops_counter.compute_macs(self.feature_extractor, batch_clips, feature_adapter_params)
if self.adapt_features:
ops_counter.compute_macs(self.feature_extractor, batch_clips, feature_adapter_params)
else:
ops_counter.compute_macs(self.feature_extractor, batch_clips)
features.append(batch_features)
@ -191,12 +199,9 @@ class FewShotRecogniser(nn.Module):
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:return: (list::dict::torch.Tensor or list::dict::list::torch.Tensor or None) Parameters of all FiLM layers.
"""
t1 = time.time()
feature_adapter_params = self.feature_adapter(task_embedding)
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
ops_counter.compute_macs(self.feature_adapter, task_embedding)
return feature_adapter_params
@ -209,20 +214,17 @@ class FewShotRecogniser(nn.Module):
:param reduction: (str) Method to aggregate clip encodings from self.set_encoder.
:return: (torch.Tensor or None) Task embedding.
"""
t1 = time.time()
reps = self.set_encoder(context_clips)
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
ops_counter.compute_macs(self.set_encoder, context_clips)
return self.set_encoder.aggregate(reps, reduction=reduction, switch_device=self.use_two_gpus)
def _get_task_embedding_in_batches(self, context_clip_loader, ops_counter=None, reduction='mean'):
def _get_task_embedding_in_batches(self, context_clips, ops_counter=None, reduction='mean'):
"""
Function that passes all of a task's context set through the set encoder to get a task embedding.
:param context_clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for context clips, each composed of self.clip_length contiguous frames.
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:param reduction: (str) Method to aggregate clip encodings from self.set_encoder.
:return: (torch.Tensor or None) Task embedding.
@ -231,14 +233,16 @@ class FewShotRecogniser(nn.Module):
return None
reps = []
for batch_clips in context_clip_loader:
num_clips = len(context_clips)
num_batches = int(np.ceil(float(num_clips) / float(self.batch_size)))
for batch in range(num_batches):
batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.batch_size)
batch_clips = context_clips[batch_start_index:batch_end_index]
batch_clips = batch_clips.to(self.device, non_blocking=True)
t1 = time.time()
batch_reps = self.set_encoder(batch_clips)
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
ops_counter.compute_macs(self.set_encoder, batch_clips)
reps.append(batch_reps)
@ -252,11 +256,8 @@ class FewShotRecogniser(nn.Module):
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:return: (torch.Tensor) Frame features pooled per clip i.e. as (num_clips) x (feat_dim).
"""
t1 = time.time()
pooled_features = self.frame_pooler(features)
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
ops_counter.add_macs(features.size(0) * features.size(1))
return pooled_features
@ -323,68 +324,54 @@ class MultiStepFewShotRecogniser(FewShotRecogniser):
self.num_grad_steps = num_grad_steps
def personalise(self, context_clips, context_clip_labels, learning_args, ops_counter=None):
def personalise(self, context_clips, context_labels, learning_args, ops_counter=None):
"""
Function that learns a new task by taking a fixed number of gradient steps on the task's context set. For each task, a new linear classification layer is added (and FiLM layers if self.adapt_features == True).
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
:param context_clip_labels: (torch.Tensor) Video-level labels for each context clip.
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
:param context_labels: (torch.Tensor) Video-level labels for each context clip.
:param learning_args: (float, func, str, float) Learning hyper-parameters including learning rate, loss function, optimiser type and factor to scale the extractor's learning rate.
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:return: Nothing.
"""
lr, loss_fn, optimizer_type, extractor_scale_factor = learning_args
num_classes = len(torch.unique(context_clip_labels))
num_classes = len(torch.unique(context_labels))
self.init_classifier(num_classes)
self.init_feature_adapter()
inner_loop_optimizer = init_optimizer(self, lr, optimizer_type, extractor_scale_factor)
context_clip_loader = get_clip_loader((context_clips, context_clip_labels), self.batch_size, with_labels=True)
batch_context_set_size = len(context_labels)
num_batches = int(np.ceil(float(batch_context_set_size) / float(self.batch_size)))
for _ in range(self.num_grad_steps):
for batch_context_clips, batch_context_labels in context_clip_loader:
batch_context_clips = batch_context_clips.to(self.device)
batch_context_labels = batch_context_labels.to(self.device)
batch_context_logits = self.predict_a_batch(batch_context_clips, ops_counter=ops_counter, context=True)
t1 = time.time()
batch_context_loss = loss_fn(batch_context_logits, batch_context_labels)
batch_context_loss.backward()
for batch in range(num_batches):
batch_start_index, batch_end_index = get_batch_indices(batch, batch_context_set_size, self.batch_size)
batch_context_clips = context_clips[batch_start_index:batch_end_index].to(self.device)
batch_context_labels = context_labels[batch_start_index:batch_end_index].to(self.device)
batch_len = len(context_labels[batch_start_index:batch_end_index])
feature_adapter_params = self._get_feature_adapter_params(None, ops_counter)
batch_context_features = self._get_features(batch_context_clips, feature_adapter_params, ops_counter, context=True)
batch_context_features = self._pool_features(batch_context_features, ops_counter)
batch_context_logits = self.classifier.predict(batch_context_features, ops_counter)
loss = loss_fn(batch_context_logits, batch_context_labels)
loss += 0.001 * self.feature_adapter.regularization_term(switch_device=self.use_two_gpus)
loss *= batch_len/batch_context_set_size
loss.backward()
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
t1 = time.time()
inner_loop_optimizer.step()
inner_loop_optimizer.zero_grad()
if ops_counter:
torch.cuda.synchronize()
ops_counter.log_time(time.time() - t1)
def predict(self, clips, ops_counter=None, context=False):
"""
Function that processes target clips in batches to get logits over object classes for each clip.
:param clips: (np.ndarray or torch.Tensor) Clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:param context: (bool) True if a context set is being processed, otherwise False.
:return: (torch.Tensor) Logits over object classes for each clip in clips.
"""
clip_loader = get_clip_loader(clips, self.batch_size)
task_embedding = None # multi-step methods do not use set encoder
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
features = self._get_features_in_batches(clip_loader, self.feature_adapter_params, ops_counter, context=context)
features = self._pool_features(features, ops_counter)
return self.classifier.predict(features, ops_counter)
def predict_a_batch(self, clips, ops_counter=None, context=False):
"""
Function that processes a batch of clips to get logits over object classes for each clip.
:param clips: (torch.Tensor) Tensor of clips, each composed of self.clip_length contiguous frames.
:param clips: (torch.Tensor) Clips, each composed of self.clip_length contiguous frames.
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:param context: (bool) True if a context set is being processed, otherwise False.
:return: (torch.Tensor) Logits over object classes for each clip in clips.
"""
task_embedding = None # multi-step methods do not use set encoder
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
features = self._get_features(clips, self.feature_adapter_params, ops_counter, context=context)
feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
features = self._get_features_in_batches(clips, feature_adapter_params, ops_counter, context=context)
features = self._pool_features(features, ops_counter)
return self.classifier.predict(features, ops_counter)
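A minimal sketch of how a caller might drive the multi-step recogniser after this change, assuming model is a MultiStepFewShotRecogniser and context_clips, context_labels, target_clips come from unpack_task (data/utils.py); the loss comes from utils.optim as in the training scripts below, while the optimiser name and hyper-parameter values here are illustrative assumptions, not taken from this diff:

from utils.optim import cross_entropy

# (lr, loss_fn, optimizer_type, extractor_scale_factor); 'sgd' and the values are illustrative assumptions
learning_args = (0.1, cross_entropy, 'sgd', 1.0)
model.personalise(context_clips, context_labels, learning_args)   # fine-tunes a new classifier (and FiLM layers if adapt_features), batched via get_batch_indices
batch_logits = model.predict_a_batch(target_clips[:batch_size])   # logits over the task's classes for one batch of target clips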
@ -424,75 +411,72 @@ class SingleStepFewShotRecogniser(FewShotRecogniser):
def personalise(self, context_clips, context_labels, ops_counter=None):
"""
Function that learns a new task by performing a forward pass of the task's context set.
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
:param context_clips: (torch.Tensor) Context clips each composed of self.clip_length contiguous frames.
:param context_labels: (torch.Tensor) Video-level labels for each context clip.
:param ops_counter: (utils.OpsCounter or None) Object that counts operations performed.
:return: Nothing.
"""
context_clip_loader = get_clip_loader(context_clips, self.batch_size)
task_embedding = self._get_task_embedding_in_batches(context_clip_loader, ops_counter)
task_embedding = self._get_task_embedding_in_batches(context_clips, ops_counter)
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding, ops_counter)
context_features = self._get_features_in_batches(context_clip_loader, self.feature_adapter_params, ops_counter, context=True)
context_features = self._get_features_in_batches(context_clips, self.feature_adapter_params, ops_counter, context=True)
self.context_features = self._pool_features(context_features, ops_counter)
self.classifier.configure(self.context_features, context_labels, ops_counter)
def personalise_with_lite(self, context_clips, context_labels):
"""
Function that learns a new task by performing a forward pass of the task's context set with LITE. Namely, a random subset of the context set (self.num_lite_samples) is processed with back-propagation enabled, while the remainder is processed with back-propagation disabled.
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
:param context_labels: (torch.Tensor) Video-level labels for each context clip.
:return: Nothing.
"""
shuffled_idxs = np.random.permutation(len(context_clips))
H = self.num_lite_samples
context_clip_loader = get_clip_loader(context_clips[shuffled_idxs][:H], self.batch_size)
task_embedding = self._get_task_embedding_with_lite(context_clip_loader, shuffled_idxs)
task_embedding = self._get_task_embedding_with_lite(context_clips[shuffled_idxs][:H], shuffled_idxs)
self.feature_adapter_params = self._get_feature_adapter_params(task_embedding)
self.context_features = self._get_pooled_features_with_lite(context_clip_loader, shuffled_idxs)
self.context_features = self._get_pooled_features_with_lite(context_clips[shuffled_idxs][:H], shuffled_idxs)
self.classifier.configure(self.context_features, context_labels[shuffled_idxs])
def _cache_context_outputs(self, context_clips):
"""
Function that performs a forward pass with a task's entire context set with back-propagation disabled and caches the individual 1) encodings from the set encoder and 2) adapted features from the adapted feature extractor, for each clip.
:param context_clips: (np.ndarray or torch.Tensor) Context clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
:return: Nothing.
"""
context_clip_loader = get_clip_loader(context_clips, self.batch_size)
with torch.set_grad_enabled(False):
# cache encoding for each clip from self.set_encoder
self.cached_set_encoder_reps = self._get_task_embedding_in_batches(context_clip_loader, reduction='none')
self.cached_set_encoder_reps = self._get_task_embedding_in_batches(context_clips, reduction='none')
# get feature adapter parameters
task_embedding = self.set_encoder.mean_pool(self.cached_set_encoder_reps)
feature_adapter_params = self._get_feature_adapter_params(task_embedding)
# cache adapted features for each clip
context_features = self._get_features_in_batches(context_clip_loader, feature_adapter_params, context=True)
context_features = self._get_features_in_batches(context_clips, feature_adapter_params, context=True)
self.cached_context_features = self._pool_features(context_features)
def _get_task_embedding_with_lite(self, context_clip_loader, idxs):
def _get_task_embedding_with_lite(self, context_clips, idxs):
"""
Function that passes all of a task's context set through the set encoder to get a task embedding with LITE.
:param context_clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for context clips, each composed of self.clip_length contiguous frames.
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
:param idxs: (torch.Tensor) Indices of elements in context_clips to process with back-propagation enabled.
:return: (torch.Tensor or None) Task embedding.
"""
if isinstance(self.set_encoder, NullSetEncoder):
return None
H = self.num_lite_samples
task_embedding_with_grads = self._get_task_embedding_in_batches(context_clip_loader, reduction='none')
task_embedding_with_grads = self._get_task_embedding_in_batches(context_clips, reduction='none')
task_embedding_without_grads = self.cached_set_encoder_reps[idxs][H:]
return torch.cat((task_embedding_with_grads, task_embedding_without_grads)).mean(dim=0)
def _get_pooled_features_with_lite(self, context_clip_loader, idxs):
def _get_pooled_features_with_lite(self, context_clips, idxs):
"""
Function that gets adapted clip features for a task's context set with LITE.
:param context_clip_loader: (torch.utils.data.DataLoader or torch.Tensor) Loader for context clips, each composed of self.clip_length contiguous frames.
:param context_clips: (torch.Tensor) Context clips, each composed of self.clip_length contiguous frames.
:param idxs: (torch.Tensor) Indices of elements in context_clips to process with back-propagation enabled.
:return: (torch.Tensor) Adapted frame features pooled per clip i.e. as (num_clips) x (feat_dim).
"""
H = self.num_lite_samples
context_features_with_grads = self._get_features_in_batches(context_clip_loader, self.feature_adapter_params, context=True)
context_features_with_grads = self._get_features_in_batches(context_clips, self.feature_adapter_params, context=True)
context_features_with_grads = self._pool_features(context_features_with_grads)
context_features_without_grads = self.cached_context_features[idxs][H:]
return torch.cat((context_features_with_grads, context_features_without_grads))
@ -500,11 +484,10 @@ class SingleStepFewShotRecogniser(FewShotRecogniser):
def predict(self, target_clips):
"""
Function that processes target clips in batches to get logits over object classes for each clip.
:param target_clips: (np.ndarray or torch.Tensor) Target clips (either as paths or tensors), each composed of self.clip_length contiguous frames.
:param target_clips: (torch.Tensor) Target clips, each composed of self.clip_length contiguous frames.
:return: (torch.Tensor) Logits over object classes for each clip in target_clips.
"""
target_clip_loader = get_clip_loader(target_clips, self.batch_size)
target_features = self._get_features_in_batches(target_clip_loader, self.feature_adapter_params)
target_features = self._get_features_in_batches(target_clips, self.feature_adapter_params)
target_features = self._pool_features(target_features)
return self.classifier.predict(target_features)
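For the single-step recogniser, the calling pattern now operates directly on clip tensors; a hedged sketch assuming model is a SingleStepFewShotRecogniser and the clips/labels come from unpack_task as in the training script below:

model.personalise(context_clips, context_labels)        # forward pass over the context set, batched internally
target_logits = model.predict(target_clips)             # per-clip logits over the task's object classes
# with LITE, only a random subset of the context set keeps gradients:
model.personalise_with_lite(context_clips, context_labels)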

View file

@ -36,11 +36,11 @@ import numpy as np
import torch.backends.cudnn as cudnn
from data.dataloaders import DataLoader
from data.utils import unpack_task, attach_frame_history
from models.few_shot_recognisers import MultiStepFewShotRecogniser
from utils.args import parse_args
from utils.ops_counter import OpsCounter
from utils.optim import cross_entropy, init_optimizer
from utils.data import unpack_task, attach_frame_history
from utils.logging import print_and_log, get_log_files, stats_to_str
from utils.eval_metrics import TrainEvaluator, ValidationEvaluator, TestEvaluator
@ -105,7 +105,6 @@ class Learner:
'frame_size': self.args.frame_size,
'annotations_to_load': self.args.annotations_to_load,
'filter_by_annotations': [self.args.filter_context, self.args.filter_target],
'preload_clips': self.args.preload_clips,
'logfile': self.logfile
}
@ -199,9 +198,9 @@ class Learner:
def train_task(self, task_dict):
context_clips, context_frames, context_labels, target_clips, target_frames, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True, preload_clips=self.args.preload_clips)
context_clips, context_frames, context_labels, target_clips, target_frames, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True)
joint_context_clips = torch.cat((context_clips, target_clips)) if self.args.preload_clips else np.concatenate((context_clips, target_clips))
joint_context_clips = torch.cat((context_clips, target_clips))
joint_context_labels = torch.cat((context_labels, target_labels), dim=0)
joint_context_logits = self.model.predict(joint_context_clips, context=True)
self.train_evaluator.update_stats(joint_context_logits, joint_context_labels)
@ -217,7 +216,7 @@ class Learner:
# loop through validation tasks (num_validation_users * num_test_tasks_per_user)
num_val_tasks = len(self.validation_queue) * self.args.num_test_tasks
for step, task_dict in enumerate(self.validation_queue.get_tasks()):
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False, preload_clips=self.args.preload_clips)
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False)
num_context_clips = len(context_clips)
self.validation_evaluator.set_task_object_list(object_list)
self.validation_evaluator.set_task_context_paths(context_paths)
@ -279,7 +278,7 @@ class Learner:
# loop through test tasks (num_test_users * num_test_tasks_per_user)
num_test_tasks = len(self.test_queue) * self.args.num_test_tasks
for step, task_dict in enumerate(self.test_queue.get_tasks()):
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False, preload_clips=self.args.preload_clips)
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False)
num_context_clips = len(context_clips)
self.test_evaluator.set_task_object_list(object_list)
self.test_evaluator.set_task_context_paths(context_paths)

View file

@ -36,11 +36,11 @@ import numpy as np
import torch.backends.cudnn as cudnn
from data.dataloaders import DataLoader
from data.utils import get_batch_indices, unpack_task, attach_frame_history
from models.few_shot_recognisers import SingleStepFewShotRecogniser
from utils.args import parse_args
from utils.ops_counter import OpsCounter
from utils.optim import cross_entropy, init_optimizer
from utils.data import get_clip_loader, unpack_task, attach_frame_history
from utils.logging import print_and_log, get_log_files, stats_to_str
from utils.eval_metrics import TrainEvaluator, ValidationEvaluator, TestEvaluator
@ -107,7 +107,6 @@ class Learner:
'frame_size': self.args.frame_size,
'annotations_to_load': self.args.annotations_to_load,
'filter_by_annotations': [self.args.filter_context, self.args.filter_target],
'preload_clips': self.args.preload_clips,
'logfile': self.logfile
}
@ -190,7 +189,7 @@ class Learner:
self.logfile.close()
def train_task(self, task_dict):
context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True, preload_clips=self.args.preload_clips)
context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device, target_to_device=True)
self.model.personalise(context_clips, context_labels)
target_logits = self.model.predict(target_clips)
@ -206,32 +205,37 @@ class Learner:
return task_loss
def train_task_with_lite(self, task_dict):
context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device, preload_clips=self.args.preload_clips)
context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list = unpack_task(task_dict, self.device)
# compute and save personalise outputs of whole context set with back-propagation disabled
self.model._cache_context_outputs(context_clips)
task_loss = 0
target_logits = []
target_clip_loader = get_clip_loader((target_clips, target_labels), self.args.batch_size, with_labels=True)
for batch_target_clips, batch_target_labels in target_clip_loader:
target_logits, target_boxes_pred = [], []
num_clips = len(target_clips)
num_batches = int(np.ceil(float(num_clips) / float(self.args.batch_size)))
for batch in range(num_batches):
self.model.personalise_with_lite(context_clips, context_labels)
batch_target_clips = batch_target_clips.to(device=self.device)
batch_target_labels = batch_target_labels.to(device=self.device)
batch_target_logits = self.model.predict_a_batch(batch_target_clips)
batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.args.batch_size)
batch_target_clips = target_clips[batch_start_index:batch_end_index].to(device=self.device)
batch_target_labels = target_labels[batch_start_index:batch_end_index].to(device=self.device)
batch_target_logits = self.model.predict_a_batch(batch_target_clips)
target_logits.extend(batch_target_logits.detach())
loss_scaling = len(context_labels) / (self.args.num_lite_samples * self.args.tasks_per_batch)
batch_loss = loss_scaling * self.loss(batch_target_logits, batch_target_labels)
batch_loss += 0.001 * self.model.feature_adapter.regularization_term(switch_device=self.args.use_two_gpus)
batch_loss += 0.001 * self.model.feature_adapter.regularization_term(switch_device=self.args.use_two_gpus)
batch_loss.backward(retain_graph=False)
task_loss += batch_loss.detach()
# reset task's params
self.model._reset()
target_logits = torch.stack(target_logits)
self.train_evaluator.update_stats(target_logits, target_labels)
return task_loss
def validate(self):
@ -241,7 +245,7 @@ class Learner:
# loop through validation tasks (num_validation_users * num_test_tasks_per_user)
num_val_tasks = len(self.validation_queue) * self.args.num_test_tasks
for step, task_dict in enumerate(self.validation_queue.get_tasks()):
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, preload_clips=self.args.preload_clips)
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device)
num_context_clips = len(context_clips)
self.validation_evaluator.set_task_object_list(object_list)
self.validation_evaluator.set_task_context_paths(context_paths)
@ -293,7 +297,7 @@ class Learner:
# loop through test tasks (num_test_users * num_test_tasks_per_user)
num_test_tasks = len(self.test_queue) * self.args.num_test_tasks
for step, task_dict in enumerate(self.test_queue.get_tasks()):
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, preload_clips=self.args.preload_clips)
context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device)
num_context_clips = len(context_clips)
self.test_evaluator.set_task_object_list(object_list)
self.test_evaluator.set_task_context_paths(context_paths)

View file

@ -80,8 +80,6 @@ def parse_args(learner='default'):
help="Method to sample clips per target video for a test/validation task (default: max).")
parser.add_argument("--clip_length", type=int, default=1,
help="Number of frames to sample per clip (default: 1).")
parser.add_argument("--no_preload_clips", action="store_true",
help="Do not preload clips per task from disk. Use if CPU memory is limited, but will mean slower training/testing.")
parser.add_argument("--frame_size", type=int, default=224, choices=[84, 224],
help="Frame size (default: 224).")
parser.add_argument("--annotations_to_load", nargs='+', type=str, default=[], choices=FRAME_ANNOTATION_OPTIONS+BOUNDING_BOX_OPTIONS,
@ -130,7 +128,6 @@ def parse_args(learner='default'):
help="Learning rate for inner loop (MAML) or fine-tuning (FineTuner) (default: 0.1).")
args = parser.parse_args()
args.preload_clips = not args.no_preload_clips
verify_args(learner, args)
return args
@ -140,9 +137,6 @@ def verify_args(learner, args):
cyellow = "\33[33m"
cend = "\33[0m"
if len(args.annotations_to_load) and args.no_preload_clips:
sys.exit('{:}error: loading annotations with --annotations_to_load is currently not supported with --no_preload_clips{:}'.format(cred, cend))
if 'train' in args.mode and not args.learn_extractor and not args.adapt_features:
sys.exit('{:}error: at least one of "--learn_extractor" and "--adapt_features" must be used during training{:}'.format(cred, cend))

View file

@ -1,133 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as tv_F
class DatasetFromClipPaths(Dataset):
def __init__(self, clip_paths, with_labels):
super().__init__()
#TODO currently doesn't support loading of annotations
self.with_labels = with_labels
if self.with_labels:
self.clip_paths, self.clip_labels = clip_paths
else:
self.clip_paths = clip_paths
self.normalize_stats = {'mean' : [0.500, 0.436, 0.396], 'std' : [0.145, 0.143, 0.138]} # orbit mean train frame
def __getitem__(self, index):
clip = []
for frame_path in self.clip_paths[index]:
frame = self.load_and_transform_frame(frame_path)
clip.append(frame)
if self.with_labels:
return torch.stack(clip, dim=0), self.clip_labels[index]
else:
return torch.stack(clip, dim=0)
def load_and_transform_frame(self, frame_path):
"""
Function to load and transform frame.
:param frame_path: (str) Path to frame.
:return: (torch.Tensor) Loaded and transformed frame.
"""
frame = Image.open(frame_path)
frame = tv_F.to_tensor(frame)
return tv_F.normalize(frame, mean=self.normalize_stats['mean'], std=self.normalize_stats['std'])
def __len__(self):
return len(self.clip_paths)
def get_clip_loader(clips, batch_size, with_labels=False):
if isinstance(clips[0], np.ndarray):
clips_dataset = DatasetFromClipPaths(clips, with_labels=with_labels)
return DataLoader(clips_dataset,
batch_size=batch_size,
num_workers=8,
pin_memory=True,
prefetch_factor=8,
persistent_workers=True)
elif isinstance(clips[0], torch.Tensor):
if with_labels:
return list(zip(clips[0].split(batch_size), clips[1].split(batch_size)))
else:
return clips.split(batch_size)
def attach_frame_history(frames, history_length):
if isinstance(frames, np.ndarray):
return attach_frame_history_paths(frames, history_length)
elif isinstance(frames, torch.Tensor):
return attach_frame_history_tensor(frames, history_length)
def attach_frame_history_paths(frame_paths, history_length):
"""
Function to attach the immediate history of history_length frames to each frame in an array of frame paths.
:param frame_paths: (np.ndarray) Frame paths.
:param history_length: (int) Number of frames of history to append to each frame.
:return: (np.ndarray) Frame paths with attached frame history.
"""
# pad with first frame so that frames 0 to history_length-1 can be evaluated
frame_paths = np.concatenate([np.repeat(frame_paths[0], history_length-1), frame_paths])
# for each frame path, attach its immediate history of history_length frames
frame_paths = [ frame_paths ]
for l in range(1, history_length):
frame_paths.append( np.roll(frame_paths[0], shift=-l, axis=0) )
frame_paths_with_history = np.stack(frame_paths, axis=1) # of size num_clips x history_length
if history_length > 1:
return frames_with_history[:-(history_length-1)] # frames have wrapped around, remove last (history_length - 1) frames
else:
return frames_with_history
def attach_frame_history_tensor(frames, history_length):
"""
Function to attach the immediate history of history_length frames to each frame in a tensor of frame data.
param frames: (torch.Tensor) Frames.
:param history_length: (int) Number of frames of history to append to each frame.
:return: (torch.Tensor) Frames with attached frame history.
"""
# pad with first frame so that frames 0 to history_length-1 can be evaluated
frame_0 = frames.narrow(0, 0, 1)
frames = torch.cat((frame_0.repeat(history_length-1, 1, 1, 1), frames), dim=0)
# for each frame, attach its immediate history of history_length frames
frames = [ frames ]
for l in range(1, history_length):
frames.append( frames[0].roll(shifts=-l, dims=0) )
frames_with_history = torch.stack(frames, dim=1) # of size num_clips x history_length
if history_length > 1:
return frames_with_history[:-(history_length-1)] # frames have wrapped around, remove last (history_length - 1) frames
else:
return frames_with_history
def unpack_task(task_dict, device, context_to_device=True, target_to_device=False, preload_clips=False):
context_clips = task_dict['context_clips']
context_paths = task_dict['context_paths']
context_labels = task_dict['context_labels']
context_annotations = task_dict['context_annotations']
target_clips = task_dict['target_clips']
target_paths = task_dict['target_paths']
target_labels = task_dict['target_labels']
target_annotations = task_dict['target_annotations']
object_list = task_dict['object_list']
if context_to_device and isinstance(context_labels, torch.Tensor):
context_labels = context_labels.to(device)
if target_to_device and isinstance(target_labels, torch.Tensor):
target_labels = target_labels.to(device)
if preload_clips:
return context_clips, context_paths, context_labels, target_clips, target_paths, target_labels, object_list
else:
return context_paths, context_paths, context_labels, target_paths, target_paths, target_labels, object_list

View file

@ -7,7 +7,8 @@ from datetime import datetime
def print_and_log(log_file, message):
print(message)
log_file.write(message + '\n')
if log_file:
log_file.write(message + '\n')
def get_log_files(checkpoint_dir, model_path):
"""