Notes
Notes - notes.io |
"""
SQL Writer & Validator Agent with LangGraph - Proper Tool Architecture

This implements the 4 tools from the workflow diagram as separate nodes:
1. Schema/Glossary Retrieval Tool
2. SQL Generator Tool
3. SQL Validator Tool
4. SQL Executor Tool

Each tool is a node in the graph with proper state management.
"""
from typing import TypedDict, Annotated, List, Dict, Any
from langgraph.graph import StateGraph, END
import operator
from loguru import logger
import json
import sqlparse
from anthropic import Anthropic
from database.db_manager_sqlite import DatabaseManager
from models import SQLQuery, SQLResult
from config_sqlite import settings
# ============================================================================
# STATE DEFINITION
# ============================================================================
class SQLAgentState(TypedDict):
    """
    Shared state across all tools in the SQL Agent workflow.

    This is passed between each tool/node of the LangGraph graph.
    ``messages`` is annotated with an ``operator.add`` reducer, so LangGraph
    combines the list a node returns for that key with the existing list
    (concatenation) instead of replacing it; nodes should therefore return
    only the entries they produced.
    """
    # Input
    query: str  # Natural language query
    # Schema Retrieval Tool outputs
    schema_info: Dict[str, Any]  # Database schema, keyed by table name (each entry has "columns", optionally "primary_keys"/"foreign_keys")
    glossary_terms: List[Dict[str, Any]]  # Business glossary: one {"term", "type", "description"} dict per column
    # SQL Generator Tool outputs
    generated_sql: str  # The SQL query ("" when generation fails)
    explanation: str  # What the SQL does
    tables_used: List[str]  # Tables involved, as reported by the LLM
    # SQL Validator Tool outputs
    is_valid: bool  # Validation status
    validation_errors: List[str]  # Any errors found
    validation_attempts: int  # Number of LLM fix attempts made so far
    # SQL Executor Tool outputs
    execution_success: bool
    result_data: List[Dict[str, Any]]  # presumably one dict per result row — TODO confirm against DatabaseManager.execute_query
    result_columns: List[str]
    row_count: int
    execution_time: float  # displayed with an "s" (seconds) suffix by the executor and main()
    execution_error: str  # Error text; also written by schema retrieval and SQL generation on failure
    # Workflow control
    current_tool: str  # Which tool is executing (each tool writes its own name)
    iteration: int  # For retry logic. NOTE(review): not read anywhere in this file; only initialised in main()
    max_iterations: int  # Max LLM fix attempts; read by the validator and the retry router
    # Messages/logs for tracking
    messages: Annotated[List[str], operator.add]
# ============================================================================
# TOOL 1: SCHEMA/GLOSSARY RETRIEVAL
# ============================================================================
class SchemaGlossaryRetrievalTool:
    """
    Tool for retrieving database schema and business glossary.

    From workflow diagram: "Refers schema and column info, key constraints"
    """

    def __init__(self, db_manager: DatabaseManager):
        """
        Args:
            db_manager: Source of schema metadata via
                ``get_database_schema_info()``.
        """
        self.db_manager = db_manager
        self.name = "Schema/Glossary Retrieval"

    def execute(self, state: SQLAgentState) -> SQLAgentState:
        """
        Retrieve schema info and derive one glossary entry per column.

        Writes ``schema_info``, ``glossary_terms`` and ``current_tool``. On
        failure the error text is stored in ``execution_error`` instead of
        raising, so the graph continues.

        The returned ``messages`` list contains ONLY the entries produced by
        this call: ``SQLAgentState.messages`` is merged with ``operator.add``,
        so appending to / returning the accumulated list would duplicate the
        whole trace every time this node runs.
        """
        logger.info(f"[{self.name}] Starting schema retrieval")
        state["current_tool"] = self.name
        new_messages: List[str] = []

        try:
            # Get all table schemas
            schema_info = self.db_manager.get_database_schema_info()
            state["schema_info"] = schema_info

            # TODO: In future, integrate with vector store to find relevant glossary
            # For now, derive a glossary entry for every column in every table.
            glossary_terms = [
                {
                    "term": f"{table_name}.{col['name']}",
                    "type": col["type"],
                    "description": f"Column {col['name']} in table {table_name}",
                }
                for table_name, table_data in schema_info.items()
                for col in table_data["columns"]
            ]
            state["glossary_terms"] = glossary_terms

            msg = f"✅ Retrieved schema for {len(schema_info)} tables, {len(glossary_terms)} columns"
            new_messages.append(f"[{self.name}] {msg}")
            logger.info(msg)
        except Exception as e:
            msg = f"❌ Schema retrieval failed: {e}"
            new_messages.append(f"[{self.name}] {msg}")
            logger.error(msg)
            state["execution_error"] = str(e)

        # Only the delta: the operator.add reducer appends it to the trace.
        state["messages"] = new_messages
        return state
