import torch

from ding.interaction.slave import Slave, TaskFail


class NaiveCollector(Slave):
    """
    Overview:
        A slave, whose master is coordinator.
        Used to pass message between comm collector and coordinator.
        Data tasks return fake, randomly generated timesteps (see ``_get_timestep``),
        so this class is suitable for exercising the coordinator without a real env.
    Interfaces:
        _process_task, _get_timestep
    """

    # Number of 'collector_data_task' messages after which the collector reports done.
    _DONE_COUNT = 20

    def __init__(self, *args, prefix='', **kwargs):
        """
        Arguments:
            - args: Positional arguments forwarded unchanged to ``Slave.__init__``.
            - prefix (:obj:`str`): Prefix used when building the file path of saved data.
            - kwargs: Keyword arguments forwarded unchanged to ``Slave.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._prefix = prefix
        # Per-task state. It is (re)set by a 'collector_start_task' message; initialised here so that
        # an out-of-order data task fails with a TaskFail instead of an AttributeError.
        self.count = 0
        self.task_info = None

    def _process_task(self, task):
        """
        Overview:
            Process a task according to input task info dict, which is passed in by master coordinator.
            For each type of task, you can refer to corresponding callback function in comm collector for details.
        Arguments:
            - task (:obj:`dict`): Task dict. Must contain key "name"; a 'collector_start_task' must also
              contain key "task_info" (read later for its "task_id" and "buffer_id" entries).
        Returns:
            - result (:obj:`dict`): Task result dict. For 'collector_data_task' this is either a data dict
              (keys 'data_id', 'buffer_id', 'unroll_split_begin', 'task_id') or, on the ``_DONE_COUNT``-th
              data task, a fixed "collector done" summary dict.
        Raises:
            - TaskFail: If the task name is unknown, or a data task arrives before any start task.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': '20', 'gpu': '1'}
        elif task_name == 'collector_start_task':
            self.count = 0
            self.task_info = task['task_info']
            return {'message': 'collector task has started'}
        elif task_name == 'collector_data_task':
            if self.task_info is None:
                raise TaskFail(
                    result={'message': 'collector task not started'},
                    message='collector data task received before collector_start_task'
                )
            self.count += 1
            data_id = './{}_{}_{}'.format(self._prefix, self.task_info['task_id'], self.count)
            # Persist the fake timestep so the consumer can load it by path (``data_id``).
            torch.save(self._get_timestep(), data_id)
            if self.count == self._DONE_COUNT:
                # Fixed fake episode statistics reported when the collector finishes.
                return {
                    'task_id': self.task_info['task_id'],
                    'collector_done': True,
                    'cur_episode': 1,
                    'cur_step': 314,
                    'cur_sample': 314,
                }
            return {
                'data_id': data_id,
                'buffer_id': self.task_info['buffer_id'],
                'unroll_split_begin': 0,
                'task_id': self.task_info['task_id'],
            }
        else:
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
            )

    def _get_timestep(self):
        """
        Overview:
            Build a single fake timestep filled with random tensors.
        Returns:
            - timesteps (:obj:`list`): A one-element list holding a dict with:
              'obs' (shape (4,), uniform in [0, 1)), 'next_obs' (shape (4,), standard normal),
              'reward' (shape (3,), float values 0. or 1.), 'action' (shape (1,), integer 0 or 1),
              and 'done' (always False).
        """
        return [
            {
                'obs': torch.rand(4),
                'next_obs': torch.randn(4),
                'reward': torch.randint(0, 2, size=(3, )).float(),
                'action': torch.randint(0, 2, size=(1, )),
                'done': False,
            }
        ]