update serving command
This commit is contained in:
Родитель
15dda5ea32
Коммит
73fcebf7ec
6
setup.py
6
setup.py
|
@ -38,9 +38,9 @@ from setuptools import find_packages, setup
|
|||
|
||||
|
||||
extras = {
|
||||
'serving': ['uvicorn', 'fastapi'],
|
||||
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'],
|
||||
'serving-torch': ['uvicorn', 'fastapi', 'torch']
|
||||
'serving': ['pydantic', 'uvicorn', 'fastapi'],
|
||||
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
|
||||
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
|
||||
}
|
||||
extras['all'] = [package for package in extras.values()]
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ from argparse import ArgumentParser
|
|||
|
||||
from transformers.commands.download import DownloadCommand
|
||||
from transformers.commands.run import RunCommand
|
||||
from transformers.commands.serving import ServeCommand
|
||||
from transformers.commands.user import UserCommands
|
||||
from transformers.commands.convert import ConvertCommand
|
||||
from transformers.commands.serving import ServeCommand
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
|
||||
|
|
|
@ -1,16 +1,23 @@
|
|||
from argparse import ArgumentParser, Namespace
|
||||
from typing import List, Optional, Union, Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Body
|
||||
from logging import getLogger
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from uvicorn import run
|
||||
try:
|
||||
from uvicorn import run
|
||||
from fastapi import FastAPI, HTTPException, Body
|
||||
from pydantic import BaseModel
|
||||
_serve_dependancies_installed = True
|
||||
except (ImportError, AttributeError):
|
||||
BaseModel = object
|
||||
Body = lambda *x, **y: None
|
||||
_serve_dependancies_installed = False
|
||||
|
||||
from transformers import Pipeline
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
logger = logging.getLogger('transformers-cli/serving')
|
||||
|
||||
def serve_command_factory(args: Namespace):
|
||||
"""
|
||||
|
@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||
serve_parser.set_defaults(func=serve_command_factory)
|
||||
|
||||
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
||||
self._logger = getLogger('transformers-cli/serving')
|
||||
|
||||
self._pipeline = pipeline
|
||||
|
||||
self._logger.info('Serving model over {}:{}'.format(host, port))
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._app = FastAPI()
|
||||
if not _serve_dependancies_installed:
|
||||
raise ImportError("Using serve command requires FastAPI and unicorn. "
|
||||
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||
"Or install FastAPI and unicorn separatly.")
|
||||
else:
|
||||
logger.info('Serving model over {}:{}'.format(host, port))
|
||||
self._app = FastAPI()
|
||||
|
||||
# Register routes
|
||||
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
|
||||
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
|
||||
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
|
||||
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
|
||||
# Register routes
|
||||
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
|
||||
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
|
||||
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
|
||||
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
|
||||
|
||||
def run(self):
|
||||
run(self._app, host=self._host, port=self._port)
|
||||
|
|
Загрузка…
Ссылка в новой задаче