import torch
from ding.interaction.slave import Slave, TaskFail


class NaiveCollector(Slave):
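    """
    Overview:
        A naive collector slave whose master is the coordinator.
        It passes messages between the comm collector and the coordinator, and returns
        randomly generated timesteps instead of real environment data.
    Interfaces:
        _process_task, _get_timestep
    """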

    def __init__(self, *args, prefix='', **kwargs):
        super().__init__(*args, **kwargs)
        self._prefix = prefix

    def _process_task(self, task):
        """
        Overview:
            Process a task according to the input task info dict, which is passed in by the master coordinator.
            For each type of task, refer to the corresponding callback function in the comm collector for details.
        Arguments:
            - task (:obj:`EasyDict`): Task dict. Must contain key "name".
        Returns:
            - result (:obj:`dict`): Task result dict.
        Raises:
            - TaskFail: If the task name is not one of the supported collector tasks.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': '20', 'gpu': '1'}
        elif task_name == 'collector_start_task':
            self.count = 0
            self.task_info = task['task_info']
            return {'message': 'collector task has started'}
        elif task_name == 'collector_data_task':
            self.count += 1
            # Save one randomly generated timestep list locally; the file path doubles as the data id.
            data_id = './{}_{}_{}'.format(self._prefix, self.task_info['task_id'], self.count)
            torch.save(self._get_timestep(), data_id)
            if self.count == 20:
                # The 20th data task ends collection: report finish info instead of data meta info.
                return {
                    'task_id': self.task_info['task_id'],
                    'collector_done': True,
                    'cur_episode': 1,
                    'cur_step': 314,
                    'cur_sample': 314,
                }
            # Otherwise return the meta info of the data that was just saved.
            return {
                'data_id': data_id,
                'buffer_id': self.task_info['buffer_id'],
                'unroll_split_begin': 0,
                'task_id': self.task_info['task_id'],
            }
        else:
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
            )

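    def _get_timestep(self):
        """
        Overview:
            Generate a list containing one random timestep dict, used as placeholder collected data.
        Returns:
            - timestep (:obj:`list`): A list with a single dict of ``obs``, ``next_obs``, ``reward``,
              ``action`` and ``done``.
        """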
        return [
            {
                'obs': torch.rand(4),
                'next_obs': torch.randn(4),
                'reward': torch.randint(0, 2, size=(3, )).float(),
                'action': torch.randint(0, 2, size=(1, )),
                'done': False,
            }
        ]
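

# A minimal sketch of the task sequence that ``_process_task`` above responds to:
# one 'resource' query, one 'collector_start_task', then repeated 'collector_data_task'
# calls until the result carries 'collector_done'.
# NOTE: ``drive_collector`` and its parameters are hypothetical, for illustration only.
# They are not part of ``ding``; a real coordinator sends these tasks to the slave
# over the network rather than calling the method directly.
def drive_collector(process_task, task_info, max_data_tasks=100):
    """Send resource, start and data tasks to ``process_task`` until it reports done."""
    resource = process_task({'name': 'resource'})
    process_task({'name': 'collector_start_task', 'task_info': task_info})
    data_list = []
    finish_info = None
    for _ in range(max_data_tasks):
        result = process_task({'name': 'collector_data_task'})
        if result.get('collector_done', False):
            # The finish message replaces the data meta info on the last call.
            finish_info = result
            break
        data_list.append(result)
    return resource, data_list, finish_info


# Possible usage, assuming ``collector`` is an already constructed NaiveCollector
# (each data task writes a file into the current directory):
#     drive_collector(collector._process_task, {'task_id': 'task0', 'buffer_id': 'buffer0'})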