# Spaces: Runtime error
import logging
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Generator

import boto3
from pydantic import BaseModel, Field, TypeAdapter
from tqdm import tqdm

from .utils import data_dir, env_str
# Logger for this module; records are tagged with the module's dotted name.
log = logging.getLogger(name=__name__)
@dataclass
class S3VoiceObj:
    """One S3 object found under the voices key prefix.

    Built from a single entry of a ``list_objects_v2`` ``Contents`` list.
    """

    key: str  # full S3 object key
    size: int  # object size in bytes, as reported by S3

    @property
    def name(self) -> str:
        """Return the last ``/``-separated component of ``key``.

        This is the empty string when the key ends with ``/`` (e.g. a
        zero-byte "folder" placeholder object).
        """
        return self.key.split("/")[-1]

    @classmethod
    def from_s3_obj(cls, obj: Any) -> "S3VoiceObj":
        """Build an instance from a ``list_objects_v2`` ``Contents`` entry.

        Only the ``Key`` and ``Size`` fields are read; a missing field raises
        ``KeyError``.
        """
        return cls(key=obj["Key"], size=obj["Size"])
class Voice(BaseModel):
    """A voice definition, parsed from one ``*.json`` file in the voices directory."""

    # Key under which the voice is stored in ``VoiceManager.voices``.
    name: str
    # Name of the TTS voice to use; voices whose ``tts`` is not found in
    # ``VoiceManager.tts_voices`` are skipped (with a warning) when loading.
    tts: str
class TTSVoice(BaseModel):
    """One entry of the TTS voice catalogue JSON read by ``VoiceManager.tts_voices``."""

    # Validated from the JSON key "ShortName" (the field alias); other keys in
    # the entry are not modelled here (ignored under pydantic's default config).
    name: str = Field(alias="ShortName")
class VoiceManager:
    """Mirror voice files from an S3 prefix into a local directory and load them.

    The bucket and key prefix come from the ``BUCKET`` and
    ``VOICES_KEY_PREFIX`` settings (via ``env_str``); files are stored in
    ``data_dir("voices")``.
    """

    # TTS voice catalogue; a relative path, so it is resolved against the
    # current working directory. Override on a subclass/instance if needed.
    tts_voices_path = Path("rvc/lib/tools/tts_voices.json")

    def __init__(self) -> None:
        self.s3 = boto3.client("s3")
        self.bucket = env_str("BUCKET")
        self.prefix = env_str("VOICES_KEY_PREFIX")
        self.voices_dir = Path(data_dir("voices"))

    def _iter_s3_objects(self) -> Generator[S3VoiceObj, None, None]:
        """Yield every downloadable object stored under ``self.prefix``.

        Uses the ``list_objects_v2`` paginator: a single call returns at most
        1000 keys, so larger prefixes would otherwise be silently truncated.
        Keys ending in ``/`` ("folder" placeholders) are skipped because they
        have no file name and would map onto ``voices_dir`` itself.
        """
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            for raw in page.get("Contents", []):
                obj = S3VoiceObj.from_s3_obj(raw)
                if obj.name:
                    yield obj

    def _needs_download(self, obj: S3VoiceObj, destination_path: Path) -> bool:
        """Return True if ``destination_path`` is missing or its size differs from ``obj``."""
        return not destination_path.exists() or destination_path.stat().st_size != obj.size

    def get_voices_size_if_missing(self) -> int:
        """
        Calculate the total size of the voice files only if they do not exist locally.

        Returns the number of bytes ``download_voice_files`` would fetch.
        Side effect: entries in ``voices_dir`` that no longer correspond to an
        S3 object are deleted (directories are left untouched).
        """
        total_size = 0
        expected: set[Path] = set()
        for obj in self._iter_s3_objects():
            destination_path = self.voices_dir / obj.name
            expected.add(destination_path)
            if self._needs_download(obj, destination_path):
                total_size += obj.size
        # Snapshot the listing before deleting from the directory we iterate.
        for path in list(self.voices_dir.glob("*")):
            # Path.unlink() cannot remove directories, so skip them.
            if path not in expected and not path.is_dir():
                path.unlink()
        return total_size

    def download_voice_files(self, progress_bar: tqdm) -> None:
        """
        Download all voice files from s3 updating the global progress bar.

        Only objects whose local copy is missing or differs in size are
        fetched; ``progress_bar`` advances by the byte counts boto3 reports.
        """

        def callback(bytes_amount: int) -> None:
            progress_bar.update(bytes_amount)

        for obj in self._iter_s3_objects():
            destination_path = self.voices_dir / obj.name
            if not self._needs_download(obj, destination_path):
                continue
            # boto3 documents Filename as str; older releases reject Path objects.
            self.s3.download_file(
                Bucket=self.bucket,
                Key=obj.key,
                Filename=str(destination_path),
                Callback=callback,
            )
            # Hide the bar while logging so the two outputs don't interleave.
            progress_bar.clear()
            log.info("Downloaded %s bytes to %s", f"{obj.size:,}", destination_path)
            progress_bar.display()

    @cached_property
    def tts_voices(self) -> dict[str, TTSVoice]:
        """Map each TTS voice's ``ShortName`` to its catalogue entry (parsed once, then cached)."""
        voices = TypeAdapter(list[TTSVoice]).validate_json(self.tts_voices_path.read_bytes())
        return {v.name: v for v in voices}

    @property
    def voice_names(self) -> list[str]:
        """Names of the currently valid local voices, in ``voices`` order."""
        return list(self.voices)

    @property
    def voices(self) -> dict[str, Voice]:
        """Load every ``*.json`` voice definition in ``voices_dir``.

        Files are read in sorted path order on every access, so newly
        downloaded voices are picked up. Voices whose ``tts`` is not in
        ``tts_voices`` are skipped with a warning.
        """
        rv: dict[str, Voice] = {}
        for path in sorted(self.voices_dir.glob("*.json")):
            voice = Voice.model_validate_json(path.read_bytes())
            if voice.tts not in self.tts_voices:
                log.warning("Voice %s references invalid tts %s", voice.name, voice.tts)
                continue
            rv[voice.name] = voice
        return rv
# Shared module-level instance, created at import time: constructing it creates
# the boto3 S3 client and reads the BUCKET / VOICES_KEY_PREFIX settings.
voice_manager: VoiceManager = VoiceManager()