DongfuJiang committed on
Commit
45557a3
2 Parent(s): cf709ea 8d4ac15

Merge branch 'main' of https://huggingface.co/spaces/TIGER-Lab/GenAI-Arena

Browse files
model/model_manager.py CHANGED
@@ -10,6 +10,7 @@ from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_mus
10
  from .pre_download import pre_download_all_models, pre_download_video_models
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
  import torch
 
13
 
14
  def debug_packages():
15
  import pkg_resources
@@ -17,7 +18,59 @@ def debug_packages():
17
  installed_packages = pkg_resources.working_set
18
  for package in installed_packages:
19
  print(f"{package.key}=={package.version}")
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class ModelManager:
22
  def __init__(self, enable_nsfw=False, do_pre_download=False, do_debug_packages=False):
23
  self.model_ig_list = IMAGE_GENERATION_MODELS
@@ -52,6 +105,13 @@ class ModelManager:
52
  else:
53
  self.guard_tokenizer = None
54
  self.guard = None
 
 
 
 
 
 
 
55
 
56
  def NSFW_filter(self, prompt):
57
  chat = [{"role": "user", "content": prompt}]
 
10
  from .pre_download import pre_download_all_models, pre_download_video_models
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
  import torch
13
+ import re
14
 
15
  def debug_packages():
16
  import pkg_resources
 
18
  installed_packages = pkg_resources.working_set
19
  for package in installed_packages:
20
  print(f"{package.key}=={package.version}")
21
+
22
def fetch_unsafe_words(file_path):
    """
    Load unsafe words from a text file and return them as a list.

    The file is expected to hold one word or phrase per line. Surrounding
    whitespace is stripped and blank lines are dropped.

    A relative ``file_path`` is tried as given first (i.e. relative to the
    current working directory) and then relative to the directory containing
    this module, so the lookup does not depend on where the process was
    launched from.

    Args:
        file_path: Path of the word list file.

    Returns:
        list[str]: The words read from the file, or the built-in default list
        when the file cannot be read from any candidate location.
    """
    import os

    candidates = [file_path]
    if isinstance(file_path, (str, os.PathLike)) and not os.path.isabs(file_path):
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidates.append(os.path.join(module_dir, file_path))

    first_error = None
    for candidate in candidates:
        try:
            # utf-8-sig reads plain UTF-8 too and drops a leading BOM, which
            # would otherwise end up glued to the first word.
            with open(candidate, 'r', encoding='utf-8-sig') as file:
                stripped = (line.strip() for line in file)
                return [word for word in stripped if word]
        except (OSError, UnicodeError, TypeError, ValueError) as e:
            # Best effort: remember the error for the path the caller asked
            # for and keep going; the default list below is the last resort.
            if first_error is None:
                first_error = e

    print(f"Error loading file: {first_error}. Using default unsafe words.")
    # Default unsafe words list
    return [
        "anal", "anus", "arse", "ass", "ballsack", "bastard", "bdsm", "bitch", "bimbo",
        "blow job", "blowjob", "blue waffle", "boob", "booobs", "breasts", "booty call",
        "boner", "bondage", "bullshit", "busty", "butthole", "cawk", "chink", "clit",
        "cnut", "cock", "cokmuncher", "cowgirl", "crap", "crotch", "cum", "cunt", "damn",
        "dick", "dildo", "dink", "deepthroat", "deep throat", "dog style", "doggie style",
        "doggy style", "doosh", "douche", "duche", "ejaculate", "ejaculating",
        "ejaculation", "ejakulate", "erotic", "erotism", "fag", "fatass", "femdom",
        "fingering", "footjob", "foot job", "fuck", "fcuk", "fingerfuck", "fistfuck",
        "fook", "fooker", "fuk", "gangbang", "gang bang", "gaysex", "handjob",
        "hand job", "hentai", "hooker", "hoer", "homo", "horny", "incest", "jackoff",
        "jack off", "jerkoff", "jerk off", "jizz", "masturbate", "mofo", "mothafuck",
        "motherfuck", "milf", "muff", "nigga", "nigger", "nipple", "nob", "numbnuts",
        "nutsack", "nude", "orgy", "orgasm", "panty", "panties", "penis", "playboy",
        "porn", "pussy", "pussies", "rape", "raping", "rapist", "rectum", "retard",
        "rimming", "sadist", "sadism", "scrotum", "sex", "semen", "shemale", "she male",
        "shit", "slut", "spunk", "strip club", "stripclub", "tit", "threesome",
        "three some", "throating", "twat", "viagra", "vagina", "wank", "whore", "whoar",
        "xxx",
    ]
