# Hugging Face Space status: running on Zero (ZeroGPU).
import copy
import dataclasses
import logging
import os
from typing import Any, Dict, List

import gradio as gr
import PIL.Image as Image
import PIL.ImageOps as ImageOps
import spaces
import torch
from peft import PeftModel
from transformers import AutoProcessor
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor

from adapter import IdeficsAdapter
from config_generator import GameConfig, generate_game_config
from utils import device, nested_to_device, sorted_list
### Constants

# Directory that holds the tangram PNG files shown in the gallery.
IMG_DIR = "tangram_pngs"
### Bot server

# Keyword arguments forwarded verbatim to `model.generate` for every turn.
GEN_KWS: Dict[str, Any] = {
    "max_new_tokens": 10,
    "do_sample": True,
    "temperature": 1.0,
    "output_logits": True,
    "return_dict_in_generate": True,
    "remove_invalid_values": True,  # just to be safe
    "renormalize_logits": True,
    "suppress_tokens": IdeficsAdapter.SUPPRESS_TOKEN_IDS,
}
# ZeroGPU Spaces only attach a GPU inside functions wrapped with `spaces.GPU`;
# the decorator has no effect in other environments.
@spaces.GPU
def get_model_response(  # predict
    model: PeftModel, adapter_name: str, adapter: IdeficsAdapter,
    image_paths: List[str], chat: str, chats: List[str],
    previous_selected: List[List[str]]
) -> List[str]:
    """Run one listener turn and return the images the model clicked.

    Args:
        model: PEFT model holding the LoRA adapters.
        adapter_name: Name of the adapter to activate before generating.
        adapter: Helper that composes the model input and parses its output.
        image_paths: Context images offered to the model.
        chat: The speaker message for the current turn.
        chats: Speaker messages from earlier turns (not mutated).
        previous_selected: Selection after each earlier turn; the last entry
            is treated as the current selection.

    Returns:
        The image names parsed from the generated text. If parsing yields
        nothing, the first context image is returned as a fallback.
    """
    if model.active_adapter != adapter_name:
        model.set_adapter(adapter_name)
    model.to(device())

    new_chats = chats + [chat]
    currently_selected = previous_selected[-1] if previous_selected else []
    model_input: Dict[str, Any] = adapter.compose(
        image_paths, new_chats, previous_selected, True, False)
    model_input = nested_to_device(model_input)

    with torch.inference_mode(), torch.autocast(device_type=device().type,
                                                dtype=torch.bfloat16):
        model_output = model.generate(**model_input, **GEN_KWS)

    decoded_out: str = adapter.tokenizer.decode(
        model_output.sequences[0], skip_special_tokens=True)
    model_clicks = adapter.parse(
        image_paths, decoded_out, currently_selected)

    # `prob` is a placeholder code that is only written to the debug log.
    if not model_clicks:
        logging.warning("empty clicks by model")
        model_clicks = [image_paths[0]]
        logging.debug(f"{image_paths=}")
        logging.debug(f"selecting {model_clicks}")
        prob = -1
    else:
        prob = -3
    logging.debug(f"{prob=}")

    logging.info(f"User input: {chat}")
    logging.info(f"Model selected: {model_clicks}")
    logging.debug(f"Model output: {decoded_out}")
    return model_clicks
def get_model() -> PeftModel:
    """Load the Idefics2 base model and attach the "r6_bp" and "r0" adapters.

    Returns:
        A `PeftModel` wrapping the bfloat16 base checkpoint with both adapter
        revisions of `lil-lab/respect` loaded in inference mode.
    """
    model_id = 'lil-lab/respect'
    checkpoint = "HuggingFaceM4/idefics2-8b"

    model = Idefics2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16)
    peft_model = PeftModel.from_pretrained(
        model, model_id, adapter_name="r6_bp", is_trainable=False,
        revision="r6_bp")

    # Add other adapter - hack to avoid conflict.
    # The second adapter reuses a copy of the first adapter's config, with
    # target modules listed explicitly: each LoRA parameter name is cut just
    # before ".lora" to recover the name of the module it is attached to.
    lora_config = copy.deepcopy(peft_model.active_peft_config)
    targets = list({
        name[:name.find('lora') - 1]
        for name, _ in model.named_parameters()
        if 'lora' in name
    })
    lora_config.target_modules = targets
    peft_model.add_adapter("r0", lora_config)
    peft_model.load_adapter(model_id, "r0", is_trainable=False, revision="r0",
                            peft_config=lora_config)
    return peft_model
def get_processor() -> Idefics2Processor:
    """Load the Idefics2 processor with image splitting off and 224px edges."""
    checkpoint = "HuggingFaceM4/idefics2-8b"
    processor = AutoProcessor.from_pretrained(
        checkpoint, do_image_splitting=False,
        size={"longest_edge": 224, "shortest_edge": 224})
    return processor
