diff --git a/main.py b/main.py index 400fc0e..1fed423 100755 --- a/main.py +++ b/main.py @@ -27,7 +27,9 @@ import logging import os.path import pickle +import shlex import sys +import subprocess from recuperabit import logic, utils # scanners @@ -70,7 +72,7 @@ def list_parts(parts, shorthands, test): """List partitions corresponding to test.""" for i, part in shorthands: if test(parts[part]): - print('Partition #' + str(i), '->', parts[part]) + yield 'Partition #' + str(i), '->', parts[part] def check_valid_part(num, parts, shorthands, rebuild=True): @@ -96,22 +98,22 @@ def check_valid_part(num, parts, shorthands, rebuild=True): def interpret(cmd, arguments, parts, shorthands, outdir): """Perform command required by user.""" if cmd == 'help': - print('Available commands:') - for name, desc in commands: - print(' %s%s' % (name.ljust(28), desc)) + return 'Available commands:' \ + + "\n".join([' %s%s' % (name.ljust(28), desc) for name, desc in commands]) \ + + "\n" elif cmd == 'tree': if len(arguments) != 1: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: part = check_valid_part(arguments[0], parts, shorthands) if part is not None: - print('-'*10) - print(utils.tree_folder(part.root)) - print(utils.tree_folder(part.lost)) - print('-'*10) + return "\n".join(['-'*10, + utils.tree_folder(part.root), + utils.tree_folder(part.lost), + '-'*10]) elif cmd == 'bodyfile': if len(arguments) != 2: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: part = check_valid_part(arguments[0], parts, shorthands) if part is not None: @@ -125,12 +127,12 @@ def interpret(cmd, arguments, parts, shorthands, outdir): try: with codecs.open(fname, 'w', encoding='utf8') as outfile: outfile.write('\n'.join(contents)) - print('Saved body file to %s' % fname) + return 'Saved body file to %s' % fname except IOError: - print('Cannot open file %s for output!' % fname) + return 'Cannot open file %s for output!' % fname elif cmd == 'csv': if len(arguments) != 2: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: part = check_valid_part(arguments[0], parts, shorthands) if part is not None: @@ -141,12 +143,12 @@ def interpret(cmd, arguments, parts, shorthands, outdir): outfile.write( '\n'.join(contents) ) - print('Saved CSV file to %s' % fname) + return 'Saved CSV file to %s' % fname except IOError: - print('Cannot open file %s for output!' % fname) + return 'Cannot open file %s for output!' % fname elif cmd == 'tikzplot': if len(arguments) not in (1, 2): - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: part = check_valid_part(arguments[0], parts, shorthands) if part is not None: @@ -155,14 +157,14 @@ def interpret(cmd, arguments, parts, shorthands, outdir): try: with codecs.open(fname, 'w') as outfile: outfile.write(utils.tikz_part(part) + '\n') - print('Saved Tikz code to %s' % fname) + return 'Saved Tikz code to %s' % fname except IOError: - print('Cannot open file %s for output!' % fname) + return 'Cannot open file %s for output!' % fname else: - print(utils.tikz_part(part)) + return utils.tikz_part(part) elif cmd == 'restore': if len(arguments) != 2: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: partid = arguments[0] part = check_valid_part(partid, parts, shorthands) @@ -177,12 +179,12 @@ def interpret(cmd, arguments, parts, shorthands, outdir): for i in [index, indexi]: myfile = part.get(i, myfile) if myfile is None: - print('The index is not valid') + return 'The index is not valid' else: logic.recursive_restore(myfile, part, partition_dir) elif cmd == 'locate': if len(arguments) != 2: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: part = check_valid_part(arguments[0], parts, shorthands) if part is not None: @@ -193,10 +195,10 @@ def interpret(cmd, arguments, parts, shorthands, outdir): ' [GHOST]' if node.is_ghost else ' [DELETED]' if node.is_deleted else '' ) - print('[%s]: %s%s' % (node.index, path, desc)) + return '[%s]: %s%s' % (node.index, path, desc) elif cmd == 'traceback': if len(arguments) != 2: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: partid = arguments[0] part = check_valid_part(partid, parts, shorthands) @@ -210,23 +212,22 @@ def interpret(cmd, arguments, parts, shorthands, outdir): for i in [index, indexi]: myfile = part.get(i, myfile) if myfile is None: - print('The index is not valid') + return 'The index is not valid' else: while myfile is not None: - print('[{}] {}'.format(myfile.index, myfile.full_path(part))) + return '[{}] {}'.format(myfile.index, myfile.full_path(part)) myfile = part.get(myfile.parent) elif cmd == 'merge': if len(arguments) != 2: - print('Wrong number of parameters!') + return 'Wrong number of parameters!' else: part1 = check_valid_part(arguments[0], parts, shorthands, rebuild=False) part2 = check_valid_part(arguments[1], parts, shorthands, rebuild=False) if None in (part1, part2): return if part1.fs_type != part2.fs_type: - print('Cannot merge partitions with types (%s, %s)' % (part1.fs_type, part2.fs_type)) - return - print('Merging partitions...') + return 'Cannot merge partitions with types (%s, %s)' % (part1.fs_type, part2.fs_type) + # print('Merging partitions...') utils.merge(part1, part2) source_position = int(arguments[1]) destination_position = int(arguments[0]) @@ -239,22 +240,22 @@ def interpret(cmd, arguments, parts, shorthands, outdir): rebuilt.remove(par) except: pass - print('There are now %d partitions.' % (len(parts), )) + return 'There are now %d partitions.' % (len(parts), ) elif cmd == 'recoverable': - list_parts(parts, shorthands, lambda x: x.recoverable) + return "\n".join(list_parts(parts, shorthands, lambda x: x.recoverable)) elif cmd == 'recoverable_size': if len(arguments) != 1: - print('Wrong number of parameters!') + return "Wrong number of parameters!" else: - list_parts(parts, shorthands, lambda x: x.size is not None and x.size > arguments[0]) + return "\n".join(list_parts(parts, shorthands, lambda x: x.size is not None and x.size > arguments[0])) elif cmd == 'other': - list_parts(parts, shorthands, lambda x: not x.recoverable) + return "\n".join(list_parts(parts, shorthands, lambda x: not x.recoverable)) elif cmd == 'allparts': - list_parts(parts, shorthands, lambda x: True) + return "\n".join(list_parts(parts, shorthands, lambda x: True)) elif cmd == 'quit': exit(0) else: - print('Unknown command.') + return 'Unknown command.' def main(): @@ -363,18 +364,60 @@ def main(): parts.update(scanner.get_partitions()) shorthands = list(enumerate(parts)) + known_cmds = [e[0].split(' ')[0] for e in commands] logging.info('%i partitions found.', len(parts)) while True: print('\nWrite command ("help" for details):') try: - command = input('> ').split(' ') + command = shlex.split(input('> ').strip()) except (EOFError, KeyboardInterrupt): print('') exit(0) + if len(command) == 0: + continue + + # pipe/redirect shell expression handling + # (if the first word is a known recuperabit command) + if ('>' in command or '|' in command) and command[0] in known_cmds: + end = None + for end, k in enumerate(command): + if k == '>' or k == '|': + break + + redirect_to = None + running = command[0:end] + out = interpret(running[0], running[1:end - 1], parts, shorthands, args.outputdir) + command = command[end:] + if command[-2] == '>': + if not os.path.exists(command[-1]): + redirect_to = command[-1] + command.pop() + command.pop() + + # print(f"Python runs \"{' '.join(running)}\", shell runs \"{' '.join(command)}\" redirects to {redirect_to}") + if command[0] == '|': + try: + r = subprocess.run(command[1:], input=out, capture_output=True, encoding='utf8', check=True) + out = r.stdout + except (TypeError, FileNotFoundError) as e: + print("Shell expression error:", e) + except subprocess.CalledProcessError as e: + print("Shell expression error:", e) + print(e.stderr) + if redirect_to: + with open(redirect_to, 'w', encoding='utf8') as f: + f.write(out) + print(f'{len(out)} bytes written to {redirect_to}') + else: + print(out) + continue + cmd = command[0] arguments = command[1:] - interpret(cmd, arguments, parts, shorthands, args.outputdir) + out = interpret(cmd, arguments, parts, shorthands, args.outputdir) + if out is not None: + print(out) if __name__ == '__main__': main()