rm boto3 dependency
This commit is contained in:
Родитель
4e817ff418
Коммит
97a375484c
|
@ -21,7 +21,7 @@ jobs:
|
|||
- name: Install dependencies
|
||||
run: |
|
||||
pip install torch
|
||||
pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses
|
||||
pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
|
||||
|
||||
- name: Torch hub list
|
||||
run: |
|
||||
|
|
|
@ -16,7 +16,7 @@ from transformers import (
|
|||
)
|
||||
|
||||
|
||||
dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
|
||||
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
|
||||
|
||||
|
||||
@add_start_docstrings(AutoConfig.__doc__)
|
||||
|
|
2
setup.py
2
setup.py
|
@ -99,8 +99,6 @@ setup(
|
|||
"tokenizers == 0.7.0",
|
||||
# dataclasses for Python versions that don't have it
|
||||
"dataclasses;python_version<'3.7'",
|
||||
# accessing files from S3 directly
|
||||
"boto3",
|
||||
# filesystem locks e.g. to prevent parallel downloads
|
||||
"filelock",
|
||||
# for downloading models over HTTPS
|
||||
|
|
|
@ -19,10 +19,7 @@ from typing import Optional
|
|||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import ClientError
|
||||
from filelock import FileLock
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
@ -144,7 +141,7 @@ def add_end_docstrings(*docstr):
|
|||
|
||||
def is_remote_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https", "s3")
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
||||
|
@ -297,55 +294,6 @@ def cached_path(
|
|||
return output_path
|
||||
|
||||
|
||||
def split_s3_path(url):
|
||||
"""Split a full s3 path into the bucket name and path."""
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc or not parsed.path:
|
||||
raise ValueError("bad s3 path {}".format(url))
|
||||
bucket_name = parsed.netloc
|
||||
s3_path = parsed.path
|
||||
# Remove '/' at beginning of path.
|
||||
if s3_path.startswith("/"):
|
||||
s3_path = s3_path[1:]
|
||||
return bucket_name, s3_path
|
||||
|
||||
|
||||
def s3_request(func):
|
||||
"""
|
||||
Wrapper function for s3 requests in order to create more helpful error
|
||||
messages.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(url, *args, **kwargs):
|
||||
try:
|
||||
return func(url, *args, **kwargs)
|
||||
except ClientError as exc:
|
||||
if int(exc.response["Error"]["Code"]) == 404:
|
||||
raise EnvironmentError("file {} not found".format(url))
|
||||
else:
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@s3_request
|
||||
def s3_etag(url, proxies=None):
|
||||
"""Check ETag on S3 object."""
|
||||
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||
bucket_name, s3_path = split_s3_path(url)
|
||||
s3_object = s3_resource.Object(bucket_name, s3_path)
|
||||
return s3_object.e_tag
|
||||
|
||||
|
||||
@s3_request
|
||||
def s3_get(url, temp_file, proxies=None):
|
||||
"""Pull a file directly from S3."""
|
||||
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||
bucket_name, s3_path = split_s3_path(url)
|
||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||
|
||||
|
||||
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||
if is_torch_available():
|
||||
|
@ -406,17 +354,13 @@ def get_from_cache(
|
|||
|
||||
etag = None
|
||||
if not local_files_only:
|
||||
# Get eTag to add to filename, if it exists.
|
||||
if url.startswith("s3://"):
|
||||
etag = s3_etag(url, proxies=proxies)
|
||||
else:
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
|
@ -483,13 +427,7 @@ def get_from_cache(
|
|||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
|
Загрузка…
Ссылка в новой задаче