Spaces:
Sleeping
Sleeping
yadda yadda... getting rid of flash attention for now
Browse files
app.py
CHANGED
@@ -25,40 +25,19 @@ MODEL_NAME = "openai/whisper-large-v3-turbo"
|
|
25 |
BATCH_SIZE = 8 # Optimized for better GPU utilization
|
26 |
YT_LENGTH_LIMIT_S = 10800 # 3 hours
|
27 |
DATASET_NAME = "dwb2023/yt-transcripts-v3"
|
|
|
28 |
|
29 |
# Environment setup
|
30 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
31 |
|
32 |
-
|
33 |
-
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
34 |
-
MODEL_NAME,
|
35 |
-
use_cache=False,
|
36 |
-
device_map="auto",
|
37 |
-
low_cpu_mem_usage=True,
|
38 |
-
attn_implementation="flash_attention_2",
|
39 |
-
torch_dtype=torch.bfloat16
|
40 |
-
)
|
41 |
-
|
42 |
-
# Flash Attention setup for memory and speed optimization if supported
|
43 |
-
try:
|
44 |
-
from flash_attn import flash_attn_fn
|
45 |
-
model.config.use_flash_attention = True
|
46 |
-
except ImportError:
|
47 |
-
print("Flash Attention is not available. Proceeding without it.")
|
48 |
-
|
49 |
-
# Note: torch.compile is not compatible with Flash Attention or the chunked long-form algorithm.
|
50 |
-
|
51 |
-
# Processor setup
|
52 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
53 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
|
54 |
|
55 |
# Pipeline setup
|
56 |
pipe = pipeline(
|
57 |
task="automatic-speech-recognition",
|
58 |
-
model=
|
59 |
-
|
60 |
-
|
61 |
-
chunk_length_s=30, # 30 seconds
|
62 |
)
|
63 |
|
64 |
def reset_and_update_dataset(new_data):
|
|
|
25 |
BATCH_SIZE = 8 # Optimized for better GPU utilization
|
26 |
YT_LENGTH_LIMIT_S = 10800 # 3 hours
|
27 |
DATASET_NAME = "dwb2023/yt-transcripts-v3"
|
28 |
+
FILE_LIMIT_MB = 1000
|
29 |
|
30 |
# Environment setup
|
31 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
32 |
|
33 |
+
device = 0 if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
# Pipeline setup
|
36 |
pipe = pipeline(
|
37 |
task="automatic-speech-recognition",
|
38 |
+
model=MODEL_NAME,
|
39 |
+
chunk_length_s=30,
|
40 |
+
device=device,
|
|
|
41 |
)
|
42 |
|
43 |
def reset_and_update_dataset(new_data):
|