SREDWise committed on
Commit 8a17d7c
1 Parent(s): fa99fc6

Create app.py file

Files changed (1)
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+from typing import Dict
+
+model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+
+# Initialize the model and tokenizer with GPU settings
+def load_model():
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    # Mistral's tokenizer ships without a pad token; reuse EOS so padding works
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+    )
+    model.eval()
+    return model, tokenizer
+
+# Load the model and tokenizer once at import time
+model, tokenizer = load_model()
+
+def generate(prompt: str,
+             max_new_tokens: int = 500,
+             temperature: float = 0.7,
+             top_p: float = 0.95,
+             top_k: int = 50) -> Dict:
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+
+    # Move inputs to the same device as the model
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,  # required for temperature/top_p/top_k to take effect
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    # Note: the decoded text includes the prompt, matching the
+    # text-generation pipeline's default behaviour
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return {"generated_text": response}
+
+def inference(inputs: Dict) -> Dict:
+    # Entry point: expects {"inputs": "<prompt>", "parameters": {...}}
+    prompt = inputs.get("inputs", "")
+    params = inputs.get("parameters", {})
+
+    return generate(
+        prompt,
+        max_new_tokens=params.get("max_new_tokens", 500),
+        temperature=params.get("temperature", 0.7),
+        top_p=params.get("top_p", 0.95),
+        top_k=params.get("top_k", 50),
+    )
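
For reference, a minimal local smoke test for this module might look like the sketch below. It is not part of the commit: the module import assumes app.py is on the Python path, the payload shape mirrors what `inference` reads ("inputs" plus a "parameters" dict), the prompt uses the Mistral-instruct `[INST] ... [/INST]` convention, and the parameter values are illustrative. Running it requires a GPU with enough memory for Mistral-7B in bfloat16.

    # Hypothetical usage example, not part of the commit
    from app import inference

    payload = {
        "inputs": "[INST] Explain what an SLO is in one paragraph. [/INST]",
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 50,
        },
    }

    result = inference(payload)
    # generated_text includes the prompt followed by the model's continuation
    print(result["generated_text"])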