зеркало из https://github.com/microsoft/archai.git
fix windows run
This commit is contained in:
Родитель
d861b5f2a6
Коммит
5a4ef14edb
|
@ -43,7 +43,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--no-search", "--algos", "darts", "--nas.eval.final_desc_filename", "models/final_model_desc.yaml"]
|
||||
"args": ["--no-search", "--algos", "darts", "--nas.eval.final_desc_filename", "models/darts/final_model_desc1.yaml"]
|
||||
},
|
||||
{
|
||||
"name": "Darts-E2E-Toy",
|
||||
|
|
|
@ -34,9 +34,6 @@ class ApexUtils:
|
|||
# to avoid circular references= with common, logger is passed from outside
|
||||
self.logger = logger
|
||||
|
||||
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
|
||||
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
|
||||
|
||||
# defaults for non-distributed mode
|
||||
self._amp, self._ddp = None, None
|
||||
self.world_size = 1
|
||||
|
@ -50,6 +47,9 @@ class ApexUtils:
|
|||
#_log_info({'apex_config': apex_config.to_dict()})
|
||||
self._log_info({'torch.distributed.is_available': dist.is_available()})
|
||||
if dist.is_available():
|
||||
# dist.* properties are otherwise not accessible
|
||||
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
|
||||
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
|
||||
self._log_info({'gloo_available': dist.is_gloo_available(),
|
||||
'mpi_available': dist.is_mpi_available(),
|
||||
'nccl_available': dist.is_nccl_available()})
|
||||
|
|
|
@ -88,7 +88,7 @@ nas:
|
|||
trainer:
|
||||
apex:
|
||||
_copy: 'common/apex'
|
||||
enabled: True
|
||||
enabled: False
|
||||
aux_weight: '_copy: nas/eval/model_desc/aux_weight'
|
||||
drop_path_prob: 0.2 # probability that given edge will be dropped
|
||||
grad_clip: 5.0 # grads above this value is clipped
|
||||
|
|
Загрузка…
Ссылка в новой задаче