Skip to content

Commit

Permalink
rsc: Track blobs in the database (#1492)
Browse files Browse the repository at this point in the history
* Track blobs in the database

* add todo
  • Loading branch information
V-FEXrt authored Dec 15, 2023
1 parent 8501a51 commit 68e7071
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 17 deletions.
62 changes: 51 additions & 11 deletions rust/rsc/src/rsc/blob.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::types::GetUploadUrlResponse;
use crate::types::{GetUploadUrlResponse, PostBlobResponse, PostBlobResponsePart};
use async_trait::async_trait;
use axum::{extract::Multipart, http::StatusCode, Json};
use data_encoding::BASE64URL;
use entity::blob;
use futures::stream::BoxStream;
use futures::TryStreamExt;
use rand_core::{OsRng, RngCore};
use sea_orm::DatabaseConnection;
use sea_orm::{ActiveModelTrait, ActiveValue::*, DatabaseConnection};
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::BufWriter;
Expand Down Expand Up @@ -66,21 +67,60 @@ pub async fn get_upload_url(server_addr: String) -> Json<GetUploadUrlResponse> {
#[tracing::instrument]
pub async fn create_blob(
mut multipart: Multipart,
_conn: Arc<DatabaseConnection>,
db: Arc<DatabaseConnection>,
store: Arc<dyn DebugBlobStore + Send + Sync>,
) -> (StatusCode, String) {
store_id: i32,
) -> (StatusCode, Json<PostBlobResponse>) {
let mut parts: Vec<PostBlobResponsePart> = Vec::new();

while let Ok(Some(field)) = multipart.next_field().await {
match store
let name = match field.name() {
Some(x) => x.to_string(),
None => {
return (
StatusCode::BAD_REQUEST,
Json(PostBlobResponse::Error {
message: "Multipart field must be named".into(),
}),
)
}
};

let result = store
.stream(Box::pin(field.map_err(|err| {
std::io::Error::new(std::io::ErrorKind::Other, err)
})))
.await
{
// TODO: The blob should be inserted into the db instead of just printing the key
Ok(key) => println!("{:?}", key),
Err(msg) => return (StatusCode::INTERNAL_SERVER_ERROR, msg.to_string()),
.await;

if let Err(msg) = result {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(PostBlobResponse::Error {
message: msg.to_string(),
}),
);
}

let active_blob = blob::ActiveModel {
// TODO: these ids should be migrated to UUIDs
id: NotSet,
created_at: NotSet,
key: Set(result.unwrap()),
store_id: Set(store_id),
};

match active_blob.insert(db.as_ref()).await {
Err(msg) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(PostBlobResponse::Error {
message: msg.to_string(),
}),
)
}
Ok(blob) => parts.push(PostBlobResponsePart { id: blob.id, name }),
}
}

(StatusCode::OK, "ok".into())
(StatusCode::OK, Json(PostBlobResponse::Ok { blobs: parts }))
}
7 changes: 4 additions & 3 deletions rust/rsc/src/rsc/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ fn create_router(conn: Arc<DatabaseConnection>, config: Arc<config::RSCConfig>)
post({
let conn = conn.clone();
let store = store.clone();
move |multipart: Multipart| blob::create_blob(multipart, conn, store)
// TODO: Don't hardcode store type here
move |multipart: Multipart| blob::create_blob(multipart, conn, store, 1)
})
.layer(DefaultBodyLimit::disable()),
)
Expand Down Expand Up @@ -339,7 +340,7 @@ mod tests {

assert_eq!(res.status(), StatusCode::OK);

// Non-matching job should 200 with expected body
// Non-matching job should 404 with expected body
let res = router
.call(
Request::builder()
Expand All @@ -363,7 +364,7 @@ mod tests {
.await
.unwrap();

assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.status(), StatusCode::NOT_FOUND);

let body = hyper::body::to_bytes(res).await.unwrap();
let body: Value = serde_json::from_slice(&body).unwrap();
Expand Down
9 changes: 6 additions & 3 deletions rust/rsc/src/rsc/read_job.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::types::{Dir, File, ReadJobPayload, ReadJobResponse, Symlink};
use axum::Json;
use entity::{job, job_use, output_dir, output_file, output_symlink};
use hyper::StatusCode;
use sea_orm::{
ActiveModelTrait, ActiveValue::*, ColumnTrait, DatabaseConnection, DbErr, EntityTrait,
ModelTrait, QueryFilter, TransactionTrait,
Expand All @@ -22,7 +23,7 @@ async fn record_use(job_id: i32, conn: Arc<DatabaseConnection>) {
pub async fn read_job(
Json(payload): Json<ReadJobPayload>,
conn: Arc<DatabaseConnection>,
) -> Json<ReadJobResponse> {
) -> (StatusCode, Json<ReadJobResponse>) {
// First find the hash so we can look up the exact job
let hash: Vec<u8> = payload.hash().into();

Expand Down Expand Up @@ -97,20 +98,22 @@ pub async fn read_job(
// If we get a match we want to record the use but we don't
// want to block sending the response on it so we spawn a task
// to go do that.
let mut status = StatusCode::NOT_FOUND;
if let ReadJobResponse::Match { .. } = response {
status = StatusCode::OK;
let shared_conn = conn.clone();
tokio::spawn(async move {
record_use(job_id, shared_conn).await;
});
}
Json(response)
(status, Json(response))
}
Err(cause) => {
tracing::error! {
%cause,
"failed to add job"
};
Json(ReadJobResponse::NoMatch)
(StatusCode::NOT_FOUND, Json(ReadJobResponse::NoMatch))
}
}
}
13 changes: 13 additions & 0 deletions rust/rsc/src/rsc/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ impl ReadJobPayload {
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct PostBlobResponsePart {
pub id: i32,
pub name: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum PostBlobResponse {
Error { message: String },
Ok { blobs: Vec<PostBlobResponsePart> },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ReadJobResponse {
Expand Down

0 comments on commit 68e7071

Please sign in to comment.