from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Dict
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
# Initialize model and tokenizer with GPU settings
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Mistral's tokenizer ships without a pad token; reuse EOS so padding works
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",  # requires `accelerate`; places layers on available GPUs
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.eval()
    return model, tokenizer

# Load model and tokenizer globally so they are initialized once per process
model, tokenizer = load_model()
def generate(prompt: str,
             max_new_tokens: int = 500,
             temperature: float = 0.7,
             top_p: float = 0.95,
             top_k: int = 50) -> Dict:
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    # Move inputs to the same device as the model (GPU if available)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sampling must be enabled for temperature/top_p/top_k to take effect
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Note: the decoded text includes the prompt followed by the completion
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": response}
def inference(inputs: Dict) -> Dict:
    prompt = inputs.get("inputs", "")
    params = inputs.get("parameters", {})
    max_new_tokens = params.get("max_new_tokens", 500)
    temperature = params.get("temperature", 0.7)
    top_p = params.get("top_p", 0.95)
    top_k = params.get("top_k", 50)
    return generate(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
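
# Minimal local smoke test (illustrative, not part of the original file): shows the
# request shape inference() expects. The prompt uses Mistral-Instruct's [INST] format;
# running it assumes the model above loaded successfully on available hardware.
if __name__ == "__main__":
    sample_request = {
        "inputs": "[INST] Explain what an SLO is in one sentence. [/INST]",
        "parameters": {"max_new_tokens": 100, "temperature": 0.7},
    }
    result = inference(sample_request)
    print(result["generated_text"])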