# Code-viewer metadata: 78 lines, 2.2 KiB, Python
import os
|
|
import base64
|
|
import json
|
|
import logging
|
|
import urllib
|
|
from io import BytesIO
|
|
import sys
|
|
|
|
import requests
|
|
import toolz
|
|
from PIL import Image, ImageOps
|
|
|
|
# Stand-alone console handler with a timestamped format. It is configured
# here but NOT attached to any logger (the addHandler call is disabled), so
# only the basicConfig handler below actually emits records.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
ch = logging.StreamHandler(stream=sys.stdout)
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

# Root logger writes INFO and above to stdout using the default format.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
# NOTE(review): attaching `ch` is left disabled; since `logger` propagates to
# the root handler, enabling it would presumably print every record twice.
# logger.addHandler(ch)
|
|
|
|
# Examples used for testing: maps an image URL to the label the classifier
# is expected to return for it.
TEST_IMAGES = {
    'https://www.britishairways.com/assets/images/information/about-ba/fleet-facts/airbus-380-800/photo-gallery/240x295-BA-A380-exterior-2-high-res.jpg': 'n02690373 airliner',
}
|
|
|
|
def read_image_from(url):
    """Fetch the resource at ``url`` and return its content as a ``BytesIO``.

    Args:
        url: Anything accepted by ``urllib.request.urlopen`` (a URL string
            or a ``urllib.request.Request``).

    Returns:
        A ``BytesIO`` holding the complete response body, positioned at
        offset 0.

    Raises:
        urllib.error.URLError: If the resource cannot be opened.
    """
    # A bare ``import urllib`` does not load the ``request`` submodule; the
    # original only worked when some other module had already imported it.
    import urllib.request

    # Use the response as a context manager so the connection is closed
    # as soon as the body has been read, instead of being leaked.
    with urllib.request.urlopen(url) as response:
        return BytesIO(response.read())
|
|
|
|
|
|
def to_rgb(img_bytes):
    """Open an image with PIL and return it converted to RGB mode.

    Args:
        img_bytes: Value handed straight to ``Image.open`` (in this file, the
            ``BytesIO`` produced by ``read_image_from``).

    Returns:
        The result of ``convert('RGB')`` on the opened image.
    """
    opened = Image.open(img_bytes)
    rgb_image = opened.convert('RGB')
    return rgb_image
|
|
|
|
|
|
@toolz.curry
def resize(img_file, new_size=(100, 100)):
    """Return ``img_file`` resized and cropped to exactly ``new_size``.

    The function is curried, so it can be partially applied with only
    ``new_size`` and later called with the image.

    Args:
        img_file: PIL image to fit.
        new_size: Target ``(width, height)`` in pixels.

    Returns:
        The image produced by ``ImageOps.fit``.
    """
    # ``Image.ANTIALIAS`` was an alias of LANCZOS and was removed in
    # Pillow 10; ``Image.LANCZOS`` selects the same filter and is available
    # on both old and current Pillow releases.
    return ImageOps.fit(img_file, new_size, Image.LANCZOS)
|
|
|
|
|
|
def to_base64(img):
    """Serialize ``img`` as PNG and return the bytes base64-encoded as text.

    Args:
        img: Object exposing ``save(file_obj, format)`` (a PIL image here).

    Returns:
        ``str`` with the base64 encoding of the PNG data.
    """
    buffer = BytesIO()
    img.save(buffer, 'PNG')
    # getvalue() yields the whole buffer regardless of the current position.
    png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
|
|
|
|
|
|
def to_img(img_url):
    """Download the image at ``img_url`` and return it as a 224x224 RGB image.

    Chains ``read_image_from``, ``to_rgb`` and ``resize`` in that order.
    """
    raw_bytes = read_image_from(img_url)
    rgb_image = to_rgb(raw_bytes)
    return resize(rgb_image, new_size=(224, 224))
|
|
|
|
|
|
def img_url_to_json(url):
    """Build the JSON request body for the image found at ``url``.

    The image is fetched and preprocessed with ``to_img``, base64-encoded
    with ``to_base64``, and wrapped as ``{"input": "[\"<base64>\"]"}`` --
    note that the value of ``input`` is itself a string containing a
    JSON-style list.

    Returns:
        The payload serialized with ``json.dumps``.
    """
    encoded = to_base64(to_img(url))
    payload = {'input': '[\"{0}\"]'.format(encoded)}
    return json.dumps(payload)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: post every image in TEST_IMAGES to the scoring endpoint and
    # fail with ValueError on the first prediction that differs from the
    # expected label.
    logger.info('Starting classifier test')
    headers = {'content-type': 'application/json'}
    # The endpoint is the same for every image, so resolve it once rather
    # than re-reading the environment on each iteration.
    score_url = os.environ.get('MODEL_API_URL', "http://0.0.0.0:88").rstrip('/') + "/score"
    for url, label in TEST_IMAGES.items():
        jsonimg = img_url_to_json(url)
        r = requests.post(score_url, data=jsonimg, headers=headers)
        # Surface HTTP errors directly; otherwise a 4xx/5xx reply would only
        # show up later as a confusing JSON-decode or KeyError failure.
        r.raise_for_status()
        json_response = r.json()
        logger.info(json_response)
        # The top prediction's label sits four levels deep in 'result'.
        prediction = json_response['result'][0][0][0][0]
        if prediction != label:
            raise ValueError('The predicted label {} is not the same as {}'.format(prediction, label))
        logger.info('CORRECT! %s', prediction)
|