first commit
This commit is contained in:
516
app/agent.py
Normal file
516
app/agent.py
Normal file
@@ -0,0 +1,516 @@
|
||||
"""
|
||||
Agentic Address Verification Orchestrator.
|
||||
|
||||
This module implements a lightweight ReAct-inspired autonomous agent for batch
|
||||
address verification. The agent follows a structured workflow:
|
||||
|
||||
1. PLANNING PHASE: Query records needing verification
|
||||
2. EXECUTION PHASE: For each record, follow think-act-observe-reflect cycle
|
||||
- If geocoding fails, attempt fuzzy matching to correct misspellings
|
||||
- Retry geocoding with corrected address
|
||||
3. REFLECTION PHASE: Summarize batch results and statistics
|
||||
|
||||
The agent is designed for resilience - individual record failures don't stop
|
||||
the batch, and progress is committed incrementally.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, date
|
||||
from typing import List, Optional
|
||||
|
||||
from geopy.geocoders import Nominatim
|
||||
from sqlalchemy import or_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import (
|
||||
BATCH_SIZE,
|
||||
COMMIT_BATCH_SIZE,
|
||||
NOMINATIM_USER_AGENT,
|
||||
)
|
||||
from app.models import CustomerCustomer
|
||||
from app.tools import (
|
||||
build_address,
|
||||
validate_address_components,
|
||||
format_address_string,
|
||||
geocode_address,
|
||||
validate_geocode_result,
|
||||
update_record,
|
||||
rate_limit_sleep,
|
||||
GeocodeResult,
|
||||
get_state_abbreviation,
|
||||
)
|
||||
from app.streets import correct_address, get_town_street_count
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BatchStats:
    """Statistics accumulated over a single batch verification run."""
    total_queried: int = 0
    processed: int = 0
    updated: int = 0
    corrected: int = 0  # Addresses repaired via fuzzy street matching
    failed: int = 0
    skipped: int = 0
    rate_limited: int = 0
    errors: List[str] = field(default_factory=list)
    corrections: List[str] = field(default_factory=list)  # Human-readable log of corrections
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None

    @property
    def duration_seconds(self) -> float:
        """Elapsed wall-clock time of the batch in seconds (0.0 until both timestamps are set)."""
        if self.start_time is None or self.end_time is None:
            return 0.0
        elapsed = self.end_time - self.start_time
        return elapsed.total_seconds()

    def to_dict(self) -> dict:
        """Serialize the statistics into a JSON-friendly dictionary."""
        summary = {
            "total_queried": self.total_queried,
            "processed": self.processed,
            "updated": self.updated,
            "corrected": self.corrected,
            "failed": self.failed,
            "skipped": self.skipped,
            "rate_limited": self.rate_limited,
            "duration_seconds": round(self.duration_seconds, 2),
            "errors_count": len(self.errors),
            # Slicing an empty list yields [], so no emptiness guard is needed.
            "sample_errors": self.errors[:5],
            "sample_corrections": self.corrections[:5],
        }
        return summary
|
||||
|
||||
|
||||
class AddressVerificationAgent:
    """
    Lightweight autonomous agent for address verification.

    Implements a ReAct-inspired workflow where each record goes through:
    - OBSERVE: Examine the address data
    - THINK: Decide if geocoding should be attempted
    - ACT: Call geocoding API
    - OBSERVE: Examine the result
    - REFLECT: Log decision and update database

    Attributes:
        session: SQLAlchemy database session
        batch_size: Maximum records per batch
        commit_size: Records between commits
        stats: Running statistics for the batch
        geocoder: Nominatim geocoder instance
    """

    def __init__(
        self,
        session: Session,
        batch_size: int = BATCH_SIZE,
        commit_size: int = COMMIT_BATCH_SIZE,
    ):
        """
        Initialize the address verification agent.

        Args:
            session: SQLAlchemy session for database operations
            batch_size: Max records to process (default from config)
            commit_size: Records before intermediate commit
        """
        self.session = session
        self.batch_size = batch_size
        self.commit_size = commit_size
        self.stats = BatchStats()
        self.geocoder = Nominatim(user_agent=NOMINATIM_USER_AGENT)

        logger.info(
            f"Agent initialized: batch_size={batch_size}, commit_size={commit_size}"
        )

    # =========================================================================
    # PHASE 1: PLANNING
    # =========================================================================

    def plan_batch(self) -> List[CustomerCustomer]:
        """
        PLANNING PHASE: Query records that need address verification.

        Criteria for selection:
        - correct_address = FALSE, OR
        - verified_at IS NULL, OR
        - verified_at < today (not verified today)

        Returns:
            List of CustomerCustomer records to process
        """
        logger.info("=" * 60)
        logger.info("PLANNING PHASE: Querying records needing verification")
        logger.info("=" * 60)

        today = date.today()

        # Build query for records needing verification
        query = self.session.query(CustomerCustomer).filter(
            or_(
                CustomerCustomer.correct_address == False,  # noqa: E712
                CustomerCustomer.verified_at.is_(None),
                func.date(CustomerCustomer.verified_at) < today,
            )
        ).limit(self.batch_size)

        records = query.all()
        self.stats.total_queried = len(records)

        logger.info(
            f"PLAN RESULT: Found {len(records)} records needing verification",
            extra={"record_count": len(records), "batch_limit": self.batch_size}
        )

        # Log sample of record IDs for debugging
        if records:
            sample_ids = [r.id for r in records[:10]]
            logger.debug(f"Sample record IDs: {sample_ids}")

        return records

    # =========================================================================
    # PHASE 2: EXECUTION (ReAct-style per record)
    # =========================================================================

    def process_record(self, customer: CustomerCustomer) -> bool:
        """
        EXECUTION PHASE: Process a single record with ReAct-style workflow.

        Steps:
        1. OBSERVE: Build address from record components
        2. THINK: Validate address - skip if obviously invalid
        3. ACT: Call Nominatim geocoder
        4. OBSERVE: Examine geocoding result
           4a. On failure, attempt fuzzy street correction and retry once
        5. REFLECT: Log decision and update database

        Args:
            customer: CustomerCustomer record to process

        Returns:
            True if record was successfully updated, False otherwise
        """
        logger.info("-" * 40)
        logger.info(f"Processing record ID: {customer.id}")

        # -----------------------------------------------------------------
        # STEP 1: OBSERVE - Build address from components
        # -----------------------------------------------------------------
        logger.debug(f"[OBSERVE] Building address for customer {customer.id}")
        address_components = build_address(customer)

        # -----------------------------------------------------------------
        # STEP 2: THINK - Validate address components
        # -----------------------------------------------------------------
        logger.debug("[THINK] Validating address components")
        address_components = validate_address_components(address_components)

        if not address_components.is_valid:
            # REFLECT: Skip invalid addresses
            logger.info(
                f"[REFLECT] Skipping record {customer.id}: "
                f"{address_components.validation_error}"
            )
            self.stats.skipped += 1

            # Still update the record to mark it as processed
            geocode_result = GeocodeResult(
                success=False,
                skipped=True,
                skip_reason=address_components.validation_error,
                error_message=address_components.validation_error,
            )
            update_record(self.session, customer, geocode_result, is_valid=False)
            return False

        # Format address for geocoding
        address_string = format_address_string(address_components)
        logger.debug(f"[THINK] Formatted address: {address_string}")

        # -----------------------------------------------------------------
        # STEP 3: ACT - Call geocoding API
        # -----------------------------------------------------------------
        logger.debug("[ACT] Calling Nominatim geocoder")
        geocode_result = geocode_address(address_string, self.geocoder)

        # -----------------------------------------------------------------
        # STEP 4: OBSERVE - Examine geocoding result
        # -----------------------------------------------------------------
        logger.debug(f"[OBSERVE] Geocoding result: success={geocode_result.success}")

        if not geocode_result.success:
            # -----------------------------------------------------------------
            # STEP 4a: THINK - Try fuzzy matching to correct address
            # -----------------------------------------------------------------
            logger.info(
                "[THINK] Geocoding failed, attempting fuzzy street matching..."
            )

            # Get state abbreviation for fuzzy matching
            state_abbr = get_state_abbreviation(customer.customer_state)
            town = address_components.city

            if state_abbr and town:
                # Check if we have street data for this town
                street_count = get_town_street_count(self.session, town, state_abbr)

                if street_count > 0:
                    # Try to correct the address
                    match = correct_address(
                        session=self.session,
                        full_address=address_components.street or "",
                        town=town,
                        state=state_abbr,
                        min_confidence=75.0,
                    )

                    if match and match.corrected_address:
                        # BUGFIX: capture the original street BEFORE mutating
                        # address_components. The previous code assigned the
                        # components to an alias and overwrote .street first,
                        # so the correction log showed the corrected value on
                        # both sides ('new' -> 'new') instead of old -> new.
                        original_street = address_components.street

                        logger.info(
                            f"[ACT] Found correction: '{original_street}' "
                            f"-> '{match.corrected_address}' "
                            f"(confidence: {match.confidence_score:.1f}%)"
                        )

                        # Build corrected address string
                        address_components.street = match.corrected_address
                        corrected_address_string = format_address_string(address_components)

                        logger.info(f"[ACT] Retrying with corrected address: {corrected_address_string}")

                        # Rate limit before retry
                        rate_limit_sleep()

                        # Retry geocoding with corrected address
                        geocode_result = geocode_address(corrected_address_string, self.geocoder)

                        if geocode_result.success:
                            logger.info(
                                "[OBSERVE] Corrected address geocoded successfully!"
                            )
                            self.stats.corrected += 1
                            self.stats.corrections.append(
                                f"ID {customer.id}: '{original_street}' "
                                f"-> '{match.corrected_address}'"
                            )
                        else:
                            logger.info(
                                "[OBSERVE] Corrected address still failed to geocode"
                            )
                    else:
                        logger.debug(
                            "[THINK] No confident fuzzy match found"
                        )
                else:
                    logger.debug(
                        f"[THINK] No street reference data for {town}, {state_abbr}. "
                        f"Use POST /streets/{town}/{state_abbr} to populate."
                    )

        # If still failed after correction attempt
        if not geocode_result.success:
            # -----------------------------------------------------------------
            # STEP 5a: REFLECT - Handle failed geocoding
            # -----------------------------------------------------------------
            logger.info(
                f"[REFLECT] Geocoding failed for record {customer.id}: "
                f"{geocode_result.error_message}"
            )
            self.stats.failed += 1
            self.stats.errors.append(
                f"ID {customer.id}: {geocode_result.error_message}"
            )

            update_record(self.session, customer, geocode_result, is_valid=False)
            return False

        # Validate geocode result quality
        is_valid, validation_reason = validate_geocode_result(geocode_result)
        logger.debug(f"[OBSERVE] Validation: valid={is_valid}, reason={validation_reason}")

        # -----------------------------------------------------------------
        # STEP 5b: REFLECT - Update database with result
        # -----------------------------------------------------------------
        if is_valid:
            logger.info(
                f"[REFLECT] Success for record {customer.id}: "
                f"lat={geocode_result.latitude}, lon={geocode_result.longitude}"
            )
            self.stats.updated += 1
        else:
            logger.info(
                f"[REFLECT] Invalid result for record {customer.id}: {validation_reason}"
            )
            self.stats.failed += 1
            self.stats.errors.append(f"ID {customer.id}: {validation_reason}")

        update_record(self.session, customer, geocode_result, is_valid=is_valid)
        return is_valid

    def execute_batch(self, records: List[CustomerCustomer]) -> None:
        """
        Execute the batch processing loop with rate limiting.

        Processes records sequentially with proper rate limiting between
        geocoding calls. Commits to database periodically. Individual record
        failures are logged and counted but never abort the batch.

        Args:
            records: List of CustomerCustomer records to process
        """
        logger.info("=" * 60)
        logger.info("EXECUTION PHASE: Processing records")
        logger.info("=" * 60)

        uncommitted_count = 0

        for i, customer in enumerate(records):
            try:
                # Process the record
                self.process_record(customer)
                self.stats.processed += 1
                uncommitted_count += 1

                # Commit in batches
                if uncommitted_count >= self.commit_size:
                    logger.info(f"Committing batch of {uncommitted_count} records")
                    self.session.commit()
                    uncommitted_count = 0

                # Rate limiting (skip on last record)
                if i < len(records) - 1:
                    rate_limit_sleep()

            except Exception as e:
                # Handle unexpected errors - continue processing
                logger.error(
                    f"Unexpected error processing record {customer.id}: {e}",
                    exc_info=True
                )
                self.stats.failed += 1
                self.stats.errors.append(f"ID {customer.id}: Unexpected error: {str(e)}")
                self.stats.processed += 1

                # Rollback the current transaction and continue.
                # NOTE(review): this also discards any uncommitted updates from
                # earlier records in this commit window; they will be re-selected
                # on the next run since their verified_at was not persisted.
                self.session.rollback()
                uncommitted_count = 0

        # Final commit for any remaining records
        if uncommitted_count > 0:
            logger.info(f"Final commit of {uncommitted_count} records")
            self.session.commit()

    # =========================================================================
    # PHASE 3: REFLECTION
    # =========================================================================

    def reflect(self) -> dict:
        """
        REFLECTION PHASE: Summarize batch results and statistics.

        Logs comprehensive statistics about the batch run and returns
        a summary dictionary suitable for API response.

        Returns:
            Dictionary with batch statistics (includes 'success_rate')
        """
        self.stats.end_time = datetime.utcnow()

        logger.info("=" * 60)
        logger.info("REFLECTION PHASE: Batch Summary")
        logger.info("=" * 60)

        stats_dict = self.stats.to_dict()

        logger.info(f"Total queried: {stats_dict['total_queried']}")
        logger.info(f"Processed: {stats_dict['processed']}")
        logger.info(f"Updated (valid): {stats_dict['updated']}")
        logger.info(f"Corrected: {stats_dict['corrected']}")
        logger.info(f"Failed: {stats_dict['failed']}")
        logger.info(f"Skipped: {stats_dict['skipped']}")
        logger.info(f"Duration: {stats_dict['duration_seconds']}s")

        if stats_dict['errors_count'] > 0:
            logger.warning(f"Errors encountered: {stats_dict['errors_count']}")
            for error in stats_dict['sample_errors']:
                logger.warning(f"  - {error}")

        if stats_dict['corrected'] > 0:
            logger.info(f"Addresses corrected via fuzzy matching: {stats_dict['corrected']}")
            for correction in stats_dict['sample_corrections']:
                logger.info(f"  - {correction}")

        # Calculate success rate
        if stats_dict['processed'] > 0:
            success_rate = (stats_dict['updated'] / stats_dict['processed']) * 100
            logger.info(f"Success rate: {success_rate:.1f}%")
            stats_dict['success_rate'] = round(success_rate, 1)
        else:
            stats_dict['success_rate'] = 0.0

        logger.info("=" * 60)

        return stats_dict

    # =========================================================================
    # MAIN ENTRY POINT
    # =========================================================================

    def run(self) -> dict:
        """
        Execute the full agent workflow.

        Runs through all three phases:
        1. Planning - Query records
        2. Execution - Process each record
        3. Reflection - Summarize results

        Returns:
            Dictionary with batch statistics and message
        """
        logger.info("*" * 60)
        logger.info("ADDRESS VERIFICATION AGENT STARTING")
        logger.info("*" * 60)

        self.stats.start_time = datetime.utcnow()

        try:
            # Phase 1: Planning
            records = self.plan_batch()

            if not records:
                logger.info("No records to process - batch complete")
                self.stats.end_time = datetime.utcnow()
                return {
                    "status": "success",
                    "message": "No records needed verification",
                    **self.stats.to_dict(),
                }

            # Phase 2: Execution
            self.execute_batch(records)

            # Phase 3: Reflection
            stats = self.reflect()

            logger.info("*" * 60)
            logger.info("ADDRESS VERIFICATION AGENT COMPLETE")
            logger.info("*" * 60)

            return {
                "status": "success",
                "message": f"Batch complete: {stats['updated']} addresses updated",
                **stats,
            }

        except Exception as e:
            logger.error(f"Agent failed with error: {e}", exc_info=True)
            self.stats.end_time = datetime.utcnow()
            return {
                "status": "error",
                "message": f"Agent failed: {str(e)}",
                **self.stats.to_dict(),
            }
|
||||
Reference in New Issue
Block a user