Add .rustfmt.toml

This commit is contained in:
Niels Andriesse 2021-03-25 10:56:16 +11:00
parent fd0025da94
commit edca4b9e95
11 changed files with 457 additions and 251 deletions

15
.rustfmt.toml Normal file
View File

@ -0,0 +1,15 @@
edition = "2018"
unstable_features = true
blank_lines_upper_bound = 3
brace_style = "PreferSameLine"
combine_control_expr = true
fn_args_layout = "Compressed"
fn_single_line = true
imports_indent = "Visual"
overflow_delimited_expr = true
group_imports = "StdExternalCrate"
trailing_comma = "Never"
use_field_init_shorthand = true
use_small_heuristics = "Max"
where_single_line = true

View File

@ -1,7 +1,7 @@
use std::convert::TryInto; use std::convert::TryInto;
use aes_gcm::aead::{generic_array::GenericArray, Aead, NewAead};
use aes_gcm::Aes256Gcm; use aes_gcm::Aes256Gcm;
use aes_gcm::aead::{Aead, NewAead, generic_array::GenericArray};
use hmac::{Hmac, Mac, NewMac}; use hmac::{Hmac, Mac, NewMac};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use rand_core::OsRng; use rand_core::OsRng;
@ -32,9 +32,14 @@ lazy_static::lazy_static! {
}; };
} }
pub async fn get_x25519_symmetric_key(public_key: &[u8], private_key: &x25519_dalek::StaticSecret) -> Result<Vec<u8>, warp::reject::Rejection> { pub async fn get_x25519_symmetric_key(
public_key: &[u8], private_key: &x25519_dalek::StaticSecret
) -> Result<Vec<u8>, warp::reject::Rejection> {
if public_key.len() != 32 { if public_key.len() != 32 {
println!("Couldn't create symmetric key using public key of invalid length: {}.", hex::encode(public_key)); println!(
"Couldn't create symmetric key using public key of invalid length: {}.",
hex::encode(public_key)
);
return Err(warp::reject::custom(Error::DecryptionFailed)); return Err(warp::reject::custom(Error::DecryptionFailed));
} }
let public_key: [u8; 32] = public_key.try_into().unwrap(); // Safe because we know it has a length of 32 at this point let public_key: [u8; 32] = public_key.try_into().unwrap(); // Safe because we know it has a length of 32 at this point
@ -45,7 +50,9 @@ pub async fn get_x25519_symmetric_key(public_key: &[u8], private_key: &x25519_da
return Ok(mac.finalize().into_bytes().to_vec()); return Ok(mac.finalize().into_bytes().to_vec());
} }
pub async fn encrypt_aes_gcm(plaintext: &[u8], symmetric_key: &[u8]) -> Result<Vec<u8>, warp::reject::Rejection> { pub async fn encrypt_aes_gcm(
plaintext: &[u8], symmetric_key: &[u8]
) -> Result<Vec<u8>, warp::reject::Rejection> {
let mut iv = [0u8; IV_SIZE]; let mut iv = [0u8; IV_SIZE];
thread_rng().fill(&mut iv[..]); thread_rng().fill(&mut iv[..]);
let cipher = Aes256Gcm::new(&GenericArray::from_slice(symmetric_key)); let cipher = Aes256Gcm::new(&GenericArray::from_slice(symmetric_key));
@ -54,7 +61,7 @@ pub async fn encrypt_aes_gcm(plaintext: &[u8], symmetric_key: &[u8]) -> Result<V
let mut iv_and_ciphertext = iv.to_vec(); let mut iv_and_ciphertext = iv.to_vec();
iv_and_ciphertext.append(&mut ciphertext); iv_and_ciphertext.append(&mut ciphertext);
return Ok(iv_and_ciphertext); return Ok(iv_and_ciphertext);
}, }
Err(e) => { Err(e) => {
println!("Couldn't encrypt ciphertext due to error: {}.", e); println!("Couldn't encrypt ciphertext due to error: {}.", e);
return Err(warp::reject::custom(Error::DecryptionFailed)); return Err(warp::reject::custom(Error::DecryptionFailed));
@ -62,7 +69,9 @@ pub async fn encrypt_aes_gcm(plaintext: &[u8], symmetric_key: &[u8]) -> Result<V
}; };
} }
pub async fn decrypt_aes_gcm(iv_and_ciphertext: &[u8], symmetric_key: &[u8]) -> Result<Vec<u8>, warp::reject::Rejection> { pub async fn decrypt_aes_gcm(
iv_and_ciphertext: &[u8], symmetric_key: &[u8]
) -> Result<Vec<u8>, warp::reject::Rejection> {
if iv_and_ciphertext.len() < IV_SIZE { if iv_and_ciphertext.len() < IV_SIZE {
println!("Ignoring ciphertext of invalid size: {}.", iv_and_ciphertext.len()); println!("Ignoring ciphertext of invalid size: {}.", iv_and_ciphertext.len());
return Err(warp::reject::custom(Error::DecryptionFailed)); return Err(warp::reject::custom(Error::DecryptionFailed));

View File

@ -1,4 +1,4 @@
use warp::{http::StatusCode, Rejection, reply::Reply, reply::Response}; use warp::{http::StatusCode, reply::Reply, reply::Response, Rejection};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
@ -15,10 +15,14 @@ impl warp::reject::Reject for Error { }
pub fn into_response(e: Rejection) -> Result<Response, Rejection> { pub fn into_response(e: Rejection) -> Result<Response, Rejection> {
if let Some(error) = e.find::<Error>() { if let Some(error) = e.find::<Error>() {
match error { match error {
Error::DecryptionFailed | Error::InvalidOnionRequest | Error::InvalidRpcCall Error::DecryptionFailed
| Error::InvalidOnionRequest
| Error::InvalidRpcCall
| Error::ValidationFailed => return Ok(StatusCode::BAD_REQUEST.into_response()), | Error::ValidationFailed => return Ok(StatusCode::BAD_REQUEST.into_response()),
Error::Unauthorized => return Ok(StatusCode::FORBIDDEN.into_response()), Error::Unauthorized => return Ok(StatusCode::FORBIDDEN.into_response()),
Error::DatabaseFailedInternally => return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()) Error::DatabaseFailedInternally => {
return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response())
}
}; };
} else { } else {
return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()); return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());

View File

@ -1,15 +1,15 @@
use std::collections::HashMap;
use std::convert::TryInto; use std::convert::TryInto;
use std::fs; use std::fs;
use std::collections::HashMap;
use std::io::prelude::*; use std::io::prelude::*;
use std::path::Path; use std::path::Path;
use chrono; use chrono;
use serde::{Deserialize, Serialize};
use rusqlite::params;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use rusqlite::params;
use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use warp::{Rejection, http::StatusCode, reply::Reply, reply::Response}; use warp::{http::StatusCode, reply::Reply, reply::Response, Rejection};
use super::crypto; use super::crypto;
use super::errors::Error; use super::errors::Error;
@ -52,10 +52,15 @@ pub async fn create_room(id: &str, name: &str) -> Result<Response, Rejection> {
// Files // Files
pub async fn store_file(base64_encoded_bytes: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn store_file(
base64_encoded_bytes: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Parse bytes // Parse bytes
let bytes = match base64::decode(base64_encoded_bytes) { let bytes = match base64::decode(base64_encoded_bytes) {
Ok(bytes) => bytes, Ok(bytes) => bytes,
@ -107,10 +112,16 @@ pub async fn store_file(base64_encoded_bytes: &str, auth_token: &str, pool: &sto
return Ok(warp::reply::json(&json).into_response()); return Ok(warp::reply::json(&json).into_response());
} }
pub async fn get_file(id: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<GenericStringResponse, Rejection> { // Doesn't return a response directly for testing purposes pub async fn get_file(
id: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<GenericStringResponse, Rejection> {
// Doesn't return a response directly for testing purposes
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Check that the ID is a valid UUID // Check that the ID is a valid UUID
match Uuid::parse_str(id) { match Uuid::parse_str(id) {
Ok(_) => (), Ok(_) => (),
@ -132,15 +143,22 @@ pub async fn get_file(id: &str, auth_token: &str, pool: &storage::DatabaseConnec
// Base64 encode the result // Base64 encode the result
let base64_encoded_bytes = base64::encode(bytes); let base64_encoded_bytes = base64::encode(bytes);
// Return // Return
let json = GenericStringResponse { status_code : StatusCode::OK.as_u16(), result : base64_encoded_bytes }; let json = GenericStringResponse {
status_code: StatusCode::OK.as_u16(),
result: base64_encoded_bytes
};
return Ok(json); return Ok(json);
} }
// Authentication // Authentication
pub async fn get_auth_token_challenge(query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool) -> Result<models::Challenge, Rejection> { // Doesn't return a response directly for testing purposes pub async fn get_auth_token_challenge(
query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool
) -> Result<models::Challenge, Rejection> {
// Doesn't return a response directly for testing purposes
// Get the public key // Get the public key
let hex_public_key = query_params.get("public_key").ok_or(warp::reject::custom(Error::InvalidRpcCall))?; let hex_public_key =
query_params.get("public_key").ok_or(warp::reject::custom(Error::InvalidRpcCall))?;
// Validate the public key // Validate the public key
if !is_valid_public_key(hex_public_key) { if !is_valid_public_key(hex_public_key) {
println!("Ignoring challenge request for invalid public key: {}.", hex_public_key); println!("Ignoring challenge request for invalid public key: {}.", hex_public_key);
@ -151,7 +169,8 @@ pub async fn get_auth_token_challenge(query_params: HashMap<String, String>, poo
// Generate an ephemeral key pair // Generate an ephemeral key pair
let (ephemeral_private_key, ephemeral_public_key) = crypto::generate_x25519_key_pair().await; let (ephemeral_private_key, ephemeral_public_key) = crypto::generate_x25519_key_pair().await;
// Generate a symmetric key from the requesting user's public key and the ephemeral private key // Generate a symmetric key from the requesting user's public key and the ephemeral private key
let symmetric_key = crypto::get_x25519_symmetric_key(&public_key, &ephemeral_private_key).await?; let symmetric_key =
crypto::get_x25519_symmetric_key(&public_key, &ephemeral_private_key).await?;
// Generate a random token // Generate a random token
let mut token = [0u8; 48]; let mut token = [0u8; 48];
thread_rng().fill(&mut token[..]); thread_rng().fill(&mut token[..]);
@ -159,7 +178,10 @@ pub async fn get_auth_token_challenge(query_params: HashMap<String, String>, poo
// Note that a given public key can have multiple pending tokens // Note that a given public key can have multiple pending tokens
let now = chrono::Utc::now().timestamp(); let now = chrono::Utc::now().timestamp();
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let stmt = format!("INSERT INTO {} (public_key, timestamp, token) VALUES (?1, ?2, ?3)", storage::PENDING_TOKENS_TABLE); let stmt = format!(
"INSERT INTO {} (public_key, timestamp, token) VALUES (?1, ?2, ?3)",
storage::PENDING_TOKENS_TABLE
);
let _ = match conn.execute(&stmt, params![hex_public_key, now, token.to_vec()]) { let _ = match conn.execute(&stmt, params![hex_public_key, now, token.to_vec()]) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
@ -170,10 +192,15 @@ pub async fn get_auth_token_challenge(query_params: HashMap<String, String>, poo
// Encrypt the token with the symmetric key // Encrypt the token with the symmetric key
let ciphertext = crypto::encrypt_aes_gcm(&token, &symmetric_key).await?; let ciphertext = crypto::encrypt_aes_gcm(&token, &symmetric_key).await?;
// Return // Return
return Ok(models::Challenge { ciphertext : base64::encode(ciphertext), ephemeral_public_key : base64::encode(ephemeral_public_key.to_bytes()) }); return Ok(models::Challenge {
ciphertext: base64::encode(ciphertext),
ephemeral_public_key: base64::encode(ephemeral_public_key.to_bytes())
});
} }
pub async fn claim_auth_token(public_key: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn claim_auth_token(
public_key: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Validate the public key // Validate the public key
if !is_valid_public_key(&public_key) { if !is_valid_public_key(&public_key) {
println!("Ignoring claim token request for invalid public key."); println!("Ignoring claim token request for invalid public key.");
@ -187,13 +214,16 @@ pub async fn claim_auth_token(public_key: &str, auth_token: &str, pool: &storage
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Get the pending tokens for the given public key // Get the pending tokens for the given public key
let raw_query = format!("SELECT timestamp, token FROM {} WHERE public_key = (?1) AND timestamp > (?2)", storage::PENDING_TOKENS_TABLE); let raw_query = format!(
"SELECT timestamp, token FROM {} WHERE public_key = (?1) AND timestamp > (?2)",
storage::PENDING_TOKENS_TABLE
);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let now = chrono::Utc::now().timestamp(); let now = chrono::Utc::now().timestamp();
let expiration = now - storage::PENDING_TOKEN_EXPIRATION; let expiration = now - storage::PENDING_TOKEN_EXPIRATION;
let rows = match query.query_map(params![ public_key, expiration ], |row| { let rows = match query
Ok((row.get(0)?, row.get(1)?)) .query_map(params![public_key, expiration], |row| Ok((row.get(0)?, row.get(1)?)))
}) { {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't get pending tokens due to error: {}.", e); println!("Couldn't get pending tokens due to error: {}.", e);
@ -203,10 +233,16 @@ pub async fn claim_auth_token(public_key: &str, auth_token: &str, pool: &storage
let pending_tokens: Vec<(i64, Vec<u8>)> = rows.filter_map(|result| result.ok()).collect(); let pending_tokens: Vec<(i64, Vec<u8>)> = rows.filter_map(|result| result.ok()).collect();
// Check that the token being claimed is in fact one of the pending tokens // Check that the token being claimed is in fact one of the pending tokens
let claim = hex::decode(auth_token).unwrap(); // Safe because we validated it above let claim = hex::decode(auth_token).unwrap(); // Safe because we validated it above
let index = pending_tokens.iter().position(|(_, pending_token)| *pending_token == claim).ok_or_else(|| Error::Unauthorized)?; let index = pending_tokens
.iter()
.position(|(_, pending_token)| *pending_token == claim)
.ok_or_else(|| Error::Unauthorized)?;
let token = &pending_tokens[index].1; let token = &pending_tokens[index].1;
// Store the claimed token // Store the claimed token
let stmt = format!("INSERT OR REPLACE INTO {} (public_key, token) VALUES (?1, ?2)", storage::TOKENS_TABLE); let stmt = format!(
"INSERT OR REPLACE INTO {} (public_key, token) VALUES (?1, ?2)",
storage::TOKENS_TABLE
);
match conn.execute(&stmt, params![public_key, hex::encode(token)]) { match conn.execute(&stmt, params![public_key, hex::encode(token)]) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
@ -225,10 +261,15 @@ pub async fn claim_auth_token(public_key: &str, auth_token: &str, pool: &storage
return Ok(warp::reply::json(&json).into_response()); return Ok(warp::reply::json(&json).into_response());
} }
pub async fn delete_auth_token(auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn delete_auth_token(
auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, requesting_public_key) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, requesting_public_key) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Delete the token // Delete the token
@ -248,20 +289,28 @@ pub async fn delete_auth_token(auth_token: &str, pool: &storage::DatabaseConnect
// Message sending & receiving // Message sending & receiving
/// Inserts the given `message` into the database if it's valid. /// Inserts the given `message` into the database if it's valid.
pub async fn insert_message(mut message: models::Message, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn insert_message(
mut message: models::Message, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Validate the message // Validate the message
if !message.is_valid() { if !message.is_valid() {
println!("Ignoring invalid message."); println!("Ignoring invalid message.");
return Err(warp::reject::custom(Error::ValidationFailed)); return Err(warp::reject::custom(Error::ValidationFailed));
} }
// Check authorization level // Check authorization level
let (has_authorization_level, requesting_public_key) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, requesting_public_key) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Get a connection and open a transaction // Get a connection and open a transaction
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?; let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the message // Insert the message
let stmt = format!("INSERT INTO {} (public_key, data, signature) VALUES (?1, ?2, ?3)", storage::MESSAGES_TABLE); let stmt = format!(
"INSERT INTO {} (public_key, data, signature) VALUES (?1, ?2, ?3)",
storage::MESSAGES_TABLE
);
match tx.execute(&stmt, params![&requesting_public_key, message.data, message.signature]) { match tx.execute(&stmt, params![&requesting_public_key, message.data, message.signature]) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
@ -280,15 +329,20 @@ pub async fn insert_message(mut message: models::Message, auth_token: &str, pool
status_code: u16, status_code: u16,
message: models::Message message: models::Message
} }
let response = Response { status_code : StatusCode::OK.as_u16(), message : message }; let response = Response { status_code: StatusCode::OK.as_u16(), message };
return Ok(warp::reply::json(&response).into_response()); return Ok(warp::reply::json(&response).into_response());
} }
/// Returns either the last `limit` messages or all messages since `from_server_id, limited to `limit`. /// Returns either the last `limit` messages or all messages since `from_server_id, limited to `limit`.
pub async fn get_messages(query_params: HashMap<String, String>, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn get_messages(
query_params: HashMap<String, String>, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Unwrap query parameters // Unwrap query parameters
@ -309,11 +363,19 @@ pub async fn get_messages(query_params: HashMap<String, String>, auth_token: &st
if query_params.get("from_server_id").is_some() { if query_params.get("from_server_id").is_some() {
raw_query = format!("SELECT id, public_key, data, signature FROM {} WHERE rowid > (?1) ORDER BY rowid ASC LIMIT (?2)", storage::MESSAGES_TABLE); raw_query = format!("SELECT id, public_key, data, signature FROM {} WHERE rowid > (?1) ORDER BY rowid ASC LIMIT (?2)", storage::MESSAGES_TABLE);
} else { } else {
raw_query = format!("SELECT id, public_key, data, signature FROM {} ORDER BY rowid DESC LIMIT (?2)", storage::MESSAGES_TABLE); raw_query = format!(
"SELECT id, public_key, data, signature FROM {} ORDER BY rowid DESC LIMIT (?2)",
storage::MESSAGES_TABLE
);
} }
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![from_server_id, limit], |row| { let rows = match query.query_map(params![from_server_id, limit], |row| {
Ok(models::Message { server_id : row.get(0)?, public_key : row.get(1)?, data : row.get(2)?, signature : row.get(3)? }) Ok(models::Message {
server_id: row.get(0)?,
public_key: row.get(1)?,
data: row.get(2)?,
signature: row.get(3)?
})
}) { }) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
@ -328,25 +390,29 @@ pub async fn get_messages(query_params: HashMap<String, String>, auth_token: &st
status_code: u16, status_code: u16,
messages: Vec<models::Message> messages: Vec<models::Message>
} }
let response = Response { status_code : StatusCode::OK.as_u16(), messages : messages }; let response = Response { status_code: StatusCode::OK.as_u16(), messages };
return Ok(warp::reply::json(&response).into_response()); return Ok(warp::reply::json(&response).into_response());
} }
// Message deletion // Message deletion
/// Deletes the message with the given `row_id` from the database, if it's present. /// Deletes the message with the given `row_id` from the database, if it's present.
pub async fn delete_message(row_id: i64, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn delete_message(
row_id: i64, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, requesting_public_key) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, requesting_public_key) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Check that the requesting user is either the sender of the message or a moderator // Check that the requesting user is either the sender of the message or a moderator
let sender_option: Option<String> = { let sender_option: Option<String> = {
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let raw_query = format!("SELECT public_key FROM {} WHERE rowid = (?1)", storage::MESSAGES_TABLE); let raw_query =
format!("SELECT public_key FROM {} WHERE rowid = (?1)", storage::MESSAGES_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![ row_id ], |row| { let rows = match query.query_map(params![row_id], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't delete message due to error: {}.", e); println!("Couldn't delete message due to error: {}.", e);
@ -357,7 +423,9 @@ pub async fn delete_message(row_id: i64, auth_token: &str, pool: &storage::Datab
public_keys.get(0).map(|s| s.to_string()) public_keys.get(0).map(|s| s.to_string())
}; };
let sender = sender_option.ok_or(warp::reject::custom(Error::DatabaseFailedInternally))?; let sender = sender_option.ok_or(warp::reject::custom(Error::DatabaseFailedInternally))?;
if !is_moderator(&requesting_public_key, pool).await? && requesting_public_key != sender { return Err(warp::reject::custom(Error::Unauthorized)); } if !is_moderator(&requesting_public_key, pool).await? && requesting_public_key != sender {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Get a connection and open a transaction // Get a connection and open a transaction
let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let mut conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?; let tx = conn.transaction().map_err(|_| Error::DatabaseFailedInternally)?;
@ -389,10 +457,15 @@ pub async fn delete_message(row_id: i64, auth_token: &str, pool: &storage::Datab
} }
/// Returns either the last `limit` deleted messages or all deleted messages since `from_server_id, limited to `limit`. /// Returns either the last `limit` deleted messages or all deleted messages since `from_server_id, limited to `limit`.
pub async fn get_deleted_messages(query_params: HashMap<String, String>, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn get_deleted_messages(
query_params: HashMap<String, String>, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Unwrap query parameters // Unwrap query parameters
@ -411,14 +484,18 @@ pub async fn get_deleted_messages(query_params: HashMap<String, String>, auth_to
// Query the database // Query the database
let raw_query: String; let raw_query: String;
if query_params.get("from_server_id").is_some() { if query_params.get("from_server_id").is_some() {
raw_query = format!("SELECT id FROM {} WHERE rowid > (?1) ORDER BY rowid ASC LIMIT (?2)", storage::DELETED_MESSAGES_TABLE); raw_query = format!(
"SELECT id FROM {} WHERE rowid > (?1) ORDER BY rowid ASC LIMIT (?2)",
storage::DELETED_MESSAGES_TABLE
);
} else { } else {
raw_query = format!("SELECT id FROM {} ORDER BY rowid DESC LIMIT (?2)", storage::DELETED_MESSAGES_TABLE); raw_query = format!(
"SELECT id FROM {} ORDER BY rowid DESC LIMIT (?2)",
storage::DELETED_MESSAGES_TABLE
);
} }
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![ from_server_id, limit ], |row| { let rows = match query.query_map(params![from_server_id, limit], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't query database due to error: {}.", e); println!("Couldn't query database due to error: {}.", e);
@ -432,17 +509,22 @@ pub async fn get_deleted_messages(query_params: HashMap<String, String>, auth_to
status_code: u16, status_code: u16,
ids: Vec<i64> ids: Vec<i64>
} }
let response = Response { status_code : StatusCode::OK.as_u16(), ids : ids }; let response = Response { status_code: StatusCode::OK.as_u16(), ids };
return Ok(warp::reply::json(&response).into_response()); return Ok(warp::reply::json(&response).into_response());
} }
// Moderation // Moderation
/// Returns the full list of moderators. /// Returns the full list of moderators.
pub async fn get_moderators(auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn get_moderators(
auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Return // Return
let public_keys = get_moderators_vector(pool).await?; let public_keys = get_moderators_vector(pool).await?;
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@ -455,17 +537,24 @@ pub async fn get_moderators(auth_token: &str, pool: &storage::DatabaseConnection
} }
/// Bans the given `public_key` if the requesting user is a moderator. /// Bans the given `public_key` if the requesting user is a moderator.
pub async fn ban(public_key: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn ban(
public_key: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Validate the public key // Validate the public key
if !is_valid_public_key(&public_key) { if !is_valid_public_key(&public_key) {
println!("Ignoring ban request for invalid public key."); println!("Ignoring ban request for invalid public key.");
return Err(warp::reject::custom(Error::ValidationFailed)); return Err(warp::reject::custom(Error::ValidationFailed));
} }
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Moderator, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Moderator, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Don't double ban public keys // Don't double ban public keys
if is_banned(&public_key, pool).await? { return Ok(StatusCode::OK.into_response()); } if is_banned(&public_key, pool).await? {
return Ok(StatusCode::OK.into_response());
}
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the message // Insert the message
@ -483,17 +572,24 @@ pub async fn ban(public_key: &str, auth_token: &str, pool: &storage::DatabaseCon
} }
/// Unbans the given `public_key` if the requesting user is a moderator. /// Unbans the given `public_key` if the requesting user is a moderator.
pub async fn unban(public_key: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn unban(
public_key: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Validate the public key // Validate the public key
if !is_valid_public_key(&public_key) { if !is_valid_public_key(&public_key) {
println!("Ignoring unban request for invalid public key."); println!("Ignoring unban request for invalid public key.");
return Err(warp::reject::custom(Error::ValidationFailed)); return Err(warp::reject::custom(Error::ValidationFailed));
} }
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Moderator, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Moderator, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Don't double unban public keys // Don't double unban public keys
if !is_banned(&public_key, pool).await? { return Ok(StatusCode::OK.into_response()); } if !is_banned(&public_key, pool).await? {
return Ok(StatusCode::OK.into_response());
}
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Insert the message // Insert the message
@ -511,10 +607,15 @@ pub async fn unban(public_key: &str, auth_token: &str, pool: &storage::DatabaseC
} }
/// Returns the full list of banned public keys. /// Returns the full list of banned public keys.
pub async fn get_banned_public_keys(auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn get_banned_public_keys(
auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Return // Return
let public_keys = get_banned_public_keys_vector(pool).await?; let public_keys = get_banned_public_keys_vector(pool).await?;
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@ -528,18 +629,21 @@ pub async fn get_banned_public_keys(auth_token: &str, pool: &storage::DatabaseCo
// General // General
pub async fn get_member_count(auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { pub async fn get_member_count(
auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check authorization level // Check authorization level
let (has_authorization_level, _) = has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?; let (has_authorization_level, _) =
if !has_authorization_level { return Err(warp::reject::custom(Error::Unauthorized)); } has_authorization_level(auth_token, AuthorizationLevel::Basic, pool).await?;
if !has_authorization_level {
return Err(warp::reject::custom(Error::Unauthorized));
}
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database // Query the database
let raw_query = format!("SELECT public_key FROM {}", storage::TOKENS_TABLE); let raw_query = format!("SELECT public_key FROM {}", storage::TOKENS_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![], |row| { let rows = match query.query_map(params![], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't query database due to error: {}.", e); println!("Couldn't query database due to error: {}.", e);
@ -554,21 +658,22 @@ pub async fn get_member_count(auth_token: &str, pool: &storage::DatabaseConnecti
status_code: u16, status_code: u16,
member_count: usize member_count: usize
} }
let response = Response { status_code : StatusCode::OK.as_u16(), member_count : public_key_count }; let response =
Response { status_code: StatusCode::OK.as_u16(), member_count: public_key_count };
return Ok(warp::reply::json(&response).into_response()); return Ok(warp::reply::json(&response).into_response());
} }
// Utilities // Utilities
async fn get_moderators_vector(pool: &storage::DatabaseConnectionPool) -> Result<Vec<String>, Rejection> { async fn get_moderators_vector(
pool: &storage::DatabaseConnectionPool
) -> Result<Vec<String>, Rejection> {
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database // Query the database
let raw_query = format!("SELECT public_key FROM {}", storage::MODERATORS_TABLE); let raw_query = format!("SELECT public_key FROM {}", storage::MODERATORS_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![], |row| { let rows = match query.query_map(params![], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't query database due to error: {}.", e); println!("Couldn't query database due to error: {}.", e);
@ -579,20 +684,22 @@ async fn get_moderators_vector(pool: &storage::DatabaseConnectionPool) -> Result
return Ok(rows.filter_map(|result| result.ok()).collect()); return Ok(rows.filter_map(|result| result.ok()).collect());
} }
async fn is_moderator(public_key: &str, pool: &storage::DatabaseConnectionPool) -> Result<bool, Rejection> { async fn is_moderator(
public_key: &str, pool: &storage::DatabaseConnectionPool
) -> Result<bool, Rejection> {
let public_keys = get_moderators_vector(&pool).await?; let public_keys = get_moderators_vector(&pool).await?;
return Ok(public_keys.contains(&public_key.to_owned())); return Ok(public_keys.contains(&public_key.to_owned()));
} }
async fn get_banned_public_keys_vector(pool: &storage::DatabaseConnectionPool) -> Result<Vec<String>, Rejection> { async fn get_banned_public_keys_vector(
pool: &storage::DatabaseConnectionPool
) -> Result<Vec<String>, Rejection> {
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database // Query the database
let raw_query = format!("SELECT public_key FROM {}", storage::BLOCK_LIST_TABLE); let raw_query = format!("SELECT public_key FROM {}", storage::BLOCK_LIST_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![], |row| { let rows = match query.query_map(params![], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't query database due to error: {}.", e); println!("Couldn't query database due to error: {}.", e);
@ -603,29 +710,35 @@ async fn get_banned_public_keys_vector(pool: &storage::DatabaseConnectionPool) -
return Ok(rows.filter_map(|result| result.ok()).collect()); return Ok(rows.filter_map(|result| result.ok()).collect());
} }
async fn is_banned(public_key: &str, pool: &storage::DatabaseConnectionPool) -> Result<bool, Rejection> { async fn is_banned(
public_key: &str, pool: &storage::DatabaseConnectionPool
) -> Result<bool, Rejection> {
let public_keys = get_banned_public_keys_vector(&pool).await?; let public_keys = get_banned_public_keys_vector(&pool).await?;
return Ok(public_keys.contains(&public_key.to_owned())); return Ok(public_keys.contains(&public_key.to_owned()));
} }
fn is_valid_public_key(public_key: &str) -> bool { fn is_valid_public_key(public_key: &str) -> bool {
// Check that it's a valid hex encoding // Check that it's a valid hex encoding
if hex::decode(public_key).is_err() { return false; } if hex::decode(public_key).is_err() {
return false;
}
// Check that it's the right length // Check that it's the right length
if public_key.len() != 66 { return false } // The version byte + 32 bytes of random data if public_key.len() != 66 {
return false;
} // The version byte + 32 bytes of random data
// It appears to be a valid public key // It appears to be a valid public key
return true return true;
} }
async fn get_public_key_for_auth_token(auth_token: &str, pool: &storage::DatabaseConnectionPool) -> Result<Option<String>, Rejection> { async fn get_public_key_for_auth_token(
auth_token: &str, pool: &storage::DatabaseConnectionPool
) -> Result<Option<String>, Rejection> {
// Get a database connection // Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?; let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Query the database // Query the database
let raw_query = format!("SELECT public_key FROM {} WHERE token = (?1)", storage::TOKENS_TABLE); let raw_query = format!("SELECT public_key FROM {} WHERE token = (?1)", storage::TOKENS_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![ auth_token ], |row| { let rows = match query.query_map(params![auth_token], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't query database due to error: {}.", e); println!("Couldn't query database due to error: {}.", e);
@ -637,17 +750,23 @@ async fn get_public_key_for_auth_token(auth_token: &str, pool: &storage::Databas
return Ok(public_keys.get(0).map(|s| s.to_string())); return Ok(public_keys.get(0).map(|s| s.to_string()));
} }
async fn has_authorization_level(auth_token: &str, level: AuthorizationLevel, pool: &storage::DatabaseConnectionPool) -> Result<(bool, String), Rejection> { async fn has_authorization_level(
auth_token: &str, level: AuthorizationLevel, pool: &storage::DatabaseConnectionPool
) -> Result<(bool, String), Rejection> {
// Check that we have a public key associated with the given auth token // Check that we have a public key associated with the given auth token
let public_key_option = get_public_key_for_auth_token(auth_token, pool).await?; let public_key_option = get_public_key_for_auth_token(auth_token, pool).await?;
let public_key = public_key_option.ok_or(warp::reject::custom(Error::Unauthorized))?; let public_key = public_key_option.ok_or(warp::reject::custom(Error::Unauthorized))?;
// Check that the given public key isn't banned // Check that the given public key isn't banned
if is_banned(&public_key, pool).await? { return Err(warp::reject::custom(Error::Unauthorized)); } if is_banned(&public_key, pool).await? {
return Err(warp::reject::custom(Error::Unauthorized));
}
// If needed, check that the given public key is a moderator // If needed, check that the given public key is a moderator
match level { match level {
AuthorizationLevel::Basic => return Ok((true, public_key)), AuthorizationLevel::Basic => return Ok((true, public_key)),
AuthorizationLevel::Moderator => { AuthorizationLevel::Moderator => {
if !is_moderator(&public_key, pool).await? { return Err(warp::reject::custom(Error::Unauthorized)); } if !is_moderator(&public_key, pool).await? {
return Err(warp::reject::custom(Error::Unauthorized));
}
return Ok((true, public_key)); return Ok((true, public_key));
} }
}; };

View File

@ -1,8 +1,8 @@
use std::fs; use std::fs;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use futures::join; use futures::join;
use structopt::StructOpt; use structopt::StructOpt;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio; use tokio;
use warp::Filter; use warp::Filter;
@ -39,7 +39,7 @@ struct Opt {
/// Set IP to bind to. /// Set IP to bind to.
#[structopt(short = "H", long = "host", default_value = "0.0.0.0")] #[structopt(short = "H", long = "host", default_value = "0.0.0.0")]
host: Ipv4Addr, host: Ipv4Addr
} }
#[tokio::main] #[tokio::main]
@ -69,7 +69,12 @@ async fn main() {
println!("Running in plaintext mode on {}.", addr); println!("Running in plaintext mode on {}.", addr);
let serve_routes_future = warp::serve(routes).run(addr); let serve_routes_future = warp::serve(routes).run(addr);
// Keep futures alive // Keep futures alive
join!(prune_pending_tokens_future, prune_tokens_future, prune_files_future, serve_routes_future); join!(
prune_pending_tokens_future,
prune_tokens_future,
prune_files_future,
serve_routes_future
);
} else { } else {
println!("Running on {} with TLS.", addr); println!("Running on {} with TLS.", addr);
let serve_routes_future = warp::serve(routes) let serve_routes_future = warp::serve(routes)
@ -78,6 +83,11 @@ async fn main() {
.key_path(opt.tls_priv_key_file) .key_path(opt.tls_priv_key_file)
.run(addr); .run(addr);
// Keep futures alive // Keep futures alive
join!(prune_pending_tokens_future, prune_tokens_future, prune_files_future, serve_routes_future); join!(
prune_pending_tokens_future,
prune_tokens_future,
prune_files_future,
serve_routes_future
);
} }
} }

View File

@ -9,10 +9,7 @@ pub struct Message {
} }
impl Message { impl Message {
pub fn is_valid(&self) -> bool { return !self.data.is_empty() && !self.signature.is_empty(); }
pub fn is_valid(&self) -> bool {
return !self.data.is_empty() && !self.signature.is_empty();
}
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]

View File

@ -1,7 +1,7 @@
use std::convert::TryInto; use std::convert::TryInto;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use warp::{Rejection, reply::Reply, reply::Response, http::StatusCode}; use warp::{http::StatusCode, reply::Reply, reply::Response, Rejection};
use super::crypto; use super::crypto;
use super::errors::Error; use super::errors::Error;
@ -34,7 +34,9 @@ pub async fn handle_onion_request(blob: warp::hyper::body::Bytes) -> Result<Resp
return handle_decrypted_onion_request(&plaintext, &symmetric_key).await; return handle_decrypted_onion_request(&plaintext, &symmetric_key).await;
} }
async fn handle_decrypted_onion_request(plaintext: &[u8], symmetric_key: &[u8]) -> Result<Response, Rejection> { async fn handle_decrypted_onion_request(
plaintext: &[u8], symmetric_key: &[u8]
) -> Result<Response, Rejection> {
let rpc_call = match serde_json::from_slice(plaintext) { let rpc_call = match serde_json::from_slice(plaintext) {
Ok(rpc_call) => rpc_call, Ok(rpc_call) => rpc_call,
Err(e) => { Err(e) => {
@ -43,7 +45,8 @@ async fn handle_decrypted_onion_request(plaintext: &[u8], symmetric_key: &[u8])
} }
}; };
// Perform the RPC call // Perform the RPC call
let result = rpc::handle_rpc_call(rpc_call).await let result = rpc::handle_rpc_call(rpc_call)
.await
// Turn any error that occurred into an HTTP response // Turn any error that occurred into an HTTP response
.or_else(super::errors::into_response)?; // Safe because at this point any error should be caught and turned into an HTTP response (i.e. an OK result) .or_else(super::errors::into_response)?; // Safe because at this point any error should be caught and turned into an HTTP response (i.e. an OK result)
// Encrypt the HTTP response so that it's propagated back to the client that made // Encrypt the HTTP response so that it's propagated back to the client that made
@ -51,7 +54,9 @@ async fn handle_decrypted_onion_request(plaintext: &[u8], symmetric_key: &[u8])
return encrypt_response(result, symmetric_key).await; return encrypt_response(result, symmetric_key).await;
} }
async fn parse_onion_request_payload(blob: warp::hyper::body::Bytes) -> Result<OnionRequestPayload, Rejection> { async fn parse_onion_request_payload(
blob: warp::hyper::body::Bytes
) -> Result<OnionRequestPayload, Rejection> {
// The encoding of an onion request looks like: | 4 bytes: size N of ciphertext | N bytes: ciphertext | json as utf8 | // The encoding of an onion request looks like: | 4 bytes: size N of ciphertext | N bytes: ciphertext | json as utf8 |
if blob.len() < 4 { if blob.len() < 4 {
println!("Ignoring blob of invalid size."); println!("Ignoring blob of invalid size.");
@ -88,9 +93,12 @@ async fn parse_onion_request_payload(blob: warp::hyper::body::Bytes) -> Result<O
} }
/// Returns the decrypted `payload.ciphertext` plus the `symmetric_key` that was used for decryption if successful. /// Returns the decrypted `payload.ciphertext` plus the `symmetric_key` that was used for decryption if successful.
async fn decrypt_onion_request_payload(payload: OnionRequestPayload) -> Result<(Vec<u8>, Vec<u8>), Rejection> { async fn decrypt_onion_request_payload(
payload: OnionRequestPayload
) -> Result<(Vec<u8>, Vec<u8>), Rejection> {
let ephemeral_key = hex::decode(payload.metadata.ephemeral_key).unwrap(); // Safe because it was validated in the parsing step let ephemeral_key = hex::decode(payload.metadata.ephemeral_key).unwrap(); // Safe because it was validated in the parsing step
let symmetric_key = crypto::get_x25519_symmetric_key(&ephemeral_key, &crypto::PRIVATE_KEY).await?; let symmetric_key =
crypto::get_x25519_symmetric_key(&ephemeral_key, &crypto::PRIVATE_KEY).await?;
let plaintext = crypto::decrypt_aes_gcm(&payload.ciphertext, &symmetric_key).await?; let plaintext = crypto::decrypt_aes_gcm(&payload.ciphertext, &symmetric_key).await?;
return Ok((plaintext, symmetric_key)); return Ok((plaintext, symmetric_key));
} }
@ -106,18 +114,16 @@ async fn encrypt_response(response: Response, symmetric_key: &[u8]) -> Result<Re
} }
let ciphertext = crypto::encrypt_aes_gcm(&bytes, symmetric_key).await.unwrap(); let ciphertext = crypto::encrypt_aes_gcm(&bytes, symmetric_key).await.unwrap();
let json = base64::encode(&ciphertext); let json = base64::encode(&ciphertext);
let response = warp::http::Response::builder() let response =
.status(StatusCode::OK.as_u16()) warp::http::Response::builder().status(StatusCode::OK.as_u16()).body(json).into_response();
.body(json)
.into_response();
return Ok(response); return Ok(response);
} }
// Utilities // Utilities
fn as_le_u32(array: &[u8; 4]) -> u32 { fn as_le_u32(array: &[u8; 4]) -> u32 {
((array[0] as u32) << 00) + ((array[0] as u32) << 00)
((array[1] as u32) << 08) + + ((array[1] as u32) << 08)
((array[2] as u32) << 16) + + ((array[2] as u32) << 16)
((array[3] as u32) << 24) + ((array[3] as u32) << 24)
} }

View File

@ -1,19 +1,19 @@
use warp::{Filter, Rejection, reply::Reply, reply::Response}; use warp::{reply::Reply, reply::Response, Filter, Rejection};
use super::errors; use super::errors;
use super::onion_requests; use super::onion_requests;
/// GET / /// GET /
pub fn root() -> impl Filter<Extract = impl warp::Reply, Error = Rejection> + Clone { pub fn root() -> impl Filter<Extract = impl warp::Reply, Error = Rejection> + Clone {
return warp::get() return warp::get().and(warp::path::end()).and_then(root_html);
.and(warp::path::end())
.and_then(root_html);
} }
/// POST /loki/v3/lsrpc /// POST /loki/v3/lsrpc
pub fn lsrpc() -> impl Filter<Extract = impl warp::Reply, Error = Rejection> + Clone { pub fn lsrpc() -> impl Filter<Extract = impl warp::Reply, Error = Rejection> + Clone {
return warp::post() return warp::post()
.and(warp::path("loki")).and(warp::path("v3")).and(warp::path("lsrpc")) .and(warp::path("loki"))
.and(warp::path("v3"))
.and(warp::path("lsrpc"))
.and(warp::body::content_length_limit(10 * 1024 * 1024)) // Match storage server .and(warp::body::content_length_limit(10 * 1024 * 1024)) // Match storage server
.and(warp::body::bytes()) // Expect bytes .and(warp::body::bytes()) // Expect bytes
.and_then(onion_requests::handle_onion_request) .and_then(onion_requests::handle_onion_request)

View File

@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use warp::{http::StatusCode, Rejection, reply::Reply, reply::Response}; use warp::{http::StatusCode, reply::Reply, reply::Response, Rejection};
use super::errors::Error; use super::errors::Error;
use super::handlers; use super::handlers;
@ -22,7 +22,7 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
Some(room_id) => room_id, Some(room_id) => room_id,
None => { None => {
println!("Missing room ID."); println!("Missing room ID.");
return Err(warp::reject::custom(Error::InvalidRpcCall)) return Err(warp::reject::custom(Error::InvalidRpcCall));
} }
}; };
let pool = storage::pool_by_room_id(&room_id); let pool = storage::pool_by_room_id(&room_id);
@ -60,7 +60,10 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
} }
} }
async fn handle_get_request(rpc_call: RpcCall, path: &str, auth_token: Option<String>, query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { async fn handle_get_request(
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Getting an auth token challenge doesn't require authorization, so we // Getting an auth token challenge doesn't require authorization, so we
// handle it first // handle it first
if path == "auth_token_challenge" { if path == "auth_token_challenge" {
@ -70,7 +73,7 @@ async fn handle_get_request(rpc_call: RpcCall, path: &str, auth_token: Option<St
status_code: u16, status_code: u16,
challenge: models::Challenge challenge: models::Challenge
} }
let response = Response { status_code : StatusCode::OK.as_u16(), challenge : challenge }; let response = Response { status_code: StatusCode::OK.as_u16(), challenge };
return Ok(warp::reply::json(&response).into_response()); return Ok(warp::reply::json(&response).into_response());
} }
// Check that the auth token is present // Check that the auth token is present
@ -83,11 +86,15 @@ async fn handle_get_request(rpc_call: RpcCall, path: &str, auth_token: Option<St
return Err(warp::reject::custom(Error::InvalidRpcCall)); return Err(warp::reject::custom(Error::InvalidRpcCall));
} }
let file_id = components[1]; let file_id = components[1];
return handlers::get_file(file_id, &auth_token, &pool).await.map(|json| warp::reply::json(&json).into_response()); return handlers::get_file(file_id, &auth_token, &pool)
.await
.map(|json| warp::reply::json(&json).into_response());
} }
match path { match path {
"messages" => return handlers::get_messages(query_params, &auth_token, pool).await, "messages" => return handlers::get_messages(query_params, &auth_token, pool).await,
"deleted_messages" => return handlers::get_deleted_messages(query_params, &auth_token, pool).await, "deleted_messages" => {
return handlers::get_deleted_messages(query_params, &auth_token, pool).await
}
"moderators" => return handlers::get_moderators(&auth_token, pool).await, "moderators" => return handlers::get_moderators(&auth_token, pool).await,
"block_list" => return handlers::get_banned_public_keys(&auth_token, pool).await, "block_list" => return handlers::get_banned_public_keys(&auth_token, pool).await,
"member_count" => return handlers::get_member_count(&auth_token, pool).await, "member_count" => return handlers::get_member_count(&auth_token, pool).await,
@ -98,9 +105,9 @@ async fn handle_get_request(rpc_call: RpcCall, path: &str, auth_token: Option<St
status_code: u16, status_code: u16,
challenge: models::Challenge challenge: models::Challenge
} }
let response = Response { status_code : StatusCode::OK.as_u16(), challenge : challenge }; let response = Response { status_code: StatusCode::OK.as_u16(), challenge };
return Ok(warp::reply::json(&response).into_response()); return Ok(warp::reply::json(&response).into_response());
}, }
_ => { _ => {
println!("Ignoring RPC call with invalid or unused endpoint: {}.", rpc_call.endpoint); println!("Ignoring RPC call with invalid or unused endpoint: {}.", rpc_call.endpoint);
return Err(warp::reject::custom(Error::InvalidRpcCall)); return Err(warp::reject::custom(Error::InvalidRpcCall));
@ -108,7 +115,10 @@ async fn handle_get_request(rpc_call: RpcCall, path: &str, auth_token: Option<St
} }
} }
async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { async fn handle_post_request(
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check that the auth token is present // Check that the auth token is present
let auth_token = auth_token.ok_or(warp::reject::custom(Error::Unauthorized))?; let auth_token = auth_token.ok_or(warp::reject::custom(Error::Unauthorized))?;
// Switch on the path // Switch on the path
@ -122,10 +132,12 @@ async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<S
} }
}; };
return handlers::insert_message(message, &auth_token, pool).await; return handlers::insert_message(message, &auth_token, pool).await;
}, }
"block_list" => { "block_list" => {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct JSON { public_key: String } struct JSON {
public_key: String
}
let json: JSON = match serde_json::from_str(&rpc_call.body) { let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(message) => message, Ok(message) => message,
Err(e) => { Err(e) => {
@ -134,10 +146,12 @@ async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<S
} }
}; };
return handlers::ban(&json.public_key, &auth_token, pool).await; return handlers::ban(&json.public_key, &auth_token, pool).await;
}, }
"claim_auth_token" => { "claim_auth_token" => {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct JSON { public_key: String } struct JSON {
public_key: String
}
let json: JSON = match serde_json::from_str(&rpc_call.body) { let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(message) => message, Ok(message) => message,
Err(e) => { Err(e) => {
@ -146,10 +160,12 @@ async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<S
} }
}; };
return handlers::claim_auth_token(&json.public_key, &auth_token, pool).await; return handlers::claim_auth_token(&json.public_key, &auth_token, pool).await;
}, }
"files" => { "files" => {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct JSON { file: String } struct JSON {
file: String
}
let json: JSON = match serde_json::from_str(&rpc_call.body) { let json: JSON = match serde_json::from_str(&rpc_call.body) {
Ok(message) => message, Ok(message) => message,
Err(e) => { Err(e) => {
@ -158,7 +174,7 @@ async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<S
} }
}; };
return handlers::store_file(&json.file, &auth_token, pool).await; return handlers::store_file(&json.file, &auth_token, pool).await;
}, }
_ => { _ => {
println!("Ignoring RPC call with invalid or unused endpoint: {}.", rpc_call.endpoint); println!("Ignoring RPC call with invalid or unused endpoint: {}.", rpc_call.endpoint);
return Err(warp::reject::custom(Error::InvalidRpcCall)); return Err(warp::reject::custom(Error::InvalidRpcCall));
@ -166,7 +182,10 @@ async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<S
} }
} }
async fn handle_delete_request(rpc_call: RpcCall, path: &str, auth_token: Option<String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> { async fn handle_delete_request(
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
pool: &storage::DatabaseConnectionPool
) -> Result<Response, Rejection> {
// Check that the auth token is present // Check that the auth token is present
let auth_token = auth_token.ok_or(warp::reject::custom(Error::Unauthorized))?; let auth_token = auth_token.ok_or(warp::reject::custom(Error::Unauthorized))?;
// DELETE /messages/:server_id // DELETE /messages/:server_id
@ -207,11 +226,15 @@ async fn handle_delete_request(rpc_call: RpcCall, path: &str, auth_token: Option
// Utilities // Utilities
fn get_auth_token(rpc_call: &RpcCall) -> Option<String> { fn get_auth_token(rpc_call: &RpcCall) -> Option<String> {
if rpc_call.headers.is_empty() { return None; } if rpc_call.headers.is_empty() {
return None;
}
return rpc_call.headers.get("Authorization").map(|s| s.to_string()); return rpc_call.headers.get("Authorization").map(|s| s.to_string());
} }
fn get_room_id(rpc_call: &RpcCall) -> Option<String> { fn get_room_id(rpc_call: &RpcCall) -> Option<String> {
if rpc_call.headers.is_empty() { return None; } if rpc_call.headers.is_empty() {
return None;
}
return rpc_call.headers.get("Room").map(|s| s.to_string()); return rpc_call.headers.get("Room").map(|s| s.to_string());
} }

View File

@ -3,8 +3,8 @@ use std::fs;
use std::path::Path; use std::path::Path;
use std::sync::Mutex; use std::sync::Mutex;
use rusqlite::params;
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::params;
use super::errors::Error; use super::errors::Error;
@ -35,7 +35,9 @@ fn create_main_tables_if_needed(conn: &DatabaseConnection) {
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE IF NOT EXISTS {} (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
name TEXT name TEXT
)", MAIN_TABLE); )",
MAIN_TABLE
);
conn.execute(&main_table_cmd, params![]).expect("Couldn't create main table."); conn.execute(&main_table_cmd, params![]).expect("Couldn't create main table.");
} }
@ -88,25 +90,34 @@ fn create_room_tables_if_needed(conn: &DatabaseConnection) {
public_key TEXT, public_key TEXT,
data TEXT, data TEXT,
signature TEXT signature TEXT
)", MESSAGES_TABLE); )",
MESSAGES_TABLE
);
conn.execute(&messages_table_cmd, params![]).expect("Couldn't create messages table."); conn.execute(&messages_table_cmd, params![]).expect("Couldn't create messages table.");
// Deleted messages // Deleted messages
let deleted_messages_table_cmd = format!( let deleted_messages_table_cmd = format!(
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE IF NOT EXISTS {} (
id INTEGER PRIMARY KEY id INTEGER PRIMARY KEY
)", DELETED_MESSAGES_TABLE); )",
conn.execute(&deleted_messages_table_cmd, params![]).expect("Couldn't create deleted messages table."); DELETED_MESSAGES_TABLE
);
conn.execute(&deleted_messages_table_cmd, params![])
.expect("Couldn't create deleted messages table.");
// Moderators // Moderators
let moderators_table_cmd = format!( let moderators_table_cmd = format!(
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE IF NOT EXISTS {} (
public_key TEXT public_key TEXT
)", MODERATORS_TABLE); )",
MODERATORS_TABLE
);
conn.execute(&moderators_table_cmd, params![]).expect("Couldn't create moderators table."); conn.execute(&moderators_table_cmd, params![]).expect("Couldn't create moderators table.");
// Block list // Block list
let block_list_table_cmd = format!( let block_list_table_cmd = format!(
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE IF NOT EXISTS {} (
public_key TEXT public_key TEXT
)", BLOCK_LIST_TABLE); )",
BLOCK_LIST_TABLE
);
conn.execute(&block_list_table_cmd, params![]).expect("Couldn't create block list table."); conn.execute(&block_list_table_cmd, params![]).expect("Couldn't create block list table.");
// Pending tokens // Pending tokens
// Note that a given public key can have multiple pending tokens // Note that a given public key can have multiple pending tokens
@ -115,8 +126,11 @@ fn create_room_tables_if_needed(conn: &DatabaseConnection) {
public_key STRING, public_key STRING,
timestamp INTEGER, timestamp INTEGER,
token BLOB token BLOB
)", PENDING_TOKENS_TABLE); )",
conn.execute(&pending_tokens_table_cmd, params![]).expect("Couldn't create pending tokens table."); PENDING_TOKENS_TABLE
);
conn.execute(&pending_tokens_table_cmd, params![])
.expect("Couldn't create pending tokens table.");
// Tokens // Tokens
// The token is stored as hex here (rather than as bytes) because it's more convenient for lookup // The token is stored as hex here (rather than as bytes) because it's more convenient for lookup
let tokens_table_cmd = format!( let tokens_table_cmd = format!(
@ -124,14 +138,18 @@ fn create_room_tables_if_needed(conn: &DatabaseConnection) {
public_key STRING PRIMARY KEY, public_key STRING PRIMARY KEY,
timestamp INTEGER, timestamp INTEGER,
token TEXT token TEXT
)", TOKENS_TABLE); )",
TOKENS_TABLE
);
conn.execute(&tokens_table_cmd, params![]).expect("Couldn't create tokens table."); conn.execute(&tokens_table_cmd, params![]).expect("Couldn't create tokens table.");
// Files // Files
let files_table_cmd = format!( let files_table_cmd = format!(
"CREATE TABLE IF NOT EXISTS {} ( "CREATE TABLE IF NOT EXISTS {} (
id STRING PRIMARY KEY, id STRING PRIMARY KEY,
timestamp INTEGER timestamp INTEGER
)", FILES_TABLE); )",
FILES_TABLE
);
conn.execute(&files_table_cmd, params![]).expect("Couldn't create files table."); conn.execute(&files_table_cmd, params![]).expect("Couldn't create files table.");
} }
@ -141,7 +159,9 @@ pub async fn prune_tokens_periodically() {
let mut timer = tokio::time::interval(chrono::Duration::minutes(10).to_std().unwrap()); let mut timer = tokio::time::interval(chrono::Duration::minutes(10).to_std().unwrap());
loop { loop {
timer.tick().await; timer.tick().await;
tokio::spawn(async { prune_tokens().await; }); tokio::spawn(async {
prune_tokens().await;
});
} }
} }
@ -149,7 +169,9 @@ pub async fn prune_pending_tokens_periodically() {
let mut timer = tokio::time::interval(chrono::Duration::minutes(10).to_std().unwrap()); let mut timer = tokio::time::interval(chrono::Duration::minutes(10).to_std().unwrap());
loop { loop {
timer.tick().await; timer.tick().await;
tokio::spawn(async { prune_pending_tokens().await; }); tokio::spawn(async {
prune_pending_tokens().await;
});
} }
} }
@ -157,7 +179,9 @@ pub async fn prune_files_periodically() {
let mut timer = tokio::time::interval(chrono::Duration::days(1).to_std().unwrap()); let mut timer = tokio::time::interval(chrono::Duration::days(1).to_std().unwrap());
loop { loop {
timer.tick().await; timer.tick().await;
tokio::spawn(async { prune_files(FILE_EXPIRATION).await; }); tokio::spawn(async {
prune_files(FILE_EXPIRATION).await;
});
} }
} }
@ -207,7 +231,8 @@ async fn prune_pending_tokens() {
println!("Pruned pending tokens."); println!("Pruned pending tokens.");
} }
pub async fn prune_files(file_expiration: i64) { // The expiration setting is passed in for testing purposes pub async fn prune_files(file_expiration: i64) {
// The expiration setting is passed in for testing purposes
let rooms = match get_all_room_ids().await { let rooms = match get_all_room_ids().await {
Ok(rooms) => rooms, Ok(rooms) => rooms,
Err(_) => return Err(_) => return
@ -228,9 +253,7 @@ pub async fn prune_files(file_expiration: i64) { // The expiration setting is pa
Ok(query) => query, Ok(query) => query,
Err(e) => return println!("Couldn't prune files due to error: {}.", e) Err(e) => return println!("Couldn't prune files due to error: {}.", e)
}; };
let rows = match query.query_map(params![ expiration ], |row| { let rows = match query.query_map(params![expiration], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
return println!("Couldn't prune files due to error: {}.", e); return println!("Couldn't prune files due to error: {}.", e);
@ -263,9 +286,7 @@ async fn get_all_room_ids() -> Result<Vec<String>, Error> {
// Query the database // Query the database
let raw_query = format!("SELECT id FROM {}", MAIN_TABLE); let raw_query = format!("SELECT id FROM {}", MAIN_TABLE);
let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?; let mut query = conn.prepare(&raw_query).map_err(|_| Error::DatabaseFailedInternally)?;
let rows = match query.query_map(params![], |row| { let rows = match query.query_map(params![], |row| Ok(row.get(0)?)) {
Ok(row.get(0)?)
}) {
Ok(rows) => rows, Ok(rows) => rows,
Err(e) => { Err(e) => {
println!("Couldn't query database due to error: {}.", e); println!("Couldn't query database due to error: {}.", e);

View File

@ -45,7 +45,8 @@ fn get_auth_token() -> (String, String) {
let challenge = aw!(handlers::get_auth_token_challenge(query_params, &pool)).unwrap(); let challenge = aw!(handlers::get_auth_token_challenge(query_params, &pool)).unwrap();
// Generate a symmetric key // Generate a symmetric key
let ephemeral_public_key = base64::decode(challenge.ephemeral_public_key).unwrap(); let ephemeral_public_key = base64::decode(challenge.ephemeral_public_key).unwrap();
let symmetric_key = aw!(crypto::get_x25519_symmetric_key(&ephemeral_public_key, &user_private_key)).unwrap(); let symmetric_key =
aw!(crypto::get_x25519_symmetric_key(&ephemeral_public_key, &user_private_key)).unwrap();
// Decrypt the challenge // Decrypt the challenge
let ciphertext = base64::decode(challenge.ciphertext).unwrap(); let ciphertext = base64::decode(challenge.ciphertext).unwrap();
let plaintext = aw!(crypto::decrypt_aes_gcm(&ciphertext, &symmetric_key)).unwrap(); let plaintext = aw!(crypto::decrypt_aes_gcm(&ciphertext, &symmetric_key)).unwrap();
@ -70,7 +71,8 @@ fn test_authorization() {
Err(_) => () Err(_) => ()
} }
// Try to claim the correct token // Try to claim the correct token
let response = aw!(handlers::claim_auth_token(&hex_user_public_key, &auth_token, &pool)).unwrap(); let response =
aw!(handlers::claim_auth_token(&hex_user_public_key, &auth_token, &pool)).unwrap();
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }
@ -87,7 +89,7 @@ fn test_file_handling() {
// Check that there's a file record // Check that there's a file record
let conn = pool.get().unwrap(); let conn = pool.get().unwrap();
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE); let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
let id: String = conn.query_row(&raw_query, params![], |row| { Ok(row.get(0)?) }).unwrap(); let id: String = conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
// Retrieve the file and check the content // Retrieve the file and check the content
let base64_encoded_file = aw!(handlers::get_file(&id, &auth_token, &pool)).unwrap().result; let base64_encoded_file = aw!(handlers::get_file(&id, &auth_token, &pool)).unwrap().result;
assert_eq!(base64_encoded_file, TEST_FILE); assert_eq!(base64_encoded_file, TEST_FILE);
@ -100,7 +102,7 @@ fn test_file_handling() {
// Check that the file record is also gone // Check that the file record is also gone
let conn = pool.get().unwrap(); let conn = pool.get().unwrap();
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE); let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
let result: Result<String, _> = conn.query_row(&raw_query, params![], |row| { Ok(row.get(0)?) }); let result: Result<String, _> = conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?));
match result { match result {
Ok(_) => assert!(false), // It should be gone now Ok(_) => assert!(false), // It should be gone now
Err(_) => () Err(_) => ()