ORBIT-Dataset/single-step-learner.py

394 lines
21 KiB
Python

"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license
This code was based on the file run_cnaps.py (https://github.com/cambridge-mlg/cnaps/blob/master/src/run_cnaps.py)
from the cambridge-mlg/cnaps library (https://github.com/cambridge-mlg/cnaps).
The original license is included below:
Copyright (c) 2019 John Bronskill, Jonathan Gordon, and James Requeima.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import os
import time
import torch
import random
import numpy as np
import torch.backends.cudnn as cudnn
from data.dataloaders import DataLoader
from data.utils import get_batch_indices, unpack_task, attach_frame_history
from model.few_shot_recognisers import SingleStepFewShotRecogniser
from utils.args import parse_args
from utils.optim import cross_entropy, init_optimizer, init_scheduler, get_curr_learning_rates
from utils.logging import print_and_log, get_log_files, stats_to_str
from utils.eval_metrics import TrainEvaluator, ValidationEvaluator, TestEvaluator
# Share tensors between processes via the file system rather than file
# descriptors. NOTE(review): presumably chosen to avoid exhausting the
# open-file limit when data-loading worker processes hand tensors back —
# confirm against the data loader configuration.
torch.multiprocessing.set_sharing_strategy('file_system')
def main():
    """Script entry point: build a Learner from the command-line options and run it."""
    Learner().run()
class Learner:
    """Train, validate and test a ``SingleStepFewShotRecogniser``.

    Options come from ``parse_args()``. ``run()`` dispatches on ``args.mode``:
    ``'train'`` trains (validating once ``args.validation_on_epoch`` is reached)
    and saves the final weights; ``'train_test'`` additionally tests the final
    and best-validation checkpoints; ``'test'`` tests the weights at
    ``args.model_path``.
    """

    def __init__(self):
        self.args = parse_args()
        (self.checkpoint_dir, self.logfile, self.checkpoint_path_validation,
         self.checkpoint_path_final) = get_log_files(self.args.checkpoint_dir, self.args.model_path)
        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile, "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        # Seed every global RNG the pipeline may draw from. numpy is seeded too,
        # in case data loading samples from numpy's global RNG.
        random.seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)

        device_id = 'cpu'
        if torch.cuda.is_available() and self.args.gpu >= 0:
            # Deterministic cuDNN kernels: trade speed for run-to-run reproducibility.
            cudnn.enabled = True
            cudnn.benchmark = False
            cudnn.deterministic = True
            device_id = 'cuda:' + str(self.args.gpu)
            torch.cuda.manual_seed_all(self.args.seed)
        self.device = torch.device(device_id)

        self.init_dataset()
        self.init_model()
        self.init_evaluators()
        self.loss = cross_entropy
        # Per-task training routine: LITE variant or standard single pass.
        self.train_task_fn = self.train_task_with_lite if self.args.with_lite else self.train_task
        print_and_log(self.logfile, f"Model details:\n"
                      f"\tfeature extractor: {self.args.feature_extractor} (pretrained: True, learnable: {self.args.learn_extractor}, generate film params: {self.args.adapt_features})\n"
                      f"\tclassifier: {self.args.classifier} with logit scale={self.args.logit_scale}\n")

    def init_dataset(self):
        """Build the train/validation/test task queues from the command-line options."""
        dataset_info = {
            'mode': self.args.mode,
            'data_path': self.args.data_path,
            'train_object_cap': self.args.train_object_cap,
            'test_object_cap': self.args.test_object_cap,
            'with_train_shot_caps': self.args.with_train_shot_caps,
            'with_cluster_labels': False,
            'train_way_method': self.args.train_way_method,
            'test_way_method': self.args.test_way_method,
            'train_shot_methods': [self.args.train_context_shot_method, self.args.train_target_shot_method],
            'test_shot_methods': [self.args.test_context_shot_method, self.args.test_target_shot_method],
            'num_train_tasks': self.args.num_train_tasks,
            'num_val_tasks': self.args.num_val_tasks,
            'num_test_tasks': self.args.num_test_tasks,
            'train_task_type': self.args.train_task_type,
            'test_set': self.args.test_set,
            'shots': [self.args.context_shot, self.args.target_shot],
            'video_types': [self.args.context_video_type, self.args.target_video_type],
            'clip_length': self.args.clip_length,
            'train_clip_methods': [self.args.train_context_clip_method, self.args.train_target_clip_method],
            'test_clip_methods': [self.args.test_context_clip_method, self.args.test_target_clip_method],
            'subsample_factor': self.args.subsample_factor,
            'frame_size': self.args.frame_size,
            'frame_norm_method': self.args.frame_norm_method,
            'annotations_to_load': self.args.annotations_to_load,
            'train_filter_by_annotations': [self.args.train_filter_context, self.args.train_filter_target],
            'test_filter_by_annotations': [self.args.test_filter_context, self.args.test_filter_target],
            'logfile': self.logfile,
        }
        dataloader = DataLoader(dataset_info)
        self.train_queue = dataloader.get_train_queue()
        self.validation_queue = dataloader.get_validation_queue()
        self.test_queue = dataloader.get_test_queue()

    def init_model(self):
        """(Re)create the few-shot recogniser and move it to ``self.device``."""
        self.model = SingleStepFewShotRecogniser(
            self.args.feature_extractor, self.args.adapt_features, self.args.classifier, self.args.clip_length,
            self.args.batch_size, self.args.learn_extractor, self.args.num_lite_samples, self.args.logit_scale)
        self.model._set_device(self.device)
        self.model._send_to_device()

    def init_evaluators(self) -> None:
        """Create the train/validation/test evaluators (all tracking frame accuracy)."""
        self.train_metrics = ['frame_acc']
        self.evaluation_metrics = ['frame_acc']
        self.train_evaluator = TrainEvaluator(self.train_metrics)
        self.validation_evaluator = ValidationEvaluator(self.evaluation_metrics)
        self.test_evaluator = TestEvaluator(self.evaluation_metrics, self.checkpoint_dir, with_ops_counter=True)

    def run(self):
        """Execute the workflow selected by ``args.mode``.

        The log file is closed in a ``finally`` clause so it is released even
        when training or testing raises.
        """
        try:
            if self.args.mode in ('train', 'train_test'):
                self._train()
                if self.args.mode == 'train_test':
                    self.test(self.checkpoint_path_final, save_evaluator=False)
                    self.test(self.checkpoint_path_validation)
            if self.args.mode == 'test':
                self.test(self.args.model_path)
        finally:
            self.logfile.close()

    def _train(self):
        """Train for ``args.epochs`` epochs, then save the final model weights.

        After each epoch: write a checkpoint (model, optimizer, best validation
        stats), step the epoch-level scheduler, and validate once
        ``epoch + 1 >= args.validation_on_epoch``.
        """
        self.optimizer = init_optimizer(self.model, self.args.learning_rate, self.args.optimizer, self.args, extractor_lr_scale=self.args.extractor_lr_scale)
        self.scheduler = init_scheduler(self.optimizer, self.args)
        num_updates = 0
        for epoch in range(self.args.epochs):
            num_updates = self._train_one_epoch(epoch, num_updates)
            self.save_checkpoint(epoch + 1)
            self.scheduler.step(epoch + 1)
            if (epoch + 1) >= self.args.validation_on_epoch:
                self.validate()
        # save the final model
        torch.save(self.model.state_dict(), self.checkpoint_path_final)

    def _train_one_epoch(self, epoch, num_updates):
        """Run one pass over the training tasks and log the epoch summary.

        Gradients are accumulated across ``args.tasks_per_batch`` tasks. The
        optimizer (and the per-update scheduler) steps after each full group
        and after the epoch's last task.

        Returns:
            The updated count of optimizer steps taken so far.
        """
        losses = []
        since = time.time()
        torch.set_grad_enabled(True)
        self.model.set_test_mode(False)
        train_tasks = self.train_queue.get_tasks()
        total_steps = len(train_tasks)
        for step, task_dict in enumerate(train_tasks):
            t1 = time.time()
            task_loss = self.train_task_fn(task_dict)
            task_time = time.time() - t1
            losses.append(task_loss.detach())
            if self.args.print_by_step:
                current_stats_str = stats_to_str(self.train_evaluator.get_current_stats())
                print_and_log(self.logfile, f'epoch [{epoch+1}/{self.args.epochs}][{step+1}/{total_steps}], train loss: {task_loss.item():.7f}, {current_stats_str.strip()}, time/task: {int(task_time/60):d}m{int(task_time%60):02d}s')
            if ((step + 1) % self.args.tasks_per_batch == 0) or (step == (total_steps - 1)):
                self.optimizer.step()
                self.optimizer.zero_grad()
                num_updates += 1
                self.scheduler.step_update(num_updates)

        mean_stats = self.train_evaluator.get_mean_stats()
        mean_epoch_loss = self._mean_loss(losses)
        lr, fe_lr = get_curr_learning_rates(self.optimizer)
        seconds = time.time() - since
        print_and_log(self.logfile, '-'*150)
        print_and_log(self.logfile, f'epoch [{epoch+1}/{self.args.epochs}] train loss: {mean_epoch_loss:.7f} {stats_to_str(mean_stats)} lr: {lr:.3e} fe-lr: {fe_lr:.3e} time/epoch: {int(seconds/60):d}m{int(seconds%60):02d}s')
        print_and_log(self.logfile, '-'*150)
        self.train_evaluator.reset()
        return num_updates

    @staticmethod
    def _mean_loss(losses):
        """Return the mean of a list of loss tensors as a Python float.

        The tensors are concatenated on their own device, so only a single
        host sync is needed. An empty list yields NaN, matching
        ``torch.Tensor([]).mean()``.
        """
        if not losses:
            return float('nan')
        return torch.cat([loss.reshape(-1).float() for loss in losses]).mean().item()

    def train_task(self, task_dict):
        """Personalise on one task, predict its targets and back-propagate the loss.

        The loss is divided by ``args.tasks_per_batch``, since gradients are
        accumulated over that many tasks. It also includes
        ``0.001 * film_generator.regularization_term()``. The model is reset
        afterwards.

        Returns:
            The task's loss tensor. Its graph has already been freed by ``backward``.
        """
        context_clips, _, context_labels, target_clips, _, target_labels, _ = unpack_task(task_dict, self.device, target_to_device=True)
        self.model.personalise(context_clips, context_labels)
        target_logits = self.model.predict(target_clips)
        self.train_evaluator.update_stats(target_logits, target_labels)
        task_loss = self.loss(target_logits, target_labels) / self.args.tasks_per_batch
        task_loss += 0.001 * self.model.film_generator.regularization_term()
        task_loss.backward(retain_graph=False)
        # reset task's params
        self.model._reset()
        return task_loss

    def train_task_with_lite(self, task_dict):
        """Train on one task with LITE, back-propagating batch by batch.

        The target clips are processed in chunks of ``args.batch_size``. Each
        chunk works as follows:

        - The model is re-personalised with ``personalise_with_lite``.
        - The batch loss is scaled by
          ``len(context_labels) / (num_lite_samples * tasks_per_batch)``.
          NOTE(review): presumably this rescales for back-propagating through
          only ``num_lite_samples`` context clips; confirm.
        - The scaled loss is back-propagated immediately.

        Returns:
            The detached sum of the batch losses.
        """
        context_clips, _, context_labels, target_clips, _, target_labels, _ = unpack_task(task_dict, self.device)
        self.model._clear_caches()
        # Loop-invariant scaling applied to every batch loss (see docstring).
        loss_scaling = len(context_labels) / (self.args.num_lite_samples * self.args.tasks_per_batch)
        task_loss = 0
        target_logits = []
        num_clips = len(target_clips)
        num_batches = int(np.ceil(float(num_clips) / float(self.args.batch_size)))
        for batch in range(num_batches):
            self.model.personalise_with_lite(context_clips, context_labels)
            batch_start_index, batch_end_index = get_batch_indices(batch, num_clips, self.args.batch_size)
            batch_target_clips = target_clips[batch_start_index:batch_end_index]
            batch_target_labels = target_labels[batch_start_index:batch_end_index]
            batch_target_logits = self.model.predict_a_batch(batch_target_clips)
            target_logits.append(batch_target_logits.detach())
            batch_loss = loss_scaling * self.loss(batch_target_logits, batch_target_labels.to(self.device))
            batch_loss += 0.001 * self.model.film_generator.regularization_term()
            batch_loss.backward(retain_graph=False)
            task_loss += batch_loss.detach()
            # reset task's params (personalisation is redone for every batch)
            self.model._reset()
        # Concatenating per-batch logits equals stacking their individual rows.
        target_logits = torch.cat(target_logits)
        self.train_evaluator.update_stats(target_logits, target_labels)
        return task_loss

    def validate(self):
        """Evaluate on the validation queue and keep the best model so far.

        Each user contributes ``args.num_val_tasks`` consecutive tasks. After a
        user's last task, that user's stats are logged. Once all tasks are done,
        the aggregate per-user, per-object, per-task and per-video stats are
        logged. If the per-video stats beat the best seen so far, the model is
        saved to ``self.checkpoint_path_validation``.
        """
        self.model.set_test_mode(True)
        num_context_clips_per_task, num_target_clips_per_task = [], []
        with torch.no_grad():
            # loop through validation tasks (num_validation_users * num_val_tasks)
            num_val_tasks = len(self.validation_queue) * self.args.num_val_tasks
            for step, task_dict in enumerate(self.validation_queue.get_tasks()):
                context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device)
                num_context_clips = len(context_clips)
                self.validation_evaluator.set_task_object_list(object_list)
                self.validation_evaluator.set_task_context_paths(context_paths)
                self.model.personalise(context_clips, context_labels)

                # loop through cached target videos for the current task
                num_target_clips = 0
                video_iterator = zip(target_frames_by_video, target_paths_by_video, target_labels_by_video)
                for video_frames, video_paths, video_label in video_iterator:
                    video_clips = attach_frame_history(video_frames, self.args.clip_length)
                    video_logits = self.model.predict(video_clips)
                    self.validation_evaluator.append_video(video_logits, video_label, video_paths)
                    num_target_clips += len(video_clips)

                # reset task's params
                self.model._reset()
                # log number of clips per task
                num_context_clips_per_task.append(num_context_clips)
                num_target_clips_per_task.append(num_target_clips)

                # if this is the user's last task, get the average performance for the user over all their tasks
                if (step + 1) % self.args.num_val_tasks == 0:
                    self.validation_evaluator.set_current_user(task_dict["task_id"])
                    _, _, _, current_video_stats = self.validation_evaluator.get_mean_stats(current_user=True)
                    print_and_log(self.logfile, f'validation user {task_dict["task_id"]} ({self.validation_evaluator.current_user+1}/{len(self.validation_queue)}) stats: {stats_to_str(current_video_stats)} avg # context clips/task: {np.mean(num_context_clips_per_task):.0f} avg # target clips/task: {np.mean(num_target_clips_per_task):.0f}')
                    if (step + 1) < num_val_tasks:
                        num_context_clips_per_task, num_target_clips_per_task = [], []  # reset per user
                        self.validation_evaluator.next_user()
                else:
                    self.validation_evaluator.next_task()

            stats_per_user, stats_per_obj, stats_per_task, stats_per_video = self.validation_evaluator.get_mean_stats()
            stats_per_user_str, stats_per_obj_str, stats_per_task_str, stats_per_video_str = stats_to_str(stats_per_user), stats_to_str(stats_per_obj), stats_to_str(stats_per_task), stats_to_str(stats_per_video)
            print_and_log(self.logfile, f'validation\n per-user stats: {stats_per_user_str}\n per-object stats: {stats_per_obj_str}\n per-task stats: {stats_per_task_str}\n per-video stats: {stats_per_video_str}\n')
            # save the model if validation is the best so far
            if self.validation_evaluator.is_better(stats_per_video):
                self.validation_evaluator.replace(stats_per_video)
                torch.save(self.model.state_dict(), self.checkpoint_path_validation)
                print_and_log(self.logfile, 'best validation model was updated.\n')
            self.validation_evaluator.reset()

    def test(self, path, save_evaluator=True):
        """Evaluate the weights at ``path`` on the test queue.

        The model is re-initialised first. If ``path`` is falsy or missing, a
        warning is logged and the freshly initialised model is tested instead.
        Weights are mapped onto ``self.device``, so checkpoints saved on another
        device (e.g. GPU, when testing on CPU) still load.

        The method logs:

        - per-user stats after each user's last task;
        - aggregate accuracy;
        - personalisation time and per-frame inference time;
        - personalisation MACs and parameter counts.

        Args:
            path: Checkpoint to load.
            save_evaluator: If True, call ``test_evaluator.save()`` at the end.
        """
        self.init_model()
        if path and os.path.exists(path):  # if path exists
            self.model.load_state_dict(torch.load(path, map_location=self.device))
        else:
            print_and_log(self.logfile, 'warning: saved model path could not be found; using pretrained initialisation.')
            path = self.checkpoint_dir
        self.model.set_test_mode(True)
        self.test_evaluator.set_base_params(self.model)
        # check for modules which thop will not count by default
        print_and_log(self.logfile, self.test_evaluator.check_for_uncounted_modules(self.model))
        num_context_clips_per_task, num_target_clips_per_task = [], []
        with torch.no_grad():
            # loop through test tasks (num_test_users * num_test_tasks_per_user)
            num_test_tasks = len(self.test_queue) * self.args.num_test_tasks
            for step, task_dict in enumerate(self.test_queue.get_tasks()):
                context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device)
                # num_context_clips will be the same for all tasks of the same user
                # since we're sampling frames uniformly from each video (clip_method = uniform)
                # and we're sampling from all the user's videos (shot method = max)
                num_context_clips = len(context_clips)
                self.test_evaluator.set_task_object_list(object_list)

                t1 = time.time()
                self.model.personalise(context_clips, context_labels, ops_counter=self.test_evaluator.ops_counter)
                self.test_evaluator.log_time(time.time() - t1, 'personalise')

                # loop through target videos for the current task
                num_target_clips = 0
                video_iterator = zip(target_frames_by_video, target_paths_by_video, target_labels_by_video)
                for video_frames, video_paths, video_label in video_iterator:
                    video_clips = attach_frame_history(video_frames, self.args.clip_length)
                    num_clips = len(video_clips)
                    t1 = time.time()
                    video_logits = self.model.predict(video_clips)
                    # log inference time per frame (so average over num_clips*clip_length)
                    self.test_evaluator.log_time((time.time() - t1) / float(num_clips * self.model.clip_length), 'inference')
                    self.test_evaluator.append_video(video_logits, video_label, video_paths)
                    num_target_clips += num_clips

                # reset task's params
                self.model._reset()
                # log number of clips per task
                num_context_clips_per_task.append(num_context_clips)
                num_target_clips_per_task.append(num_target_clips)
                # complete the task (required for correct ops counter numbers)
                self.test_evaluator.task_complete()

                # if this is the user's last task, get the average performance for the user over all their tasks
                if (step + 1) % self.args.num_test_tasks == 0:
                    self.test_evaluator.set_current_user(task_dict["task_id"])
                    _, _, _, current_video_stats = self.test_evaluator.get_mean_stats(current_user=True)
                    current_macs_mean, _, _, _ = self.test_evaluator.get_mean_ops_counter_stats(current_user=True)
                    print_and_log(self.logfile, f'{self.args.test_set} user {task_dict["task_id"]} ({self.test_evaluator.current_user+1}/{len(self.test_queue)}) stats: {stats_to_str(current_video_stats)}, avg MACs to personalise/task: {current_macs_mean}, avg # context clips/task: {np.mean(num_context_clips_per_task):.0f}, avg # target clips/task: {np.mean(num_target_clips_per_task):.0f}')
                    if (step + 1) < num_test_tasks:
                        num_context_clips_per_task, num_target_clips_per_task = [], []  # reset per user
                        self.test_evaluator.next_user()
                else:
                    self.test_evaluator.next_task()

            stats_per_user, stats_per_obj, stats_per_task, stats_per_video = self.test_evaluator.get_mean_stats()
            stats_per_user_str, stats_per_obj_str, stats_per_task_str, stats_per_video_str = stats_to_str(stats_per_user), stats_to_str(stats_per_obj), stats_to_str(stats_per_task), stats_to_str(stats_per_video)
            mean_macs, std_macs, mean_params, params_breakdown = self.test_evaluator.get_mean_ops_counter_stats()
            mean_personalise_time, std_personalise_time, mean_inference_time, std_inference_time = self.test_evaluator.get_mean_times()
            print_and_log(self.logfile, (f"{self.args.test_set} [{path}]\n"
                                         f"Frame accuracy (averaged per user): {stats_per_user_str}\n"
                                         f"Frame accuracy (averaged per object): {stats_per_obj_str}\n"
                                         f"Frame accuracy (averaged per task): {stats_per_task_str}\n"
                                         f"Frame accuracy (averaged per video): {stats_per_video_str}\n"
                                         f"Time to personalise (averaged per task) {mean_personalise_time} ({std_personalise_time})\n"
                                         f"Inference time per frame (averaged per task): {mean_inference_time} ({std_inference_time})\n"
                                         f"MACs to personalise (averaged per task): {mean_macs} ({std_macs})\n"
                                         f"Number of params: {mean_params} ({params_breakdown})\n"))
            if save_evaluator:
                self.test_evaluator.save()
            self.test_evaluator.reset()

    def save_checkpoint(self, epoch):
        """Write the epoch, model/optimizer state and best validation stats to ``checkpoint.pt``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_stats': self.validation_evaluator.get_current_best_stats(),
        }, os.path.join(self.checkpoint_dir, 'checkpoint.pt'))

    def load_checkpoint(self):
        """Restore the state written by ``save_checkpoint`` and set ``self.start_epoch``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on
        another device can be restored.
        """
        checkpoint = torch.load(os.path.join(self.checkpoint_dir, 'checkpoint.pt'), map_location=self.device)
        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.validation_evaluator.replace(checkpoint['best_stats'])
if __name__ == "__main__":
    # Run the learner only when executed as a script, not when imported.
    main()