Re-apply #4446 + add packaging dependency

As discussed w/ @lysandrejik

packaging is maintained by PyPA (the Python Packaging Authority), and should be lightweight and stable
This commit is contained in:
Julien Chaumond 2020-05-22 17:27:47 -04:00
Родитель e6aeb0d3e8
Коммит 2c1ebb8b50
2 изменённых файлов: 9 добавлений и 1 удалений

Просмотреть файл

@ -111,6 +111,8 @@ setup(
"tokenizers == 0.7.0",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions
"packaging",
# filesystem locks e.g. to prevent parallel downloads
"filelock",
# for downloading models over HTTPS

Просмотреть файл

@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
@ -494,7 +495,12 @@ class Trainer:
):
logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
logs["learning_rate"] = scheduler.get_last_lr()[0]
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss = tr_loss
self._log(logs)