File size: 1,862 Bytes
bce23e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225162b
 
 
 
 
 
 
 
bce23e4
 
 
 
 
 
 
 
 
 
 
d810840
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import json
import logging
import os

import requests
from youtube_transcript_api import YouTubeTranscriptApi

# Authorization header for the Hugging Face Inference API, built once at import
# time from the HF_Token environment variable (raises KeyError if it is unset).
# NOTE: put this somewhere else
headers = {"Authorization": "Bearer " + os.environ["HF_Token"]}


def retrieve_transcript(vid_id):
    """Fetch the transcript of a YouTube video.

    Args:
        vid_id: YouTube video ID, passed to ``YouTubeTranscriptApi.get_transcript``.

    Returns:
        Whatever ``YouTubeTranscriptApi.get_transcript`` returns (expected to be
        a list of segment dicts with a ``"text"`` key, since that is what
        ``split_transcript`` reads), or ``None`` if retrieval failed for any
        reason.
    """
    try:
        return YouTubeTranscriptApi.get_transcript(vid_id)
    except Exception:
        # Best-effort by design: callers get None instead of an exception, but
        # record the cause so failures are not completely invisible.
        logging.getLogger(__name__).warning(
            "Could not retrieve transcript for video %r", vid_id, exc_info=True
        )
        return None


def split_transcript(transcript, chunk_size=40):
    """Group consecutive transcript segments into space-joined text chunks.

    Args:
        transcript: Sliceable sequence of segment dicts, each with a ``"text"`` key.
        chunk_size: Number of consecutive segments merged into one chunk.

    Returns:
        A list of strings; each is the ``"text"`` values of up to
        ``chunk_size`` segments joined with single spaces.
    """
    return [
        " ".join(segment["text"] for segment in transcript[start : start + chunk_size])
        for start in range(0, len(transcript), chunk_size)
    ]


def query_punctuation(splits, timeout=60):
    """Send transcript chunks to the Hugging Face punctuation model.

    Args:
        splits: List of text chunks (as produced by ``split_transcript``),
            sent together as the ``inputs`` of a single request.
        timeout: Seconds to wait for the Inference API before giving up;
            without a timeout ``requests`` may block indefinitely.

    Returns:
        The decoded JSON response body. ``parse_output`` consumes it as one
        iterable of entity dicts per chunk — presumably the API's shape for a
        list input; confirm against the Inference API docs.

    Raises:
        requests.HTTPError: If the API responds with a 4xx/5xx status, rather
            than handing the error body to ``parse_output`` as if it were
            predictions (which silently drops or garbles chunks there).
    """
    payload = {"inputs": splits}
    API_URL = "https://api-inference.huggingface.co/models/oliverguhr/fullstop-punctuation-multilang-large"
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()


def parse_output(output, comb):
    """Insert predicted punctuation into each text chunk and join the chunks.

    Args:
        output: Per-chunk model predictions; ``output[i]`` is an iterable of
            entity dicts for ``comb[i]``, each with an ``"entity_group"``
            label and an ``"end"`` character offset into the original chunk.
        comb: The text chunks that were sent to the model.

    Returns:
        All chunks, with punctuation labels inserted after their ``"end"``
        offsets, joined by single spaces.
    """
    # Labels that insert nothing: "0" (presumably the model's "no punctuation"
    # label — confirm), empty labels, and commas, which are not inserted.
    skip_labels = {"0", ",", ""}
    punctuated = []

    # Loop over the response from the Hugging Face API, one entry per chunk.
    for i, predictions in enumerate(output):
        text = comb[i]
        # Each insertion shifts all later offsets right by the label's length.
        shift = 0
        for entity in predictions:
            try:
                label = entity["entity_group"]
                if label in skip_labels:
                    continue
                pos = entity["end"] + shift
                text = text[:pos] + label + text[pos:]
            except (KeyError, TypeError):
                # Skip malformed entries (missing keys, non-dict items) rather
                # than aborting the whole transcript.
                continue
            shift += len(label)
        punctuated.append(text)
    return " ".join(punctuated)


def punctuate(video_id):
    """Fetch a video's transcript and return it with punctuation restored.

    Args:
        video_id: YouTube video ID.

    Returns:
        A tuple ``(punctuated_transcript, transcript)``: the punctuated text
        as a single string, and the raw transcript from ``retrieve_transcript``.

    Raises:
        ValueError: If no transcript could be retrieved for ``video_id``.
    """
    # Get the transcript from the YouTubeTranscriptApi (None on failure).
    transcript = retrieve_transcript(video_id)
    if transcript is None:
        # Fail with a clear message instead of a TypeError from len(None)
        # inside split_transcript.
        raise ValueError(f"Could not retrieve a transcript for video {video_id!r}")

    splits = split_transcript(transcript)
    if not splits:
        # Empty transcript: nothing to punctuate, so skip the network call.
        return "", transcript

    resp = query_punctuation(splits)  # Get the response from the Inference API
    punctuated_transcript = parse_output(resp, splits)
    return punctuated_transcript, transcript