test suite and cargp fmt

This commit is contained in:
Sean Darcy 2021-10-08 16:56:27 +11:00
parent ad4cc06ad8
commit 953b752100
14 changed files with 507 additions and 396 deletions

1
.gitignore vendored
View file

@ -6,3 +6,4 @@
*.pem
/files
/rooms
/uploads

View file

@ -62,7 +62,7 @@ lazy_static::lazy_static! {
/// Takes hex string representation of an ed25519 pubkey, returns the ed25519
/// pubkey, derived x25519 pubkey, and the Session id in hex.
pub fn get_pubkeys(
edpk_hex: &str
edpk_hex: &str,
) -> Result<(ed25519_dalek::PublicKey, x25519_dalek::PublicKey, String), warp::reject::Rejection> {
if edpk_hex.len() != 64 {
return Err(warp::reject::custom(Error::DecryptionFailed));
@ -92,9 +92,8 @@ pub fn get_pubkeys(
pub fn get_x25519_symmetric_key(
public_key: &[u8],
private_key: &x25519_dalek::StaticSecret
) -> Result<Vec<u8>, warp::reject::Rejection>
{
private_key: &x25519_dalek::StaticSecret,
) -> Result<Vec<u8>, warp::reject::Rejection> {
if public_key.len() != 32 {
error!(
"Couldn't create symmetric key using public key of invalid length: {}.",
@ -112,9 +111,8 @@ pub fn get_x25519_symmetric_key(
pub fn encrypt_aes_gcm(
plaintext: &[u8],
symmetric_key: &[u8]
) -> Result<Vec<u8>, warp::reject::Rejection>
{
symmetric_key: &[u8],
) -> Result<Vec<u8>, warp::reject::Rejection> {
let mut iv = [0u8; IV_SIZE];
thread_rng().fill(&mut iv[..]);
let cipher = Aes256Gcm::new(&GenericArray::from_slice(symmetric_key));
@ -133,9 +131,8 @@ pub fn encrypt_aes_gcm(
pub fn decrypt_aes_gcm(
iv_and_ciphertext: &[u8],
symmetric_key: &[u8]
) -> Result<Vec<u8>, warp::reject::Rejection>
{
symmetric_key: &[u8],
) -> Result<Vec<u8>, warp::reject::Rejection> {
if iv_and_ciphertext.len() < IV_SIZE {
warn!("Ignoring ciphertext of invalid size: {}.", iv_and_ciphertext.len());
return Err(warp::reject::custom(Error::DecryptionFailed));
@ -162,9 +159,8 @@ pub fn generate_x25519_key_pair() -> (x25519_dalek::StaticSecret, x25519_dalek::
pub fn verify_signature(
edpk: &ed25519_dalek::PublicKey,
sig: &ed25519_dalek::Signature,
parts: &[&[u8]]
) -> Result<(), Error>
{
parts: &[&[u8]],
) -> Result<(), Error> {
let mut verify_buf: Vec<u8> = Vec::new();
let verify: &[u8];
if parts.len() == 1 {

View file

@ -16,7 +16,7 @@ pub enum Error {
/// The requesting user provided a valid auth token, but they don't have a
/// high enough permission level.
Unauthorized,
ValidationFailed
ValidationFailed,
}
impl Reject for Error {}

View file

@ -8,6 +8,8 @@ use base64;
use ed25519_dalek::Signer;
use log::{debug, error, info, warn};
use parking_lot::RwLock;
use r2d2::PooledConnection;
use r2d2_sqlite::SqliteConnectionManager;
use regex::Regex;
use rusqlite::{params, params_from_iter};
use serde::{Deserialize, Serialize};
@ -29,17 +31,17 @@ use super::storage::{self, db_error};
// for moderators/admins.
#[derive(Default)]
pub struct AuthorizationRequired {
admin: bool, // Required admin permission (server or room)
moderator: bool, // Requires moderator or admin permission (server or room)
read: bool, // Requires read permission
write: bool, // Requires write permission
upload: bool // Requires upload permission
pub admin: bool, // Required admin permission (server or room)
pub moderator: bool, // Requires moderator or admin permission (server or room)
pub read: bool, // Requires read permission
pub write: bool, // Requires write permission
pub upload: bool, // Requires upload permission
}
#[derive(Debug, Serialize)]
pub struct GenericStringResponse {
pub status_code: u16,
pub result: String
pub result: String,
}
// FIXME: this is used to query the github API periodically to find new releases. Ew.
@ -86,14 +88,12 @@ pub const TOKEN_ID_SIZE: usize = 33;
pub const TOKEN_SIG_SIZE: usize = 64;
pub const TOKEN_SIZE: usize = TOKEN_ID_SIZE + TOKEN_SIG_SIZE;
// Rooms
//
#[derive(Deserialize)]
pub struct CreateRoom {
pub token: String,
pub name: String
pub name: String,
}
// Not publicly exposed.
@ -103,10 +103,7 @@ pub async fn create_room(room: CreateRoom) -> Result<Response, Rejection> {
// Get a connection
let conn = storage::get_conn()?;
// Insert the room
let stmt = "INSERT INTO rooms (token, name) VALUES (?, ?) \
ON CONFLICT DO UPDATE SET token = excluded.token, name = excluded.name";
if let Err(e) = conn.execute(&stmt, params![room.token, room.name]) {
if let Err(e) = create_room_with_conn(&conn, &room) {
error!("Couldn't create room: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
@ -117,6 +114,16 @@ pub async fn create_room(room: CreateRoom) -> Result<Response, Rejection> {
Ok(warp::reply::json(&json).into_response())
}
pub fn create_room_with_conn(
conn: &PooledConnection<SqliteConnectionManager>,
room: &CreateRoom,
) -> Result<usize, rusqlite::Error> {
// Insert the room
let stmt = "INSERT INTO rooms (token, name) VALUES (?, ?) \
ON CONFLICT DO UPDATE SET token = excluded.token, name = excluded.name";
return conn.execute(&stmt, params![room.token, room.name]);
}
// Not publicly exposed.
pub async fn delete_room(token: String) -> Result<Response, Rejection> {
// Get a connection
@ -187,10 +194,13 @@ pub fn get_all_rooms_v01x() -> Result<Response, Rejection> {
#[derive(Debug, Serialize)]
struct OldRoom {
id: String,
name: String
name: String,
}
let rooms = get_all_rooms_impl()?.into_iter().map(|r| OldRoom{ id: r.token, name: r.name }).collect::<Vec<OldRoom>>();
let rooms = get_all_rooms_impl()?
.into_iter()
.map(|r| OldRoom { id: r.token, name: r.name })
.collect::<Vec<OldRoom>>();
let response = json!({ "status_code": StatusCode::OK.as_u16(), "rooms": rooms });
Ok(warp::reply::json(&response).into_response())
@ -201,12 +211,12 @@ pub fn get_all_rooms_v01x() -> Result<Response, Rejection> {
/// RAII class holding an in-progress upload transaction and path details. If this is dropped
/// without `commit()` being called we remove the file from disk and abort the transaction
/// inserting the upload into the database.
struct FileUpload<'a> {
pub struct FileUpload<'a> {
pub id: i64, // The value of `id` in the `files` table for this new file
pub room: &'a Room, // The room the file is uploaded to
pub path: String, // The relative path containing the in-progress file upload
tx: Option<rusqlite::Transaction<'a>>,
committed: bool
committed: bool,
}
impl FileUpload<'_> {
pub fn new<'a>(tx: rusqlite::Transaction<'a>, room: &'a Room) -> FileUpload<'a> {
@ -232,22 +242,20 @@ impl Drop for FileUpload<'_> {
}
}
/// Does the actual work involved in storing a file, inserting into the database, etc.
///
/// Returns a FileUpload on success. The caller may optionally use this to perform additional
/// actions, but *must* call `.commit()` on success -- if dropped the FileUpload will clean up the
/// temporary file and drop the transaction inserting the records.
fn store_file_impl<'a>(
pub fn store_file_impl<'a>(
conn: &'a mut storage::DatabaseConnection,
room: &'a Room,
user: &User,
auth: AuthorizationRequired,
data_b64: &str,
filename: Option<&str>,
expires: bool
) -> Result<FileUpload<'a>, Rejection>
{
expires: bool,
) -> Result<FileUpload<'a>, Rejection> {
// Determine the file size from the base64 data without decoding it (we'll do that later
// directly to the destination file).
let mut bytes: usize = data_b64.len() / 4 * 3;
@ -293,7 +301,7 @@ fn store_file_impl<'a>(
(SystemTime::now() + UPLOAD_DEFAULT_EXPIRY)
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs_f64()
.as_secs_f64(),
)
} else {
None
@ -305,7 +313,7 @@ fn store_file_impl<'a>(
.prepare_cached(
"INSERT INTO files (room, uploader, size, expiry, filename, path) \
VALUES (?, ?, ?, ?, ?, 'tmp') \
RETURNING id"
RETURNING id",
)
.map_err(db_error)?
.query_row(params![room.id, user.id, bytes, db_filename, expiry], |row| row.get(0))
@ -323,7 +331,7 @@ fn store_file_impl<'a>(
if fs_filename.len() > UPLOAD_FILENAME_MAX {
fs_filename.replace_range(
UPLOAD_FILENAME_KEEP_PREFIX..fs_filename.len() - UPLOAD_FILENAME_KEEP_SUFFIX,
"..."
"...",
);
}
fs_filename = format!("{}/{}_{}", files_dir, upload.id, fs_filename);
@ -373,12 +381,12 @@ fn store_file_impl<'a>(
}
pub fn store_file(
room: Room,
user: User,
room: &Room,
user: &User,
data_b64: &str,
filename: Option<&str>
) -> Result<Response, Rejection>
{
filename: Option<&str>,
) -> Result<Response, Rejection> {
let mut conn = storage::get_conn()?;
if !matches!(rpc::MODE, rpc::Mode::OpenGroupServer) {
panic!("FIXME file mode FIXME FIXME TODO!");
// FIXME TODO
@ -386,18 +394,16 @@ pub fn store_file(
let auth = AuthorizationRequired { upload: true, write: true, ..Default::default() };
let mut conn = storage::get_conn()?;
let mut upload = match store_file_impl(&mut conn, &room, &user, auth, data_b64, filename, true)
{
Ok(id) => id,
Err(e) => return Err(e)
Err(e) => return Err(e),
};
if let Err(e) = upload.commit() {
error!("File upload failed: {}", e);
return Err(Error::DatabaseFailedInternally.into());
}
let response = json!({ "status_code": StatusCode::OK.as_u16(), "result": upload.id });
Ok(warp::reply::json(&response).into_response())
}
@ -437,18 +443,28 @@ fn file_response(path_row: rusqlite::Result<String>) -> Result<Response, Rejecti
// Return
let json = GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: base64_encoded_bytes
result: base64_encoded_bytes,
};
Ok(warp::reply::json(&json).into_response())
}
pub fn get_file(room: Room, id: i64, user: User) -> Result<Response, Rejection> {
let conn = storage::get_conn()?;
return get_file_conn(&conn, &room, id, user);
}
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
pub fn get_file_conn(
conn: &PooledConnection<SqliteConnectionManager>,
room: &Room,
id: i64,
user: User,
) -> Result<Response, Rejection> {
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let mut row = conn
.prepare_cached("SELECT path FROM files WHERE room = ? AND id = ?")
@ -474,7 +490,7 @@ pub async fn get_room_image(room: Room) -> Result<Response, Rejection> {
let row = conn
.prepare_cached(
"SELECT path FROM rooms JOIN files ON rooms.image = files.id WHERE rooms.id = ?"
"SELECT path FROM rooms JOIN files ON rooms.image = files.id WHERE rooms.id = ?",
)
.map_err(db_error)?
.query_row(params![room.id], |row| row.get(0));
@ -485,9 +501,8 @@ pub async fn set_room_image(
room: Room,
user: User,
data_b64: &str,
filename: Option<&str>
) -> Result<Response, Rejection>
{
filename: Option<&str>,
) -> Result<Response, Rejection> {
let auth = AuthorizationRequired { moderator: true, ..Default::default() };
let mut conn = storage::get_conn()?;
@ -534,20 +549,18 @@ pub fn insert_or_update_user(conn: &rusqlite::Connection, session_id: &str) -> R
.prepare_cached(
"INSERT INTO users (session_id) VALUES (?) \
ON CONFLICT DO UPDATE SET last_active = ((julianday('now') - 2440587.5)*86400.0) \
RETURNING *"
RETURNING *",
)
.map_err(db_error)?
.query_row(params![&session_id], User::from_row)
.map_err(db_error)?)
}
// Validates a (backwards compat) token string.
pub fn get_user_from_token(
conn: &rusqlite::Connection,
auth_token_str: &str
) -> Result<User, Error>
{
auth_token_str: &str,
) -> Result<User, Error> {
let auth_token =
decode_hex_or_b64(auth_token_str, TOKEN_SIZE).map_err(|_| Error::NoAuthToken)?;
if auth_token[0] != 0x05 {
@ -566,7 +579,6 @@ pub fn get_user_from_token(
insert_or_update_user(conn, &session_id)
}
pub fn get_auth_token_challenge(public_key: &str) -> Result<models::Challenge, Rejection> {
// Doesn't return a response directly for testing purposes
@ -597,7 +609,7 @@ pub fn get_auth_token_challenge(public_key: &str) -> Result<models::Challenge, R
// Return
Ok(models::Challenge {
ciphertext: base64::encode(ciphertext),
ephemeral_public_key: base64::encode(ephemeral_public_key.to_bytes())
ephemeral_public_key: base64::encode(ephemeral_public_key.to_bytes()),
})
}
@ -611,15 +623,16 @@ pub fn insert_message(
room: Room,
user: User,
data: &[u8],
signature: &[u8]
) -> Result<Response, Rejection>
{
signature: &[u8],
) -> Result<Response, Rejection> {
let mut conn = storage::get_conn()?;
let tx = storage::get_transaction(&mut conn)?;
require_authorization(&tx, &user, &room, AuthorizationRequired {
write: true,
..Default::default()
})?;
require_authorization(
&tx,
&user,
&room,
AuthorizationRequired { write: true, ..Default::default() },
)?;
// Check if the requesting user needs to be rate limited
@ -639,7 +652,7 @@ pub fn insert_message(
let size = data.len();
let trimmed = match data.iter().rposition(|&c| c != 0u8) {
Some(last) => &data[0..=last],
None => &data
None => &data,
};
// Insert the message
@ -647,7 +660,7 @@ pub fn insert_message(
.prepare_cached(
"INSERT INTO messages (room, user, data, data_size, signature) \
VALUES (?, ?, ?, ?, ?) \
RETURNING *"
RETURNING *",
)
.map_err(db_error)?
.query_row(params![room.id, user.id, trimmed, size, signature], OldMessage::from_row)
@ -701,15 +714,16 @@ fn get_messages_params(query_params: &HashMap<String, String>) -> (Option<i64>,
pub fn get_messages(
query_params: HashMap<String, String>,
user: User,
room: Room
) -> Result<Vec<OldMessage>, Rejection>
{
room: Room,
) -> Result<Vec<OldMessage>, Rejection> {
let conn = storage::get_conn()?;
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let (from_server_id, limit) = get_messages_params(&query_params);
@ -757,9 +771,8 @@ pub fn delete_message(
conn: &rusqlite::Connection,
id: i64,
user: &User,
room: &Room
) -> Result<Response, Rejection>
{
room: &Room,
) -> Result<Response, Rejection> {
let mut auth_req = AuthorizationRequired { read: true, ..Default::default() };
// Check to see if the message to be deleted is owned by someone else: if it is, we require
@ -778,14 +791,14 @@ pub fn delete_message(
let response = json!({"status_code": StatusCode::NOT_FOUND.as_u16()});
return Ok(warp::reply::json(&response).into_response());
}
Err(_) => return Err(Error::DatabaseFailedInternally.into())
Err(_) => return Err(Error::DatabaseFailedInternally.into()),
};
require_authorization(conn, user, room, auth_req)?;
let mut del_st = conn
.prepare_cached(
"UPDATE messages SET data = NULL, data_size = NULL, signature = NULL WHERE id = ?"
"UPDATE messages SET data = NULL, data_size = NULL, signature = NULL WHERE id = ?",
)
.map_err(db_error)?;
@ -805,17 +818,18 @@ pub fn delete_message(
pub fn get_deleted_messages(
query_params: HashMap<String, String>,
user: User,
room: Room
) -> Result<Vec<models::DeletedMessage>, Rejection>
{
room: Room,
) -> Result<Vec<models::DeletedMessage>, Rejection> {
let conn = storage::get_conn()?;
let (from_server_id, limit) = get_messages_params(&query_params);
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
// Query the database
let mut st = conn.prepare_cached(if from_server_id.is_some() {
@ -843,13 +857,14 @@ pub fn add_moderator_public(
room: Room,
user: User,
session_id: &str,
admin: bool
) -> Result<Response, Rejection>
{
require_authorization(&*storage::get_conn()?, &user, &room, AuthorizationRequired {
admin: true,
..Default::default()
})?;
admin: bool,
) -> Result<Response, Rejection> {
require_authorization(
&*storage::get_conn()?,
&user,
&room,
AuthorizationRequired { admin: true, ..Default::default() },
)?;
add_moderator_impl(session_id, admin, room)
}
@ -858,21 +873,20 @@ pub fn add_moderator_public(
// Not publicly exposed.
pub async fn add_moderator(
body: models::ChangeModeratorRequestBody
body: models::ChangeModeratorRequestBody,
) -> Result<Response, Rejection> {
add_moderator_impl(
&body.session_id,
body.admin.unwrap_or(false),
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?,
)
}
pub fn add_moderator_impl(
session_id: &str,
admin: bool,
room: Room
) -> Result<Response, Rejection>
{
room: Room,
) -> Result<Response, Rejection> {
require_session_id(session_id)?;
let mut conn = storage::get_conn()?;
@ -913,23 +927,24 @@ pub fn add_moderator_impl(
pub fn delete_moderator_public(
session_id: &str,
user: User,
room: Room
) -> Result<Response, Rejection>
{
require_authorization(&*storage::get_conn()?, &user, &room, AuthorizationRequired {
admin: true,
..Default::default()
})?;
room: Room,
) -> Result<Response, Rejection> {
require_authorization(
&*storage::get_conn()?,
&user,
&room,
AuthorizationRequired { admin: true, ..Default::default() },
)?;
delete_moderator_impl(session_id, room)
}
// Not publicly exposed.
pub async fn delete_moderator(
body: models::ChangeModeratorRequestBody
body: models::ChangeModeratorRequestBody,
) -> Result<Response, Rejection> {
delete_moderator_impl(
&body.session_id,
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?,
)
}
@ -951,7 +966,7 @@ pub fn delete_moderator_impl(session_id: &str, room: Room) -> Result<Response, R
Ok(count) if count > 0 => {
info!("Removed moderator {} from room {}", session_id, room.token)
}
Ok(_count) => info!("{} is not a moderator of room {}", session_id, room.token)
Ok(_count) => info!("{} is not a moderator of room {}", session_id, room.token),
}
let json = models::StatusCode { status_code: StatusCode::OK.as_u16() };
@ -962,17 +977,18 @@ pub fn delete_moderator_impl(session_id: &str, room: Room) -> Result<Response, R
pub fn get_moderators(
conn: &rusqlite::Connection,
user: &User,
room: &Room
) -> Result<Vec<String>, Rejection>
{
require_authorization(conn, user, room, AuthorizationRequired {
read: true,
..Default::default()
})?;
room: &Room,
) -> Result<Vec<String>, Rejection> {
require_authorization(
conn,
user,
room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let mut st = conn
.prepare_cached(
"SELECT session_id FROM user_permissions WHERE room = ? AND moderator AND visible_mod"
"SELECT session_id FROM user_permissions WHERE room = ? AND moderator AND visible_mod",
)
.map_err(db_error)?;
@ -999,9 +1015,8 @@ pub async fn ban(
session_id: &str,
delete_all: bool,
user: &User,
room: &Room
) -> Result<Response, Rejection>
{
room: &Room,
) -> Result<Response, Rejection> {
if !is_session_id(&session_id) {
warn!("Ignoring ban request: invalid session_id.");
return Err(Error::ValidationFailed.into());
@ -1022,7 +1037,7 @@ pub async fn ban(
match tx
.prepare_cached(
"SELECT user, moderator, global_moderator FROM user_permissions \
WHERE room = ? AND session_id = ?"
WHERE room = ? AND session_id = ?",
)
.map_err(db_error)?
.query_row(params![room.id, session_id], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))
@ -1042,7 +1057,7 @@ pub async fn ban(
let response = json!({"status_code": StatusCode::NOT_FOUND.as_u16()});
return Ok(warp::reply::json(&response).into_response());
}
Err(_) => return Err(Error::DatabaseFailedInternally.into())
Err(_) => return Err(Error::DatabaseFailedInternally.into()),
};
require_authorization(&tx, user, room, auth)?;
@ -1053,7 +1068,7 @@ pub async fn ban(
INSERT INTO user_permission_overrides (room, user, banned, moderator, admin) \
VALUES (?, ?, TRUE, FALSE, FALSE) \
ON CONFLICT DO UPDATE SET banned = TRUE, moderator = FALSE, admin = FALSE
"
",
)
.map_err(db_error)?
.execute(params![room.id, userid])
@ -1068,7 +1083,7 @@ pub async fn ban(
posts_removed += match tx
.prepare_cached(
"UPDATE messages SET data = NULL, data_size = NULL, signature = NULL \
WHERE room = ? AND user = ?"
WHERE room = ? AND user = ?",
)
.map_err(db_error)?
.execute(params![room.id, userid])
@ -1085,7 +1100,7 @@ pub async fn ban(
// from disk).
files_removed = tx
.prepare_cached(
"UPDATE files SET room = NULL, expiry = ? WHERE room = ? AND uploader = ?"
"UPDATE files SET room = NULL, expiry = ? WHERE room = ? AND uploader = ?",
)
.map_err(db_error)?
.execute(params![unixtime_f64(), room.id, userid])
@ -1114,15 +1129,17 @@ pub fn unban(session_id: &str, user: &User, room: &Room) -> Result<Response, Rej
}
let conn = storage::get_conn()?;
require_authorization(&conn, user, room, AuthorizationRequired {
moderator: true,
..Default::default()
})?;
require_authorization(
&conn,
user,
room,
AuthorizationRequired { moderator: true, ..Default::default() },
)?;
let count = match conn
.prepare_cached(
"UPDATE user_permission_overrides SET banned = FALSE \
WHERE room = ? AND user IN (SELECT id FROM users WHERE session_id = ?)"
WHERE room = ? AND user IN (SELECT id FROM users WHERE session_id = ?)",
)
.map_err(db_error)?
.execute(params![room.id, session_id])
@ -1148,10 +1165,12 @@ pub fn unban(session_id: &str, user: &User, room: &Room) -> Result<Response, Rej
/// Returns the full list of banned public keys.
pub fn get_banned_public_keys(user: &User, room: &Room) -> Result<Response, Rejection> {
let conn = storage::get_conn()?;
require_authorization(&conn, user, room, AuthorizationRequired {
moderator: true,
..Default::default()
})?;
require_authorization(
&conn,
user,
room,
AuthorizationRequired { moderator: true, ..Default::default() },
)?;
let banned_members: Result<Vec<String>, _> = match conn
.prepare_cached("SELECT session_id FROM user_permissions WHERE room = ? AND banned")
@ -1184,14 +1203,15 @@ pub fn get_member_count(user: User, room: Room) -> Result<Response, Rejection> {
pub fn get_member_count_since(
user: User,
room: Room,
ago: Duration
) -> Result<Response, Rejection>
{
ago: Duration,
) -> Result<Response, Rejection> {
let conn = storage::get_conn()?;
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let mut st = conn
.prepare_cached("SELECT COUNT(*) FROM room_users WHERE room = ? AND last_active >= ?")
@ -1236,15 +1256,13 @@ pub fn get_room_updates(user: User, room: Room, since_update: i64) {
*/
}
/// Deprecated room polling; unlike the above, this does not handle metadata (except for
/// moderators, which are *always* included even though they rarely change), does not support
/// message edits, and has non-obvious alternate modes of operation.
pub fn compact_poll(
user: Option<User>,
request_bodies: Vec<models::CompactPollRequestBody>
) -> Result<Response, Rejection>
{
request_bodies: Vec<models::CompactPollRequestBody>,
) -> Result<Response, Rejection> {
let mut response_bodies = Vec::<models::CompactPollResponseBody>::new();
let mut conn = storage::get_conn()?;
@ -1275,7 +1293,7 @@ pub fn compact_poll(
.prepare_cached(
"SELECT * FROM message_details \
WHERE room = ? AND data IS NOT NULL \
ORDER BY id DESC LIMIT 256"
ORDER BY id DESC LIMIT 256",
)
.map_err(db_error)?;
@ -1283,7 +1301,7 @@ pub fn compact_poll(
.prepare_cached(
"SELECT id, updated FROM messages \
WHERE room = ? AND data IS NULL \
ORDER BY updated DESC LIMIT 256"
ORDER BY updated DESC LIMIT 256",
)
.map_err(db_error)?;
@ -1292,7 +1310,7 @@ pub fn compact_poll(
.prepare_cached(
"SELECT id, updated FROM messages \
WHERE room = ? AND updated > ? AND data IS NULL \
ORDER BY updated LIMIT 256"
ORDER BY updated LIMIT 256",
)
.map_err(db_error)?;
@ -1300,18 +1318,17 @@ pub fn compact_poll(
.prepare_cached(
"SELECT * FROM message_details \
WHERE room = ? AND id > ? AND data IS NOT NULL \
ORDER BY id LIMIT 256"
ORDER BY id LIMIT 256",
)
.map_err(db_error)?;
for request in request_bodies {
let mut response = models::CompactPollResponseBody {
room_token: request.room_token.clone(),
status_code: StatusCode::OK.as_u16(),
messages: vec![],
deletions: vec![],
moderators: vec![]
moderators: vec![],
};
let room: &Room = match rooms.get(&request.room_token) {
@ -1449,13 +1466,12 @@ pub async fn get_session_version(platform: &str) -> Result<String, Rejection> {
// not publicly exposed.
pub async fn get_stats_for_room(
room_token: String,
query_map: HashMap<String, i64>
) -> Result<Response, Rejection>
{
query_map: HashMap<String, i64>,
) -> Result<Response, Rejection> {
let window = *query_map.get("window").unwrap_or(&3600) as f64;
let upperbound = match query_map.get("start") {
Some(ts) => *ts as f64,
None => unixtime_f64()
None => unixtime_f64(),
};
let lowerbound = upperbound - window;
@ -1466,7 +1482,7 @@ pub async fn get_stats_for_room(
let active = tx
.prepare_cached(
"SELECT COUNT(*) FROM room_users WHERE room = ? AND last_active BETWEEN ? AND ?"
"SELECT COUNT(*) FROM room_users WHERE room = ? AND last_active BETWEEN ? AND ?",
)
.map_err(db_error)?
.query_row(params![room.id, lowerbound, upperbound], |row| Ok(row.get::<_, i64>(0)?))
@ -1511,9 +1527,8 @@ fn require_authorization(
conn: &rusqlite::Connection,
user: &User,
room: &Room,
req: AuthorizationRequired
) -> Result<(), Error>
{
req: AuthorizationRequired,
) -> Result<(), Error> {
require_authorization_impl(conn, &user, &room, req, true)
}
/// Same as above, but does not update the room/user last activity timestamp.
@ -1522,9 +1537,8 @@ fn require_authorization_no_activity(
conn: &rusqlite::Connection,
user: &User,
room: &Room,
req: AuthorizationRequired
) -> Result<(), Error>
{
req: AuthorizationRequired,
) -> Result<(), Error> {
return require_authorization_impl(conn, &user, &room, req, false);
}
@ -1533,9 +1547,8 @@ fn require_authorization_impl(
user: &User,
room: &Room,
need: AuthorizationRequired,
log_active: bool
) -> Result<(), Error>
{
log_active: bool,
) -> Result<(), Error> {
let mut st = conn
.prepare_cached(
"SELECT banned, read, write, upload, moderator, admin FROM user_permissions WHERE room = ? AND user = ?"

View file

@ -1,9 +1,13 @@
use log::LevelFilter;
use log4rs::{append::{console::ConsoleAppender,
rolling_file::{policy::compound, RollingFileAppender}},
use log4rs::{
append::{
console::ConsoleAppender,
rolling_file::{policy::compound, RollingFileAppender},
},
config::{Appender, Logger, Root},
encode::pattern::PatternEncoder,
filter::threshold::ThresholdFilter};
filter::threshold::ThresholdFilter,
};
use std::str::FromStr;
pub fn init(log_file: Option<String>, log_level: Option<String>) {

View file

@ -3,8 +3,10 @@
use parking_lot::RwLock;
use std::fs;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::{collections::HashMap,
sync::atomic::{AtomicBool, AtomicU16, Ordering}};
use std::{
collections::HashMap,
sync::atomic::{AtomicBool, AtomicU16, Ordering},
};
use futures::join;
use log::info;
@ -15,13 +17,13 @@ mod crypto;
mod errors;
mod handlers;
mod logging;
mod migration;
mod models;
mod onion_requests;
mod options;
mod routes;
mod rpc;
mod storage;
mod migration;
#[cfg(test)]
mod tests;

View file

@ -1,16 +1,15 @@
use std::fs;
use std::path::Path;
use std::os::unix::fs::MetadataExt;
use std::path::Path;
use std::time::SystemTime;
use log::{info, warn};
use rusqlite::{Connection, OpenFlags, params, types::Null};
use super::handlers;
use super::storage;
use log::{info, warn};
use rusqlite::{params, types::Null, Connection, OpenFlags};
// Performs database migration from v0.1.8 to v0.2.0
pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// Old database database.db is a single table database containing just the list of rooms:
/*
CREATE TABLE IF NOT EXISTS main (
@ -24,7 +23,10 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// that starting again will try to import again).
let tx = conn.transaction()?;
struct Rm { token: String, name: Option<String> }
struct Rm {
token: String,
name: Option<String>,
}
let rooms = Connection::open_with_flags("database.db", OpenFlags::SQLITE_OPEN_READ_ONLY)?
.prepare("SELECT id, name FROM main")?
@ -34,34 +36,44 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
warn!("{} rooms to import", rooms.len());
{
tx.execute("\
tx.execute(
"\
CREATE TABLE room_import_hacks (
room INTEGER PRIMARY KEY NOT NULL REFERENCES rooms(id),
old_message_id_max INTEGER NOT NULL,
message_id_offset INTEGER NOT NULL
)", [])?;
)",
[],
)?;
let mut used_room_hacks: bool = false;
let mut ins_room_hack = tx.prepare(
"INSERT INTO room_import_hacks (room, old_message_id_max, message_id_offset) VALUES (?, ?, ?)")?;
tx.execute("\
tx.execute(
"\
CREATE TABLE file_id_hacks (
room INTEGER NOT NULL REFERENCES rooms(id),
old_file_id INTEGER NOT NULL,
file INTEGER NOT NULL REFERENCES files(id) ON DELETE CASCADE,
PRIMARY KEY(room, old_file_id)
)", [])?;
)",
[],
)?;
let mut used_file_hacks: bool = false;
let mut ins_file_hack = tx.prepare(
"INSERT INTO file_id_hacks (room, old_file_id, file) VALUES (?, ?, ?)")?;
let mut ins_file_hack =
tx.prepare("INSERT INTO file_id_hacks (room, old_file_id, file) VALUES (?, ?, ?)")?;
let mut ins_room = tx.prepare("INSERT INTO rooms (token, name) VALUES (?, ?) RETURNING id")?;
let mut ins_room =
tx.prepare("INSERT INTO rooms (token, name) VALUES (?, ?) RETURNING id")?;
let mut ins_user = tx.prepare("INSERT INTO users (session_id, last_active) VALUES (?, 0.0) ON CONFLICT DO NOTHING")?;
let mut ins_user = tx.prepare(
"INSERT INTO users (session_id, last_active) VALUES (?, 0.0) ON CONFLICT DO NOTHING",
)?;
let mut ins_msg = tx.prepare(
"INSERT INTO messages (id, room, user, posted, data, data_size, signature) \
VALUES (?, ?, (SELECT id FROM users WHERE session_id = ?), ?, ?, ?, ?)")?;
VALUES (?, ?, (SELECT id FROM users WHERE session_id = ?), ?, ?, ?, ?)",
)?;
let mut upd_msg_updated = tx.prepare("UPDATE messages SET updated = ? WHERE id = ?")?;
let mut upd_room_updates = tx.prepare("UPDATE rooms SET updates = ? WHERE id = ?")?;
@ -83,7 +95,8 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
"INSERT INTO room_users (room, user, last_active) VALUES (?, (SELECT id FROM users WHERE session_id = ?), ?) \
ON CONFLICT DO UPDATE SET last_active = excluded.last_active WHERE excluded.last_active > last_active")?;
let mut upd_user_activity = tx.prepare(
"UPDATE users SET last_active = ?1 WHERE session_id = ?2 AND last_active < ?1")?;
"UPDATE users SET last_active = ?1 WHERE session_id = ?2 AND last_active < ?1",
)?;
for room in rooms {
let room_db_filename = format!("rooms/{}.db", room.token);
@ -95,7 +108,8 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
info!("Importing room {}...", room.token);
let room_id = ins_room.query_row(params![room.token, room.name], |row| row.get::<_, i64>(0))?;
let room_id =
ins_room.query_row(params![room.token, room.name], |row| row.get::<_, i64>(0))?;
let rconn = Connection::open_with_flags(room_db, OpenFlags::SQLITE_OPEN_READ_ONLY)?;
@ -138,12 +152,23 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
deletion ids.
*/
let mut id_offset: i64 = tx.query_row("SELECT COALESCE(MAX(id), 0) + 1 FROM messages", [], |row| row.get(0))?;
let mut id_offset: i64 =
tx.query_row("SELECT COALESCE(MAX(id), 0) + 1 FROM messages", [], |row| {
row.get(0)
})?;
let mut top_old_id: i64 = -1;
let mut updated: i64 = 0;
let mut imported_msgs: i64 = 0;
struct Msg { id: i64, session_id: String, ts_ms: i64, data: Option<String>, signature: Option<String>, deleted: Option<i64> }
let n_msgs: i64 = rconn.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
struct Msg {
id: i64,
session_id: String,
ts_ms: i64,
data: Option<String>,
signature: Option<String>,
deleted: Option<i64>,
}
let n_msgs: i64 =
rconn.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
let mut msg_st = rconn.prepare("\
SELECT messages.id, public_key, timestamp, data, signature, is_deleted, deleted_messages.id \
FROM messages LEFT JOIN deleted_messages ON messages.id = deleted_messages.deleted_message_id
@ -153,8 +178,16 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
let mut dupe_dels: i64 = 0;
while let Some(row) = msg_rows.next()? {
let msg = Msg {
id: row.get(0)?, session_id: row.get(1)?, ts_ms: row.get(2)?, data: row.get(3)?, signature: row.get(4)?,
deleted: if row.get::<_, Option<bool>>(5)?.unwrap_or(false) { Some(row.get(6)?) } else { None }
id: row.get(0)?,
session_id: row.get(1)?,
ts_ms: row.get(2)?,
data: row.get(3)?,
signature: row.get(4)?,
deleted: if row.get::<_, Option<bool>>(5)?.unwrap_or(false) {
Some(row.get(6)?)
} else {
None
},
};
if top_old_id == -1 {
id_offset -= msg.id;
@ -177,21 +210,42 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// Data was pointlessly store padding, so unpad it:
let padded_data = match base64::decode(msg.data.unwrap()) {
Ok(d) => d,
Err(e) => panic!("Unexpected data: {} message id={} has non-base64 data ({})", room_db.display(), msg.id, e)
Err(e) => panic!(
"Unexpected data: {} message id={} has non-base64 data ({})",
room_db.display(),
msg.id,
e
),
};
let data_size = padded_data.len();
let data = match padded_data.iter().rposition(|&c| c != 0u8) {
Some(last) => &padded_data[0..=last],
None => &padded_data
None => &padded_data,
};
let sig = match base64::decode(msg.signature.unwrap()) {
Ok(d) if d.len() == 64 => d,
Ok(_) => panic!("Unexpected data: {} message id={} has invalid signature", room_db.display(), msg.id),
Err(e) => panic!("Unexpected data: {} message id={} has non-base64 signature ({})", room_db.display(), msg.id, e)
Ok(_) => panic!(
"Unexpected data: {} message id={} has invalid signature",
room_db.display(),
msg.id
),
Err(e) => panic!(
"Unexpected data: {} message id={} has non-base64 signature ({})",
room_db.display(),
msg.id,
e
),
};
ins_msg.execute(params![msg.id + id_offset, room_id, msg.session_id, (msg.ts_ms as f64) / 1000., data, data_size, sig])?;
ins_msg.execute(params![
msg.id + id_offset,
room_id,
msg.session_id,
(msg.ts_ms as f64) / 1000.,
data,
data_size,
sig
])?;
} else if msg.deleted.is_some() &&
// Deleted messages are usually set to the fixed string "deleted" (why not
// NULL?) for data and signature, so accept either null or that string if the
@ -205,7 +259,15 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// deletion id as the "updated" field. (We do this with a second query because the
// first query is going to trigger an automatic update of the field).
ins_msg.execute(params![msg.id + id_offset, room_id, msg.session_id, (msg.ts_ms as f64) / 1000., Null, Null, Null])?;
ins_msg.execute(params![
msg.id + id_offset,
room_id,
msg.session_id,
(msg.ts_ms as f64) / 1000.,
Null,
Null,
Null
])?;
} else {
panic!("Inconsistent message in {} database: message id={} has inconsistent deletion state (data: {}, signature: {}, del row: {})",
room_db.display(), msg.id, msg.data.is_some(), msg.signature.is_some(), msg.deleted.is_some());
@ -217,7 +279,10 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
info!("- ... imported {}/{} messages", imported_msgs, n_msgs);
}
}
info!("- migrated {} messages, {} duplicate deletions ignored", imported_msgs, dupe_dels);
info!(
"- migrated {} messages, {} duplicate deletions ignored",
imported_msgs, dupe_dels
);
upd_room_updates.execute(params![updated, room_id])?;
@ -229,36 +294,55 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
let mut imported_files: i64 = 0;
let n_files: i64 = rconn.query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0))?;
let n_files: i64 =
rconn.query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0))?;
// WTF is this id stored as a TEXT?
struct File { id: String, ts: i64 }
struct File {
id: String,
ts: i64,
}
let mut rows_st = rconn.prepare("SELECT id, timestamp FROM files")?;
let mut file_rows = rows_st.query([])?;
while let Some(row) = file_rows.next()? {
let file = File { id: row.get(0)?, ts: row.get(1)? };
let old_id = match file.id.parse::<i64>() {
Ok(id) => id,
Err(e) => panic!("Invalid fileid '{}' found in {}: {}", file.id, room_db.display(), e)
Err(e) => {
panic!("Invalid fileid '{}' found in {}: {}", file.id, room_db.display(), e)
}
};
let old_path = format!("files/{}_files/{}", room.token, old_id);
let size = match fs::metadata(&old_path) {
Ok(md) => md.len(),
Err(e) => {
warn!("Error accessing file {} ({}); skipping import of this upload", old_path, e);
warn!(
"Error accessing file {} ({}); skipping import of this upload",
old_path, e
);
continue;
}
};
let ts = if file.ts > 10000000000 {
warn!("- file {} has nonsensical timestamp {}; importing it with current time", old_path, file.ts);
warn!(
"- file {} has nonsensical timestamp {}; importing it with current time",
old_path, file.ts
);
SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs_f64()
} else {
file.ts as f64
};
let new_id = ins_file.query_row(
params![room_id, size, ts, ts + handlers::UPLOAD_DEFAULT_EXPIRY.as_secs_f64(), old_path],
|row| row.get::<_, i64>(0))?;
params![
room_id,
size,
ts,
ts + handlers::UPLOAD_DEFAULT_EXPIRY.as_secs_f64(),
old_path
],
|row| row.get::<_, i64>(0),
)?;
ins_file_hack.execute(params![room_id, old_id, new_id])?;
imported_files += 1;
@ -266,7 +350,9 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
info!("- ... imported {}/{} files", imported_files, n_files);
}
}
if imported_files > 0 { used_file_hacks = true; }
if imported_files > 0 {
used_file_hacks = true;
}
info!("- migrated {} files", imported_files);
// There's also a potential room image, which is just stored on disk and not referenced in
@ -282,11 +368,21 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
panic!("Unable to mkdir {} for room file storage: {}", files_dir, e);
}
let file_id = ins_file.query_row(
params![room_id, md.len(), md.mtime() as f64 + md.mtime_nsec() as f64 * 1e-9, Null, "tmp"],
|row| row.get::<_, i64>(0))?;
params![
room_id,
md.len(),
md.mtime() as f64 + md.mtime_nsec() as f64 * 1e-9,
Null,
"tmp"
],
|row| row.get::<_, i64>(0),
)?;
let new_image_path = format!("uploads/{}/{}_(unnamed)", room.token, file_id);
if let Err(e) = fs::hard_link(&room_image_path, &new_image_path) {
panic!("Unable to hard link room image file {} => {}: {}", room_image_path, new_image_path, e);
panic!(
"Unable to hard link room image file {} => {}: {}",
room_image_path, new_image_path, e
);
}
upd_file_path.execute(params![new_image_path, file_id])?;
upd_room_image.execute(params![file_id, room_id])?;
@ -325,22 +421,36 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
let mut imported_activity: i64 = 0;
let mut imported_active: i64 = 0;
// Don't import rows we're going to immediately prune:
let import_cutoff = (SystemTime::now() - storage::ROOM_ACTIVE_PRUNE_THRESHOLD).duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs_f64();
let n_activity: i64 = rconn.query_row("SELECT COUNT(*) FROM user_activity WHERE last_active > ?",
params![import_cutoff], |row| row.get(0))?;
let import_cutoff = (SystemTime::now() - storage::ROOM_ACTIVE_PRUNE_THRESHOLD)
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs_f64();
let n_activity: i64 = rconn.query_row(
"SELECT COUNT(*) FROM user_activity WHERE last_active > ?",
params![import_cutoff],
|row| row.get(0),
)?;
let mut activity_st = rconn.prepare("SELECT public_key, last_active FROM user_activity WHERE last_active > ? AND public_key IS NOT NULL")?;
let mut act_rows = activity_st.query(params![import_cutoff])?;
let cutoff = (SystemTime::now() - handlers::ROOM_DEFAULT_ACTIVE_THRESHOLD).duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs_f64();
let cutoff = (SystemTime::now() - handlers::ROOM_DEFAULT_ACTIVE_THRESHOLD)
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs_f64();
while let Some(row) = act_rows.next()? {
let session_id: String = row.get(0)?;
let ts: f64 = row.get::<_, i64>(1)? as f64;
ins_user.execute(params![session_id])?;
ins_room_activity.execute(params![room_id, session_id, ts])?;
upd_user_activity.execute(params![ts, session_id])?;
if ts >= cutoff { imported_active += 1; }
if ts >= cutoff {
imported_active += 1;
}
imported_activity += 1;
if imported_activity % 1000 == 0 {
info!("- ... imported {}/{} user activity records ({} active)", imported_activity, n_activity, imported_active);
info!(
"- ... imported {}/{} user activity records ({} active)",
imported_activity, n_activity, imported_active
);
}
}
warn!("Imported room {}: {} messages, {} files, {} moderators, {} bans, {} users ({} active)",

View file

@ -9,7 +9,7 @@ pub struct User {
pub last_active: f64,
pub banned: bool,
pub moderator: bool,
pub admin: bool
pub admin: bool,
}
impl User {
@ -21,13 +21,15 @@ impl User {
last_active: row.get(row.column_index("last_active")?)?,
banned: row.get(row.column_index("banned")?)?,
moderator: row.get(row.column_index("moderator")?)?,
admin: row.get(row.column_index("admin")?)?
admin: row.get(row.column_index("admin")?)?,
});
}
}
fn as_opt_base64<S>(val: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error>
where S: Serializer {
where
S: Serializer,
{
s.serialize_str(&base64::encode(val.as_ref().unwrap()))
}
@ -50,7 +52,7 @@ pub struct OldMessage {
/// XEd25519 message signature of the `data` bytes (not the base64 representation), encoded in
/// base64
#[serde(serialize_with = "as_opt_base64")]
pub signature: Option<Vec<u8>>
pub signature: Option<Vec<u8>>,
}
impl OldMessage {
@ -59,19 +61,18 @@ impl OldMessage {
repad(&mut data, row.get::<_, Option<usize>>(row.column_index("data_size")?)?);
let session_id = match row.column_index("session_id") {
Ok(index) => Some(row.get(index)?),
Err(_) => None
Err(_) => None,
};
return Ok(OldMessage {
server_id: row.get(row.column_index("id")?)?,
public_key: session_id,
timestamp: (row.get::<_, f64>(row.column_index("posted")?)? * 1000.0) as i64,
data,
signature: row.get(row.column_index("signature")?)?
signature: row.get(row.column_index("signature")?)?,
});
}
}
#[derive(Debug, Serialize)]
pub struct Message {
/// The message id.
@ -95,7 +96,7 @@ pub struct Message {
pub signature: Option<Vec<u8>>,
/// Flag set to true if the message is deleted, and omitted otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted: Option<bool>
pub deleted: Option<bool>,
}
fn repad(data: &mut Option<Vec<u8>>, size: Option<usize>) {
@ -113,7 +114,7 @@ impl Message {
let deleted = if data.is_none() { Some(true) } else { None };
let session_id = match row.column_index("session_id") {
Ok(index) => Some(row.get(index)?),
Err(_) => None
Err(_) => None,
};
return Ok(Message {
id: row.get(row.column_index("id")?)?,
@ -123,13 +124,15 @@ impl Message {
updated: row.get(row.column_index("updated")?)?,
data,
signature: row.get(row.column_index("signature")?)?,
deleted
deleted,
});
}
}
fn bytes_from_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where D: Deserializer<'de> {
where
D: Deserializer<'de>,
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|str| base64::decode(&str).map_err(|err| Error::custom(err.to_string())))
@ -140,14 +143,14 @@ pub struct PostMessage {
#[serde(deserialize_with = "bytes_from_base64")]
pub data: Vec<u8>,
#[serde(deserialize_with = "bytes_from_base64")]
pub signature: Vec<u8>
pub signature: Vec<u8>,
}
#[derive(Debug, Serialize)]
pub struct DeletedMessage {
#[serde(rename = "id")]
pub updated: i64,
pub deleted_message_id: i64
pub deleted_message_id: i64,
}
#[derive(Debug, Serialize)]
@ -165,7 +168,7 @@ pub struct Room {
pub updates: i64,
pub default_read: bool,
pub default_write: bool,
pub default_upload: bool
pub default_upload: bool,
}
impl Room {
@ -180,13 +183,11 @@ impl Room {
updates: row.get(row.column_index("updates")?)?,
default_read: row.get(row.column_index("read")?)?,
default_write: row.get(row.column_index("write")?)?,
default_upload: row.get(row.column_index("upload")?)?
default_upload: row.get(row.column_index("upload")?)?,
});
}
}
// FIXME: this appears to be used for both add/remove. But what if we want to promote to admin, or
// demote to moderator?
#[derive(Debug, Deserialize)]
@ -195,7 +196,7 @@ pub struct ChangeModeratorRequestBody {
pub room_token: String,
#[serde(rename = "public_key")]
pub session_id: String,
pub admin: Option<bool>
pub admin: Option<bool>,
}
#[derive(Debug, Deserialize)]
@ -204,7 +205,7 @@ pub struct PollRoomMetadata {
pub room: String,
/// The last `info_update` value the client has; results are only returned if the room has been
/// modified since the value provided by the client.
pub since_update: i64
pub since_update: i64,
}
#[derive(Debug, Serialize)]
@ -216,7 +217,7 @@ pub struct RoomDetails {
/// Metadata of the room; this omitted from the response when polling if the room metadata
/// (other than active user count) has not changed since the request update counter.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<RoomMetadata>
pub details: Option<RoomMetadata>,
}
#[derive(Debug, Serialize)]
@ -255,7 +256,7 @@ pub struct RoomMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
pub moderator: Option<bool>,
/// Will be present and true if the requesting user has admin powers, omitted otherwise.
pub admin: Option<bool>
pub admin: Option<bool>,
}
#[derive(Debug, Deserialize)]
@ -264,7 +265,7 @@ pub struct PollRoomMessages {
pub room: String,
/// Return new messages, edit, and deletions posted since this `updates` value. Clients should
/// poll with the most recent updates value they have received.
pub since_update: i64
pub since_update: i64,
}
#[derive(Debug, Serialize)]
@ -272,7 +273,7 @@ pub struct RoomMessages {
/// The token of this room
pub room: String,
/// Vector of new/edited/deleted message posted to the room since the requested update.
pub messages: Vec<Message>
pub messages: Vec<Message>,
}
#[derive(Debug, Deserialize)]
@ -288,7 +289,7 @@ pub struct CompactPollRequestBody {
// messages/deletions, in reverse order from what you get with regular polling. New clients
// should update to the new polling endpoints ASAP.
pub from_message_server_id: Option<i64>,
pub from_deletion_server_id: Option<i64>
pub from_deletion_server_id: Option<i64>,
}
#[derive(Debug, Serialize)]
@ -298,16 +299,16 @@ pub struct CompactPollResponseBody {
pub status_code: u16,
pub deletions: Vec<DeletedMessage>,
pub messages: Vec<OldMessage>,
pub moderators: Vec<String>
pub moderators: Vec<String>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Challenge {
pub ciphertext: String,
pub ephemeral_public_key: String
pub ephemeral_public_key: String,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct StatusCode {
pub status_code: u16
pub status_code: u16,
}

View file

@ -12,12 +12,12 @@ use super::rpc;
#[derive(Deserialize, Serialize, Debug)]
struct OnionRequestPayload {
pub ciphertext: Vec<u8>,
pub metadata: OnionRequestPayloadMetadata
pub metadata: OnionRequestPayloadMetadata,
}
#[derive(Deserialize, Serialize, Debug)]
struct OnionRequestPayloadMetadata {
pub ephemeral_key: String
pub ephemeral_key: String,
}
pub async fn handle_onion_request(blob: warp::hyper::body::Bytes) -> Result<Response, Rejection> {
@ -36,9 +36,8 @@ pub async fn handle_onion_request(blob: warp::hyper::body::Bytes) -> Result<Resp
async fn handle_decrypted_onion_request(
plaintext: &[u8],
symmetric_key: &[u8]
) -> Result<Response, Rejection>
{
symmetric_key: &[u8],
) -> Result<Response, Rejection> {
let rpc_call = match serde_json::from_slice(plaintext) {
Ok(rpc_call) => rpc_call,
Err(e) => {
@ -59,7 +58,7 @@ async fn handle_decrypted_onion_request(
}
fn parse_onion_request_payload(
blob: warp::hyper::body::Bytes
blob: warp::hyper::body::Bytes,
) -> Result<OnionRequestPayload, Rejection> {
// The encoding of an onion request looks like:
//
@ -107,7 +106,7 @@ fn parse_onion_request_payload(
/// Returns the decrypted `payload.ciphertext` plus the `symmetric_key` that was used for
/// decryption if successful.
fn decrypt_onion_request_payload(
payload: OnionRequestPayload
payload: OnionRequestPayload,
) -> Result<(Vec<u8>, Vec<u8>), Rejection> {
let ephemeral_key = hex::decode(payload.metadata.ephemeral_key).unwrap(); // Safe because it was validated in the parsing step
let symmetric_key = crypto::get_x25519_symmetric_key(&ephemeral_key, &crypto::PRIVATE_KEY)?;

View file

@ -64,5 +64,5 @@ pub struct Opt {
/// Prints the URL format users can use to join rooms on this open group server.
#[structopt(long = "print-url")]
pub print_url: bool
pub print_url: bool,
}

View file

@ -105,9 +105,8 @@ pub async fn root_html() -> Result<Response, Rejection> {
pub async fn fallback_html(
room: String,
query_map: HashMap<String, String>
) -> Result<Response, Rejection>
{
query_map: HashMap<String, String>,
) -> Result<Response, Rejection> {
if !query_map.contains_key("public_key") || room == "" {
return fallback_nopubkey_html().await;
}

View file

@ -15,7 +15,7 @@ use super::storage;
#[allow(dead_code)]
pub enum Mode {
FileServer,
OpenGroupServer
OpenGroupServer,
}
#[derive(Deserialize, Debug)]
@ -34,7 +34,7 @@ pub struct RpcCall {
/// Arbitrary string; must be different on each request
pub nonce: Option<String>,
/// Ed25519 signature (in base64 or hex) of (method || endpoint || body || nonce)
pub signature: Option<String>
pub signature: Option<String>,
}
pub const MODE: Mode = Mode::OpenGroupServer;
@ -43,9 +43,8 @@ pub const MODE: Mode = Mode::OpenGroupServer;
// is a parseable auth token, and an error for anything else.
fn get_user_from_auth_header(
conn: &rusqlite::Connection,
rpc: &RpcCall
) -> Result<Option<User>, Error>
{
rpc: &RpcCall,
) -> Result<Option<User>, Error> {
if let Some(auth_token_str) = rpc.headers.get("Authorization") {
return Ok(Some(handlers::get_user_from_token(conn, auth_token_str)?));
}
@ -77,16 +76,20 @@ pub async fn handle_rpc_call(mut rpc_call: RpcCall) -> Result<Response, Rejectio
let mut sig_bytes: [u8; 64] = [0; 64];
sig_bytes.copy_from_slice(
&handlers::decode_hex_or_b64(rpc_call.signature.as_ref().unwrap(), 64)?[0..64]
&handlers::decode_hex_or_b64(rpc_call.signature.as_ref().unwrap(), 64)?[0..64],
);
let sig = ed25519_dalek::Signature::new(sig_bytes);
if let Err(sigerr) = crypto::verify_signature(&edpk, &sig, &vec![
if let Err(sigerr) = crypto::verify_signature(
&edpk,
&sig,
&vec![
rpc_call.endpoint.as_bytes(),
rpc_call.method.as_bytes(),
rpc_call.body.as_bytes(),
nonce.as_bytes(),
]) {
],
) {
warn!("Signature verification failed for request from {}", sessionid);
return Err(sigerr.into());
}
@ -111,7 +114,7 @@ pub async fn handle_rpc_call(mut rpc_call: RpcCall) -> Result<Response, Rejectio
path = uri.path().trim_start_matches('/').to_string();
query_params = match uri.query() {
Some(qs) => form_urlencoded::parse(qs.as_bytes()).into_owned().collect(),
None => HashMap::new()
None => HashMap::new(),
};
}
Err(e) => {
@ -153,9 +156,8 @@ async fn handle_get_request(
rpc_call: RpcCall,
path: &str,
user: Option<User>,
query_params: HashMap<String, String>
) -> Result<Response, Rejection>
{
query_params: HashMap<String, String>,
) -> Result<Response, Rejection> {
let mut components: Vec<&str> = path.split('/').collect();
if components.len() == 0 {
components.push("");
@ -165,7 +167,7 @@ async fn handle_get_request(
if components[0] == "auth_token_challenge" && components.len() == 1 {
reject_if_file_server_mode(path)?;
let challenge = handlers::get_auth_token_challenge(
query_params.get("public_key").ok_or(Error::InvalidRpcCall)?
query_params.get("public_key").ok_or(Error::InvalidRpcCall)?,
)?;
let response = json!({ "status_code": StatusCode::OK.as_u16(), "challenge": challenge });
return Ok(warp::reply::json(&response).into_response());
@ -208,7 +210,7 @@ async fn handle_get_request(
warn!("Ignoring RPC call with invalid or unused endpoint: {}.", path);
return Err(Error::InvalidRpcCall.into());
}
Mode::FileServer => ()
Mode::FileServer => (),
}
let platform = query_params
.get("platform")
@ -216,7 +218,7 @@ async fn handle_get_request(
let version = handlers::get_session_version(platform).await?;
let response = handlers::GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: version
result: version,
};
return Ok(warp::reply::json(&response).into_response());
}
@ -299,9 +301,8 @@ async fn handle_post_request(
room: Option<Room>,
rpc_call: RpcCall,
path: &str,
user: Option<User>
) -> Result<Response, Rejection>
{
user: Option<User>,
) -> Result<Response, Rejection> {
// Handle routes that don't require authorization first
// The compact poll endpoint expects the auth token to be in the request body; not in the
@ -314,7 +315,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct CompactPollRequestBodyWrapper {
requests: Vec<models::CompactPollRequestBody>
requests: Vec<models::CompactPollRequestBody>,
}
let wrapper: CompactPollRequestBodyWrapper = match serde_json::from_str(&rpc_call.body) {
Ok(bodies) => bodies,
@ -344,7 +345,7 @@ async fn handle_post_request(
if components.len() == 3 && components[2] == "image" {
#[derive(Debug, Deserialize)]
struct JSON {
file: String
file: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -405,7 +406,7 @@ async fn handle_post_request(
// filename
#[derive(Debug, Deserialize)]
struct JSON {
file: String
file: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -418,7 +419,7 @@ async fn handle_post_request(
// FIXME TODO: add an input field so that the uploader can pass the filename
let filename: Option<&str> = None;
return handlers::store_file(room, user, &json.file, filename);
return handlers::store_file(&room, &user, &json.file, filename);
}
// FIXME: deprecate these next two separate endpoints and replace with a single
@ -428,7 +429,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct JSON {
public_key: String
public_key: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -443,7 +444,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct JSON {
public_key: String
public_key: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -478,7 +479,7 @@ async fn handle_post_request(
room,
user,
&body.session_id,
body.admin.unwrap_or(false)
body.admin.unwrap_or(false),
);
}
"delete_messages" => {
@ -488,7 +489,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct JSON {
ids: Vec<i64>
ids: Vec<i64>,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -510,9 +511,8 @@ async fn handle_delete_request(
room: Room,
rpc_call: RpcCall,
path: &str,
user: Option<User>
) -> Result<Response, Rejection>
{
user: Option<User>,
) -> Result<Response, Rejection> {
// Check that the auth token is present
let user = user.ok_or(Error::NoAuthToken)?;
// DELETE /messages/:server_id
@ -594,7 +594,7 @@ fn get_room(rpc_call: &RpcCall) -> Result<Option<Room>, Error> {
let room_token = match rpc_call.headers.get("Room") {
Some(s) => s,
None => return Ok(None)
None => return Ok(None),
};
return Ok(Some(storage::get_room_from_token(&*storage::get_conn()?, room_token)?));
}
@ -605,6 +605,6 @@ fn reject_if_file_server_mode(path: &str) -> Result<(), Rejection> {
warn!("Ignoring RPC call with invalid or unused endpoint: {}.", path);
return Err(Error::InvalidRpcCall.into());
}
Mode::OpenGroupServer => return Ok(())
Mode::OpenGroupServer => return Ok(()),
}
}

View file

@ -1,16 +1,17 @@
use std::fs;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::time::{Duration, SystemTime};
use log::{error, info, warn};
use r2d2::PooledConnection;
use r2d2_sqlite::SqliteConnectionManager;
use regex::Regex;
use rusqlite::{config::DbConfig, params};
use super::errors::Error;
use super::models::Room;
use super::migration;
use super::models::Room;
pub type DatabaseConnection = r2d2::PooledConnection<SqliteConnectionManager>;
pub type DatabaseConnectionPool = r2d2::Pool<SqliteConnectionManager>;
@ -32,7 +33,7 @@ impl RoomId {
Ok(())
} else {
Err(Error::ValidationFailed)
}
};
}
pub fn new(room_id: &str) -> Result<RoomId, Error> {
@ -51,12 +52,11 @@ pub const ROOM_ACTIVE_PRUNE_THRESHOLD: Duration = Duration::from_secs(60 * 86400
// How long we keep message edit/deletion history.
pub const MESSAGE_HISTORY_PRUNE_THRESHOLD: Duration = Duration::from_secs(30 * 86400);
// Migration support: when migrating to 0.2.x old room ids cannot be preserved, so we map the old
// id range [1, max] to the new range [offset+1, offset+max].
pub struct RoomMigrationMap {
pub max: i64,
pub offset: i64
pub offset: i64,
}
lazy_static::lazy_static! {
@ -127,9 +127,12 @@ pub fn get_conn() -> Result<DatabaseConnection, Error> {
}
}
fn get_room_hacks(conn: &rusqlite::Connection) -> Result<HashMap<i64, RoomMigrationMap>, rusqlite::Error> {
fn get_room_hacks(
conn: &rusqlite::Connection,
) -> Result<HashMap<i64, RoomMigrationMap>, rusqlite::Error> {
let mut hacks = HashMap::new();
let mut st = conn.prepare("SELECT room, old_message_id_max, message_id_offset FROM room_import_hacks")?;
let mut st =
conn.prepare("SELECT room, old_message_id_max, message_id_offset FROM room_import_hacks")?;
let mut query = st.query([])?;
while let Some(row) = query.next()? {
hacks.insert(row.get(0)?, RoomMigrationMap { max: row.get(1)?, offset: row.get(2)? });
@ -138,7 +141,7 @@ fn get_room_hacks(conn: &rusqlite::Connection) -> Result<HashMap<i64, RoomMigrat
}
pub fn get_transaction<'a>(
conn: &'a mut DatabaseConnection
conn: &'a mut DatabaseConnection,
) -> Result<DatabaseTransaction<'a>, Error> {
conn.transaction().map_err(db_error)
}
@ -150,16 +153,19 @@ pub fn db_error(e: rusqlite::Error) -> Error {
/// Initialize the database, creating and migrating its structure if necessary.
pub fn setup_database() {
let mut conn = get_conn().unwrap();
setup_database_with_conn(&mut conn);
}
pub fn setup_database_with_conn(conn: &mut PooledConnection<SqliteConnectionManager>) {
if rusqlite::version_number() < 3035000 {
panic!("SQLite 3.35.0+ is required!");
}
let mut conn = get_conn().unwrap();
let have_messages = match conn.query_row(
"SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'messages')",
params![],
|row| row.get::<_, bool>(0)
|row| row.get::<_, bool>(0),
) {
Ok(exists) => exists,
Err(e) => {
@ -172,7 +178,8 @@ pub fn setup_database() {
conn.execute_batch(include_str!("schema.sql")).expect("Couldn't create database schema.");
}
let n_rooms = match conn.query_row("SELECT COUNT(*) FROM rooms", params![], |row| row.get::<_, i64>(0)) {
let n_rooms =
match conn.query_row("SELECT COUNT(*) FROM rooms", params![], |row| row.get::<_, i64>(0)) {
Ok(r) => r,
Err(e) => {
panic!("Error querying database for # of rooms: {}", e);
@ -183,7 +190,7 @@ pub fn setup_database() {
// If we have no rooms then check to see if there is an old (pre-v0.2) set of databases to
// import from.
warn!("No rooms found, but database.db exists; attempting migration");
if let Err(e) = migration::migrate_0_2_0(&mut conn) {
if let Err(e) = migration::migrate_0_2_0(conn) {
panic!("\n\ndatabase.db exists but migration failed:\n\n {}.\n\n\
Please report this bug!\n\n\
If no migration from 0.1.x is needed then rename or delete database.db to start up with a fresh (new) database.\n\n", e);
@ -193,7 +200,6 @@ pub fn setup_database() {
// Future migrations here
}
// Performs periodic DB maintenance: file pruning, delayed permission applying,
// etc.
pub async fn db_maintenance_job() {
@ -216,7 +222,7 @@ pub async fn db_maintenance_job() {
/// Removes all files with expiries <= the given time (which should generally by
/// `SystemTime::now()`, except in the test suite).
fn prune_files(conn: &mut DatabaseConnection, now: &SystemTime) {
pub fn prune_files(conn: &mut DatabaseConnection, now: &SystemTime) {
let mut st = match conn.prepare_cached("DELETE FROM files WHERE expiry <= ? RETURNING path") {
Ok(st) => st,
Err(e) => {
@ -315,7 +321,7 @@ fn apply_permission_updates(conn: &mut DatabaseConnection, now: &SystemTime) {
ON CONFLICT DO UPDATE SET
read = COALESCE(excluded.read, read),
write = COALESCE(excluded.write, write),
upload = COALESCE(excluded.upload, upload)"
upload = COALESCE(excluded.upload, upload)",
) {
Ok(st) => st,
Err(e) => {
@ -355,11 +361,7 @@ fn apply_permission_updates(conn: &mut DatabaseConnection, now: &SystemTime) {
// Utilities
pub fn get_room_from_token(
conn: &rusqlite::Connection,
token: &str
) -> Result<Room, Error>
{
pub fn get_room_from_token(conn: &rusqlite::Connection, token: &str) -> Result<Room, Error> {
match conn
.prepare_cached("SELECT * FROM rooms WHERE token = ?")
.map_err(db_error)?
@ -367,6 +369,6 @@ pub fn get_room_from_token(
{
Ok(room) => return Ok(room),
Err(rusqlite::Error::QueryReturnedNoRows) => return Err(Error::NoSuchRoom.into()),
Err(_) => return Err(Error::DatabaseFailedInternally.into())
Err(_) => return Err(Error::DatabaseFailedInternally.into()),
}
}

View file

@ -1,18 +1,25 @@
use std::collections::HashMap;
//use std::collections::HashMap;
use r2d2::PooledConnection;
use r2d2_sqlite::SqliteConnectionManager;
use std::fs;
use std::time::{Duration, SystemTime};
use rand::{thread_rng, Rng};
//use rand::{thread_rng, Rng};
use rusqlite::params;
use rusqlite::OpenFlags;
use warp::http::StatusCode;
use warp::{hyper, Reply};
use crate::storage::DatabaseConnectionPool;
use super::crypto;
use super::handlers;
use super::handlers::CreateRoom;
use super::models::User;
use super::storage;
use crate::handlers::GenericStringResponse;
async fn set_up_test_room() -> DatabaseConnectionPool {
async fn set_up_test_room() -> (PooledConnection<SqliteConnectionManager>, DatabaseConnectionPool) {
let manager = r2d2_sqlite::SqliteConnectionManager::file("file::memory:?cache=shared");
let mut flags = OpenFlags::default();
flags.set(OpenFlags::SQLITE_OPEN_URI, true);
@ -21,94 +28,71 @@ async fn set_up_test_room() -> DatabaseConnectionPool {
let pool = r2d2::Pool::<r2d2_sqlite::SqliteConnectionManager>::new(manager).unwrap();
let conn = pool.get().unwrap();
let mut conn = pool.get().unwrap();
storage::create_room_tables_if_needed(&conn);
storage::setup_database_with_conn(&mut conn);
let success = handlers::create_room_with_conn(
&conn,
&CreateRoom { token: "test_room".to_string(), name: "Test".to_string() },
);
assert!(success.is_ok());
pool
return (conn, pool);
}
fn get_auth_token(pool: &DatabaseConnectionPool) -> (String, String) {
fn get_user(conn: &rusqlite::Connection) -> User {
// Generate a fake user key pair
let (user_private_key, user_public_key) = crypto::generate_x25519_key_pair();
let (_, user_public_key) = crypto::generate_x25519_key_pair();
let hex_user_public_key = format!("05{}", hex::encode(user_public_key.to_bytes()));
// Get a challenge
let mut query_params: HashMap<String, String> = HashMap::new();
query_params.insert("public_key".to_string(), hex_user_public_key.clone());
let challenge = handlers::get_auth_token_challenge(query_params, &pool).unwrap();
// Generate a symmetric key
let ephemeral_public_key = base64::decode(challenge.ephemeral_public_key).unwrap();
let symmetric_key =
crypto::get_x25519_symmetric_key(&ephemeral_public_key, &user_private_key).unwrap();
// Decrypt the challenge
let ciphertext = base64::decode(challenge.ciphertext).unwrap();
let plaintext = crypto::decrypt_aes_gcm(&ciphertext, &symmetric_key).unwrap();
let auth_token = hex::encode(plaintext);
// Try to claim the token
let response = handlers::claim_auth_token(&hex_user_public_key, &auth_token, &pool).unwrap();
assert_eq!(response.status(), StatusCode::OK);
// return
return (auth_token, hex_user_public_key);
}
#[tokio::test]
async fn test_authorization() {
// Ensure the test room is set up and get a database connection pool
let pool = set_up_test_room().await;
// Get an auth token
// This tests claiming a token internally
let (_, hex_user_public_key) = get_auth_token(&pool);
// Try to claim an incorrect token
let mut incorrect_token = [0u8; 48];
thread_rng().fill(&mut incorrect_token[..]);
let hex_incorrect_token = hex::encode(incorrect_token);
match handlers::claim_auth_token(&hex_user_public_key, &hex_incorrect_token, &pool) {
Ok(_) => assert!(false),
Err(_) => ()
}
let result = handlers::insert_or_update_user(conn, &hex_user_public_key);
assert!(result.is_ok());
return result.unwrap();
}
#[tokio::test]
async fn test_file_handling() {
// Ensure the test room is set up and get a database connection pool
let pool = set_up_test_room().await;
let (mut conn, pool) = set_up_test_room().await;
let test_room_id = storage::RoomId::new("test_room").unwrap();
let room = storage::get_room_from_token(&conn, "test_room").unwrap();
// Get an auth token
let (auth_token, _) = get_auth_token(&pool);
let user = get_user(&conn);
// Store the test file
handlers::store_file(
Some(test_room_id.get_id().to_string()),
TEST_FILE,
Some(auth_token.clone()),
&pool
)
.await
let filename: Option<&str> = None;
let auth = handlers::AuthorizationRequired { upload: true, write: true, ..Default::default() };
let id =
match handlers::store_file_impl(&mut conn, &room, &user, auth, TEST_FILE, filename, true)
.ok()
{
Some(mut upload) => {
let result = upload.commit();
assert!(result.is_ok());
Some(upload.id)
}
_ => None,
}
.unwrap();
// Check that there's a file record
let conn = pool.get().unwrap();
let raw_query = "SELECT id FROM files";
let id_as_string: String =
conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
let id = id_as_string.parse::<u64>().unwrap();
// Retrieve the file and check the content
let base64_encoded_file = handlers::get_file(
Some(test_room_id.get_id().to_string()),
id,
Some(auth_token.clone()),
&pool,
)
.await
.unwrap()
.result;
assert_eq!(base64_encoded_file, TEST_FILE);
let room_id = room.id;
let response = handlers::get_file_conn(&mut conn, &room, id, user).unwrap();
let response_bytes = hyper::body::to_bytes(response).await.ok();
// The expected json response
let json = GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: TEST_FILE.to_string(),
};
let expected_result = warp::reply::json(&json).into_response();
let expected_result_bytes = hyper::body::to_bytes(expected_result).await.ok();
assert_eq!(response_bytes, expected_result_bytes);
// Prune the file and check that it's gone
// Will evaluate to now + 60
storage::prune_files_for_room(&pool, &test_room_id, -60).await;
let sixty_seconds_ago = SystemTime::now() - Duration::new(60, 0);
storage::prune_files(&mut conn, &sixty_seconds_ago);
// It should be gone now
fs::read(format!("files/{}_files/{}", test_room_id.get_id(), id)).unwrap_err();
fs::read(format!("files/{}_files/{}", room_id, id)).unwrap_err();
// Check that the file record is also gone
let conn = pool.get().unwrap();
let raw_query = "SELECT id FROM files";