import gradio as gr
import requests
from typing import Dict, Tuple, List
import json
from dataclasses import dataclass
from typing import Optional
@dataclass()
class Feature:
    """A single activated feature reported for one token.

    Instances are built by ``process_features`` from the API response.
    """

    # Feature index, read from the response's ``feature_index`` field.
    feature_id: int
    # Activation strength, read from the response's ``activation_value`` field.
    activation: float
    # Text of the token this feature fired on.
    token: str
    # Rank of this feature within the token's ``top_features`` list (0-based).
    position: int
class FeatureState:
    """Container for per-session UI state.

    Holds the features grouped by token, the set of tokens whose feature
    list is expanded, and the currently selected feature.

    NOTE(review): the visible UI code keeps this same state in a plain dict
    with matching keys rather than instantiating this class — confirm
    whether it is still needed.
    """

    def __init__(self):
        # token text -> list of features for that token
        self.features_by_token = dict()
        # tokens whose full feature list should be shown
        self.expanded_tokens = set()
        # nothing is selected until a feature is chosen
        self.selected_feature = None
def get_features(text: str, *, timeout: float = 60.0) -> Optional[Dict]:
    """Get neural features for *text* from the Neuronpedia API.

    Uses the exact parameters the Neuronpedia website sends (model
    ``gemma-2-2b``, layer ``20-gemmascope-res-16k``).

    Args:
        text: The text to analyze.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON response, or ``None`` when the request fails,
        times out, returns an HTTP error status, or the body is not valid
        JSON.
    """
    import logging

    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    try:
        # Without a timeout, requests may block the UI handler indefinitely.
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers an undecodable JSON body. Callers treat None as
        # "analysis failed", so keep that contract but record why.
        logging.getLogger(__name__).warning("Neuronpedia request failed: %s", exc)
        return None
def format_feature_list(
    features: List[Feature],
    token: str,
    expanded: bool = False,
    *,
    max_visible: int = 3,
) -> str:
    """Format a token's features as a text snippet for the HTML output.

    Args:
        features: Features to list, in display order.
        token: The token the features belong to. Currently unused; kept so
            existing callers keep working.
        expanded: When True, list every feature; otherwise list only the
            first ``max_visible`` features.
        max_visible: How many features to show when collapsed. Negative
            values are treated as 0.

    Returns:
        One ``Feature <id>`` / ``(Activation: <x.xx>)`` entry per shown
        feature. When collapsed and some features are hidden, the entries are
        followed by a ``<n> more features available`` line. Returns an empty
        string for an empty list.
    """
    limit = max(max_visible, 0)
    shown = features if expanded else features[:limit]
    parts = [
        f"\nFeature {feature.feature_id}\n(Activation: {feature.activation:.2f})\n"
        for feature in shown
    ]
    hidden = len(features) - limit
    if not expanded and hidden > 0:
        parts.append(f"\n{hidden} more features available\n")
    # join avoids the repeated-copy cost of building the string with +=.
    return "".join(parts)
def format_dashboard(feature: Feature) -> str:
    """Build the dashboard snippet for the selected feature.

    Returns an empty string when *feature* is falsy (e.g. ``None``, meaning
    nothing is selected). Otherwise returns a single line naming the feature
    and its activation, rounded to two decimals.
    """
    if feature:
        return f"\nFeature {feature.feature_id} Dashboard (Activation: {feature.activation:.2f})\n"
    return ""
def process_features(data: Dict) -> Dict[str, List[Feature]]:
    """Group the API response's features by token text.

    Each entry of ``data['results']`` must carry a ``token`` key; entries
    whose token is the empty string are skipped. Each entry's
    ``top_features`` list becomes a list of ``Feature`` objects, where
    ``position`` is the feature's 0-based rank in that list.

    Because the result is keyed by token text, a token that appears more
    than once keeps only its last entry's features.
    """
    grouped = {}
    for entry in data.get('results', []):
        token = entry['token']
        if token == '':
            continue
        grouped[token] = [
            Feature(
                feature_id=item['feature_index'],
                activation=item['activation_value'],
                token=token,
                position=rank,
            )
            for rank, item in enumerate(entry.get('top_features', []))
        ]
    return grouped
