# -*- coding: utf-8 -*-
from io import BytesIO
import os
import logging
import hashlib
import boto3
from botocore.exceptions import ClientError as BotoClientError
from botocore.client import Config
from botocore import UNSIGNED
from petrel_client.ceph.ceph import Ceph
from petrel_client.common.io_profile import profile
from petrel_client.common.exception import NoSuchBucketError, NoSuchKeyError, S3ClientError, AccessDeniedError, RangeError, MultipartError
from .generator import FileIterator
LOG = logging.getLogger(__name__)
class S3Client(Ceph):
def __init__(self, cluster, conf, anonymous_access, *args, **kwargs):
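        """Construct an S3-backed Ceph client for one cluster.

        `conf` is expected to provide `endpoint_url` and, unless
        `anonymous_access` is set, `access_key` and `secret_key`, plus the
        transfer settings `verify_ssl`, `multipart_threshold`,
        `max_concurrency`, `multipart_chunksize` and `use_threads`.
        """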
if anonymous_access:
s3_args = {
'config': Config(signature_version=UNSIGNED)
}
else:
s3_args = {
'aws_access_key_id': conf['access_key'],
'aws_secret_access_key': conf['secret_key']
}
s3_args['endpoint_url'] = conf['endpoint_url']
s3_args['verify'] = conf.get_boolean('verify_ssl', False)
super(S3Client, self).__init__(cluster, conf, *args, **kwargs)
self._cluster = cluster
self._conf = conf
self._session = boto3.session.Session()
self._s3_resource = self._session.resource(
's3',
**s3_args
)
        # Use keyword arguments here: passed positionally, the fourth value
        # would land on TransferConfig's num_download_attempts parameter
        # instead of use_threads.
        self._multipart_config = boto3.s3.transfer.TransferConfig(
            multipart_threshold=conf.get_int('multipart_threshold'),
            max_concurrency=conf.get_int('max_concurrency'),
            multipart_chunksize=conf.get_int('multipart_chunksize'),
            use_threads=conf.get_boolean('use_threads'))
    # Return value: (1) without multipart: the fetched content;
    #               (2) with multipart: the name of the file that was written.
@profile('get')
def get_with_info(self, cluster, bucket, key, **kwargs):
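        """Fetch an object, optionally as a stream or via multipart download.

        Supported kwargs: `range` (byte-range string), `enable_etag`,
        `enable_stream`, `multipart` and `filename` (required when
        `multipart` is set). Returns `(content, info)` without multipart,
        or `(filename, info)` when the object is downloaded to a file.
        """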
        enable_etag = kwargs.get('enable_etag', False)
        enable_stream = kwargs.get('enable_stream', False)
        info = {}
        get_kwargs = {}
assert self._cluster == cluster
range = kwargs.get('range', None)
if range:
get_kwargs['Range'] = f'bytes={range}'
        try:
            if kwargs.get("multipart"):
                if kwargs.get("filename") is None:
                    raise MultipartError("need filename")
                # Download to a local file using the shared multipart
                # transfer configuration.
                self._s3_resource.meta.client.download_file(
                    bucket, key, kwargs["filename"], Config=self._multipart_config)
return kwargs["filename"], info
else:
                obj = self._s3_resource.Object(bucket, key).get(**get_kwargs)
                if enable_etag:
                    # The ETag is only available once the object has been fetched.
                    info['etag'] = obj['ETag'].strip('"')
                content = obj['Body']
                if not enable_stream:
                    content = content.read()
                return content, info
except BotoClientError as err:
if type(err).__name__ == 'NoSuchKey':
                # The type of err here is botocore.errorfactory.NoSuchKey or
                # NoSuchBucket, but those classes are built at runtime via
                #     type(exception_name, (ClientError,), {})  # botocore/errorfactory.py:83
                # so the only way to catch them is through their base class
                # ClientError and then dispatch on the class name.
raise NoSuchKeyError(cluster, bucket, key)
elif type(err).__name__ == 'NoSuchBucket':
raise NoSuchBucketError(cluster, bucket)
elif err.response['ResponseMetadata']['HTTPStatusCode'] == 403:
raise AccessDeniedError(err)
elif err.response['ResponseMetadata']['HTTPStatusCode'] == 416:
raise RangeError(range)
else:
raise S3ClientError(err)
def create_bucket(self, bucket):
return self._s3_resource.create_bucket(Bucket=bucket)
def isdir(self, bucket, key):
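        """Return True if listing `key` as a prefix yields at least one entry."""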
itr = self.list(bucket, key)
try:
next(itr)
return True
except StopIteration:
return False
def list(self, bucket, key, page_size=None):
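        """Yield the immediate children of `key`: sub-prefixes with a
        trailing '/' followed by object names, one page at a time."""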
if key is None:
key = ''
elif key and not key.endswith('/'):
key = key + '/'
client = self._s3_resource.meta.client
paginator = client.get_paginator('list_objects')
paging_args = {
'Bucket': bucket, 'Prefix': key, 'Delimiter': '/',
'PaginationConfig': {'PageSize': page_size}
}
itr = paginator.paginate(**paging_args)
for response_data in itr:
common_prefixes = response_data.get('CommonPrefixes', [])
contents = response_data.get('Contents', [])
for common_prefix in common_prefixes:
prefix_components = common_prefix['Prefix'].split('/')
prefix = prefix_components[-2]
yield prefix + '/'
for content in contents:
filename_components = content['Key'].split('/')
filename = filename_components[-1]
yield filename
def get_file_iterator(self, bucket, key):
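        """Return a FileIterator over 's3://<bucket>[/<key>]'."""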
client = self._s3_resource.meta.client
path = 's3://{0}'.format(bucket)
if key:
path = path + '/' + key
file_iterator = FileIterator(client, path)
return file_iterator
@profile('put')
def put_with_info(self, cluster, bucket, key, body, **kwargs):
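        """Store `body` under `key`.

        In-memory payloads up to 2 GiB go through a single PutObject call;
        larger payloads are streamed through the low-level multipart API,
        and file-like objects go through boto3's managed transfer. Returns
        `(result, info)`, with `info['etag']` filled in when `enable_etag`
        is set and an ETag is available.
        """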
        if isinstance(body, (bytes, bytearray)):
            if len(body) <= 1024 * 1024 * 1024 * 2:  # 2 GiB
                result, info = self.put_bytes(
                    cluster, bucket, key, body, **kwargs)
            else:
                # bytes objects have no `.name`, so large in-memory payloads
                # are wrapped in BytesIO and sent through the low-level
                # multipart upload path.
                result, info = self.multipart_upload(
                    cluster, bucket, key, BytesIO(body), **kwargs)
elif hasattr(body, 'read'):
result, info = self.multipart_upload_boto3(
cluster, bucket, key, body.name, **kwargs)
else:
raise TypeError(
f'{type(self)} does not support content type {type(body)}')
        if kwargs.get('enable_etag', False) and hasattr(result, 'e_tag'):
            # multipart_upload_boto3 does not return an object resource,
            # so an ETag is only reported when one is available.
            info['etag'] = result.e_tag.strip('"')
return result, info
def multipart_upload_boto3(self, cluster, bucket, key, filename, **kwargs):
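        """Upload a local file with boto3's managed multipart transfer."""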
        # Store the file under the requested object key.
        self._s3_resource.meta.client.upload_file(
            filename, bucket, key, Config=self._multipart_config)
return True, {}
def put_bytes(self, cluster, bucket, key, body, **kwargs):
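        """Upload an in-memory payload with a single PutObject request,
        optionally recording its MD5 digest in `info`."""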
assert self._cluster == cluster
enable_md5 = kwargs.get('enable_md5', False)
info = {}
try:
obj = self._s3_resource.Object(bucket, key)
obj.put(Body=body)
if enable_md5:
info['md5'] = hashlib.md5(body).hexdigest()
return obj, info
except BotoClientError as err:
if err.response['ResponseMetadata']['HTTPStatusCode'] == 403:
raise AccessDeniedError(err)
else:
raise S3ClientError(err)
def multipart_upload(self, cluster, bucket, key, stream, chunk_size=1024 * 1024 * 1024 * 2, **kwargs):
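        """Upload `stream` chunk by chunk through the low-level multipart
        API (initiate, upload parts, complete), optionally accumulating an
        MD5 digest over the streamed data."""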
assert self._cluster == cluster
info = {}
obj = self._s3_resource.Object(bucket, key)
multipart = obj.initiate_multipart_upload()
part_id = 0
parts = []
total_size = 0
enable_md5 = kwargs.get('enable_md5', False)
if enable_md5:
md5 = hashlib.md5()
while True:
chunk = stream.read(chunk_size)
actual_size = len(chunk)
if actual_size == 0:
break
part_id += 1
total_size += actual_size
part = multipart.Part(part_id)
response = part.upload(Body=chunk)
parts.append({
"PartNumber": part_id,
"ETag": response["ETag"]
})
if enable_md5:
md5.update(chunk)
part_info = {
'Parts': parts
}
result = multipart.complete(MultipartUpload=part_info)
if enable_md5:
info['md5'] = md5.hexdigest()
return result, info
def size(self, cluster, bucket, key):
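        """Return the object's size in bytes (Content-Length)."""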
assert self._cluster == cluster
try:
obj = self._s3_resource.Object(bucket, key)
obj.load()
return obj.content_length
except BotoClientError as err:
if err.response['ResponseMetadata']['HTTPStatusCode'] == 404:
raise NoSuchKeyError(cluster, bucket, key, err) from None
elif err.response['ResponseMetadata']['HTTPStatusCode'] == 403:
raise AccessDeniedError(err) from None
else:
raise S3ClientError(err) from None
def contains(self, cluster, bucket, key):
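        """Return True if the object exists, False on a 404."""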
assert self._cluster == cluster
try:
self._s3_resource.Object(bucket, key).load()
return True
except BotoClientError as err:
if err.response['ResponseMetadata']['HTTPStatusCode'] == 404:
return False
elif err.response['ResponseMetadata']['HTTPStatusCode'] == 403:
raise AccessDeniedError(err)
else:
raise S3ClientError(err)
def delete(self, cluster, bucket, key, **kwargs):
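        """Delete the object; S3 errors are re-raised as petrel_client exceptions."""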
assert self._cluster == cluster
try:
return self._s3_resource.Object(bucket, key).delete()
except BotoClientError as err:
if type(err).__name__ == 'NoSuchKey':
raise NoSuchKeyError(cluster, bucket, key)
elif type(err).__name__ == 'NoSuchBucket':
raise NoSuchBucketError(cluster, bucket)
elif err.response['ResponseMetadata']['HTTPStatusCode'] == 403:
raise AccessDeniedError(err)
else:
raise S3ClientError(err)
def generate_presigned_url(self, cluster, bucket, key, client_method, expires_in):
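        """Create a presigned URL for `client_method` (e.g. 'get_object')
        that expires after `expires_in` seconds."""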
assert self._cluster == cluster
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.generate_presigned_url
return self._s3_resource.meta.client.generate_presigned_url(
client_method,
{'Bucket': bucket, 'Key': key},
expires_in
)
def generate_presigned_post(self, cluster, bucket, key, fields=None, conditions=None, expires_in=3600):
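        """Create the URL and form fields for a presigned browser POST upload."""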
assert self._cluster == cluster
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.generate_presigned_post
return self._s3_resource.meta.client.generate_presigned_post(
bucket, key, fields, conditions, expires_in
)
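# A minimal usage sketch, not part of the library API. It assumes a reachable
# S3-compatible endpoint and credentials (the address and keys below are
# placeholders), that `Ceph.__init__` and the `@profile` decorators need no
# extra setup, and it uses a hypothetical `_DemoConf` stand-in for
# petrel_client's real configuration object.
if __name__ == '__main__':
    class _DemoConf(dict):
        # Hypothetical configuration stub exposing the accessors S3Client uses.
        def get_int(self, key):
            return int(self[key])

        def get_boolean(self, key, default=False):
            return bool(self.get(key, default))

    demo_conf = _DemoConf(
        endpoint_url='http://127.0.0.1:9000',   # placeholder endpoint
        access_key='demo-access-key',           # placeholder credentials
        secret_key='demo-secret-key',
        verify_ssl=False,
        multipart_threshold=100 * 1024 ** 2,
        max_concurrency=8,
        multipart_chunksize=8 * 1024 ** 2,
        use_threads=True,
    )

    client = S3Client('demo-cluster', demo_conf, anonymous_access=False)
    client.create_bucket('demo-bucket')
    client.put_with_info('demo-cluster', 'demo-bucket', 'hello.txt', b'hello world')
    data, _ = client.get_with_info('demo-cluster', 'demo-bucket', 'hello.txt')
    print(data)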