diff --git a/backend/alembic.ini b/backend/alembic.ini new file mode 100644 index 0000000..aaaf6a6 --- /dev/null +++ b/backend/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql+asyncpg://user:password@localhost:5432/councilOS + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/backend/alembic/env.py b/backend/alembic/env.py new file mode 100644 index 0000000..44da48a --- /dev/null +++ b/backend/alembic/env.py @@ -0,0 +1,74 @@ +"""Alembic environment configuration for async SQLAlchemy.""" + +import asyncio +import os +import sys +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +# Add the backend directory to sys.path so we can import our models +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from models.blueprint import Base # noqa: E402 + +# Alembic Config object +config = context.config + +# Override sqlalchemy.url from environment variable if present +database_url = os.environ.get("DATABASE_URL") +if database_url: + config.set_main_option("sqlalchemy.url", database_url) + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode (generate SQL without DB connection).""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + 
async def run_async_migrations() -> None:
    """Run migrations in 'online' mode with an async engine.

    Builds an async engine from the ``sqlalchemy.``-prefixed options of the
    active Alembic config section (using ``pool.NullPool``), opens a single
    connection, and runs the synchronous ``do_run_migrations`` routine on it
    through ``run_sync``.

    The engine is disposed in a ``finally`` block so that it is released
    even when connecting or a migration step raises.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    try:
        async with connectable.connect() as connection:
            # Alembic's migration context is synchronous; run_sync hands it
            # a sync-style connection backed by this async connection.
            await connection.run_sync(do_run_migrations)
    finally:
        await connectable.dispose()
def upgrade() -> None:
    """Create the ``blueprints`` table.

    Columns: a 36-character string primary key, an integer ``version``
    (server default ``1``), a ``name`` of up to 255 characters, JSON
    ``nodes`` and ``edges`` (server default ``[]``), and two timezone-aware
    timestamps that default to the database's ``now()``.
    """

    def _timestamp_column(column_name: str) -> sa.Column:
        # created_at and updated_at share an identical definition.
        return sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        )

    columns = (
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("version", sa.Integer(), nullable=False, server_default="1"),
        sa.Column("name", sa.String(255), nullable=False),
        sa.Column("nodes", sa.JSON(), nullable=False, server_default="[]"),
        sa.Column("edges", sa.JSON(), nullable=False, server_default="[]"),
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
    )
    op.create_table("blueprints", *columns)
class BlueprintCreateRequest(BaseModel):
    """Request body for creating a blueprint via ``POST /councils/``."""

    version: int = 1
    name: str = Field(..., min_length=1, max_length=255)
    nodes: List[BlueprintNodeSchema]
    # default_factory rather than a literal ``[]`` so the default follows the
    # same pattern as the factory-based defaults on BlueprintNodeSchema.
    edges: List[BlueprintEdgeSchema] = Field(default_factory=list)
    # Optional client-chosen ID; the create endpoint forwards it to the
    # service as ``blueprint_id``.
    id: Optional[str] = None
@blueprint_router.post(
    "/councils/",
    response_model=BlueprintResponse,
    status_code=201,
)
async def create_new_blueprint(
    request: BlueprintCreateRequest,
    session: AsyncSession = Depends(get_session),
):
    """Create a new council blueprint.

    Args:
        request: Validated creation payload; nodes and edges are dumped to
            plain dicts before being handed to the service layer.
        session: Database session injected by FastAPI.

    Returns:
        The stored blueprint serialized with ``to_dict()``.

    Raises:
        HTTPException: 409 when the client supplied an ``id`` and the
            insert violated a database integrity constraint (in practice,
            the ID is already taken).
    """
    # Imported locally so this handler stays self-contained with respect to
    # the module's import block.
    from sqlalchemy.exc import IntegrityError

    try:
        bp = await create_blueprint(
            session=session,
            name=request.name,
            nodes=[n.model_dump() for n in request.nodes],
            edges=[e.model_dump() for e in request.edges],
            blueprint_id=request.id,
            version=request.version,
        )
    except IntegrityError as exc:
        # A failed flush leaves the session's transaction unusable until it
        # is rolled back.
        await session.rollback()
        if request.id is None:
            # Without a client-supplied ID this is not a simple conflict;
            # let it surface unchanged.
            raise
        raise HTTPException(
            status_code=409,
            detail=f"Blueprint '{request.id}' already exists.",
        ) from exc
    return bp.to_dict()
@blueprint_router.delete("/councils/{blueprint_id}", status_code=204)
async def delete_existing_blueprint(
    blueprint_id: str,
    session: AsyncSession = Depends(get_session),
):
    """Delete the blueprint with the given ID.

    Responds with 204 and no body on success.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    if await delete_blueprint(session, blueprint_id):
        return
    raise HTTPException(
        status_code=404,
        detail=f"Blueprint '{blueprint_id}' not found.",
    )
