"""
|
|
Copyright (c) Microsoft Corporation.
|
|
Licensed under the MIT license
|
|
|
|
This code was based on the file run_cnaps.py (https://github.com/cambridge-mlg/cnaps/blob/master/src/run_cnaps.py)
|
|
from the cambridge-mlg/cnaps library (https://github.com/cambridge-mlg/cnaps).
|
|
|
|
The original license is included below:
|
|
|
|
Copyright (c) 2019 John Bronskill, Jonathan Gordon, and James Requeima.
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in all
|
|
copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
SOFTWARE.
|
|
"""
|
|
|
|
import os
import time
import torch
import random
import numpy as np
import torch.backends.cudnn as cudnn

from data.dataloaders import DataLoader
from data.utils import unpack_task, attach_frame_history
from model.few_shot_recognisers import MultiStepFewShotRecogniser
from utils.args import parse_args
from utils.optim import cross_entropy
from utils.eval_metrics import TestEvaluator
from utils.logging import print_and_log, get_log_files, stats_to_str

torch.multiprocessing.set_sharing_strategy('file_system')
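
# Example invocation (flag names below are illustrative only; parse_args in utils/args.py
# defines the actual command-line interface):
#   python <this-script>.py --data_path <path/to/dataset> --model_path <path/to/checkpoint.pt>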

def main():
    learner = Learner()
    learner.run()

class Learner:
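    """Evaluates a MultiStepFewShotRecogniser on the test set.

    For each test task, a copy of the base model is personalised to the task's context
    clips by gradient-based finetuning and then evaluated on the task's target videos.
    """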

    def __init__(self):
        self.args = parse_args(learner='multi-step-learner')

        self.checkpoint_dir, self.logfile, _, _ \
            = get_log_files(self.args.checkpoint_dir, self.args.model_path)

        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile, "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        device_id = 'cpu'
        if torch.cuda.is_available() and self.args.gpu >= 0:
            cudnn.enabled = True
            cudnn.benchmark = False
            cudnn.deterministic = True
            device_id = 'cuda:' + str(self.args.gpu)
            torch.cuda.manual_seed_all(self.args.seed)

        self.device = torch.device(device_id)
        self.init_dataset()
        self.init_evaluators()
        self.model = self.init_model()
        self.loss = cross_entropy

        print_and_log(self.logfile, f"Model details:\n"
                      f"\tfeature extractor: {self.args.feature_extractor} (pretrained: True, learnable: {self.args.learn_extractor}, finetune film params: {self.args.adapt_features})\n"
                      f"\tclassifier: {self.args.classifier} with logit scale={self.args.logit_scale}\n")

    def init_dataset(self):
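        """Build the dataset options from the command-line arguments and queue up the test tasks."""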
        dataset_info = {
            'mode': self.args.mode,
            'data_path': self.args.data_path,
            'test_object_cap': self.args.test_object_cap,
            'test_way_method': self.args.test_way_method,
            'test_shot_methods': [self.args.test_context_shot_method, self.args.test_target_shot_method],
            'num_test_tasks': self.args.num_test_tasks,
            'test_set': self.args.test_set,
            'shots': [self.args.context_shot, self.args.target_shot],
            'video_types': [self.args.context_video_type, self.args.target_video_type],
            'clip_length': self.args.clip_length,
            'test_clip_methods': [self.args.test_context_clip_method, self.args.test_target_clip_method],
            'subsample_factor': self.args.subsample_factor,
            'frame_size': self.args.frame_size,
            'frame_norm_method': self.args.frame_norm_method,
            'annotations_to_load': self.args.annotations_to_load,
            'test_filter_by_annotations': [self.args.test_filter_context, self.args.test_filter_target],
            'logfile': self.logfile
        }

        dataloader = DataLoader(dataset_info)
        self.test_queue = dataloader.get_test_queue()

    def init_model(self):
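        """Construct a MultiStepFewShotRecogniser from the command-line options and place it on the chosen device."""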
        model = MultiStepFewShotRecogniser(
            self.args.feature_extractor, self.args.adapt_features, self.args.classifier, self.args.clip_length,
            self.args.batch_size, self.args.learn_extractor, self.args.logit_scale
        )
        model._set_device(self.device)
        model._send_to_device()

        return model

    def init_finetuner(self):
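        """Create a fresh copy of the base model, loaded with its current weights, to be finetuned on a single task."""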
        finetuner = self.init_model()
        finetuner.load_state_dict(self.model.state_dict(), strict=False)
        finetuner.set_test_mode(True)
        return finetuner

    def init_evaluators(self):
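        """Set up the test evaluator, with MAC and parameter counting enabled."""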
        self.evaluation_metrics = ['frame_acc']

        self.test_evaluator = TestEvaluator(self.evaluation_metrics, self.checkpoint_dir, with_ops_counter=True, count_backwards=True)

    def run(self):
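        """Evaluate the model checkpoint at self.args.model_path, then close the log file."""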
        self.test(self.args.model_path)
        self.logfile.close()

    def test(self, path, save_evaluator=True):
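        """Personalise and evaluate the model on every test task, logging per-user and overall statistics."""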
        self.model = self.init_model()
        if path and os.path.exists(path): # if the checkpoint path exists, load the model from disk
            self.model.load_state_dict(torch.load(path), strict=False)
        else:
            print_and_log(self.logfile, 'warning: saved model path could not be found; using original param initialisation.')
            path = self.checkpoint_dir
        self.test_evaluator.set_base_params(self.model)
        print_and_log(self.logfile, self.test_evaluator.check_for_uncounted_modules(self.model)) # report modules that thop will not count by default
        num_context_clips_per_task, num_target_clips_per_task = [], []

        # loop through test tasks (num_test_users * num_test_tasks_per_user)
        num_test_tasks = len(self.test_queue) * self.args.num_test_tasks
        for step, task_dict in enumerate(self.test_queue.get_tasks()):
            context_clips, context_paths, context_labels, target_frames_by_video, target_paths_by_video, target_labels_by_video, object_list = unpack_task(task_dict, self.device, context_to_device=False)
            # num_context_clips will be the same for all tasks of the same user,
            # since we're sampling frames uniformly from each video (clip_method = uniform)
            # and we're sampling from all the user's videos (shot method = max)
            num_context_clips = len(context_clips)
            self.test_evaluator.set_task_object_list(object_list)

            # initialise finetuner model to initial state of self.model for current task
            finetuner = self.init_finetuner()

            # adapt to current task by finetuning on context clips
            t1 = time.time()
            learning_args = {
                'num_grad_steps': self.args.personalize_num_grad_steps,
                'learning_rate': self.args.personalize_learning_rate,
                'extractor_lr_scale': self.args.personalize_extractor_lr_scale,
                'loss_fn': self.loss,
                'optimizer': self.args.personalize_optimizer,
                'momentum': self.args.personalize_momentum,
                'weight_decay': self.args.personalize_weight_decay,
                'betas': self.args.personalize_betas,
                'epsilon': self.args.personalize_epsilon
            }
            finetuner.personalise(context_clips, context_labels, learning_args, ops_counter=self.test_evaluator.ops_counter)
            self.test_evaluator.log_time(time.time() - t1, 'personalise')

            # loop through target videos for the current task
            with torch.no_grad():
                num_target_clips = 0
                video_iterator = zip(target_frames_by_video, target_paths_by_video, target_labels_by_video)
                for video_frames, video_paths, video_label in video_iterator:
                    video_clips = attach_frame_history(video_frames, self.args.clip_length)
                    num_clips = len(video_clips)
                    t1 = time.time()
                    video_logits = finetuner.predict(video_clips)
                    # log inference time per frame (i.e. average over num_clips * clip_length)
                    self.test_evaluator.log_time((time.time() - t1) / float(num_clips * self.model.clip_length), 'inference')
                    self.test_evaluator.append_video(video_logits, video_label, video_paths)
                    num_target_clips += num_clips

            # log number of clips per task
            num_context_clips_per_task.append(num_context_clips)
            num_target_clips_per_task.append(num_target_clips)

            # complete the task (required for correct ops counter numbers)
            self.test_evaluator.task_complete()

            # if this is the user's last task, get the average performance for the user over all their tasks
            if (step+1) % self.args.num_test_tasks == 0:
                self.test_evaluator.set_current_user(task_dict["task_id"])
                _, _, _, current_video_stats = self.test_evaluator.get_mean_stats(current_user=True)
                current_macs_mean, _, _, _ = self.test_evaluator.get_mean_ops_counter_stats(current_user=True)

                print_and_log(self.logfile, f'{self.args.test_set} user {task_dict["task_id"]} ({self.test_evaluator.current_user+1}/{len(self.test_queue)}) stats: {stats_to_str(current_video_stats)}, avg MACs to personalise/task: {current_macs_mean}, avg # context clips/task: {np.mean(num_context_clips_per_task):.0f}, avg # target clips/task: {np.mean(num_target_clips_per_task):.0f}')
                if (step+1) < num_test_tasks:
                    num_context_clips_per_task, num_target_clips_per_task = [], [] # reset per user
                    self.test_evaluator.next_user()
            else:
                self.test_evaluator.next_task()

        self.model._reset()

        # get average performance over all users
        stats_per_user, stats_per_obj, stats_per_task, stats_per_video = self.test_evaluator.get_mean_stats()
        stats_per_user_str, stats_per_obj_str, stats_per_task_str, stats_per_video_str = stats_to_str(stats_per_user), stats_to_str(stats_per_obj), stats_to_str(stats_per_task), stats_to_str(stats_per_video)
        mean_macs, std_macs, mean_params, params_breakdown = self.test_evaluator.get_mean_ops_counter_stats()
        mean_personalise_time, std_personalise_time, mean_inference_time, std_inference_time = self.test_evaluator.get_mean_times()
        print_and_log(self.logfile, (f"{self.args.test_set} [{path}]\n"
                                     f"Frame accuracy (averaged per user): {stats_per_user_str}\n"
                                     f"Frame accuracy (averaged per object): {stats_per_obj_str}\n"
                                     f"Frame accuracy (averaged per task): {stats_per_task_str}\n"
                                     f"Frame accuracy (averaged per video): {stats_per_video_str}\n"
                                     f"Time to personalise (averaged per task): {mean_personalise_time} ({std_personalise_time})\n"
                                     f"Inference time per frame (averaged per task): {mean_inference_time} ({std_inference_time})\n"
                                     f"MACs to personalise (averaged per task): {mean_macs} ({std_macs})\n"
                                     f"Number of params: {mean_params} ({params_breakdown})\n"))

        if save_evaluator:
            self.test_evaluator.save()
        self.test_evaluator.reset()

if __name__ == "__main__":
    main()