Make file handling async
This commit is contained in:
parent
57ea49e30e
commit
023d134067
|
@ -1,7 +1,5 @@
|
|||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::fs;
|
||||
use std::io::prelude::*;
|
||||
use std::path::Path;
|
||||
|
||||
use chrono;
|
||||
|
@ -9,6 +7,8 @@ use log::{error, info, warn};
|
|||
use rand::{thread_rng, Rng};
|
||||
use rusqlite::params;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::fs::File;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use warp::{http::StatusCode, reply::Reply, reply::Response, Rejection};
|
||||
|
||||
use super::crypto;
|
||||
|
@ -123,7 +123,7 @@ pub fn get_all_rooms() -> Result<Response, Rejection> {
|
|||
|
||||
// Files
|
||||
|
||||
pub fn store_file(
|
||||
pub async fn store_file(
|
||||
base64_encoded_bytes: &str, auth_token: &str, pool: &storage::DatabaseConnectionPool,
|
||||
) -> Result<Response, Rejection> {
|
||||
// It'd be nice to use the UUID crate for the file ID, but clients want an integer ID
|
||||
|
@ -157,26 +157,22 @@ pub fn store_file(
|
|||
}
|
||||
};
|
||||
// Write to file
|
||||
let mut pos = 0;
|
||||
let raw_path = format!("files/{}", &now);
|
||||
let path = Path::new(&raw_path);
|
||||
let mut buffer = match fs::File::create(path) {
|
||||
Ok(buffer) => buffer,
|
||||
let mut file = match File::create(path).await {
|
||||
Ok(file) => file,
|
||||
Err(e) => {
|
||||
error!("Couldn't store file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
|
||||
}
|
||||
};
|
||||
match file.write_all(&bytes).await {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
error!("Couldn't store file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
|
||||
}
|
||||
};
|
||||
while pos < bytes.len() {
|
||||
let count = match buffer.write(&bytes[pos..]) {
|
||||
Ok(count) => count,
|
||||
Err(e) => {
|
||||
error!("Couldn't store file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::DatabaseFailedInternally));
|
||||
}
|
||||
};
|
||||
pos += count;
|
||||
}
|
||||
// Return
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct Response {
|
||||
|
@ -187,7 +183,7 @@ pub fn store_file(
|
|||
return Ok(warp::reply::json(&response).into_response());
|
||||
}
|
||||
|
||||
pub fn get_file(
|
||||
pub async fn get_file(
|
||||
id: i64, auth_token: &str, pool: &storage::DatabaseConnectionPool,
|
||||
) -> Result<GenericStringResponse, Rejection> {
|
||||
// Doesn't return a response directly for testing purposes
|
||||
|
@ -198,10 +194,18 @@ pub fn get_file(
|
|||
return Err(warp::reject::custom(Error::Unauthorized));
|
||||
}
|
||||
// Try to read the file
|
||||
let mut bytes = vec![];
|
||||
let raw_path = format!("files/{}", id);
|
||||
let path = Path::new(&raw_path);
|
||||
let bytes = match fs::read(path) {
|
||||
Ok(bytes) => bytes,
|
||||
let mut file = match File::open(path).await {
|
||||
Ok(file) => file,
|
||||
Err(e) => {
|
||||
error!("Couldn't read file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::ValidationFailed));
|
||||
}
|
||||
};
|
||||
match file.read_to_end(&mut bytes).await {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
error!("Couldn't read file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::ValidationFailed));
|
||||
|
@ -217,12 +221,20 @@ pub fn get_file(
|
|||
return Ok(json);
|
||||
}
|
||||
|
||||
pub fn get_group_image(room_id: &str) -> Result<Response, Rejection> {
|
||||
pub async fn get_group_image(room_id: &str) -> Result<Response, Rejection> {
|
||||
// Try to read the file
|
||||
let mut bytes = vec![];
|
||||
let raw_path = format!("files/{}", room_id);
|
||||
let path = Path::new(&raw_path);
|
||||
let bytes = match fs::read(path) {
|
||||
Ok(bytes) => bytes,
|
||||
let mut file = match File::open(path).await {
|
||||
Ok(file) => file,
|
||||
Err(e) => {
|
||||
error!("Couldn't read file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::ValidationFailed));
|
||||
}
|
||||
};
|
||||
match file.read_to_end(&mut bytes).await {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
error!("Couldn't read file due to error: {}.", e);
|
||||
return Err(warp::reject::custom(Error::ValidationFailed));
|
||||
|
|
|
@ -47,6 +47,7 @@ async fn handle_decrypted_onion_request(
|
|||
};
|
||||
// Perform the RPC call
|
||||
let result = rpc::handle_rpc_call(rpc_call)
|
||||
.await
|
||||
// Turn any error that occurred into an HTTP response
|
||||
// Unwrapping is safe because at this point any error should be caught and turned into an HTTP response (i.e. an OK result)
|
||||
.or_else(super::errors::into_response)?;
|
||||
|
|
15
src/rpc.rs
15
src/rpc.rs
|
@ -25,7 +25,7 @@ pub struct RpcCall {
|
|||
|
||||
const MODE: Mode = Mode::OpenGroupServer;
|
||||
|
||||
pub fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
|
||||
pub async fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
|
||||
// Check that the endpoint is a valid URI and deconstruct it into a path
|
||||
// and query parameters.
|
||||
// Adding "http://placeholder.io" in front of the endpoint is a workaround
|
||||
|
@ -50,10 +50,10 @@ pub fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
|
|||
let auth_token = get_auth_token(&rpc_call);
|
||||
// Switch on the HTTP method
|
||||
match rpc_call.method.as_ref() {
|
||||
"GET" => return handle_get_request(rpc_call, &path, auth_token, query_params),
|
||||
"GET" => return handle_get_request(rpc_call, &path, auth_token, query_params).await,
|
||||
"POST" => {
|
||||
let pool = get_pool_for_room(&rpc_call)?;
|
||||
return handle_post_request(rpc_call, &path, auth_token, &pool);
|
||||
return handle_post_request(rpc_call, &path, auth_token, &pool).await;
|
||||
}
|
||||
"DELETE" => {
|
||||
let pool = get_pool_for_room(&rpc_call)?;
|
||||
|
@ -66,7 +66,7 @@ pub fn handle_rpc_call(rpc_call: RpcCall) -> Result<Response, Rejection> {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_get_request(
|
||||
async fn handle_get_request(
|
||||
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
|
||||
query_params: HashMap<String, String>,
|
||||
) -> Result<Response, Rejection> {
|
||||
|
@ -91,7 +91,7 @@ fn handle_get_request(
|
|||
return handlers::get_room(&room_id);
|
||||
} else if components.len() == 3 && components[2] == "image" {
|
||||
let room_id = components[1];
|
||||
return handlers::get_group_image(&room_id);
|
||||
return handlers::get_group_image(&room_id).await;
|
||||
} else {
|
||||
warn!("Invalid endpoint: {}.", rpc_call.endpoint);
|
||||
return Err(warp::reject::custom(Error::InvalidRpcCall));
|
||||
|
@ -115,6 +115,7 @@ fn handle_get_request(
|
|||
}
|
||||
};
|
||||
return handlers::get_file(file_id, &auth_token, &pool)
|
||||
.await
|
||||
.map(|json| warp::reply::json(&json).into_response());
|
||||
}
|
||||
match path {
|
||||
|
@ -145,7 +146,7 @@ fn handle_get_request(
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_post_request(
|
||||
async fn handle_post_request(
|
||||
rpc_call: RpcCall, path: &str, auth_token: Option<String>,
|
||||
pool: &storage::DatabaseConnectionPool,
|
||||
) -> Result<Response, Rejection> {
|
||||
|
@ -205,7 +206,7 @@ fn handle_post_request(
|
|||
return Err(warp::reject::custom(Error::InvalidRpcCall));
|
||||
}
|
||||
};
|
||||
return handlers::store_file(&json.file, &auth_token, pool);
|
||||
return handlers::store_file(&json.file, &auth_token, pool).await;
|
||||
}
|
||||
_ => {
|
||||
warn!("Ignoring RPC call with invalid or unused endpoint: {}.", path);
|
||||
|
|
|
@ -87,13 +87,13 @@ fn test_file_handling() {
|
|||
// Get an auth token
|
||||
let (auth_token, _) = get_auth_token();
|
||||
// Store the test file
|
||||
handlers::store_file(TEST_FILE, &auth_token, &pool).unwrap();
|
||||
aw!(handlers::store_file(TEST_FILE, &auth_token, &pool)).unwrap();
|
||||
// Check that there's a file record
|
||||
let conn = pool.get().unwrap();
|
||||
let raw_query = format!("SELECT id FROM {}", storage::FILES_TABLE);
|
||||
let id: i64 = conn.query_row(&raw_query, params![], |row| Ok(row.get(0)?)).unwrap();
|
||||
// Retrieve the file and check the content
|
||||
let base64_encoded_file = handlers::get_file(id, &auth_token, &pool).unwrap().result;
|
||||
let base64_encoded_file = aw!(handlers::get_file(id, &auth_token, &pool)).unwrap().result;
|
||||
assert_eq!(base64_encoded_file, TEST_FILE);
|
||||
// Prune the file and check that it's gone
|
||||
aw!(storage::prune_files(-60)); // Will evaluate to now + 60
|
||||
|
|
Loading…
Reference in New Issue