From 304df4e1255ef11998f1ad16808e808cc16d1fcc Mon Sep 17 00:00:00 2001 From: Niels Andriesse Date: Thu, 18 Mar 2021 15:53:24 +1100 Subject: [PATCH] Get room from RPC call --- src/rpc.rs | 25 ++++++++++++++++++++----- src/storage.rs | 32 ++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/src/rpc.rs b/src/rpc.rs index 3cd8bd8..da762a4 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -23,8 +23,11 @@ pub struct QueryOptions { pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result { // Get a connection pool for the given room - let room = "main"; // TODO: Get room from RPC call - let pool = storage::pool(room); + let room_id = match get_room_id(&rpc_call) { + Some(room_id) => room_id, + None => return Err(warp::reject::custom(Error::InvalidRpcCall)) + }; + let pool = storage::pool_by_room_id(room_id)?; // Check that the endpoint is a valid URI let uri = match rpc_call.endpoint.parse::() { Ok(uri) => uri, @@ -164,7 +167,19 @@ fn get_auth_token(rpc_call: &RpcCall) -> Option { Ok(headers) => headers, Err(_) => return None }; - let header = headers.get("Authorization"); - if header == None { return None; } - return header.unwrap().strip_prefix("Bearer").map(|s| s.to_string()).or(None); + let header = headers.get("Authorization")?; + return header.strip_prefix("Bearer").map(|s| s.to_string()).or(None); +} + +fn get_room_id(rpc_call: &RpcCall) -> Option { + if rpc_call.headers.is_empty() { return None; } + let headers: HashMap = match serde_json::from_str(&rpc_call.headers) { + Ok(headers) => headers, + Err(_) => return None + }; + let header = headers.get("Room")?; + match header.parse() { + Ok(room_id) => return Some(room_id), + Err(_) => return None + }; } \ No newline at end of file diff --git a/src/storage.rs b/src/storage.rs index e6d4bc4..09c71b0 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -54,7 +54,30 @@ lazy_static::lazy_static! { static ref POOLS: Mutex> = Mutex::new(HashMap::new()); } -pub fn pool(room: &str) -> DatabaseConnectionPool { +pub fn pool_by_room_id(room_id: usize) -> Result { + // Get a database connection + let conn = MAIN_POOL.get().map_err(|_| Error::DatabaseFailedInternally)?; + // Query the database + let raw_query = format!("SELECT name FROM {} WHERE id = (?1)", MAIN_TABLE); + let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; + let rows = match query.query_map(params![ room_id ], |row| { + Ok(row.get(0)?) + }) { + Ok(rows) => rows, + Err(e) => { + println!("Couldn't query database due to error: {}.", e); + return Err(Error::DatabaseFailedInternally); + } + }; + let names: Vec = rows.filter_map(|result| result.ok()).collect(); + if let Some(name) = names.first() { + return Ok(pool_by_room_name(name)); + } else { + return Err(Error::DatabaseFailedInternally); + } +} + +pub fn pool_by_room_name(room: &str) -> DatabaseConnectionPool { let mut pools = POOLS.lock().unwrap(); if let Some(pool) = pools.get(room) { return pool.clone(); @@ -68,7 +91,7 @@ pub fn pool(room: &str) -> DatabaseConnectionPool { } pub fn create_database_if_needed(room: &str) { - let pool = pool(room); + let pool = pool_by_room_name(room); let conn = pool.get().unwrap(); create_room_tables_if_needed(&conn); } @@ -146,7 +169,7 @@ async fn prune_tokens() { Err(_) => return }; for room in rooms { - let pool = pool(&room); + let pool = pool_by_room_name(&room); // It's not catastrophic if we fail to prune the database for a given room let mut conn = match pool.get() { Ok(conn) => conn, @@ -177,7 +200,7 @@ async fn prune_pending_tokens() { Err(_) => return }; for room in rooms { - let pool = pool(&room); + let pool = pool_by_room_name(&room); // It's not catastrophic if we fail to prune the database for a given room let mut conn = match pool.get() { Ok(conn) => conn, @@ -218,5 +241,6 @@ async fn get_all_rooms() -> Result, Error> { } }; let names: Vec = rows.filter_map(|result| result.ok()).collect(); + // Return return Ok(names); } \ No newline at end of file