ronniet committed on
Commit
003c8a7
1 Parent(s): 059ab7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -36
app.py CHANGED
@@ -6,14 +6,14 @@ import librosa
6
  import numpy as np
7
  import torch
8
 
9
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
10
  from transformers import AutoProcessor, AutoModelForCausalLM
11
 
12
 
13
- checkpoint = "microsoft/speecht5_tts"
14
- tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
15
- tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
16
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
17
 
18
  # ic_processor = AutoProcessor.from_pretrained("microsoft/git-base")
19
  # ic_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
@@ -21,40 +21,24 @@ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
21
  ic_processor = AutoProcessor.from_pretrained("ronniet/git-base-env")
22
  ic_model = AutoModelForCausalLM.from_pretrained("ronniet/git-base-env")
23
 
24
- def tts(text):
25
- if len(text.strip()) == 0:
26
- return (16000, np.zeros(0).astype(np.int16))
27
 
28
- inputs = tts_processor(text=text, return_tensors="pt")
29
 
30
- # limit input length
31
- input_ids = inputs["input_ids"]
32
- input_ids = input_ids[..., :tts_model.config.max_text_positions]
33
 
34
- # if speaker == "Surprise Me!":
35
- # # load one of the provided speaker embeddings at random
36
- # idx = np.random.randint(len(speaker_embeddings))
37
- # key = list(speaker_embeddings.keys())[idx]
38
- # speaker_embedding = np.load(speaker_embeddings[key])
39
 
40
- # # randomly shuffle the elements
41
- # np.random.shuffle(speaker_embedding)
42
 
43
- # # randomly flip half the values
44
- # x = (np.random.rand(512) >= 0.5) * 1.0
45
- # x[x == 0] = -1.0
46
- # speaker_embedding *= x
47
 
48
- #speaker_embedding = np.random.rand(512).astype(np.float32) * 0.3 - 0.15
49
- # else:
50
- speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")
51
-
52
- speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
53
-
54
- speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
55
-
56
- speech = (speech.numpy() * 32767).astype(np.int16)
57
- return (16000, speech)
58
 
59
 
60
  # captioner = pipeline(model="microsoft/git-base")
@@ -71,16 +55,16 @@ def predict(image):
71
  text_ids = ic_model.generate(pixel_values=pixel_values, max_length=50)
72
  text = ic_processor.batch_decode(text_ids, skip_special_tokens=True)[0]
73
 
74
- audio = tts(text)
75
 
76
- return text, audio
77
 
78
  # theme = gr.themes.Default(primary_hue="#002A5B")
79
 
80
  demo = gr.Interface(
81
  fn=predict,
82
  inputs=gr.Image(type="pil",label="Environment"),
83
- outputs=[gr.Textbox(label="Caption"), gr.Audio(type="numpy",label="Audio Feedback")],
84
  css=".gradio-container {background-color: #002A5B}",
85
  theme=gr.themes.Soft()
86
  )
 
6
  import numpy as np
7
  import torch
8
 
9
+ # from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
10
  from transformers import AutoProcessor, AutoModelForCausalLM
11
 
12
 
13
+ # checkpoint = "microsoft/speecht5_tts"
14
+ # tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
15
+ # tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
16
+ # vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
17
 
18
  # ic_processor = AutoProcessor.from_pretrained("microsoft/git-base")
19
  # ic_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
 
21
  ic_processor = AutoProcessor.from_pretrained("ronniet/git-base-env")
22
  ic_model = AutoModelForCausalLM.from_pretrained("ronniet/git-base-env")
23
 
24
+ # def tts(text):
25
+ # if len(text.strip()) == 0:
26
+ # return (16000, np.zeros(0).astype(np.int16))
27
 
28
+ # inputs = tts_processor(text=text, return_tensors="pt")
29
 
30
+ # # limit input length
31
+ # input_ids = inputs["input_ids"]
32
+ # input_ids = input_ids[..., :tts_model.config.max_text_positions]
33
 
34
+ # speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")
 
 
 
 
35
 
36
+ # speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
 
37
 
38
+ # speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
 
 
 
39
 
40
+ # speech = (speech.numpy() * 32767).astype(np.int16)
41
+ # return (16000, speech)
 
 
 
 
 
 
 
 
42
 
43
 
44
  # captioner = pipeline(model="microsoft/git-base")
 
55
  text_ids = ic_model.generate(pixel_values=pixel_values, max_length=50)
56
  text = ic_processor.batch_decode(text_ids, skip_special_tokens=True)[0]
57
 
58
+ # audio = tts(text)
59
 
60
+ return text
61
 
62
  # theme = gr.themes.Default(primary_hue="#002A5B")
63
 
64
  demo = gr.Interface(
65
  fn=predict,
66
  inputs=gr.Image(type="pil",label="Environment"),
67
+ outputs=gr.Textbox(label="Caption"),
68
  css=".gradio-container {background-color: #002A5B}",
69
  theme=gr.themes.Soft()
70
  )