"""Gradio app that explores Gemma-2-2B SAE features for a piece of text via Neuronpedia."""
import html
import json
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests

logger = logging.getLogger(__name__)

# Neuronpedia endpoint and the model / SAE layer used for every query.
NEURONPEDIA_SEARCH_URL = "https://www.neuronpedia.org/api/search-with-topk"
MODEL_ID = "gemma-2-2b"
SAE_LAYER = "20-gemmascope-res-16k"
# Seconds to wait for Neuronpedia before giving up; without a timeout a
# stalled request would block the handler indefinitely.
REQUEST_TIMEOUT_SECONDS = 30
# Number of features shown per token while its list is collapsed.
COLLAPSED_FEATURE_COUNT = 3


@dataclass
class Feature:
    """One feature returned by the API for a single token."""

    feature_id: int  # taken from the API's 'feature_index'
    activation: float  # taken from the API's 'activation_value'
    token: str  # token text the feature was reported for
    position: int  # rank of the feature within that token's 'top_features' list


class FeatureState:
    """Container for per-session UI state.

    Not referenced by the rest of this module: the Gradio state is a plain
    dict with the same three keys (built in ``analyze_features``).
    """

    def __init__(self):
        self.features_by_token = {}
        self.expanded_tokens = set()
        self.selected_feature = None


