Mirror of https://github.com/microsoft/archai.git
fix(supergraph): Adds Shital's fixes for working with FP16.
Parent: f631895dff
Commit: 3c5c453e58
@@ -209,7 +209,7 @@
             "request": "launch",
             "program": "${cwd}/scripts/supergraph/main.py",
             "console": "integratedTerminal",
-            "args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
+            "args": ["--no-search", "--algos", "manual"]
         },
         {
             "name": "Resnet-Full",
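Note: dropping the hard-coded "--datasets", "imagenet" pair presumably lets this debug launch fall back to whatever dataset the manual algo's own configuration selects; the remaining flag values are verbatim from the hunk.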
@@ -172,7 +172,7 @@ class ApexUtils:
     def is_mixed(self)->bool:
         return self._enabled and self._mixed_prec_enabled
     def is_dist(self)->bool:
-        return self._enabled and self._distributed_enabled
+        return self._enabled and self._distributed_enabled and self.world_size > 1
     def is_master(self)->bool:
         return self.global_rank == 0
     def is_ray(self)->bool:
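The is_dist() change adds a world-size guard: a run launched with distributed mode enabled but only one process is no longer treated as distributed, so callers can gate distributed-only work (collective ops and the like) on more than one process actually existing. A minimal stand-alone sketch of the new predicate, where ApexUtilsSketch and its constructor are hypothetical and only the three attribute names are taken from the hunk:

    class ApexUtilsSketch:
        # Hypothetical container mirroring the attributes the hunk reads.
        def __init__(self, enabled: bool, distributed_enabled: bool, world_size: int):
            self._enabled = enabled
            self._distributed_enabled = distributed_enabled
            self.world_size = world_size  # number of participating processes

        def is_dist(self) -> bool:
            # New guard: a "world" of one process behaves like a plain local run.
            return self._enabled and self._distributed_enabled and self.world_size > 1

    assert not ApexUtilsSketch(True, True, world_size=1).is_dist()
    assert ApexUtilsSketch(True, True, world_size=2).is_dist()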
@@ -4,7 +4,6 @@
 import os
 
 from overrides import overrides
-from PIL import Image
 from torchvision import datasets
 from torchvision.transforms import transforms
 
@@ -59,7 +58,7 @@ class ImagenetProvider(DatasetProvider):
         transform_train = transforms.Compose([
             transforms.RandomResizedCrop(224,
                 scale=(0.08, 1.0), # TODO: these two params are normally not specified
-                interpolation=Image.BICUBIC),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip(),
             transforms.ColorJitter(
                 brightness=0.4,
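Context for the interpolation change: torchvision 0.9 introduced the transforms.InterpolationMode enum and deprecated passing PIL's integer constants such as Image.BICUBIC, which is presumably why the PIL import above is dropped as well. A minimal sketch of the updated pipeline, assuming torchvision >= 0.9 and omitting the ColorJitter arguments the hunk truncates:

    from torchvision.transforms import transforms

    # Same access path as the commit: transforms.py re-exports InterpolationMode.
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(
            224,
            scale=(0.08, 1.0),
            interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
    ])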
@@ -294,7 +294,8 @@ class Trainer(EnforceOverrides):
 
             loss_sum += loss_c.item() * len(logits_c)
             loss_count += len(logits_c)
-            logits_chunks.append(logits_c.detach().cpu()) # pyright: ignore[reportGeneralTypeIssues]
+            # TODO: cannot place on CPU if it was half precision but should we somehow?
+            logits_chunks.append(logits_c.detach()) # pyright: ignore[reportGeneralTypeIssues]
 
         # TODO: original darts clips alphas as well but pt.darts doesn't
         self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
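Dropping the .cpu() call keeps each detached logits chunk on its original device because, per the new TODO, offloading is problematic under FP16: many CPU kernels do not implement torch.float16. A hedged sketch of one way the TODO could be resolved if CPU offload were still wanted; this is an assumption, not what the commit does, and detach_chunk is a hypothetical helper:

    import torch

    def detach_chunk(logits_c: torch.Tensor, offload_to_cpu: bool = False) -> torch.Tensor:
        # Hypothetical variant: upcast fp16/bf16 to fp32 before moving to CPU so
        # later CPU-side ops (metrics, concatenation) are safe; .float() leaves
        # tensors that are already fp32 unchanged.
        chunk = logits_c.detach()
        if offload_to_cpu:
            chunk = chunk.float().cpu()
        return chunk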
@@ -304,7 +305,8 @@ class Trainer(EnforceOverrides):
         # TODO: we possibly need to sync so all replicas are upto date
         self._apex.sync_devices()
 
-        self.post_step(x, y,
+        # TODO: we need to put y on GPU because logits are on GPU. Is this good idea from GPU mem perspective?
+        self.post_step(x, y.to(self.get_device(), non_blocking=True),
             ml_utils.join_chunks(logits_chunks),
             torch.tensor(loss_sum/loss_count),
             steps)
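Since the logits chunks now stay on GPU, post_step receives the targets on the same device. One detail of the pattern worth noting: non_blocking=True only overlaps the host-to-device copy with computation when the source tensor lives in pinned (page-locked) memory, for example from a DataLoader created with pin_memory=True; otherwise the copy silently degrades to a synchronous one. A minimal sketch of the pattern in isolation, with shapes chosen arbitrarily:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    y = torch.randint(0, 1000, (64,))  # fake integer class targets
    if device.type == 'cuda':
        y = y.pin_memory()             # page-locked source enables an async copy
    y_dev = y.to(device, non_blocking=True)  # overlaps with compute only if pinned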
@@ -0,0 +1,2 @@
+REM Creates symbolic link to datasets folder, so that Archai can find the datasets
+mklink /j %USERPROFILE%\dataroot E:\datasets
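Usage note: mklink is a cmd.exe built-in, so this new script must run from a Command Prompt (or a .bat/.cmd file) rather than PowerShell, and the /j flag creates a directory junction, which unlike an mklink /d symbolic link does not require administrator rights. The E:\datasets target is specific to the author's machine; point the junction at wherever the datasets actually live.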