2025-02-07 15:59:50 +01:00

158 lines
4.5 KiB
Rust

use std::collections::HashSet;
use argon2::{
password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
Argon2, PasswordHash, PasswordVerifier,
};
use axum::{
extract::FromRequestParts,
http::{header, request::Parts, HeaderMap, StatusCode},
RequestPartsExt,
};
use axum_extra::{
extract::cookie::{Cookie, CookieJar},
headers::{authorization::Bearer, Authorization},
typed_header::TypedHeaderRejectionReason,
TypedHeader,
};
use bearer::verify_bearer;
use chrono::Utc;
pub use error::AuthError;
use rand::distr::Alphanumeric;
use rand::prelude::*;
use rand_chacha::ChaCha20Rng;
use sqlx::PgPool;
use tokio::task;
use crate::{database::model::Session, model::User};
mod bearer;
mod error;
mod scopes;
/// Set of permission names granted to the current request.
///
/// Produced by the `FromRequestParts` extractor implemented for this type; the
/// names are string slices borrowed for the lifetime `'a`.
#[derive(Debug)]
pub struct Permissions<'a>(
    /// The granted permission names (for example `"root"`).
    pub HashSet<&'a str>,
);
// Middleware for getting permissions
impl<S> FromRequestParts<S> for Permissions<'_>
where
S: Send + Sync,
{
type Rejection = crate::Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// First check if the request has a beaerer token to authenticate
match parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
Ok(bearer) => {
verify_bearer(bearer.token().to_string()).map_err(|_| AuthError::InvalidToken)?;
let permissions = Permissions {
0: HashSet::from(["root"]),
};
return Ok(permissions);
}
Err(err) => match err.reason() {
TypedHeaderRejectionReason::Missing => (),
TypedHeaderRejectionReason::Error(_err) => {
return Err(AuthError::InvalidToken.into())
}
_ => return Err(AuthError::Unexpected.into()),
},
};
match parts.extract::<CookieJar>().await {
Ok(jar) => {
if let Some(session_token) = jar.get("session_token") {
// TODO: Implement function to retrieve user permissions
tracing::info!("{session_token:?}")
}
}
Err(_) => (),
}
Err(AuthError::Unauthorized.into())
}
}
/// Resolves the [`User`] owning the session token carried in the request's
/// `Authorization: Bearer <token>` header.
///
/// Only session tokens, recognised by their `ses_` prefix, are accepted. The
/// token is looked up with `Session::from_token`, its `expires_at` is checked
/// against the current time, and the session's user is then loaded.
///
/// # Errors
///
/// Returns [`AuthError::InvalidToken`] when the header is missing or cannot be
/// read as a string, is not a bearer credential, is not a session token, no
/// session can be loaded for it, the session has expired, or the session's
/// user cannot be loaded.
pub async fn get_user_from_header(pool: &PgPool, headers: &HeaderMap) -> Result<User, AuthError> {
    let bearer_value = headers
        .get(header::AUTHORIZATION)
        .ok_or(AuthError::InvalidToken)?
        .to_str()
        .map_err(|_| AuthError::InvalidToken)?;
    let token = get_token_from_bearer(bearer_value)?;

    // Session tokens are identified by a `ses_` prefix; no other kind is handled here.
    if !token.starts_with("ses_") {
        return Err(AuthError::InvalidToken);
    }

    let session = Session::from_token(pool, &token)
        .await
        .map_err(|_| AuthError::InvalidToken)?;
    if session.expires_at < Utc::now() {
        return Err(AuthError::InvalidToken);
    }

    let db_user = crate::database::model::User::get(pool, session.user_id)
        .await
        .map_err(|_| AuthError::InvalidToken)?;
    Ok(db_user.into())
}
/// Extracts the credential from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The scheme name is matched case-insensitively (authentication schemes are
/// case-insensitive per RFC 7235), and one or more spaces between the scheme
/// and the token are accepted (RFC 6750 uses `1*SP`). Any text after the token
/// is returned unchanged.
///
/// # Errors
///
/// Returns [`AuthError::InvalidToken`] if the value does not use the `Bearer`
/// scheme or carries an empty token.
pub fn get_token_from_bearer(bearer: &str) -> Result<String, AuthError> {
    let (scheme, rest) = bearer.split_once(' ').ok_or(AuthError::InvalidToken)?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err(AuthError::InvalidToken);
    }
    // Skip any additional separator spaces before the token itself.
    let token = rest.trim_start_matches(' ');
    if token.is_empty() {
        return Err(AuthError::InvalidToken);
    }
    Ok(token.to_owned())
}
/// Hashes `password` with `Argon2::default()` and a freshly generated random salt,
/// returning the string form of the resulting hash.
///
/// The hashing runs on Tokio's blocking thread pool (`spawn_blocking`) so the
/// expensive Argon2 computation does not stall the async executor.
///
/// # Errors
///
/// Returns the `argon2::password_hash::Error` reported by `hash_password`.
///
/// # Panics
///
/// Re-raises any panic from the blocking hashing task, and panics if that task
/// fails for another reason (e.g. it was cancelled).
pub async fn generate_password_hash(
    password: String,
) -> Result<String, argon2::password_hash::Error> {
    task::spawn_blocking(move || {
        let salt = SaltString::generate(&mut OsRng);
        let hashed = Argon2::default().hash_password(password.as_bytes(), &salt);
        hashed.map(|hash| hash.to_string())
    })
    .await
    .unwrap_or_else(|join_err| match join_err.try_into_panic() {
        // Propagate the original panic payload rather than a generic JoinError panic.
        Ok(payload) => std::panic::resume_unwind(payload),
        Err(join_err) => panic!("password hashing task failed: {join_err}"),
    })
}
pub async fn verify_password_hash(
password: &str,
hash: &str,
) -> Result<(), argon2::password_hash::Error> {
let parsed_hash = PasswordHash::new(hash)?;
Argon2::default().verify_password(password.as_bytes(), &parsed_hash)?;
Ok(())
}
/// Generates a 60-character session token made of alphanumeric characters.
///
/// Each character is sampled from the `Alphanumeric` distribution using a
/// `ChaCha20Rng` created with `from_os_rng`.
pub fn generate_session_token() -> String {
    const TOKEN_LEN: usize = 60;

    let mut rng = ChaCha20Rng::from_os_rng();
    (0..TOKEN_LEN)
        .map(|_| {
            let byte: u8 = rng.sample(Alphanumeric);
            char::from(byte)
        })
        .collect()
}