muryshev committed on
Commit 8b45905
1 Parent(s): 8799be8

Update app.py

Files changed (1)
  1. app.py +16 -14
app.py CHANGED
@@ -16,6 +16,8 @@ ROLE_TOKENS = {
     "system": SYSTEM_TOKEN
 }
 
+CONTEXT_SIZE = 4096
+
 # Create a lock object
 lock = threading.Lock()
 
@@ -38,16 +40,16 @@ model_name = "ggml-model-q4_1.gguf"
 snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)
 
 model = Llama(
-    model_path=model_name,
-    n_ctx=4096,
-    n_parts=1,
-    #n_batch=100,
-    logits_all=True,
-    #n_threads=12,
-    verbose=True,
-    n_gpu_layers=35,
-    n_gqa=8 #must be set for 70b models
-)
+    model_path=model_name,
+    n_ctx=CONTEXT_SIZE,
+    n_parts=1,
+    #n_batch=100,
+    logits_all=True,
+    #n_threads=12,
+    verbose=True,
+    n_gpu_layers=35,
+    n_gqa=8 #must be set for 70b models
+)
 
 
 def get_message_tokens(model, role, content):
@@ -189,14 +191,13 @@ def generate_response():
 
     tokens = []
 
-    for message in messages:#[:-1]:
+    for message in messages:
         if message.get("from") == "assistant":
             message_tokens = get_message_tokens(model=model, role="bot", content=message.get("content", ""))
         else:
             message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
 
         tokens.extend(message_tokens)
-        #LINEBREAK_TOKEN)
 
     #app.logger.info('model.eval start')
     #model.eval(tokens)
@@ -210,12 +211,13 @@ def generate_response():
 
     tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
 
+
     app.logger.info('Prompt:')
-    app.logger.info(model.detokenize(tokens).decode("utf-8", errors="ignore"))
+    app.logger.info(model.detokenize(tokens[:CONTEXT_SIZE]).decode("utf-8", errors="ignore"))
 
     app.logger.info('Generate started')
     generator = model.generate(
-        tokens,
+        tokens[:CONTEXT_SIZE],
         top_k=top_k,
         top_p=top_p,
         temp=temperature,
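Taken together, the change clamps the prompt to the model's context window before it is logged and passed to model.generate. A minimal sketch of the pattern as a standalone helper (clamp_to_context is a hypothetical name; CONTEXT_SIZE and the slicing come straight from the diff above):

CONTEXT_SIZE = 4096  # must match the n_ctx the Llama model was constructed with

def clamp_to_context(tokens, context_size=CONTEXT_SIZE):
    # Keep at most context_size tokens so model.generate never receives
    # more tokens than the model's context window (n_ctx) can hold.
    # Note: tokens[:context_size] keeps the OLDEST tokens and drops the
    # newest once the history overflows.
    return tokens[:context_size]

Because the slice keeps the head of the list, a history longer than CONTEXT_SIZE loses its most recent turns, including the [BOS, BOT_TOKEN, LINEBREAK_TOKEN] suffix appended just above; slicing from the tail (tokens[-context_size:]) is the usual alternative when the newest context matters most.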
 
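For context, the prompt assembly in generate_response() reduces to roughly the following (a condensed sketch, not a verbatim excerpt; get_message_tokens, BOT_TOKEN, and LINEBREAK_TOKEN are defined elsewhere in app.py and assumed in scope, and the message shape is inferred from the message.get(...) calls above):

def build_prompt_tokens(model, messages):
    tokens = []
    for message in messages:
        # Assistant turns are re-tokenized with the "bot" role;
        # everything else is treated as the user.
        role = "bot" if message.get("from") == "assistant" else "user"
        tokens.extend(get_message_tokens(model=model, role=role,
                                         content=message.get("content", "")))
    # Cue the model to answer next: BOS token + bot role marker + line break.
    tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
    return tokens[:CONTEXT_SIZE]  # the clamp introduced by this commit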