diff --git a/Cargo.lock b/Cargo.lock index 0e6cc5a..015dc65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +version = 3 + [[package]] name = "addr2line" version = "0.14.1" @@ -1764,6 +1766,7 @@ dependencies = [ "r2d2_sqlite", "rand 0.8.3", "rand_core 0.5.1", + "regex", "reqwest", "rusqlite", "rusqlite_migration", diff --git a/Cargo.toml b/Cargo.toml index dc3f3a5..dbe2e1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ log4rs = "1.0" octocrab = "0.9" rand = "0.8" rand_core = "0.5" +regex = "1" reqwest = { version = "0.11", features = ["json"] } rusqlite = { version = "0.24", features = ["bundled"] } rusqlite_migration = "0.4" diff --git a/src/handlers.rs b/src/handlers.rs index d30bddb..6a2d874 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -52,7 +52,9 @@ pub async fn create_room(room: models::Room) -> Result { } } // Set up the database - storage::create_database_if_needed(&room.id); + storage::create_database_if_needed( + &storage::RoomId::new(&room.id).ok_or(Error::ValidationFailed)?, + ); // Return info!("Added room with ID: {}", &room.id); let json = models::StatusCode { status_code: StatusCode::OK.as_u16() }; @@ -768,7 +770,8 @@ pub fn get_deleted_messages( pub async fn add_moderator_public( body: models::ChangeModeratorRequestBody, auth_token: &str, ) -> Result { - let pool = storage::pool_by_room_id(&body.room_id); + let room_id = storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?; + let pool = storage::pool_by_room_id(&room_id); let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Moderator, &pool)?; if !has_authorization_level { @@ -782,7 +785,9 @@ pub async fn add_moderator( body: models::ChangeModeratorRequestBody, ) -> Result { // Get a database connection - let pool = storage::pool_by_room_id(&body.room_id); + let pool = storage::pool_by_room_id( + &storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?, + ); let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; // Insert the moderator let stmt = format!("INSERT INTO {} (public_key) VALUES (?1)", storage::MODERATORS_TABLE); @@ -802,7 +807,9 @@ pub async fn add_moderator( pub async fn delete_moderator_public( body: models::ChangeModeratorRequestBody, auth_token: &str, ) -> Result { - let pool = storage::pool_by_room_id(&body.room_id); + let pool = storage::pool_by_room_id( + &storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?, + ); let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Moderator, &pool)?; if !has_authorization_level { @@ -816,7 +823,9 @@ pub async fn delete_moderator( body: models::ChangeModeratorRequestBody, ) -> Result { // Get a database connection - let pool = storage::pool_by_room_id(&body.room_id); + let pool = storage::pool_by_room_id( + &storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?, + ); let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; // Insert the moderator let stmt = format!("DELETE FROM {} WHERE public_key = (?1)", storage::MODERATORS_TABLE); @@ -1051,7 +1060,9 @@ pub fn compact_poll( } }; // Get the database connection pool - let pool = storage::pool_by_room_id(&room_id); + let pool = storage::pool_by_room_id( + &storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?, + ); // Get the new messages let mut get_messages_query_params: HashMap = HashMap::new(); if let Some(from_message_server_id) = from_message_server_id { @@ -1160,7 +1171,7 @@ pub async fn get_session_version(platform: &str) -> Result { // not publicly exposed. pub async fn get_stats_for_room( - room: String, query_map: HashMap, + room_id: String, query_map: HashMap, ) -> Result { let now = chrono::Utc::now().timestamp(); let window = match query_map.get("window") { @@ -1174,7 +1185,8 @@ pub async fn get_stats_for_room( }; let lowerbound = upperbound - window; - let pool = storage::pool_by_room_id(&room); + let pool = + storage::pool_by_room_id(&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?); let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let raw_query_users = format!( diff --git a/src/rpc.rs b/src/rpc.rs index 9e36213..196a22b 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -49,13 +49,13 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result { // Get the auth token if possible let auth_token = get_auth_token(&rpc_call); // Get the room ID - let room_id = get_room_id(&rpc_call); + let room_id_str = get_room_id(&rpc_call); // Switch on the HTTP method match rpc_call.method.as_ref() { "GET" => { - return handle_get_request(room_id, rpc_call, &path, auth_token, query_params).await + return handle_get_request(room_id_str, rpc_call, &path, auth_token, query_params).await } - "POST" => return handle_post_request(room_id, rpc_call, &path, auth_token).await, + "POST" => return handle_post_request(room_id_str, rpc_call, &path, auth_token).await, "DELETE" => { let pool = get_pool_for_room(&rpc_call)?; return handle_delete_request(rpc_call, &path, auth_token, &pool).await; @@ -408,14 +408,10 @@ async fn handle_delete_request( // Utilities fn get_pool_for_room(rpc_call: &RpcCall) -> Result { - let room_id = match get_room_id(&rpc_call) { - Some(room_id) => room_id, - None => { - warn!("Missing room ID."); - return Err(warp::reject::custom(Error::InvalidRpcCall)); - } - }; - return Ok(storage::pool_by_room_id(&room_id)); + let room_id = get_room_id(&rpc_call).ok_or(Error::ValidationFailed)?; + return Ok(storage::pool_by_room_id( + &storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?, + )); } fn get_auth_token(rpc_call: &RpcCall) -> Option { diff --git a/src/storage.rs b/src/storage.rs index 45b46b5..31edc39 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,3 +1,4 @@ +use regex::Regex; use std::collections::HashMap; use std::path::Path; use std::sync::Mutex; @@ -12,6 +13,30 @@ use super::errors::Error; pub type DatabaseConnection = r2d2::PooledConnection; pub type DatabaseConnectionPool = r2d2::Pool; +#[derive(PartialEq, Eq, Hash)] +pub struct RoomId { + id: String, +} + +lazy_static::lazy_static! { + // Alphanumeric, Decimals "-" & "_" only and must be between 1 - 64 characters + static ref REGULAR_CHARACTERS_ONLY: Regex = Regex::new(r"^[\w-]{1,64}$").unwrap(); +} + +impl RoomId { + pub fn new(room_id: &str) -> Option { + if REGULAR_CHARACTERS_ONLY.is_match(room_id) { + return Some(RoomId { id: room_id.to_string() }); + } else { + return None; + } + } + + pub fn get_id(&self) -> &str { + &self.id + } +} + // Main pub const MAIN_TABLE: &str = "main"; @@ -63,21 +88,21 @@ lazy_static::lazy_static! { static ref POOLS: Mutex> = Mutex::new(HashMap::new()); } -pub fn pool_by_room_id(room_id: &str) -> DatabaseConnectionPool { +pub fn pool_by_room_id(room_id: &RoomId) -> DatabaseConnectionPool { let mut pools = POOLS.lock().unwrap(); - if let Some(pool) = pools.get(room_id) { + if let Some(pool) = pools.get(room_id.get_id()) { return pool.clone(); } else { - let raw_path = format!("rooms/{}.db", room_id); + let raw_path = format!("rooms/{}.db", room_id.get_id()); let path = Path::new(&raw_path); let db_manager = r2d2_sqlite::SqliteConnectionManager::file(path); let pool = r2d2::Pool::new(db_manager).unwrap(); - pools.insert(room_id.to_string(), pool); - return pools[room_id].clone(); + pools.insert(room_id.get_id().to_string(), pool); + return pools[room_id.get_id()].clone(); } } -pub fn create_database_if_needed(room_id: &str) { +pub fn create_database_if_needed(room_id: &RoomId) { let pool = pool_by_room_id(room_id); let conn = pool.get().unwrap(); create_room_tables_if_needed(&conn); @@ -270,7 +295,9 @@ fn get_expired_file_ids( Ok(rows.filter_map(|result| result.ok()).collect()) } -pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, file_expiration: i64) { +pub async fn prune_files_for_room( + pool: &DatabaseConnectionPool, room: &RoomId, file_expiration: i64, +) { let ids = get_expired_file_ids(&pool, file_expiration); match ids { @@ -278,7 +305,7 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil // Delete the files let futs = ids.iter().map(|id| async move { ( - tokio::fs::remove_file(format!("files/{}_files/{}", room, id)).await, + tokio::fs::remove_file(format!("files/{}_files/{}", room.get_id(), id)).await, id.to_owned(), ) }); @@ -289,7 +316,9 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil if let Err(err) = res { error!( "Couldn't delete file: {} from room: {} due to error: {}.", - id, room, err + id, + room.get_id(), + err ); } } @@ -320,7 +349,7 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil }; } // Log the result - info!("Pruned files for room: {}. Took: {:?}", room, now.elapsed()); + info!("Pruned files for room: {}. Took: {:?}", room.get_id(), now.elapsed()); } Ok(_) => { // empty @@ -375,7 +404,7 @@ pub fn perform_migration() { // Utilities -fn get_all_room_ids() -> Result, Error> { +fn get_all_room_ids() -> Result, Error> { // Get a database connection let conn = MAIN_POOL.get().map_err(|_| Error::DatabaseFailedInternally)?; // Query the database @@ -388,7 +417,11 @@ fn get_all_room_ids() -> Result, Error> { return Err(Error::DatabaseFailedInternally); } }; - let ids: Vec = rows.filter_map(|result| result.ok()).collect(); + let room_ids: Vec<_> = rows + .filter_map(|result: Result| result.ok()) + .map(|opt| RoomId::new(&opt)) + .flatten() + .collect(); // Return - return Ok(ids); + return Ok(room_ids); } diff --git a/src/tests.rs b/src/tests.rs index 09ecd4c..0a3fbad 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -75,12 +75,12 @@ async fn test_file_handling() { // Ensure the test room is set up and get a database connection pool let pool = set_up_test_room().await; - let test_room_id = "test_room"; + let test_room_id = storage::RoomId::new("test_room").unwrap(); // Get an auth token let (auth_token, _) = get_auth_token(&pool); // Store the test file handlers::store_file( - Some(test_room_id.to_string()), + Some(test_room_id.get_id().to_string()), TEST_FILE, Some(auth_token.clone()), &pool, @@ -94,17 +94,21 @@ async fn test_file_handling() { conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap(); let id = id_as_string.parse::().unwrap(); // Retrieve the file and check the content - let base64_encoded_file = - handlers::get_file(Some(test_room_id.to_string()), id, Some(auth_token.clone()), &pool) - .await - .unwrap() - .result; + let base64_encoded_file = handlers::get_file( + Some(test_room_id.get_id().to_string()), + id, + Some(auth_token.clone()), + &pool, + ) + .await + .unwrap() + .result; assert_eq!(base64_encoded_file, TEST_FILE); // Prune the file and check that it's gone // Will evaluate to now + 60 - storage::prune_files_for_room(&pool, test_room_id, -60).await; + storage::prune_files_for_room(&pool, &test_room_id, -60).await; // It should be gone now - fs::read(format!("files/{}_files/{}", test_room_id, id)).unwrap_err(); + fs::read(format!("files/{}_files/{}", test_room_id.get_id(), id)).unwrap_err(); // Check that the file record is also gone let conn = pool.get().unwrap(); let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);