diff --git a/README.md b/README.md index 876c566543532d521f06a0d1ce3b8b01e78923ed..dfe7b99f5bc577ac02e459678a370b2c271bf0be 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- title: DataViz -emoji: 📉 -colorFrom: red -colorTo: indigo +emoji: 👁 +colorFrom: blue +colorTo: pink sdk: streamlit sdk_version: 1.36.0 app_file: app.py diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2a51a0fdcaba7fe979b0e7fd1d2ecd6b261b3d24 --- /dev/null +++ b/app.py @@ -0,0 +1,303 @@ +import os +import re +import json +import streamlit as st +from PIL import Image, ImageDraw +import requests +from io import BytesIO +import seaborn as sns +import matplotlib.pyplot as plt +from streamlit_chat import message as st_message + +import yaml + +st.set_page_config(page_title="Data Exploration", page_icon="🌍", layout="wide", initial_sidebar_state="collapsed") +COLORS = sns.color_palette("Paired", n_colors=100).as_hex() + +def load_config(config_fn, field='data_explore') -> dict: + config = yaml.load(open(config_fn), Loader=yaml.Loader) + return config[field] + +def convert_from_prompt_tokens(s_with_region_tokens): + """Convert from strings with prompt tokens for prompt encoders + + e.g.: + + Input: "<24>" + + Output: [0.012, 0.024, 0.101, 0.777] + """ + REGION_PATTERN = r'(\s*\s*\s*\s*\s*)' + boxes = [] + boxes_str = re.findall(REGION_PATTERN, s_with_region_tokens) + for boxes_str_i in boxes_str: + matched_str_i, boxes_str_i = boxes_str_i[0], boxes_str_i[1:] + boxes.append(tuple([int(s)/1000 for s in boxes_str_i])) + return boxes + +def parse_regions(s): + pattern = r"\[([\d.,\s]+)\]" + matches = re.findall(pattern, s) + bboxes = [] + points = [] + for res in matches: + res = eval(res) + if len(res) == 4: + # bbox + x1, y1, x2, y2 = res + bboxes.append((x1, y1, x2, y2)) + else: + x1, y1 = res + points.append((x1, y1)) + + bboxes.extend(convert_from_prompt_tokens(s)) + return list(set(bboxes)) + +def get_image(image_path, bboxes=None): + + if os.path.exists(image_path): + image = Image.open(image_path).convert('RGB') + else: + # 从URL获取图片 + response = requests.get(image_path) + image = Image.open(BytesIO(response.content)).convert('RGB') + + draw = ImageDraw.Draw(image, 'RGB') + color_mapping = None + if bboxes is not None: + width, height = image.size + color_mapping = [] + for i, bbox_coords in enumerate(bboxes): + color = COLORS[i] + + x1, y1, x2, y2 = bbox_coords + x1 *= width + y1 *= height + x2 *= width + y2 *= height + draw.rectangle([x1, y1, x2, y2], outline=color, width=3) + + color_mapping.append([bbox_coords, color]) + + color_mapping = dict(color_mapping) + return image, color_mapping + +def insert_color(s, color_mapping): + for coords, color in color_mapping.items(): + coords_str = ', '.join([str(x) for x in coords]) + s = s.replace('[' + coords_str + ']', f'' + ' [' + coords_str + ']') + + return s + +modal_indicator = ['', '