dwb2023 committed on
Commit
4c90570
1 Parent(s): 793d15e

yadda yadda... getting rid of flash attention for now

Browse files
Files changed (1) hide show
  1. app.py +5 -26
app.py CHANGED
@@ -25,40 +25,19 @@ MODEL_NAME = "openai/whisper-large-v3-turbo"
25
  BATCH_SIZE = 8 # Optimized for better GPU utilization
26
  YT_LENGTH_LIMIT_S = 10800 # 3 hours
27
  DATASET_NAME = "dwb2023/yt-transcripts-v3"
 
28
 
29
  # Environment setup
30
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
31
 
32
- # Model setup
33
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
34
- MODEL_NAME,
35
- use_cache=False,
36
- device_map="auto",
37
- low_cpu_mem_usage=True,
38
- attn_implementation="flash_attention_2",
39
- torch_dtype=torch.bfloat16
40
- )
41
-
42
- # Flash Attention setup for memory and speed optimization if supported
43
- try:
44
- from flash_attn import flash_attn_fn
45
- model.config.use_flash_attention = True
46
- except ImportError:
47
- print("Flash Attention is not available. Proceeding without it.")
48
-
49
- # Note: torch.compile is not compatible with Flash Attention or the chunked long-form algorithm.
50
-
51
- # Processor setup
52
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
53
- feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
54
 
55
  # Pipeline setup
56
  pipe = pipeline(
57
  task="automatic-speech-recognition",
58
- model=model,
59
- tokenizer=tokenizer,
60
- feature_extractor=feature_extractor,
61
- chunk_length_s=30, # 30 seconds
62
  )
63
 
64
  def reset_and_update_dataset(new_data):
 
25
  BATCH_SIZE = 8 # Optimized for better GPU utilization
26
  YT_LENGTH_LIMIT_S = 10800 # 3 hours
27
  DATASET_NAME = "dwb2023/yt-transcripts-v3"
28
+ FILE_LIMIT_MB = 1000
29
 
30
  # Environment setup
31
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
32
 
33
+ device = 0 if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # Pipeline setup
36
  pipe = pipeline(
37
  task="automatic-speech-recognition",
38
+ model=MODEL_NAME,
39
+ chunk_length_s=30,
40
+ device=device,
 
41
  )
42
 
43
  def reset_and_update_dataset(new_data):