File size: 4,266 Bytes
16c358c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import importlib
import traceback
from functools import partial

import gradio as gr

import extensions
import modules.shared as shared

# Maps extension name -> [loaded flag, position in shared.args.extensions].
# Written by load_extensions() and read by iterator().
state = {}
# Names that load_extensions() is willing to import. Nothing in this file
# fills it in — presumably populated by another module; TODO confirm.
available_extensions = []
# Extension script modules whose setup() hook has already been invoked,
# so that load_extensions() never runs the same hook twice.
setup_called = set()


def load_extensions():
    """Import and initialise every extension requested on the command line.

    Walks ``shared.args.extensions`` in order. Names that are not listed in
    ``available_extensions`` are skipped without any message. For every other
    name the module ``extensions.<name>.script`` is imported, its optional
    ``setup()`` hook is invoked (at most once per module, tracked through
    ``setup_called``), and ``state[name]`` is set to ``[True, position]``
    where ``position`` is the index of the name in the command-line list.

    A failure while importing or setting up one extension prints ``Fail.``
    followed by the traceback and leaves ``state[name]`` untouched; the
    remaining extensions are still processed. Progress messages are
    suppressed for the extension named ``'api'`` (the traceback is not).
    """
    for i, name in enumerate(shared.args.extensions):
        if name not in available_extensions:
            continue

        verbose = name != 'api'
        if verbose:
            print(f'Loading the extension "{name}"... ', end='')
        try:
            # import_module performs the same import as the previous
            # exec()-built statement, without executing generated source.
            # It also binds the submodules on the ``extensions`` package,
            # which iterator() relies on.
            extension = importlib.import_module(f"extensions.{name}.script")
            if extension not in setup_called and hasattr(extension, "setup"):
                # Recorded before the call, so a setup() that raises is not
                # retried on a later load.
                setup_called.add(extension)
                extension.setup()
            state[name] = [True, i]
            if verbose:
                print('Ok.')
        except Exception:
            # Best effort: one broken extension must not prevent the others
            # from loading. KeyboardInterrupt/SystemExit are deliberately
            # allowed to propagate.
            if verbose:
                print('Fail.')
            traceback.print_exc()


# This iterator returns the extensions in the order specified in the command-line
def iterator():
    """Yield ``(script_module, name)`` for each loaded extension.

    Extensions are visited by ascending position stored in ``state[name][1]``;
    entries whose loaded flag ``state[name][0]`` is falsy are skipped.
    """
    ordered_names = sorted(state, key=lambda ext_name: state[ext_name][1])
    for ext_name in ordered_names:
        is_loaded = state[ext_name][0]
        if not is_loaded:
            continue
        yield getattr(extensions, ext_name).script, ext_name


# Extension functions that map string -> string
def _apply_string_extensions(function_name, text):
    """Pipe ``text`` through the hook ``function_name`` of each loaded extension.

    Extensions that do not define the hook are skipped. Each hook receives
    the output of the previous one, in iterator() order, and the final value
    is returned.
    """
    for module, _name in iterator():
        if not hasattr(module, function_name):
            continue
        hook = getattr(module, function_name)
        text = hook(text)
    return text


# Input hijack of extensions
def _apply_input_hijack(text, visible_text):
    """Let extensions with an armed ``input_hijack`` replace the user input.

    For each loaded extension whose ``input_hijack['state']`` is truthy, the
    flag is reset to ``False`` and ``input_hijack['value']`` is applied: a
    callable is invoked with ``(text, visible_text)`` and must return the new
    pair, anything else is unpacked directly as the new pair.

    Returns the resulting ``(text, visible_text)`` tuple.
    """
    for module, _name in iterator():
        if not hasattr(module, 'input_hijack'):
            continue
        hijack = module.input_hijack
        if not hijack['state']:
            continue

        # A hijack fires once: disarm it before applying the replacement.
        hijack['state'] = False
        replacement = hijack['value']
        if callable(replacement):
            text, visible_text = replacement(text, visible_text)
        else:
            text, visible_text = replacement
    return text, visible_text


# custom_generate_chat_prompt handling
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Delegate prompt building to the first extension that overrides it.

    Scans the loaded extensions in iterator() order and calls the first
    ``custom_generate_chat_prompt`` found with ``(text, state, **kwargs)``,
    returning its result. Returns ``None`` when no loaded extension defines
    that hook.

    Note that the ``state`` parameter shadows the module-level ``state``
    dictionary inside this function; it is passed through to the hook as-is.
    """
    for extension, _ in iterator():
        if hasattr(extension, 'custom_generate_chat_prompt'):
            # First match wins, so there is no reason to keep scanning.
            return extension.custom_generate_chat_prompt(text, state, **kwargs)
    return None


# Extension functions that override the default tokenizer output
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Run the tokenizer hook ``function_name`` of each loaded extension.

    Every extension defining the hook is called with
    ``(state, prompt, input_ids, input_embeds)`` and must return a new
    ``(prompt, input_ids, input_embeds)`` triple, which is fed to the next
    extension. The final triple is returned.
    """
    for module, _name in iterator():
        if not hasattr(module, function_name):
            continue
        hook = getattr(module, function_name)
        prompt, input_ids, input_embeds = hook(state, prompt, input_ids, input_embeds)
    return prompt, input_ids, input_embeds


# Dispatch table used by apply_extensions(): maps an extension type string to
# the function that runs the matching hook across the loaded extensions.
EXTENSION_MAP = {
    # String -> string hooks; partial() pre-binds the hook name to look up.
    "input": partial(_apply_string_extensions, "input_modifier"),
    "output": partial(_apply_string_extensions, "output_modifier"),
    "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
    # Hook operating on (state, prompt, input_ids, input_embeds).
    "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
    # Handlers with their own calling conventions.
    "input_hijack": _apply_input_hijack,
    "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
}


def apply_extensions(typ, *args, **kwargs):
    """Run the extension handler registered for ``typ`` in ``EXTENSION_MAP``.

    All remaining positional and keyword arguments are forwarded to the
    handler and its return value is returned unchanged.

    Raises:
        ValueError: if ``typ`` is not a key of ``EXTENSION_MAP``.
    """
    handler = EXTENSION_MAP.get(typ)
    if handler is None:
        raise ValueError(f"Invalid extension type {typ}")
    return handler(*args, **kwargs)


def create_extensions_block():
    """Apply saved parameter overrides and build the extensions' UI section.

    First, for every loaded extension exposing a ``params`` mapping, each
    parameter whose key ``"<extension name>-<param>"`` is present in
    ``shared.settings`` is overwritten with the stored value.

    Then, if at least one loaded extension defines ``ui()``, a Gradio column
    with ``elem_id="extensions"`` is created and, for each such extension, a
    ``### <name>`` heading is emitted followed by a call to its ``ui()``.
    Extensions without a ``ui()`` get no heading, so the column never
    contains empty sections. Must be called inside an active Gradio layout
    context, since it creates components as a side effect.
    """
    # Updating the default values
    for extension, name in iterator():
        if hasattr(extension, 'params'):
            for param in extension.params:
                _id = f"{name}-{param}"
                if _id in shared.settings:
                    extension.params[param] = shared.settings[_id]

    # Only extensions that actually provide a UI are rendered.
    with_ui = [(extension, name) for extension, name in iterator() if hasattr(extension, "ui")]
    if not with_ui:
        return

    # Creating the extension ui elements
    with gr.Column(elem_id="extensions"):
        for extension, name in with_ui:
            gr.Markdown(f"\n### {name}")
            extension.ui()