ai-forever committed on
Commit
e759a3a
1 Parent(s): 592fbff

add version of gigachat.py from slides_generator

Browse files
Files changed (1) hide show
  1. src/gigachat.py +143 -0
src/gigachat.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import json
import os
import time
import uuid
from typing import Dict, Optional, Any

import requests
from dotenv import load_dotenv
9
+
10
# Load environment variables from a local .env file so the credentials
# below can be supplied without exporting them in the shell.
load_dotenv()

# Credentials read from the environment; each is None when the variable is
# unset. AUTH_TOKEN is sent as the HTTP Basic credential and COOKIE as the
# Cookie header when requesting an access token.
AUTH_TOKEN = os.getenv("AUTH_TOKEN")
COOKIE = os.getenv("COOKIE")

# NOTE: never print or log AUTH_TOKEN / COOKIE; they are secrets.
18
+
19
def get_auth_token(timeout: float = 2) -> Dict[str, Any]:
    """
    Get authentication token.

    Sends a form-encoded POST with the ``GIGACHAT_API_CORP`` scope to the
    token endpoint, authenticating with the module-level ``AUTH_TOKEN``
    (HTTP Basic) and ``COOKIE`` values.

    Args:
        timeout (float): Timeout duration in seconds.

    Returns:
        Dict[str, Any]: Dictionary containing the access token
            (``access_token``) and its expiration time (``expires_at``).
            NOTE(review): ``expires_at`` is compared against ``time.time()``
            by ``check_auth_token``, so it is presumably Unix epoch
            seconds -- confirm against the API.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        KeyError: If the JSON body lacks the ``tok`` or ``exp`` field.
    """
    url = "https://beta.saluteai.sberdevices.ru/v1/token"
    payload = 'scope=GIGACHAT_API_CORP'
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
        # Fresh request id for every token request.
        'RqUID': str(uuid.uuid4()),
        'Cookie': COOKIE,
        'Authorization': f'Basic {AUTH_TOKEN}'
    }
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    # Fail with the real HTTP error instead of a misleading KeyError when
    # the credentials are rejected or the service is unavailable.
    response.raise_for_status()
    response_dict = response.json()
    return {
        'access_token': response_dict['tok'],
        'expires_at': response_dict['exp']
    }
44
+
45
def check_auth_token(token_data: Dict[str, Any], min_remaining: float = 5) -> bool:
    """
    Check if the authentication token is valid.

    A token counts as valid only while more than ``min_remaining`` seconds
    are left before ``token_data['expires_at']``, so a request started now
    does not race the expiry.

    Args:
        token_data (Dict[str, Any]): Dictionary containing token data; must
            have an ``expires_at`` entry comparable with ``time.time()``.
        min_remaining (float): Safety margin in seconds that must remain
            before expiry. Defaults to 5.

    Returns:
        bool: True if the token is valid, False otherwise.

    Raises:
        KeyError: If ``token_data`` has no ``expires_at`` entry.
    """
    seconds_left = token_data['expires_at'] - time.time()
    return seconds_left > min_remaining
56
+
57
# Module-level cache for the most recently fetched access token; stays None
# until get_response() obtains one, and is replaced once it is about to expire.
token_data: Optional[Dict[str, Any]] = None
58
+
59
def get_response(
    prompt: str,
    model: str,
    timeout: int = 120,
    n: int = 1,
    fuse_key_word: Optional[str] = None,
    use_giga_censor: bool = False,
    max_tokens: int = 512,
) -> requests.Response:
    """
    Send a text generation request to the API.

    The access token is cached in the module-level ``token_data`` and is
    fetched again via ``get_auth_token()`` when it is missing or
    ``check_auth_token()`` reports it as no longer valid.

    Args:
        prompt (str): The input prompt.
        model (str): The model to be used for generation.
        timeout (int): Timeout duration in seconds for the completion request.
        n (int): Number of responses.
        fuse_key_word (Optional[str]): Optional keyword that is prepended to
            the prompt, separated by a single space.
        use_giga_censor (bool): Whether to use profanity filtering.
        max_tokens (int): Maximum number of tokens in the response.

    Returns:
        requests.Response: Raw API response. The HTTP status is not checked
            here; that is left to the caller.
    """
    global token_data

    url = "https://beta.saluteai.sberdevices.ru/v1/chat/completions"
    content = ' '.join([fuse_key_word, prompt]) if fuse_key_word else prompt
    payload = json.dumps({
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ],
        "temperature": 0.87,
        "top_p": 0.47,
        "n": n,
        "stream": False,
        "max_tokens": max_tokens,
        "repetition_penalty": 1.07,
        "profanity_check": use_giga_censor
    })

    # Reuse the cached token while it is still valid; otherwise fetch a new one.
    if token_data is None or not check_auth_token(token_data):
        token_data = get_auth_token()

    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {token_data["access_token"]}'
    }
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    return response
113
+
114
def giga_generate(
    prompt: str,
    model_version: str = "GigaChat-Max",
    max_tokens: int = 2048
) -> str:
    """
    Generate text using the GigaChat model.

    Calls ``get_response`` with profanity filtering disabled and returns the
    content of the first choice.

    Args:
        prompt (str): The input prompt.
        model_version (str): The version of the model to use.
        max_tokens (int): Maximum number of tokens in the response.

    Returns:
        str: Generated text, or the literal ``'Censored Text'`` when the
            first choice finished with reason ``'blacklist'``.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
    """
    response = get_response(
        prompt,
        model_version,
        use_giga_censor=False,
        max_tokens=max_tokens,
    )
    # Surface API failures as HTTP errors rather than as a KeyError on the
    # missing 'choices' field of an error body.
    response.raise_for_status()
    first_choice = response.json()['choices'][0]

    if first_choice['finish_reason'] == 'blacklist':
        print('GigaCensor triggered!')
        return 'Censored Text'

    return first_choice['message']['content']