fix(config_base): Fixes Config class not inheriting from PretrainedConfig.

Gustavo de Rosa 2022-03-30 16:31:14 -03:00 committed by Gustavo Rosa
Parent e898dd35f5
Commit ed7fb59ea8
3 changed files with 12 additions and 37 deletions


@@ -9,39 +9,22 @@ from typing import Any, Dict, Mapping, Optional, List, Union

 import torch

 from transformers import PretrainedConfig

-class Config:
+class Config(PretrainedConfig):
     """Base configuration class, used to define some common attributes
     and shared methods for loading and saving configurations.
     """

-    hyperparameter_map: Dict[str, str] = {}
-
-    def __getattribute__(self, key: str) -> Any:
-        if key != 'hyperparameter_map' and key in super().__getattribute__('hyperparameter_map'):
-            key = super().__getattribute__('hyperparameter_map')[key]
-        return super().__getattribute__(key)
-
-    def __setattr__(self, key: str, value: Any) -> None:
-        if key in super().__getattribute__('hyperparameter_map'):
-            key = super().__getattribute__('hyperparameter_map')[key]
-        super().__setattr__(key, value)
-
     def __init__(self, **kwargs) -> None:
         """Initializes the class by verifying whether keyword arguments
         are valid and setting them as attributes.
         """
-        # Non-default attributes
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as e:
-                raise e
+        super().__init__(**kwargs)

     def _map_to_list(self,
                      variable: Union[int, float, List[Union[int, float]]],
@@ -58,14 +41,6 @@ class Config:
         return [variable] * size

-    def to_dict(self) -> Dict[str, Any]:
-        config_dict = {}
-        for key, value in self.__dict__.items():
-            config_dict[key] = value
-        return config_dict
-

 class SearchConfigParameter:
     """Base search configuration parameter class, used to define whether


@@ -7,15 +7,15 @@

 from typing import List, Optional, Union

 from archai.nlp.models.config_base import Config, SearchConfigParameter, SearchConfig
-from transformers import CONFIG_MAPPING, PretrainedConfig
+from transformers import CONFIG_MAPPING

-class HfGPT2Config(Config, PretrainedConfig):
+class HfGPT2Config(Config):
     """Huggingface's Open AI GPT-2 default configuration.
     """

-    hyperparameter_map = {
+    attribute_map = {
         'n_token': 'vocab_size',
         'tgt_len': 'n_positions',
         'd_model': 'n_embd',
@@ -24,7 +24,7 @@ class HfGPT2Config(Config, PretrainedConfig):
         'dropatt': 'attn_pdrop',
         'weight_init_std': 'initializer_range'
     }
-    hyperparameter_map.update(CONFIG_MAPPING['gpt2']().attribute_map)
+    attribute_map.update(CONFIG_MAPPING['gpt2']().attribute_map)

     def __init__(self,
                  n_token: Optional[int] = 10000, # changed from 50257 for model's production
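
The rename from hyperparameter_map to attribute_map is the crux here: PretrainedConfig only consults an attribute literally named attribute_map, so once the custom overrides were deleted, aliases stored under any other name would simply be ignored. The update(...) call then folds in the aliases Hugging Face already defines for GPT-2. A small sketch of that merge (the stock aliases named in the comment are examples; exact contents depend on the installed transformers version):

from transformers import CONFIG_MAPPING

# Archai-specific aliases, as declared in HfGPT2Config above.
attribute_map = {'n_token': 'vocab_size', 'd_model': 'n_embd'}

# CONFIG_MAPPING['gpt2'] is the GPT2Config class; an instance carries the
# stock Hugging Face aliases, e.g. 'hidden_size' -> 'n_embd'.
attribute_map.update(CONFIG_MAPPING['gpt2']().attribute_map)

# After the merge, both naming schemes resolve to the same attribute.
assert attribute_map['d_model'] == attribute_map['hidden_size'] == 'n_embd'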


@@ -7,21 +7,21 @@

 from typing import List, Optional

 from archai.nlp.models.config_base import Config, SearchConfigParameter, SearchConfig
-from transformers import CONFIG_MAPPING, PretrainedConfig
+from transformers import CONFIG_MAPPING

-class HfTransfoXLConfig(Config, PretrainedConfig):
+class HfTransfoXLConfig(Config):
     """Huggingface's Transformer-XL default configuration.
     """

-    hyperparameter_map = {
+    attribute_map = {
         'n_token': 'vocab_size',
         'weight_init_type': 'init',
         'weight_init_range': 'init_range',
         'weight_init_std': 'init_std',
     }
-    hyperparameter_map.update(CONFIG_MAPPING['transfo-xl']().attribute_map)
+    attribute_map.update(CONFIG_MAPPING['transfo-xl']().attribute_map)

     def __init__(self,
                  n_token: Optional[int] = 267736,
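
With the inheritance fixed, these configs also pick up PretrainedConfig's serialization for free, which is why the hand-rolled Config.to_dict() in the first file could be deleted. A rough end-to-end sketch (a stand-in subclass is used because the diff does not show the import path or full constructor of the real configs):

from transformers import PretrainedConfig

class TinyConfig(PretrainedConfig):
    model_type = 'tiny'  # illustrative; real configs register their own type
    attribute_map = {'n_token': 'vocab_size'}

config = TinyConfig(n_token=267736, d_inner=2048)

# to_dict() and to_json_string() are inherited from PretrainedConfig and
# already include every attribute set above, replacing the deleted method.
d = config.to_dict()
assert d['vocab_size'] == 267736 and d['d_inner'] == 2048
print(config.to_json_string())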