зеркало из https://github.com/microsoft/archai.git
fix(config_base): Fixes Config class not being inheritted by PretrainedConfig.
This commit is contained in:
Родитель
e898dd35f5
Коммит
ed7fb59ea8
|
@ -9,39 +9,22 @@ from typing import Any, Dict, Mapping, Optional, List, Union
|
|||
|
||||
import torch
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
class Config:
|
||||
|
||||
class Config(PretrainedConfig):
|
||||
"""Base configuration class, used to define some common attributes
|
||||
and shared methods for loading and saving configurations.
|
||||
|
||||
"""
|
||||
|
||||
hyperparameter_map: Dict[str, str] = {}
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
if key != 'hyperparameter_map' and key in super().__getattribute__('hyperparameter_map'):
|
||||
key = super().__getattribute__('hyperparameter_map')[key]
|
||||
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
if key in super().__getattribute__('hyperparameter_map'):
|
||||
key = super().__getattribute__('hyperparameter_map')[key]
|
||||
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""Initializes the class by verifying whether keyword arguments
|
||||
are valid and setting them as attributes.
|
||||
|
||||
"""
|
||||
|
||||
# Non-default attributes
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except AttributeError as e:
|
||||
raise e
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _map_to_list(self,
|
||||
variable: Union[int, float, List[Union[int, float]]],
|
||||
|
@ -58,14 +41,6 @@ class Config:
|
|||
|
||||
return [variable] * size
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
config_dict = {}
|
||||
|
||||
for key, value in self.__dict__.items():
|
||||
config_dict[key] = value
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
class SearchConfigParameter:
|
||||
"""Base search configuration parameter class, used to define whether
|
||||
|
|
|
@ -7,15 +7,15 @@
|
|||
from typing import List, Optional, Union
|
||||
|
||||
from archai.nlp.models.config_base import Config, SearchConfigParameter, SearchConfig
|
||||
from transformers import CONFIG_MAPPING, PretrainedConfig
|
||||
from transformers import CONFIG_MAPPING
|
||||
|
||||
|
||||
class HfGPT2Config(Config, PretrainedConfig):
|
||||
class HfGPT2Config(Config):
|
||||
"""Huggingface's Open AI GPT-2 default configuration.
|
||||
|
||||
"""
|
||||
|
||||
hyperparameter_map = {
|
||||
attribute_map = {
|
||||
'n_token': 'vocab_size',
|
||||
'tgt_len': 'n_positions',
|
||||
'd_model': 'n_embd',
|
||||
|
@ -24,7 +24,7 @@ class HfGPT2Config(Config, PretrainedConfig):
|
|||
'dropatt': 'attn_pdrop',
|
||||
'weight_init_std': 'initializer_range'
|
||||
}
|
||||
hyperparameter_map.update(CONFIG_MAPPING['gpt2']().attribute_map)
|
||||
attribute_map.update(CONFIG_MAPPING['gpt2']().attribute_map)
|
||||
|
||||
def __init__(self,
|
||||
n_token: Optional[int] = 10000, # changed from 50257 for model's production
|
||||
|
|
|
@ -7,21 +7,21 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from archai.nlp.models.config_base import Config, SearchConfigParameter, SearchConfig
|
||||
from transformers import CONFIG_MAPPING, PretrainedConfig
|
||||
from transformers import CONFIG_MAPPING
|
||||
|
||||
|
||||
class HfTransfoXLConfig(Config, PretrainedConfig):
|
||||
class HfTransfoXLConfig(Config):
|
||||
"""Huggingface's Transformer-XL default configuration.
|
||||
|
||||
"""
|
||||
|
||||
hyperparameter_map = {
|
||||
attribute_map = {
|
||||
'n_token': 'vocab_size',
|
||||
'weight_init_type': 'init',
|
||||
'weight_init_range': 'init_range',
|
||||
'weight_init_std': 'init_std',
|
||||
}
|
||||
hyperparameter_map.update(CONFIG_MAPPING['transfo-xl']().attribute_map)
|
||||
attribute_map.update(CONFIG_MAPPING['transfo-xl']().attribute_map)
|
||||
|
||||
def __init__(self,
|
||||
n_token: Optional[int] = 267736,
|
||||
|
|
Загрузка…
Ссылка в новой задаче