"""Mirror voice files from an S3 prefix into the local data directory and load the voice definitions."""

import logging
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Generator

import boto3
from pydantic import BaseModel, Field, TypeAdapter
from tqdm import tqdm

from .utils import data_dir, env_str

log = logging.getLogger(__name__)


@dataclass
class S3VoiceObj:
    """A single object listed under the voices prefix in S3."""

    key: str  # full S3 object key
    size: int  # object size as reported by S3; compared against the local file's st_size

    @property
    def name(self) -> str:
        """Return the last ``/``-separated component of the key (used as the local file name)."""
        return self.key.split("/")[-1]

    @classmethod
    def from_s3_obj(cls, obj: Any) -> "S3VoiceObj":
        """Build an instance from one entry of a ``list_objects_v2`` ``Contents`` list."""
        return cls(key=obj["Key"], size=obj["Size"])


class Voice(BaseModel):
    """A voice definition parsed from a ``*.json`` file in the voices directory."""

    name: str
    # ShortName of the TTS voice this voice uses; must be a key of VoiceManager.tts_voices.
    tts: str


class TTSVoice(BaseModel):
    """An entry of ``tts_voices.json``; only its ``ShortName`` field is kept (as ``name``)."""

    name: str = Field(alias="ShortName")


class VoiceManager:
    """Keep a local copy of the voice files stored in S3 and load the voices from it."""

    def __init__(self) -> None:
        self.s3 = boto3.client("s3")
        self.bucket = env_str("BUCKET")
        self.prefix = env_str("VOICES_KEY_PREFIX")
        self.voices_dir = Path(data_dir("voices"))

    def _iter_s3_objects(self) -> Generator[S3VoiceObj, None, None]:
        """Yield every file object stored under ``self.prefix`` in ``self.bucket``.

        A single ``list_objects_v2`` call returns at most 1000 keys, so the
        listing is paginated. Keys ending in ``/`` (folder placeholders) have no
        file name and are skipped; otherwise they would map onto the voices
        directory itself.
        """
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            for obj in page.get("Contents", []):
                voice_obj = S3VoiceObj.from_s3_obj(obj)
                if voice_obj.name:
                    yield voice_obj

    @staticmethod
    def _needs_download(obj: S3VoiceObj, destination_path: Path) -> bool:
        """Return True when the local copy is missing or its size differs from the S3 object."""
        try:
            return destination_path.stat().st_size != obj.size
        except FileNotFoundError:
            return True

    def get_voices_size_if_missing(self) -> int:
        """
        Calculate the total size of the voice files only if they do not exist locally.

        A file counts as missing when it is absent locally or its size differs
        from the S3 object.

        Side effect: every non-directory entry in the voices directory that does
        not correspond to an S3 object is deleted, so that only files present
        in S3 remain.

        Returns:
            The number of bytes ``download_voice_files`` would fetch.
        """
        total_size = 0
        expected_paths: set[Path] = set()
        for obj in self._iter_s3_objects():
            destination_path = self.voices_dir / obj.name
            expected_paths.add(destination_path)
            if self._needs_download(obj, destination_path):
                total_size += obj.size

        for path in self.voices_dir.glob("*"):
            # Directories are skipped: Path.unlink() cannot remove them and would raise.
            if path not in expected_paths and not path.is_dir():
                path.unlink()
        return total_size

    def download_voice_files(self, progress_bar: tqdm) -> None:
        """
        Download all voice files from s3 updating the global progress bar.

        Only files that are missing locally or whose size differs from the S3
        object are fetched. ``progress_bar`` is advanced by the byte counts that
        boto3 reports through its transfer callback.
        """

        def callback(bytes_amount: int) -> None:
            progress_bar.update(bytes_amount)

        # download_file does not create missing parent directories.
        self.voices_dir.mkdir(parents=True, exist_ok=True)
        for obj in self._iter_s3_objects():
            destination_path = self.voices_dir / obj.name
            if not self._needs_download(obj, destination_path):
                continue
            # str() keeps this working with boto3 versions that only accept str filenames.
            self.s3.download_file(
                Bucket=self.bucket,
                Key=obj.key,
                Filename=str(destination_path),
                Callback=callback,
            )
            # Hide the bar while logging so the message is not drawn over it, then redraw it.
            progress_bar.clear()
            log.info(f"Downloaded {obj.size:,} bytes to {destination_path}")
            progress_bar.display()

    @cached_property
    def tts_voices(self) -> dict[str, TTSVoice]:
        """TTS voices listed in ``rvc/lib/tools/tts_voices.json``, keyed by ShortName.

        NOTE(review): the path is relative to the current working directory.
        """
        path = Path("rvc/lib/tools/tts_voices.json")
        voices = TypeAdapter(list[TTSVoice]).validate_json(path.read_bytes())
        return {v.name: v for v in voices}

    @property
    def voice_names(self) -> list[str]:
        """Names of all valid local voices, in the same order as ``voices``."""
        return list(self.voices)

    @cached_property
    def voices(self) -> dict[str, Voice]:
        """Load the local voice definitions, keyed by voice name.

        Every ``*.json`` file in the voices directory is parsed in sorted path
        order. A voice whose ``tts`` is not a key of ``tts_voices`` is skipped
        with a warning. The result is cached on first access, so files added
        to the directory afterwards are not picked up by this instance.
        """
        rv: dict[str, Voice] = {}
        for path in sorted(self.voices_dir.glob("*.json")):
            voice = Voice.model_validate_json(path.read_bytes())
            if voice.tts not in self.tts_voices:
                log.warning("Voice %s references invalid tts %s", voice.name, voice.tts)
            else:
                rv[voice.name] = voice
        return rv


voice_manager = VoiceManager()