dkhatate committed
Commit 5b0f86d
1 Parent(s): 64071f5

modified code

Files changed (2)
  1. app.py +26 -12
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,17 +1,31 @@
-import gradio as gr
+from accelerate import init_empty_weights
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-# Load the model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained("Cognitive-Lab/LLama3-Gaja-Hindi-8B-v0.1")
-model = AutoModelForCausalLM.from_pretrained("Cognitive-Lab/LLama3-Gaja-Hindi-8B-v0.1").to("cuda")
+# Initialize the tokenizer and model with empty weights
+with init_empty_weights():
+    tokenizer = AutoTokenizer.from_pretrained("Cognitive-Lab/LLama3-Gaja-Hindi-8B-v0.1")
+    model = AutoModelForCausalLM.from_pretrained("Cognitive-Lab/LLama3-Gaja-Hindi-8B-v0.1")
 
-def generate_text(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-    with torch.no_grad():
-        output = model.generate(inputs["input_ids"], max_new_tokens=50)
-    return tokenizer.decode(output[0], skip_special_tokens=True)
+# Move the model to the GPU if available; otherwise, keep it on the CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
 
-# Launch Gradio app
-interface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
-interface.launch()
+# Define a function to run inference
+def generate_response(prompt, max_new_tokens=30):
+    # Tokenize the input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Generate response using the model
+    with torch.no_grad():  # Disable gradient calculation
+        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+
+    # Decode the generated tokens to get the output text
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
+
+# Example usage
+if __name__ == "__main__":
+    prompt = "आपका नाम क्या है?"  # Example Hindi prompt
+    response = generate_response(prompt)
+    print(response)
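
Review note on the new loading path: init_empty_weights() creates parameters on PyTorch's meta device, so wrapping from_pretrained in it typically leaves the model without materialized weights, and the later model.to(device) then fails with a meta-tensor copy error; the tokenizer holds no weights at all, so the context manager is a no-op for it. A minimal alternative sketch, assuming the goal is memory-efficient loading of the same checkpoint, is to let from_pretrained place weights itself via device_map="auto" (which uses accelerate under the hood); the torch_dtype=torch.float16 choice is an assumption, not part of this commit.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "Cognitive-Lab/LLama3-Gaja-Hindi-8B-v0.1"

# Tokenizers have no weights to defer, so this loads normally.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" (backed by accelerate) materializes weights directly on the
# available device(s); float16 is an assumed memory saving, not the commit's choice.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

def generate_response(prompt, max_new_tokens=30):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)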
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 transformers
 torch # If you are using PyTorch
 gradio
+accelerate
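
One more observation on the dependency list: gradio is still installed, but the rewritten app.py no longer builds or launches an interface, so a Space running this commit serves no UI. If the UI is meant to stay, a minimal re-wiring sketch follows; it is intended to be appended to the new app.py, where generate_response is already defined, and everything else in it is assumed rather than taken from the commit.

import gradio as gr

# Append to app.py: reuses the generate_response function defined above.
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Response"),
)

if __name__ == "__main__":
    demo.launch()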