refactor(auth): migrate to httpOnly cookies and update vendor listings

Migrated JWT authentication from localStorage to httpOnly cookies using axum-extra. Refactored vendor listing and edit pages to use the centralized API client. Updated schema and data models to support these changes.
This commit is contained in:
2026-02-09 16:25:38 -05:00
parent feb1a173ec
commit caa318508b
11 changed files with 612 additions and 195 deletions

163
Cargo.lock generated
View File

@@ -40,6 +40,15 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "aho-corasick"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "allocator-api2" name = "allocator-api2"
version = "0.2.21" version = "0.2.21"
@@ -67,6 +76,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"argon2", "argon2",
"axum", "axum",
"axum-extra",
"chrono", "chrono",
"dotenv", "dotenv",
"hyper", "hyper",
@@ -74,8 +84,11 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
"time",
"tokio", "tokio",
"tower-http", "tower-http",
"tracing",
"tracing-subscriber",
] ]
[[package]] [[package]]
@@ -166,6 +179,28 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-extra"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a93e433be9382c737320af3924f7d5fc6f89c155cf2bf88949d8f5126fab283f"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"http",
"http-body",
"mime",
"pin-project-lite",
"serde",
"tokio",
"tower",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.75" version = "0.3.75"
@@ -277,6 +312,17 @@ dependencies = [
"windows-link", "windows-link",
] ]
[[package]]
name = "cookie"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7efb37c3e1ccb1ff97164ad95ac1606e8ccd35b3fa0a7d99a304c7f4a428cc24"
dependencies = [
"percent-encoding",
"time",
"version_check",
]
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.7" version = "0.8.7"
@@ -826,6 +872,12 @@ dependencies = [
"simple_asn1", "simple_asn1",
] ]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.172" version = "0.2.172"
@@ -864,6 +916,15 @@ version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.7.3" version = "0.7.3"
@@ -928,6 +989,15 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "num-bigint" name = "num-bigint"
version = "0.4.6" version = "0.4.6"
@@ -1190,6 +1260,23 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "regex-automata"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c"
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.16.20" version = "0.16.20"
@@ -1350,6 +1437,15 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
@@ -1605,6 +1701,15 @@ dependencies = [
"syn 2.0.102", "syn 2.0.102",
] ]
[[package]]
name = "thread_local"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "time" name = "time"
version = "0.3.41" version = "0.3.41"
@@ -1760,22 +1865,64 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.41" version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [ dependencies = [
"log", "log",
"pin-project-lite", "pin-project-lite",
"tracing-attributes",
"tracing-core", "tracing-core",
] ]
[[package]] [[package]]
name = "tracing-core" name = "tracing-attributes"
version = "0.1.34" version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.102",
]
[[package]]
name = "tracing-core"
version = "0.1.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
] ]
[[package]] [[package]]
@@ -1858,6 +2005,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.5" version = "0.9.5"

View File

@@ -15,3 +15,7 @@ dotenv = "0.15"
tower-http = { version = "0.4", features = ["cors"] } tower-http = { version = "0.4", features = ["cors"] }
argon2 = { version = "0.5.3", features = ["std"] } argon2 = { version = "0.5.3", features = ["std"] }
hyper = "0.14" hyper = "0.14"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
axum-extra = { version = "0.7", features = ["cookie"] }
time = "0.3"

View File

