Upload folder using huggingface_hub
Browse files- Dockerfile +80 -0
- README.md +96 -5
- TODO_movement_first_rubric.md +101 -0
- __init__.py +10 -0
- client.py +59 -0
- models.py +47 -0
- openenv.yaml +7 -0
- openenv_adhd_env.egg-info/PKG-INFO +9 -0
- openenv_adhd_env.egg-info/SOURCES.txt +15 -0
- openenv_adhd_env.egg-info/dependency_links.txt +1 -0
- openenv_adhd_env.egg-info/entry_points.txt +2 -0
- openenv_adhd_env.egg-info/requires.txt +5 -0
- openenv_adhd_env.egg-info/top_level.txt +1 -0
- pyproject.toml +46 -0
- reward.py +164 -0
- server/__init__.py +5 -0
- server/adhd_env_environment.py +135 -0
- server/app.py +33 -0
- server/requirements.txt +6 -0
- test_environment.py +262 -0
- test_with_model.py +285 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
# curl and ca-certificates are installed alongside git because the uv bootstrap
# step further down downloads its installer with `curl ... | sh` whenever the
# base image does not already ship uv.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git curl ca-certificates && \
    rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=adhd_env
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check (use Python - curl may not be in base image)
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,101 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: ADHD Task Initiation Coaching Environment
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- reinforcement-learning
|
| 12 |
+
- adhd
|
| 13 |
+
- executive-function
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# ADHD Task Initiation Coaching Environment
|
| 17 |
+
|
| 18 |
+
An OpenEnv environment that evaluates ADHD coaching response quality. It scores AI coaching responses for task initiation paralysis based on tool calling and response quality.
|
| 19 |
+
|
| 20 |
+
**Innovation**: State tracking ("knobs") + tool calling evaluation - not just text scoring.
|
| 21 |
+
|
| 22 |
+
## Quick Start
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
from adhd_env import ADHDAction, ADHDEnv
|
| 26 |
+
|
| 27 |
+
# Connect to deployed environment
|
| 28 |
+
with ADHDEnv(base_url="https://YOUR-SPACE.hf.space") as env:
|
| 29 |
+
# Get an ADHD scenario
|
| 30 |
+
result = env.reset()
|
| 31 |
+
print(f"Scenario: {result.observation.scenario}")
|
| 32 |
+
|
| 33 |
+
# Submit a coaching response for scoring
|
| 34 |
+
result = env.step(ADHDAction(
|
| 35 |
+
tool_calls=["adhd_task_initiation_coach"],
|
| 36 |
+
message="Open email and type just the recipient name. Stop there."
|
| 37 |
+
))
|
| 38 |
+
print(f"Reward: {result.reward}") # 1.0
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## How Scoring Works
|
| 42 |
+
|
| 43 |
+
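The table below describes the original V1 scoring, which looks only at tool calling. Note that the bundled `reward.py` implements a newer V2 rubric (tool calling 40%, state awareness 30%, ADHD relevance 30%) and checks for the tool name `adhd_coach_tool`, so the values below may not match the rewards the current code returns: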
|
| 44 |
+
|
| 45 |
+
| Action | Reward | Why |
|
| 46 |
+
|--------|--------|-----|
|
| 47 |
+
| Called `adhd_task_initiation_coach` | **1.0** | Used the primary coaching tool |
|
| 48 |
+
| Called `set_timer` or `break_down_task` | **0.5** | Valid tool, but not the primary one |
|
| 49 |
+
| No tools called | **0.0** | No tool engagement |
|
| 50 |
+
|
| 51 |
+
### Available Tools
|
| 52 |
+
- `adhd_task_initiation_coach` - Primary coaching tool for task initiation
|
| 53 |
+
- `set_timer` - Focus timers for task boxing
|
| 54 |
+
- `break_down_task` - Decompose large tasks into micro-steps
|
| 55 |
+
|
| 56 |
+
## API
|
| 57 |
+
|
| 58 |
+
### POST /reset
|
| 59 |
+
Returns a new ADHD scenario with user state.
|
| 60 |
+
|
| 61 |
+
### POST /step
|
| 62 |
+
Scores a coaching response. Body: `{"action": {"tool_calls": [...], "message": "..."}}`
|
| 63 |
+
|
| 64 |
+
### GET /health
|
| 65 |
+
Health check endpoint.
|
| 66 |
+
|
| 67 |
+
### GET /schema
|
| 68 |
+
JSON schemas for action and observation models.
|
| 69 |
+
|
| 70 |
+
## Environment Details
|
| 71 |
+
|
| 72 |
+
### ADHDAction
|
| 73 |
+
- `tool_calls` (list[str]) - Tools the model would call
|
| 74 |
+
- `message` (str) - The coaching response text
|
| 75 |
+
|
| 76 |
+
### ADHDObservation
|
| 77 |
+
- `scenario` (str) - The ADHD task initiation scenario
|
| 78 |
+
- `state` (dict) - User state tracking (sitting time, energy, etc.)
|
| 79 |
+
- `scoring` (dict) - Detailed scoring breakdown with explanations
|
| 80 |
+
- `reward` (float) - Score 0.0-1.0
|
| 81 |
+
- `done` (bool) - Episode complete flag
|
| 82 |
+
|
| 83 |
+
## Development
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
# Install dependencies
|
| 87 |
+
cd adhd_env && uv sync
|
| 88 |
+
|
| 89 |
+
# Run locally
|
| 90 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 91 |
+
|
| 92 |
+
# Test
|
| 93 |
+
python test_environment.py # Direct test
|
| 94 |
+
python test_environment.py --http # HTTP test (server must be running)
|
| 95 |
+
|
| 96 |
+
# Validate structure
|
| 97 |
+
openenv validate --verbose
|
| 98 |
+
|
| 99 |
+
# Deploy to HF Spaces
|
| 100 |
+
openenv push --repo-id USERNAME/adhd-env
|
| 101 |
+
```
|
TODO_movement_first_rubric.md
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: "Movement First" Rubric Criterion
|
| 2 |
+
|
| 3 |
+
## The Ideal Response Pattern
|
| 4 |
+
|
| 5 |
+
When user state indicates physical distress (slouching, long sitting, late evening),
|
| 6 |
+
the OPTIMAL coaching response prioritizes body movement BEFORE task work:
|
| 7 |
+
|
| 8 |
+
> "Before you do anything else, get up and move your body. Maybe get a drink of water,
|
| 9 |
+
> go outside and touch grass, go split some wood. When you come back I will have some
|
| 10 |
+
> questions for you."
|
| 11 |
+
|
| 12 |
+
Then it calls `adhd_coach_tool` to set up the follow-up interaction for when they return.
|
| 13 |
+
|
| 14 |
+
This should be rewarded very heavily — it's the best possible ADHD coaching response
|
| 15 |
+
for these states.
|
| 16 |
+
|
| 17 |
+
## State Triggers (any of these)
|
| 18 |
+
|
| 19 |
+
- `position_in_chair == "slouching"`
|
| 20 |
+
- `minutes_since_last_stood >= 60`
|
| 21 |
+
- Late evening: hour >= 20
|
| 22 |
+
|
| 23 |
+
## What Makes This Response Pattern Unique
|
| 24 |
+
|
| 25 |
+
It has THREE parts, all present together:
|
| 26 |
+
1. **Movement-first priority** — "before anything else", "first", "before you start"
|
| 27 |
+
2. **Physical activity suggestions** — water, outside, walk, fresh air, move your body, stretch
|
| 28 |
+
3. **Promise of return/continuation** — "when you come back", "after that we'll", "then I'll help"
|
| 29 |
+
|
| 30 |
+
Plus: calls `adhd_coach_tool` (to prepare the follow-up)
|
| 31 |
+
|
| 32 |
+
## Brainstorm: First-Cut Keyword Approach
|
| 33 |
+
|
| 34 |
+
Could score this without an LLM judge by checking for all 3 categories:
|
| 35 |
+
|
| 36 |
+
```python
|
| 37 |
+
def score_movement_first(action, user_state) -> float:
    """Heavy bonus when response prioritizes movement before task work.

    Args:
        action: Object exposing a ``message`` string (the coaching reply).
        user_state: Mapping of user-state knobs. Reads ``position_in_chair``,
            ``minutes_since_last_stood`` and ``time_of_day`` ("HH:MM").

    Returns:
        0.0 when the state does not call for movement or nothing matches,
        0.3 when only a physical activity is mentioned,
        0.6 when a priority phrase and an activity are both present,
        1.0 when priority, activity and return-promise phrases all appear.
    """
    # A missing, null or malformed time_of_day must not crash scoring; in
    # that case the "late evening" trigger simply does not fire.
    try:
        hour = int(str(user_state.get("time_of_day", "12:00")).split(":")[0])
    except ValueError:
        hour = None

    # Treat an explicit None the same as "key absent".
    minutes_since_stood = user_state.get("minutes_since_last_stood") or 0

    # Only triggers when state warrants it
    needs_movement = (
        user_state.get("position_in_chair") == "slouching"
        or minutes_since_stood >= 60
        or (hour is not None and hour >= 20)
    )
    if not needs_movement:
        return 0.0  # not applicable, no bonus

    msg = action.message.lower()

    # Category 1: Prioritizes movement BEFORE task
    priority_words = ["before", "first", "before anything", "step away", "stop"]

    # Category 2: Physical activity
    activity_words = ["water", "outside", "walk", "fresh air", "move", "body",
                      "stretch", "drink", "grass", "sunshine", "exercise"]

    # Category 3: Promise to continue after
    return_words = ["come back", "when you return", "after that", "then we",
                    "then i", "ready", "waiting", "here for you"]

    # Plain substring checks: cheap, but prone to false positives.
    has_priority = any(w in msg for w in priority_words)
    has_activity = any(w in msg for w in activity_words)
    has_return = any(w in msg for w in return_words)

    if has_priority and has_activity and has_return:
        return 1.0  # full bonus — all 3 parts present
    elif has_priority and has_activity:
        return 0.6  # good — movement first but no return promise
    elif has_activity:
        return 0.3  # mentions movement but doesn't prioritize it
    return 0.0
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### How to integrate into rubric weights
|
| 75 |
+
|
| 76 |
+
Option A: Add as 4th criterion, rebalance weights:
|
| 77 |
+
- tool_calling: 30%, state_awareness: 20%, adhd_relevance: 20%, movement_first: 30%
|
| 78 |
+
|
| 79 |
+
Option B: Make it a multiplier/bonus on top of existing score:
|
| 80 |
+
- If movement_first triggers fully, multiply final score by 1.5 (before clamp)
|
| 81 |
+
|
| 82 |
+
Option C: Replace state_awareness with this (it's a superset):
|
| 83 |
+
- This IS state awareness, just the most important kind
|
| 84 |
+
|
| 85 |
+
## Limitations of Keyword Approach
|
| 86 |
+
|
| 87 |
+
- Can't tell if the response ACTUALLY prioritizes movement vs just mentioning it
|
| 88 |
+
- "Don't walk away from your task" would false-positive on "walk" (an activity keyword) even though it advises against moving
|
| 89 |
+
- Can't evaluate the QUALITY or tone of the suggestion
|
| 90 |
+
- Can't tell if the return promise is genuine coaching setup vs throwaway
|
| 91 |
+
|
| 92 |
+
## Future: LLM-as-Judge
|
| 93 |
+
|
| 94 |
+
For a proper implementation, we'd want an LLM judge that evaluates:
|
| 95 |
+
1. Does the response prioritize physical wellbeing over task completion?
|
| 96 |
+
2. Does it give concrete physical activity suggestions (not just "take a break")?
|
| 97 |
+
3. Does it promise meaningful follow-up (not just dismissing the user)?
|
| 98 |
+
4. Is the tone encouraging rather than prescriptive?
|
| 99 |
+
|
| 100 |
+
Could use a small model (SmolLM3-3B) as judge with a rubric prompt.
|
| 101 |
+
Trade-off: slower scoring, but much more accurate for this nuanced criterion.
|
__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ADHD Task Initiation Coaching Evaluation Environment."""
|
| 2 |
+
|
| 3 |
+
from .client import ADHDEnv
|
| 4 |
+
from .models import ADHDAction, ADHDObservation
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ADHDAction",
|
| 8 |
+
"ADHDObservation",
|
| 9 |
+
"ADHDEnv",
|
| 10 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ADHD Environment Client.
|
| 2 |
+
|
| 3 |
+
Connects to an ADHD coaching evaluation environment server via WebSocket.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Dict
|
| 7 |
+
|
| 8 |
+
from openenv.core.client_types import StepResult
|
| 9 |
+
from openenv.core.env_server.types import State
|
| 10 |
+
from openenv.core import EnvClient
|
| 11 |
+
|
| 12 |
+
from .models import ADHDAction, ADHDObservation
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ADHDEnv(EnvClient[ADHDAction, ADHDObservation]):
    """Typed client for the ADHD Task Initiation Coaching Environment.

    Converts ``ADHDAction`` objects into request payloads and turns the
    server's dict responses back into ``StepResult`` / ``State`` objects.

    Example:
        >>> with ADHDEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.scenario)
        ...
        ...     result = client.step(ADHDAction(
        ...         tool_calls=["adhd_task_initiation_coach"],
        ...         message="Open email and type just the recipient name."
        ...     ))
        ...     print(f"Reward: {result.reward}")

    NOTE(review): reward.py scores the tool name ``adhd_coach_tool``; confirm
    whether the tool name used in the example above is still the right one.
    """

    def _step_payload(self, action: ADHDAction) -> Dict:
        """Return the request body for a step: the action's two fields."""
        body: Dict = {}
        body["tool_calls"] = action.tool_calls
        body["message"] = action.message
        return body

    def _parse_result(self, payload: Dict) -> StepResult[ADHDObservation]:
        """Build a ``StepResult`` from a server response dict.

        ``reward`` and ``done`` are read from the top level of ``payload``
        (defaulting to 0.0 / False) and copied onto both the observation and
        the result; the remaining fields come from ``payload["observation"]``.
        """
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)
        fields = payload.get("observation", {})

        parsed_observation = ADHDObservation(
            scenario=fields.get("scenario", ""),
            state=fields.get("state", {}),
            scoring=fields.get("scoring", {}),
            done=done,
            reward=reward,
        )
        return StepResult(observation=parsed_observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from the ``episode_id`` / ``step_count`` keys."""
        episode_id = payload.get("episode_id")
        step_count = payload.get("step_count", 0)
        return State(episode_id=episode_id, step_count=step_count)
|
models.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models for the ADHD Task Initiation Coaching Environment."""
|
| 2 |
+
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
from openenv.core.env_server.types import Action, Observation
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ADHDAction(Action):
    """A model's proposed coaching turn, submitted to the environment for scoring.

    Carries the names of the tools the model would invoke together with the
    coaching text it would send to the user.

    NOTE(review): the ``tool_calls`` description cites
    'adhd_task_initiation_coach' as an example, while reward.py scores the
    name 'adhd_coach_tool' — confirm which tool name is canonical.
    """

    # Tool names the model chose to call; defaults to an empty list.
    tool_calls: List[str] = Field(
        description="Tools called by the model (e.g., ['adhd_task_initiation_coach'])",
        default_factory=list,
    )
    # Free-text coaching reply; defaults to the empty string.
    message: str = Field(
        description="The coaching response text",
        default="",
    )
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ADHDObservation(Observation):
    """What the environment reports back: the scenario, user state and scoring.

    reset() returns it populated with the scenario and state; step() returns
    it with the scored reward and the scoring details.

    Note: ``done``, ``reward`` and ``metadata`` are inherited from the
    ``Observation`` base class and are therefore not redeclared here.
    Note: OpenEnv's serialize_observation excludes 'metadata' from HTTP
    responses, which is why scoring details live in the dedicated ``scoring``
    field instead.
    """

    # The user utterance describing the task-initiation problem.
    scenario: str = Field(
        description="The task initiation scenario (user utterance)",
        default="",
    )
    # User-state knobs; defaults to an empty dict.
    state: Dict[str, Any] = Field(
        description="User state tracking (sitting time, energy, etc.)",
        default_factory=dict,
    )
    # Per-criterion scoring breakdown; defaults to an empty dict.
    scoring: Dict[str, Any] = Field(
        description="Scoring breakdown and explanation (visible in HTTP responses)",
        default_factory=dict,
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: adhd_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
openenv_adhd_env.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-adhd_env
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Adhd Env environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.0
|
| 7 |
+
Provides-Extra: dev
|
| 8 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 9 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_adhd_env.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./client.py
|
| 5 |
+
./models.py
|
| 6 |
+
./reward.py
|
| 7 |
+
openenv_adhd_env.egg-info/PKG-INFO
|
| 8 |
+
openenv_adhd_env.egg-info/SOURCES.txt
|
| 9 |
+
openenv_adhd_env.egg-info/dependency_links.txt
|
| 10 |
+
openenv_adhd_env.egg-info/entry_points.txt
|
| 11 |
+
openenv_adhd_env.egg-info/requires.txt
|
| 12 |
+
openenv_adhd_env.egg-info/top_level.txt
|
| 13 |
+
server/__init__.py
|
| 14 |
+
server/adhd_env_environment.py
|
| 15 |
+
server/app.py
|
openenv_adhd_env.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_adhd_env.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = adhd_env.server.app:main
|
openenv_adhd_env.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
+
|
| 3 |
+
[dev]
|
| 4 |
+
pytest>=8.0.0
|
| 5 |
+
pytest-cov>=4.0.0
|
openenv_adhd_env.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
adhd_env
|
pyproject.toml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-adhd_env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Adhd Env environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.0",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
"openai>=1.0.0",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[project.scripts]
|
| 39 |
+
# Server entry point - enables running via: uv run --project . server
|
| 40 |
+
# or: python -m adhd_env.server.app
|
| 41 |
+
server = "adhd_env.server.app:main"
|
| 42 |
+
|
| 43 |
+
[tool.setuptools]
|
| 44 |
+
include-package-data = true
|
| 45 |
+
packages = ["adhd_env", "adhd_env.server"]
|
| 46 |
+
package-dir = { "adhd_env" = ".", "adhd_env.server" = "server" }
|
reward.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward scoring for ADHD coaching environment.
|
| 2 |
+
|
| 3 |
+
V2: Rubric-based scoring with tool calling + state awareness.
|
| 4 |
+
- Tool calling: 40% weight - penalizes wrong-domain tools
|
| 5 |
+
- State awareness: 30% weight - rewards state-responsive coaching
|
| 6 |
+
- ADHD relevance: 30% weight - rewards directive, low-cognitive-load responses
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
from models import ADHDAction
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ADHD-domain tools
|
| 14 |
+
ADHD_TOOLS = {"adhd_coach_tool"}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def score_tool_calling(
    action: ADHDAction,
    is_adhd_scenario: bool,
    expected_tool: Optional[str] = None,
) -> float:
    """Rate the tools an action called against the kind of scenario.

    ADHD scenario:
        1.0  - ``adhd_coach_tool`` is among the called tools
        0.0  - nothing was called
        -0.5 - only non-ADHD tools were called (wrong domain)

    Non-ADHD scenario:
        -0.5 - ``adhd_coach_tool`` was called (wrong domain)
        0.7  - ``expected_tool`` is set and was called
        0.5  - anything else, including no tool at all (neutral)

    Args:
        action: Action whose ``tool_calls`` are inspected.
        is_adhd_scenario: Whether the scenario is an ADHD one.
        expected_tool: Tool that earns 0.7 on a non-ADHD scenario; ignored
            when empty/None or when the scenario is an ADHD one.

    Returns:
        One of the scores listed above.
    """
    invoked = set(action.tool_calls)
    used_adhd_tool = "adhd_coach_tool" in invoked

    if not is_adhd_scenario:
        # Reaching for the ADHD tool outside its domain is penalized first.
        if used_adhd_tool:
            return -0.5
        hit_expected = bool(expected_tool) and expected_tool in invoked
        return 0.7 if hit_expected else 0.5

    if used_adhd_tool:
        return 1.0
    # Empty call list is neutral-zero; any other tool is the wrong domain.
    return -0.5 if invoked else 0.0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _hour_of_day(time_of_day) -> Optional[int]:
    """Return the hour from an "HH:MM"-style value, or None if unparseable."""
    try:
        return int(str(time_of_day).split(":")[0])
    except ValueError:
        return None


def score_state_awareness(action: ADHDAction, user_state: dict) -> float:
    """Score whether response accounts for user state.

    1.0 - mentions movement/stretching when sitting 60+ min or slouching
    1.0 - suggests simpler tasks when evening (hour >= 20)
    0.5 - generic response (default, neutral)

    Args:
        action: Action whose ``message`` is searched (case-insensitively).
        user_state: Mapping read for ``minutes_since_last_stood`` (default 0),
            ``position_in_chair`` (default "normal") and ``time_of_day``
            (default "12:00").

    Returns:
        1.0 when either state-specific rule matches, otherwise 0.5.

    A malformed, null or non-string ``time_of_day`` no longer raises; the
    evening rule is skipped instead. A null ``minutes_since_last_stood`` is
    treated as 0.

    NOTE: keywords are matched as plain substrings, so e.g. "understand"
    matches "stand" and "interest" matches "rest".
    """
    msg = action.message.lower()
    score = 0.5  # default neutral

    # `or 0` guards against an explicit None, which cannot be compared to 60.
    minutes_sitting = user_state.get("minutes_since_last_stood", 0) or 0
    position = user_state.get("position_in_chair", "normal")
    hour = _hour_of_day(user_state.get("time_of_day", "12:00"))

    movement_keywords = [
        "stand", "stretch", "walk", "move", "get up", "posture",
        "take a break", "step away", "physical",
    ]

    # Reward movement suggestions when sitting too long or slouching
    if minutes_sitting >= 60 or position == "slouching":
        if any(kw in msg for kw in movement_keywords):
            score = 1.0

    # Reward simpler task suggestions in the evening
    evening_keywords = [
        "simple", "small", "easy", "quick", "short", "wind down",
        "rest", "tomorrow", "lighter",
    ]
    if hour is not None and hour >= 20:
        if any(kw in msg for kw in evening_keywords):
            score = 1.0

    return score
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def score_adhd_relevance(action: ADHDAction, is_adhd_scenario: bool) -> float:
    """Rate how well the message suits an ADHD scenario.

    Non-ADHD scenarios always get a neutral 0.5. For ADHD scenarios an empty
    (whitespace-only) message scores 0.0; otherwise the score starts at 0.5
    and is adjusted as follows before being clamped to 0.0-1.0:

        +0.15  the message contains "?", plus one of "what"/"how" and one
               reflective word (substring matches on the lowercased text)
        +0.25  the message is 5 to 50 words long
        -0.25  the message is longer than 100 words

    Args:
        action: Action whose ``message`` is evaluated.
        is_adhd_scenario: Whether the scenario is an ADHD one.

    Returns:
        The clamped relevance score.
    """
    if not is_adhd_scenario:
        return 0.5

    text = action.message.strip()
    if not text:
        return 0.0

    lowered = text.lower()
    total = 0.5  # baseline

    # Bonus for a reflective/clarifying question.
    asks_question = "?" in text
    has_question_word = any(
        word in lowered for word in ("what", "how")
    )
    has_reflective_word = any(
        word in lowered
        for word in ("specific", "detail", "details", "feeling", "think", "reflect", "explain")
    )
    if asks_question and has_question_word and has_reflective_word:
        total += 0.15

    # Length adjustment: 5-50 words is rewarded, more than 100 is penalized,
    # everything else (under 5, or 51-100) leaves the score unchanged.
    n_words = len(text.split())
    if 5 <= n_words <= 50:
        total += 0.25
    elif n_words > 100:
        total -= 0.25  # too long, high cognitive load

    return max(0.0, min(1.0, total))
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def score_rubric(
    action: ADHDAction,
    scenario: str,
    user_state: dict,
    is_adhd_scenario: bool,
    expected_tool: Optional[str] = None,
) -> Dict[str, Any]:
    """Run all three criteria and combine them into one weighted score.

    The total is 0.4 * tool_calling + 0.3 * state_awareness
    + 0.3 * adhd_relevance, clamped to 0.0-1.0 and rounded to 3 decimals.

    Args:
        action: The action being scored.
        scenario: Scenario text; accepted but not read by this function.
        user_state: User-state mapping, passed to the state-awareness
            criterion and echoed back in the result.
        is_adhd_scenario: Whether the scenario is an ADHD one.
        expected_tool: Expected tool for non-ADHD scenarios, if any.

    Returns:
        Dict with ``version``, ``total_score`` and a ``criteria`` breakdown
        holding each criterion's raw score, weight and inputs.
    """
    tool_weight, state_weight, relevance_weight = 0.4, 0.3, 0.3

    tool_score = score_tool_calling(action, is_adhd_scenario, expected_tool)
    state_score = score_state_awareness(action, user_state)
    relevance_score = score_adhd_relevance(action, is_adhd_scenario)

    weighted_sum = (
        (tool_score * tool_weight)
        + (state_score * state_weight)
        + (relevance_score * relevance_weight)
    )
    # The tool criterion can be negative, so the sum is clamped on both ends.
    clamped = max(0.0, min(1.0, weighted_sum))

    criteria = {
        "tool_calling": {
            "score": tool_score,
            "weight": tool_weight,
            "is_adhd_scenario": is_adhd_scenario,
            "expected_tool": expected_tool,
            "tools_called": action.tool_calls,
        },
        "state_awareness": {
            "score": state_score,
            "weight": state_weight,
            "user_state": user_state,
        },
        "adhd_relevance": {
            "score": relevance_score,
            "weight": relevance_weight,
        },
    }

    return {
        "version": "v2.1",
        "total_score": round(clamped, 3),
        "criteria": criteria,
    }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ADHD environment server components."""
|
| 2 |
+
|
| 3 |
+
from .adhd_env_environment import ADHDEnvironment
|
| 4 |
+
|
| 5 |
+
__all__ = ["ADHDEnvironment"]
|
server/adhd_env_environment.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ADHD Task Initiation Coaching Evaluation Environment.
|
| 2 |
+
|
| 3 |
+
Evaluates ADHD coaching responses by scoring tool calling and response quality.
|
| 4 |
+
V2: Multiple scenarios, state tracking, rubric-based scoring.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import random
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
|
| 11 |
+
from openenv.core.env_server.interfaces import Environment
|
| 12 |
+
from openenv.core.env_server.types import State
|
| 13 |
+
|
| 14 |
+
from models import ADHDAction, ADHDObservation
|
| 15 |
+
from reward import score_rubric
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ADHD task initiation scenarios
|
| 19 |
+
ADHD_SCENARIOS = [
|
| 20 |
+
"I can't start writing the email to my manager",
|
| 21 |
+
"I've been staring at this blank document for 30 minutes",
|
| 22 |
+
"I need to make a phone call but I keep putting it off",
|
| 23 |
+
"I'm stuck on starting this presentation",
|
| 24 |
+
"I've been avoiding this report all day",
|
| 25 |
+
"I don't know how to begin this project proposal",
|
| 26 |
+
"I keep switching tabs instead of starting my work",
|
| 27 |
+
"I'm overwhelmed by this task list",
|
| 28 |
+
"I can't focus on writing this code review",
|
| 29 |
+
"I've been procrastinating on this assignment for hours",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
# Non-ADHD scenarios: (prompt, expected_tool or None)
|
| 33 |
+
NON_ADHD_SCENARIOS = [
|
| 34 |
+
("What's the weather like today?", "web_search_tool"),
|
| 35 |
+
("What is the latest revenue for IBM?", "web_search_tool"),
|
| 36 |
+
("What is the capital of France?", "web_search_tool"),
|
| 37 |
+
("Write me a poem about cats", None),
|
| 38 |
+
("Translate this sentence to Spanish", None),
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def generate_user_state(rng: Optional[random.Random] = None) -> dict:
    """Generate a randomized user state (the 'knobs').

    Args:
        rng: Optional ``random.Random`` instance to draw from, so a caller
            can reproduce a state from a seed. When omitted, the shared
            module-level generator of ``random`` is used.

    Returns:
        A dict with three keys:
        - "time_of_day": "HH:MM" string, hour 06-22 and minute 00-59.
        - "position_in_chair": one of "normal", "slouching", "standing".
        - "minutes_since_last_stood": int from 0 to 240 inclusive.
    """
    # The ``random`` module exposes the same randint/choice functions as a
    # Random instance, so either can serve as the source.
    source = random if rng is None else rng

    # Draw order (hour, minute, position, minutes) is kept fixed so a given
    # seed always maps to the same state.
    hour = source.randint(6, 22)
    minute = source.randint(0, 59)
    position = source.choice(["normal", "slouching", "standing"])
    minutes_since_last_stood = source.randint(0, 240)

    return {
        "time_of_day": f"{hour:02d}:{minute:02d}",
        "position_in_chair": position,
        "minutes_since_last_stood": minutes_since_last_stood,
    }
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ADHDEnvironment(Environment):
    """ADHD Task Initiation Coaching Evaluation Environment.

    Evaluates coaching responses for ADHD task initiation paralysis.
    Innovation: state tracking + tool calling evaluation.

    V2: Multiple scenarios, state tracking, rubric-based scoring.
    - 10 ADHD scenarios + 5 non-ADHD scenarios
    - 3 state variables (time_of_day, position_in_chair, minutes_since_last_stood)
    - Rubric with tool calling + state awareness scoring

    Single-turn: reset() -> step() -> done=True
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Probability that reset() draws from ADHD_SCENARIOS; the remainder draws
    # from NON_ADHD_SCENARIOS.
    # NOTE(review): an earlier comment here said 80% / 20% while the code
    # compared against 0.7. The behavior (0.7) is kept; confirm the intended
    # ratio.
    ADHD_SCENARIO_PROBABILITY: float = 0.7

    def __init__(self):
        """Start with a fresh episode id and no scenario selected yet."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.current_scenario: str = ""
        self.current_user_state: dict = {}
        self.is_adhd_scenario: bool = True
        self.expected_tool: Optional[str] = None

    def reset(self) -> ADHDObservation:
        """Generate new episode with randomized scenario and user state.

        Returns:
            An observation with the chosen scenario and user state,
            ``done=False``, ``reward=0.0`` and a ``scoring`` dict that lists
            the rubric version and the available tool names.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.current_user_state = generate_user_state()

        # Pick ADHD 70% / non-ADHD 30% (see ADHD_SCENARIO_PROBABILITY).
        if random.random() < self.ADHD_SCENARIO_PROBABILITY:
            self.current_scenario = random.choice(ADHD_SCENARIOS)
            self.is_adhd_scenario = True
            self.expected_tool = "adhd_coach_tool"
        else:
            # Non-ADHD entries are (prompt, expected_tool or None) pairs.
            prompt, tool = random.choice(NON_ADHD_SCENARIOS)
            self.current_scenario = prompt
            self.is_adhd_scenario = False
            self.expected_tool = tool

        return ADHDObservation(
            scenario=self.current_scenario,
            state=self.current_user_state,
            done=False,
            reward=0.0,
            scoring={
                "version": "v2.1",
                "available_tools": [
                    "adhd_coach_tool",
                    "web_search_tool",
                ],
            },
        )

    def step(self, action: ADHDAction) -> ADHDObservation:  # type: ignore[override]
        """Score a coaching response.

        Single-turn: returns done=True after scoring.

        Args:
            action: The response to grade; its tool calls and message are
                echoed back under ``scoring["action"]``.

        Returns:
            An observation whose ``reward`` is the rubric's total score and
            whose ``scoring`` holds the full per-criterion breakdown.
        """
        self._state.step_count += 1

        scoring = score_rubric(
            action,
            self.current_scenario,
            self.current_user_state,
            self.is_adhd_scenario,
            self.expected_tool,
        )
        scoring["action"] = {
            "tool_calls": action.tool_calls,
            "message": action.message,
        }

        return ADHDObservation(
            scenario=self.current_scenario,
            state=self.current_user_state,
            done=True,
            reward=scoring["total_score"],
            scoring=scoring,
        )

    @property
    def state(self) -> State:
        """Episode bookkeeping (episode id and step count)."""
        return self._state
|
server/app.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application for the ADHD Coaching Environment.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
# Development (with auto-reload):
|
| 5 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 6 |
+
|
| 7 |
+
# Production:
|
| 8 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from openenv.core.env_server.http_server import create_app
|
| 12 |
+
|
| 13 |
+
from models import ADHDAction, ADHDObservation
|
| 14 |
+
from .adhd_env_environment import ADHDEnvironment
|
| 15 |
+
|
| 16 |
+
app = create_app(
|
| 17 |
+
ADHDEnvironment,
|
| 18 |
+
ADHDAction,
|
| 19 |
+
ADHDObservation,
|
| 20 |
+
env_name="adhd_env",
|
| 21 |
+
max_concurrent_envs=1,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve the module-level ``app`` with uvicorn.

    Entry point for: uv run --project . server

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
    """
    # Imported here rather than at module level, as in the original.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
main()
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
test_environment.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Test script for the ADHD coaching environment.
|
| 3 |
+
|
| 4 |
+
Tests the environment directly (no server needed) and via HTTP if a server is running.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# Direct test (no server):
|
| 8 |
+
cd adhd_env && .venv/bin/python test_environment.py
|
| 9 |
+
|
| 10 |
+
# With server running:
|
| 11 |
+
cd adhd_env && .venv/bin/uvicorn server.app:app --host 0.0.0.0 --port 8000 &
|
| 12 |
+
cd adhd_env && .venv/bin/python test_environment.py --http
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_direct():
    """Test environment directly without HTTP server.

    Checks that reset() returns a non-empty scenario with done=False and
    reward=0.0, that the user state has all three keys with values in range,
    and that ten resets produce at least two distinct states.
    """
    from server.adhd_env_environment import ADHDEnvironment

    env = ADHDEnvironment()
    print("=" * 60)
    print("DIRECT ENVIRONMENT TEST")
    print("=" * 60)

    # Test reset returns valid state
    obs = env.reset()
    print("\n--- Reset ---")
    print(f"Scenario: {obs.scenario}")
    print(f"State: {obs.state}")
    print(f"Done: {obs.done}")
    print(f"Reward: {obs.reward}")

    assert obs.scenario, "Scenario should not be empty"
    assert obs.done is False
    assert obs.reward == 0.0

    # Validate state has all 3 keys
    assert "time_of_day" in obs.state, "Missing time_of_day"
    assert "position_in_chair" in obs.state, "Missing position_in_chair"
    assert "minutes_since_last_stood" in obs.state, "Missing minutes_since_last_stood"
    assert obs.state["position_in_chair"] in ("normal", "slouching", "standing")
    assert 0 <= obs.state["minutes_since_last_stood"] <= 240
    print("State validation: PASS")

    # Variety check: reset 10x and verify we get at least 2 distinct states
    observed = set()
    for _ in range(10):
        fresh = env.reset().state
        observed.add(
            (
                fresh["time_of_day"],
                fresh["position_in_chair"],
                fresh["minutes_since_last_stood"],
            )
        )
    unique_states = len(observed)
    assert unique_states >= 2, f"Expected at least 2 distinct states, got {unique_states}"
    print(f"State variety check ({unique_states} unique in 10 resets): PASS")

    print(f"\n{'=' * 60}")
    print("ALL DIRECT TESTS PASSED")
    print(f"{'=' * 60}")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_rubric():
    """Test rubric scoring with positive and negative cases.

    Each case scores one action with ``score_rubric`` and asserts that the
    total falls in the expected band. The final check asserts that a
    reflective clarifying question outscores a plain generic nudge on the
    same scenario and state.
    """
    from models import ADHDAction
    from reward import score_rubric

    print(f"\n{'=' * 60}")
    print("RUBRIC TEST")
    print(f"{'=' * 60}")

    # State where user has been sitting a long time and is slouching
    tired_state = {
        "time_of_day": "14:00",
        "position_in_chair": "slouching",
        "minutes_since_last_stood": 90,
    }

    evening_state = {
        "time_of_day": "21:00",
        "position_in_chair": "normal",
        "minutes_since_last_stood": 30,
    }

    def expect(label, tool_calls, message, scenario, state, is_adhd, expected_tool,
               accept, expectation):
        """Score one action, print its total, and assert it is acceptable."""
        action = ADHDAction(tool_calls=tool_calls, message=message)
        result = score_rubric(action, scenario, state, is_adhd, expected_tool)
        print(f"\n{label}: {result['total_score']}")
        assert accept(result["total_score"]), \
            f"Expected {expectation}, got {result['total_score']}"
        print("PASS")

    # POSITIVE: ADHD scenario + primary tool + state-aware message
    expect(
        "POSITIVE (ADHD + primary tool + state-aware)",
        ["adhd_coach_tool"],
        "Stand up and stretch for 30 seconds, then type just the recipient name.",
        "I can't start the email", tired_state, True, None,
        lambda score: score >= 0.7, ">= 0.7",
    )

    # NEGATIVE: ADHD scenario + wrong-domain tool
    expect(
        "NEGATIVE (ADHD + web_search_tool)",
        ["web_search_tool"],
        "Let me search for tips on email writing.",
        "I can't start the email", tired_state, True, None,
        lambda score: score < 0.3, "< 0.3",
    )

    # NEGATIVE: Non-ADHD scenario + ADHD tool
    expect(
        "NEGATIVE (non-ADHD + ADHD tool)",
        ["adhd_coach_tool"],
        "Let me help you initiate that task.",
        "What's the weather?", tired_state, False, "web_search_tool",
        lambda score: score < 0.3, "< 0.3",
    )

    # SLIGHTLY POSITIVE: Non-ADHD factual + correct tool
    expect(
        "SLIGHTLY POSITIVE (non-ADHD + correct tool)",
        ["web_search_tool"],
        "Let me look that up for you.",
        "What is the capital of France?", tired_state, False, "web_search_tool",
        lambda score: score >= 0.5, ">= 0.5",
    )

    # NEUTRAL: Non-ADHD creative + no tool
    expect(
        "NEUTRAL (non-ADHD creative + no tool)",
        [],
        "Here is a poem about cats.",
        "Write me a poem about cats", tired_state, False, None,
        lambda score: 0.3 <= score <= 0.7, "0.3-0.7",
    )

    # MEDIUM: ADHD + primary tool + generic message (no state awareness)
    expect(
        "MEDIUM (ADHD + primary tool + generic)",
        ["adhd_coach_tool"],
        "Try breaking this task into smaller pieces.",
        "I'm stuck on this report", tired_state, True, None,
        lambda score: 0.4 <= score <= 0.85, "0.4-0.85",
    )

    # EVENING: ADHD + primary tool + evening-aware message
    expect(
        "EVENING AWARE (ADHD + primary tool + evening tips)",
        ["adhd_coach_tool"],
        "It's late. Pick a small easy task to finish tonight, save the rest for tomorrow.",
        "I can't focus on this", evening_state, True, None,
        lambda score: score >= 0.7, ">= 0.7",
    )

    # REFLECTIVE QUESTION: ADHD + primary tool + clarifying question
    action_reflective = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="What are you specifically stuck on? Explain the first step you think you need to take.",
    )
    result_reflective = score_rubric(
        action_reflective, "I've been stuck for 30 minutes", tired_state, True, None
    )
    # Compare against same scenario with generic non-reflective message
    action_plain = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="Just try to get started on it.",
    )
    result_plain = score_rubric(
        action_plain, "I've been stuck for 30 minutes", tired_state, True, None
    )
    print(f"\nREFLECTIVE Q (ADHD + primary tool + clarifying question): {result_reflective['total_score']}")
    print(f" vs PLAIN (ADHD + primary tool + generic): {result_plain['total_score']}")
    assert result_reflective["total_score"] > result_plain["total_score"], \
        f"Reflective question should score higher than plain: {result_reflective['total_score']} vs {result_plain['total_score']}"
    print("PASS")

    print(f"\n{'=' * 60}")
    print("ALL RUBRIC TESTS PASSED")
    print(f"{'=' * 60}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def test_http(base_url="http://localhost:8000", timeout=10.0):
    """Test environment via HTTP endpoints.

    Exercises /health, /schema, /reset and /step against a running server
    and asserts on status codes and response shape.

    Args:
        base_url: Root URL of the running environment server.
        timeout: Seconds to wait on each request before giving up. Without
            a timeout, ``requests`` waits indefinitely on an unresponsive
            server.
    """
    import requests

    def get(path):
        """GET a server path with the shared timeout."""
        return requests.get(f"{base_url}{path}", timeout=timeout)

    def post(path, payload=None):
        """POST an optional JSON payload to a server path with the shared timeout."""
        return requests.post(f"{base_url}{path}", json=payload, timeout=timeout)

    print(f"\n{'=' * 60}")
    print(f"HTTP TEST ({base_url})")
    print(f"{'=' * 60}")

    # Health check
    r = get("/health")
    assert r.status_code == 200
    print(f"\nHealth: {r.json()}")

    # Schema
    r = get("/schema")
    assert r.status_code == 200
    schema = r.json()
    assert "action" in schema
    assert "observation" in schema
    print(f"Schema: action has {list(schema['action']['properties'].keys())}")
    print(f"Schema: observation has {list(schema['observation']['properties'].keys())}")

    # Reset
    r = post("/reset")
    assert r.status_code == 200
    data = r.json()
    assert data["done"] is False
    assert data["reward"] == 0.0
    assert "scenario" in data["observation"]
    obs = data["observation"]
    assert "state" in obs
    assert "time_of_day" in obs["state"]
    assert "position_in_chair" in obs["state"]
    assert "minutes_since_last_stood" in obs["state"]
    print(f"\nReset: scenario='{obs['scenario']}'")
    print(f" state={obs['state']}")
    print(" State keys present: PASS")

    # Good action (ADHD scenario + primary tool)
    r = post("/step", {
        "action": {
            "tool_calls": ["adhd_coach_tool"],
            "message": "Stand up and stretch, then type just the recipient name.",
        }
    })
    assert r.status_code == 200
    data = r.json()
    assert data["done"] is True
    assert data["reward"] > 0
    print(f"Good action: reward={data['reward']} PASS")

    # Bad action (no tools on presumed ADHD scenario)
    r = post("/step", {
        "action": {
            "tool_calls": [],
            "message": "What do you want to work on?",
        }
    })
    assert r.status_code == 200
    data = r.json()
    print(f"No-tool action: reward={data['reward']}")

    # Verify scoring details in response
    assert "scoring" in data["observation"]
    assert "total_score" in data["observation"]["scoring"]
    assert "criteria" in data["observation"]["scoring"]
    print("Scoring details present: PASS")

    print(f"\n{'=' * 60}")
    print("ALL HTTP TESTS PASSED")
    print(f"{'=' * 60}")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
    # The direct and rubric checks always run; the HTTP checks are opt-in.
    test_direct()
    test_rubric()

    if "--http" in sys.argv:
        # Any argument starting with "http" overrides the default server URL;
        # when several are given, the last one wins.
        overrides = [arg for arg in sys.argv if arg.startswith("http")]
        test_http(overrides[-1] if overrides else "http://localhost:8000")
|
test_with_model.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""End-to-end test: LLM with tool calling -> ADHD environment scoring.
|
| 3 |
+
|
| 4 |
+
Tests whether LLMs pick the correct tools for ADHD vs non-ADHD scenarios,
|
| 5 |
+
and scores their responses using the environment's rubric.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
cd adhd_env && .venv/bin/python test_with_model.py
|
| 9 |
+
cd adhd_env && .venv/bin/python test_with_model.py --model Qwen/Qwen3.5-9B
|
| 10 |
+
cd adhd_env && .venv/bin/python test_with_model.py --all
|
| 11 |
+
|
| 12 |
+
Requires HF_TOKEN environment variable.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
from openai import OpenAI
|
| 20 |
+
|
| 21 |
+
from models import ADHDAction
|
| 22 |
+
from reward import score_rubric
|
| 23 |
+
|
| 24 |
+
MODELS = [
|
| 25 |
+
"HuggingFaceTB/SmolLM3-3B",
|
| 26 |
+
"Qwen/Qwen3.5-9B",
|
| 27 |
+
"allenai/OLMo-3-7B-Instruct",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
# Tool definitions the LLM sees
|
| 31 |
+
TOOLS = [
|
| 32 |
+
{
|
| 33 |
+
"type": "function",
|
| 34 |
+
"function": {
|
| 35 |
+
"name": "adhd_assist_tool",
|
| 36 |
+
"description": (
|
| 37 |
+
"Help a user with ADHD task initiation paralysis. "
|
| 38 |
+
"Use when someone is stuck starting a task, procrastinating, "
|
| 39 |
+
"or overwhelmed by executive function challenges."
|
| 40 |
+
),
|
| 41 |
+
"parameters": {
|
| 42 |
+
"type": "object",
|
| 43 |
+
"properties": {
|
| 44 |
+
"coaching_message": {
|
| 45 |
+
"type": "string",
|
| 46 |
+
"description": "The coaching response to help the user start their task.",
|
| 47 |
+
}
|
| 48 |
+
},
|
| 49 |
+
"required": ["coaching_message"],
|
| 50 |
+
},
|
| 51 |
+
},
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"type": "function",
|
| 55 |
+
"function": {
|
| 56 |
+
"name": "web_search_tool",
|
| 57 |
+
"description": (
|
| 58 |
+
"Search the web for information. Use for general knowledge questions, "
|
| 59 |
+
"weather, facts, latest news, etc."
|
| 60 |
+
),
|
| 61 |
+
"parameters": {
|
| 62 |
+
"type": "object",
|
| 63 |
+
"properties": {
|
| 64 |
+
"query": {
|
| 65 |
+
"type": "string",
|
| 66 |
+
"description": "The search query.",
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"required": ["query"],
|
| 70 |
+
},
|
| 71 |
+
},
|
| 72 |
+
},
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
# LLM tool name -> environment tool name
|
| 76 |
+
TOOL_NAME_MAP = {
|
| 77 |
+
"adhd_assist_tool": "adhd_coach_tool",
|
| 78 |
+
"web_search_tool": "web_search_tool",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# Test cases: one dict per case with keys scenario, user_state, is_adhd, expected_tool, expected_llm_tool, description
|
| 82 |
+
TEST_CASES = [
|
| 83 |
+
{
|
| 84 |
+
"scenario": "I can't start writing the email to my manager",
|
| 85 |
+
"user_state": {"time_of_day": "10:00", "position_in_chair": "normal", "minutes_since_last_stood": 30},
|
| 86 |
+
"is_adhd": True,
|
| 87 |
+
"expected_tool": None,
|
| 88 |
+
"expected_llm_tool": "adhd_assist_tool",
|
| 89 |
+
"description": "ADHD task initiation - should use adhd_assist_tool",
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"scenario": "What's the weather like today?",
|
| 93 |
+
"user_state": {"time_of_day": "12:00", "position_in_chair": "normal", "minutes_since_last_stood": 15},
|
| 94 |
+
"is_adhd": False,
|
| 95 |
+
"expected_tool": "web_search_tool",
|
| 96 |
+
"expected_llm_tool": "web_search_tool",
|
| 97 |
+
"description": "Weather question - should use web_search_tool",
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"scenario": "I've been procrastinating on this assignment for hours and I'm exhausted",
|
| 101 |
+
"user_state": {"time_of_day": "21:30", "position_in_chair": "slouching", "minutes_since_last_stood": 120},
|
| 102 |
+
"is_adhd": True,
|
| 103 |
+
"expected_tool": None,
|
| 104 |
+
"expected_llm_tool": "adhd_assist_tool",
|
| 105 |
+
"description": "Evening ADHD with fatigue - should use adhd_assist_tool",
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"scenario": "Write me a poem about cats",
|
| 109 |
+
"user_state": {"time_of_day": "14:00", "position_in_chair": "normal", "minutes_since_last_stood": 20},
|
| 110 |
+
"is_adhd": False,
|
| 111 |
+
"expected_tool": None,
|
| 112 |
+
"expected_llm_tool": None,
|
| 113 |
+
"description": "Creative request - should NOT use adhd_assist_tool",
|
| 114 |
+
},
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def call_model(client: OpenAI, model: str, scenario: str, user_state: dict) -> dict:
    """Send scenario to LLM and parse tool call response.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        model: Model identifier passed to the API.
        scenario: The user prompt.
        user_state: Dict with "time_of_day", "position_in_chair" and
            "minutes_since_last_stood"; it is rendered into the system prompt.

    Returns:
        A dict that always has the same four keys:
        - "tool_calls": tool names translated through TOOL_NAME_MAP.
        - "llm_tool_names": tool names exactly as the model emitted them.
        - "message": the reply text, or the first tool call's
          "coaching_message" / "query" argument when the reply has no text.
        - "error": None on success, otherwise a non-empty description.
    """
    # Local import, as in the original: json is not among this file's
    # top-level imports.
    import json

    def failure(reason: str) -> dict:
        """Build an error result with the same keys as a success result."""
        return {"error": reason, "tool_calls": [], "llm_tool_names": [], "message": ""}

    system_prompt = (
        "You are a helpful assistant. You have access to tools. "
        "Use the appropriate tool when the user's request matches a tool's purpose. "
        "If no tool is appropriate, respond directly without calling any tool.\n\n"
        f"User context: time={user_state['time_of_day']}, "
        f"position={user_state['position_in_chair']}, "
        f"minutes since last stood={user_state['minutes_since_last_stood']}"
    )

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": scenario},
            ],
            tools=TOOLS,
            tool_choice="auto",
            max_tokens=256,
        )
    except Exception as e:
        # Best-effort boundary: any API failure is reported, not raised.
        # repr() keeps the error truthy when the exception has no message.
        return failure(str(e) or repr(e))

    if not response.choices:
        return failure("model returned no choices")

    msg = response.choices[0].message
    tool_calls_raw = msg.tool_calls or []

    # Map LLM tool names to environment tool names
    llm_tool_names = [tc.function.name for tc in tool_calls_raw]
    env_tool_calls = [TOOL_NAME_MAP.get(name, name) for name in llm_tool_names]

    # Extract message from tool args or content
    message = msg.content or ""
    if not message and tool_calls_raw:
        try:
            args = json.loads(tool_calls_raw[0].function.arguments)
        except (json.JSONDecodeError, TypeError):
            # Malformed or missing arguments: leave the message empty.
            args = None
        if isinstance(args, dict):
            candidate = args.get("coaching_message", args.get("query", ""))
            # Only accept text; a null or non-string argument yields "".
            message = candidate if isinstance(candidate, str) else ""

    return {
        "tool_calls": env_tool_calls,
        "llm_tool_names": llm_tool_names,
        "message": message,
        "error": None,
    }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def run_model_tests(client: OpenAI, model: str) -> dict:
    """Run every entry of TEST_CASES against one model and summarize.

    Each case is sent through ``call_model``, the reply is graded with
    ``score_rubric``, and the model's first tool choice is compared with the
    case's ``expected_llm_tool``. A case whose call fails is logged and
    recorded, adds nothing to the reward sum, and still counts in the totals.

    Args:
        client: OpenAI-compatible client handed to ``call_model``.
        model: Model identifier to evaluate.

    Returns:
        A dict with "model", "correct", "total", "avg_reward" and the
        per-case "results" list.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"MODEL: {model}")
    print(f"{banner}")

    case_count = len(TEST_CASES)
    hits = 0
    reward_sum = 0.0
    records = []

    for number, case in enumerate(TEST_CASES, start=1):
        print(f"\n--- Test {number}: {case['description']} ---")
        print(f" Scenario: {case['scenario']}")

        reply = call_model(client, model, case["scenario"], case["user_state"])

        if reply.get("error"):
            print(f" ERROR: {reply['error']}")
            records.append({"test": number, "error": reply["error"]})
            continue

        print(f" LLM tools: {reply['llm_tool_names']}")
        print(f" Message: {reply['message'][:80]}...")

        # Score with environment rubric
        graded = score_rubric(
            ADHDAction(tool_calls=reply["tool_calls"], message=reply["message"]),
            case["scenario"],
            case["user_state"],
            case["is_adhd"],
            case["expected_tool"],
        )
        reward = graded["total_score"]
        reward_sum += reward

        # Only the model's first tool call is judged.
        picked = reply["llm_tool_names"][0] if reply["llm_tool_names"] else None
        wanted = case["expected_llm_tool"]

        # With no tool expected, anything except adhd_assist_tool is correct.
        tool_correct = (
            picked != "adhd_assist_tool" if wanted is None else picked == wanted
        )
        hits += int(tool_correct)

        verdict = "CORRECT" if tool_correct else "WRONG"
        print(f" Tool choice: {verdict} (picked={picked}, expected={wanted})")
        print(f" Reward: {reward}")
        records.append({
            "test": number,
            "tool_correct": tool_correct,
            "reward": reward,
            "picked": picked,
            "expected": wanted,
        })

    avg_reward = reward_sum / case_count if case_count > 0 else 0
    print(f"\n--- Summary for {model} ---")
    print(f" Tool accuracy: {hits}/{case_count}")
    print(f" Avg reward: {avg_reward:.3f}")

    return {
        "model": model,
        "correct": hits,
        "total": case_count,
        "avg_reward": avg_reward,
        "results": records,
    }
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def main():
    """Parse CLI flags, run the selected models, and print a leaderboard.

    ``--all`` evaluates every entry of MODELS, ``--model NAME`` evaluates one
    model, and with neither flag the first entry of MODELS is used. Exits
    with status 1 when the HF_TOKEN environment variable is not set. The
    leaderboard is printed only when more than one model was evaluated.
    """
    parser = argparse.ArgumentParser(description="Test LLM tool calling with ADHD environment")
    parser.add_argument("--model", type=str, help="Model to test (default: first in list)")
    parser.add_argument("--all", action="store_true", help="Test all models and show leaderboard")
    args = parser.parse_args()

    token = os.environ.get("HF_TOKEN")
    if not token:
        print("ERROR: HF_TOKEN environment variable not set.")
        print("Run: export HF_TOKEN=hf_...")
        sys.exit(1)

    client = OpenAI(
        base_url="https://router.huggingface.co/v1",
        api_key=token,
    )

    if args.all:
        models = MODELS
    elif args.model:
        models = [args.model]
    else:
        models = [MODELS[0]]

    all_results = [run_model_tests(client, model) for model in models]

    if len(all_results) > 1:
        print(f"\n{'=' * 60}")
        print("MODEL LEADERBOARD")
        print(f"{'=' * 60}")
        print(f"{'Model':<40} {'Accuracy':>10} {'Avg Reward':>12}")
        # Column widths 40 + 10 + 12 plus two separating spaces.
        print("-" * 64)

        for r in sorted(all_results, key=lambda x: x["avg_reward"], reverse=True):
            # Build "correct/total" first so the column width applies to the
            # whole fraction rather than to the denominator alone.
            accuracy = f"{r['correct']}/{r['total']}"
            print(f"{r['model']:<40} {accuracy:>10} {r['avg_reward']:>12.3f}")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
if __name__ == "__main__":
|
| 285 |
+
main()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|