Replaces room_id variables with typed versions
Allows for better protection of the input strings and fits into type checking.
This commit is contained in:
parent
5718c48524
commit
1523699dd1
|
@ -1,5 +1,7 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.14.1"
|
||||
|
@ -1764,6 +1766,7 @@ dependencies = [
|
|||
"r2d2_sqlite",
|
||||
"rand 0.8.3",
|
||||
"rand_core 0.5.1",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"rusqlite",
|
||||
"rusqlite_migration",
|
||||
|
|
|
@ -27,6 +27,7 @@ log4rs = "1.0"
|
|||
octocrab = "0.9"
|
||||
rand = "0.8"
|
||||
rand_core = "0.5"
|
||||
regex = "1"
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
rusqlite = { version = "0.24", features = ["bundled"] }
|
||||
rusqlite_migration = "0.4"
|
||||
|
|
|
@ -52,7 +52,9 @@ pub async fn create_room(room: models::Room) -> Result<Response, Rejection> {
|
|||
}
|
||||
}
|
||||
// Set up the database
|
||||
storage::create_database_if_needed(&room.id);
|
||||
storage::create_database_if_needed(
|
||||
&storage::RoomId::new(&room.id).ok_or(Error::ValidationFailed)?,
|
||||
);
|
||||
// Return
|
||||
info!("Added room with ID: {}", &room.id);
|
||||
let json = models::StatusCode { status_code: StatusCode::OK.as_u16() };
|
||||
|
@ -768,7 +770,8 @@ pub fn get_deleted_messages(
|
|||
pub async fn add_moderator_public(
|
||||
body: models::ChangeModeratorRequestBody, auth_token: &str,
|
||||
) -> Result<Response, Rejection> {
|
||||
let pool = storage::pool_by_room_id(&body.room_id);
|
||||
let room_id = storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?;
|
||||
let pool = storage::pool_by_room_id(&room_id);
|
||||
let (has_authorization_level, _) =
|
||||
has_authorization_level(auth_token, AuthorizationLevel::Moderator, &pool)?;
|
||||
if !has_authorization_level {
|
||||
|
@ -782,7 +785,9 @@ pub async fn add_moderator(
|
|||
body: models::ChangeModeratorRequestBody,
|
||||
) -> Result<Response, Rejection> {
|
||||
// Get a database connection
|
||||
let pool = storage::pool_by_room_id(&body.room_id);
|
||||
let pool = storage::pool_by_room_id(
|
||||
&storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?,
|
||||
);
|
||||
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
|
||||
// Insert the moderator
|
||||
let stmt = format!("INSERT INTO {} (public_key) VALUES (?1)", storage::MODERATORS_TABLE);
|
||||
|
@ -802,7 +807,9 @@ pub async fn add_moderator(
|
|||
pub async fn delete_moderator_public(
|
||||
body: models::ChangeModeratorRequestBody, auth_token: &str,
|
||||
) -> Result<Response, Rejection> {
|
||||
let pool = storage::pool_by_room_id(&body.room_id);
|
||||
let pool = storage::pool_by_room_id(
|
||||
&storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?,
|
||||
);
|
||||
let (has_authorization_level, _) =
|
||||
has_authorization_level(auth_token, AuthorizationLevel::Moderator, &pool)?;
|
||||
if !has_authorization_level {
|
||||
|
@ -816,7 +823,9 @@ pub async fn delete_moderator(
|
|||
body: models::ChangeModeratorRequestBody,
|
||||
) -> Result<Response, Rejection> {
|
||||
// Get a database connection
|
||||
let pool = storage::pool_by_room_id(&body.room_id);
|
||||
let pool = storage::pool_by_room_id(
|
||||
&storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?,
|
||||
);
|
||||
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
|
||||
// Insert the moderator
|
||||
let stmt = format!("DELETE FROM {} WHERE public_key = (?1)", storage::MODERATORS_TABLE);
|
||||
|
@ -1051,7 +1060,9 @@ pub fn compact_poll(
|
|||
}
|
||||
};
|
||||
// Get the database connection pool
|
||||
let pool = storage::pool_by_room_id(&room_id);
|
||||
let pool = storage::pool_by_room_id(
|
||||
&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?,
|
||||
);
|
||||
// Get the new messages
|
||||
let mut get_messages_query_params: HashMap<String, String> = HashMap::new();
|
||||
if let Some(from_message_server_id) = from_message_server_id {
|
||||
|
@ -1160,7 +1171,7 @@ pub async fn get_session_version(platform: &str) -> Result<String, Rejection> {
|
|||
|
||||
// not publicly exposed.
|
||||
pub async fn get_stats_for_room(
|
||||
room: String, query_map: HashMap<String, i64>,
|
||||
room_id: String, query_map: HashMap<String, i64>,
|
||||
) -> Result<Response, Rejection> {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let window = match query_map.get("window") {
|
||||
|
@ -1174,7 +1185,8 @@ pub async fn get_stats_for_room(
|
|||
};
|
||||
|
||||
let lowerbound = upperbound - window;
|
||||
let pool = storage::pool_by_room_id(&room);
|
||||
let pool =
|
||||
storage::pool_by_room_id(&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?);
|
||||
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
|
||||
|
||||
let raw_query_users = format!(
|
||||
|
|
18
src/rpc.rs
18
src/rpc.rs
|
@ -49,13 +49,13 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
|
|||
// Get the auth token if possible
|
||||
let auth_token = get_auth_token(&rpc_call);
|
||||
// Get the room ID
|
||||
let room_id = get_room_id(&rpc_call);
|
||||
let room_id_str = get_room_id(&rpc_call);
|
||||
// Switch on the HTTP method
|
||||
match rpc_call.method.as_ref() {
|
||||
"GET" => {
|
||||
return handle_get_request(room_id, rpc_call, &path, auth_token, query_params).await
|
||||
return handle_get_request(room_id_str, rpc_call, &path, auth_token, query_params).await
|
||||
}
|
||||
"POST" => return handle_post_request(room_id, rpc_call, &path, auth_token).await,
|
||||
"POST" => return handle_post_request(room_id_str, rpc_call, &path, auth_token).await,
|
||||
"DELETE" => {
|
||||
let pool = get_pool_for_room(&rpc_call)?;
|
||||
return handle_delete_request(rpc_call, &path, auth_token, &pool).await;
|
||||
|
@ -408,14 +408,10 @@ async fn handle_delete_request(
|
|||
// Utilities
|
||||
|
||||
fn get_pool_for_room(rpc_call: &RpcCall) -> Result<storage::DatabaseConnectionPool, Rejection> {
|
||||
let room_id = match get_room_id(&rpc_call) {
|
||||
Some(room_id) => room_id,
|
||||
None => {
|
||||
warn!("Missing room ID.");
|
||||
return Err(warp::reject::custom(Error::InvalidRpcCall));
|
||||
}
|
||||
};
|
||||
return Ok(storage::pool_by_room_id(&room_id));
|
||||
let room_id = get_room_id(&rpc_call).ok_or(Error::ValidationFailed)?;
|
||||
return Ok(storage::pool_by_room_id(
|
||||
&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?,
|
||||
));
|
||||
}
|
||||
|
||||
fn get_auth_token(rpc_call: &RpcCall) -> Option<String> {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
|
@ -12,6 +13,30 @@ use super::errors::Error;
|
|||
pub type DatabaseConnection = r2d2::PooledConnection<SqliteConnectionManager>;
|
||||
pub type DatabaseConnectionPool = r2d2::Pool<SqliteConnectionManager>;
|
||||
|
||||
#[derive(PartialEq, Eq, Hash)]
|
||||
pub struct RoomId {
|
||||
id: String,
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
// Alphanumeric, Decimals "-" & "_" only and must be between 1 - 64 characters
|
||||
static ref REGULAR_CHARACTERS_ONLY: Regex = Regex::new(r"^[\w-]{1,64}$").unwrap();
|
||||
}
|
||||
|
||||
impl RoomId {
|
||||
pub fn new(room_id: &str) -> Option<RoomId> {
|
||||
if REGULAR_CHARACTERS_ONLY.is_match(room_id) {
|
||||
return Some(RoomId { id: room_id.to_string() });
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
// Main
|
||||
|
||||
pub const MAIN_TABLE: &str = "main";
|
||||
|
@ -63,21 +88,21 @@ lazy_static::lazy_static! {
|
|||
static ref POOLS: Mutex<HashMap<String, DatabaseConnectionPool>> = Mutex::new(HashMap::new());
|
||||
}
|
||||
|
||||
pub fn pool_by_room_id(room_id: &str) -> DatabaseConnectionPool {
|
||||
pub fn pool_by_room_id(room_id: &RoomId) -> DatabaseConnectionPool {
|
||||
let mut pools = POOLS.lock().unwrap();
|
||||
if let Some(pool) = pools.get(room_id) {
|
||||
if let Some(pool) = pools.get(room_id.get_id()) {
|
||||
return pool.clone();
|
||||
} else {
|
||||
let raw_path = format!("rooms/{}.db", room_id);
|
||||
let raw_path = format!("rooms/{}.db", room_id.get_id());
|
||||
let path = Path::new(&raw_path);
|
||||
let db_manager = r2d2_sqlite::SqliteConnectionManager::file(path);
|
||||
let pool = r2d2::Pool::new(db_manager).unwrap();
|
||||
pools.insert(room_id.to_string(), pool);
|
||||
return pools[room_id].clone();
|
||||
pools.insert(room_id.get_id().to_string(), pool);
|
||||
return pools[room_id.get_id()].clone();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_database_if_needed(room_id: &str) {
|
||||
pub fn create_database_if_needed(room_id: &RoomId) {
|
||||
let pool = pool_by_room_id(room_id);
|
||||
let conn = pool.get().unwrap();
|
||||
create_room_tables_if_needed(&conn);
|
||||
|
@ -270,7 +295,9 @@ fn get_expired_file_ids(
|
|||
Ok(rows.filter_map(|result| result.ok()).collect())
|
||||
}
|
||||
|
||||
pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, file_expiration: i64) {
|
||||
pub async fn prune_files_for_room(
|
||||
pool: &DatabaseConnectionPool, room: &RoomId, file_expiration: i64,
|
||||
) {
|
||||
let ids = get_expired_file_ids(&pool, file_expiration);
|
||||
|
||||
match ids {
|
||||
|
@ -278,7 +305,7 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil
|
|||
// Delete the files
|
||||
let futs = ids.iter().map(|id| async move {
|
||||
(
|
||||
tokio::fs::remove_file(format!("files/{}_files/{}", room, id)).await,
|
||||
tokio::fs::remove_file(format!("files/{}_files/{}", room.get_id(), id)).await,
|
||||
id.to_owned(),
|
||||
)
|
||||
});
|
||||
|
@ -289,7 +316,9 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil
|
|||
if let Err(err) = res {
|
||||
error!(
|
||||
"Couldn't delete file: {} from room: {} due to error: {}.",
|
||||
id, room, err
|
||||
id,
|
||||
room.get_id(),
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -320,7 +349,7 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil
|
|||
};
|
||||
}
|
||||
// Log the result
|
||||
info!("Pruned files for room: {}. Took: {:?}", room, now.elapsed());
|
||||
info!("Pruned files for room: {}. Took: {:?}", room.get_id(), now.elapsed());
|
||||
}
|
||||
Ok(_) => {
|
||||
// empty
|
||||
|
@ -375,7 +404,7 @@ pub fn perform_migration() {
|
|||
|
||||
// Utilities
|
||||
|
||||
fn get_all_room_ids() -> Result<Vec<String>, Error> {
|
||||
fn get_all_room_ids() -> Result<Vec<RoomId>, Error> {
|
||||
// Get a database connection
|
||||
let conn = MAIN_POOL.get().map_err(|_| Error::DatabaseFailedInternally)?;
|
||||
// Query the database
|
||||
|
@ -388,7 +417,11 @@ fn get_all_room_ids() -> Result<Vec<String>, Error> {
|
|||
return Err(Error::DatabaseFailedInternally);
|
||||
}
|
||||
};
|
||||
let ids: Vec<String> = rows.filter_map(|result| result.ok()).collect();
|
||||
let room_ids: Vec<_> = rows
|
||||
.filter_map(|result: Result<String, _>| result.ok())
|
||||
.map(|opt| RoomId::new(&opt))
|
||||
.flatten()
|
||||
.collect();
|
||||
// Return
|
||||
return Ok(ids);
|
||||
return Ok(room_ids);
|
||||
}
|
||||
|
|
22
src/tests.rs
22
src/tests.rs
|
@ -75,12 +75,12 @@ async fn test_file_handling() {
|
|||
// Ensure the test room is set up and get a database connection pool
|
||||
let pool = set_up_test_room().await;
|
||||
|
||||
let test_room_id = "test_room";
|
||||
let test_room_id = storage::RoomId::new("test_room").unwrap();
|
||||
// Get an auth token
|
||||
let (auth_token, _) = get_auth_token(&pool);
|
||||
// Store the test file
|
||||
handlers::store_file(
|
||||
Some(test_room_id.to_string()),
|
||||
Some(test_room_id.get_id().to_string()),
|
||||
TEST_FILE,
|
||||
Some(auth_token.clone()),
|
||||
&pool,
|
||||
|
@ -94,17 +94,21 @@ async fn test_file_handling() {
|
|||
conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
|
||||
let id = id_as_string.parse::<u64>().unwrap();
|
||||
// Retrieve the file and check the content
|
||||
let base64_encoded_file =
|
||||
handlers::get_file(Some(test_room_id.to_string()), id, Some(auth_token.clone()), &pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.result;
|
||||
let base64_encoded_file = handlers::get_file(
|
||||
Some(test_room_id.get_id().to_string()),
|
||||
id,
|
||||
Some(auth_token.clone()),
|
||||
&pool,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.result;
|
||||
assert_eq!(base64_encoded_file, TEST_FILE);
|
||||
// Prune the file and check that it's gone
|
||||
// Will evaluate to now + 60
|
||||
storage::prune_files_for_room(&pool, test_room_id, -60).await;
|
||||
storage::prune_files_for_room(&pool, &test_room_id, -60).await;
|
||||
// It should be gone now
|
||||
fs::read(format!("files/{}_files/{}", test_room_id, id)).unwrap_err();
|
||||
fs::read(format!("files/{}_files/{}", test_room_id.get_id(), id)).unwrap_err();
|
||||
// Check that the file record is also gone
|
||||
let conn = pool.get().unwrap();
|
||||
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
|
||||
|
|
Loading…
Reference in New Issue