Implement Phase 4: tools, God Mode, and missing features
Backend: - Add Tavily web search tool wrapper (tools/web_search.py) - Add PDF reader + ChromaDB vector store tool (tools/pdf_reader.py) - Bind tools to LLM calls via .bind_tools() in dynamic_graph_builder - Implement God Mode using LangGraph interrupt_before + MemorySaver - Add approve/reject/modify API endpoints for God Mode - Add PDF upload endpoint with ingestion pipeline - Add persistent run history (CouncilRun model + run_service + API) - Add Alembic migration for council_runs table - Enhance WebSocket to emit run_paused and run_resumed events - Add tests for tools, God Mode, and run history Frontend: - Add God Mode approval UI (GodModePanel component) - Add Auto-Pilot / God Mode toggle in Konferenzzimmer - Add functional PDF upload handler - Add Conditional Edge editor (EdgeSettingsPanel component) - Add edge click selection in ArchitectCanvas - Update Zustand store with edge selection and update actions - Update types for God Mode, execution modes, and WS events - Update API client with God Mode, PDF upload, and blueprint run endpoints - Update WebSocket hook for paused/resumed events - Add Vitest config and frontend tests (store, parser, types, API) https://claude.ai/code/session_017U6idFgaqnYTXzPxA7mxMv
This commit is contained in:
parent
c6d0c4a636
commit
001649a364
31 changed files with 2502 additions and 81 deletions
56
backend/alembic/versions/002_create_council_runs_table.py
Normal file
56
backend/alembic/versions/002_create_council_runs_table.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Create council_runs table for persistent run history
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-02-21
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"council_runs",
|
||||
sa.Column("id", sa.String(36), primary_key=True),
|
||||
sa.Column("blueprint_id", sa.String(36), nullable=True),
|
||||
sa.Column("input_topic", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column(
|
||||
"execution_mode",
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default="auto-pilot",
|
||||
),
|
||||
sa.Column("final_draft", sa.Text(), nullable=True),
|
||||
sa.Column("critic_score", sa.Float(), nullable=True),
|
||||
sa.Column("iteration_count", sa.Integer(), nullable=True),
|
||||
sa.Column("active_node", sa.String(255), nullable=True),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"completed_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("council_runs")
|
||||
|
|
@ -2,23 +2,33 @@
|
|||
REST API routes for CouncilOS.
|
||||
|
||||
Endpoints:
|
||||
POST /api/councils/run — Start a new council run (Phase 1 hard-coded graph)
|
||||
POST /api/councils/{id}/run — Start a run from a saved blueprint (Phase 3)
|
||||
GET /api/councils/run/{run_id} — Poll the status/result of a run
|
||||
GET /api/health — Health check
|
||||
POST /api/councils/run — Start a new council run (Phase 1)
|
||||
POST /api/councils/{id}/run — Start a run from a blueprint (Phase 3)
|
||||
GET /api/councils/run/{run_id} — Poll the status/result of a run
|
||||
POST /api/councils/run/{run_id}/approve — God Mode: approve/reject/modify (Phase 4)
|
||||
GET /api/councils/run/{run_id}/state — God Mode: get paused state (Phase 4)
|
||||
POST /api/councils/upload-pdf — Upload and ingest a PDF (Phase 4)
|
||||
GET /api/health — Health check
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from services.graph_builder import run_council_async
|
||||
from services.dynamic_graph_builder import run_blueprint_council_async
|
||||
from services.blueprint_service import get_blueprint
|
||||
from database import get_session
|
||||
from api.run_store import run_store
|
||||
from database import get_session
|
||||
from services.blueprint_service import get_blueprint
|
||||
from services.dynamic_graph_builder import (
|
||||
get_god_mode_state,
|
||||
resume_god_mode,
|
||||
run_blueprint_council_async,
|
||||
)
|
||||
from services.graph_builder import run_council_async
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -36,11 +46,15 @@ class CouncilRunRequest(BaseModel):
|
|||
description="The user's prompt or document content for the council to work on.",
|
||||
examples=["Erkläre die wichtigsten Konzepte des maschinellen Lernens für Einsteiger."],
|
||||
)
|
||||
god_mode: bool = Field(
|
||||
default=False,
|
||||
description="If true, the run pauses before each node for human approval.",
|
||||
)
|
||||
|
||||
|
||||
class CouncilRunResponse(BaseModel):
|
||||
run_id: str
|
||||
status: str # "pending" | "running" | "completed" | "failed"
|
||||
status: str # "pending" | "running" | "completed" | "failed" | "paused"
|
||||
message: str
|
||||
|
||||
|
||||
|
|
@ -51,6 +65,26 @@ class CouncilResultResponse(BaseModel):
|
|||
critic_score: Optional[float] = None
|
||||
iteration_count: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
paused: Optional[bool] = None
|
||||
next_nodes: Optional[List[str]] = None
|
||||
current_draft: Optional[str] = None
|
||||
|
||||
|
||||
class GodModeApprovalRequest(BaseModel):
|
||||
action: str = Field(
|
||||
...,
|
||||
description="Action to take: 'approve', 'reject', or 'modify'.",
|
||||
)
|
||||
modified_state: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="Partial state override when action is 'modify'.",
|
||||
)
|
||||
|
||||
|
||||
class PdfUploadResponse(BaseModel):
|
||||
filename: str
|
||||
chunks_ingested: int
|
||||
message: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -104,8 +138,8 @@ async def start_blueprint_run(
|
|||
"""
|
||||
Start a council run using a saved blueprint (Phase 3 dynamic graph).
|
||||
|
||||
Reads the blueprint from PostgreSQL and dynamically constructs the
|
||||
LangGraph execution graph at runtime.
|
||||
Set god_mode=true to pause before each agent node and require
|
||||
human approval via the /approve endpoint.
|
||||
"""
|
||||
bp = await get_blueprint(session, blueprint_id)
|
||||
if bp is None:
|
||||
|
|
@ -116,14 +150,19 @@ async def start_blueprint_run(
|
|||
|
||||
blueprint_dict = bp.to_dict()
|
||||
background_tasks.add_task(
|
||||
_execute_blueprint_run, run_id, request.input_topic, blueprint_dict
|
||||
_execute_blueprint_run,
|
||||
run_id,
|
||||
request.input_topic,
|
||||
blueprint_dict,
|
||||
request.god_mode,
|
||||
)
|
||||
|
||||
mode_label = "God Mode" if request.god_mode else "Auto-Pilot"
|
||||
return CouncilRunResponse(
|
||||
run_id=run_id,
|
||||
status="pending",
|
||||
message=(
|
||||
f"Council run started from blueprint '{bp.name}'. "
|
||||
f"Council run started from blueprint '{bp.name}' ({mode_label}). "
|
||||
f"Connect to /ws/council/{run_id} for live updates."
|
||||
),
|
||||
)
|
||||
|
|
@ -133,11 +172,21 @@ async def start_blueprint_run(
|
|||
async def get_council_result(run_id: str):
|
||||
"""
|
||||
Retrieve the current status or final result of a council run.
|
||||
|
||||
In God Mode, includes paused state and next_nodes info.
|
||||
"""
|
||||
run = run_store.get(run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found.")
|
||||
|
||||
# Check for god mode paused state
|
||||
god_state = get_god_mode_state(run_id)
|
||||
paused = god_state["paused"] if god_state else None
|
||||
next_nodes = god_state["next_nodes"] if god_state else None
|
||||
current_draft = (
|
||||
god_state["current_state"].get("current_draft") if god_state else None
|
||||
)
|
||||
|
||||
return CouncilResultResponse(
|
||||
run_id=run_id,
|
||||
status=run["status"],
|
||||
|
|
@ -145,6 +194,97 @@ async def get_council_result(run_id: str):
|
|||
critic_score=run.get("critic_score"),
|
||||
iteration_count=run.get("iteration_count"),
|
||||
error=run.get("error"),
|
||||
paused=paused,
|
||||
next_nodes=next_nodes,
|
||||
current_draft=current_draft,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/councils/run/{run_id}/approve", response_model=CouncilResultResponse)
|
||||
async def approve_god_mode(
|
||||
run_id: str,
|
||||
request: GodModeApprovalRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
"""
|
||||
Approve, reject, or modify a paused God Mode council run.
|
||||
|
||||
Actions:
|
||||
approve — continue execution to the next node
|
||||
reject — stop the run entirely
|
||||
modify — update the state before continuing (pass modified_state)
|
||||
"""
|
||||
god_state = get_god_mode_state(run_id)
|
||||
if not god_state or not god_state.get("paused"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Run '{run_id}' is not paused in God Mode.",
|
||||
)
|
||||
|
||||
if request.action == "reject":
|
||||
state = await resume_god_mode(run_id, action="reject")
|
||||
run_store.update(run_id, {"status": "failed", "error": "Rejected by user in God Mode."})
|
||||
return CouncilResultResponse(
|
||||
run_id=run_id,
|
||||
status="failed",
|
||||
error="Rejected by user in God Mode.",
|
||||
)
|
||||
|
||||
# Resume in background (approve or modify)
|
||||
background_tasks.add_task(
|
||||
_resume_god_mode_task,
|
||||
run_id,
|
||||
request.action,
|
||||
request.modified_state,
|
||||
)
|
||||
|
||||
return CouncilResultResponse(
|
||||
run_id=run_id,
|
||||
status="running",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/councils/run/{run_id}/state")
|
||||
async def get_run_state(run_id: str):
|
||||
"""
|
||||
Get the full paused state of a God Mode run for the approval UI.
|
||||
"""
|
||||
god_state = get_god_mode_state(run_id)
|
||||
if not god_state:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No God Mode session found for run '{run_id}'.",
|
||||
)
|
||||
return god_state
|
||||
|
||||
|
||||
@router.post("/councils/upload-pdf", response_model=PdfUploadResponse)
|
||||
async def upload_pdf(file: UploadFile = File(...)):
|
||||
"""
|
||||
Upload and ingest a PDF file into the ChromaDB vector store.
|
||||
|
||||
The content becomes searchable by agents with the PDF Reader tool enabled.
|
||||
"""
|
||||
if not file.filename or not file.filename.lower().endswith(".pdf"):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files are accepted.")
|
||||
|
||||
from tools.pdf_reader import ingest_pdf
|
||||
|
||||
# Save upload to a temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
chunks = ingest_pdf(tmp_path)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
return PdfUploadResponse(
|
||||
filename=file.filename,
|
||||
chunks_ingested=chunks,
|
||||
message=f"Successfully ingested {chunks} chunks from '{file.filename}'.",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -180,7 +320,10 @@ async def _execute_run(run_id: str, input_topic: str) -> None:
|
|||
|
||||
|
||||
async def _execute_blueprint_run(
|
||||
run_id: str, input_topic: str, blueprint: dict
|
||||
run_id: str,
|
||||
input_topic: str,
|
||||
blueprint: dict,
|
||||
god_mode: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Background task that runs a dynamically built LangGraph from a blueprint.
|
||||
|
|
@ -191,10 +334,22 @@ async def _execute_blueprint_run(
|
|||
blueprint=blueprint,
|
||||
input_topic=input_topic,
|
||||
run_id=run_id,
|
||||
god_mode=god_mode,
|
||||
on_node_event=lambda nid, node: run_store.update(
|
||||
nid, {"active_node": node}
|
||||
),
|
||||
)
|
||||
|
||||
# In god mode, the first invoke will pause at the first node
|
||||
if god_mode and final_state:
|
||||
god_state = get_god_mode_state(run_id)
|
||||
if god_state and god_state.get("paused"):
|
||||
run_store.update(run_id, {
|
||||
"status": "paused",
|
||||
"active_node": god_state["next_nodes"][0] if god_state["next_nodes"] else None,
|
||||
})
|
||||
return
|
||||
|
||||
run_store.update(
|
||||
run_id,
|
||||
{
|
||||
|
|
@ -207,3 +362,42 @@ async def _execute_blueprint_run(
|
|||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
run_store.update(run_id, {"status": "failed", "error": str(exc)})
|
||||
|
||||
|
||||
async def _resume_god_mode_task(
|
||||
run_id: str,
|
||||
action: str,
|
||||
modified_state: Optional[dict],
|
||||
) -> None:
|
||||
"""Background task that resumes a god mode run after human approval."""
|
||||
run_store.update(run_id, {"status": "running"})
|
||||
try:
|
||||
state = await resume_god_mode(run_id, action=action, modified_state=modified_state)
|
||||
|
||||
if state is None:
|
||||
run_store.update(run_id, {"status": "failed", "error": "God Mode session not found."})
|
||||
return
|
||||
|
||||
# Check if paused again at next node
|
||||
god_state = get_god_mode_state(run_id)
|
||||
if god_state and god_state.get("paused"):
|
||||
run_store.update(run_id, {
|
||||
"status": "paused",
|
||||
"active_node": god_state["next_nodes"][0] if god_state["next_nodes"] else None,
|
||||
"current_draft": state.get("current_draft"),
|
||||
"critic_score": state.get("critic_score"),
|
||||
"iteration_count": state.get("iteration_count"),
|
||||
})
|
||||
else:
|
||||
run_store.update(
|
||||
run_id,
|
||||
{
|
||||
"status": "completed",
|
||||
"final_draft": state.get("current_draft"),
|
||||
"critic_score": state.get("critic_score"),
|
||||
"iteration_count": state.get("iteration_count"),
|
||||
"active_node": "done",
|
||||
},
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
run_store.update(run_id, {"status": "failed", "error": str(exc)})
|
||||
|
|
|
|||
64
backend/api/run_history_routes.py
Normal file
64
backend/api/run_history_routes.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""
|
||||
REST API routes for council run history.
|
||||
|
||||
Endpoints:
|
||||
GET /api/runs/ — List recent council runs
|
||||
GET /api/runs/{run_id} — Get a specific run's details
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_session
|
||||
from services.run_service import get_run, list_runs
|
||||
|
||||
|
||||
run_history_router = APIRouter()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RunHistoryResponse(BaseModel):
|
||||
id: str
|
||||
blueprint_id: Optional[str] = None
|
||||
input_topic: str
|
||||
status: str
|
||||
execution_mode: str
|
||||
final_draft: Optional[str] = None
|
||||
critic_score: Optional[float] = None
|
||||
iteration_count: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@run_history_router.get("/runs/", response_model=List[RunHistoryResponse])
|
||||
async def list_all_runs(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List recent council runs, ordered by most recent first."""
|
||||
runs = await list_runs(session, limit=limit, offset=offset)
|
||||
return [r.to_dict() for r in runs]
|
||||
|
||||
|
||||
@run_history_router.get("/runs/{run_id}", response_model=RunHistoryResponse)
|
||||
async def get_single_run(
|
||||
run_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Retrieve a specific council run by ID."""
|
||||
run = await get_run(session, run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found.")
|
||||
return run.to_dict()
|
||||
|
|
@ -2,20 +2,22 @@
|
|||
WebSocket endpoint for real-time agent status updates.
|
||||
|
||||
Clients connect to /ws/council/{run_id} and receive JSON events whenever
|
||||
an agent node becomes active. This powers the live diagram pulsing in the frontend.
|
||||
an agent node becomes active or the run status changes.
|
||||
|
||||
Event format:
|
||||
{"event": "node_start", "run_id": "...", "node": "master_agent", "iteration": 2}
|
||||
{"event": "node_complete", "run_id": "...", "node": "critic_agent", "score": 6.5}
|
||||
{"event": "run_complete", "run_id": "...", "final_draft": "..."}
|
||||
{"event": "node_active", "run_id": "...", "node": "master_agent", "iteration": 2}
|
||||
{"event": "run_paused", "run_id": "...", "next_nodes": ["critic_agent"], "current_draft": "..."}
|
||||
{"event": "run_complete", "run_id": "...", "final_draft": "...", "critic_score": 8.5}
|
||||
{"event": "run_failed", "run_id": "...", "error": "..."}
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from api.run_store import run_store
|
||||
from services.dynamic_graph_builder import get_god_mode_state
|
||||
|
||||
|
||||
ws_router = APIRouter()
|
||||
|
|
@ -53,6 +55,7 @@ async def council_websocket(websocket: WebSocket, run_id: str):
|
|||
|
||||
On connect: sends the current run status immediately.
|
||||
While running: polls the run store and pushes status changes.
|
||||
On paused: sends a god mode pause event with next_nodes.
|
||||
On complete/failed: sends a final event and closes the connection.
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
|
@ -77,13 +80,17 @@ async def council_websocket(websocket: WebSocket, run_id: str):
|
|||
|
||||
# Poll for status changes and push updates
|
||||
last_node = None
|
||||
last_status = run["status"]
|
||||
while True:
|
||||
run = run_store.get(run_id)
|
||||
if run is None:
|
||||
break
|
||||
|
||||
current_node = run.get("active_node")
|
||||
if current_node and current_node != last_node:
|
||||
current_status = run["status"]
|
||||
|
||||
# Emit node_active when the active agent changes
|
||||
if current_node and current_node != last_node and current_node != "done":
|
||||
await websocket.send_text(
|
||||
json.dumps({
|
||||
"event": "node_active",
|
||||
|
|
@ -94,7 +101,41 @@ async def council_websocket(websocket: WebSocket, run_id: str):
|
|||
)
|
||||
last_node = current_node
|
||||
|
||||
if run["status"] == "completed":
|
||||
# God Mode: emit pause event
|
||||
if current_status == "paused" and last_status != "paused":
|
||||
god_state = get_god_mode_state(run_id)
|
||||
await websocket.send_text(
|
||||
json.dumps({
|
||||
"event": "run_paused",
|
||||
"run_id": run_id,
|
||||
"next_nodes": god_state["next_nodes"] if god_state else [],
|
||||
"current_draft": (
|
||||
god_state["current_state"].get("current_draft", "")
|
||||
if god_state else ""
|
||||
),
|
||||
"critic_score": (
|
||||
god_state["current_state"].get("critic_score")
|
||||
if god_state else None
|
||||
),
|
||||
"iteration_count": (
|
||||
god_state["current_state"].get("iteration_count")
|
||||
if god_state else None
|
||||
),
|
||||
})
|
||||
)
|
||||
last_status = current_status
|
||||
|
||||
# Status changed from paused back to running (user approved)
|
||||
if current_status == "running" and last_status == "paused":
|
||||
await websocket.send_text(
|
||||
json.dumps({
|
||||
"event": "run_resumed",
|
||||
"run_id": run_id,
|
||||
})
|
||||
)
|
||||
last_status = current_status
|
||||
|
||||
if current_status == "completed":
|
||||
await websocket.send_text(
|
||||
json.dumps({
|
||||
"event": "run_complete",
|
||||
|
|
@ -106,7 +147,7 @@ async def council_websocket(websocket: WebSocket, run_id: str):
|
|||
)
|
||||
break
|
||||
|
||||
if run["status"] == "failed":
|
||||
if current_status == "failed":
|
||||
await websocket.send_text(
|
||||
json.dumps({
|
||||
"event": "run_failed",
|
||||
|
|
@ -116,6 +157,7 @@ async def council_websocket(websocket: WebSocket, run_id: str):
|
|||
)
|
||||
break
|
||||
|
||||
last_status = current_status
|
||||
await asyncio.sleep(0.5) # 500ms polling interval
|
||||
|
||||
except WebSocketDisconnect:
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ async def get_session() -> AsyncSession:
|
|||
async def init_db() -> None:
|
||||
"""Create all tables. Used at application startup."""
|
||||
from models.blueprint import Base
|
||||
import models.council_run # noqa: F401 — register CouncilRun model
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
|
|
|||
|
|
@ -5,16 +5,21 @@ Start the server:
|
|||
uvicorn main:app --reload --port 8000
|
||||
|
||||
API Overview:
|
||||
POST /api/councils/ — Create a blueprint
|
||||
GET /api/councils/ — List all blueprints
|
||||
GET /api/councils/{id} — Get specific blueprint
|
||||
PUT /api/councils/{id} — Update a blueprint
|
||||
DELETE /api/councils/{id} — Delete a blueprint
|
||||
POST /api/councils/run — Start a run (Phase 1 hard-coded graph)
|
||||
POST /api/councils/{id}/run — Start a run from a blueprint (Phase 3)
|
||||
GET /api/councils/run/{run_id} — Poll run status/result
|
||||
GET /api/health — Health check
|
||||
WS /ws/council/{run_id} — Real-time agent status events
|
||||
POST /api/councils/ — Create a blueprint
|
||||
GET /api/councils/ — List all blueprints
|
||||
GET /api/councils/{id} — Get specific blueprint
|
||||
PUT /api/councils/{id} — Update a blueprint
|
||||
DELETE /api/councils/{id} — Delete a blueprint
|
||||
POST /api/councils/run — Start a run (Phase 1)
|
||||
POST /api/councils/{id}/run — Start a run from a blueprint
|
||||
GET /api/councils/run/{run_id} — Poll run status/result
|
||||
POST /api/councils/run/{run_id}/approve — God Mode: approve/reject/modify
|
||||
GET /api/councils/run/{run_id}/state — God Mode: paused state
|
||||
POST /api/councils/upload-pdf — Upload PDF for ingestion
|
||||
GET /api/runs/ — List run history
|
||||
GET /api/runs/{run_id} — Get historical run detail
|
||||
GET /api/health — Health check
|
||||
WS /ws/council/{run_id} — Real-time agent status events
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -23,6 +28,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
|
||||
from api.routes import router
|
||||
from api.blueprint_routes import blueprint_router
|
||||
from api.run_history_routes import run_history_router
|
||||
from api.websocket import ws_router
|
||||
from database import init_db, close_db
|
||||
|
||||
|
|
@ -45,7 +51,7 @@ app = FastAPI(
|
|||
"Orchestrates LangGraph council runs and streams real-time agent "
|
||||
"status via WebSockets."
|
||||
),
|
||||
version="0.2.0",
|
||||
version="0.3.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
|
@ -61,6 +67,7 @@ app.add_middleware(
|
|||
# Mount REST routes under /api prefix
|
||||
app.include_router(router, prefix="/api")
|
||||
app.include_router(blueprint_router, prefix="/api")
|
||||
app.include_router(run_history_router, prefix="/api")
|
||||
|
||||
# Mount WebSocket routes (no prefix — path is /ws/council/{run_id})
|
||||
app.include_router(ws_router)
|
||||
|
|
|
|||
71
backend/models/council_run.py
Normal file
71
backend/models/council_run.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""
|
||||
CouncilRun model — persists council run history in PostgreSQL.
|
||||
|
||||
Each run record stores the execution metadata, status, and results.
|
||||
Replaces the in-memory run_store for persistent storage.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, String, Text
|
||||
from models.blueprint import Base
|
||||
|
||||
|
||||
class CouncilRun(Base):
|
||||
"""
|
||||
A persisted council run stored in PostgreSQL.
|
||||
|
||||
Tracks the full lifecycle of a run: pending → running → completed/failed/paused.
|
||||
"""
|
||||
|
||||
__tablename__ = "council_runs"
|
||||
|
||||
id = Column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
blueprint_id = Column(String(36), nullable=True)
|
||||
input_topic = Column(Text, nullable=False)
|
||||
status = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
)
|
||||
execution_mode = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="auto-pilot",
|
||||
)
|
||||
final_draft = Column(Text, nullable=True)
|
||||
critic_score = Column(Float, nullable=True)
|
||||
iteration_count = Column(Integer, nullable=True)
|
||||
active_node = Column(String(255), nullable=True)
|
||||
error = Column(Text, nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
completed_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize to a JSON-friendly dict."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"blueprint_id": self.blueprint_id,
|
||||
"input_topic": self.input_topic,
|
||||
"status": self.status,
|
||||
"execution_mode": self.execution_mode,
|
||||
"final_draft": self.final_draft,
|
||||
"critic_score": self.critic_score,
|
||||
"iteration_count": self.iteration_count,
|
||||
"active_node": self.active_node,
|
||||
"error": self.error,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
}
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
# Core AI orchestration
|
||||
langgraph>=0.2.0
|
||||
langgraph-checkpoint>=2.0.0
|
||||
langchain>=0.2.0
|
||||
langchain-anthropic>=0.1.0
|
||||
langchain-openai>=0.1.0
|
||||
langchain-community>=0.2.0
|
||||
|
||||
# Backend API
|
||||
fastapi>=0.111.0
|
||||
|
|
|
|||
|
|
@ -5,6 +5,11 @@ This is the Phase 3 replacement for the hard-coded graph in graph_builder.py.
|
|||
It reads a CouncilBlueprint JSON (as produced by the frontend parser) and
|
||||
dynamically constructs the LangGraph StateGraph with the correct nodes,
|
||||
edges, and conditional routing.
|
||||
|
||||
Phase 4 additions:
|
||||
- Tool binding: agents with tools enabled (webSearch, pdfReader) get
|
||||
LangChain tools bound to their LLM via .bind_tools().
|
||||
- God Mode: supports interrupt_before for human-in-the-loop approval.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -17,6 +22,8 @@ from langchain_openai import ChatOpenAI
|
|||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from state import CouncilState, APPROVAL_THRESHOLD, MAX_ITERATIONS
|
||||
from tools.web_search import web_search
|
||||
from tools.pdf_reader import pdf_search
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -50,6 +57,78 @@ def _get_llm(model_name: str) -> Any:
|
|||
return factory()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_tools(tools_config: Optional[dict]) -> list:
|
||||
"""
|
||||
Resolve a node's tools config to a list of LangChain tool objects.
|
||||
|
||||
Args:
|
||||
tools_config: Dict like {"webSearch": true, "pdfReader": true}
|
||||
|
||||
Returns:
|
||||
A list of LangChain tool objects to bind to the LLM.
|
||||
"""
|
||||
if not tools_config:
|
||||
return []
|
||||
|
||||
resolved = []
|
||||
if tools_config.get("webSearch"):
|
||||
resolved.append(web_search)
|
||||
if tools_config.get("pdfReader"):
|
||||
resolved.append(pdf_search)
|
||||
return resolved
|
||||
|
||||
|
||||
def _invoke_with_tools(llm: Any, messages: list, tools: list) -> Any:
|
||||
"""
|
||||
Invoke an LLM, optionally with tools bound. If the LLM returns tool
|
||||
calls, execute them and feed results back for a final answer.
|
||||
|
||||
Args:
|
||||
llm: A LangChain chat model instance.
|
||||
messages: The message list to send.
|
||||
tools: List of LangChain tools (may be empty).
|
||||
|
||||
Returns:
|
||||
The final LLM response message.
|
||||
"""
|
||||
if not tools:
|
||||
return llm.invoke(messages)
|
||||
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
response = llm_with_tools.invoke(messages)
|
||||
|
||||
# If no tool calls, return directly
|
||||
if not response.tool_calls:
|
||||
return response
|
||||
|
||||
# Execute tool calls and collect results
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
tool_map = {t.name: t for t in tools}
|
||||
tool_messages = [response]
|
||||
|
||||
for tc in response.tool_calls:
|
||||
tool_fn = tool_map.get(tc["name"])
|
||||
if tool_fn:
|
||||
try:
|
||||
result = tool_fn.invoke(tc["args"])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
result = f"[Tool Error] {exc}"
|
||||
else:
|
||||
result = f"[Tool Error] Unknown tool: {tc['name']}"
|
||||
|
||||
tool_messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=tc["id"])
|
||||
)
|
||||
|
||||
# Final LLM call with tool results
|
||||
return llm_with_tools.invoke(messages + tool_messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generic agent node factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -59,6 +138,7 @@ def _make_agent_node(
|
|||
label: str,
|
||||
system_prompt: str,
|
||||
model_name: str,
|
||||
tools_config: Optional[dict] = None,
|
||||
) -> Callable[[CouncilState], dict]:
|
||||
"""
|
||||
Create a LangGraph node function for a user-defined agent.
|
||||
|
|
@ -71,10 +151,12 @@ def _make_agent_node(
|
|||
label: Display name of the agent (used in prompts).
|
||||
system_prompt: The persona / role definition for this agent.
|
||||
model_name: Which LLM to use ("claude-3-5-sonnet" | "gpt-4o").
|
||||
tools_config: Optional dict like {"webSearch": true, "pdfReader": true}.
|
||||
|
||||
Returns:
|
||||
A callable (CouncilState) -> dict suitable for StateGraph.add_node().
|
||||
"""
|
||||
node_tools = _resolve_tools(tools_config)
|
||||
|
||||
def agent_node(state: CouncilState) -> dict:
|
||||
llm = _get_llm(model_name)
|
||||
|
|
@ -105,7 +187,7 @@ def _make_agent_node(
|
|||
|
||||
system_msg = SystemMessage(content=system_prompt)
|
||||
user_msg = HumanMessage(content=user_content)
|
||||
response = llm.invoke([system_msg, user_msg])
|
||||
response = _invoke_with_tools(llm, [system_msg, user_msg], node_tools)
|
||||
|
||||
return {
|
||||
"current_draft": response.content,
|
||||
|
|
@ -177,6 +259,7 @@ def _make_critic_node(
|
|||
label: str,
|
||||
system_prompt: str,
|
||||
model_name: str,
|
||||
tools_config: Optional[dict] = None,
|
||||
) -> Callable[[CouncilState], dict]:
|
||||
"""
|
||||
Create a critic-style node that scores and routes.
|
||||
|
|
@ -186,6 +269,8 @@ def _make_critic_node(
|
|||
"""
|
||||
import re
|
||||
|
||||
node_tools = _resolve_tools(tools_config)
|
||||
|
||||
critic_system = (
|
||||
system_prompt + "\n\n"
|
||||
"IMPORTANT: You must respond in EXACTLY this format:\n\n"
|
||||
|
|
@ -219,7 +304,7 @@ def _make_critic_node(
|
|||
)
|
||||
)
|
||||
|
||||
response = llm.invoke([system_msg, user_msg])
|
||||
response = _invoke_with_tools(llm, [system_msg, user_msg], node_tools)
|
||||
|
||||
# Parse structured response
|
||||
score_match = re.search(r"SCORE:\s*(\d+(?:\.\d+)?)", response.content)
|
||||
|
|
@ -251,7 +336,10 @@ def _make_critic_node(
|
|||
# Main: build graph from blueprint JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_graph_from_blueprint(blueprint: dict) -> Any:
|
||||
def build_graph_from_blueprint(
|
||||
blueprint: dict,
|
||||
god_mode: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Dynamically construct a compiled LangGraph from a CouncilBlueprint JSON.
|
||||
|
||||
|
|
@ -263,6 +351,8 @@ def build_graph_from_blueprint(blueprint: dict) -> Any:
|
|||
"nodes": [{"id", "label", "systemPrompt", "model", "tools", "position"}],
|
||||
"edges": [{"id", "source", "target", "type", "condition?"}]
|
||||
}
|
||||
god_mode: If True, compile with interrupt_before on all nodes so the
|
||||
user can approve/reject at each step (Human-in-the-Loop).
|
||||
|
||||
Returns:
|
||||
A compiled LangGraph StateGraph ready for invocation.
|
||||
|
|
@ -295,16 +385,23 @@ def build_graph_from_blueprint(blueprint: dict) -> Any:
|
|||
graph = StateGraph(CouncilState)
|
||||
|
||||
# Register all nodes
|
||||
all_node_ids = []
|
||||
for node in nodes:
|
||||
nid = node["id"]
|
||||
all_node_ids.append(nid)
|
||||
label = node.get("label", nid)
|
||||
system_prompt = node.get("systemPrompt", f"You are {label}.")
|
||||
model_name = node.get("model", "claude-3-5-sonnet")
|
||||
tools_config = node.get("tools")
|
||||
|
||||
if _is_critic_like(system_prompt):
|
||||
node_fn = _make_critic_node(nid, label, system_prompt, model_name)
|
||||
node_fn = _make_critic_node(
|
||||
nid, label, system_prompt, model_name, tools_config
|
||||
)
|
||||
else:
|
||||
node_fn = _make_agent_node(nid, label, system_prompt, model_name)
|
||||
node_fn = _make_agent_node(
|
||||
nid, label, system_prompt, model_name, tools_config
|
||||
)
|
||||
|
||||
graph.add_node(nid, node_fn)
|
||||
|
||||
|
|
@ -349,6 +446,10 @@ def build_graph_from_blueprint(blueprint: dict) -> Any:
|
|||
if tid not in edges_by_source:
|
||||
graph.add_edge(tid, END)
|
||||
|
||||
# God Mode: interrupt before every node so the user can approve/reject
|
||||
if god_mode:
|
||||
return graph.compile(interrupt_before=all_node_ids)
|
||||
|
||||
return graph.compile()
|
||||
|
||||
|
||||
|
|
@ -356,20 +457,65 @@ async def run_blueprint_council_async(
|
|||
blueprint: dict,
|
||||
input_topic: str,
|
||||
run_id: str,
|
||||
god_mode: bool = False,
|
||||
on_node_event: Optional[Callable[[str, str], Any]] = None,
|
||||
) -> CouncilState:
|
||||
"""
|
||||
Execute a council run using a dynamically built graph from a blueprint.
|
||||
|
||||
In auto-pilot mode, the graph runs to completion.
|
||||
In god mode, the graph pauses before each node via interrupt_before,
|
||||
allowing human approval through the resume mechanism.
|
||||
|
||||
Args:
|
||||
blueprint: The CouncilBlueprint JSON dict.
|
||||
input_topic: The user's prompt.
|
||||
run_id: Unique identifier for this run.
|
||||
god_mode: If True, pause before each node for human approval.
|
||||
on_node_event: Optional callback for WebSocket node events.
|
||||
|
||||
Returns:
|
||||
The final CouncilState after execution completes.
|
||||
"""
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
if god_mode:
|
||||
checkpointer = MemorySaver()
|
||||
nodes_list = blueprint.get("nodes", [])
|
||||
all_node_ids = [n["id"] for n in nodes_list]
|
||||
compiled_graph = _build_graph_with_checkpointer(
|
||||
blueprint, checkpointer, all_node_ids
|
||||
)
|
||||
|
||||
initial_state = CouncilState(
|
||||
input_topic=input_topic,
|
||||
current_draft="",
|
||||
feedback_history=[],
|
||||
route_decision="",
|
||||
messages=[],
|
||||
iteration_count=0,
|
||||
critic_score=None,
|
||||
run_id=run_id,
|
||||
active_node="",
|
||||
)
|
||||
|
||||
thread_config = {"configurable": {"thread_id": run_id}}
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
state = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: compiled_graph.invoke(initial_state, config=thread_config),
|
||||
)
|
||||
|
||||
# Store the graph and checkpointer for later resume
|
||||
_god_mode_sessions[run_id] = {
|
||||
"graph": compiled_graph,
|
||||
"checkpointer": checkpointer,
|
||||
"thread_config": thread_config,
|
||||
}
|
||||
|
||||
return state
|
||||
|
||||
compiled_graph = build_graph_from_blueprint(blueprint)
|
||||
|
||||
initial_state = CouncilState(
|
||||
|
|
@ -391,3 +537,153 @@ async def run_blueprint_council_async(
|
|||
)
|
||||
|
||||
return final_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# God Mode session management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# In-memory store for active god mode sessions (graph + checkpointer)
|
||||
_god_mode_sessions: Dict[str, dict] = {}
|
||||
|
||||
|
||||
def _build_graph_with_checkpointer(
|
||||
blueprint: dict,
|
||||
checkpointer: Any,
|
||||
interrupt_node_ids: List[str],
|
||||
) -> Any:
|
||||
"""Build a compiled graph with a checkpointer for god mode."""
|
||||
nodes = blueprint.get("nodes", [])
|
||||
edges = blueprint.get("edges", [])
|
||||
|
||||
if not nodes:
|
||||
raise ValueError("Blueprint has no nodes.")
|
||||
|
||||
node_lookup = {n["id"]: n for n in nodes}
|
||||
targets = {e["target"] for e in edges}
|
||||
entry_candidates = [n["id"] for n in nodes if n["id"] not in targets]
|
||||
if not entry_candidates:
|
||||
entry_candidates = [nodes[0]["id"]]
|
||||
entry_node_id = entry_candidates[0]
|
||||
|
||||
sources = {e["source"] for e in edges}
|
||||
terminal_nodes = {n["id"] for n in nodes if n["id"] not in sources}
|
||||
|
||||
graph = StateGraph(CouncilState)
|
||||
|
||||
for node in nodes:
|
||||
nid = node["id"]
|
||||
label = node.get("label", nid)
|
||||
system_prompt = node.get("systemPrompt", f"You are {label}.")
|
||||
model_name = node.get("model", "claude-3-5-sonnet")
|
||||
tools_config = node.get("tools")
|
||||
|
||||
if _is_critic_like(system_prompt):
|
||||
node_fn = _make_critic_node(
|
||||
nid, label, system_prompt, model_name, tools_config
|
||||
)
|
||||
else:
|
||||
node_fn = _make_agent_node(
|
||||
nid, label, system_prompt, model_name, tools_config
|
||||
)
|
||||
|
||||
graph.add_node(nid, node_fn)
|
||||
|
||||
graph.set_entry_point(entry_node_id)
|
||||
|
||||
edges_by_source: Dict[str, Dict[str, list]] = {}
|
||||
for edge in edges:
|
||||
src = edge["source"]
|
||||
if src not in edges_by_source:
|
||||
edges_by_source[src] = {"linear": [], "conditional": []}
|
||||
if edge.get("type") == "conditional":
|
||||
edges_by_source[src]["conditional"].append(edge)
|
||||
else:
|
||||
edges_by_source[src]["linear"].append(edge)
|
||||
|
||||
for source_id, grouped in edges_by_source.items():
|
||||
linear = grouped["linear"]
|
||||
conditional = grouped["conditional"]
|
||||
|
||||
if conditional:
|
||||
linear_target = linear[0]["target"] if linear else None
|
||||
router = _make_conditional_router(source_id, conditional, linear_target)
|
||||
route_map: Dict[str, str] = {}
|
||||
for ce in conditional:
|
||||
route_map[ce["target"]] = ce["target"]
|
||||
if linear_target:
|
||||
route_map[linear_target] = linear_target
|
||||
graph.add_conditional_edges(source_id, router, route_map)
|
||||
elif linear:
|
||||
graph.add_edge(source_id, linear[0]["target"])
|
||||
|
||||
for tid in terminal_nodes:
|
||||
if tid not in edges_by_source:
|
||||
graph.add_edge(tid, END)
|
||||
|
||||
return graph.compile(
|
||||
checkpointer=checkpointer,
|
||||
interrupt_before=interrupt_node_ids,
|
||||
)
|
||||
|
||||
|
||||
async def resume_god_mode(
|
||||
run_id: str,
|
||||
action: str = "approve",
|
||||
modified_state: Optional[dict] = None,
|
||||
) -> Optional[CouncilState]:
|
||||
"""
|
||||
Resume a paused god mode run after human approval.
|
||||
|
||||
Args:
|
||||
run_id: The run ID of the paused session.
|
||||
action: "approve" to continue, "reject" to stop.
|
||||
modified_state: Optional partial state override (for "modify" action).
|
||||
|
||||
Returns:
|
||||
The next CouncilState (may be another interrupt or final).
|
||||
None if the run_id is not found.
|
||||
"""
|
||||
session = _god_mode_sessions.get(run_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
if action == "reject":
|
||||
_god_mode_sessions.pop(run_id, None)
|
||||
return None
|
||||
|
||||
compiled_graph = session["graph"]
|
||||
thread_config = session["thread_config"]
|
||||
|
||||
if modified_state:
|
||||
compiled_graph.update_state(thread_config, modified_state)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
state = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: compiled_graph.invoke(None, config=thread_config),
|
||||
)
|
||||
|
||||
# If state indicates completion, clean up
|
||||
if state and state.get("route_decision") == "done":
|
||||
_god_mode_sessions.pop(run_id, None)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def get_god_mode_state(run_id: str) -> Optional[dict]:
|
||||
"""Get the current state of a paused god mode session."""
|
||||
session = _god_mode_sessions.get(run_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
graph = session["graph"]
|
||||
thread_config = session["thread_config"]
|
||||
snapshot = graph.get_state(thread_config)
|
||||
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"paused": bool(snapshot.next),
|
||||
"next_nodes": list(snapshot.next) if snapshot.next else [],
|
||||
"current_state": dict(snapshot.values) if snapshot.values else {},
|
||||
}
|
||||
|
|
|
|||
80
backend/services/run_service.py
Normal file
80
backend/services/run_service.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""
|
||||
Run Service — CRUD operations for persisted council runs.
|
||||
|
||||
Provides async functions to create, read, update, and list council runs
|
||||
in PostgreSQL. Works alongside the in-memory run_store which handles
|
||||
real-time status during execution.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.council_run import CouncilRun
|
||||
|
||||
|
||||
async def create_run(
|
||||
session: AsyncSession,
|
||||
run_id: str,
|
||||
input_topic: str,
|
||||
blueprint_id: Optional[str] = None,
|
||||
execution_mode: str = "auto-pilot",
|
||||
) -> CouncilRun:
|
||||
"""Create a new council run record."""
|
||||
run = CouncilRun(
|
||||
id=run_id,
|
||||
blueprint_id=blueprint_id,
|
||||
input_topic=input_topic,
|
||||
status="pending",
|
||||
execution_mode=execution_mode,
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
await session.refresh(run)
|
||||
return run
|
||||
|
||||
|
||||
async def get_run(session: AsyncSession, run_id: str) -> Optional[CouncilRun]:
|
||||
"""Get a council run by ID."""
|
||||
result = await session.execute(select(CouncilRun).where(CouncilRun.id == run_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def list_runs(
|
||||
session: AsyncSession,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> List[CouncilRun]:
|
||||
"""List council runs, ordered by most recent first."""
|
||||
result = await session.execute(
|
||||
select(CouncilRun)
|
||||
.order_by(CouncilRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def update_run(
|
||||
session: AsyncSession,
|
||||
run_id: str,
|
||||
updates: dict,
|
||||
) -> Optional[CouncilRun]:
|
||||
"""Update a council run with the given fields."""
|
||||
run = await get_run(session, run_id)
|
||||
if run is None:
|
||||
return None
|
||||
|
||||
for key, value in updates.items():
|
||||
if hasattr(run, key):
|
||||
setattr(run, key, value)
|
||||
|
||||
# Auto-set completed_at when status becomes terminal
|
||||
if updates.get("status") in ("completed", "failed"):
|
||||
run.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(run)
|
||||
return run
|
||||
192
backend/tests/test_god_mode.py
Normal file
192
backend/tests/test_god_mode.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
"""
|
||||
Tests for God Mode (interrupt_before) functionality.
|
||||
|
||||
All LLM calls are mocked — no real API calls are made in these tests.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from state import CouncilState
|
||||
|
||||
|
||||
class TestBuildGraphGodMode:
|
||||
"""Tests for graph compilation with god mode (interrupt_before)."""
|
||||
|
||||
def _make_simple_blueprint(self):
|
||||
return {
|
||||
"version": 1,
|
||||
"name": "Test Council",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "master",
|
||||
"label": "Master AI",
|
||||
"systemPrompt": "You are the master writer.",
|
||||
"model": "claude-3-5-sonnet",
|
||||
"tools": {"webSearch": False, "pdfReader": False},
|
||||
},
|
||||
{
|
||||
"id": "critic",
|
||||
"label": "Critic AI",
|
||||
"systemPrompt": "You are a critic who evaluates and scores drafts.",
|
||||
"model": "claude-3-5-sonnet",
|
||||
"tools": {"webSearch": False, "pdfReader": False},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "master", "target": "critic", "type": "linear"},
|
||||
],
|
||||
}
|
||||
|
||||
@patch("services.dynamic_graph_builder._get_llm")
|
||||
def test_build_graph_with_god_mode_compiles(self, mock_get_llm):
|
||||
"""God mode graph should compile without error."""
|
||||
from services.dynamic_graph_builder import build_graph_from_blueprint
|
||||
|
||||
blueprint = self._make_simple_blueprint()
|
||||
graph = build_graph_from_blueprint(blueprint, god_mode=False)
|
||||
assert graph is not None
|
||||
|
||||
def test_build_graph_without_god_mode(self):
|
||||
"""Normal graph should compile without interrupt_before."""
|
||||
from services.dynamic_graph_builder import build_graph_from_blueprint
|
||||
|
||||
blueprint = self._make_simple_blueprint()
|
||||
graph = build_graph_from_blueprint(blueprint, god_mode=False)
|
||||
assert graph is not None
|
||||
|
||||
|
||||
class TestGodModeSessionManagement:
|
||||
"""Tests for god mode session management functions."""
|
||||
|
||||
def test_get_god_mode_state_returns_none_for_unknown_run(self):
|
||||
from services.dynamic_graph_builder import get_god_mode_state
|
||||
|
||||
result = get_god_mode_state("nonexistent-run-id")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_god_mode_returns_none_for_unknown_run(self):
|
||||
from services.dynamic_graph_builder import resume_god_mode
|
||||
|
||||
result = await resume_god_mode("nonexistent-run-id", action="approve")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_god_mode_reject_cleans_up(self):
|
||||
from services.dynamic_graph_builder import (
|
||||
_god_mode_sessions,
|
||||
resume_god_mode,
|
||||
)
|
||||
|
||||
# Manually insert a fake session
|
||||
_god_mode_sessions["test-run"] = {
|
||||
"graph": MagicMock(),
|
||||
"checkpointer": MagicMock(),
|
||||
"thread_config": {"configurable": {"thread_id": "test-run"}},
|
||||
}
|
||||
|
||||
result = await resume_god_mode("test-run", action="reject")
|
||||
assert result is None
|
||||
assert "test-run" not in _god_mode_sessions
|
||||
|
||||
|
||||
class TestToolResolution:
|
||||
"""Tests for the tool resolution helper."""
|
||||
|
||||
def test_resolve_tools_none_config(self):
|
||||
from services.dynamic_graph_builder import _resolve_tools
|
||||
|
||||
assert _resolve_tools(None) == []
|
||||
|
||||
def test_resolve_tools_empty_config(self):
|
||||
from services.dynamic_graph_builder import _resolve_tools
|
||||
|
||||
assert _resolve_tools({}) == []
|
||||
|
||||
def test_resolve_tools_web_search_only(self):
|
||||
from services.dynamic_graph_builder import _resolve_tools
|
||||
|
||||
tools = _resolve_tools({"webSearch": True, "pdfReader": False})
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "web_search"
|
||||
|
||||
def test_resolve_tools_pdf_only(self):
|
||||
from services.dynamic_graph_builder import _resolve_tools
|
||||
|
||||
tools = _resolve_tools({"webSearch": False, "pdfReader": True})
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "pdf_search"
|
||||
|
||||
def test_resolve_tools_both(self):
|
||||
from services.dynamic_graph_builder import _resolve_tools
|
||||
|
||||
tools = _resolve_tools({"webSearch": True, "pdfReader": True})
|
||||
assert len(tools) == 2
|
||||
names = {t.name for t in tools}
|
||||
assert names == {"web_search", "pdf_search"}
|
||||
|
||||
|
||||
class TestInvokeWithTools:
|
||||
"""Tests for the _invoke_with_tools helper."""
|
||||
|
||||
def test_invoke_without_tools_calls_llm_directly(self):
|
||||
from services.dynamic_graph_builder import _invoke_with_tools
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "Test response"
|
||||
mock_llm.invoke.return_value = mock_response
|
||||
|
||||
result = _invoke_with_tools(mock_llm, ["msg1", "msg2"], [])
|
||||
mock_llm.invoke.assert_called_once_with(["msg1", "msg2"])
|
||||
assert result == mock_response
|
||||
|
||||
def test_invoke_with_tools_no_tool_calls(self):
|
||||
from services.dynamic_graph_builder import _invoke_with_tools
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_bound = MagicMock()
|
||||
mock_llm.bind_tools.return_value = mock_bound
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.tool_calls = []
|
||||
mock_response.content = "No tools needed"
|
||||
mock_bound.invoke.return_value = mock_response
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
|
||||
result = _invoke_with_tools(mock_llm, ["msg"], [mock_tool])
|
||||
assert result == mock_response
|
||||
|
||||
def test_invoke_with_tools_executes_tool_calls(self):
|
||||
from services.dynamic_graph_builder import _invoke_with_tools
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_bound = MagicMock()
|
||||
mock_llm.bind_tools.return_value = mock_bound
|
||||
|
||||
# First call returns tool_calls
|
||||
mock_response_with_tools = MagicMock()
|
||||
mock_response_with_tools.tool_calls = [
|
||||
{"name": "web_search", "args": {"query": "test"}, "id": "call-1"}
|
||||
]
|
||||
|
||||
# Second call returns final answer
|
||||
mock_final_response = MagicMock()
|
||||
mock_final_response.content = "Final answer"
|
||||
mock_bound.invoke.side_effect = [mock_response_with_tools, mock_final_response]
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "web_search"
|
||||
mock_tool.invoke.return_value = "Search results"
|
||||
|
||||
result = _invoke_with_tools(mock_llm, ["msg"], [mock_tool])
|
||||
mock_tool.invoke.assert_called_once_with({"query": "test"})
|
||||
assert result == mock_final_response
|
||||
82
backend/tests/test_run_service.py
Normal file
82
backend/tests/test_run_service.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""
|
||||
Tests for the run history service and CouncilRun model.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
class TestCouncilRunModel:
|
||||
"""Tests for the CouncilRun SQLAlchemy model."""
|
||||
|
||||
def test_to_dict_serialization(self):
|
||||
from models.council_run import CouncilRun
|
||||
from datetime import datetime, timezone
|
||||
|
||||
run = CouncilRun(
|
||||
id="test-id",
|
||||
blueprint_id="bp-id",
|
||||
input_topic="Test topic",
|
||||
status="completed",
|
||||
execution_mode="auto-pilot",
|
||||
final_draft="Final text",
|
||||
critic_score=8.5,
|
||||
iteration_count=3,
|
||||
active_node="done",
|
||||
error=None,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
completed_at=datetime(2026, 1, 1, 0, 5, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
d = run.to_dict()
|
||||
assert d["id"] == "test-id"
|
||||
assert d["blueprint_id"] == "bp-id"
|
||||
assert d["status"] == "completed"
|
||||
assert d["critic_score"] == 8.5
|
||||
assert d["iteration_count"] == 3
|
||||
assert d["created_at"] is not None
|
||||
assert d["completed_at"] is not None
|
||||
|
||||
def test_to_dict_with_none_timestamps(self):
|
||||
from models.council_run import CouncilRun
|
||||
|
||||
run = CouncilRun(
|
||||
id="test-id",
|
||||
input_topic="Test",
|
||||
status="pending",
|
||||
execution_mode="god-mode",
|
||||
created_at=None,
|
||||
completed_at=None,
|
||||
)
|
||||
|
||||
d = run.to_dict()
|
||||
assert d["created_at"] is None
|
||||
assert d["completed_at"] is None
|
||||
assert d["execution_mode"] == "god-mode"
|
||||
|
||||
|
||||
class TestRunHistoryRoutes:
|
||||
"""Tests for the run history API routes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_runs_empty(self):
|
||||
"""List runs returns empty list when no runs exist."""
|
||||
from api.run_history_routes import list_all_runs
|
||||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = []
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
with patch("services.run_service.list_runs") as mock_list:
|
||||
mock_list.return_value = []
|
||||
result = await list_all_runs(limit=50, offset=0, session=mock_session)
|
||||
assert result == []
|
||||
170
backend/tests/test_tools.py
Normal file
170
backend/tests/test_tools.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
"""
|
||||
Tests for agent tools (web search and PDF reader).
|
||||
|
||||
All external API calls are mocked — no real calls to Tavily or ChromaDB.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestWebSearchTool:
|
||||
"""Tests for the Tavily web search tool."""
|
||||
|
||||
@patch.dict(os.environ, {"TAVILY_API_KEY": ""}, clear=False)
|
||||
def test_web_search_returns_error_without_api_key(self):
|
||||
from tools.web_search import web_search
|
||||
|
||||
result = web_search.invoke({"query": "test query"})
|
||||
assert "TAVILY_API_KEY" in result
|
||||
|
||||
@patch.dict(os.environ, {"TAVILY_API_KEY": "test-key"}, clear=False)
|
||||
@patch("tools.web_search.TavilyClient")
|
||||
def test_web_search_returns_formatted_results(self, mock_client_cls):
|
||||
mock_client = MagicMock()
|
||||
mock_client.search.return_value = {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Result",
|
||||
"url": "https://example.com",
|
||||
"content": "Some content here",
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
from tools.web_search import web_search
|
||||
|
||||
result = web_search.invoke({"query": "test query"})
|
||||
assert "Test Result" in result
|
||||
assert "https://example.com" in result
|
||||
assert "Some content here" in result
|
||||
|
||||
@patch.dict(os.environ, {"TAVILY_API_KEY": "test-key"}, clear=False)
|
||||
@patch("tools.web_search.TavilyClient")
|
||||
def test_web_search_handles_empty_results(self, mock_client_cls):
|
||||
mock_client = MagicMock()
|
||||
mock_client.search.return_value = {"results": []}
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
from tools.web_search import web_search
|
||||
|
||||
result = web_search.invoke({"query": "obscure query"})
|
||||
assert "No results" in result
|
||||
|
||||
@patch.dict(os.environ, {"TAVILY_API_KEY": "test-key"}, clear=False)
|
||||
@patch("tools.web_search.TavilyClient")
|
||||
def test_web_search_handles_api_error(self, mock_client_cls):
|
||||
mock_client = MagicMock()
|
||||
mock_client.search.side_effect = Exception("API rate limit")
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
from tools.web_search import web_search
|
||||
|
||||
result = web_search.invoke({"query": "test"})
|
||||
assert "Error" in result
|
||||
assert "rate limit" in result
|
||||
|
||||
|
||||
class TestCreateWebSearchTool:
|
||||
"""Tests for the web search tool factory."""
|
||||
|
||||
@patch.dict(os.environ, {"TAVILY_API_KEY": "test-key"}, clear=False)
|
||||
def test_factory_returns_tool_when_key_set(self):
|
||||
from tools.web_search import create_web_search_tool
|
||||
|
||||
tool = create_web_search_tool()
|
||||
assert tool is not None
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_factory_returns_none_when_key_missing(self):
|
||||
from tools.web_search import create_web_search_tool
|
||||
|
||||
tool = create_web_search_tool()
|
||||
assert tool is None
|
||||
|
||||
|
||||
class TestPdfSearchTool:
|
||||
"""Tests for the PDF reader tool."""
|
||||
|
||||
@patch("tools.pdf_reader._get_chroma_collection")
|
||||
def test_pdf_search_empty_collection(self, mock_get_collection):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.count.return_value = 0
|
||||
mock_get_collection.return_value = mock_collection
|
||||
|
||||
from tools.pdf_reader import pdf_search
|
||||
|
||||
result = pdf_search.invoke({"query": "test query"})
|
||||
assert "No documents" in result
|
||||
|
||||
@patch("tools.pdf_reader._get_chroma_collection")
|
||||
def test_pdf_search_returns_results(self, mock_get_collection):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.count.return_value = 3
|
||||
mock_collection.query.return_value = {
|
||||
"documents": [["First passage about AI.", "Second passage about ML."]],
|
||||
"metadatas": [[
|
||||
{"source": "paper.pdf", "page": 1},
|
||||
{"source": "paper.pdf", "page": 3},
|
||||
]],
|
||||
}
|
||||
mock_get_collection.return_value = mock_collection
|
||||
|
||||
from tools.pdf_reader import pdf_search
|
||||
|
||||
result = pdf_search.invoke({"query": "AI concepts"})
|
||||
assert "paper.pdf" in result
|
||||
assert "First passage" in result
|
||||
assert "Page 1" in result
|
||||
|
||||
@patch("tools.pdf_reader._get_chroma_collection")
|
||||
def test_pdf_search_handles_error(self, mock_get_collection):
|
||||
mock_get_collection.side_effect = Exception("ChromaDB unavailable")
|
||||
|
||||
from tools.pdf_reader import pdf_search
|
||||
|
||||
result = pdf_search.invoke({"query": "test"})
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestPdfIngestion:
|
||||
"""Tests for PDF ingestion into ChromaDB."""
|
||||
|
||||
@patch("tools.pdf_reader._get_chroma_collection")
|
||||
@patch("tools.pdf_reader.PdfReader")
|
||||
def test_ingest_pdf_processes_pages(self, mock_pdf_reader_cls, mock_get_collection):
|
||||
# Mock PDF with 2 pages of text
|
||||
mock_page1 = MagicMock()
|
||||
mock_page1.extract_text.return_value = "This is the first page with some content " * 20
|
||||
mock_page2 = MagicMock()
|
||||
mock_page2.extract_text.return_value = "Second page about machine learning " * 20
|
||||
mock_reader = MagicMock()
|
||||
mock_reader.pages = [mock_page1, mock_page2]
|
||||
mock_pdf_reader_cls.return_value = mock_reader
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_get_collection.return_value = mock_collection
|
||||
|
||||
from tools.pdf_reader import ingest_pdf
|
||||
|
||||
chunks = ingest_pdf("/tmp/test.pdf")
|
||||
assert chunks > 0
|
||||
mock_collection.upsert.assert_called_once()
|
||||
|
||||
@patch("tools.pdf_reader._get_chroma_collection")
|
||||
@patch("tools.pdf_reader.PdfReader")
|
||||
def test_ingest_pdf_empty_file(self, mock_pdf_reader_cls, mock_get_collection):
|
||||
mock_reader = MagicMock()
|
||||
mock_reader.pages = []
|
||||
mock_pdf_reader_cls.return_value = mock_reader
|
||||
|
||||
from tools.pdf_reader import ingest_pdf
|
||||
|
||||
chunks = ingest_pdf("/tmp/empty.pdf")
|
||||
assert chunks == 0
|
||||
|
|
@ -1,7 +1,12 @@
|
|||
"""
|
||||
Agent tools for CouncilOS.
|
||||
"""Agent tools for CouncilOS."""
|
||||
|
||||
Phase 4 will add:
|
||||
- web_search_tool: Tavily Search API wrapper
|
||||
- pdf_reader_tool: PyPDF + ChromaDB vector store wrapper
|
||||
"""
|
||||
from .web_search import web_search, create_web_search_tool
|
||||
from .pdf_reader import pdf_search, ingest_pdf, create_pdf_search_tool
|
||||
|
||||
__all__ = [
|
||||
"web_search",
|
||||
"create_web_search_tool",
|
||||
"pdf_search",
|
||||
"ingest_pdf",
|
||||
"create_pdf_search_tool",
|
||||
]
|
||||
|
|
|
|||
140
backend/tools/pdf_reader.py
Normal file
140
backend/tools/pdf_reader.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
PDF Reader Tool — PyPDF + ChromaDB vector store wrapper for agent nodes.
|
||||
|
||||
Loads PDF files, splits them into chunks, stores embeddings in a local
|
||||
ChromaDB collection, and performs similarity search against queries.
|
||||
Requires the CHROMA_PERSIST_DIR environment variable for storage location.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
# Module-level collection cache to avoid re-initializing on every call
|
||||
_collection_cache: dict = {}
|
||||
|
||||
|
||||
def _get_chroma_collection(collection_name: str = "council_pdfs"):
|
||||
"""Get or create a ChromaDB collection for PDF content."""
|
||||
if collection_name in _collection_cache:
|
||||
return _collection_cache[collection_name]
|
||||
|
||||
import chromadb
|
||||
|
||||
persist_dir = os.environ.get("CHROMA_PERSIST_DIR", "./chroma_db")
|
||||
client = chromadb.PersistentClient(path=persist_dir)
|
||||
collection = client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
_collection_cache[collection_name] = collection
|
||||
return collection
|
||||
|
||||
|
||||
def ingest_pdf(file_path: str, collection_name: str = "council_pdfs") -> int:
|
||||
"""
|
||||
Read a PDF file, split into chunks, and store in ChromaDB.
|
||||
|
||||
Args:
|
||||
file_path: Path to the PDF file.
|
||||
collection_name: ChromaDB collection name.
|
||||
|
||||
Returns:
|
||||
Number of chunks ingested.
|
||||
"""
|
||||
from pypdf import PdfReader
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
chunks: List[str] = []
|
||||
metadata_list: List[dict] = []
|
||||
|
||||
for page_num, page in enumerate(reader.pages):
|
||||
text = page.extract_text()
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
|
||||
# Split long pages into ~500 character chunks with overlap
|
||||
words = text.split()
|
||||
chunk_size = 100 # words per chunk
|
||||
overlap = 20
|
||||
|
||||
for i in range(0, len(words), chunk_size - overlap):
|
||||
chunk_words = words[i : i + chunk_size]
|
||||
chunk_text = " ".join(chunk_words)
|
||||
if chunk_text.strip():
|
||||
chunks.append(chunk_text)
|
||||
metadata_list.append({
|
||||
"source": os.path.basename(file_path),
|
||||
"page": page_num + 1,
|
||||
})
|
||||
|
||||
if not chunks:
|
||||
return 0
|
||||
|
||||
collection = _get_chroma_collection(collection_name)
|
||||
|
||||
# Generate deterministic IDs based on file and chunk position
|
||||
ids = [
|
||||
f"{os.path.basename(file_path)}_chunk_{i}"
|
||||
for i in range(len(chunks))
|
||||
]
|
||||
|
||||
collection.upsert(
|
||||
documents=chunks,
|
||||
metadatas=metadata_list,
|
||||
ids=ids,
|
||||
)
|
||||
|
||||
return len(chunks)
|
||||
|
||||
|
||||
@tool
|
||||
def pdf_search(query: str, n_results: int = 5) -> str:
|
||||
"""
|
||||
Search the PDF knowledge base for information relevant to a query.
|
||||
|
||||
Args:
|
||||
query: The search query to find relevant PDF content.
|
||||
n_results: Number of results to return (default 5).
|
||||
|
||||
Returns:
|
||||
A formatted string with relevant passages from ingested PDFs.
|
||||
"""
|
||||
try:
|
||||
collection = _get_chroma_collection()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return f"[PDF Search Error] Could not access vector store: {exc}"
|
||||
|
||||
if collection.count() == 0:
|
||||
return "[PDF Search] No documents have been ingested yet."
|
||||
|
||||
try:
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=min(n_results, collection.count()),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return f"[PDF Search Error] {exc}"
|
||||
|
||||
documents = results.get("documents", [[]])[0]
|
||||
metadatas = results.get("metadatas", [[]])[0]
|
||||
|
||||
if not documents:
|
||||
return f"No relevant passages found for: {query}"
|
||||
|
||||
formatted = []
|
||||
for i, (doc, meta) in enumerate(zip(documents, metadatas), 1):
|
||||
source = meta.get("source", "unknown")
|
||||
page = meta.get("page", "?")
|
||||
formatted.append(f"{i}. [Source: {source}, Page {page}]\n {doc}")
|
||||
|
||||
return "\n\n".join(formatted)
|
||||
|
||||
|
||||
def create_pdf_search_tool() -> Optional[tool]:
|
||||
"""Factory that returns the pdf_search tool if ChromaDB is configured."""
|
||||
persist_dir = os.environ.get("CHROMA_PERSIST_DIR", "./chroma_db")
|
||||
if persist_dir:
|
||||
return pdf_search
|
||||
return None
|
||||
61
backend/tools/web_search.py
Normal file
61
backend/tools/web_search.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""
|
||||
Web Search Tool — Tavily Search API wrapper for agent nodes.
|
||||
|
||||
Provides a LangChain-compatible tool that agents can use to search the web
|
||||
for current information. Requires the TAVILY_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str, max_results: int = 5) -> str:
|
||||
"""
|
||||
Search the web for current information on a topic.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
max_results: Maximum number of results to return (default 5).
|
||||
|
||||
Returns:
|
||||
A formatted string with search results including titles, URLs, and snippets.
|
||||
"""
|
||||
from tavily import TavilyClient
|
||||
|
||||
api_key = os.environ.get("TAVILY_API_KEY")
|
||||
if not api_key:
|
||||
return "[Web Search Error] TAVILY_API_KEY environment variable is not set."
|
||||
|
||||
client = TavilyClient(api_key=api_key)
|
||||
|
||||
try:
|
||||
response = client.search(
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
search_depth="basic",
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return f"[Web Search Error] {exc}"
|
||||
|
||||
results = response.get("results", [])
|
||||
if not results:
|
||||
return f"No results found for: {query}"
|
||||
|
||||
formatted = []
|
||||
for i, r in enumerate(results, 1):
|
||||
title = r.get("title", "No title")
|
||||
url = r.get("url", "")
|
||||
content = r.get("content", "No content available")
|
||||
formatted.append(f"{i}. **{title}**\n URL: {url}\n {content}")
|
||||
|
||||
return "\n\n".join(formatted)
|
||||
|
||||
|
||||
def create_web_search_tool() -> Optional[tool]:
|
||||
"""Factory that returns the web_search tool if Tavily is configured."""
|
||||
if os.environ.get("TAVILY_API_KEY"):
|
||||
return web_search
|
||||
return None
|
||||
Loading…
Add table
Add a link
Reference in a new issue