# ============================================================================
# TOOL 2: SQL GENERATOR
# ============================================================================
class SQLGeneratorTool:
    """
    Tool for generating SQL queries from natural language.

    From workflow diagram: "Generates SQL from enriched context"
    """

    # Instructions sent as the Anthropic ``system`` prompt on every call.
    SYSTEM_PROMPT = (
        "You are an expert SQL query generator.\n"
        "Generate accurate SQL queries based on the provided schema.\n"
        "IMPORTANT: Return ONLY valid JSON in this exact format "
        "(no markdown, no code blocks):\n"
        "{\n"
        '"query": "SELECT ... FROM ... WHERE ...",\n'
        '"explanation": "Brief explanation",\n'
        '"tables_used": ["table1", "table2"]\n'
        "}"
    )

    def __init__(self):
        """Create the Anthropic client from ``settings`` (API key and model name)."""
        self.client = Anthropic(api_key=settings.ANTHROPIC_API_KEY)
        self.model = settings.MODEL_NAME
        self.name = "SQL Generator"

    def execute(self, state: SQLAgentState) -> SQLAgentState:
        """
        Generate a SQL query for ``state["query"]`` using the LLM.

        Reads ``query`` and ``schema_info``; writes ``generated_sql``
        (formatted with sqlparse), ``explanation``, ``tables_used`` and
        ``current_tool``. On any failure ``generated_sql`` is set to "" and
        the error text is stored in ``execution_error``.

        The returned ``messages`` list holds only this call's entries:
        ``SQLAgentState.messages`` is merged with ``operator.add``, so
        returning the accumulated list would duplicate the trace.
        """
        logger.info(f"[{self.name}] Generating SQL query")
        state["current_tool"] = self.name
        new_messages: List[str] = []

        try:
            schema_context = self._build_schema_context(state["schema_info"])
            user_prompt = (
                "Generate a SQL query for this question:\n"
                f"Question: {state['query']}\n"
                "Available Schema:\n"
                f"{schema_context}\n"
                "Return the JSON now."
            )

            response = self.client.messages.create(
                model=self.model,
                max_tokens=2000,
                temperature=0.3,
                system=self.SYSTEM_PROMPT,
                messages=[{"role": "user", "content": user_prompt}],
            )
            result = self._parse_json_response(response.content[0].text)

            # Format SQL nicely
            state["generated_sql"] = sqlparse.format(
                result["query"],
                reindent=True,
                keyword_case="upper",
            )
            state["explanation"] = result.get("explanation", "")
            state["tables_used"] = result.get("tables_used", [])

            msg = f"✅ Generated SQL using {len(state['tables_used'])} tables"
            new_messages.append(f"[{self.name}] {msg}")
            logger.info(msg)
        except Exception as e:
            msg = f"❌ SQL generation failed: {e}"
            new_messages.append(f"[{self.name}] {msg}")
            logger.error(msg)
            state["generated_sql"] = ""
            state["execution_error"] = str(e)

        state["messages"] = new_messages
        return state

    @staticmethod
    def _parse_json_response(text: str) -> Any:
        """
        Parse the JSON object from an LLM reply.

        Strips surrounding whitespace and an optional Markdown code fence
        (```json ... ``` or ``` ... ```). If the remainder still is not valid
        JSON, falls back to the span between the first "{" and the last "}"
        so a sentence of prose around the object does not break parsing.

        Raises:
            json.JSONDecodeError: if no parseable JSON is found.
        """
        cleaned = text.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[len("```json"):]
        elif cleaned.startswith("```"):
            cleaned = cleaned[len("```"):]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-len("```")]
        cleaned = cleaned.strip()

        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            start, end = cleaned.find("{"), cleaned.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(cleaned[start:end + 1])

    def _build_schema_context(self, schema_info: Dict[str, Any]) -> str:
        """
        Render ``schema_info`` as plain text for the LLM prompt.

        Each table becomes its own section (preceded by a blank line, so the
        text starts with a newline) listing columns with type, NULL/NOT NULL
        and a primary-key marker, followed by any foreign keys. Reads
        ``columns`` (dicts with ``name``, ``type``, ``nullable``) and the
        optional ``primary_keys`` / ``foreign_keys`` of each table entry.
        """
        lines: List[str] = []
        for table_name, table_data in schema_info.items():
            # Leading "\n" puts a blank line between tables once joined.
            lines.append(f"\nTable: {table_name}")
            lines.append("Columns:")
            primary_keys = table_data.get("primary_keys") or []
            for col in table_data["columns"]:
                nullable = "NULL" if col["nullable"] else "NOT NULL"
                pk = " (PRIMARY KEY)" if col["name"] in primary_keys else ""
                lines.append(f" - {col['name']}: {col['type']} {nullable}{pk}")

            foreign_keys = table_data.get("foreign_keys")
            if foreign_keys:
                lines.append("Foreign Keys:")
                for fk in foreign_keys:
                    lines.append(
                        f" - {', '.join(fk['constrained_columns'])} -> "
                        f"{fk['referred_table']}.{', '.join(fk['referred_columns'])}"
                    )
        return "\n".join(lines)
