зеркало из https://github.com/microsoft/archai.git
chore(constraints): Adds initial draft of Text Predict scoring during search.
This commit is contained in:
Родитель
c53653cdd5
Коммит
57019465dc
|
@ -14,7 +14,8 @@ from archai.nlp.nas.nas_utils.constraints.onnx_constraints import (measure_onnx_
|
||||||
from archai.nlp.nas.nas_utils.constraints.torch_constraints import (measure_torch_inference_latency,
|
from archai.nlp.nas.nas_utils.constraints.torch_constraints import (measure_torch_inference_latency,
|
||||||
measure_torch_parameters,
|
measure_torch_parameters,
|
||||||
measure_torch_peak_memory,
|
measure_torch_peak_memory,
|
||||||
measure_torch_perplexity)
|
measure_torch_perplexity,
|
||||||
|
measure_torch_text_predict)
|
||||||
|
|
||||||
# Latency upper bound on different device targets
|
# Latency upper bound on different device targets
|
||||||
# Any model with more latency than this will be removed from consideration during search
|
# Any model with more latency than this will be removed from consideration during search
|
||||||
|
@ -88,9 +89,10 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
training_strategy: Optional[str] = 'decoder_params',
|
training_strategy: Optional[str] = 'decoder_params',
|
||||||
training_dataset: Optional[str] = 'wt103',
|
dataset: Optional[str] = 'wt103',
|
||||||
training_vocab_type: Optional[str] = 'word',
|
scoring_file: Optional[str] = None,
|
||||||
training_vocab_size: Optional[int] = 10000,
|
vocab_type: Optional[str] = 'word',
|
||||||
|
vocab_size: Optional[int] = 10000,
|
||||||
training_max_step: Optional[int] = 100,
|
training_max_step: Optional[int] = 100,
|
||||||
use_quantization: Optional[bool] = False,
|
use_quantization: Optional[bool] = False,
|
||||||
use_median: Optional[bool] = False,
|
use_median: Optional[bool] = False,
|
||||||
|
@ -102,10 +104,11 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
||||||
"""Overrides initialization method.
|
"""Overrides initialization method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
training_strategy: Training strategy (`decoder_params`, `val_ppl` or `char_acc_rate`).
|
training_strategy: Training strategy (`decoder_params`, `val_ppl` or `text_predict`).
|
||||||
training_dataset: Training dataset (if not using `decoder_params`).
|
dataset: Dataset (if not using `decoder_params`).
|
||||||
training_vocab_type: Type of training vocabulary (if not using `decoder_params`).
|
scoring_file: Scoring .ljson file (if using `text_predict`).
|
||||||
training_vocab_size: Size of training vocabulary (if not using `decoder_params`).
|
vocab_type: Type of vocabulary (if not using `decoder_params`).
|
||||||
|
vocab_size: Size of vocabulary (if not using `decoder_params`).
|
||||||
training_max_step: Maximum training steps (if not using `decoder_params`).
|
training_max_step: Maximum training steps (if not using `decoder_params`).
|
||||||
use_quantization: Whether measurement should be calculated with quantizated model or not.
|
use_quantization: Whether measurement should be calculated with quantizated model or not.
|
||||||
use_median: Whether should use median instead of mean for measurement.
|
use_median: Whether should use median instead of mean for measurement.
|
||||||
|
@ -118,19 +121,23 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.training_strategy = training_strategy
|
self.training_strategy = training_strategy
|
||||||
self.training_dataset = training_dataset
|
self.dataset = dataset
|
||||||
self.training_vocab_type = training_vocab_type
|
self.scoring_file = scoring_file
|
||||||
self.training_vocab_size = training_vocab_size
|
self.vocab_type = vocab_type
|
||||||
|
self.vocab_size = vocab_size
|
||||||
self.training_max_step = training_max_step
|
self.training_max_step = training_max_step
|
||||||
|
|
||||||
super().__init__(use_quantization, use_median, batch_size,
|
super().__init__(use_quantization, use_median, batch_size,
|
||||||
seq_len, n_threads, n_trials, device)
|
seq_len, n_threads, n_trials, device)
|
||||||
|
|
||||||
def __call__(self, model: torch.nn.Module) -> Tuple[Union[int, float], int, float, float]:
|
def __call__(self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
model_config: Dict[str, Any]) -> Tuple[Union[int, float], int, float, float]:
|
||||||
"""Invokes the built-in call method.
|
"""Invokes the built-in call method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model to be used within constraint pipeline.
|
model: Model to be used within constraint pipeline.
|
||||||
|
model_config: Configuration of model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(Tuple[Union[int, float], int, float, float]): Decoder parameters or
|
(Tuple[Union[int, float], int, float, float]): Decoder parameters or
|
||||||
|
@ -143,16 +150,26 @@ class TorchConstraintPipeline(ConstraintPipeline):
|
||||||
measure_torch_proxy = measure_torch_parameters(model, ['non_embedding'])
|
measure_torch_proxy = measure_torch_parameters(model, ['non_embedding'])
|
||||||
elif self.training_strategy == 'val_ppl':
|
elif self.training_strategy == 'val_ppl':
|
||||||
# Validation perplexity
|
# Validation perplexity
|
||||||
measure_torch_proxy = measure_torch_perplexity(model,
|
_, measure_torch_proxy = measure_torch_perplexity(model,
|
||||||
dataset=self.training_dataset,
|
model_config,
|
||||||
vocab_type=self.training_vocab_type,
|
dataset=self.dataset,
|
||||||
vocab_size=self.training_vocab_size,
|
vocab_type=self.vocab_type,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
max_step=self.training_max_step)
|
||||||
|
elif self.training_strategy == 'text_predict':
|
||||||
|
# Text Predict with character acceptance rate
|
||||||
|
measure_torch_proxy = measure_torch_text_predict(model,
|
||||||
|
model_config,
|
||||||
|
dataset=self.dataset,
|
||||||
|
scoring_file=self.scoring_file,
|
||||||
|
vocab_type=self.vocab_type,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
max_step=self.training_max_step)
|
max_step=self.training_max_step)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'training_strategy: {self.training_strategy} has not been implemented yet.')
|
raise NotImplementedError(f'training_strategy: {self.training_strategy} has not been implemented yet.')
|
||||||
|
|
||||||
return (
|
return (
|
||||||
# Proxy (either decoder parameters or validation perplexity)
|
# Proxy (decoder parameters, validation perplexity or character acceptance rate)
|
||||||
measure_torch_proxy,
|
measure_torch_proxy,
|
||||||
|
|
||||||
# Number of total parameters
|
# Number of total parameters
|
||||||
|
|
|
@ -4,15 +4,18 @@
|
||||||
"""PyTorch-based constraints.
|
"""PyTorch-based constraints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.benchmark as benchmark
|
import torch.utils.benchmark as benchmark
|
||||||
from torch.profiler import ProfilerActivity, profile
|
from torch.profiler import ProfilerActivity, profile
|
||||||
|
|
||||||
from archai.nlp.compression.quantization.ptq import dynamic_quantization_torch_from_model
|
|
||||||
from archai.nlp import train
|
from archai.nlp import train
|
||||||
|
from archai.nlp.compression.quantization.ptq import dynamic_quantization_torch_from_model
|
||||||
|
from archai.nlp.metrics.text_predict.predictor import run_score
|
||||||
|
|
||||||
|
|
||||||
def measure_torch_inference_latency(model: torch.nn.Module,
|
def measure_torch_inference_latency(model: torch.nn.Module,
|
||||||
|
@ -129,21 +132,23 @@ def measure_torch_peak_memory(model: torch.nn.Module,
|
||||||
|
|
||||||
|
|
||||||
def measure_torch_perplexity(model: torch.nn.Module,
|
def measure_torch_perplexity(model: torch.nn.Module,
|
||||||
|
model_config: Dict[str, Any],
|
||||||
dataset: Optional[str] = 'wt103',
|
dataset: Optional[str] = 'wt103',
|
||||||
vocab_type: Optional[str] = 'word',
|
vocab_type: Optional[str] = 'word',
|
||||||
vocab_size: Optional[int] = 10000,
|
vocab_size: Optional[int] = 10000,
|
||||||
max_step: Optional[int] = 100) -> float:
|
max_step: Optional[int] = 100) -> Tuple[Namespace, float]:
|
||||||
"""Measures a model's validation perplexity.
|
"""Measures a model's validation perplexity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model instance.
|
model: Model instance.
|
||||||
|
model_config: Configuration of the model.
|
||||||
dataset: Training dataset.
|
dataset: Training dataset.
|
||||||
vocab_type: Type of vocabulary.
|
vocab_type: Type of vocabulary.
|
||||||
vocab_size: Vocabulary size.
|
vocab_size: Vocabulary size.
|
||||||
max_step: Maximum training steps.
|
max_step: Maximum training steps.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(float): Validation perplexity.
|
(Tuple[Namespace, float]): Training arguments and validation perplexity.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -165,7 +170,45 @@ def measure_torch_perplexity(model: torch.nn.Module,
|
||||||
model.to(device)
|
model.to(device)
|
||||||
scheduler, scheduler_sparse = train.create_scheduler(args, optimizer, optimizer_sparse)
|
scheduler, scheduler_sparse = train.create_scheduler(args, optimizer, optimizer_sparse)
|
||||||
_, best_val_loss, _ = train.train_main(args, device, train_itr, valid_itr, model, para_model,
|
_, best_val_loss, _ = train.train_main(args, device, train_itr, valid_itr, model, para_model,
|
||||||
None, optimizer, optimizer_sparse, scheduler,
|
model_config, optimizer, optimizer_sparse, scheduler,
|
||||||
scheduler_sparse, scaler, vocab, file_stats[1])
|
scheduler_sparse, scaler, vocab, file_stats[1])
|
||||||
|
|
||||||
return math.exp(best_val_loss)
|
return args, math.exp(best_val_loss)
|
||||||
|
|
||||||
|
|
||||||
|
def measure_torch_text_predict(model: torch.nn.Module,
|
||||||
|
model_config: Dict[str, Any],
|
||||||
|
dataset: Optional[str] = 'wt103',
|
||||||
|
scoring_file: Optional[str] = None,
|
||||||
|
vocab_type: Optional[str] = 'word',
|
||||||
|
vocab_size: Optional[int] = 10000,
|
||||||
|
max_step: Optional[int] = 100) -> float:
|
||||||
|
"""Measures a model's character acceptance rate with Text Predict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model instance.
|
||||||
|
model_config: Configuration of the model.
|
||||||
|
dataset: Training dataset.
|
||||||
|
scoring_file: Scoring .ljson file.
|
||||||
|
vocab_type: Type of vocabulary.
|
||||||
|
vocab_size: Vocabulary size.
|
||||||
|
max_step: Maximum training steps.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(float): Character acceptance rate.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
if vocab_type == 'word':
|
||||||
|
raise ValueError('`vocab_type` should be either `bbpe` or `gpt2`.')
|
||||||
|
|
||||||
|
# Re-uses the perplexity function to train the model
|
||||||
|
args, _ = measure_torch_perplexity(model, model_config, dataset, vocab_type, vocab_size, max_step=10)
|
||||||
|
|
||||||
|
# Defines some missing variables to run TextPredict
|
||||||
|
model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
|
||||||
|
vocab_path = os.path.join(args.cache_dir, dataset, vocab_type, str(vocab_size), 'vocab', 'bbpe_tokenizer.json')
|
||||||
|
input_file_type = 'smartcompose'
|
||||||
|
|
||||||
|
# Runs the Text Predict scoring function
|
||||||
|
run_score(args.work_dir, model_path, vocab_path, scoring_file, input_file_type, args.model_type)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче