added option to filter train and test differently
This commit is contained in:
Родитель
a5393bc73c
Коммит
9a0c19de93
|
@ -27,7 +27,7 @@ class DataLoader():
|
|||
dataset_info['frame_size'],
|
||||
dataset_info['frame_norm_method'],
|
||||
dataset_info['annotations_to_load'],
|
||||
dataset_info['filter_by_annotations'],
|
||||
dataset_info['train_filter_by_annotations'],
|
||||
dataset_info['num_train_tasks'],
|
||||
with_cluster_labels=dataset_info['with_cluster_labels'],
|
||||
with_caps=dataset_info['with_train_shot_caps'],
|
||||
|
@ -46,7 +46,7 @@ class DataLoader():
|
|||
dataset_info['frame_size'],
|
||||
dataset_info['frame_norm_method'],
|
||||
dataset_info['annotations_to_load'],
|
||||
dataset_info['filter_by_annotations'],
|
||||
dataset_info['test_filter_by_annotations'],
|
||||
dataset_info['num_val_tasks'],
|
||||
test_mode=True,
|
||||
logfile=dataset_info['logfile'])
|
||||
|
@ -64,7 +64,7 @@ class DataLoader():
|
|||
dataset_info['frame_size'],
|
||||
dataset_info['frame_norm_method'],
|
||||
dataset_info['annotations_to_load'],
|
||||
dataset_info['filter_by_annotations'],
|
||||
dataset_info['test_filter_by_annotations'],
|
||||
dataset_info['num_test_tasks'],
|
||||
test_mode=True,
|
||||
logfile=dataset_info['logfile'])
|
||||
|
|
|
@ -63,14 +63,12 @@ class Learner:
|
|||
random.seed(self.args.seed)
|
||||
torch.manual_seed(self.args.seed)
|
||||
device_id = 'cpu'
|
||||
self.map_location = 'cpu'
|
||||
if torch.cuda.is_available() and self.args.gpu >= 0:
|
||||
cudnn.enabled = True
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
device_id = 'cuda:' + str(self.args.gpu)
|
||||
torch.cuda.manual_seed_all(self.args.seed)
|
||||
self.map_location = lambda storage, loc: storage.cuda()
|
||||
|
||||
self.device = torch.device(device_id)
|
||||
self.ops_counter = OpsCounter(count_backward=True)
|
||||
|
@ -101,7 +99,7 @@ class Learner:
|
|||
'frame_size': self.args.frame_size,
|
||||
'frame_norm_method': self.args.frame_norm_method,
|
||||
'annotations_to_load': self.args.annotations_to_load,
|
||||
'filter_by_annotations': [self.args.filter_context, self.args.filter_target],
|
||||
'test_filter_by_annotations': [self.args.test_filter_context, self.args.test_filter_target],
|
||||
'logfile': self.logfile
|
||||
}
|
||||
|
||||
|
@ -137,7 +135,7 @@ class Learner:
|
|||
|
||||
self.model = self.init_model()
|
||||
if path and os.path.exists(path): # if path exists, load from disk
|
||||
self.model.load_state_dict(torch.load(path, map_location=self.map_location), strict=False)
|
||||
self.model.load_state_dict(torch.load(path), strict=False)
|
||||
else:
|
||||
print_and_log(self.logfile, 'warning: saved model path could not be found; using original param initialisation.')
|
||||
path = self.checkpoint_dir
|
||||
|
@ -190,7 +188,7 @@ class Learner:
|
|||
if (step+1) % self.args.num_test_tasks == 0:
|
||||
self.test_evaluator.set_current_user(task_dict["task_id"])
|
||||
_,_,_,current_video_stats = self.test_evaluator.get_mean_stats(current_user=True)
|
||||
print_and_log(self.logfile, f'{self.args.test_set} user {task_dict["task_id"]} ({self.test_evaluator.current_user+1}/{len(self.test_queue)}) stats: {stats_to_str(current_video_stats)} avg. #context clips/task: {np.mean(num_context_clips_per_task):.0f} avg. #target clips/task: {np.mean(num_target_clips_per_task):.0f}')
|
||||
print_and_log(self.logfile, f'{self.args.test_set} user {task_dict["task_id"]} ({self.test_evaluator.current_user+1}/{len(self.test_queue)}) stats: {stats_to_str(current_video_stats)} avg # context clips/task: {np.mean(num_context_clips_per_task):.0f} avg # target clips/task: {np.mean(num_target_clips_per_task):.0f}')
|
||||
if (step+1) < num_test_tasks:
|
||||
num_context_clips_per_task, num_target_clips_per_task = [], []
|
||||
self.test_evaluator.next_user()
|
||||
|
|
|
@ -64,14 +64,12 @@ class Learner:
|
|||
random.seed(self.args.seed)
|
||||
torch.manual_seed(self.args.seed)
|
||||
device_id='cpu'
|
||||
self.map_location='cpu'
|
||||
if torch.cuda.is_available() and self.args.gpu>=0:
|
||||
cudnn.enabled = True
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
device_id = 'cuda:' + str(self.args.gpu)
|
||||
torch.cuda.manual_seed_all(self.args.seed)
|
||||
self.map_location=lambda storage, loc: storage.cuda()
|
||||
|
||||
self.device = torch.device(device_id)
|
||||
self.ops_counter = OpsCounter()
|
||||
|
@ -112,7 +110,8 @@ class Learner:
|
|||
'frame_size': self.args.frame_size,
|
||||
'frame_norm_method': self.args.frame_norm_method,
|
||||
'annotations_to_load': self.args.annotations_to_load,
|
||||
'filter_by_annotations': [self.args.filter_context, self.args.filter_target],
|
||||
'train_filter_by_annotations': [self.args.train_filter_context, self.args.train_filter_target],
|
||||
'test_filter_by_annotations': [self.args.test_filter_context, self.args.test_filter_target],
|
||||
'logfile': self.logfile
|
||||
}
|
||||
|
||||
|
@ -277,7 +276,7 @@ class Learner:
|
|||
if (step+1) % self.args.num_val_tasks == 0:
|
||||
self.validation_evaluator.set_current_user(task_dict["task_id"])
|
||||
_,_,_,current_video_stats = self.validation_evaluator.get_mean_stats(current_user=True)
|
||||
print_and_log(self.logfile, f'validation user {task_dict["task_id"]} ({self.validation_evaluator.current_user+1}/{len(self.validation_queue)}) stats: {stats_to_str(current_video_stats)} avg. #context clips/task: {np.mean(num_context_clips_per_task):.0f} avg. #target clips/task: {np.mean(num_target_clips_per_task):.0f}')
|
||||
print_and_log(self.logfile, f'validation user {task_dict["task_id"]} ({self.validation_evaluator.current_user+1}/{len(self.validation_queue)}) stats: {stats_to_str(current_video_stats)} avg # context clips/task: {np.mean(num_context_clips_per_task):.0f} avg # target clips/task: {np.mean(num_target_clips_per_task):.0f}')
|
||||
if (step+1) < num_val_tasks:
|
||||
num_context_clips_per_task, num_target_clips_per_task = [], []
|
||||
self.validation_evaluator.next_user()
|
||||
|
@ -302,7 +301,7 @@ class Learner:
|
|||
|
||||
self.init_model()
|
||||
if path and os.path.exists(path): #if path exists
|
||||
self.model.load_state_dict(torch.load(path, map_location=self.map_location))
|
||||
self.model.load_state_dict(torch.load(path))
|
||||
else:
|
||||
print_and_log(self.logfile, 'warning: saved model path could not be found; using pretrained initialisation.')
|
||||
path = self.checkpoint_dir
|
||||
|
@ -344,7 +343,7 @@ class Learner:
|
|||
if (step+1) % self.args.num_test_tasks == 0:
|
||||
self.test_evaluator.set_current_user(task_dict["task_id"])
|
||||
_,_,_,current_video_stats = self.test_evaluator.get_mean_stats(current_user=True)
|
||||
print_and_log(self.logfile, f'{self.args.test_set} user {task_dict["task_id"]} ({self.test_evaluator.current_user+1}/{len(self.test_queue)}) stats: {stats_to_str(current_video_stats)} avg. #context clips/task: {np.mean(num_context_clips_per_task):.0f} avg. #target clips/task: {np.mean(num_target_clips_per_task):.0f}')
|
||||
print_and_log(self.logfile, f'{self.args.test_set} user {task_dict["task_id"]} ({self.test_evaluator.current_user+1}/{len(self.test_queue)}) stats: {stats_to_str(current_video_stats)} avg # context clips/task: {np.mean(num_context_clips_per_task):.0f} avg # target clips/task: {np.mean(num_target_clips_per_task):.0f}')
|
||||
if (step+1) < num_test_tasks:
|
||||
num_context_clips_per_task, num_target_clips_per_task = [], []
|
||||
self.test_evaluator.next_user()
|
||||
|
|
|
@ -174,8 +174,10 @@ def parse_args(learner='default'):
|
|||
help="Momentum for SGD optimizer during personalization (default: 0.0).")
|
||||
|
||||
args = parser.parse_args()
|
||||
args.filter_context = expand_issues(args.filter_context)
|
||||
args.filter_target = expand_issues(args.filter_target)
|
||||
args.train_filter_context = expand_issues(args.train_filter_context)
|
||||
args.train_filter_target = expand_issues(args.train_filter_target)
|
||||
args.test_filter_context = expand_issues(args.test_filter_context)
|
||||
args.test_filter_target = expand_issues(args.test_filter_target)
|
||||
if args.feature_extractor == 'efficientnet_b0':
|
||||
args.frame_norm_method = 'imagenet'
|
||||
elif args.feature_extractor in ['efficientnet_v2_s', 'vit_s_32', 'vit_b_32']:
|
||||
|
|
Загрузка…
Ссылка в новой задаче