test suite and cargp fmt

This commit is contained in:
Sean Darcy 2021-10-08 16:56:27 +11:00
parent ad4cc06ad8
commit 953b752100
14 changed files with 507 additions and 396 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@
*.pem
/files
/rooms
/uploads

View File

@ -62,7 +62,7 @@ lazy_static::lazy_static! {
/// Takes hex string representation of an ed25519 pubkey, returns the ed25519
/// pubkey, derived x25519 pubkey, and the Session id in hex.
pub fn get_pubkeys(
edpk_hex: &str
edpk_hex: &str,
) -> Result<(ed25519_dalek::PublicKey, x25519_dalek::PublicKey, String), warp::reject::Rejection> {
if edpk_hex.len() != 64 {
return Err(warp::reject::custom(Error::DecryptionFailed));
@ -92,9 +92,8 @@ pub fn get_pubkeys(
pub fn get_x25519_symmetric_key(
public_key: &[u8],
private_key: &x25519_dalek::StaticSecret
) -> Result<Vec<u8>, warp::reject::Rejection>
{
private_key: &x25519_dalek::StaticSecret,
) -> Result<Vec<u8>, warp::reject::Rejection> {
if public_key.len() != 32 {
error!(
"Couldn't create symmetric key using public key of invalid length: {}.",
@ -112,9 +111,8 @@ pub fn get_x25519_symmetric_key(
pub fn encrypt_aes_gcm(
plaintext: &[u8],
symmetric_key: &[u8]
) -> Result<Vec<u8>, warp::reject::Rejection>
{
symmetric_key: &[u8],
) -> Result<Vec<u8>, warp::reject::Rejection> {
let mut iv = [0u8; IV_SIZE];
thread_rng().fill(&mut iv[..]);
let cipher = Aes256Gcm::new(&GenericArray::from_slice(symmetric_key));
@ -133,9 +131,8 @@ pub fn encrypt_aes_gcm(
pub fn decrypt_aes_gcm(
iv_and_ciphertext: &[u8],
symmetric_key: &[u8]
) -> Result<Vec<u8>, warp::reject::Rejection>
{
symmetric_key: &[u8],
) -> Result<Vec<u8>, warp::reject::Rejection> {
if iv_and_ciphertext.len() < IV_SIZE {
warn!("Ignoring ciphertext of invalid size: {}.", iv_and_ciphertext.len());
return Err(warp::reject::custom(Error::DecryptionFailed));
@ -162,9 +159,8 @@ pub fn generate_x25519_key_pair() -> (x25519_dalek::StaticSecret, x25519_dalek::
pub fn verify_signature(
edpk: &ed25519_dalek::PublicKey,
sig: &ed25519_dalek::Signature,
parts: &[&[u8]]
) -> Result<(), Error>
{
parts: &[&[u8]],
) -> Result<(), Error> {
let mut verify_buf: Vec<u8> = Vec::new();
let verify: &[u8];
if parts.len() == 1 {

View File

@ -16,7 +16,7 @@ pub enum Error {
/// The requesting user provided a valid auth token, but they don't have a
/// high enough permission level.
Unauthorized,
ValidationFailed
ValidationFailed,
}
impl Reject for Error {}

View File

@ -8,6 +8,8 @@ use base64;
use ed25519_dalek::Signer;
use log::{debug, error, info, warn};
use parking_lot::RwLock;
use r2d2::PooledConnection;
use r2d2_sqlite::SqliteConnectionManager;
use regex::Regex;
use rusqlite::{params, params_from_iter};
use serde::{Deserialize, Serialize};
@ -29,17 +31,17 @@ use super::storage::{self, db_error};
// for moderators/admins.
#[derive(Default)]
pub struct AuthorizationRequired {
admin: bool, // Required admin permission (server or room)
moderator: bool, // Requires moderator or admin permission (server or room)
read: bool, // Requires read permission
write: bool, // Requires write permission
upload: bool // Requires upload permission
pub admin: bool, // Required admin permission (server or room)
pub moderator: bool, // Requires moderator or admin permission (server or room)
pub read: bool, // Requires read permission
pub write: bool, // Requires write permission
pub upload: bool, // Requires upload permission
}
#[derive(Debug, Serialize)]
pub struct GenericStringResponse {
pub status_code: u16,
pub result: String
pub result: String,
}
// FIXME: this is used to query the github API periodically to find new releases. Ew.
@ -86,14 +88,12 @@ pub const TOKEN_ID_SIZE: usize = 33;
pub const TOKEN_SIG_SIZE: usize = 64;
pub const TOKEN_SIZE: usize = TOKEN_ID_SIZE + TOKEN_SIG_SIZE;
// Rooms
//
#[derive(Deserialize)]
pub struct CreateRoom {
pub token: String,
pub name: String
pub name: String,
}
// Not publicly exposed.
@ -103,10 +103,7 @@ pub async fn create_room(room: CreateRoom) -> Result<Response, Rejection> {
// Get a connection
let conn = storage::get_conn()?;
// Insert the room
let stmt = "INSERT INTO rooms (token, name) VALUES (?, ?) \
ON CONFLICT DO UPDATE SET token = excluded.token, name = excluded.name";
if let Err(e) = conn.execute(&stmt, params![room.token, room.name]) {
if let Err(e) = create_room_with_conn(&conn, &room) {
error!("Couldn't create room: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
@ -117,6 +114,16 @@ pub async fn create_room(room: CreateRoom) -> Result<Response, Rejection> {
Ok(warp::reply::json(&json).into_response())
}
pub fn create_room_with_conn(
conn: &PooledConnection<SqliteConnectionManager>,
room: &CreateRoom,
) -> Result<usize, rusqlite::Error> {
// Insert the room
let stmt = "INSERT INTO rooms (token, name) VALUES (?, ?) \
ON CONFLICT DO UPDATE SET token = excluded.token, name = excluded.name";
return conn.execute(&stmt, params![room.token, room.name]);
}
// Not publicly exposed.
pub async fn delete_room(token: String) -> Result<Response, Rejection> {
// Get a connection
@ -187,10 +194,13 @@ pub fn get_all_rooms_v01x() -> Result<Response, Rejection> {
#[derive(Debug, Serialize)]
struct OldRoom {
id: String,
name: String
name: String,
}
let rooms = get_all_rooms_impl()?.into_iter().map(|r| OldRoom{ id: r.token, name: r.name }).collect::<Vec<OldRoom>>();
let rooms = get_all_rooms_impl()?
.into_iter()
.map(|r| OldRoom { id: r.token, name: r.name })
.collect::<Vec<OldRoom>>();
let response = json!({ "status_code": StatusCode::OK.as_u16(), "rooms": rooms });
Ok(warp::reply::json(&response).into_response())
@ -201,12 +211,12 @@ pub fn get_all_rooms_v01x() -> Result<Response, Rejection> {
/// RAII class holding an in-progress upload transaction and path details. If this is dropped
/// without `commit()` being called we remove the file from disk and abort the transaction
/// inserting the upload into the database.
struct FileUpload<'a> {
pub struct FileUpload<'a> {
pub id: i64, // The value of `id` in the `files` table for this new file
pub room: &'a Room, // The room the file is uploaded to
pub path: String, // The relative path containing the in-progress file upload
tx: Option<rusqlite::Transaction<'a>>,
committed: bool
committed: bool,
}
impl FileUpload<'_> {
pub fn new<'a>(tx: rusqlite::Transaction<'a>, room: &'a Room) -> FileUpload<'a> {
@ -232,22 +242,20 @@ impl Drop for FileUpload<'_> {
}
}
/// Does the actual work involved in storing a file, inserting into the database, etc.
///
/// Returns a FileUpload on success. The caller may optionally use this to perform additional
/// actions, but *must* call `.commit()` on success -- if dropped the FileUpload will clean up the
/// temporary file and drop the transaction inserting the records.
fn store_file_impl<'a>(
pub fn store_file_impl<'a>(
conn: &'a mut storage::DatabaseConnection,
room: &'a Room,
user: &User,
auth: AuthorizationRequired,
data_b64: &str,
filename: Option<&str>,
expires: bool
) -> Result<FileUpload<'a>, Rejection>
{
expires: bool,
) -> Result<FileUpload<'a>, Rejection> {
// Determine the file size from the base64 data without decoding it (we'll do that later
// directly to the destination file).
let mut bytes: usize = data_b64.len() / 4 * 3;
@ -293,7 +301,7 @@ fn store_file_impl<'a>(
(SystemTime::now() + UPLOAD_DEFAULT_EXPIRY)
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs_f64()
.as_secs_f64(),
)
} else {
None
@ -305,7 +313,7 @@ fn store_file_impl<'a>(
.prepare_cached(
"INSERT INTO files (room, uploader, size, expiry, filename, path) \
VALUES (?, ?, ?, ?, ?, 'tmp') \
RETURNING id"
RETURNING id",
)
.map_err(db_error)?
.query_row(params![room.id, user.id, bytes, db_filename, expiry], |row| row.get(0))
@ -323,7 +331,7 @@ fn store_file_impl<'a>(
if fs_filename.len() > UPLOAD_FILENAME_MAX {
fs_filename.replace_range(
UPLOAD_FILENAME_KEEP_PREFIX..fs_filename.len() - UPLOAD_FILENAME_KEEP_SUFFIX,
"..."
"...",
);
}
fs_filename = format!("{}/{}_{}", files_dir, upload.id, fs_filename);
@ -373,12 +381,12 @@ fn store_file_impl<'a>(
}
pub fn store_file(
room: Room,
user: User,
room: &Room,
user: &User,
data_b64: &str,
filename: Option<&str>
) -> Result<Response, Rejection>
{
filename: Option<&str>,
) -> Result<Response, Rejection> {
let mut conn = storage::get_conn()?;
if !matches!(rpc::MODE, rpc::Mode::OpenGroupServer) {
panic!("FIXME file mode FIXME FIXME TODO!");
// FIXME TODO
@ -386,18 +394,16 @@ pub fn store_file(
let auth = AuthorizationRequired { upload: true, write: true, ..Default::default() };
let mut conn = storage::get_conn()?;
let mut upload = match store_file_impl(&mut conn, &room, &user, auth, data_b64, filename, true)
{
Ok(id) => id,
Err(e) => return Err(e)
Err(e) => return Err(e),
};
if let Err(e) = upload.commit() {
error!("File upload failed: {}", e);
return Err(Error::DatabaseFailedInternally.into());
}
let response = json!({ "status_code": StatusCode::OK.as_u16(), "result": upload.id });
Ok(warp::reply::json(&response).into_response())
}
@ -437,18 +443,28 @@ fn file_response(path_row: rusqlite::Result<String>) -> Result<Response, Rejecti
// Return
let json = GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: base64_encoded_bytes
result: base64_encoded_bytes,
};
Ok(warp::reply::json(&json).into_response())
}
pub fn get_file(room: Room, id: i64, user: User) -> Result<Response, Rejection> {
let conn = storage::get_conn()?;
return get_file_conn(&conn, &room, id, user);
}
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
pub fn get_file_conn(
conn: &PooledConnection<SqliteConnectionManager>,
room: &Room,
id: i64,
user: User,
) -> Result<Response, Rejection> {
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let mut row = conn
.prepare_cached("SELECT path FROM files WHERE room = ? AND id = ?")
@ -474,7 +490,7 @@ pub async fn get_room_image(room: Room) -> Result<Response, Rejection> {
let row = conn
.prepare_cached(
"SELECT path FROM rooms JOIN files ON rooms.image = files.id WHERE rooms.id = ?"
"SELECT path FROM rooms JOIN files ON rooms.image = files.id WHERE rooms.id = ?",
)
.map_err(db_error)?
.query_row(params![room.id], |row| row.get(0));
@ -485,9 +501,8 @@ pub async fn set_room_image(
room: Room,
user: User,
data_b64: &str,
filename: Option<&str>
) -> Result<Response, Rejection>
{
filename: Option<&str>,
) -> Result<Response, Rejection> {
let auth = AuthorizationRequired { moderator: true, ..Default::default() };
let mut conn = storage::get_conn()?;
@ -534,20 +549,18 @@ pub fn insert_or_update_user(conn: &rusqlite::Connection, session_id: &str) -> R
.prepare_cached(
"INSERT INTO users (session_id) VALUES (?) \
ON CONFLICT DO UPDATE SET last_active = ((julianday('now') - 2440587.5)*86400.0) \
RETURNING *"
RETURNING *",
)
.map_err(db_error)?
.query_row(params![&session_id], User::from_row)
.map_err(db_error)?)
}
// Validates a (backwards compat) token string.
pub fn get_user_from_token(
conn: &rusqlite::Connection,
auth_token_str: &str
) -> Result<User, Error>
{
auth_token_str: &str,
) -> Result<User, Error> {
let auth_token =
decode_hex_or_b64(auth_token_str, TOKEN_SIZE).map_err(|_| Error::NoAuthToken)?;
if auth_token[0] != 0x05 {
@ -566,7 +579,6 @@ pub fn get_user_from_token(
insert_or_update_user(conn, &session_id)
}
pub fn get_auth_token_challenge(public_key: &str) -> Result<models::Challenge, Rejection> {
// Doesn't return a response directly for testing purposes
@ -597,7 +609,7 @@ pub fn get_auth_token_challenge(public_key: &str) -> Result<models::Challenge, R
// Return
Ok(models::Challenge {
ciphertext: base64::encode(ciphertext),
ephemeral_public_key: base64::encode(ephemeral_public_key.to_bytes())
ephemeral_public_key: base64::encode(ephemeral_public_key.to_bytes()),
})
}
@ -611,15 +623,16 @@ pub fn insert_message(
room: Room,
user: User,
data: &[u8],
signature: &[u8]
) -> Result<Response, Rejection>
{
signature: &[u8],
) -> Result<Response, Rejection> {
let mut conn = storage::get_conn()?;
let tx = storage::get_transaction(&mut conn)?;
require_authorization(&tx, &user, &room, AuthorizationRequired {
write: true,
..Default::default()
})?;
require_authorization(
&tx,
&user,
&room,
AuthorizationRequired { write: true, ..Default::default() },
)?;
// Check if the requesting user needs to be rate limited
@ -639,7 +652,7 @@ pub fn insert_message(
let size = data.len();
let trimmed = match data.iter().rposition(|&c| c != 0u8) {
Some(last) => &data[0..=last],
None => &data
None => &data,
};
// Insert the message
@ -647,7 +660,7 @@ pub fn insert_message(
.prepare_cached(
"INSERT INTO messages (room, user, data, data_size, signature) \
VALUES (?, ?, ?, ?, ?) \
RETURNING *"
RETURNING *",
)
.map_err(db_error)?
.query_row(params![room.id, user.id, trimmed, size, signature], OldMessage::from_row)
@ -701,15 +714,16 @@ fn get_messages_params(query_params: &HashMap<String, String>) -> (Option<i64>,
pub fn get_messages(
query_params: HashMap<String, String>,
user: User,
room: Room
) -> Result<Vec<OldMessage>, Rejection>
{
room: Room,
) -> Result<Vec<OldMessage>, Rejection> {
let conn = storage::get_conn()?;
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let (from_server_id, limit) = get_messages_params(&query_params);
@ -757,9 +771,8 @@ pub fn delete_message(
conn: &rusqlite::Connection,
id: i64,
user: &User,
room: &Room
) -> Result<Response, Rejection>
{
room: &Room,
) -> Result<Response, Rejection> {
let mut auth_req = AuthorizationRequired { read: true, ..Default::default() };
// Check to see if the message to be deleted is owned by someone else: if it is, we require
@ -778,14 +791,14 @@ pub fn delete_message(
let response = json!({"status_code": StatusCode::NOT_FOUND.as_u16()});
return Ok(warp::reply::json(&response).into_response());
}
Err(_) => return Err(Error::DatabaseFailedInternally.into())
Err(_) => return Err(Error::DatabaseFailedInternally.into()),
};
require_authorization(conn, user, room, auth_req)?;
let mut del_st = conn
.prepare_cached(
"UPDATE messages SET data = NULL, data_size = NULL, signature = NULL WHERE id = ?"
"UPDATE messages SET data = NULL, data_size = NULL, signature = NULL WHERE id = ?",
)
.map_err(db_error)?;
@ -805,17 +818,18 @@ pub fn delete_message(
pub fn get_deleted_messages(
query_params: HashMap<String, String>,
user: User,
room: Room
) -> Result<Vec<models::DeletedMessage>, Rejection>
{
room: Room,
) -> Result<Vec<models::DeletedMessage>, Rejection> {
let conn = storage::get_conn()?;
let (from_server_id, limit) = get_messages_params(&query_params);
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
// Query the database
let mut st = conn.prepare_cached(if from_server_id.is_some() {
@ -843,13 +857,14 @@ pub fn add_moderator_public(
room: Room,
user: User,
session_id: &str,
admin: bool
) -> Result<Response, Rejection>
{
require_authorization(&*storage::get_conn()?, &user, &room, AuthorizationRequired {
admin: true,
..Default::default()
})?;
admin: bool,
) -> Result<Response, Rejection> {
require_authorization(
&*storage::get_conn()?,
&user,
&room,
AuthorizationRequired { admin: true, ..Default::default() },
)?;
add_moderator_impl(session_id, admin, room)
}
@ -858,21 +873,20 @@ pub fn add_moderator_public(
// Not publicly exposed.
pub async fn add_moderator(
body: models::ChangeModeratorRequestBody
body: models::ChangeModeratorRequestBody,
) -> Result<Response, Rejection> {
add_moderator_impl(
&body.session_id,
body.admin.unwrap_or(false),
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?,
)
}
pub fn add_moderator_impl(
session_id: &str,
admin: bool,
room: Room
) -> Result<Response, Rejection>
{
room: Room,
) -> Result<Response, Rejection> {
require_session_id(session_id)?;
let mut conn = storage::get_conn()?;
@ -913,23 +927,24 @@ pub fn add_moderator_impl(
pub fn delete_moderator_public(
session_id: &str,
user: User,
room: Room
) -> Result<Response, Rejection>
{
require_authorization(&*storage::get_conn()?, &user, &room, AuthorizationRequired {
admin: true,
..Default::default()
})?;
room: Room,
) -> Result<Response, Rejection> {
require_authorization(
&*storage::get_conn()?,
&user,
&room,
AuthorizationRequired { admin: true, ..Default::default() },
)?;
delete_moderator_impl(session_id, room)
}
// Not publicly exposed.
pub async fn delete_moderator(
body: models::ChangeModeratorRequestBody
body: models::ChangeModeratorRequestBody,
) -> Result<Response, Rejection> {
delete_moderator_impl(
&body.session_id,
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?
storage::get_room_from_token(&*storage::get_conn()?, &body.room_token)?,
)
}
@ -951,7 +966,7 @@ pub fn delete_moderator_impl(session_id: &str, room: Room) -> Result<Response, R
Ok(count) if count > 0 => {
info!("Removed moderator {} from room {}", session_id, room.token)
}
Ok(_count) => info!("{} is not a moderator of room {}", session_id, room.token)
Ok(_count) => info!("{} is not a moderator of room {}", session_id, room.token),
}
let json = models::StatusCode { status_code: StatusCode::OK.as_u16() };
@ -962,17 +977,18 @@ pub fn delete_moderator_impl(session_id: &str, room: Room) -> Result<Response, R
pub fn get_moderators(
conn: &rusqlite::Connection,
user: &User,
room: &Room
) -> Result<Vec<String>, Rejection>
{
require_authorization(conn, user, room, AuthorizationRequired {
read: true,
..Default::default()
})?;
room: &Room,
) -> Result<Vec<String>, Rejection> {
require_authorization(
conn,
user,
room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let mut st = conn
.prepare_cached(
"SELECT session_id FROM user_permissions WHERE room = ? AND moderator AND visible_mod"
"SELECT session_id FROM user_permissions WHERE room = ? AND moderator AND visible_mod",
)
.map_err(db_error)?;
@ -999,9 +1015,8 @@ pub async fn ban(
session_id: &str,
delete_all: bool,
user: &User,
room: &Room
) -> Result<Response, Rejection>
{
room: &Room,
) -> Result<Response, Rejection> {
if !is_session_id(&session_id) {
warn!("Ignoring ban request: invalid session_id.");
return Err(Error::ValidationFailed.into());
@ -1022,7 +1037,7 @@ pub async fn ban(
match tx
.prepare_cached(
"SELECT user, moderator, global_moderator FROM user_permissions \
WHERE room = ? AND session_id = ?"
WHERE room = ? AND session_id = ?",
)
.map_err(db_error)?
.query_row(params![room.id, session_id], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))
@ -1042,7 +1057,7 @@ pub async fn ban(
let response = json!({"status_code": StatusCode::NOT_FOUND.as_u16()});
return Ok(warp::reply::json(&response).into_response());
}
Err(_) => return Err(Error::DatabaseFailedInternally.into())
Err(_) => return Err(Error::DatabaseFailedInternally.into()),
};
require_authorization(&tx, user, room, auth)?;
@ -1053,7 +1068,7 @@ pub async fn ban(
INSERT INTO user_permission_overrides (room, user, banned, moderator, admin) \
VALUES (?, ?, TRUE, FALSE, FALSE) \
ON CONFLICT DO UPDATE SET banned = TRUE, moderator = FALSE, admin = FALSE
"
",
)
.map_err(db_error)?
.execute(params![room.id, userid])
@ -1068,7 +1083,7 @@ pub async fn ban(
posts_removed += match tx
.prepare_cached(
"UPDATE messages SET data = NULL, data_size = NULL, signature = NULL \
WHERE room = ? AND user = ?"
WHERE room = ? AND user = ?",
)
.map_err(db_error)?
.execute(params![room.id, userid])
@ -1085,7 +1100,7 @@ pub async fn ban(
// from disk).
files_removed = tx
.prepare_cached(
"UPDATE files SET room = NULL, expiry = ? WHERE room = ? AND uploader = ?"
"UPDATE files SET room = NULL, expiry = ? WHERE room = ? AND uploader = ?",
)
.map_err(db_error)?
.execute(params![unixtime_f64(), room.id, userid])
@ -1114,15 +1129,17 @@ pub fn unban(session_id: &str, user: &User, room: &Room) -> Result<Response, Rej
}
let conn = storage::get_conn()?;
require_authorization(&conn, user, room, AuthorizationRequired {
moderator: true,
..Default::default()
})?;
require_authorization(
&conn,
user,
room,
AuthorizationRequired { moderator: true, ..Default::default() },
)?;
let count = match conn
.prepare_cached(
"UPDATE user_permission_overrides SET banned = FALSE \
WHERE room = ? AND user IN (SELECT id FROM users WHERE session_id = ?)"
WHERE room = ? AND user IN (SELECT id FROM users WHERE session_id = ?)",
)
.map_err(db_error)?
.execute(params![room.id, session_id])
@ -1148,10 +1165,12 @@ pub fn unban(session_id: &str, user: &User, room: &Room) -> Result<Response, Rej
/// Returns the full list of banned public keys.
pub fn get_banned_public_keys(user: &User, room: &Room) -> Result<Response, Rejection> {
let conn = storage::get_conn()?;
require_authorization(&conn, user, room, AuthorizationRequired {
moderator: true,
..Default::default()
})?;
require_authorization(
&conn,
user,
room,
AuthorizationRequired { moderator: true, ..Default::default() },
)?;
let banned_members: Result<Vec<String>, _> = match conn
.prepare_cached("SELECT session_id FROM user_permissions WHERE room = ? AND banned")
@ -1184,14 +1203,15 @@ pub fn get_member_count(user: User, room: Room) -> Result<Response, Rejection> {
pub fn get_member_count_since(
user: User,
room: Room,
ago: Duration
) -> Result<Response, Rejection>
{
ago: Duration,
) -> Result<Response, Rejection> {
let conn = storage::get_conn()?;
require_authorization(&conn, &user, &room, AuthorizationRequired {
read: true,
..Default::default()
})?;
require_authorization(
&conn,
&user,
&room,
AuthorizationRequired { read: true, ..Default::default() },
)?;
let mut st = conn
.prepare_cached("SELECT COUNT(*) FROM room_users WHERE room = ? AND last_active >= ?")
@ -1236,15 +1256,13 @@ pub fn get_room_updates(user: User, room: Room, since_update: i64) {
*/
}
/// Deprecated room polling; unlike the above, this does not handle metadata (except for
/// moderators, which are *always* included even though they rarely change), does not support
/// message edits, and has non-obvious alternate modes of operation.
pub fn compact_poll(
user: Option<User>,
request_bodies: Vec<models::CompactPollRequestBody>
) -> Result<Response, Rejection>
{
request_bodies: Vec<models::CompactPollRequestBody>,
) -> Result<Response, Rejection> {
let mut response_bodies = Vec::<models::CompactPollResponseBody>::new();
let mut conn = storage::get_conn()?;
@ -1275,7 +1293,7 @@ pub fn compact_poll(
.prepare_cached(
"SELECT * FROM message_details \
WHERE room = ? AND data IS NOT NULL \
ORDER BY id DESC LIMIT 256"
ORDER BY id DESC LIMIT 256",
)
.map_err(db_error)?;
@ -1283,7 +1301,7 @@ pub fn compact_poll(
.prepare_cached(
"SELECT id, updated FROM messages \
WHERE room = ? AND data IS NULL \
ORDER BY updated DESC LIMIT 256"
ORDER BY updated DESC LIMIT 256",
)
.map_err(db_error)?;
@ -1292,7 +1310,7 @@ pub fn compact_poll(
.prepare_cached(
"SELECT id, updated FROM messages \
WHERE room = ? AND updated > ? AND data IS NULL \
ORDER BY updated LIMIT 256"
ORDER BY updated LIMIT 256",
)
.map_err(db_error)?;
@ -1300,18 +1318,17 @@ pub fn compact_poll(
.prepare_cached(
"SELECT * FROM message_details \
WHERE room = ? AND id > ? AND data IS NOT NULL \
ORDER BY id LIMIT 256"
ORDER BY id LIMIT 256",
)
.map_err(db_error)?;
for request in request_bodies {
let mut response = models::CompactPollResponseBody {
room_token: request.room_token.clone(),
status_code: StatusCode::OK.as_u16(),
messages: vec![],
deletions: vec![],
moderators: vec![]
moderators: vec![],
};
let room: &Room = match rooms.get(&request.room_token) {
@ -1449,13 +1466,12 @@ pub async fn get_session_version(platform: &str) -> Result<String, Rejection> {
// not publicly exposed.
pub async fn get_stats_for_room(
room_token: String,
query_map: HashMap<String, i64>
) -> Result<Response, Rejection>
{
query_map: HashMap<String, i64>,
) -> Result<Response, Rejection> {
let window = *query_map.get("window").unwrap_or(&3600) as f64;
let upperbound = match query_map.get("start") {
Some(ts) => *ts as f64,
None => unixtime_f64()
None => unixtime_f64(),
};
let lowerbound = upperbound - window;
@ -1466,7 +1482,7 @@ pub async fn get_stats_for_room(
let active = tx
.prepare_cached(
"SELECT COUNT(*) FROM room_users WHERE room = ? AND last_active BETWEEN ? AND ?"
"SELECT COUNT(*) FROM room_users WHERE room = ? AND last_active BETWEEN ? AND ?",
)
.map_err(db_error)?
.query_row(params![room.id, lowerbound, upperbound], |row| Ok(row.get::<_, i64>(0)?))
@ -1511,9 +1527,8 @@ fn require_authorization(
conn: &rusqlite::Connection,
user: &User,
room: &Room,
req: AuthorizationRequired
) -> Result<(), Error>
{
req: AuthorizationRequired,
) -> Result<(), Error> {
require_authorization_impl(conn, &user, &room, req, true)
}
/// Same as above, but does not update the room/user last activity timestamp.
@ -1522,9 +1537,8 @@ fn require_authorization_no_activity(
conn: &rusqlite::Connection,
user: &User,
room: &Room,
req: AuthorizationRequired
) -> Result<(), Error>
{
req: AuthorizationRequired,
) -> Result<(), Error> {
return require_authorization_impl(conn, &user, &room, req, false);
}
@ -1533,9 +1547,8 @@ fn require_authorization_impl(
user: &User,
room: &Room,
need: AuthorizationRequired,
log_active: bool
) -> Result<(), Error>
{
log_active: bool,
) -> Result<(), Error> {
let mut st = conn
.prepare_cached(
"SELECT banned, read, write, upload, moderator, admin FROM user_permissions WHERE room = ? AND user = ?"

View File

@ -1,9 +1,13 @@
use log::LevelFilter;
use log4rs::{append::{console::ConsoleAppender,
rolling_file::{policy::compound, RollingFileAppender}},
config::{Appender, Logger, Root},
encode::pattern::PatternEncoder,
filter::threshold::ThresholdFilter};
use log4rs::{
append::{
console::ConsoleAppender,
rolling_file::{policy::compound, RollingFileAppender},
},
config::{Appender, Logger, Root},
encode::pattern::PatternEncoder,
filter::threshold::ThresholdFilter,
};
use std::str::FromStr;
pub fn init(log_file: Option<String>, log_level: Option<String>) {

View File

@ -3,8 +3,10 @@
use parking_lot::RwLock;
use std::fs;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::{collections::HashMap,
sync::atomic::{AtomicBool, AtomicU16, Ordering}};
use std::{
collections::HashMap,
sync::atomic::{AtomicBool, AtomicU16, Ordering},
};
use futures::join;
use log::info;
@ -15,13 +17,13 @@ mod crypto;
mod errors;
mod handlers;
mod logging;
mod migration;
mod models;
mod onion_requests;
mod options;
mod routes;
mod rpc;
mod storage;
mod migration;
#[cfg(test)]
mod tests;

View File

@ -1,16 +1,15 @@
use std::fs;
use std::path::Path;
use std::os::unix::fs::MetadataExt;
use std::path::Path;
use std::time::SystemTime;
use log::{info, warn};
use rusqlite::{Connection, OpenFlags, params, types::Null};
use super::handlers;
use super::storage;
use log::{info, warn};
use rusqlite::{params, types::Null, Connection, OpenFlags};
// Performs database migration from v0.1.8 to v0.2.0
pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// Old database database.db is a single table database containing just the list of rooms:
/*
CREATE TABLE IF NOT EXISTS main (
@ -24,7 +23,10 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// that starting again will try to import again).
let tx = conn.transaction()?;
struct Rm { token: String, name: Option<String> }
struct Rm {
token: String,
name: Option<String>,
}
let rooms = Connection::open_with_flags("database.db", OpenFlags::SQLITE_OPEN_READ_ONLY)?
.prepare("SELECT id, name FROM main")?
@ -34,34 +36,44 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
warn!("{} rooms to import", rooms.len());
{
tx.execute("\
tx.execute(
"\
CREATE TABLE room_import_hacks (
room INTEGER PRIMARY KEY NOT NULL REFERENCES rooms(id),
old_message_id_max INTEGER NOT NULL,
message_id_offset INTEGER NOT NULL
)", [])?;
)",
[],
)?;
let mut used_room_hacks: bool = false;
let mut ins_room_hack = tx.prepare(
"INSERT INTO room_import_hacks (room, old_message_id_max, message_id_offset) VALUES (?, ?, ?)")?;
tx.execute("\
tx.execute(
"\
CREATE TABLE file_id_hacks (
room INTEGER NOT NULL REFERENCES rooms(id),
old_file_id INTEGER NOT NULL,
file INTEGER NOT NULL REFERENCES files(id) ON DELETE CASCADE,
PRIMARY KEY(room, old_file_id)
)", [])?;
)",
[],
)?;
let mut used_file_hacks: bool = false;
let mut ins_file_hack = tx.prepare(
"INSERT INTO file_id_hacks (room, old_file_id, file) VALUES (?, ?, ?)")?;
let mut ins_file_hack =
tx.prepare("INSERT INTO file_id_hacks (room, old_file_id, file) VALUES (?, ?, ?)")?;
let mut ins_room = tx.prepare("INSERT INTO rooms (token, name) VALUES (?, ?) RETURNING id")?;
let mut ins_room =
tx.prepare("INSERT INTO rooms (token, name) VALUES (?, ?) RETURNING id")?;
let mut ins_user = tx.prepare("INSERT INTO users (session_id, last_active) VALUES (?, 0.0) ON CONFLICT DO NOTHING")?;
let mut ins_user = tx.prepare(
"INSERT INTO users (session_id, last_active) VALUES (?, 0.0) ON CONFLICT DO NOTHING",
)?;
let mut ins_msg = tx.prepare(
"INSERT INTO messages (id, room, user, posted, data, data_size, signature) \
VALUES (?, ?, (SELECT id FROM users WHERE session_id = ?), ?, ?, ?, ?)")?;
VALUES (?, ?, (SELECT id FROM users WHERE session_id = ?), ?, ?, ?, ?)",
)?;
let mut upd_msg_updated = tx.prepare("UPDATE messages SET updated = ? WHERE id = ?")?;
let mut upd_room_updates = tx.prepare("UPDATE rooms SET updates = ? WHERE id = ?")?;
@ -83,7 +95,8 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
"INSERT INTO room_users (room, user, last_active) VALUES (?, (SELECT id FROM users WHERE session_id = ?), ?) \
ON CONFLICT DO UPDATE SET last_active = excluded.last_active WHERE excluded.last_active > last_active")?;
let mut upd_user_activity = tx.prepare(
"UPDATE users SET last_active = ?1 WHERE session_id = ?2 AND last_active < ?1")?;
"UPDATE users SET last_active = ?1 WHERE session_id = ?2 AND last_active < ?1",
)?;
for room in rooms {
let room_db_filename = format!("rooms/{}.db", room.token);
@ -95,7 +108,8 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
info!("Importing room {}...", room.token);
let room_id = ins_room.query_row(params![room.token, room.name], |row| row.get::<_, i64>(0))?;
let room_id =
ins_room.query_row(params![room.token, room.name], |row| row.get::<_, i64>(0))?;
let rconn = Connection::open_with_flags(room_db, OpenFlags::SQLITE_OPEN_READ_ONLY)?;
@ -115,7 +129,7 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
data and signature are in base64 (wtf), data is typically padded from the client (i.e. to
the next multiple, with lots of 0s on the end). If the message was deleted then it remains
here but `is_deleted` is set to 1 (data are signature should be NULL as well, but older
versions apparently didn't do that), plus we have a row in here:
versions apparently didn't do that), plus we have a row in here:
CREATE TABLE IF NOT EXISTS deleted_messages (
id INTEGER PRIMARY KEY,
@ -138,12 +152,23 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
deletion ids.
*/
let mut id_offset: i64 = tx.query_row("SELECT COALESCE(MAX(id), 0) + 1 FROM messages", [], |row| row.get(0))?;
let mut id_offset: i64 =
tx.query_row("SELECT COALESCE(MAX(id), 0) + 1 FROM messages", [], |row| {
row.get(0)
})?;
let mut top_old_id: i64 = -1;
let mut updated: i64 = 0;
let mut imported_msgs: i64 = 0;
struct Msg { id: i64, session_id: String, ts_ms: i64, data: Option<String>, signature: Option<String>, deleted: Option<i64> }
let n_msgs: i64 = rconn.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
struct Msg {
id: i64,
session_id: String,
ts_ms: i64,
data: Option<String>,
signature: Option<String>,
deleted: Option<i64>,
}
let n_msgs: i64 =
rconn.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
let mut msg_st = rconn.prepare("\
SELECT messages.id, public_key, timestamp, data, signature, is_deleted, deleted_messages.id \
FROM messages LEFT JOIN deleted_messages ON messages.id = deleted_messages.deleted_message_id
@ -153,8 +178,16 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
let mut dupe_dels: i64 = 0;
while let Some(row) = msg_rows.next()? {
let msg = Msg {
id: row.get(0)?, session_id: row.get(1)?, ts_ms: row.get(2)?, data: row.get(3)?, signature: row.get(4)?,
deleted: if row.get::<_, Option<bool>>(5)?.unwrap_or(false) { Some(row.get(6)?) } else { None }
id: row.get(0)?,
session_id: row.get(1)?,
ts_ms: row.get(2)?,
data: row.get(3)?,
signature: row.get(4)?,
deleted: if row.get::<_, Option<bool>>(5)?.unwrap_or(false) {
Some(row.get(6)?)
} else {
None
},
};
if top_old_id == -1 {
id_offset -= msg.id;
@ -177,21 +210,42 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// Data was pointlessly store padding, so unpad it:
let padded_data = match base64::decode(msg.data.unwrap()) {
Ok(d) => d,
Err(e) => panic!("Unexpected data: {} message id={} has non-base64 data ({})", room_db.display(), msg.id, e)
Err(e) => panic!(
"Unexpected data: {} message id={} has non-base64 data ({})",
room_db.display(),
msg.id,
e
),
};
let data_size = padded_data.len();
let data = match padded_data.iter().rposition(|&c| c != 0u8) {
Some(last) => &padded_data[0..=last],
None => &padded_data
None => &padded_data,
};
let sig = match base64::decode(msg.signature.unwrap()) {
Ok(d) if d.len() == 64 => d,
Ok(_) => panic!("Unexpected data: {} message id={} has invalid signature", room_db.display(), msg.id),
Err(e) => panic!("Unexpected data: {} message id={} has non-base64 signature ({})", room_db.display(), msg.id, e)
Ok(_) => panic!(
"Unexpected data: {} message id={} has invalid signature",
room_db.display(),
msg.id
),
Err(e) => panic!(
"Unexpected data: {} message id={} has non-base64 signature ({})",
room_db.display(),
msg.id,
e
),
};
ins_msg.execute(params![msg.id + id_offset, room_id, msg.session_id, (msg.ts_ms as f64) / 1000., data, data_size, sig])?;
ins_msg.execute(params![
msg.id + id_offset,
room_id,
msg.session_id,
(msg.ts_ms as f64) / 1000.,
data,
data_size,
sig
])?;
} else if msg.deleted.is_some() &&
// Deleted messages are usually set to the fixed string "deleted" (why not
// NULL?) for data and signature, so accept either null or that string if the
@ -205,7 +259,15 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
// deletion id as the "updated" field. (We do this with a second query because the
// first query is going to trigger an automatic update of the field).
ins_msg.execute(params![msg.id + id_offset, room_id, msg.session_id, (msg.ts_ms as f64) / 1000., Null, Null, Null])?;
ins_msg.execute(params![
msg.id + id_offset,
room_id,
msg.session_id,
(msg.ts_ms as f64) / 1000.,
Null,
Null,
Null
])?;
} else {
panic!("Inconsistent message in {} database: message id={} has inconsistent deletion state (data: {}, signature: {}, del row: {})",
room_db.display(), msg.id, msg.data.is_some(), msg.signature.is_some(), msg.deleted.is_some());
@ -217,7 +279,10 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
info!("- ... imported {}/{} messages", imported_msgs, n_msgs);
}
}
info!("- migrated {} messages, {} duplicate deletions ignored", imported_msgs, dupe_dels);
info!(
"- migrated {} messages, {} duplicate deletions ignored",
imported_msgs, dupe_dels
);
upd_room_updates.execute(params![updated, room_id])?;
@ -229,36 +294,55 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
let mut imported_files: i64 = 0;
let n_files: i64 = rconn.query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0))?;
let n_files: i64 =
rconn.query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0))?;
// WTF is this id stored as a TEXT?
struct File { id: String, ts: i64 }
struct File {
id: String,
ts: i64,
}
let mut rows_st = rconn.prepare("SELECT id, timestamp FROM files")?;
let mut file_rows = rows_st.query([])?;
while let Some(row) = file_rows.next()? {
let file = File { id: row.get(0)?, ts: row.get(1)? };
let old_id = match file.id.parse::<i64>() {
Ok(id) => id,
Err(e) => panic!("Invalid fileid '{}' found in {}: {}", file.id, room_db.display(), e)
Err(e) => {
panic!("Invalid fileid '{}' found in {}: {}", file.id, room_db.display(), e)
}
};
let old_path = format!("files/{}_files/{}", room.token, old_id);
let size = match fs::metadata(&old_path) {
Ok(md) => md.len(),
Err(e) => {
warn!("Error accessing file {} ({}); skipping import of this upload", old_path, e);
warn!(
"Error accessing file {} ({}); skipping import of this upload",
old_path, e
);
continue;
}
};
let ts = if file.ts > 10000000000 {
warn!("- file {} has nonsensical timestamp {}; importing it with current time", old_path, file.ts);
warn!(
"- file {} has nonsensical timestamp {}; importing it with current time",
old_path, file.ts
);
SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs_f64()
} else {
file.ts as f64
};
let new_id = ins_file.query_row(
params![room_id, size, ts, ts + handlers::UPLOAD_DEFAULT_EXPIRY.as_secs_f64(), old_path],
|row| row.get::<_, i64>(0))?;
params![
room_id,
size,
ts,
ts + handlers::UPLOAD_DEFAULT_EXPIRY.as_secs_f64(),
old_path
],
|row| row.get::<_, i64>(0),
)?;
ins_file_hack.execute(params![room_id, old_id, new_id])?;
imported_files += 1;
@ -266,7 +350,9 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
info!("- ... imported {}/{} files", imported_files, n_files);
}
}
if imported_files > 0 { used_file_hacks = true; }
if imported_files > 0 {
used_file_hacks = true;
}
info!("- migrated {} files", imported_files);
// There's also a potential room image, which is just stored on disk and not referenced in
@ -282,11 +368,21 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
panic!("Unable to mkdir {} for room file storage: {}", files_dir, e);
}
let file_id = ins_file.query_row(
params![room_id, md.len(), md.mtime() as f64 + md.mtime_nsec() as f64 * 1e-9, Null, "tmp"],
|row| row.get::<_, i64>(0))?;
params![
room_id,
md.len(),
md.mtime() as f64 + md.mtime_nsec() as f64 * 1e-9,
Null,
"tmp"
],
|row| row.get::<_, i64>(0),
)?;
let new_image_path = format!("uploads/{}/{}_(unnamed)", room.token, file_id);
if let Err(e) = fs::hard_link(&room_image_path, &new_image_path) {
panic!("Unable to hard link room image file {} => {}: {}", room_image_path, new_image_path, e);
panic!(
"Unable to hard link room image file {} => {}: {}",
room_image_path, new_image_path, e
);
}
upd_file_path.execute(params![new_image_path, file_id])?;
upd_room_image.execute(params![file_id, room_id])?;
@ -325,22 +421,36 @@ pub fn migrate_0_2_0(conn: &mut Connection) -> Result<(), rusqlite::Error> {
let mut imported_activity: i64 = 0;
let mut imported_active: i64 = 0;
// Don't import rows we're going to immediately prune:
let import_cutoff = (SystemTime::now() - storage::ROOM_ACTIVE_PRUNE_THRESHOLD).duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs_f64();
let n_activity: i64 = rconn.query_row("SELECT COUNT(*) FROM user_activity WHERE last_active > ?",
params![import_cutoff], |row| row.get(0))?;
let import_cutoff = (SystemTime::now() - storage::ROOM_ACTIVE_PRUNE_THRESHOLD)
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs_f64();
let n_activity: i64 = rconn.query_row(
"SELECT COUNT(*) FROM user_activity WHERE last_active > ?",
params![import_cutoff],
|row| row.get(0),
)?;
let mut activity_st = rconn.prepare("SELECT public_key, last_active FROM user_activity WHERE last_active > ? AND public_key IS NOT NULL")?;
let mut act_rows = activity_st.query(params![import_cutoff])?;
let cutoff = (SystemTime::now() - handlers::ROOM_DEFAULT_ACTIVE_THRESHOLD).duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs_f64();
let cutoff = (SystemTime::now() - handlers::ROOM_DEFAULT_ACTIVE_THRESHOLD)
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs_f64();
while let Some(row) = act_rows.next()? {
let session_id: String = row.get(0)?;
let ts: f64 = row.get::<_, i64>(1)? as f64;
ins_user.execute(params![session_id])?;
ins_room_activity.execute(params![room_id, session_id, ts])?;
upd_user_activity.execute(params![ts, session_id])?;
if ts >= cutoff { imported_active += 1; }
if ts >= cutoff {
imported_active += 1;
}
imported_activity += 1;
if imported_activity % 1000 == 0 {
info!("- ... imported {}/{} user activity records ({} active)", imported_activity, n_activity, imported_active);
info!(
"- ... imported {}/{} user activity records ({} active)",
imported_activity, n_activity, imported_active
);
}
}
warn!("Imported room {}: {} messages, {} files, {} moderators, {} bans, {} users ({} active)",

View File

@ -9,7 +9,7 @@ pub struct User {
pub last_active: f64,
pub banned: bool,
pub moderator: bool,
pub admin: bool
pub admin: bool,
}
impl User {
@ -21,13 +21,15 @@ impl User {
last_active: row.get(row.column_index("last_active")?)?,
banned: row.get(row.column_index("banned")?)?,
moderator: row.get(row.column_index("moderator")?)?,
admin: row.get(row.column_index("admin")?)?
admin: row.get(row.column_index("admin")?)?,
});
}
}
fn as_opt_base64<S>(val: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error>
where S: Serializer {
where
S: Serializer,
{
s.serialize_str(&base64::encode(val.as_ref().unwrap()))
}
@ -50,7 +52,7 @@ pub struct OldMessage {
/// XEd25519 message signature of the `data` bytes (not the base64 representation), encoded in
/// base64
#[serde(serialize_with = "as_opt_base64")]
pub signature: Option<Vec<u8>>
pub signature: Option<Vec<u8>>,
}
impl OldMessage {
@ -59,19 +61,18 @@ impl OldMessage {
repad(&mut data, row.get::<_, Option<usize>>(row.column_index("data_size")?)?);
let session_id = match row.column_index("session_id") {
Ok(index) => Some(row.get(index)?),
Err(_) => None
Err(_) => None,
};
return Ok(OldMessage {
server_id: row.get(row.column_index("id")?)?,
public_key: session_id,
timestamp: (row.get::<_, f64>(row.column_index("posted")?)? * 1000.0) as i64,
data,
signature: row.get(row.column_index("signature")?)?
signature: row.get(row.column_index("signature")?)?,
});
}
}
#[derive(Debug, Serialize)]
pub struct Message {
/// The message id.
@ -95,7 +96,7 @@ pub struct Message {
pub signature: Option<Vec<u8>>,
/// Flag set to true if the message is deleted, and omitted otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted: Option<bool>
pub deleted: Option<bool>,
}
fn repad(data: &mut Option<Vec<u8>>, size: Option<usize>) {
@ -113,7 +114,7 @@ impl Message {
let deleted = if data.is_none() { Some(true) } else { None };
let session_id = match row.column_index("session_id") {
Ok(index) => Some(row.get(index)?),
Err(_) => None
Err(_) => None,
};
return Ok(Message {
id: row.get(row.column_index("id")?)?,
@ -123,13 +124,15 @@ impl Message {
updated: row.get(row.column_index("updated")?)?,
data,
signature: row.get(row.column_index("signature")?)?,
deleted
deleted,
});
}
}
fn bytes_from_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where D: Deserializer<'de> {
where
D: Deserializer<'de>,
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|str| base64::decode(&str).map_err(|err| Error::custom(err.to_string())))
@ -140,14 +143,14 @@ pub struct PostMessage {
#[serde(deserialize_with = "bytes_from_base64")]
pub data: Vec<u8>,
#[serde(deserialize_with = "bytes_from_base64")]
pub signature: Vec<u8>
pub signature: Vec<u8>,
}
#[derive(Debug, Serialize)]
pub struct DeletedMessage {
#[serde(rename = "id")]
pub updated: i64,
pub deleted_message_id: i64
pub deleted_message_id: i64,
}
#[derive(Debug, Serialize)]
@ -165,7 +168,7 @@ pub struct Room {
pub updates: i64,
pub default_read: bool,
pub default_write: bool,
pub default_upload: bool
pub default_upload: bool,
}
impl Room {
@ -180,13 +183,11 @@ impl Room {
updates: row.get(row.column_index("updates")?)?,
default_read: row.get(row.column_index("read")?)?,
default_write: row.get(row.column_index("write")?)?,
default_upload: row.get(row.column_index("upload")?)?
default_upload: row.get(row.column_index("upload")?)?,
});
}
}
// FIXME: this appears to be used for both add/remove. But what if we want to promote to admin, or
// demote to moderator?
#[derive(Debug, Deserialize)]
@ -195,7 +196,7 @@ pub struct ChangeModeratorRequestBody {
pub room_token: String,
#[serde(rename = "public_key")]
pub session_id: String,
pub admin: Option<bool>
pub admin: Option<bool>,
}
#[derive(Debug, Deserialize)]
@ -204,7 +205,7 @@ pub struct PollRoomMetadata {
pub room: String,
/// The last `info_update` value the client has; results are only returned if the room has been
/// modified since the value provided by the client.
pub since_update: i64
pub since_update: i64,
}
#[derive(Debug, Serialize)]
@ -216,7 +217,7 @@ pub struct RoomDetails {
/// Metadata of the room; this omitted from the response when polling if the room metadata
/// (other than active user count) has not changed since the request update counter.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<RoomMetadata>
pub details: Option<RoomMetadata>,
}
#[derive(Debug, Serialize)]
@ -255,7 +256,7 @@ pub struct RoomMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
pub moderator: Option<bool>,
/// Will be present and true if the requesting user has admin powers, omitted otherwise.
pub admin: Option<bool>
pub admin: Option<bool>,
}
#[derive(Debug, Deserialize)]
@ -264,7 +265,7 @@ pub struct PollRoomMessages {
pub room: String,
/// Return new messages, edit, and deletions posted since this `updates` value. Clients should
/// poll with the most recent updates value they have received.
pub since_update: i64
pub since_update: i64,
}
#[derive(Debug, Serialize)]
@ -272,7 +273,7 @@ pub struct RoomMessages {
/// The token of this room
pub room: String,
/// Vector of new/edited/deleted message posted to the room since the requested update.
pub messages: Vec<Message>
pub messages: Vec<Message>,
}
#[derive(Debug, Deserialize)]
@ -288,7 +289,7 @@ pub struct CompactPollRequestBody {
// messages/deletions, in reverse order from what you get with regular polling. New clients
// should update to the new polling endpoints ASAP.
pub from_message_server_id: Option<i64>,
pub from_deletion_server_id: Option<i64>
pub from_deletion_server_id: Option<i64>,
}
#[derive(Debug, Serialize)]
@ -298,16 +299,16 @@ pub struct CompactPollResponseBody {
pub status_code: u16,
pub deletions: Vec<DeletedMessage>,
pub messages: Vec<OldMessage>,
pub moderators: Vec<String>
pub moderators: Vec<String>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Challenge {
pub ciphertext: String,
pub ephemeral_public_key: String
pub ephemeral_public_key: String,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct StatusCode {
pub status_code: u16
pub status_code: u16,
}

View File

@ -12,12 +12,12 @@ use super::rpc;
#[derive(Deserialize, Serialize, Debug)]
struct OnionRequestPayload {
pub ciphertext: Vec<u8>,
pub metadata: OnionRequestPayloadMetadata
pub metadata: OnionRequestPayloadMetadata,
}
#[derive(Deserialize, Serialize, Debug)]
struct OnionRequestPayloadMetadata {
pub ephemeral_key: String
pub ephemeral_key: String,
}
pub async fn handle_onion_request(blob: warp::hyper::body::Bytes) -> Result<Response, Rejection> {
@ -36,9 +36,8 @@ pub async fn handle_onion_request(blob: warp::hyper::body::Bytes) -> Result<Resp
async fn handle_decrypted_onion_request(
plaintext: &[u8],
symmetric_key: &[u8]
) -> Result<Response, Rejection>
{
symmetric_key: &[u8],
) -> Result<Response, Rejection> {
let rpc_call = match serde_json::from_slice(plaintext) {
Ok(rpc_call) => rpc_call,
Err(e) => {
@ -59,7 +58,7 @@ async fn handle_decrypted_onion_request(
}
fn parse_onion_request_payload(
blob: warp::hyper::body::Bytes
blob: warp::hyper::body::Bytes,
) -> Result<OnionRequestPayload, Rejection> {
// The encoding of an onion request looks like:
//
@ -107,7 +106,7 @@ fn parse_onion_request_payload(
/// Returns the decrypted `payload.ciphertext` plus the `symmetric_key` that was used for
/// decryption if successful.
fn decrypt_onion_request_payload(
payload: OnionRequestPayload
payload: OnionRequestPayload,
) -> Result<(Vec<u8>, Vec<u8>), Rejection> {
let ephemeral_key = hex::decode(payload.metadata.ephemeral_key).unwrap(); // Safe because it was validated in the parsing step
let symmetric_key = crypto::get_x25519_symmetric_key(&ephemeral_key, &crypto::PRIVATE_KEY)?;

View File

@ -64,5 +64,5 @@ pub struct Opt {
/// Prints the URL format users can use to join rooms on this open group server.
#[structopt(long = "print-url")]
pub print_url: bool
pub print_url: bool,
}

View File

@ -105,9 +105,8 @@ pub async fn root_html() -> Result<Response, Rejection> {
pub async fn fallback_html(
room: String,
query_map: HashMap<String, String>
) -> Result<Response, Rejection>
{
query_map: HashMap<String, String>,
) -> Result<Response, Rejection> {
if !query_map.contains_key("public_key") || room == "" {
return fallback_nopubkey_html().await;
}

View File

@ -15,7 +15,7 @@ use super::storage;
#[allow(dead_code)]
pub enum Mode {
FileServer,
OpenGroupServer
OpenGroupServer,
}
#[derive(Deserialize, Debug)]
@ -34,7 +34,7 @@ pub struct RpcCall {
/// Arbitrary string; must be different on each request
pub nonce: Option<String>,
/// Ed25519 signature (in base64 or hex) of (method || endpoint || body || nonce)
pub signature: Option<String>
pub signature: Option<String>,
}
pub const MODE: Mode = Mode::OpenGroupServer;
@ -43,9 +43,8 @@ pub const MODE: Mode = Mode::OpenGroupServer;
// is a parseable auth token, and an error for anything else.
fn get_user_from_auth_header(
conn: &rusqlite::Connection,
rpc: &RpcCall
) -> Result<Option<User>, Error>
{
rpc: &RpcCall,
) -> Result<Option<User>, Error> {
if let Some(auth_token_str) = rpc.headers.get("Authorization") {
return Ok(Some(handlers::get_user_from_token(conn, auth_token_str)?));
}
@ -77,16 +76,20 @@ pub async fn handle_rpc_call(mut rpc_call: RpcCall) -> Result<Response, Rejectio
let mut sig_bytes: [u8; 64] = [0; 64];
sig_bytes.copy_from_slice(
&handlers::decode_hex_or_b64(rpc_call.signature.as_ref().unwrap(), 64)?[0..64]
&handlers::decode_hex_or_b64(rpc_call.signature.as_ref().unwrap(), 64)?[0..64],
);
let sig = ed25519_dalek::Signature::new(sig_bytes);
if let Err(sigerr) = crypto::verify_signature(&edpk, &sig, &vec![
rpc_call.endpoint.as_bytes(),
rpc_call.method.as_bytes(),
rpc_call.body.as_bytes(),
nonce.as_bytes(),
]) {
if let Err(sigerr) = crypto::verify_signature(
&edpk,
&sig,
&vec![
rpc_call.endpoint.as_bytes(),
rpc_call.method.as_bytes(),
rpc_call.body.as_bytes(),
nonce.as_bytes(),
],
) {
warn!("Signature verification failed for request from {}", sessionid);
return Err(sigerr.into());
}
@ -111,7 +114,7 @@ pub async fn handle_rpc_call(mut rpc_call: RpcCall) -> Result<Response, Rejectio
path = uri.path().trim_start_matches('/').to_string();
query_params = match uri.query() {
Some(qs) => form_urlencoded::parse(qs.as_bytes()).into_owned().collect(),
None => HashMap::new()
None => HashMap::new(),
};
}
Err(e) => {
@ -153,9 +156,8 @@ async fn handle_get_request(
rpc_call: RpcCall,
path: &str,
user: Option<User>,
query_params: HashMap<String, String>
) -> Result<Response, Rejection>
{
query_params: HashMap<String, String>,
) -> Result<Response, Rejection> {
let mut components: Vec<&str> = path.split('/').collect();
if components.len() == 0 {
components.push("");
@ -165,7 +167,7 @@ async fn handle_get_request(
if components[0] == "auth_token_challenge" && components.len() == 1 {
reject_if_file_server_mode(path)?;
let challenge = handlers::get_auth_token_challenge(
query_params.get("public_key").ok_or(Error::InvalidRpcCall)?
query_params.get("public_key").ok_or(Error::InvalidRpcCall)?,
)?;
let response = json!({ "status_code": StatusCode::OK.as_u16(), "challenge": challenge });
return Ok(warp::reply::json(&response).into_response());
@ -208,7 +210,7 @@ async fn handle_get_request(
warn!("Ignoring RPC call with invalid or unused endpoint: {}.", path);
return Err(Error::InvalidRpcCall.into());
}
Mode::FileServer => ()
Mode::FileServer => (),
}
let platform = query_params
.get("platform")
@ -216,7 +218,7 @@ async fn handle_get_request(
let version = handlers::get_session_version(platform).await?;
let response = handlers::GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: version
result: version,
};
return Ok(warp::reply::json(&response).into_response());
}
@ -299,9 +301,8 @@ async fn handle_post_request(
room: Option<Room>,
rpc_call: RpcCall,
path: &str,
user: Option<User>
) -> Result<Response, Rejection>
{
user: Option<User>,
) -> Result<Response, Rejection> {
// Handle routes that don't require authorization first
// The compact poll endpoint expects the auth token to be in the request body; not in the
@ -314,7 +315,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct CompactPollRequestBodyWrapper {
requests: Vec<models::CompactPollRequestBody>
requests: Vec<models::CompactPollRequestBody>,
}
let wrapper: CompactPollRequestBodyWrapper = match serde_json::from_str(&rpc_call.body) {
Ok(bodies) => bodies,
@ -344,7 +345,7 @@ async fn handle_post_request(
if components.len() == 3 && components[2] == "image" {
#[derive(Debug, Deserialize)]
struct JSON {
file: String
file: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -405,7 +406,7 @@ async fn handle_post_request(
// filename
#[derive(Debug, Deserialize)]
struct JSON {
file: String
file: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -418,7 +419,7 @@ async fn handle_post_request(
// FIXME TODO: add an input field so that the uploader can pass the filename
let filename: Option<&str> = None;
return handlers::store_file(room, user, &json.file, filename);
return handlers::store_file(&room, &user, &json.file, filename);
}
// FIXME: deprecate these next two separate endpoints and replace with a single
@ -428,7 +429,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct JSON {
public_key: String
public_key: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -443,7 +444,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct JSON {
public_key: String
public_key: String,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -478,7 +479,7 @@ async fn handle_post_request(
room,
user,
&body.session_id,
body.admin.unwrap_or(false)
body.admin.unwrap_or(false),
);
}
"delete_messages" => {
@ -488,7 +489,7 @@ async fn handle_post_request(
reject_if_file_server_mode(path)?;
#[derive(Debug, Deserialize)]
struct JSON {
ids: Vec<i64>
ids: Vec<i64>,
}
let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(json) => json,
@ -510,9 +511,8 @@ async fn handle_delete_request(
room: Room,
rpc_call: RpcCall,
path: &str,
user: Option<User>
) -> Result<Response, Rejection>
{
user: Option<User>,
) -> Result<Response, Rejection> {
// Check that the auth token is present
let user = user.ok_or(Error::NoAuthToken)?;
// DELETE /messages/:server_id
@ -594,7 +594,7 @@ fn get_room(rpc_call: &RpcCall) -> Result<Option<Room>, Error> {
let room_token = match rpc_call.headers.get("Room") {
Some(s) => s,
None => return Ok(None)
None => return Ok(None),
};
return Ok(Some(storage::get_room_from_token(&*storage::get_conn()?, room_token)?));
}
@ -605,6 +605,6 @@ fn reject_if_file_server_mode(path: &str) -> Result<(), Rejection> {
warn!("Ignoring RPC call with invalid or unused endpoint: {}.", path);
return Err(Error::InvalidRpcCall.into());
}
Mode::OpenGroupServer => return Ok(())
Mode::OpenGroupServer => return Ok(()),
}
}

View File

@ -1,16 +1,17 @@
use std::fs;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::time::{Duration, SystemTime};
use log::{error, info, warn};
use r2d2::PooledConnection;
use r2d2_sqlite::SqliteConnectionManager;
use regex::Regex;
use rusqlite::{config::DbConfig, params};
use super::errors::Error;
use super::models::Room;
use super::migration;
use super::models::Room;
pub type DatabaseConnection = r2d2::PooledConnection<SqliteConnectionManager>;
pub type DatabaseConnectionPool = r2d2::Pool<SqliteConnectionManager>;
@ -32,7 +33,7 @@ impl RoomId {
Ok(())
} else {
Err(Error::ValidationFailed)
}
};
}
pub fn new(room_id: &str) -> Result<RoomId, Error> {
@ -51,12 +52,11 @@ pub const ROOM_ACTIVE_PRUNE_THRESHOLD: Duration = Duration::from_secs(60 * 86400
// How long we keep message edit/deletion history.
pub const MESSAGE_HISTORY_PRUNE_THRESHOLD: Duration = Duration::from_secs(30 * 86400);
// Migration support: when migrating to 0.2.x old room ids cannot be preserved, so we map the old
// id range [1, max] to the new range [offset+1, offset+max].
pub struct RoomMigrationMap {
pub max: i64,
pub offset: i64
pub offset: i64,
}
lazy_static::lazy_static! {
@ -127,9 +127,12 @@ pub fn get_conn() -> Result<DatabaseConnection, Error> {
}
}
fn get_room_hacks(conn: &rusqlite::Connection) -> Result<HashMap<i64, RoomMigrationMap>, rusqlite::Error> {
fn get_room_hacks(
conn: &rusqlite::Connection,
) -> Result<HashMap<i64, RoomMigrationMap>, rusqlite::Error> {
let mut hacks = HashMap::new();
let mut st = conn.prepare("SELECT room, old_message_id_max, message_id_offset FROM room_import_hacks")?;
let mut st =
conn.prepare("SELECT room, old_message_id_max, message_id_offset FROM room_import_hacks")?;
let mut query = st.query([])?;
while let Some(row) = query.next()? {
hacks.insert(row.get(0)?, RoomMigrationMap { max: row.get(1)?, offset: row.get(2)? });
@ -138,7 +141,7 @@ fn get_room_hacks(conn: &rusqlite::Connection) -> Result<HashMap<i64, RoomMigrat
}
pub fn get_transaction<'a>(
conn: &'a mut DatabaseConnection
conn: &'a mut DatabaseConnection,
) -> Result<DatabaseTransaction<'a>, Error> {
conn.transaction().map_err(db_error)
}
@ -150,16 +153,19 @@ pub fn db_error(e: rusqlite::Error) -> Error {
/// Initialize the database, creating and migrating its structure if necessary.
pub fn setup_database() {
let mut conn = get_conn().unwrap();
setup_database_with_conn(&mut conn);
}
pub fn setup_database_with_conn(conn: &mut PooledConnection<SqliteConnectionManager>) {
if rusqlite::version_number() < 3035000 {
panic!("SQLite 3.35.0+ is required!");
}
let mut conn = get_conn().unwrap();
let have_messages = match conn.query_row(
"SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'messages')",
params![],
|row| row.get::<_, bool>(0)
|row| row.get::<_, bool>(0),
) {
Ok(exists) => exists,
Err(e) => {
@ -172,18 +178,19 @@ pub fn setup_database() {
conn.execute_batch(include_str!("schema.sql")).expect("Couldn't create database schema.");
}
let n_rooms = match conn.query_row("SELECT COUNT(*) FROM rooms", params![], |row| row.get::<_, i64>(0)) {
Ok(r) => r,
Err(e) => {
panic!("Error querying database for # of rooms: {}", e);
}
};
let n_rooms =
match conn.query_row("SELECT COUNT(*) FROM rooms", params![], |row| row.get::<_, i64>(0)) {
Ok(r) => r,
Err(e) => {
panic!("Error querying database for # of rooms: {}", e);
}
};
if n_rooms == 0 && Path::new("database.db").exists() {
// If we have no rooms then check to see if there is an old (pre-v0.2) set of databases to
// import from.
warn!("No rooms found, but database.db exists; attempting migration");
if let Err(e) = migration::migrate_0_2_0(&mut conn) {
if let Err(e) = migration::migrate_0_2_0(conn) {
panic!("\n\ndatabase.db exists but migration failed:\n\n {}.\n\n\
Please report this bug!\n\n\
If no migration from 0.1.x is needed then rename or delete database.db to start up with a fresh (new) database.\n\n", e);
@ -193,7 +200,6 @@ pub fn setup_database() {
// Future migrations here
}
// Performs periodic DB maintenance: file pruning, delayed permission applying,
// etc.
pub async fn db_maintenance_job() {
@ -216,7 +222,7 @@ pub async fn db_maintenance_job() {
/// Removes all files with expiries <= the given time (which should generally by
/// `SystemTime::now()`, except in the test suite).
fn prune_files(conn: &mut DatabaseConnection, now: &SystemTime) {
pub fn prune_files(conn: &mut DatabaseConnection, now: &SystemTime) {
let mut st = match conn.prepare_cached("DELETE FROM files WHERE expiry <= ? RETURNING path") {
Ok(st) => st,
Err(e) => {
@ -315,7 +321,7 @@ fn apply_permission_updates(conn: &mut DatabaseConnection, now: &SystemTime) {
ON CONFLICT DO UPDATE SET
read = COALESCE(excluded.read, read),
write = COALESCE(excluded.write, write),
upload = COALESCE(excluded.upload, upload)"
upload = COALESCE(excluded.upload, upload)",
) {
Ok(st) => st,
Err(e) => {
@ -355,11 +361,7 @@ fn apply_permission_updates(conn: &mut DatabaseConnection, now: &SystemTime) {
// Utilities
pub fn get_room_from_token(
conn: &rusqlite::Connection,
token: &str
) -> Result<Room, Error>
{
pub fn get_room_from_token(conn: &rusqlite::Connection, token: &str) -> Result<Room, Error> {
match conn
.prepare_cached("SELECT * FROM rooms WHERE token = ?")
.map_err(db_error)?
@ -367,6 +369,6 @@ pub fn get_room_from_token(
{
Ok(room) => return Ok(room),
Err(rusqlite::Error::QueryReturnedNoRows) => return Err(Error::NoSuchRoom.into()),
Err(_) => return Err(Error::DatabaseFailedInternally.into())
Err(_) => return Err(Error::DatabaseFailedInternally.into()),
}
}

View File

@ -1,18 +1,25 @@
use std::collections::HashMap;
//use std::collections::HashMap;
use r2d2::PooledConnection;
use r2d2_sqlite::SqliteConnectionManager;
use std::fs;
use std::time::{Duration, SystemTime};
use rand::{thread_rng, Rng};
//use rand::{thread_rng, Rng};
use rusqlite::params;
use rusqlite::OpenFlags;
use warp::http::StatusCode;
use warp::{hyper, Reply};
use crate::storage::DatabaseConnectionPool;
use super::crypto;
use super::handlers;
use super::handlers::CreateRoom;
use super::models::User;
use super::storage;
use crate::handlers::GenericStringResponse;
async fn set_up_test_room() -> DatabaseConnectionPool {
async fn set_up_test_room() -> (PooledConnection<SqliteConnectionManager>, DatabaseConnectionPool) {
let manager = r2d2_sqlite::SqliteConnectionManager::file("file::memory:?cache=shared");
let mut flags = OpenFlags::default();
flags.set(OpenFlags::SQLITE_OPEN_URI, true);
@ -21,94 +28,71 @@ async fn set_up_test_room() -> DatabaseConnectionPool {
let pool = r2d2::Pool::<r2d2_sqlite::SqliteConnectionManager>::new(manager).unwrap();
let conn = pool.get().unwrap();
let mut conn = pool.get().unwrap();
storage::create_room_tables_if_needed(&conn);
storage::setup_database_with_conn(&mut conn);
let success = handlers::create_room_with_conn(
&conn,
&CreateRoom { token: "test_room".to_string(), name: "Test".to_string() },
);
assert!(success.is_ok());
pool
return (conn, pool);
}
fn get_auth_token(pool: &DatabaseConnectionPool) -> (String, String) {
fn get_user(conn: &rusqlite::Connection) -> User {
// Generate a fake user key pair
let (user_private_key, user_public_key) = crypto::generate_x25519_key_pair();
let (_, user_public_key) = crypto::generate_x25519_key_pair();
let hex_user_public_key = format!("05{}", hex::encode(user_public_key.to_bytes()));
// Get a challenge
let mut query_params: HashMap<String, String> = HashMap::new();
query_params.insert("public_key".to_string(), hex_user_public_key.clone());
let challenge = handlers::get_auth_token_challenge(query_params, &pool).unwrap();
// Generate a symmetric key
let ephemeral_public_key = base64::decode(challenge.ephemeral_public_key).unwrap();
let symmetric_key =
crypto::get_x25519_symmetric_key(&ephemeral_public_key, &user_private_key).unwrap();
// Decrypt the challenge
let ciphertext = base64::decode(challenge.ciphertext).unwrap();
let plaintext = crypto::decrypt_aes_gcm(&ciphertext, &symmetric_key).unwrap();
let auth_token = hex::encode(plaintext);
// Try to claim the token
let response = handlers::claim_auth_token(&hex_user_public_key, &auth_token, &pool).unwrap();
assert_eq!(response.status(), StatusCode::OK);
// return
return (auth_token, hex_user_public_key);
}
#[tokio::test]
async fn test_authorization() {
// Ensure the test room is set up and get a database connection pool
let pool = set_up_test_room().await;
// Get an auth token
// This tests claiming a token internally
let (_, hex_user_public_key) = get_auth_token(&pool);
// Try to claim an incorrect token
let mut incorrect_token = [0u8; 48];
thread_rng().fill(&mut incorrect_token[..]);
let hex_incorrect_token = hex::encode(incorrect_token);
match handlers::claim_auth_token(&hex_user_public_key, &hex_incorrect_token, &pool) {
Ok(_) => assert!(false),
Err(_) => ()
}
let result = handlers::insert_or_update_user(conn, &hex_user_public_key);
assert!(result.is_ok());
return result.unwrap();
}
#[tokio::test]
async fn test_file_handling() {
// Ensure the test room is set up and get a database connection pool
let pool = set_up_test_room().await;
let (mut conn, pool) = set_up_test_room().await;
let test_room_id = storage::RoomId::new("test_room").unwrap();
let room = storage::get_room_from_token(&conn, "test_room").unwrap();
// Get an auth token
let (auth_token, _) = get_auth_token(&pool);
let user = get_user(&conn);
// Store the test file
handlers::store_file(
Some(test_room_id.get_id().to_string()),
TEST_FILE,
Some(auth_token.clone()),
&pool
)
.await
.unwrap();
// Check that there's a file record
let conn = pool.get().unwrap();
let raw_query = "SELECT id FROM files";
let id_as_string: String =
conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
let id = id_as_string.parse::<u64>().unwrap();
let filename: Option<&str> = None;
let auth = handlers::AuthorizationRequired { upload: true, write: true, ..Default::default() };
let id =
match handlers::store_file_impl(&mut conn, &room, &user, auth, TEST_FILE, filename, true)
.ok()
{
Some(mut upload) => {
let result = upload.commit();
assert!(result.is_ok());
Some(upload.id)
}
_ => None,
}
.unwrap();
// Retrieve the file and check the content
let base64_encoded_file = handlers::get_file(
Some(test_room_id.get_id().to_string()),
id,
Some(auth_token.clone()),
&pool,
)
.await
.unwrap()
.result;
assert_eq!(base64_encoded_file, TEST_FILE);
let room_id = room.id;
let response = handlers::get_file_conn(&mut conn, &room, id, user).unwrap();
let response_bytes = hyper::body::to_bytes(response).await.ok();
// The expected json response
let json = GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: TEST_FILE.to_string(),
};
let expected_result = warp::reply::json(&json).into_response();
let expected_result_bytes = hyper::body::to_bytes(expected_result).await.ok();
assert_eq!(response_bytes, expected_result_bytes);
// Prune the file and check that it's gone
// Will evaluate to now + 60
storage::prune_files_for_room(&pool, &test_room_id, -60).await;
let sixty_seconds_ago = SystemTime::now() - Duration::new(60, 0);
storage::prune_files(&mut conn, &sixty_seconds_ago);
// It should be gone now
fs::read(format!("files/{}_files/{}", test_room_id.get_id(), id)).unwrap_err();
fs::read(format!("files/{}_files/{}", room_id, id)).unwrap_err();
// Check that the file record is also gone
let conn = pool.get().unwrap();
let raw_query = "SELECT id FROM files";