import io
from ditk import logging
import os
import pickle
import time
from functools import lru_cache
from typing import Union

import torch

from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc
from .lock_helper import get_file_lock

# Lazily created module-level clients; see ``_ensure_memcached`` and ``_ensure_rediscluster``.
_memcached = None
_redis_cluster = None

# DI-store support is opt-in through the ``DI_STORE`` environment variable.
if os.environ.get('DI_STORE', 'off').lower() == 'on':
    print('Enable DI-store')
    from di_store import Client

    di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
    di_store_client = Client(di_store_config_path)

    def save_to_di_store(data):
        """
        Overview:
            Put ``data`` into DI-store through the module-level client.
        Arguments:
            - data (:obj:`Any`): The object to store
        Returns:
            - (:obj:`Any`): Whatever ``di_store_client.put`` returns (presumably an object reference \
                that can be passed to ``read_from_di_store``; confirm against the di_store API)
        """
        return di_store_client.put(data)

    def read_from_di_store(object_ref):
        """
        Overview:
            Get the object behind ``object_ref`` from DI-store, then delete it from the store.
        Arguments:
            - object_ref (:obj:`Any`): The reference handed to ``di_store_client.get``
        Returns:
            - (:obj:`Any`): The fetched data
        """
        data = di_store_client.get(object_ref)
        # The entry is removed right after the read, so each reference can be read only once.
        di_store_client.delete(object_ref)
        return data
else:
    save_to_di_store = read_from_di_store = None


@lru_cache()
def get_ceph_package():
    """
    Overview:
        Return the result of ``try_import_ceph()``; the result is cached after the first call.
    """
    return try_import_ceph()


@lru_cache()
def get_redis_package():
    """
    Overview:
        Return the result of ``try_import_redis()``; the result is cached after the first call.
    """
    return try_import_redis()


@lru_cache()
def get_rediscluster_package():
    """
    Overview:
        Return the result of ``try_import_rediscluster()``; the result is cached after the first call.
    """
    return try_import_rediscluster()


@lru_cache()
def get_mc_package():
    """
    Overview:
        Return the result of ``try_import_mc()``; the result is cached after the first call.
    """
    return try_import_mc()


def read_from_ceph(path: str) -> object:
    """
    Overview:
        Read file from ceph
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If the ceph package returns an empty/falsy value for ``path``
    """
    value = get_ceph_package().Get(path)
    if not value:
        raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))
    # NOTE(review): pickle.loads can execute arbitrary code; only read from trusted storage.
    return pickle.loads(value)


@lru_cache()
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Build a redis client; cached per distinct ``(host, port)`` arguments
    Arguments:
        - host (:obj:`str`): Host string
        - port (:obj:`int`): Port number
    Returns:
        - (:obj:`Redis(object)`): Redis object with given ``host``, ``port``, and ``db=0``
    """
    return get_redis_package().StrictRedis(host=host, port=port, db=0)


def read_from_redis(path: str) -> object:
    """
    Overview:
        Read file from redis
    Arguments:
        - path (:obj:`str`): File path in redis, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    # NOTE(review): pickle.loads can execute arbitrary code; only read from trusted storage.
    return pickle.loads(_get_redis().get(path))


def _ensure_rediscluster(startup_nodes=None):
    """
    Overview:
        Make sure the module-level ``_redis_cluster`` client exists, creating it on first use. \
        Nothing is returned; callers use the ``_redis_cluster`` global afterwards.
    Arguments:
        - startup_nodes (:obj:`list` of :obj:`dict` or :obj:`None`): Startup nodes, each a dict with \
            ``host`` (:obj:`str`) and ``port`` keys. ``None`` means \
            ``[{"host": "127.0.0.1", "port": "7000"}]``. Only used when the client is first created.
    """
    global _redis_cluster
    if startup_nodes is None:
        # Build the default per call instead of sharing one mutable list as a default argument.
        startup_nodes = [{"host": "127.0.0.1", "port": "7000"}]
    if _redis_cluster is None:
        _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)


def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read file from rediscluster
    Arguments:
        - path (:obj:`str`): File path in rediscluster, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    _ensure_rediscluster()
    value_bytes = _redis_cluster.get(path)
    # NOTE(review): pickle.loads can execute arbitrary code; only read from trusted storage.
    value = pickle.loads(value_bytes)
    return value


def read_from_file(path: str) -> object:
    """
    Overview:
        Read file from local file system
    Arguments:
        - path (:obj:`str`): File path in local file system
    Returns:
        - (:obj:`data`): Deserialized (unpickled) data
    """
    with open(path, "rb") as f:
        value = pickle.load(f)
    return value


def _ensure_memcached():
    """
    Overview:
        Make sure the module-level ``_memcached`` client exists, creating it on first use from the \
        hard-coded ``server_list.conf`` and ``client.conf`` paths. Nothing is returned; callers use \
        the ``_memcached`` global afterwards.
    """
    global _memcached
    if _memcached is None:
        server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
        client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
        _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)