# ============================================================================
# TOOL 3: SQL VALIDATOR
# ============================================================================
class SQLValidatorTool:
    """
    Tool for validating SQL queries.

    From workflow diagram: "Validates SQL and fixes errors if needed"
    """

    # Instructions sent as the Anthropic ``system`` prompt for fix attempts.
    FIX_SYSTEM_PROMPT = (
        "You are an expert at fixing SQL errors.\n"
        "Analyze the error and correct the SQL query.\n"
        "Return ONLY JSON (no markdown):\n"
        "{\n"
        '"query": "CORRECTED SQL",\n'
        '"explanation": "What was fixed"\n'
        "}"
    )

    def __init__(self, db_manager: DatabaseManager, llm_client: Anthropic):
        """
        Args:
            db_manager: Provides ``validate_sql(sql)``, expected to return
                an ``(is_valid, errors)`` pair.
            llm_client: Anthropic client used for fix attempts.
        """
        self.db_manager = db_manager
        self.client = llm_client
        self.model = settings.MODEL_NAME
        self.name = "SQL Validator"

    def execute(self, state: SQLAgentState) -> SQLAgentState:
        """
        Validate ``state["generated_sql"]`` and, if invalid, attempt one LLM fix.

        A fix is attempted only while ``validation_attempts`` is below
        ``max_iterations`` (default 3); each attempt increments the counter.
        After a successful fix the corrected SQL is validated again, so on
        return ``is_valid`` / ``validation_errors`` always describe the SQL
        currently in ``generated_sql`` (the executor gates on ``is_valid``).
        An exception from ``validate_sql`` is treated as a validation failure
        whose text becomes the error, so it also goes through the bounded fix
        path instead of leaving the attempt counter untouched.

        The returned ``messages`` list holds only this call's entries:
        ``SQLAgentState.messages`` is merged with ``operator.add``, so
        returning the accumulated list would duplicate the trace.
        """
        logger.info(f"[{self.name}] Validating SQL")
        state["current_tool"] = self.name
        new_messages: List[str] = []

        if not state.get("generated_sql"):
            state["is_valid"] = False
            state["validation_errors"] = ["No SQL query to validate"]
            new_messages.append(f"[{self.name}] ⚠️ No SQL query to validate")
            state["messages"] = new_messages
            return state

        is_valid, errors = self._validate(state["generated_sql"], new_messages)
        state["is_valid"] = is_valid
        state["validation_errors"] = errors

        attempts = state.get("validation_attempts", 0)
        if not is_valid and attempts < state.get("max_iterations", 3):
            state["validation_attempts"] = attempts + 1
            if self._fix_sql(state, new_messages):
                # Re-check the corrected SQL so validity reflects what the
                # executor would actually run (otherwise the last fix is lost).
                state["is_valid"], state["validation_errors"] = self._validate(
                    state["generated_sql"], new_messages
                )

        state["messages"] = new_messages
        return state

    def _validate(self, sql: str, messages: List[str]) -> tuple:
        """
        Run ``db_manager.validate_sql`` and record the outcome in ``messages``.

        Returns:
            ``(is_valid, errors)``; an exception raised by the validator is
            converted to ``(False, [str(exc)])``.
        """
        try:
            is_valid, errors = self.db_manager.validate_sql(sql)
        except Exception as e:
            msg = f"❌ Validation error: {e}"
            messages.append(f"[{self.name}] {msg}")
            logger.error(msg)
            return False, [str(e)]

        errors = errors or []
        if is_valid:
            msg = "✅ SQL is valid"
            messages.append(f"[{self.name}] {msg}")
            logger.info(msg)
        else:
            msg = f"⚠️ SQL validation failed: {errors}"
            messages.append(f"[{self.name}] {msg}")
            logger.warning(msg)
        return is_valid, errors

    def _fix_sql(self, state: SQLAgentState, messages: List[str]) -> bool:
        """
        Ask the LLM to correct ``state["generated_sql"]`` in place.

        The prompt includes the original question so the fix keeps the
        query's intent, plus the validation errors and the schema.

        Returns:
            True if a non-empty corrected query was written to
            ``generated_sql``; False (SQL left untouched) if the LLM call,
            parsing or formatting failed.
        """
        logger.info(f"[{self.name}] Attempting to fix SQL (attempt {state['validation_attempts']})")
        try:
            schema_context = self._build_schema_context(state["schema_info"])
            errors_text = "\n".join(str(err) for err in state["validation_errors"])
            user_prompt = (
                "Fix this SQL query:\n"
                f"Question: {state.get('query', '')}\n"
                "Original Query:\n"
                f"{state['generated_sql']}\n"
                "Errors:\n"
                f"{errors_text}\n"
                "Available Schema:\n"
                f"{schema_context}\n"
                "Return corrected SQL in JSON format."
            )

            response = self.client.messages.create(
                model=self.model,
                max_tokens=2000,
                temperature=0.3,
                system=self.FIX_SYSTEM_PROMPT,
                messages=[{"role": "user", "content": user_prompt}],
            )
            result = self._parse_json_response(response.content[0].text)

            fixed_sql = sqlparse.format(
                result["query"],
                reindent=True,
                keyword_case="upper",
            )
            if not fixed_sql.strip():
                raise ValueError("LLM returned an empty query")
            state["generated_sql"] = fixed_sql

            msg = f"🔧 Fixed SQL: {result.get('explanation', 'Applied corrections')}"
            messages.append(f"[{self.name}] {msg}")
            logger.info(msg)
            return True
        except Exception as e:
            msg = f"❌ Fix attempt failed: {e}"
            messages.append(f"[{self.name}] {msg}")
            logger.error(msg)
            return False

    @staticmethod
    def _parse_json_response(text: str) -> Any:
        """
        Parse the JSON object from an LLM reply.

        Strips surrounding whitespace and an optional Markdown code fence
        (```json ... ``` or ``` ... ```). If the remainder still is not valid
        JSON, falls back to the span between the first "{" and the last "}".

        Raises:
            json.JSONDecodeError: if no parseable JSON is found.
        """
        cleaned = text.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[len("```json"):]
        elif cleaned.startswith("```"):
            cleaned = cleaned[len("```"):]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-len("```")]
        cleaned = cleaned.strip()

        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            start, end = cleaned.find("{"), cleaned.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(cleaned[start:end + 1])

    def _build_schema_context(self, schema_info: Dict[str, Any]) -> str:
        """
        Render ``schema_info`` as plain text for the fix prompt.

        Same layout the generator uses: one section per table (preceded by a
        blank line) with column type, NULL/NOT NULL, primary-key marker and
        foreign keys, so the fixer sees the join information too.
        """
        lines: List[str] = []
        for table_name, table_data in schema_info.items():
            lines.append(f"\nTable: {table_name}")
            lines.append("Columns:")
            primary_keys = table_data.get("primary_keys") or []
            for col in table_data["columns"]:
                nullable = "NULL" if col["nullable"] else "NOT NULL"
                pk = " (PRIMARY KEY)" if col["name"] in primary_keys else ""
                lines.append(f" - {col['name']}: {col['type']} {nullable}{pk}")

            foreign_keys = table_data.get("foreign_keys")
            if foreign_keys:
                lines.append("Foreign Keys:")
                for fk in foreign_keys:
                    lines.append(
                        f" - {', '.join(fk['constrained_columns'])} -> "
                        f"{fk['referred_table']}.{', '.join(fk['referred_columns'])}"
                    )
        return "\n".join(lines)
