use serde; use rocket::{ self, http::Status, serde::Serialize, request::{ self, FromRequest, }, Request, outcome::{ try_outcome, Outcome::{ Success, Failure, Forward, }, }, }; use rocket_db_pools::{ Database, Connection, }; use sqlx::{ pool::PoolConnection, Postgres, }; #[derive(Database)] #[database("notes")] pub struct Notes(sqlx::PgPool); pub type Result> = std::result::Result; #[derive(Serialize)] pub struct Note { pub id: i32, pub content: String, pub list_id: i32, } impl ToString for Note { fn to_string(&self) -> String { format!("Note: {}", self.content) } } #[derive(Serialize)] pub struct User { pub id: i32, pub uuid: String, pub username: String, pub password: String, pub email: String, } impl ToString for User { fn to_string(&self) -> String { format!("User: {}", self.username) } } #[derive(Debug)] pub enum UserError { NoCookie, InvalidCookie, WrappedDBError(Option>), DBError(sqlx::Error), } #[rocket::async_trait] impl<'a> FromRequest<'a> for User { type Error = UserError; async fn from_request(req: &'a Request<'_>) -> request::Outcome { let cookies = req.cookies(); let mut db = match req.guard::>().await { Success(dbv) => dbv, Failure((http_err, err)) => return Failure((http_err, UserError::WrappedDBError(err))), Forward(next) => return Forward(next), }; let user_uuid = match cookies.get_private("user_uuid") { Some(crumb) => crumb, None => return Forward(()), // this will redirect to the login page }; match sqlx::query!("SELECT * FROM users WHERE uuid = $1", user_uuid.value()) .fetch_optional(&mut *db) .await { Ok(Some(row)) => { Success(User { id: row.id, uuid: row.uuid, username: row.username, password: row.password, email: row.email, }) }, Ok(None) => Forward(()), // this will redirect to the login page Err(e) => Failure((Status::InternalServerError, UserError::DBError(e))), } } } #[derive(Serialize)] pub struct List { pub id: i32, pub name: String, pub owner_id: i32, } impl ToString for List { fn to_string(&self) -> String { format!("{}: {} (owned by {})", self.id, self.name, self.owner_id) } } pub async fn get_list_notes(db: &mut PoolConnection, lid: i32) -> Result> { Ok(sqlx::query!(" SELECT id,content,list_id FROM note WHERE list_id = $1", lid) .fetch_all(db) .await? .iter() .map(|r| Note { id: r.id, content: r.content.clone(), list_id: r.list_id }) .collect()) } pub async fn get_user_lists(db: &mut PoolConnection, uid: i32) -> Result> { Ok(sqlx::query!(" SELECT id,name,owner_id FROM list WHERE owner_id = $1", uid) .fetch_all(db) .await? .iter() .map(|r| List { name: r.name.clone(), id: r.id, owner_id: r.owner_id }) .collect()) } pub async fn get_user_lists_from_perms(db: &mut PoolConnection, uid: i32) -> Result> { Ok(sqlx::query!(" SELECT list.id,list.name,list.owner_id FROM list JOIN perms ON list.id=perms.list_id WHERE perms.user_id = $1 AND perms.read = TRUE", uid) .fetch_all(db) .await? .iter() .map(|r| List { name: r.name.clone(), id: r.id, owner_id: r.owner_id }) .collect()) } pub async fn get_user_from_email(db: &mut PoolConnection, email: String) -> Result> { match sqlx::query!( "SELECT * FROM users WHERE email = $1", email ).fetch_optional(&mut *db) .await? { Some(row) => Ok(Some(User { id: row.id, uuid: row.uuid, username: row.username, password: row.password, email: row.email, })), None => Ok(None), } } pub async fn add_permission(db: &mut PoolConnection, user_id: i32, list_id: i32, perm: i32) -> Result<()> { sqlx::query!(" INSERT INTO perms (user_id, list_id, read, write) VALUES ($1, $2, $3, $4)", user_id, list_id, perm >= 1, perm >= 2 ).execute(db) .await?; Ok(()) }