import os
import re
import json
from io import BytesIO

import requests
import seaborn as sns
import matplotlib.pyplot as plt
import streamlit as st
import yaml
from PIL import Image, ImageDraw
from streamlit_chat import message as st_message

st.set_page_config(
    page_title="Data Exploration",
    page_icon="🌍",
    layout="wide",
    initial_sidebar_state="collapsed",
)

COLORS = sns.color_palette("Paired", n_colors=100).as_hex()

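# config.yaml is expected to contain a `data_explore` section that maps dataset
# names to image root directories and (optionally) lists `data_paths` to scan
# for annotation files.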
def load_config(config_fn, field='data_explore') -> dict:
    with open(config_fn) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config[field]

def convert_from_prompt_tokens(s_with_region_tokens):
    """Convert strings containing region prompt tokens (for prompt encoders) into boxes.

    e.g.:

    Input: "<Region><L12><L24><L101><L777></Region>"

    Output: [(0.012, 0.024, 0.101, 0.777)]
    """
    REGION_PATTERN = r'<Region>(\s*<L(\d{1,4})>\s*<L(\d{1,4})>\s*<L(\d{1,4})>\s*<L(\d{1,4})>\s*)</Region>'
    boxes = []
    boxes_str = re.findall(REGION_PATTERN, s_with_region_tokens)
    for boxes_str_i in boxes_str:
        # boxes_str_i[0] is the full inner match; boxes_str_i[1:] are the four <L...> numbers.
        boxes_str_i = boxes_str_i[1:]
        boxes.append(tuple(int(s) / 1000 for s in boxes_str_i))
    return boxes

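# Pull region annotations out of an example's text: bracketed lists like
# [x1, y1, x2, y2] become bounding boxes, two-value lists become points, and
# <Region> prompt tokens are converted as well. Coordinates are normalized to [0, 1].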
def parse_regions(s):
    pattern = r"\[([\d.,\s]+)\]"
    matches = re.findall(pattern, s)
    bboxes = []
    points = []
    for res in matches:
        try:
            res = eval(res)  # e.g. "0.1, 0.2, 0.3, 0.4" -> (0.1, 0.2, 0.3, 0.4)
        except SyntaxError:
            continue
        if not isinstance(res, tuple):
            continue
        if len(res) == 4:
            x1, y1, x2, y2 = res
            bboxes.append((x1, y1, x2, y2))
        elif len(res) == 2:
            x1, y1 = res
            points.append((x1, y1))
        # Other lengths are not box/point annotations and are ignored.

    bboxes.extend(convert_from_prompt_tokens(s))
    return list(set(bboxes))

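# Load the example image from a local path or a URL and draw each bounding box
# in its own color; returns the annotated image plus a {bbox: color} mapping
# used for the legend and for highlighting coordinates in the text.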
def get_image(image_path, bboxes=None):
    if os.path.exists(image_path):
        image = Image.open(image_path).convert('RGB')
    else:
        # Not a local file: treat the path as a URL.
        response = requests.get(image_path)
        image = Image.open(BytesIO(response.content)).convert('RGB')

    draw = ImageDraw.Draw(image, 'RGB')
    color_mapping = {}
    if bboxes is not None:
        width, height = image.size
        for i, bbox_coords in enumerate(bboxes):
            color = COLORS[i % len(COLORS)]

            # Coordinates are normalized to [0, 1]; scale them to pixel space.
            x1, y1, x2, y2 = bbox_coords
            x1 *= width
            y1 *= height
            x2 *= width
            y2 *= height
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

            color_mapping[bbox_coords] = color

    return image, color_mapping

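# Prefix every bracketed coordinate list in the text with a colored square that
# matches the box drawn on the image.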
def insert_color(s, color_mapping):
    for coords, color in color_mapping.items():
        coords_str = ', '.join([str(x) for x in coords])
        marker = f'<span style="color: {color}; font-weight: bold;">■</span>'
        s = s.replace('[' + coords_str + ']', marker + ' [' + coords_str + ']')
    return s

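# Rendering helpers for multimodal conversations: a message may contain
# <image>, <audio> and <video> placeholders, each consumed from `modal_inputs`
# in order.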
modal_indicator = ['<image>', '<audio>', '<video>']

def show_one_msg(msg, modal_inputs):
    splits = re.split('(' + '|'.join(modal_indicator) + ')', msg)
    for s in splits:
        if s == '<image>':
            st.image(modal_inputs['image'].pop(0))
        elif s == '<audio>':
            st.audio(modal_inputs['audio'].pop(0))
        elif s == '<video>':
            st.video(modal_inputs['video'].pop(0))
        else:
            st.write(s)

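# Examples without a top-level 'image' field are rendered as a plain chat with
# inline media.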
def show_multimodal_example(example, col1, col2):
    with col1:
        info = example.get('info', {})
        info['modal_inputs'] = example['modal_inputs']
        st.json(info)

    with col2:
        conversations = example['conversations']
        modal_inputs = example['modal_inputs']
        for i in range(len(conversations) // 2):
            with st.chat_message("user"):
                show_one_msg(conversations[2*i]['value'], modal_inputs)
            with st.chat_message("assistant"):
                show_one_msg(conversations[2*i+1]['value'], modal_inputs)

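# Render one image-grounded example: image, metadata and the box-color legend in
# the left column; the conversation (or instruction/input/output fields) and the
# optional scoring widgets in the right column.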
def show_example(example, col1, col2, enable_scores=True):
    if 'conversations' in example:
        regions = parse_regions(str(example['conversations']))
    else:
        regions = parse_regions(str(example))

    image_fn = example['image']
    image, color_mapping = get_image(image_fn, regions)

    with col1:
        st.image(image)
        info = example.get('info', {})
        info['id'] = example.get('id', 'N/A')
        info['image'] = image_fn
        if 'dataset' in example:
            info['source'] = example['dataset']
        st.json(info)

        if len(color_mapping):
            table_md = "| Color | Coordinates |\n| --- | --- |\n"
            for coords, color in color_mapping.items():
                color_cell = f'<span style="color: {color}; font-weight: bold;">■</span>'
                table_md += f"| {color_cell} | {coords} |\n"

            st.markdown(table_md, unsafe_allow_html=True)

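    # Right column: conversation or instruction-style fields, plus optional scoring.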
    score_dict = None
    with col2:
        if 'conversations' in example:
            if enable_scores:
                score_dict = {'image': image_fn, 'conversations': example['conversations']}
                with st.expander("Give a score based on the result above", expanded=True):
                    quality_score = st.radio("Question quality score", ('Bad', 'Mediocre', 'Good'), key="quality", horizontal=True)
                    format_score = st.radio("Format score", ('Bad', 'Mediocre', 'Good'), key="format", horizontal=True)
                    score_dict['scores'] = {
                        'quality': quality_score, 'format': format_score
                    }
            st.subheader("Chat")
            conversations = example['conversations']
            for i in range(len(conversations) // 2):
                st_message(conversations[2*i]['value'], is_user=True, key=image_fn + str(2*i))
                st_message(conversations[2*i+1]['value'], is_user=False, key=image_fn + str(2*i+1))

            if 'ground_truth' in example:
                gt = insert_color(json.dumps(example['ground_truth']), color_mapping)
                st.markdown(f"**Ground Truth:**\n\n{gt}", unsafe_allow_html=True)
        else:
            instruction = insert_color(example['instruction'], color_mapping)
            st.markdown(f"**Instruction:**\n\n{instruction}", unsafe_allow_html=True)

            if 'input' in example:
                input_text = insert_color(example['input'], color_mapping)
                st.markdown(f"**Input:**\n\n{input_text}", unsafe_allow_html=True)

            output = insert_color(example['output'], color_mapping)
            st.markdown(f"**Output:**\n\n{output}", unsafe_allow_html=True)

            if 'query' in example:
                query = insert_color(json.dumps(example['query']), color_mapping)
                st.markdown(f"**Query:**\n\n{query}", unsafe_allow_html=True)

    return score_dict

def reset_state():
    print('RESET')
    st.session_state['data_explore'] = {'idx': 0}
    st.session_state.scores = {}

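# A dataset directory is described by mapping.yaml with two sections:
# `mapping` (annotation-file stem -> image-path key) and `image_paths`
# (key -> image root directory).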
def load_dir_data(dir, dataset_configs):
    mapping_file = os.path.join(dir, 'mapping.yaml')
    assert os.path.exists(mapping_file)

    config = yaml.load(open(mapping_file), Loader=yaml.Loader)

    image_paths = config['image_paths']
    image_paths['default'] = image_paths.get('default', '.')

    res = []
    for k, v in config['mapping'].items():
        if os.path.exists(os.path.join(dir, k + '.json')):
            data = json.load(open(os.path.join(dir, k + '.json')))
        elif os.path.exists(os.path.join(dir, k + '.jsonl')):
            data = [json.loads(line) for line in open(os.path.join(dir, k + '.jsonl'))]
        elif os.path.exists(os.path.join(dir, k + '.txt')):
            data = [json.loads(line) for line in open(os.path.join(dir, k + '.txt'))]
        else:
            # No annotation file for this entry; skip it instead of failing below.
            continue

        image_path = image_paths.get(v, image_paths['default'])
        for example in data:
            example['image'] = os.path.join(image_path, example['image'])
            example['dataset'] = k
        res.extend(data)

    return res

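# Load one selectable entry: a directory with mapping.yaml, a .txt/.jsonl file
# with one JSON object per line, or a plain .json file. Image paths are resolved
# against the per-dataset root from the config.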
@st.cache_data
def load_data(fn, dataset_configs):
    if os.path.isdir(fn):
        res = load_dir_data(fn, dataset_configs)
        return res

    if fn.endswith(('.txt', '.jsonl')):
        res = []
        for line in open(fn):
            example = json.loads(line)
            res.append(example)
    else:
        res = json.load(open(fn))

    for example in res:
        dataset_path = dataset_configs[example.get('dataset', 'default')]

        if 'image' in example:
            example['image'] = os.path.join(dataset_path, example['image'])
        elif 'img_info' in example:
            if isinstance(example['img_info'], str):
                example['image'] = os.path.join(dataset_path, example['img_info'])
            elif 'coco_url' in example['img_info']:
                example['image'] = example['img_info']['coco_url']
        else:
            assert 'modal_inputs' in example

    return res

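# Collect every selectable dataset (annotation files or mapping.yaml directories)
# under the configured data paths.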
dataset_configs = load_config('config.yaml')
print(dataset_configs)
data_paths = dataset_configs.get('data_paths', ['instruction_data'])

files = []

def add_file(path):
    if os.path.exists(os.path.join(path, 'mapping.yaml')):
        # A directory with mapping.yaml is treated as a single dataset entry.
        files.append(path)
    else:
        for f in sorted(os.listdir(path)):
            file = os.path.join(path, f)
            if os.path.isfile(file):
                if file.endswith(('.txt', '.json', '.jsonl')):
                    files.append(file)
            else:
                # Only recurse into subdirectories; other files are ignored.
                add_file(file)

for data_path in data_paths:
    add_file(data_path)

if 'data_explore' not in st.session_state:
    st.session_state['data_explore'] = {'idx': 0}

enable_score = st.sidebar.checkbox('Score it!', value=False)
if enable_score and 'scores' not in st.session_state:
    st.session_state.scores = {}

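# Layout: a status banner, two control columns (file selector and example index),
# then two content columns filled by show_example / show_multimodal_example.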
status_placeholder = st.empty()
control_col1, control_col2 = st.columns(2)

with control_col1:
    selected_file = st.selectbox('Select a file', files, on_change=reset_state)

col1, col2 = st.columns(2)

if selected_file:
    data = load_data(selected_file, dataset_configs)

    with control_col2:
        idx = st.number_input(
            f'Input an idx (Total: {len(data)})',
            min_value=0,
            max_value=len(data) - 1,
            value=st.session_state.get('data_explore', {}).get('idx', 0),
        )
        st.session_state['data_explore']['idx'] = idx

    if 'image' in data[idx]:
        score_dict = show_example(data[idx], col1, col2, enable_scores=enable_score)
        if enable_score and score_dict is not None:
            # Keep the latest scores for this example so they can be saved on submit.
            st.session_state.scores[idx] = score_dict
    else:
        show_multimodal_example(data[idx], col1, col2)

    if enable_score:
        name = st.sidebar.text_input("Username", placeholder="Enter your name", value="cc")
        if st.sidebar.button(label="Submit scores", key="submit"):
            if name:
                os.makedirs("score_results", exist_ok=True)
                score_path = f"score_results/{os.path.basename(selected_file)}_{name}.json"
                with open(score_path, "w") as score_file:
                    json.dump(st.session_state.scores, score_file, indent=4)
                status_placeholder.success("Successfully saved!")
            else:
                status_placeholder.error("Please enter your name on the sidebar!")