# ============================================================================
# TOOL 4: SQL EXECUTOR
# ============================================================================
class SQLExecutorTool:
    """
    Tool for executing SQL queries.

    From workflow diagram: "Executes validated SQL and returns results"
    """

    def __init__(self, db_manager: DatabaseManager):
        """
        Args:
            db_manager: Runs queries via ``execute_query(query=..., timeout=...)``.
        """
        self.db_manager = db_manager
        self.name = "SQL Executor"

    def execute(self, state: SQLAgentState) -> SQLAgentState:
        """
        Execute ``state["generated_sql"]`` if the validator marked it valid.

        Writes ``execution_success``, ``result_data``, ``result_columns``,
        ``row_count``, ``execution_time``, ``execution_error`` and
        ``current_tool``. Invalid SQL is not executed; in that case an
        ``execution_error`` recorded earlier (e.g. by a failed generation
        step) is kept, since it is the root cause, and otherwise the generic
        "SQL validation failed" is used.

        The returned ``messages`` list holds only this call's entries:
        ``SQLAgentState.messages`` is merged with ``operator.add``, so
        returning the accumulated list would duplicate the trace.
        """
        logger.info(f"[{self.name}] Executing SQL")
        state["current_tool"] = self.name
        new_messages: List[str] = []

        # Check if SQL is valid before executing
        if not state.get("is_valid", False):
            msg = "❌ Cannot execute invalid SQL"
            new_messages.append(f"[{self.name}] {msg}")
            logger.warning(msg)
            state["execution_success"] = False
            state["execution_error"] = state.get("execution_error") or "SQL validation failed"
            state["messages"] = new_messages
            return state

        try:
            result = self.db_manager.execute_query(
                query=state["generated_sql"],
                timeout=settings.SQL_TIMEOUT,
            )

            # `or` defaults also cover keys that are present but None.
            state["execution_success"] = result["success"]
            state["result_data"] = result.get("data") or []
            state["result_columns"] = result.get("columns") or []
            state["row_count"] = result.get("row_count") or 0
            state["execution_time"] = result.get("execution_time") or 0.0
            state["execution_error"] = result.get("error") or ""

            if result["success"]:
                msg = f"✅ Executed successfully: {state['row_count']} rows in {state['execution_time']:.3f}s"
                logger.info(msg)
            else:
                msg = f"❌ Execution failed: {state['execution_error']}"
                logger.error(msg)
            new_messages.append(f"[{self.name}] {msg}")
        except Exception as e:
            msg = f"❌ Execution error: {e}"
            new_messages.append(f"[{self.name}] {msg}")
            logger.error(msg)
            state["execution_success"] = False
            state["execution_error"] = str(e)

        state["messages"] = new_messages
        return state
