Merge pull request #20 from darcys22/room-id-type

Replaces room_id variables with typed versions
This commit is contained in:
Sean 2021-09-14 12:07:14 +10:00 committed by GitHub
commit 0b897052ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 41 deletions

3
Cargo.lock generated
View File

@ -1,5 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.14.1"
@ -1764,6 +1766,7 @@ dependencies = [
"r2d2_sqlite",
"rand 0.8.3",
"rand_core 0.5.1",
"regex",
"reqwest",
"rusqlite",
"rusqlite_migration",

View File

@ -27,6 +27,7 @@ log4rs = "1.0"
octocrab = "0.9"
rand = "0.8"
rand_core = "0.5"
regex = "1"
reqwest = { version = "0.11", features = ["json"] }
rusqlite = { version = "0.24", features = ["bundled"] }
rusqlite_migration = "0.4"

View File

@ -52,7 +52,9 @@ pub async fn create_room(room: models::Room) -> Result<Response, Rejection> {
}
}
// Set up the database
storage::create_database_if_needed(&room.id);
storage::create_database_if_needed(
&storage::RoomId::new(&room.id).ok_or(Error::ValidationFailed)?,
);
// Return
info!("Added room with ID: {}", &room.id);
let json = models::StatusCode { status_code: StatusCode::OK.as_u16() };
@ -743,7 +745,8 @@ pub fn get_deleted_messages(
pub async fn add_moderator_public(
body: models::ChangeModeratorRequestBody, auth_token: &str,
) -> Result<Response, Rejection> {
let pool = storage::pool_by_room_id(&body.room_id);
let room_id = storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?;
let pool = storage::pool_by_room_id(&room_id);
let (has_authorization_level, _) =
has_authorization_level(auth_token, AuthorizationLevel::Moderator, &pool)?;
if !has_authorization_level {
@ -757,7 +760,9 @@ pub async fn add_moderator(
body: models::ChangeModeratorRequestBody,
) -> Result<Response, Rejection> {
// Get a database connection
let pool = storage::pool_by_room_id(&body.room_id);
let pool = storage::pool_by_room_id(
&storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?,
);
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the moderator
let stmt = "INSERT INTO moderators (public_key) VALUES (?1)";
@ -777,7 +782,9 @@ pub async fn add_moderator(
pub async fn delete_moderator_public(
body: models::ChangeModeratorRequestBody, auth_token: &str,
) -> Result<Response, Rejection> {
let pool = storage::pool_by_room_id(&body.room_id);
let pool = storage::pool_by_room_id(
&storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?,
);
let (has_authorization_level, _) =
has_authorization_level(auth_token, AuthorizationLevel::Moderator, &pool)?;
if !has_authorization_level {
@ -791,7 +798,9 @@ pub async fn delete_moderator(
body: models::ChangeModeratorRequestBody,
) -> Result<Response, Rejection> {
// Get a database connection
let pool = storage::pool_by_room_id(&body.room_id);
let pool = storage::pool_by_room_id(
&storage::RoomId::new(&body.room_id).ok_or(Error::ValidationFailed)?,
);
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the moderator
let stmt = "DELETE FROM moderators WHERE public_key = (?1)";
@ -1023,7 +1032,9 @@ pub fn compact_poll(
}
};
// Get the database connection pool
let pool = storage::pool_by_room_id(&room_id);
let pool = storage::pool_by_room_id(
&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?,
);
// Get the new messages
let mut get_messages_query_params: HashMap<String, String> = HashMap::new();
if let Some(from_message_server_id) = from_message_server_id {
@ -1132,7 +1143,7 @@ pub async fn get_session_version(platform: &str) -> Result<String, Rejection> {
// not publicly exposed.
pub async fn get_stats_for_room(
room: String, query_map: HashMap<String, i64>,
room_id: String, query_map: HashMap<String, i64>,
) -> Result<Response, Rejection> {
let now = chrono::Utc::now().timestamp();
let window = match query_map.get("window") {
@ -1146,7 +1157,8 @@ pub async fn get_stats_for_room(
};
let lowerbound = upperbound - window;
let pool = storage::pool_by_room_id(&room);
let pool =
storage::pool_by_room_id(&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?);
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let raw_query_users =

View File

@ -49,13 +49,13 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
// Get the auth token if possible
let auth_token = get_auth_token(&rpc_call);
// Get the room ID
let room_id = get_room_id(&rpc_call);
let room_id_str = get_room_id(&rpc_call);
// Switch on the HTTP method
match rpc_call.method.as_ref() {
"GET" => {
return handle_get_request(room_id, rpc_call, &path, auth_token, query_params).await
return handle_get_request(room_id_str, rpc_call, &path, auth_token, query_params).await
}
"POST" => return handle_post_request(room_id, rpc_call, &path, auth_token).await,
"POST" => return handle_post_request(room_id_str, rpc_call, &path, auth_token).await,
"DELETE" => {
let pool = get_pool_for_room(&rpc_call)?;
return handle_delete_request(rpc_call, &path, auth_token, &pool).await;
@ -408,14 +408,10 @@ async fn handle_delete_request(
// Utilities
fn get_pool_for_room(rpc_call: &RpcCall) -> Result<storage::DatabaseConnectionPool, Rejection> {
let room_id = match get_room_id(&rpc_call) {
Some(room_id) => room_id,
None => {
warn!("Missing room ID.");
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
return Ok(storage::pool_by_room_id(&room_id));
let room_id = get_room_id(&rpc_call).ok_or(Error::ValidationFailed)?;
return Ok(storage::pool_by_room_id(
&storage::RoomId::new(&room_id).ok_or(Error::ValidationFailed)?,
));
}
fn get_auth_token(rpc_call: &RpcCall) -> Option<String> {

View File

@ -1,3 +1,4 @@
use regex::Regex;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;
@ -12,6 +13,30 @@ use super::errors::Error;
pub type DatabaseConnection = r2d2::PooledConnection<SqliteConnectionManager>;
pub type DatabaseConnectionPool = r2d2::Pool<SqliteConnectionManager>;
#[derive(PartialEq, Eq, Hash)]
pub struct RoomId {
id: String,
}
lazy_static::lazy_static! {
// Alphanumeric, Decimals "-" & "_" only and must be between 1 - 64 characters
static ref REGULAR_CHARACTERS_ONLY: Regex = Regex::new(r"^[\w-]{1,64}$").unwrap();
}
impl RoomId {
pub fn new(room_id: &str) -> Option<RoomId> {
if REGULAR_CHARACTERS_ONLY.is_match(room_id) {
return Some(RoomId { id: room_id.to_string() });
} else {
return None;
}
}
pub fn get_id(&self) -> &str {
&self.id
}
}
// Main
lazy_static::lazy_static! {
@ -49,21 +74,21 @@ lazy_static::lazy_static! {
static ref POOLS: Mutex<HashMap<String, DatabaseConnectionPool>> = Mutex::new(HashMap::new());
}
pub fn pool_by_room_id(room_id: &str) -> DatabaseConnectionPool {
pub fn pool_by_room_id(room_id: &RoomId) -> DatabaseConnectionPool {
let mut pools = POOLS.lock().unwrap();
if let Some(pool) = pools.get(room_id) {
if let Some(pool) = pools.get(room_id.get_id()) {
return pool.clone();
} else {
let raw_path = format!("rooms/{}.db", room_id);
let raw_path = format!("rooms/{}.db", room_id.get_id());
let path = Path::new(&raw_path);
let db_manager = r2d2_sqlite::SqliteConnectionManager::file(path);
let pool = r2d2::Pool::new(db_manager).unwrap();
pools.insert(room_id.to_string(), pool);
return pools[room_id].clone();
pools.insert(room_id.get_id().to_string(), pool);
return pools[room_id.get_id()].clone();
}
}
pub fn create_database_if_needed(room_id: &str) {
pub fn create_database_if_needed(room_id: &RoomId) {
let pool = pool_by_room_id(room_id);
let conn = pool.get().unwrap();
create_room_tables_if_needed(&conn);
@ -232,7 +257,9 @@ fn get_expired_file_ids(
Ok(rows.filter_map(|result| result.ok()).collect())
}
pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, file_expiration: i64) {
pub async fn prune_files_for_room(
pool: &DatabaseConnectionPool, room: &RoomId, file_expiration: i64,
) {
let ids = get_expired_file_ids(&pool, file_expiration);
match ids {
@ -240,7 +267,7 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil
// Delete the files
let futs = ids.iter().map(|id| async move {
(
tokio::fs::remove_file(format!("files/{}_files/{}", room, id)).await,
tokio::fs::remove_file(format!("files/{}_files/{}", room.get_id(), id)).await,
id.to_owned(),
)
});
@ -251,7 +278,9 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil
if let Err(err) = res {
error!(
"Couldn't delete file: {} from room: {} due to error: {}.",
id, room, err
id,
room.get_id(),
err
);
}
}
@ -282,7 +311,7 @@ pub async fn prune_files_for_room(pool: &DatabaseConnectionPool, room: &str, fil
};
}
// Log the result
info!("Pruned files for room: {}. Took: {:?}", room, now.elapsed());
info!("Pruned files for room: {}. Took: {:?}", room.get_id(), now.elapsed());
}
Ok(_) => {
// empty
@ -334,7 +363,7 @@ pub fn perform_migration() {
// Utilities
fn get_all_room_ids() -> Result<Vec<String>, Error> {
fn get_all_room_ids() -> Result<Vec<RoomId>, Error> {
// Get a database connection
let conn = MAIN_POOL.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database
@ -347,7 +376,11 @@ fn get_all_room_ids() -> Result<Vec<String>, Error> {
return Err(Error::DatabaseFailedInternally);
}
};
let ids: Vec<String> = rows.filter_map(|result| result.ok()).collect();
let room_ids: Vec<_> = rows
.filter_map(|result: Result<String, _>| result.ok())
.map(|opt| RoomId::new(&opt))
.flatten()
.collect();
// Return
return Ok(ids);
return Ok(room_ids);
}

View File

@ -75,12 +75,12 @@ async fn test_file_handling() {
// Ensure the test room is set up and get a database connection pool
let pool = set_up_test_room().await;
let test_room_id = "test_room";
let test_room_id = storage::RoomId::new("test_room").unwrap();
// Get an auth token
let (auth_token, _) = get_auth_token(&pool);
// Store the test file
handlers::store_file(
Some(test_room_id.to_string()),
Some(test_room_id.get_id().to_string()),
TEST_FILE,
Some(auth_token.clone()),
&pool,
@ -94,17 +94,21 @@ async fn test_file_handling() {
conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
let id = id_as_string.parse::<u64>().unwrap();
// Retrieve the file and check the content
let base64_encoded_file =
handlers::get_file(Some(test_room_id.to_string()), id, Some(auth_token.clone()), &pool)
.await
.unwrap()
.result;
let base64_encoded_file = handlers::get_file(
Some(test_room_id.get_id().to_string()),
id,
Some(auth_token.clone()),
&pool,
)
.await
.unwrap()
.result;
assert_eq!(base64_encoded_file, TEST_FILE);
// Prune the file and check that it's gone
// Will evaluate to now + 60
storage::prune_files_for_room(&pool, test_room_id, -60).await;
storage::prune_files_for_room(&pool, &test_room_id, -60).await;
// It should be gone now
fs::read(format!("files/{}_files/{}", test_room_id, id)).unwrap_err();
fs::read(format!("files/{}_files/{}", test_room_id.get_id(), id)).unwrap_err();
// Check that the file record is also gone
let conn = pool.get().unwrap();
let raw_query = "SELECT id FROM files";