File size: 8,635 Bytes
824afbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f90f1b
824afbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import gradio as gr
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import matplotlib.pyplot as plt
from utils import load_ckpt, print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config
from typing import Tuple
import numpy as np
import os

# Process-wide cache for the model and tokenizer. Both stay None until
# init_model() builds them and are reused by every later call.
global_generator = None
global_tokenizer = None
# Audio file pre-loaded into the input widget when it exists on disk.
default_audio_path = "sample.wav"  # Changed from "testingtesting.wav"

def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Build the generator model and audio tokenizer once and cache them.

    The first successful call constructs both objects and stores them in the
    module-level ``global_generator`` / ``global_tokenizer``. Every later call
    returns the cached pair without rebuilding anything.

    Args:
        use_pure_audio_ablation: Forwarded to ``get_hertz_dev_config``. Only
            the call that actually builds the model honours it; once a model
            is cached the flag is ignored and the cached model is returned.

    Returns:
        The ``(generator, tokenizer)`` pair.
    """
    global global_generator, global_tokenizer

    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    if device == 'cuda':
        T.cuda.set_device(0)

    print_colored("Initializing model and tokenizer...", "blue")
    tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)
    generator = model_config().eval().to(T.bfloat16).to(device)

    # Publish both objects only after both were built, so a failure part-way
    # through never leaves the cache half-initialised.
    global_generator, global_tokenizer = generator, tokenizer
    print_colored("Model initialization complete!", "green")

    return global_generator, global_tokenizer

def process_audio(audio_path: str, sr: int) -> T.Tensor:
    """Load an audio file as a mono, resampled, length-capped batch of one.

    Args:
        audio_path: Path to a file readable by ``torchaudio.load``.
        sr: Target sample rate in Hz. The file is resampled to this rate when
            its native rate differs.

    Returns:
        A tensor of shape ``(1, 1, num_samples)`` holding at most five
        minutes of audio at ``sr``.
    """
    # Keep the file's native rate in its own name so the requested target
    # rate `sr` is not overwritten.
    audio_tensor, file_sr = torchaudio.load(audio_path)

    # Downmix any multi-channel input (stereo, 5.1, ...) to a single channel.
    if audio_tensor.shape[0] > 1:
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)

    if file_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=file_sr, new_freq=sr)
        audio_tensor = resampler(audio_tensor)

    max_samples = sr * 60 * 5  # 5 minutes
    if audio_tensor.shape[1] > max_samples:
        audio_tensor = audio_tensor[:, :max_samples]

    # Add the leading batch dimension.
    return audio_tensor.unsqueeze(0)

def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True)
) -> list:
    """Generate several audio continuations of the start of an input file.

    Args:
        audio_file: Path of the prompt audio file.
        prompt_len_seconds: How much of the file's start is used as prompt.
        num_completions: Number of independent continuations to sample.
        generation_seconds: Length of each continuation.
        token_temp: First element of the ``temps`` tuple passed to
            ``generator.completion``.
        categorical_temp: Categorical temperature (second ``temps`` element).
        gaussian_temp: Gaussian temperature (second ``temps`` element).
        progress: Gradio progress tracker used for status reporting.

    Returns:
        A list of ``(16000, samples)`` tuples, one per completion, where
        ``samples`` is a float NumPy array laid out ``(num_samples, channels)``.

    Raises:
        gr.Error: If no audio file was supplied.
    """
    if not audio_file:
        raise gr.Error("Please provide an input audio file.")

    sample_rate = 16000
    latents_per_second = 8
    samples_per_latent = sample_rate // latents_per_second  # 2000

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    # Sliders may deliver the count as a float; range() needs an int.
    num_completions = int(num_completions)

    # Returns the cached pair, building it first if startup did not.
    generator, audio_tokenizer = init_model()

    progress(0, desc="Processing input audio...")
    prompt_audio = process_audio(audio_file, sr=sample_rate)
    prompt_len = int(prompt_len_seconds * latents_per_second)
    gen_len = int(generation_seconds * latents_per_second)

    progress(0.2, desc="Encoding prompt...")
    # Autocast must target the device actually in use, not a fixed 'cuda'.
    with T.no_grad(), T.autocast(device_type=device, dtype=T.bfloat16):
        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))

    # The prompt is the same for every completion, so slice it once.
    encoded_prompt = encoded_prompt_audio[:, :prompt_len]

    # Drop the prompt from the output except for its final second. Use the
    # prompt length actually available: a short input yields fewer latents
    # than requested, and trimming by the requested length would cut into the
    # generated audio.
    # NOTE(review): assumes the decoded audio starts with the prompt - confirm.
    trim_start = max(encoded_prompt.shape[1] * samples_per_latent - sample_rate, 0)

    completions = []
    for i in range(num_completions):
        # Report the work finished so far; the bar reaches 100% only at the end.
        progress(
            0.2 + 0.8 * i / num_completions,
            desc=f"Generating completion {i+1}/{num_completions}"
        )

        with T.no_grad(), T.autocast(device_type=device, dtype=T.bfloat16):
            completed_audio_batch = generator.completion(
                encoded_prompt,
                temps=(token_temp, (categorical_temp, gaussian_temp)),
                use_cache=True,
                gen_len=gen_len
            )
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())

        # Bring the audio to a float (channels, samples) CPU tensor.
        audio_tensor = decoded_completion.cpu().squeeze()
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.float()

        # Peak-normalise only when the signal would clip.
        peak = audio_tensor.abs().max()
        if peak > 1:
            audio_tensor = audio_tensor / peak

        output_audio = audio_tensor[:, trim_start:]
        completions.append((sample_rate, output_audio.numpy().T))

    progress(1.0, desc="Generation complete!")
    return completions

def create_interface():
    """Build the Gradio Blocks UI for the audio completion generator.

    Loads the model up front, then lays out the input controls, a fixed pool
    of output audio players and the event wiring between them.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    # Initialize model at startup
    init_model()

    # Size of the output player pool; also the slider's upper bound.
    max_completions = 10
    default_num_completions = 5

    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown("""
        # Audio Completion Generator
        Upload an audio file (or use the default) and generate AI completions based on the prompt.
        """)

        with gr.Row():
            with gr.Column():
                # Load the default audio if it exists
                default_value = default_audio_path if os.path.exists(default_audio_path) else None

                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value
                )

                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=0.5,
                        label="Prompt Length (seconds)"
                    )
                    num_completions = gr.Slider(
                        minimum=1,
                        maximum=max_completions,
                        value=default_num_completions,
                        step=1,
                        label="Number of Completions"
                    )
                    gen_length = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=20,
                        step=5,
                        label="Generation Length (seconds)"
                    )

                with gr.Row():
                    token_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Token Temperature"
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Categorical Temperature"
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Gaussian Temperature"
                    )

                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")

            with gr.Column():
                # A fixed pool of players; visibility is toggled to match the
                # requested number of completions.
                output_audios = [
                    gr.Audio(
                        label=f"Generated Completion {i+1}",
                        type="numpy",
                        visible=False
                    )
                    for i in range(max_completions)
                ]

        def update_visibility(num):
            """Show the first ``num`` output players and hide the rest."""
            return [gr.update(visible=(i < num)) for i in range(max_completions)]

        def generate_with_status(
            audio_file,
            prompt_seconds,
            completion_count,
            gen_seconds,
            token_t,
            cat_t,
            gauss_t,
            progress=gr.Progress(track_tqdm=True),
        ):
            """Run generation, streaming status text alongside the audio.

            Status must be sent through an output component; assigning to
            ``status_text.value`` inside a handler never reaches the browser.
            The explicit ``progress`` parameter lets Gradio inject a live
            tracker, which is then forwarded to ``generate_completion``.
            """
            unchanged = [gr.update() for _ in range(max_completions)]
            yield unchanged + ["Processing input audio..."]

            try:
                completions = generate_completion(
                    audio_file,
                    prompt_seconds,
                    int(completion_count),
                    gen_seconds,
                    token_t,
                    cat_t,
                    gauss_t,
                    progress=progress,
                )
            except Exception:
                # Do not leave the status stuck on "Processing..."; surface
                # the failure and let Gradio report the error itself.
                yield unchanged + ["Generation failed."]
                raise

            # Pad with None so every player in the pool receives a value.
            padded = list(completions[:max_completions])
            padded += [None] * (max_completions - len(padded))
            yield padded + ["Generation complete!"]

        # Set initial visibility on load
        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios
        )

        # Update visibility when slider changes
        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios
        )

        generate_btn.click(
            fn=generate_with_status,
            inputs=[
                audio_input,
                prompt_len,
                num_completions,
                gen_length,
                token_temp,
                cat_temp,
                gauss_temp
            ],
            outputs=output_audios + [status_text]
        )

    return app

if __name__ == "__main__":
    # Script entry point: build the interface, then serve it with sharing on.
    create_interface().launch(share=True)