ai-forever committed on
Commit
2adb577
1 Parent(s): e759a3a

Add i2v system prompt and hard-code max_tokens=128 in giga_generate

Browse files
Files changed (1) hide show
  1. src/gigachat.py +153 -142
src/gigachat.py CHANGED
@@ -1,143 +1,154 @@
1
- import requests
2
- import base64
3
- import uuid
4
- import json
5
- import time
6
- from typing import Dict, Optional, Any
7
- from dotenv import load_dotenv
8
- import os
9
-
10
- # Load environment variables from .env file
11
- load_dotenv()
12
-
13
- AUTH_TOKEN = os.getenv("AUTH_TOKEN")
14
- COOKIE = os.getenv("COOKIE")
15
-
16
- # print(f"AUTH_TOKEN: {AUTH_TOKEN}")
17
- # print(f"COOKIE: {COOKIE}")
18
-
19
- def get_auth_token(timeout: float = 2) -> Dict[str, Any]:
20
- """
21
- Get authentication token.
22
-
23
- Args:
24
- timeout (float): Timeout duration in seconds.
25
-
26
- Returns:
27
- Dict[str, Any]: Dictionary containing the access token and its expiration time.
28
- """
29
- url = "https://beta.saluteai.sberdevices.ru/v1/token"
30
- payload = 'scope=GIGACHAT_API_CORP'
31
- headers = {
32
- 'Content-Type': 'application/x-www-form-urlencoded',
33
- 'Accept': 'application/json',
34
- 'RqUID': str(uuid.uuid4()),
35
- 'Cookie': COOKIE,
36
- 'Authorization': f'Basic {AUTH_TOKEN}'
37
- }
38
- response = requests.post(url, headers=headers, data=payload, timeout=timeout)
39
- response_dict = response.json()
40
- return {
41
- 'access_token': response_dict['tok'],
42
- 'expires_at': response_dict['exp']
43
- }
44
-
45
- def check_auth_token(token_data: Dict[str, Any]) -> bool:
46
- """
47
- Check if the authentication token is valid.
48
-
49
- Args:
50
- token_data (Dict[str, Any]): Dictionary containing token data.
51
-
52
- Returns:
53
- bool: True if the token is valid, False otherwise.
54
- """
55
- return token_data['expires_at'] - time.time() > 5
56
-
57
- token_data: Optional[Dict[str, Any]] = None
58
-
59
- def get_response(
60
- prompt: str,
61
- model: str,
62
- timeout: int = 120,
63
- n: int = 1,
64
- fuse_key_word: Optional[str] = None,
65
- use_giga_censor: bool = False,
66
- max_tokens: int = 512,
67
- ) -> requests.Response:
68
- """
69
- Send a text generation request to the API.
70
-
71
- Args:
72
- prompt (str): The input prompt.
73
- model (str): The model to be used for generation.
74
- timeout (int): Timeout duration in seconds.
75
- n (int): Number of responses.
76
- fuse_key_word (Optional[str]): Additional keyword to include in the prompt.
77
- use_giga_censor (bool): Whether to use profanity filtering.
78
- max_tokens (int): Maximum number of tokens in the response.
79
-
80
- Returns:
81
- requests.Response: API response.
82
- """
83
- global token_data
84
-
85
- url = "https://beta.saluteai.sberdevices.ru/v1/chat/completions"
86
- payload = json.dumps({
87
- "model": model,
88
- "messages": [
89
- {
90
- "role": "user",
91
- "content": ' '.join([fuse_key_word, prompt]) if fuse_key_word else prompt
92
- }
93
- ],
94
- "temperature": 0.87,
95
- "top_p": 0.47,
96
- "n": n,
97
- "stream": False,
98
- "max_tokens": max_tokens,
99
- "repetition_penalty": 1.07,
100
- "profanity_check": use_giga_censor
101
- })
102
-
103
- if token_data is None or not check_auth_token(token_data):
104
- token_data = get_auth_token()
105
-
106
- headers = {
107
- 'Content-Type': 'application/json',
108
- 'Accept': 'application/json',
109
- 'Authorization': f'Bearer {token_data["access_token"]}'
110
- }
111
- response = requests.post(url, headers=headers, data=payload, timeout=timeout)
112
- return response
113
-
114
- def giga_generate(
115
- prompt: str,
116
- model_version: str = "GigaChat-Max",
117
- max_tokens: int = 2048
118
- ) -> str:
119
- """
120
- Generate text using the GigaChat model.
121
-
122
- Args:
123
- prompt (str): The input prompt.
124
- model_version (str): The version of the model to use.
125
- max_tokens (int): Maximum number of tokens in the response.
126
-
127
- Returns:
128
- str: Generated text.
129
- """
130
- response = get_response(
131
- prompt,
132
- model_version,
133
- use_giga_censor=False,
134
- max_tokens=max_tokens,
135
- )
136
- response_dict = response.json()
137
-
138
- if response_dict['choices'][0]['finish_reason'] == 'blacklist':
139
- print('GigaCensor triggered!')
140
- return 'Censored Text'
141
- else:
142
- response_str = response_dict['choices'][0]['message']['content']
 
 
 
 
 
 
 
 
 
 
 
143
  return response_str
 
