import io |
from ditk import logging |
import os |
import pickle |
import time |
from functools import lru_cache |
from typing import Union |
import torch |
from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc |
from .lock_helper import get_file_lock |
# Module-wide client handles. Both start as ``None`` and are filled in on first
# use by ``_ensure_memcached`` / ``_ensure_rediscluster``.
_memcached = None
_redis_cluster = None

# DI-store support is opt-in: it is only wired up when the ``DI_STORE``
# environment variable equals "on" (case-insensitive).
if os.environ.get('DI_STORE', 'off').lower() != 'on':
    # Disabled: both helpers are exposed as ``None`` so callers can test for them.
    save_to_di_store = None
    read_from_di_store = None
else:
    print('Enable DI-store')
    from di_store import Client

    # The client config location can be overridden with ``DI_STORE_CONFIG_PATH``.
    di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
    di_store_client = Client(di_store_config_path)

    def save_to_di_store(data):
        """
        Overview:
            Put ``data`` into DI-store through the module-level client
        Arguments:
            - data (:obj:`Any`): Object handed to ``di_store_client.put``
        Returns:
            - (:obj:`Any`): Whatever ``di_store_client.put`` returns
        """
        return di_store_client.put(data)

    def read_from_di_store(object_ref):
        """
        Overview:
            Fetch an object from DI-store and delete it from the store afterwards
        Arguments:
            - object_ref (:obj:`Any`): Reference passed to ``di_store_client.get`` and ``delete``
        Returns:
            - (:obj:`Any`): The object returned by ``di_store_client.get``
        """
        fetched = di_store_client.get(object_ref)
        # Each stored object is read once: it is deleted right after being fetched.
        di_store_client.delete(object_ref)
        return fetched
@lru_cache()
def get_ceph_package():
    """
    Overview:
        Return the result of ``try_import_ceph()``. The call is memoized by ``lru_cache``,
        so the import attempt runs once and the same object is returned afterwards.
    Returns:
        - (:obj:`object`): Whatever ``try_import_ceph()`` returned. Other functions in this
            file treat ``None`` as "ceph is not available".
    """
    ceph_package = try_import_ceph()
    return ceph_package
@lru_cache()
def get_redis_package():
    """
    Overview:
        Return the result of ``try_import_redis()``. The call is memoized by ``lru_cache``,
        so the import attempt runs once and the same object is returned afterwards.
    Returns:
        - (:obj:`object`): Whatever ``try_import_redis()`` returned (presumably ``None`` when
            redis is not installed - confirm in ``import_helper``).
    """
    redis_package = try_import_redis()
    return redis_package
@lru_cache()
def get_rediscluster_package():
    """
    Overview:
        Return the result of ``try_import_rediscluster()``. The call is memoized by
        ``lru_cache``, so the import attempt runs once and the same object is returned afterwards.
    Returns:
        - (:obj:`object`): Whatever ``try_import_rediscluster()`` returned (presumably ``None``
            when rediscluster is not installed - confirm in ``import_helper``).
    """
    rediscluster_package = try_import_rediscluster()
    return rediscluster_package
@lru_cache()
def get_mc_package():
    """
    Overview:
        Return the result of ``try_import_mc()``. The call is memoized by ``lru_cache``,
        so the import attempt runs once and the same object is returned afterwards.
    Returns:
        - (:obj:`object`): Whatever ``try_import_mc()`` returned. Other functions in this
            file treat ``None`` as "memcached is not available".
    """
    mc_package = try_import_mc()
    return mc_package
def read_from_ceph(path: str) -> object:
    """
    Overview:
        Fetch an object from ceph and unpickle it
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: When ceph returns an empty/falsy value for ``path``
    """
    raw = get_ceph_package().Get(path)
    if raw:
        # NOTE: unpickling can run arbitrary code; only read objects from trusted buckets.
        return pickle.loads(raw)
    raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))
@lru_cache()
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Build a redis client for the given address. Results are memoized by ``lru_cache``,
        so calls with the same ``host`` and ``port`` share one client object.
    Arguments:
        - host (:obj:`str`): Host string
        - port (:obj:`int`): Port number
    Returns:
        - (:obj:`Redis(object)`): ``StrictRedis`` object created with the given ``host``, ``port``, and ``db=0``
    """
    redis_package = get_redis_package()
    client = redis_package.StrictRedis(host=host, port=port, db=0)
    return client
def read_from_redis(path: str) -> object:
    """
    Overview:
        Read a pickled value from redis (default client from ``_get_redis()``) and unpickle it
    Arguments:
        - path (:obj:`str`): File path in redis, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    client = _get_redis()
    raw = client.get(path)
    # NOTE: unpickling can run arbitrary code; only read keys written by trusted producers.
    return pickle.loads(raw)
