diff --git a/server/src/database/model/user.rs b/server/src/database/model/user.rs index a2cd28b..58f0947 100644 --- a/server/src/database/model/user.rs +++ b/server/src/database/model/user.rs @@ -9,10 +9,17 @@ pub struct User { pub user_id: uuid::Uuid, #[validate(email)] pub email: String, - pub password: String, + pub password: Option, pub admin: bool, } +#[derive(Debug)] +pub struct UpdateUser { + pub email: Option, + pub password: Option, + pub admin: Option, +} + impl User { pub async fn insert( transaction: &mut sqlx::Transaction<'_, Postgres>, @@ -56,6 +63,41 @@ impl User { Ok(user) } + + pub async fn get_password(&self, pool: &PgPool) -> Result { + let password = sqlx::query_scalar!( + " + SELECT password FROM users WHERE user_id = $1 + ", + self.user_id, + ) + .fetch_one(pool) + .await?; + + Ok(password) + } + + pub async fn update( + &self, + transaction: &mut sqlx::Transaction<'_, Postgres>, + update_user: UpdateUser, + ) -> Result<(), sqlx::Error> { + sqlx::query!( + " + UPDATE users + SET email = coalesce($1, email), + password = coalesce($2, password) + WHERE user_id = $3; + ", + update_user.email, + update_user.password, + self.user_id + ) + .execute(&mut **transaction) + .await?; + + Ok(()) + } } #[derive(Debug)] diff --git a/server/src/model/user.rs b/server/src/model/user.rs index f380237..86c407f 100644 --- a/server/src/model/user.rs +++ b/server/src/model/user.rs @@ -23,6 +23,17 @@ impl From for User { } } +impl From for DbUser { + fn from(user: User) -> Self { + Self { + user_id: user.id, + email: user.email, + admin: user.admin, + password: None, + } + } +} + impl User { pub async fn members(&self, pool: &PgPool) -> Result, sqlx::Error> { let related_members = DbUserMember::get_members_from_user(pool, &self.id).await?; diff --git a/server/src/routes/auth.rs b/server/src/routes/auth.rs index a0a4919..9f10abb 100644 --- a/server/src/routes/auth.rs +++ b/server/src/routes/auth.rs @@ -1,6 +1,11 @@ +use axum::debug_handler; +use axum::http::HeaderMap; use axum::{extract::State, routing::post, Json, Router}; +use serde::Deserialize; use crate::auth::verify_password_hash; +use crate::auth::{get_user_from_header, AuthError}; +use crate::database::model::user::UpdateUser; use crate::database::model::Member as DbMember; use crate::database::model::Session as DbSession; use crate::database::model::User as DbUser; @@ -12,9 +17,11 @@ pub fn routes() -> Router { Router::new() .route("/auth/login", post(login)) .route("/auth/register", post(register)) + .route("/auth/change_password", post(change_password)) + .route("/auth/change_email", post(change_email)) } -#[derive(serde::Deserialize)] +#[derive(Deserialize)] pub struct LoginRequest { email: String, password: String, @@ -26,10 +33,14 @@ pub async fn login<'a>( ) -> Result { let db_user = DbUser::get_from_email(&state.pool, login_request.email).await?; - match verify_password_hash(&login_request.password, &db_user.password).await { - Ok(_) => (), - Err(_err) => return Err(crate::Error::Auth(crate::auth::AuthError::InvalidPassword)), - }; + if let Some(pass) = db_user.password { + match verify_password_hash(&login_request.password, &pass).await { + Ok(_) => (), + Err(_err) => return Err(crate::Error::Auth(crate::auth::AuthError::InvalidPassword)), + }; + } else { + return Err(AuthError::Unexpected.into()); + } // Create session let mut transaction = state.pool.begin().await?; @@ -42,14 +53,14 @@ pub async fn login<'a>( Ok(db_session.token) } -#[derive(serde::Deserialize)] +#[derive(Deserialize)] pub struct RegisterRequest { email: String, password: String, registration_tokens: Vec, } -pub async fn register<'a>( +pub async fn register( State(state): State, Json(auth_request): Json, ) -> Result { @@ -83,3 +94,88 @@ pub async fn register<'a>( Ok(db_session.token) } + +#[derive(Debug, Deserialize)] +pub struct ChangePasswordRequest { + pub old_password: String, + pub new_password: String, +} + +pub async fn change_password( + State(state): State, + headers: HeaderMap, + Json(request): Json, +) -> Result<(), crate::Error> { + let (_, user) = get_user_from_header(&state.pool, &headers).await?; + + let password_hash = match generate_password_hash(request.new_password).await { + Ok(hash) => hash, + Err(_err) => return Err(crate::Error::Auth(crate::auth::AuthError::InvalidPassword)), + }; + + let db_user: DbUser = user.into(); + + let old_password_hash = db_user.get_password(&state.pool).await?; + + match verify_password_hash(&request.old_password, &old_password_hash).await { + Ok(_) => (), + Err(_err) => return Err(crate::Error::Auth(crate::auth::AuthError::InvalidPassword)), + }; + + let mut transaction = state.pool.begin().await?; + + db_user + .update( + &mut transaction, + UpdateUser { + email: None, + password: Some(password_hash), + admin: None, + }, + ) + .await?; + + transaction.commit().await?; + + Ok(()) +} + +#[derive(Debug, Deserialize)] +pub struct ChangeEmailRequest { + pub password: String, + pub new_email: String, +} + +pub async fn change_email( + State(state): State, + headers: HeaderMap, + Json(request): Json, +) -> Result<(), crate::Error> { + let (_, user) = get_user_from_header(&state.pool, &headers).await?; + + let db_user: DbUser = user.into(); + + let password_hash = db_user.get_password(&state.pool).await?; + + match verify_password_hash(&request.password, &password_hash).await { + Ok(_) => (), + Err(_err) => return Err(crate::Error::Auth(crate::auth::AuthError::InvalidPassword)), + }; + + let mut transaction = state.pool.begin().await?; + + db_user + .update( + &mut transaction, + UpdateUser { + email: Some(request.new_email), + password: None, + admin: None, + }, + ) + .await?; + + transaction.commit().await?; + + Ok(()) +} diff --git a/server/src/util/error.rs b/server/src/util/error.rs index d164baa..c009799 100644 --- a/server/src/util/error.rs +++ b/server/src/util/error.rs @@ -40,6 +40,12 @@ impl IntoResponse for Error { let (status_code, code) = match self { Self::Sqlx(ref err_kind) => match err_kind { sqlx::Error::RowNotFound => (StatusCode::NOT_FOUND, "DATABASE_ROW_NOT_FOUND"), + sqlx::Error::Database(db_err) => match db_err.kind() { + sqlx::error::ErrorKind::UniqueViolation => { + (StatusCode::INTERNAL_SERVER_ERROR, "DATABASE_DUPLICATE") + } + _ => (StatusCode::INTERNAL_SERVER_ERROR, "DATABASE_ERROR"), + }, _ => (StatusCode::INTERNAL_SERVER_ERROR, "DATABASE_ERROR"), },