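"""Upload GenAI-Arena battle votes to the Hugging Face Hub.

Reads a cleaned battle log (JSON), pairs each battle with the two generated
images stored under the vote log directory, and pushes the result to the
configured Hub repository as a `datasets` dataset.
"""
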
import fire
import json
import os
import datasets
import random
from pathlib import Path
from datetime import datetime
from PIL import Image

datasets.config.DEFAULT_MAX_BATCH_SIZE = 500

def create_hf_battle_dataset(data: list, split="test", task_type="t2i_generation"):
    """Build a `datasets.Dataset` from a list of battle records."""
    if task_type == "t2i_generation":
        features = datasets.Features(
            {
                "index": datasets.Value("int32"),
                "tstamp": datasets.Value("int32"),
                "prompt": datasets.Value("string"),
                "left_model": datasets.Value("string"),
                "left_image": datasets.Image(),
                "right_model": datasets.Value("string"),
                "right_image": datasets.Image(),
                "vote_type": datasets.Value("string"),
                "winner": datasets.Value("string"),
                "anony": datasets.Value("bool"),
                "judge": datasets.Value("string"),
            }
        )
    else:
        raise ValueError(f"Task type {task_type} not supported")
    hf_dataset = datasets.Dataset.from_list(
        data,
        features=features,
        split=split,
    )
    return hf_dataset

def load_image(path: str):
    try:
        return Image.open(path)
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return None

def get_date_from_time_stamp(unix_timestamp: int):
    # Create a datetime object from the Unix timestamp
    dt = datetime.fromtimestamp(unix_timestamp)
    # Convert the datetime object to a string with the desired format
    date_str = dt.strftime("%Y-%m-%d")
    return date_str

def load_battle_image(battle, log_dir):
    image_path = (
        Path(log_dir)
        / f"{get_date_from_time_stamp(battle['tstamp'])}-convinput_images"
        / f"input_image_{battle['question_id']}.png"
    )
    return load_image(image_path)

def find_media_path(conv_id, task_type, log_dir):
    media_directory_map = {
        "t2i_generation": "images/generation",
        "image_edition": "images/edition",
        "text2video": "videos/generation",
    }
    if task_type == "t2i_generation":
        media_path = Path(log_dir) / media_directory_map[task_type] / f"{conv_id}.jpg"
    else:
        raise ValueError(f"Task type {task_type} not supported")
    return media_path
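
# Note: for "t2i_generation", generated images are expected on disk at
#   <log_dir>/images/generation/<conv_id>.jpg
# (layout inferred from media_directory_map above; other task types are not handled yet).
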
def main(
    task_type="t2i_generation",
    data_file: str = None,
    repo_id: str = "TIGER-Lab/GenAI-Arena-human-eval",
    log_dir: str = os.getenv("LOGDIR", "../GenAI-Arena-hf-logs/vote_log"),
    config_name="battle",
    split="test",
    token=os.environ.get("HUGGINGFACE_TOKEN", None),
    seed=42,
):
    """Convert cleaned battle logs into an image dataset and push it to the Hub."""
    if data_file is None:
        data_file = f"./results/latest/clean_battle_{task_type}.json"
    if not os.path.exists(data_file):
        raise ValueError(f"Data file {data_file} does not exist")
    with open(data_file, "r") as f:
        data = json.load(f)

    if seed is not None:
        random.seed(seed)
    # Sort battles chronologically by tstamp
    data = sorted(data, key=lambda x: x["tstamp"])

    # Keep only battles whose inputs contain the keys required for this task
    required_keys_each_task = {
        "image_editing": ["source_prompt", "target_prompt", "instruct_prompt"],
        "t2i_generation": ["prompt"],
        "video_generation": ["prompt"],
    }
    valid_data = []
    for i, battle in enumerate(data):
        if any(key not in battle["inputs"] for key in required_keys_each_task[task_type]):
            continue
        valid_data.append(battle)
    print(f"Total battles: {len(data)}, valid battles: {len(valid_data)}, removed battles: {len(data) - len(valid_data)}")
    data = valid_data

    # Assign a stable index following the chronological order
    for i, battle in enumerate(data):
        battle["index"] = i

    new_data = []
    if task_type == "t2i_generation":
        for battle in data:
            prompt = battle["inputs"]["prompt"]
            model_a = battle["model_a"]
            model_b = battle["model_b"]
            model_a_conv_id = battle["model_a_conv_id"]
            model_b_conv_id = battle["model_b_conv_id"]
            tstamp = battle["tstamp"]
            vote_type = battle["vote_type"]
            left_image_path = find_media_path(model_a_conv_id, task_type, log_dir)
            right_image_path = find_media_path(model_b_conv_id, task_type, log_dir)
            left_image = load_image(left_image_path)
            right_image = load_image(right_image_path)
            if left_image is None or right_image is None:
                print(f"Skipping battle {battle['index']} due to missing images")
                continue
            new_data.append({
                "index": battle["index"],
                "tstamp": tstamp,
                "prompt": prompt,
                "left_model": model_a,
                "left_image": left_image,
                "right_model": model_b,
                "right_image": right_image,
                "vote_type": vote_type,
                "winner": battle["winner"],
                "anony": battle["anony"],
                "judge": battle["judge"],
            })
        split = "test"
        hf_dataset = create_hf_battle_dataset(new_data, split, task_type)
    else:
        raise ValueError(f"Task type {task_type} not supported")

    print(hf_dataset)
    print(f"Uploading to {repo_id}:{split}...")
    hf_dataset.push_to_hub(
        repo_id=repo_id,
        config_name=config_name,
        split=split,
        token=token,
        commit_message=f"Add vision-arena {split} dataset",
    )
    print("Done!")

if __name__ == "__main__":
    fire.Fire(main)
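
# Example invocation (hypothetical script name; flags map to `main`'s keyword arguments):
#   python upload_battle_dataset.py --task_type t2i_generation \
#       --repo_id TIGER-Lab/GenAI-Arena-human-eval --config_name battle --split test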