# Custom stylesheet passed to gr.Blocks(css=...) in create_interface.
# It imports Open Sans from Google Fonts, applies it to the page body, and
# styles `.feature-card` and `.dashboard-container` elements.
# NOTE(review): none of the visible HTML-building functions emit those two
# class names, so those rules appear unused — confirm before relying on them.
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body {
font-family: 'Open Sans', sans-serif !important;
}
.feature-card {
border: 1px solid #e0e5ff;
background-color: #ffffff;
transition: all 0.2s ease;
}
.feature-card:hover {
border-color: #3452db;
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
}
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
"""
# Brand palette, from lightest (c50) to darkest (c950). c100 and c600 are the
# same colours the custom CSS uses for card borders and the hover border.
_BRAND_BLUE = gr.themes.colors.Color(
    name="blue",
    c50="#eef1ff",
    c100="#e0e5ff",
    c200="#c3cbff",
    c300="#a5b2ff",
    c400="#8798ff",
    c500="#6a7eff",
    c600="#3452db",
    c700="#2a41af",
    c800="#1f3183",
    c900="#152156",
    c950="#0a102b",
)

# Gradio "Soft" theme recoloured with the brand palette as its primary hue.
theme = gr.themes.Soft(primary_hue=_BRAND_BLUE)
def _render_state(state: Dict) -> str:
    """Render the per-token feature listing plus the selected-feature dashboard."""
    import html

    expanded_tokens = state.get('expanded_tokens', set())
    output = []
    for token, features in state['features_by_token'].items():
        expanded = token in expanded_tokens
        # Tokens come from the API response and are shown in a gr.HTML
        # component, so escape them to keep markup characters literal.
        token_html = f"Token: {html.escape(token)}\n"
        features_html = format_feature_list(features, token, expanded)
        output.append(f"{token_html}{features_html}\n")
    selected = state.get('selected_feature')
    if selected:
        output.append(format_dashboard(selected))
    return "\n".join(output)


def analyze_features(text: str, state: Optional[Dict] = None) -> Tuple[str, Optional[Dict]]:
    """Analyze *text*, or re-render an existing *state*, and return ``(html, state)``.

    Modes:

    * ``text`` is a non-empty string: query the API and build a fresh state,
      with no tokens expanded and the first feature of the first token
      selected. Then render that state.
    * ``text is None`` and ``state`` already holds ``features_by_token``:
      re-render that state without calling the API. ``toggle_expansion`` and
      ``select_feature`` call this function that way after mutating state.

    Any other falsy ``text`` (e.g. an empty textbox) yields ``("", None)``.
    If the API request fails, returns ``("Error analyzing text", None)``.
    """
    # Re-render request: previously this fell into the empty-text branch and
    # wiped the caller's state.
    if text is None and state and 'features_by_token' in state:
        return _render_state(state), state
    if not text:
        return "", None

    data = get_features(text)
    if not data:
        return "Error analyzing text", None

    features_by_token = process_features(data)
    # Always start from fresh state for new text. Reusing a previous run's
    # state would keep stale features, expansions and selection.
    state = {
        'features_by_token': features_by_token,
        'expanded_tokens': set(),
        'selected_feature': None,
    }
    # Default selection: first feature of the first token. Using next(..., None)
    # avoids StopIteration when every token was filtered out.
    first_features = next(iter(features_by_token.values()), None)
    if first_features:
        state['selected_feature'] = first_features[0]

    return _render_state(state), state
def toggle_expansion(token: str, state: Dict) -> Tuple[str, Dict]:
    """Flip whether *token* shows its full feature list, then re-render.

    Mutates ``state['expanded_tokens']`` in place. Rendering is delegated to
    ``analyze_features(None, state)``, and its result is returned.
    """
    expanded = state['expanded_tokens']
    if token not in expanded:
        expanded.add(token)
    else:
        expanded.discard(token)
    rendered, new_state = analyze_features(None, state)
    return rendered, new_state
def select_feature(feature_id: int, state: Dict) -> Tuple[str, Dict]:
    """Select the first feature whose id is *feature_id*, then re-render.

    Tokens are scanned in ``features_by_token`` order, and features within a
    token in list order. The first match becomes ``state['selected_feature']``.
    The original nested ``break`` only left the inner loop, so a later token
    with the same id overwrote the selection. When nothing matches, the
    previous selection is kept.
    """
    chosen = next(
        (
            feature
            for features in state['features_by_token'].values()
            for feature in features
            if feature.feature_id == feature_id
        ),
        None,
    )
    if chosen is not None:
        state['selected_feature'] = chosen
    # NOTE(review): relies on analyze_features treating text=None as a
    # "re-render the given state" request.
    output_html, state = analyze_features(None, state)
    return output_html, state
def create_interface():
    """Build (but do not launch) the Gradio UI for the feature analyzer.

    Returns:
        The ``gr.Blocks`` app. It has a text input with examples and an
        "Analyze Features" button in the left column, and an HTML output pane
        in the right column.
    """
    with gr.Blocks(theme=theme, css=css) as interface:
        # Session state is created inside the Blocks context so Gradio
        # registers it with this app. A component created outside the
        # context is not part of the Blocks.
        state = gr.State({})
        gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
        gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                analyze_btn = gr.Button("Analyze Features", variant="primary")
                gr.Examples(
                    examples=["WordLift", "Think Different", "Just Do It"],
                    inputs=input_text,
                )
            with gr.Column(scale=2):
                output = gr.HTML()
        # Event handlers (must be attached inside the Blocks context).
        analyze_btn.click(
            fn=analyze_features,
            inputs=[input_text, state],
            outputs=[output, state],
        )
    return interface
if __name__ == "__main__":
    # Running this module directly builds the UI and starts the Gradio server.
    app = create_interface()
    app.launch()