Fix room ID vs room name usage

This commit is contained in:
Niels Andriesse 2021-03-23 15:39:42 +11:00
parent 0a54da8079
commit 786f7293d5
4 changed files with 26 additions and 51 deletions

View File

@ -43,7 +43,7 @@ pub async fn create_room(id: &str, name: &str) -> Result<Response, Rejection> {
}
}
// Set up the database
storage::create_database_if_needed(name);
storage::create_database_if_needed(id);
// Return
let json = models::StatusCode { status_code : StatusCode::OK.as_u16() };
return Ok(warp::reply::json(&json).into_response());

View File

@ -30,7 +30,7 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
return Err(warp::reject::custom(Error::InvalidRpcCall))
}
};
let pool = storage::pool_by_room_id(&room_id)?;
let pool = storage::pool_by_room_id(&room_id);
// Check that the endpoint is a valid URI
let uri = match rpc_call.endpoint.parse::<http::Uri>() {
Ok(uri) => uri,

View File

@ -58,46 +58,22 @@ lazy_static::lazy_static! {
static ref POOLS: Mutex<HashMap<String, DatabaseConnectionPool>> = Mutex::new(HashMap::new());
}
pub fn pool_by_room_id(room_id: &str) -> Result<DatabaseConnectionPool, Error> {
// Get a database connection
let conn = MAIN_POOL.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database
let raw_query = format!("SELECT name FROM {} WHERE id = (?1)", MAIN_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![ room_id ], |row| {
Ok(row.get(0)?)
}) {
Ok(rows) => rows,
Err(e) => {
println!("Couldn't query database due to error: {}.", e);
return Err(Error::DatabaseFailedInternally);
}
};
let names: Vec<String> = rows.filter_map(|result| result.ok()).collect();
// Return
if let Some(name) = names.first() {
return Ok(pool_by_room_name(name));
} else {
return Err(Error::DatabaseFailedInternally);
}
}
pub fn pool_by_room_name(room: &str) -> DatabaseConnectionPool {
pub fn pool_by_room_id(room_id: &str) -> DatabaseConnectionPool {
let mut pools = POOLS.lock().unwrap();
if let Some(pool) = pools.get(room) {
if let Some(pool) = pools.get(room_id) {
return pool.clone();
} else {
let raw_path = format!("rooms/{}.db", room);
let raw_path = format!("rooms/{}.db", room_id);
let path = Path::new(&raw_path);
let db_manager = r2d2_sqlite::SqliteConnectionManager::file(path);
let pool = r2d2::Pool::new(db_manager).unwrap();
pools.insert(room.to_string(), pool);
return pools[room].clone();
pools.insert(room_id.to_string(), pool);
return pools[room_id].clone();
}
}
pub fn create_database_if_needed(room: &str) {
let pool = pool_by_room_name(room);
pub fn create_database_if_needed(room_id: &str) {
let pool = pool_by_room_id(room_id);
let conn = pool.get().unwrap();
create_room_tables_if_needed(&conn);
}
@ -186,12 +162,12 @@ pub async fn prune_files_periodically() {
}
async fn prune_tokens() {
let rooms = match get_all_rooms().await {
let rooms = match get_all_room_ids().await {
Ok(rooms) => rooms,
Err(_) => return
};
for room in rooms {
let pool = pool_by_room_name(&room);
let pool = pool_by_room_id(&room);
// It's not catastrophic if we fail to prune the database for a given room
let conn = match pool.get() {
Ok(conn) => conn,
@ -209,12 +185,12 @@ async fn prune_tokens() {
}
async fn prune_pending_tokens() {
let rooms = match get_all_rooms().await {
let rooms = match get_all_room_ids().await {
Ok(rooms) => rooms,
Err(_) => return
};
for room in rooms {
let pool = pool_by_room_name(&room);
let pool = pool_by_room_id(&room);
// It's not catastrophic if we fail to prune the database for a given room
let conn = match pool.get() {
Ok(conn) => conn,
@ -232,13 +208,13 @@ async fn prune_pending_tokens() {
}
pub async fn prune_files(file_expiration: i64) { // The expiration setting is passed in for testing purposes
let rooms = match get_all_rooms().await {
let rooms = match get_all_room_ids().await {
Ok(rooms) => rooms,
Err(_) => return
};
for room in rooms {
// It's not catastrophic if we fail to prune the database for a given room
let pool = pool_by_room_name(&room);
let pool = pool_by_room_id(&room);
let now = chrono::Utc::now().timestamp();
let expiration = now - file_expiration;
// Get a database connection and open a transaction
@ -279,7 +255,7 @@ pub async fn prune_files(file_expiration: i64) { // The expiration setting is pa
println!("Pruned files.");
}
async fn get_all_rooms() -> Result<Vec<String>, Error> {
async fn get_all_room_ids() -> Result<Vec<String>, Error> {
// Get a database connection
let conn = MAIN_POOL.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database
@ -294,7 +270,7 @@ async fn get_all_rooms() -> Result<Vec<String>, Error> {
return Err(Error::DatabaseFailedInternally);
}
};
let names: Vec<String> = rows.filter_map(|result| result.ok()).collect();
let ids: Vec<String> = rows.filter_map(|result| result.ok()).collect();
// Return
return Ok(names);
return Ok(ids);
}

View File

@ -23,22 +23,20 @@ fn perform_main_setup() {
fn set_up_test_room() {
perform_main_setup();
let test_room = "test_room";
storage::create_database_if_needed(test_room);
let raw_path = format!("rooms/{}.db", test_room);
let test_room_id = "test_room";
let test_room_name = "Test Room";
aw!(handlers::create_room(&test_room_id, &test_room_name)).unwrap();
let raw_path = format!("rooms/{}.db", test_room_id);
let path = Path::new(&raw_path);
fs::read(path).unwrap(); // Fail if this doesn't exist
let pool: &storage::DatabaseConnectionPool = &storage::MAIN_POOL;
let conn = pool.get().unwrap();
let stmt = format!("REPLACE INTO {} (id, name) VALUES (?1, ?2)", storage::MAIN_TABLE);
conn.execute(&stmt, params![ test_room, "Test Room" ]).unwrap();
}
#[test]
fn test_authorization() {
// Ensure the test room is set up and get a database connection pool
set_up_test_room();
let pool = storage::pool_by_room_name("test_room");
let test_room_id = "test_room";
let pool = storage::pool_by_room_id(&test_room_id);
// Generate a fake user key pair
let (user_private_key, user_public_key) = aw!(crypto::generate_x25519_key_pair());
let hex_user_public_key = format!("05{}", hex::encode(user_public_key.to_bytes()));
@ -68,7 +66,8 @@ fn test_authorization() {
fn test_file_handling() {
// Ensure the test room is set up and get a database connection pool
set_up_test_room();
let pool = storage::pool_by_room_name("test_room");
let test_room_id = "test_room";
let pool = storage::pool_by_room_id(&test_room_id);
// Store the test file
aw!(handlers::store_file(TEST_FILE, &pool)).unwrap();
// Check that there's a file record