This commit is contained in:
Julien Chaumond 2020-04-25 10:43:44 -04:00
Parent 4e817ff418
Commit 97a375484c
4 changed files with 11 additions and 75 deletions

.github/workflows/github-torch-hub.yml (vendored)

@@ -21,7 +21,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install torch
-          pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses
+          pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
       - name: Torch hub list
         run: |
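This is the dependency set the torch.hub smoke test installs before listing the repo's entry points. A minimal sketch of the calls the workflow exercises (the "tokenizer" entry-point name and the checkpoint are assumptions, not shown in this diff):

    import torch

    # Entry points are discovered from hubconf.py at the repo root.
    print(torch.hub.list("huggingface/transformers"))

    # Loading one entry point; name and checkpoint are illustrative only.
    tokenizer = torch.hub.load("huggingface/transformers", "tokenizer", "bert-base-uncased")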

hubconf.py

@@ -16,7 +16,7 @@ from transformers import (
 )
 
-dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
+dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
 
 @add_start_docstrings(AutoConfig.__doc__)

setup.py

@@ -99,8 +99,6 @@ setup(
         "tokenizers == 0.7.0",
         # dataclasses for Python versions that don't have it
         "dataclasses;python_version<'3.7'",
-        # accessing files from S3 directly
-        "boto3",
         # filesystem locks e.g. to prevent parallel downloads
         "filelock",
         # for downloading models over HTTPS
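A quick way to confirm the dependency really dropped out, assuming this revision of transformers is installed (Python 3.8+ for importlib.metadata):

    from importlib.metadata import requires

    # Expect an empty list: boto3 is no longer a declared requirement.
    print([req for req in requires("transformers") if req.startswith("boto3")])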

src/transformers/file_utils.py

@@ -19,10 +19,7 @@ from typing import Optional
 from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
 
-import boto3
 import requests
-from botocore.config import Config
-from botocore.exceptions import ClientError
 from filelock import FileLock
 from tqdm.auto import tqdm
@@ -144,7 +141,7 @@ def add_end_docstrings(*docstr):
 
 def is_remote_url(url_or_filename):
     parsed = urlparse(url_or_filename)
-    return parsed.scheme in ("http", "https", "s3")
+    return parsed.scheme in ("http", "https")
 
 
 def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
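The effect of this one-line change: s3:// URIs no longer count as remote, so they fall through to the local-path branch in cached_path and fail there instead. A minimal illustration (the example URLs are assumptions):

    from urllib.parse import urlparse

    def is_remote_url(url_or_filename):
        parsed = urlparse(url_or_filename)
        return parsed.scheme in ("http", "https")

    print(is_remote_url("https://cdn.huggingface.co/bert-base-uncased-vocab.txt"))   # True
    print(is_remote_url("s3://models.huggingface.co/bert/bert-base-uncased.tar.gz")) # False after this change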
@@ -297,55 +294,6 @@ def cached_path(
     return output_path
 
 
-def split_s3_path(url):
-    """Split a full s3 path into the bucket name and path."""
-    parsed = urlparse(url)
-    if not parsed.netloc or not parsed.path:
-        raise ValueError("bad s3 path {}".format(url))
-    bucket_name = parsed.netloc
-    s3_path = parsed.path
-    # Remove '/' at beginning of path.
-    if s3_path.startswith("/"):
-        s3_path = s3_path[1:]
-    return bucket_name, s3_path
-
-
-def s3_request(func):
-    """
-    Wrapper function for s3 requests in order to create more helpful error
-    messages.
-    """
-
-    @wraps(func)
-    def wrapper(url, *args, **kwargs):
-        try:
-            return func(url, *args, **kwargs)
-        except ClientError as exc:
-            if int(exc.response["Error"]["Code"]) == 404:
-                raise EnvironmentError("file {} not found".format(url))
-            else:
-                raise
-
-    return wrapper
-
-
-@s3_request
-def s3_etag(url, proxies=None):
-    """Check ETag on S3 object."""
-    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
-    bucket_name, s3_path = split_s3_path(url)
-    s3_object = s3_resource.Object(bucket_name, s3_path)
-    return s3_object.e_tag
-
-
-@s3_request
-def s3_get(url, temp_file, proxies=None):
-    """Pull a file directly from S3."""
-    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
-    bucket_name, s3_path = split_s3_path(url)
-    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
-
-
 def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
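Nothing replaces these helpers one-for-one: S3 (and the CDN in front of it) serves both the object bytes and its ETag over plain HTTPS, so the generic HTTP path already covers them. A rough requests-only equivalent of the deleted s3_etag, as a sketch:

    import requests

    def https_etag(url, proxies=None, timeout=10):
        # HEAD returns the object's ETag header without transferring the body.
        response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=timeout)
        response.raise_for_status()
        return response.headers.get("ETag")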
@@ -406,17 +354,13 @@ def get_from_cache(
 
     etag = None
     if not local_files_only:
-        # Get eTag to add to filename, if it exists.
-        if url.startswith("s3://"):
-            etag = s3_etag(url, proxies=proxies)
-        else:
-            try:
-                response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
-                if response.status_code == 200:
-                    etag = response.headers.get("ETag")
-            except (EnvironmentError, requests.exceptions.Timeout):
-                # etag is already None
-                pass
+        try:
+            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
+            if response.status_code == 200:
+                etag = response.headers.get("ETag")
+        except (EnvironmentError, requests.exceptions.Timeout):
+            # etag is already None
+            pass
 
     filename = url_to_filename(url, etag)
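The ETag fetched above feeds into the cache key, which is what keeps the cache correct without any S3-specific code: when the hosted file changes, its ETag changes, and the download lands in a new cache entry. A sketch of such a scheme (illustrative, not the verbatim url_to_filename body):

    from hashlib import sha256

    def url_to_filename_sketch(url, etag=None):
        # Base name from the URL; suffix from the ETag so a new upstream
        # version (new ETag) maps to a new cache entry.
        filename = sha256(url.encode("utf-8")).hexdigest()
        if etag:
            filename += "." + sha256(etag.encode("utf-8")).hexdigest()
        return filename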
@@ -483,13 +427,7 @@ def get_from_cache(
 
         with temp_file_manager() as temp_file:
             logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
 
-            # GET file object
-            if url.startswith("s3://"):
-                if resume_download:
-                    logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
-                s3_get(url, temp_file, proxies=proxies)
-            else:
-                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+            http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
 
         logger.info("storing %s in cache at %s", url, cache_path)
         os.replace(temp_file.name, cache_path)
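Note the deleted warning: resumable downloads were never implemented for s3:// URLs, while the surviving http_get supports them through resume_size. A minimal sketch of how resuming works over HTTP(S), assuming the server honors Range requests:

    import requests

    def http_get_sketch(url, temp_file, proxies=None, resume_size=0):
        # A Range header asks the server for only the remaining bytes.
        headers = {"Range": "bytes={}-".format(resume_size)} if resume_size else None
        response = requests.get(url, stream=True, proxies=proxies, headers=headers)
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:  # skip keep-alive chunks
                temp_file.write(chunk)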