Superkingjcj's picture
Upload 111 files
e679d69 verified
raw
history blame
1.47 kB
import importlib
import sys
from typing import Dict
import ray
from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string
class AsyncAgentRayActor:
def __init__(
self,
config: Dict,
num_gpus: int,
):
cls_name = config.pop('type')
python_path = config.pop('python_path', None)
cls_name = load_class_from_string(cls_name, python_path) if isinstance(
cls_name, str) else cls_name
AsyncAgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
self.agent_actor = AsyncAgentActor.remote(**config)
async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
response = await self.agent_actor.__call__.remote(
*message, session_id=session_id, **kwargs)
return response
class AgentRayActor:
def __init__(
self,
config: Dict,
num_gpus: int,
):
cls_name = config.pop('type')
python_path = config.pop('python_path', None)
cls_name = load_class_from_string(cls_name, python_path) if isinstance(
cls_name, str) else cls_name
AgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
self.agent_actor = AgentActor.remote(**config)
def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
response = self.agent_actor.__call__.remote(
*message, session_id=session_id, **kwargs)
return ray.get(response)