gomoku / DI-engine /ding /utils /file_helper.py
zjowowen's picture
init space
079c32c
raw
history blame
10.4 kB
import io
from ditk import logging
import os
import pickle
import time
from functools import lru_cache
from typing import Union
import torch
from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc
from .lock_helper import get_file_lock
_memcached = None
_redis_cluster = None
if os.environ.get('DI_STORE', 'off').lower() == 'on':
print('Enable DI-store')
from di_store import Client
di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
di_store_client = Client(di_store_config_path)
def save_to_di_store(data):
return di_store_client.put(data)
def read_from_di_store(object_ref):
data = di_store_client.get(object_ref)
di_store_client.delete(object_ref)
return data
else:
save_to_di_store = read_from_di_store = None
@lru_cache()
def get_ceph_package():
    """
    Overview:
        Import the ceph client package on first use and memoise the result, so later calls are cheap.
    Returns:
        - package: Whatever ``try_import_ceph`` returns; callers in this module treat ``None`` as \
            "ceph is not installed".
    """
    ceph_pkg = try_import_ceph()
    return ceph_pkg
@lru_cache()
def get_redis_package():
    """
    Overview:
        Import the redis package on first use and memoise the result.
    Returns:
        - package: Whatever ``try_import_redis`` returns (presumably ``None`` when redis is unavailable — \
            confirm in ``import_helper``).
    """
    redis_pkg = try_import_redis()
    return redis_pkg
@lru_cache()
def get_rediscluster_package():
    """
    Overview:
        Import the redis-cluster package on first use and memoise the result.
    Returns:
        - package: Whatever ``try_import_rediscluster`` returns (presumably ``None`` when unavailable — \
            confirm in ``import_helper``).
    """
    rediscluster_pkg = try_import_rediscluster()
    return rediscluster_pkg
@lru_cache()
def get_mc_package():
    """
    Overview:
        Import the memcached client package on first use and memoise the result.
    Returns:
        - package: Whatever ``try_import_mc`` returns; ``read_file``/``save_file`` treat ``None`` as \
            "memcached is not available".
    """
    mc_pkg = try_import_mc()
    return mc_pkg
def read_from_ceph(path: str) -> object:
    """
    Overview:
        Fetch the raw bytes stored at ``path`` in ceph and unpickle them.
    Arguments:
        - path (:obj:`str`): Object path in ceph, starting with ``"s3://"``.
    Returns:
        - data (:obj:`object`): The unpickled object.
    Raises:
        - FileNotFoundError: If ceph returns an empty (falsy) value for ``path``.
    """
    raw_bytes = get_ceph_package().Get(path)
    if raw_bytes:
        return pickle.loads(raw_bytes)
    raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))
@lru_cache()
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Return a ``StrictRedis`` client for ``host``/``port`` using database 0. Clients are memoised per \
        ``(host, port)`` pair by ``lru_cache``, so repeated calls reuse the same client.
    Arguments:
        - host (:obj:`str`): Redis host name.
        - port (:obj:`int`): Redis port number.
    Returns:
        - client (:obj:`StrictRedis`): Redis client bound to ``host``, ``port`` and ``db=0``.
    """
    redis_pkg = get_redis_package()
    return redis_pkg.StrictRedis(host=host, port=port, db=0)
def read_from_redis(path: str) -> object:
    """
    Overview:
        Read the value stored under key ``path`` in redis and unpickle it.
    Arguments:
        - path (:obj:`str`): Key in redis (any string).
    Returns:
        - data (:obj:`object`): The unpickled object.
    Raises:
        - FileNotFoundError: If ``path`` does not exist in redis.
    """
    value = _get_redis().get(path)
    # redis ``GET`` yields ``None`` for a missing key; ``pickle.loads(None)`` would fail with an opaque
    # ``TypeError``, so report a missing key the same way ``read_from_ceph`` reports a missing object.
    if value is None:
        raise FileNotFoundError("File({}) doesn't exist in redis".format(path))
    return pickle.loads(value)
def _ensure_rediscluster(startup_nodes=None):
    """
    Overview:
        Lazily create the module-level ``RedisCluster`` client on first call and reuse it afterwards.
    Arguments:
        - startup_nodes (:obj:`list` of :obj:`dict` or :obj:`None`): Startup nodes, each a dict with ``host`` \
            and ``port`` keys. ``None`` (default) means ``[{"host": "127.0.0.1", "port": "7000"}]``. Only used \
            when the client has not been created yet; later calls ignore it.
    Returns:
        - client (:obj:`RedisCluster`): The shared client, created with ``decode_responses=False`` so values \
            come back as raw bytes (which the callers feed to ``pickle.loads``).
    """
    global _redis_cluster
    if _redis_cluster is None:
        # Build the default node list per call instead of sharing a mutable default argument.
        if startup_nodes is None:
            startup_nodes = [{"host": "127.0.0.1", "port": "7000"}]
        _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)
    return _redis_cluster
def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read the value stored under key ``path`` in the redis cluster and unpickle it.
    Arguments:
        - path (:obj:`str`): Key in the redis cluster (any string).
    Returns:
        - data (:obj:`object`): The unpickled object.
    Raises:
        - FileNotFoundError: If ``path`` does not exist in the cluster.
    """
    _ensure_rediscluster()
    value_bytes = _redis_cluster.get(path)
    # A missing key comes back as ``None``; fail with a clear error instead of a TypeError from pickle.
    if value_bytes is None:
        raise FileNotFoundError("File({}) doesn't exist in rediscluster".format(path))
    return pickle.loads(value_bytes)