56
+
57
def check_prompt_safety(prompt, unsafe_words_file='./profanity_words.txt'):
    """
    Check a prompt against the unsafe word list.

    Matching is case-insensitive and on whole words/phrases only, so an
    unsafe word embedded inside a longer word does not trigger the filter.

    Args:
        prompt: The user prompt to check. ``None`` or an empty string is
            treated as safe.
        unsafe_words_file: Path of the word list handed to
            ``fetch_unsafe_words``, which supplies a default list when the
            file cannot be read.

    Returns:
        bool: ``True`` when the prompt is safe, ``False`` when it contains at
        least one unsafe word.
    """
    # Nothing to scan; also avoids calling .lower() on None.
    if not prompt:
        return True

    # Load unsafe words from the provided file or use default if loading fails
    unsafe_words = fetch_unsafe_words(unsafe_words_file)

    # The prompt is lower-cased below, so the entries must be lower-cased as
    # well or any capitalised entry in the file could never match.
    normalized = {word.strip().lower() for word in unsafe_words if word and word.strip()}
    if not normalized:
        return True

    # A single alternation wrapped in word boundaries matches exactly when at
    # least one per-word r'\bword\b' search would. Sorting keeps the pattern
    # text stable across calls so the re module's pattern cache can reuse it.
    pattern = r'\b(?:' + '|'.join(re.escape(word) for word in sorted(normalized)) + r')\b'
    return re.search(pattern, prompt.lower()) is None
73
+
74
  class ModelManager:
75
  def __init__(self, enable_nsfw=False, do_pre_download=False, do_debug_packages=False):
76
  self.model_ig_list = IMAGE_GENERATION_MODELS
 
105
  else:
106
  self.guard_tokenizer = None
107
  self.guard = None
108
+
109
def NSFW_filter_simple(self, prompt):
    """
    Classify a prompt with the word-list check in ``check_prompt_safety``.

    Returns the string "safe" when the check passes and "unsafe" otherwise.
    """
    return "safe" if check_prompt_safety(prompt) else "unsafe"
115
 
116
  def NSFW_filter(self, prompt):
117
  chat = [{"role": "user", "content": prompt}]
model/profanity_words.txt ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ anal
2
+ anus
3
+ arse
4
+ ass
5
+ ballsack
6
+ bastard
7
+ bdsm
8
+ bitch
9
+ bimbo
10
+ blow job
11
+ blowjob
12
+ blue waffle
13
+ boob
14
+ booobs
15
+ breasts
16
+ booty call
17
+ boner
18
+ bondage
19
+ bullshit
20
+ busty
21
+ butthole
22
+ cawk
23
+ chink
24
+ clit
25
+ cnut
26
+ cock
27
+ cokmuncher
28
+ cowgirl
29
+ crap
30
+ crotch
31
+ cum
32
+ cunt
33
+ damn
34
+ dick
35
+ dildo
36
+ dink
37
+ deepthroat
38
+ deep throat
39
+ dog style
40
+ doggie style
41
+ doggy style
42
+ doosh
43
+ douche
44
+ duche
45
+ ejaculate
46
+ ejaculating
47
+ ejaculation
48
+ ejakulate
49
+ erotic
50
+ erotism
51
+ fag
52
+ fatass
53
+ femdom
54
+ fingering
55
+ footjob
56
+ foot job
57
+ fuck
58
+ fcuk
59
+ fingerfuck
60
+ fistfuck
61
+ fook
62
+ fooker
63
+ fuk
64
+ gangbang
65
+ gang bang
66
+ gaysex
67
+ handjob
68
+ hand job
69
+ hentai
70
+ hooker
71
+ hoer
72
+ homo
73
+ horny
74
+ incest
75
+ jackoff
76
+ jack off
77
+ jerkoff
78
+ jerk off
79
+ jizz
80
+ masturbate
81
+ mofo
82
+ mothafuck
83
+ motherfuck
84
+ milf
85
+ muff
86
+ nigga
87
+ nigger
88
+ nipple
89
+ nob
90
+ numbnuts
91
+ nutsack
92
+ nude
93
+ orgy
94
+ orgasm
95
+ panty
96
+ panties
97
+ penis
98
+ playboy
99
+ porn
100
+ pussy
101
+ pussies
102
+ rape
103
+ raping
104
+ rapist
105
+ rectum
106
+ retard
107
+ rimming
108
+ sadist
109
+ sadism
110
+ scrotum
111
+ sex
112
+ semen
113
+ shemale
114
+ she male
115
+ shit
116
+ slut
117
+ spunk
118
+ strip club
119
+ stripclub
120
+ tit
121
+ threesome
122
+ three some
123
+ throating
124
+ twat
125
+ viagra
126
+ vagina
127
+ wank
128
+ whore
129
+ whoar
130
+ xxx
serve/leaderboard.py CHANGED
@@ -49,13 +49,6 @@ def load_leaderboard_table_csv(filename, add_hyperlink=True):
49
  for col in df.columns:
