diff --git a/scripts/format_includes.py b/scripts/format_includes.py index 463bf346..e1f515bf 100644 --- a/scripts/format_includes.py +++ b/scripts/format_includes.py @@ -1,8 +1,18 @@ +import os import sys -def check_line(line, words): +def check_line(line: str, words): + index_include: int = line.find("#include") + if index_include == -1: + return False + index_include += len('#include') + line_after_include = line[index_include:] + index: int = line_after_include.find("\"") + 1 + if (index == 0): + return False + file = line_after_include[index:] for word in words: - if line.startswith("#include \"" + word): + if (file.startswith(word)): return True return False @@ -26,15 +36,20 @@ def format_file(input_path, words): with open(input_path, 'w') as file: file.writelines(formatted_content) - +def format_directory(directory, words): + for root, _, files in os.walk(directory): + for file_name in files: + if file_name.endswith(('.hpp', '.cpp')): + file_path = os.path.join(root, file_name) + format_file(file_path, words) def main(): if len(sys.argv) < 3: - print('Usage: python3 format_includes.py start_words...') + print('Usage: python3 format_includes.py start_words...') return input_path = sys.argv[1] words = sys.argv[2:] - format_file(input_path, words) + format_directory(input_path, words) if __name__ == '__main__':