fix(supergraph): Adds Shital's fixes for working with FP16.

This commit is contained in:
Gustavo Rosa 2023-02-23 09:41:43 -03:00
Parent f631895dff
Commit 3c5c453e58
5 changed files with 9 additions and 6 deletions

.vscode/launch.json (vendored)
View file

@@ -209,7 +209,7 @@
             "request": "launch",
             "program": "${cwd}/scripts/supergraph/main.py",
             "console": "integratedTerminal",
-            "args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
+            "args": ["--no-search", "--algos", "manual"]
         },
         {
             "name": "Resnet-Full",

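The config change above drops the hard-coded `--datasets imagenet` pair, presumably so the debug launch falls back to whatever dataset the experiment config names. A minimal sketch of that fallback pattern, assuming an argparse-style CLI (this is not Archai's actual parser, and the 'cifar10' default is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-search', action='store_true')
parser.add_argument('--algos', type=str, default='')
parser.add_argument('--datasets', type=str, default=None)  # None -> defer to config file
args = parser.parse_args(['--no-search', '--algos', 'manual'])

# hypothetical fallback: use the config's dataset when --datasets is omitted
dataset = args.datasets or 'cifar10'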
View file

@@ -172,7 +172,7 @@ class ApexUtils:
     def is_mixed(self)->bool:
         return self._enabled and self._mixed_prec_enabled
     def is_dist(self)->bool:
-        return self._enabled and self._distributed_enabled
+        return self._enabled and self._distributed_enabled and self.world_size > 1
     def is_master(self)->bool:
         return self.global_rank == 0
     def is_ray(self)->bool:

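The extra `world_size > 1` guard means a run with distributed support enabled but only a single process is treated as non-distributed, which matters for the FP16/Apex code paths that would otherwise try to synchronize across replicas. A minimal sketch of the guard, assuming `world_size` comes from the launcher's WORLD_SIZE environment variable (the class name here is hypothetical, not Archai's full ApexUtils):

import os

class DistFlags:
    def __init__(self, enabled: bool, distributed_enabled: bool):
        self._enabled = enabled
        self._distributed_enabled = distributed_enabled
        # launchers such as torch.distributed set WORLD_SIZE for each process
        self.world_size = int(os.environ.get('WORLD_SIZE', '1'))

    def is_dist(self) -> bool:
        # a single process is never treated as distributed, even if enabled
        return self._enabled and self._distributed_enabled and self.world_size > 1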
View file

@@ -4,7 +4,6 @@
 import os
 from overrides import overrides
-from PIL import Image
 from torchvision import datasets
 from torchvision.transforms import transforms
@@ -59,7 +58,7 @@ class ImagenetProvider(DatasetProvider):
         transform_train = transforms.Compose([
             transforms.RandomResizedCrop(224,
                 scale=(0.08, 1.0), # TODO: these two params are normally not specified
-                interpolation=Image.BICUBIC),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip(),
             transforms.ColorJitter(
                 brightness=0.4,

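`Image.BICUBIC` is a plain PIL integer constant; newer torchvision releases warn when the `interpolation` argument is an int and expect the `InterpolationMode` enum instead (available since torchvision 0.9), which also lets the unused PIL import above be dropped. A standalone sketch of the fixed call:

from torchvision import transforms

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224,
        scale=(0.08, 1.0),
        # the enum form torchvision expects; avoids the deprecation warning
        interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
])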
View file

@@ -294,7 +294,8 @@ class Trainer(EnforceOverrides):
             loss_sum += loss_c.item() * len(logits_c)
             loss_count += len(logits_c)
-            logits_chunks.append(logits_c.detach().cpu()) # pyright: ignore[reportGeneralTypeIssues]
+            # TODO: cannot place on CPU if it was half precision but should we somehow?
+            logits_chunks.append(logits_c.detach()) # pyright: ignore[reportGeneralTypeIssues]
         # TODO: original darts clips alphas as well but pt.darts doesn't
         self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
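Under Apex mixed precision the logits come out as float16, and as the new TODO notes, moving each chunk to the CPU is questionable for half tensors (many CPU ops historically lacked float16 support); it also forces a device sync per chunk. A sketch of the accumulation pattern after the fix, assuming `ml_utils.join_chunks` behaves like `torch.cat`:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logits_chunks = []
for _ in range(3):
    logits_c = torch.randn(4, 10, device=device).half()  # FP16 logits, as under Apex
    logits_chunks.append(logits_c.detach())  # no .cpu(): chunks stay on the model's device
logits = torch.cat(logits_chunks)  # stand-in for ml_utils.join_chunks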
@@ -304,7 +305,8 @@ class Trainer(EnforceOverrides):
         # TODO: we possibly need to sync so all replicas are upto date
         self._apex.sync_devices()
-        self.post_step(x, y,
+        # TODO: we need to put y on GPU because logits are on GPU. Is this good idea from GPU mem perspective?
+        self.post_step(x, y.to(self.get_device(), non_blocking=True),
                        ml_utils.join_chunks(logits_chunks),
                        torch.tensor(loss_sum/loss_count),
                        steps)

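Since the joined logits now stay on the GPU, `y` must be moved to the same device before `post_step` computes metrics, or the comparison raises a device-mismatch error; `non_blocking=True` lets the copy overlap with compute when the DataLoader uses pinned memory. A self-contained sketch of the issue:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logits = torch.randn(8, 10, device=device)   # model outputs live on the GPU
y = torch.randint(0, 10, (8,))               # targets arrive on the CPU
y = y.to(device, non_blocking=True)          # align devices; overlaps copy if y is pinned
correct = (logits.argmax(dim=1) == y).sum().item()  # valid now that devices match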
View file

@@ -0,0 +1,2 @@
+REM Creates symbolic link to datasets folder, so that Archai can find the datasets
+mklink /j %USERPROFILE%\dataroot E:\datasets
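`mklink /j` creates a directory junction, which unlike a `/d` symbolic link does not require an elevated prompt, so the script can run as a normal user. A quick way to verify the link from Python (E:\datasets is just this machine's target; substitute the actual dataset drive):

from pathlib import Path

dataroot = Path.home() / 'dataroot'
# resolve() follows the junction; expect the real dataset location, e.g. E:\datasets
print(dataroot.exists(), dataroot.resolve())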