1
+ import requests
2
+ import base64
3
+ import uuid
4
+ import json
5
+ import time
6
+ from typing import Dict, Optional, Any
7
+ from dotenv import load_dotenv
8
+ import os
9
+ from prompts import giga_system_prompt_i2v
10
+
11
+ # Load environment variables from .env file
12
+ load_dotenv()
13
+
14
+ AUTH_TOKEN = os.getenv("AUTH_TOKEN")
15
+ COOKIE = os.getenv("COOKIE")
16
+
17
+ # print(f"AUTH_TOKEN: {AUTH_TOKEN}")
18
+ # print(f"COOKIE: {COOKIE}")
19
+
20
+ def get_auth_token(timeout: float = 2) -> Dict[str, Any]:
21
+ """
22
+ Get authentication token.
23
+
24
+ Args:
25
+ timeout (float): Timeout duration in seconds.
26
+
27
+ Returns:
28
+ Dict[str, Any]: Dictionary containing the access token and its expiration time.
29
+ """
30
+ url = "https://beta.saluteai.sberdevices.ru/v1/token"
31
+ payload = 'scope=GIGACHAT_API_CORP'
32
+ headers = {
33
+ 'Content-Type': 'application/x-www-form-urlencoded',
34
+ 'Accept': 'application/json',
35
+ 'RqUID': str(uuid.uuid4()),
36
+ 'Cookie': COOKIE,
37
+ 'Authorization': f'Basic {AUTH_TOKEN}'
38
+ }
39
+ response = requests.post(url, headers=headers, data=payload, timeout=timeout)
40
+ response_dict = response.json()
41
+ return {
42
+ 'access_token': response_dict['tok'],
43
+ 'expires_at': response_dict['exp']
44
+ }
45
+
46
+ def check_auth_token(token_data: Dict[str, Any]) -> bool:
47
+ """
48
+ Check if the authentication token is valid.
49
+
50
+ Args:
51
+ token_data (Dict[str, Any]): Dictionary containing token data.
52
+
53
+ Returns:
54
+ bool: True if the token is valid, False otherwise.
55
+ """
56
+ return token_data['expires_at'] - time.time() > 5
57
+
58
+ token_data: Optional[Dict[str, Any]] = None
59
+
60
+ def get_response(
61
+ prompt: str,
62
+ model: str,
63
+ timeout: int = 120,
64
+ n: int = 1,
65
+ fuse_key_word: Optional[str] = None,
66
+ use_giga_censor: bool = False,
67
+ max_tokens: int = 512,
68
+ ) -> requests.Response:
69
+ """
70
+ Send a text generation request to the API.
71
+
72
+ Args:
73
+ prompt (str): The input prompt.
74
+ model (str): The model to be used for generation.
75
+ timeout (int): Timeout duration in seconds.
76
+ n (int): Number of responses.
77
+ fuse_key_word (Optional[str]): Additional keyword to include in the prompt.
78
+ use_giga_censor (bool): Whether to use profanity filtering.
79
+ max_tokens (int): Maximum number of tokens in the response.
80
+
81
+ Returns:
82
+ requests.Response: API response.
83
+ """
84
+ global token_data
85
+
86
+ url = "https://beta.saluteai.sberdevices.ru/v1/chat/completions"
87
+ payload = json.dumps({
88
+ "model": model,
89
+ # "messages": [
90
+ # {
91
+ # "role": "user",
92
+ # "content": ' '.join([fuse_key_word, prompt]) if fuse_key_word else prompt
93
+ # }
94
+ # ],
95
+ "messages": [
96
+ {
97
+ "role": "system",
98
+ "content": giga_system_prompt_i2v
99
+ },
100
+ {
101
+ "role": "user",
102
+ "content": prompt
103
+ }
104
+ ],
105
+ "temperature": 0.87,
106
+ "top_p": 0.47,
107
+ "n": n,
108
+ "stream": False,
109
+ "max_tokens": max_tokens,
110
+ "repetition_penalty": 1.07,
111
+ "profanity_check": use_giga_censor
112
+ })
113
+
114
+ if token_data is None or not check_auth_token(token_data):
115
+ token_data = get_auth_token()
116
+
117
+ headers = {
118
+ 'Content-Type': 'application/json',
119
+ 'Accept': 'application/json',
120
+ 'Authorization': f'Bearer {token_data["access_token"]}'
121
+ }
122
+ response = requests.post(url, headers=headers, data=payload, timeout=timeout)
123
+ return response
124
+
125
+ def giga_generate(
126
+ prompt: str,
127
+ model_version: str = "GigaChat-Max",
128
+ max_tokens: int = 2048
129
+ ) -> str:
130
+ """
131
+ Generate text using the GigaChat model.
132
+
133
+ Args:
134
+ prompt (str): The input prompt.
135
+ model_version (str): The version of the model to use.
136
+ max_tokens (int): Maximum number of tokens in the response.
137
+
138
+ Returns:
139
+ str: Generated text.
140
+ """
141
+ response = get_response(
142
+ prompt,
143
+ model_version,
144
+ use_giga_censor=False,
145
+ max_tokens=128,
146
+ )
147
+ response_dict = response.json()
148
+
149
+ if response_dict['choices'][0]['finish_reason'] == 'blacklist':
150
+ print('GigaCensor triggered!')
151
+ return 'Censored Text'
152
+ else:
153
+ response_str = response_dict['choices'][0]['message']['content']
154
  return response_str