This commit is contained in:
Julien Chaumond 2020-04-25 10:43:44 -04:00
Parent 4e817ff418
Commit 97a375484c
4 changed files with 11 additions and 75 deletions

.github/workflows/github-torch-hub.yml (vendored)

@@ -21,7 +21,7 @@ jobs:
     - name: Install dependencies
       run: |
         pip install torch
-        pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses
+        pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
     - name: Torch hub list
       run: |

hubconf.py

@@ -16,7 +16,7 @@ from transformers import (
 )
 
-dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
+dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
 
 @add_start_docstrings(AutoConfig.__doc__)
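`dependencies` is the list torch.hub verifies before running the entry points defined in this file, so dropping boto3 means hub users no longer need it installed. A minimal sketch of calling one such entry point (the name `config` is an assumption based on the `AutoConfig` docstring attached just below this hunk):

import torch

# torch.hub fetches hubconf.py from the repo, checks that every package in
# `dependencies` is importable, then invokes the requested entry point.
config = torch.hub.load("huggingface/transformers", "config", "bert-base-uncased")
print(type(config))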

setup.py

@@ -99,8 +99,6 @@ setup(
         "tokenizers == 0.7.0",
         # dataclasses for Python versions that don't have it
         "dataclasses;python_version<'3.7'",
-        # accessing files from S3 directly
-        "boto3",
         # filesystem locks e.g. to prevent parallel downloads
         "filelock",
         # for downloading models over HTTPS

src/transformers/file_utils.py

@@ -19,10 +19,7 @@ from typing import Optional
 from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
 
-import boto3
 import requests
-from botocore.config import Config
-from botocore.exceptions import ClientError
 from filelock import FileLock
 from tqdm.auto import tqdm
@@ -144,7 +141,7 @@ def add_end_docstrings(*docstr):
 def is_remote_url(url_or_filename):
     parsed = urlparse(url_or_filename)
-    return parsed.scheme in ("http", "https", "s3")
+    return parsed.scheme in ("http", "https")
 
 def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
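The practical effect of the narrowed check, as a self-contained sketch: an `s3://` URL is no longer classified as remote, so `cached_path` treats it like a local path instead of downloading it (the CDN URL below is illustrative):

from urllib.parse import urlparse

def is_remote_url(url_or_filename):
    # Post-commit behavior: only plain HTTP(S) counts as remote.
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

assert is_remote_url("https://cdn.huggingface.co/bert-base-uncased-config.json")
assert not is_remote_url("s3://models.huggingface.co/bert/config.json")  # s3 scheme dropped
assert not is_remote_url("/local/path/to/config.json")  # no scheme at all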
@@ -297,55 +294,6 @@ def cached_path(
     return output_path
 
-def split_s3_path(url):
-    """Split a full s3 path into the bucket name and path."""
-    parsed = urlparse(url)
-    if not parsed.netloc or not parsed.path:
-        raise ValueError("bad s3 path {}".format(url))
-    bucket_name = parsed.netloc
-    s3_path = parsed.path
-    # Remove '/' at beginning of path.
-    if s3_path.startswith("/"):
-        s3_path = s3_path[1:]
-    return bucket_name, s3_path
-
-
-def s3_request(func):
-    """
-    Wrapper function for s3 requests in order to create more helpful error
-    messages.
-    """
-
-    @wraps(func)
-    def wrapper(url, *args, **kwargs):
-        try:
-            return func(url, *args, **kwargs)
-        except ClientError as exc:
-            if int(exc.response["Error"]["Code"]) == 404:
-                raise EnvironmentError("file {} not found".format(url))
-            else:
-                raise
-
-    return wrapper
-
-
-@s3_request
-def s3_etag(url, proxies=None):
-    """Check ETag on S3 object."""
-    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
-    bucket_name, s3_path = split_s3_path(url)
-    s3_object = s3_resource.Object(bucket_name, s3_path)
-    return s3_object.e_tag
-
-
-@s3_request
-def s3_get(url, temp_file, proxies=None):
-    """Pull a file directly from S3."""
-    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
-    bucket_name, s3_path = split_s3_path(url)
-    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
-
 def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
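For reference, the HTTP path that now handles every download: a minimal sketch of a resumable fetch in the spirit of `http_get` (the real function also sends a user-agent header and wraps the loop in a tqdm progress bar; `http_get_sketch` is a hypothetical name, and we assume the server honors Range requests):

import requests

def http_get_sketch(url, temp_file, proxies=None, resume_size=0):
    # Resume an interrupted download by asking only for the remaining bytes.
    headers = {"Range": "bytes=%d-" % resume_size} if resume_size > 0 else None
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range Not Satisfiable: file already complete
        return
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # skip keep-alive chunks
            temp_file.write(chunk)

This is exactly the capability the S3 path never had: the removed `s3_get` warned that resumable downloads were not implemented for `s3://` URLs.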
@@ -406,17 +354,13 @@ def get_from_cache(
     etag = None
     if not local_files_only:
-        # Get eTag to add to filename, if it exists.
-        if url.startswith("s3://"):
-            etag = s3_etag(url, proxies=proxies)
-        else:
-            try:
-                response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
-                if response.status_code == 200:
-                    etag = response.headers.get("ETag")
-            except (EnvironmentError, requests.exceptions.Timeout):
-                # etag is already None
-                pass
+        try:
+            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
+            if response.status_code == 200:
+                etag = response.headers.get("ETag")
+        except (EnvironmentError, requests.exceptions.Timeout):
+            # etag is already None
+            pass
 
     filename = url_to_filename(url, etag)
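The ETag fetched here feeds directly into the cache key. A minimal sketch of how `url_to_filename` builds it, assuming the sha256-based scheme used in `file_utils.py` (treat the details as an approximation):

from hashlib import sha256

def url_to_filename_sketch(url, etag=None):
    # Hash the URL so the cache filename is filesystem-safe and unique per URL.
    filename = sha256(url.encode("utf-8")).hexdigest()
    if etag:
        # A changed ETag (new upstream version) yields a new cache entry.
        filename += "." + sha256(etag.encode("utf-8")).hexdigest()
    return filename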
@@ -483,13 +427,7 @@ def get_from_cache(
         with temp_file_manager() as temp_file:
             logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
 
             # GET file object
-            if url.startswith("s3://"):
-                if resume_download:
-                    logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
-                s3_get(url, temp_file, proxies=proxies)
-            else:
-                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+            http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
 
             logger.info("storing %s in cache at %s", url, cache_path)
             os.replace(temp_file.name, cache_path)
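Taken together, every model file now travels over plain HTTPS. A hedged end-to-end sketch using the two helpers this diff touches (the identifier and postfix are illustrative):

from transformers.file_utils import cached_path, hf_bucket_url

# Build an HTTPS URL for a file in the model bucket; cdn=True would switch
# to the CDN host instead of the bucket itself.
url = hf_bucket_url("bert-base-uncased", postfix="config.json")

# Downloads over HTTPS (no boto3 anywhere) and returns the local cache path.
local_path = cached_path(url)
print(local_path)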