refactor(auth): migrate to httpOnly cookies and update vendor listings

Migrated JWT authentication from localStorage to httpOnly cookies using axum-extra. Refactored vendor listing and edit pages to use the centralized API client. Updated schema and data models to support these changes.
This commit is contained in:
2026-02-09 16:25:38 -05:00
parent feb1a173ec
commit caa318508b
11 changed files with 612 additions and 195 deletions

163
Cargo.lock generated
View File

@@ -40,6 +40,15 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.21"
@@ -67,6 +76,7 @@ version = "0.1.0"
dependencies = [
"argon2",
"axum",
"axum-extra",
"chrono",
"dotenv",
"hyper",
@@ -74,8 +84,11 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
"time",
"tokio",
"tower-http",
"tracing",
"tracing-subscriber",
]
[[package]]
@@ -166,6 +179,28 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum-extra"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a93e433be9382c737320af3924f7d5fc6f89c155cf2bf88949d8f5126fab283f"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"http",
"http-body",
"mime",
"pin-project-lite",
"serde",
"tokio",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.75"
@@ -277,6 +312,17 @@ dependencies = [
"windows-link",
]
[[package]]
name = "cookie"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7efb37c3e1ccb1ff97164ad95ac1606e8ccd35b3fa0a7d99a304c7f4a428cc24"
dependencies = [
"percent-encoding",
"time",
"version_check",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@@ -826,6 +872,12 @@ dependencies = [
"simple_asn1",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.172"
@@ -864,6 +916,15 @@ version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.7.3"
@@ -928,6 +989,15 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
@@ -1190,6 +1260,23 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "regex-automata"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c"
[[package]]
name = "ring"
version = "0.16.20"
@@ -1350,6 +1437,15 @@ dependencies = [
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.3.0"
@@ -1605,6 +1701,15 @@ dependencies = [
"syn 2.0.102",
]
[[package]]
name = "thread_local"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
dependencies = [
"cfg-if",
]
[[package]]
name = "time"
version = "0.3.41"
@@ -1760,22 +1865,64 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
version = "0.1.41"
version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-core"
version = "0.1.34"
name = "tracing-attributes"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678"
checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.102",
]
[[package]]
name = "tracing-core"
version = "0.1.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
@@ -1858,6 +2005,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "version_check"
version = "0.9.5"

View File

@@ -15,3 +15,7 @@ dotenv = "0.15"
tower-http = { version = "0.4", features = ["cors"] }
argon2 = { version = "0.5.3", features = ["std"] }
hyper = "0.14"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
axum-extra = { version = "0.7", features = ["cookie"] }
time = "0.3"

View File

@@ -16,7 +16,7 @@ CREATE TABLE service_categories (
total_companies INTEGER DEFAULT 0
);
CREATE TABLE companies (
CREATE TABLE company (
id SERIAL PRIMARY KEY,
active BOOLEAN DEFAULT true,
created DATE NOT NULL DEFAULT CURRENT_DATE,
@@ -31,6 +31,14 @@ CREATE TABLE companies (
user_id INTEGER
);
-- Counties (populated by scripts/add_county_to_db.py)
CREATE TABLE county (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
state VARCHAR(2) NOT NULL,
UNIQUE(name, state)
);
CREATE TABLE listings (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255) NOT NULL,

View File

@@ -6,6 +6,7 @@ use axum::{
body::Body,
http::Request as HttpRequest,
};
use axum_extra::extract::cookie::{CookieJar, Cookie, SameSite};
use crate::auth::structs::{AppState, User, RegisterRequest, LoginRequest, Claims};
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
@@ -13,13 +14,16 @@ use argon2::{
};
use jsonwebtoken::{decode, encode, Header, EncodingKey, DecodingKey, Validation};
// Cookie configuration constants
const AUTH_COOKIE_NAME: &str = "auth_token";
// A helper function to convert any error into a 500 Internal Server Error response.
fn internal_error<E>(err: E) -> Response
where
E: std::error::Error,
{
// Log the specific error to the server console for debugging.
eprintln!("Internal server error: {}", err);
tracing::error!("Internal server error: {}", err);
// Return a generic error message to the client.
(
@@ -34,6 +38,7 @@ pub async fn register(
State(state): State<AppState>,
Json(payload): Json<RegisterRequest>,
) -> impl IntoResponse {
tracing::info!(username = %payload.username, "Registration attempt");
// 1. Check if username exists, handling potential database errors
let user_exists = match sqlx::query("SELECT 1 FROM users WHERE username = $1")
.bind(&payload.username)
@@ -45,6 +50,7 @@ pub async fn register(
};
if user_exists {
tracing::warn!(username = %payload.username, "Registration failed: username already exists");
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "Username already exists" })),
@@ -71,20 +77,22 @@ pub async fn register(
.await;
match result {
Ok(user) => (StatusCode::CREATED, Json(user)).into_response(),
Ok(user) => {
tracing::info!(username = %payload.username, user_id = user.id, "User registered successfully");
(StatusCode::CREATED, Json(user)).into_response()
},
Err(e) => internal_error(e),
}
}
// Updated Login endpoint
// Login endpoint - sets JWT as httpOnly cookie
pub async fn login(
State(state): State<AppState>,
jar: CookieJar,
Json(payload): Json<LoginRequest>,
) -> impl IntoResponse {
tracing::info!(username = %payload.username.trim(), "Login attempt");
// 1. Fetch user from the database
let user = match sqlx::query_as::<_, User>("SELECT * FROM users WHERE TRIM(username) = $1")
.bind(&payload.username.trim())
@@ -94,9 +102,13 @@ pub async fn login(
Ok(Some(user)) => user,
Ok(None) => {
// User not found. Use a generic error message to prevent username enumeration attacks.
return (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" }))).into_response();
tracing::warn!(username = %payload.username.trim(), "Login failed: user not found");
return (jar, (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" })))).into_response();
}
Err(e) => {
tracing::error!("Database error during login: {}", e);
return (jar, (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "An internal server error occurred"})))).into_response();
}
Err(e) => return internal_error(e), // Database query failed
};
// --- FIX: Trim whitespace from the password hash string ---
@@ -108,7 +120,8 @@ pub async fn login(
Ok(hash) => hash,
Err(e) => {
// This is a server error because the hash in the DB is malformed.
return internal_error(e);
tracing::error!("Failed to parse password hash: {}", e);
return (jar, (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "An internal server error occurred"})))).into_response();
}
};
@@ -118,7 +131,8 @@ pub async fn login(
.is_err()
{
// Passwords do not match.
return (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" }))).into_response();
tracing::warn!(username = %payload.username.trim(), "Login failed: invalid password");
return (jar, (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" })))).into_response();
}
// 4. Update last_login. If this fails, log it but don't fail the login.
@@ -129,7 +143,7 @@ pub async fn login(
.execute(&*state.db)
.await
{
eprintln!("Failed to update last_login for user {}: {:?}", user.username, e);
tracing::error!("Failed to update last_login for user {}: {:?}", user.username, e);
}
// 5. Generate JWT
@@ -144,52 +158,103 @@ pub async fn login(
&EncodingKey::from_secret(state.jwt_secret.as_bytes()),
) {
Ok(t) => t,
Err(e) => return internal_error(e), // JWT generation failed
Err(e) => {
tracing::error!("Failed to generate JWT: {}", e);
return (jar, (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "An internal server error occurred"})))).into_response();
}
};
(StatusCode::OK, Json(serde_json::json!({ "token": token, "user": user }))).into_response()
// 6. Set JWT as httpOnly cookie
let cookie = Cookie::build(AUTH_COOKIE_NAME, token.clone())
.http_only(true)
.path("/")
.max_age(time::Duration::hours(24))
.same_site(SameSite::Lax)
// Note: In production with HTTPS, also add .secure(true)
.finish();
let jar = jar.add(cookie);
tracing::info!(username = %user.username.trim(), user_id = user.id, "Login successful");
// Return user data (token is now in cookie, but also include for backward compatibility)
(jar, (StatusCode::OK, Json(serde_json::json!({ "token": token, "user": user })))).into_response()
}
// Logout endpoint - clears the auth cookie
pub async fn logout(jar: CookieJar) -> impl IntoResponse {
// Create a cookie with empty value and immediate expiration to clear it
let cookie = Cookie::build(AUTH_COOKIE_NAME, "")
.http_only(true)
.path("/")
.max_age(time::Duration::seconds(0))
.same_site(SameSite::Lax)
.finish();
let jar = jar.remove(Cookie::named(AUTH_COOKIE_NAME)).add(cookie);
tracing::info!("User logged out");
(jar, (StatusCode::OK, Json(serde_json::json!({ "message": "Logged out successfully" })))).into_response()
}
pub async fn auth_middleware(
State(state): State<AppState>,
jar: CookieJar,
mut request: HttpRequest<Body>,
next: Next<Body>,
) -> Result<Response, Response> {
// Manually extract Authorization header
let auth_header = match request.headers().get(axum::http::header::AUTHORIZATION) {
Some(header) => header.to_str().ok(),
None => return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Missing authorization header"}))).into_response()),
};
// Try to get token from cookie first, then fall back to Authorization header
let token = if let Some(cookie) = jar.get(AUTH_COOKIE_NAME) {
tracing::debug!("Auth middleware: using token from cookie");
cookie.value().to_string()
} else {
// Fall back to Authorization header for API client compatibility
let auth_header = match request.headers().get(axum::http::header::AUTHORIZATION) {
Some(header) => header.to_str().ok(),
None => {
tracing::warn!("Auth middleware: no cookie or authorization header");
return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Missing authentication"}))).into_response());
},
};
let token = match auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
Some(t) => t,
None => return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid authorization header"}))).into_response()),
match auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
Some(t) => {
tracing::debug!("Auth middleware: using token from Authorization header");
t.to_string()
},
None => {
tracing::warn!("Auth middleware: invalid authorization header format");
return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid authorization header"}))).into_response());
},
}
};
let claims = match decode::<Claims>(
token,
&token,
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(token_data) => token_data.claims,
Err(_) => return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid token"}))).into_response()),
Err(e) => {
tracing::warn!(error = %e, "Auth middleware: invalid token");
return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid token"}))).into_response());
},
};
// Fetch user from database using username from claims
// Note: Database might pad CHAR fields with spaces, so we trim the username
let trimmed_username = claims.sub.trim();
eprintln!("Looking up user: '{}' (trimmed: '{}')", &claims.sub, trimmed_username);
tracing::debug!("Looking up user: '{}' (trimmed: '{}')", &claims.sub, trimmed_username);
let user = match sqlx::query_as::<_, User>("SELECT * FROM users WHERE TRIM(username) = $1")
.bind(trimmed_username)
.fetch_one(&*state.db)
.await
{
Ok(user) => {
eprintln!("Found user: {}", user.username.trim());
tracing::debug!("Found user: {}", user.username.trim());
user
},
Err(e) => {
eprintln!("Database error finding user '{}' : {:?}", trimmed_username, e);
tracing::error!("Database error finding user '{}' : {:?}", trimmed_username, e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "User not found"}))).into_response());
}
};

