File size: 4,491 Bytes
b6bff08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import traceback
from queue import Queue
from threading import Thread
import collections.abc

import torch
from transformers import StoppingCriteria


class StoppingCriteriaSub(StoppingCriteria):
    """Stopping criterion that halts generation after stop sequences are seen.

    On every call the tail of the first sequence in ``input_ids`` is compared
    with each stop sequence. A per-stop counter is incremented on every call
    where the tail matches, and generation stops (``True`` is returned) once a
    counter reaches its required number of encounters. Counters are never
    reset, so reusing an instance carries counts over between calls.

    Args:
        stops: iterable of token-id tensors. Each one is flattened to 1-D, so
            ``(k,)``, ``(1, k)`` and 0-d (single token) tensors all work.
            Defaults to no stop sequences.
        encounters: required match counts. Stop ``i`` uses
            ``encounters[i % len(encounters)]``, so ``len(stops)`` must be a
            multiple of ``len(encounters)``. Defaults to ``[1]`` (stop on the
            first match).
        device: device the stop tensors are moved to up front (``"cuda"`` by
            default, as before). ``None`` leaves them where they are. A stop
            tensor on a different device than ``input_ids`` is moved on its
            first comparison in any case.

    Raises:
        ValueError: if ``len(stops)`` is not a multiple of ``len(encounters)``.
    """

    def __init__(self, stops=None, encounters=None, device="cuda"):
        super().__init__()
        stops = [] if stops is None else list(stops)
        # An empty encounters list would make the modulo below divide by zero.
        encounters = [1] if not encounters else list(encounters)
        if len(stops) % len(encounters) != 0:
            raise ValueError(
                "Number of stops must be a multiple of the number of encounters"
            )
        self.encounters = encounters
        flat = [stop.reshape(-1) for stop in stops]
        self.stops = [stop.to(device) for stop in flat] if device is not None else flat
        self.num_stops = [0] * len(self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence of the batch is inspected.
        tokens = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            n = len(stop)
            # An empty stop, or one longer than the tokens so far, cannot match.
            # Comparing mismatched lengths would raise, or broadcast into a
            # false positive, instead of simply not matching.
            if n == 0 or n > tokens.shape[0]:
                continue
            if stop.device != tokens.device:
                stop = stop.to(tokens.device)
                self.stops[stopi] = stop  # cache so the move happens only once
            if torch.all(stop == tokens[-n:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False


class Stream(StoppingCriteria):
    """
    Stopping criterion that never stops generation and only forwards tokens.

    On every generation step the token ids of the first sequence in the batch
    are handed to ``func``, when one was supplied. For decoder-only models these
    ids include the prompt tokens as well as the generated ones.

    Args:
        func (`callable`, *optional*):
            Called with ``input_ids[0]`` on every step; ``None`` disables it.
    """

    def __init__(self, func=None):
        self.func = func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        callback = self.func
        if callback is None:
            return False
        # Stream just the first of possibly several responses.
        callback(input_ids[0])
        return False


class CallbackToGenerator(collections.abc.Generator):
    """
    A generator wrapper for a function that invokes a callback multiple times.

    ``func(*args, callback=<internal callback>, **kwargs)`` runs on a background
    thread. The thread starts immediately but only calls ``func`` once the first
    ``send``/``next``/``throw`` arrives. Each time ``func`` calls the callback
    with a value, that value is returned to the consumer and the worker blocks
    until the consumer resumes it:

    * ``send(x)`` (or ``next()``, i.e. ``send(None)``) makes the callback
      return ``x``;
    * ``throw(exc)`` makes the callback raise ``exc`` inside ``func``.

    When ``func`` returns, the generator finishes with ``StopIteration``. Its
    value is ``func``'s return value when that is not ``None``. Any exception
    escaping ``func`` is re-raised to the consumer.

    Note this starts a background thread. ``close()`` (also called from
    ``__del__``) injects ``GeneratorExit`` into ``func`` and joins the thread.

    Args:
        func: callable accepting a ``callback`` keyword argument.
        *args: positional arguments forwarded to ``func``.
        callback: stored as ``self.callback``; not used by this class.
        **kwargs: keyword arguments forwarded to ``func``.
    """

    def __init__(self, func, *args, callback=None, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        # Kept for backward compatibility only: ``func`` always receives the
        # internal ``val_callback`` below, never this value.
        self.callback = callback

        # Consumer and worker strictly alternate, so one slot per direction is enough.
        self._ready_queue = Queue(1)  # consumer -> worker: ('send', value) / ('throw', exc)
        self._done_queue = Queue(1)   # worker -> consumer: (is_exception, value)
        self._done_holder = [False]   # becomes True once func has finished

        # The worker must capture only these locals (plus func/args/kwargs) and
        # never ``self``. A running thread that references ``self`` keeps the
        # wrapper alive forever, so ``__del__`` (and hence ``close``) would
        # never run and the thread would leak.
        ready_queue = self._ready_queue
        done_queue = self._done_queue
        done_holder = self._done_holder

        def val_callback(value):
            # Hand ``value`` to the consumer, then wait to be resumed.
            done_queue.put((False, value))
            cmd, val = ready_queue.get()
            if cmd == 'send':
                return val
            elif cmd == 'throw':
                raise val
            else:
                assert False  # pragma: no cover

        def thread_func():
            # Like a just-started generator, the first resume must be
            # send(None); a non-None send is rejected without starting func.
            while True:
                cmd, val = ready_queue.get()
                if cmd == 'send' and val is not None:
                    done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
                    continue
                break
            try:
                if cmd == 'throw':
                    raise val
                ret = func(*args, callback=val_callback, **kwargs)
                raise StopIteration(ret) if ret is not None else StopIteration
            except BaseException as e:
                # Every way of finishing (return, error, thrown-in exception)
                # is reported to the consumer as an exception to re-raise.
                done_holder[0] = True
                done_queue.put((True, e))

        thread = Thread(target=thread_func)
        thread.start()
        # Publish the thread only once it is running. Closing a thread that
        # never started would block forever waiting for its reply.
        self._thread = thread

    def _put(self, *args):
        """Send one command to the worker and return (or raise) its reply."""
        if self._done_holder[0]:
            # func already finished: behave like an exhausted generator.
            raise StopIteration
        self._ready_queue.put(args)
        is_exception, val = self._done_queue.get()
        if is_exception:
            try:
                raise val
            finally:
                # prevent val's traceback containing a reference cycle
                del val
        return val

    def send(self, value):
        """Resume ``func``; its pending callback call returns ``value``."""
        return self._put('send', value)

    def throw(self, exc):
        """Resume ``func``; its pending callback call raises ``exc``."""
        return self._put('throw', exc)

    def close(self):
        """Raise ``GeneratorExit`` inside ``func`` and join the worker thread.

        Raises:
            RuntimeError: if ``func`` swallowed ``GeneratorExit`` and produced
                another value. The worker thread is then left blocked.
            BaseException: any other exception raised by ``func`` while
                handling ``GeneratorExit``; re-raised after the join.
        """
        try:
            self.throw(GeneratorExit)
        except (StopIteration, GeneratorExit):
            self._thread.join()
        except BaseException:
            self._thread.join()
            raise
        else:
            # yielded again, can't clean up the thread
            raise RuntimeError('Task with callback ignored GeneratorExit')

    def __del__(self):
        # __init__ may have failed before the worker thread was started.
        if getattr(self, '_thread', None) is not None:
            self.close()