# ============================================================================
# LANGGRAPH WORKFLOW
# ============================================================================
def create_sql_agent_workflow(db_manager: DatabaseManager) -> StateGraph:
    """
    Create LangGraph workflow with 4 tools as separate nodes.

    Flow: schema_retrieval -> sql_generator -> sql_validator
    -> (sql_validator again while a fix attempt may still help)
    -> sql_executor -> END.

    Args:
        db_manager: Shared by the schema, validator and executor tools.

    Returns:
        The result of ``workflow.compile()`` — a compiled, invokable graph,
        not the ``StateGraph`` builder named in the return annotation.
    """
    # Initialize tools
    schema_tool = SchemaGlossaryRetrievalTool(db_manager)
    generator_tool = SQLGeneratorTool()
    # The validator reuses the generator's Anthropic client for its fix prompts.
    validator_tool = SQLValidatorTool(db_manager, generator_tool.client)
    executor_tool = SQLExecutorTool(db_manager)

    # Create graph
    workflow = StateGraph(SQLAgentState)

    # Add nodes (one for each tool)
    workflow.add_node("schema_retrieval", schema_tool.execute)
    workflow.add_node("sql_generator", generator_tool.execute)
    workflow.add_node("sql_validator", validator_tool.execute)
    workflow.add_node("sql_executor", executor_tool.execute)

    # Define flow
    workflow.set_entry_point("schema_retrieval")
    workflow.add_edge("schema_retrieval", "sql_generator")
    workflow.add_edge("sql_generator", "sql_validator")

    def should_retry_or_execute(state: SQLAgentState) -> str:
        """Decide whether to re-run the validator (after a fix attempt) or proceed to execution."""
        if state.get("is_valid"):
            return "execute"

        # No SQL means generation failed. The validator has nothing to fix and
        # does not advance validation_attempts in that case, so routing back to
        # it would loop until LangGraph's recursion limit aborts the run.
        if not state.get("generated_sql"):
            return "execute"

        # Retry if we haven't exceeded max attempts
        if state.get("validation_attempts", 0) < state.get("max_iterations", 3):
            return "retry"

        # Give up; the executor refuses invalid SQL and records the failure.
        return "execute"

    workflow.add_conditional_edges(
        "sql_validator",
        should_retry_or_execute,
        {
            "retry": "sql_validator",  # Re-validate (after fix attempt)
            "execute": "sql_executor",
        },
    )
    workflow.add_edge("sql_executor", END)

    return workflow.compile()