def _ensure_rediscluster(startup_nodes=None):
    """
    Overview:
        Make sure the module-level ``_redis_cluster`` client exists, creating it on first call.
        Later calls are no-ops, so ``startup_nodes`` only matters for the call that actually
        creates the client.
    Arguments:
        - startup_nodes (:obj:`list` of :obj:`dict` or :obj:`None`): Startup nodes, each a dict with
            - host (:obj:`str`): Host string
            - port (:obj:`int`): Port number
            ``None`` (default) means ``[{"host": "", "port": "7000"}]``.
    Returns:
        - None. The ``RedisCluster`` object (created with ``decode_responses=False``) is stored
            in the module global ``_redis_cluster``.
    """
    global _redis_cluster
    if _redis_cluster is not None:
        return
    if startup_nodes is None:
        # Build the default per call instead of sharing one mutable list object as the
        # function default, which the client library would otherwise be free to modify.
        startup_nodes = [{"host": "", "port": "7000"}]
    _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)
    return
def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read a pickled value from rediscluster and unpickle it
    Arguments:
        - path (:obj:`str`): File path in rediscluster, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    # Creates the shared cluster client on first use.
    _ensure_rediscluster()
    # NOTE: unpickling can run arbitrary code; only read keys written by trusted producers.
    return pickle.loads(_redis_cluster.get(path))
def read_from_file(path: str) -> object:
    """
    Overview:
        Unpickle an object from a file in the local file system
    Arguments:
        - path (:obj:`str`): File path in local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    with open(path, "rb") as handle:
        # NOTE: unpickling can run arbitrary code; only load files from trusted sources.
        return pickle.load(handle)
def _ensure_memcached():
    """
    Overview:
        Make sure the module-level ``_memcached`` client exists, creating it on first call
        from the fixed ``server_list.conf`` and ``client.conf`` files under
        ``/mnt/lustre/share/memcached_client``.
    Returns:
        - None. The ``MemcachedClient`` instance is stored in the module global ``_memcached``.
    """
    global _memcached
    if _memcached is not None:
        return
    client_cls = get_mc_package().MemcachedClient
    _memcached = client_cls.GetInstance(
        "/mnt/lustre/share/memcached_client/server_list.conf",
        "/mnt/lustre/share/memcached_client/client.conf",
    )
    return
def read_from_mc(path: str, flush=False, max_retry=None, retry_interval=0.01) -> object:
    """
    Overview:
        Read file from memcache, file must be saved by `torch.save()`. Any failure is retried
        after a short sleep; by default the retry count is unlimited.
    Arguments:
        - path (:obj:`str`): File path in local system
        - flush (:obj:`bool`): If ``True``, only issue a ``Get`` with ``MC_READ_THROUGH`` and \
            return ``None`` without deserializing anything
        - max_retry (:obj:`int` or :obj:`None`): Number of retries allowed after a failed attempt. \
            ``None`` (default) retries forever; ``0`` raises on the first failure.
        - retry_interval (:obj:`float`): Seconds to sleep between attempts
    Returns:
        - (:obj:`data`): Deserialized data, or ``None`` when ``flush`` is ``True``
    Raises:
        - Exception: The last error is re-raised once ``max_retry`` retries have been used up
    """
    _ensure_memcached()
    mc = get_mc_package()
    failures = 0
    while True:
        try:
            value = mc.pyvector()
            if flush:
                _memcached.Get(path, value, mc.MC_READ_THROUGH)
                return
            _memcached.Get(path, value)
            value_buf = mc.ConvertBuffer(value)
            return torch.load(io.BytesIO(value_buf), map_location='cpu')
        except Exception as e:
            failures += 1
            if max_retry is not None and failures > max_retry:
                # Retry budget exhausted: surface the real error instead of looping forever.
                raise
            # Report the cause so a permanently failing read (e.g. a corrupt file) is diagnosable.
            print('read mc failed, retry... ({!r})'.format(e))
            time.sleep(retry_interval)
def read_from_path(path: str):
    """
    Overview:
        Read pickled data from ceph when the ceph package is available, otherwise fall back
        to the local file system (with an info log about the fallback)
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, or use local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if get_ceph_package() is not None:
        return read_from_ceph(path)
    logging.info(
        "You do not have ceph installed! Loading local file!"
        " If you are not testing locally, something is wrong!"
    )
    return read_from_file(path)
