File size: 2,385 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import uuid
from ditk import logging
from time import sleep
from typing import Tuple

import redis
from ding.framework.message_queue.mq import MQ
from ding.utils import MQ_REGISTRY


@MQ_REGISTRY.register("redis")
class RedisMQ(MQ):
    """
    Overview:
        Message queue backed by redis pub/sub. Every payload published by an instance is prefixed with that
        instance's 32-character hex id followed by ``b"::"``, so ``recv`` can discard messages it sent itself.
    """

    def __init__(self, redis_host: str, redis_port: int, **kwargs) -> None:
        """
        Overview:
            Connect distributed processes with redis. No connection is opened here; call ``listen`` first.
        Arguments:
            - redis_host (:obj:`str`): Redis server host.
            - redis_port (:obj:`int`): Redis server port; non-int values are converted with ``int()``.
            - kwargs: Accepted for interface compatibility and ignored.
        """
        # Set first so stop()/__del__ see a consistent state even if int(redis_port) below raises.
        self._running = False
        self.host = redis_host
        self.port = redis_port if isinstance(redis_port, int) else int(redis_port)
        self.db = 0
        # uuid4().hex is always 32 hex chars; recv uses this length to validate the sender prefix.
        self._id = uuid.uuid4().hex.encode()

    def listen(self) -> None:
        """
        Overview:
            Open the redis client and the pub/sub handle used by ``subscribe``/``recv``, and mark the queue
            as running.
        """
        self._client = client = redis.Redis(host=self.host, port=self.port, db=self.db)
        self._sub = client.pubsub()
        self._running = True

    def publish(self, topic: str, data: bytes) -> None:
        """
        Overview:
            Publish ``data`` on channel ``topic``, prefixed with ``<id>::`` to identify the sender.
            Requires ``listen`` to have been called.
        Arguments:
            - topic (:obj:`str`): Channel name.
            - data (:obj:`bytes`): Payload.
        """
        self._client.publish(topic, self._id + b"::" + data)

    def subscribe(self, topic: str) -> None:
        """
        Overview:
            Subscribe the pub/sub handle to channel ``topic``.
        """
        self._sub.subscribe(topic)

    def unsubscribe(self, topic: str) -> None:
        """
        Overview:
            Unsubscribe the pub/sub handle from channel ``topic``.
        """
        self._sub.unsubscribe(topic)

    def recv(self) -> Tuple[str, bytes]:
        """
        Overview:
            Block until a well-formed message sent by another instance arrives on a subscribed channel.
            Malformed messages and this instance's own messages are skipped; errors are logged and
            polling continues.
        Returns:
            - (topic, data): channel name decoded to ``str`` and the payload with the sender prefix removed.
        Raises:
            - RuntimeError: if the queue is not running (``listen`` not called, or ``stop`` was called).
        """
        while True:
            if not self._running:
                raise RuntimeError("Redis MQ was not running!")
            try:
                msg = self._sub.get_message(ignore_subscribe_messages=True)
                if msg is None:
                    sleep(0.001)  # Nothing pending; yield briefly instead of busy-polling.
                    continue
                topic = msg["channel"].decode()
                parts = msg["data"].split(b"::", maxsplit=1)
                if len(parts) != 2 or len(parts[0]) != len(self._id):
                    logging.warning("Got invalid message from topic: {}".format(topic))
                    continue
                node_id, data = parts
                if node_id == self._id:  # Discard message sent by self
                    continue
                return topic, data
            except Exception as e:
                # Best-effort: keep listening, but report the cause. The exception is formatted into the
                # message explicitly; passing it as an extra argument without a placeholder breaks rendering.
                logging.error("Meet exception when listening for new messages: {}".format(e))
                # Back off so a persistent failure (e.g. a dropped connection) does not spin and flood logs.
                sleep(0.001)

    def stop(self) -> None:
        """
        Overview:
            Stop the queue and close the pub/sub handle and the redis client. Calling it again is a no-op.
        """
        if self._running:
            self._running = False
            try:
                self._sub.close()
            finally:
                # Close the client even if closing the pub/sub handle fails.
                self._client.close()

    def __del__(self) -> None:
        # Guard against a partially constructed instance that never set _running.
        if getattr(self, "_running", False):
            self.stop()