# ============================================================================
# MAIN DEMO
# ============================================================================
def main():
    """
    Demonstrate the 4-tool LangGraph architecture.

    Connects to the database (creating sample data when it has no tables),
    builds the workflow, runs one example question through it and prints
    the generated SQL, the validation/execution outcome and the workflow
    trace.
    """
    print("=" * 80)
    print("SQL AGENT WITH LANGGRAPH - 4 TOOLS AS SEPARATE NODES")
    print("=" * 80)

    # Initialize database
    print("\n📚 Step 1: Initialize Database")
    db_manager = DatabaseManager()
    if not db_manager.test_connection():
        print("❌ Database connection failed")
        return

    # Create sample data if needed
    tables = db_manager.get_all_tables()
    if not tables:
        print("Creating sample database...")
        db_manager.create_sample_database()
        tables = db_manager.get_all_tables()
    print(f"✅ Connected to database with {len(tables)} tables")

    # Create workflow
    print("\n🔧 Step 2: Create LangGraph Workflow")
    workflow = create_sql_agent_workflow(db_manager)
    print("✅ Workflow created with 4 tool nodes:")
    print(" 1️⃣ Schema/Glossary Retrieval Tool")
    print(" 2️⃣ SQL Generator Tool")
    print(" 3️⃣ SQL Validator Tool")
    print(" 4️⃣ SQL Executor Tool")
    print("\n Flow: 1 → 2 → 3 → (retry if invalid) → 4")

    # Process query
    print("\n🚀 Step 3: Process Query")
    query = "Show all customers from California"
    print(f" Query: '{query}'")

    # Seed every SQLAgentState key so nodes and the report below can index directly.
    initial_state = {
        "query": query,
        "schema_info": {},
        "glossary_terms": [],
        "generated_sql": "",
        "explanation": "",
        "tables_used": [],
        "is_valid": False,
        "validation_errors": [],
        "validation_attempts": 0,
        "execution_success": False,
        "result_data": [],
        "result_columns": [],
        "row_count": 0,
        "execution_time": 0.0,
        "execution_error": "",
        "current_tool": "",
        "iteration": 0,
        "max_iterations": 3,
        "messages": [],
    }

    # Execute workflow
    print("\n⏳ Executing workflow...")
    print("-" * 80)
    final_state = workflow.invoke(initial_state)
    print("-" * 80)

    # Display results
    print("\n📊 RESULTS")
    print("=" * 80)
    print("\n📝 Generated SQL:")
    print(final_state["generated_sql"])
    print(f"\nExplanation: {final_state['explanation']}")
    # tables_used comes from the LLM's JSON, so don't assume every item is a str.
    print(f"Tables used: {', '.join(map(str, final_state['tables_used']))}")

    print(f"\n✓ Validation: {'✅ Valid' if final_state['is_valid'] else '❌ Invalid'}")
    if final_state["validation_errors"]:
        print(f" Errors: {final_state['validation_errors']}")
    print(f" Validation attempts: {final_state['validation_attempts']}")

    print(f"\n⚡ Execution: {'✅ Success' if final_state['execution_success'] else '❌ Failed'}")
    if final_state["execution_success"]:
        print(f" Rows: {final_state['row_count']}")
        print(f" Time: {final_state['execution_time']:.3f}s")
        if final_state["result_data"]:
            print("\n Sample row:")
            for key, value in final_state["result_data"][0].items():
                print(f" {key}: {value}")
    else:
        print(f" Error: {final_state['execution_error']}")

    print("\n📜 Workflow Trace:")
    for i, msg in enumerate(final_state["messages"], 1):
        print(f" {i}. {msg}")

    print("\n" + "=" * 80)
    print("✅ Workflow completed!")
    print("=" * 80)
