import os
import re
import json
import streamlit as st
from PIL import Image, ImageDraw
import requests
from io import BytesIO
import seaborn as sns
import matplotlib.pyplot as plt
from streamlit_chat import message as st_message
import yaml
# Configure the Streamlit page: title, icon, wide layout, sidebar collapsed on load.
st.set_page_config(page_title="Data Exploration", page_icon="🌍", layout="wide", initial_sidebar_state="collapsed")
# 100 colours from seaborn's "Paired" palette, converted with as_hex();
# get_image() picks the outline colour of each box from this list by box index.
COLORS = sns.color_palette("Paired", n_colors=100).as_hex()
def load_config(config_fn, field='data_explore') -> dict:
    """Load a YAML config file and return one of its top-level sections.

    Args:
        config_fn: Path of the YAML file to read.
        field: Top-level key whose value is returned.

    Returns:
        The value stored under ``field`` in the parsed document.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the parsed mapping has no ``field`` key.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(config_fn) as config_file:
        # SECURITY NOTE(review): yaml.Loader can construct arbitrary Python
        # objects from tagged nodes. The loader is kept unchanged to preserve
        # behaviour; switch to yaml.SafeLoader if this file can ever come from
        # an untrusted source.
        config = yaml.load(config_file, Loader=yaml.Loader)
    return config[field]
def convert_from_prompt_tokens(s_with_region_tokens):
    """Convert strings with integer prompt tokens into normalized boxes.

    A box is encoded as four consecutive integer tokens written in angle
    brackets, optionally separated by whitespace. Each integer is divided by
    1000 to obtain one coordinate.

    e.g.:
        Input: "<12><24><101><777>"
        Output: [(0.012, 0.024, 0.101, 0.777)]

    Args:
        s_with_region_tokens: Text that may contain region tokens.

    Returns:
        A list of 4-tuples of floats, in order of appearance. The list is
        empty when the text contains no complete group of four tokens.
    """
    # NOTE(review): a pattern made only of `\s*` quantifiers in a single group
    # makes re.findall() return plain strings, always ending with an empty
    # match, so indexing the match raised IndexError for every input. The
    # angle-bracket tokens appear to have been lost from the literal. This
    # pattern is reconstructed from the docstring example and the
    # int(...)/1000 conversion; any begin/end marker tokens surrounding the
    # four numbers are unknown and therefore not required here -- confirm
    # against the tokenizer's region format.
    REGION_PATTERN = r'<(\d+)>\s*<(\d+)>\s*<(\d+)>\s*<(\d+)>'
    return [
        tuple(int(token) / 1000 for token in match)
        for match in re.findall(REGION_PATTERN, s_with_region_tokens)
    ]
def parse_regions(s):
    """Collect the unique bounding boxes mentioned in a piece of text.

    Two notations are recognised:

    * bracketed numeric lists such as ``[0.1, 0.2, 0.3, 0.4]``; only lists
      with exactly four values are treated as boxes ``(x1, y1, x2, y2)``.
      Lists of any other length (for example two-value points) are ignored.
    * region prompt tokens, decoded by ``convert_from_prompt_tokens``.

    Args:
        s: Text to scan.

    Returns:
        A list of 4-tuples without duplicates, ordered by first appearance
        (bracketed boxes first, then token boxes). Integers written without a
        decimal point stay ``int`` so that ``str()`` of a coordinate matches
        the way it was written in the text.
    """
    pattern = r"\[([\d.,\s]+)\]"
    bboxes = []
    for group in re.findall(pattern, s):
        tokens = [token.strip() for token in group.split(',')]
        # Tolerate a single trailing comma, e.g. "[1, 2, 3, 4,]".
        if tokens and tokens[-1] == '':
            tokens.pop()
        # Parse the numbers explicitly instead of eval()-ing text taken from
        # the input; malformed lists (e.g. "1..2" or "1 2") are skipped
        # rather than raising.
        try:
            values = tuple(
                int(token) if token.isdigit() else float(token)
                for token in tokens
            )
        except ValueError:
            continue
        if len(values) == 4:
            bboxes.append(values)
    bboxes.extend(convert_from_prompt_tokens(s))
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result (and any colour assigned by position) is stable.
    return list(dict.fromkeys(bboxes))
def get_image(image_path, bboxes=None, timeout=30):
    """Load an image from disk or a URL and optionally draw boxes on it.

    Args:
        image_path: Local file path. If no such path exists, the value is
            treated as a URL and fetched with ``requests``.
        bboxes: Optional iterable of ``(x1, y1, x2, y2)`` boxes. Coordinates
            are multiplied by the image width/height before drawing, i.e.
            they are fractions of the image size. Each box must be hashable
            because it is used as a key of the returned mapping.
        timeout: Seconds to wait for the HTTP request when ``image_path`` is
            fetched as a URL.

    Returns:
        ``(image, color_mapping)``: the RGB ``PIL.Image`` and a dict mapping
        each box to the colour of its outline. ``color_mapping`` is ``None``
        when ``bboxes`` is ``None``.

    Raises:
        requests.RequestException: If the download fails, times out, or
            returns an HTTP error status.
    """
    if os.path.exists(image_path):
        # convert() returns an independent image, so the file handle opened
        # by Image.open() can be closed right away.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
    else:
        # Fetch the image from a URL. A timeout keeps the app from hanging on
        # an unresponsive host, and raise_for_status() reports HTTP errors
        # directly instead of failing later while decoding an error page.
        response = requests.get(image_path, timeout=timeout)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as source:
            image = source.convert('RGB')

    if bboxes is None:
        return image, None

    draw = ImageDraw.Draw(image, 'RGB')
    width, height = image.size
    color_mapping = {}
    for i, bbox_coords in enumerate(bboxes):
        # Wrap around the palette so more boxes than colours does not raise.
        color = COLORS[i % len(COLORS)]
        x1, y1, x2, y2 = bbox_coords
        # ImageDraw.rectangle expects x1 >= x0 and y1 >= y0; order the
        # corners so boxes given in reverse are still drawn.
        left, right = sorted((x1 * width, x2 * width))
        top, bottom = sorted((y1 * height, y2 * height))
        draw.rectangle([left, top, right, bottom], outline=color, width=3)
        color_mapping[bbox_coords] = color
    return image, color_mapping
def insert_color(s, color_mapping):
    """Prefix every known box in ``s`` with a square marker.

    For each box key in ``color_mapping`` the text ``[c1, c2, ...]`` (values
    joined with ", ") is replaced by ``■ [c1, c2, ...]``.

    Args:
        s: Text that may contain bracketed box coordinates.
        color_mapping: Mapping from box coordinate tuples to colours, as
            returned by ``get_image``; may be ``None`` or empty.

    Returns:
        The text with a marker inserted before each matching box. The text is
        returned unchanged when ``color_mapping`` is ``None`` or empty.
    """
    # get_image() returns None as the mapping when no boxes were supplied.
    if not color_mapping:
        return s
    # NOTE(review): the colour values of the mapping are not used, so the
    # marker is inserted without any colour information. Presumably colour
    # markup around the marker was intended -- confirm against how the text
    # is rendered before adding it.
    for coords in color_mapping:
        coords_str = ', '.join(str(x) for x in coords)
        box_text = '[' + coords_str + ']'
        s = s.replace(box_text, '■ ' + box_text)
    return s
modal_indicator = ['', '