zjowowen's picture
init space
079c32c
raw
history blame
2.44 kB
import torch
from ding.interaction.slave import Slave, TaskFail
class NaiveCollector(Slave):
    """
    Overview:
        A slave, whose master is coordinator.
        Used to pass message between comm collector and coordinator. Instead of running an
        environment, it fabricates random timesteps and saves them with ``torch.save``.
    Interfaces:
        _process_task, _get_timestep
    """

    # Number of ``collector_data_task`` calls (counted since the last ``collector_start_task``)
    # at which the collector reports ``collector_done`` instead of a data record.
    _DATA_TASK_LIMIT = 20

    def __init__(self, *args, prefix='', **kwargs):
        """
        Overview:
            Initialize the underlying slave and the per-task bookkeeping.
        Arguments:
            - args: Positional arguments forwarded unchanged to ``Slave.__init__``.
            - prefix (:obj:`str`): Prefix of the file names under which generated data is saved.
            - kwargs: Keyword arguments forwarded unchanged to ``Slave.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._prefix = prefix
        # Both are (re)set by ``collector_start_task``. Initializing them here lets a data task
        # that arrives before any start task fail with TaskFail rather than AttributeError.
        self.count = 0
        self.task_info = None

    def _process_task(self, task):
        """
        Overview:
            Process a task according to input task info dict, which is passed in by master coordinator.
            For each type of task, you can refer to corresponding callback function in comm collector for details.
        Arguments:
            - task (:obj:`dict`): Task dict (e.g. an ``EasyDict``). Must contain key "name". A
              ``collector_start_task`` must also contain "task_info", from which "task_id" and
              "buffer_id" are read by later ``collector_data_task`` calls.
        Returns:
            - result (:obj:`dict`): Task result dict.
        Raises:
            - TaskFail: If the task name is unknown, or if a ``collector_data_task`` arrives before
              any ``collector_start_task``.
        """
        task_name = task['name']
        if task_name == 'resource':
            return {'cpu': '20', 'gpu': '1'}
        elif task_name == 'collector_start_task':
            self.count = 0
            self.task_info = task['task_info']
            return {'message': 'collector task has started'}
        elif task_name == 'collector_data_task':
            if self.task_info is None:
                raise TaskFail(
                    result={'message': 'collector task not started'},
                    message='collector_data_task received before collector_start_task'
                )
            self.count += 1
            task_id = self.task_info['task_id']
            data_id = './{}_{}_{}'.format(self._prefix, task_id, self.count)
            # Data is written even on the call that reports completion; its id is just not returned then.
            torch.save(self._get_timestep(), data_id)
            data = {
                'data_id': data_id,
                'buffer_id': self.task_info['buffer_id'],
                'unroll_split_begin': 0,
                'task_id': task_id,
            }
            # NOTE: equality check -- calls past the limit go back to returning data records.
            if self.count == self._DATA_TASK_LIMIT:
                # Fixed mock statistics reported when the collector task finishes.
                return {
                    'task_id': task_id,
                    'collector_done': True,
                    'cur_episode': 1,
                    'cur_step': 314,
                    'cur_sample': 314,
                }
            return data
        else:
            raise TaskFail(
                result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
            )

    def _get_timestep(self):
        """
        Overview:
            Build a single fake timestep filled with random tensors.
        Returns:
            - timestep (:obj:`list`): A one-element list holding a dict with
              ``obs`` (shape (4,), uniform in [0, 1)), ``next_obs`` (shape (4,), standard normal),
              ``reward`` (shape (3,), float values 0. or 1.), ``action`` (shape (1,), integer 0 or 1)
              and ``done`` (always ``False``).
        """
        return [
            {
                'obs': torch.rand(4),
                'next_obs': torch.randn(4),
                'reward': torch.randint(0, 2, size=(3, )).float(),
                'action': torch.randint(0, 2, size=(1, )),
                'done': False,
            }
        ]