storresbusquets commited on
Commit
a445c9d
·
1 Parent(s): f357bf0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -13
app.py CHANGED
@@ -10,10 +10,7 @@ class GradioInference():
10
  self.current_size = "base"
11
  self.loaded_model = whisper.load_model(self.current_size)
12
  self.yt = None
13
- # self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
14
-
15
- self.tokenizer_model = AutoTokenizer.from_pretrained("google/pegasus-large")
16
- self.summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-large")
17
 
18
  # Initialize VoiceLabT5 model and tokenizer
19
  self.keyword_model = T5ForConditionalGeneration.from_pretrained("Voicelab/vlt5-base-keywords")
@@ -35,16 +32,9 @@ class GradioInference():
35
  self.current_size = size
36
 
37
  results = self.loaded_model.transcribe(path, language=lang)
38
-
39
- inputs = self.tokenizer_model(results["text"], max_length=1024, truncation=True, return_tensors="pt")
40
-
41
- summary_ids = self.keyword_model.generate(inputs["input_ids"])
42
- summary = self.keyword_tokenizer.batch_decode(summary_ids,
43
- skip_special_tokens=True,
44
- clean_up_tokenization_spaces=False)
45
 
46
  # Perform summarization on the transcription
47
- # transcription_summary = self.summarizer(results["text"], max_length=130, min_length=30, do_sample=False)
48
 
49
  # Extract keywords using VoiceLabT5
50
  task_prefix = "Keywords: "
@@ -56,7 +46,7 @@ class GradioInference():
56
 
57
  label = self.classifier(results["text"])[0]["label"]
58
 
59
- return results["text"], summary[0], keywords, label
60
 
61
  def populate_metadata(self, link):
62
  self.yt = YouTube(link)
 
10
  self.current_size = "base"
11
  self.loaded_model = whisper.load_model(self.current_size)
12
  self.yt = None
13
+ self.summarizer = pipeline("summarization", model="google/pegasus-large")
 
 
 
14
 
15
  # Initialize VoiceLabT5 model and tokenizer
16
  self.keyword_model = T5ForConditionalGeneration.from_pretrained("Voicelab/vlt5-base-keywords")
 
32
  self.current_size = size
33
 
34
  results = self.loaded_model.transcribe(path, language=lang)
 
 
 
 
 
 
 
35
 
36
  # Perform summarization on the transcription
37
+ transcription_summary = self.summarizer(results["text"], max_length=130, min_length=30, do_sample=False)
38
 
39
  # Extract keywords using VoiceLabT5
40
  task_prefix = "Keywords: "
 
46
 
47
  label = self.classifier(results["text"])[0]["label"]
48
 
49
+ return results["text"], transcription_summary[0]["summary_text"], keywords, label
50
 
51
  def populate_metadata(self, link):
52
  self.yt = YouTube(link)