File size: 3,510 Bytes
e06c27b
 
 
65e0b57
6c022f9
65e0b57
e06c27b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611b363
6c022f9
 
 
 
 
 
 
 
 
 
 
 
e06c27b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c022f9
 
 
e06c27b
 
 
6c022f9
 
 
 
e06c27b
 
 
6c022f9
 
e06c27b
 
 
 
 
 
6c022f9
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os


import streamlit as st
from urllib.parse import urlparse, parse_qs

from tqdm import tqdm
from stqdm import stqdm

# https://github.com/pytorch/pytorch/issues/77764
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from youtube_transcript_api import YouTubeTranscriptApi

from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer

import torch

# Select the PyTorch device: CUDA if available, then Apple MPS, else CPU.
# `torch.has_mps` is deprecated and only reports whether PyTorch was *built*
# with MPS, not whether MPS is usable on this machine; query availability via
# the MPS backend instead (guarded for builds without `torch.backends.mps`).
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')



def get_videoid_from_url(url:str) -> str:
    """Extract the YouTube video ID from *url*.

    Supported forms:
      * ``https://www.youtube.com/watch?v=<id>`` (ID in the ``v`` query parameter)
      * ``https://youtu.be/<id>`` (short link, ID is the first path segment)

    Returns an empty string when no ID can be found.
    """
    url_data = urlparse(url)

    # Short links carry the ID in the path rather than in a query parameter.
    # `hostname` is lower-cased and port-free, or None when there is no netloc.
    if (url_data.hostname or '') in ('youtu.be', 'www.youtu.be'):
        return url_data.path.lstrip('/').split('/', 1)[0]

    # parse_qs maps each key to a non-empty list of values.
    query = parse_qs(url_data.query)
    return query.get('v', [''])[0]

def _fetch_subtitles(video_id: str) -> list:
    """Return the English subtitle lines of *video_id*, minus ``[Music]`` cues.

    A manually created English transcript is preferred; if none exists the
    auto-generated English transcript is used instead. Errors raised by
    youtube_transcript_api propagate to the caller.
    """
    from youtube_transcript_api import NoTranscriptFound

    transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
    try:
        transcript = transcript_list.find_manually_created_transcript(['en'])
    except NoTranscriptFound:
        print('No manual transcripts were found, trying to load generated ones...')
        transcript = transcript_list.find_generated_transcript(['en'])

    return [sbt['text'] for sbt in transcript.fetch() if sbt['text'] != '[Music]']


def _summarize_subtitles(subtitles: list) -> str:
    """Summarize non-empty *subtitles* chunk by chunk with t5-small.

    Returns the per-chunk summaries joined by single spaces.
    """
    total_len = sum(len(sbt) for sbt in subtitles)
    sbt_mean_len = total_len / len(subtitles)
    print('Mean length of subtitles: {}'.format(sbt_mean_len))
    print('Number of subtitles: {}'.format(len(subtitles)))

    # Number of subtitles per summary. Generated and manual transcripts have
    # differently sized lines, so the step is derived from the mean line
    # length: 400 / (mean / 4) lines is about 1600 characters per chunk
    # (presumably ~400 tokens at ~4 chars/token -- TODO confirm intent).
    # Clamp to >= 1 so very long or empty lines cannot yield a zero step.
    if sbt_mean_len > 0:
        n_sbt_per_step = max(1, int(400 / (sbt_mean_len / 4)))
    else:
        n_sbt_per_step = len(subtitles)
    print('Number subtitles per summary: {}'.format(n_sbt_per_step))

    # Build the model once; the original rebuilt it on every chunk.
    # NOTE(review): the module-level `device` is not passed here, so the
    # pipeline runs on transformers' default device -- confirm intent.
    summarizer = pipeline('summarization', model='t5-small', tokenizer='t5-small',
                          max_length=512, truncation=True)

    summaries = []
    # range(0, len, step) yields ceil(len / step) chunks, with the last one possibly short.
    for start in stqdm(range(0, len(subtitles), n_sbt_per_step)):
        sbt_txt = ' '.join(subtitles[start:start + n_sbt_per_step])
        summary = summarizer(sbt_txt, do_sample=False)
        summaries.append(summary[0]['summary_text'])

    return ' '.join(summaries)


def process_click_callback(video_id: str = ''):
    """Fetch a video's transcript, summarize it and publish the summary.

    Args:
        video_id: YouTube video ID to process. When empty, the ID is parsed
            from ``st.session_state['video_url']`` if that key exists.
            Otherwise the demo video ``'aircAruvnKk'`` (3Blue1Brown) is used.

    On success the summary is stored in ``st.session_state.summary_output``.
    Problems (unparsable URL, no transcript, no subtitles) are reported in
    the UI instead of raising.
    """
    from youtube_transcript_api import CouldNotRetrieveTranscript

    # NOTE: the original set st.session_state.process_btn here. That is the key
    # of an st.button, and Streamlit forbids setting button values through
    # session state, so those writes were removed.
    print('Using {} device'.format(device))

    if not video_id:
        if 'video_url' in st.session_state:
            video_id = get_videoid_from_url(st.session_state['video_url'])
            if not video_id:
                st.error('Could not find a YouTube video ID in the given URL.')
                return
        else:
            video_id = 'aircAruvnKk'  # 3blue1Brown

    try:
        subtitles = _fetch_subtitles(video_id)
    except CouldNotRetrieveTranscript as err:
        st.error('Could not retrieve an English transcript: {}'.format(err))
        return

    if not subtitles:
        st.warning('No subtitles found for this video.')
        st.session_state.summary_output = ''
        return

    out = _summarize_subtitles(subtitles)
    print(out)

    st.session_state.summary_output = out
    st.success('Processing complete!', icon="✅")



def main():
    """Render the Streamlit UI: URL input, process button and summary area."""
    st.title('YouTube Video Summary 📃')
    st.markdown('Creates summary for given YouTube video URL based on transcripts.')
    # Example URLs for the user to copy.
    st.code('https://www.youtube.com/watch?v=aircAruvnKk')
    st.code('https://youtu.be/p0G68ORc8uQ')

    col1, col2 = st.columns(2)

    with col1:
        # `key` makes the entered URL available as st.session_state['video_url'].
        video_url = st.text_input('YouTube Video URL:', placeholder='YouTube URL',
                                  label_visibility='collapsed', key='video_url')
        # Echo the parsed video ID as feedback.
        st.write(get_videoid_from_url(video_url))

    with col2:
        st.button('Process 📭', key='process_btn', on_click=process_click_callback)

    # Streamlit warns about empty labels, so use a real label and hide it instead.
    st.text_area(label='Summary', key='summary_output', height=444,
                 label_visibility='collapsed')


# Start the UI only when this file runs as a script, not when it is imported.
if __name__ == '__main__':
    main()