import sys
from io import BytesIO

import numpy as np
import soundfile as sf
from pydub import AudioSegment, effects
import pyrubberband as pyrb

# Largest value a signed 16-bit PCM sample can hold (32767), as a plain int.
INT16_MAX = int(np.iinfo(np.dtype("int16")).max)


def audio_to_int16(audio_data: np.ndarray) -> np.ndarray:
    """
    Convert floating-point audio (nominally in [-1.0, 1.0]) to int16 PCM.

    Floating-point input of any precision is scaled by the int16 maximum
    (32767), clipped to the int16 range and truncated toward zero. Input
    with a non-floating dtype is returned unchanged (the same object).

    Args:
        audio_data: sample array of any shape.

    Returns:
        An int16 array of the same shape, or ``audio_data`` itself when it is
        not floating point.
    """
    # issubdtype covers every float width without naming platform-specific
    # types such as np.float128, which does not exist on e.g. Windows.
    if not np.issubdtype(audio_data.dtype, np.floating):
        return audio_data

    limits = np.iinfo(np.int16)
    # Work in at least float32: float16 cannot represent 32767 exactly (it
    # rounds to 32768), which would overflow int16 for full-scale samples.
    work_dtype = np.promote_types(audio_data.dtype, np.float32)
    scaled = audio_data.astype(work_dtype) * limits.max
    # Saturate out-of-range samples instead of letting the cast wrap around.
    return np.clip(scaled, limits.min, limits.max).astype(np.int16)


def pydub_to_np(audio: AudioSegment) -> tuple[int, np.ndarray]:
    """
    Convert a pydub AudioSegment into a normalized float32 numpy array.

    Samples are divided by ``2 ** (8 * sample_width - 1)`` (e.g. 32768 for
    16-bit audio), so signed integer samples map to roughly [-1.0, 1.0].

    Args:
        audio: the segment to convert.

    Returns:
        A tuple ``(frame_rate, samples)``. ``samples`` has shape
        ``(n_frames,)`` for mono audio and ``(n_frames, channels)`` otherwise.
    """
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    if audio.channels != 1:
        # The flat sample array holds one value per channel per frame; fold it
        # into one row per frame.
        samples = samples.reshape((-1, audio.channels))
    full_scale = 1 << (8 * audio.sample_width - 1)
    samples = samples / full_scale

    return audio.frame_rate, samples


def audiosegment_to_librosawav(audiosegment: AudioSegment) -> np.ndarray:
    """
    Convert a pydub AudioSegment into a flat float32 numpy array.

    Integer samples are divided by the maximum of their integer type
    (e.g. 32767 for 16-bit audio), so values lie roughly in [-1.0, 1.0].

    NOTE: the result is always 1-D. For multi-channel audio the channels are
    interleaved frame by frame (``[c0, c1, c0, c1, ...]``), giving length
    ``n_frames * channels``; reshape with ``(-1, channels)`` to get one column
    per channel.
    """
    raw = audiosegment.get_array_of_samples()
    # pydub's sample array is already interleaved frame by frame, so it can be
    # converted directly without splitting and re-joining the channels.
    fp_arr = np.asarray(raw, dtype=np.float32)
    fp_arr /= np.iinfo(raw.typecode).max

    return fp_arr


def ndarray_to_segment(
    ndarray: np.ndarray, frame_rate: int, sample_width: int = None, channels: int = None
) -> AudioSegment:
    """
    Build a pydub AudioSegment from a numpy sample array.

    soundfile encodes the array as an in-memory 16-bit PCM WAV at
    ``frame_rate``, and pydub decodes it. ``sample_width`` and ``channels``
    default to whatever the decoded WAV has; when given, the result is
    converted to them.
    """
    wav_buffer = BytesIO()
    sf.write(wav_buffer, ndarray, frame_rate, format="wav", subtype="PCM_16")
    wav_buffer.seek(0)
    decoded: AudioSegment = AudioSegment.from_wav(wav_buffer)

    target_width = decoded.sample_width if sample_width is None else sample_width
    target_channels = decoded.channels if channels is None else channels

    at_rate = decoded.set_frame_rate(frame_rate)
    at_width = at_rate.set_sample_width(target_width)
    return at_width.set_channels(target_channels)


