File size: 2,120 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import time
import os
from ding.interaction import Slave, TaskFail
from ding.utils import lists_to_dicts


class NaiveLearner(Slave):
    """Minimal ``Slave`` subclass that answers learner-style tasks with canned results.

    Every task is a mapping with a ``'name'`` key selecting the behaviour; see
    ``_process_task`` for the supported names. After more than five
    ``'learner_learn_task'`` calls the learner reports ``learner_done`` and
    creates the marker file ``'<prefix>_final_model.pth'``.
    """

    def __init__(self, *args, prefix='', **kwargs):
        """Forward all arguments to ``Slave`` and remember the file-name prefix.

        Args:
            *args: Positional arguments passed unchanged to ``Slave.__init__``.
            prefix: String prepended to ``'_final_model.pth'`` when the marker
                file is created.
            **kwargs: Keyword arguments passed unchanged to ``Slave.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._prefix = prefix

    def _process_task(self, task):
        """Dispatch on ``task['name']`` and return the matching result dict.

        Supported names: ``'resource'``, ``'learner_start_task'``,
        ``'learner_get_data_task'``, ``'learner_learn_task'`` and
        ``'learner_close_task'``. All but ``'resource'`` and
        ``'learner_start_task'`` read ``self.task_info``, which is only set by
        ``'learner_start_task'``.

        Args:
            task: Mapping with at least a ``'name'`` key. ``'learner_start_task'``
                also needs ``'task_info'`` (indexed by ``'task_id'`` and
                ``'buffer_id'`` in later tasks); ``'learner_learn_task'`` also
                needs ``'data'``.

        Returns:
            A dict whose contents depend on the task name.

        Raises:
            TaskFail: If ``'learner_learn_task'`` receives ``data`` of ``None``,
                or if the task name is not one of the supported names.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': 'xxx', 'gpu': 'xxx'}
        elif task_name == 'learner_start_task':
            time.sleep(1)
            self.task_info = task['task_info']
            self.count = 0
            return {'message': 'learner task has started'}
        elif task_name == 'learner_get_data_task':
            time.sleep(0.01)
            return {
                'task_id': self.task_info['task_id'],
                'buffer_id': self.task_info['buffer_id'],
                'batch_size': 2,
                'cur_learner_iter': 1
            }
        elif task_name == 'learner_learn_task':
            data = task['data']
            if data is None:
                raise TaskFail(result={'message': 'no data'})
            time.sleep(0.1)
            # NOTE(review): presumably turns a list of per-sample dicts into a
            # dict of lists -- confirm against ding.utils.lists_to_dicts.
            data = lists_to_dicts(data)
            assert 'data_id' in data
            priority_keys = ['replay_unique_id', 'replay_buffer_idx', 'priority']
            self.count += 1
            ret = {
                'info': {
                    'learner_step': self.count
                },
                'task_id': self.task_info['task_id'],
                'buffer_id': self.task_info['buffer_id']
            }
            ret['info']['priority_info'] = {k: data[k] for k in priority_keys}
            if self.count > 5:
                ret['info']['learner_done'] = True
                # Create (or refresh) the marker file directly instead of
                # spawning `touch` through a shell: no quoting issues with the
                # prefix, no unclosed pipe, and the file is guaranteed to exist
                # by the time this method returns.
                final_model_path = '{}_final_model.pth'.format(self._prefix)
                with open(final_model_path, 'a'):
                    pass
                os.utime(final_model_path, None)
            return ret
        elif task_name == 'learner_close_task':
            return {'task_id': self.task_info['task_id'], 'buffer_id': self.task_info['buffer_id']}
        else:
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal learner task <{}>'.format(task_name)
            )