File size: 3,411 Bytes
1378843
 
 
 
 
 
 
 
 
 
 
 
2c01ee6
 
1378843
 
c6fd5b2
1378843
 
 
 
 
 
 
 
c6fd5b2
 
1378843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6fd5b2
1378843
c6fd5b2
1378843
 
c6fd5b2
 
 
 
1378843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c01ee6
 
 
1378843
 
 
c6fd5b2
1378843
 
 
 
 
 
 
 
 
 
2c01ee6
1378843
571ece7
 
1378843
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Generator

import boto3
from pydantic import BaseModel, Field, TypeAdapter
from tqdm import tqdm

from .utils import data_dir, env_str

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


@dataclass
class S3Obj:
    """Minimal view of one entry from an S3 object listing.

    Attributes:
        key: Full object key, including any prefix.
        size: Object size as reported by the listing entry.
    """

    key: str
    size: int

    @property
    def name(self) -> str:
        """Return the last ``/``-separated segment of the key.

        A key ending in ``/`` yields an empty string.
        """
        return self.key.rsplit("/", 1)[-1]

    @classmethod
    def from_s3_obj(cls, obj: Any) -> "S3Obj":
        """Build an instance from a listing entry.

        Args:
            obj: Mapping that provides ``"Key"`` and ``"Size"`` items.

        Returns:
            A new instance of the class the method was called on, so
            subclasses get instances of their own type.

        Raises:
            KeyError: If ``obj`` lacks ``"Key"`` or ``"Size"``.
        """
        # Use ``cls`` rather than the hard-coded class name so that
        # subclasses calling this constructor are not downgraded to S3Obj.
        return cls(key=obj["Key"], size=obj["Size"])


# A voice definition; VoiceManager.voices parses one of these from each
# JSON file in the local voices directory.
class Voice(BaseModel):
    # Name of the voice; VoiceManager.voices uses it as the dictionary key.
    name: str
    # Name of the TTS voice this voice refers to; VoiceManager.voices checks
    # it against the keys of VoiceManager.tts_voices.
    tts: str


# One entry of the TTS voice list that VoiceManager.tts_voices loads from
# tts_voices.json.
class TTSVoice(BaseModel):
    # Populated from the "ShortName" field of the input (pydantic alias).
    name: str = Field(alias="ShortName")


class VoiceManager:
    """Mirror voice definition files from S3 and expose them as models.

    The ``.json`` objects listed under ``self.prefix`` in ``self.bucket`` are
    mirrored into ``self.voices_dir``, using the last path segment of each
    object key as the local file name.
    """

    def __init__(self) -> None:
        """Create the S3 client and resolve bucket, key prefix and local directory.

        NOTE(review): ``env_str`` and ``data_dir`` come from ``.utils``;
        presumably they read an environment variable and return a ``Path``
        to a data sub-directory respectively -- confirm against that module.
        """
        self.s3 = boto3.client("s3")
        self.bucket = env_str("BUCKET")
        self.prefix = env_str("VOICES_KEY_PREFIX")
        self.voices_dir = data_dir("voices")

    def _iter_s3_objects(self) -> Generator[S3Obj, None, None]:
        """Yield every ``.json`` object found under the configured prefix.

        The listing is paginated: a single ``list_objects_v2`` call returns
        only one page of results, so the paginator is used to follow the
        continuation tokens until the listing is exhausted.
        """
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            # "Contents" is absent from a page that holds no objects.
            for entry in page.get("Contents", []):
                s3_obj = S3Obj.from_s3_obj(entry)
                if s3_obj.name.endswith(".json"):
                    yield s3_obj

    @staticmethod
    def _needs_download(obj: S3Obj, destination_path: Path) -> bool:
        """Return True when the local copy is missing or its size differs from S3."""
        return not destination_path.exists() or destination_path.stat().st_size != obj.size

    def get_voices_size_if_missing(self) -> int:
        """
        Calculate the total size of the voice files that still need downloading.

        A file counts as missing when it does not exist locally or when its
        local size differs from the size reported by S3.

        Side effect: every entry directly inside the voices directory that
        does not correspond to a listed S3 object is deleted.

        Returns:
            Sum of the sizes of all objects that need to be downloaded.
        """
        total_size = 0
        expected_paths: set[Path] = set()
        for obj in self._iter_s3_objects():
            destination_path = self.voices_dir / obj.name
            expected_paths.add(destination_path)
            if self._needs_download(obj, destination_path):
                total_size += obj.size

        # Prune local files that no longer have a counterpart in S3.
        for path in self.voices_dir.glob("*"):
            if path not in expected_paths:
                path.unlink()

        return total_size

    def download_voice_files(self, progress_bar: tqdm) -> None:
        """
        Download all missing or size-mismatched voice files from S3.

        Args:
            progress_bar: Progress bar advanced by the number of bytes
                transferred as each download progresses.
        """

        def callback(bytes_amount: int) -> None:
            progress_bar.update(bytes_amount)

        for obj in self._iter_s3_objects():
            destination_path = self.voices_dir / obj.name
            if not self._needs_download(obj, destination_path):
                continue
            # Pass the file name as ``str``: a plain string is accepted by
            # every boto3 release, unlike a ``Path`` object.
            self.s3.download_file(
                Bucket=self.bucket,
                Key=obj.key,
                Filename=str(destination_path),
                Callback=callback,
            )
            # Clear the bar before logging and redraw it afterwards.
            progress_bar.clear()
            log.info(f"Downloaded {obj.size:,} bytes to {destination_path}")
            progress_bar.display()

    @cached_property
    def tts_voices(self) -> dict[str, TTSVoice]:
        """Return the TTS voices from ``tts_voices.json``, keyed by name.

        The file is read from the directory containing this module; the
        result is computed once per instance and then cached.
        """
        path = Path(__file__).parent / "tts_voices.json"
        voices = TypeAdapter(list[TTSVoice]).validate_json(path.read_bytes())
        return {v.name: v for v in voices}

    @property
    def voice_names(self) -> list[str]:
        """Return the names of all valid voices, in the order of ``voices``."""
        return list(self.voices)

    @cached_property
    def voices(self) -> dict[str, Voice]:
        """Return the locally available voices, keyed by voice name.

        Every ``*.json`` file in the voices directory is parsed in sorted
        path order. Voices whose ``tts`` is not a known TTS voice are logged
        and skipped.

        The result is cached on first access, so files downloaded afterwards
        are not picked up by the same instance.
        """
        rv: dict[str, Voice] = {}
        for path in sorted(self.voices_dir.glob("*.json")):
            voice = Voice.model_validate_json(path.read_bytes())
            if voice.tts not in self.tts_voices:
                # Use the module logger (not the root logger) so the record
                # carries this module's name like the other log calls here.
                log.warning("Voice %s references invalid tts %s", voice.name, voice.tts)
            else:
                rv[voice.name] = voice
        return rv


# Shared module-level instance. NOTE: it is constructed at import time, so
# importing this module runs VoiceManager.__init__ (creates the boto3 S3
# client and calls env_str / data_dir).
voice_manager = VoiceManager()