50
  if "Arena Elo rating" in col:
51
  df[col] = df[col].apply(lambda x: int(x) if x != "-" else np.nan)
52
- elif col == "MMLU":
53
- df[col] = df[col].apply(lambda x: round(x * 100, 1) if x != "-" else np.nan)
54
- elif col == "MT-bench (win rate %)":
55
- df[col] = df[col].apply(lambda x: round(x, 1) if x != "-" else np.nan)
56
- elif col == "MT-bench (score)":
57
- df[col] = df[col].apply(lambda x: round(x, 2) if x != "-" else np.nan)
58
-
59
  if add_hyperlink and col == "Model":
60
  df[col] = df.apply(lambda row: model_hyperlink(row[col], row["Link"]), axis=1)
61
  return df
@@ -111,9 +104,6 @@ def get_full_table(anony_arena_df, full_arena_df, model_table_df):
111
  row.append(np.nan)
112
  row.append("N/A")
113
  row.append(np.nan)
114
- # row.append(model_table_df.iloc[i]["MT-bench (score)"])
115
- # row.append(model_table_df.iloc[i]["Num Battles"])
116
- # row.append(model_table_df.iloc[i]["MMLU"])
117
  # Organization
118
  row.append(model_table_df.iloc[i]["Organization"])
119
  # license
@@ -244,7 +234,7 @@ def build_leaderboard_tab(elo_results_file, leaderboard_table_file, show_plot=Tr
244
  value=arena_table_vals,
245
  elem_id="arena_leaderboard_dataframe",
246
  height=700,
247
- column_widths=[30, 50, 30, 30, 30, 70, 150],
248
  wrap=True,
249
  )
250
  with gr.Tab("Full Leaderboard", id=1):
@@ -266,7 +256,7 @@ def build_leaderboard_tab(elo_results_file, leaderboard_table_file, show_plot=Tr
266
  datatype=["str", "markdown", "number", "str", "number", "str", "number", "str", "str"],
267
  value=full_table_vals,
268
  elem_id="full_leaderboard_dataframe",
269
- column_widths=[30, 50, 30, 30, 30, 30, 30, 70, 150],
270
  height=700,
271
  wrap=True,
272
  )
 
49
  for col in df.columns:
50
  if "Arena Elo rating" in col:
51
  df[col] = df[col].apply(lambda x: int(x) if x != "-" else np.nan)
 
 
 
 
 
 
 
52
  if add_hyperlink and col == "Model":
53
  df[col] = df.apply(lambda row: model_hyperlink(row[col], row["Link"]), axis=1)
54
  return df
 
104
  row.append(np.nan)
105
  row.append("N/A")
106
  row.append(np.nan)
 
 
 
107
  # Organization
108
  row.append(model_table_df.iloc[i]["Organization"])
109
  # license
 
234
  value=arena_table_vals,
235
  elem_id="arena_leaderboard_dataframe",
236
  height=700,
237
+ column_widths=[30, 70, 30, 30, 30, 70, 100],
238
  wrap=True,
239
  )
240
  with gr.Tab("Full Leaderboard", id=1):
 
256
  datatype=["str", "markdown", "number", "str", "number", "str", "number", "str", "str"],
257
  value=full_table_vals,
258
  elem_id="full_leaderboard_dataframe",
259
+ column_widths=[30, 70, 30, 30, 30, 30, 30, 70, 100],
260
  height=700,
261
  wrap=True,
262
  )