File size: 3,717 Bytes
1378843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Generator

import boto3
from pydantic import BaseModel, Field, TypeAdapter
from tqdm import tqdm

from .utils import data_dir, env_str


@dataclass
class S3VoiceObj:
    """A single object returned by an S3 listing: its full key and its size in bytes."""

    key: str
    size: int

    @property
    def name(self) -> str:
        """
        Return the last "/"-separated segment of the key, used as the local file name.

        Keys ending with "/" produce an empty string.
        """
        return self.key.rsplit("/", 1)[-1]

    @classmethod
    def from_s3_obj(cls, obj: Any) -> "S3VoiceObj":
        """
        Build an instance from one entry of a ``list_objects_v2`` ``Contents`` list.

        :param obj: mapping with at least the ``"Key"`` and ``"Size"`` entries.
        :returns: an instance of ``cls``, so subclasses construct their own type.
        """
        return cls(key=obj["Key"], size=obj["Size"])


class Voice(BaseModel):
    """
    Configuration of one voice, parsed from a ``*.json`` file in the voices directory.

    Only ``name``, ``model`` and ``tts`` are interpreted by this module. The remaining
    fields are optional settings with defaults, and their meaning is defined by
    whatever consumes this model (not visible here).
    """

    # Key under which the voice is registered in VoiceManager.voices.
    name: str
    # Model file name relative to the voices directory; VoiceManager.voices
    # checks that it exists and then replaces it with the full local path.
    model: str
    # Must match a TTSVoice short name from tts_voices.json, otherwise the voice is skipped.
    tts: str
    # NOTE(review): presumably an index file name, with "" meaning "none" — confirm with the consumer.
    index: str = ""
    # Optional tuning parameters below; semantics are defined downstream (TODO document).
    autotune: float | None = None
    clean: float | None = 0.5
    upscale: bool = False
    pitch: int = 0
    filter_radius: int = 3
    index_rate: float = 0.75
    rms_mix_rate: float = 1
    protect: float = 0.5
    hop_length: int = 128
    f0_method: str = "rmvpe"
    embedder_model: str = "contentvec"


class TTSVoice(BaseModel):
    """One entry of the TTS voice list JSON; only the voice's short name is modeled."""

    # Populated from the JSON "ShortName" key through the alias; required.
    name: str = Field(alias="ShortName")


class VoiceManager:
    """
    Mirror voice files from an S3 prefix into the local voices directory and load
    the voice configurations found there.
    """

    def __init__(self) -> None:
        self.s3 = boto3.client("s3")
        # Source location comes from the BUCKET and VOICES_KEY_PREFIX settings.
        self.bucket = env_str("BUCKET")
        self.prefix = env_str("VOICES_KEY_PREFIX")
        self.voices_dir = Path(data_dir("voices"))

    def _iter_s3_objects(self) -> Generator[S3VoiceObj, None, None]:
        """
        Yield every object stored under ``self.prefix`` in ``self.bucket``.

        A single ``list_objects_v2`` call returns at most 1000 keys, so the listing
        is paginated. Keys ending in "/" (e.g. "folder" placeholder objects) have no
        file name and are skipped, because they would map onto the voices directory itself.
        """
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            for raw in page.get("Contents", []):
                obj = S3VoiceObj.from_s3_obj(raw)
                if obj.name:
                    yield obj

    def _local_path(self, obj: S3VoiceObj) -> Path:
        """Return where ``obj`` is stored locally: its file name inside the voices directory."""
        return self.voices_dir / obj.name

    def _needs_download(self, obj: S3VoiceObj) -> bool:
        """Return True when the local copy of ``obj`` is absent or its size differs from S3."""
        path = self._local_path(obj)
        return not path.exists() or path.stat().st_size != obj.size

    def get_voices_size_if_missing(self) -> int:
        """
        Calculate the total size of the voice files only if they do not exist locally.

        Objects whose local copy is missing or has a different size are counted.
        Side effect: every non-directory entry in the voices directory that no longer
        corresponds to an S3 object is deleted.

        :returns: the number of bytes still to be downloaded.
        """
        total_size = 0
        expected: set[Path] = set()
        for obj in self._iter_s3_objects():
            expected.add(self._local_path(obj))
            if self._needs_download(obj):
                total_size += obj.size

        # Prune stale local files. Directories are skipped because Path.unlink() cannot remove them.
        for path in self.voices_dir.glob("*"):
            if path not in expected and not path.is_dir():
                path.unlink()

        return total_size

    def download_voice_files(self, progress_bar: tqdm) -> None:
        """
        Download all missing or size-mismatched voice files from S3.

        :param progress_bar: advanced by the byte counts boto3 reports via its transfer callback.
        """
        for obj in self._iter_s3_objects():
            if self._needs_download(obj):
                self.s3.download_file(
                    Bucket=self.bucket,
                    Key=obj.key,
                    # str() keeps this working on boto3 versions that only accept string filenames.
                    Filename=str(self._local_path(obj)),
                    Callback=progress_bar.update,
                )

    @cached_property
    def tts_voices(self) -> dict[str, TTSVoice]:
        """
        Known TTS voices read from ``rvc/lib/tools/tts_voices.json``, keyed by short name.

        The path is relative, so it is resolved against the current working directory.
        Computed once per instance.
        """
        path = Path("rvc/lib/tools/tts_voices.json")
        voices = TypeAdapter(list[TTSVoice]).validate_json(path.read_bytes())
        return {v.name: v for v in voices}

    @property
    def voice_names(self) -> list[str]:
        """Names of all successfully loaded voices."""
        return list(self.voices.keys())

    @cached_property
    def voices(self) -> dict[str, Voice]:
        """
        Load every ``*.json`` voice config from the voices directory, keyed by voice name.

        A voice is skipped, with a warning, if its model file is missing or its ``tts``
        is not a known TTS voice. For accepted voices, ``model`` is rewritten to the
        full local path of the model file. Computed once per instance.
        """
        rv: dict[str, Voice] = {}
        for path in self.voices_dir.glob("*.json"):
            voice = Voice.model_validate_json(path.read_bytes())
            model_path = self.voices_dir / voice.model
            if not model_path.exists():
                logging.warning("Voice %s missing model %s", voice.name, voice.model)
            elif voice.tts not in self.tts_voices:
                # Report the tts value that failed the lookup, not the model.
                logging.warning("Voice %s references invalid tts %s", voice.name, voice.tts)
            else:
                voice.model = str(model_path)
                rv[voice.name] = voice
        return rv


# Shared module-level instance. Constructing it at import time creates the boto3
# S3 client and reads the BUCKET / VOICES_KEY_PREFIX settings.
voice_manager = VoiceManager()