зеркало из https://github.com/microsoft/archai.git
chore(constraints): Adds initial draft of Text Predict scoring during search.
This commit is contained in:
Родитель
c53653cdd5
Коммит
57019465dc
|
@ -14,7 +14,8 @@ from archai.nlp.nas.nas_utils.constraints.onnx_constraints import (measure_onnx_
|
|||
from archai.nlp.nas.nas_utils.constraints.torch_constraints import (measure_torch_inference_latency,
|
||||
measure_torch_parameters,
|
||||
measure_torch_peak_memory,
|
||||
measure_torch_perplexity)
|
||||
measure_torch_perplexity,
|
||||
measure_torch_text_predict)
|
||||
|
||||
# Latency upper bound on different device targets
|
||||
# Any model with more latency than this will be removed from consideration during search
|
||||
|
@ -88,9 +89,10 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
|||
|
||||
def __init__(self,
|
||||
training_strategy: Optional[str] = 'decoder_params',
|
||||
training_dataset: Optional[str] = 'wt103',
|
||||
training_vocab_type: Optional[str] = 'word',
|
||||
training_vocab_size: Optional[int] = 10000,
|
||||
dataset: Optional[str] = 'wt103',
|
||||
scoring_file: Optional[str] = None,
|
||||
vocab_type: Optional[str] = 'word',
|
||||
vocab_size: Optional[int] = 10000,
|
||||
training_max_step: Optional[int] = 100,
|
||||
use_quantization: Optional[bool] = False,
|
||||
use_median: Optional[bool] = False,
|
||||
|
@ -102,10 +104,11 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
|||
"""Overrides initialization method.
|
||||
|
||||
Args:
|
||||
training_strategy: Training strategy (`decoder_params`, `val_ppl` or `char_acc_rate`).
|
||||
training_dataset: Training dataset (if not using `decoder_params`).
|
||||
training_vocab_type: Type of training vocabulary (if not using `decoder_params`).
|
||||
training_vocab_size: Size of training vocabulary (if not using `decoder_params`).
|
||||
training_strategy: Training strategy (`decoder_params`, `val_ppl` or `text_predict`).
|
||||
dataset: Dataset (if not using `decoder_params`).
|
||||
scoring_file: Scoring .ljson file (if using `text_predict`).
|
||||
vocab_type: Type of vocabulary (if not using `decoder_params`).
|
||||
vocab_size: Size of vocabulary (if not using `decoder_params`).
|
||||
training_max_step: Maximum training steps (if not using `decoder_params`).
|
||||
use_quantization: Whether measurement should be calculated with a quantized model or not.
|
||||
use_median: Whether the median should be used instead of the mean for measurement.
|
||||
|
@ -118,19 +121,23 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
|||
"""
|
||||
|
||||
self.training_strategy = training_strategy
|
||||
self.training_dataset = training_dataset
|
||||
self.training_vocab_type = training_vocab_type
|
||||
self.training_vocab_size = training_vocab_size
|
||||
self.dataset = dataset
|
||||
self.scoring_file = scoring_file
|
||||
self.vocab_type = vocab_type
|
||||
self.vocab_size = vocab_size
|
||||
self.training_max_step = training_max_step
|
||||
|
||||
super().__init__(use_quantization, use_median, batch_size,
|
||||
seq_len, n_threads, n_trials, device)
|
||||
|
||||
def __call__(self, model: torch.nn.Module) -> Tuple[Union[int, float], int, float, float]:
|
||||
def __call__(self,
|
||||
model: torch.nn.Module,
|
||||
model_config: Dict[str, Any]) -> Tuple[Union[int, float], int, float, float]:
|
||||
"""Invokes the built-in call method.
|
||||
|
||||
Args:
|
||||
model: Model to be used within constraint pipeline.
|
||||
model_config: Configuration of model.
|
||||
|
||||
Returns:
|
||||
(Tuple[Union[int, float], int, float, float]): Decoder parameters or
|
||||
|
@ -143,16 +150,26 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
|||
measure_torch_proxy = measure_torch_parameters(model, ['non_embedding'])
|
||||
elif self.training_strategy == 'val_ppl':
|
||||
# Validation perplexity
|
||||
measure_torch_proxy = measure_torch_perplexity(model,
|
||||
dataset=self.training_dataset,
|
||||
vocab_type=self.training_vocab_type,
|
||||
vocab_size=self.training_vocab_size,
|
||||
max_step=self.training_max_step)
|
||||
_, measure_torch_proxy = measure_torch_perplexity(model,
|
||||
model_config,
|
||||
dataset=self.dataset,
|
||||
vocab_type=self.vocab_type,
|
||||
vocab_size=self.vocab_size,
|
||||
max_step=self.training_max_step)
|
||||
elif self.training_strategy == 'text_predict':
|
||||
# Text Predict with character acceptance rate
|
||||
measure_torch_proxy = measure_torch_text_predict(model,
|
||||
model_config,
|
||||
dataset=self.dataset,
|
||||
scoring_file=self.scoring_file,
|
||||
vocab_type=self.vocab_type,
|
||||
vocab_size=self.vocab_size,
|
||||
max_step=self.training_max_step)
|
||||
else:
|
||||
raise NotImplementedError(f'training_strategy: {self.training_strategy} has not been implemented yet.')
|
||||
|
||||
return (
|
||||
# Proxy (either decoder parameters or validation perplexity)
|
||||
# Proxy (decoder parameters, validation perplexity or character acceptance rate)
|
||||
measure_torch_proxy,
|
||||
|
||||
# Number of total parameters
|
||||
|
|
|
@ -4,15 +4,18 @@
|
|||
"""PyTorch-based constraints.
|
||||
"""
|
||||
|
||||
from argparse import Namespace
|
||||
import math
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from torch.profiler import ProfilerActivity, profile
|
||||
|
||||
from archai.nlp.compression.quantization.ptq import dynamic_quantization_torch_from_model
|
||||
from archai.nlp import train
|
||||
from archai.nlp.compression.quantization.ptq import dynamic_quantization_torch_from_model
|
||||
from archai.nlp.metrics.text_predict.predictor import run_score
|
||||
|
||||
|
||||
def measure_torch_inference_latency(model: torch.nn.Module,
|
||||
|
@ -129,21 +132,23 @@ def measure_torch_peak_memory(model: torch.nn.Module,
|
|||
|
||||
|
||||
def measure_torch_perplexity(model: torch.nn.Module,
|
||||
model_config: Dict[str, Any],
|
||||
dataset: Optional[str] = 'wt103',
|
||||
vocab_type: Optional[str] = 'word',
|
||||
vocab_size: Optional[int] = 10000,
|
||||
max_step: Optional[int] = 100) -> float:
|
||||
max_step: Optional[int] = 100) -> Tuple[Namespace, float]:
|
||||
"""Measures a model's validation perplexity.
|
||||
|
||||
Args:
|
||||
model: Model instance.
|
||||
model_config: Configuration of the model.
|
||||
dataset: Training dataset.
|
||||
vocab_type: Type of vocabulary.
|
||||
vocab_size: Vocabulary size.
|
||||
max_step: Maximum training steps.
|
||||
|
||||
Returns:
|
||||
(float): Validation perplexity.
|
||||
(Tuple[Namespace, float]): Training arguments and validation perplexity.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -165,7 +170,45 @@ def measure_torch_perplexity(model: torch.nn.Module,
|
|||
model.to(device)
|
||||
scheduler, scheduler_sparse = train.create_scheduler(args, optimizer, optimizer_sparse)
|
||||
_, best_val_loss, _ = train.train_main(args, device, train_itr, valid_itr, model, para_model,
|
||||
None, optimizer, optimizer_sparse, scheduler,
|
||||
model_config, optimizer, optimizer_sparse, scheduler,
|
||||
scheduler_sparse, scaler, vocab, file_stats[1])
|
||||
|
||||
return math.exp(best_val_loss)
|
||||
return args, math.exp(best_val_loss)
|
||||
|
||||
|
||||
def measure_torch_text_predict(model: torch.nn.Module,
                               model_config: Dict[str, Any],
                               dataset: Optional[str] = 'wt103',
                               scoring_file: Optional[str] = None,
                               vocab_type: Optional[str] = 'word',
                               vocab_size: Optional[int] = 10000,
                               max_step: Optional[int] = 100) -> float:
    """Measures a model's character acceptance rate with Text Predict.

    The model is first trained by re-using `measure_torch_perplexity`, and the
    training arguments it returns are then used to locate the checkpoint and
    vocabulary that are handed to the Text Predict scoring function.

    Args:
        model: Model instance.
        model_config: Configuration of the model.
        dataset: Training dataset.
        scoring_file: Scoring .ljson file.
        vocab_type: Type of vocabulary.
        vocab_size: Vocabulary size.
        max_step: Maximum training steps.

    Returns:
        (float): Character acceptance rate.

    Raises:
        ValueError: If `vocab_type` is `word`.

    """

    # Rejects the unsupported vocabulary before any training time is spent
    if vocab_type == 'word':
        raise ValueError('`vocab_type` should be either `bbpe` or `gpt2`.')

    # Re-uses the perplexity function to train the model, honoring the
    # caller-supplied number of steps instead of a hard-coded value
    args, _ = measure_torch_perplexity(model,
                                       model_config,
                                       dataset=dataset,
                                       vocab_type=vocab_type,
                                       vocab_size=vocab_size,
                                       max_step=max_step)

    # Defines some missing variables to run Text Predict
    model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')

    # NOTE(review): the tokenizer file name is fixed to `bbpe_tokenizer.json`
    # although the check above also mentions `gpt2` - confirm this path is
    # valid for both vocabulary types
    vocab_path = os.path.join(args.cache_dir, dataset, vocab_type, str(vocab_size), 'vocab', 'bbpe_tokenizer.json')
    input_file_type = 'smartcompose'

    # Runs the Text Predict scoring function and propagates its result
    # NOTE(review): assumes `run_score` returns the character acceptance rate;
    # confirm against its implementation
    return run_score(args.work_dir, model_path, vocab_path, scoring_file, input_file_type, args.model_type)
|
||||
|
|
Загрузка…
Ссылка в новой задаче