276 lines
8.5 KiB
Rust
276 lines
8.5 KiB
Rust
use axum::{
|
|
extract::{Query, State},
|
|
http::StatusCode,
|
|
response::{Html, IntoResponse, Redirect, Response},
|
|
Json,
|
|
};
|
|
use axum_login::AuthSession;
|
|
use sea_orm::{DatabaseConnection, EntityTrait};
|
|
use serde::{Deserialize, Serialize};
|
|
use sha2::{Digest, Sha256};
|
|
use tower_sessions::Session;
|
|
use tracing::{info, warn};
|
|
use utoipa::ToSchema;
|
|
|
|
use crate::auth::AuthBackend;
|
|
use crate::models::User;
|
|
use crate::services::OAuthService;
|
|
|
|
/// Session key holding the CSRF token issued by the web OAuth flow; the
/// callback expects Google to echo it back as the `state` query parameter.
const CSRF_TOKEN_KEY: &str = "oauth_csrf_token";
/// Session key holding the optional frontend URL the web OAuth flow should
/// redirect to once the callback has logged the user in.
const FRONTEND_URL_KEY: &str = "oauth_frontend_url";
|
|
|
|
fn mobile_secret() -> String {
|
|
std::env::var("MOBILE_SECRET").unwrap_or_else(|_| "family-budget-mobile-secret".to_string())
|
|
}
|
|
|
|
fn sign(data: &str) -> String {
|
|
let secret = mobile_secret();
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(format!("{}:{}", secret, data).as_bytes());
|
|
hex::encode(hasher.finalize())
|
|
}
|
|
|
|
fn make_mobile_csrf_state(nonce: &str) -> String {
|
|
let sig = sign(&format!("csrf.mobile.{}", nonce));
|
|
format!("mobile.{}.{}", nonce, sig)
|
|
}
|
|
|
|
fn verify_mobile_csrf_state(state: &str) -> bool {
|
|
let mut parts = state.splitn(3, '.');
|
|
match (parts.next(), parts.next(), parts.next()) {
|
|
(Some("mobile"), Some(nonce), Some(sig)) => {
|
|
sign(&format!("csrf.mobile.{}", nonce)) == sig
|
|
}
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
fn make_auth_token(user_id: i32) -> String {
|
|
let timestamp = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs();
|
|
let payload = format!("{}.{}", user_id, timestamp);
|
|
let sig = sign(&format!("auth.{}", payload));
|
|
format!("{}.{}", payload, sig)
|
|
}
|
|
|
|
fn verify_auth_token(token: &str) -> Option<i32> {
|
|
let mut parts = token.splitn(3, '.');
|
|
let user_id_str = parts.next()?;
|
|
let timestamp_str = parts.next()?;
|
|
let sig = parts.next()?;
|
|
|
|
let payload = format!("{}.{}", user_id_str, timestamp_str);
|
|
if sign(&format!("auth.{}", payload)) != sig {
|
|
return None;
|
|
}
|
|
|
|
let timestamp: u64 = timestamp_str.parse().ok()?;
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs();
|
|
if now.saturating_sub(timestamp) > 300 {
|
|
return None;
|
|
}
|
|
|
|
user_id_str.parse().ok()
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, ToSchema)]
|
|
pub struct GoogleAuthQuery {
|
|
pub redirect_url: Option<String>,
|
|
pub mobile: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct GoogleCallbackQuery {
|
|
pub code: String,
|
|
pub state: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, ToSchema)]
|
|
pub struct OAuthUrlResponse {
|
|
pub url: String,
|
|
}
|
|
|
|
#[utoipa::path(
|
|
get,
|
|
path = "/auth/google",
|
|
tag = "auth",
|
|
params(
|
|
("redirect_url" = Option<String>, Query, description = "Frontend URL to redirect after auth"),
|
|
("mobile" = Option<bool>, Query, description = "Mobile OAuth flow")
|
|
),
|
|
responses(
|
|
(status = 200, description = "Returns Google OAuth URL", body = OAuthUrlResponse)
|
|
)
|
|
)]
|
|
pub async fn google_auth(
|
|
session: Session,
|
|
Query(query): Query<GoogleAuthQuery>,
|
|
) -> Result<Json<OAuthUrlResponse>, StatusCode> {
|
|
let oauth_service = OAuthService::new();
|
|
|
|
if query.mobile.unwrap_or(false) {
|
|
let nonce = uuid::Uuid::new_v4().to_string();
|
|
let mobile_state = make_mobile_csrf_state(&nonce);
|
|
let auth_url = oauth_service.get_auth_url_with_state(mobile_state);
|
|
info!("mobile google_auth: generated signed state for nonce={}", nonce);
|
|
return Ok(Json(OAuthUrlResponse { url: auth_url }));
|
|
}
|
|
|
|
let (auth_url, csrf_token) = oauth_service.get_auth_url();
|
|
|
|
session
|
|
.insert(CSRF_TOKEN_KEY, csrf_token.secret().clone())
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
if let Some(redirect_url) = query.redirect_url {
|
|
session
|
|
.insert(FRONTEND_URL_KEY, redirect_url)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
}
|
|
|
|
Ok(Json(OAuthUrlResponse { url: auth_url }))
|
|
}
|
|
|
|
#[utoipa::path(
|
|
get,
|
|
path = "/auth/google/callback",
|
|
tag = "auth",
|
|
responses(
|
|
(status = 302, description = "Redirects to frontend after successful auth"),
|
|
(status = 401, description = "Authentication failed")
|
|
)
|
|
)]
|
|
pub async fn google_callback(
|
|
mut auth_session: AuthSession<AuthBackend>,
|
|
session: Session,
|
|
State(db): State<DatabaseConnection>,
|
|
Query(query): Query<GoogleCallbackQuery>,
|
|
) -> Result<Response, StatusCode> {
|
|
let is_mobile = verify_mobile_csrf_state(&query.state);
|
|
info!("google_callback: state={} is_mobile={}", &query.state[..query.state.len().min(20)], is_mobile);
|
|
|
|
if !is_mobile {
|
|
let session_csrf: Option<String> = session
|
|
.get(CSRF_TOKEN_KEY)
|
|
.await
|
|
.unwrap_or(None);
|
|
session.remove::<String>(CSRF_TOKEN_KEY).await.ok();
|
|
|
|
match session_csrf {
|
|
Some(csrf) if csrf == query.state => {}
|
|
_ => {
|
|
warn!("google_callback: CSRF mismatch, session_csrf={:?}", session_csrf.as_deref().map(|s| &s[..s.len().min(10)]));
|
|
return Err(StatusCode::UNAUTHORIZED);
|
|
}
|
|
}
|
|
}
|
|
|
|
let frontend_url: Option<String> = session
|
|
.get(FRONTEND_URL_KEY)
|
|
.await
|
|
.unwrap_or(None);
|
|
session.remove::<String>(FRONTEND_URL_KEY).await.ok();
|
|
|
|
let oauth_service = OAuthService::new();
|
|
|
|
let access_token = oauth_service
|
|
.exchange_code(query.code)
|
|
.await
|
|
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
|
|
|
let google_user = oauth_service
|
|
.get_user_info(&access_token)
|
|
.await
|
|
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
|
|
|
let user = oauth_service
|
|
.find_or_create_user(&db, google_user)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
if is_mobile {
|
|
let token = make_auth_token(user.id);
|
|
info!("google_callback: mobile auth for user_id={}, token_prefix={}", user.id, &token[..token.len().min(20)]);
|
|
let deep_link = format!("com.arrelin.family-budget-android://auth?token={}", token);
|
|
let html = format!(
|
|
r#"<!DOCTYPE html><html><head><meta http-equiv="refresh" content="0;url={0}"></head><body><script>window.location="{0}"</script></body></html>"#,
|
|
deep_link
|
|
);
|
|
return Ok(Html(html).into_response());
|
|
}
|
|
|
|
auth_session
|
|
.login(&user)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
if let Some(family_id) = user.family_id {
|
|
let mut authorized_families: Vec<i32> = session
|
|
.get("authorized_families")
|
|
.await
|
|
.unwrap_or(None)
|
|
.unwrap_or_default();
|
|
if !authorized_families.contains(&family_id) {
|
|
authorized_families.push(family_id);
|
|
session.insert("authorized_families", authorized_families).await.ok();
|
|
}
|
|
}
|
|
|
|
let redirect_url = frontend_url.unwrap_or_else(|| "http://localhost:3000".to_string());
|
|
Ok(Redirect::temporary(&redirect_url).into_response())
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct MobileCallbackQuery {
|
|
pub token: String,
|
|
}
|
|
|
|
pub async fn mobile_callback(
|
|
mut auth_session: AuthSession<AuthBackend>,
|
|
session: Session,
|
|
State(db): State<DatabaseConnection>,
|
|
Query(query): Query<MobileCallbackQuery>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
info!("mobile_callback: received token_prefix={}", &query.token[..query.token.len().min(20)]);
|
|
let user_id = match verify_auth_token(&query.token) {
|
|
Some(id) => id,
|
|
None => {
|
|
warn!("mobile_callback: token verification failed for token={}", &query.token[..query.token.len().min(40)]);
|
|
return Err(StatusCode::UNAUTHORIZED);
|
|
}
|
|
};
|
|
info!("mobile_callback: token valid for user_id={}", user_id);
|
|
|
|
let user = User::find_by_id(user_id)
|
|
.one(&db)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
|
.ok_or(StatusCode::UNAUTHORIZED)?;
|
|
|
|
auth_session
|
|
.login(&user)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
if let Some(family_id) = user.family_id {
|
|
let mut authorized_families: Vec<i32> = session
|
|
.get("authorized_families")
|
|
.await
|
|
.unwrap_or(None)
|
|
.unwrap_or_default();
|
|
if !authorized_families.contains(&family_id) {
|
|
authorized_families.push(family_id);
|
|
session.insert("authorized_families", authorized_families).await.ok();
|
|
}
|
|
}
|
|
|
|
Ok(Json(serde_json::json!({"success": true})))
|
|
}
|