def apply_prosody_to_audio_segment(
    audio_segment: AudioSegment,
    rate: float = 1,
    volume: float = 0,
    pitch: int = 0,
    sr: int = 24000,
) -> AudioSegment:
    """
    Apply tempo, gain and pitch changes to a pydub AudioSegment.

    Args:
        audio_segment: input audio, at any frame rate and channel count.
        rate: time-stretch factor forwarded to ``apply_prosody_to_audio_data``.
        volume: sample multiplier forwarded to ``apply_prosody_to_audio_data``
            (0 means "leave unchanged", not "mute").
        pitch: pitch shift forwarded to ``apply_prosody_to_audio_data``.
        sr: frame rate used for processing and for the returned segment. The
            input is first resampled to ``sr`` when its frame rate differs.

    Returns:
        A new AudioSegment at ``sr`` Hz with the input's sample width and
        channel count.
    """
    channels = audio_segment.channels
    sample_width = audio_segment.sample_width

    # The processed samples are written back out at ``sr``. Resample first so
    # audio recorded at another rate is not merely relabelled, which would
    # change its speed and pitch.
    if audio_segment.frame_rate != sr:
        audio_segment = audio_segment.set_frame_rate(sr)

    audio_data = audiosegment_to_librosawav(audio_segment)
    if channels > 1:
        # Interleaved [c0, c1, c0, c1, ...] -> (frames, channels). Each channel
        # is then processed as its own signal rather than as one
        # double-length mono stream.
        audio_data = audio_data.reshape(-1, channels)

    audio_data = apply_prosody_to_audio_data(audio_data, rate, volume, pitch, sr)

    return ndarray_to_segment(audio_data, sr, sample_width, channels)


def apply_prosody_to_audio_data(
    audio_data: np.ndarray,
    rate: float = 1,
    volume: float = 0,
    pitch: float = 0,
    sr: int = 24000,
) -> np.ndarray:
    """
    Apply time-stretch, gain and pitch-shift to raw samples, in that order.

    Args:
        audio_data: samples handed to pyrubberband and/or scaled by ``volume``.
        rate: time-stretch factor for ``pyrb.time_stretch``; 1 skips it.
        volume: multiplier applied to the samples. 0 is treated as "no change";
            it does NOT mute the audio.
        pitch: value passed as ``n_steps`` to ``pyrb.pitch_shift``; 0 skips it.
        sr: sample rate forwarded to pyrubberband.

    Returns:
        The processed samples. When every stage is skipped, the input object
        itself is returned.
    """
    processed = audio_data

    # Each stage is a no-op at its neutral value, so only run what was requested.
    stretch_requested = rate != 1
    gain_requested = volume != 0
    shift_requested = pitch != 0

    if stretch_requested:
        processed = pyrb.time_stretch(processed, sr=sr, rate=rate)
    if gain_requested:
        processed = processed * volume
    if shift_requested:
        processed = pyrb.pitch_shift(processed, sr=sr, n_steps=pitch)

    return processed


def apply_normalize(
    audio_data: np.ndarray,
    headroom: float = 1,
    sr: int = 24000,
):
    """
    Peak-normalize audio samples with pydub's ``effects.normalize``.

    The samples pass through ``ndarray_to_segment`` (a 16-bit PCM round trip
    at ``sr``) before normalization.

    Args:
        audio_data: samples to normalize.
        headroom: forwarded unchanged to ``effects.normalize``.
        sr: frame rate used to build the intermediate segment.

    Returns:
        The ``(frame_rate, samples)`` tuple produced by ``pydub_to_np``.
    """
    normalized = effects.normalize(
        seg=ndarray_to_segment(audio_data, sr),
        headroom=headroom,
    )
    return pydub_to_np(normalized)


if __name__ == "__main__":
    # Manual check: write time-stretched and pitch-shifted variants of an MP3
    # file next to it, named e.g. "<input>_time_0.5.wav" / "<input>_pitch_5.wav".
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_MP3")

    input_file = sys.argv[1]

    time_stretch_factors = [0.5, 0.75, 1.5, 1.0]
    pitch_shift_factors = [-12, -5, 0, 5, 12]

    input_sound = AudioSegment.from_mp3(input_file)
    # Process at the file's own sample rate, so processing and output stay at
    # that rate instead of the helper's 24000 Hz default.
    input_sr = input_sound.frame_rate

    for time_factor in time_stretch_factors:
        output_wav = f"{input_file}_time_{time_factor}.wav"
        output_sound = apply_prosody_to_audio_segment(
            input_sound, rate=time_factor, sr=input_sr
        )
        output_sound.export(output_wav, format="wav")

    for pitch_factor in pitch_shift_factors:
        output_wav = f"{input_file}_pitch_{pitch_factor}.wav"
        output_sound = apply_prosody_to_audio_segment(
            input_sound, pitch=pitch_factor, sr=input_sr
        )
        output_sound.export(output_wav, format="wav")