@@ -33,9 +33,17 @@ use std::sync::Arc;
3333use std:: time:: { Duration , SystemTime } ;
3434
3535use anyhow:: { bail, ensure, Context , Result } ;
36+ use aws_config:: { AppName , BehaviorVersion , SdkConfig } ;
37+ use aws_sdk_s3:: config:: { Credentials , SharedCredentialsProvider } ;
38+ use aws_sdk_s3:: primitives:: ByteStream ;
39+ use aws_sdk_s3:: Client ;
40+ use aws_smithy_async:: rt:: sleep:: { SharedAsyncSleep , TokioSleep } ;
3641use axum_server:: tls_rustls:: RustlsConfig ;
42+ use base64:: prelude:: BASE64_STANDARD ;
43+ use base64:: Engine ;
3744use chrono:: Local ;
3845use chrono_humanize:: { Accuracy , HumanTime , Tense } ;
46+ use constcat:: concat;
3947use flate2:: read:: GzDecoder ;
4048use flate2:: write:: GzEncoder ;
4149use flate2:: Compression ;
@@ -84,6 +92,9 @@ pub struct State {
8492
8593 /// Https private key file
8694 key : Option < PathBuf > ,
95+
96+ /// AWS client and S3 bucket
97+ aws : Option < ( Client , String ) > ,
8798}
8899
89100impl State {
@@ -99,6 +110,9 @@ impl State {
99110 key : Option < PathBuf > ,
100111 pg_cert : Option < PathBuf > ,
101112 #[ cfg( feature = "vt" ) ] vt_client : Option < malwaredb_virustotal:: VirusTotalClient > ,
113+ aws_access_key_id : Option < String > ,
114+ aws_secret_access_key : Option < String > ,
115+ aws_bucket : Option < String > ,
102116 ) -> Result < Self > {
103117 if let Some ( dir) = & directory {
104118 if !dir. exists ( ) {
@@ -121,6 +135,34 @@ impl State {
121135 ensure ! ( key_file. exists( ) , "Key file {key_file:?} does not exist!" ) ;
122136 }
123137
138+ if aws_access_key_id. is_some ( ) != aws_secret_access_key. is_some ( )
139+ || aws_access_key_id. is_some ( ) != aws_bucket. is_some ( )
140+ {
141+ bail ! ( "AWS credentials and the S3 bucket must be provided, or be `None`." ) ;
142+ }
143+
144+ let client = if aws_access_key_id. is_none ( ) {
145+ None
146+ } else {
147+ // The Aws-related .unwrap() calls are okay since we know they're `Some()`
148+ let creds = Credentials :: new (
149+ aws_access_key_id. unwrap ( ) ,
150+ aws_secret_access_key. unwrap ( ) ,
151+ None ,
152+ None ,
153+ "malwaredb" ,
154+ ) ;
155+ let provider = SharedCredentialsProvider :: new ( creds) ;
156+ let sleep_impl = SharedAsyncSleep :: new ( TokioSleep :: new ( ) ) ;
157+ let config = SdkConfig :: builder ( )
158+ . credentials_provider ( provider)
159+ . behavior_version ( BehaviorVersion :: latest ( ) )
160+ . sleep_impl ( sleep_impl)
161+ . app_name ( AppName :: new ( concat ! ( "malwaredb-v" , MDB_VERSION ) ) . unwrap ( ) )
162+ . build ( ) ;
163+ Some ( ( Client :: new ( & config) , aws_bucket. unwrap ( ) ) )
164+ } ;
165+
124166 let db_type = db:: DatabaseType :: from_string ( db_string, pg_cert) . await ?;
125167 // TODO: allow config and keys to be refreshed so changes don't require a restart
126168 let db_config = db_type. get_config ( ) . await ?;
@@ -139,15 +181,19 @@ impl State {
139181 started : SystemTime :: now ( ) ,
140182 cert,
141183 key,
184+ aws : client,
142185 } )
143186 }
144187
145188 /// Store the sample with a depth of three based on the sample's SHA-256 hash, even if compressed
146189 pub async fn store_bytes ( & self , data : & [ u8 ] ) -> Result < bool > {
147- if let Some ( dest_path) = & self . directory {
190+ if self . directory . is_none ( ) && self . aws . is_none ( ) {
191+ Ok ( false )
192+ } else {
148193 let mut hasher = Sha256 :: new ( ) ;
149194 hasher. update ( data) ;
150- let sha256 = hex:: encode ( hasher. finalize ( ) ) ;
195+ let sha256_bytes = hasher. finalize ( ) ;
196+ let sha256 = hex:: encode ( sha256_bytes) ;
151197
152198 // Trait `HashPath` needs to be re-worked so it can work with Strings.
153199 // This code below ends up making the String into ASCII representations of the hash
@@ -160,16 +206,6 @@ impl State {
160206 sha256
161207 ) ;
162208
163- // The path which has the file name included, with the storage directory prepended.
164- //let hashed_path = result.hashed_path(3);
165- let mut dest_path = dest_path. clone ( ) ;
166- dest_path. push ( hashed_path) ;
167-
168- // Remove the file name so we can just have the directory path.
169- let mut just_the_dir = dest_path. clone ( ) ;
170- just_the_dir. pop ( ) ;
171- std:: fs:: create_dir_all ( just_the_dir) ?;
172-
173209 let data = if self . db_config . compression {
174210 let mut compressor = GzEncoder :: new ( Vec :: new ( ) , Compression :: default ( ) ) ;
175211 compressor. write_all ( data) ?;
@@ -190,30 +226,79 @@ impl State {
190226 data
191227 } ;
192228
193- std:: fs:: write ( dest_path, data) ?;
229+ if let Some ( dest_path) = & self . directory {
230+ // The path which has the file name included, with the storage directory prepended.
231+ //let hashed_path = result.hashed_path(3);
232+ let mut dest_path = dest_path. clone ( ) ;
233+ dest_path. push ( hashed_path) ;
234+
235+ // Remove the file name so we can just have the directory path.
236+ let mut just_the_dir = dest_path. clone ( ) ;
237+ just_the_dir. pop ( ) ;
238+ std:: fs:: create_dir_all ( just_the_dir) ?;
239+
240+ std:: fs:: write ( dest_path, data) ?;
241+ } else if let Some ( ( client, bucket) ) = & self . aws {
242+ // Data is compressed and/or encrypted, so re-hash for AWS
243+ let sha256_b64 =
244+ if self . db_config . compression || self . db_config . default_key . is_some ( ) {
245+ let mut hasher = Sha256 :: new ( ) ;
246+ hasher. update ( & data) ;
247+ BASE64_STANDARD . encode ( hasher. finalize ( ) )
248+ } else {
249+ BASE64_STANDARD . encode ( sha256_bytes)
250+ } ;
251+
252+ let bytes = ByteStream :: from ( data) ;
253+ client
254+ . put_object ( )
255+ . bucket ( bucket)
256+ . checksum_sha256 ( sha256_b64)
257+ . key ( hashed_path)
258+ . body ( bytes)
259+ . send ( )
260+ . await ?;
261+ }
194262
195263 Ok ( true )
196- } else {
197- Ok ( false )
198264 }
199265 }
200266
201267 /// Retrieve a sample given the SHA-256 hash
202268 /// Assumes that MalwareDB permissions have already been checked to ensure this is permitted.
203269 pub async fn retrieve_bytes ( & self , sha256 : & String ) -> Result < Vec < u8 > > {
204- if let Some ( dest_path) = & self . directory {
270+ if self . directory . is_none ( ) && self . aws . is_none ( ) {
271+ bail ! ( "files are not saved" )
272+ } else {
273+ // Trait `HashPath` needs to be re-worked so it can work with Strings.
274+ // This code below ends up making the String into ASCII representations of the hash
275+ // See: https://github.com/malwaredb/malwaredb-rs/issues/60
276+ //let path = sha256.as_bytes().iter().hashed_path(3);
277+
205278 let path = format ! (
206279 "{}/{}/{}/{}" ,
207280 & sha256[ 0 ..2 ] ,
208281 & sha256[ 2 ..4 ] ,
209282 & sha256[ 4 ..6 ] ,
210283 sha256
211284 ) ;
212- // Trait `HashPath` needs to be re-worked so it can work with Strings.
213- // This code below ends up making the String into ASCII representations of the hash
214- // See: https://github.com/malwaredb/malwaredb-rs/issues/60
215- //let path = sha256.as_bytes().iter().hashed_path(3);
216- let contents = std:: fs:: read ( dest_path. join ( path) ) ?;
285+
286+ let contents = if let Some ( dest_path) = & self . directory {
287+ std:: fs:: read ( dest_path. join ( path) ) ?
288+ } else if let Some ( ( client, bucket) ) = & self . aws {
289+ client
290+ . get_object ( )
291+ . bucket ( bucket)
292+ . key ( path)
293+ . send ( )
294+ . await ?
295+ . body
296+ . collect ( )
297+ . await ?
298+ . to_vec ( )
299+ } else {
300+ bail ! ( "No AWS config and no local directory, shouldn't happen!" )
301+ } ;
217302
218303 let contents = if !self . keys . is_empty ( ) {
219304 let ( key_id, nonce) = self . db_type . get_file_encryption_key_id ( sha256) . await ?;
@@ -241,8 +326,6 @@ impl State {
241326 } else {
242327 Ok ( contents)
243328 }
244- } else {
245- bail ! ( "files are not saved" )
246329 }
247330 }
248331
0 commit comments