def get_adapter() -> IdeficsAdapter:
    """Build the `IdeficsAdapter` over `IMG_DIR` with a freshly loaded processor."""
    processor = get_processor()
    return IdeficsAdapter(IMG_DIR, processor)
### Game logic | |
@dataclasses.dataclass
class GameState:
    """State of one game after a given number of turns.

    Instances are built with keyword arguments, so the class must be a
    dataclass for the annotated fields to become constructor parameters.
    Project and PIL types are written as string annotations so they are not
    evaluated when the class is defined.
    """

    # Game setup; this class reads `targets` and `speaker_context` from it.
    config: "GameConfig"
    # Name of the LoRA adapter the game is played against.
    adapter_name: str
    # Speaker messages, one per completed turn.
    chats: List[str]
    # Images selected after the latest turn.
    currently_selected: List[str]
    # Selection after each completed turn.
    selected_accum: List[List[str]]
    # Images the model clicked in each completed turn.
    clicks_accum: List[List[str]]
    # Number of completed turns.
    turn: int = 0

    def has_ended(self) -> bool:
        """Return True once the targets are all selected or 10 turns are used."""
        return self.has_successfully_ended() or self.turn >= 10

    def has_successfully_ended(self) -> bool:
        """Return True when the selection equals the target set exactly."""
        return set(self.currently_selected) == set(self.config.targets)

    ### UI helpers

    def serialize_conversation(self) -> str:
        """Return the speaker messages as numbered lines, one per turn."""
        return "\n".join(
            f"Turn {i+1}: {message}" for i, message in enumerate(self.chats)
        )

    def _latest_clicks(self) -> List[str]:
        """Return the model's clicks from the latest turn, or [] before any turn."""
        return self.clicks_accum[-1] if self.clicks_accum else []

    def markup_images(self) -> "List[Image.Image]":
        """Render the speaker's context images with borders and action markers.

        The yellow marker is placed on the images the model clicked in the
        latest turn, so both new selections and deselections are visible.
        """
        return self._display_context(
            self.config.speaker_context,
            self.config.targets,
            self._latest_clicks(),
            self.currently_selected,
        )

    @staticmethod
    def _display_context(context: List[str], targets: List[str],
                         changes: List[str],
                         selected: List[str]) -> "List[Image.Image]":
        """Draw each context image with a status border and optional marker.

        Border colors: green for a selected target, black for an unselected
        target, red for a selected non-target, white otherwise. Images listed
        in `changes` get the yellow circle pasted on top.
        """
        tangram_list: "List[Image.Image]" = []
        arrow = Image.open("yellow_circle.png").resize((20, 20)).convert("RGBA")
        for img in context:
            is_target = img in targets
            is_selected = img in selected
            if is_target and is_selected:  # listener selected a target image
                border_color = "green"
            elif is_target:  # unselected target
                border_color = "black"
            elif is_selected:  # listener selected a wrong image
                border_color = "red"
            else:
                border_color = "white"

            image = Image.open(os.path.join(IMG_DIR, img)).resize((60, 60)).convert("RGB")
            image = ImageOps.expand(image, border=2, fill="white")
            image = ImageOps.expand(image, border=10, fill=border_color)
            image = ImageOps.expand(image, border=2, fill="white")
            if img in changes:
                # The framed image is 88px wide, so x=68 puts the 20px marker
                # in the top-right corner.
                image.paste(arrow, (68, 0), mask=arrow)
            tangram_list.append(image)
        return tangram_list
class GameFlow:
    """Stateless transitions that create and advance `GameState` objects."""

    @classmethod
    def initialize(cls, model_iteration: str) -> GameState:
        """Start a new game against the chosen system.

        Args:
            model_iteration: "Initial System" selects the "r0" adapter; any
                other value selects "r6_bp".

        Returns:
            A fresh state at turn 0 with a newly generated game config.
        """
        config = generate_game_config()
        adapter_name = "r0" if model_iteration == "Initial System" else "r6_bp"
        return GameState(
            config=config,
            adapter_name=adapter_name,
            chats=[],
            currently_selected=[],
            selected_accum=[],
            clicks_accum=[],
            turn=0,
        )

    @classmethod
    def progress(cls, state: GameState, chat: str,
                 model: PeftModel,
                 adapter: IdeficsAdapter) -> GameState:
        """Play one turn and return the resulting state.

        The input state is not modified; the histories in the returned state
        are new lists.

        Args:
            state: State before the turn.
            chat: Speaker message for this turn.
            model: PEFT model used to produce the listener's clicks.
            adapter: Helper passed through to `get_model_response`.

        Returns:
            A new state with the message, selection and clicks appended and
            the turn counter advanced by one.
        """
        model_clicks = get_model_response(
            model, state.adapter_name, adapter,
            state.config.listener_context, chat,
            state.chats, state.selected_accum
        )

        # Symmetric difference: a click on a selected image deselects it,
        # a click on an unselected image selects it.
        currently_selected2 = sorted_list(
            set(state.currently_selected) ^ set(model_clicks)
        )

        return GameState(
            # constants
            config=state.config,
            adapter_name=state.adapter_name,
            # updates
            chats=[*state.chats, chat],
            currently_selected=currently_selected2,
            selected_accum=[*state.selected_accum, currently_selected2],
            clicks_accum=[*state.clicks_accum, model_clicks],
            turn=state.turn + 1,
        )
### UI | |
def create_app_inner():
    """Build the game UI inside the currently open `gr.Blocks` context.

    Creates the layout, loads the model and adapter once, and wires the
    "Start Game" and "Send" buttons to their callbacks. Must be called while
    a `gr.Blocks` context is active.
    """
    ### layout
    # NOTE(review): the Row/Column nesting below is reconstructed; confirm it
    # matches the intended layout.
    gr.Markdown("# Tangram Multi-Reference Game")
    gr.Markdown(
        '### You will be playing a multi-reference game against a model. '
        'To start a game, first select whether you wish to play against our '
        'initial trained model ("Initial System") or '
        'our model at the end of continual learning ("Final System") '
        'and press the "Start Game" button.')
    gr.Markdown(
        'You will take on a "speaker" role at each round. '
        'Your goal is to describe this image (via a message in the textbox) '
        'so that the model can guess what it is. '
        'Targets have black borders. '
        'Correctly selected targets have green borders. '
        'Incorrectly selected images have red borders. '
        'Actions are marked with a yellow dot. '
        'The listener cannot see boxes or colors and the order is different.')
    gr.Markdown(
        '### Press "Send" to submit your action to proceed to the next turn. '
        'You have 10 turns in total.')

    with gr.Row():
        model_iteration = gr.Radio(["Initial System", "Final System"],
                                   label="Model Iteration",
                                   value="Final System")
        start_btn = gr.Button("Start Game")

    status = gr.Textbox(label="Status", interactive=False, show_label=False,
                        text_align="center", value="Please start a game.")

    with gr.Row():
        image_output = gr.Gallery(
            label="CONTEXT", show_label=False, elem_id="gallery",
            columns=5, rows=2, object_fit="contain", height="250px",
            allow_preview=False, container=True, interactive=False
        )

    with gr.Row():
        conversation_output = gr.Textbox(label="Interaction History")
        with gr.Column():
            user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
            send_btn = gr.Button("Send", interactive=True)

    ### globals
    model = get_model()
    adapter = get_adapter()
    game_state = gr.State(value=None)

    # Both buttons update the same components, in this order.
    outputs = [image_output, conversation_output, status,
               user_input, send_btn, model_iteration, game_state]

    ### callbacks
    def output_from_state(state: GameState):
        """Map a game state to one value per component in `outputs`."""
        has_ended = state.has_ended()
        if has_ended:
            success = "Success" if state.has_successfully_ended() else "Failure"
            status_text = f"{success} (Turn {state.turn}/10) - Start another game?"
        else:
            status_text = f"Turn {state.turn+1}/10"
        return (
            state.markup_images(),  # image_output
            state.serialize_conversation(),  # conversation_output
            status_text,  # status
            gr.update(interactive=not has_ended, value=""),  # user_input
            gr.update(interactive=not has_ended),  # send_btn
            gr.update(interactive=has_ended),  # model_iteration
            state,  # game_state
        )

    def on_start_interaction(model_iteration: str):
        """Start a new game against the selected system."""
        if model_iteration not in ("Initial System", "Final System"):
            raise ValueError(f"Unknown model iteration: {model_iteration!r}")
        state = GameFlow.initialize(model_iteration)
        return output_from_state(state)

    def on_send_message(message: str, state: GameState):
        """Advance the game by one turn using the speaker's message."""
        if state is None:
            # "Send" is clickable before any game exists; leave every
            # component untouched and repeat the prompt to start a game.
            logging.info("Message sent before a game was started")
            return (gr.update(), gr.update(), "Please start a game.",
                    gr.update(), gr.update(), gr.update(), state)
        if not message.strip():
            logging.info("Empty message")
            return output_from_state(state)
        state = GameFlow.progress(state, message, model, adapter)
        return output_from_state(state)

    start_btn.click(
        on_start_interaction,
        inputs=[model_iteration],
        outputs=outputs,
        queue=False
    )
    send_btn.click(
        on_send_message,
        inputs=[user_input, game_state],
        outputs=outputs,
        queue=True
    )
def create_app():
    """Create the Gradio Blocks app with the game UI and return it."""
    with gr.Blocks(theme='saq1b/gradio-theme') as app:
        create_app_inner()
    return app
if __name__ == "__main__":
    # Build the app, enable the request queue, then start serving.
    app = create_app()
    app.queue()
    app.launch()