Make file handling async

This commit is contained in:
Niels Andriesse 2021-04-01 10:32:25 +11:00
parent 57ea49e30e
commit 023d134067
4 changed files with 45 additions and 31 deletions

View File

@ -1,7 +1,5 @@
use std::collections::HashMap;
use std::convert::TryInto;
use std::fs;
use std::io::prelude::*;
use std::path::Path;
use chrono;
@ -9,6 +7,8 @@ use log::{error, info, warn};
use rand::{thread_rng, Rng};
use rusqlite::params;
use serde::{Deserialize, Serialize};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use warp::{http::StatusCode, reply::Reply, reply::Response, Rejection};
use super::crypto;
@ -123,7 +123,7 @@ pub fn get_all_rooms() -> Result<Response, Rejection> {
// Files
pub fn store_file(
pub async fn store_file(
base64_encoded_bytes: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool,
) -> Result<Response, Rejection> {
// It'd be nice to use the UUID crate for the file ID, but clients want an integer ID
@ -157,26 +157,22 @@ pub fn store_file(
}
};
// Write to file
let mut pos = 0;
let raw_path = format!("files/{}", &now);
let path = Path::new(&raw_path);
let mut buffer = match fs::File::create(path) {
Ok(buffer) => buffer,
let mut file = match File::create(path).await {
Ok(file) => file,
Err(e) => {
error!("Couldn't store file due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
match file.write_all(&bytes).await {
Ok(_) => (),
Err(e) => {
error!("Couldn't store file due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
while pos < bytes.len() {
let count = match buffer.write(&bytes[pos..]) {
Ok(count) => count,
Err(e) => {
error!("Couldn't store file due to error: {}.", e);
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
}
};
pos += count;
}
// Return
#[derive(Debug, Deserialize, Serialize)]
struct Response {
@ -187,7 +183,7 @@ pub fn store_file(
return Ok(warp::reply::json(&response).into_response());
}
pub fn get_file(
pub async fn get_file(
id: i64, auth_token: &str, pool: &storage::DatabaseConnectionPool,
) -> Result<GenericStringResponse, Rejection> {
// Doesn't return a response directly for testing purposes
@ -198,10 +194,18 @@ pub fn get_file(
return Err(warp::reject::custom(Error::Unauthorized));
}
// Try to read the file
let mut bytes = vec![];
let raw_path = format!("files/{}", id);
let path = Path::new(&raw_path);
let bytes = match fs::read(path) {
Ok(bytes) => bytes,
let mut file = match File::open(path).await {
Ok(file) => file,
Err(e) => {
error!("Couldn't read file due to error: {}.", e);
return Err(warp::reject::custom(Error::ValidationFailed));
}
};
match file.read_to_end(&mut bytes).await {
Ok(_) => (),
Err(e) => {
error!("Couldn't read file due to error: {}.", e);
return Err(warp::reject::custom(Error::ValidationFailed));
@ -217,12 +221,20 @@ pub fn get_file(
return Ok(json);
}
pub fn get_group_image(room_id: &str) -> Result<Response, Rejection> {
pub async fn get_group_image(room_id: &str) -> Result<Response, Rejection> {
// Try to read the file
let mut bytes = vec![];
let raw_path = format!("files/{}", room_id);
let path = Path::new(&raw_path);
let bytes = match fs::read(path) {
Ok(bytes) => bytes,
let mut file = match File::open(path).await {
Ok(file) => file,
Err(e) => {
error!("Couldn't read file due to error: {}.", e);
return Err(warp::reject::custom(Error::ValidationFailed));
}
};
match file.read_to_end(&mut bytes).await {
Ok(_) => (),
Err(e) => {
error!("Couldn't read file due to error: {}.", e);
return Err(warp::reject::custom(Error::ValidationFailed));

View File

@ -47,6 +47,7 @@ async fn handle_decrypted_onion_request(
};
// Perform the RPC call
let result = rpc::handle_rpc_call(rpc_call)
.await
// Turn any error that occurred into an HTTP response
// Unwrapping is safe because at this point any error should be caught and turned into an HTTP response (i.e. an OK result)
.or_else(super::errors::into_response)?;

View File

@ -25,7 +25,7 @@ pub struct RpcCall {
const MODE: Mode = Mode::OpenGroupServer;
pub fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
// Check that the endpoint is a valid URI and deconstruct it into a path
// and query parameters.
// Adding "http://placeholder.io" in front of the endpoint is a workaround
@ -50,10 +50,10 @@ pub fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
let auth_token = get_auth_token(&rpc_call);
// Switch on the HTTP method
match rpc_call.method.as_ref() {
"GET" => return handle_get_request(rpc_call, &path, auth_token, query_params),
"GET" => return handle_get_request(rpc_call, &path, auth_token, query_params).await,
"POST" => {
let pool = get_pool_for_room(&rpc_call)?;
return handle_post_request(rpc_call, &path, auth_token, &pool);
return handle_post_request(rpc_call, &path, auth_token, &pool).await;
}
"DELETE" => {
let pool = get_pool_for_room(&rpc_call)?;
@ -66,7 +66,7 @@ pub fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
}
}
fn handle_get_request(
async fn handle_get_request(
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
query_params: HashMap<String, String>,
) -> Result<Response, Rejection> {
@ -91,7 +91,7 @@ fn handle_get_request(
return handlers::get_room(&room_id);
} else if components.len() == 3 && components[2] == "image" {
let room_id = components[1];
return handlers::get_group_image(&room_id);
return handlers::get_group_image(&room_id).await;
} else {
warn!("Invalid endpoint: {}.", rpc_call.endpoint);
return Err(warp::reject::custom(Error::InvalidRpcCall));
@ -115,6 +115,7 @@ fn handle_get_request(
}
};
return handlers::get_file(file_id, &auth_token, &pool)
.await
.map(|json| warp::reply::json(&json).into_response());
}
match path {
@ -145,7 +146,7 @@ fn handle_get_request(
}
}
fn handle_post_request(
async fn handle_post_request(
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
pool: &storage::DatabaseConnectionPool,
) -> Result<Response, Rejection> {
@ -205,7 +206,7 @@ fn handle_post_request(
return Err(warp::reject::custom(Error::InvalidRpcCall));
}
};
return handlers::store_file(&json.file, &auth_token, pool);
return handlers::store_file(&json.file, &auth_token, pool).await;
}
_ => {
warn!("Ignoring RPC call with invalid or unused endpoint: {}.", path);

View File

@ -87,13 +87,13 @@ fn test_file_handling() {
// Get an auth token
let (auth_token, _) = get_auth_token();
// Store the test file
handlers::store_file(TEST_FILE, &auth_token, &pool).unwrap();
aw!(handlers::store_file(TEST_FILE, &auth_token, &pool)).unwrap();
// Check that there's a file record
let conn = pool.get().unwrap();
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
let id: i64 = conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
// Retrieve the file and check the content
let base64_encoded_file = handlers::get_file(id, &auth_token, &pool).unwrap().result;
let base64_encoded_file = aw!(handlers::get_file(id, &auth_token, &pool)).unwrap().result;
assert_eq!(base64_encoded_file, TEST_FILE);
// Prune the file and check that it's gone
aw!(storage::prune_files(-60)); // Will evaluate to now + 60