"""
Agentic Address Verification Orchestrator.

This module implements a lightweight ReAct-inspired autonomous agent for batch
address verification. The agent follows a structured workflow:

1. PLANNING PHASE: Query records needing verification
2. EXECUTION PHASE: For each record, follow think-act-observe-reflect cycle
   - If geocoding fails, attempt fuzzy matching to correct misspellings
   - Retry geocoding with corrected address
3. REFLECTION PHASE: Summarize batch results and statistics

The agent is designed for resilience - individual record failures don't stop
the batch, and progress is committed incrementally.
"""
|
|
|
|
import logging
from dataclasses import dataclass, field
from datetime import date, datetime, timezone
from typing import List, Optional

from geopy.geocoders import Nominatim
from sqlalchemy import func, or_
from sqlalchemy.orm import Session

from app.config import (
    BATCH_SIZE,
    COMMIT_BATCH_SIZE,
    NOMINATIM_USER_AGENT,
)
from app.models import CustomerCustomer
from app.streets import correct_address, get_town_street_count
from app.tools import (
    GeocodeResult,
    build_address,
    format_address_string,
    geocode_address,
    get_state_abbreviation,
    rate_limit_sleep,
    update_record,
    validate_address_components,
    validate_geocode_result,
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class BatchStats:
    """Aggregated counters and timing for one batch verification run."""

    total_queried: int = 0   # records returned by the planning query
    processed: int = 0       # records attempted (success or failure)
    updated: int = 0         # records geocoded and validated successfully
    corrected: int = 0       # addresses fixed via fuzzy matching
    failed: int = 0          # records that failed geocoding or validation
    skipped: int = 0         # records skipped due to invalid components
    rate_limited: int = 0    # geocoder rate-limit events observed
    errors: List[str] = field(default_factory=list)       # per-record error log
    corrections: List[str] = field(default_factory=list)  # log of corrections made
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None

    @property
    def duration_seconds(self) -> float:
        """Elapsed seconds between start and end; 0.0 when either is unset."""
        if self.start_time is None or self.end_time is None:
            return 0.0
        delta = self.end_time - self.start_time
        return delta.total_seconds()

    def to_dict(self) -> dict:
        """Serialize the stats into a JSON-friendly dictionary."""
        # Only the first five errors/corrections are sampled for the response.
        sample_errors = self.errors[:5] if self.errors else []
        sample_corrections = self.corrections[:5] if self.corrections else []
        return {
            "total_queried": self.total_queried,
            "processed": self.processed,
            "updated": self.updated,
            "corrected": self.corrected,
            "failed": self.failed,
            "skipped": self.skipped,
            "rate_limited": self.rate_limited,
            "duration_seconds": round(self.duration_seconds, 2),
            "errors_count": len(self.errors),
            "sample_errors": sample_errors,
            "sample_corrections": sample_corrections,
        }
|
|
|
|
|
|
class AddressVerificationAgent:
    """
    Lightweight autonomous agent for address verification.

    Implements a ReAct-inspired workflow where each record goes through:
    - OBSERVE: Examine the address data
    - THINK: Decide if geocoding should be attempted
    - ACT: Call geocoding API
    - OBSERVE: Examine the result
    - REFLECT: Log decision and update database

    Attributes:
        session: SQLAlchemy database session
        batch_size: Maximum records per batch
        commit_size: Records between commits
        stats: Running statistics for the batch
        geocoder: Nominatim geocoder instance
    """

    def __init__(
        self,
        session: Session,
        batch_size: int = BATCH_SIZE,
        commit_size: int = COMMIT_BATCH_SIZE,
    ):
        """
        Initialize the address verification agent.

        Args:
            session: SQLAlchemy session for database operations
            batch_size: Max records to process (default from config)
            commit_size: Records before intermediate commit
        """
        self.session = session
        self.batch_size = batch_size
        self.commit_size = commit_size
        self.stats = BatchStats()
        self.geocoder = Nominatim(user_agent=NOMINATIM_USER_AGENT)

        logger.info(
            f"Agent initialized: batch_size={batch_size}, commit_size={commit_size}"
        )

    # =========================================================================
    # PHASE 1: PLANNING
    # =========================================================================

    def plan_batch(self) -> List[CustomerCustomer]:
        """
        PLANNING PHASE: Query records that need address verification.

        Criteria for selection:
        - correct_address = FALSE, OR
        - verified_at IS NULL, OR
        - verified_at < today (not verified today)

        Returns:
            List of CustomerCustomer records to process
        """
        logger.info("=" * 60)
        logger.info("PLANNING PHASE: Querying records needing verification")
        logger.info("=" * 60)

        today = date.today()

        # Build query for records needing verification.
        query = self.session.query(CustomerCustomer).filter(
            or_(
                CustomerCustomer.correct_address == False,  # noqa: E712
                CustomerCustomer.verified_at.is_(None),
                func.date(CustomerCustomer.verified_at) < today,
            )
        ).limit(self.batch_size)

        records = query.all()
        self.stats.total_queried = len(records)

        logger.info(
            f"PLAN RESULT: Found {len(records)} records needing verification",
            extra={"record_count": len(records), "batch_limit": self.batch_size}
        )

        # Log sample of record IDs for debugging.
        if records:
            sample_ids = [r.id for r in records[:10]]
            logger.debug(f"Sample record IDs: {sample_ids}")

        return records

    # =========================================================================
    # PHASE 2: EXECUTION (ReAct-style per record)
    # =========================================================================

    def process_record(self, customer: CustomerCustomer) -> bool:
        """
        EXECUTION PHASE: Process a single record with ReAct-style workflow.

        Steps:
        1. OBSERVE: Build address from record components
        2. THINK: Validate address - skip if obviously invalid
        3. ACT: Call Nominatim geocoder
        4. OBSERVE: Examine geocoding result
        5. REFLECT: Log decision and update database

        Args:
            customer: CustomerCustomer record to process

        Returns:
            True if record was successfully updated, False otherwise
        """
        logger.info("-" * 40)
        logger.info(f"Processing record ID: {customer.id}")

        # -----------------------------------------------------------------
        # STEP 1: OBSERVE - Build address from components
        # -----------------------------------------------------------------
        logger.debug(f"[OBSERVE] Building address for customer {customer.id}")
        address_components = build_address(customer)

        # -----------------------------------------------------------------
        # STEP 2: THINK - Validate address components
        # -----------------------------------------------------------------
        logger.debug("[THINK] Validating address components")
        address_components = validate_address_components(address_components)

        if not address_components.is_valid:
            # REFLECT: Skip invalid addresses.
            logger.info(
                f"[REFLECT] Skipping record {customer.id}: "
                f"{address_components.validation_error}"
            )
            self.stats.skipped += 1

            # Still update the record to mark it as processed.
            geocode_result = GeocodeResult(
                success=False,
                skipped=True,
                skip_reason=address_components.validation_error,
                error_message=address_components.validation_error,
            )
            update_record(self.session, customer, geocode_result, is_valid=False)
            return False

        # Format address for geocoding.
        address_string = format_address_string(address_components)
        logger.debug(f"[THINK] Formatted address: {address_string}")

        # -----------------------------------------------------------------
        # STEP 3: ACT - Call geocoding API
        # -----------------------------------------------------------------
        logger.debug("[ACT] Calling Nominatim geocoder")
        geocode_result = geocode_address(address_string, self.geocoder)

        # -----------------------------------------------------------------
        # STEP 4: OBSERVE - Examine geocoding result
        # -----------------------------------------------------------------
        logger.debug(f"[OBSERVE] Geocoding result: success={geocode_result.success}")

        if not geocode_result.success:
            # -----------------------------------------------------------------
            # STEP 4a: THINK - Try fuzzy matching to correct address
            # -----------------------------------------------------------------
            logger.info(
                "[THINK] Geocoding failed, attempting fuzzy street matching..."
            )

            # Get state abbreviation for fuzzy matching.
            state_abbr = get_state_abbreviation(customer.customer_state)
            town = address_components.city

            if state_abbr and town:
                # Check if we have street data for this town.
                street_count = get_town_street_count(self.session, town, state_abbr)

                if street_count > 0:
                    # Try to correct the address.
                    match = correct_address(
                        session=self.session,
                        full_address=address_components.street or "",
                        town=town,
                        state=state_abbr,
                        min_confidence=75.0,
                    )

                    if match and match.corrected_address:
                        # BUG FIX: capture the street BEFORE it is mutated
                        # below. Previously the corrections log read the
                        # already-corrected street, recording
                        # "corrected -> corrected" instead of the real change.
                        original_street = address_components.street

                        logger.info(
                            f"[ACT] Found correction: '{original_street}' "
                            f"-> '{match.corrected_address}' "
                            f"(confidence: {match.confidence_score:.1f}%)"
                        )

                        # Build corrected address string. This mutates the
                        # shared components object in place (same as the
                        # original aliasing); the pre-correction street is
                        # preserved in original_street above.
                        address_components.street = match.corrected_address
                        corrected_address_string = format_address_string(address_components)

                        logger.info(f"[ACT] Retrying with corrected address: {corrected_address_string}")

                        # Rate limit before retry.
                        rate_limit_sleep()

                        # Retry geocoding with corrected address.
                        geocode_result = geocode_address(corrected_address_string, self.geocoder)

                        if geocode_result.success:
                            logger.info(
                                "[OBSERVE] Corrected address geocoded successfully!"
                            )
                            self.stats.corrected += 1
                            self.stats.corrections.append(
                                f"ID {customer.id}: '{original_street}' "
                                f"-> '{match.corrected_address}'"
                            )
                        else:
                            logger.info(
                                "[OBSERVE] Corrected address still failed to geocode"
                            )
                    else:
                        logger.debug(
                            "[THINK] No confident fuzzy match found"
                        )
                else:
                    logger.debug(
                        f"[THINK] No street reference data for {town}, {state_abbr}. "
                        f"Use POST /streets/{town}/{state_abbr} to populate."
                    )

        # If still failed after correction attempt.
        if not geocode_result.success:
            # -----------------------------------------------------------------
            # STEP 5a: REFLECT - Handle failed geocoding
            # -----------------------------------------------------------------
            logger.info(
                f"[REFLECT] Geocoding failed for record {customer.id}: "
                f"{geocode_result.error_message}"
            )
            self.stats.failed += 1
            self.stats.errors.append(
                f"ID {customer.id}: {geocode_result.error_message}"
            )

            update_record(self.session, customer, geocode_result, is_valid=False)
            return False

        # Validate geocode result quality.
        is_valid, validation_reason = validate_geocode_result(geocode_result)
        logger.debug(f"[OBSERVE] Validation: valid={is_valid}, reason={validation_reason}")

        # -----------------------------------------------------------------
        # STEP 5b: REFLECT - Update database with result
        # -----------------------------------------------------------------
        if is_valid:
            logger.info(
                f"[REFLECT] Success for record {customer.id}: "
                f"lat={geocode_result.latitude}, lon={geocode_result.longitude}"
            )
            self.stats.updated += 1
        else:
            logger.info(
                f"[REFLECT] Invalid result for record {customer.id}: {validation_reason}"
            )
            self.stats.failed += 1
            self.stats.errors.append(f"ID {customer.id}: {validation_reason}")

        update_record(self.session, customer, geocode_result, is_valid=is_valid)
        return is_valid

    def execute_batch(self, records: List[CustomerCustomer]) -> None:
        """
        Execute the batch processing loop with rate limiting.

        Processes records sequentially with proper rate limiting between
        geocoding calls. Commits to database periodically.

        Args:
            records: List of CustomerCustomer records to process
        """
        logger.info("=" * 60)
        logger.info("EXECUTION PHASE: Processing records")
        logger.info("=" * 60)

        uncommitted_count = 0

        for i, customer in enumerate(records):
            try:
                # Process the record.
                self.process_record(customer)
                self.stats.processed += 1
                uncommitted_count += 1

                # Commit in batches.
                if uncommitted_count >= self.commit_size:
                    logger.info(f"Committing batch of {uncommitted_count} records")
                    self.session.commit()
                    uncommitted_count = 0

                # Rate limiting (skip on last record).
                if i < len(records) - 1:
                    rate_limit_sleep()

            except Exception as e:
                # Broad catch is deliberate: one bad record must not abort
                # the whole batch. Log with traceback and keep going.
                logger.error(
                    f"Unexpected error processing record {customer.id}: {e}",
                    exc_info=True
                )
                self.stats.failed += 1
                self.stats.errors.append(f"ID {customer.id}: Unexpected error: {str(e)}")
                self.stats.processed += 1

                # Rollback the current transaction and continue.
                # NOTE(review): rollback also discards uncommitted updates
                # from earlier records in this commit window, so "updated"
                # may slightly overcount the rows actually persisted.
                self.session.rollback()
                uncommitted_count = 0

        # Final commit for any remaining records.
        if uncommitted_count > 0:
            logger.info(f"Final commit of {uncommitted_count} records")
            self.session.commit()

    # =========================================================================
    # PHASE 3: REFLECTION
    # =========================================================================

    def reflect(self) -> dict:
        """
        REFLECTION PHASE: Summarize batch results and statistics.

        Logs comprehensive statistics about the batch run and returns
        a summary dictionary suitable for API response.

        Returns:
            Dictionary with batch statistics
        """
        # datetime.utcnow() is deprecated since 3.12; use an aware UTC time.
        self.stats.end_time = datetime.now(timezone.utc)

        logger.info("=" * 60)
        logger.info("REFLECTION PHASE: Batch Summary")
        logger.info("=" * 60)

        stats_dict = self.stats.to_dict()

        logger.info(f"Total queried: {stats_dict['total_queried']}")
        logger.info(f"Processed: {stats_dict['processed']}")
        logger.info(f"Updated (valid): {stats_dict['updated']}")
        logger.info(f"Corrected: {stats_dict['corrected']}")
        logger.info(f"Failed: {stats_dict['failed']}")
        logger.info(f"Skipped: {stats_dict['skipped']}")
        logger.info(f"Duration: {stats_dict['duration_seconds']}s")

        if stats_dict['errors_count'] > 0:
            logger.warning(f"Errors encountered: {stats_dict['errors_count']}")
            for error in stats_dict['sample_errors']:
                logger.warning(f"  - {error}")

        if stats_dict['corrected'] > 0:
            logger.info(f"Addresses corrected via fuzzy matching: {stats_dict['corrected']}")
            for correction in stats_dict['sample_corrections']:
                logger.info(f"  - {correction}")

        # Calculate success rate.
        if stats_dict['processed'] > 0:
            success_rate = (stats_dict['updated'] / stats_dict['processed']) * 100
            logger.info(f"Success rate: {success_rate:.1f}%")
            stats_dict['success_rate'] = round(success_rate, 1)
        else:
            stats_dict['success_rate'] = 0.0

        logger.info("=" * 60)

        return stats_dict

    # =========================================================================
    # MAIN ENTRY POINT
    # =========================================================================

    def run(self) -> dict:
        """
        Execute the full agent workflow.

        Runs through all three phases:
        1. Planning - Query records
        2. Execution - Process each record
        3. Reflection - Summarize results

        Returns:
            Dictionary with batch statistics and message
        """
        logger.info("*" * 60)
        logger.info("ADDRESS VERIFICATION AGENT STARTING")
        logger.info("*" * 60)

        self.stats.start_time = datetime.now(timezone.utc)

        try:
            # Phase 1: Planning.
            records = self.plan_batch()

            if not records:
                logger.info("No records to process - batch complete")
                self.stats.end_time = datetime.now(timezone.utc)
                return {
                    "status": "success",
                    "message": "No records needed verification",
                    **self.stats.to_dict(),
                }

            # Phase 2: Execution.
            self.execute_batch(records)

            # Phase 3: Reflection.
            stats = self.reflect()

            logger.info("*" * 60)
            logger.info("ADDRESS VERIFICATION AGENT COMPLETE")
            logger.info("*" * 60)

            return {
                "status": "success",
                "message": f"Batch complete: {stats['updated']} addresses updated",
                **stats,
            }

        except Exception as e:
            # Top-level boundary: surface the failure in the API response
            # rather than propagating.
            logger.error(f"Agent failed with error: {e}", exc_info=True)
            self.stats.end_time = datetime.now(timezone.utc)
            return {
                "status": "error",
                "message": f"Agent failed: {str(e)}",
                **self.stats.to_dict(),
            }
|