import io
from ditk import logging
import os
import pickle
import time
from functools import lru_cache
from typing import Union

import torch

from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc
from .lock_helper import get_file_lock

_memcached = None
_redis_cluster = None

if os.environ.get('DI_STORE', 'off').lower() == 'on':
    print('Enable DI-store')
    from di_store import Client

    di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
    di_store_client = Client(di_store_config_path)

    def save_to_di_store(data):
        return di_store_client.put(data)

    def read_from_di_store(object_ref):
        data = di_store_client.get(object_ref)
        di_store_client.delete(object_ref)
        return data
else:
    save_to_di_store = read_from_di_store = None
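
# Hedged usage sketch (illustrative only, assuming ``DI_STORE=on`` and a valid ``di_store.yaml``):
#
#     ref = save_to_di_store(data)       # store ``data`` and get back an object reference
#     data = read_from_di_store(ref)     # fetch the data and delete the object from the store
#
# When DI_STORE is off, both helpers are None, so callers should check for that before use.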


@lru_cache()
def get_ceph_package():
    return try_import_ceph()


@lru_cache()
def get_redis_package():
    return try_import_redis()


@lru_cache()
def get_rediscluster_package():
    return try_import_rediscluster()


@lru_cache()
def get_mc_package():
    return try_import_mc()


def read_from_ceph(path: str) -> object:
    """
    Overview:
        Read file from ceph
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``
    Returns:
        - (:obj:`data`): Deserialized data
    """
    value = get_ceph_package().Get(path)
    if not value:
        raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))

    return pickle.loads(value)


@lru_cache()
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Get a cached redis client (``StrictRedis``) for the given ``host`` and ``port``
    Arguments:
        - host (:obj:`str`): Host string
        - port (:obj:`int`): Port number
    Returns:
        - (:obj:`Redis(object)`): Redis object with given ``host``, ``port``, and ``db=0``
    """
    return get_redis_package().StrictRedis(host=host, port=port, db=0)


def read_from_redis(path: str) -> object:
    """
    Overview:
        Read file from redis
    Arguments:
        - path (:obj:`str`): File path in redis, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    return pickle.loads(_get_redis().get(path))


def _ensure_rediscluster(startup_nodes=[{"host": "127.0.0.1", "port": "7000"}]):
    """
    Overview:
        Ensures redis usage
    Arguments:
        - List of startup nodes (:obj:`dict`) of
            - host (:obj:`str`): Host string
            - port (:obj:`int`): Port number
    Returns:
        - (:obj:`RedisCluster(object)`): RedisCluster object with given ``host``, ``port``, \
            and ``False`` for ``decode_responses`` in default.
    """
    global _redis_cluster
    if _redis_cluster is None:
        _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)
    return


def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read file from rediscluster
    Arguments:
        - path (:obj:`str`): File path in rediscluster, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    _ensure_rediscluster()
    value_bytes = _redis_cluster.get(path)
    value = pickle.loads(value_bytes)
    return value


def read_from_file(path: str) -> object:
    """
    Overview:
        Read file from local file system
    Arguments:
        - path (:obj:`str`): File path in local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    with open(path, "rb") as f:
        value = pickle.load(f)

    return value


def _ensure_memcached():
    """
    Overview:
        Ensures that the global memcached client is initialized
    Returns:
        - None. A ``MemcachedClient`` instance, built from the current memcached_client's \
            ``server_list.conf`` and ``client.conf`` files, is cached in the module-level ``_memcached`` variable.
    """
    global _memcached
    if _memcached is None:
        server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
        client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
        _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)
    return


def read_from_mc(path: str, flush=False) -> object:
    """
    Overview:
        Read file from memcache, file must be saved by `torch.save()`
    Arguments:
        - path (:obj:`str`): File path in local system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    _ensure_memcached()
    while True:
        try:
            value = get_mc_package().pyvector()
            if flush:
                _memcached.Get(path, value, get_mc_package().MC_READ_THROUGH)
                return
            else:
                _memcached.Get(path, value)
            value_buf = get_mc_package().ConvertBuffer(value)
            value_str = io.BytesIO(value_buf)
            value_str = torch.load(value_str, map_location='cpu')
            return value_str
        except Exception:
            print('read mc failed, retry...')
            time.sleep(0.01)


def read_from_path(path: str):
    """
    Overview:
        Read file from ceph if it is available, otherwise fall back to the local file system
    Arguments:
        - path (:obj:`str`): File path in ceph (starting with ``"s3://"``) or in the local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if get_ceph_package() is None:
        logging.info(
            "You do not have ceph installed! Loading local file!"
            " If you are not testing locally, something is wrong!"
        )
        return read_from_file(path)
    else:
        return read_from_ceph(path)


def save_file_ceph(path, data):
    """
    Overview:
        Save pickle dumped data file to ceph
    Arguments:
        - path (:obj:`str`): File path in ceph, starting with ``"s3://"``; the local file system is used \
            when ceph is not available
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    data = pickle.dumps(data)
    save_path = os.path.dirname(path)
    file_name = os.path.basename(path)
    ceph = get_ceph_package()
    if ceph is not None:
        if hasattr(ceph, 'save_from_string'):
            ceph.save_from_string(save_path, file_name, data)
        elif hasattr(ceph, 'put'):
            ceph.put(os.path.join(save_path, file_name), data)
        else:
            raise RuntimeError('ceph can not save file, check your ceph installation')
    else:
        size = len(data)
        if save_path == 'do_not_save':
            logging.info(
                "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
                " If you are not testing locally, something is wrong!"
            )
            return
        p = os.path.join(save_path, file_name)
        with open(p, 'wb') as f:
            logging.info(
                "You do not have ceph installed! Saving as local file at {} of size {}!".format(p, size) +
                " If you are not testing locally, something is wrong!"
            )
            f.write(data)


def save_file_redis(path, data):
    """
    Overview:
        Save pickle dumped data file to redis
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _get_redis().set(path, pickle.dumps(data))
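
# Hedged round-trip sketch (illustrative only, assuming a redis server on localhost:6379;
# the key name is hypothetical):
#
#     save_file_redis('buffer_0', {'obs': torch.randn(4)})   # pickled and stored under the key
#     data = read_from_redis('buffer_0')                     # unpickled back into a Python object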


def save_file_rediscluster(path, data):
    """
    Overview:
        Save pickle dumped data file to rediscluster
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _ensure_rediscluster()
    data = pickle.dumps(data)
    _redis_cluster.set(path, data)
    return


def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read file from path
    Arguments:
        - path (:obj:`str`): The path of file to read
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``
        - use_lock (:obj:`bool`): Whether ``use_lock`` is in local normal file system
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        data = read_from_path(path)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'read'):
                data = torch.load(path, map_location='cpu')
        else:
            data = torch.load(path, map_location='cpu')
    elif fs_type == 'mc':
        data = read_from_mc(path)
    return data


def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save data to the file at the given path
    Arguments:
        - path (:obj:`str`): The path of file to save to
        - data (:obj:`object`): The data to save
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``; \
            inferred from the path and available packages when ``None``
        - use_lock (:obj:`bool`): Whether to use a file lock when writing to the local (normal) file system
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
        else:
            torch.save(data, path)
    elif fs_type == 'mc':
        torch.save(data, path)
        read_from_mc(path, flush=True)
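
# Hedged usage sketch (illustrative only; the checkpoint path is hypothetical). With ``fs_type=None``
# the type is inferred: 's3://...' paths map to 'ceph', otherwise 'mc' if memcached is importable,
# else 'normal' (plain ``torch.save``/``torch.load``):
#
#     save_file('./ckpt_demo.pth.tar', {'model': {'w': torch.zeros(2)}}, fs_type='normal', use_lock=True)
#     state = read_file('./ckpt_demo.pth.tar', fs_type='normal', use_lock=True)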


def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Remove file
    Arguments:
        - path (:obj:`str`): The path of file you want to remove
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``
    """
    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph']
    if fs_type == 'ceph':
        os.popen("aws s3 rm --recursive {}".format(path))
    elif fs_type == 'normal':
        os.popen("rm -rf {}".format(path))