Skip to content

Commit 4116e65

Browse files
committed
feat: aws s3 as a file storage option
Signed-off-by: Richard Zak <[email protected]>
1 parent e5b3308 commit 4116e65

File tree

10 files changed

+587
-43
lines changed

10 files changed

+587
-43
lines changed

Cargo.lock

Lines changed: 438 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ aes-gcm = { version = "0.10.3", default-features = false }
102102
anyhow = { version = "1.0", default-features = false }
103103
app-memory-usage-fetcher = { version = "0.2.1", default-features = false }
104104
argon2 = { version = "0.5.3", default-features = false }
105+
aws-config = { version = "1.1.7", default-features = false }
106+
aws-sdk-s3 = { version = "1.76.0", default-features = false }
107+
aws-smithy-async = { version = "1.2.4", default-features = false }
105108
axum = { version = "0.8.1", default-features = false }
106109
axum-server = { version = "0.7.2", default-features = false }
107110
base64 = { version = "0.22.1", default-features = false }

crates/server/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ aes-gcm = { workspace = true, features = ["aes", "alloc", "getrandom", "std"] }
2727
anyhow = { workspace = true, features = ["std"] }
2828
app-memory-usage-fetcher = { workspace = true }
2929
argon2 = { workspace = true, features = ["alloc", "password-hash", "std"] }
30+
aws-config = { workspace = true, features = ["behavior-version-latest"] }
31+
aws-sdk-s3 = { workspace = true }
32+
aws-smithy-async = { workspace = true, features = ["rt-tokio"] }
3033
axum = { workspace = true, features = ["http1", "http2", "json", "macros", "tokio"] }
3134
axum-server = { workspace = true, features = ["tls-rustls", "rustls"] }
3235
base64 = { workspace = true, features = ["alloc", "std"] }

crates/server/src/db/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ mod tests {
749749
}),
750750
cert: None,
751751
key: None,
752+
aws: None,
752753
};
753754

754755
let vt: VtUpdater = state.try_into().expect("failed to create VtUpdater");
@@ -810,6 +811,7 @@ mod tests {
810811
}),
811812
cert: None,
812813
key: None,
814+
aws: None,
813815
};
814816

815817
let sqlite_second = Sqlite::new(DB_FILE)

crates/server/src/http/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ mod tests {
435435
vt_client: None,
436436
cert: None,
437437
key: None,
438+
aws: None,
438439
};
439440

