add download capability to FaceSyntheticsDataset (#216)

* cleanup launch.json
add download capability to FaceSyntheticsDataset
add download_and_extract_zip helper

* fix file count test

* add ability to resolve environment variables in a config file.
Chris Lovett 2023-04-12 10:48:14 -07:00, committed by GitHub
Parent 30d929c9c2
Commit 3179a8c7a2
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 103 deletions

.vscode/launch.json (vendored)
View file

@@ -262,12 +262,7 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"cwd": "D:\\git\\microsoft\\archai\\archai\\tasks\\face_segmentation",
"args":[
"--dataset_dir", "C:\\datasets\\FaceSynthetics",
"--output_dir", "d:\\temp\\face_segmentation",
"--search_config", "confs\\snp_search.yaml",
"--serial_training"
]
}
]

View file

@@ -15,11 +15,12 @@ import yaml
from archai.common import yaml_utils
# global config instance
_config:Optional['Config'] = None
_config: Optional['Config'] = None
# TODO: remove this duplicate code which is also in utils.py without circular deps
def deep_update(d:MutableMapping, u:Mapping, create_map:Callable[[],MutableMapping])\
->MutableMapping:
def deep_update(d: MutableMapping, u: Mapping, create_map: Callable[[], MutableMapping])\
-> MutableMapping:
for k, v in u.items():
if isinstance(v, Mapping):
d[k] = deep_update(d.get(k, create_map()), v, create_map)
@@ -27,24 +28,25 @@ def deep_update(d:MutableMapping, u:Mapping, create_map:Callable[[],MutableMappi
d[k] = v
return d
class Config(UserDict):
def __init__(self, config_filepath:Optional[str]=None,
app_desc:Optional[str]=None, use_args=False,
param_args: Sequence = [], resolve_redirects=True) -> None:
def __init__(self, config_filepath: Optional[str] = None,
app_desc: Optional[str] = None, use_args=False,
param_args: Sequence = [], resolve_redirects=True,
resolve_env_vars=False) -> None:
"""Create config from specified files and args
Config is simply a dictionary of key, value map. The value can itself be
a dictionary so config can be hierarchical. This class allows to load
config from yaml. A special key '__include__' can specify another yaml
relative file path (or list of file paths) which will be loaded first
and the key-value pairs in the main file
will override the ones in include file. You can think of included file as
defaults provider. This allows to create one base config and then several
environment/experiment specific configs. On the top of that you can use
Config is simply a dictionary of key, value map. The value can itself be a dictionary so config can be
hierarchical. This class allows to load config from yaml. A special key '__include__' can specify another yaml
relative file path (or list of file paths) which will be loaded first and the key-value pairs in the main file
will override the ones in include file. You can think of included file as defaults provider. This allows to
create one base config and then several environment/experiment specific configs. On the top of that you can use
param_args to perform final overrides for a given run.
You can also have values that reference environment variables using ${ENV_VAR_NAME} syntax.
Keyword Arguments:
config_filepath {[str]} -- [Yaml file to load config from, could be names of files separated by semicolon which will be loaded in sequence oveeriding previous config] (default: {None})
config_filepath {[str]} -- [Yaml file to load config from, could be names of files separated by semicolon which will be loaded in sequence overriding previous config] (default: {None})
app_desc {[str]} -- [app description that will show up in --help] (default: {None})
use_args {bool} -- [if true then command line parameters will override parameters from config files] (default: {False})
param_args {Sequence} -- [parameters specified as ['--key1',val1,'--key2',val2,...] which will override parameters from config file.] (default: {[]})
@@ -58,7 +60,7 @@ class Config(UserDict):
# let command line args specify/override config file
parser = argparse.ArgumentParser(description=app_desc)
parser.add_argument('--config', type=str, default=None,
help='config filepath in yaml format, can be list separated by ;')
help='config filepath in yaml format, can be list separated by ;')
self.args, self.extra_args = parser.parse_known_args()
config_filepath = self.args.config or config_filepath
@@ -74,15 +76,18 @@ class Config(UserDict):
yaml_utils.resolve_all(resolved_conf)
# Let's do final overrides from args
self._update_from_args(param_args, resolved_conf) # merge from params
self._update_from_args(self.extra_args, resolved_conf) # merge from command line
self._update_from_args(param_args, resolved_conf) # merge from params
self._update_from_args(self.extra_args, resolved_conf) # merge from command line
if resolve_env_vars:
self._process_envvars(resolved_conf)
if resolve_redirects:
yaml_utils.resolve_all(self)
self.config_filepath = config_filepath
def _load_from_file(self, filepath: Optional[str])->None:
def _load_from_file(self, filepath: Optional[str]) -> None:
if filepath:
filepath = os.path.expanduser(os.path.expandvars(filepath))
filepath = os.path.abspath(filepath)
@@ -92,7 +97,7 @@ class Config(UserDict):
deep_update(self, config_yaml, lambda: Config(resolve_redirects=False))
print('config loaded from: ', filepath)
def _process_includes(self, config_yaml, filepath:str):
def _process_includes(self, config_yaml, filepath: str):
if '__include__' in config_yaml:
# include could be file name or array of file names to apply in sequence
includes = config_yaml['__include__']
@@ -103,58 +108,66 @@ class Config(UserDict):
include_filepath = os.path.join(os.path.dirname(filepath), include)
self._load_from_file(include_filepath)
def _update_from_args(self, args:Sequence, resolved_section:'Config')->None:
def _process_envvars(self, config_yaml):
for key in config_yaml:
value = config_yaml[key]
if isinstance(value, Config):
self._process_envvars(value)
elif isinstance(value, str) and '$' in value:
config_yaml[key] = os.path.expandvars(value)
def _update_from_args(self, args: Sequence, resolved_section: 'Config') -> None:
i = 0
while i < len(args)-1:
arg = args[i]
if arg.startswith(("--")):
path = arg[len("--"):].split('.')
i += Config._update_section(self, path, args[i+1], resolved_section)
else: # some other arg
else: # some other arg
i += 1
def to_dict(self)->dict:
return deep_update({}, self, lambda: dict()) # type: ignore
def to_dict(self) -> dict:
return deep_update({}, self, lambda: dict()) # type: ignore
@staticmethod
def _update_section(section:'Config', path:List[str], val:Any, resolved_section:'Config')->int:
def _update_section(section: 'Config', path: List[str], val: Any, resolved_section: 'Config') -> int:
for p in range(len(path)-1):
sub_path = path[p]
if sub_path in resolved_section:
resolved_section = resolved_section[sub_path]
if not sub_path in section:
if sub_path not in section:
section[sub_path] = Config(resolve_redirects=False)
section = section[sub_path]
else:
return 1 # path not found, ignore this
key = path[-1] # final leaf node value
return 1 # path not found, ignore this
key = path[-1] # final leaf node value
if key in resolved_section:
original_val, original_type = None, None
try:
original_val = resolved_section[key]
original_type = type(original_val)
if original_type == bool: # bool('False') is True :(
original_type = lambda x: strtobool(x)==1
if original_type == bool: # bool('False') is True :(
original_type = lambda x: strtobool(x) == 1
section[key] = original_type(val)
except Exception as e:
raise KeyError(
f'The yaml key or command line argument "{key}" is likely not named correctly or value is of wrong data type. Error was occured when setting it to value "{val}".'
f'The yaml key or command line argument "{key}" is likely not named correctly or the value is of the wrong data type. The error occurred when setting it to value "{val}".'
f'Originally it is set to {original_val} which is of type {original_type}.'
f'Original exception: {e}')
return 2 # path was found, increment arg pointer by 2 as we use up val
return 2 # path was found, increment arg pointer by 2 as we use up val
else:
return 1 # path not found, ignore this
return 1 # path not found, ignore this
def get_val(self, key, default_val):
return super().get(key, default_val)
@staticmethod
def set_inst(instance:'Config')->None:
def set_inst(instance: 'Config') -> None:
global _config
_config = instance
@staticmethod
def get_inst()->'Config':
def get_inst() -> 'Config':
global _config
return _config
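
For context, a minimal usage sketch of the include, dotted-path override, and new resolve_env_vars mechanics described in the docstring above; the file names, keys, and values here are illustrative only, not part of this commit:

# base.yaml:
#   nas:
#     search:
#       epochs: 10
#       data_dir: '${DATA_ROOT}/face_synthetics'
#
# exp.yaml:
#   __include__: 'base.yaml'   # base.yaml loads first; exp.yaml overrides it
#   nas:
#     search:
#       epochs: 20

import os
os.environ['DATA_ROOT'] = '/datasets'

conf = Config(config_filepath='exp.yaml',
              param_args=['--nas.search.epochs', '5'],  # final dotted-path override
              resolve_env_vars=True)                    # new in this commit

print(conf['nas']['search']['epochs'])    # expected: 5 (param_args wins over both files)
print(conf['nas']['search']['data_dir'])  # expected: /datasets/face_synthetics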

View file

@@ -54,13 +54,14 @@ class AverageMeter:
self.cnt += n
self.avg = self.sum / self.cnt
def first_or_default(it:Iterable, default=None):
def first_or_default(it: Iterable, default=None):
for i in it:
return i
return default
def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)\
->MutableMapping:
def deep_update(d: MutableMapping, u: Mapping, map_type: Type[MutableMapping] = dict) -> MutableMapping:
for k, v in u.items():
if isinstance(v, Mapping):
d[k] = deep_update(d.get(k, map_type()), v, map_type)
@@ -68,7 +69,8 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)
d[k] = v
return d
def state_dict(val)->Mapping:
def state_dict(val) -> Mapping:
assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict'
# Can't do below because val has state_dict() which calls utils.state_dict
@@ -79,7 +81,8 @@ def state_dict(val)->Mapping:
return {'yaml': yaml.dump(val)}
def load_state_dict(val:Any, state_dict:Mapping)->None:
def load_state_dict(val: Any, state_dict: Mapping) -> None:
assert hasattr(val, '__dict__'), 'val must be object with __dict__'
# Can't do below because val has state_dict() which calls utils.state_dict
@@ -93,7 +96,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None:
for k, v in obj.__dict__.items():
setattr(val, k, v)
def deep_comp(o1:Any, o2:Any)->bool:
def deep_comp(o1: Any, o2: Any) -> bool:
# NOTE: dict don't have __dict__
o1d = getattr(o1, '__dict__', None)
o2d = getattr(o2, '__dict__', None)
@@ -111,21 +115,25 @@ def deep_comp(o1:Any, o2:Any)->bool:
if not deep_comp(o1[k], o2[k]):
return False
else:
return False # some key missing
return False # some key missing
return True
# mismatched object types or both are scalars, or one or both None
return o1 == o2
# We set the 'vs_code_debugging' env variable when debugging mode is detected.
# The reason for this is that when Python multiprocessing is used, the newly
# spawned processes do not inherit 'pydevd', so they are not detected as being
# in debugging mode even though they are. An env var, however, is inherited by sub processes.
if 'pydevd' in sys.modules:
os.environ['vs_code_debugging'] = 'True'
def is_debugging()->bool:
return 'vs_code_debugging' in os.environ and os.environ['vs_code_debugging']=='True'
def full_path(path:str, create=False)->str:
def is_debugging() -> bool:
return 'vs_code_debugging' in os.environ and os.environ['vs_code_debugging'] == 'True'
def full_path(path: str, create=False) -> str:
assert path
path = os.path.abspath(
os.path.expanduser(
@@ -134,21 +142,27 @@ def full_path(path:str, create=False)->str:
os.makedirs(path, exist_ok=True)
return path
def zero_file(filepath)->None:
def zero_file(filepath) -> None:
"""Creates or truncates existing file"""
open(filepath, 'w').close()
def write_string(filepath:str, content:str)->None:
def write_string(filepath: str, content: str) -> None:
pathlib.Path(filepath).write_text(content)
def read_string(filepath:str)->str:
def read_string(filepath: str) -> str:
return pathlib.Path(filepath).read_text()
def fmt(val:Any)->str:
def fmt(val: Any) -> str:
if isinstance(val, float):
return f'{val:.4g}'
return str(val)
def append_csv_file(filepath:str, new_row:List[Tuple[str, Any]], delimiter='\t'):
def append_csv_file(filepath: str, new_row: List[Tuple[str, Any]], delimiter='\t'):
fieldnames, rows = [], []
if os.path.exists(filepath):
with open(filepath, 'r') as f:
@@ -160,19 +174,21 @@ def append_csv_file(filepath:str, new_row:List[Tuple[str, Any]], delimiter='\t')
new_fieldnames = OrderedDict([(fn, None) for fn, v in new_row])
for fn in fieldnames:
new_fieldnames[fn]=None
new_fieldnames[fn] = None
with open(filepath, 'w', newline='') as f:
dr = csv.DictWriter(f, fieldnames=new_fieldnames.keys(), delimiter=delimiter)
dr.writeheader()
for row in rows:
d = dict((k,v) for k,v in zip(fieldnames, row))
d = dict((k, v) for k, v in zip(fieldnames, row))
dr.writerow(d)
dr.writerow(OrderedDict(new_row))
def has_method(o, name):
return callable(getattr(o, name, None))
def extract_tar(src, dest=None, gzip=None, delete=False):
import tarfile
@@ -188,6 +204,20 @@ def extract_tar(src, dest=None, gzip=None, delete=False):
if delete:
os.remove(src)
def extract_zip(src, dest=None, delete=False):
import zipfile
if dest is None:
dest = os.path.dirname(src)
with zipfile.ZipFile(src, 'r') as zip_ref:
zip_ref.extractall(dest)
if delete:
os.remove(src)
def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
md5=None, **kwargs):
download_root = os.path.expanduser(download_root)
@@ -201,7 +231,22 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
def setup_cuda(seed:Union[float, int], local_rank:int=0):
def download_and_extract_zip(url, download_root, extract_root=None, filename=None,
md5=None, **kwargs):
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if filename is None:
filename = os.path.basename(url)
if not tvutils.check_integrity(os.path.join(download_root, filename), md5):
tvutils.download_url(url, download_root, filename=filename, md5=md5)
extract_zip(os.path.join(download_root, filename), extract_root, delete=True, **kwargs)
def setup_cuda(seed: Union[float, int], local_rank: int = 0):
seed = int(seed) + local_rank
# setup cuda
cudnn.enabled = True
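
For reference, a minimal usage sketch of the new download_and_extract_zip helper; the URL and directory below are placeholders. torchvision's check_integrity skips the download when the file already exists (and matches md5 when one is given), and the archive is deleted after extraction.

from archai.common.utils import download_and_extract_zip

# Download example.zip into ~/datasets unless it is already present and passes
# the integrity check, then extract it in place and delete the archive.
download_and_extract_zip(
    'https://example.com/data/example.zip',  # placeholder URL
    download_root='~/datasets',
    md5=None)  # optionally pass an md5 hex digest to enable integrity caching

Since the zip is deleted after extraction, callers need their own re-download guard; FaceSyntheticsDataset below re-checks the extracted file count rather than the archive.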
@@ -210,17 +255,19 @@ def setup_cuda(seed:Union[float, int], local_rank:int=0):
np.random.seed(seed)
random.seed(seed)
#torch.cuda.manual_seed_all(seed)
cudnn.benchmark = True # set to false if deterministic
# torch.cuda.manual_seed_all(seed)
cudnn.benchmark = True # set to false if deterministic
torch.set_printoptions(precision=10)
#cudnn.deterministic = False
# cudnn.deterministic = False
# torch.cuda.empty_cache()
# torch.cuda.synchronize()
def cuda_device_names()->str:
def cuda_device_names() -> str:
return ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
def exec_shell_command(command:str, print_command_start=True, print_command_end=True)->subprocess.CompletedProcess:
def exec_shell_command(command: str, print_command_start=True, print_command_end=True) -> subprocess.CompletedProcess:
if print_command_start:
print(f'[{datetime.now()}] Running: {command}')
@@ -231,49 +278,59 @@ def exec_shell_command(command:str, print_command_start=True, print_command_end=
return ret
def zip_eq(*iterables):
sentinel = object()
for count, combo in enumerate(zip_longest(*iterables, fillvalue=sentinel)):
if any(True for c in combo if sentinel is c):
shorter_its = ','.join([str(i) for i,c in enumerate(combo) if sentinel is c])
shorter_its = ','.join([str(i) for i, c in enumerate(combo) if sentinel is c])
raise ValueError(f'Iterator {shorter_its} have length {count} which is shorter than others')
yield combo
def dir_downloads()->str:
def dir_downloads() -> str:
return full_path(str(os.path.join(pathlib.Path.home(), "Downloads")))
def filepath_without_ext(filepath:str)->str:
def filepath_without_ext(filepath: str) -> str:
"""Returns '/a/b/c/d.e' for '/a/b/c/d.e.f' """
return str(pathlib.Path(filepath).with_suffix(''))
def filepath_ext(filepath:str)->str:
def filepath_ext(filepath: str) -> str:
"""Returns '.f' for '/a/b/c/d.e.f' """
return pathlib.Path(filepath).suffix
def filepath_name_ext(filepath:str)->str:
def filepath_name_ext(filepath: str) -> str:
"""Returns 'd.e.f' for '/a/b/c/d.e.f' """
return pathlib.Path(filepath).name
def filepath_name_only(filepath:str)->str:
def filepath_name_only(filepath: str) -> str:
"""Returns 'd.e' for '/a/b/c/d.e.f' """
return pathlib.Path(filepath).stem
def change_filepath_ext(filepath:str, new_ext:str)->str:
def change_filepath_ext(filepath: str, new_ext: str) -> str:
"""Returns '/a/b/c/d.e.g' for filepath='/a/b/c/d.e.f', new_ext='.g' """
return str(pathlib.Path(filepath).with_suffix(new_ext))
def change_filepath_name(filepath:str, new_name:str, new_ext:Optional[str]=None)->str:
def change_filepath_name(filepath: str, new_name: str, new_ext: Optional[str] = None) -> str:
"""Returns '/a/b/c/h.f' for filepath='/a/b/c/d.e.f', new_name='h' """
ext = new_ext or filepath_ext(filepath)
return str(pathlib.Path(filepath).with_name(new_name).with_suffix(ext))
def append_to_filename(filepath:str, name_suffix:str, new_ext:Optional[str]=None)->str:
def append_to_filename(filepath: str, name_suffix: str, new_ext: Optional[str] = None) -> str:
"""Returns '/a/b/c/h.f' for filepath='/a/b/c/d.e.f', new_name='h' """
ext = new_ext or filepath_ext(filepath)
name = filepath_name_only(filepath)
return str(pathlib.Path(filepath).with_name(name+name_suffix).with_suffix(ext))
def copy_file(src_file:str, dest_dir_or_file:str, preserve_metadata=False, use_shutil:bool=True)->str:
def copy_file(src_file: str, dest_dir_or_file: str, preserve_metadata=False, use_shutil: bool = True) -> str:
if not use_shutil:
assert not preserve_metadata
return copy_file_basic(src_file, dest_dir_or_file)
@@ -284,11 +341,12 @@ def copy_file(src_file:str, dest_dir_or_file:str, preserve_metadata=False, use_s
copy_fn = shutil.copy2 if preserve_metadata else shutil.copy
return copy_fn(src_file, dest_dir_or_file)
except OSError as ex:
if preserve_metadata or ex.errno != 38: # OSError: [Errno 38] Function not implemented
if preserve_metadata or ex.errno != 38: # OSError: [Errno 38] Function not implemented
raise
return copy_file_basic(src_file, dest_dir_or_file)
def copy_file_basic(src_file:str, dest_dir_or_file:str)->str:
def copy_file_basic(src_file: str, dest_dir_or_file: str) -> str:
# try basic python functions
# first if dest is dir, get dest file name
if os.path.isdir(dest_dir_or_file):
@@ -297,7 +355,8 @@ def copy_file_basic(src_file:str, dest_dir_or_file:str)->str:
dst.write(src.read())
return dest_dir_or_file
def copy_dir(src_dir:str, dest_dir:str, use_shutil:bool=True)->None:
def copy_dir(src_dir: str, dest_dir: str, use_shutil: bool = True) -> None:
if os.path.isdir(src_dir):
if use_shutil:
shutil.copytree(src_dir, dest_dir)
@@ -307,24 +366,33 @@ def copy_dir(src_dir:str, dest_dir:str, use_shutil:bool=True)->None:
files = os.listdir(src_dir)
for f in files:
copy_dir(os.path.join(src_dir, f),
os.path.join(dest_dir, f), use_shutil=use_shutil)
os.path.join(dest_dir, f), use_shutil=use_shutil)
else:
copy_file(src_dir, dest_dir, use_shutil=use_shutil)
if 'main_process_pid' not in os.environ:
os.environ['main_process_pid'] = str(os.getpid())
def is_main_process()->bool:
def is_main_process() -> bool:
"""Returns True if this process was started as main process instead of child process during multiprocessing"""
return multiprocessing.current_process().name == 'MainProcess' and os.environ['main_process_pid'] == str(os.getpid())
def main_process_pid()->int:
def main_process_pid() -> int:
return int(os.environ['main_process_pid'])
def process_name()->str:
def process_name() -> str:
return multiprocessing.current_process().name
def is_windows()->bool:
return platform.system()=='Windows'
def path2uri(path:str, windows_non_standard:bool=False)->str:
def is_windows() -> bool:
return platform.system() == 'Windows'
def path2uri(path: str, windows_non_standard: bool = False) -> str:
uri = pathlib.Path(full_path(path)).as_uri()
# there is a lot of buggy regex-based code out there which expects Windows file URIs as
@@ -334,7 +402,8 @@ def path2uri(path:str, windows_non_standard:bool=False)->str:
uri = uri.replace('file:///', 'file://')
return uri
def uri2path(file_uri:str, windows_non_standard:bool=False)->str:
def uri2path(file_uri: str, windows_non_standard: bool = False) -> str:
# there is a lot of buggy regex-based code out there which expects Windows file URIs as
# file://C/... instead of standard file:///C/...
# When passing file uri to such code, turn on windows_non_standard
@@ -347,28 +416,33 @@ def uri2path(file_uri:str, windows_non_standard:bool=False)->str:
os.path.join(host, url2pathname(unquote(parsed.path)))
)
def get_ranks(items:list, key=lambda v:v, reverse=False)->List[int]:
def get_ranks(items: list, key=lambda v: v, reverse=False) -> List[int]:
sorted_t = sorted(zip(items, range(len(items))),
key=lambda t: key(t[0]),
reverse=reverse)
sorted_map = dict((t[1], i) for i, t in enumerate(sorted_t))
return [sorted_map[i] for i in range(len(items))]
def dedup_list(l:List)->List:
def dedup_list(l: List) -> List:
return list(OrderedDict.fromkeys(l))
def delete_file(filepath:str)->bool:
def delete_file(filepath: str) -> bool:
if os.path.isfile(filepath):
os.remove(filepath)
return True
else:
return False
def save_as_yaml(obj, filepath:str)->None:
def save_as_yaml(obj, filepath: str) -> None:
with open(filepath, 'w', encoding='utf-8') as f:
yaml.dump(obj, f, default_flow_style=False)
def map_to_list(variable:Union[int,float,Sized], size:int)->Sized:
def map_to_list(variable: Union[int, float, Sized], size: int) -> Sized:
if isinstance(variable, Sized):
size_diff = size - len(variable)
@@ -381,7 +455,8 @@ def map_to_list(variable:Union[int,float,Sized], size:int)->Sized:
return [variable] * size
def attr_to_dict(obj:Any, recursive:bool=True)->Dict[str, Any]:
def attr_to_dict(obj: Any, recursive: bool = True) -> Dict[str, Any]:
MAX_LIST_LEN = 10
variables = {}

View file

@@ -7,6 +7,7 @@ import torchvision.transforms.functional as F
from torchvision.io import read_image
from archai.api.dataset_provider import DatasetProvider
from archai.common.utils import download_and_extract_zip
class FaceSyntheticsDataset(torch.utils.data.Dataset):
@@ -25,18 +26,19 @@ class FaceSyntheticsDataset(torch.utils.data.Dataset):
img_size (Tuple[int, int]): Image size (width, height). Defaults to (256, 256).
subset (str, optional): Subset ['train', 'test', 'validation']. Defaults to 'train'.
val_size (int, optional): Validation set size. Defaults to 2000.
mask_size (Optional[Tuple[int, int]], optional): Segmentation mask size (width, height). If `None`,
mask_size (Optional[Tuple[int, int]], optional): Segmentation mask size (width, height). If `None`,
`img_size` is used. Defaults to None.
augmentation (Optional[Callable], optional): Augmentation function. Expects a callable object
with named arguments 'image' and 'mask' that returns a dictionary with 'image' and 'mask' as
with named arguments 'image' and 'mask' that returns a dictionary with 'image' and 'mask' as
keys. Defaults to None.
"""
dataset_dir = Path(dataset_dir)
assert dataset_dir.is_dir()
assert isinstance(img_size, tuple)
zip_url = "https://facesyntheticspubwedata.blob.core.windows.net/iccv-2021/dataset_100000.zip"
self.img_size = img_size
self.dataset_dir = dataset_dir
self.subset = subset
@@ -44,6 +46,10 @@ class FaceSyntheticsDataset(torch.utils.data.Dataset):
self.augmentation = augmentation
all_seg_files = [str(f) for f in sorted(self.dataset_dir.glob('*_seg.png'))]
if len(all_seg_files) < 100000:
download_and_extract_zip(zip_url, self.dataset_dir)
all_seg_files = [str(f) for f in sorted(self.dataset_dir.glob('*_seg.png'))]
train_subset, test_subset = all_seg_files[:90_000], all_seg_files[90_000:]
if subset == 'train':
@@ -55,22 +61,22 @@ class FaceSyntheticsDataset(torch.utils.data.Dataset):
self.img_files = [s.replace("_seg.png",".png") for s in self.seg_files]
self.ignore_index = ignore_index
def __len__(self):
return len(self.img_files)
def __getitem__(self, idx):
sample = {
'image': read_image(self.img_files[idx]),
'mask': read_image(self.seg_files[idx]).long()
}
if self.augmentation and self.subset == 'train':
sample = self.augmentation(**sample)
sample['image'] = sample['image']/255
mask_size = self.mask_size if self.mask_size else self.img_size
mask_size = self.mask_size if self.mask_size else self.img_size
sample['mask'] = F.resize(
sample['mask'], mask_size[::-1],
interpolation=F.InterpolationMode.NEAREST
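
The augmentation hook expects an albumentations-style callable: named 'image' and 'mask' keyword arguments in, a dict with 'image' and 'mask' keys out. A minimal hand-rolled sketch (not part of this commit; the dataset path is a placeholder):

import torch

def horizontal_flip(image: torch.Tensor, mask: torch.Tensor) -> dict:
    # Flip the image and its segmentation mask together half the time,
    # preserving pixel alignment between the two.
    if torch.rand(1).item() < 0.5:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    return {'image': image, 'mask': mask}

dataset = FaceSyntheticsDataset('/data/face_synthetics',
                                subset='train', augmentation=horizontal_flip)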
@@ -84,17 +90,17 @@ class FaceSyntheticsDatasetProvider(DatasetProvider):
def __init__(self, dataset_dir: str):
self.dataset_dir = Path(dataset_dir)
assert self.dataset_dir.is_dir()
@overrides
def get_train_dataset(self, **kwargs) -> torch.utils.data.Dataset:
return FaceSyntheticsDataset(
self.dataset_dir, subset='train', **kwargs
)
@overrides
def get_test_dataset(self, **kwargs) -> torch.utils.data.Dataset:
return FaceSyntheticsDataset(
self.dataset_dir, subset='train', **kwargs
self.dataset_dir, subset='test', **kwargs
)
@overrides