Fix database interaction

This commit is contained in:
Niels Andriesse 2021-03-19 13:26:53 +11:00
parent 07114fad8b
commit 9c1bd0242c
3 changed files with 74 additions and 127 deletions

View File

@ -38,17 +38,15 @@ pub async fn store_file(base64_encoded_bytes: &str, pool: &storage:: DatabaseCon
// We do this * before * storing the actual file, so that in case something goes
// wrong we're not left with files that'll never be pruned.
let now = chrono::Utc::now().timestamp();
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let stmt = format!("INSERT INTO {} (id, timestamp) VALUES (?1, ?2)", storage::FILES_TABLE);
let _ = match tx.execute(&stmt, params![ id, now ]) {
let _ = match conn.execute(&stmt, params![ id, now ]) {
Ok(rows) => rows,
Err(e) => {
println!("Couldn't insert file record due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
tx.commit().map_err(|_| Error::DatabaseFailedInternally)?;
// Write to file
let mut pos = 0;
let mut buffer = match fs::File::create(format!("files/{}", &id)) {
@ -114,19 +112,15 @@ pub async fn get_auth_token_challenge(hex_public_key: &str, pool: &storage::Data
thread_rng().fill(&mut token[..]);
// Store the (pending) token
// Note that a given public key can have multiple pending tokens
{
let now = chrono::Utc::now().timestamp();
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
let stmt = format!("INSERT INTO {} (public_key, timestamp, token) VALUES (?1, ?2, ?3)", storage::PENDING_TOKENS_TABLE);
let _ = match tx.execute(&stmt, params![ hex_public_key, now, token.to_vec() ]) {
Ok(rows) => rows,
Err(e) => {
println!("Couldn't insert pending token due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
tx.commit().map_err(|_| Error::DatabaseFailedInternally)?;
let now = chrono::Utc::now().timestamp();
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let stmt = format!("INSERT INTO {} (public_key, timestamp, token) VALUES (?1, ?2, ?3)", storage::PENDING_TOKENS_TABLE);
let _ = match conn.execute(&stmt, params![ hex_public_key, now, token.to_vec() ]) {
Ok(rows) => rows,
Err(e) => {
println!("Couldn't insert pending token due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
// Encrypt the token with the symmetric key
let ciphertext = crypto::encrypt_aes_gcm(&token, &symmetric_key).await?;
@ -152,47 +146,42 @@ pub async fn claim_auth_token(public_key: &str, token: Option<String>, pool: &st
println!("Ignoring claim token request for invalid token.");
return Err(warp::reject::custom(Error::ValidationFailed));
}
// Get a database connection and open a transaction
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
// Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Get the pending tokens for the given public key
let pending_tokens: Vec<(i64, Vec<u8>)> = {
let raw_query = format!("SELECT timestamp, token FROM {} WHERE public_key = (?1) AND timestamp > (?2)", storage::PENDING_TOKENS_TABLE);
let mut query = tx.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let now = chrono::Utc::now().timestamp();
let expiration = now - storage::PENDING_TOKEN_EXPIRATION;
let rows = match query.query_map(params![ public_key, expiration ], |row| {
Ok((row.get(0)?, row.get(1)?))
}) {
Ok(rows) => rows,
Err(e) => {
println!("Couldn't get pending tokens due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
rows.filter_map(|result| result.ok()).collect()
let raw_query = format!("SELECT timestamp, token FROM {} WHERE public_key = (?1) AND timestamp > (?2)", storage::PENDING_TOKENS_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let now = chrono::Utc::now().timestamp();
let expiration = now - storage::PENDING_TOKEN_EXPIRATION;
let rows = match query.query_map(params![ public_key, expiration ], |row| {
Ok((row.get(0)?, row.get(1)?))
}) {
Ok(rows) => rows,
Err(e) => {
println!("Couldn't get pending tokens due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
let pending_tokens: Vec<(i64, Vec<u8>)> = rows.filter_map(|result| result.ok()).collect();
// Check that the token being claimed is in fact one of the pending tokens
let claim = hex::decode(token).unwrap(); // Safe because we validated it above
let index = pending_tokens.iter().position(|(_, pending_token)| *pending_token == claim).ok_or_else(|| Error::Unauthorized)?;
let token = &pending_tokens[index].1;
// Delete all pending tokens for the given public key
let stmt = format!("DELETE FROM {} WHERE public_key = (?1)", storage::PENDING_TOKENS_TABLE);
match tx.execute(&stmt, params![ public_key ]) {
Ok(_) => (),
Err(e) => println!("Couldn't delete pending tokens due to error: {}.", e) // It's not catastrophic if this fails
};
// Store the claimed token
let stmt = format!("INSERT OR REPLACE INTO {} (public_key, token) VALUES (?1, ?2)", storage::TOKENS_TABLE);
match tx.execute(&stmt, params![ public_key, hex::encode(token) ]) {
match conn.execute(&stmt, params![ public_key, hex::encode(token) ]) {
Ok(_) => (),
Err(e) => {
println!("Couldn't insert token due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
}
// Commit
tx.commit().map_err(|_| Error::DatabaseFailedInternally)?;
// Delete all pending tokens for the given public key
let stmt = format!("DELETE FROM {} WHERE public_key = (?1)", storage::PENDING_TOKENS_TABLE);
match conn.execute(&stmt, params![ public_key ]) {
Ok(_) => (),
Err(e) => println!("Couldn't delete pending tokens due to error: {}.", e) // It's not catastrophic if this fails
};
// Return
return Ok(StatusCode::OK.into_response());
}
@ -201,20 +190,17 @@ pub async fn delete_auth_token(auth_token: Option<String>, pool: &storage::Datab
// Check authorization level
let (has_authorization_level, requesting_public_key) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); }
// Get a connection and open a transaction
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
// Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Delete the token
let stmt = format!("DELETE FROM {} WHERE public_key = (?1)", storage::TOKENS_TABLE);
match tx.execute(&stmt, params![ requesting_public_key ]) {
match conn.execute(&stmt, params![ requesting_public_key ]) {
Ok(_) => (),
Err(e) => {
println!("Couldn't delete token due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
// Commit
tx.commit().map_err(|_| Error::DatabaseFailedInternally)?;
// Return
return Ok(StatusCode::OK.into_response());
}
@ -384,20 +370,17 @@ pub async fn ban(public_key: &str, auth_token: Option<String>, pool: &storage::D
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); }
// Don't double ban public keys
if is_banned(&public_key, pool).await? { return Ok(StatusCode::OK.into_response()); }
// Get a connection and open a transaction
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
// Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the message
let stmt = format!("INSERT INTO {} (public_key) VALUES (?1)", storage::BLOCK_LIST_TABLE);
match tx.execute(&stmt, params![ public_key ]) {
match conn.execute(&stmt, params![ public_key ]) {
Ok(_) => (),
Err(e) => {
println!("Couldn't ban public key due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
// Commit
tx.commit().map_err(|_| Error::DatabaseFailedInternally)?;
// Return
return Ok(StatusCode::OK.into_response());
}
@ -414,20 +397,17 @@ pub async fn unban(public_key: &str, auth_token: Option<String>, pool: &storage:
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); }
// Don't double unban public keys
if !is_banned(&public_key, pool).await? { return Ok(StatusCode::OK.into_response()); }
// Get a connection and open a transaction
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
// Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the message
let stmt = format!("DELETE FROM {} WHERE public_key = (?1)", storage::BLOCK_LIST_TABLE);
match tx.execute(&stmt, params![ public_key ]) {
match conn.execute(&stmt, params![ public_key ]) {
Ok(_) => (),
Err(e) => {
println!("Couldn't unban public key due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
// Commit
tx.commit().map_err(|_| Error::DatabaseFailedInternally)?;
// Return
return Ok(StatusCode::OK.into_response());
}

View File

@ -73,6 +73,7 @@ pub fn pool_by_room_id(room_id: isize) -> Result<DatabaseConnectionPool, Error>
}
};
let names: Vec<String> = rows.filter_map(|result| result.ok()).collect();
// Return
if let Some(name) = names.first() {
return Ok(pool_by_room_name(name));
} else {
@ -189,22 +190,14 @@ async fn prune_tokens() {
for room in rooms {
let pool = pool_by_room_name(&room);
// It's not catastrophic if we fail to prune the database for a given room
let mut conn = match pool.get() {
let conn = match pool.get() {
Ok(conn) => conn,
Err(e) => return println!("Couldn't prune tokens due to error: {}.", e)
};
let tx = match conn.transaction() {
Ok(tx) => tx,
Err(e) => return println!("Couldn't prune tokens due to error: {}.", e)
};
let stmt = format!("DELETE FROM {} WHERE timestamp < (?1)", TOKENS_TABLE);
let now = chrono::Utc::now().timestamp();
let expiration = now - TOKEN_EXPIRATION;
match tx.execute(&stmt, params![ expiration ]) {
Ok(_) => (),
Err(e) => return println!("Couldn't prune tokens due to error: {}.", e)
};
match tx.commit() {
match conn.execute(&stmt, params![ expiration ]) {
Ok(_) => (),
Err(e) => return println!("Couldn't prune tokens due to error: {}.", e)
};
@ -220,22 +213,14 @@ async fn prune_pending_tokens() {
for room in rooms {
let pool = pool_by_room_name(&room);
// It's not catastrophic if we fail to prune the database for a given room
let mut conn = match pool.get() {
let conn = match pool.get() {
Ok(conn) => conn,
Err(e) => return println!("Couldn't prune pending tokens due to error: {}.", e)
};
let tx = match conn.transaction() {
Ok(tx) => tx,
Err(e) => return println!("Couldn't prune pending tokens due to error: {}.", e)
};
let stmt = format!("DELETE FROM {} WHERE timestamp < (?1)", PENDING_TOKENS_TABLE);
let now = chrono::Utc::now().timestamp();
let expiration = now - PENDING_TOKEN_EXPIRATION;
match tx.execute(&stmt, params![ expiration ]) {
Ok(_) => (),
Err(e) => return println!("Couldn't prune pending tokens due to error: {}.", e)
};
match tx.commit() {
match conn.execute(&stmt, params![ expiration ]) {
Ok(_) => (),
Err(e) => return println!("Couldn't prune pending tokens due to error: {}.", e)
};
@ -250,36 +235,29 @@ pub async fn prune_files(file_expiration: i64) { // The expiration setting is pa
};
for room in rooms {
// It's not catastrophic if we fail to prune the database for a given room
println!("room name: {}", room);
let pool = pool_by_room_name(&room);
let now = chrono::Utc::now().timestamp();
let expiration = now - file_expiration;
// Get a database connection and open a transaction
let mut conn = match pool.get() {
let conn = match pool.get() {
Ok(conn) => conn,
Err(e) => return println!("Couldn't prune files due to error: {}.", e)
};
let tx = match conn.transaction() {
Ok(tx) => tx,
// Get the IDs of the files to delete
let raw_query = format!("SELECT id FROM {} WHERE timestamp < (?1)", FILES_TABLE);
let mut query = match conn.prepare(&raw_query) {
Ok(query) => query,
Err(e) => return println!("Couldn't prune files due to error: {}.", e)
};
// Get the IDs of the files to delete
let ids: Vec<String> = {
let raw_query = format!("SELECT id FROM {} WHERE timestamp < (?1)", FILES_TABLE);
let mut query = match tx.prepare(&raw_query) {
Ok(query) => query,
Err(e) => return println!("Couldn't prune files due to error: {}.", e)
};
let rows = match query.query_map(params![ expiration ], |row| {
Ok(row.get(0)?)
}) {
Ok(rows) => rows,
Err(e) => {
return println!("Couldn't prune files due to error: {}.", e);
}
};
rows.filter_map(|result| result.ok()).collect()
let rows = match query.query_map(params![ expiration ], |row| {
Ok(row.get(0)?)
}) {
Ok(rows) => rows,
Err(e) => {
return println!("Couldn't prune files due to error: {}.", e);
}
};
let ids: Vec<String> = rows.filter_map(|result| result.ok()).collect();
// Delete the files
let mut deleted_ids: Vec<String> = vec![];
for id in ids {
@ -288,13 +266,9 @@ pub async fn prune_files(file_expiration: i64) { // The expiration setting is pa
Err(e) => println!("Couldn't delete file due to error: {}.", e)
}
}
// Remove the file records from the database (only for the files that were actually deleted)
// Remove the file records from the database (only for the files that were successfully deleted)
let stmt = format!("DELETE FROM {} WHERE id IN (?1)", FILES_TABLE);
match tx.execute(&stmt, deleted_ids) {
Ok(_) => (),
Err(e) => return println!("Couldn't prune files due to error: {}.", e)
};
match tx.commit() {
match conn.execute(&stmt, deleted_ids) {
Ok(_) => (),
Err(e) => return println!("Couldn't prune files due to error: {}.", e)
};

View File

@ -27,31 +27,24 @@ fn set_up_test_room() {
storage::create_database_if_needed(test_room);
fs::read("rooms/test_room.db").unwrap(); // Fail if this doesn't exist
let pool: &storage::DatabaseConnectionPool = &storage::MAIN_POOL;
let mut conn = pool.get().unwrap();
let tx = conn.transaction().unwrap();
let conn = pool.get().unwrap();
let stmt = format!("REPLACE INTO {} (id, name) VALUES (?1, ?2)", storage::MAIN_TABLE);
tx.execute(&stmt, params![ test_room, "Test Room" ]).unwrap();
tx.commit().unwrap();
conn.execute(&stmt, params![ test_room, "Test Room" ]).unwrap();
}
#[test]
fn test_file_handling() {
// Ensure the test room is set up
set_up_test_room();
// Test file storage
// Store the test file
let pool = storage::pool_by_room_name("test_room");
aw!(handlers::store_file(TEST_FILE, &pool)).unwrap();
// Check that there's a file record
let mut conn = pool.get().unwrap();
let tx = conn.transaction().unwrap();
let conn = pool.get().unwrap();
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
let mut query = tx.prepare(&raw_query).unwrap();
let rows = query.query_map(params![], |row| { Ok(row.get(0)?) }).unwrap();
let ids: Vec<String> = rows.filter_map(|result| result.ok()).collect();
assert_eq!(ids.len(), 1);
let id = ids.first().unwrap();
let id: String = conn.query_row(&raw_query, params![], |row| { Ok(row.get(0)?) }).unwrap();
// Retrieve the file and check the content
let base64_encoded_file = aw!(handlers::get_file(id)).unwrap();
let base64_encoded_file = aw!(handlers::get_file(&id)).unwrap();
assert_eq!(base64_encoded_file, TEST_FILE);
// Prune the file and check that it's gone
aw!(storage::prune_files(-60)); // Will evaluate to now + 60
@ -60,13 +53,13 @@ fn test_file_handling() {
Err(_) => ()
}
// Check that the file record is also gone
let mut conn = pool.get().unwrap();
let tx = conn.transaction().unwrap();
let conn = pool.get().unwrap();
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
let mut query = tx.prepare(&raw_query).unwrap();
let rows = query.query_map(params![], |row| { Ok(row.get(0)?) }).unwrap();
let ids: Vec<String> = rows.filter_map(|result| result.ok()).collect();
assert_eq!(ids.len(), 0);
let result: Result<String, _> = conn.query_row(&raw_query, params![], |row| { Ok(row.get(0)?) });
match result {
Ok(_) => assert!(false), // It should be gone now
Err(_) => ()
}
}
// Data