File size: 1,472 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import importlib
import sys
from typing import Dict

import ray

from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string


class AsyncAgentRayActor:
    """Run an agent inside a Ray actor and expose it as an async callable.

    The agent class is taken from ``config['type']`` and instantiated remotely
    as a Ray actor that reserves ``num_gpus`` GPUs. Awaiting an instance of
    this wrapper forwards the call to the remote agent's ``__call__``.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        """Create the remote agent actor.

        Args:
            config: Agent construction config. Must contain ``'type'``, either
                a class object or a string resolved via
                ``load_class_from_string``. May contain ``'python_path'``,
                which is forwarded to that loader. Every remaining key is
                passed as a keyword argument to the agent constructor. The
                caller's dict is not modified.
            num_gpus: Number of GPUs requested for the Ray actor.

        Raises:
            KeyError: If ``config`` has no ``'type'`` key.
        """
        # Pop from a shallow copy so the caller's config stays intact and can
        # be reused, e.g. to spawn several actors from the same config.
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        AsyncAgentActor = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = AsyncAgentActor.remote(**config)

    async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward a call to the remote agent and await its response.

        Args:
            *message: Messages passed positionally to the agent's ``__call__``.
            session_id: Session identifier forwarded to the agent.
            **kwargs: Extra keyword arguments forwarded to the agent.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        # The ObjectRef returned by ``.remote`` is awaited directly rather
        # than resolved with the blocking ``ray.get``.
        response = await self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        return response


class AgentRayActor:
    """Run an agent inside a Ray actor and expose it as a blocking callable.

    The agent class is taken from ``config['type']`` and instantiated remotely
    as a Ray actor that reserves ``num_gpus`` GPUs. Calling an instance of
    this wrapper forwards the call to the remote agent's ``__call__``. It then
    waits for the result with ``ray.get``.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        """Create the remote agent actor.

        Args:
            config: Agent construction config. Must contain ``'type'``, either
                a class object or a string resolved via
                ``load_class_from_string``. May contain ``'python_path'``,
                which is forwarded to that loader. Every remaining key is
                passed as a keyword argument to the agent constructor. The
                caller's dict is not modified.
            num_gpus: Number of GPUs requested for the Ray actor.

        Raises:
            KeyError: If ``config`` has no ``'type'`` key.
        """
        # Pop from a shallow copy so the caller's config stays intact and can
        # be reused, e.g. to spawn several actors from the same config.
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        AgentActor = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = AgentActor.remote(**config)

    def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward a call to the remote agent and block until it returns.

        Args:
            *message: Messages passed positionally to the agent's ``__call__``.
            session_id: Session identifier forwarded to the agent.
            **kwargs: Extra keyword arguments forwarded to the agent.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        response = self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        # ``ray.get`` blocks the calling thread until the remote call finishes.
        return ray.get(response)