Fix query parameter handling

This commit is contained in:
nielsandriesse 2021-03-24 09:12:54 +11:00
parent 224de9e696
commit be11886515
5 changed files with 87 additions and 81 deletions

1
Cargo.lock generated
View File

@ -1322,6 +1322,7 @@ dependencies = [
"structopt",
"tokio",
"tokio-test",
"url",
"uuid",
"warp",
"x25519-dalek",

View File

@ -6,7 +6,6 @@ edition = "2018"
[dependencies]
aes-gcm = "0.8"
structopt = "0.3"
base64 = "0.13"
chrono = "0.4"
curve25519-parser = "0.2"
@ -23,7 +22,9 @@ r2d2 = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.9"
structopt = "0.3"
tokio = { version = "1.3", features = ["full"] }
url = "2.2.1"
uuid = { version = "0.8", features = ["v4"] }
warp = { version = "0.3", features = ["tls"] }
x25519-dalek = "1.1"

View File

@ -1,5 +1,6 @@
use std::convert::TryInto;
use std::fs;
use std::collections::HashMap;
use std::io::prelude::*;
use std::path::Path;
@ -13,7 +14,6 @@ use warp::{Rejection, http::StatusCode, reply::Reply, reply::Response};
use super::crypto;
use super::errors::Error;
use super::models;
use super::rpc;
use super::storage;
enum AuthorizationLevel {
@ -131,10 +131,12 @@ pub async fn get_file(id: &str) -> Result<GenericStringResponse, Rejection> { //
// Authentication
pub async fn get_auth_token_challenge(hex_public_key: &str, pool: &storage::DatabaseConnectionPool) -> Result<models::Challenge, Rejection> { // Doesn't return a response directly for testing purposes
pub async fn get_auth_token_challenge(query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool) -> Result<models::Challenge, Rejection> { // Doesn't return a response directly for testing purposes
// Get the public key
let hex_public_key = query_params.get("public_key").ok_or(warp::reject::custom(Error::InvalidRpcCall))?;
// Validate the public key
if !is_valid_public_key(hex_public_key) {
println!("Ignoring challenge request for invalid public key.");
println!("Ignoring challenge request for invalid public key: {}.", hex_public_key);
return Err(warp::reject::custom(Error::ValidationFailed));
}
// Convert the public key to bytes and cut off the version byte
@ -270,15 +272,37 @@ pub async fn insert_message(mut message: models::Message, auth_token: Option<Str
}
/// Returns either the last `limit` messages or all messages since `from_server_id, limited to `limit`.
pub async fn get_messages(options: rpc::QueryOptions, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
pub async fn get_messages(query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
// Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Unwrap parameters
let from_server_id = options.from_server_id.unwrap_or(0);
let limit = options.limit.unwrap_or(256); // Never return more than 256 messages at once
// Unwrap query parameters
let from_server_id: i64;
if let Some(str) = query_params.get("from_server_id") {
from_server_id = match str.parse() {
Ok(from_server_id) => from_server_id,
Err(_) => {
println!("Couldn't parse query parameter from: {}.", str);
return Err(warp::reject::custom(Error::ValidationFailed));
}
}
} else {
from_server_id = 0;
}
let limit: u16;
if let Some(str) = query_params.get("limit") {
limit = match str.parse() {
Ok(limit) => limit,
Err(_) => {
println!("Couldn't parse query parameter from: {}.", str);
return Err(warp::reject::custom(Error::ValidationFailed));
}
}
} else {
limit = 256; // Never return more than 256 messages at once
}
// Query the database
let raw_query: String;
if options.from_server_id.is_some() {
if query_params.get("from_server_id").is_some() {
raw_query = format!("SELECT id, data, signature FROM {} WHERE rowid > (?1) LIMIT (?2)", storage::MESSAGES_TABLE);
} else {
raw_query = format!("SELECT id, data, signature FROM {} ORDER BY rowid DESC LIMIT (?2)", storage::MESSAGES_TABLE);
@ -360,15 +384,37 @@ pub async fn delete_message(row_id: i64, auth_token: Option<String>, pool: &stor
}
/// Returns either the last `limit` deleted messages or all deleted messages since `from_server_id, limited to `limit`.
pub async fn get_deleted_messages(options: rpc::QueryOptions, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
pub async fn get_deleted_messages(query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
// Get a database connection
let conn = pool.get().map_err(|_| Error::DatabaseFailedInternally)?;
// Unwrap parameters
let from_server_id = options.from_server_id.unwrap_or(0);
let limit = options.limit.unwrap_or(256); // Never return more than 256 deleted messages at once
// Unwrap query parameters
let from_server_id: i64;
if let Some(str) = query_params.get("from_server_id") {
from_server_id = match str.parse() {
Ok(from_server_id) => from_server_id,
Err(_) => {
println!("Couldn't parse query parameter from: {}.", str);
return Err(warp::reject::custom(Error::ValidationFailed));
}
}
} else {
from_server_id = 0;
}
let limit: u16;
if let Some(str) = query_params.get("limit") {
limit = match str.parse() {
Ok(limit) => limit,
Err(_) => {
println!("Couldn't parse query parameter from: {}.", str);
return Err(warp::reject::custom(Error::ValidationFailed));
}
}
} else {
limit = 256; // Never return more than 256 messages at once
}
// Query the database
let raw_query: String;
if options.from_server_id.is_some() {
if query_params.get("from_server_id").is_some() {
raw_query = format!("SELECT id FROM {} WHERE rowid > (?1) LIMIT (?2)", storage::DELETED_MESSAGES_TABLE);
} else {
raw_query = format!("SELECT id FROM {} ORDER BY rowid DESC LIMIT (?2)", storage::DELETED_MESSAGES_TABLE);

View File

@ -15,12 +15,6 @@ pub struct RpcCall {
pub headers: HashMap<String, String>
}
#[derive(Debug, Deserialize)]
pub struct QueryOptions {
pub limit: Option<u16>,
pub from_server_id: Option<i64>
}
pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
// Get a connection pool for the given room
let room_id = match get_room_id(&rpc_call) {
@ -32,11 +26,18 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
};
let pool = storage::pool_by_room_id(&room_id);
// Check that the endpoint is a valid URI
let raw_uri = format!("/{}", rpc_call.endpoint.trim_start_matches("/"));
let uri = match raw_uri.parse::<http::Uri>() {
Ok(uri) => uri,
let raw_uri = format!("http://placeholder.io/{}", rpc_call.endpoint.trim_start_matches("/"));
let path: String = match raw_uri.parse::<http::Uri>() {
Ok(uri) => uri.path().trim_start_matches("/").to_string(),
Err(e) => {
println!("Couldn't parse URI from: {} due to error: {}.", rpc_call.endpoint, e);
println!("Couldn't parse URI from: {} due to error: {}.", &raw_uri, e);
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
let query_params: HashMap<String, String> = match url::Url::parse(&raw_uri) {
Ok(url) => url.query_pairs().into_owned().collect(),
Err(e) => {
println!("Couldn't parse URL from: {} due to error: {}.", &raw_uri, e);
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
@ -44,9 +45,9 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
let auth_token = get_auth_token(&rpc_call);
// Switch on the HTTP method
match rpc_call.method.as_ref() {
"GET" => return handle_get_request(rpc_call, uri, &pool).await,
"POST" => return handle_post_request(rpc_call, uri, auth_token, &pool).await,
"DELETE" => return handle_delete_request(rpc_call, uri, auth_token, &pool).await,
"GET" => return handle_get_request(rpc_call, &path, query_params, &pool).await,
"POST" => return handle_post_request(rpc_call, &path, auth_token, &pool).await,
"DELETE" => return handle_delete_request(rpc_call, &path, auth_token, &pool).await,
_ => {
println!("Ignoring RPC call with invalid or unused HTTP method: {}.", rpc_call.method);
return Err(warp::reject::custom(Error::InvalidRpcCall));
@ -54,9 +55,8 @@ pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
}
}
async fn handle_get_request(rpc_call: RpcCall, uri: http::Uri, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
// Switch on the path
let path = uri.path().trim_start_matches("/");
async fn handle_get_request(rpc_call: RpcCall, path: &str, query_params: HashMap<String, String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
// Switch on the path
if path.starts_with("files") {
let components: Vec<&str> = path.split("/").collect(); // Split on subsequent slashes
if components.len() != 2 {
@ -67,56 +67,13 @@ async fn handle_get_request(rpc_call: RpcCall, uri: http::Uri, pool: &storage::D
return handlers::get_file(file_id).await.map(|json| warp::reply::json(&json).into_response());
}
match path {
"messages" => {
let query_options: QueryOptions;
if let Some(query) = uri.query() {
query_options = match serde_json::from_str(&query) {
Ok(query_options) => query_options,
Err(e) => {
println!("Couldn't parse query options from: {} due to error: {}.", query, e);
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
} else {
query_options = QueryOptions { limit : None, from_server_id : None };
}
return handlers::get_messages(query_options, pool).await;
},
"deleted_messages" => {
let query_options: QueryOptions;
if let Some(query) = uri.query() {
query_options = match serde_json::from_str(&query) {
Ok(query_options) => query_options,
Err(e) => {
println!("Couldn't parse query options from: {} due to error: {}.", query, e);
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
} else {
query_options = QueryOptions { limit : None, from_server_id : None };
}
return handlers::get_deleted_messages(query_options, pool).await
},
"messages" => return handlers::get_messages(query_params, pool).await,
"deleted_messages" => return handlers::get_deleted_messages(query_params, pool).await,
"moderators" => return handlers::get_moderators(pool).await,
"block_list" => return handlers::get_banned_public_keys(pool).await,
"member_count" => return handlers::get_member_count(pool).await,
"auth_token_challenge" => {
#[derive(Debug, Deserialize)]
struct QueryOptions { public_key: String }
let query_options: QueryOptions;
if let Some(query) = uri.query() {
query_options = match serde_json::from_str(&query) {
Ok(query_options) => query_options,
Err(e) => {
println!("Couldn't parse query options from: {} due to error: {}.", query, e);
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
} else {
println!("Missing query options.");
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
return handlers::get_auth_token_challenge(&query_options.public_key, pool).await.map(|json| warp::reply::json(&json).into_response());
return handlers::get_auth_token_challenge(query_params, pool).await.map(|json| warp::reply::json(&json).into_response());
},
_ => {
println!("Ignoring RPC call with invalid or unused endpoint: {}.", rpc_call.endpoint);
@ -125,8 +82,7 @@ async fn handle_get_request(rpc_call: RpcCall, uri: http::Uri, pool: &storage::D
}
}
async fn handle_post_request(rpc_call: RpcCall, uri: http::Uri, auth_token: Option<String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
let path = uri.path().trim_start_matches("/");
async fn handle_post_request(rpc_call: RpcCall, path: &str, auth_token: Option<String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
match path {
"messages" => {
let message = match serde_json::from_str(&rpc_call.body) {
@ -181,8 +137,7 @@ async fn handle_post_request(rpc_call: RpcCall, uri: http::Uri, auth_token: Opti
}
}
async fn handle_delete_request(rpc_call: RpcCall, uri: http::Uri, auth_token: Option<String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
let path = uri.path().trim_start_matches("/");
async fn handle_delete_request(rpc_call: RpcCall, path: &str, auth_token: Option<String>, pool: &storage::DatabaseConnectionPool) -> Result<Response, Rejection> {
// DELETE /messages/:server_id
if path.starts_with("messages") {
let components: Vec<&str> = path.split("/").collect(); // Split on subsequent slashes

View File

@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::fs;
use std::path::Path;
@ -41,7 +42,9 @@ fn test_authorization() {
let (user_private_key, user_public_key) = aw!(crypto::generate_x25519_key_pair());
let hex_user_public_key = format!("05{}", hex::encode(user_public_key.to_bytes()));
// Get a challenge
let challenge = aw!(handlers::get_auth_token_challenge(&hex_user_public_key, &pool)).unwrap();
let mut query_params: HashMap<String, String> = HashMap::new();
query_params.insert("public_key".to_string(), hex_user_public_key.clone());
let challenge = aw!(handlers::get_auth_token_challenge(query_params, &pool)).unwrap();
// Generate a symmetric key
let ephemeral_public_key = base64::decode(challenge.ephemeral_public_key).unwrap();
let symmetric_key = aw!(crypto::get_x25519_symmetric_key(&ephemeral_public_key, &user_private_key)).unwrap();