This commit is contained in:
thomwolf 2019-12-20 13:47:35 +01:00
Родитель 15dda5ea32
Коммит 73fcebf7ec
3 изменённых файлов: 27 добавлений и 16 удалений

Просмотреть файл

@ -38,9 +38,9 @@ from setuptools import find_packages, setup
extras = {
'serving': ['uvicorn', 'fastapi'],
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['uvicorn', 'fastapi', 'torch']
'serving': ['pydantic', 'uvicorn', 'fastapi'],
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
}
extras['all'] = [package for package in extras.values()]

Просмотреть файл

@ -3,9 +3,9 @@ from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from transformers.commands.convert import ConvertCommand
from transformers.commands.serving import ServeCommand
if __name__ == '__main__':
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')

Просмотреть файл

@ -1,16 +1,23 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Body
from logging import getLogger
import logging
from pydantic import BaseModel
from uvicorn import run
try:
from uvicorn import run
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
_serve_dependancies_installed = True
except (ImportError, AttributeError):
BaseModel = object
Body = lambda *x, **y: None
_serve_dependancies_installed = False
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
def serve_command_factory(args: Namespace):
"""
@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int):
self._logger = getLogger('transformers-cli/serving')
self._pipeline = pipeline
self._logger.info('Serving model over {}:{}'.format(host, port))
self._host = host
self._port = port
self._app = FastAPI()
if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly.")
else:
logger.info('Serving model over {}:{}'.format(host, port))
self._app = FastAPI()
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
def run(self):
run(self._app, host=self._host, port=self._port)