Re-apply #4446 + add packaging dependency
As discussed with @lysandrejik, packaging is maintained by PyPA (the Python Packaging Authority), and should be lightweight and stable
This commit is contained in:
Parent
e6aeb0d3e8
Commit
2c1ebb8b50
2
setup.py
2
setup.py
|
@@ -111,6 +111,8 @@ setup(
|
|||
"tokenizers == 0.7.0",
|
||||
# dataclasses for Python versions that don't have it
|
||||
"dataclasses;python_version<'3.7'",
|
||||
# utilities from PyPA to e.g. compare versions
|
||||
"packaging",
|
||||
# filesystem locks e.g. to prevent parallel downloads
|
||||
"filelock",
|
||||
# for downloading models over HTTPS
|
||||
|
|
|
@@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
@@ -494,7 +495,12 @@ class Trainer:
|
|||
):
|
||||
logs: Dict[str, float] = {}
|
||||
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
|
||||
logs["learning_rate"] = scheduler.get_last_lr()[0]
|
||||
# backward compatibility for pytorch schedulers
|
||||
logs["learning_rate"] = (
|
||||
scheduler.get_last_lr()[0]
|
||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||
else scheduler.get_lr()[0]
|
||||
)
|
||||
logging_loss = tr_loss
|
||||
|
||||
self._log(logs)
|
||||
|
|
Loading…
Open link in new issue