diff --git a/Cargo.lock b/Cargo.lock index e3a996d..0abd391 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "achubb_website" @@ -10,8 +10,12 @@ dependencies = [ "axum", "bb8", "clap", + "cookie", "futures-util", + "pbkdf2", "rand", + "rand_chacha", + "rand_core", "serde", "sqlx", "time", @@ -382,6 +386,16 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "time", + "version_check", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -1102,12 +1116,35 @@ dependencies = [ "windows-targets 0.52.5", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" diff --git a/Cargo.toml b/Cargo.toml index 6bddf3c..a4c4e2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,20 @@ askama = "0.12.1" axum = "0.6" bb8 = "0.8.3" clap = { version = "4.5.13", features = ["derive"] } +cookie = "0.18.1" futures-util = "0.3.30" +pbkdf2 = { version = "0.12.2", features = ["simple"] } rand = "0.8.5" +rand_chacha = "0.3.1" +rand_core = "0.6.4" serde = { version = "1.0.197", features = ["derive"] } -sqlx = {version = "0.7.4", features = ["runtime-tokio-rustls", "postgres", "time", "macros"]} -time = {version = "0.3.36", features = ["macros", "serde"]} +sqlx = { version = "0.7.4", features = [ + "runtime-tokio-rustls", + "postgres", + "time", + "macros", +] } +time = { version = "0.3.36", features = ["macros", "serde"] } tokio = { version = "1.35.0", features = ["full"] } tower = "0.4.13" tower-http = { version = "0.4.4", features = ["fs"] } diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..f3e1bf3 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,158 @@ +use axum::{ + body::Empty, + http::{Request, Response, StatusCode}, + middleware, + response::IntoResponse, + Extension, Form, +}; +use pbkdf2::{ + password_hash::{PasswordHash, PasswordVerifier}, + Pbkdf2, +}; +use sqlx::PgPool; + +use crate::{ + database::{ + session::{new_session, Random}, + user::create_user, + }, + errors::{LoginError, SignupError}, + html::{ + root::{error_page, get_login}, + Login, Signup, + }, +}; + +#[derive(Clone)] +pub struct UserInfo { + pub user_id: i32, + pub admin: bool, +} + +#[derive(Clone)] +pub struct AuthState(Option<(u128, Option, PgPool)>); + +pub async fn auth( + mut req: Request, + next: middleware::Next, + pool: PgPool, +) -> axum::response::Response { + let session_token = req + .headers() + .get_all("Cookie") + .iter() + .filter_map(|cookie| { + cookie + .to_str() + .ok() + .and_then(|cookie| cookie.parse::().ok()) + }) + .find_map(|cookie| { + (cookie.name() == "session_token").then(move || cookie.value().to_owned()) + }) + .and_then(|cookie_value| cookie_value.parse::().ok()); + + if session_token.is_none() + && req.uri().to_string().contains("/admin") + && !req.uri().to_string().contains("/login") + { + return get_login().await.into_response(); + } + + req.extensions_mut() + .insert(AuthState(session_token.map(|v| (v, None, pool)))); + + next.run(req).await +} + +impl AuthState { + pub async fn get_user(&mut self) -> Option<&UserInfo> { + let (session_token, store, pool) = self.0.as_mut()?; + + if store.is_none() { + const QUERY: &str = + "SELECT id, admin FROM users JOIN sessions ON user_id = id WHERE session_token = $1;"; + let user: Option<(i32, bool)> = sqlx::query_as(QUERY) + .bind(&session_token.to_le_bytes().to_vec()) + .fetch_optional(&*pool) + .await + .unwrap(); + + if let Some((id, admin)) = user { + *store = Some(UserInfo { + user_id: id, + admin + }); + } + } + store.as_ref() + } +} + +pub fn set_cookie(session_token: &str) -> impl IntoResponse { + Response::builder() + .status(StatusCode::SEE_OTHER) + .header("Location", "/") + .header( + "Set-Cookie", + format!("session_token={}; Max-Age=999999", session_token), + ) + .body(Empty::new()) + .unwrap() +} + +pub async fn post_login( + Extension(pool): Extension, + Extension(random): Extension, + Form(login): Form, +) -> impl IntoResponse { + const LOGIN_QUERY: &str = "SELECT id, password FROM users WHERE users.username = $1;"; + + let row: Option<(i32, String)> = sqlx::query_as(LOGIN_QUERY) + .bind(login.username) + .fetch_optional(&pool) + .await + .unwrap(); + + let (user_id, hashed_password) = if let Some(row) = row { + row + } else { + return Err(error_page(&LoginError::UserDoesNotExist)); + }; + + // Verify password against PHC string + let parsed_hash = PasswordHash::new(&hashed_password).unwrap(); + + if let Err(_err) = Pbkdf2.verify_password(login.password.as_bytes(), &parsed_hash) { + return Err(error_page(&LoginError::WrongPassword)); + } + + let session_token = new_session(&pool, random, user_id).await; + + Ok(set_cookie(&session_token)) +} + +pub async fn post_signup( + Extension(pool): Extension, + Extension(random): Extension, + Form(signup): Form, +) -> impl IntoResponse { + if signup.password != signup.confirm_password { + return Err(error_page(&SignupError::PasswordsDoNotMatch)); + } + + let user_id = create_user(&signup.username, &signup.password, &pool).await.unwrap(); + + let session_token = new_session(&pool, random, user_id).await; + + Ok(set_cookie(&session_token)) +} + +pub async fn logout_response() -> impl IntoResponse { + Response::builder() + .status(StatusCode::SEE_OTHER) + .header("Location", "/") + .header("Set-Cookie", "session_token=_; Max-Age=0") + .body(Empty::new()) + .unwrap() +} diff --git a/src/database/mod.rs b/src/database/mod.rs index c48b1a1..93ea341 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -9,6 +9,8 @@ use std::{ pub mod article; pub mod link; +pub mod session; +pub mod user; pub async fn establish_connection() -> Result> { let db_url = match env::var("ACHUBB_DATABASE_URL") { diff --git a/src/database/session.rs b/src/database/session.rs new file mode 100644 index 0000000..1473803 --- /dev/null +++ b/src/database/session.rs @@ -0,0 +1,108 @@ +use crate::{database::PsqlData, errors::DatabaseError}; +use futures_util::TryStreamExt; +use rand::RngCore; +use rand_chacha::ChaCha8Rng; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgPool; +use std::{ + error::Error, + sync::{Arc, Mutex}, +}; + +pub type Random = Arc>; + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct Session { + pub session_token: Vec, + pub user_id: i32, +} + +impl Session { + pub async fn read_by_token( + pool: &PgPool, + session_token: &String, + ) -> Result, Box> { + let token: Vec = session_token.parse::()?.to_le_bytes().to_vec(); + let result = sqlx::query_as!( + Self, + "SELECT * FROM sessions WHERE session_token = $1", + token, + ) + .fetch_one(pool) + .await?; + + Ok(Box::new(result)) + } +} + +impl PsqlData for Session { + async fn read_all(pool: &PgPool) -> Result>, Box> { + crate::psql_read_all!(Self, pool, "sessions") + } + + async fn read(pool: &PgPool, id: i32) -> Result, Box> { + let result = sqlx::query_as!(Self, "SELECT * FROM sessions WHERE user_id = $1", id) + .fetch_one(pool) + .await?; + + Ok(Box::new(result)) + } + + async fn insert(&self, pool: &PgPool) -> Result<(), Box> { + sqlx::query!( + "INSERT INTO sessions (session_token, user_id) VALUES ($1, $2)", + self.session_token, + self.user_id, + ) + .execute(pool) + .await?; + + Ok(()) + } + + async fn update(&self, pool: &PgPool) -> Result<(), Box> { + sqlx::query!( + "UPDATE sessions SET session_token=$1 WHERE user_id=$2", + self.session_token, + self.user_id, + ) + .execute(pool) + .await?; + + Ok(()) + } + + async fn delete(&self, pool: &PgPool) -> Result<(), Box> { + let _result = sqlx::query!("DELETE FROM sessions WHERE user_id = $1", self.user_id) + .execute(pool) + .await?; + + Ok(()) + } +} + +pub async fn new_session(pool: &PgPool, random: Random, user_id: i32) -> String { + let mut u128_pool = [0u8; 16]; + random.lock().unwrap().fill_bytes(&mut u128_pool); + + let session_token = u128::from_le_bytes(u128_pool); + + let session = Session { + user_id, + session_token: session_token.to_le_bytes().to_vec(), + }; + + session.insert(pool).await.unwrap(); + + session_token.to_string() +} + +pub async fn clear_sessions_for_user(pool: &PgPool, user_id: i32) -> Result<(), DatabaseError> { + const QUERY: &str = "DELETE FROM sessions WHERE user_id=$1;"; + let _result = sqlx::query(QUERY) + .bind(user_id) + .execute(pool) + .await + .unwrap(); + Ok(()) +} diff --git a/src/database/user.rs b/src/database/user.rs new file mode 100644 index 0000000..46e980f --- /dev/null +++ b/src/database/user.rs @@ -0,0 +1,107 @@ +use crate::{database::PsqlData, errors::SignupError}; +use futures_util::TryStreamExt; +use pbkdf2::{ + password_hash::{rand_core::OsRng, PasswordHasher, SaltString}, + Pbkdf2, +}; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgPool; +use std::error::Error; + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct User { + pub id: i32, + pub username: String, + pub password: String, + pub admin: bool, +} + +impl User { + pub async fn read_by_name( + pool: &PgPool, + username: &String, + ) -> Result, Box> { + let result = sqlx::query_as!(Self, "SELECT * FROM users WHERE username = $1", username,) + .fetch_one(pool) + .await?; + + Ok(Box::new(result)) + } +} + +impl PsqlData for User { + async fn read_all(pool: &PgPool) -> Result>, Box> { + crate::psql_read_all!(Self, pool, "users") + } + + async fn read(pool: &PgPool, id: i32) -> Result, Box> { + crate::psql_read!(Self, pool, id, "users") + } + + async fn insert(&self, pool: &PgPool) -> Result<(), Box> { + sqlx::query!( + "INSERT INTO users (username, password, admin) VALUES ($1, $2, $3)", + self.username, + self.password, + self.admin, + ) + .execute(pool) + .await?; + + Ok(()) + } + + async fn update(&self, pool: &PgPool) -> Result<(), Box> { + sqlx::query!( + "UPDATE users SET username=$1, password=$2 WHERE id=$3", + self.username, + self.password, + self.id, + ) + .execute(pool) + .await?; + + Ok(()) + } + + async fn delete(&self, pool: &PgPool) -> Result<(), Box> { + let id = &self.id; + crate::psql_delete!(id, pool, "users") + } +} + +pub async fn create_user( + username: &str, + password: &str, + pool: &PgPool, +) -> Result { + let salt = SaltString::generate(&mut OsRng); + + // Hash password to PHC string ($pbkdf2-sha256$...) + let hashed_password = Pbkdf2 + .hash_password(password.as_bytes(), &salt) + .unwrap() + .to_string(); + + const INSERT_QUERY: &str = + "INSERT INTO users (username, password, admin) VALUES ($1, $2, $3) RETURNING id;"; + + let fetch_one = sqlx::query_as(INSERT_QUERY) + .bind(username) + .bind(hashed_password) + .bind(false) + .fetch_one(pool) + .await; + + match fetch_one { + Ok((user_id,)) => Ok(user_id), + Err(sqlx::Error::Database(database)) + if database.constraint() == Some("users_username_key") => + { + return Err(SignupError::UsernameExists); + } + Err(_) => { + return Err(SignupError::InternalError); + } + } +} diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..354636d --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,70 @@ +use std::{error::Error, fmt::Display}; + +#[derive(Debug)] +pub struct NotLoggedIn; + +impl Display for NotLoggedIn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Not logged in") + } +} + +impl Error for NotLoggedIn {} + +#[derive(Debug)] +pub enum SignupError { + UsernameExists, + InvalidUsername, + PasswordsDoNotMatch, + MissingDetails, + InvalidPassword, + InternalError, +} + +impl Display for SignupError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SignupError::InvalidUsername => f.write_str("Invalid username"), + SignupError::UsernameExists => f.write_str("Username already exists"), + SignupError::PasswordsDoNotMatch => f.write_str("Passwords do not match"), + SignupError::MissingDetails => f.write_str("Missing Details"), + SignupError::InvalidPassword => f.write_str("Invalid Password"), + SignupError::InternalError => f.write_str("Internal Error"), + } + } +} + +impl Error for SignupError {} + +#[derive(Debug)] +pub enum LoginError { + UserDoesNotExist, + WrongPassword, +} + +impl Display for LoginError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LoginError::UserDoesNotExist => f.write_str("User does not exist"), + LoginError::WrongPassword => f.write_str("Wrong password"), + } + } +} + +impl Error for LoginError {} + +#[derive(Debug)] +pub struct NoUser(pub String); + +impl Display for NoUser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("could not find user '{}'", self.0)) + } +} + +impl Error for NoUser {} + +#[derive(Debug)] +pub enum DatabaseError { + NoEntries, +} diff --git a/src/html/admin.rs b/src/html/admin.rs new file mode 100644 index 0000000..0925320 --- /dev/null +++ b/src/html/admin.rs @@ -0,0 +1,31 @@ +use axum::{ + response::{IntoResponse, Redirect, Response}, + routing::{get, Router}, + Extension, +}; + +use crate::auth::{AuthState, UserInfo}; + +use super::templates::{AdminTemplate, HtmlTemplate}; + +pub fn get_router() -> Router { + Router::new().route("/", get(get_admin)) +} + +pub async fn get_admin( + Extension(mut current_user): Extension, +) -> Response { + let user_info: UserInfo = current_user.get_user().await.unwrap().clone(); + if !user_info.admin { + return Redirect::to("/").into_response(); + } + HtmlTemplate(AdminTemplate {}).into_response() +} + +fn get_trimmed_string(input: String) -> Option { + if input.trim().is_empty() { + None + } else { + Some(input.trim().to_string()) + } +} diff --git a/src/html/mod.rs b/src/html/mod.rs index 8a1d45c..c8b2e78 100644 --- a/src/html/mod.rs +++ b/src/html/mod.rs @@ -1,5 +1,20 @@ +use serde::Deserialize; + +pub mod admin; pub mod blog; pub mod garden; pub mod root; pub mod templates; +#[derive(Deserialize)] +pub struct Login { + pub username: String, + pub password: String, +} + +#[derive(Deserialize)] +pub struct Signup { + pub username: String, + pub password: String, + pub confirm_password: String, +} diff --git a/src/html/root.rs b/src/html/root.rs index 660acae..2495bf9 100644 --- a/src/html/root.rs +++ b/src/html/root.rs @@ -1,33 +1,42 @@ +use std:: sync::{Arc, Mutex}; + use axum::{ + http::{Response, StatusCode}, response::{IntoResponse, Redirect}, - routing::{get, Router}, + routing::{get, post, Router}, Extension, }; -use rand::{seq::SliceRandom, thread_rng}; +use rand::{seq::SliceRandom, thread_rng, RngCore, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rand_core::OsRng; use sqlx::PgPool; use std::error::Error; use tower_http::services::ServeDir; -use crate::database::{ - link::{Link, LinkType}, - PsqlData, +use crate::{ + auth::{auth, logout_response, post_login, post_signup}, + database::{ + link::{Link, LinkType}, + PsqlData, + }, }; use super::{ - blog::{self, get_articles_date_sorted}, - garden, - templates::{ - AboutTemplate, AiTemplate, BlogrollTemplate, ContactTemplate, GiftsTemplate, HomeTemplate, - HtmlTemplate, InterestsTemplate, LinksPageTemplate, NowTemplate, ResumeTemplate, - WorkTemplate, UsesTemplate, - }, + admin, blog::{self, get_articles_date_sorted}, garden, templates::{ + AboutTemplate, AiTemplate, BlogrollTemplate, ContactTemplate, GiftsTemplate, HomeTemplate, HtmlTemplate, InterestsTemplate, LinksPageTemplate, LoginTemplate, NowTemplate, ResumeTemplate, SignupTemplate, UsesTemplate, WorkTemplate + } }; pub fn get_router(pool: PgPool) -> Router { let assets_path = std::env::current_dir().unwrap(); + + let random = ChaCha8Rng::seed_from_u64(OsRng.next_u64()); + let middleware_database = pool.clone(); + Router::new() .nest("/blog", blog::get_router()) .nest("/garden", garden::get_router()) + .nest("/admin", admin::get_router()) .nest_service( "/assets", ServeDir::new(format!("{}/assets", assets_path.to_str().unwrap())), @@ -44,11 +53,18 @@ pub fn get_router(pool: PgPool) -> Router { .route("/resume", get(resume)) .route("/gifts", get(gifts)) .route("/hire", get(work)) + .route("/login", get(get_login).post(post_login)) + .route("/signup", get(get_signup).post(post_signup)) + .route("/logout", post(logout_response)) .route( "/robots.txt", get(|| async { Redirect::permanent("/assets/robots.txt") }), ) + .layer(axum::middleware::from_fn(move |req, next| { + auth(req, next, middleware_database.clone()) + })) .layer(Extension(pool)) + .layer(Extension(Arc::new(Mutex::new(random)))) } async fn home(Extension(pool): Extension) -> impl IntoResponse { @@ -133,3 +149,18 @@ pub async fn get_links_as_list( .collect(); Ok(list) } + +pub fn error_page(err: &dyn std::error::Error) -> impl IntoResponse { + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(format!("Err: {}", err)) + .unwrap() +} + +pub async fn get_login() -> impl IntoResponse { + HtmlTemplate(LoginTemplate { username: None }) +} + +pub async fn get_signup() -> impl IntoResponse { + HtmlTemplate(SignupTemplate {}) +} diff --git a/src/html/templates.rs b/src/html/templates.rs index 5417e94..a9b2393 100644 --- a/src/html/templates.rs +++ b/src/html/templates.rs @@ -144,3 +144,17 @@ pub struct WorkTemplate {} #[derive(Template)] #[template(path = "technology.html")] pub struct TechnologyTemplate {} + +#[derive(Template)] +#[template(path = "login.html")] +pub struct LoginTemplate { + pub username: Option, +} + +#[derive(Template)] +#[template(path = "signup.html")] +pub struct SignupTemplate {} + +#[derive(Template)] +#[template(path = "admin.html")] +pub struct AdminTemplate {} diff --git a/src/lib.rs b/src/lib.rs index 57ff0b6..bba3b33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,9 @@ use std::error::Error; use tracing::info; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +mod auth; pub mod database; +mod errors; mod html; mod macros; diff --git a/templates/admin.html b/templates/admin.html new file mode 100644 index 0000000..e5b804f --- /dev/null +++ b/templates/admin.html @@ -0,0 +1,12 @@ + +{% extends "base.html" %} + +{% block content %} +

Admin

+
+ +
+

+ My admin page +

+{% endblock %} diff --git a/templates/login.html b/templates/login.html new file mode 100644 index 0000000..53fd1f0 --- /dev/null +++ b/templates/login.html @@ -0,0 +1,32 @@ + + + + + + + Awstin + {% block head %}{% endblock %} + + + +
+
+

+ {% if let Some(username) = username %} + + + {% else %} + + + {% endif %} +

+

+ + +

+ +
+
+ + + diff --git a/templates/signup.html b/templates/signup.html new file mode 100644 index 0000000..136ef2d --- /dev/null +++ b/templates/signup.html @@ -0,0 +1,31 @@ + + + + + + + Awstin + {% block head %}{% endblock %} + + + +
+
+

+ + +

+

+ + +

+

+ + +

+ +
+
+ + +