SREDWise committed
Commit 48a9cfb
1 Parent(s): 8a17d7c

Update app.py

Files changed (1)
  1. app.py +56 -45
app.py CHANGED
@@ -1,62 +1,73 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-from typing import Dict, List
+from typing import Dict
 import os
 
-model_id = "mistralai/Mistral-7B-Instruct-v0.2"
-
-# Initialize model and tokenizer with GPU settings
-def load_model():
+def get_model():
+    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+
+    # Force CUDA to be the default device
+    if torch.cuda.is_available():
+        torch.set_default_device('cuda')
+
+    # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    # Load model with explicit device placement
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto",
         torch_dtype=torch.bfloat16,
         trust_remote_code=True
     )
-    model.eval()
-    return model, tokenizer
-
-# Load model and tokenizer globally
-model, tokenizer = load_model()
-
-def generate(prompt: str,
-             max_new_tokens: int = 500,
-             temperature: float = 0.7,
-             top_p: float = 0.95,
-             top_k: int = 50) -> Dict:
 
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
-
-    # Move inputs to GPU
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-    )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"generated_text": response}
+    # Explicitly move model to GPU
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    return model, tokenizer
+
+# Initialize model and tokenizer
+model, tokenizer = get_model()
+
+def generate(text: str, params: Dict) -> Dict:
+    try:
+        # Ensure we're using CUDA
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        print(f"Using device: {device}")
+
+        # Tokenize with explicit device placement
+        inputs = tokenizer(text, return_tensors="pt", padding=True)
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Print debug info
+        print(f"Input device: {inputs['input_ids'].device}")
+        print(f"Model device: {next(model.parameters()).device}")
+
+        # Generate with explicit device placement
+        with torch.cuda.device(device):
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=params.get('max_new_tokens', 500),
+                temperature=params.get('temperature', 0.7),
+                top_p=params.get('top_p', 0.95),
+                top_k=params.get('top_k', 50),
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return {"generated_text": response}
+
+    except Exception as e:
+        print(f"Error in generation: {str(e)}")
+        # Print device information for debugging
+        print(f"CUDA available: {torch.cuda.is_available()}")
+        if torch.cuda.is_available():
+            print(f"Current CUDA device: {torch.cuda.current_device()}")
+            print(f"Device count: {torch.cuda.device_count()}")
+        raise e
 
 def inference(inputs: Dict) -> Dict:
     prompt = inputs.get("inputs", "")
     params = inputs.get("parameters", {})
-
-    max_new_tokens = params.get("max_new_tokens", 500)
-    temperature = params.get("temperature", 0.7)
-    top_p = params.get("top_p", 0.95)
-    top_k = params.get("top_k", 50)
-
-    return generate(
-        prompt,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k
-    )
+    return generate(prompt, params)
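
For quick sanity checking, here is a minimal sketch of how the updated inference handler could be exercised locally once app.py is loaded. The payload shape (an "inputs" string plus a "parameters" dict) mirrors the code above; the prompt text and parameter values are illustrative assumptions, not part of the commit.

# Hypothetical smoke test for the updated handler; the prompt and values are illustrative.
if __name__ == "__main__":
    payload = {
        "inputs": "Summarize what this service does in one sentence.",
        "parameters": {
            "max_new_tokens": 200,  # handler falls back to 500 if omitted
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 50,
        },
    }
    result = inference(payload)  # returns {"generated_text": "..."}
    print(result["generated_text"])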