refactor(wandb): consolidate import (#5044)
This commit is contained in:
Родитель
9e03364999
Коммит
edcb3ac59a
|
@ -21,7 +21,7 @@ from tqdm.auto import tqdm, trange
|
|||
from .data.data_collator import DataCollator, default_data_collator
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available
|
||||
from .training_args import TrainingArguments, is_torch_tpu_available
|
||||
|
||||
|
||||
|
@ -59,22 +59,9 @@ def is_tensorboard_available():
|
|||
return _has_tensorboard
|
||||
|
||||
|
||||
try:
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb.ensure_configured()
|
||||
if wandb.api.api_key is None:
|
||||
_has_wandb = False
|
||||
wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
|
||||
else:
|
||||
_has_wandb = False if os.getenv("WANDB_DISABLED") else True
|
||||
except (ImportError, AttributeError):
|
||||
_has_wandb = False
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -10,26 +10,13 @@ import tensorflow as tf
|
|||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
from .optimization_tf import GradientAccumulator, create_optimizer
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
||||
|
||||
try:
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb.ensure_configured()
|
||||
if wandb.api.api_key is None:
|
||||
_has_wandb = False
|
||||
wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
|
||||
else:
|
||||
_has_wandb = False if os.getenv("WANDB_DISABLED") else True
|
||||
except (ImportError, AttributeError):
|
||||
_has_wandb = False
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -1,8 +1,26 @@
|
|||
import os
|
||||
from typing import Dict, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
wandb.ensure_configured()
|
||||
if wandb.api.api_key is None:
|
||||
_has_wandb = False
|
||||
wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
|
||||
else:
|
||||
_has_wandb = False if os.getenv("WANDB_DISABLED") else True
|
||||
except (ImportError, AttributeError):
|
||||
_has_wandb = False
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
|
||||
class EvalPrediction(NamedTuple):
|
||||
"""
|
||||
Evaluation output (always contains labels), to be used
|
||||
|
|
Загрузка…
Ссылка в новой задаче