async def _execute_blueprint_run(
    run_id: str, input_topic: str, blueprint: dict
) -> None:
    """
    Background task: run a blueprint-defined council and record the outcome.

    Marks the run as "running", awaits the dynamically built graph, then
    stores either the completed result fields or, on any exception, a
    "failed" status together with the error text.
    """

    def _record_active_node(event_run_id, node_name):
        # Callback handed to the graph runner; stores the reported node
        # under whichever run ID the runner passes in.
        return run_store.update(event_run_id, {"active_node": node_name})

    run_store.update(run_id, {"status": "running"})
    try:
        final_state = await run_blueprint_council_async(
            blueprint=blueprint,
            input_topic=input_topic,
            run_id=run_id,
            on_node_event=_record_active_node,
        )
        outcome = {
            "status": "completed",
            "final_draft": final_state.get("current_draft"),
            "critic_score": final_state.get("critic_score"),
            "iteration_count": final_state.get("iteration_count"),
            "active_node": "done",
        }
        run_store.update(run_id, outcome)
    except Exception as exc:  # noqa: BLE001
        # Task boundary: the failure is recorded on the run instead of
        # propagating out of the background task.
        run_store.update(run_id, {"status": "failed", "error": str(exc)})
async def close_db() -> None:
    """Dispose of the engine connection pool.

    Awaits ``dispose()`` on the module-level async ``engine``. The
    application's lifespan handler awaits this after its ``yield``, i.e.
    during shutdown.
    """
    await engine.dispose()
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _new_blueprint_id() -> str:
    """Return a freshly generated UUID4 as a string."""
    return str(uuid.uuid4())


class Blueprint(Base):
    """
    ORM model for a stored council blueprint (table ``blueprints``).

    ``nodes`` and ``edges`` are JSON columns; ``to_dict`` exposes them
    under the camelCase keys of the frontend's CouncilBlueprint format.
    """

    __tablename__ = "blueprints"

    # Primary key: generated client-side as a UUID4 string when not supplied.
    id = Column(String(36), primary_key=True, default=_new_blueprint_id)
    version = Column(Integer, nullable=False, default=1)
    name = Column(String(255), nullable=False)
    nodes = Column(JSON, nullable=False, default=list)
    edges = Column(JSON, nullable=False, default=list)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=_utc_now,
    )
    # Refreshed through ``onupdate`` whenever the row is updated.
    updated_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=_utc_now,
        onupdate=_utc_now,
    )

    def to_dict(self) -> dict:
        """Return the blueprint as a plain dict in the frontend JSON shape.

        Timestamps are rendered with ``isoformat()``; a missing timestamp
        becomes ``None``.
        """

        def _iso(moment):
            return moment.isoformat() if moment else None

        return {
            "id": self.id,
            "version": self.version,
            "name": self.name,
            "nodes": self.nodes,
            "edges": self.edges,
            "createdAt": _iso(self.created_at),
            "updatedAt": _iso(self.updated_at),
        }
async def update_blueprint(
    session: AsyncSession,
    blueprint_id: str,
    name: Optional[str] = None,
    nodes: Optional[list] = None,
    edges: Optional[list] = None,
) -> Optional[Blueprint]:
    """Apply a partial update to a stored blueprint.

    Only arguments that are not ``None`` overwrite the stored value.
    ``updated_at`` is always set to the current UTC time, after which the
    change is committed and the instance refreshed.

    Returns:
        The refreshed blueprint, or ``None`` when no blueprint has this ID.
    """
    existing = await get_blueprint(session, blueprint_id)
    if existing is None:
        return None

    requested = {"name": name, "nodes": nodes, "edges": edges}
    for attribute, new_value in requested.items():
        if new_value is not None:
            setattr(existing, attribute, new_value)
    existing.updated_at = datetime.now(timezone.utc)

    await session.commit()
    await session.refresh(existing)
    return existing
