Mirror of https://github.com/microsoft/archai.git
add download capability to FaceSyntheticsDataset (#216)
* cleanup launch.json, add download capability to FaceSyntheticsDataset, add download_and_extract_zip helper
* fix file count test
* add ability to resolve environment variables in a config file
Parent: 30d929c9c2
Commit: 3179a8c7a2
@@ -262,12 +262,7 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"cwd": "D:\\git\\microsoft\\archai\\archai\\tasks\\face_segmentation",
"args":[
"--dataset_dir", "C:\\datasets\\FaceSynthetics",
"--output_dir", "d:\\temp\\face_segmentation",
"--search_config", "confs\\snp_search.yaml",
"--serial_training"
]
}
]
@@ -15,11 +15,12 @@ import yaml
from archai.common import yaml_utils

# global config instance
_config:Optional['Config'] = None
_config: Optional['Config'] = None


# TODO: remove this duplicate code which is also in utils.py without circular deps
def deep_update(d:MutableMapping, u:Mapping, create_map:Callable[[],MutableMapping])\
->MutableMapping:
def deep_update(d: MutableMapping, u: Mapping, create_map: Callable[[], MutableMapping])\
-> MutableMapping:
for k, v in u.items():
if isinstance(v, Mapping):
d[k] = deep_update(d.get(k, create_map()), v, create_map)

@@ -27,24 +28,25 @@ def deep_update(d:MutableMapping, u:Mapping, create_map:Callable[[],MutableMappi
d[k] = v
return d


class Config(UserDict):
def __init__(self, config_filepath:Optional[str]=None,
app_desc:Optional[str]=None, use_args=False,
param_args: Sequence = [], resolve_redirects=True) -> None:
def __init__(self, config_filepath: Optional[str] = None,
app_desc: Optional[str] = None, use_args=False,
param_args: Sequence = [], resolve_redirects=True,
resolve_env_vars=False) -> None:
"""Create config from specified files and args

Config is simply a dictionary of key, value map. The value can itself be
a dictionary so config can be hierarchical. This class allows to load
config from yaml. A special key '__include__' can specify another yaml
relative file path (or list of file paths) which will be loaded first
and the key-value pairs in the main file
will override the ones in include file. You can think of included file as
defaults provider. This allows to create one base config and then several
environment/experiment specific configs. On the top of that you can use
Config is simply a dictionary of key, value map. The value can itself be a dictionary so config can be
hierarchical. This class allows to load config from yaml. A special key '__include__' can specify another yaml
relative file path (or list of file paths) which will be loaded first and the key-value pairs in the main file
will override the ones in include file. You can think of included file as defaults provider. This allows to
create one base config and then several environment/experiment specific configs. On the top of that you can use
param_args to perform final overrides for a given run.

You can also have values that reference environment variables using ${ENV_VAR_NAME} syntax.

Keyword Arguments:
config_filepath {[str]} -- [Yaml file to load config from, could be names of files separated by semicolon which will be loaded in sequence oveeriding previous config] (default: {None})
config_filepath {[str]} -- [Yaml file to load config from, could be names of files separated by semicolon which will be loaded in sequence overriding previous config] (default: {None})
app_desc {[str]} -- [app description that will show up in --help] (default: {None})
use_args {bool} -- [if true then command line parameters will override parameters from config files] (default: {False})
param_args {Sequence} -- [parameters specified as ['--key1',val1,'--key2',val2,...] which will override parameters from config file.] (default: {[]})
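For context, here is a minimal usage sketch of the behaviour described in the docstring above; it is not part of this commit. The file names and the MY_DATA_ROOT variable are hypothetical, the module path archai.common.config is assumed, and exact runtime behaviour depends on parts of config.py not shown in this diff.

import os
from archai.common.config import Config  # module path assumed

# base.yaml acts as the defaults provider; exp.yaml includes it and overrides one key.
with open('base.yaml', 'w') as f:
    f.write("dataset:\n  name: 'cifar10'\n  root: '${MY_DATA_ROOT}/cifar10'\n")
with open('exp.yaml', 'w') as f:
    f.write("__include__: 'base.yaml'\ndataset:\n  name: 'imagenet'\n")

os.environ['MY_DATA_ROOT'] = '/data'
conf = Config(config_filepath='exp.yaml', resolve_env_vars=True)
print(conf['dataset']['name'])  # 'imagenet' -- the main file overrides the included defaults
# With resolve_env_vars=True (new in this commit), the _process_envvars pass runs
# os.path.expandvars over string values containing '$', so '${MY_DATA_ROOT}/cifar10'
# can resolve to '/data/cifar10'.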
@@ -58,7 +60,7 @@ class Config(UserDict):
# let command line args specify/override config file
parser = argparse.ArgumentParser(description=app_desc)
parser.add_argument('--config', type=str, default=None,
help='config filepath in yaml format, can be list separated by ;')
help='config filepath in yaml format, can be list separated by ;')
self.args, self.extra_args = parser.parse_known_args()
config_filepath = self.args.config or config_filepath

@@ -74,15 +76,18 @@ class Config(UserDict):
yaml_utils.resolve_all(resolved_conf)

# Let's do final overrides from args
self._update_from_args(param_args, resolved_conf) # merge from params
self._update_from_args(self.extra_args, resolved_conf) # merge from command line
self._update_from_args(param_args, resolved_conf) # merge from params
self._update_from_args(self.extra_args, resolved_conf) # merge from command line

if resolve_env_vars:
self._process_envvars(resolved_conf)

if resolve_redirects:
yaml_utils.resolve_all(self)

self.config_filepath = config_filepath

def _load_from_file(self, filepath: Optional[str])->None:
def _load_from_file(self, filepath: Optional[str]) -> None:
if filepath:
filepath = os.path.expanduser(os.path.expandvars(filepath))
filepath = os.path.abspath(filepath)

@@ -92,7 +97,7 @@ class Config(UserDict):
deep_update(self, config_yaml, lambda: Config(resolve_redirects=False))
print('config loaded from: ', filepath)

def _process_includes(self, config_yaml, filepath:str):
def _process_includes(self, config_yaml, filepath: str):
if '__include__' in config_yaml:
# include could be file name or array of file names to apply in sequence
includes = config_yaml['__include__']
@@ -103,58 +108,66 @@ class Config(UserDict):
include_filepath = os.path.join(os.path.dirname(filepath), include)
self._load_from_file(include_filepath)

def _update_from_args(self, args:Sequence, resolved_section:'Config')->None:
def _process_envvars(self, config_yaml):
for key in config_yaml:
value = config_yaml[key]
if isinstance(value, Config):
self._process_envvars(value)
elif isinstance(value, str) and '$' in value:
config_yaml[key] = os.path.expandvars(value)

def _update_from_args(self, args: Sequence, resolved_section: 'Config') -> None:
i = 0
while i < len(args)-1:
arg = args[i]
if arg.startswith(("--")):
path = arg[len("--"):].split('.')
i += Config._update_section(self, path, args[i+1], resolved_section)
else: # some other arg
else: # some other arg
i += 1

def to_dict(self)->dict:
return deep_update({}, self, lambda: dict()) # type: ignore
def to_dict(self) -> dict:
return deep_update({}, self, lambda: dict()) # type: ignore

@staticmethod
def _update_section(section:'Config', path:List[str], val:Any, resolved_section:'Config')->int:
def _update_section(section: 'Config', path: List[str], val: Any, resolved_section: 'Config') -> int:
for p in range(len(path)-1):
sub_path = path[p]
if sub_path in resolved_section:
resolved_section = resolved_section[sub_path]
if not sub_path in section:
if sub_path not in section:
section[sub_path] = Config(resolve_redirects=False)
section = section[sub_path]
else:
return 1 # path not found, ignore this
key = path[-1] # final leaf node value
return 1 # path not found, ignore this
key = path[-1] # final leaf node value

if key in resolved_section:
original_val, original_type = None, None
try:
original_val = resolved_section[key]
original_type = type(original_val)
if original_type == bool: # bool('False') is True :(
original_type = lambda x: strtobool(x)==1
if original_type == bool: # bool('False') is True :(
original_type = lambda x: strtobool(x) == 1
section[key] = original_type(val)
except Exception as e:
raise KeyError(
f'The yaml key or command line argument "{key}" is likely not named correctly or value is of wrong data type. Error was occured when setting it to value "{val}".'
f'The yaml key or command line argument "{key}" is likely not named correctly or value is of wrong data type. Error was occurred when setting it to value "{val}".'
f'Originally it is set to {original_val} which is of type {original_type}.'
f'Original exception: {e}')
return 2 # path was found, increment arg pointer by 2 as we use up val
return 2 # path was found, increment arg pointer by 2 as we use up val
else:
return 1 # path not found, ignore this
return 1 # path not found, ignore this

def get_val(self, key, default_val):
return super().get(key, default_val)

@staticmethod
def set_inst(instance:'Config')->None:
def set_inst(instance: 'Config') -> None:
global _config
_config = instance

@staticmethod
def get_inst()->'Config':
def get_inst() -> 'Config':
global _config
return _config
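Another illustrative sketch, not part of the diff: the param_args override path implemented by _update_from_args and _update_section above. The file name is hypothetical and the module path is assumed.

from archai.common.config import Config  # module path assumed

with open('conf.yaml', 'w') as f:
    f.write("optimizer:\n  lr: 0.001\n")

# '--optimizer.lr' is split on '.' and walked through the nested sections; the string '0.01'
# is converted with the type of the existing value before being stored.
conf = Config(config_filepath='conf.yaml', param_args=['--optimizer.lr', '0.01'])
print(conf['optimizer']['lr'])  # expected: 0.01 (a float, matching the original value's type)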
@@ -54,13 +54,14 @@ class AverageMeter:
self.cnt += n
self.avg = self.sum / self.cnt

def first_or_default(it:Iterable, default=None):

def first_or_default(it: Iterable, default=None):
for i in it:
return i
return default

def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)\
->MutableMapping:

def deep_update(d: MutableMapping, u: Mapping, map_type: Type[MutableMapping] = dict) -> MutableMapping:
for k, v in u.items():
if isinstance(v, Mapping):
d[k] = deep_update(d.get(k, map_type()), v, map_type)
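A small sketch (not part of the diff) of the deep_update helper shown above: nested mappings are merged recursively, while plain values are overwritten.

from archai.common.utils import deep_update

d = {'a': {'b': 1}}
deep_update(d, {'a': {'c': 2}, 'x': 3})
print(d)  # {'a': {'b': 1, 'c': 2}, 'x': 3}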
@@ -68,7 +69,8 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)
d[k] = v
return d

def state_dict(val)->Mapping:

def state_dict(val) -> Mapping:
assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict'

# Can't do below because val has state_dict() which calls utils.state_dict

@@ -79,7 +81,8 @@ def state_dict(val)->Mapping:

return {'yaml': yaml.dump(val)}

def load_state_dict(val:Any, state_dict:Mapping)->None:

def load_state_dict(val: Any, state_dict: Mapping) -> None:
assert hasattr(val, '__dict__'), 'val must be object with __dict__'

# Can't do below because val has state_dict() which calls utils.state_dict

@@ -93,7 +96,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None:
for k, v in obj.__dict__.items():
setattr(val, k, v)

def deep_comp(o1:Any, o2:Any)->bool:

def deep_comp(o1: Any, o2: Any) -> bool:
# NOTE: dict don't have __dict__
o1d = getattr(o1, '__dict__', None)
o2d = getattr(o2, '__dict__', None)

@@ -111,21 +115,25 @@ def deep_comp(o1:Any, o2:Any)->bool:
if not deep_comp(o1[k], o2[k]):
return False
else:
return False # some key missing
return False # some key missing
return True
# mismatched object types or both are scalers, or one or both None
return o1 == o2


# We setup env variable if debugging mode is detected for vs_code_debugging.
# The reason for this is that when Python multiprocessing is used, the new process
# spawned do not inherit 'pydevd' so those process do not get detected as in debugging mode
# even though they are. So we set env var which does get inherited by sub processes.
if 'pydevd' in sys.modules:
os.environ['vs_code_debugging'] = 'True'
def is_debugging()->bool:
return 'vs_code_debugging' in os.environ and os.environ['vs_code_debugging']=='True'

def full_path(path:str, create=False)->str:

def is_debugging() -> bool:
return 'vs_code_debugging' in os.environ and os.environ['vs_code_debugging'] == 'True'


def full_path(path: str, create=False) -> str:
assert path
path = os.path.abspath(
os.path.expanduser(

@@ -134,21 +142,27 @@ def full_path(path:str, create=False)->str:
os.makedirs(path, exist_ok=True)
return path

def zero_file(filepath)->None:

def zero_file(filepath) -> None:
"""Creates or truncates existing file"""
open(filepath, 'w').close()

def write_string(filepath:str, content:str)->None:

def write_string(filepath: str, content: str) -> None:
pathlib.Path(filepath).write_text(content)
def read_string(filepath:str)->str:


def read_string(filepath: str) -> str:
return pathlib.Path(filepath).read_text()

def fmt(val:Any)->str:

def fmt(val: Any) -> str:
if isinstance(val, float):
return f'{val:.4g}'
return str(val)

def append_csv_file(filepath:str, new_row:List[Tuple[str, Any]], delimiter='\t'):

def append_csv_file(filepath: str, new_row: List[Tuple[str, Any]], delimiter='\t'):
fieldnames, rows = [], []
if os.path.exists(filepath):
with open(filepath, 'r') as f:

@@ -160,19 +174,21 @@ def append_csv_file(filepath:str, new_row:List[Tuple[str, Any]], delimiter='\t')

new_fieldnames = OrderedDict([(fn, None) for fn, v in new_row])
for fn in fieldnames:
new_fieldnames[fn]=None
new_fieldnames[fn] = None

with open(filepath, 'w', newline='') as f:
dr = csv.DictWriter(f, fieldnames=new_fieldnames.keys(), delimiter=delimiter)
dr.writeheader()
for row in rows:
d = dict((k,v) for k,v in zip(fieldnames, row))
d = dict((k, v) for k, v in zip(fieldnames, row))
dr.writerow(d)
dr.writerow(OrderedDict(new_row))


def has_method(o, name):
return callable(getattr(o, name, None))


def extract_tar(src, dest=None, gzip=None, delete=False):
import tarfile
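An illustrative sketch (not part of the diff) of append_csv_file from the hunk above; the file name is hypothetical.

from archai.common.utils import append_csv_file

append_csv_file('metrics.tsv', [('epoch', 1), ('top1', 0.72)])
# A later call can introduce new columns; existing rows are rewritten under the merged header.
append_csv_file('metrics.tsv', [('epoch', 2), ('top1', 0.78), ('top5', 0.93)])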
@@ -188,6 +204,20 @@ def extract_tar(src, dest=None, gzip=None, delete=False):
if delete:
os.remove(src)


def extract_zip(src, dest=None, delete=False):
import zipfile

if dest is None:
dest = os.path.dirname(src)

with zipfile.ZipFile(src, 'r') as zip_ref:
zip_ref.extractall(dest)

if delete:
os.remove(src)


def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
md5=None, **kwargs):
download_root = os.path.expanduser(download_root)

@@ -201,7 +231,22 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non

extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)

def setup_cuda(seed:Union[float, int], local_rank:int=0):

def download_and_extract_zip(url, download_root, extract_root=None, filename=None,
md5=None, **kwargs):
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if filename is None:
filename = os.path.basename(url)

if not tvutils.check_integrity(os.path.join(download_root, filename), md5):
tvutils.download_url(url, download_root, filename=filename, md5=md5)

extract_zip(os.path.join(download_root, filename), extract_root, delete=True, **kwargs)


def setup_cuda(seed: Union[float, int], local_rank: int = 0):
seed = int(seed) + local_rank
# setup cuda
cudnn.enabled = True
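For context (not part of the diff), the new download_and_extract_zip helper can be used like this; the URL and directory are hypothetical.

from archai.common.utils import download_and_extract_zip

# Skips the download if the file is already present (and, when given, passes the md5 check),
# extracts into download_root (or extract_root if given), and deletes the zip afterwards
# (delete=True is passed to extract_zip internally).
download_and_extract_zip('https://example.com/dataset_100000.zip',
                         download_root='~/datasets/face_synthetics')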
@@ -210,17 +255,19 @@ def setup_cuda(seed:Union[float, int], local_rank:int=0):
np.random.seed(seed)
random.seed(seed)

#torch.cuda.manual_seed_all(seed)
cudnn.benchmark = True # set to false if deterministic
# torch.cuda.manual_seed_all(seed)
cudnn.benchmark = True # set to false if deterministic
torch.set_printoptions(precision=10)
#cudnn.deterministic = False
# cudnn.deterministic = False
# torch.cuda.empty_cache()
# torch.cuda.synchronize()

def cuda_device_names()->str:

def cuda_device_names() -> str:
return ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])

def exec_shell_command(command:str, print_command_start=True, print_command_end=True)->subprocess.CompletedProcess:

def exec_shell_command(command: str, print_command_start=True, print_command_end=True) -> subprocess.CompletedProcess:
if print_command_start:
print(f'[{datetime.now()}] Running: {command}')
@@ -231,49 +278,59 @@ def exec_shell_command(command:str, print_command_start=True, print_command_end=

return ret


def zip_eq(*iterables):
sentinel = object()
for count, combo in enumerate(zip_longest(*iterables, fillvalue=sentinel)):
if any(True for c in combo if sentinel is c):
shorter_its = ','.join([str(i) for i,c in enumerate(combo) if sentinel is c])
shorter_its = ','.join([str(i) for i, c in enumerate(combo) if sentinel is c])
raise ValueError(f'Iterator {shorter_its} have length {count} which is shorter than others')
yield combo

def dir_downloads()->str:

def dir_downloads() -> str:
return full_path(str(os.path.join(pathlib.Path.home(), "Downloads")))

def filepath_without_ext(filepath:str)->str:

def filepath_without_ext(filepath: str) -> str:
"""Returns '/a/b/c/d.e' for '/a/b/c/d.e.f' """
return str(pathlib.Path(filepath).with_suffix(''))

def filepath_ext(filepath:str)->str:

def filepath_ext(filepath: str) -> str:
"""Returns '.f' for '/a/b/c/d.e.f' """
return pathlib.Path(filepath).suffix

def filepath_name_ext(filepath:str)->str:

def filepath_name_ext(filepath: str) -> str:
"""Returns 'd.e.f' for '/a/b/c/d.e.f' """
return pathlib.Path(filepath).name

def filepath_name_only(filepath:str)->str:

def filepath_name_only(filepath: str) -> str:
"""Returns 'd.e' for '/a/b/c/d.e.f' """
return pathlib.Path(filepath).stem

def change_filepath_ext(filepath:str, new_ext:str)->str:

def change_filepath_ext(filepath: str, new_ext: str) -> str:
"""Returns '/a/b/c/d.e.g' for filepath='/a/b/c/d.e.f', new_ext='.g' """
return str(pathlib.Path(filepath).with_suffix(new_ext))

def change_filepath_name(filepath:str, new_name:str, new_ext:Optional[str]=None)->str:

def change_filepath_name(filepath: str, new_name: str, new_ext: Optional[str] = None) -> str:
"""Returns '/a/b/c/h.f' for filepath='/a/b/c/d.e.f', new_name='h' """
ext = new_ext or filepath_ext(filepath)
return str(pathlib.Path(filepath).with_name(new_name).with_suffix(ext))

def append_to_filename(filepath:str, name_suffix:str, new_ext:Optional[str]=None)->str:

def append_to_filename(filepath: str, name_suffix: str, new_ext: Optional[str] = None) -> str:
"""Returns '/a/b/c/h.f' for filepath='/a/b/c/d.e.f', new_name='h' """
ext = new_ext or filepath_ext(filepath)
name = filepath_name_only(filepath)
return str(pathlib.Path(filepath).with_name(name+name_suffix).with_suffix(ext))

def copy_file(src_file:str, dest_dir_or_file:str, preserve_metadata=False, use_shutil:bool=True)->str:

def copy_file(src_file: str, dest_dir_or_file: str, preserve_metadata=False, use_shutil: bool = True) -> str:
if not use_shutil:
assert not preserve_metadata
return copy_file_basic(src_file, dest_dir_or_file)
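A short sketch (not part of the diff) of zip_eq from the top of the hunk above: it behaves like zip() but raises instead of silently truncating.

from archai.common.utils import zip_eq

print(list(zip_eq([1, 2], ['a', 'b'])))  # [(1, 'a'), (2, 'b')]
# list(zip_eq([1, 2], ['a']))  # raises ValueError because the iterables differ in length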
@@ -284,11 +341,12 @@ def copy_file(src_file:str, dest_dir_or_file:str, preserve_metadata=False, use_s
copy_fn = shutil.copy2 if preserve_metadata else shutil.copy
return copy_fn(src_file, dest_dir_or_file)
except OSError as ex:
if preserve_metadata or ex.errno != 38: # OSError: [Errno 38] Function not implemented
if preserve_metadata or ex.errno != 38: # OSError: [Errno 38] Function not implemented
raise
return copy_file_basic(src_file, dest_dir_or_file)

def copy_file_basic(src_file:str, dest_dir_or_file:str)->str:

def copy_file_basic(src_file: str, dest_dir_or_file: str) -> str:
# try basic python functions
# first if dest is dir, get dest file name
if os.path.isdir(dest_dir_or_file):

@@ -297,7 +355,8 @@ def copy_file_basic(src_file:str, dest_dir_or_file:str)->str:
dst.write(src.read())
return dest_dir_or_file

def copy_dir(src_dir:str, dest_dir:str, use_shutil:bool=True)->None:

def copy_dir(src_dir: str, dest_dir: str, use_shutil: bool = True) -> None:
if os.path.isdir(src_dir):
if use_shutil:
shutil.copytree(src_dir, dest_dir)

@@ -307,24 +366,33 @@ def copy_dir(src_dir:str, dest_dir:str, use_shutil:bool=True)->None:
files = os.listdir(src_dir)
for f in files:
copy_dir(os.path.join(src_dir, f),
os.path.join(dest_dir, f), use_shutil=use_shutil)
os.path.join(dest_dir, f), use_shutil=use_shutil)
else:
copy_file(src_dir, dest_dir, use_shutil=use_shutil)


if 'main_process_pid' not in os.environ:
os.environ['main_process_pid'] = str(os.getpid())
def is_main_process()->bool:


def is_main_process() -> bool:
"""Returns True if this process was started as main process instead of child process during multiprocessing"""
return multiprocessing.current_process().name == 'MainProcess' and os.environ['main_process_pid'] == str(os.getpid())
def main_process_pid()->int:


def main_process_pid() -> int:
return int(os.environ['main_process_pid'])
def process_name()->str:


def process_name() -> str:
return multiprocessing.current_process().name

def is_windows()->bool:
return platform.system()=='Windows'

def path2uri(path:str, windows_non_standard:bool=False)->str:
def is_windows() -> bool:
return platform.system() == 'Windows'


def path2uri(path: str, windows_non_standard: bool = False) -> str:
uri = pathlib.Path(full_path(path)).as_uri()

# there is lot of buggy regex based code out there which expects Windows file URIs as

@@ -334,7 +402,8 @@ def path2uri(path:str, windows_non_standard:bool=False)->str:
uri = uri.replace('file:///', 'file://')
return uri

def uri2path(file_uri:str, windows_non_standard:bool=False)->str:

def uri2path(file_uri: str, windows_non_standard: bool = False) -> str:
# there is lot of buggy regex based code out there which expects Windows file URIs as
# file://C/... instead of standard file:///C/...
# When passing file uri to such code, turn on windows_non_standard

@@ -347,28 +416,33 @@ def uri2path(file_uri:str, windows_non_standard:bool=False)->str:
os.path.join(host, url2pathname(unquote(parsed.path)))
)

def get_ranks(items:list, key=lambda v:v, reverse=False)->List[int]:

def get_ranks(items: list, key=lambda v: v, reverse=False) -> List[int]:
sorted_t = sorted(zip(items, range(len(items))),
key=lambda t: key(t[0]),
reverse=reverse)
sorted_map = dict((t[1], i) for i, t in enumerate(sorted_t))
return [sorted_map[i] for i in range(len(items))]

def dedup_list(l:List)->List:

def dedup_list(l: List) -> List:
return list(OrderedDict.fromkeys(l))

def delete_file(filepath:str)->bool:

def delete_file(filepath: str) -> bool:
if os.path.isfile(filepath):
os.remove(filepath)
return True
else:
return False

def save_as_yaml(obj, filepath:str)->None:

def save_as_yaml(obj, filepath: str) -> None:
with open(filepath, 'w', encoding='utf-8') as f:
yaml.dump(obj, f, default_flow_style=False)

def map_to_list(variable:Union[int,float,Sized], size:int)->Sized:

def map_to_list(variable: Union[int, float, Sized], size: int) -> Sized:
if isinstance(variable, Sized):
size_diff = size - len(variable)
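An illustrative sketch (not part of the diff) of get_ranks from the hunk above: for each position it returns the rank of that item, 0 being the smallest by default or the largest with reverse=True.

from archai.common.utils import get_ranks

print(get_ranks([10, 30, 20]))                # [0, 2, 1]
print(get_ranks([10, 30, 20], reverse=True))  # [2, 0, 1]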
@@ -381,7 +455,8 @@ def map_to_list(variable:Union[int,float,Sized], size:int)->Sized:

return [variable] * size

def attr_to_dict(obj:Any, recursive:bool=True)->Dict[str, Any]:

def attr_to_dict(obj: Any, recursive: bool = True) -> Dict[str, Any]:
MAX_LIST_LEN = 10
variables = {}
@@ -7,6 +7,7 @@ import torchvision.transforms.functional as F
from torchvision.io import read_image

from archai.api.dataset_provider import DatasetProvider
from archai.common.utils import download_and_extract_zip


class FaceSyntheticsDataset(torch.utils.data.Dataset):

@@ -25,18 +26,19 @@ class FaceSyntheticsDataset(torch.utils.data.Dataset):
img_size (Tuple[int, int]): Image size (width, height). Defaults to (256, 256).
subset (str, optional): Subset ['train', 'test', 'validation']. Defaults to 'train'.
val_size (int, optional): Validation set size. Defaults to 2000.

mask_size (Optional[Tuple[int, int]], optional): Segmentation mask size (width, height). If `None`,

mask_size (Optional[Tuple[int, int]], optional): Segmentation mask size (width, height). If `None`,
`img_size` is used. Defaults to None.


augmentation (Optional[Callable], optional): Augmentation function. Expects a callable object
with named arguments 'image' and 'mask' that returns a dictionary with 'image' and 'mask' as
with named arguments 'image' and 'mask' that returns a dictionary with 'image' and 'mask' as
keys. Defaults to None.
"""
dataset_dir = Path(dataset_dir)
assert dataset_dir.is_dir()
assert isinstance(img_size, tuple)

zip_url = "https://facesyntheticspubwedata.blob.core.windows.net/iccv-2021/dataset_100000.zip"
self.img_size = img_size
self.dataset_dir = dataset_dir
self.subset = subset

@@ -44,6 +46,10 @@ class FaceSyntheticsDataset(torch.utils.data.Dataset):
self.augmentation = augmentation

all_seg_files = [str(f) for f in sorted(self.dataset_dir.glob('*_seg.png'))]
if len(all_seg_files) < 100000:
download_and_extract_zip(zip_url, self.dataset_dir)
all_seg_files = [str(f) for f in sorted(self.dataset_dir.glob('*_seg.png'))]

train_subset, test_subset = all_seg_files[:90_000], all_seg_files[90_000:]

if subset == 'train':

@@ -55,22 +61,22 @@ class FaceSyntheticsDataset(torch.utils.data.Dataset):

self.img_files = [s.replace("_seg.png",".png") for s in self.seg_files]
self.ignore_index = ignore_index


def __len__(self):
return len(self.img_files)


def __getitem__(self, idx):
sample = {
'image': read_image(self.img_files[idx]),
'mask': read_image(self.seg_files[idx]).long()
}


if self.augmentation and self.subset == 'train':
sample = self.augmentation(**sample)


sample['image'] = sample['image']/255

mask_size = self.mask_size if self.mask_size else self.img_size

mask_size = self.mask_size if self.mask_size else self.img_size
sample['mask'] = F.resize(
sample['mask'], mask_size[::-1],
interpolation=F.InterpolationMode.NEAREST

@@ -84,17 +90,17 @@ class FaceSyntheticsDatasetProvider(DatasetProvider):
def __init__(self, dataset_dir: str):
self.dataset_dir = Path(dataset_dir)
assert self.dataset_dir.is_dir()


@overrides
def get_train_dataset(self, **kwargs) -> torch.utils.data.Dataset:
return FaceSyntheticsDataset(
self.dataset_dir, subset='train', **kwargs
)


@overrides
def get_test_dataset(self, **kwargs) -> torch.utils.data.Dataset:
return FaceSyntheticsDataset(
self.dataset_dir, subset='train', **kwargs
self.dataset_dir, subset='test', **kwargs
)

@overrides
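Finally, an illustrative usage sketch of the dataset provider changed above; it is not part of the commit. The dataset directory is hypothetical and the import path of the module is assumed. With this commit the dataset downloads and extracts itself when the directory contains fewer than 100,000 *_seg.png files.

from archai.datasets.cv.face_synthetics import FaceSyntheticsDatasetProvider  # import path assumed

provider = FaceSyntheticsDatasetProvider(dataset_dir='/data/FaceSynthetics')  # directory must exist
train_ds = provider.get_train_dataset(img_size=(256, 256))  # may trigger download_and_extract_zip
test_ds = provider.get_test_dataset()   # now returns the 'test' subset (fixed in this diff)
sample = train_ds[0]                    # dict with 'image' (scaled by /255) and 'mask' tensors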