File size: 5,691 Bytes
cad6415
1352014
cad6415
 
 
 
 
1352014
 
 
cad6415
 
 
 
 
0f8c45c
1352014
cad6415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f8c45c
 
cad6415
1352014
 
 
 
 
 
 
 
 
cad6415
 
9a01338
1352014
0f8c45c
cad6415
1352014
cad6415
 
 
1352014
cad6415
 
 
 
 
81ac7bd
cad6415
76ee624
cad6415
 
 
1352014
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, pipeline, AutoProcessor, MusicgenForConditionalGeneration
import torch
from PIL import Image
import scipy.io.wavfile

# Load the MusicGen model
#musicgen = pipeline("text-to-audio", model="facebook/musicgen-small")
musicProcessor = AutoProcessor.from_pretrained("facebook/musicgen-small")
musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
# Load the StreetCLIP model
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Egypt', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda','Saudi Arabia', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Syria','Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
labelsAr = ['ألبانيا', 'أندورا', 'الأرجنتين', 'أستراليا', 'النمسا', 'بنغلاديش', 'بلجيكا', 'برمودا', 'بوتان', 'بوليفيا', 'بوتسوانا', 'البرازيل', 'بلغاريا', 'كمبوديا', 'كندا', 'تشيلي', 'الصين', 'كولومبيا', 'كرواتيا', 'جمهورية التشيك', 'الدنمارك', 'جمهورية الدومينيكان', 'مصر', 'الإكوادور', 'إستونيا', 'فنلندا', 'فرنسا', 'ألمانيا', 'غانا', 'اليونان', 'جرينلاند', 'غوام', 'غواتيمالا', 'المجر', 'آيسلندا', 'الهند', 'إندونيسيا', 'أيرلندا', 'إسرائيل', 'إيطاليا', 'اليابان', 'الأردن', 'كينيا', 'قيرغيزستان', 'لاوس', 'لاتفيا', 'ليسوتو', 'ليتوانيا', 'لوكسمبورغ', 'مقدونيا', 'مدغشقر', 'ماليزيا', 'مالطا', 'المكسيك', 'موناكو', 'منغوليا', 'الجبل الأسود', 'هولندا', 'نيوزيلندا', 'نيجيريا', 'النرويج', 'باكستان', 'فلسطين', 'بيرو', 'الفلبين', 'بولندا', 'البرتغال', 'بورتوريكو', 'رومانيا', 'روسيا', 'رواندا', 'المملكة العربية السعودية', 'السنغال', 'صربيا', 'سنغافورة', 'سلوفاكيا', 'سلوفينيا', 'جنوب أفريقيا', 'كوريا الجنوبية', 'إسبانيا', 'سريلانكا', 'سوازيلاند', 'السويد', 'سويسرا', 'سوريا', 'تايوان', 'تايلاند', 'تونس', 'تركيا', 'أوغندا', 'أوكرانيا', 'الإمارات العربية المتحدة', 'المملكة المتحدة', 'الولايات المتحدة', 'أوروغواي']
def process_image(image, audio_path="musicgen_out.wav"):
    # Ensure the image is in the correct format
    if isinstance(image, str):
        image = Image.open(image)

    # Process the image and text inputs
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Get the model outputs
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    # Get the country with the highest probability
    country_index = probs.argmax(dim=1).item()
    country = labels[country_index]
    countryAr = labelsAr[country_index]
    
    # Generate music based on the country
    music_description = f"Traditional music from {country}"
    #music = musicgen(music_description, forward_params={"do_sample": True})
    inputs = musicProcessor(
    text=[music_description],
    padding=True,
    return_tensors="pt",
    )
    audio_values = musicgen.generate(**inputs, max_new_tokens=256)


    # Save the generated music to the specified path
    sampling_rate = musicgen.config.audio_encoder.sampling_rate
    scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
    country = country + " - " + countryAr
    # Return the country and the path to the generated music
    return country, audio_path

# Define the Gradio interface
inputs = gr.Image(type="pil", label="Upload a photo (تحميل صورة)")
outputs = [gr.Textbox(label="Country (البلد)"), gr.Audio(label="Generated Music (الموسيقى المولدة)")]

iface = gr.Interface(
    fn=process_image,
    inputs=inputs,
    outputs=outputs,
    title="Photo to Country and Music Generator محدد الموقع من الصور و منشئ موسيقى",
    description="Upload a photo to identify the country and generate traditional music from that country. (قم بتحميل صورة لتحديد البلد وإنشاء موسيقى تقليدية من هذا البلد.)",
    examples=["Egypt.jfif", "Riyadh.jpeg", "Syria.jfif", "Turkey.jfif"]
)

# Launch the interface
iface.launch(debug=True)