File size: 3,638 Bytes
5352a13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
This module handles the management of the database.
"""
from dataclasses import dataclass
import enum
import os
from typing import List

import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
from google.cloud.firestore_v1 import base_query
import gradio as gr

from credentials import get_credentials_json


def get_required_env(name: str) -> str:
  """Return the value of the environment variable `name`.

  An empty string counts as set and is returned unchanged.

  Args:
    name: Name of the environment variable to read.

  Returns:
    The variable's value.

  Raises:
    ValueError: If the variable is absent from the environment.
  """
  if (found := os.environ.get(name)) is not None:
    return found
  raise ValueError(f"Environment variable {name} is not set")


# Firestore collection names, read from the environment at import time.
# Importing this module raises ValueError if any of them is missing
# (see get_required_env).
RATINGS_COLLECTION = get_required_env("RATINGS_COLLECTION")
SUMMARIZATIONS_COLLECTION = get_required_env("SUMMARIZATIONS_COLLECTION")
TRANSLATIONS_COLLECTION = get_required_env("TRANSLATIONS_COLLECTION")

# Initialize the Firebase app and create the shared Firestore client `db`.
# `db` is only defined when gr.NO_RELOAD is truthy.
# NOTE(review): gr.NO_RELOAD presumably keeps this block from being re-run by
# Gradio's reload mode so the app is initialized only once - confirm.
if gr.NO_RELOAD:
  firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))
  db = firestore.client()


class Category(enum.Enum):
  """Task categories handled by this module.

  The value is used as the first component of a ratings document ID, and the
  member selects which collection get_battles() queries.
  """
  SUMMARIZATION = "summarization"
  TRANSLATION = "translation"


@dataclass
class Rating:
  """A single model's rating, as read from or written to a ratings document."""
  # Model identifier; used as the field name in the ratings document.
  model: str
  # Rating value stored under that field.
  rating: int


def get_ratings(category: Category, source_lang: str | None,
                target_lang: str | None) -> List[Rating] | None:
  """Fetch the stored ratings for a category and optional language(s).

  The document ID is the category value followed by each language that is
  given, lowercased, joined with "#". This matches the ID that set_ratings()
  writes to.

  Args:
    category: Category whose ratings are requested.
    source_lang: Source language, or None/empty to omit it from the ID.
    target_lang: Target language, or None/empty to omit it from the ID.

  Returns:
    One Rating per non-timestamp field of the document, or None if the
    document does not exist.
  """
  # set_ratings() lowercases the languages before building the document ID, so
  # the lookup has to do the same or mixed-case input would never match.
  langs = [lang.lower() for lang in (source_lang, target_lang) if lang]
  doc_id = "#".join([category.value] + langs)
  # TODO(#37): Make it more clear what fields are in the document.
  doc_dict = db.collection(RATINGS_COLLECTION).document(doc_id).get().to_dict()
  if doc_dict is None:
    return None

  # TODO(#37): Return the timestamp as well.
  # Use a default so a document without a timestamp field does not raise
  # KeyError.
  doc_dict.pop("timestamp", None)

  return [Rating(model, rating) for model, rating in doc_dict.items()]


def set_ratings(category: Category, ratings: List[Rating], source_lang: str,
                target_lang: str | None):
  """Write ratings to the document for a category and language(s).

  The document ID is the category value, the lowercased source language and,
  when given, the lowercased target language, joined with "#". The ratings are
  written with merge=True together with a server-side timestamp.

  Args:
    category: Category the ratings belong to.
    ratings: Ratings to store; each model name becomes a document field.
    source_lang: Source language (required).
    target_lang: Target language, or None/empty to omit it from the ID.
  """
  lang_parts = [source_lang.lower()]
  if target_lang:
    lang_parts.append(target_lang.lower())

  doc_id = "#".join([category.value, *lang_parts])
  doc_ref = db.collection(RATINGS_COLLECTION).document(doc_id)

  payload = {entry.model: entry.rating for entry in ratings}
  # Set last so it overrides any model field that happens to share the name.
  payload.update(timestamp=firestore.SERVER_TIMESTAMP)
  doc_ref.set(payload, merge=True)


@dataclass
class Battle:
  """One battle record between two models.

  Populated by get_battles() from the "model_a", "model_b" and "winner" fields
  of a battle document.
  """
  model_a: str
  model_b: str
  # NOTE(review): the encoding of `winner` (model name vs. side label) is not
  # visible in this module - confirm against the writer.
  winner: str


def get_battles(category: Category, source_lang: str | None,
                target_lang: str | None) -> List[Battle]:
  """Load battles for a category, ordered by timestamp.

  Languages are lowercased before filtering. For summarization, a given source
  language must match both models' response language and the target language
  is ignored. For translation, the source and target languages are each
  filtered on when given.

  Args:
    category: Which battle collection to query.
    source_lang: Source language filter, or None/empty for no filter.
    target_lang: Target language filter, or None/empty for no filter.

  Returns:
    A Battle for every matching document.

  Raises:
    ValueError: If the category is neither summarization nor translation.
  """
  source = source_lang.lower() if source_lang else None
  target = target_lang.lower() if target_lang else None

  # (field, value) equality constraints; falsy values are skipped below.
  if category == Category.SUMMARIZATION:
    collection_name = SUMMARIZATIONS_COLLECTION
    constraints = [("model_a_response_language", source),
                   ("model_b_response_language", source)]
  elif category == Category.TRANSLATION:
    collection_name = TRANSLATIONS_COLLECTION
    constraints = [("source_language", source), ("target_language", target)]
  else:
    raise ValueError(f"Invalid category: {category}")

  query = db.collection(collection_name).order_by("timestamp")
  for field, value in constraints:
    if value:
      query = query.where(filter=base_query.FieldFilter(field, "==", value))

  return [
      Battle(record["model_a"], record["model_b"], record["winner"])
      for record in (doc.to_dict() for doc in query.stream())
  ]