View File

@@ -10,7 +10,7 @@ use crate::state::structs::ErrorResponse;
pub async fn get_all_categories(
State(app_state): State<AppState>,
) -> Result<Json<Vec<ServiceCategory>>, (StatusCode, Json<ErrorResponse>)> {
println!("Querying all service categories");
tracing::info!("Querying all service categories");
match sqlx::query_as::<_, ServiceCategory>("SELECT id, name, description, clicks_total, total_companies
FROM service_categories ORDER BY name ASC")
@@ -18,10 +18,11 @@ pub async fn get_all_categories(
.await
{
Ok(categories) => {
tracing::info!(count = categories.len(), "Retrieved service categories");
Ok(Json(categories))
}
Err(e) => {
eprintln!("Database error fetching service categories: {}", e);
tracing::error!(error = %e, "Database error fetching service categories");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {

View File

@@ -48,6 +48,7 @@ pub async fn get_company(
State(state): State<AppState>,
Extension(user): Extension<User>,
) -> Response {
tracing::info!(user_id = user.id, "Fetching company for user");
match sqlx::query_as::<_, Company>(
"SELECT * FROM company WHERE user_id = $1 AND active = true"
)
@@ -55,9 +56,18 @@ pub async fn get_company(
.fetch_optional(&*state.db)
.await
{
Ok(Some(company)) => (StatusCode::OK, Json(company)).into_response(),
Ok(None) => (StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response(),
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
Ok(Some(company)) => {
tracing::info!(user_id = user.id, company_id = company.id, "Company found");
(StatusCode::OK, Json(company)).into_response()
},
Ok(None) => {
tracing::warn!(user_id = user.id, "No company found for user");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
},
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error fetching company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
}
@@ -66,15 +76,22 @@ pub async fn create_company(
Extension(user): Extension<User>,
Json(payload): Json<CompanyRequest>,
) -> Response {
tracing::info!(user_id = user.id, company_name = %payload.name, "Creating company");
// Check if company already exists
match sqlx::query("SELECT 1 FROM company WHERE user_id = $1 AND active = true")
.bind(user.id)
.fetch_optional(&*state.db)
.await
{
Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response(),
Ok(Some(_)) => {
tracing::warn!(user_id = user.id, "Company already exists for user");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response()
},
Ok(None) => {},
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error checking existing company");
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
match sqlx::query_as::<_, Company>(
@@ -102,6 +119,7 @@ pub async fn update_company(
Extension(user): Extension<User>,
Json(payload): Json<CompanyRequest>,
) -> Response {
tracing::info!(user_id = user.id, company_name = %payload.name, "Updating company");
match sqlx::query_as::<_, Company>(
"UPDATE company SET name = $1, address = $2, town = $3, state = $4::text, phone = $5, owner_name = $6, owner_phone_number = $7, email = $8 WHERE user_id = $9 AND active = true RETURNING *"
)
@@ -127,6 +145,7 @@ pub async fn delete_company(
State(state): State<AppState>,
Extension(user): Extension<User>,
) -> Response {
tracing::info!(user_id = user.id, "Deleting company (soft delete)");
match sqlx::query("UPDATE company SET active = false WHERE user_id = $1 AND active = true")
.bind(user.id)
.execute(&*state.db)
@@ -134,61 +153,91 @@ pub async fn delete_company(
{
Ok(result) => {
if result.rows_affected() == 0 {
tracing::warn!(user_id = user.id, "No company found to delete");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
} else {
tracing::info!(user_id = user.id, "Company deleted successfully");
(StatusCode::OK, Json(json!({"success": true, "message": "Company deleted"}))).into_response()
}
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error deleting company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
}
pub async fn company_handler(
request: Request<Body>,
) -> impl IntoResponse {
let method = request.method().clone();
tracing::debug!(method = %method, "Company handler invoked");
// Extract user and state from extensions
let user = match request.extensions().get::<User>().cloned() {
Some(user) => user,
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(),
None => {
tracing::warn!("Unauthorized access attempt to company handler");
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response();
},
};
let state = match request.extensions().get::<AppState>().cloned() {
Some(state) => state,
None => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "State not found"}))).into_response(),
None => {
tracing::error!("App state not found in request extensions");
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "State not found"}))).into_response();
},
};
let method = request.method().clone();
tracing::info!(user_id = user.id, method = %method, "Processing company request");
match method {
Method::GET => get_company_logic(&state, &user).await,
Method::POST => {
let body = match hyper::body::to_bytes(request.into_body()).await {
Ok(bytes) => bytes,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response(),
Err(e) => {
tracing::error!(error = %e, "Failed to read request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response();
},
};
let payload: CompanyRequest = match serde_json::from_slice(&body) {
Ok(data) => data,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response(),
Err(e) => {
tracing::warn!(error = %e, "Invalid JSON in request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response();
},
};
create_company_logic(&state, &user, payload).await
}
Method::PUT => {
let body = match hyper::body::to_bytes(request.into_body()).await {
Ok(bytes) => bytes,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response(),
Err(e) => {
tracing::error!(error = %e, "Failed to read request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response();
},
};
let payload: CompanyRequest = match serde_json::from_slice(&body) {
Ok(data) => data,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response(),
Err(e) => {
tracing::warn!(error = %e, "Invalid JSON in request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response();
},
};
update_company_logic(&state, &user, payload).await
}
Method::DELETE => delete_company_logic(&state, &user).await,
_ => (StatusCode::METHOD_NOT_ALLOWED, Json(json!({"error": "Method not allowed"}))).into_response(),
_ => {
tracing::warn!(method = %method, "Method not allowed for company endpoint");
(StatusCode::METHOD_NOT_ALLOWED, Json(json!({"error": "Method not allowed"}))).into_response()
},
}
}
async fn get_company_logic(state: &AppState, user: &User) -> Response {
tracing::debug!(user_id = user.id, "get_company_logic called");
match sqlx::query_as::<_, Company>(
"SELECT * FROM company WHERE user_id = $1 AND active = true"
)
@@ -196,22 +245,38 @@ async fn get_company_logic(state: &AppState, user: &User) -> Response {
.fetch_optional(&*state.db)
.await
{
Ok(Some(company)) => (StatusCode::OK, Json(company)).into_response(),
Ok(None) => (StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response(),
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
Ok(Some(company)) => {
tracing::info!(user_id = user.id, company_id = company.id, "Company retrieved");
(StatusCode::OK, Json(company)).into_response()
},
Ok(None) => {
tracing::warn!(user_id = user.id, "Company not found");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
},
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error in get_company_logic");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
}
async fn create_company_logic(state: &AppState, user: &User, payload: CompanyRequest) -> Response {
tracing::debug!(user_id = user.id, company_name = %payload.name, "create_company_logic called");
// Check if company already exists
match sqlx::query("SELECT 1 FROM company WHERE user_id = $1 AND active = true")
.bind(user.id)
.fetch_optional(&*state.db)
.await
{
Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response(),
Ok(Some(_)) => {
tracing::warn!(user_id = user.id, "Company already exists");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response()
},
Ok(None) => {},
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error checking existing company");
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
match sqlx::query_as::<_, Company>(
@@ -235,7 +300,7 @@ async fn create_company_logic(state: &AppState, user: &User, payload: CompanyReq
}
async fn update_company_logic(state: &AppState, user: &User, payload: CompanyRequest) -> Response {
eprintln!("Updating company for user {}: {:?}", user.id, payload);
tracing::debug!(user_id = user.id, company_name = %payload.name, "update_company_logic called");
match sqlx::query_as::<_, Company>(
"UPDATE company SET name = $1, address = $2, town = $3, state = $4::text, phone = $5, owner_name = $6, owner_phone_number = $7, email = $8 WHERE user_id = $9 AND active = true RETURNING *"
)
@@ -252,21 +317,22 @@ async fn update_company_logic(state: &AppState, user: &User, payload: CompanyReq
.await
{
Ok(Some(company)) => {
eprintln!("Updated company successfully");
tracing::info!(user_id = user.id, company_id = company.id, "Company updated successfully");
(StatusCode::OK, Json(company)).into_response()
},
Ok(None) => {
eprintln!("No company found to update, creating new one");
tracing::info!(user_id = user.id, "No company found to update, creating new one");
create_company_logic(state, user, payload).await
},
Err(e) => {
eprintln!("Database error updating company: {:?}", e);
tracing::error!(user_id = user.id, error = %e, "Database error updating company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
}
async fn delete_company_logic(state: &AppState, user: &User) -> Response {
tracing::debug!(user_id = user.id, "delete_company_logic called");
match sqlx::query("UPDATE company SET active = false WHERE user_id = $1 AND active = true")
.bind(user.id)
.execute(&*state.db)
@@ -274,11 +340,16 @@ async fn delete_company_logic(state: &AppState, user: &User) -> Response {
{
Ok(result) => {
if result.rows_affected() == 0 {
tracing::warn!(user_id = user.id, "No company found to delete");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
} else {
tracing::info!(user_id = user.id, "Company soft-deleted successfully");
(StatusCode::OK, Json(json!({"success": true, "message": "Company deleted"}))).into_response()
}
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error deleting company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
}
}

View File

@@ -1,17 +1,24 @@
use axum::{
extract::State,
Extension,
http::StatusCode,
response::IntoResponse,
Json,
};
use crate::auth::structs::AppState;
use crate::auth::structs::User;
// Define the handler for the /user/ endpoint
pub async fn get_user(State(state): State<AppState>) -> impl IntoResponse {
// Placeholder for user data retrieval logic
// In a real application, you would query the database using state.db
let users = sqlx::query("SELECT * FROM users").fetch_all(&*state.db).await;
match users {
Ok(_users) => (StatusCode::OK, "User data retrieved successfully".to_string()).into_response(),
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to retrieve users: {}", e)).into_response(),
}
// Returns the authenticated user's information (password excluded)
pub async fn get_user(Extension(user): Extension<User>) -> impl IntoResponse {
tracing::info!(user_id = user.id, username = %user.username.trim(), "User info requested");
// Create a response without the password field for security
let user_response = serde_json::json!({
"id": user.id,
"username": user.username.trim(),
"email": user.email,
"created": user.created,
"last_login": user.last_login,
"owner": user.owner
});
(StatusCode::OK, Json(user_response)).into_response()
}

View File

@@ -14,6 +14,7 @@ pub async fn get_listings(
State(app_state): State<AppState>,
Extension(user): Extension<User>,
) -> Result<Json<Vec<Listing>>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(user_id = user.id, "Fetching listings for user");
match sqlx::query_as::<_, Listing>(
"SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE user_id = $1 ORDER BY id DESC"
)
@@ -21,11 +22,17 @@ pub async fn get_listings(
.fetch_all(&*app_state.db)
.await
{
Ok(listings) => Ok(Json(listings)),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to fetch listings: {}", e)}))
)),
Ok(listings) => {
tracing::info!(user_id = user.id, count = listings.len(), "Listings retrieved");
Ok(Json(listings))
},
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Failed to fetch listings");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to fetch listings: {}", e)}))
))
},
}
}
@@ -34,6 +41,7 @@ pub async fn get_listing_by_id(
Path(listing_id): Path<i32>,
Extension(user): Extension<User>,
) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(user_id = user.id, listing_id = listing_id, "Fetching listing by ID");
match sqlx::query_as::<_, Listing>(
"SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE id = $1 AND user_id = $2"
)
@@ -42,15 +50,24 @@ pub async fn get_listing_by_id(
.fetch_optional(&*app_state.db)
.await
{
Ok(Some(listing)) => Ok(Json(listing)),
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"}))
)),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to fetch listing: {}", e)}))
)),
Ok(Some(listing)) => {
tracing::info!(user_id = user.id, listing_id = listing_id, "Listing found");
Ok(Json(listing))
},
Ok(None) => {
tracing::warn!(user_id = user.id, listing_id = listing_id, "Listing not found");
Err((
StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"}))
))
},
Err(e) => {
tracing::error!(user_id = user.id, listing_id = listing_id, error = %e, "Failed to fetch listing");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to fetch listing: {}", e)}))
))
},
}
}
@@ -59,8 +76,15 @@ pub async fn create_listing(
Extension(user): Extension<User>,
Json(payload): Json<CreateListingRequest>,
) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> {
eprintln!("DEBUG: Starting create_listing for user_id: {}", user.id);
eprintln!("DEBUG: Payload: {:?}", payload);
tracing::debug!("Starting create_listing for user_id: {}", user.id);
tracing::debug!("Payload: {:?}", payload);
if let Err(e) = payload.validate() {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({"error": e}))
));
}
// Create the listing directly without company validation
match sqlx::query_as::<_, Listing>(
@@ -83,11 +107,11 @@ pub async fn create_listing(
.await
{
Ok(listing) => {
eprintln!("DEBUG: Successfully created listing: {:?}", listing);
tracing::debug!("Successfully created listing: {:?}", listing);
Ok(Json(listing))
},
Err(e) => {
eprintln!("DEBUG: Error creating listing: {:?}", e);
tracing::error!("Error creating listing: {:?}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create listing: {}", e)}))
@@ -102,97 +126,96 @@ pub async fn update_listing(
Extension(user): Extension<User>,
Json(payload): Json<UpdateListingRequest>,
) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> {
// Build dynamic update query
let mut query = "UPDATE listings SET ".to_string();
let mut params: Vec<String> = Vec::new();
let mut param_count = 1;
if let Some(company_name) = &payload.company_name {
params.push(format!("company_name = ${}", param_count));
param_count += 1;
}
if let Some(is_active) = payload.is_active {
params.push(format!("is_active = ${}", param_count));
param_count += 1;
}
if let Some(price_per_gallon) = payload.price_per_gallon {
params.push(format!("price_per_gallon = ${}", param_count));
param_count += 1;
}
if let Some(price_per_gallon_cash) = payload.price_per_gallon_cash {
params.push(format!("price_per_gallon_cash = ${}", param_count));
param_count += 1;
}
if let Some(note) = &payload.note {
params.push(format!("note = ${}", param_count));
param_count += 1;
}
if let Some(minimum_order) = payload.minimum_order {
params.push(format!("minimum_order = ${}", param_count));
param_count += 1;
}
if let Some(service) = payload.service {
params.push(format!("service = ${}", param_count));
param_count += 1;
}
if let Some(bio_percent) = payload.bio_percent {
params.push(format!("bio_percent = ${}", param_count));
param_count += 1;
}
if let Some(phone) = &payload.phone {
params.push(format!("phone = ${}", param_count));
param_count += 1;
}
if let Some(online_ordering) = &payload.online_ordering {
params.push(format!("online_ordering = ${}", param_count));
param_count += 1;
}
if let Some(county_id) = payload.county_id {
params.push(format!("county_id = ${}", param_count));
param_count += 1;
}
if params.is_empty() {
tracing::info!(user_id = user.id, listing_id = listing_id, "Updating listing");
if let Err(e) = payload.validate() {
tracing::warn!(user_id = user.id, listing_id = listing_id, error = %e, "Validation failed for update");
return Err((
StatusCode::BAD_REQUEST,
Json(json!({"error": "No fields to update"}))
Json(json!({"error": e}))
));
}
query.push_str(&params.join(", "));
query.push_str(&format!(" WHERE id = ${} AND user_id = ${} RETURNING *", param_count, param_count + 1));
let mut query_builder = sqlx::QueryBuilder::new("UPDATE listings SET ");
let mut separated = query_builder.separated(", ");
// This is a simplified version - in production, you'd want to build the query more safely
// For now, let's use a simpler approach
match sqlx::query_as::<_, Listing>(
"UPDATE listings SET company_name = COALESCE($1, company_name), is_active = COALESCE($2, is_active), price_per_gallon = COALESCE($3, price_per_gallon), price_per_gallon_cash = COALESCE($4, price_per_gallon_cash), note = COALESCE($5, note), minimum_order = COALESCE($6, minimum_order), service = COALESCE($7, service), bio_percent = COALESCE($8, bio_percent), phone = COALESCE($9, phone), online_ordering = COALESCE($10, online_ordering), county_id = COALESCE($11, county_id), town = COALESCE($12, town), last_edited = CURRENT_TIMESTAMP WHERE id = $13 AND user_id = $14 RETURNING id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited"
)
.bind(&payload.company_name)
.bind(payload.is_active)
.bind(payload.price_per_gallon)
.bind(payload.price_per_gallon_cash)
.bind(&payload.note)
.bind(payload.minimum_order)
.bind(payload.service)
.bind(payload.bio_percent)
.bind(&payload.phone)
.bind(&payload.online_ordering)
.bind(payload.county_id)
.bind(&payload.town)
.bind(listing_id)
.bind(user.id)
.fetch_optional(&*app_state.db)
.await
{
Ok(Some(listing)) => Ok(Json(listing)),
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"}))
)),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to update listing: {}", e)}))
)),
if let Some(company_name) = &payload.company_name {
separated.push("company_name = ");
separated.push_bind_unseparated(company_name);
}
if let Some(is_active) = payload.is_active {
separated.push("is_active = ");
separated.push_bind_unseparated(is_active);
}
if let Some(price_per_gallon) = payload.price_per_gallon {
separated.push("price_per_gallon = ");
separated.push_bind_unseparated(price_per_gallon);
}
if let Some(price_per_gallon_cash) = payload.price_per_gallon_cash {
separated.push("price_per_gallon_cash = ");
separated.push_bind_unseparated(price_per_gallon_cash);
}
if let Some(note) = &payload.note {
separated.push("note = ");
separated.push_bind_unseparated(note);
}
if let Some(minimum_order) = payload.minimum_order {
separated.push("minimum_order = ");
separated.push_bind_unseparated(minimum_order);
}
if let Some(service) = payload.service {
separated.push("service = ");
separated.push_bind_unseparated(service);
}
if let Some(bio_percent) = payload.bio_percent {
separated.push("bio_percent = ");
separated.push_bind_unseparated(bio_percent);
}
if let Some(phone) = &payload.phone {
separated.push("phone = ");
separated.push_bind_unseparated(phone);
}
if let Some(online_ordering) = &payload.online_ordering {
separated.push("online_ordering = ");
separated.push_bind_unseparated(online_ordering);
}
if let Some(county_id) = payload.county_id {
separated.push("county_id = ");
separated.push_bind_unseparated(county_id);
}
if let Some(town) = &payload.town {
separated.push("town = ");
separated.push_bind_unseparated(town);
}
separated.push("last_edited = CURRENT_TIMESTAMP");
query_builder.push(" WHERE id = ");
query_builder.push_bind(listing_id);
query_builder.push(" AND user_id = ");
query_builder.push_bind(user.id);
query_builder.push(" RETURNING id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited");
let query = query_builder.build_query_as::<Listing>();
match query.fetch_optional(&*app_state.db).await {
Ok(Some(listing)) => {
tracing::info!(user_id = user.id, listing_id = listing_id, "Listing updated successfully");
Ok(Json(listing))
},
Ok(None) => {
tracing::warn!(user_id = user.id, listing_id = listing_id, "Listing not found or access denied");
Err((
StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found or access denied"}))
))
},
Err(e) => {
tracing::error!(user_id = user.id, listing_id = listing_id, error = %e, "Failed to update listing");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to update listing: {}", e)}))
))
},
}
}
@@ -200,6 +223,7 @@ pub async fn get_listings_by_county(
State(app_state): State<AppState>,
Path(county_id): Path<i32>,
) -> Result<Json<Vec<Listing>>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(county_id = county_id, "Fetching listings by county");
match sqlx::query_as::<_, Listing>(
"SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE county_id = $1 AND is_active = true ORDER BY last_edited DESC"
)
@@ -220,6 +244,7 @@ pub async fn delete_listing(
Path(listing_id): Path<i32>,
Extension(user): Extension<User>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(user_id = user.id, listing_id = listing_id, "Deleting listing");
match sqlx::query("DELETE FROM listings WHERE id = $1 AND user_id = $2")
.bind(listing_id)
.bind(user.id)
@@ -228,17 +253,22 @@ pub async fn delete_listing(
{
Ok(result) => {
if result.rows_affected() == 0 {
tracing::warn!(user_id = user.id, listing_id = listing_id, "Listing not found for deletion");
Err((
StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"}))
))
} else {
tracing::info!(user_id = user.id, listing_id = listing_id, "Listing deleted successfully");
Ok(Json(json!({"success": true, "message": "Listing deleted"})))
}
}
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete listing: {}", e)}))
)),
Err(e) => {
tracing::error!(user_id = user.id, listing_id = listing_id, error = %e, "Failed to delete listing");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete listing: {}", e)}))
))
},
}
}

View File

@@ -38,6 +38,28 @@ pub struct CreateListingRequest {
pub town: Option<String>,
}
impl CreateListingRequest {
pub fn validate(&self) -> Result<(), String> {
if self.price_per_gallon <= 0.0 {
return Err("Price per gallon must be greater than 0".to_string());
}
if let Some(cash_price) = self.price_per_gallon_cash {
if cash_price < 0.0 {
return Err("Cash price must be non-negative".to_string());
}
}
if self.bio_percent < 0 || self.bio_percent > 100 {
return Err("Bio percent must be between 0 and 100".to_string());
}
if let Some(min_order) = self.minimum_order {
if min_order < 0 {
return Err("Minimum order must be non-negative".to_string());
}
}
Ok(())
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateListingRequest {
pub company_name: Option<String>,
@@ -53,3 +75,29 @@ pub struct UpdateListingRequest {
pub county_id: Option<i32>,
pub town: Option<String>,
}
impl UpdateListingRequest {
pub fn validate(&self) -> Result<(), String> {
if let Some(price) = self.price_per_gallon {
if price <= 0.0 {
return Err("Price per gallon must be greater than 0".to_string());
}
}
if let Some(cash_price) = self.price_per_gallon_cash {
if cash_price < 0.0 {
return Err("Cash price must be non-negative".to_string());
}
}
if let Some(bio) = self.bio_percent {
if bio < 0 || bio > 100 {
return Err("Bio percent must be between 0 and 100".to_string());
}
}
if let Some(min_order) = self.minimum_order {
if min_order < 0 {
return Err("Minimum order must be non-negative".to_string());
}
}
Ok(())
}
}

View File

@@ -1,12 +1,11 @@
use axum::{
http::{header, Method},
Router,
};
use std::env;
use tower_http::cors::{CorsLayer, Any};
use tower_http::cors::CorsLayer;
use crate::auth::structs::AppState;
use crate::auth::auth::{auth_middleware, login, register};
use crate::auth::auth::{auth_middleware, login, register, logout};
use crate::data::data::get_user;
use crate::state::data::{get_counties_by_state, get_county_by_id};
use crate::listing::data::{get_listings, get_listing_by_id, get_listings_by_county, create_listing, update_listing, delete_listing};
@@ -22,17 +21,32 @@ mod listing;
#[tokio::main]
async fn main() {
// Initialize tracing first (before any logging)
// RUST_LOG env var controls log level, e.g. RUST_LOG=debug or RUST_LOG=api_rust=debug
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("api_rust=info".parse().unwrap())
)
.init();
tracing::info!("Starting NewEnglandBio API server...");
// Load environment variables
dotenv::dotenv().ok();
let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let _jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
let frontend_origin = env::var("FRONTEND_ORIGIN").unwrap_or_else(|_| "http://localhost:9551".to_string());
tracing::info!(frontend_origin = %frontend_origin, "Configuration loaded");
// Connect to PostgreSQL
tracing::info!("Connecting to PostgreSQL database...");
let db_pool = PgPool::connect(&database_url)
.await
.expect("Failed to connect to database");
tracing::info!("Database connection established");
let db = Arc::new(db_pool);
// Create app state
@@ -41,11 +55,14 @@ async fn main() {
jwt_secret: env::var("JWT_SECRET").expect("JWT_SECRET must be set"),
};
// Configure CORS
// Configure CORS with credentials support for cookie auth
let cors = CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::exact(frontend_origin.parse::<header::HeaderValue>().unwrap()))
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_headers(Any);
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
.allow_credentials(true);
tracing::debug!("CORS configured for origin: {} with credentials", frontend_origin);
// Build router with separated public and protected routes
let protected_routes = Router::new()
@@ -61,15 +78,20 @@ async fn main() {
let public_routes = Router::new()
.route("/auth/register", axum::routing::post(register))
.route("/auth/login", axum::routing::post(login))
.route("/auth/logout", axum::routing::post(logout))
.route("/categories", axum::routing::get(crate::company::category::get_all_categories))
.route("/state/:state_abbr", axum::routing::get(get_counties_by_state))
.route("/state/:state_abbr/:county_id", axum::routing::get(get_county_by_id))
.route("/listings/county/:county_id", axum::routing::get(get_listings_by_county));
let app = public_routes.merge(protected_routes).with_state(state).layer(cors);
let app = public_routes
.merge(protected_routes)
.with_state(state)
.layer(cors);
// Print server status
println!("Server is running on http://0.0.0.0:9552");
tracing::info!("Routes configured");
tracing::info!("Server is running on http://0.0.0.0:9552");
tracing::info!("Press Ctrl+C to stop");
// Run server
axum::Server::bind(&"0.0.0.0:9552".parse().unwrap())

View File

@@ -14,7 +14,7 @@ pub async fn get_counties_by_state(
Path(state_abbr): Path<String>,
) -> Result<Json<Vec<County>>, (StatusCode, Json<ErrorResponse>)> {
let state_abbr_upper = state_abbr.to_uppercase();
println!("Querying counties for state: {}", state_abbr_upper);
tracing::info!(state = %state_abbr_upper, "Querying counties for state");
match sqlx::query_as::<_, County>("SELECT id, name, state FROM county WHERE UPPER(state) = $1 ORDER BY name ASC")
.bind(&state_abbr_upper)
@@ -23,6 +23,7 @@ pub async fn get_counties_by_state(
{
Ok(counties) => {
if counties.is_empty() {
tracing::warn!(state = %state_abbr_upper, "No counties found for state");
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
@@ -30,11 +31,12 @@ pub async fn get_counties_by_state(
}),
))
} else {
tracing::info!(state = %state_abbr_upper, count = counties.len(), "Counties retrieved");
Ok(Json(counties))
}
}
Err(e) => {
eprintln!("Database error fetching counties for state {}: {}", state_abbr_upper, e);
tracing::error!(state = %state_abbr_upper, error = %e, "Database error fetching counties");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
@@ -50,7 +52,7 @@ pub async fn get_county_by_id(
Path((state_abbr, county_id)): Path<(String, i32)>,
) -> Result<Json<County>, (StatusCode, Json<ErrorResponse>)> {
let state_abbr_upper = state_abbr.to_uppercase();
println!("Querying county with ID: {} for state: {}", county_id, state_abbr_upper);
tracing::info!(state = %state_abbr_upper, county_id = county_id, "Querying county by ID");
match sqlx::query_as::<_, County>("SELECT id, name, state FROM county WHERE UPPER(state) = $1 AND id = $2")
.bind(&state_abbr_upper)
@@ -58,15 +60,21 @@ pub async fn get_county_by_id(
.fetch_one(&*app_state.db)
.await
{
Ok(county) => Ok(Json(county)),
Err(sqlx::Error::RowNotFound) => Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("County with ID {} not found in state {}", county_id, state_abbr_upper),
}),
)),
Ok(county) => {
tracing::info!(state = %state_abbr_upper, county_id = county_id, "County retrieved");
Ok(Json(county))
},
Err(sqlx::Error::RowNotFound) => {
tracing::warn!(state = %state_abbr_upper, county_id = county_id, "County not found");
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!("County with ID {} not found in state {}", county_id, state_abbr_upper),
}),
))
},
Err(e) => {
eprintln!("Database error fetching county with ID {} for state {}: {}", county_id, state_abbr_upper, e);
tracing::error!(state = %state_abbr_upper, county_id = county_id, error = %e, "Database error fetching county");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {