File size: 7,818 Bytes
444f09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# encoding: utf-8
# @Time   : 2023/12/25
# @Author : Spike
# @Descr   :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image, get_pictures_list

# Proxy settings and request timeout used for outgoing HTTP calls in this
# module; both are looked up once, at import time, through get_conf.
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")

"""
========================================================================
第五部分 一些文件处理方法
files_filter_handler 根据type过滤文件
input_encode_handler 提取input中的文件,并解析
file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
html_view_blank 超链接
html_local_file 本地文件取相对路径
to_markdown_tabs 文件list 转换为 md tab
"""


def files_filter_handler(file_list):
    """Return the existing image files referenced by ``file_list``.

    Every entry is converted to ``str`` and all ``"file="`` markers are removed
    from it. The cleaned path is kept when it exists on disk and the text after
    the last ``"."`` of its base name is one of the known image extensions
    (the comparison is case-sensitive).

    Args:
        file_list: Iterable of path-like values, optionally carrying a
            ``file=`` marker.

    Returns:
        list: The cleaned paths that passed both checks, in input order.
    """
    image_suffixes = [
        "png",
        "jpg",
        "jpeg",
        "bmp",
        "svg",
        "webp",
        "ico",
        "tif",
        "tiff",
        "raw",
        "eps",
    ]
    # Strip the "file=" marker lazily so each entry is cleaned exactly once.
    cleaned_paths = (str(entry).replace("file=", "") for entry in file_list)
    return [
        path
        for path in cleaned_paths
        if os.path.exists(path)
        and str(os.path.basename(path)).split(".")[-1] in image_suffixes
    ]


def input_encode_handler(inputs, llm_kwargs):
    """Encode the pictures of the most recent upload for an image request.

    The upload location is read from ``llm_kwargs["most_recent_uploaded"]["path"]``.
    The pictures found there by ``get_pictures_list`` are encoded with
    ``encode_image`` (both come from ``toolbox``; presumably base64 output --
    confirm against toolbox).

    Args:
        inputs: The user's text input; returned unchanged.
        llm_kwargs: Request settings. A missing or empty
            ``most_recent_uploaded`` entry, or one without a ``path``, simply
            produces no images.

    Returns:
        tuple: ``(inputs, md_encode)`` where ``md_encode`` is a list of
        ``{"data": <encoded image>, "type": <lower-case extension>}`` dicts,
        with ``jpg`` normalised to ``jpeg``.
    """
    recent_upload = llm_kwargs.get("most_recent_uploaded") or {}
    upload_path = recent_upload.get("path")
    # Without an upload path there is nothing to encode. Previously the image
    # list was left unbound in this case and the loop below crashed.
    image_paths = get_pictures_list(upload_path) if upload_path else []
    md_encode = []
    for md_path in image_paths:
        # Lower-case the extension so "X.JPG" maps to "jpeg" like "x.jpg" does.
        type_ = os.path.splitext(md_path)[1].replace(".", "").lower()
        type_ = "jpeg" if type_ == "jpg" else type_
        md_encode.append({"data": encode_image(md_path), "type": type_})
    return inputs, md_encode


def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    """Render the entries of ``file_list`` that look like files as markup.

    For every entry, in order:

    * values that are not ``str``/``bytes``/``os.PathLike`` (numbers, ``None``,
      ...) are passed through unchanged, because they cannot name a file;
    * if the text after the last ``"."`` of the base name is in ``filter_``,
      the entry is rendered with ``html_local_img(file, md=md_type)`` (the
      file's existence is not checked here);
    * otherwise, if the path exists, it is rendered with ``link_mtime_to_md``;
    * anything else is kept as is.

    Args:
        file_list: Iterable of candidate paths or arbitrary cell values.
        filter_: Extensions (without dot) treated as images. A falsy value
            selects the built-in image extension list.
        md_type: Forwarded to ``html_local_img`` as ``md``; true yields
            markdown image syntax, false yields an HTML ``<img>`` wrapper.

    Returns:
        list: One output item per input entry.
    """
    if not filter_:
        filter_ = [
            "png",
            "jpg",
            "jpeg",
            "bmp",
            "svg",
            "webp",
            "ico",
            "tif",
            "tiff",
            "raw",
            "eps",
        ]
    new_list = []
    for file in file_list:
        # os.path.basename raises TypeError on non-path values, and
        # os.path.exists would treat an int as a file descriptor, so keep
        # such cells untouched.
        if not isinstance(file, (str, bytes, os.PathLike)):
            new_list.append(file)
            continue
        if str(os.path.basename(file)).split(".")[-1] in filter_:
            new_list.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            new_list.append(link_mtime_to_md(file))
        else:
            new_list.append(file)
    return new_list


