Merge pull request 'mobile update' (#34) from bugfix/iro4ka into master
All checks were successful
Build and Publish Images / build-and-push (push) Successful in 2m8s

Reviewed-on: http://192.168.31.100:3847/Arrelin/family_budget/pulls/34
This commit was merged in pull request #34.
This commit is contained in:
2026-03-10 14:28:58 +03:00
5 changed files with 134 additions and 84 deletions

View File

@@ -26,3 +26,5 @@ oauth2 = { version = "5.0.0", features = ["reqwest"] }
reqwest = { version = "0.13.1", features = ["json"] } reqwest = { version = "0.13.1", features = ["json"] }
rand = "0.9.2" rand = "0.9.2"
uuid = { version = "1", features = ["v4"] } uuid = { version = "1", features = ["v4"] }
sha2 = "0.10"
hex = "0.4"

View File

@@ -1,14 +1,10 @@
use axum::{ use axum::{
routing::{get, post, put, delete}, routing::{get, post, put, delete},
Router, middleware as axum_middleware, Router, middleware as axum_middleware,
Extension,
}; };
use sea_orm::{sqlx, Database, DatabaseConnection, DbErr}; use sea_orm::{sqlx, Database, DatabaseConnection, DbErr};
use sea_orm_migration::prelude::*; use sea_orm_migration::prelude::*;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use tower_sessions::{Expiry, SessionManagerLayer, cookie::SameSite}; use tower_sessions::{Expiry, SessionManagerLayer, cookie::SameSite};
@@ -18,12 +14,6 @@ use time::Duration;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use axum::http::{Method, HeaderValue}; use axum::http::{Method, HeaderValue};
pub enum MobileStoreEntry {
Csrf { created_at: Instant },
AuthToken { user_id: i32, created_at: Instant },
}
pub type MobileTokenStore = Arc<Mutex<HashMap<String, MobileStoreEntry>>>;
pub mod models; pub mod models;
pub mod services; pub mod services;
@@ -179,13 +169,10 @@ pub async fn create_app(db: DatabaseConnection) -> Result<Router, DbErr> {
.layer(auth_layer.clone()) .layer(auth_layer.clone())
.with_state(db.clone()); .with_state(db.clone());
let mobile_token_store: MobileTokenStore = Arc::new(Mutex::new(HashMap::new()));
let oauth_routes = Router::new() let oauth_routes = Router::new()
.route("/auth/google", get(routes::oauth::google_auth)) .route("/auth/google", get(routes::oauth::google_auth))
.route("/auth/google/callback", get(routes::oauth::google_callback)) .route("/auth/google/callback", get(routes::oauth::google_callback))
.route("/auth/mobile-callback", get(routes::oauth::mobile_callback)) .route("/auth/mobile-callback", get(routes::oauth::mobile_callback))
.layer(Extension(mobile_token_store))
.layer(auth_layer.clone()) .layer(auth_layer.clone())
.with_state(db.clone()); .with_state(db.clone());

View File

@@ -2,23 +2,81 @@ use axum::{
extract::{Query, State}, extract::{Query, State},
http::StatusCode, http::StatusCode,
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Redirect, Response},
Extension,
Json, Json,
}; };
use axum_login::AuthSession; use axum_login::AuthSession;
use sea_orm::{DatabaseConnection, EntityTrait}; use sea_orm::{DatabaseConnection, EntityTrait};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tower_sessions::Session; use tower_sessions::Session;
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::auth::AuthBackend; use crate::auth::AuthBackend;
use crate::models::User; use crate::models::User;
use crate::services::OAuthService; use crate::services::OAuthService;
use crate::{MobileStoreEntry, MobileTokenStore};
const CSRF_TOKEN_KEY: &str = "oauth_csrf_token"; const CSRF_TOKEN_KEY: &str = "oauth_csrf_token";
const FRONTEND_URL_KEY: &str = "oauth_frontend_url"; const FRONTEND_URL_KEY: &str = "oauth_frontend_url";
fn mobile_secret() -> String {
std::env::var("MOBILE_SECRET").unwrap_or_else(|_| "family-budget-mobile-secret".to_string())
}
fn sign(data: &str) -> String {
let secret = mobile_secret();
let mut hasher = Sha256::new();
hasher.update(format!("{}:{}", secret, data).as_bytes());
hex::encode(hasher.finalize())
}
fn make_mobile_csrf_state(nonce: &str) -> String {
let sig = sign(&format!("csrf.mobile.{}", nonce));
format!("mobile.{}.{}", nonce, sig)
}
fn verify_mobile_csrf_state(state: &str) -> bool {
let mut parts = state.splitn(3, '.');
match (parts.next(), parts.next(), parts.next()) {
(Some("mobile"), Some(nonce), Some(sig)) => {
sign(&format!("csrf.mobile.{}", nonce)) == sig
}
_ => false,
}
}
fn make_auth_token(user_id: i32) -> String {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let payload = format!("{}.{}", user_id, timestamp);
let sig = sign(&format!("auth.{}", payload));
format!("{}.{}", payload, sig)
}
fn verify_auth_token(token: &str) -> Option<i32> {
let mut parts = token.splitn(3, '.');
let user_id_str = parts.next()?;
let timestamp_str = parts.next()?;
let sig = parts.next()?;
let payload = format!("{}.{}", user_id_str, timestamp_str);
if sign(&format!("auth.{}", payload)) != sig {
return None;
}
let timestamp: u64 = timestamp_str.parse().ok()?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
if now.saturating_sub(timestamp) > 300 {
return None;
}
user_id_str.parse().ok()
}
#[derive(Debug, Deserialize, ToSchema)] #[derive(Debug, Deserialize, ToSchema)]
pub struct GoogleAuthQuery { pub struct GoogleAuthQuery {
pub redirect_url: Option<String>, pub redirect_url: Option<String>,
@@ -41,7 +99,8 @@ pub struct OAuthUrlResponse {
path = "/auth/google", path = "/auth/google",
tag = "auth", tag = "auth",
params( params(
("redirect_url" = Option<String>, Query, description = "Frontend URL to redirect after auth") ("redirect_url" = Option<String>, Query, description = "Frontend URL to redirect after auth"),
("mobile" = Option<bool>, Query, description = "Mobile OAuth flow")
), ),
responses( responses(
(status = 200, description = "Returns Google OAuth URL", body = OAuthUrlResponse) (status = 200, description = "Returns Google OAuth URL", body = OAuthUrlResponse)
@@ -49,30 +108,29 @@ pub struct OAuthUrlResponse {
)] )]
pub async fn google_auth( pub async fn google_auth(
session: Session, session: Session,
Extension(token_store): Extension<MobileTokenStore>,
Query(query): Query<GoogleAuthQuery>, Query(query): Query<GoogleAuthQuery>,
) -> Result<Json<OAuthUrlResponse>, StatusCode> { ) -> Result<Json<OAuthUrlResponse>, StatusCode> {
let oauth_service = OAuthService::new(); let oauth_service = OAuthService::new();
let (auth_url, csrf_token) = oauth_service.get_auth_url();
if query.mobile.unwrap_or(false) { if query.mobile.unwrap_or(false) {
let mut store = token_store.lock().unwrap(); let nonce = uuid::Uuid::new_v4().to_string();
store.insert( let mobile_state = make_mobile_csrf_state(&nonce);
format!("csrf:{}", csrf_token.secret()), let auth_url = oauth_service.get_auth_url_with_state(mobile_state);
MobileStoreEntry::Csrf { created_at: std::time::Instant::now() }, return Ok(Json(OAuthUrlResponse { url: auth_url }));
); }
} else {
let (auth_url, csrf_token) = oauth_service.get_auth_url();
session
.insert(CSRF_TOKEN_KEY, csrf_token.secret().clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(redirect_url) = query.redirect_url {
session session
.insert(CSRF_TOKEN_KEY, csrf_token.secret().clone()) .insert(FRONTEND_URL_KEY, redirect_url)
.await .await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(redirect_url) = query.redirect_url {
session
.insert(FRONTEND_URL_KEY, redirect_url)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
} }
Ok(Json(OAuthUrlResponse { url: auth_url })) Ok(Json(OAuthUrlResponse { url: auth_url }))
@@ -91,34 +149,21 @@ pub async fn google_callback(
mut auth_session: AuthSession<AuthBackend>, mut auth_session: AuthSession<AuthBackend>,
session: Session, session: Session,
State(db): State<DatabaseConnection>, State(db): State<DatabaseConnection>,
Extension(token_store): Extension<MobileTokenStore>,
Query(query): Query<GoogleCallbackQuery>, Query(query): Query<GoogleCallbackQuery>,
) -> Result<Response, StatusCode> { ) -> Result<Response, StatusCode> {
let session_csrf: Option<String> = session let is_mobile = verify_mobile_csrf_state(&query.state);
.get(CSRF_TOKEN_KEY)
.await
.unwrap_or(None);
let is_mobile; if !is_mobile {
let csrf_valid; let session_csrf: Option<String> = session
.get(CSRF_TOKEN_KEY)
if let Some(csrf) = session_csrf { .await
is_mobile = false; .unwrap_or(None);
csrf_valid = csrf == query.state;
session.remove::<String>(CSRF_TOKEN_KEY).await.ok(); session.remove::<String>(CSRF_TOKEN_KEY).await.ok();
} else {
let key = format!("csrf:{}", &query.state);
let mut store = token_store.lock().unwrap();
csrf_valid = matches!(
store.get(&key),
Some(MobileStoreEntry::Csrf { created_at }) if created_at.elapsed().as_secs() < 300
);
store.remove(&key);
is_mobile = csrf_valid;
}
if !csrf_valid { match session_csrf {
return Err(StatusCode::UNAUTHORIZED); Some(csrf) if csrf == query.state => {}
_ => return Err(StatusCode::UNAUTHORIZED),
}
} }
let frontend_url: Option<String> = session let frontend_url: Option<String> = session
@@ -145,14 +190,7 @@ pub async fn google_callback(
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if is_mobile { if is_mobile {
let token = uuid::Uuid::new_v4().to_string(); let token = make_auth_token(user.id);
{
let mut store = token_store.lock().unwrap();
store.insert(
token.clone(),
MobileStoreEntry::AuthToken { user_id: user.id, created_at: std::time::Instant::now() },
);
}
let deep_link = format!("com.arrelin.family-budget-android://auth?token={}", token); let deep_link = format!("com.arrelin.family-budget-android://auth?token={}", token);
let html = format!( let html = format!(
r#"<!DOCTYPE html><html><head><meta http-equiv="refresh" content="0;url={0}"></head><body><script>window.location="{0}"</script></body></html>"#, r#"<!DOCTYPE html><html><head><meta http-equiv="refresh" content="0;url={0}"></head><body><script>window.location="{0}"</script></body></html>"#,
@@ -189,26 +227,11 @@ pub struct MobileCallbackQuery {
pub async fn mobile_callback( pub async fn mobile_callback(
mut auth_session: AuthSession<AuthBackend>, mut auth_session: AuthSession<AuthBackend>,
session: Session,
State(db): State<DatabaseConnection>, State(db): State<DatabaseConnection>,
Extension(token_store): Extension<MobileTokenStore>,
Query(query): Query<MobileCallbackQuery>, Query(query): Query<MobileCallbackQuery>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let user_id = { let user_id = verify_auth_token(&query.token).ok_or(StatusCode::UNAUTHORIZED)?;
let mut store = token_store.lock().unwrap();
match store.get(&query.token) {
Some(MobileStoreEntry::AuthToken { user_id, created_at })
if created_at.elapsed().as_secs() < 300 =>
{
let uid = *user_id;
store.remove(&query.token);
uid
}
_ => {
store.remove(&query.token);
return Err(StatusCode::UNAUTHORIZED);
}
}
};
let user = User::find_by_id(user_id) let user = User::find_by_id(user_id)
.one(&db) .one(&db)
@@ -221,5 +244,17 @@ pub async fn mobile_callback(
.await .await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(family_id) = user.family_id {
let mut authorized_families: Vec<i32> = session
.get("authorized_families")
.await
.unwrap_or(None)
.unwrap_or_default();
if !authorized_families.contains(&family_id) {
authorized_families.push(family_id);
session.insert("authorized_families", authorized_families).await.ok();
}
}
Ok(Json(serde_json::json!({"success": true}))) Ok(Json(serde_json::json!({"success": true})))
} }

View File

@@ -44,6 +44,26 @@ impl OAuthService {
(auth_url.to_string(), csrf_token) (auth_url.to_string(), csrf_token)
} }
pub fn get_auth_url_with_state(&self, state: String) -> String {
let client_id = std::env::var("GOOGLE_CLIENT_ID")
.expect("GOOGLE_CLIENT_ID must be set");
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET")
.expect("GOOGLE_CLIENT_SECRET must be set");
let redirect_url = std::env::var("GOOGLE_REDIRECT_URL")
.unwrap_or_else(|_| "http://localhost:8080/api/auth/google/callback".to_string());
let client = Self::get_client(client_id, client_secret, redirect_url);
let (auth_url, _) = client
.authorize_url(move || CsrfToken::new(state))
.add_scope(Scope::new("openid".to_string()))
.add_scope(Scope::new("email".to_string()))
.add_scope(Scope::new("profile".to_string()))
.url();
auth_url.to_string()
}
pub async fn exchange_code(&self, code: String) -> Result<String, OAuthError> { pub async fn exchange_code(&self, code: String) -> Result<String, OAuthError> {
let client_id = std::env::var("GOOGLE_CLIENT_ID") let client_id = std::env::var("GOOGLE_CLIENT_ID")
.expect("GOOGLE_CLIENT_ID must be set"); .expect("GOOGLE_CLIENT_ID must be set");

View File

@@ -26,8 +26,14 @@ export default function Login() {
const url = Array.isArray(urls) ? urls[0] : urls; const url = Array.isArray(urls) ? urls[0] : urls;
if (!url.startsWith(DEEP_LINK_SCHEME)) return; if (!url.startsWith(DEEP_LINK_SCHEME)) return;
const token = new URL(url).searchParams.get('token'); let token: string | null;
if (!token) return; try {
token = new URL(url).searchParams.get('token');
} catch {
setError(t('login.error'));
return;
}
if (!token) { setError(t('login.error')); return; }
try { try {
setLoading(true); setLoading(true);
@@ -35,7 +41,7 @@ export default function Login() {
const me = await authApi.me(); const me = await authApi.me();
setUser(me.data); setUser(me.data);
} catch { } catch {
setError(t('login.authError')); setError(t('login.error'));
setLoading(false); setLoading(false);
} }
}); });