@@ -16,7 +16,7 @@ CREATE TABLE service_categories (
total_companies INTEGER DEFAULT 0 total_companies INTEGER DEFAULT 0
); );
CREATE TABLE companies ( CREATE TABLE company (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
active BOOLEAN DEFAULT true, active BOOLEAN DEFAULT true,
created DATE NOT NULL DEFAULT CURRENT_DATE, created DATE NOT NULL DEFAULT CURRENT_DATE,
@@ -31,6 +31,14 @@ CREATE TABLE companies (
user_id INTEGER user_id INTEGER
); );
-- Counties (populated by scripts/add_county_to_db.py)
CREATE TABLE county (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
state VARCHAR(2) NOT NULL,
UNIQUE(name, state)
);
CREATE TABLE listings ( CREATE TABLE listings (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
company_name VARCHAR(255) NOT NULL, company_name VARCHAR(255) NOT NULL,

View File

@@ -6,6 +6,7 @@ use axum::{
body::Body, body::Body,
http::Request as HttpRequest, http::Request as HttpRequest,
}; };
use axum_extra::extract::cookie::{CookieJar, Cookie, SameSite};
use crate::auth::structs::{AppState, User, RegisterRequest, LoginRequest, Claims}; use crate::auth::structs::{AppState, User, RegisterRequest, LoginRequest, Claims};
use argon2::{ use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
@@ -13,13 +14,16 @@ use argon2::{
}; };
use jsonwebtoken::{decode, encode, Header, EncodingKey, DecodingKey, Validation}; use jsonwebtoken::{decode, encode, Header, EncodingKey, DecodingKey, Validation};
// Cookie configuration constants
const AUTH_COOKIE_NAME: &str = "auth_token";
// A helper function to convert any error into a 500 Internal Server Error response. // A helper function to convert any error into a 500 Internal Server Error response.
fn internal_error<E>(err: E) -> Response fn internal_error<E>(err: E) -> Response
where where
E: std::error::Error, E: std::error::Error,
{ {
// Log the specific error to the server console for debugging. // Log the specific error to the server console for debugging.
eprintln!("Internal server error: {}", err); tracing::error!("Internal server error: {}", err);
// Return a generic error message to the client. // Return a generic error message to the client.
( (
@@ -34,6 +38,7 @@ pub async fn register(
State(state): State<AppState>, State(state): State<AppState>,
Json(payload): Json<RegisterRequest>, Json(payload): Json<RegisterRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
tracing::info!(username = %payload.username, "Registration attempt");
// 1. Check if username exists, handling potential database errors // 1. Check if username exists, handling potential database errors
let user_exists = match sqlx::query("SELECT 1 FROM users WHERE username = $1") let user_exists = match sqlx::query("SELECT 1 FROM users WHERE username = $1")
.bind(&payload.username) .bind(&payload.username)
@@ -45,6 +50,7 @@ pub async fn register(
}; };
if user_exists { if user_exists {
tracing::warn!(username = %payload.username, "Registration failed: username already exists");
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "Username already exists" })), Json(serde_json::json!({ "error": "Username already exists" })),
@@ -71,20 +77,22 @@ pub async fn register(
.await; .await;
match result { match result {
Ok(user) => (StatusCode::CREATED, Json(user)).into_response(), Ok(user) => {
tracing::info!(username = %payload.username, user_id = user.id, "User registered successfully");
(StatusCode::CREATED, Json(user)).into_response()
},
Err(e) => internal_error(e), Err(e) => internal_error(e),
} }
} }
// Login endpoint - sets JWT as httpOnly cookie
// Updated Login endpoint
pub async fn login( pub async fn login(
State(state): State<AppState>, State(state): State<AppState>,
jar: CookieJar,
Json(payload): Json<LoginRequest>, Json(payload): Json<LoginRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
tracing::info!(username = %payload.username.trim(), "Login attempt");
// 1. Fetch user from the database // 1. Fetch user from the database
let user = match sqlx::query_as::<_, User>("SELECT * FROM users WHERE TRIM(username) = $1") let user = match sqlx::query_as::<_, User>("SELECT * FROM users WHERE TRIM(username) = $1")
.bind(&payload.username.trim()) .bind(&payload.username.trim())
@@ -94,9 +102,13 @@ pub async fn login(
Ok(Some(user)) => user, Ok(Some(user)) => user,
Ok(None) => { Ok(None) => {
// User not found. Use a generic error message to prevent username enumeration attacks. // User not found. Use a generic error message to prevent username enumeration attacks.
return (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" }))).into_response(); tracing::warn!(username = %payload.username.trim(), "Login failed: user not found");
return (jar, (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" })))).into_response();
}
Err(e) => {
tracing::error!("Database error during login: {}", e);
return (jar, (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "An internal server error occurred"})))).into_response();
} }
Err(e) => return internal_error(e), // Database query failed
}; };
// --- FIX: Trim whitespace from the password hash string --- // --- FIX: Trim whitespace from the password hash string ---
@@ -108,7 +120,8 @@ pub async fn login(
Ok(hash) => hash, Ok(hash) => hash,
Err(e) => { Err(e) => {
// This is a server error because the hash in the DB is malformed. // This is a server error because the hash in the DB is malformed.
return internal_error(e); tracing::error!("Failed to parse password hash: {}", e);
return (jar, (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "An internal server error occurred"})))).into_response();
} }
}; };
@@ -118,7 +131,8 @@ pub async fn login(
.is_err() .is_err()
{ {
// Passwords do not match. // Passwords do not match.
return (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" }))).into_response(); tracing::warn!(username = %payload.username.trim(), "Login failed: invalid password");
return (jar, (StatusCode::UNAUTHORIZED, Json(serde_json::json!({ "error": "Invalid credentials" })))).into_response();
} }
// 4. Update last_login. If this fails, log it but don't fail the login. // 4. Update last_login. If this fails, log it but don't fail the login.
@@ -129,7 +143,7 @@ pub async fn login(
.execute(&*state.db) .execute(&*state.db)
.await .await
{ {
eprintln!("Failed to update last_login for user {}: {:?}", user.username, e); tracing::error!("Failed to update last_login for user {}: {:?}", user.username, e);
} }
// 5. Generate JWT // 5. Generate JWT
@@ -144,52 +158,103 @@ pub async fn login(
&EncodingKey::from_secret(state.jwt_secret.as_bytes()), &EncodingKey::from_secret(state.jwt_secret.as_bytes()),
) { ) {
Ok(t) => t, Ok(t) => t,
Err(e) => return internal_error(e), // JWT generation failed Err(e) => {
tracing::error!("Failed to generate JWT: {}", e);
return (jar, (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "An internal server error occurred"})))).into_response();
}
}; };
(StatusCode::OK, Json(serde_json::json!({ "token": token, "user": user }))).into_response() // 6. Set JWT as httpOnly cookie
let cookie = Cookie::build(AUTH_COOKIE_NAME, token.clone())
.http_only(true)
.path("/")
.max_age(time::Duration::hours(24))
.same_site(SameSite::Lax)
// Note: In production with HTTPS, also add .secure(true)
.finish();
let jar = jar.add(cookie);
tracing::info!(username = %user.username.trim(), user_id = user.id, "Login successful");
// Return user data (token is now in cookie, but also include for backward compatibility)
(jar, (StatusCode::OK, Json(serde_json::json!({ "token": token, "user": user })))).into_response()
}
// Logout endpoint - clears the auth cookie
pub async fn logout(jar: CookieJar) -> impl IntoResponse {
// Create a cookie with empty value and immediate expiration to clear it
let cookie = Cookie::build(AUTH_COOKIE_NAME, "")
.http_only(true)
.path("/")
.max_age(time::Duration::seconds(0))
.same_site(SameSite::Lax)
.finish();
let jar = jar.remove(Cookie::named(AUTH_COOKIE_NAME)).add(cookie);
tracing::info!("User logged out");
(jar, (StatusCode::OK, Json(serde_json::json!({ "message": "Logged out successfully" })))).into_response()
} }
pub async fn auth_middleware( pub async fn auth_middleware(
State(state): State<AppState>, State(state): State<AppState>,
jar: CookieJar,
mut request: HttpRequest<Body>, mut request: HttpRequest<Body>,
next: Next<Body>, next: Next<Body>,
) -> Result<Response, Response> { ) -> Result<Response, Response> {
// Manually extract Authorization header // Try to get token from cookie first, then fall back to Authorization header
let token = if let Some(cookie) = jar.get(AUTH_COOKIE_NAME) {
tracing::debug!("Auth middleware: using token from cookie");
cookie.value().to_string()
} else {
// Fall back to Authorization header for API client compatibility
let auth_header = match request.headers().get(axum::http::header::AUTHORIZATION) { let auth_header = match request.headers().get(axum::http::header::AUTHORIZATION) {
Some(header) => header.to_str().ok(), Some(header) => header.to_str().ok(),
None => return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Missing authorization header"}))).into_response()), None => {
tracing::warn!("Auth middleware: no cookie or authorization header");
return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Missing authentication"}))).into_response());
},
}; };
let token = match auth_header.and_then(|h| h.strip_prefix("Bearer ")) { match auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
Some(t) => t, Some(t) => {
None => return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid authorization header"}))).into_response()), tracing::debug!("Auth middleware: using token from Authorization header");
t.to_string()
},
None => {
tracing::warn!("Auth middleware: invalid authorization header format");
return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid authorization header"}))).into_response());
},
}
}; };
let claims = match decode::<Claims>( let claims = match decode::<Claims>(
token, &token,
&DecodingKey::from_secret(state.jwt_secret.as_bytes()), &DecodingKey::from_secret(state.jwt_secret.as_bytes()),
&Validation::default(), &Validation::default(),
) { ) {
Ok(token_data) => token_data.claims, Ok(token_data) => token_data.claims,
Err(_) => return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid token"}))).into_response()), Err(e) => {
tracing::warn!(error = %e, "Auth middleware: invalid token");
return Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid token"}))).into_response());
},
}; };
// Fetch user from database using username from claims // Fetch user from database using username from claims
// Note: Database might pad CHAR fields with spaces, so we trim the username // Note: Database might pad CHAR fields with spaces, so we trim the username
let trimmed_username = claims.sub.trim(); let trimmed_username = claims.sub.trim();
eprintln!("Looking up user: '{}' (trimmed: '{}')", &claims.sub, trimmed_username); tracing::debug!("Looking up user: '{}' (trimmed: '{}')", &claims.sub, trimmed_username);
let user = match sqlx::query_as::<_, User>("SELECT * FROM users WHERE TRIM(username) = $1") let user = match sqlx::query_as::<_, User>("SELECT * FROM users WHERE TRIM(username) = $1")
.bind(trimmed_username) .bind(trimmed_username)
.fetch_one(&*state.db) .fetch_one(&*state.db)
.await .await
{ {
Ok(user) => { Ok(user) => {
eprintln!("Found user: {}", user.username.trim()); tracing::debug!("Found user: {}", user.username.trim());
user user
}, },
Err(e) => { Err(e) => {
eprintln!("Database error finding user '{}' : {:?}", trimmed_username, e); tracing::error!("Database error finding user '{}' : {:?}", trimmed_username, e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "User not found"}))).into_response()); return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "User not found"}))).into_response());
} }
}; };

