File size: 2,335 Bytes
000c2c2
 
04b62bf
 
 
 
 
003c8a7
be0f3ee
04b62bf
 
003c8a7
 
 
 
04b62bf
5c37754
 
253c97c
5c37754
 
be0f3ee
003c8a7
 
 
04b62bf
003c8a7
04b62bf
003c8a7
 
 
04b62bf
003c8a7
04b62bf
003c8a7
04b62bf
003c8a7
04b62bf
003c8a7
 
04b62bf
cca9a4c
be0f3ee
04b62bf
18d8458
1c09801
000c2c2
be0f3ee
000c2c2
04b62bf
 
be0f3ee
 
 
 
 
003c8a7
1c09801
003c8a7
18d8458
35a5116
43439da
000c2c2
 
37fe72d
003c8a7
9ac034b
37fe72d
000c2c2
18d8458
000c2c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from transformers import pipeline

import librosa
import numpy as np
import torch

# from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import AutoProcessor, AutoModelForCausalLM


# checkpoint = "microsoft/speecht5_tts"
# tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
# tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
# vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# ic_processor = AutoProcessor.from_pretrained("microsoft/git-base")
# ic_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

# Image-captioning ("ic") model: a GIT (GenerativeImage2Text) checkpoint
# fine-tuned on environment images. Both calls download weights from the
# Hugging Face Hub on first run and cache them locally afterwards.
ic_processor = AutoProcessor.from_pretrained("ronniet/git-base-env")
ic_model = AutoModelForCausalLM.from_pretrained("ronniet/git-base-env")

# def tts(text):
#     if len(text.strip()) == 0:
#         return (16000, np.zeros(0).astype(np.int16))

#     inputs = tts_processor(text=text, return_tensors="pt")

#     # limit input length
#     input_ids = inputs["input_ids"]
#     input_ids = input_ids[..., :tts_model.config.max_text_positions]

#     speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")

#     speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)

#     speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)

#     speech = (speech.numpy() * 32767).astype(np.int16)
#     return (16000, speech)


# captioner = pipeline(model="microsoft/git-base")
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)


def predict(image):
    """Generate a text caption for *image* with the GIT captioning model.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        str: the decoded caption, with special tokens stripped.
    """
    # Preprocess the image into the pixel tensor the model expects.
    pixel_values = ic_processor(images=image, return_tensors="pt").pixel_values

    # Inference only — disable autograd so no computation graph is built.
    with torch.no_grad():
        # Cap generation at 50 tokens to keep captions short.
        text_ids = ic_model.generate(pixel_values=pixel_values, max_length=50)

    # batch_decode returns one caption per batch item; batch size is 1 here.
    return ic_processor.batch_decode(text_ids, skip_special_tokens=True)[0]

# theme = gr.themes.Default(primary_hue="#002A5B")

# Build the Gradio UI: one image input mapped through `predict` to a
# caption textbox output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil",label="Environment"),
    outputs=gr.Textbox(label="Caption"),
    # Custom page background color; widget styling comes from the theme.
    css=".gradio-container {background-color: #002A5B}",
    theme=gr.themes.Soft()
)

# Start the local web server (blocks until the process is interrupted).
demo.launch()