muryshev committed
Commit c04a60f
Parent: 5de4fc8

Create app.py

Files changed (1): app.py (+107, -0)
app.py ADDED
from flask import Flask, request, jsonify
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

app = Flask(__name__)

MODEL_NAME = "IlyaGusev/saiga2_70b_lora"
DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
# "You are Saiga, a Russian-language automated assistant. You talk with people and help them."
DEFAULT_SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."


class Conversation:
    def __init__(
        self,
        message_template=DEFAULT_MESSAGE_TEMPLATE,
        system_prompt=DEFAULT_SYSTEM_PROMPT,
        start_token_id=1,
        bot_token_id=9225
    ):
        self.message_template = message_template
        self.start_token_id = start_token_id
        self.bot_token_id = bot_token_id
        self.messages = [{
            "role": "system",
            "content": system_prompt
        }]

    def get_start_token_id(self):
        return self.start_token_id

    def get_bot_token_id(self):
        return self.bot_token_id

    def add_user_message(self, message):
        self.messages.append({
            "role": "user",
            "content": message
        })

    def add_bot_message(self, message):
        self.messages.append({
            "role": "bot",
            "content": message
        })

    def get_prompt(self, tokenizer):
        final_text = ""
        for message in self.messages:
            message_text = self.message_template.format(**message)
            final_text += message_text
        # Append "<s>bot" so the model continues the dialogue as the assistant
        final_text += tokenizer.decode([self.start_token_id, self.bot_token_id])
        return final_text.strip()


def generate(model, tokenizer, prompt, generation_config):
    data = tokenizer(prompt, return_tensors="pt")
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(
        **data,
        generation_config=generation_config
    )[0]
    # Drop the prompt tokens; keep only the newly generated completion
    output_ids = output_ids[len(data["input_ids"][0]):]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output.strip()


config = PeftConfig.from_pretrained(MODEL_NAME)

# Use GPU if available, else fall back to CPU
# (note: load_in_8bit requires the bitsandbytes package and a CUDA device)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map=device
)
model = PeftModel.from_pretrained(
    model,
    MODEL_NAME,
    torch_dtype=torch.float16
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)


@app.route('/run_inference', methods=['POST'])
def run_inference():
    try:
        data = request.json
        inputs = data.get('inputs', [])

        conversation = Conversation()
        outputs = []

        for inp in inputs:
            conversation.add_user_message(inp)
            prompt = conversation.get_prompt(tokenizer)
            output = generate(model, tokenizer, prompt, generation_config)
            # Record the reply so later turns in this request see the full dialogue
            conversation.add_bot_message(output)
            outputs.append({'input': inp, 'output': output})

        return jsonify(outputs)

    except Exception as e:
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    app.run(port=7860)
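
For reference, a minimal client sketch (illustrative, not part of the commit): it assumes the server above is running locally on port 7860 and that the `requests` library is installed. Each string in "inputs" becomes one user turn of a single conversation, and the endpoint returns one {"input": ..., "output": ...} pair per turn.

import requests

# "Hi! Who are you?" — the model is tuned for Russian, so the example input is Russian
payload = {"inputs": ["Привет! Кто ты?"]}

response = requests.post("http://127.0.0.1:7860/run_inference", json=payload)
response.raise_for_status()
print(response.json())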