View File

@@ -10,7 +10,7 @@ use crate::state::structs::ErrorResponse;
pub async fn get_all_categories( pub async fn get_all_categories(
State(app_state): State<AppState>, State(app_state): State<AppState>,
) -> Result<Json<Vec<ServiceCategory>>, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Json<Vec<ServiceCategory>>, (StatusCode, Json<ErrorResponse>)> {
println!("Querying all service categories"); tracing::info!("Querying all service categories");
match sqlx::query_as::<_, ServiceCategory>("SELECT id, name, description, clicks_total, total_companies match sqlx::query_as::<_, ServiceCategory>("SELECT id, name, description, clicks_total, total_companies
FROM service_categories ORDER BY name ASC") FROM service_categories ORDER BY name ASC")
@@ -18,10 +18,11 @@ pub async fn get_all_categories(
.await .await
{ {
Ok(categories) => { Ok(categories) => {
tracing::info!(count = categories.len(), "Retrieved service categories");
Ok(Json(categories)) Ok(Json(categories))
} }
Err(e) => { Err(e) => {
eprintln!("Database error fetching service categories: {}", e); tracing::error!(error = %e, "Database error fetching service categories");
Err(( Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse { Json(ErrorResponse {

View File

@@ -48,6 +48,7 @@ pub async fn get_company(
State(state): State<AppState>, State(state): State<AppState>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
) -> Response { ) -> Response {
tracing::info!(user_id = user.id, "Fetching company for user");
match sqlx::query_as::<_, Company>( match sqlx::query_as::<_, Company>(
"SELECT * FROM company WHERE user_id = $1 AND active = true" "SELECT * FROM company WHERE user_id = $1 AND active = true"
) )
@@ -55,9 +56,18 @@ pub async fn get_company(
.fetch_optional(&*state.db) .fetch_optional(&*state.db)
.await .await
{ {
Ok(Some(company)) => (StatusCode::OK, Json(company)).into_response(), Ok(Some(company)) => {
Ok(None) => (StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response(), tracing::info!(user_id = user.id, company_id = company.id, "Company found");
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(), (StatusCode::OK, Json(company)).into_response()
},
Ok(None) => {
tracing::warn!(user_id = user.id, "No company found for user");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
},
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error fetching company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
} }
} }
@@ -66,15 +76,22 @@ pub async fn create_company(
Extension(user): Extension<User>, Extension(user): Extension<User>,
Json(payload): Json<CompanyRequest>, Json(payload): Json<CompanyRequest>,
) -> Response { ) -> Response {
tracing::info!(user_id = user.id, company_name = %payload.name, "Creating company");
// Check if company already exists // Check if company already exists
match sqlx::query("SELECT 1 FROM company WHERE user_id = $1 AND active = true") match sqlx::query("SELECT 1 FROM company WHERE user_id = $1 AND active = true")
.bind(user.id) .bind(user.id)
.fetch_optional(&*state.db) .fetch_optional(&*state.db)
.await .await
{ {
Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response(), Ok(Some(_)) => {
tracing::warn!(user_id = user.id, "Company already exists for user");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response()
},
Ok(None) => {}, Ok(None) => {},
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(), Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error checking existing company");
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
} }
match sqlx::query_as::<_, Company>( match sqlx::query_as::<_, Company>(
@@ -102,6 +119,7 @@ pub async fn update_company(
Extension(user): Extension<User>, Extension(user): Extension<User>,
Json(payload): Json<CompanyRequest>, Json(payload): Json<CompanyRequest>,
) -> Response { ) -> Response {
tracing::info!(user_id = user.id, company_name = %payload.name, "Updating company");
match sqlx::query_as::<_, Company>( match sqlx::query_as::<_, Company>(
"UPDATE company SET name = $1, address = $2, town = $3, state = $4::text, phone = $5, owner_name = $6, owner_phone_number = $7, email = $8 WHERE user_id = $9 AND active = true RETURNING *" "UPDATE company SET name = $1, address = $2, town = $3, state = $4::text, phone = $5, owner_name = $6, owner_phone_number = $7, email = $8 WHERE user_id = $9 AND active = true RETURNING *"
) )
@@ -127,6 +145,7 @@ pub async fn delete_company(
State(state): State<AppState>, State(state): State<AppState>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
) -> Response { ) -> Response {
tracing::info!(user_id = user.id, "Deleting company (soft delete)");
match sqlx::query("UPDATE company SET active = false WHERE user_id = $1 AND active = true") match sqlx::query("UPDATE company SET active = false WHERE user_id = $1 AND active = true")
.bind(user.id) .bind(user.id)
.execute(&*state.db) .execute(&*state.db)
@@ -134,61 +153,91 @@ pub async fn delete_company(
{ {
Ok(result) => { Ok(result) => {
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
tracing::warn!(user_id = user.id, "No company found to delete");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response() (StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
} else { } else {
tracing::info!(user_id = user.id, "Company deleted successfully");
(StatusCode::OK, Json(json!({"success": true, "message": "Company deleted"}))).into_response() (StatusCode::OK, Json(json!({"success": true, "message": "Company deleted"}))).into_response()
} }
} }
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(), Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error deleting company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
} }
} }
pub async fn company_handler( pub async fn company_handler(
request: Request<Body>, request: Request<Body>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let method = request.method().clone();
tracing::debug!(method = %method, "Company handler invoked");
// Extract user and state from extensions // Extract user and state from extensions
let user = match request.extensions().get::<User>().cloned() { let user = match request.extensions().get::<User>().cloned() {
Some(user) => user, Some(user) => user,
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(), None => {
tracing::warn!("Unauthorized access attempt to company handler");
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response();
},
}; };
let state = match request.extensions().get::<AppState>().cloned() { let state = match request.extensions().get::<AppState>().cloned() {
Some(state) => state, Some(state) => state,
None => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "State not found"}))).into_response(), None => {
tracing::error!("App state not found in request extensions");
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "State not found"}))).into_response();
},
}; };
let method = request.method().clone(); tracing::info!(user_id = user.id, method = %method, "Processing company request");
match method { match method {
Method::GET => get_company_logic(&state, &user).await, Method::GET => get_company_logic(&state, &user).await,
Method::POST => { Method::POST => {
let body = match hyper::body::to_bytes(request.into_body()).await { let body = match hyper::body::to_bytes(request.into_body()).await {
Ok(bytes) => bytes, Ok(bytes) => bytes,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response(), Err(e) => {
tracing::error!(error = %e, "Failed to read request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response();
},
}; };
let payload: CompanyRequest = match serde_json::from_slice(&body) { let payload: CompanyRequest = match serde_json::from_slice(&body) {
Ok(data) => data, Ok(data) => data,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response(), Err(e) => {
tracing::warn!(error = %e, "Invalid JSON in request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response();
},
}; };
create_company_logic(&state, &user, payload).await create_company_logic(&state, &user, payload).await
} }
Method::PUT => { Method::PUT => {
let body = match hyper::body::to_bytes(request.into_body()).await { let body = match hyper::body::to_bytes(request.into_body()).await {
Ok(bytes) => bytes, Ok(bytes) => bytes,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response(), Err(e) => {
tracing::error!(error = %e, "Failed to read request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid body"}))).into_response();
},
}; };
let payload: CompanyRequest = match serde_json::from_slice(&body) { let payload: CompanyRequest = match serde_json::from_slice(&body) {
Ok(data) => data, Ok(data) => data,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response(), Err(e) => {
tracing::warn!(error = %e, "Invalid JSON in request body");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Invalid JSON"}))).into_response();
},
}; };
update_company_logic(&state, &user, payload).await update_company_logic(&state, &user, payload).await
} }
Method::DELETE => delete_company_logic(&state, &user).await, Method::DELETE => delete_company_logic(&state, &user).await,
_ => (StatusCode::METHOD_NOT_ALLOWED, Json(json!({"error": "Method not allowed"}))).into_response(), _ => {
tracing::warn!(method = %method, "Method not allowed for company endpoint");
(StatusCode::METHOD_NOT_ALLOWED, Json(json!({"error": "Method not allowed"}))).into_response()
},
} }
} }
async fn get_company_logic(state: &AppState, user: &User) -> Response { async fn get_company_logic(state: &AppState, user: &User) -> Response {
tracing::debug!(user_id = user.id, "get_company_logic called");
match sqlx::query_as::<_, Company>( match sqlx::query_as::<_, Company>(
"SELECT * FROM company WHERE user_id = $1 AND active = true" "SELECT * FROM company WHERE user_id = $1 AND active = true"
) )
@@ -196,22 +245,38 @@ async fn get_company_logic(state: &AppState, user: &User) -> Response {
.fetch_optional(&*state.db) .fetch_optional(&*state.db)
.await .await
{ {
Ok(Some(company)) => (StatusCode::OK, Json(company)).into_response(), Ok(Some(company)) => {
Ok(None) => (StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response(), tracing::info!(user_id = user.id, company_id = company.id, "Company retrieved");
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(), (StatusCode::OK, Json(company)).into_response()
},
Ok(None) => {
tracing::warn!(user_id = user.id, "Company not found");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
},
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error in get_company_logic");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
} }
} }
async fn create_company_logic(state: &AppState, user: &User, payload: CompanyRequest) -> Response { async fn create_company_logic(state: &AppState, user: &User, payload: CompanyRequest) -> Response {
tracing::debug!(user_id = user.id, company_name = %payload.name, "create_company_logic called");
// Check if company already exists // Check if company already exists
match sqlx::query("SELECT 1 FROM company WHERE user_id = $1 AND active = true") match sqlx::query("SELECT 1 FROM company WHERE user_id = $1 AND active = true")
.bind(user.id) .bind(user.id)
.fetch_optional(&*state.db) .fetch_optional(&*state.db)
.await .await
{ {
Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response(), Ok(Some(_)) => {
tracing::warn!(user_id = user.id, "Company already exists");
return (StatusCode::BAD_REQUEST, Json(json!({"error": "Company already exists"}))).into_response()
},
Ok(None) => {}, Ok(None) => {},
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(), Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error checking existing company");
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
} }
match sqlx::query_as::<_, Company>( match sqlx::query_as::<_, Company>(
@@ -235,7 +300,7 @@ async fn create_company_logic(state: &AppState, user: &User, payload: CompanyReq
} }
async fn update_company_logic(state: &AppState, user: &User, payload: CompanyRequest) -> Response { async fn update_company_logic(state: &AppState, user: &User, payload: CompanyRequest) -> Response {
eprintln!("Updating company for user {}: {:?}", user.id, payload); tracing::debug!(user_id = user.id, company_name = %payload.name, "update_company_logic called");
match sqlx::query_as::<_, Company>( match sqlx::query_as::<_, Company>(
"UPDATE company SET name = $1, address = $2, town = $3, state = $4::text, phone = $5, owner_name = $6, owner_phone_number = $7, email = $8 WHERE user_id = $9 AND active = true RETURNING *" "UPDATE company SET name = $1, address = $2, town = $3, state = $4::text, phone = $5, owner_name = $6, owner_phone_number = $7, email = $8 WHERE user_id = $9 AND active = true RETURNING *"
) )
@@ -252,21 +317,22 @@ async fn update_company_logic(state: &AppState, user: &User, payload: CompanyReq
.await .await
{ {
Ok(Some(company)) => { Ok(Some(company)) => {
eprintln!("Updated company successfully"); tracing::info!(user_id = user.id, company_id = company.id, "Company updated successfully");
(StatusCode::OK, Json(company)).into_response() (StatusCode::OK, Json(company)).into_response()
}, },
Ok(None) => { Ok(None) => {
eprintln!("No company found to update, creating new one"); tracing::info!(user_id = user.id, "No company found to update, creating new one");
create_company_logic(state, user, payload).await create_company_logic(state, user, payload).await
}, },
Err(e) => { Err(e) => {
eprintln!("Database error updating company: {:?}", e); tracing::error!(user_id = user.id, error = %e, "Database error updating company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response() (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
}, },
} }
} }
async fn delete_company_logic(state: &AppState, user: &User) -> Response { async fn delete_company_logic(state: &AppState, user: &User) -> Response {
tracing::debug!(user_id = user.id, "delete_company_logic called");
match sqlx::query("UPDATE company SET active = false WHERE user_id = $1 AND active = true") match sqlx::query("UPDATE company SET active = false WHERE user_id = $1 AND active = true")
.bind(user.id) .bind(user.id)
.execute(&*state.db) .execute(&*state.db)
@@ -274,11 +340,16 @@ async fn delete_company_logic(state: &AppState, user: &User) -> Response {
{ {
Ok(result) => { Ok(result) => {
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
tracing::warn!(user_id = user.id, "No company found to delete");
(StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response() (StatusCode::NOT_FOUND, Json(json!({"error": "Company not found"}))).into_response()
} else { } else {
tracing::info!(user_id = user.id, "Company soft-deleted successfully");
(StatusCode::OK, Json(json!({"success": true, "message": "Company deleted"}))).into_response() (StatusCode::OK, Json(json!({"success": true, "message": "Company deleted"}))).into_response()
} }
} }
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(), Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Database error deleting company");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response()
},
} }
} }

View File

@@ -1,17 +1,24 @@
use axum::{ use axum::{
extract::State, Extension,
http::StatusCode, http::StatusCode,
response::IntoResponse, response::IntoResponse,
Json,
}; };
use crate::auth::structs::AppState; use crate::auth::structs::User;
// Define the handler for the /user/ endpoint // Define the handler for the /user/ endpoint
pub async fn get_user(State(state): State<AppState>) -> impl IntoResponse { // Returns the authenticated user's information (password excluded)
// Placeholder for user data retrieval logic pub async fn get_user(Extension(user): Extension<User>) -> impl IntoResponse {
// In a real application, you would query the database using state.db tracing::info!(user_id = user.id, username = %user.username.trim(), "User info requested");
let users = sqlx::query("SELECT * FROM users").fetch_all(&*state.db).await; // Create a response without the password field for security
match users { let user_response = serde_json::json!({
Ok(_users) => (StatusCode::OK, "User data retrieved successfully".to_string()).into_response(), "id": user.id,
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to retrieve users: {}", e)).into_response(), "username": user.username.trim(),
} "email": user.email,
"created": user.created,
"last_login": user.last_login,
"owner": user.owner
});
(StatusCode::OK, Json(user_response)).into_response()
} }

View File

@@ -14,6 +14,7 @@ pub async fn get_listings(
State(app_state): State<AppState>, State(app_state): State<AppState>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
) -> Result<Json<Vec<Listing>>, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Json<Vec<Listing>>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(user_id = user.id, "Fetching listings for user");
match sqlx::query_as::<_, Listing>( match sqlx::query_as::<_, Listing>(
"SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE user_id = $1 ORDER BY id DESC" "SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE user_id = $1 ORDER BY id DESC"
) )
@@ -21,11 +22,17 @@ pub async fn get_listings(
.fetch_all(&*app_state.db) .fetch_all(&*app_state.db)
.await .await
{ {
Ok(listings) => Ok(Json(listings)), Ok(listings) => {
Err(e) => Err(( tracing::info!(user_id = user.id, count = listings.len(), "Listings retrieved");
Ok(Json(listings))
},
Err(e) => {
tracing::error!(user_id = user.id, error = %e, "Failed to fetch listings");
Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to fetch listings: {}", e)})) Json(json!({"error": format!("Failed to fetch listings: {}", e)}))
)), ))
},
} }
} }
@@ -34,6 +41,7 @@ pub async fn get_listing_by_id(
Path(listing_id): Path<i32>, Path(listing_id): Path<i32>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(user_id = user.id, listing_id = listing_id, "Fetching listing by ID");
match sqlx::query_as::<_, Listing>( match sqlx::query_as::<_, Listing>(
"SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE id = $1 AND user_id = $2" "SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE id = $1 AND user_id = $2"
) )
@@ -42,15 +50,24 @@ pub async fn get_listing_by_id(
.fetch_optional(&*app_state.db) .fetch_optional(&*app_state.db)
.await .await
{ {
Ok(Some(listing)) => Ok(Json(listing)), Ok(Some(listing)) => {
Ok(None) => Err(( tracing::info!(user_id = user.id, listing_id = listing_id, "Listing found");
Ok(Json(listing))
},
Ok(None) => {
tracing::warn!(user_id = user.id, listing_id = listing_id, "Listing not found");
Err((
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"})) Json(json!({"error": "Listing not found"}))
)), ))
Err(e) => Err(( },
Err(e) => {
tracing::error!(user_id = user.id, listing_id = listing_id, error = %e, "Failed to fetch listing");
Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to fetch listing: {}", e)})) Json(json!({"error": format!("Failed to fetch listing: {}", e)}))
)), ))
},
} }
} }
@@ -59,8 +76,15 @@ pub async fn create_listing(
Extension(user): Extension<User>, Extension(user): Extension<User>,
Json(payload): Json<CreateListingRequest>, Json(payload): Json<CreateListingRequest>,
) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> {
eprintln!("DEBUG: Starting create_listing for user_id: {}", user.id); tracing::debug!("Starting create_listing for user_id: {}", user.id);
eprintln!("DEBUG: Payload: {:?}", payload); tracing::debug!("Payload: {:?}", payload);
if let Err(e) = payload.validate() {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({"error": e}))
));
}
// Create the listing directly without company validation // Create the listing directly without company validation
match sqlx::query_as::<_, Listing>( match sqlx::query_as::<_, Listing>(
@@ -83,11 +107,11 @@ pub async fn create_listing(
.await .await
{ {
Ok(listing) => { Ok(listing) => {
eprintln!("DEBUG: Successfully created listing: {:?}", listing); tracing::debug!("Successfully created listing: {:?}", listing);
Ok(Json(listing)) Ok(Json(listing))
}, },
Err(e) => { Err(e) => {
eprintln!("DEBUG: Error creating listing: {:?}", e); tracing::error!("Error creating listing: {:?}", e);
Err(( Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create listing: {}", e)})) Json(json!({"error": format!("Failed to create listing: {}", e)}))
@@ -102,97 +126,96 @@ pub async fn update_listing(
Extension(user): Extension<User>, Extension(user): Extension<User>,
Json(payload): Json<UpdateListingRequest>, Json(payload): Json<UpdateListingRequest>,
) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Json<Listing>, (StatusCode, Json<serde_json::Value>)> {
// Build dynamic update query tracing::info!(user_id = user.id, listing_id = listing_id, "Updating listing");
let mut query = "UPDATE listings SET ".to_string(); if let Err(e) = payload.validate() {
let mut params: Vec<String> = Vec::new(); tracing::warn!(user_id = user.id, listing_id = listing_id, error = %e, "Validation failed for update");
let mut param_count = 1;
if let Some(company_name) = &payload.company_name {
params.push(format!("company_name = ${}", param_count));
param_count += 1;
}
if let Some(is_active) = payload.is_active {
params.push(format!("is_active = ${}", param_count));
param_count += 1;
}
if let Some(price_per_gallon) = payload.price_per_gallon {
params.push(format!("price_per_gallon = ${}", param_count));
param_count += 1;
}
if let Some(price_per_gallon_cash) = payload.price_per_gallon_cash {
params.push(format!("price_per_gallon_cash = ${}", param_count));
param_count += 1;
}
if let Some(note) = &payload.note {
params.push(format!("note = ${}", param_count));
param_count += 1;
}
if let Some(minimum_order) = payload.minimum_order {
params.push(format!("minimum_order = ${}", param_count));
param_count += 1;
}
if let Some(service) = payload.service {
params.push(format!("service = ${}", param_count));
param_count += 1;
}
if let Some(bio_percent) = payload.bio_percent {
params.push(format!("bio_percent = ${}", param_count));
param_count += 1;
}
if let Some(phone) = &payload.phone {
params.push(format!("phone = ${}", param_count));
param_count += 1;
}
if let Some(online_ordering) = &payload.online_ordering {
params.push(format!("online_ordering = ${}", param_count));
param_count += 1;
}
if let Some(county_id) = payload.county_id {
params.push(format!("county_id = ${}", param_count));
param_count += 1;
}
if params.is_empty() {
return Err(( return Err((
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(json!({"error": "No fields to update"})) Json(json!({"error": e}))
)); ));
} }
query.push_str(&params.join(", ")); let mut query_builder = sqlx::QueryBuilder::new("UPDATE listings SET ");
query.push_str(&format!(" WHERE id = ${} AND user_id = ${} RETURNING *", param_count, param_count + 1)); let mut separated = query_builder.separated(", ");
// This is a simplified version - in production, you'd want to build the query more safely if let Some(company_name) = &payload.company_name {
// For now, let's use a simpler approach separated.push("company_name = ");
match sqlx::query_as::<_, Listing>( separated.push_bind_unseparated(company_name);
"UPDATE listings SET company_name = COALESCE($1, company_name), is_active = COALESCE($2, is_active), price_per_gallon = COALESCE($3, price_per_gallon), price_per_gallon_cash = COALESCE($4, price_per_gallon_cash), note = COALESCE($5, note), minimum_order = COALESCE($6, minimum_order), service = COALESCE($7, service), bio_percent = COALESCE($8, bio_percent), phone = COALESCE($9, phone), online_ordering = COALESCE($10, online_ordering), county_id = COALESCE($11, county_id), town = COALESCE($12, town), last_edited = CURRENT_TIMESTAMP WHERE id = $13 AND user_id = $14 RETURNING id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited" }
) if let Some(is_active) = payload.is_active {
.bind(&payload.company_name) separated.push("is_active = ");
.bind(payload.is_active) separated.push_bind_unseparated(is_active);
.bind(payload.price_per_gallon) }
.bind(payload.price_per_gallon_cash) if let Some(price_per_gallon) = payload.price_per_gallon {
.bind(&payload.note) separated.push("price_per_gallon = ");
.bind(payload.minimum_order) separated.push_bind_unseparated(price_per_gallon);
.bind(payload.service) }
.bind(payload.bio_percent) if let Some(price_per_gallon_cash) = payload.price_per_gallon_cash {
.bind(&payload.phone) separated.push("price_per_gallon_cash = ");
.bind(&payload.online_ordering) separated.push_bind_unseparated(price_per_gallon_cash);
.bind(payload.county_id) }
.bind(&payload.town) if let Some(note) = &payload.note {
.bind(listing_id) separated.push("note = ");
.bind(user.id) separated.push_bind_unseparated(note);
.fetch_optional(&*app_state.db) }
.await if let Some(minimum_order) = payload.minimum_order {
{ separated.push("minimum_order = ");
Ok(Some(listing)) => Ok(Json(listing)), separated.push_bind_unseparated(minimum_order);
Ok(None) => Err(( }
if let Some(service) = payload.service {
separated.push("service = ");
separated.push_bind_unseparated(service);
}
if let Some(bio_percent) = payload.bio_percent {
separated.push("bio_percent = ");
separated.push_bind_unseparated(bio_percent);
}
if let Some(phone) = &payload.phone {
separated.push("phone = ");
separated.push_bind_unseparated(phone);
}
if let Some(online_ordering) = &payload.online_ordering {
separated.push("online_ordering = ");
separated.push_bind_unseparated(online_ordering);
}
if let Some(county_id) = payload.county_id {
separated.push("county_id = ");
separated.push_bind_unseparated(county_id);
}
if let Some(town) = &payload.town {
separated.push("town = ");
separated.push_bind_unseparated(town);
}
separated.push("last_edited = CURRENT_TIMESTAMP");
query_builder.push(" WHERE id = ");
query_builder.push_bind(listing_id);
query_builder.push(" AND user_id = ");
query_builder.push_bind(user.id);
query_builder.push(" RETURNING id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited");
let query = query_builder.build_query_as::<Listing>();
match query.fetch_optional(&*app_state.db).await {
Ok(Some(listing)) => {
tracing::info!(user_id = user.id, listing_id = listing_id, "Listing updated successfully");
Ok(Json(listing))
},
Ok(None) => {
tracing::warn!(user_id = user.id, listing_id = listing_id, "Listing not found or access denied");
Err((
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"})) Json(json!({"error": "Listing not found or access denied"}))
)), ))
Err(e) => Err(( },
Err(e) => {
tracing::error!(user_id = user.id, listing_id = listing_id, error = %e, "Failed to update listing");
Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to update listing: {}", e)})) Json(json!({"error": format!("Failed to update listing: {}", e)}))
)), ))
},
} }
} }
@@ -200,6 +223,7 @@ pub async fn get_listings_by_county(
State(app_state): State<AppState>, State(app_state): State<AppState>,
Path(county_id): Path<i32>, Path(county_id): Path<i32>,
) -> Result<Json<Vec<Listing>>, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Json<Vec<Listing>>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(county_id = county_id, "Fetching listings by county");
match sqlx::query_as::<_, Listing>( match sqlx::query_as::<_, Listing>(
"SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE county_id = $1 AND is_active = true ORDER BY last_edited DESC" "SELECT id, company_name, is_active, price_per_gallon, price_per_gallon_cash, note, minimum_order, service, bio_percent, phone, online_ordering, county_id, town, user_id, last_edited FROM listings WHERE county_id = $1 AND is_active = true ORDER BY last_edited DESC"
) )
@@ -220,6 +244,7 @@ pub async fn delete_listing(
Path(listing_id): Path<i32>, Path(listing_id): Path<i32>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
tracing::info!(user_id = user.id, listing_id = listing_id, "Deleting listing");
match sqlx::query("DELETE FROM listings WHERE id = $1 AND user_id = $2") match sqlx::query("DELETE FROM listings WHERE id = $1 AND user_id = $2")
.bind(listing_id) .bind(listing_id)
.bind(user.id) .bind(user.id)
@@ -228,17 +253,22 @@ pub async fn delete_listing(
{ {
Ok(result) => { Ok(result) => {
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
tracing::warn!(user_id = user.id, listing_id = listing_id, "Listing not found for deletion");
Err(( Err((
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(json!({"error": "Listing not found"})) Json(json!({"error": "Listing not found"}))
)) ))
} else { } else {
tracing::info!(user_id = user.id, listing_id = listing_id, "Listing deleted successfully");
Ok(Json(json!({"success": true, "message": "Listing deleted"}))) Ok(Json(json!({"success": true, "message": "Listing deleted"})))
} }
} }
Err(e) => Err(( Err(e) => {
tracing::error!(user_id = user.id, listing_id = listing_id, error = %e, "Failed to delete listing");
Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete listing: {}", e)})) Json(json!({"error": format!("Failed to delete listing: {}", e)}))
)), ))
},
} }
} }

View File

@@ -38,6 +38,28 @@ pub struct CreateListingRequest {
pub town: Option<String>, pub town: Option<String>,
} }
impl CreateListingRequest {
pub fn validate(&self) -> Result<(), String> {
if self.price_per_gallon <= 0.0 {
return Err("Price per gallon must be greater than 0".to_string());
}
if let Some(cash_price) = self.price_per_gallon_cash {
if cash_price < 0.0 {
return Err("Cash price must be non-negative".to_string());
}
}
if self.bio_percent < 0 || self.bio_percent > 100 {
return Err("Bio percent must be between 0 and 100".to_string());
}
if let Some(min_order) = self.minimum_order {
if min_order < 0 {
return Err("Minimum order must be non-negative".to_string());
}
}
Ok(())
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct UpdateListingRequest { pub struct UpdateListingRequest {
pub company_name: Option<String>, pub company_name: Option<String>,
@@ -53,3 +75,29 @@ pub struct UpdateListingRequest {
pub county_id: Option<i32>, pub county_id: Option<i32>,
pub town: Option<String>, pub town: Option<String>,
} }
impl UpdateListingRequest {
pub fn validate(&self) -> Result<(), String> {
if let Some(price) = self.price_per_gallon {
if price <= 0.0 {
return Err("Price per gallon must be greater than 0".to_string());
}
}
if let Some(cash_price) = self.price_per_gallon_cash {
if cash_price < 0.0 {
return Err("Cash price must be non-negative".to_string());
}
}
if let Some(bio) = self.bio_percent {
if bio < 0 || bio > 100 {
return Err("Bio percent must be between 0 and 100".to_string());
}
}
if let Some(min_order) = self.minimum_order {
if min_order < 0 {
return Err("Minimum order must be non-negative".to_string());
}
}
Ok(())
}
}

View File

@@ -1,12 +1,11 @@
use axum::{ use axum::{
http::{header, Method}, http::{header, Method},
Router, Router,
}; };
use std::env; use std::env;
use tower_http::cors::{CorsLayer, Any}; use tower_http::cors::CorsLayer;
use crate::auth::structs::AppState; use crate::auth::structs::AppState;
use crate::auth::auth::{auth_middleware, login, register}; use crate::auth::auth::{auth_middleware, login, register, logout};
use crate::data::data::get_user; use crate::data::data::get_user;
use crate::state::data::{get_counties_by_state, get_county_by_id}; use crate::state::data::{get_counties_by_state, get_county_by_id};
use crate::listing::data::{get_listings, get_listing_by_id, get_listings_by_county, create_listing, update_listing, delete_listing}; use crate::listing::data::{get_listings, get_listing_by_id, get_listings_by_county, create_listing, update_listing, delete_listing};
@@ -22,17 +21,32 @@ mod listing;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
// Initialize tracing first (before any logging)
// RUST_LOG env var controls log level, e.g. RUST_LOG=debug or RUST_LOG=api_rust=debug
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("api_rust=info".parse().unwrap())
)
.init();
tracing::info!("Starting NewEnglandBio API server...");
// Load environment variables // Load environment variables
dotenv::dotenv().ok(); dotenv::dotenv().ok();
let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let _jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
let frontend_origin = env::var("FRONTEND_ORIGIN").unwrap_or_else(|_| "http://localhost:9551".to_string()); let frontend_origin = env::var("FRONTEND_ORIGIN").unwrap_or_else(|_| "http://localhost:9551".to_string());
tracing::info!(frontend_origin = %frontend_origin, "Configuration loaded");
// Connect to PostgreSQL // Connect to PostgreSQL
tracing::info!("Connecting to PostgreSQL database...");
let db_pool = PgPool::connect(&database_url) let db_pool = PgPool::connect(&database_url)
.await .await
.expect("Failed to connect to database"); .expect("Failed to connect to database");
tracing::info!("Database connection established");
let db = Arc::new(db_pool); let db = Arc::new(db_pool);
// Create app state // Create app state
@@ -41,11 +55,14 @@ async fn main() {
jwt_secret: env::var("JWT_SECRET").expect("JWT_SECRET must be set"), jwt_secret: env::var("JWT_SECRET").expect("JWT_SECRET must be set"),
}; };
// Configure CORS // Configure CORS with credentials support for cookie auth
let cors = CorsLayer::new() let cors = CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::exact(frontend_origin.parse::<header::HeaderValue>().unwrap())) .allow_origin(tower_http::cors::AllowOrigin::exact(frontend_origin.parse::<header::HeaderValue>().unwrap()))
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_headers(Any); .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
.allow_credentials(true);
tracing::debug!("CORS configured for origin: {} with credentials", frontend_origin);
// Build router with separated public and protected routes // Build router with separated public and protected routes
let protected_routes = Router::new() let protected_routes = Router::new()
@@ -61,15 +78,20 @@ async fn main() {
let public_routes = Router::new() let public_routes = Router::new()
.route("/auth/register", axum::routing::post(register)) .route("/auth/register", axum::routing::post(register))
.route("/auth/login", axum::routing::post(login)) .route("/auth/login", axum::routing::post(login))
.route("/auth/logout", axum::routing::post(logout))
.route("/categories", axum::routing::get(crate::company::category::get_all_categories)) .route("/categories", axum::routing::get(crate::company::category::get_all_categories))
.route("/state/:state_abbr", axum::routing::get(get_counties_by_state)) .route("/state/:state_abbr", axum::routing::get(get_counties_by_state))
.route("/state/:state_abbr/:county_id", axum::routing::get(get_county_by_id)) .route("/state/:state_abbr/:county_id", axum::routing::get(get_county_by_id))
.route("/listings/county/:county_id", axum::routing::get(get_listings_by_county)); .route("/listings/county/:county_id", axum::routing::get(get_listings_by_county));
let app = public_routes.merge(protected_routes).with_state(state).layer(cors); let app = public_routes
.merge(protected_routes)
.with_state(state)
.layer(cors);
// Print server status tracing::info!("Routes configured");
println!("Server is running on http://0.0.0.0:9552"); tracing::info!("Server is running on http://0.0.0.0:9552");
tracing::info!("Press Ctrl+C to stop");
// Run server // Run server
axum::Server::bind(&"0.0.0.0:9552".parse().unwrap()) axum::Server::bind(&"0.0.0.0:9552".parse().unwrap())

View File

@@ -14,7 +14,7 @@ pub async fn get_counties_by_state(
Path(state_abbr): Path<String>, Path(state_abbr): Path<String>,
) -> Result<Json<Vec<County>>, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Json<Vec<County>>, (StatusCode, Json<ErrorResponse>)> {
let state_abbr_upper = state_abbr.to_uppercase(); let state_abbr_upper = state_abbr.to_uppercase();
println!("Querying counties for state: {}", state_abbr_upper); tracing::info!(state = %state_abbr_upper, "Querying counties for state");
match sqlx::query_as::<_, County>("SELECT id, name, state FROM county WHERE UPPER(state) = $1 ORDER BY name ASC") match sqlx::query_as::<_, County>("SELECT id, name, state FROM county WHERE UPPER(state) = $1 ORDER BY name ASC")
.bind(&state_abbr_upper) .bind(&state_abbr_upper)
@@ -23,6 +23,7 @@ pub async fn get_counties_by_state(
{ {
Ok(counties) => { Ok(counties) => {
if counties.is_empty() { if counties.is_empty() {
tracing::warn!(state = %state_abbr_upper, "No counties found for state");
Err(( Err((
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(ErrorResponse { Json(ErrorResponse {
@@ -30,11 +31,12 @@ pub async fn get_counties_by_state(
}), }),
)) ))
} else { } else {
tracing::info!(state = %state_abbr_upper, count = counties.len(), "Counties retrieved");
Ok(Json(counties)) Ok(Json(counties))
} }
} }
Err(e) => { Err(e) => {
eprintln!("Database error fetching counties for state {}: {}", state_abbr_upper, e); tracing::error!(state = %state_abbr_upper, error = %e, "Database error fetching counties");
Err(( Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse { Json(ErrorResponse {
@@ -50,7 +52,7 @@ pub async fn get_county_by_id(
Path((state_abbr, county_id)): Path<(String, i32)>, Path((state_abbr, county_id)): Path<(String, i32)>,
) -> Result<Json<County>, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Json<County>, (StatusCode, Json<ErrorResponse>)> {
let state_abbr_upper = state_abbr.to_uppercase(); let state_abbr_upper = state_abbr.to_uppercase();
println!("Querying county with ID: {} for state: {}", county_id, state_abbr_upper); tracing::info!(state = %state_abbr_upper, county_id = county_id, "Querying county by ID");
match sqlx::query_as::<_, County>("SELECT id, name, state FROM county WHERE UPPER(state) = $1 AND id = $2") match sqlx::query_as::<_, County>("SELECT id, name, state FROM county WHERE UPPER(state) = $1 AND id = $2")
.bind(&state_abbr_upper) .bind(&state_abbr_upper)
@@ -58,15 +60,21 @@ pub async fn get_county_by_id(
.fetch_one(&*app_state.db) .fetch_one(&*app_state.db)
.await .await
{ {
Ok(county) => Ok(Json(county)), Ok(county) => {
Err(sqlx::Error::RowNotFound) => Err(( tracing::info!(state = %state_abbr_upper, county_id = county_id, "County retrieved");
Ok(Json(county))
},
Err(sqlx::Error::RowNotFound) => {
tracing::warn!(state = %state_abbr_upper, county_id = county_id, "County not found");
Err((
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(ErrorResponse { Json(ErrorResponse {
error: format!("County with ID {} not found in state {}", county_id, state_abbr_upper), error: format!("County with ID {} not found in state {}", county_id, state_abbr_upper),
}), }),
)), ))
},
Err(e) => { Err(e) => {
eprintln!("Database error fetching county with ID {} for state {}: {}", county_id, state_abbr_upper, e); tracing::error!(state = %state_abbr_upper, county_id = county_id, error = %e, "Database error fetching county");
Err(( Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse { Json(ErrorResponse {