def get_features(text: str, timeout: float = REQUEST_TIMEOUT_SECONDS) -> Optional[Dict]:
    """Get neural features from the API using the exact website parameters.

    Args:
        text: Text to analyze.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The decoded JSON response, or ``None`` if the request failed, returned
        an HTTP error status, or the body was not valid JSON. Failures are
        logged instead of being silently discarded.
    """
    payload = {
        "modelId": MODEL_ID,
        "text": text,
        "layer": SAE_LAYER,
    }
    try:
        response = requests.post(
            NEURONPEDIA_SEARCH_URL,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers an unparseable JSON body.
        logger.exception("Neuronpedia feature search failed")
        return None


def format_feature_list(
    features: List[Feature],
    token: str,
    expanded: bool = False,
    limit: int = COLLAPSED_FEATURE_COUNT,
) -> str:
    """Format a token's features as an HTML list of feature cards.

    Args:
        features: Features for one token, in API order.
        token: The token the features belong to (escaped into a data attribute).
        expanded: Show every feature instead of only the first ``limit``.
        limit: Number of features shown while collapsed.

    Returns:
        HTML for the cards. When the list is collapsed and truncated, it ends
        with an "N more features available" note.
    """
    shown = features if expanded else features[:limit]
    safe_token = html.escape(token, quote=True)
    parts = [
        f'<div class="feature-card" data-token="{safe_token}" '
        f'data-feature-id="{feature.feature_id}">'
        f"Feature {feature.feature_id} (Activation: {feature.activation:.2f})"
        "</div>"
        for feature in shown
    ]
    if not expanded and len(features) > limit:
        remaining = len(features) - limit
        parts.append(
            f'<div class="more-features">{remaining} more features available</div>'
        )
    return "\n".join(parts)


def format_dashboard(feature: Optional[Feature]) -> str:
    """Format the dashboard HTML for a selected feature ("" if none is selected)."""
    if feature is None:
        return ""
    # NOTE(review): the original markup was lost. This reconstruction embeds
    # the Neuronpedia page for the same model/layer/feature. Confirm the embed
    # query parameters against Neuronpedia's embed documentation.
    embed_url = (
        f"https://www.neuronpedia.org/{MODEL_ID}/{SAE_LAYER}/{feature.feature_id}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )
    return (
        '<div class="dashboard-container">\n'
        f"<h3>Feature {feature.feature_id} Dashboard "
        f"(Activation: {feature.activation:.2f})</h3>\n"
        f'<iframe src="{html.escape(embed_url, quote=True)}" '
        f'title="Feature {feature.feature_id} dashboard" '
        'style="width: 100%; height: 400px; border: none;"></iframe>\n'
        "</div>"
    )


def process_features(data: Dict) -> Dict[str, List[Feature]]:
    """Process the API response into features grouped by token.

    Results with an empty (or missing) token are skipped.

    NOTE(review): results are keyed by token text, so a token that occurs
    several times in the input keeps only its last occurrence's features.
    """
    if not data:
        return {}
    features_by_token: Dict[str, List[Feature]] = {}
    for result in data.get('results', []):
        token = result.get('token', '')
        if token == '':
            continue
        features_by_token[token] = [
            Feature(
                feature_id=feature['feature_index'],
                activation=feature['activation_value'],
                token=token,
                position=idx,
            )
            for idx, feature in enumerate(result.get('top_features', []))
        ]
    return features_by_token


css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');

body {
    font-family: 'Open Sans', sans-serif !important;
}

.token-section {
    margin-bottom: 16px;
}

.token-header {
    font-weight: 600;
    margin-bottom: 6px;
}

.feature-card {
    border: 1px solid #e0e5ff;
    border-radius: 6px;
    padding: 6px 10px;
    margin: 4px 0;
    background-color: #ffffff;
    transition: all 0.2s ease;
}

.feature-card:hover {
    border-color: #3452db;
    box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
}

.more-features {
    color: #3452db;
    font-size: 0.9em;
    margin-top: 4px;
}

.dashboard-container {
    border: 1px solid #e0e5ff;
    border-radius: 8px;
    background-color: #ffffff;
}
"""

theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.Color(
        name="blue",
        c50="#eef1ff",
        c100="#e0e5ff",
        c200="#c3cbff",
        c300="#a5b2ff",
        c400="#8798ff",
        c500="#6a7eff",
        c600="#3452db",
        c700="#2a41af",
        c800="#1f3183",
        c900="#152156",
        c950="#0a102b",
    )
)


def _render_output(state: Dict) -> str:
    """Build the per-token feature HTML plus the dashboard from an existing state."""
    expanded_tokens = state.get('expanded_tokens', set())
    sections = []
    for token, features in state.get('features_by_token', {}).items():
        token_html = f'<div class="token-header">Token: {html.escape(token)}</div>'
        features_html = format_feature_list(features, token, token in expanded_tokens)
        sections.append(
            f'<div class="token-section">\n{token_html}\n{features_html}\n</div>'
        )
    selected = state.get('selected_feature')
    if selected:
        sections.append(format_dashboard(selected))
    return "\n".join(sections)


def analyze_features(text: str, state: Optional[Dict] = None) -> Tuple[str, Optional[Dict]]:
    """Analyze *text* and return the formatted HTML together with a fresh UI state.

    Args:
        text: Text entered by the user.
        state: Previous UI state. It is accepted to fit the Gradio
            inputs/outputs wiring, but a new analysis always starts from a
            fresh state. Tokens, expansions and the selected feature from an
            earlier text therefore never leak into the new results.

    Returns:
        ``(html, state)``. Empty input gives ``("", None)``, and an API
        failure gives ``("Error analyzing text", None)``.
    """
    if not text:
        return "", None

    data = get_features(text)
    if not data:
        return "Error analyzing text", None

    features_by_token = process_features(data)
    new_state = {
        'features_by_token': features_by_token,
        'expanded_tokens': set(),
        # Default selection: first feature of the first token that has any.
        'selected_feature': next(
            (features[0] for features in features_by_token.values() if features),
            None,
        ),
    }
    if not features_by_token:
        return "No features found for this text", new_state
    return _render_output(new_state), new_state


def toggle_expansion(token: str, state: Dict) -> Tuple[str, Dict]:
    """Toggle the expanded/collapsed view of *token*'s features and re-render.

    This renders from the existing state. Calling ``analyze_features(None, ...)``
    would hit its empty-text guard and discard the state.
    """
    if not state:
        return "", state
    expanded = state.setdefault('expanded_tokens', set())
    if token in expanded:
        expanded.discard(token)
    else:
        expanded.add(token)
    return _render_output(state), state


def select_feature(feature_id: int, state: Dict) -> Tuple[str, Dict]:
    """Select the first feature with *feature_id* (across all tokens) and re-render.

    If no feature matches, the current selection is kept.
    """
    if not state:
        return "", state
    match = next(
        (
            feature
            for features in state.get('features_by_token', {}).values()
            for feature in features
            if feature.feature_id == feature_id
        ),
        None,
    )
    if match is not None:
        state['selected_feature'] = match
    return _render_output(state), state


def create_interface() -> gr.Blocks:
    """Build the Gradio UI: text input with examples on the left, results on the right."""
    with gr.Blocks(theme=theme, css=css) as interface:
        # Created inside the Blocks context so it is registered with this app.
        state = gr.State({})

        gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
        gr.Markdown(
            "*Analyze text using Gemma's interpretable neural features*",
            elem_classes="text-gray-600 mb-6",
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                analyze_btn = gr.Button("Analyze Features", variant="primary")
                gr.Examples(
                    examples=["WordLift", "Think Different", "Just Do It"],
                    inputs=input_text,
                )
            with gr.Column(scale=2):
                output = gr.HTML()

        # Event handlers
        analyze_btn.click(
            fn=analyze_features,
            inputs=[input_text, state],
            outputs=[output, state],
        )

    return interface


if __name__ == "__main__":
    create_interface().launch()