chore(constraints): Adds initial draft of Text Predict scoring during search.

Gustavo Rosa 2022-04-04 19:18:27 +00:00
Parent c53653cdd5
Commit 57019465dc
2 changed files with 84 additions and 24 deletions

View file

@@ -14,7 +14,8 @@ from archai.nlp.nas.nas_utils.constraints.onnx_constraints import (measure_onnx_
from archai.nlp.nas.nas_utils.constraints.torch_constraints import (measure_torch_inference_latency,
                                                                    measure_torch_parameters,
                                                                    measure_torch_peak_memory,
-                                                                   measure_torch_perplexity)
+                                                                   measure_torch_perplexity,
+                                                                   measure_torch_text_predict)
# Latency upper bound on different device targets
# Any model with more latency than this will be removed from consideration during search
@@ -88,9 +89,10 @@ class TorchConstraintPipeline(ConstraintPipeline):
    def __init__(self,
                 training_strategy: Optional[str] = 'decoder_params',
-                training_dataset: Optional[str] = 'wt103',
-                training_vocab_type: Optional[str] = 'word',
-                training_vocab_size: Optional[int] = 10000,
+                dataset: Optional[str] = 'wt103',
+                scoring_file: Optional[str] = None,
+                vocab_type: Optional[str] = 'word',
+                vocab_size: Optional[int] = 10000,
                 training_max_step: Optional[int] = 100,
                 use_quantization: Optional[bool] = False,
                 use_median: Optional[bool] = False,
@@ -102,10 +104,11 @@ class TorchConstraintPipeline(ConstraintPipeline):
        """Overrides initialization method.

        Args:
-           training_strategy: Training strategy (`decoder_params`, `val_ppl` or `char_acc_rate`).
-           training_dataset: Training dataset (if not using `decoder_params`).
-           training_vocab_type: Type of training vocabulary (if not using `decoder_params`).
-           training_vocab_size: Size of training vocabulary (if not using `decoder_params`).
+           training_strategy: Training strategy (`decoder_params`, `val_ppl` or `text_predict`).
+           dataset: Dataset (if not using `decoder_params`).
+           scoring_file: Scoring .ljson file (if using `text_predict`).
+           vocab_type: Type of vocabulary (if not using `decoder_params`).
+           vocab_size: Size of vocabulary (if not using `decoder_params`).
            training_max_step: Maximum training steps (if not using `decoder_params`).
            use_quantization: Whether measurement should be calculated with a quantized model or not.
            use_median: Whether median should be used instead of mean for measurement.
@@ -118,19 +121,23 @@ class TorchConstraintPipeline(ConstraintPipeline):
        """
        self.training_strategy = training_strategy
-       self.training_dataset = training_dataset
-       self.training_vocab_type = training_vocab_type
-       self.training_vocab_size = training_vocab_size
+       self.dataset = dataset
+       self.scoring_file = scoring_file
+       self.vocab_type = vocab_type
+       self.vocab_size = vocab_size
        self.training_max_step = training_max_step

        super().__init__(use_quantization, use_median, batch_size,
                         seq_len, n_threads, n_trials, device)
-   def __call__(self, model: torch.nn.Module) -> Tuple[Union[int, float], int, float, float]:
+   def __call__(self,
+                model: torch.nn.Module,
+                model_config: Dict[str, Any]) -> Tuple[Union[int, float], int, float, float]:
        """Invokes the built-in call method.

        Args:
            model: Model to be used within constraint pipeline.
+           model_config: Configuration of model.

        Returns:
            (Tuple[Union[int, float], int, float, float]): Decoder parameters or
@@ -143,16 +150,26 @@
            measure_torch_proxy = measure_torch_parameters(model, ['non_embedding'])
        elif self.training_strategy == 'val_ppl':
            # Validation perplexity
-           measure_torch_proxy = measure_torch_perplexity(model,
-                                                          dataset=self.training_dataset,
-                                                          vocab_type=self.training_vocab_type,
-                                                          vocab_size=self.training_vocab_size,
-                                                          max_step=self.training_max_step)
+           _, measure_torch_proxy = measure_torch_perplexity(model,
+                                                             model_config,
+                                                             dataset=self.dataset,
+                                                             vocab_type=self.vocab_type,
+                                                             vocab_size=self.vocab_size,
+                                                             max_step=self.training_max_step)
+       elif self.training_strategy == 'text_predict':
+           # Text Predict with character acceptance rate
+           measure_torch_proxy = measure_torch_text_predict(model,
+                                                            model_config,
+                                                            dataset=self.dataset,
+                                                            scoring_file=self.scoring_file,
+                                                            vocab_type=self.vocab_type,
+                                                            vocab_size=self.vocab_size,
+                                                            max_step=self.training_max_step)
        else:
            raise NotImplementedError(f'training_strategy: {self.training_strategy} has not been implemented yet.')

        return (
-           # Proxy (either decoder parameters or validation perplexity)
+           # Proxy (decoder parameters, validation perplexity or character acceptance rate)
            measure_torch_proxy,
            # Number of total parameters
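For context, a minimal sketch of how the updated pipeline could be driven with the new strategy. The import path and the toy model below are assumptions for illustration; only `training_strategy='text_predict'`, the `scoring_file` argument, and the two-argument call come from this commit:

import torch

# Assumed import path for the pipeline module shown in the diff above.
from archai.nlp.nas.nas_utils.constraints.constraint_pipeline import TorchConstraintPipeline

# Hypothetical stand-ins: any torch.nn.Module plus its configuration dictionary.
model = torch.nn.Linear(8, 8)
model_config = {'n_layer': 2, 'n_head': 4}

# `text_predict` needs a sub-word vocabulary and a scoring .ljson file.
pipeline = TorchConstraintPipeline(training_strategy='text_predict',
                                   dataset='wt103',
                                   scoring_file='scoring.ljson',
                                   vocab_type='bbpe',
                                   vocab_size=10000)

# Per the return tuple above: the proxy (here, character acceptance rate),
# followed by the other measurements the pipeline already produced.
proxy, total_params, latency, memory = pipeline(model, model_config)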

View file

@@ -4,15 +4,18 @@
"""PyTorch-based constraints.
"""

+from argparse import Namespace
import math
-from typing import List, Optional
+import os
+from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.utils.benchmark as benchmark
from torch.profiler import ProfilerActivity, profile
-from archai.nlp.compression.quantization.ptq import dynamic_quantization_torch_from_model
+from archai.nlp import train
+from archai.nlp.compression.quantization.ptq import dynamic_quantization_torch_from_model
+from archai.nlp.metrics.text_predict.predictor import run_score
def measure_torch_inference_latency(model: torch.nn.Module,
@@ -129,21 +132,23 @@ def measure_torch_peak_memory(model: torch.nn.Module,
def measure_torch_perplexity(model: torch.nn.Module,
+                            model_config: Dict[str, Any],
                             dataset: Optional[str] = 'wt103',
                             vocab_type: Optional[str] = 'word',
                             vocab_size: Optional[int] = 10000,
-                            max_step: Optional[int] = 100) -> float:
+                            max_step: Optional[int] = 100) -> Tuple[Namespace, float]:
    """Measures a model's validation perplexity.

    Args:
        model: Model instance.
+       model_config: Configuration of the model.
        dataset: Training dataset.
        vocab_type: Type of vocabulary.
        vocab_size: Vocabulary size.
        max_step: Maximum training steps.

    Returns:
-       (float): Validation perplexity.
+       (Tuple[Namespace, float]): Training arguments and validation perplexity.

    """
@@ -165,7 +170,45 @@ def measure_torch_perplexity(model: torch.nn.Module,
    model.to(device)

    scheduler, scheduler_sparse = train.create_scheduler(args, optimizer, optimizer_sparse)
    _, best_val_loss, _ = train.train_main(args, device, train_itr, valid_itr, model, para_model,
-                                          None, optimizer, optimizer_sparse, scheduler,
+                                          model_config, optimizer, optimizer_sparse, scheduler,
                                           scheduler_sparse, scaler, vocab, file_stats[1])

-   return math.exp(best_val_loss)
+   return args, math.exp(best_val_loss)
+
+
+def measure_torch_text_predict(model: torch.nn.Module,
+                               model_config: Dict[str, Any],
+                               dataset: Optional[str] = 'wt103',
+                               scoring_file: Optional[str] = None,
+                               vocab_type: Optional[str] = 'word',
+                               vocab_size: Optional[int] = 10000,
+                               max_step: Optional[int] = 100) -> float:
+    """Measures a model's character acceptance rate with Text Predict.
+
+    Args:
+        model: Model instance.
+        model_config: Configuration of the model.
+        dataset: Training dataset.
+        scoring_file: Scoring .ljson file.
+        vocab_type: Type of vocabulary.
+        vocab_size: Vocabulary size.
+        max_step: Maximum training steps.
+
+    Returns:
+        (float): Character acceptance rate.
+
+    """
+
+    if vocab_type == 'word':
+        raise ValueError('`vocab_type` should be either `bbpe` or `gpt2`.')
+
+    # Re-uses the perplexity function to train the model
+    args, _ = measure_torch_perplexity(model, model_config, dataset, vocab_type, vocab_size, max_step=10)
+
+    # Defines some missing variables to run Text Predict
+    model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
+    vocab_path = os.path.join(args.cache_dir, dataset, vocab_type, str(vocab_size), 'vocab', 'bbpe_tokenizer.json')
+    input_file_type = 'smartcompose'
+
+    # Runs the Text Predict scoring function
+    run_score(args.work_dir, model_path, vocab_path, scoring_file, input_file_type, args.model_type)
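And a sketch of exercising the new measurement directly, under the same assumptions as the earlier example (placeholder scoring file; `model` and `model_config` as before). Note that `vocab_type` must be sub-word, since the draft rejects word-level vocabularies up front:

from archai.nlp.nas.nas_utils.constraints.torch_constraints import measure_torch_text_predict

# Trains briefly via measure_torch_perplexity (max_step is pinned to 10
# internally), then scores the best checkpoint with Text Predict against
# the given .ljson file. A word-level vocab raises ValueError immediately.
measure_torch_text_predict(model,
                           model_config,
                           dataset='wt103',
                           scoring_file='scoring.ljson',
                           vocab_type='bbpe',
                           vocab_size=10000)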