Spaces: Runtime error
uploading all .py files
- app.py +92 -0
- data_collection.py +93 -0
- dataset_preparation.py +125 -0
- inference.py +53 -0
- language_translation.py +116 -0
- llava_finetune.py +143 -0
- translation_model_evaluation.py +84 -0
- vqa_collection.py +89 -0
app.py
ADDED
@@ -0,0 +1,92 @@
import torch
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig, AutoProcessor
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import requests
from PIL import Image
import gradio as gr


# Load translation model and tokenizer
translate_model_name = "facebook/mbart-large-50-many-to-many-mmt"
translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name)

# 4-bit quantization config for loading the vision-language model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# fine-tuned model on the Hugging Face Hub (LoRA adapters already merged into the base)
model_id = 'somnathsingh31/llava-1.5-7b-hf-ft-merged_model'

# load the merged model in 4-bit
merged_model = LlavaForConditionalGeneration.from_pretrained(model_id,
                                                             quantization_config=quantization_config,
                                                             torch_dtype=torch.float16)

# create processor from the base model
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# function to translate
def translate(text, source_lang, target_lang):
    # Set source language
    tokenizer.src_lang = source_lang

    # Encode the text
    encoded_text = tokenizer(text, return_tensors="pt")

    # Force target language token
    forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]

    # Generate the translation
    generated_tokens = translate_model.generate(**encoded_text, forced_bos_token_id=forced_bos_token_id)

    # Decode the translation
    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return translation


# function for making inference
def ask_vlm(hindi_input_text, image):
    # translate the question from Hindi to English
    prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
    prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"

    # Accept a URL string, a file-like object, or a PIL image (what Gradio passes through)
    if isinstance(image, str):
        image = Image.open(requests.get(image, stream=True).raw)
    elif hasattr(image, 'read'):
        image = Image.open(image)

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generate_ids = merged_model.generate(**inputs, max_new_tokens=250)
    decoded_response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    assistant_index = decoded_response.find("ASSISTANT:")

    # Extract the text after "ASSISTANT:"; fall back to the whole response so
    # translate() is never handed None
    if assistant_index != -1:
        text_after_assistant = decoded_response[assistant_index + len("ASSISTANT:"):].strip()
    else:
        text_after_assistant = decoded_response.strip()

    # translate the answer back from English to Hindi
    hindi_output_text = translate(text_after_assistant, "en_XX", "hi_IN")
    return hindi_output_text


# Define Gradio interface components (gr.inputs/gr.outputs were removed in
# Gradio 4.x and raise AttributeError; the top-level components replace them)
input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
input_question = gr.Textbox(lines=2, label="Question (Hindi)")
output_text = gr.Textbox(label="Response (Hindi)")

# Create and launch the Gradio app
demo = gr.Interface(fn=ask_vlm, inputs=[input_question, input_image], outputs=output_text,
                    title="Image and Text-based Dialogue System",
                    description="Enter a question in Hindi and an image, either by uploading or providing a URL, and get a response in Hindi.")
demo.launch()


if __name__ == '__main__':
    # standalone smoke test; note it only runs once the Gradio server above exits
    image_url = 'https://images.metmuseum.org/CRDImages/ad/original/138425.jpg'
    user_query = 'यह किस प्रकार की कला है? विस्तार से बताइये'  # "What kind of art is this? Explain in detail."
    output = ask_vlm(user_query, image_url)
    print('Output:\n', output)
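One common cause of the "Runtime error" banner above is a missing or incomplete requirements.txt for the Space. A minimal sketch for this app might look like the following; the package list is inferred from app.py's imports and the transformers pin from llava_finetune.py, so treat it as an assumption rather than part of the commit:

# requirements.txt (hypothetical sketch, inferred from the imports)
torch
transformers>=4.39.0
accelerate        # needed by transformers for quantized/device-mapped loading
bitsandbytes      # 4-bit quantization backend
sentencepiece     # required by the MBart50 tokenizer
gradio
Pillow
requests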
data_collection.py
ADDED
@@ -0,0 +1,93 @@
import pandas as pd
import requests
import json
import cv2
import numpy as np
from bs4 import BeautifulSoup

def download_metmuseum_data(ids):
    # fetch each object record from the Met's open-access API and save them all as one JSON file
    data_final = []
    for object_id in ids:
        response = requests.get(f'https://collectionapi.metmuseum.org/public/collection/v1/objects/{object_id}')
        if response.status_code == 200:
            data_final.append(json.loads(response.text))

    with open('met_museum_collection.json', 'w') as json_file:
        json.dump(data_final, json_file)

def filter_and_save_data(keys_to_keep):
    # read from the working directory (the original hard-coded Colab's /content/ path)
    df = pd.read_json('met_museum_collection.json')
    final_data_df = df[keys_to_keep]
    final_data_df.to_csv('final_data_df.csv', index=False)
    final_data_df.to_json('final_data_json.json', orient='records', indent=4)
    return final_data_df

def download_and_convert_image(url):
    # unused helper kept for reference: fetch an image URL and decode it into an OpenCV array
    response = requests.get(url)
    if response.status_code == 200:
        nparr = np.frombuffer(response.content, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        return img
    else:
        return None

def extract_description(url):
    # scrape the free-text description block from an object's Met web page
    try:
        response = requests.get(url)
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, 'html.parser')
            description_element = soup.find('div', class_='artwork__intro__desc')
            if description_element:
                description = description_element.get_text(strip=True)
                return description
            else:
                return None
        else:
            return None
    except Exception as e:
        return str(e)

def create_master_data(df):
    df['description'] = df['objectURL'].apply(extract_description)
    df.to_csv('master_data.csv', index=False)
    df.to_json('master_data.json', orient='records', indent=4)


def main():
    # object IDs filtered down to Indian art from the full Met Museum dataset
    ids = [1986,7608,9748,38638,86116,86136,86167,86169,86171,86173,86176,86177,86178,86179,86186,86187,86189,86195,86216,86223,126982,127021,127511,308007,443097,443099,444411,444412,444539,444541,444546,444552,444557,444575,444577,444582,444588,444592,444609,444618,444623,444628,444634,444648,444650,444652,444657,444658,444661,444662,444667,444668,444670,444671,444672,444674,444675,444676,444678,444679,444681,444684,444689,444692,444694,444695,444720,444814,444816,444832,444836,444838,444854,444856,444857,444859,444860,444861,444863,444864,444865,444866,444867,444868,444963,444964,444984,444987,444988,444989,445003,445004,445005,445227,445231,445232,445233,445239,445245,445246,445261,445280,445362,445645,445646,445647,445648,445689,445690,445691,445693,445708,445869,445870,445871,445872,445873,445874,445875,445878,445879,445880,445881,445882,445889,445890,445894,445957,445958,445961,445963,445975,445976,445997,445998,445999,446000,446001,446002,446003,446004,446005,446007,446157,446158,446163,446164,446189,446266,446276,446278,446280,446290,446293,446299,446549,446556,446558,446560,446561,446562,446563,446564,446566,446567,446571,446572,446573,446574,446575,446576,446577,446578,446579,446580,446581,446582,446583,446584,446585,446586,446587,446588,446589,446591,446633,446634,446635,446638,446640,446643,446644,446646,446661,446662,446663,446704,446705,446706,446707,446708,446709,446710,446711,446712,446713,446714,446715,446716,446717,446718,446719,446720,446721,446722,446723,446724,446725,446726,446727,446728,446729,446730,446731,446732,446733,446734,446735,446736,446737,446738,446739,446740,446741,446742,446743,446744,446745,446746,446747,446748,446749,446750,446751,446752,446753,446754,446755,446756,446757,446759,446760,446761,446762,446763,446764,446765,446766,446767,446768,446769,446770,446771,446772,446773,446774,446775,446776,446777,446778,446779,446780,446781,446782,446783,446784,446785,446786,446787,446788,446789,446790,446791,446792,446793,446794,446795,446796,446797,446798,446799,446800,446801,446802,446804,446807,446808,446810,446816,446817,446855,446857,446859,446872,446882,446893,446967,446987,446988,446991,446992,446993,447021,447050,447051,447053,447054,447055,447062,447067,447082,447083,447084,447086,447087,447088,447089,447090,447091,447092,447093,447094,447095,447101,447102,447108,447109,447114,447118,447297,447298,447299,447300,447301,447302,447303,447304,447305,447306,447307,447308,447364,447365,447388,447389,447399,447401,447402,447403,447404,447405,447406,447407,447408,447409,447410,447411,447412,447413,447414,447415,447416,447417,447419,447421,447530,447531,447606,447627,447675,447703,447743,447744,447752,447755,447757,447758,447760,447773,447774,447775,447776,447780,447795,447796,447797,447799,447800,447801,447808,447816,447817,447821,447822,447848,447849,447867,447868,447918,447919,447939,447944,447946,447947,447948,447969,448015,448018,448147,448148,448149,448150,448151,448152,448153,448154,448155,448156,448157,448158,448159,448160,448161,448162,448163,448164,448165,448166,448167,448168,448178,448180,448188,448189,448190,448191,448206,448212,448213,448214,448215,448216,448241,448247,448257,448277,448278,448351,448352,448378,448379,448380,448450,448460,448464,448465,448466,448467,448468,448469,448470,448471,448472,448473,448474,448475,448476,448477,448478,448479,448480,448481,448484,448485,448486,448487,448488,448489,448491,448494,448495,448508,448517,448518,448519,448521,448522,448523,448524,448525,448526,448527,448528,448529,448530,448531,448532,448533,448534,448535,448536,448553,448583,448587,448589,448591,448657,448677,448923,448936,448937,449094,449105,450374,450395,450465,450469,450471,450496,450508,450529,450538,450572,450573,450584,450586,450592,450595,450601,450610,450613,450614,450618,450622,450625,450631,450639,450642,450665,450674,450677,450699,450708,450733,450750,450754,450964,450986,451009,451010,451018,451021,451118,451186,451195,451201,451257,451258,451259,451260,451261,451262,451263,451264,451265,451266,451267,451268,451269,451270,451271,451272,451273,451274,451275,451276,451277,451278,451279,451280,451281,451282,451283,451284,451285,451286,451287,451288,451289,451291,451292,451293,451294,451295,451296,451297,451313,451314,451315,451316,451317,451318,451334,451337,451393,451401,451409,451413,451479,451483,451485,451611,451683,451961,451982,452104,452192,452197,452205,452214,452294,452307,452310,452394,452395,452418,452419,452520,452521,452554,452680,452681,452734,452745,452746,452749,452776,452777,452780,452805,452811,452813,452814,452815,452818,452819,452836,452912,452913,452942,452943,452944,452963,453053,453054,453058,453059,453060,453061,453062,453063,453064,453065,453066,453067,453068,453069,453070,453071,453072,453073,453074,453075,453076,453077,453078,453079,453080,453081,453082,453083,453084,453085,453086,453087,453088,453089,453090,453091,453092,453093,453094,453095,453096,453097,453098,453099,453100,453101,453102,453103,453104,453105,453106,453107,453108,453109,453110,453111,453112,453113,453114,453115,453116,453117,453118,453119,453120,453121,453122,453123,453124,453125,453126,453127,453128,453129,453130,453131,453132,453133,453134,453135,453136,453137,453138,453139,453140,453141,453142,453143,453144,453145,453146,453147,453148,453149,453150,453151,453152,453153,453154,453159,453160,453161,453163,453169,453180,453183,453184,453188,453191,453192,453193,453194,453195,453206,453207,453210,453212,453213,453214,453215,453216,453217,453218,453219,453220,453221,453222,453223,453224,453225,453226,453227,453228,453234,453241,453242,453243,453253,453258,453259,453260,453261,453263,453265,453266,453267,453268,453271,453272,453273,453276,453280,453281,453287,453288,453309,453313,453328,453329,453330,453331,453332,453334,453335,453337,453338,453339,453340,453341,453343,453344,453348,453350,453354,453355,453356,453360,453366,453367,453384,453386,453387,453397,453412,453415,453417,453427,453428,453435,453436,453442,453479,453482,453498,453500,453564,453642,453673,453975,453983,454003,454011,454013,454027,454038,454049,454054,454064,454083,454084,454087,454088,454089,454090,454091,454093,454094,454095,454096,454097,454098,454099,454100,454113,454114,454115,454139,454343,454344,454357,454358,454359,454400,454406,454407,454408,454409,454415,454429,454445,454515,454530,454568,454601,454619,454620,454623,454624,454625,454713,454718,454738,454739,454768,455027,455028,455029,455030,455031,455032,455033,455034,455035,455036,455048,455060,455064,455285,455300,455544,455545,455546,455547,455548,456128,456312,456313,456314,456342,456875,456921,456922,456923,456926,456927,456928,456949,456951,456958,456959,456964,456966,456967,456968,456985,456986,457019,457020,457054,457058,457711,457728,457730,457743,457755,457758,457771,457781,457782,457783,457784,457788,457977,458260,458378,458437,458438,458439,458440,458444,458445,500709,500711,500712,500713,500714,500715,500716,500717,500718,500719,500720,500731,500732,500733,500734,500735,500736,500738,500739,500740,500741,500742,500743,500744,500745,500748,500749,500750,500751,500752,500754,500755,500756,500757,500758,500759,500760,500761,500763,500764,500767,500769,500771,500772,500773,500774,500775,500777,500778,500779,500781,500783,500785,500790,500791,500793,500794,500800,500801,500802,500806,500809,500835,500838,500844,500869,500909,500914,500994,501466,501811,501879,501899,501903,501981,502018,502072,502111,502149,502150,502236,502310,503431,503453,503554,503635,503672,503937,504217,505237,505818,506064,506151,632624,646829,659894,659901,659903,659904,659905,659906,662008,684250,688168,688233,688240,688242,698587,698590,698600,698602,698612,698613,698616,698646,698647,698648,698649,698651,698652,698655,698657,698660,698661,698662,698663,698664,698671,698672,698673,707408,717767,734073,752291,761737,772012,781487,781855,781858,781892,784283,784629,817378,820019,821905,822698,823157,823285,827351,835238,838357,843822,843823,853714,872977]

    # download the records for the IDs above (writes met_museum_collection.json)
    download_metmuseum_data(ids)

    # Read the JSON file into a DataFrame
    df = pd.read_json('met_museum_collection.json')

    # List of keys to keep
    keys_to_keep = [
        'objectID', 'primaryImage', 'department', 'objectName', 'title', 'culture', 'period',
        'artistDisplayName', 'objectDate', 'medium', 'dimensions', 'city',
        'state', 'country', 'region', 'classification', 'repository', 'objectURL'
    ]

    # Filter data (.copy() avoids pandas' SettingWithCopyWarning on the assignment below)
    final_data_df = df[keys_to_keep].copy()

    # create a new column with the scraped description
    final_data_df['description'] = final_data_df['objectURL'].apply(extract_description)

    # Save as CSV and JSON
    final_data_df.to_csv('master_data.csv', index=False)
    final_data_df.to_json('master_data.json', orient='records', indent=4)

if __name__ == "__main__":
    main()
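A hedged aside on data_collection.py: the script fires roughly 900 unauthenticated GETs at the open-access Met API in a tight loop. A small wrapper with a timeout, retries, and a pause (the names here are illustrative, not from the commit) keeps the scrape polite and more robust to transient failures:

import time
import requests

def fetch_object(object_id, retries=3, pause=0.5):
    # fetch one object record with a timeout, simple retries, and linear backoff
    url = f'https://collectionapi.metmuseum.org/public/collection/v1/objects/{object_id}'
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                return response.json()
        except requests.RequestException:
            pass  # transient network error: retry after a pause
        time.sleep(pause * (attempt + 1))
    return None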
dataset_preparation.py
ADDED
@@ -0,0 +1,125 @@
import pandas as pd
import os
import shutil
import requests
from PIL import Image
from io import BytesIO
from datasets import Dataset

def modify_dataframe_and_extract_data(df):
    data_list = []
    for _, row in df.iterrows():
        messages = []
        for i in range(1, 5):
            user_question = row[f'Question{i}']
            user_answer = row[f'Answer{i}']
            # explicit NaN checks: a float NaN is truthy, so `if user_question:` alone lets empty cells through
            if pd.notna(user_question) and user_question:
                message_content = [{'index': None, 'text': user_question, 'type': 'text'}]
                if i == 1:
                    # the image is attached to the first user turn only
                    message_content.append({'index': 0, 'text': None, 'type': 'image'})
                messages.append({'content': message_content, 'role': 'user'})
            if pd.notna(user_answer) and user_answer:
                messages.append({'content': [{'index': None, 'text': user_answer, 'type': 'text'}], 'role': 'assistant'})
        image = Image.open(row['imagePath'])
        data_list.append({'messages': messages, 'images': [image]})
    return {'messages': [data['messages'] for data in data_list], 'images': [data['images'] for data in data_list]}


def download_and_resize_images(df, image_dir, target_size=(250, 250)):
    image_paths = []
    for index, row in df.iterrows():
        image_url = row['primaryImageLink']
        object_id = row['objectID']
        if image_url:
            # Build the target filename from the object ID
            filename = os.path.join(image_dir, f"{object_id}.jpg")
            # Download the image from the URL
            response = requests.get(image_url)
            if response.status_code == 200:
                # Open the image using PIL
                image = Image.open(BytesIO(response.content))
                # Resize the image
                image = image.resize(target_size)
                # Save the resized image
                image.save(filename)
                image_paths.append(filename)
            else:
                print(f"Failed to download image from {image_url}")
                image_paths.append(None)
        else:
            image_paths.append(None)
    return image_paths

def split_data_dict(data_dict, train_ratio=0.7, test_ratio=0.2, val_ratio=0.1):
    # compare with a tolerance: 0.7 + 0.2 + 0.1 != 1.0 exactly in float arithmetic,
    # so a plain equality check rejects even the default ratios
    assert abs(train_ratio + test_ratio + val_ratio - 1.0) < 1e-9, "Ratios must sum up to 1.0"

    total_samples = len(data_dict['messages'])
    train_size = int(total_samples * train_ratio)
    test_size = int(total_samples * test_ratio)
    val_size = int(total_samples * val_ratio)

    train_data_dict = {
        'messages': data_dict['messages'][:train_size],
        'images': data_dict['images'][:train_size]
    }
    test_data_dict = {
        'messages': data_dict['messages'][train_size:train_size + test_size],
        'images': data_dict['images'][train_size:train_size + test_size]
    }
    val_data_dict = {
        'messages': data_dict['messages'][-val_size:],
        'images': data_dict['images'][-val_size:]
    }

    return train_data_dict, test_data_dict, val_data_dict


def save_data_dict_as_arrow(data_dict, file_path):
    # Convert the dictionary to a Dataset object
    dataset = Dataset.from_dict(data_dict)

    # Save the dataset to an Arrow file
    dataset.save_to_disk(file_path)

if __name__ == "__main__":
    # Example usage:

    # df = pd.read_csv("/data/data_set_metmuseum.csv")
    # df1 = df[['objectID', 'primaryImageLink', 'Question1', 'Answer1', 'Question2', 'Answer2', 'Question3', 'Answer3', 'Question4', 'Answer4']]
    # df2 = df1.sample(frac=1)
    # df3 = df2.head(250)

    # df4 = df3.copy()

    df4 = pd.read_csv("sampled_data250.csv")
    paths = ['input_dataset', os.path.join('input_dataset', 'images'), 'output_dataset']
    for path in paths:
        os.makedirs(path, exist_ok=True)

    # Call the function to download and resize images
    image_dir = 'input_dataset/images'
    # image_paths = download_and_resize_images(df4, image_dir)

    # Update the DataFrame with the resized image paths
    # new_df = df4.copy()  # Create a copy of the original DataFrame
    # df4['imagePath'] = image_paths  # Add a new column 'imagePath' containing the resized image paths
    # df4 = df4.drop(['primaryImageLink'], axis=1)

    # Call the function to modify the DataFrame and extract data
    data_dict = modify_dataframe_and_extract_data(df4)
    # split data_dict into train, test, valid
    train_data_dict, test_data_dict, val_data_dict = split_data_dict(data_dict, train_ratio=0.6, test_ratio=0.2, val_ratio=0.2)

    # save the splits as Arrow datasets
    save_data_dict_as_arrow(train_data_dict, os.path.join('output_dataset', 'train.arrow'))
    save_data_dict_as_arrow(test_data_dict, os.path.join('output_dataset', 'test.arrow'))
    save_data_dict_as_arrow(val_data_dict, os.path.join('output_dataset', 'val.arrow'))

    # zip both directories (the /content/ prefix assumes a Colab runtime)
    shutil.make_archive("/content/input_dataset", "zip", "/content/input_dataset")
    shutil.make_archive("/content/output_dataset", "zip", "/content/output_dataset")

    # read an Arrow dataset back from disk as a sanity check
    test_data = Dataset.load_from_disk("output_dataset/test.arrow")
    print(test_data)
inference.py
ADDED
@@ -0,0 +1,53 @@
import torch
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig, AutoProcessor
from peft import PeftModel
import requests
from PIL import Image


def load_base_model(model_id):
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    base_model = LlavaForConditionalGeneration.from_pretrained(model_id,
                                                               quantization_config=quantization_config,
                                                               torch_dtype=torch.float16)
    return base_model


def load_peft_lora_adapter(base_model, peft_lora_adapter_path):
    # wraps the base model with the LoRA adapter as a PeftModel
    peft_lora_adapter = PeftModel.from_pretrained(base_model, peft_lora_adapter_path, adapter_name="lora_adapter")
    return peft_lora_adapter


def merge_adapters(base_model, peft_lora_adapter_path):
    # note: load_adapter() attaches the adapter for inference; it does not fold
    # the LoRA weights into the base weights (see the merge sketch below)
    base_model.load_adapter(peft_lora_adapter_path, adapter_name="lora_adapter")
    return base_model


def main():
    model_id = "llava-hf/llava-1.5-7b-hf"  # base model id
    peft_lora_adapter_path = 'somnathsingh31/llava-1.5-7b-hf-ft-museum'  # adapter path on the Hub

    # Load the base model
    base_model = load_base_model(model_id)

    # Attach the PEFT LoRA adapter (either helper above is sufficient on its
    # own; the original called both, which loads the same adapter twice)
    merged_model = merge_adapters(base_model, peft_lora_adapter_path)

    prompt = "USER: <image>\nWhat is special in this chess set and pieces? \nASSISTANT:"
    url = "https://images.metmuseum.org/CRDImages/ad/original/138425.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    processor = AutoProcessor.from_pretrained(model_id)
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generate_ids = merged_model.generate(**inputs, max_new_tokens=150)
    decoded_response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    print("Generated response:", decoded_response)


if __name__ == "__main__":
    main()
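Worth noting for inference.py: load_adapter() only attaches the LoRA weights for inference; it does not produce a standalone merged checkpoint like the one app.py loads. A minimal sketch of actually merging with PEFT is below; it assumes the base model is loaded in fp16 without quantization, since 4-bit weights cannot be merged directly, and the output directory name is illustrative:

import torch
from transformers import LlavaForConditionalGeneration
from peft import PeftModel

base = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, "somnathsingh31/llava-1.5-7b-hf-ft-museum")
merged = model.merge_and_unload()       # folds the LoRA deltas into the base weights
merged.save_pretrained("merged_model")  # hypothetical output directory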
language_translation.py
ADDED
@@ -0,0 +1,116 @@
# !pip install googletrans==3.1.0a0
# !pip install transformers sentencepiece

from googletrans import Translator
from transformers import MarianMTModel, MarianTokenizer  # transformer-based pre-trained translation models
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast


def translate_hi2en_gtrans(sentence):
    """
    Translate from Hindi to English via Google Translate.

    Args:
    - sentence: string in Hindi

    Returns:
    - English translated text string
    """
    translator = Translator()
    output = translator.translate(sentence, dest='en', src='hi')
    return output.text

def translate_en2hi_gtrans(sentence):
    """
    Translate from English to Hindi via Google Translate.

    Args:
    - sentence: string in English

    Returns:
    - Hindi translated text string
    """
    translator = Translator()
    output = translator.translate(sentence, dest='hi', src='en')
    return output.text

# Translates English text to Hindi using a pre-trained MarianMT model
def translate_en_hi_transformer(text):
    # Load the pre-trained English-to-Hindi model and tokenizer (loading inside
    # the function is simple but slow; cache at module level for bulk use)
    model_name_en_hi = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name_en_hi)
    model_en_hi = MarianMTModel.from_pretrained(model_name_en_hi)
    encoded = tokenizer(text, return_tensors="pt")
    translated = model_en_hi.generate(**encoded)
    return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

# Translates Hindi text to English using a pre-trained MarianMT model
def translate_hi_en_transformer(text):
    # Load the pre-trained Hindi-to-English model and tokenizer
    model_name_hi_en = "Helsinki-NLP/opus-mt-hi-en"
    tokenizer_hi = MarianTokenizer.from_pretrained(model_name_hi_en)
    model_hi_en = MarianMTModel.from_pretrained(model_name_hi_en)
    encoded = tokenizer_hi(text, return_tensors="pt")
    translated = model_hi_en.generate(**encoded)
    return tokenizer_hi.batch_decode(translated, skip_special_tokens=True)[0]

def translate_mbart(text, source_lang, target_lang):
    # Load model and tokenizer (again, hoist these out of the function when
    # translating many strings)
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

    # Set source language
    tokenizer.src_lang = source_lang
    # Encode the text
    encoded_text = tokenizer(text, return_tensors="pt")
    # Force target language token
    forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
    # Generate the translation
    generated_tokens = model.generate(**encoded_text, forced_bos_token_id=forced_bos_token_id)
    # Decode the translation
    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translation


if __name__ == "__main__":
    print(translate_hi2en_gtrans("मैं खुश हूँ!!!"))  # "I am happy!!!"
    print(translate_en2hi_gtrans("I am happy!!!"))
    import pandas as pd

    # Read CSV file into a Pandas DataFrame
    df_en = pd.read_csv('Data_with_QnA.csv', usecols=['Question1', 'Answer1', 'Question2', 'Answer2', 'Question3', 'Answer3', 'Question4', 'Answer4'])
    print(df_en.head())

    # translate all the question-answer cells from English to Hindi
    # (applymap is deprecated since pandas 2.1; DataFrame.map is its replacement)
    df_hi = df_en.applymap(translate_en2hi_gtrans)
    print(df_hi.head())

    # Save the translated DataFrame to a CSV file
    df_hi.to_csv('Hindi_QnA.csv', index=False)

    # English to Hindi example
    english_text = "What is the material used to create the chess set?"
    hindi_translation = translate_en_hi_transformer(english_text)
    print(f"English: {english_text}")
    print(f"Hindi: {hindi_translation}")

    # Hindi to English example
    hindi_text = "आपका दिन कैसा चल रहा है?"  # How is your day going?
    english_translation = translate_hi_en_transformer(hindi_text)
    print(f"Hindi: {hindi_text}")
    print(f"English: {english_translation}")

    # Longer example: a Hindi paragraph on the history of Hindi literature
    hindi_text = "हिन्दी साहित्य पर अगर समुचित परिप्रेक्ष्य में विचार किया जाए तो स्पष्ट होता है कि हिन्दी साहित्य का इतिहास अत्यन्त विस्तृत व प्राचीन है। सुप्रसिद्ध भाषा वैज्ञानिक डॉ० हरदेव बाहरी के शब्दों में, हिन्दी साहित्य का इतिहास वस्तुतः वैदिक काल से आरम्भ होता है। यह कहना ही ठीक होगा कि वैदिक भाषा ही हिन्दी है। इस भाषा का दुर्भाग्य रहा है कि युग-युग में इसका नाम परिवर्तित होता रहा है। कभी 'वैदिक', कभी 'संस्कृत', कभी 'प्राकृत', कभी'अपभ्रंश' और अब - हिन्दी।[1] आलोचक कह सकते हैं कि 'वैदिक संस्कृत' और 'हिन्दी' में तो जमीन-आसमान का अन्तर है। पर ध्यान देने योग्य है कि हिब्रू, रूसी, चीनी, जर्मन और तमिल आदि जिन भाषाओं को 'बहुत पुरानी' बताया जाता है, उनके भी प्राचीन और वर्तमान रूपों में जमीन-आसमान का अन्तर है; पर लोगों ने उन भाषाओं के नाम नहीं बदले और उनके परिवर्तित स्वरूपों को 'प्राचीन', 'मध्यकालीन', 'आधुनिक' आदि कहा गया, जबकि 'हिन्दी' के सन्दर्भ में प्रत्येक युग की भाषा का नया नाम रखा जाता रहा।"
    english_translation = translate_mbart(hindi_text, "hi_IN", "en_XX")
    print(english_translation)

    # Longer example: an English paragraph on English literature
    english_text = "English literature, the body of written works produced in the English language by inhabitants of the British Isles (including Ireland) from the 7th century to the present day. The major literatures written in English outside the British Isles are treated separately under American literature, Australian literature, Canadian literature, and New Zealand literature. English literature has sometimes been stigmatized as insular. It can be argued that no single English novel attains the universality of the Russian writer Leo Tolstoy’s War and Peace or the French writer Gustave Flaubert’s Madame Bovary. Yet in the Middle Ages the Old English literature of the subjugated Saxons was leavened by the Latin and Anglo-Norman writings, eminently foreign in origin, in which the churchmen and the Norman conquerors expressed themselves. From this combination emerged a flexible and subtle linguistic instrument exploited by Geoffrey Chaucer and brought to supreme application by William Shakespeare. During the Renaissance the renewed interest in Classical learning and values had an important effect on English literature, as on all the arts; and ideas of Augustan literary propriety in the 18th century and reverence in the 19th century for a less specific, though still selectively viewed, Classical antiquity continued to shape the literature. All three of these impulses derived from a foreign source, namely the Mediterranean basin. The Decadents of the late 19th century and the Modernists of the early 20th looked to continental European individuals and movements for inspiration. Nor was attraction toward European intellectualism dead in the late 20th century, for by the mid-1980s the approach known as structuralism, a phenomenon predominantly French and German in origin, infused the very study of English literature itself in a host of published critical studies and university departments. Additional influence was exercised by deconstructionist analysis, based largely on the work of French philosopher Jacques Derrida."
    hindi_translation = translate_mbart(english_text, "en_XX", "hi_IN")
    print(hindi_translation)
llava_finetune.py
ADDED
@@ -0,0 +1,143 @@
# install dependencies
# !pip install -U "transformers>=4.39.0"
# !pip install peft bitsandbytes
# !pip install -U "trl>=0.8.3"

import torch
import io
import PIL.Image as Image
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTTrainer
from peft import LoraConfig
from datasets import Dataset
from huggingface_hub import notebook_login


def load_llava_model(model_id):
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = LlavaForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config,
                                                          torch_dtype=torch.float16)
    torch.cuda.empty_cache()
    return model

class LLavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            messages = example["messages"]
            # render the chat turns into LLaVA's prompt format
            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            # images round-trip through Arrow as {'bytes': ..., 'path': ...} dicts
            img = Image.open(io.BytesIO(example['images'][0]['bytes']))
            images.append(img)

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        # causal-LM labels: copy input_ids and mask padding positions with -100
        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch


def load_datasets(train_path, test_path):
    train_dataset = Dataset.load_from_disk(train_path)
    eval_dataset = Dataset.load_from_disk(test_path)
    return train_dataset, eval_dataset


def configure_training_args(output_dir="llava-1.5-7b-hf-ft-museum"):
    training_args = TrainingArguments(
        output_dir=output_dir,
        report_to="tensorboard",
        learning_rate=1.4e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        logging_steps=5,
        num_train_epochs=5,
        push_to_hub=True,
        gradient_checkpointing=True,
        remove_unused_columns=False,
        fp16=True,
        bf16=False
    )
    return training_args


def configure_lora_config():
    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        target_modules="all-linear"
    )
    return lora_config


def configure_trainer(model, args, train_dataset, eval_dataset, lora_config, tokenizer, data_collator):
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        dataset_text_field="text",  # dummy field; the collator formats the batch itself
        tokenizer=tokenizer,
        data_collator=data_collator,
        dataset_kwargs={"skip_prepare_dataset": True},
    )
    return trainer


def save_model(model, output_dir="llava_model_museum"):
    # save locally only; the trainer already pushes the adapter to the Hub
    model.save_pretrained(output_dir)


def main():
    # Load the LLaVA base model in 4-bit
    model_id = "llava-hf/llava-1.5-7b-hf"
    model = load_llava_model(model_id)

    # Load datasets
    train_path = '/kaggle/input/metmesuem-data/train.arrow'
    test_path = '/kaggle/input/metmesuem-data/test.arrow'
    train_dataset, eval_dataset = load_datasets(train_path, test_path)

    # Log in to the Hugging Face Hub
    notebook_login()

    # Configure training arguments
    args = configure_training_args()

    # Configure LoRA
    lora_config = configure_lora_config()

    # Configure trainer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    processor.tokenizer = tokenizer
    data_collator = LLavaDataCollator(processor)
    trainer = configure_trainer(model, args, train_dataset, eval_dataset, lora_config, tokenizer, data_collator)

    # Load and set TensorBoard for logging (notebook magics)
    # %load_ext tensorboard
    # %tensorboard --logdir llava-1.5-7b-hf-ft-museum

    # Train model
    trainer.train()

    # Push the LoRA adapters to the Hugging Face Hub
    trainer.push_to_hub()

    # Save the model locally
    save_model(model)


if __name__ == "__main__":
    main()
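For reference, each record the collator in llava_finetune.py consumes looks roughly like the sketch below. The schema is inferred from dataset_preparation.py and the collator's own field accesses; the field values are illustrative:

example = {
    "messages": [
        {"role": "user", "content": [
            {"index": None, "text": "What is depicted here?", "type": "text"},
            {"index": 0, "text": None, "type": "image"},  # image slot on the first turn
        ]},
        {"role": "assistant", "content": [
            {"index": None, "text": "A carved ivory chess set.", "type": "text"},
        ]},
    ],
    # after Arrow serialization, each image arrives as a dict of raw bytes,
    # which is why the collator calls Image.open(io.BytesIO(...))
    "images": [{"bytes": b"...", "path": None}],
}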
translation_model_evaluation.py
ADDED
@@ -0,0 +1,84 @@
import nltk
from nltk.translate.meteor_score import meteor_score
from nltk.translate.bleu_score import corpus_bleu
from indicnlp.tokenize import indic_tokenize
import pandas as pd
import numpy as np

def configure_nltk():
    nltk.download('punkt')
    nltk.download('wordnet')

def tokenize_hindi(text):
    return indic_tokenize.trivial_tokenize(text, lang='hi')

def calculate_meteor_score(ref, output):
    ref_tokens = tokenize_hindi(ref)
    output_tokens = tokenize_hindi(output)
    return meteor_score([ref_tokens], output_tokens)

def calculate_bleu_score(ref, output):
    # corpus_bleu over a single sentence pair is just unsmoothed sentence-level BLEU
    ref_tokens = [tokenize_hindi(ref)]
    output_tokens = tokenize_hindi(output)
    return corpus_bleu(ref_tokens, [output_tokens])

def read_hindi_data(filepath):
    hindi_df = pd.read_csv(filepath)
    # 'नेन' is Google Translate's transliteration of "NaN"; turn it back into a real missing value
    hindi_df = hindi_df.replace('नेन', np.nan)
    hindi_df = hindi_df.dropna()
    return hindi_df

def create_reference_df(hindi_df):
    # carry over both the reference columns and the machine-translated *_hindi
    # columns that calculate_scores() reads (this assumes the CSV contains both;
    # the original only copied the references, so the *_hindi lookups failed)
    reference_df = pd.DataFrame()
    for i in range(1, 3):
        reference_df[f'Question{i}'] = hindi_df[f'Question{i}']
        reference_df[f'Answer{i}'] = hindi_df[f'Answer{i}']
        reference_df[f'Question{i}_hindi'] = hindi_df[f'Question{i}_hindi']
        reference_df[f'Answer{i}_hindi'] = hindi_df[f'Answer{i}_hindi']
    return reference_df

def select_first_n_rows(df, n=10):
    return df.iloc[:n]

def calculate_scores(reference_df, n=10):
    meteor_scores = []
    bleu_scores = []

    for i in range(1, 3):
        ref_question_col = f'Question{i}'
        ref_answer_col = f'Answer{i}'
        output_question_col = f'Question{i}_hindi'
        output_answer_col = f'Answer{i}_hindi'

        for ref_question, output_question in zip(reference_df[ref_question_col], reference_df[output_question_col]):
            meteor_scores.append(calculate_meteor_score(ref_question, output_question))
            bleu_scores.append(calculate_bleu_score(ref_question, output_question))

        for ref_answer, output_answer in zip(reference_df[ref_answer_col], reference_df[output_answer_col]):
            meteor_scores.append(calculate_meteor_score(ref_answer, output_answer))
            bleu_scores.append(calculate_bleu_score(ref_answer, output_answer))

    average_meteor_score = sum(meteor_scores) / len(meteor_scores)
    average_bleu_score = sum(bleu_scores) / len(bleu_scores)

    return average_meteor_score, average_bleu_score

def main():
    # Configure NLTK
    configure_nltk()

    # Read the Hindi data (the /content/ path assumes a Colab runtime)
    hindi_df = read_hindi_data('/content/Hindi_QnA.csv')

    # Create the reference dataframe
    reference_df = create_reference_df(hindi_df)

    # Select the first 10 rows
    reference_df_first10 = select_first_n_rows(reference_df)

    # Calculate scores
    average_meteor_score, average_bleu_score = calculate_scores(reference_df_first10)

    print("Average METEOR Score:", average_meteor_score)
    print("Average BLEU Score:", average_bleu_score)

if __name__ == "__main__":
    main()
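One caveat on the BLEU numbers above: unsmoothed BLEU on short single sentences is often exactly 0 whenever there is no 4-gram overlap, which drags the average down. A drop-in alternative using NLTK's smoothing (the function name here is new, not from the commit):

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def calculate_bleu_score_smoothed(ref, output):
    # method1 adds a small count for empty n-gram matches so short sentences score > 0
    smoothie = SmoothingFunction().method1
    return sentence_bleu([tokenize_hindi(ref)], tokenize_hindi(output),
                         smoothing_function=smoothie)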
vqa_collection.py
ADDED
@@ -0,0 +1,89 @@
# !pip install -q -U google-generativeai

import time
import json
import pandas as pd
from google.colab import userdata
import google.generativeai as genai

def configure_genai():
    # The API key is read from Colab's secure user secrets store
    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
    genai.configure(api_key=GOOGLE_API_KEY)

def generate_art_questions(data_json, model):
    art_texts = []
    i = 0

    for data_entry in data_json:
        # Prompt for generating questions based on the art description and metadata
        prompt = f"Following is the description and metadata of an artwork. Based on this, generate 4 questions and corresponding answers. Denote each question as **Question and each answer as **Answer. {data_entry}"
        # Generate content using the model
        response = model.generate_content(prompt)
        art_texts.append(response.text)
        i += 1
        # pause after every 60 requests to stay within the API rate limit
        if i == 60:
            print("60 completed")
            i = 0
            time.sleep(10)

    return art_texts

def extract_qa_pairs(text):
    qa_pairs = {}
    current_q = 1
    current_a = 1

    for line in text.splitlines():
        if line.startswith("**Question"):
            # take the text after the first colon, dropping bold markers and spaces
            qa_pairs[f"Question{current_q}"] = line.strip().split(":", 1)[1].strip("* ")
            current_q += 1
        elif line.startswith(("**Answer", "Answer")):
            qa_pairs[f"Answer{current_a}"] = line.strip().split(":", 1)[1].strip("* ")
            current_a += 1

    return qa_pairs

def create_qa_dataframe(art_texts):
    df = pd.DataFrame(columns=["Question1", "Answer1", "Question2", "Answer2", "Question3", "Answer3", "Question4", "Answer4"])

    for art_text in art_texts:
        qa_dict = extract_qa_pairs(art_text)

        if len(qa_dict) == 8:
            # all four Q/A pairs were extracted: append them as one row
            df.loc[len(df)] = list(qa_dict.values())
        else:
            # malformed generation: keep a placeholder row so indices still line up with the source data
            df.loc[len(df)] = [None] * 8

    return df

def main():
    # Configure GenAI
    configure_genai()

    # Initialize the generative model
    model = genai.GenerativeModel('gemini-pro')

    # Read the data from CSV (the /content/ paths assume a Colab runtime)
    data = pd.read_csv('/content/Data_for_questioning.csv')

    # Convert the DataFrame to JSON
    data.to_json('final_data_json.json', orient='records', indent=4)

    # Open the JSON file and load its contents
    with open('/content/final_data_json.json', 'r') as f:
        data_json = json.load(f)

    # Generate questions for the art descriptions
    art_texts = generate_art_questions(data_json, model)

    # Create a DataFrame with the questions and answers
    df = create_qa_dataframe(art_texts)

    # Save the DataFrame to CSV
    df.to_csv('QnA.csv', index=False)

if __name__ == "__main__":
    main()
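extract_qa_pairs() depends on Gemini emitting exactly the "**Question"/"**Answer" prefixes the prompt asks for, which model output does not always honor. A regex-based pass (a sketch, more tolerant of formatting drift such as "**Question 1:**") is one way to recover more rows:

import re

def extract_qa_pairs_regex(text):
    # tolerate optional numbering, punctuation, and bold markers around the labels
    questions = re.findall(r"\*{0,2}Question\s*\d*\s*[:.]\s*(.+)", text)
    answers = re.findall(r"\*{0,2}Answer\s*\d*\s*[:.]\s*(.+)", text)
    pairs = {}
    for i, q in enumerate(questions, start=1):
        pairs[f"Question{i}"] = q.strip(" *")
    for i, a in enumerate(answers, start=1):
        pairs[f"Answer{i}"] = a.strip(" *")
    return pairs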