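"""Gradio app: image and text-based dialogue in Hindi.

Takes a question in Hindi and an image (upload or URL), translates the
question to English with mBART-50, queries a fine-tuned LLaVA-1.5-7B model,
and translates the answer back into Hindi.
"""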
import requests
import torch
import gradio as gr
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
)

# Load translation model and tokenizer
translate_model_name = "facebook/mbart-large-50-many-to-many-mmt"
translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name)
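# mBART-50 selects source/target languages via codes such as "hi_IN" (Hindi) and "en_XX" (English)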

# 4-bit quantization config used when loading the LLaVA model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# Fine-tuned LLaVA-1.5-7B with the adapter already merged (Hugging Face Hub)
model_id = 'somnathsingh31/llava-1.5-7b-hf-ft-merged_model'

# Load the merged model in 4-bit with float16 compute
merged_model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

# Processor (tokenizer + image processor) from the base LLaVA model
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")


# Translate text with mBART-50
def translate(text, source_lang, target_lang):
    # Set source language
    tokenizer.src_lang = source_lang
    # Encode the text
    encoded_text = tokenizer(text, return_tensors="pt")
    # Force the target language token at the start of generation
    forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
    # Generate the translation
    generated_tokens = translate_model.generate(**encoded_text, forced_bos_token_id=forced_bos_token_id)
    # Decode the translation
    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translation
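
# Illustrative usage (output is approximate, not an exact model result):
#   translate('यह किस प्रकार की कला है?', 'hi_IN', 'en_XX')  -> "What kind of art is this?"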


# Answer a Hindi question about an image with the fine-tuned LLaVA model
def ask_vlm(hindi_input_text, image):
    # Translate the question from Hindi to English
    prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
    prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"
    # Accept a PIL image (Gradio upload), a file-like object, or an image URL
    if not isinstance(image, Image.Image):
        if hasattr(image, 'read'):
            image = Image.open(image)
        else:
            image = Image.open(requests.get(image, stream=True).raw)
    # Prepare inputs and move them to the model's device
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(merged_model.device)
    generate_ids = merged_model.generate(**inputs, max_new_tokens=250)
    decoded_response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # Extract the text after "ASSISTANT:"
    assistant_index = decoded_response.find("ASSISTANT:")
    if assistant_index != -1:
        text_after_assistant = decoded_response[assistant_index + len("ASSISTANT:"):].strip()
    else:
        # Fall back to the full decoded response if the marker is missing
        text_after_assistant = decoded_response.strip()
    # Translate the answer back from English to Hindi
    hindi_output_text = translate(text_after_assistant, "en_XX", "hi_IN")
    return hindi_output_text


# Define the Gradio interface components
input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
input_question = gr.Textbox(lines=2, label="Question (Hindi)")
output_text = gr.Textbox(label="Response (Hindi)")

# Create the Gradio app (launched in the __main__ block below)
demo = gr.Interface(
    fn=ask_vlm,
    inputs=[input_question, input_image],
    outputs=output_text,
    title="Image and Text-based Dialogue System",
    description="Enter a question in Hindi and an image, either by uploading or providing a URL, and get a response in Hindi.",
)

if __name__ == '__main__':
    # Quick sanity check with a sample image URL and Hindi query
    image_url = 'https://images.metmuseum.org/CRDImages/ad/original/138425.jpg'
    user_query = 'यह किस प्रकार की कला है? विस्तार से बताइये'  # "What kind of art is this? Explain in detail."
    output = ask_vlm(user_query, image_url)
    print('Output:\n', output)
    # Start the Gradio app
    demo.launch()