if __name__ == "__main__":
    try:
        main()
    except ImportError as e:
        # NOTE: langgraph, anthropic, sqlparse, etc. are imported at module top
        # level, so a missing package fails before this handler exists; this
        # branch only catches ImportErrors raised while main() itself runs.
        print(f"❌ Missing dependency: {e}")
        print("\nInstall LangGraph:")
        print(" pip install langgraph langchain-core")
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
![]() |
Notes is a web-based application for taking notes online. You can take notes and share them with other people. If you like taking long notes, notes.io is designed for you. To date, over 8,000,000,000 notes have been created, and counting...
With notes.io;
- You can take a note from anywhere, on any device with an internet connection.
- You can share notes on social platforms (YouTube, Facebook, Twitter, Instagram, etc.).
- You can quickly share your content without a website, blog, or e-mail.
- You don't need to create an account to share a note. You can share quick, easy short links to your notes via SMS, websites, e-mail, or messaging services (WhatsApp, iMessage, Telegram, Signal).
- Notes.io's infrastructure is designed around short links, so you can share a note as an easy, understandable link.
Fast: Notes.io is built for speed and performance. You can take notes quickly and browse your archive.
Easy: Notes.io doesn't require installation. Just write and share a note!
Short: Notes.io URLs are just 8 characters long. You'll get a shortened link to your note when you want to share it. (Ex: notes.io/q )
Free: Notes.io has been running for 14 years and has been free since the day it started.
You can create your first note immediately and start sharing it with anyone you wish. If you want to contact us, you can use the following communication channels:
Email: [email protected]
Twitter: http://twitter.com/notesio
Instagram: http://instagram.com/notes.io
Facebook: http://facebook.com/notesio
Regards;
Notes.io Team