def save_file_ceph(path, data):
    """
    Overview:
        Pickle ``data`` and save it to ceph. When the ceph package is unavailable the pickled
        bytes are written to the local file system instead, unless the directory part of
        ``path`` is ``do_not_save``, in which case the data is dropped with an info log.
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, use file system when not
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    Raises:
        - RuntimeError: When the ceph package offers neither ``save_from_string`` nor ``put``
    """
    payload = pickle.dumps(data)
    directory, file_name = os.path.split(path)
    ceph = get_ceph_package()

    if ceph is not None:
        # Different ceph client packages expose different upload entry points.
        if hasattr(ceph, 'save_from_string'):
            ceph.save_from_string(directory, file_name, payload)
        elif hasattr(ceph, 'put'):
            ceph.put(os.path.join(directory, file_name), payload)
        else:
            raise RuntimeError('ceph can not save file, check your ceph installation')
        return

    # No ceph available: local fallback.
    size = len(payload)
    if directory == 'do_not_save':
        message = "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size)
        logging.info(message + " If you are not testing locally, something is wrong!")
        return
    local_path = os.path.join(directory, file_name)
    with open(local_path, 'wb') as handle:
        message = "You do not have ceph installed! Saving as local file at {} of size {}!".format(local_path, size)
        logging.info(message + " If you are not testing locally, something is wrong!")
        handle.write(payload)
def save_file_redis(path, data):
    """
    Overview:
        Pickle ``data`` and store it in redis (default client from ``_get_redis()``)
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    client = _get_redis()
    serialized = pickle.dumps(data)
    client.set(path, serialized)
def save_file_rediscluster(path, data):
    """
    Overview:
        Pickle ``data`` and store it in rediscluster
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    # Creates the shared cluster client on first use.
    _ensure_rediscluster()
    _redis_cluster.set(path, pickle.dumps(data))
    return
def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read file from path, dispatching on the file system type
    Arguments:
        - path (:obj:`str`): The path of file to read
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``. \
            When ``None`` it is inferred: ``'ceph'`` if ``path`` starts with ``s3`` (case-insensitive), \
            else ``'mc'`` if the memcached package is available, else ``'normal'``.
        - use_lock (:obj:`bool`): Whether to hold the file's read lock while loading; only used \
            for the ``'normal'`` file system
    Returns:
        - (:obj:`data`): The loaded data
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        else:
            fs_type = 'normal' if get_mc_package() is None else 'mc'
    assert fs_type in ['normal', 'ceph', 'mc']

    if fs_type == 'mc':
        data = read_from_mc(path)
    elif fs_type == 'ceph':
        data = read_from_path(path)
    elif fs_type == 'normal':
        if not use_lock:
            data = torch.load(path, map_location='cpu')
        else:
            with get_file_lock(path, 'read'):
                data = torch.load(path, map_location='cpu')
    return data
def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save data to file of path, dispatching on the file system type
    Arguments:
        - path (:obj:`str`): The path of file to save to
        - data (:obj:`object`): The data to save
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``. \
            When ``None`` it is inferred: ``'ceph'`` if ``path`` starts with ``s3`` (case-insensitive), \
            else ``'mc'`` if the memcached package is available, else ``'normal'``.
        - use_lock (:obj:`bool`): Whether to hold the file's write lock while saving; only used \
            for the ``'normal'`` file system
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        else:
            fs_type = 'normal' if get_mc_package() is None else 'mc'
    assert fs_type in ['normal', 'ceph', 'mc']

    if fs_type == 'mc':
        torch.save(data, path)
        # Follow the write with a read-through ``Get`` (presumably so memcached picks up
        # the new file content - confirm against the memcached client semantics).
        read_from_mc(path, flush=True)
    elif fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if not use_lock:
            torch.save(data, path)
        else:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Remove file (or directory tree). Removal is best-effort: a missing target is ignored
        and other failures are logged instead of raised. The function returns only after the
        removal attempt has finished.
    Arguments:
        - path (:obj:`str`): The path of file you want to remove. For the ``'normal'`` file system \
            ``~``, environment variables and glob wildcards are expanded; the path is never \
            passed through a shell, so spaces and shell metacharacters are safe.
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``. \
            When ``None``, ``'ceph'`` is used if ``path`` starts with ``s3`` (case-insensitive), \
            else ``'normal'``.
    """
    # Function-scope imports keep this block self-contained.
    import glob
    import shutil
    import subprocess

    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph']
    if fs_type == 'ceph':
        try:
            # Argument list (no shell): the path cannot be split or interpreted by a shell.
            # stdout is discarded, matching the previous behavior of never reading it.
            subprocess.run(["aws", "s3", "rm", "--recursive", path], stdout=subprocess.DEVNULL, check=False)
        except OSError as e:
            # e.g. the aws CLI is not installed; stay best-effort.
            logging.warning("Failed to run aws cli to remove {}: {}".format(path, e))
    elif fs_type == 'normal':
        expanded = os.path.expanduser(os.path.expandvars(path))
        # Like a shell, fall back to the literal path when the pattern matches nothing.
        for target in glob.glob(expanded) or [expanded]:
            try:
                if os.path.isdir(target) and not os.path.islink(target):
                    shutil.rmtree(target)
                else:
                    os.remove(target)
            except FileNotFoundError:
                # Same as ``rm -f``: a missing target is not an error.
                pass
            except OSError as e:
                logging.warning("Failed to remove {}: {}".format(target, e))