Skip to content
This repository was archived by the owner on Jun 20, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ This tool will scan all objects that have not been previously scanned in the buc
asynchronously. As such you'll have to go to your cloudwatch logs to see the scan results or failures. Additionally,
the script uses the same environment variables you'd use in your lambda so you can configure them similarly.

If you want to scan a subset of the bucket (or even a single object) you can use the `--prefix` option to filter objects. Remember that S3 objects are not real paths and prefixes match on substrings, so `--prefix=test` will match objects named `test`, `testfoo`, and `testing/foo`:

```sh
scan_bucket.py --lambda-function-name=<lambda_function_name> --s3-bucket-name=<s3-bucket-to-scan> --prefix=path/to/files/
```

## Testing

There are two types of tests in this repository. The first is pre-commit tests and the second are python tests. All of
Expand Down
11 changes: 7 additions & 4 deletions display_infected.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@


# Get all objects in an S3 bucket that are infected
def get_objects_and_sigs(s3_client, s3_bucket_name):
def get_objects_and_sigs(s3_client, s3_bucket_name, prefix=None):

s3_object_list = []

s3_list_objects_result = {"IsTruncated": True}
while s3_list_objects_result["IsTruncated"]:
s3_list_objects_config = {"Bucket": s3_bucket_name}
if prefix is not None:
s3_list_objects_config["Prefix"] = prefix
continuation_token = s3_list_objects_result.get("NextContinuationToken")
if continuation_token:
s3_list_objects_config["ContinuationToken"] = continuation_token
Expand Down Expand Up @@ -75,7 +77,7 @@ def object_infected(s3_client, s3_bucket_name, key_name):
return False, None


def main(s3_bucket_name):
def main(s3_bucket_name, prefix):

# Verify the S3 bucket exists
s3_client = boto3.client("s3")
Expand All @@ -86,7 +88,7 @@ def main(s3_bucket_name):
sys.exit(1)

# Scan the objects in the bucket
s3_object_and_sigs_list = get_objects_and_sigs(s3_client, s3_bucket_name)
s3_object_and_sigs_list = get_objects_and_sigs(s3_client, s3_bucket_name, prefix)
for (key_name, av_signature) in s3_object_and_sigs_list:
print("Infected: {}/{}, {}".format(s3_bucket_name, key_name, av_signature))

Expand All @@ -98,6 +100,7 @@ def main(s3_bucket_name):
parser.add_argument(
"--s3-bucket-name", required=True, help="The name of the S3 bucket to scan"
)
parser.add_argument("--prefix", help="The prefix to filter the bucket objects by")
args = parser.parse_args()

main(args.s3_bucket_name)
main(args.s3_bucket_name, args.prefix)
16 changes: 12 additions & 4 deletions scan_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@


# Get all objects in an S3 bucket that have not been previously scanned
def get_objects(s3_client, s3_bucket_name):
def get_objects(s3_client, s3_bucket_name, prefix=None):

s3_object_list = []

s3_list_objects_result = {"IsTruncated": True}
while s3_list_objects_result["IsTruncated"]:
s3_list_objects_config = {"Bucket": s3_bucket_name}
if prefix is not None:
s3_list_objects_config["Prefix"] = prefix
continuation_token = s3_list_objects_result.get("NextContinuationToken")
if continuation_token:
s3_list_objects_config["ContinuationToken"] = continuation_token
Expand Down Expand Up @@ -85,7 +87,7 @@ def format_s3_event(s3_bucket_name, key_name):
return s3_event


def main(lambda_function_name, s3_bucket_name, limit):
def main(lambda_function_name, s3_bucket_name, limit, prefix):
# Verify the lambda exists
lambda_client = boto3.client("lambda")
try:
Expand All @@ -103,7 +105,7 @@ def main(lambda_function_name, s3_bucket_name, limit):
sys.exit(1)

# Scan the objects in the bucket
s3_object_list = get_objects(s3_client, s3_bucket_name)
s3_object_list = get_objects(s3_client, s3_bucket_name, prefix)
if limit:
s3_object_list = s3_object_list[: min(limit, len(s3_object_list))]
for key_name in s3_object_list:
Expand All @@ -121,6 +123,12 @@ def main(lambda_function_name, s3_bucket_name, limit):
"--s3-bucket-name", required=True, help="The name of the S3 bucket to scan"
)
parser.add_argument("--limit", type=int, help="The number of records to limit to")
parser.add_argument("--prefix", help="The prefix to filter the bucket objects by")
args = parser.parse_args()

main(args.lambda_function_name, args.s3_bucket_name, args.limit)
main(
args.lambda_function_name,
args.s3_bucket_name,
limit=args.limit,
prefix=args.prefix,
)