diff --git a/README.md b/README.md index e0ce5cb6..b262a861 100644 --- a/README.md +++ b/README.md @@ -344,6 +344,12 @@ This tool will scan all objects that have not been previously scanned in the buc asynchronously. As such you'll have to go to your cloudwatch logs to see the scan results or failures. Additionally, the script uses the same environment variables you'd use in your lambda so you can configure them similarly. +If you want to scan a subset of the bucket (or even a single object) you can use the `--prefix` option to filter objects. Remember that S3 objects are not real paths and prefixes match on substrings, so `--prefix=test` will match objects named `test`, `testfoo`, and `testing/foo`: + +```sh +scan_bucket.py --lambda-function-name= --s3-bucket-name= --prefix=path/to/files/ +``` + ## Testing There are two types of tests in this repository. The first is pre-commit tests and the second are python tests. All of diff --git a/display_infected.py b/display_infected.py index 0c40bc98..b69ae8cb 100755 --- a/display_infected.py +++ b/display_infected.py @@ -29,13 +29,15 @@ # Get all objects in an S3 bucket that are infected -def get_objects_and_sigs(s3_client, s3_bucket_name): +def get_objects_and_sigs(s3_client, s3_bucket_name, prefix=None): s3_object_list = [] s3_list_objects_result = {"IsTruncated": True} while s3_list_objects_result["IsTruncated"]: s3_list_objects_config = {"Bucket": s3_bucket_name} + if prefix is not None: + s3_list_objects_config["Prefix"] = prefix continuation_token = s3_list_objects_result.get("NextContinuationToken") if continuation_token: s3_list_objects_config["ContinuationToken"] = continuation_token @@ -75,7 +77,7 @@ def object_infected(s3_client, s3_bucket_name, key_name): return False, None -def main(s3_bucket_name): +def main(s3_bucket_name, prefix): # Verify the S3 bucket exists s3_client = boto3.client("s3") @@ -86,7 +88,7 @@ def main(s3_bucket_name): sys.exit(1) # Scan the objects in the bucket - s3_object_and_sigs_list = get_objects_and_sigs(s3_client, s3_bucket_name) + s3_object_and_sigs_list = get_objects_and_sigs(s3_client, s3_bucket_name, prefix) for (key_name, av_signature) in s3_object_and_sigs_list: print("Infected: {}/{}, {}".format(s3_bucket_name, key_name, av_signature)) @@ -98,6 +100,7 @@ def main(s3_bucket_name): parser.add_argument( "--s3-bucket-name", required=True, help="The name of the S3 bucket to scan" ) + parser.add_argument("--prefix", help="The prefix to filter the bucket objects by") args = parser.parse_args() - main(args.s3_bucket_name) + main(args.s3_bucket_name, args.prefix) diff --git a/scan_bucket.py b/scan_bucket.py index 6043ffb0..11e8236e 100755 --- a/scan_bucket.py +++ b/scan_bucket.py @@ -26,13 +26,15 @@ # Get all objects in an S3 bucket that have not been previously scanned -def get_objects(s3_client, s3_bucket_name): +def get_objects(s3_client, s3_bucket_name, prefix=None): s3_object_list = [] s3_list_objects_result = {"IsTruncated": True} while s3_list_objects_result["IsTruncated"]: s3_list_objects_config = {"Bucket": s3_bucket_name} + if prefix is not None: + s3_list_objects_config["Prefix"] = prefix continuation_token = s3_list_objects_result.get("NextContinuationToken") if continuation_token: s3_list_objects_config["ContinuationToken"] = continuation_token @@ -85,7 +87,7 @@ def format_s3_event(s3_bucket_name, key_name): return s3_event -def main(lambda_function_name, s3_bucket_name, limit): +def main(lambda_function_name, s3_bucket_name, limit, prefix): # Verify the lambda exists lambda_client = boto3.client("lambda") try: @@ -103,7 +105,7 @@ def main(lambda_function_name, s3_bucket_name, limit): sys.exit(1) # Scan the objects in the bucket - s3_object_list = get_objects(s3_client, s3_bucket_name) + s3_object_list = get_objects(s3_client, s3_bucket_name, prefix) if limit: s3_object_list = s3_object_list[: min(limit, len(s3_object_list))] for key_name in s3_object_list: @@ -121,6 +123,12 @@ def main(lambda_function_name, s3_bucket_name, limit): "--s3-bucket-name", required=True, help="The name of the S3 bucket to scan" ) parser.add_argument("--limit", type=int, help="The number of records to limit to") + parser.add_argument("--prefix", help="The prefix to filter the bucket objects by") args = parser.parse_args() - main(args.lambda_function_name, args.s3_bucket_name, args.limit) + main( + args.lambda_function_name, + args.s3_bucket_name, + limit=args.limit, + prefix=args.prefix, + )