oauth2
This commit is contained in:
121
backend/src/services/invite_link_service.rs
Normal file
121
backend/src/services/invite_link_service.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use sea_orm::*;
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::Rng;
|
||||
use crate::models::invite_link::{self, Entity as InviteLink, Model as InviteLinkModel};
|
||||
use crate::models::{user, User};
|
||||
|
||||
pub struct InviteLinkService;
|
||||
|
||||
impl InviteLinkService {
|
||||
pub fn generate_token() -> String {
|
||||
rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
db: &DatabaseConnection,
|
||||
family_id: i32,
|
||||
created_by: i32,
|
||||
expires_at: Option<chrono::NaiveDateTime>,
|
||||
max_uses: Option<i32>,
|
||||
) -> Result<InviteLinkModel, DbErr> {
|
||||
let token = Self::generate_token();
|
||||
|
||||
let invite = invite_link::ActiveModel {
|
||||
family_id: Set(family_id),
|
||||
token: Set(token),
|
||||
created_by: Set(created_by),
|
||||
expires_at: Set(expires_at),
|
||||
max_uses: Set(max_uses),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
invite.insert(db).await
|
||||
}
|
||||
|
||||
pub async fn find_by_token(
|
||||
db: &DatabaseConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<InviteLinkModel>, DbErr> {
|
||||
InviteLink::find()
|
||||
.filter(invite_link::Column::Token.eq(token))
|
||||
.one(db)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_by_family(
|
||||
db: &DatabaseConnection,
|
||||
family_id: i32,
|
||||
) -> Result<Vec<InviteLinkModel>, DbErr> {
|
||||
InviteLink::find()
|
||||
.filter(invite_link::Column::FamilyId.eq(family_id))
|
||||
.all(db)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn validate_and_use(
|
||||
db: &DatabaseConnection,
|
||||
token: &str,
|
||||
user_id: i32,
|
||||
) -> Result<InviteLinkModel, DbErr> {
|
||||
let invite = InviteLink::find()
|
||||
.filter(invite_link::Column::Token.eq(token))
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("Invite link not found".to_string()))?;
|
||||
|
||||
if let Some(expires_at) = invite.expires_at {
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
if now > expires_at {
|
||||
return Err(DbErr::Custom("Invite link has expired".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_uses) = invite.max_uses {
|
||||
if invite.uses_count >= max_uses {
|
||||
return Err(DbErr::Custom("Invite link has reached max uses".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let user = User::find_by_id(user_id)
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("User not found".to_string()))?;
|
||||
|
||||
if user.family_id.is_some() {
|
||||
return Err(DbErr::Custom("User already belongs to a family".to_string()));
|
||||
}
|
||||
|
||||
let mut active_user: user::ActiveModel = user.into();
|
||||
active_user.family_id = Set(Some(invite.family_id));
|
||||
active_user.update(db).await?;
|
||||
|
||||
let mut active_invite: invite_link::ActiveModel = invite.clone().into();
|
||||
active_invite.uses_count = Set(invite.uses_count + 1);
|
||||
active_invite.update(db).await
|
||||
}
|
||||
|
||||
pub async fn delete(db: &DatabaseConnection, id: i32) -> Result<DeleteResult, DbErr> {
|
||||
let invite = InviteLink::find_by_id(id)
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("Invite link not found".to_string()))?;
|
||||
|
||||
let invite: invite_link::ActiveModel = invite.into();
|
||||
invite.delete(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_by_token(db: &DatabaseConnection, token: &str) -> Result<DeleteResult, DbErr> {
|
||||
let invite = InviteLink::find()
|
||||
.filter(invite_link::Column::Token.eq(token))
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or(DbErr::RecordNotFound("Invite link not found".to_string()))?;
|
||||
|
||||
let invite: invite_link::ActiveModel = invite.into();
|
||||
invite.delete(db).await
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,12 @@ pub mod family_service;
|
||||
pub mod category_service;
|
||||
pub mod expense_service;
|
||||
pub mod shopping_item_service;
|
||||
pub mod oauth_service;
|
||||
pub mod invite_link_service;
|
||||
|
||||
pub use family_service::FamilyService;
|
||||
pub use category_service::CategoryService;
|
||||
pub use expense_service::ExpenseService;
|
||||
pub use shopping_item_service::ShoppingItemService;
|
||||
pub use oauth_service::OAuthService;
|
||||
pub use invite_link_service::InviteLinkService;
|
||||
|
||||
148
backend/src/services/oauth_service.rs
Normal file
148
backend/src/services/oauth_service.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl,
|
||||
AuthorizationCode, TokenResponse, Scope, CsrfToken,
|
||||
};
|
||||
use reqwest::Client as HttpClient;
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, ColumnTrait, QueryFilter, ActiveModelTrait, Set};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::models::{user, User};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GoogleUserInfo {
|
||||
pub id: String,
|
||||
pub email: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
pub struct OAuthService {
|
||||
http_client: HttpClient,
|
||||
}
|
||||
|
||||
impl OAuthService {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http_client: HttpClient::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_auth_url(&self) -> (String, CsrfToken) {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID")
|
||||
.expect("GOOGLE_CLIENT_ID must be set");
|
||||
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET")
|
||||
.expect("GOOGLE_CLIENT_SECRET must be set");
|
||||
let redirect_url = std::env::var("GOOGLE_REDIRECT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080/api/auth/google/callback".to_string());
|
||||
|
||||
let client = BasicClient::new(ClientId::new(client_id))
|
||||
.set_client_secret(ClientSecret::new(client_secret))
|
||||
.set_auth_uri(AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string()).unwrap())
|
||||
.set_token_uri(TokenUrl::new("https://oauth2.googleapis.com/token".to_string()).unwrap())
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap());
|
||||
|
||||
let (auth_url, csrf_token) = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new("openid".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.url();
|
||||
|
||||
(auth_url.to_string(), csrf_token)
|
||||
}
|
||||
|
||||
pub async fn exchange_code(&self, code: String) -> Result<String, OAuthError> {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID")
|
||||
.expect("GOOGLE_CLIENT_ID must be set");
|
||||
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET")
|
||||
.expect("GOOGLE_CLIENT_SECRET must be set");
|
||||
let redirect_url = std::env::var("GOOGLE_REDIRECT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080/api/auth/google/callback".to_string());
|
||||
|
||||
let client = BasicClient::new(ClientId::new(client_id))
|
||||
.set_client_secret(ClientSecret::new(client_secret))
|
||||
.set_auth_uri(AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string()).unwrap())
|
||||
.set_token_uri(TokenUrl::new("https://oauth2.googleapis.com/token".to_string()).unwrap())
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap());
|
||||
|
||||
let http_client = oauth2::reqwest::ClientBuilder::new()
|
||||
.build()
|
||||
.map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
|
||||
|
||||
let token = client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.request_async(&http_client)
|
||||
.await
|
||||
.map_err(|e: oauth2::RequestTokenError<_, _>| OAuthError::TokenExchange(e.to_string()))?;
|
||||
|
||||
Ok(token.access_token().secret().clone())
|
||||
}
|
||||
|
||||
pub async fn get_user_info(&self, access_token: &str) -> Result<GoogleUserInfo, OAuthError> {
|
||||
let response = self.http_client
|
||||
.get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
.bearer_auth(access_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| OAuthError::UserInfo(e.to_string()))?;
|
||||
|
||||
response
|
||||
.json::<GoogleUserInfo>()
|
||||
.await
|
||||
.map_err(|e| OAuthError::UserInfo(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn find_or_create_user(
|
||||
&self,
|
||||
db: &DatabaseConnection,
|
||||
google_user: GoogleUserInfo,
|
||||
) -> Result<user::Model, OAuthError> {
|
||||
let existing = User::find()
|
||||
.filter(user::Column::GoogleId.eq(&google_user.id))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
|
||||
if let Some(user) = existing {
|
||||
return Ok(user);
|
||||
}
|
||||
|
||||
let existing_by_email = User::find()
|
||||
.filter(user::Column::Email.eq(&google_user.email))
|
||||
.one(db)
|
||||
.await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
|
||||
if let Some(user) = existing_by_email {
|
||||
let mut active: user::ActiveModel = user.into();
|
||||
active.google_id = Set(Some(google_user.id));
|
||||
let updated = active.update(db).await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
return Ok(updated);
|
||||
}
|
||||
|
||||
let new_user = user::ActiveModel {
|
||||
email: Set(Some(google_user.email)),
|
||||
google_id: Set(Some(google_user.id)),
|
||||
username: Set(google_user.name),
|
||||
password_hash: Set(None),
|
||||
is_admin: Set(false),
|
||||
family_id: Set(None),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let created = new_user.insert(db).await
|
||||
.map_err(|e| OAuthError::Database(e.to_string()))?;
|
||||
|
||||
Ok(created)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OAuthError {
|
||||
#[error("Token exchange failed: {0}")]
|
||||
TokenExchange(String),
|
||||
#[error("Failed to get user info: {0}")]
|
||||
UserInfo(String),
|
||||
#[error("Database error: {0}")]
|
||||
Database(String),
|
||||
}
|
||||
Reference in New Issue
Block a user