use axum::{ extract::{Query, State}, http::StatusCode, response::{Html, IntoResponse, Redirect, Response}, Json, }; use axum_login::AuthSession; use sea_orm::{DatabaseConnection, EntityTrait}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tower_sessions::Session; use tracing::{info, warn}; use utoipa::ToSchema; use crate::auth::AuthBackend; use crate::models::User; use crate::services::OAuthService; const CSRF_TOKEN_KEY: &str = "oauth_csrf_token"; const FRONTEND_URL_KEY: &str = "oauth_frontend_url"; fn mobile_secret() -> String { std::env::var("MOBILE_SECRET").unwrap_or_else(|_| "family-budget-mobile-secret".to_string()) } fn sign(data: &str) -> String { let secret = mobile_secret(); let mut hasher = Sha256::new(); hasher.update(format!("{}:{}", secret, data).as_bytes()); hex::encode(hasher.finalize()) } fn make_mobile_csrf_state(nonce: &str) -> String { let sig = sign(&format!("csrf.mobile.{}", nonce)); format!("mobile.{}.{}", nonce, sig) } fn verify_mobile_csrf_state(state: &str) -> bool { let mut parts = state.splitn(3, '.'); match (parts.next(), parts.next(), parts.next()) { (Some("mobile"), Some(nonce), Some(sig)) => { sign(&format!("csrf.mobile.{}", nonce)) == sig } _ => false, } } fn make_auth_token(user_id: i32) -> String { let timestamp = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs(); let payload = format!("{}.{}", user_id, timestamp); let sig = sign(&format!("auth.{}", payload)); format!("{}.{}", payload, sig) } fn verify_auth_token(token: &str) -> Option { let mut parts = token.splitn(3, '.'); let user_id_str = parts.next()?; let timestamp_str = parts.next()?; let sig = parts.next()?; let payload = format!("{}.{}", user_id_str, timestamp_str); if sign(&format!("auth.{}", payload)) != sig { return None; } let timestamp: u64 = timestamp_str.parse().ok()?; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs(); if now.saturating_sub(timestamp) > 300 { return None; } user_id_str.parse().ok() } #[derive(Debug, Deserialize, ToSchema)] pub struct GoogleAuthQuery { pub redirect_url: Option, pub mobile: Option, } #[derive(Debug, Deserialize)] pub struct GoogleCallbackQuery { pub code: String, pub state: String, } #[derive(Debug, Serialize, ToSchema)] pub struct OAuthUrlResponse { pub url: String, } #[utoipa::path( get, path = "/auth/google", tag = "auth", params( ("redirect_url" = Option, Query, description = "Frontend URL to redirect after auth"), ("mobile" = Option, Query, description = "Mobile OAuth flow") ), responses( (status = 200, description = "Returns Google OAuth URL", body = OAuthUrlResponse) ) )] pub async fn google_auth( session: Session, Query(query): Query, ) -> Result, StatusCode> { let oauth_service = OAuthService::new(); if query.mobile.unwrap_or(false) { let nonce = uuid::Uuid::new_v4().to_string(); let mobile_state = make_mobile_csrf_state(&nonce); let auth_url = oauth_service.get_auth_url_with_state(mobile_state); info!("mobile google_auth: generated signed state for nonce={}", nonce); return Ok(Json(OAuthUrlResponse { url: auth_url })); } let (auth_url, csrf_token) = oauth_service.get_auth_url(); session .insert(CSRF_TOKEN_KEY, csrf_token.secret().clone()) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if let Some(redirect_url) = query.redirect_url { session .insert(FRONTEND_URL_KEY, redirect_url) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; } Ok(Json(OAuthUrlResponse { url: auth_url })) } #[utoipa::path( get, path = "/auth/google/callback", tag = "auth", responses( (status = 302, description = "Redirects to frontend after successful auth"), (status = 401, description = "Authentication failed") ) )] pub async fn google_callback( mut auth_session: AuthSession, session: Session, State(db): State, Query(query): Query, ) -> Result { let is_mobile = verify_mobile_csrf_state(&query.state); info!("google_callback: state={} is_mobile={}", &query.state[..query.state.len().min(20)], is_mobile); if !is_mobile { let session_csrf: Option = session .get(CSRF_TOKEN_KEY) .await .unwrap_or(None); session.remove::(CSRF_TOKEN_KEY).await.ok(); match session_csrf { Some(csrf) if csrf == query.state => {} _ => { warn!("google_callback: CSRF mismatch, session_csrf={:?}", session_csrf.as_deref().map(|s| &s[..s.len().min(10)])); return Err(StatusCode::UNAUTHORIZED); } } } let frontend_url: Option = session .get(FRONTEND_URL_KEY) .await .unwrap_or(None); session.remove::(FRONTEND_URL_KEY).await.ok(); let oauth_service = OAuthService::new(); let access_token = oauth_service .exchange_code(query.code) .await .map_err(|_| StatusCode::UNAUTHORIZED)?; let google_user = oauth_service .get_user_info(&access_token) .await .map_err(|_| StatusCode::UNAUTHORIZED)?; let user = oauth_service .find_or_create_user(&db, google_user) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if is_mobile { let token = make_auth_token(user.id); info!("google_callback: mobile auth for user_id={}, token_prefix={}", user.id, &token[..token.len().min(20)]); let deep_link = format!("com.arrelin.family-budget-android://auth?token={}", token); let html = format!( r#""#, deep_link ); return Ok(Html(html).into_response()); } auth_session .login(&user) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if let Some(family_id) = user.family_id { let mut authorized_families: Vec = session .get("authorized_families") .await .unwrap_or(None) .unwrap_or_default(); if !authorized_families.contains(&family_id) { authorized_families.push(family_id); session.insert("authorized_families", authorized_families).await.ok(); } } let redirect_url = frontend_url.unwrap_or_else(|| "http://localhost:3000".to_string()); Ok(Redirect::temporary(&redirect_url).into_response()) } #[derive(Debug, Deserialize)] pub struct MobileCallbackQuery { pub token: String, } pub async fn mobile_callback( mut auth_session: AuthSession, session: Session, State(db): State, Query(query): Query, ) -> Result, StatusCode> { info!("mobile_callback: received token_prefix={}", &query.token[..query.token.len().min(20)]); let user_id = match verify_auth_token(&query.token) { Some(id) => id, None => { warn!("mobile_callback: token verification failed for token={}", &query.token[..query.token.len().min(40)]); return Err(StatusCode::UNAUTHORIZED); } }; info!("mobile_callback: token valid for user_id={}", user_id); let user = User::find_by_id(user_id) .one(&db) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::UNAUTHORIZED)?; auth_session .login(&user) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if let Some(family_id) = user.family_id { let mut authorized_families: Vec = session .get("authorized_families") .await .unwrap_or(None) .unwrap_or_default(); if !authorized_families.contains(&family_id) { authorized_families.push(family_id); session.insert("authorized_families", authorized_families).await.ok(); } } Ok(Json(serde_json::json!({"success": true}))) }