Spaces:
Runtime error
Runtime error
Benjamin Bossan
commited on
Commit
•
2cbbc23
1
Parent(s):
308c6f6
Add youtube transcription processor
Browse files- Dockerfile +1 -1
- demo.py +2 -0
- pyproject.toml +2 -1
- requirements-dev.txt +1 -0
- requirements.txt +2 -0
- src/gistillery/config.py +2 -0
- src/gistillery/media.py +68 -0
- src/gistillery/preprocessing.py +75 -2
- src/gistillery/registry.py +5 -3
Dockerfile
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
FROM pytorch/pytorch:latest
|
2 |
|
3 |
-
RUN apt update && apt install -y && rm -rf /var/lib/apt/lists/*
|
4 |
|
5 |
# Set up a new user named "user" with user ID 1000
|
6 |
RUN useradd -m -u 1000 user
|
|
|
1 |
FROM pytorch/pytorch:latest
|
2 |
|
3 |
+
RUN apt update && apt install -y && apt install ffmpeg && rm -rf /var/lib/apt/lists/*
|
4 |
|
5 |
# Set up a new user named "user" with user ID 1000
|
6 |
RUN useradd -m -u 1000 user
|
demo.py
CHANGED
@@ -6,6 +6,8 @@ client = httpx.Client()
|
|
6 |
|
7 |
|
8 |
def submit(inputs):
|
|
|
|
|
9 |
payload = {"content": inputs, "author": "anna nymous"}
|
10 |
httpx.post("http://localhost:8080/submit/", json=payload)
|
11 |
|
|
|
6 |
|
7 |
|
8 |
def submit(inputs):
|
9 |
+
if not inputs:
|
10 |
+
return
|
11 |
payload = {"content": inputs, "author": "anna nymous"}
|
12 |
httpx.post("http://localhost:8080/submit/", json=payload)
|
13 |
|
pyproject.toml
CHANGED
@@ -16,7 +16,8 @@ addopts = "--cov=src --cov-report=term-missing"
|
|
16 |
[tool.mypy]
|
17 |
no_implicit_optional = true
|
18 |
strict = true
|
|
|
19 |
|
20 |
[[tool.mypy.overrides]]
|
21 |
-
module = "huggingface_hub,trafilatura,transformers
|
22 |
ignore_missing_imports = true
|
|
|
16 |
[tool.mypy]
|
17 |
no_implicit_optional = true
|
18 |
strict = true
|
19 |
+
plugins = "numpy.typing.mypy_plugin"
|
20 |
|
21 |
[[tool.mypy.overrides]]
|
22 |
+
module = "huggingface_hub,trafilatura,transformers.*,pytube"
|
23 |
ignore_missing_imports = true
|
requirements-dev.txt
CHANGED
@@ -5,3 +5,4 @@ ruff
|
|
5 |
pytest
|
6 |
pytest-cov
|
7 |
types-Pillow
|
|
|
|
5 |
pytest
|
6 |
pytest-cov
|
7 |
types-Pillow
|
8 |
+
types-urllib3
|
requirements.txt
CHANGED
@@ -8,3 +8,5 @@ charset-normalizer
|
|
8 |
trafilatura
|
9 |
pillow
|
10 |
gradio
|
|
|
|
|
|
8 |
trafilatura
|
9 |
pillow
|
10 |
gradio
|
11 |
+
urllib3
|
12 |
+
pytube
|
src/gistillery/config.py
CHANGED
@@ -8,6 +8,8 @@ class Config(BaseSettings):
|
|
8 |
hf_hub_token: str = "missing"
|
9 |
hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
|
10 |
db_file_name: Path = Path("sqlite-data.db")
|
|
|
|
|
11 |
|
12 |
class Config:
|
13 |
# load .env file by default, with provisio to use other .env files if set
|
|
|
8 |
hf_hub_token: str = "missing"
|
9 |
hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
|
10 |
db_file_name: Path = Path("sqlite-data.db")
|
11 |
+
sampling_rate: int = 16_000 # audio transcription
|
12 |
+
max_yt_length: int = 1800 # in minutes
|
13 |
|
14 |
class Config:
|
15 |
# load .env file by default, with provisio to use other .env files if set
|
src/gistillery/media.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import tempfile
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import numpy.typing as npt
|
6 |
+
import pytube
|
7 |
+
|
8 |
+
|
9 |
+
def download_yt_audio(url: str, max_length: int) -> str:
|
10 |
+
yt = pytube.YouTube(url)
|
11 |
+
if (max_length is not None) and (yt.length > max_length):
|
12 |
+
raise ValueError(f"Youtube video exceeds max length of {max_length}")
|
13 |
+
|
14 |
+
video = yt.streams.filter(only_audio=True).first()
|
15 |
+
tmp_path = tempfile.mkdtemp()
|
16 |
+
fname = video.download(output_path=tmp_path)
|
17 |
+
assert isinstance(fname, str)
|
18 |
+
return fname
|
19 |
+
|
20 |
+
|
21 |
+
def check_ffmpeg_installed() -> None:
|
22 |
+
cmd = ["ffmpeg", "-version"] # sic
|
23 |
+
try:
|
24 |
+
subprocess.run(cmd, check=True)
|
25 |
+
except FileNotFoundError as exc:
|
26 |
+
raise RuntimeError("This feature requires ffmpeg to be installed") from exc
|
27 |
+
|
28 |
+
|
29 |
+
# from openai whisper
|
30 |
+
def load_audio(file: str, sampling_rate: int) -> npt.NDArray[np.float32]:
|
31 |
+
"""Open an audio file and read as mono waveform, resampling as necessary
|
32 |
+
|
33 |
+
Parameters
|
34 |
+
----------
|
35 |
+
file: str
|
36 |
+
The audio file to open
|
37 |
+
|
38 |
+
sampling_rate: int
|
39 |
+
The sample rate to resample the audio if necessary
|
40 |
+
|
41 |
+
Returns
|
42 |
+
-------
|
43 |
+
A NumPy array containing the audio waveform, in float32 dtype.
|
44 |
+
|
45 |
+
"""
|
46 |
+
check_ffmpeg_installed() # BB
|
47 |
+
|
48 |
+
# This launches a subprocess to decode audio while down-mixing
|
49 |
+
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
50 |
+
# fmt: off
|
51 |
+
cmd = [
|
52 |
+
"ffmpeg",
|
53 |
+
"-nostdin",
|
54 |
+
"-threads", "0",
|
55 |
+
"-i", file,
|
56 |
+
"-f", "s16le",
|
57 |
+
"-ac", "1",
|
58 |
+
"-acodec", "pcm_s16le",
|
59 |
+
"-ar", str(sampling_rate),
|
60 |
+
"-"
|
61 |
+
]
|
62 |
+
# fmt: on
|
63 |
+
try:
|
64 |
+
out = subprocess.run(cmd, capture_output=True, check=True).stdout
|
65 |
+
except subprocess.CalledProcessError as e:
|
66 |
+
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
67 |
+
|
68 |
+
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
src/gistillery/preprocessing.py
CHANGED
@@ -4,15 +4,18 @@ import logging
|
|
4 |
import re
|
5 |
from typing import Optional
|
6 |
|
|
|
7 |
import trafilatura
|
|
|
8 |
from httpx import Client
|
9 |
-
|
10 |
from PIL import Image
|
|
|
11 |
|
12 |
from gistillery.base import JobInput
|
|
|
|
|
13 |
from gistillery.tools import get_agent
|
14 |
|
15 |
-
|
16 |
logger = logging.getLogger(__name__)
|
17 |
logger.setLevel(logging.DEBUG)
|
18 |
|
@@ -109,3 +112,73 @@ class ImageUrlProcessor(Processor):
|
|
109 |
image = Image.open(io.BytesIO(response.content)).convert('RGB')
|
110 |
caption = get_agent().run("Caption the following image", image=image)
|
111 |
return str(caption)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import re
|
5 |
from typing import Optional
|
6 |
|
7 |
+
import torch
|
8 |
import trafilatura
|
9 |
+
import urllib3
|
10 |
from httpx import Client
|
|
|
11 |
from PIL import Image
|
12 |
+
from transformers import AutoProcessor, WhisperForConditionalGeneration
|
13 |
|
14 |
from gistillery.base import JobInput
|
15 |
+
from gistillery.config import get_config
|
16 |
+
from gistillery.media import download_yt_audio, load_audio
|
17 |
from gistillery.tools import get_agent
|
18 |
|
|
|
19 |
logger = logging.getLogger(__name__)
|
20 |
logger.setLevel(logging.DEBUG)
|
21 |
|
|
|
112 |
image = Image.open(io.BytesIO(response.content)).convert('RGB')
|
113 |
caption = get_agent().run("Caption the following image", image=image)
|
114 |
return str(caption)
|
115 |
+
|
116 |
+
|
117 |
+
class YoutubeUrlProcessor(Processor):
|
118 |
+
"""Download yt audio, transcribe with whisper"""
|
119 |
+
|
120 |
+
def __init__(self) -> None:
|
121 |
+
self.client = Client()
|
122 |
+
self.url = Optional[str]
|
123 |
+
self.template = "{url}\n\n{content}"
|
124 |
+
|
125 |
+
self.processor = AutoProcessor.from_pretrained("openai/whisper-small.en")
|
126 |
+
self.model = WhisperForConditionalGeneration.from_pretrained(
|
127 |
+
"openai/whisper-small.en"
|
128 |
+
)
|
129 |
+
|
130 |
+
self.hosts = {"www.youtube.com", "youtube.com", "youtu.be"}
|
131 |
+
|
132 |
+
def match(self, input: JobInput) -> bool:
|
133 |
+
url = get_url(input.content.strip())
|
134 |
+
if url is None:
|
135 |
+
return False
|
136 |
+
|
137 |
+
parsed = urllib3.util.parse_url(url)
|
138 |
+
if parsed.host not in self.hosts:
|
139 |
+
return False
|
140 |
+
|
141 |
+
self.url = url
|
142 |
+
return True
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def make_batch(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
|
146 |
+
"""Create batches from last dimension, pad last batch if necessary
|
147 |
+
|
148 |
+
Examples
|
149 |
+
>>> import torch
|
150 |
+
>>> x = torch.zeros((1, 10, 213))
|
151 |
+
>>> YoutubeUrlProcessor.make_batch(x, max_len=100).shape
|
152 |
+
torch.Size([3, 10, 100])
|
153 |
+
|
154 |
+
"""
|
155 |
+
# ugly workaround, transformers whisper implementation requires a
|
156 |
+
# specific shape of input length, probably there is a better way...
|
157 |
+
batches = input_ids.split(max_len, dim=-1) # type: ignore
|
158 |
+
last = batches[-1]
|
159 |
+
n = last.shape[-1]
|
160 |
+
last = torch.nn.functional.pad(last, (1, max_len - n - 1), value=0.0)
|
161 |
+
batches = batches[:-1] + (last,)
|
162 |
+
return torch.concat(batches)
|
163 |
+
|
164 |
+
def process(self, input: JobInput) -> str:
|
165 |
+
if not isinstance(self.url, str):
|
166 |
+
raise TypeError("self.url must be a string")
|
167 |
+
|
168 |
+
config = get_config()
|
169 |
+
fname = download_yt_audio(self.url, max_length=config.max_yt_length)
|
170 |
+
audio = load_audio(fname, sampling_rate=config.sampling_rate)
|
171 |
+
inputs = self.processor(
|
172 |
+
audio,
|
173 |
+
return_tensors='pt',
|
174 |
+
sampling_rate=config.sampling_rate,
|
175 |
+
max_length=-1,
|
176 |
+
)
|
177 |
+
batch = self.make_batch(
|
178 |
+
inputs['input_features'], max_len=2 * self.model.config.max_source_positions
|
179 |
+
)
|
180 |
+
generated_ids = self.model.generate(batch)
|
181 |
+
transcription = self.processor.batch_decode(
|
182 |
+
generated_ids, skip_special_tokens=True
|
183 |
+
)
|
184 |
+
return self.template.format(url=self.url, content=" ".join(transcription))
|
src/gistillery/registry.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
from gistillery.base import JobInput
|
2 |
-
from gistillery.tools import Summarizer, Tagger, HfDefaultSummarizer, HfDefaultTagger
|
3 |
from gistillery.preprocessing import (
|
|
|
|
|
4 |
Processor,
|
5 |
RawTextProcessor,
|
6 |
-
|
7 |
-
DefaultUrlProcessor,
|
8 |
)
|
|
|
9 |
|
10 |
|
11 |
class ToolRegistry:
|
@@ -57,6 +58,7 @@ def get_tool_registry() -> ToolRegistry:
|
|
57 |
tagger = HfDefaultTagger()
|
58 |
|
59 |
_registry = ToolRegistry()
|
|
|
60 |
_registry.register_processor(ImageUrlProcessor())
|
61 |
_registry.register_processor(DefaultUrlProcessor())
|
62 |
_registry.register_processor(RawTextProcessor())
|
|
|
1 |
from gistillery.base import JobInput
|
|
|
2 |
from gistillery.preprocessing import (
|
3 |
+
DefaultUrlProcessor,
|
4 |
+
ImageUrlProcessor,
|
5 |
Processor,
|
6 |
RawTextProcessor,
|
7 |
+
YoutubeUrlProcessor,
|
|
|
8 |
)
|
9 |
+
from gistillery.tools import HfDefaultSummarizer, HfDefaultTagger, Summarizer, Tagger
|
10 |
|
11 |
|
12 |
class ToolRegistry:
|
|
|
58 |
tagger = HfDefaultTagger()
|
59 |
|
60 |
_registry = ToolRegistry()
|
61 |
+
_registry.register_processor(YoutubeUrlProcessor())
|
62 |
_registry.register_processor(ImageUrlProcessor())
|
63 |
_registry.register_processor(DefaultUrlProcessor())
|
64 |
_registry.register_processor(RawTextProcessor())
|