def read_from_mc(path: str, flush=False) -> object:
    """
    Overview:
        Read file from memcache, file must be saved by `torch.save()`
    Arguments:
        - path (:obj:`str`): File path in local system
        - flush (:obj:`bool`): If True, only issue a ``Get`` with ``MC_READ_THROUGH`` and return \
            ``None`` without deserializing anything
    Returns:
        - (:obj:`data`): Deserialized data (loaded onto CPU), or ``None`` when ``flush`` is True
    """
    _ensure_memcached()
    # Any exception triggers a retry after a short sleep; there is no retry limit.
    while True:
        try:
            value = get_mc_package().pyvector()
            if flush:
                _memcached.Get(path, value, get_mc_package().MC_READ_THROUGH)
                return
            else:
                _memcached.Get(path, value)
            value_buf = get_mc_package().ConvertBuffer(value)
            value_str = io.BytesIO(value_buf)
            value_str = torch.load(value_str, map_location='cpu')
            return value_str
        except Exception:
            print('read mc failed, retry...')
            time.sleep(0.01)


def read_from_path(path: str):
    """
    Overview:
        Read file from ceph, falling back to the local file system when the ceph package \
        is not available
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, or use local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if get_ceph_package() is None:
        logging.info(
            "You do not have ceph installed! Loading local file!"
            " If you are not testing locally, something is wrong!"
        )
        return read_from_file(path)
    else:
        return read_from_ceph(path)


def save_file_ceph(path, data):
    """
    Overview:
        Save pickle dumped data file to ceph. When the ceph package is not available, the data is \
        written to the local file system instead, unless the directory part of ``path`` is exactly \
        ``'do_not_save'``, in which case the data is dropped.
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, use file system when not
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    Raises:
        - RuntimeError: If the ceph package exposes neither ``save_from_string`` nor ``put``
    """
    data = pickle.dumps(data)
    save_path = os.path.dirname(path)
    file_name = os.path.basename(path)
    ceph = get_ceph_package()
    if ceph is not None:
        # Different ceph client packages expose different write entry points.
        if hasattr(ceph, 'save_from_string'):
            ceph.save_from_string(save_path, file_name, data)
        elif hasattr(ceph, 'put'):
            ceph.put(os.path.join(save_path, file_name), data)
        else:
            raise RuntimeError('ceph can not save file, check your ceph installation')
    else:
        size = len(data)
        if save_path == 'do_not_save':
            logging.info(
                "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
                " If you are not testing locally, something is wrong!"
            )
            return
        p = os.path.join(save_path, file_name)
        with open(p, 'wb') as f:
            logging.info(
                "You do not have ceph installed! Saving as local file at {} of size {}!".format(p, size) +
                " If you are not testing locally, something is wrong!"
            )
            f.write(data)


def save_file_redis(path, data):
    """
    Overview:
        Save pickle dumped data file to redis
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _get_redis().set(path, pickle.dumps(data))


def save_file_rediscluster(path, data):
    """
    Overview:
        Save pickle dumped data file to rediscluster
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in rediscluster
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _ensure_rediscluster()
    data = pickle.dumps(data)
    _redis_cluster.set(path, data)


def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read file from path
    Arguments:
        - path (:obj:`str`): The path of file to read
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support \
            ``{'normal', 'ceph', 'mc'}``. When ``None`` it is inferred: ``'ceph'`` if ``path`` starts \
            with ``s3`` (case-insensitive), else ``'mc'`` if the mc package is available, \
            else ``'normal'``
        - use_lock (:obj:`bool`): Whether to hold a read file lock; only used for ``'normal'``
    Returns:
        - (:obj:`data`): The loaded data
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        data = read_from_path(path)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'read'):
                data = torch.load(path, map_location='cpu')
        else:
            data = torch.load(path, map_location='cpu')
    elif fs_type == 'mc':
        data = read_from_mc(path)
    return data


def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save data to file of path
    Arguments:
        - path (:obj:`str`): The path of file to save to
        - data (:obj:`object`): The data to save
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support \
            ``{'normal', 'ceph', 'mc'}``. When ``None`` it is inferred: ``'ceph'`` if ``path`` starts \
            with ``s3`` (case-insensitive), else ``'mc'`` if the mc package is available, \
            else ``'normal'``
        - use_lock (:obj:`bool`): Whether to hold a write file lock; only used for ``'normal'``
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
        else:
            torch.save(data, path)
    elif fs_type == 'mc':
        # Write with torch.save first, then issue a read-through Get on the same path.
        torch.save(data, path)
        read_from_mc(path, flush=True)


def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Remove file
    Arguments:
        - path (:obj:`str`): The path of file you want to remove
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``. \
            When ``None``, ``'ceph'`` is used if ``path`` starts with ``s3`` (case-insensitive), \
            otherwise ``'normal'``
    """
    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph']
    # NOTE(review): ``path`` is interpolated into a shell command unquoted and the result of
    # ``os.popen`` is discarded, so the exit status is never checked. Only pass trusted paths.
    if fs_type == 'ceph':
        os.popen("aws s3 rm --recursive {}".format(path))
    elif fs_type == 'normal':
        os.popen("rm -rf {}".format(path))