from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Dict

def get_model():
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"

    # Force CUDA to be the default device for newly created tensors
    if torch.cuda.is_available():
        torch.set_default_device('cuda')

    # Load tokenizer; Mistral ships without a pad token, so reuse EOS
    # (otherwise tokenizer(..., padding=True) fails in generate())
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model in bfloat16 to roughly halve memory versus float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # Explicitly move the model to the GPU when one is present
    if torch.cuda.is_available():
        model = model.cuda()

    return model, tokenizer

# Initialize model and tokenizer once at import time
model, tokenizer = get_model()

def generate(text: str, params: Dict) -> Dict:
    try:
        # Use CUDA when available, otherwise fall back to CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        # Tokenize, then move every input tensor onto the target device
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Debug info: inputs and model must live on the same device
        print(f"Input device: {inputs['input_ids'].device}")
        print(f"Model device: {next(model.parameters()).device}")

        # Generate without gradient tracking. do_sample=True is needed for
        # temperature/top_p/top_k to take effect; inference_mode() also
        # replaces torch.cuda.device(device), which errors on CPU-only hosts.
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=params.get('max_new_tokens', 500),
                do_sample=True,
                temperature=params.get('temperature', 0.7),
                top_p=params.get('top_p', 0.95),
                top_k=params.get('top_k', 50),
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode the full sequence (prompt + completion) into text
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": response}
    except Exception as e:
        print(f"Error in generation: {str(e)}")
        # Dump device information to aid debugging
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"Current CUDA device: {torch.cuda.current_device()}")
            print(f"Device count: {torch.cuda.device_count()}")
        raise  # re-raise with the original traceback

def inference(inputs: Dict) -> Dict:
    # Handler entry point: expects {"inputs": <prompt>, "parameters": {...}}
    prompt = inputs.get("inputs", "")
    params = inputs.get("parameters", {})
    return generate(prompt, params)
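
# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal local smoke test, assuming the script is run directly on a
# machine with enough memory for the 7B model. The [INST]...[/INST] wrapper
# is Mistral-Instruct's chat format; the prompt and parameter values here
# are example choices, not required by the functions above.
if __name__ == "__main__":
    result = inference({
        "inputs": "[INST] Explain what an SRE does in one sentence. [/INST]",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    })
    print(result["generated_text"])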