rm boto3 dependency
This commit is contained in:
Родитель
4e817ff418
Коммит
97a375484c
|
@ -21,7 +21,7 @@ jobs:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install torch
|
pip install torch
|
||||||
pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses
|
pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
|
||||||
|
|
||||||
- name: Torch hub list
|
- name: Torch hub list
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -16,7 +16,7 @@ from transformers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
|
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(AutoConfig.__doc__)
|
@add_start_docstrings(AutoConfig.__doc__)
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -99,8 +99,6 @@ setup(
|
||||||
"tokenizers == 0.7.0",
|
"tokenizers == 0.7.0",
|
||||||
# dataclasses for Python versions that don't have it
|
# dataclasses for Python versions that don't have it
|
||||||
"dataclasses;python_version<'3.7'",
|
"dataclasses;python_version<'3.7'",
|
||||||
# accessing files from S3 directly
|
|
||||||
"boto3",
|
|
||||||
# filesystem locks e.g. to prevent parallel downloads
|
# filesystem locks e.g. to prevent parallel downloads
|
||||||
"filelock",
|
"filelock",
|
||||||
# for downloading models over HTTPS
|
# for downloading models over HTTPS
|
||||||
|
|
|
@ -19,10 +19,7 @@ from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from zipfile import ZipFile, is_zipfile
|
from zipfile import ZipFile, is_zipfile
|
||||||
|
|
||||||
import boto3
|
|
||||||
import requests
|
import requests
|
||||||
from botocore.config import Config
|
|
||||||
from botocore.exceptions import ClientError
|
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
@ -144,7 +141,7 @@ def add_end_docstrings(*docstr):
|
||||||
|
|
||||||
def is_remote_url(url_or_filename):
|
def is_remote_url(url_or_filename):
|
||||||
parsed = urlparse(url_or_filename)
|
parsed = urlparse(url_or_filename)
|
||||||
return parsed.scheme in ("http", "https", "s3")
|
return parsed.scheme in ("http", "https")
|
||||||
|
|
||||||
|
|
||||||
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
||||||
|
@ -297,55 +294,6 @@ def cached_path(
|
||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
def split_s3_path(url):
|
|
||||||
"""Split a full s3 path into the bucket name and path."""
|
|
||||||
parsed = urlparse(url)
|
|
||||||
if not parsed.netloc or not parsed.path:
|
|
||||||
raise ValueError("bad s3 path {}".format(url))
|
|
||||||
bucket_name = parsed.netloc
|
|
||||||
s3_path = parsed.path
|
|
||||||
# Remove '/' at beginning of path.
|
|
||||||
if s3_path.startswith("/"):
|
|
||||||
s3_path = s3_path[1:]
|
|
||||||
return bucket_name, s3_path
|
|
||||||
|
|
||||||
|
|
||||||
def s3_request(func):
|
|
||||||
"""
|
|
||||||
Wrapper function for s3 requests in order to create more helpful error
|
|
||||||
messages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(url, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
return func(url, *args, **kwargs)
|
|
||||||
except ClientError as exc:
|
|
||||||
if int(exc.response["Error"]["Code"]) == 404:
|
|
||||||
raise EnvironmentError("file {} not found".format(url))
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@s3_request
|
|
||||||
def s3_etag(url, proxies=None):
|
|
||||||
"""Check ETag on S3 object."""
|
|
||||||
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
|
||||||
s3_object = s3_resource.Object(bucket_name, s3_path)
|
|
||||||
return s3_object.e_tag
|
|
||||||
|
|
||||||
|
|
||||||
@s3_request
|
|
||||||
def s3_get(url, temp_file, proxies=None):
|
|
||||||
"""Pull a file directly from S3."""
|
|
||||||
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
|
||||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
|
||||||
|
|
||||||
|
|
||||||
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
@ -406,17 +354,13 @@ def get_from_cache(
|
||||||
|
|
||||||
etag = None
|
etag = None
|
||||||
if not local_files_only:
|
if not local_files_only:
|
||||||
# Get eTag to add to filename, if it exists.
|
try:
|
||||||
if url.startswith("s3://"):
|
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||||
etag = s3_etag(url, proxies=proxies)
|
if response.status_code == 200:
|
||||||
else:
|
etag = response.headers.get("ETag")
|
||||||
try:
|
except (EnvironmentError, requests.exceptions.Timeout):
|
||||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
# etag is already None
|
||||||
if response.status_code == 200:
|
pass
|
||||||
etag = response.headers.get("ETag")
|
|
||||||
except (EnvironmentError, requests.exceptions.Timeout):
|
|
||||||
# etag is already None
|
|
||||||
pass
|
|
||||||
|
|
||||||
filename = url_to_filename(url, etag)
|
filename = url_to_filename(url, etag)
|
||||||
|
|
||||||
|
@ -483,13 +427,7 @@ def get_from_cache(
|
||||||
with temp_file_manager() as temp_file:
|
with temp_file_manager() as temp_file:
|
||||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||||
|
|
||||||
# GET file object
|
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||||
if url.startswith("s3://"):
|
|
||||||
if resume_download:
|
|
||||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
|
||||||
s3_get(url, temp_file, proxies=proxies)
|
|
||||||
else:
|
|
||||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
|
||||||
|
|
||||||
logger.info("storing %s in cache at %s", url, cache_path)
|
logger.info("storing %s in cache at %s", url, cache_path)
|
||||||
os.replace(temp_file.name, cache_path)
|
os.replace(temp_file.name, cache_path)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче