"""Gradio demo for PECoRe context attribution, built on ``inseq``.

The page has two tabs: one to enter an input context/query and view
PECoRe's highlighted output, and one exposing every PECoRe, generation and
model parameter together with presets for common model families.
"""

# NOTE: import order preserved from the original file. `spaces` is presumably
# kept ahead of modules that may pull in torch (HF ZeroGPU asks for this) —
# confirm before reordering.
import json
import os

import gradio as gr
import spaces
from contents import (
    citation,
    description,
    examples,
    how_it_works,
    how_to_use,
    subtitle,
    title,
)
from gradio_highlightedtextbox import HighlightedTextbox
from presets import (
    set_chatml_preset,
    set_cora_preset,
    set_default_preset,
    set_gemma_preset,
    set_mmt_preset,
    set_towerinstruct_preset,
    set_zephyr_preset,
)
from style import custom_css
from utils import get_formatted_attribute_context_results

from inseq import list_feature_attribution_methods, list_step_functions
from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
    attribute_context_with_model,
)
from inseq.models import HuggingfaceModel

# Files written by every PECoRe run (JSON results + HTML visualization).
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "outputs")
OUTPUT_JSON_PATH = os.path.join(OUTPUT_DIR, "output.json")
OUTPUT_HTML_PATH = os.path.join(OUTPUT_DIR, "output.html")

# Model shared across requests; only reloaded when the requested name changes.
loaded_model: HuggingfaceModel | None = None


def _parse_json_kwargs(raw: str, field_name: str) -> dict:
    """Parse the JSON object typed into one of the ``*_kwargs`` code boxes.

    A blank box is treated as ``{}``.

    Args:
        raw: Text content of the code box.
        field_name: Human-readable name of the box, used in error messages.

    Returns:
        The decoded dictionary.

    Raises:
        gr.Error: if ``raw`` is not valid JSON or does not decode to an object,
            so the problem is shown in the UI instead of as a raw traceback.
    """
    if raw is None or not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as err:
        raise gr.Error(f"{field_name} is not valid JSON: {err}") from err
    if not isinstance(parsed, dict):
        raise gr.Error(f"{field_name} must be a JSON object, e.g. {{}}.")
    return parsed


def _ensure_model_loaded(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: dict,
    tokenizer_kwargs: dict,
) -> HuggingfaceModel:
    """Return the cached model, loading it first if the requested name differs.

    Only the model name is compared: changing kwargs or the attribution method
    for an already-loaded model does not trigger a reload.
    """
    global loaded_model
    if loaded_model is None or model_name_or_path != loaded_model.model_name:
        gr.Info("Loading model...")
        loaded_model = HuggingfaceModel.load(
            model_name_or_path,
            attribution_method,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
        )
    return loaded_model


@spaces.GPU()
def pecore(
    input_current_text: str,
    input_context_text: str,
    output_current_text: str,
    output_context_text: str,
    model_name_or_path: str,
    attribution_method: str,
    attributed_fn: str | None,
    context_sensitivity_metric: str,
    context_sensitivity_std_threshold: float,
    context_sensitivity_topk: int,
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
    contextless_input_current_text: str,
    output_template: str,
    special_tokens_to_keep: str | list[str] | None,
    decoder_input_output_separator: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
):
    """Run PECoRe context attribution with the parameters set in the UI.

    Results are also saved to ``OUTPUT_JSON_PATH`` / ``OUTPUT_HTML_PATH`` so the
    download buttons can serve them.

    The ``*_kwargs`` arguments are JSON strings from the code boxes. Top-k
    values of 0 (or an empty field) mean "keep all", and optional texts are
    only forwarded when non-empty.

    Returns:
        A 3-tuple: the ``(text, label)`` pairs for the highlighted output box,
        and two ``gr.Button(visible=True)`` updates revealing the download
        buttons. When PECoRe finds no pairs, the single pair holds the model
        output plus a warning message.

    Raises:
        gr.Error: if ``{context}`` appears in the output template without a
            generation context, or if a kwargs box holds invalid JSON.
    """
    if "{context}" in output_template and not output_context_text:
        # The message names the UI field ("Generation context") the user must fill.
        raise gr.Error(
            "Parameter 'Generation context' is required when using {context} in the output template."
        )

    # Parse every JSON box once, up front, so bad input fails before loading.
    parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs")
    parsed_tokenizer_kwargs = _parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs")
    parsed_generation_kwargs = _parse_json_kwargs(generation_kwargs, "Generation kwargs")
    parsed_attribution_kwargs = _parse_json_kwargs(attribution_kwargs, "Attribution kwargs")

    model = _ensure_model_loaded(
        model_name_or_path,
        attribution_method,
        parsed_model_kwargs,
        parsed_tokenizer_kwargs,
    )

    # Optional arguments are only forwarded when set, so that
    # AttributeContextArgs falls back to its own defaults otherwise.
    # `or 0` guards against a cleared gr.Number field, which yields None.
    kwargs = {}
    if (context_sensitivity_topk or 0) > 0:
        kwargs["context_sensitivity_topk"] = context_sensitivity_topk
    if (attribution_topk or 0) > 0:
        kwargs["attribution_topk"] = attribution_topk
    if input_context_text:
        kwargs["input_context_text"] = input_context_text
    if output_context_text:
        kwargs["output_context_text"] = output_context_text
    if output_current_text:
        kwargs["output_current_text"] = output_current_text
    if decoder_input_output_separator:
        kwargs["decoder_input_output_separator"] = decoder_input_output_separator

    # Make sure the output files can be written.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    pecore_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=OUTPUT_JSON_PATH,
        add_output_info=True,
        viz_path=OUTPUT_HTML_PATH,
        show_viz=False,
        model_name_or_path=model_name_or_path,
        attribution_method=attribution_method,
        attributed_fn=attributed_fn,
        attribution_selectors=None,
        attribution_aggregators=None,
        normalize_attributions=True,
        model_kwargs=parsed_model_kwargs,
        tokenizer_kwargs=parsed_tokenizer_kwargs,
        generation_kwargs=parsed_generation_kwargs,
        attribution_kwargs=parsed_attribution_kwargs,
        context_sensitivity_metric=context_sensitivity_metric,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
        attribution_std_threshold=attribution_std_threshold,
        input_current_text=input_current_text,
        input_template=input_template,
        output_template=output_template,
        contextless_input_current_text=contextless_input_current_text,
        handle_output_context_strategy="pre",
        **kwargs,
    )
    out = attribute_context_with_model(pecore_args, model)
    tuples = get_formatted_attribute_context_results(model, out.info, out)
    if not tuples:
        msg = f"Output: {out.output_current}\nWarning: No pairs were found by PECoRe. Try adjusting Results Selection parameters."
        tuples = [(msg, None)]
    return tuples, gr.Button(visible=True), gr.Button(visible=True)


@spaces.GPU()
def preload_model(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
):
    """Load the requested model ahead of the first PECoRe run.

    Used by the "Load model" button and after applying a preset. Nothing is
    reloaded if the same model name is already loaded.

    Raises:
        gr.Error: if a kwargs box holds invalid JSON.
    """
    _ensure_model_loaded(
        model_name_or_path,
        attribution_method,
        _parse_json_kwargs(model_kwargs, "Model kwargs"),
        _parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs"),
    )


with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(title)
    gr.Markdown(subtitle)
    gr.Markdown(description)
    with gr.Tab("🐑 Attributing Context"):
        with gr.Row():
            with gr.Column():
                input_context_text = gr.Textbox(
                    label="Input context", lines=4, placeholder="Your input context..."
                )
                input_current_text = gr.Textbox(
                    label="Input query", placeholder="Your input query..."
                )
                attribute_input_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                pecore_output_highlights = HighlightedTextbox(
                    value=[
                        ("This output will contain ", None),
                        ("context sensitive", "Context sensitive"),
                        (" generated tokens and ", None),
                        ("influential context", "Influential context"),
                        (" tokens.", None),
                    ],
                    color_map={
                        "Context sensitive": "green",
                        "Influential context": "blue",
                    },
                    show_legend=True,
                    label="PECoRe Output",
                    combine_adjacent=True,
                    interactive=False,
                )
                with gr.Row(equal_height=True):
                    # The original os.path.join(dirname, "/file=...") always
                    # evaluated to the absolute second argument, so the link is
                    # written out directly. It is resolved relative to the
                    # working directory, like allowed_paths=["outputs/"] below —
                    # this presumably assumes the app is launched from its own
                    # directory (where OUTPUT_DIR lives); confirm if that changes.
                    download_output_file_button = gr.Button(
                        "⇓ Download output",
                        visible=False,
                        link="/file=outputs/output.json",
                    )
                    download_output_html_button = gr.Button(
                        "🔍 Download HTML",
                        visible=False,
                        link="/file=outputs/output.html",
                    )
        attribute_input_examples = gr.Examples(
            examples,
            inputs=[input_current_text, input_context_text],
            outputs=pecore_output_highlights,
        )
    with gr.Tab("⚙️ Parameters") as params_tab:
        gr.Markdown(
            "## ✨ Presets\nSelect a preset to load default parameters into the fields below. ⚠️ This will overwrite existing parameters."
        )
        with gr.Row(equal_height=True):
            with gr.Column():
                default_preset = gr.Button("Default", variant="secondary")
                gr.Markdown(
                    "Default preset using templates without special tokens or parameters.\nCan be used with most decoder-only and encoder-decoder models."
                )
            with gr.Column():
                cora_preset = gr.Button("CORA mQA", variant="secondary")
                gr.Markdown(
                    "Preset for the CORA Multilingual QA model.\nUses special templates for inputs."
                )
            with gr.Column():
                zephyr_preset = gr.Button("Zephyr Template", variant="secondary")
                gr.Markdown(
                    "Preset for models using the Zephyr conversational template.\nUses <|system|>, <|user|> and <|assistant|> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                multilingual_mt_template = gr.Button(
                    "Multilingual MT", variant="secondary"
                )
                gr.Markdown(
                    "Preset for multilingual MT models such as NLLB and mBART using language tags."
                )
            with gr.Column(scale=1):
                chatml_template = gr.Button("Qwen ChatML", variant="secondary")
                gr.Markdown(
                    "Preset for models using the ChatML conversational template.\nUses <|im_start|>, <|im_end|> special tokens."
                )
            with gr.Column(scale=1):
                towerinstruct_template = gr.Button(
                    "Unbabel TowerInstruct", variant="secondary"
                )
                gr.Markdown(
                    "Preset for models using the Unbabel TowerInstruct conversational template.\nUses <|im_start|>, <|im_end|> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gemma_template = gr.Button("Gemma Chat Template", variant="secondary")
                gr.Markdown("Preset for Gemma instruction-tuned models.")
        gr.Markdown("## ⚙️ PECoRe Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                model_name_or_path = gr.Textbox(
                    value="gpt2",
                    label="Model",
                    info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
                    interactive=True,
                )
                load_model_button = gr.Button(
                    "Load model",
                    variant="secondary",
                )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
                info="Metric to use to measure context sensitivity of generated tokens.",
                choices=list_step_functions(),
                interactive=True,
            )
            attribution_method = gr.Dropdown(
                value="saliency",
                label="Attribution method",
                info="Attribution method identifier to identify relevant context tokens.",
                choices=list_feature_attribution_methods(),
                interactive=True,
            )
            attributed_fn = gr.Dropdown(
                value="contrast_prob_diff",
                label="Attributed function",
                info="Function of model logits to use as target for the attribution method.",
                choices=list_step_functions(),
                interactive=True,
            )
        gr.Markdown("#### Results Selection Parameters")
        with gr.Row(equal_height=True):
            context_sensitivity_std_threshold = gr.Number(
                value=1.0,
                label="Context sensitivity threshold",
                info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            context_sensitivity_topk = gr.Number(
                value=0,
                label="Context sensitivity top-k",
                info="Select N to keep top N context sensitive tokens. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=10,
            )
            attribution_std_threshold = gr.Number(
                value=1.0,
                label="Attribution threshold",
                info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            attribution_topk = gr.Number(
                value=0,
                label="Attribution top-k",
                info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=50,
            )
        gr.Markdown("#### Text Format Parameters")
        with gr.Row(equal_height=True):
            input_template = gr.Textbox(
                # NOTE(review): this default appears damaged in the copy this file
                # was recovered from (tag-like text such as "<P>" may have been
                # stripped and replaced by a blank line); kept as found —
                # confirm against the presets module.
                value="{current}\n\n:{context}",
                label="Input template",
                info="Template to format the input for the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
            output_template = gr.Textbox(
                value="{current}",
                label="Output template",
                info="Template to format the output from the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
            contextless_input_current_text = gr.Textbox(
                # NOTE(review): possibly also lost a leading tag (e.g. "<Q>");
                # kept as found — confirm against the presets module.
                value=":{current}",
                label="Input current text template",
                info="Template to format the input query for the model. Use {current} placeholder.",
                interactive=True,
            )
        with gr.Row(equal_height=True):
            special_tokens_to_keep = gr.Dropdown(
                label="Special tokens to keep",
                info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
                value=None,
                multiselect=True,
                allow_custom_value=True,
            )
            decoder_input_output_separator = gr.Textbox(
                label="Decoder input/output separator",
                info="Separator to use between input and output in the decoder input.",
                value="",
                interactive=True,
                lines=1,
            )
        gr.Markdown("## ⚙️ Generation Parameters")
        with gr.Row(equal_height=True):
            # Gradio expects integer `scale` values; 1:2 keeps the original
            # 0.5:1 width ratio.
            with gr.Column(scale=1):
                gr.Markdown(
                    "The following arguments can be used to control generation parameters and force specific model outputs."
                )
            with gr.Column(scale=2):
                generation_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Generation kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
        with gr.Row(equal_height=True):
            output_current_text = gr.Textbox(
                label="Generation output",
                info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
                interactive=True,
            )
            output_context_text = gr.Textbox(
                label="Generation context",
                info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
                interactive=True,
            )
        gr.Markdown("## ⚙️ Other Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown(
                    "The following arguments will be passed to initialize the Hugging Face model and tokenizer, and to the `inseq_model.attribute` method."
                )
            with gr.Column():
                model_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Model kwargs (JSON)",
                    interactive=True,
                    lines=1,
                    min_width=160,
                )
            with gr.Column():
                tokenizer_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Tokenizer kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
            with gr.Column():
                attribution_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Attribution kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
    gr.Markdown(how_it_works)
    gr.Markdown(how_to_use)
    gr.Markdown(citation)

    # Main logic
    load_model_args = [
        model_name_or_path,
        attribution_method,
        model_kwargs,
        tokenizer_kwargs,
    ]

    # Input order must match the positional parameters of `pecore`.
    attribute_input_button.click(
        pecore,
        inputs=[
            input_current_text,
            input_context_text,
            output_current_text,
            output_context_text,
            model_name_or_path,
            attribution_method,
            attributed_fn,
            context_sensitivity_metric,
            context_sensitivity_std_threshold,
            context_sensitivity_topk,
            attribution_std_threshold,
            attribution_topk,
            input_template,
            contextless_input_current_text,
            output_template,
            special_tokens_to_keep,
            decoder_input_output_separator,
            model_kwargs,
            tokenizer_kwargs,
            generation_kwargs,
            attribution_kwargs,
        ],
        outputs=[
            pecore_output_highlights,
            download_output_file_button,
            download_output_html_button,
        ],
    )

    load_model_event = load_model_button.click(
        preload_model,
        inputs=load_model_args,
        outputs=[],
    )

    # Preset params: every preset first resets these fields via
    # set_default_preset, then (except "Default") overrides a subset of them.
    outputs_to_reset = [
        model_name_or_path,
        input_template,
        contextless_input_current_text,
        output_template,
        special_tokens_to_keep,
        decoder_input_output_separator,
        model_kwargs,
        tokenizer_kwargs,
        generation_kwargs,
        attribution_kwargs,
    ]
    reset_kwargs = {
        "fn": set_default_preset,
        "inputs": None,
        "outputs": outputs_to_reset,
    }

    # Presets: after the fields are updated, preload the preset's model and
    # cancel any manual load still in progress.
    default_preset.click(**reset_kwargs).success(
        preload_model, inputs=load_model_args, cancels=load_model_event
    )
    cora_preset.click(**reset_kwargs).then(
        set_cora_preset,
        outputs=[model_name_or_path, input_template, contextless_input_current_text],
    ).success(preload_model, inputs=load_model_args, cancels=load_model_event)
    zephyr_preset.click(**reset_kwargs).then(
        set_zephyr_preset,
        outputs=[
            model_name_or_path,
            input_template,
            contextless_input_current_text,
            decoder_input_output_separator,
        ],
    ).success(preload_model, inputs=load_model_args, cancels=load_model_event)
    multilingual_mt_template.click(**reset_kwargs).then(
        set_mmt_preset,
        outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs],
    ).success(preload_model, inputs=load_model_args, cancels=load_model_event)
    chatml_template.click(**reset_kwargs).then(
        set_chatml_preset,
        outputs=[
            model_name_or_path,
            input_template,
            contextless_input_current_text,
            decoder_input_output_separator,
            special_tokens_to_keep,
        ],
    ).success(preload_model, inputs=load_model_args, cancels=load_model_event)
    towerinstruct_template.click(**reset_kwargs).then(
        set_towerinstruct_preset,
        outputs=[
            model_name_or_path,
            input_template,
            contextless_input_current_text,
            decoder_input_output_separator,
            special_tokens_to_keep,
        ],
    ).success(preload_model, inputs=load_model_args, cancels=load_model_event)
    gemma_template.click(**reset_kwargs).then(
        set_gemma_preset,
        outputs=[
            model_name_or_path,
            input_template,
            contextless_input_current_text,
            decoder_input_output_separator,
            special_tokens_to_keep,
        ],
    ).success(preload_model, inputs=load_model_args, cancels=load_model_event)

demo.launch(allowed_paths=["outputs/"])