oauth2
This commit is contained in:
@@ -21,4 +21,7 @@ tower-sessions-sqlx-store = { version = "0.12", features = ["postgres"] }
|
||||
argon2 = "0.5"
|
||||
async-trait = "0.1"
|
||||
thiserror = "2.0"
|
||||
time = "0.3"
|
||||
time = "0.3"
|
||||
oauth2 = { version = "5.0.0", features = ["reqwest"] }
|
||||
reqwest = { version = "0.12.28", features = ["json"] }
|
||||
rand = "0.8"
|
||||
@@ -20,7 +20,7 @@ impl AuthUser for user::Model {
|
||||
}
|
||||
|
||||
fn session_auth_hash(&self) -> &[u8] {
|
||||
self.password_hash.as_bytes()
|
||||
self.password_hash.as_deref().unwrap_or("oauth").as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,8 @@ impl AuthnBackend for AuthBackend {
|
||||
.await?;
|
||||
|
||||
if let Some(user) = user {
|
||||
let parsed_hash = PasswordHash::new(&user.password_hash)
|
||||
let password_hash = user.password_hash.as_ref().ok_or(Error::InvalidCredentials)?;
|
||||
let parsed_hash = PasswordHash::new(password_hash)
|
||||
.map_err(|_| Error::PasswordHash)?;
|
||||
|
||||
let is_valid = Argon2::default()
|
||||
|
||||
@@ -29,7 +29,12 @@ pub use middleware::{require_admin, require_family_access};
|
||||
paths(
|
||||
routes::auth::login,
|
||||
routes::auth::logout,
|
||||
routes::auth::me,
|
||||
routes::auth::family_login,
|
||||
routes::oauth::google_auth,
|
||||
routes::oauth::google_callback,
|
||||
routes::family::create_family,
|
||||
routes::family::create_my_family,
|
||||
routes::family::get_family,
|
||||
routes::family::get_all_families,
|
||||
routes::family::update_family,
|
||||
@@ -53,6 +58,11 @@ pub use middleware::{require_admin, require_family_access};
|
||||
routes::shopping_item::mark_as_purchased,
|
||||
routes::shopping_item::mark_all_as_purchased,
|
||||
routes::shopping_item::clear_all,
|
||||
routes::invite_link::create_invite_link,
|
||||
routes::invite_link::get_my_invite_links,
|
||||
routes::invite_link::delete_invite_link,
|
||||
routes::invite_link::validate_invite_link,
|
||||
routes::invite_link::join_family_via_invite,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
@@ -62,7 +72,13 @@ pub use middleware::{require_admin, require_family_access};
|
||||
models::shopping_item::Model,
|
||||
routes::auth::LoginRequest,
|
||||
routes::auth::LoginResponse,
|
||||
routes::auth::MeResponse,
|
||||
routes::auth::FamilyLoginRequest,
|
||||
routes::auth::FamilyLoginResponse,
|
||||
routes::oauth::OAuthUrlResponse,
|
||||
routes::family::CreateFamilyRequest,
|
||||
routes::family::CreateMyFamilyRequest,
|
||||
routes::family::CreateMyFamilyResponse,
|
||||
routes::family::UpdateFamilyRequest,
|
||||
routes::category::CreateCategoryRequest,
|
||||
routes::category::UpdateCategoryRequest,
|
||||
@@ -73,6 +89,11 @@ pub use middleware::{require_admin, require_family_access};
|
||||
routes::shopping_item::UpdateShoppingItemRequest,
|
||||
routes::shopping_item::MarkAsPurchasedRequest,
|
||||
routes::shopping_item::BulkOperationResponse,
|
||||
models::invite_link::Model,
|
||||
routes::invite_link::CreateInviteLinkRequest,
|
||||
routes::invite_link::InviteLinkResponse,
|
||||
routes::invite_link::ValidateInviteResponse,
|
||||
routes::invite_link::JoinFamilyResponse,
|
||||
)
|
||||
),
|
||||
tags(
|
||||
@@ -80,7 +101,8 @@ pub use middleware::{require_admin, require_family_access};
|
||||
(name = "families", description = "Family management endpoints"),
|
||||
(name = "categories", description = "Category management endpoints"),
|
||||
(name = "expenses", description = "Expense management endpoints"),
|
||||
(name = "shopping-items", description = "Shopping list management endpoints")
|
||||
(name = "shopping-items", description = "Shopping list management endpoints"),
|
||||
(name = "invite-links", description = "Family invite link management endpoints")
|
||||
),
|
||||
info(
|
||||
title = "Family Budget API",
|
||||
@@ -130,6 +152,23 @@ pub async fn create_app(db: DatabaseConnection) -> Result<Router, DbErr> {
|
||||
let auth_routes = Router::new()
|
||||
.route("/login", post(routes::auth::login))
|
||||
.route("/logout", post(routes::auth::logout))
|
||||
.route("/me", get(routes::auth::me))
|
||||
.route("/my-family", post(routes::family::create_my_family))
|
||||
.route("/auth/family-login", post(routes::auth::family_login))
|
||||
.layer(auth_layer.clone())
|
||||
.with_state(db.clone());
|
||||
|
||||
let oauth_routes = Router::new()
|
||||
.route("/auth/google", get(routes::oauth::google_auth))
|
||||
.route("/auth/google/callback", get(routes::oauth::google_callback))
|
||||
.layer(auth_layer.clone())
|
||||
.with_state(db.clone());
|
||||
|
||||
let invite_link_routes = Router::new()
|
||||
.route("/my-family/invite-links", post(routes::invite_link::create_invite_link))
|
||||
.route("/my-family/invite-links", get(routes::invite_link::get_my_invite_links))
|
||||
.route("/my-family/invite-links/:token", delete(routes::invite_link::delete_invite_link))
|
||||
.route("/invite/:token/join", post(routes::invite_link::join_family_via_invite))
|
||||
.layer(auth_layer)
|
||||
.with_state(db.clone());
|
||||
|
||||
@@ -162,12 +201,15 @@ pub async fn create_app(db: DatabaseConnection) -> Result<Router, DbErr> {
|
||||
.route("/families/:id", get(routes::family::get_family))
|
||||
.route("/families/:id", put(routes::family::update_family))
|
||||
.route("/families/:id/verify", post(routes::family::verify_family_password))
|
||||
.route("/invite/:token", get(routes::invite_link::validate_invite_link))
|
||||
.layer(session_layer)
|
||||
.with_state(db);
|
||||
|
||||
let api_routes = Router::new()
|
||||
.merge(admin_family_routes)
|
||||
.merge(auth_routes)
|
||||
.merge(oauth_routes)
|
||||
.merge(invite_link_routes)
|
||||
.merge(family_protected_routes)
|
||||
.merge(public_routes);
|
||||
|
||||
|
||||
77
backend/src/migration/m20250116_000001_add_oauth_fields.rs
Normal file
77
backend/src/migration/m20250116_000001_add_oauth_fields.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.alter_table(
|
||||
Table::alter()
|
||||
.table(User::Table)
|
||||
.add_column(ColumnDef::new(User::Email).string().unique_key())
|
||||
.add_column(ColumnDef::new(User::GoogleId).string().unique_key())
|
||||
.add_column(ColumnDef::new(User::FamilyId).integer())
|
||||
.modify_column(ColumnDef::new(User::PasswordHash).string().null())
|
||||
.modify_column(ColumnDef::new(User::Username).string().null())
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
manager
|
||||
.create_foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("fk_user_family")
|
||||
.from(User::Table, User::FamilyId)
|
||||
.to(Family::Table, Family::Id)
|
||||
.on_delete(ForeignKeyAction::SetNull)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.drop_foreign_key(
|
||||
ForeignKey::drop()
|
||||
.name("fk_user_family")
|
||||
.table(User::Table)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
manager
|
||||
.alter_table(
|
||||
Table::alter()
|
||||
.table(User::Table)
|
||||
.drop_column(User::Email)
|
||||
.drop_column(User::GoogleId)
|
||||
.drop_column(User::FamilyId)
|
||||
.modify_column(ColumnDef::new(User::PasswordHash).string().not_null())
|
||||
.modify_column(ColumnDef::new(User::Username).string().not_null())
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum User {
|
||||
Table,
|
||||
Email,
|
||||
GoogleId,
|
||||
FamilyId,
|
||||
PasswordHash,
|
||||
Username,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum Family {
|
||||
Table,
|
||||
Id,
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.create_table(
|
||||
Table::create()
|
||||
.table(InviteLink::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(InviteLink::Id)
|
||||
.integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key(),
|
||||
)
|
||||
.col(ColumnDef::new(InviteLink::FamilyId).integer().not_null())
|
||||
.col(
|
||||
ColumnDef::new(InviteLink::Token)
|
||||
.string()
|
||||
.not_null()
|
||||
.unique_key(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(InviteLink::CreatedAt)
|
||||
.timestamp()
|
||||
.not_null()
|
||||
.default(Expr::current_timestamp()),
|
||||
)
|
||||
.col(ColumnDef::new(InviteLink::ExpiresAt).timestamp())
|
||||
.col(ColumnDef::new(InviteLink::MaxUses).integer())
|
||||
.col(
|
||||
ColumnDef::new(InviteLink::UsesCount)
|
||||
.integer()
|
||||
.not_null()
|
||||
.default(0),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(InviteLink::CreatedBy)
|
||||
.integer()
|
||||
.not_null(),
|
||||
)
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("fk_invite_link_family")
|
||||
.from(InviteLink::Table, InviteLink::FamilyId)
|
||||
.to(Family::Table, Family::Id)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("fk_invite_link_created_by")
|
||||
.from(InviteLink::Table, InviteLink::CreatedBy)
|
||||
.to(User::Table, User::Id)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.to_owned(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.drop_table(Table::drop().table(InviteLink::Table).to_owned())
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum InviteLink {
|
||||
Table,
|
||||
Id,
|
||||
FamilyId,
|
||||
Token,
|
||||
CreatedAt,
|
||||
ExpiresAt,
|
||||
MaxUses,
|
||||
UsesCount,
|
||||
CreatedBy,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum Family {
|
||||
Table,
|
||||
Id,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum User {
|
||||
Table,
|
||||
Id,
|
||||
}
|
||||
@@ -5,6 +5,8 @@ mod m20241209_000002_create_users;
|
||||
mod m20241209_000003_seed_admin;
|
||||
mod m20241215_000001_add_family_password;
|
||||
mod m20241224_000001_create_shopping_items;
|
||||
mod m20250116_000001_add_oauth_fields;
|
||||
mod m20250117_000001_create_invite_links;
|
||||
|
||||
pub struct Migrator;
|
||||
|
||||
@@ -17,6 +19,8 @@ impl MigratorTrait for Migrator {
|
||||
Box::new(m20241209_000003_seed_admin::Migration),
|
||||
Box::new(m20241215_000001_add_family_password::Migration),
|
||||
Box::new(m20241224_000001_create_shopping_items::Migration),
|
||||
Box::new(m20250116_000001_add_oauth_fields::Migration),
|
||||
Box::new(m20250117_000001_create_invite_links::Migration),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
48
backend/src/models/invite_link.rs
Normal file
48
backend/src/models/invite_link.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize, ToSchema)]
|
||||
#[sea_orm(table_name = "invite_link")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub family_id: i32,
|
||||
#[sea_orm(unique)]
|
||||
pub token: String,
|
||||
pub created_at: DateTime,
|
||||
pub expires_at: Option<DateTime>,
|
||||
pub max_uses: Option<i32>,
|
||||
pub uses_count: i32,
|
||||
pub created_by: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::family::Entity",
|
||||
from = "Column::FamilyId",
|
||||
to = "super::family::Column::Id"
|
||||
)]
|
||||
Family,
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::CreatedBy",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::family::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Family.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -3,9 +3,11 @@ pub mod category;
|
||||
pub mod expense;
|
||||
pub mod user;
|
||||
pub mod shopping_item;
|
||||
pub mod invite_link;
|
||||
|
||||
pub use family::Entity as Family;
|
||||
pub use category::Entity as Category;
|
||||
pub use expense::Entity as Expense;
|
||||
pub use user::Entity as User;
|
||||
pub use shopping_item::Entity as ShoppingItem;
|
||||
pub use invite_link::Entity as InviteLink;
|
||||
|
||||
@@ -7,12 +7,32 @@ pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
#[sea_orm(unique)]
|
||||
pub username: String,
|
||||
pub password_hash: String,
|
||||
pub username: Option<String>,
|
||||
#[serde(skip_serializing)]
|
||||
pub password_hash: Option<String>,
|
||||
pub is_admin: bool,
|
||||
#[sea_orm(unique)]
|
||||
pub email: Option<String>,
|
||||
#[sea_orm(unique)]
|
||||
#[serde(skip_serializing)]
|
||||
pub google_id: Option<String>,
|
||||
pub family_id: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::family::Entity",
|
||||
from = "Column::FamilyId",
|
||||
to = "super::family::Column::Id"
|
||||
)]
|
||||
Family,
|
||||
}
|
||||
|
||||
impl Related<super::family::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Family.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use axum_login::AuthSession;
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, ColumnTrait, QueryFilter, ActiveModelTrait, Set};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::auth::{AuthBackend, Credentials};
|
||||
use crate::models::{user, User, family, Family};
|
||||
use crate::services::FamilyService;
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
pub struct LoginRequest {
|
||||
@@ -20,6 +24,15 @@ pub struct LoginResponse {
|
||||
pub is_admin: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct MeResponse {
|
||||
pub id: i32,
|
||||
pub username: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub is_admin: bool,
|
||||
pub family_id: Option<i32>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/login",
|
||||
@@ -72,3 +85,105 @@ pub async fn logout(
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/me",
|
||||
tag = "auth",
|
||||
responses(
|
||||
(status = 200, description = "Current user info", body = MeResponse),
|
||||
(status = 401, description = "Not authenticated")
|
||||
)
|
||||
)]
|
||||
pub async fn me(
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
) -> Result<Json<MeResponse>, StatusCode> {
|
||||
let user = auth_session.user.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
Ok(Json(MeResponse {
|
||||
id: user.id,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
is_admin: user.is_admin,
|
||||
family_id: user.family_id,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
pub struct FamilyLoginRequest {
|
||||
pub family_name: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct FamilyLoginResponse {
|
||||
pub success: bool,
|
||||
pub family_id: i32,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/family-login",
|
||||
tag = "auth",
|
||||
request_body = FamilyLoginRequest,
|
||||
responses(
|
||||
(status = 200, description = "Login successful", body = FamilyLoginResponse),
|
||||
(status = 401, description = "Invalid credentials"),
|
||||
(status = 404, description = "Family not found")
|
||||
)
|
||||
)]
|
||||
pub async fn family_login(
|
||||
mut auth_session: AuthSession<AuthBackend>,
|
||||
State(db): State<DatabaseConnection>,
|
||||
Json(payload): Json<FamilyLoginRequest>,
|
||||
) -> Result<Json<FamilyLoginResponse>, StatusCode> {
|
||||
let family = Family::find()
|
||||
.filter(family::Column::Name.eq(&payload.family_name))
|
||||
.one(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
let valid = FamilyService::verify_password(&db, family.id, payload.password.clone())
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !valid {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let existing_member = User::find()
|
||||
.filter(user::Column::FamilyId.eq(family.id))
|
||||
.filter(user::Column::GoogleId.is_null())
|
||||
.filter(user::Column::Username.eq(&payload.family_name))
|
||||
.one(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let member_user = if let Some(user) = existing_member {
|
||||
user
|
||||
} else {
|
||||
let new_member = user::ActiveModel {
|
||||
username: Set(Some(payload.family_name)),
|
||||
email: Set(None),
|
||||
google_id: Set(None),
|
||||
password_hash: Set(None),
|
||||
is_admin: Set(false),
|
||||
family_id: Set(Some(family.id)),
|
||||
..Default::default()
|
||||
};
|
||||
new_member.insert(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
};
|
||||
|
||||
auth_session
|
||||
.login(&member_user)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(FamilyLoginResponse {
|
||||
success: true,
|
||||
family_id: family.id,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -3,12 +3,15 @@ use axum::{
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use sea_orm::DatabaseConnection;
|
||||
use axum_login::AuthSession;
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, ActiveModelTrait, Set};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
use tower_sessions::Session;
|
||||
|
||||
use crate::auth::AuthBackend;
|
||||
use crate::models::family::Model as FamilyModel;
|
||||
use crate::models::{user, User};
|
||||
use crate::services::FamilyService;
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
@@ -18,6 +21,14 @@ pub struct CreateFamilyRequest {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
#[schema(example = json!({"name": "Smith Family", "password": "secret123"}))]
|
||||
pub struct CreateMyFamilyRequest {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
#[schema(example = json!({"password": "secret123"}))]
|
||||
pub struct VerifyFamilyPasswordRequest {
|
||||
@@ -188,3 +199,57 @@ pub async fn verify_family_password(
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct CreateMyFamilyResponse {
|
||||
pub family: FamilyModel,
|
||||
pub user_id: i32,
|
||||
pub family_id: i32,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/my-family",
|
||||
tag = "families",
|
||||
request_body = CreateMyFamilyRequest,
|
||||
responses(
|
||||
(status = 200, description = "Family created and linked to user", body = CreateMyFamilyResponse),
|
||||
(status = 401, description = "Not authenticated"),
|
||||
(status = 409, description = "User already has a family"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_my_family(
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
State(db): State<DatabaseConnection>,
|
||||
Json(payload): Json<CreateMyFamilyRequest>,
|
||||
) -> Result<Json<CreateMyFamilyResponse>, StatusCode> {
|
||||
let current_user = auth_session.user.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if current_user.family_id.is_some() {
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
let password = payload.password.unwrap_or_default();
|
||||
let family = FamilyService::create(&db, payload.name, password)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mut active_user: user::ActiveModel = User::find_by_id(current_user.id)
|
||||
.one(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?
|
||||
.into();
|
||||
|
||||
active_user.family_id = Set(Some(family.id));
|
||||
active_user.update(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(CreateMyFamilyResponse {
|
||||
family_id: family.id,
|
||||
user_id: current_user.id,
|
||||
family,
|
||||
}))
|
||||
}
|
||||
|
||||
264
backend/src/routes/invite_link.rs
Normal file
264
backend/src/routes/invite_link.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use axum_login::AuthSession;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::auth::AuthBackend;
|
||||
use crate::models::invite_link::Model as InviteLinkModel;
|
||||
use crate::services::InviteLinkService;
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
#[schema(example = json!({"expires_in_hours": 24, "max_uses": 5}))]
|
||||
pub struct CreateInviteLinkRequest {
|
||||
pub expires_in_hours: Option<i64>,
|
||||
pub max_uses: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct InviteLinkResponse {
|
||||
pub id: i32,
|
||||
pub family_id: i32,
|
||||
pub token: String,
|
||||
pub invite_url: String,
|
||||
pub expires_at: Option<String>,
|
||||
pub max_uses: Option<i32>,
|
||||
pub uses_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ValidateInviteResponse {
|
||||
pub valid: bool,
|
||||
pub family_id: Option<i32>,
|
||||
pub family_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct JoinFamilyResponse {
|
||||
pub success: bool,
|
||||
pub family_id: i32,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
fn model_to_response(model: InviteLinkModel, base_url: &str) -> InviteLinkResponse {
|
||||
InviteLinkResponse {
|
||||
id: model.id,
|
||||
family_id: model.family_id,
|
||||
token: model.token.clone(),
|
||||
invite_url: format!("{}/invite/{}", base_url, model.token),
|
||||
expires_at: model.expires_at.map(|dt| dt.to_string()),
|
||||
max_uses: model.max_uses,
|
||||
uses_count: model.uses_count,
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/my-family/invite-links",
|
||||
tag = "invite-links",
|
||||
request_body = CreateInviteLinkRequest,
|
||||
responses(
|
||||
(status = 200, description = "Invite link created", body = InviteLinkResponse),
|
||||
(status = 401, description = "Not authenticated"),
|
||||
(status = 403, description = "User has no family"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn create_invite_link(
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
State(db): State<DatabaseConnection>,
|
||||
Json(payload): Json<CreateInviteLinkRequest>,
|
||||
) -> Result<Json<InviteLinkResponse>, StatusCode> {
|
||||
let user = auth_session.user.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
let family_id = user.family_id.ok_or(StatusCode::FORBIDDEN)?;
|
||||
|
||||
let expires_at = payload.expires_in_hours.map(|hours| {
|
||||
chrono::Utc::now().naive_utc() + chrono::Duration::hours(hours)
|
||||
});
|
||||
|
||||
let invite = InviteLinkService::create(&db, family_id, user.id, expires_at, payload.max_uses)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let base_url = std::env::var("FRONTEND_URL").unwrap_or_else(|_| "http://localhost:5173".to_string());
|
||||
Ok(Json(model_to_response(invite, &base_url)))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/my-family/invite-links",
|
||||
tag = "invite-links",
|
||||
responses(
|
||||
(status = 200, description = "List of invite links", body = Vec<InviteLinkResponse>),
|
||||
(status = 401, description = "Not authenticated"),
|
||||
(status = 403, description = "User has no family"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn get_my_invite_links(
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
State(db): State<DatabaseConnection>,
|
||||
) -> Result<Json<Vec<InviteLinkResponse>>, StatusCode> {
|
||||
let user = auth_session.user.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
let family_id = user.family_id.ok_or(StatusCode::FORBIDDEN)?;
|
||||
|
||||
let invites = InviteLinkService::find_by_family(&db, family_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let base_url = std::env::var("FRONTEND_URL").unwrap_or_else(|_| "http://localhost:5173".to_string());
|
||||
let responses: Vec<InviteLinkResponse> = invites
|
||||
.into_iter()
|
||||
.map(|i| model_to_response(i, &base_url))
|
||||
.collect();
|
||||
|
||||
Ok(Json(responses))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/my-family/invite-links/{token}",
|
||||
tag = "invite-links",
|
||||
params(
|
||||
("token" = String, Path, description = "Invite token")
|
||||
),
|
||||
responses(
|
||||
(status = 204, description = "Invite link deleted"),
|
||||
(status = 401, description = "Not authenticated"),
|
||||
(status = 403, description = "User has no family or not authorized"),
|
||||
(status = 404, description = "Invite link not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn delete_invite_link(
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
State(db): State<DatabaseConnection>,
|
||||
Path(token): Path<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let user = auth_session.user.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
let family_id = user.family_id.ok_or(StatusCode::FORBIDDEN)?;
|
||||
|
||||
let invite = InviteLinkService::find_by_token(&db, &token)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
if invite.family_id != family_id {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
InviteLinkService::delete_by_token(&db, &token)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/invite/{token}",
|
||||
tag = "invite-links",
|
||||
params(
|
||||
("token" = String, Path, description = "Invite token")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Invite link is valid", body = ValidateInviteResponse),
|
||||
(status = 404, description = "Invite link not found or invalid"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn validate_invite_link(
|
||||
State(db): State<DatabaseConnection>,
|
||||
Path(token): Path<String>,
|
||||
) -> Result<Json<ValidateInviteResponse>, StatusCode> {
|
||||
use crate::models::Family;
|
||||
use sea_orm::EntityTrait;
|
||||
|
||||
let invite = InviteLinkService::find_by_token(&db, &token)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
if let Some(expires_at) = invite.expires_at {
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
if now > expires_at {
|
||||
return Ok(Json(ValidateInviteResponse {
|
||||
valid: false,
|
||||
family_id: None,
|
||||
family_name: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_uses) = invite.max_uses {
|
||||
if invite.uses_count >= max_uses {
|
||||
return Ok(Json(ValidateInviteResponse {
|
||||
valid: false,
|
||||
family_id: None,
|
||||
family_name: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
let family = Family::find_by_id(invite.family_id)
|
||||
.one(&db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(ValidateInviteResponse {
|
||||
valid: true,
|
||||
family_id: Some(invite.family_id),
|
||||
family_name: family.map(|f| f.name),
|
||||
}))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/invite/{token}/join",
|
||||
tag = "invite-links",
|
||||
params(
|
||||
("token" = String, Path, description = "Invite token")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Successfully joined family", body = JoinFamilyResponse),
|
||||
(status = 401, description = "Not authenticated"),
|
||||
(status = 400, description = "User already in a family or invite invalid"),
|
||||
(status = 404, description = "Invite link not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
pub async fn join_family_via_invite(
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
State(db): State<DatabaseConnection>,
|
||||
Path(token): Path<String>,
|
||||
) -> Result<Json<JoinFamilyResponse>, StatusCode> {
|
||||
let user = auth_session.user.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if user.family_id.is_some() {
|
||||
return Ok(Json(JoinFamilyResponse {
|
||||
success: false,
|
||||
family_id: 0,
|
||||
message: "You already belong to a family".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
let invite = InviteLinkService::validate_and_use(&db, &token, user.id)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sea_orm::DbErr::RecordNotFound(_) => StatusCode::NOT_FOUND,
|
||||
sea_orm::DbErr::Custom(msg) if msg.contains("expired") => StatusCode::BAD_REQUEST,
|
||||
sea_orm::DbErr::Custom(msg) if msg.contains("max uses") => StatusCode::BAD_REQUEST,
|
||||
sea_orm::DbErr::Custom(msg) if msg.contains("already belongs") => StatusCode::BAD_REQUEST,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
})?;
|
||||
|
||||
Ok(Json(JoinFamilyResponse {
|
||||
success: true,
|
||||
family_id: invite.family_id,
|
||||
message: "Successfully joined family".to_string(),
|
||||
}))
|
||||
}
|
||||
@@ -3,3 +3,5 @@ pub mod category;
|
||||
pub mod expense;
|
||||
pub mod auth;
|
||||
pub mod shopping_item;
|
||||
pub mod oauth;
|
||||
pub mod invite_link;
|
||||
|
||||
125
backend/src/routes/oauth.rs
Normal file
125
backend/src/routes/oauth.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::Redirect,
|
||||
Json,
|
||||
};
|
||||
use axum_login::AuthSession;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_sessions::Session;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::auth::AuthBackend;
|
||||
use crate::services::OAuthService;
|
||||
|
||||
const CSRF_TOKEN_KEY: &str = "oauth_csrf_token";
|
||||
const FRONTEND_URL_KEY: &str = "oauth_frontend_url";
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
pub struct GoogleAuthQuery {
|
||||
pub redirect_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GoogleCallbackQuery {
|
||||
pub code: String,
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct OAuthUrlResponse {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/auth/google",
|
||||
tag = "auth",
|
||||
params(
|
||||
("redirect_url" = Option<String>, Query, description = "Frontend URL to redirect after auth")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Returns Google OAuth URL", body = OAuthUrlResponse)
|
||||
)
|
||||
)]
|
||||
pub async fn google_auth(
|
||||
session: Session,
|
||||
Query(query): Query<GoogleAuthQuery>,
|
||||
) -> Result<Json<OAuthUrlResponse>, StatusCode> {
|
||||
let oauth_service = OAuthService::new();
|
||||
let (auth_url, csrf_token) = oauth_service.get_auth_url();
|
||||
|
||||
session
|
||||
.insert(CSRF_TOKEN_KEY, csrf_token.secret().clone())
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if let Some(redirect_url) = query.redirect_url {
|
||||
session
|
||||
.insert(FRONTEND_URL_KEY, redirect_url)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
Ok(Json(OAuthUrlResponse { url: auth_url }))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/auth/google/callback",
|
||||
tag = "auth",
|
||||
responses(
|
||||
(status = 302, description = "Redirects to frontend after successful auth"),
|
||||
(status = 401, description = "Authentication failed")
|
||||
)
|
||||
)]
|
||||
pub async fn google_callback(
|
||||
mut auth_session: AuthSession<AuthBackend>,
|
||||
session: Session,
|
||||
State(db): State<DatabaseConnection>,
|
||||
Query(query): Query<GoogleCallbackQuery>,
|
||||
) -> Result<Redirect, StatusCode> {
|
||||
let stored_csrf: Option<String> = session
|
||||
.get(CSRF_TOKEN_KEY)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let frontend_url: Option<String> = session
|
||||
.get(FRONTEND_URL_KEY)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
session.remove::<String>(CSRF_TOKEN_KEY).await.ok();
|
||||
session.remove::<String>(FRONTEND_URL_KEY).await.ok();
|
||||
|
||||
if stored_csrf.as_deref() != Some(&query.state) {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let oauth_service = OAuthService::new();
|
||||
|
||||
let access_token = oauth_service
|
||||
.exchange_code(query.code)
|
||||
.await
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let google_user = oauth_service
|
||||
.get_user_info(&access_token)
|
||||
.await
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user = oauth_service
|
||||
.find_or_create_user(&db, google_user)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
auth_session
|
||||
.login(&user)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let redirect_url = frontend_url.unwrap_or_else(|| "http://localhost:3000".to_string());
|
||||
|
||||
Ok(Redirect::temporary(&redirect_url))
|
||||
}
|
||||
121
backend/src/services/invite_link_service.rs
Normal file
121
backend/src/services/invite_link_service.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use sea_orm::*;
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::Rng;
|
||||
use crate::models::invite_link::{self, Entity as InviteLink, Model as InviteLinkModel};
|
||||
use crate::models::{user, User};
|
||||
|
||||
pub struct InviteLinkService;
|
||||
|
||||
impl InviteLinkService {
|
||||
pub fn generate_token() -> String {
|
||||
rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
db: &DatabaseConnection,
|
||||
family_id: i32,
|
||||
created_by: i32,
|
||||
expires_at: Option<chrono::NaiveDateTime>,
|
||||
max_uses: Option<i32>,
|
||||
) -> Result<InviteLinkModel, DbErr> {
|
||||
let token = Self::generate_token();
|
||||
|
||||
let invite = invite_link::ActiveModel {
|
||||
family_id: Set(family_id),
|
||||
token: Set(token),
|
||||
created_by: Set(created_by),
|
||||
expires_at: Set(expires_at),
|
||||
max_uses: Set(max_uses),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
invite.insert(db).await
|
||||
}
|
||||
|
||||
pub async fn find_by_token(
|
||||
db: &DatabaseConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<InviteLinkModel>, DbErr> {
|
||||
InviteLink::find()
|
||||
.filter(invite_link::Column::Token.eq(token))
|
||||
.one(db)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_by_family(
|
||||
db: &DatabaseConnection,
|
||||
family_id: i32,
|
||||
) -> Result<Vec<InviteLinkModel>, DbErr> {
|
||||
InviteLink::find()
|
||||
.filter(invite_link::Column::FamilyId.eq(family_id))
|
||||
.all(db)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn validate_and_use(
|
||||
db: &DatabaseConnection,
|
||||
token: &str,
|
||||
user_id: i32,
|
||||
) -> Result<InviteLinkModel, DbErr> {
|
||||
let invite = InviteLink::find()
|
||||
.filter(invite_link::Column::Token.eq(token))
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("Invite link not found".to_string()))?;
|
||||
|
||||
if let Some(expires_at) = invite.expires_at {
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
if now > expires_at {
|
||||
return Err(DbErr::Custom("Invite link has expired".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_uses) = invite.max_uses {
|
||||
if invite.uses_count >= max_uses {
|
||||
return Err(DbErr::Custom("Invite link has reached max uses".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let user = User::find_by_id(user_id)
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("User not found".to_string()))?;
|
||||
|
||||
if user.family_id.is_some() {
|
||||
return Err(DbErr::Custom("User already belongs to a family".to_string()));
|
||||
}
|
||||
|
||||
let mut active_user: user::ActiveModel = user.into();
|
||||
active_user.family_id = Set(Some(invite.family_id));
|
||||
active_user.update(db).await?;
|
||||
|
||||
let mut active_invite: invite_link::ActiveModel = invite.clone().into();
|
||||
active_invite.uses_count = Set(invite.uses_count + 1);
|
||||
active_invite.update(db).await
|
||||
}
|
||||
|
||||
pub async fn delete(db: &DatabaseConnection, id: i32) -> Result<DeleteResult, DbErr> {
|
||||
let invite = InviteLink::find_by_id(id)
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("Invite link not found".to_string()))?;
|
||||
|
||||
let invite: invite_link::ActiveModel = invite.into();
|
||||
invite.delete(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_by_token(db: &DatabaseConnection, token: &str) -> Result<DeleteResult, DbErr> {
|
||||
let invite = InviteLink::find()
|
||||
.filter(invite_link::Column::Token.eq(token))
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("Invite link not found".to_string()))?;
|
||||
|
||||
let invite: invite_link::ActiveModel = invite.into();
|
||||
invite.delete(db).await
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,12 @@ pub mod family_service;
|
||||
pub mod category_service;
|
||||
pub mod expense_service;
|
||||
pub mod shopping_item_service;
|
||||
pub mod oauth_service;
|
||||
pub mod invite_link_service;
|
||||
|
||||
pub use family_service::FamilyService;
|
||||
pub use category_service::CategoryService;
|
||||
pub use expense_service::ExpenseService;
|
||||
pub use shopping_item_service::ShoppingItemService;
|
||||
pub use oauth_service::OAuthService;
|
||||
pub use invite_link_service::InviteLinkService;
|
||||
|
||||
148
backend/src/services/oauth_service.rs
Normal file
148
backend/src/services/oauth_service.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl,
|
||||
AuthorizationCode, TokenResponse, Scope, CsrfToken,
|
||||
};
|
||||
use reqwest::Client as HttpClient;
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, ColumnTrait, QueryFilter, ActiveModelTrait, Set};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::models::{user, User};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GoogleUserInfo {
|
||||
pub id: String,
|
||||
pub email: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
pub struct OAuthService {
|
||||
http_client: HttpClient,
|
||||
}
|
||||
|
||||
impl OAuthService {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http_client: HttpClient::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_auth_url(&self) -> (String, CsrfToken) {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID")
|
||||
.expect("GOOGLE_CLIENT_ID must be set");
|
||||
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET")
|
||||
.expect("GOOGLE_CLIENT_SECRET must be set");
|
||||
let redirect_url = std::env::var("GOOGLE_REDIRECT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080/api/auth/google/callback".to_string());
|
||||
|
||||
let client = BasicClient::new(ClientId::new(client_id))
|
||||
.set_client_secret(ClientSecret::new(client_secret))
|
||||
.set_auth_uri(AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string()).unwrap())
|
||||
.set_token_uri(TokenUrl::new("https://oauth2.googleapis.com/token".to_string()).unwrap())
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap());
|
||||
|
||||
let (auth_url, csrf_token) = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new("openid".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.url();
|
||||
|
||||
(auth_url.to_string(), csrf_token)
|
||||
}
|
||||
|
||||
pub async fn exchange_code(&self, code: String) -> Result<String, OAuthError> {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID")
|
||||
.expect("GOOGLE_CLIENT_ID must be set");
|
||||
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET")
|
||||
.expect("GOOGLE_CLIENT_SECRET must be set");
|
||||
let redirect_url = std::env::var("GOOGLE_REDIRECT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080/api/auth/google/callback".to_string());
|
||||
|
||||
let client = BasicClient::new(ClientId::new(client_id))
|
||||
.set_client_secret(ClientSecret::new(client_secret))
|
||||
.set_auth_uri(AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string()).unwrap())
|
||||
.set_token_uri(TokenUrl::new("https://oauth2.googleapis.com/token".to_string()).unwrap())
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap());
|
||||
|
||||
let http_client = oauth2::reqwest::ClientBuilder::new()
|
||||
.build()
|
||||
.map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
|
||||
|
||||
let token = client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.request_async(&http_client)
|
||||
.await
|
||||
.map_err(|e: oauth2::RequestTokenError<_, _>| OAuthError::TokenExchange(e.to_string()))?;
|
||||
|
||||
Ok(token.access_token().secret().clone())
|
||||
}
|
||||
|
||||
pub async fn get_user_info(&self, access_token: &str) -> Result<GoogleUserInfo, OAuthError> {
|
||||
let response = self.http_client
|
||||
.get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
.bearer_auth(access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| OAuthError::UserInfo(e.to_string()))?;
|
||||
|
||||
response
|
||||
.json::<GoogleUserInfo>()
|
||||
.await
|
||||
.map_err(|e| OAuthError::UserInfo(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn find_or_create_user(
|
||||
&self,
|
||||
db: &DatabaseConnection,
|
||||
google_user: GoogleUserInfo,
|
||||
) -> Result<user::Model, OAuthError> {
|
||||
let existing = User::find()
|
||||
.filter(user::Column::GoogleId.eq(&google_user.id))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
|
||||
if let Some(user) = existing {
|
||||
return Ok(user);
|
||||
}
|
||||
|
||||
let existing_by_email = User::find()
|
||||
.filter(user::Column::Email.eq(&google_user.email))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
|
||||
if let Some(user) = existing_by_email {
|
||||
let mut active: user::ActiveModel = user.into();
|
||||
active.google_id = Set(Some(google_user.id));
|
||||
let updated = active.update(db).await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
return Ok(updated);
|
||||
}
|
||||
|
||||
let new_user = user::ActiveModel {
|
||||
email: Set(Some(google_user.email)),
|
||||
google_id: Set(Some(google_user.id)),
|
||||
username: Set(google_user.name),
|
||||
password_hash: Set(None),
|
||||
is_admin: Set(false),
|
||||
family_id: Set(None),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let created = new_user.insert(db).await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
|
||||
Ok(created)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OAuthError {
|
||||
#[error("Token exchange failed: {0}")]
|
||||
TokenExchange(String),
|
||||
#[error("Failed to get user info: {0}")]
|
||||
UserInfo(String),
|
||||
#[error("Database error: {0}")]
|
||||
Database(String),
|
||||
}
|
||||
Reference in New Issue
Block a user