fix(supergraph): Adds Shital's fixes for working with FP16.

Gustavo Rosa 2023-02-23 09:41:43 -03:00
Parent f631895dff
Commit 3c5c453e58
5 changed files with 9 additions and 6 deletions

.vscode/launch.json (vendored)

@@ -209,7 +209,7 @@
             "request": "launch",
             "program": "${cwd}/scripts/supergraph/main.py",
             "console": "integratedTerminal",
-            "args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
+            "args": ["--no-search", "--algos", "manual"]
         },
         {
             "name": "Resnet-Full",

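With this edit the supergraph debug configuration runs the equivalent of "python scripts/supergraph/main.py --no-search --algos manual", presumably letting the script fall back to its default dataset instead of pinning imagenet.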

@@ -172,7 +172,7 @@ class ApexUtils:
     def is_mixed(self)->bool:
         return self._enabled and self._mixed_prec_enabled
     def is_dist(self)->bool:
-        return self._enabled and self._distributed_enabled
+        return self._enabled and self._distributed_enabled and self.world_size > 1
     def is_master(self)->bool:
         return self.global_rank == 0
     def is_ray(self)->bool:
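For context, a minimal standalone sketch of why the extra world_size > 1 guard matters; the world_size helper below is hypothetical, not the ApexUtils API:

    import torch.distributed as dist

    def world_size() -> int:
        # Hypothetical helper: a plain single-process run (FP16 or not)
        # has no initialized process group, so report a world size of 1.
        return dist.get_world_size() if dist.is_initialized() else 1

    def is_dist(enabled: bool, distributed_enabled: bool) -> bool:
        # Even when distributed mode is enabled by config, a single
        # replica has no peers to synchronize with, so gradient
        # all-reduce and similar collectives would be pure overhead.
        return enabled and distributed_enabled and world_size() > 1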


@@ -4,7 +4,6 @@
 import os
 from overrides import overrides
-from PIL import Image
 from torchvision import datasets
 from torchvision.transforms import transforms
@@ -59,7 +58,7 @@ class ImagenetProvider(DatasetProvider):
         transform_train = transforms.Compose([
             transforms.RandomResizedCrop(224,
                 scale=(0.08, 1.0), # TODO: these two params are normally not specified
-                interpolation=Image.BICUBIC),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip(),
             transforms.ColorJitter(
                 brightness=0.4,
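The interpolation swap tracks a torchvision API change: since roughly v0.9 the resize-style transforms expect the InterpolationMode enum, and passing PIL constants such as Image.BICUBIC is deprecated, which also lets the PIL import be dropped. A minimal sketch of the new-style call:

    from torchvision import transforms

    # InterpolationMode.BICUBIC replaces the deprecated PIL constant
    # Image.BICUBIC as the interpolation argument.
    crop = transforms.RandomResizedCrop(
        224,
        scale=(0.08, 1.0),
        interpolation=transforms.InterpolationMode.BICUBIC)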


@@ -294,7 +294,8 @@ class Trainer(EnforceOverrides):
                 loss_sum += loss_c.item() * len(logits_c)
                 loss_count += len(logits_c)
-                logits_chunks.append(logits_c.detach().cpu()) # pyright: ignore[reportGeneralTypeIssues]
+                # TODO: cannot place on CPU if it was half precision but should we somehow?
+                logits_chunks.append(logits_c.detach()) # pyright: ignore[reportGeneralTypeIssues]

             # TODO: original darts clips alphas as well but pt.darts doesn't
             self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
@@ -304,7 +305,8 @@ class Trainer(EnforceOverrides):
             # TODO: we possibly need to sync so all replicas are upto date
             self._apex.sync_devices()

-            self.post_step(x, y,
+            # TODO: we need to put y on GPU because logits are on GPU. Is this good idea from GPU mem perspective?
+            self.post_step(x, y.to(self.get_device(), non_blocking=True),
                           ml_utils.join_chunks(logits_chunks),
                           torch.tensor(loss_sum/loss_count),
                           steps)
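Both Trainer changes are about device placement under mixed precision. A minimal standalone sketch of the reasoning, with illustrative tensor shapes that are not Archai's:

    import torch

    device = torch.device('cuda')

    # Under FP16 the logits live on the GPU in half precision; float16
    # support on CPU is limited, so the detached chunk stays on the GPU
    # instead of being moved with .cpu().
    logits = torch.randn(8, 10, device=device, dtype=torch.float16)
    chunk = logits.detach()  # no grad history, still on the GPU

    # Targets must be on the same device as the logits before metrics
    # are computed; non_blocking=True can overlap the host-to-device
    # copy with computation when the source tensor is pinned.
    y = torch.randint(0, 10, (8,))
    y = y.to(device, non_blocking=True)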


@@ -0,0 +1,2 @@
+REM Creates symbolic link to datasets folder, so that Archai can find the datasets
+mklink /j %USERPROFILE%\dataroot E:\datasets
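Note that mklink /j actually creates an NTFS directory junction rather than a true symbolic link (mklink /d): %USERPROFILE%\dataroot then resolves to E:\datasets without copying any data, and unlike symbolic links, creating a junction does not require administrator privileges.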