440441
state
@@ -495,6 +496,7 @@ mod tests {
495496
vt_client: None,
496497
cert: Some("../../testdata/server_ca_cert.pem".into()),
497498
key: Some("../../testdata/server_key.pem".into()),
499+
aws: None,
498500
}
499501
} else {
500502
State {
@@ -514,6 +516,7 @@ mod tests {
514516
vt_client: None,
515517
cert: Some("../../testdata/server_cert.der".into()),
516518
key: Some("../../testdata/server_key.der".into()),
519+
aws: None,
517520
}
518521
};
519522

crates/server/src/lib.rs

Lines changed: 106 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,17 @@ use std::sync::Arc;
3333
use std::time::{Duration, SystemTime};
3434

3535
use anyhow::{bail, ensure, Context, Result};
36+
use aws_config::{AppName, BehaviorVersion, SdkConfig};
37+
use aws_sdk_s3::config::{Credentials, SharedCredentialsProvider};
38+
use aws_sdk_s3::primitives::ByteStream;
39+
use aws_sdk_s3::Client;
40+
use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
3641
use axum_server::tls_rustls::RustlsConfig;
42+
use base64::prelude::BASE64_STANDARD;
43+
use base64::Engine;
3744
use chrono::Local;
3845
use chrono_humanize::{Accuracy, HumanTime, Tense};
46+
use constcat::concat;
3947
use flate2::read::GzDecoder;
4048
use flate2::write::GzEncoder;
4149
use flate2::Compression;
@@ -84,6 +92,9 @@ pub struct State {
8492

8593
/// Https private key file
8694
key: Option<PathBuf>,
95+
96+
/// AWS client and S3 bucket
97+
aws: Option<(Client, String)>,
8798
}
8899

89100
impl State {
@@ -99,6 +110,9 @@ impl State {
99110
key: Option<PathBuf>,
100111
pg_cert: Option<PathBuf>,
101112
#[cfg(feature = "vt")] vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
113+
aws_access_key_id: Option<String>,
114+
aws_secret_access_key: Option<String>,
115+
aws_bucket: Option<String>,
102116
) -> Result<Self> {
103117
if let Some(dir) = &directory {
104118
if !dir.exists() {
@@ -121,6 +135,34 @@ impl State {
121135
ensure!(key_file.exists(), "Key file {key_file:?} does not exist!");
122136
}
123137

138+
if aws_access_key_id.is_some() != aws_secret_access_key.is_some()
139+
|| aws_access_key_id.is_some() != aws_bucket.is_some()
140+
{
141+
bail!("AWS credentials and the S3 bucket must be provided, or be `None`.");
142+
}
143+
144+
let client = if aws_access_key_id.is_none() {
145+
None
146+
} else {
147+
// The Aws-related .unwrap() calls are okay since we know they're `Some()`
148+
let creds = Credentials::new(
149+
aws_access_key_id.unwrap(),
150+
aws_secret_access_key.unwrap(),
151+
None,
152+
None,
153+
"malwaredb",
154+
);
155+
let provider = SharedCredentialsProvider::new(creds);
156+
let sleep_impl = SharedAsyncSleep::new(TokioSleep::new());
157+
let config = SdkConfig::builder()
158+
.credentials_provider(provider)
159+
.behavior_version(BehaviorVersion::latest())
160+
.sleep_impl(sleep_impl)
161+
.app_name(AppName::new(concat!("malwaredb-v", MDB_VERSION)).unwrap())
162+
.build();
163+
Some((Client::new(&config), aws_bucket.unwrap()))
164+
};
165+
124166
let db_type = db::DatabaseType::from_string(db_string, pg_cert).await?;
125167
// TODO: allow config and keys to be refreshed so changes don't require a restart
126168
let db_config = db_type.get_config().await?;
@@ -139,15 +181,19 @@ impl State {
139181
started: SystemTime::now(),
140182
cert,
141183
key,
184+
aws: client,
142185
})
143186
}
144187

145188
/// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
146189
pub async fn store_bytes(&self, data: &[u8]) -> Result<bool> {
147-
if let Some(dest_path) = &self.directory {
190+
if self.directory.is_none() && self.aws.is_none() {
191+
Ok(false)
192+
} else {
148193
let mut hasher = Sha256::new();
149194
hasher.update(data);
150-
let sha256 = hex::encode(hasher.finalize());
195+
let sha256_bytes = hasher.finalize();
196+
let sha256 = hex::encode(sha256_bytes);
151197

152198
// Trait `HashPath` needs to be re-worked so it can work with Strings.
153199
// This code below ends up making the String into ASCII representations of the hash
@@ -160,16 +206,6 @@ impl State {
160206
sha256
161207
);
162208

163-
// The path which has the file name included, with the storage directory prepended.
164-
//let hashed_path = result.hashed_path(3);
165-
let mut dest_path = dest_path.clone();
166-
dest_path.push(hashed_path);
167-
168-
// Remove the file name so we can just have the directory path.
169-
let mut just_the_dir = dest_path.clone();
170-
just_the_dir.pop();
171-
std::fs::create_dir_all(just_the_dir)?;
172-
173209
let data = if self.db_config.compression {
174210
let mut compressor = GzEncoder::new(Vec::new(), Compression::default());
175211
compressor.write_all(data)?;
@@ -190,30 +226,79 @@ impl State {
190226
data
191227
};
192228

193-
std::fs::write(dest_path, data)?;
229+
if let Some(dest_path) = &self.directory {
230+
// The path which has the file name included, with the storage directory prepended.
231+
//let hashed_path = result.hashed_path(3);
232+
let mut dest_path = dest_path.clone();
233+
dest_path.push(hashed_path);
234+
235+
// Remove the file name so we can just have the directory path.
236+
let mut just_the_dir = dest_path.clone();
237+
just_the_dir.pop();
238+
std::fs::create_dir_all(just_the_dir)?;
239+
240+
std::fs::write(dest_path, data)?;
241+
} else if let Some((client, bucket)) = &self.aws {
242+
// Data is compressed and/or encrypted, so re-hash for AWS
243+
let sha256_b64 =
244+
if self.db_config.compression || self.db_config.default_key.is_some() {
245+
let mut hasher = Sha256::new();
246+
hasher.update(&data);
247+
BASE64_STANDARD.encode(hasher.finalize())
248+
} else {
249+
BASE64_STANDARD.encode(sha256_bytes)
250+
};
251+
252+
let bytes = ByteStream::from(data);
253+
client
254+
.put_object()
255+
.bucket(bucket)
256+
.checksum_sha256(sha256_b64)
257+
.key(hashed_path)
258+
.body(bytes)
259+
.send()
260+
.await?;
261+
}
194262

195263
Ok(true)
196-
} else {
197-
Ok(false)
198264
}
199265
}
200266

201267
/// Retrieve a sample given the SHA-256 hash
202268
/// Assumes that MalwareDB permissions have already been checked to ensure this is permitted.
203269
pub async fn retrieve_bytes(&self, sha256: &String) -> Result<Vec<u8>> {
204-
if let Some(dest_path) = &self.directory {
270+
if self.directory.is_none() && self.aws.is_none() {
271+
bail!("files are not saved")
272+
} else {
273+
// Trait `HashPath` needs to be re-worked so it can work with Strings.
274+
// This code below ends up making the String into ASCII representations of the hash
275+
// See: https://github.com/malwaredb/malwaredb-rs/issues/60
276+
//let path = sha256.as_bytes().iter().hashed_path(3);
277+
205278
let path = format!(
206279
"{}/{}/{}/{}",
207280
&sha256[0..2],
208281
&sha256[2..4],
209282
&sha256[4..6],
210283
sha256
211284
);
212-
// Trait `HashPath` needs to be re-worked so it can work with Strings.
213-
// This code below ends up making the String into ASCII representations of the hash
214-
// See: https://github.com/malwaredb/malwaredb-rs/issues/60
215-
//let path = sha256.as_bytes().iter().hashed_path(3);
216-
let contents = std::fs::read(dest_path.join(path))?;
285+
286+
let contents = if let Some(dest_path) = &self.directory {
287+
std::fs::read(dest_path.join(path))?
288+
} else if let Some((client, bucket)) = &self.aws {
289+
client
290+
.get_object()
291+
.bucket(bucket)
292+
.key(path)
293+
.send()
294+
.await?
295+
.body
296+
.collect()
297+
.await?
298+
.to_vec()
299+
} else {
300+
bail!("No AWS config and no local directory, shouldn't happen!")
301+
};
217302

218303
let contents = if !self.keys.is_empty() {
219304
let (key_id, nonce) = self.db_type.get_file_encryption_key_id(sha256).await?;
@@ -241,8 +326,6 @@ impl State {
241326
} else {
242327
Ok(contents)
243328
}
244-
} else {
245-
bail!("files are not saved")
246329
}
247330
}
248331

src/cli/config.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,21 @@ pub struct Config {
8080
#[cfg(feature = "vt")]
8181
#[serde(default)]
8282
pub vt_client: Option<malwaredb_virustotal::VirusTotalClient>,
83+
84+
/// AWS Access Key
85+
#[arg(short, long)]
86+
#[serde(default)]
87+
pub access_key_id: Option<String>,
88+
89+
/// AWS Secret Access Key
90+
#[arg(short, long)]
91+
#[serde(default)]
92+
pub secret_access_key: Option<String>,
93+
94+
/// AWS S3 bucket
95+
#[arg(short, long)]
96+
#[serde(default)]
97+
pub bucket: Option<String>,
8398
}
8499

85100
const fn default_port() -> u16 {
@@ -108,6 +123,9 @@ impl Config {
108123
self.pg_cert,
109124
#[cfg(feature = "vt")]
110125
self.vt_client,
126+
self.access_key_id,
127+
self.secret_access_key,
128+
self.bucket,
111129
)
112130
.await
113131
}
@@ -174,6 +192,9 @@ impl Default for Config {
174192
cert: None,
175193
key: None,
176194
pg_cert: None,
195+
access_key_id: None,
196+
secret_access_key: None,
197+
bucket: None,
177198
}
178199
}
179200
}

src/cli/run.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ impl Run {
3939
cfg.pg_cert,
4040
#[cfg(feature = "vt")]
4141
cfg.vt_client,
42+
cfg.access_key_id,
43+
cfg.secret_access_key,
44+
cfg.bucket,
4245
)
4346
.await
4447
}
@@ -76,6 +79,7 @@ impl Load {
7679
}
7780
}
7881

82+
#[allow(clippy::large_enum_variant)]
7983
#[derive(Subcommand, Clone, Debug)]
8084
pub enum Subcommands {
8185
Load(Load),
@@ -105,6 +109,9 @@ mod tests {
105109
None,
106110
#[cfg(feature = "vt")]
107111
None,
112+
None,
113+
None,
114+
None,
108115
)
109116
.await
110117
.unwrap();

src/cli/vt/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ impl VtOption {
3535
}
3636

3737
/// VirusTotal config either from disk or command line for VT operations outside the normal running of MalwareDB
38+
#[allow(clippy::large_enum_variant)]
3839
#[derive(Subcommand, Clone, Debug)]
3940
enum VtConfig {
4041
/// Load VirusTotal config from a file

src/main.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ async fn main() -> anyhow::Result<ExitCode> {
3434
"malwaredb_types",
3535
"deadpool_postgres",
3636
"postgres",
37+
"aws_config",
38+
"aws_sdk_s3",
39+
"aws_smithy_async",
3740
#[cfg(feature = "sqlite")]
3841
"rusqlite",
3942
]

0 commit comments

Comments
 (0)