diff --git a/.gitignore b/.gitignore index e581009..d65843d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,116 @@ -__pycache__ -*.pyc -*.egg-info +#### joe made this: http://goel.io/joe + +#####=== IPythonNotebook ===##### +# Temporary data +.ipynb_checkpoints/ + +#####=== Python ===##### + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +#####=== JetBrains ===##### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio + +*.iml + +## Directory-based project format: +.idea/ +# if you remove the above rule, at least ignore the following: + +# User-specific stuff: +# .idea/workspace.xml +# .idea/tasks.xml +# .idea/dictionaries + +# Sensitive or high-churn files: +# .idea/dataSources.ids +# .idea/dataSources.xml +# .idea/sqlDataSources.xml +# .idea/dynamic.xml +# .idea/uiDesigner.xml + +# Gradle: +# .idea/gradle.xml +# .idea/libraries + +# Mongo Explorer plugin: +# .idea/mongoSettings.xml + +## File-based project format: +*.ipr +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties + +.ropeproject diff --git a/.travis.yml b/.travis.yml index f80490f..6eb9fd3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,10 +3,19 @@ language: python matrix: include: - python: '2.7' + env: + IPYTHON_VERSION='5.8' - python: '3.5' + env: + IPYTHON_VERSION='7.9' - python: '3.6' + env: + IPYTHON_VERSION='7' - python: '3.7' + env: + IPYTHON_VERSION='7' install: +- pip install IPython==$IPYTHON_VERSION - python setup.py install - pip install -r requirements.txt script: diff --git a/README.md b/README.md index 91c321f..b59f636 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,13 @@ pytorch_memlab A simple and accurate **CUDA** memory management laboratory for pytorch, it consists of different parts about the memory: - - A `line_profiler` style CUDA memory profiler with simple API. - - A reporter to inspect tensors occupying the CUDA memory. - - An interesting feature to temporarily move all the CUDA tensors into - CPU memory for courtesy, and of course the backward transferring. 
+
+  - [A `line_profiler` style CUDA memory profiler with simple API.](#memory-profiler)
+  - [A reporter to inspect tensors occupying the CUDA memory.](#memory-reporter)
+  - [An interesting feature to temporarily move all the CUDA tensors into
+    CPU memory for courtesy, and of course the backward transferring.](#courtesy)
+  - [IPython support through `%mlrun`/`%%mlrun` line/cell magic
+    commands.](#ipython-support)
 
 Installation
 -----
@@ -130,6 +133,49 @@ func()
 
 More samples can be found in `test/test_line_profiler.py`
 
+### IPython support
+
+Make sure you have `IPython` installed, or have installed `pytorch-memlab` with
+`pip install pytorch-memlab[ipython]`.
+
+First, load the extension:
+
+```python
+%load_ext pytorch_memlab
+```
+
+This makes the `%mlrun` and `%%mlrun` line/cell magics available for use. For
+example, in a new cell, run the following to profile an entire cell:
+
+```python
+%%mlrun -f func
+import torch
+from pytorch_memlab import profile, set_target_gpu
+def func():
+    net1 = torch.nn.Linear(1024, 1024).cuda(0)
+    set_target_gpu(1)
+    net2 = torch.nn.Linear(1024, 1024).cuda(1)
+    set_target_gpu(0)
+    net3 = torch.nn.Linear(1024, 1024).cuda(0)
+```
+
+Or you can invoke the profiler for a single statement via the `%mlrun` line
+magic.
+
+```python
+import torch
+from pytorch_memlab import profile, set_target_gpu
+def func(input_size):
+    net1 = torch.nn.Linear(input_size, 1024).cuda(0)
+%mlrun -f func func(2048)
+```
+
+See `%mlrun?` for help on what arguments are supported. You can set the GPU
+device to profile, dump profiling results to a file, and return the
+`LineProfiler` object for post-profile inspection.
+
+Find out more by checking out the [demo Jupyter notebook](./demo.ipynb).
+
 
 ### Memory Reporter
 
diff --git a/demo.ipynb b/demo.ipynb
new file mode 100644
index 0000000..de82abb
--- /dev/null
+++ b/demo.ipynb
@@ -0,0 +1,302 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Once installed, you need to load the `pytorch_memlab` IPython extension:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext pytorch_memlab"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "One magic is provided, `mlrun`, which can act either as a line magic `%mlrun`, or as a cell magic `%%mlrun`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\u001b[0;31mDocstring:\u001b[0m\n",
+       "::\n",
+       "\n",
+       "  %mlrun [--function FUNC] [-r] [-T OUTPUT] [-g GPU_ID] [-q]\n",
+       "              [statement [statement ...]]\n",
+       "\n",
+       "Execute a statement/cell under the PyTorch Memlab profiler to collect CUDA memory\n",
+       "allocation information on a per-line basis.\n",
+       "\n",
+       "positional arguments:\n",
+       "  statement             Code to run under profiler. You can omit this in cell\n",
+       "                        magic mode.\n",
+       "\n",
+       "optional arguments:\n",
+       "  --function FUNC, -f FUNC\n",
+       "                        Function to profile. Can be specified multiple times\n",
+       "  -r, --return-profiler\n",
+       "                        Return LineProfiler object for introspection\n",
+       "  -T OUTPUT, --dump-profile OUTPUT\n",
+       "                        Dump text profile output to file\n",
+       "  -g GPU_ID, --gpu GPU_ID\n",
+       "                        Profile memory usage of GPU ID\n",
+       "  -q, --quiet           Don't print out profile results\n",
+       "\u001b[0;31mFile:\u001b[0m      ~/pytorch_memlab/pytorch_memlab/extension.py\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%mlrun?"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First we need some torch code to profile:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "def x():\n",
+    "    torch.nn.Linear(100, 100).cuda()\n",
+    "    \n",
+    "def y(gpu=0):\n",
+    "    torch.nn.Linear(1000, 100).cuda(device=gpu)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can profile multiple functions at the same time by repeatedly specifying `-f`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "File: \n",
+      "Function: x at line 3\n",
+      "\n",
+      "Line # Max usage   Peak usage diff max diff peak  Line Contents\n",
+      "===============================================================\n",
+      "     3                                           def x():\n",
+      "     4     0.00B        2.00M    0.00B    2.00M      torch.nn.Linear(100, 100).cuda()\n",
+      "\n",
+      "File: \n",
+      "Function: y at line 6\n",
+      "\n",
+      "Line # Max usage   Peak usage diff max diff peak  Line Contents\n",
+      "===============================================================\n",
+      "     6                                           def y(gpu=0):\n",
+      "     7     0.00B        2.00M    0.00B    2.00M      torch.nn.Linear(1000, 100).cuda(device=gpu)\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%mlrun -f x -f y\n",
+    "\n",
+    "x()\n",
+    "y()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can also profile with the `%mlrun` line magic"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "File: \n",
+      "Function: z at line 1\n",
+      "\n",
+      "Line # Max usage   Peak usage diff max diff peak  Line Contents\n",
+      "===============================================================\n",
+      "     1                                           def z():\n",
+      "     2     0.00B        2.00M    0.00B    2.00M      torch.nn.Linear(100, 100).cuda()\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "def z():\n",
+    "    torch.nn.Linear(100, 100).cuda()\n",
+    "%mlrun -f z z()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can specify which GPU you wish to profile using `-g`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "File: \n",
+      "Function: x at line 3\n",
+      "\n",
+      "Line # Max usage   Peak usage diff max diff peak  Line Contents\n",
+      "===============================================================\n",
+      "     3                                           def x():\n",
+      "     4     0.00B        0.00B    0.00B    0.00B      torch.nn.Linear(100, 100).cuda()\n",
+      "\n",
+      "File: \n",
+      "Function: y at line 6\n",
+      "\n",
+      "Line # Max usage   Peak usage diff max diff peak  Line Contents\n",
+      "===============================================================\n",
+      "     6                                           def y(gpu=0):\n",
+      "     7     0.00B        2.00M    0.00B    2.00M      torch.nn.Linear(1000, 100).cuda(device=gpu)\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%mlrun -f x -f y -g 1 y\n",
+    "\n",
+    "x()\n",
+    "y(gpu=1)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can get a handle on the `LineProfiler` object using `-r`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{\", line 3>: {'line_stat': defaultdict(list,\n",
+       "             {-1: [(0, 0)], 4: [(0, 2097152)]}),\n",
+       "  'func': ,\n",
+       "  'func_name': 'x',\n",
+       "  'source_code': (['def x():\\n', '    torch.nn.Linear(100, 100).cuda()\\n'], 3),\n",
+       "  'last_lineno': 0}}"
+      ]
+     },
+     "execution_count": 7,
+
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "profiler = %mlrun -q -r -f x x()\n", + "profiler.code_map" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can dump stats out to a file using `-T`:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "%mlrun -q -T profile.log -f x x()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "File: \n", + "Function: x at line 3\n", + "\n", + "Line # Max usage Peak usage diff max diff peak Line Contents\n", + "===============================================================\n", + " 3 def x():\n", + " 4 0.00B 2.00M 0.00B 2.00M torch.nn.Linear(100, 100).cuda()\n", + "\n" + ] + } + ], + "source": [ + "!head profile.log" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:memlab]", + "language": "python", + "name": "conda-env-memlab-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pytorch_memlab/__init__.py b/pytorch_memlab/__init__.py index 2ba1c34..38649d2 100644 --- a/pytorch_memlab/__init__.py +++ b/pytorch_memlab/__init__.py @@ -1,3 +1,7 @@ from .courtesy import Courtesy from .mem_reporter import MemReporter from .line_profiler import LineProfiler, profile, profile_every, set_target_gpu +try: + from .extension import load_ipython_extension +except ImportError: + pass diff --git a/pytorch_memlab/extension.py b/pytorch_memlab/extension.py new file mode 100644 index 0000000..c79a56e --- /dev/null +++ b/pytorch_memlab/extension.py @@ -0,0 +1,87 @@ +from IPython.core.magic import ( + Magics, + magics_class, + line_cell_magic, + needs_local_scope, +) +from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring +from .line_profiler import LineProfiler +from tempfile import mkstemp + + +class UsageError(Exception): + pass + + +@magics_class +class MemlabMagics(Magics): + @magic_arguments() + @argument('--function', + '-f', + metavar='FUNC', + action='append', + default=[], + help="""Function to profile. Can be specified multiple times to profile multiple + functions""") + @argument('-r', + '--return-profiler', + action='store_true', + help='Return LineProfiler object for introspection') + @argument('-T', + '--dump-profile', + metavar='OUTPUT', + help='Dump text profile output to file') + @argument('-g', + '--gpu', + metavar='GPU_ID', + default=0, + type=int, + help='Profile memory usage of this GPU') + @argument('-q', + '--quiet', + action='store_true', + help='Don\'t print out profile results') + @argument('statement', + nargs='*', + default=None, + help='Code to run under profiler. You can omit this in cell magic mode.') + @line_cell_magic + @needs_local_scope + def mlrun(self, line=None, cell=None, local_ns=None): + """Execute a statement/cell under the PyTorch Memlab profiler to collect CUDA memory + allocation information on a per-line basis. 
+ """ + args = parse_argstring(self.mlrun, line) + global_ns = self.shell.user_global_ns + + funcs = [] + for name in args.function: + try: + fn = eval(name, global_ns, local_ns) + funcs.append(fn) + except NameError as e: + raise UsageError('Could not find function {!r}.\n{}: {}'.format( + name, e.__class__.__name__, e) + ) + profiler = LineProfiler(*funcs, target_gpu=args.gpu) + if cell is not None: + code = cell + else: + assert args.statement is not None + code = '\n'.join(args.statement) + with profiler: + exec(compile(code, filename='', mode='exec'), local_ns) + + if not args.quiet: + profiler.print_stats() + + if args.dump_profile is not None: + with open(args.dump_profile, 'w') as f: + profiler.print_stats(stream=f) + + if args.return_profiler: + return profiler + + +def load_ipython_extension(ipython): + ipython.register_magics(MemlabMagics) diff --git a/pytorch_memlab/line_profiler.py b/pytorch_memlab/line_profiler.py index 4da28b1..755cd76 100644 --- a/pytorch_memlab/line_profiler.py +++ b/pytorch_memlab/line_profiler.py @@ -8,10 +8,6 @@ from .utils import readable_size -# profile the memory usage on gpu=0 by default -target_gpu = 0 - - def set_target_gpu(gpu_id): """Set the target GPU id to profile memory @@ -24,8 +20,7 @@ def set_target_gpu(gpu_id): - gpu_id: cuda index to profile the memory on, also accepts `torch.device` object. """ - global target_gpu - target_gpu = gpu_id + global_line_profiler.target_gpu = gpu_id class LineProfiler: @@ -51,7 +46,8 @@ class LineProfiler: ``` """ - def __init__(self, *functions): + def __init__(self, *functions, **kwargs): + self.target_gpu = kwargs.get('target_gpu', 0) self.functions = [] self.code_map = {} self.enabled = False @@ -108,7 +104,7 @@ def trace_callback(self, frame, event, arg): if event in ['line', 'return'] and frame.f_code in self.code_map: line_stat = self.code_map[frame.f_code]['line_stat'] - with torch.cuda.device(target_gpu): + with torch.cuda.device(self.target_gpu): allocated_memory = torch.cuda.memory_allocated() cached_memory = torch.cuda.memory_cached() torch.cuda.empty_cache() @@ -121,22 +117,24 @@ def trace_callback(self, frame, event, arg): self.code_map[frame.f_code]['last_lineno'] = lineno return - def print_stats(self): + def print_stats(self, stream=None): """Print the stat of each functions """ for code, stat in self.code_map.items(): show_func( filename=code.co_filename, trace_stat=stat, + stream=stream ) - def print_func_stats(self, func): + def print_func_stats(self, func, stream=None): """Print the stat of a registered function""" code = func.__code__ if code in self.code_map: show_func( filename=code.co_filename, trace_stat=self.code_map[code], + stream=stream ) @@ -227,19 +225,9 @@ def show_func(filename, trace_stat, stream=None): linenos = list(trace_stat['line_stat'].keys()) start_lineno = trace_stat['source_code'][1] - if os.path.exists(filename): - stream.write("File: %s\n" % filename) - stream.write("Function: %s at line %s\n" % (func_name, start_lineno)) - sublines = trace_stat['source_code'][0] - else: - stream.write("\n") - stream.write("Could not find file %s\n" % filename) - stream.write("Are you sure you are running this program from the same directory\n") - stream.write("that you ran the profiler from?\n") - stream.write("Continuing without the function's contents.\n") - # Fake empty lines so we can see the timings, if not the code. 
- nlines = max(linenos) - min(min(linenos), start_lineno) + 1 - sublines = [''] * nlines + stream.write("File: %s\n" % filename) + stream.write("Function: %s at line %s\n" % (func_name, start_lineno)) + sublines = trace_stat['source_code'][0] prev_max_allocated = 0 prev_max_cached = 0 @@ -267,17 +255,17 @@ def show_func(filename, trace_stat, stream=None): empty = ('', '', '', '') header = template % ('Line #', 'Max usage', 'Peak usage', 'diff max', 'diff peak', 'Line Contents') - stream.write("\n") + stream.write('\n') stream.write(header) - stream.write("\n") + stream.write('\n') stream.write('=' * len(header)) - stream.write("\n") + stream.write('\n') for lineno, line in zip(linenos, sublines): show_line_stat = lineno_mem.get(lineno, empty) max_usage, peak_usage, diff_max, diff_peak = show_line_stat txt = template % (lineno, max_usage, peak_usage, diff_max, diff_peak, line.rstrip('\n').rstrip('\r')) stream.write(txt) - stream.write("\n") - stream.write("\n") + stream.write('\n') + stream.write('\n') stream.flush() diff --git a/requirements.txt b/requirements.txt index eda30ca..757e4ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ calmsize torch future +ipython diff --git a/setup.py b/setup.py index 50d92c3..5448ba7 100644 --- a/setup.py +++ b/setup.py @@ -27,5 +27,8 @@ 'setuptools', 'calmsize', ], + extras_require={ + 'ipython': ['IPython>=0.13'], + }, packages=['pytorch_memlab'], )
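For reference, the `LineProfiler` changes above (the `target_gpu` keyword argument and the `stream` parameter of `print_stats`) can also be used directly, outside IPython. The following is a minimal sketch, assuming a CUDA-capable device; the function body, tensor sizes, and output file name are illustrative only and are not part of the patch.

```python
import torch
from pytorch_memlab import LineProfiler

def alloc():
    # a small allocation so the profiled line registers some CUDA memory usage
    torch.nn.Linear(1024, 1024).cuda()

# register `alloc` and profile it on GPU 0 via the target_gpu keyword
profiler = LineProfiler(alloc, target_gpu=0)
with profiler:
    alloc()

# print the per-line report to stdout, then write the same report to a file
# through the stream argument of print_stats()
profiler.print_stats()
with open('memlab_profile.txt', 'w') as f:
    profiler.print_stats(stream=f)
```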