def read_from_file(path: str) -> object:
    """
    Overview:
        Unpickle and return the object stored in a local file.
    Arguments:
        - path (:obj:`str`): Path of the file in the local file system.
    Returns:
        - data (:obj:`object`): The unpickled object.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)
def _ensure_memcached():
    """
    Overview:
        Lazily create the module-level memcached client on first call and reuse it afterwards.
    Returns:
        - client (:obj:`MemcachedClient`): The shared client, built from the fixed ``server_list.conf`` and \
            ``client.conf`` files under ``/mnt/lustre/share/memcached_client/``.
    """
    global _memcached
    if _memcached is None:
        server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
        client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
        _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)
    # Return the client, as documented (previously this returned ``None``).
    return _memcached
def read_from_mc(
        path: str,
        flush=False,
        max_retry: Union[None, int] = None,
        retry_interval: float = 0.01,
) -> object:
    """
    Overview:
        Read a file through memcached and deserialize it with ``torch.load`` onto CPU. The file must have \
        been written by ``torch.save()``. Failed attempts are retried after a short sleep.
    Arguments:
        - path (:obj:`str`): File path, used as the memcached key.
        - flush (:obj:`bool`): If True, only issue a ``Get`` with ``MC_READ_THROUGH`` for ``path`` and return \
            ``None`` without deserializing (``save_file`` calls this right after writing a file).
        - max_retry (:obj:`int` or :obj:`None`): Maximum number of retries after the first failed attempt. \
            ``None`` (default) retries forever, which keeps the historical behaviour.
        - retry_interval (:obj:`float`): Seconds to sleep between attempts.
    Returns:
        - data (:obj:`object`): The deserialized data, or ``None`` when ``flush`` is True.
    Raises:
        - Exception: The last error, re-raised once ``max_retry`` retries have all failed.
    """
    _ensure_memcached()
    mc = get_mc_package()  # cached by ``lru_cache``; hoisted out of the retry loop
    retries = 0
    while True:
        try:
            value = mc.pyvector()
            if flush:
                _memcached.Get(path, value, mc.MC_READ_THROUGH)
                return None
            _memcached.Get(path, value)
            value_buf = mc.ConvertBuffer(value)
            return torch.load(io.BytesIO(value_buf), map_location='cpu')
        except Exception:
            # Bounded retry: without a limit a persistent error (e.g. a missing key) would spin forever.
            if max_retry is not None and retries >= max_retry:
                raise
            retries += 1
            print('read mc failed, retry...')
            time.sleep(retry_interval)
def read_from_path(path: str):
    """
    Overview:
        Read pickled data from ceph when the ceph package is available. Otherwise log a notice and fall \
        back to reading the same path from the local file system.
    Arguments:
        - path (:obj:`str`): Path in ceph (starting with ``"s3://"``) or in the local file system.
    Returns:
        - data (:obj:`object`): The unpickled object.
    """
    if get_ceph_package() is not None:
        return read_from_ceph(path)
    logging.info(
        "You do not have ceph installed! Loading local file!"
        " If you are not testing locally, something is wrong!"
    )
    return read_from_file(path)
def save_file_ceph(path, data):
    """
    Overview:
        Pickle ``data`` and store it in ceph. Without ceph, write the bytes to the local file ``path``, or \
        skip writing entirely when the directory part of ``path`` is the literal ``'do_not_save'``.
    Arguments:
        - path (:obj:`str`): Target path; a ceph path (``"s3://..."``) or a local path when ceph is absent.
        - data (:obj:`Any`): Any picklable object (dict, list, tensor, ...).
    Raises:
        - RuntimeError: If the ceph package exposes neither ``save_from_string`` nor ``put``.
    """
    payload = pickle.dumps(data)
    directory, file_name = os.path.split(path)
    ceph = get_ceph_package()
    if ceph is None:
        size = len(payload)
        if directory == 'do_not_save':
            logging.info(
                "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
                " If you are not testing locally, something is wrong!"
            )
            return
        local_path = os.path.join(directory, file_name)
        with open(local_path, 'wb') as f:
            logging.info(
                "You do not have ceph installed! Saving as local file at {} of size {}!".format(local_path, size) +
                " If you are not testing locally, something is wrong!"
            )
            f.write(payload)
        return
    # Different ceph client versions expose different write APIs; use whichever is present.
    if hasattr(ceph, 'save_from_string'):
        ceph.save_from_string(directory, file_name, payload)
    elif hasattr(ceph, 'put'):
        ceph.put(os.path.join(directory, file_name), payload)
    else:
        raise RuntimeError('ceph can not save file, check your ceph installation')
def save_file_redis(path, data):
    """
    Overview:
        Pickle ``data`` and store it in redis under key ``path``.
    Arguments:
        - path (:obj:`str`): Redis key (any string).
        - data (:obj:`Any`): Any picklable object (dict, list, tensor, ...).
    """
    client = _get_redis()
    client.set(path, pickle.dumps(data))
def save_file_rediscluster(path, data):
    """
    Overview:
        Pickle ``data`` and store it in the redis cluster under key ``path``.
    Arguments:
        - path (:obj:`str`): Redis cluster key (any string).
        - data (:obj:`Any`): Any picklable object (dict, list, tensor, ...).
    """
    _ensure_rediscluster()
    _redis_cluster.set(path, pickle.dumps(data))
def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read data from ``path``, dispatching on the file system type.
    Arguments:
        - path (:obj:`str`): The path of the file to read.
        - fs_type (:obj:`str` or :obj:`None`): One of ``'normal'``, ``'ceph'`` or ``'mc'``. When ``None`` it is \
            inferred: paths starting with ``s3`` (case-insensitive) use ``'ceph'``; otherwise ``'mc'`` if the \
            memcached package is available, else ``'normal'``.
        - use_lock (:obj:`bool`): Only for ``'normal'``: hold a read file lock while loading.
    Returns:
        - data (:obj:`object`): The loaded data. ``'ceph'`` goes through ``read_from_path`` (pickle), while \
            ``'normal'`` and ``'mc'`` deserialize with ``torch.load(..., map_location='cpu')``.
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        else:
            fs_type = 'mc' if get_mc_package() is not None else 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'mc':
        data = read_from_mc(path)
    elif fs_type == 'ceph':
        data = read_from_path(path)
    elif fs_type == 'normal':
        if not use_lock:
            data = torch.load(path, map_location='cpu')
        else:
            with get_file_lock(path, 'read'):
                data = torch.load(path, map_location='cpu')
    return data
def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save ``data`` to ``path``, dispatching on the file system type.
    Arguments:
        - path (:obj:`str`): The path of the file to save to.
        - data (:obj:`object`): The data to save.
        - fs_type (:obj:`str` or :obj:`None`): One of ``'normal'``, ``'ceph'`` or ``'mc'``. When ``None`` it is \
            inferred: paths starting with ``s3`` (case-insensitive) use ``'ceph'``; otherwise ``'mc'`` if the \
            memcached package is available, else ``'normal'``.
        - use_lock (:obj:`bool`): Only for ``'normal'``: hold a write file lock while saving.
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        else:
            fs_type = 'mc' if get_mc_package() is not None else 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'mc':
        torch.save(data, path)
        # Issue a read-through get right after writing (``flush=True`` skips deserialization).
        read_from_mc(path, flush=True)
    elif fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if not use_lock:
            torch.save(data, path)
        else:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Remove a file or directory (recursively) and wait until the removal has finished. Removal is \
        best-effort: missing paths are ignored and failures are logged rather than raised.
    Arguments:
        - path (:obj:`str`): The path to remove. For the local file system a leading ``~`` and glob \
            wildcards are expanded, as the shell did for the former ``rm -rf`` command.
        - fs_type (:obj:`str` or :obj:`None`): ``'normal'`` or ``'ceph'``. When ``None``, paths starting with \
            ``s3`` (case-insensitive) are treated as ceph and everything else as local.
    """
    import glob
    import shutil
    import subprocess

    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph']
    if fs_type == 'ceph':
        # Pass arguments as a list (no shell re-parsing of ``path``) and block until done. The former
        # ``os.popen`` returned immediately, so a later write to the same path could race the removal.
        try:
            subprocess.run(["aws", "s3", "rm", "--recursive", path], stdout=subprocess.DEVNULL, check=False)
        except FileNotFoundError:
            logging.warning("aws CLI not found, could not remove {}".format(path))
    elif fs_type == 'normal':
        # Python equivalent of ``rm -rf``: remove directories recursively, ignore paths that do not exist.
        for target in glob.glob(os.path.expanduser(path)):
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target, ignore_errors=True)
            else:
                try:
                    os.remove(target)
                except FileNotFoundError:
                    pass
                except OSError as err:
                    logging.warning("could not remove {}: {}".format(target, err))