def link_mtime_to_md(file):
    """Build a markdown link to ``file`` with its mtime as a query string.

    The link text is the file's base name and the target is whatever
    ``html_local_file`` returns for it. The modification time from
    ``os.path.getmtime`` is appended after ``?`` so that the URL changes
    whenever the file does (per the module banner, this avoids downloading a
    cached copy).

    Args:
        file: Path of an existing file; ``os.path.getmtime`` raises otherwise.

    Returns:
        str: ``[<basename>](<target>?<mtime>)``.
    """
    target = html_local_file(file)
    label = os.path.basename(file)
    modified_at = os.path.getmtime(file)
    return f"[{label}]({target}?{modified_at})"


def html_local_file(file):
    """Turn an existing local path into a ``file=<path>`` reference.

    If ``str(file)`` exists on disk, occurrences of this module's directory in
    the path are replaced by ``"."`` and the result is prefixed with
    ``"file="``. Values that do not name an existing path are returned
    unchanged.

    Args:
        file: A path given as ``str`` or any object whose ``str()`` is a path
            (for example ``pathlib.Path``).

    Returns:
        The ``"file=..."`` string for existing paths, otherwise ``file``
        itself.
    """
    path_text = str(file)
    if not os.path.exists(path_text):
        return file
    base_path = os.path.dirname(__file__)  # project directory
    # Operate on the string form: the existence check above uses str(file),
    # and calling .replace() on a Path object would not do text substitution.
    # Skip an empty base path, since str.replace("", ".") would insert a dot
    # between every character.
    if base_path:
        path_text = path_text.replace(base_path, ".")
    return f"file={path_text}"


def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
    """Render an image reference as markdown or as an HTML snippet.

    The path is first passed through ``html_local_file``.

    Args:
        __file: Image path or reference.
        layout: Value of the wrapping ``<div>``'s ``align`` attribute
            (HTML output only).
        max_width: Optional CSS ``max-width`` value (HTML output only).
        max_height: Optional CSS ``max-height`` value (HTML output only).
        md: When true (the default) return markdown image syntax, ignoring
            the layout and size options; otherwise return the HTML form.

    Returns:
        str: ``![src](src)`` for markdown, or
        ``<div align=...><img src=... style=...></div>`` for HTML.
    """
    __file = html_local_file(__file)
    if md:
        return f"![{__file}]({__file})"

    # Only the size limits that were actually supplied end up in the style.
    css_rules = []
    if max_width is not None:
        css_rules.append(f"max-width: {max_width};")
    if max_height is not None:
        css_rules.append(f"max-height: {max_height};")
    style = "".join(css_rules)
    return f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'