# Maps the model names sent by the frontend to zero-argument factories.
# The values are lambdas, so a chat client is only constructed (and its API
# key only read from the environment) when a factory is actually called.
_MODEL_MAP = {
    "claude-3-5-sonnet": lambda: ChatAnthropic(
        model="claude-3-5-sonnet-20241022",
        api_key=os.environ.get("ANTHROPIC_API_KEY"),
        temperature=0.7,
        max_tokens=4096,
    ),
    "gpt-4o": lambda: ChatOpenAI(
        model="gpt-4o",
        api_key=os.environ.get("OPENAI_API_KEY"),
        temperature=0.7,
        max_tokens=4096,
    ),
}
def _make_agent_node(
    node_id: str,
    label: str,
    system_prompt: str,
    model_name: str,
) -> Callable[[CouncilState], dict]:
    """
    Build the LangGraph node function for one user-defined agent.

    The returned function instantiates the configured LLM, derives a user
    prompt from the current state (first draft, revision with feedback, or
    plain review), invokes the model, and returns the state updates.

    Args:
        node_id: Unique node ID from the blueprint; reported as active_node.
        label: Display name of the agent (not used inside the node itself).
        system_prompt: The persona / role definition sent as system message.
        model_name: Frontend model name, resolved through _get_llm.

    Returns:
        A callable (CouncilState) -> dict suitable for StateGraph.add_node().
    """

    def _compose_user_prompt(state: CouncilState) -> str:
        draft = state["current_draft"]
        if not draft:
            # Nothing written yet: ask for a first version.
            return f"Please work on the following topic:\n\n{state['input_topic']}"

        history = state["feedback_history"]
        topic = state["input_topic"]
        if not history:
            return (
                f"Topic: {topic}\n\n"
                f"Current draft:\n{draft}\n\n"
                f"Please review and improve this draft."
            )

        rounds = [
            f"Feedback round {number}:\n{text}"
            for number, text in enumerate(history, start=1)
        ]
        feedback_block = "\n\n---\n".join(rounds)
        return (
            f"Topic: {topic}\n\n"
            f"Current draft:\n{draft}\n\n"
            f"Feedback ({len(history)} round(s)):\n\n"
            f"{feedback_block}\n\n"
            f"Please produce an improved version."
        )

    def agent_node(state: CouncilState) -> dict:
        # Resolve the model first so an unknown model name fails before
        # any prompt is assembled.
        llm = _get_llm(model_name)

        conversation = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=_compose_user_prompt(state)),
        ]
        reply = llm.invoke(conversation)

        return {
            "current_draft": reply.content,
            "messages": [*conversation, reply],
            "active_node": node_id,
            "iteration_count": state.get("iteration_count", 0) + 1,
        }

    agent_node.__name__ = f"agent_{node_id}"
    return agent_node
_CRITIC_KEYWORDS = {"critic", "kritik", "bewert", "evaluat", "review", "score"}


def _is_critic_like(system_prompt: str) -> bool:
    """Heuristic check for a critic/evaluator persona.

    Returns True when the lower-cased prompt contains any of the keyword
    fragments in _CRITIC_KEYWORDS as a substring, otherwise False.
    """
    normalized = system_prompt.lower()
    for fragment in _CRITIC_KEYWORDS:
        if fragment in normalized:
            return True
    return False
