# hindilanguage / app.py
# Author: dkhatate
# Commit message: "fixed err"
# Commit: 6e22ae8
from accelerate import init_empty_weights
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Hugging Face Hub identifier of the Hindi Llama-3 model used by this app.
MODEL_ID = "Cognitive-Lab/LLama3-Gaja-Hindi-8B-v0.1"

# Load the tokenizer and the model. `from_pretrained` fetches the checkpoint
# (or reads it from the local cache) and loads its weights, so no separate
# `load_state_dict` step is needed. The earlier empty-weights init followed by
# `torch.load` of a placeholder path ("path/to/your/model/...") was redundant
# and made the script fail at startup; it also un-pickled an arbitrary file.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Run on the GPU if one is available; otherwise keep the model on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define a function to run inference
def generate_response(prompt, max_new_tokens=30, *, include_prompt=True):
    """Generate a continuation of ``prompt`` with the module-level model.

    Args:
        prompt: Input text (e.g. a Hindi question).
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        include_prompt: When True (the default, matching the previous
            behavior) the returned text is the prompt followed by the
            generated continuation, because ``model.generate`` returns the
            prompt tokens plus the new ones. Pass False to get only the newly
            generated text.

    Returns:
        The decoded text with special tokens removed.
    """
    # Tokenize and move the tensors to the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Llama-family tokenizers often define no pad token; passing EOS explicitly
    # spares `generate` from falling back to it (with a warning) on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Gradients are never needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_token_id,
        )

    # outputs[0] is the single sequence: prompt tokens, then generated tokens.
    token_ids = outputs[0] if include_prompt else outputs[0][prompt_length:]
    return tokenizer.decode(token_ids, skip_special_tokens=True)
# Demo: answer one sample Hindi prompt when this file is run as a script.
if __name__ == "__main__":
    sample_prompt = "आपका नाम क्या है?"  # Hindi for "What is your name?"
    print(generate_response(sample_prompt))