def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
    """Render a markdown table.

    Each inner list of ``tabs`` supplies the values of one table column: row
    ``i`` of the output takes element ``i`` from every inner list, and shorter
    lists are padded with empty cells. Every row is passed through
    ``file_manifest_filter_html`` so that file paths become links or images.

    Args:
        head: Header cells.
        tabs: Table values, e.g. ``[[col1...], [col2...], [col3...]]``.
        alignment: Alignment marker used for every column: ``:---`` left,
            ``:---:`` centered, ``---:`` right.
        column: When true, ``tabs`` is transposed with ``zip`` first (which
            truncates to the shortest inner list); when false (default) it is
            used as given.

    Returns:
        str: The markdown table: header line, alignment line and one line per
        row, each terminated by a newline. An empty ``tabs`` yields just the
        two header lines.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Longest column decides the row count; default=0 keeps an empty table
    # body from raising ValueError in max().
    max_len = max((len(col) for col in transposed_tabs), default=0)

    tab_format = "| %s "
    lines = [
        "".join([tab_format % title for title in head]) + "|\n",
        "".join([tab_format % alignment for _ in head]) + "|\n",
    ]

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        lines.append("".join([tab_format % cell for cell in row_data]) + "|\n")

    return "".join(lines)


class GoogleChatInit:
    """Builds Gemini chat payloads and posts them to the streaming endpoint."""

    # Endpoint template: "%m" is replaced by the model name and "%k" by the
    # API key each time a payload is generated.
    URL_TEMPLATE = "https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k"

    def __init__(self):
        # Holds the template until generate_message_payload resolves it.
        self.url_gemini = self.URL_TEMPLATE

    def _build_url(self, model, api_key):
        """Return the endpoint URL for ``model`` authenticated with ``api_key``."""
        return self.URL_TEMPLATE.replace("%m", model).replace("%k", api_key)

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """Send the conversation to the endpoint and stream the reply.

        Args:
            inputs: Current user input.
            llm_kwargs: Request settings; see ``generate_message_payload``.
            history: Flat list alternating user and model messages.
            system_prompt: Forwarded to ``generate_message_payload``.

        Returns:
            The iterator returned by ``response.iter_lines()``. The HTTP
            status is not inspected here.
        """
        headers, payload = self.generate_message_payload(
            inputs, llm_kwargs, history, system_prompt
        )
        response = requests.post(
            url=self.url_gemini,
            headers=headers,
            data=json.dumps(payload),
            stream=True,
            proxies=proxies,
            timeout=TIMEOUT_SECONDS,
        )
        return response.iter_lines()

    def __conversation_user(self, user_input, llm_kwargs):
        """Build one ``user`` message, attaching images for vision models.

        Images are only collected (via ``input_encode_handler``) when the
        current URL contains ``"vision"``.
        """
        what_i_have_asked = {"role": "user", "parts": []}
        if "vision" not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
        what_i_have_asked["parts"].append({"text": input_})
        for data in encode_img:
            what_i_have_asked["parts"].append(
                {
                    "inline_data": {
                        "mime_type": f"image/{data['type']}",
                        "data": data["data"],
                    }
                }
            )
        return what_i_have_asked

    def __conversation_history(self, history, llm_kwargs):
        """Convert ``history`` into alternating ``user``/``model`` messages.

        ``history`` is read in (question, answer) pairs; a trailing unpaired
        item is ignored.
        """
        messages = []
        conversation_cnt = len(history) // 2
        for index in range(0, 2 * conversation_cnt, 2):
            what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
            what_gpt_answer = {
                "role": "model",
                "parts": [{"text": history[index + 1]}],
            }
            messages.append(what_i_have_asked)
            messages.append(what_gpt_answer)
        return messages

    def generate_message_payload(
        self, inputs, llm_kwargs, history, system_prompt
    ) -> Tuple[Dict, Dict]:
        """Resolve the endpoint URL and build the request header and payload.

        Side effect: ``self.url_gemini`` is set to the URL for
        ``llm_kwargs["llm_model"]`` and the configured ``GEMINI_API_KEY``.
        The URL is rebuilt from the template on every call, so one instance
        can be reused with a different model or key.

        Args:
            inputs: Current user input.
            llm_kwargs: Must contain ``llm_model``; ``stop``, ``temperature``
                and ``top_p`` are optional.
            history: Flat list alternating user and model messages; only sent
                when the URL does not contain ``"vision"``.
            system_prompt: Currently unused (see the note in the body).

        Returns:
            tuple: ``(header, payload)`` dictionaries for the POST request.
        """
        messages = [
            # Gemini does not allow an even number of conversation turns, so a
            # system-prompt exchange like the one below cannot be injected;
            # revisit if the API adds support.
            # {"role": "system", "parts": [{"text": system_prompt}]},
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        # Always start from the template: substituting in place would consume
        # the "%m"/"%k" placeholders and pin the instance to its first model.
        self.url_gemini = self._build_url(
            llm_kwargs["llm_model"], get_conf("GEMINI_API_KEY")
        )
        header = {"Content-Type": "application/json"}
        if "vision" not in self.url_gemini:  # history is only sent for non-vision models
            messages.extend(self.__conversation_history(history, llm_kwargs))
        messages.append(self.__conversation_user(inputs, llm_kwargs))  # current user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                # "maxOutputTokens": 800,
                # NOTE(review): with no "stop" setting this evaluates to [""];
                # confirm the API accepts an empty stop sequence.
                "stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
                "temperature": llm_kwargs.get("temperature", 1),
                "topP": llm_kwargs.get("top_p", 0.8),
                "topK": 10,
            },
        }
        return header, payload


if __name__ == "__main__":
    # Manual smoke test: only constructs the client. The commented-out calls
    # below are kept as usage sketches and are not executed.
    google = GoogleChatInit()
    # print(google.generate_message_payload('hello', {},  ['123123', '3123123'], ''))
    # NOTE(review): no input_encode_handle method is defined on the class in this file;
    # the module-level input_encode_handler(inputs, llm_kwargs) may be what was meant.
    # google.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')