+ ) + + def critic_node(state: CouncilState) -> dict: + # Safety valve + if state.get("iteration_count", 0) >= MAX_ITERATIONS: + return { + "route_decision": "approve", + "critic_score": APPROVAL_THRESHOLD, + "feedback_history": [ + f"[Auto-approved after {MAX_ITERATIONS} iterations]" + ], + "messages": [], + "active_node": node_id, + } + + llm = _get_llm(model_name) + + system_msg = SystemMessage(content=critic_system) + user_msg = HumanMessage( + content=( + f"Evaluate this draft on the topic '{state['input_topic']}':\n\n" + f"{state['current_draft']}" + ) + ) + + response = llm.invoke([system_msg, user_msg]) + + # Parse structured response + score_match = re.search(r"SCORE:\s*(\d+(?:\.\d+)?)", response.content) + feedback_match = re.search(r"FEEDBACK:\s*(.*)", response.content, re.DOTALL) + + score = float(score_match.group(1)) if score_match else 0.0 + score = max(0.0, min(10.0, score)) + feedback = feedback_match.group(1).strip() if feedback_match else response.content.strip() + + route_decision = "approve" if score >= APPROVAL_THRESHOLD else "rework" + + result: dict = { + "critic_score": score, + "route_decision": route_decision, + "messages": [system_msg, user_msg, response], + "active_node": node_id, + } + + if route_decision == "rework": + result["feedback_history"] = [f"Score: {score}/10\n{feedback}"] + + return result + + critic_node.__name__ = f"critic_{node_id}" + return critic_node + + +# --------------------------------------------------------------------------- +# Main: build graph from blueprint JSON +# --------------------------------------------------------------------------- + +def build_graph_from_blueprint(blueprint: dict) -> Any: + """ + Dynamically construct a compiled LangGraph from a CouncilBlueprint JSON. 
+ + Args: + blueprint: A dict matching the CouncilBlueprint schema: + { + "version": 1, + "name": "...", + "nodes": [{"id", "label", "systemPrompt", "model", "tools", "position"}], + "edges": [{"id", "source", "target", "type", "condition?"}] + } + + Returns: + A compiled LangGraph StateGraph ready for invocation. + + Raises: + ValueError: If the blueprint is invalid (no nodes, no entry point, etc.) + """ + nodes = blueprint.get("nodes", []) + edges = blueprint.get("edges", []) + + if not nodes: + raise ValueError("Blueprint has no nodes.") + + # Build node lookup + node_lookup = {n["id"]: n for n in nodes} + + # Find entry point: the node that has no incoming edges + targets = {e["target"] for e in edges} + entry_candidates = [n["id"] for n in nodes if n["id"] not in targets] + if not entry_candidates: + # All nodes have incoming edges (pure cycle) — use first node + entry_candidates = [nodes[0]["id"]] + entry_node_id = entry_candidates[0] + + # Find terminal nodes: nodes that have no outgoing edges + sources = {e["source"] for e in edges} + terminal_nodes = {n["id"] for n in nodes if n["id"] not in sources} + + # Build the StateGraph + graph = StateGraph(CouncilState) + + # Register all nodes + for node in nodes: + nid = node["id"] + label = node.get("label", nid) + system_prompt = node.get("systemPrompt", f"You are {label}.") + model_name = node.get("model", "claude-3-5-sonnet") + + if _is_critic_like(system_prompt): + node_fn = _make_critic_node(nid, label, system_prompt, model_name) + else: + node_fn = _make_agent_node(nid, label, system_prompt, model_name) + + graph.add_node(nid, node_fn) + + # Set entry point + graph.set_entry_point(entry_node_id) + + # Group edges by source + edges_by_source: Dict[str, Dict[str, list]] = {} + for edge in edges: + src = edge["source"] + if src not in edges_by_source: + edges_by_source[src] = {"linear": [], "conditional": []} + if edge.get("type") == "conditional": + edges_by_source[src]["conditional"].append(edge) + else: + 
async def run_blueprint_council_async(
    blueprint: dict,
    input_topic: str,
    run_id: str,
    on_node_event: Optional[Callable[[str, str], Any]] = None,
) -> CouncilState:
    """
    Execute a council run using a dynamically built graph from a blueprint.

    The blueprint is compiled into a graph, a fresh ``CouncilState`` is
    seeded with the topic and run id, and the blocking ``invoke`` call is
    handed to the event loop's default executor so the loop is not stalled.

    Args:
        blueprint: The CouncilBlueprint JSON dict.
        input_topic: The user's prompt.
        run_id: Unique identifier for this run.
        on_node_event: Optional callback for WebSocket node events.
            NOTE(review): this argument is accepted but never invoked in
            this function — confirm whether events should be emitted here.

    Returns:
        The final CouncilState after execution completes.

    Raises:
        ValueError: Propagated from ``build_graph_from_blueprint`` when the
            blueprint is invalid (e.g. it has no nodes).
    """
    compiled_graph = build_graph_from_blueprint(blueprint)

    initial_state = CouncilState(
        input_topic=input_topic,
        current_draft="",
        feedback_history=[],
        route_decision="",
        messages=[],
        iteration_count=0,
        critic_score=None,
        run_id=run_id,
        active_node="",
    )

    # Inside a coroutine a loop is always running; get_running_loop() is the
    # supported accessor there (get_event_loop() is deprecated for this use).
    loop = asyncio.get_running_loop()

    # The graph's invoke() is synchronous, so run it on the default executor.
    return await loop.run_in_executor(None, compiled_graph.invoke, initial_state)
class TestBlueprintEndpoints:
    """Exercise the ``/api/councils/`` CRUD routes through the ASGI test client."""

    @pytest.mark.asyncio
    async def test_create_blueprint(self, client):
        """A valid POST answers 201 and returns name, version, nodes and an id."""
        resp = await client.post("/api/councils/", json=SAMPLE_BLUEPRINT)
        assert resp.status_code == 201
        body = resp.json()
        assert body["name"] == "Test Council"
        assert body["version"] == 1
        assert len(body["nodes"]) == 2
        assert "id" in body

    @pytest.mark.asyncio
    async def test_list_blueprints(self, client):
        """Two created blueprints are both returned by the collection GET."""
        second_payload = {**SAMPLE_BLUEPRINT, "name": "Second Council"}
        await client.post("/api/councils/", json=SAMPLE_BLUEPRINT)
        await client.post("/api/councils/", json=second_payload)

        listing = await client.get("/api/councils/")
        assert listing.status_code == 200
        assert len(listing.json()) == 2

    @pytest.mark.asyncio
    async def test_get_blueprint(self, client):
        """A created blueprint can be fetched back by its id."""
        created = await client.post("/api/councils/", json=SAMPLE_BLUEPRINT)
        blueprint_id = created.json()["id"]

        fetched = await client.get(f"/api/councils/{blueprint_id}")
        assert fetched.status_code == 200
        assert fetched.json()["name"] == "Test Council"

    @pytest.mark.asyncio
    async def test_get_nonexistent_returns_404(self, client):
        """Fetching an unknown id answers 404."""
        missing = await client.get("/api/councils/nonexistent-id")
        assert missing.status_code == 404

    @pytest.mark.asyncio
    async def test_update_blueprint(self, client):
        """A PUT carrying only a new name renames the blueprint."""
        created = await client.post("/api/councils/", json=SAMPLE_BLUEPRINT)
        blueprint_id = created.json()["id"]

        rename_payload = {"name": "Renamed Council"}
        renamed = await client.put(f"/api/councils/{blueprint_id}", json=rename_payload)
        assert renamed.status_code == 200
        assert renamed.json()["name"] == "Renamed Council"

    @pytest.mark.asyncio
    async def test_update_nonexistent_returns_404(self, client):
        """Updating an unknown id answers 404."""
        ghost_payload = {"name": "Ghost"}
        missing = await client.put("/api/councils/ghost-id", json=ghost_payload)
        assert missing.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_blueprint(self, client):
        """DELETE answers 204 and the blueprint is no longer retrievable."""
        created = await client.post("/api/councils/", json=SAMPLE_BLUEPRINT)
        blueprint_id = created.json()["id"]
        resource_url = f"/api/councils/{blueprint_id}"

        removed = await client.delete(resource_url)
        assert removed.status_code == 204

        after_delete = await client.get(resource_url)
        assert after_delete.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_nonexistent_returns_404(self, client):
        """Deleting an unknown id answers 404."""
        missing = await client.delete("/api/councils/ghost-id")
        assert missing.status_code == 404

    @pytest.mark.asyncio
    async def test_create_rejects_missing_name(self, client):
        """A payload without the ``name`` key is rejected with 422."""
        nameless = dict(SAMPLE_BLUEPRINT)
        nameless.pop("name")
        rejected = await client.post("/api/councils/", json=nameless)
        assert rejected.status_code == 422

    @pytest.mark.asyncio
    async def test_create_rejects_empty_name(self, client):
        """A payload whose ``name`` is the empty string is rejected with 422."""
        blank_name = {**SAMPLE_BLUEPRINT, "name": ""}
        rejected = await client.post("/api/councils/", json=blank_name)
        assert rejected.status_code == 422
class TestBlueprintCRUD:
    """Service-level CRUD checks run against the per-test database session."""

    @pytest.mark.asyncio
    async def test_create_blueprint(self, session):
        """Creating a blueprint assigns an id and stores name, version, nodes, edges."""
        created = await create_blueprint(session, "Test Council", SAMPLE_NODES, SAMPLE_EDGES)
        assert created.id is not None
        assert created.name == "Test Council"
        assert created.version == 1
        assert len(created.nodes) == 2
        assert len(created.edges) == 1

    @pytest.mark.asyncio
    async def test_create_with_custom_id(self, session):
        """An explicit ``blueprint_id`` is used as the record's id."""
        created = await create_blueprint(
            session,
            "Custom ID",
            SAMPLE_NODES,
            SAMPLE_EDGES,
            blueprint_id="my-custom-id",
        )
        assert created.id == "my-custom-id"

    @pytest.mark.asyncio
    async def test_get_blueprint(self, session):
        """A stored blueprint can be read back by id."""
        created = await create_blueprint(session, "Get Test", SAMPLE_NODES, SAMPLE_EDGES)
        loaded = await get_blueprint(session, created.id)
        assert loaded is not None
        assert loaded.name == "Get Test"

    @pytest.mark.asyncio
    async def test_get_nonexistent_returns_none(self, session):
        """Looking up an unknown id yields ``None``."""
        missing = await get_blueprint(session, "nonexistent-id")
        assert missing is None

    @pytest.mark.asyncio
    async def test_list_blueprints(self, session):
        """Listing returns every blueprint created in this session."""
        for title in ("First", "Second"):
            await create_blueprint(session, title, SAMPLE_NODES, SAMPLE_EDGES)
        listed = await list_blueprints(session)
        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_update_blueprint_name(self, session):
        """Passing ``name`` to the update changes the stored name."""
        created = await create_blueprint(session, "Original", SAMPLE_NODES, SAMPLE_EDGES)
        renamed = await update_blueprint(session, created.id, name="Renamed")
        assert renamed is not None
        assert renamed.name == "Renamed"

    @pytest.mark.asyncio
    async def test_update_blueprint_nodes(self, session):
        """Passing ``nodes`` to the update replaces the stored node list."""
        created = await create_blueprint(session, "Nodes Test", SAMPLE_NODES, SAMPLE_EDGES)
        first_node_only = SAMPLE_NODES[:1]  # drop the second node
        trimmed = await update_blueprint(session, created.id, nodes=first_node_only)
        assert trimmed is not None
        assert len(trimmed.nodes) == 1

    @pytest.mark.asyncio
    async def test_update_nonexistent_returns_none(self, session):
        """Updating an unknown id yields ``None``."""
        outcome = await update_blueprint(session, "ghost-id", name="New Name")
        assert outcome is None

    @pytest.mark.asyncio
    async def test_delete_blueprint(self, session):
        """Deleting reports ``True`` and the record can no longer be fetched."""
        created = await create_blueprint(session, "To Delete", SAMPLE_NODES, SAMPLE_EDGES)
        was_deleted = await delete_blueprint(session, created.id)
        assert was_deleted is True
        leftover = await get_blueprint(session, created.id)
        assert leftover is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent_returns_false(self, session):
        """Deleting an unknown id reports ``False``."""
        was_deleted = await delete_blueprint(session, "ghost-id")
        assert was_deleted is False

    @pytest.mark.asyncio
    async def test_to_dict_format(self, session):
        """``to_dict`` exposes id, version, name, timestamps and list-typed nodes/edges."""
        created = await create_blueprint(session, "Dict Test", SAMPLE_NODES, SAMPLE_EDGES)
        payload = created.to_dict()
        assert payload["id"] == created.id
        assert payload["version"] == 1
        assert payload["name"] == "Dict Test"
        for timestamp_key in ("createdAt", "updatedAt"):
            assert timestamp_key in payload
        for collection_key in ("nodes", "edges"):
            assert isinstance(payload[collection_key], list)
+""" + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import pytest +from unittest.mock import patch, MagicMock + +from services.dynamic_graph_builder import ( + build_graph_from_blueprint, + _make_agent_node, + _make_critic_node, + _make_conditional_router, + _is_critic_like, + _get_llm, +) +from services.graph_builder import create_initial_state +from state import CouncilState, APPROVAL_THRESHOLD, MAX_ITERATIONS + + +# --------------------------------------------------------------------------- +# Sample blueprints for testing +# --------------------------------------------------------------------------- + +SIMPLE_LINEAR_BLUEPRINT = { + "version": 1, + "name": "Simple Linear", + "nodes": [ + { + "id": "node-1", + "label": "Writer", + "systemPrompt": "You are a professional writer.", + "model": "claude-3-5-sonnet", + "tools": {"webSearch": False, "pdfReader": False}, + "position": {"x": 0, "y": 0}, + }, + { + "id": "node-2", + "label": "Editor", + "systemPrompt": "You are a professional editor. Polish the text.", + "model": "claude-3-5-sonnet", + "tools": {"webSearch": False, "pdfReader": False}, + "position": {"x": 300, "y": 0}, + }, + ], + "edges": [ + {"id": "edge-1", "source": "node-1", "target": "node-2", "type": "linear"}, + ], +} + +CYCLIC_BLUEPRINT = { + "version": 1, + "name": "Cyclic Council", + "nodes": [ + { + "id": "master", + "label": "Master Agent", + "systemPrompt": "You are the master writer.", + "model": "claude-3-5-sonnet", + "tools": {"webSearch": False, "pdfReader": False}, + "position": {"x": 0, "y": 0}, + }, + { + "id": "critic", + "label": "Critic Agent", + "systemPrompt": "You are the critic. 
class TestCriticDetection:
    """Checks for the ``_is_critic_like`` system-prompt heuristic."""

    def test_detects_critic_keyword(self):
        """A prompt mentioning 'critic' is classified as critic-like."""
        assert _is_critic_like("You are the critic. Evaluate drafts.") is True

    def test_detects_evaluate_keyword(self):
        """A prompt mentioning 'evaluate' is classified as critic-like."""
        assert _is_critic_like("Your role is to evaluate and score.") is True

    def test_detects_review_keyword(self):
        """A prompt starting with 'Review' is classified as critic-like."""
        assert _is_critic_like("Review the document for quality.") is True

    def test_no_match_for_writer(self):
        """A plain writer prompt is not classified as critic-like."""
        assert _is_critic_like("You are a professional writer.") is False

    def test_case_insensitive(self):
        """Upper-case 'CRITIC' is still detected."""
        assert _is_critic_like("You are the CRITIC agent.") is True


# ---------------------------------------------------------------------------
# Test: conditional routing
# ---------------------------------------------------------------------------

class TestConditionalRouter:
    """Checks for routers produced by ``_make_conditional_router``."""

    def test_routes_to_correct_target(self):
        """An 'approve' decision selects the edge whose condition is 'approve'."""
        edges = [
            {"target": "node-a", "condition": "rework"},
            {"target": "node-b", "condition": "approve"},
        ]
        router = _make_conditional_router("source", edges, None)

        state = create_initial_state("topic", "run-1")
        state["route_decision"] = "approve"
        assert router(state) == "node-b"

    def test_routes_rework(self):
        """A 'rework' decision selects the edge whose condition is 'rework'."""
        edges = [
            {"target": "node-a", "condition": "rework"},
            {"target": "node-b", "condition": "approve"},
        ]
        router = _make_conditional_router("source", edges, None)

        state = create_initial_state("topic", "run-1")
        state["route_decision"] = "rework"
        assert router(state) == "node-a"

    def test_unknown_decision_uses_linear_fallback(self):
        """An unmatched decision falls back to the linear target when one is given."""
        edges = [
            {"target": "node-a", "condition": "rework"},
        ]
        router = _make_conditional_router("source", edges, "fallback-node")

        state = create_initial_state("topic", "run-1")
        state["route_decision"] = "unknown"
        assert router(state) == "fallback-node"

    def test_unknown_decision_uses_first_conditional_as_fallback(self):
        """Without a linear target, an unmatched decision resolves to the first edge."""
        edges = [
            {"target": "node-a", "condition": "rework"},
            {"target": "node-b", "condition": "approve"},
        ]
        router = _make_conditional_router("source", edges, None)

        state = create_initial_state("topic", "run-1")
        state["route_decision"] = "unknown"
        assert router(state) == "node-a"


# ---------------------------------------------------------------------------
# Test: agent node factory
# ---------------------------------------------------------------------------

class TestAgentNodeFactory:
    """Checks for node callables produced by ``_make_agent_node`` (LLM mocked)."""

    def test_agent_node_returns_draft(self):
        """The mocked LLM reply becomes the draft; node id and iteration are reported."""
        mock_response = MagicMock()
        mock_response.content = "Generated content about AI."

        with patch("services.dynamic_graph_builder.ChatAnthropic") as MockLLM:
            MockLLM.return_value.invoke.return_value = mock_response

            node_fn = _make_agent_node("node-1", "Writer", "You write.", "claude-3-5-sonnet")
            state = create_initial_state("AI basics", "run-1")
            result = node_fn(state)

        assert result["current_draft"] == "Generated content about AI."
        assert result["active_node"] == "node-1"
        assert result["iteration_count"] == 1

    def test_agent_node_with_existing_draft_and_feedback(self):
        """With a prior draft and feedback, the draft is replaced and the count goes 1 -> 2."""
        mock_response = MagicMock()
        mock_response.content = "Improved draft."

        with patch("services.dynamic_graph_builder.ChatAnthropic") as MockLLM:
            MockLLM.return_value.invoke.return_value = mock_response

            node_fn = _make_agent_node("node-1", "Writer", "You write.", "claude-3-5-sonnet")
            state = create_initial_state("AI", "run-1")
            state["current_draft"] = "First draft"
            state["feedback_history"] = ["Needs more detail"]
            state["iteration_count"] = 1
            result = node_fn(state)

        assert result["current_draft"] == "Improved draft."
        assert result["iteration_count"] == 2


# ---------------------------------------------------------------------------
# Test: critic node factory
# ---------------------------------------------------------------------------

class TestCriticNodeFactory:
    """Checks for node callables produced by ``_make_critic_node`` (LLM mocked)."""

    def test_critic_node_approves_high_score(self):
        """A 'SCORE: 9 / VERDICT: approve' reply yields an approve decision and score 9.0."""
        mock_response = MagicMock()
        mock_response.content = "SCORE: 9\nVERDICT: approve\nFEEDBACK:\nExcellent work."

        with patch("services.dynamic_graph_builder.ChatAnthropic") as MockLLM:
            MockLLM.return_value.invoke.return_value = mock_response

            node_fn = _make_critic_node("critic-1", "Critic", "You evaluate.", "claude-3-5-sonnet")
            state = create_initial_state("Topic", "run-1")
            state["current_draft"] = "A great draft"
            result = node_fn(state)

        assert result["route_decision"] == "approve"
        assert result["critic_score"] == 9.0

    def test_critic_node_reworks_low_score(self):
        """A 'SCORE: 4 / VERDICT: rework' reply yields rework and records the feedback."""
        mock_response = MagicMock()
        mock_response.content = "SCORE: 4\nVERDICT: rework\nFEEDBACK:\nNeeds more structure."

        with patch("services.dynamic_graph_builder.ChatAnthropic") as MockLLM:
            MockLLM.return_value.invoke.return_value = mock_response

            node_fn = _make_critic_node("critic-1", "Critic", "You evaluate.", "claude-3-5-sonnet")
            state = create_initial_state("Topic", "run-1")
            state["current_draft"] = "Draft"
            result = node_fn(state)

        assert result["route_decision"] == "rework"
        assert result["critic_score"] == 4.0
        assert len(result["feedback_history"]) == 1
        assert "structure" in result["feedback_history"][0]

    def test_critic_safety_valve_at_max_iterations(self):
        """At MAX_ITERATIONS the critic approves with a score of APPROVAL_THRESHOLD."""
        # Patch the LLM class so this test stays offline no matter when the
        # factory constructs its model.
        # NOTE(review): presumably the safety valve short-circuits before any
        # LLM call — confirm against _make_critic_node.
        with patch("services.dynamic_graph_builder.ChatAnthropic"):
            node_fn = _make_critic_node("critic-1", "Critic", "Evaluate.", "claude-3-5-sonnet")
            state = create_initial_state("Topic", "run-1")
            state["current_draft"] = "Draft"
            state["iteration_count"] = MAX_ITERATIONS

            result = node_fn(state)

        assert result["route_decision"] == "approve"
        assert result["critic_score"] == APPROVAL_THRESHOLD


# ---------------------------------------------------------------------------
# Test: build_graph_from_blueprint
# ---------------------------------------------------------------------------

class TestBuildGraphFromBlueprint:
    """Compile-level checks for ``build_graph_from_blueprint``."""

    def test_rejects_empty_blueprint(self):
        """A blueprint with an empty node list raises ValueError mentioning 'no nodes'."""
        with pytest.raises(ValueError, match="no nodes"):
            build_graph_from_blueprint({"version": 1, "name": "Empty", "nodes": [], "edges": []})

    def test_builds_linear_graph(self):
        """A simple linear blueprint should compile without error."""
        graph = build_graph_from_blueprint(SIMPLE_LINEAR_BLUEPRINT)
        assert graph is not None

    def test_builds_cyclic_graph(self):
        """A cyclic blueprint with conditional edges should compile."""
        graph = build_graph_from_blueprint(CYCLIC_BLUEPRINT)
        assert graph is not None

    def test_entry_point_is_node_with_no_incoming(self):
        """A blueprint with exactly one node lacking incoming edges compiles.

        SIMPLE_LINEAR_BLUEPRINT is used because its only edge is
        node-1 -> node-2, so node-1 is the sole node with no incoming edge.
        CYCLIC_BLUEPRINT is unsuitable: there every node is the target of
        some edge.

        NOTE: this only asserts that compilation succeeds; it does not
        inspect which node was chosen as the entry point.
        """
        graph = build_graph_from_blueprint(SIMPLE_LINEAR_BLUEPRINT)
        assert graph is not None

    def test_single_node_blueprint(self):
        """A single node with no edges should work (trivial graph)."""
        bp = {
            "version": 1,
            "name": "Single",
            "nodes": [
                {
                    "id": "only-node",
                    "label": "Solo Agent",
                    "systemPrompt": "You work alone.",
                    "model": "claude-3-5-sonnet",
                    "tools": {"webSearch": False, "pdfReader": False},
                    "position": {"x": 0, "y": 0},
                }
            ],
            "edges": [],
        }
        graph = build_graph_from_blueprint(bp)
        assert graph is not None


# ---------------------------------------------------------------------------
# Test: model lookup
# ---------------------------------------------------------------------------

class TestModelLookup:
    """Checks that ``_get_llm`` maps model names onto the right client class."""

    def test_unknown_model_raises(self):
        """An unrecognised model name raises ValueError mentioning 'Unknown model'."""
        with pytest.raises(ValueError, match="Unknown model"):
            _get_llm("nonexistent-model")

    def test_claude_model_creates_instance(self):
        """'claude-3-5-sonnet' constructs ChatAnthropic exactly once."""
        with patch("services.dynamic_graph_builder.ChatAnthropic") as MockLLM:
            _get_llm("claude-3-5-sonnet")
            MockLLM.assert_called_once()

    def test_gpt4o_model_creates_instance(self):
        """'gpt-4o' constructs ChatOpenAI exactly once."""
        with patch("services.dynamic_graph_builder.ChatOpenAI") as MockLLM:
            _get_llm("gpt-4o")
            MockLLM.assert_called_once()