diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..fd13dc8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,10 @@ + +The MIT License (MIT) +Copyright (c) 2019, Amil Khan + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..039defb --- /dev/null +++ b/Makefile @@ -0,0 +1,144 @@ +.PHONY: clean data lint requirements sync_data_to_s3 sync_data_from_s3 + +################################################################################# +# GLOBALS # +################################################################################# + +PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +BUCKET = [OPTIONAL] your-bucket-for-syncing-data (do not include 's3://') +PROFILE = default +PROJECT_NAME = gandetection +PYTHON_INTERPRETER = python3 + +ifeq (,$(shell which conda)) +HAS_CONDA=False +else +HAS_CONDA=True +endif + +################################################################################# +# COMMANDS # +################################################################################# + +## Install Python Dependencies +requirements: test_environment + $(PYTHON_INTERPRETER) -m pip install -U pip setuptools wheel + $(PYTHON_INTERPRETER) -m pip install -r requirements.txt + +## Make Dataset +data: requirements + $(PYTHON_INTERPRETER) src/data/make_dataset.py + +## Delete all compiled Python files +clean: + find . -type f -name "*.py[co]" -delete + find . -type d -name "__pycache__" -delete + +## Lint using flake8 +lint: + flake8 src + +## Upload Data to S3 +sync_data_to_s3: +ifeq (default,$(PROFILE)) + aws s3 sync data/ s3://$(BUCKET)/data/ +else + aws s3 sync data/ s3://$(BUCKET)/data/ --profile $(PROFILE) +endif + +## Download Data from S3 +sync_data_from_s3: +ifeq (default,$(PROFILE)) + aws s3 sync s3://$(BUCKET)/data/ data/ +else + aws s3 sync s3://$(BUCKET)/data/ data/ --profile $(PROFILE) +endif + +## Set up python interpreter environment +create_environment: +ifeq (True,$(HAS_CONDA)) + @echo ">>> Detected conda, creating conda environment." +ifeq (3,$(findstring 3,$(PYTHON_INTERPRETER))) + conda create --name $(PROJECT_NAME) python=3 +else + conda create --name $(PROJECT_NAME) python=2.7 +endif + @echo ">>> New conda env created. 
Activate with:\nsource activate $(PROJECT_NAME)" +else + $(PYTHON_INTERPRETER) -m pip install -q virtualenv virtualenvwrapper + @echo ">>> Installing virtualenvwrapper if not already installed.\nMake sure the following lines are in your shell startup file\n\ + export WORKON_HOME=$$HOME/.virtualenvs\nexport PROJECT_HOME=$$HOME/Devel\nsource /usr/local/bin/virtualenvwrapper.sh\n" + @bash -c "source `which virtualenvwrapper.sh`;mkvirtualenv $(PROJECT_NAME) --python=$(PYTHON_INTERPRETER)" + @echo ">>> New virtualenv created. Activate with:\nworkon $(PROJECT_NAME)" +endif + +## Test python environment is set up correctly +test_environment: + $(PYTHON_INTERPRETER) test_environment.py + +################################################################################# +# PROJECT RULES # +################################################################################# + + + +################################################################################# +# Self Documenting Commands # +################################################################################# + +.DEFAULT_GOAL := help + +# Inspired by +# sed script explained: +# /^##/: +# * save line in hold space +# * purge line +# * Loop: +# * append newline + line to hold space +# * go to next line +# * if line starts with doc comment, strip comment character off and loop +# * remove target prerequisites +# * append hold space (+ newline) to line +# * replace newline plus comments by `---` +# * print line +# Separate expressions are necessary because labels cannot be delimited by +# semicolon; see +.PHONY: help +help: + @echo "$$(tput bold)Available rules:$$(tput sgr0)" + @echo + @sed -n -e "/^## / { \ + h; \ + s/.*//; \ + :doc" \ + -e "H; \ + n; \ + s/^## //; \ + t doc" \ + -e "s/:.*//; \ + G; \ + s/\\n## /---/; \ + s/\\n/ /g; \ + p; \ + }" ${MAKEFILE_LIST} \ + | LC_ALL='C' sort --ignore-case \ + | awk -F '---' \ + -v ncol=$$(tput cols) \ + -v indent=19 \ + -v col_on="$$(tput setaf 6)" \ + -v col_off="$$(tput sgr0)" \ + '{ \ + printf "%s%*s%s ", col_on, -indent, $$1, col_off; \ + n = split($$2, words, " "); \ + line_length = ncol - indent; \ + for (i = 1; i <= n; i++) { \ + line_length -= length(words[i]) + 1; \ + if (line_length <= 0) { \ + line_length = ncol - indent - length(words[i]) - 1; \ + printf "\n%*s ", -indent, " "; \ + } \ + printf "%s ", words[i]; \ + } \ + printf "\n"; \ + }' \ + | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars')
diff --git a/README.md b/README.md new file mode 100644 index 0000000..72cf66f --- /dev/null +++ b/README.md @@ -0,0 +1,67 @@ +# **GAN Generated Image Detection using Convolutional Neural Networks** + +### Abstract + +Synthetic images generated with _Generative Adversarial Network (GAN)_ architectures have become increasingly hard to distinguish from real photographs. With a relatively low-cost GPU and enough time, we have seen images of fake celebrities, bedrooms, and landscapes---to name a few---deceive a reasonable person. Previous use cases of GANs, such as increasing the number of samples in a small dataset, have seen widespread adoption across disciplines. Over the past year, however, we have also seen GANs put to malicious use. Hence, we feed GAN-generated images we produced to a model whose task is to determine whether an image is "Real" or "Fake".
We demonstrate GAN-generated image detection using five ImageNet classification models, each tasked with classifying the real and fake images presented to it as inputs. + +### Data Generation + +For the data generation portion, we used the Progressive GAN implementation, since it produced the best results at the time of this writing. In particular, the Progressive GAN approach of starting with low-resolution images and then progressively increasing the resolution by adding layers to the networks allowed greater stability during training, shorter training times, and higher-resolution images. + +We put everything you need for the data generation portion inside the `data` folder. + + +### Model Performance + +The best model balanced computational time, complexity, and accuracy. With 96% accuracy on the test set at epoch 11, it achieved low error rates and generalized well to unseen data. (A minimal sketch of this classification setup is shown below.) + +Project Organization +------------ + + ├── LICENSE + ├── Makefile <- Makefile with commands like `make data` or `make train` + ├── README.md <- The top-level README for developers using this project. + ├── data + │   ├── external <- Data from third party sources. + │   ├── interim <- Intermediate data that has been transformed. + │   ├── processed <- The final, canonical data sets for modeling. + │   └── raw <- The original, immutable data dump. + │ + ├── docs <- A default Sphinx project; see sphinx-doc.org for details + │ + ├── models <- Trained and serialized models, model predictions, or model summaries + │ + ├── notebooks <- Jupyter notebooks. + │ + ├── references <- Data dictionaries, manuals, and all other explanatory materials. + │ + ├── reports <- Generated analysis as HTML, PDF, LaTeX, etc. + │   └── figures <- Generated graphics and figures to be used in reporting + │ + ├── requirements.txt <- The requirements file for reproducing the analysis environment, e.g. + │ generated with `pip freeze > requirements.txt` + │ + ├── setup.py <- makes project pip installable (pip install -e .) so src can be imported + ├── src <- Source code for use in this project. + │   ├── __init__.py <- Makes src a Python module + │ │ + │   ├── data <- Scripts to download or generate data + │   │   └── make_dataset.py + │ │ + │   ├── features <- Scripts to turn raw data into features for modeling + │   │   └── build_features.py + │ │ + │   ├── models <- Scripts to train models and then use trained models to make + │ │ │ predictions + │   │   ├── predict_model.py + │   │   └── train_model.py + │ │ + │   └── visualization <- Scripts to create exploratory and results oriented visualizations + │   └── visualize.py + │ + └── tox.ini <- tox file with settings for running tox; see tox.testrun.org + + +--------
diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..7400ab1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,153 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
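The following is a minimal sketch of the Real/Fake classification setup described in the README above; it is an assumed illustration rather than the project's actual training code (the `resnet18` backbone, optimizer, and learning rate are placeholders for whichever of the five ImageNet models is used):

```python
# Sketch (assumptions noted above): fine-tune a pretrained ImageNet backbone
# by swapping its 1000-way classifier head for a two-class Real/Fake head.
import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=True)       # backbone choice is illustrative
model.fc = nn.Linear(model.fc.in_features, 2)  # Real vs. Fake output head

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(images, labels):
    """One optimization step on a batch; labels are 0 (real) or 1 (fake)."""
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```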
+ +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make <target>' where <target> is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + -rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/gandetection.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/gandetection.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/gandetection" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/gandetection" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..."
+ $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." diff --git a/docs/commands.rst b/docs/commands.rst new file mode 100644 index 0000000..2d162f3 --- /dev/null +++ b/docs/commands.rst @@ -0,0 +1,10 @@ +Commands +======== + +The Makefile contains the central entry points for common tasks related to this project. + +Syncing data to S3 +^^^^^^^^^^^^^^^^^^ + +* `make sync_data_to_s3` will use `aws s3 sync` to recursively sync files in `data/` up to `s3://[OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')/data/`. +* `make sync_data_from_s3` will use `aws s3 sync` to recursively sync files from `s3://[OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')/data/` to `data/`. diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..e9f14c9 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- +# +# GanDetection documentation build configuration file, created by +# sphinx-quickstart. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import os +import sys + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# sys.path.insert(0, os.path.abspath('.')) + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [] + +# Add any paths that contain templates here, relative to this directory. 
+templates_path = ['_templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +# source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'GanDetection' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = '0.1' +# The full version, including alpha/beta/rc tags. +release = '0.1' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ['_build'] + +# The reST default role (used for this markup: `text`) to use for all documents. +# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# "<project> v<release> documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names.
+# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a <link> tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +# html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = 'gandetectiondoc' + + +# -- Options for LaTeX output -------------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # 'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ('index', + 'gandetection.tex', + u'GanDetection Documentation', + u"Amil Khan", 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. +# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'gandetection', u'GanDetection Documentation', + [u"Amil Khan"], 1) +] + +# If true, show URL addresses after external links. +# man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------------ + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ('index', 'gandetection', u'GanDetection Documentation', + u"Amil Khan", 'GanDetection', + 'Detecting GAN Generated images using Convolutional Neural Networks', 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +# texinfo_appendices = [] + +# If false, no module index is generated. +# texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'.
+# texinfo_show_urls = 'footnote'
diff --git a/docs/getting-started.rst b/docs/getting-started.rst new file mode 100644 index 0000000..b4f71c3 --- /dev/null +++ b/docs/getting-started.rst @@ -0,0 +1,6 @@ +Getting started +=============== + +This is where you describe how to get set up on a clean install, including the +commands necessary to get the raw data (using the `sync_data_from_s3` command, +for example), and then how to make the cleaned, final data sets.
diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..16f4986 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,24 @@ +.. GanDetection documentation master file, created by + sphinx-quickstart. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +GanDetection documentation! +============================================== + +Contents: + +.. toctree:: + :maxdepth: 2 + + getting-started + commands + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search`
diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..d5a6536 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,190 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^<target^>` where ^<target^> is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files.
+ goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\gandetection.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\gandetection.qhc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt.
+ goto end +) + +:end diff --git a/models/bninception.py b/models/bninception.py new file mode 100644 index 0000000..04371b3 --- /dev/null +++ b/models/bninception.py @@ -0,0 +1,515 @@ +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import os +import sys + +__all__ = ['BNInception', 'bninception'] + +pretrained_settings = { + 'bninception': { + 'imagenet': { + # Was ported using python2 (may trigger warning) + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth', + # 'url': 'http://yjxiong.me/others/bn_inception-9f5701afb96c8044.pth', + 'input_space': 'BGR', + 'input_size': [3, 224, 224], + 'input_range': [0, 255], + 'mean': [104, 117, 128], + 'std': [1, 1, 1], + 'num_classes': 1000 + } + } +} + +class BNInception(nn.Module): + + def __init__(self, num_classes=1000): + super(BNInception, self).__init__() + inplace = True + self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) + self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True) + self.conv1_relu_7x7 = nn.ReLU (inplace) + self.pool1_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) + self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.conv2_relu_3x3_reduce = nn.ReLU (inplace) + self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True) + self.conv2_relu_3x3 = nn.ReLU (inplace) + self.pool2_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3a_relu_1x1 = nn.ReLU (inplace) + self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3a_relu_3x3 = nn.ReLU (inplace) + self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True) + self.inception_3a_relu_pool_proj = nn.ReLU (inplace) + self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3b_relu_1x1 = nn.ReLU (inplace) + self.inception_3b_3x3_reduce = 
nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3b_relu_3x3 = nn.ReLU (inplace) + self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3b_relu_pool_proj = nn.ReLU (inplace) + self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) + self.inception_3c_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True) + self.inception_3c_relu_3x3 = nn.ReLU (inplace) + self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.inception_3c_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3c_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) + self.inception_3c_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3c_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True) + self.inception_4a_relu_1x1 = nn.ReLU (inplace) + self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) + self.inception_4a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True) + self.inception_4a_relu_3x3 = nn.ReLU (inplace) + self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) + self.inception_4a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + 
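+ # NOTE: the Caffe-style layer names above and below follow one repeating pattern: every Conv2d is paired with its own BatchNorm2d and ReLU, and each 'double_3x3' branch stacks two 3x3 convolutions to cover the receptive field of a single 5x5 conv, as in the batch-normalized Inception (Inception-BN) design.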
self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4a_relu_pool_proj = nn.ReLU (inplace) + self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True) + self.inception_4b_relu_1x1 = nn.ReLU (inplace) + self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) + self.inception_4b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4b_relu_3x3 = nn.ReLU (inplace) + self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) + self.inception_4b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4b_relu_pool_proj = nn.ReLU (inplace) + self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True) + self.inception_4c_relu_1x1 = nn.ReLU (inplace) + self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4c_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True) + self.inception_4c_relu_3x3 = nn.ReLU (inplace) + self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4c_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True) + self.inception_4c_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, 
affine=True) + self.inception_4c_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4c_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4c_relu_pool_proj = nn.ReLU (inplace) + self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True) + self.inception_4d_relu_1x1 = nn.ReLU (inplace) + self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4d_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True) + self.inception_4d_relu_3x3 = nn.ReLU (inplace) + self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) + self.inception_4d_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True) + self.inception_4d_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True) + self.inception_4d_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4d_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4d_relu_pool_proj = nn.ReLU (inplace) + self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) + self.inception_4e_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True) + self.inception_4e_relu_3x3 = nn.ReLU (inplace) + self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) + self.inception_4e_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True) + self.inception_4e_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True) + self.inception_4e_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4e_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True) + self.inception_5a_relu_1x1 = nn.ReLU (inplace) + self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_3x3_reduce_bn 
= nn.BatchNorm2d(192, affine=True) + self.inception_5a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True) + self.inception_5a_relu_3x3 = nn.ReLU (inplace) + self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) + self.inception_5a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) + self.inception_5a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) + self.inception_5a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_5a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) + self.inception_5a_relu_pool_proj = nn.ReLU (inplace) + self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True) + self.inception_5b_relu_1x1 = nn.ReLU (inplace) + self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) + self.inception_5b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True) + self.inception_5b_relu_3x3 = nn.ReLU (inplace) + self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) + self.inception_5b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) + self.inception_5b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) + self.inception_5b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_5b_pool = nn.MaxPool2d ((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) + self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) + self.inception_5b_relu_pool_proj = nn.ReLU (inplace) + self.global_pool = nn.AvgPool2d (7, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + self.last_linear = nn.Linear (1024, num_classes) + + def features(self, input): + conv1_7x7_s2_out = self.conv1_7x7_s2(input) + conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) + conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) + pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out) + conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) + conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) + conv2_relu_3x3_reduce_out = 
self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) + conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out) + conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) + conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) + pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out) + inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) + inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) + inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) + inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) + inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) + inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) + inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out) + inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) + inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) + inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) + inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(inception_3a_double_3x3_reduce_out) + inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(inception_3a_double_3x3_reduce_bn_out) + inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out) + inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) + inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) + inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out) + inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) + inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) + inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) + inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) + inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) + inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) + inception_3a_output_out = torch.cat([inception_3a_relu_1x1_out,inception_3a_relu_3x3_out,inception_3a_relu_double_3x3_2_out ,inception_3a_relu_pool_proj_out], 1) + inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) + inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) + inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) + inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) + inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) + inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) + inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out) + inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) + inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) + inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) + inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(inception_3b_double_3x3_reduce_out) + inception_3b_relu_double_3x3_reduce_out = 
self.inception_3b_relu_double_3x3_reduce(inception_3b_double_3x3_reduce_bn_out) + inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out) + inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) + inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) + inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out) + inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) + inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) + inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) + inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) + inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) + inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) + inception_3b_output_out = torch.cat([inception_3b_relu_1x1_out,inception_3b_relu_3x3_out,inception_3b_relu_double_3x3_2_out,inception_3b_relu_pool_proj_out], 1) + inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) + inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) + inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) + inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out) + inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) + inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) + inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) + inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(inception_3c_double_3x3_reduce_out) + inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(inception_3c_double_3x3_reduce_bn_out) + inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out) + inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) + inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) + inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out) + inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) + inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) + inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) + inception_3c_output_out = torch.cat([inception_3c_relu_3x3_out,inception_3c_relu_double_3x3_2_out,inception_3c_pool_out], 1) + inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) + inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) + inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) + inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) + inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) + inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) + inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out) + inception_4a_3x3_bn_out = 
self.inception_4a_3x3_bn(inception_4a_3x3_out) + inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) + inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) + inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(inception_4a_double_3x3_reduce_out) + inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(inception_4a_double_3x3_reduce_bn_out) + inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out) + inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) + inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) + inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out) + inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) + inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) + inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) + inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) + inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) + inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) + inception_4a_output_out = torch.cat([inception_4a_relu_1x1_out,inception_4a_relu_3x3_out,inception_4a_relu_double_3x3_2_out,inception_4a_relu_pool_proj_out], 1) + inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) + inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) + inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) + inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) + inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) + inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) + inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out) + inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) + inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) + inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) + inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(inception_4b_double_3x3_reduce_out) + inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(inception_4b_double_3x3_reduce_bn_out) + inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out) + inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) + inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) + inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out) + inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) + inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) + inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) + inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) + inception_4b_pool_proj_bn_out = 
self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) + inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) + inception_4b_output_out = torch.cat([inception_4b_relu_1x1_out,inception_4b_relu_3x3_out,inception_4b_relu_double_3x3_2_out,inception_4b_relu_pool_proj_out], 1) + inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) + inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) + inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) + inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) + inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) + inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) + inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out) + inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) + inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) + inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) + inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(inception_4c_double_3x3_reduce_out) + inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(inception_4c_double_3x3_reduce_bn_out) + inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out) + inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) + inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) + inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out) + inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) + inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) + inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) + inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) + inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) + inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) + inception_4c_output_out = torch.cat([inception_4c_relu_1x1_out,inception_4c_relu_3x3_out,inception_4c_relu_double_3x3_2_out,inception_4c_relu_pool_proj_out], 1) + inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) + inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) + inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) + inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) + inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) + inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) + inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out) + inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) + inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) + inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) + inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(inception_4d_double_3x3_reduce_out) + 
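+ # The "double_3x3" branches stack two 3x3 convolutions after a 1x1 reduction,
+ # covering a 5x5 receptive field with fewer parameters than a single 5x5
+ # kernel; every convolution here is followed by BatchNorm and then ReLU.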
inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(inception_4d_double_3x3_reduce_bn_out) + inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out) + inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) + inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) + inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out) + inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) + inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) + inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) + inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) + inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) + inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) + inception_4d_output_out = torch.cat([inception_4d_relu_1x1_out,inception_4d_relu_3x3_out,inception_4d_relu_double_3x3_2_out,inception_4d_relu_pool_proj_out], 1) + inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) + inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) + inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) + inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out) + inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) + inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) + inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) + inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(inception_4e_double_3x3_reduce_out) + inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(inception_4e_double_3x3_reduce_bn_out) + inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out) + inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) + inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) + inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out) + inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) + inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) + inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) + inception_4e_output_out = torch.cat([inception_4e_relu_3x3_out,inception_4e_relu_double_3x3_2_out,inception_4e_pool_out], 1) + inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) + inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) + inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) + inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) + inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) + inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) + inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out) + 
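+ # Each inception block above unrolls the same generated pattern: every branch
+ # is Conv2d -> BatchNorm2d -> ReLU, and the block output is the channel-wise
+ # concatenation of the branch outputs. A hand-written branch would reduce to
+ # (hypothetical helper, for illustration only):
+ #
+ #     def branch(conv, bn, relu, x):
+ #         return relu(bn(conv(x)))
+ #
+ # The grid-reduction blocks (3c and 4e above) have no 1x1 branch and
+ # concatenate the pooled input directly, trading spatial size for channels.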
inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) + inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) + inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) + inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(inception_5a_double_3x3_reduce_out) + inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(inception_5a_double_3x3_reduce_bn_out) + inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out) + inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) + inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) + inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out) + inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) + inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) + inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) + inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) + inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) + inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) + inception_5a_output_out = torch.cat([inception_5a_relu_1x1_out,inception_5a_relu_3x3_out,inception_5a_relu_double_3x3_2_out,inception_5a_relu_pool_proj_out], 1) + inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) + inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) + inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) + inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out) + inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) + inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) + inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out) + inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) + inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) + inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) + inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(inception_5b_double_3x3_reduce_out) + inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(inception_5b_double_3x3_reduce_bn_out) + inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out) + inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) + inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) + inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out) + inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) + inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) + inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) + inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) + inception_5b_pool_proj_bn_out = 
self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) + inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) + inception_5b_output_out = torch.cat([inception_5b_relu_1x1_out,inception_5b_relu_3x3_out,inception_5b_relu_double_3x3_2_out,inception_5b_relu_pool_proj_out], 1) + return inception_5b_output_out + + def logits(self, features): + x = self.global_pool(features) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + +def bninception(num_classes=1000, pretrained='imagenet'): + r"""BNInception model architecture from the `"Batch Normalization" <https://arxiv.org/abs/1502.03167>`_ paper. + """ + model = BNInception(num_classes=num_classes) + if pretrained is not None: + settings = pretrained_settings['bninception'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + model.load_state_dict(model_zoo.load_url(settings['url'])) + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + return model + + +if __name__ == '__main__': + + model = bninception() diff --git a/models/inceptionresnetv2.py b/models/inceptionresnetv2.py new file mode 100644 index 0000000..8f55bb0 --- /dev/null +++ b/models/inceptionresnetv2.py @@ -0,0 +1,380 @@ +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import os +import sys + +__all__ = ['InceptionResNetV2', 'inceptionresnetv2'] + +pretrained_settings = { + 'inceptionresnetv2': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + 'imagenet+background': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class BasicConv2d(nn.Module): + + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_planes, out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, bias=False) # verify bias false + self.bn = nn.BatchNorm2d(out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = 
self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + + def __init__(self): + super(Mixed_7a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, noReLU=False): + super(Block8, self).__init__() + + self.scale = scale + self.noReLU = noReLU + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), + BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + if not 
self.noReLU: + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if not self.noReLU: + out = self.relu(out) + return out + + +class InceptionResNetV2(nn.Module): + + def __init__(self, num_classes=1001): + super(InceptionResNetV2, self).__init__() + # Special attributes + self.input_space = None + self.input_size = (299, 299, 3) + self.mean = None + self.std = None + # Modules + self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17) + ) + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10) + ) + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20) + ) + self.block8 = Block8(noReLU=True) + self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) + self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False) + self.last_linear = nn.Linear(1536, num_classes) + + def features(self, input): + x = self.conv2d_1a(input) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def logits(self, features): + x = self.avgpool_1a(features) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + +def inceptionresnetv2(num_classes=1000, pretrained='imagenet'): + r"""InceptionResNetV2 model architecture from the + `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper. 
+ """ + if pretrained: + settings = pretrained_settings['inceptionresnetv2'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + + # both 'imagenet'&'imagenet+background' are loaded from same parameters + model = InceptionResNetV2(num_classes=1001) + model.load_state_dict(model_zoo.load_url(settings['url'])) + + if pretrained == 'imagenet': + new_last_linear = nn.Linear(1536, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear + + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + + model.mean = settings['mean'] + model.std = settings['std'] + else: + model = InceptionResNetV2(num_classes=num_classes) + return model + +''' +TEST +Run this code with: +``` +cd $HOME/pretrained-models.pytorch +python -m pretrainedmodels.inceptionresnetv2 +``` +''' +if __name__ == '__main__': + + assert inceptionresnetv2(num_classes=10, pretrained=None) + print('success') + assert inceptionresnetv2(num_classes=1000, pretrained='imagenet') + print('success') + assert inceptionresnetv2(num_classes=1001, pretrained='imagenet+background') + print('success') + + # fail + assert inceptionresnetv2(num_classes=1001, pretrained='imagenet') \ No newline at end of file diff --git a/models/inceptionv4.py b/models/inceptionv4.py new file mode 100644 index 0000000..d77bd2f --- /dev/null +++ b/models/inceptionv4.py @@ -0,0 +1,356 @@ +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import os +import sys + +__all__ = ['InceptionV4', 'inceptionv4'] + +pretrained_settings = { + 'inceptionv4': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + 'imagenet+background': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class BasicConv2d(nn.Module): + + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_planes, out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, bias=False) # verify bias false + self.bn = nn.BatchNorm2d(out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_3a(nn.Module): + + def __init__(self): + super(Mixed_3a, self).__init__() + self.maxpool = nn.MaxPool2d(3, stride=2) + self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) + + def forward(self, x): + x0 = self.maxpool(x) + x1 = self.conv(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed_4a(nn.Module): + + def __init__(self): + super(Mixed_4a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(160, 64, 
kernel_size=1, stride=1), + BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(64, 96, kernel_size=(3,3), stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed_5a(nn.Module): + + def __init__(self): + super(Mixed_5a, self).__init__() + self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) + self.maxpool = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.conv(x) + x1 = self.maxpool(x) + out = torch.cat((x0, x1), 1) + return out + + +class Inception_A(nn.Module): + + def __init__(self): + super(Inception_A, self).__init__() + self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(384, 96, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Reduction_A(nn.Module): + + def __init__(self): + super(Reduction_A, self).__init__() + self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), + BasicConv2d(224, 256, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Inception_B(nn.Module): + + def __init__(self): + super(Inception_B, self).__init__() + self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0)) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)), + BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3)) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1024, 128, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Reduction_B(nn.Module): + + def __init__(self): + super(Reduction_B, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)), + 
BasicConv2d(320, 320, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Inception_C(nn.Module): + + def __init__(self): + super(Inception_C, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1)) + self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0)) + self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1)) + self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1)) + self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class InceptionV4(nn.Module): + + def __init__(self, num_classes=1001): + super(InceptionV4, self).__init__() + # Special attributes + self.input_space = None + self.input_size = (299, 299, 3) + self.mean = None + self.std = None + # Modules + self.features = nn.Sequential( + BasicConv2d(3, 32, kernel_size=3, stride=2), + BasicConv2d(32, 32, kernel_size=3, stride=1), + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), + Mixed_3a(), + Mixed_4a(), + Mixed_5a(), + Inception_A(), + Inception_A(), + Inception_A(), + Inception_A(), + Reduction_A(), # Mixed_6a + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Reduction_B(), # Mixed_7a + Inception_C(), + Inception_C(), + Inception_C() + ) + self.avg_pool = nn.AvgPool2d(8, count_include_pad=False) + self.last_linear = nn.Linear(1536, num_classes) + + def logits(self, features): + x = self.avg_pool(features) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +def inceptionv4(num_classes=1000, pretrained='imagenet'): + if pretrained: + settings = pretrained_settings['inceptionv4'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + + # both 'imagenet'&'imagenet+background' are loaded from same parameters + model = InceptionV4(num_classes=1001) + model.load_state_dict(model_zoo.load_url(settings['url'])) + + if pretrained == 'imagenet': + new_last_linear = nn.Linear(1536, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear + + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + 
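+ # These attributes describe the preprocessing the pretrained weights expect
+ # (RGB input, 299x299, values scaled to [0, 1], then normalised with
+ # mean/std 0.5). A matching torchvision pipeline could look like this
+ # sketch (assumes torchvision.transforms is available; sizes taken from
+ # input_size rather than hard-coded):
+ #
+ #     transforms.Compose([
+ #         transforms.Resize(model.input_size[1]),
+ #         transforms.CenterCrop(model.input_size[1]),
+ #         transforms.ToTensor(),              # produces the [0, 1] input_range
+ #         transforms.Normalize(mean=model.mean, std=model.std),
+ #     ])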
model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + else: + model = InceptionV4(num_classes=num_classes) + return model + + +''' +TEST +Run this code with: +``` +cd $HOME/pretrained-models.pytorch +python -m pretrainedmodels.inceptionv4 +``` +''' +if __name__ == '__main__': + + assert inceptionv4(num_classes=10, pretrained=None) + print('success') + assert inceptionv4(num_classes=1000, pretrained='imagenet') + print('success') + assert inceptionv4(num_classes=1001, pretrained='imagenet+background') + print('success') + + # fail + assert inceptionv4(num_classes=1001, pretrained='imagenet') \ No newline at end of file diff --git a/models/nasnet.py b/models/nasnet.py new file mode 100644 index 0000000..4ec3b39 --- /dev/null +++ b/models/nasnet.py @@ -0,0 +1,646 @@ +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from torch.autograd import Variable + +pretrained_settings = { + 'nasnetalarge': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + 'input_space': 'RGB', + 'input_size': [3, 331, 331], # resize 354 + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + 'imagenet+background': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + 'input_space': 'RGB', + 'input_size': [3, 331, 331], # resize 354 + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class MaxPoolPad(nn.Module): + + def __init__(self): + super(MaxPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:] + return x + + +class AvgPoolPad(nn.Module): + + def __init__(self, stride=2, padding=1): + super(AvgPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:] + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, dw_kernel, + stride=dw_stride, + padding=dw_padding, + bias=bias, + groups=in_channels) + self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparables, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias) + self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return 
x + + +class BranchSeparablesStem(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparablesStem, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesReduction(BranchSeparables): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False): + BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias) + self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0)) + + def forward(self, x): + x = self.relu(x) + x = self.padding(x) + x = self.separable_1(x) + x = x[:, :, 1:, 1:].contiguous() + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + def __init__(self, stem_filters, num_filters=42): + super(CellStem0, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2) + self.comb_iter_0_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return 
x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_filters, num_filters): + super(CellStem1, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_filters, self.num_filters, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True)) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False)) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True) + + self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.relu(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class FirstCell(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(FirstCell, self).__init__() + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', 
nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + def forward(self, x, x_prev): + x_relu = self.relu(x_prev) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NormalCell(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', 
nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + + self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False) + self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + 
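+ # Each "comb_iter" pairs two branch outputs and sums them element-wise; the
+ # input pairs and operations were selected by the NASNet architecture
+ # search. The cell output then concatenates the chosen intermediate states
+ # along the channel dimension.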
x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NASNetALarge(nn.Module): + """NASNetALarge (6 @ 4032) """ 
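+ # "6 @ 4032" follows the paper's naming scheme: 6 normal cells per stack and
+ # 4032 penultimate filters, so the base cell width is 4032 // 24 = 168.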
+ + def __init__(self, num_classes=1001, stem_filters=96, penultimate_filters=4032, filters_multiplier=2): + super(NASNetALarge, self).__init__() + self.num_classes = num_classes + self.stem_filters = stem_filters + self.penultimate_filters = penultimate_filters + self.filters_multiplier = filters_multiplier + + filters = self.penultimate_filters // 24 + # 24 is default value for the architecture + + self.conv0 = nn.Sequential() + self.conv0.add_module('conv', nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3, padding=0, stride=2, + bias=False)) + self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_filters, eps=0.001, momentum=0.1, affine=True)) + + self.cell_stem_0 = CellStem0(self.stem_filters, num_filters=filters // (filters_multiplier ** 2)) + self.cell_stem_1 = CellStem1(self.stem_filters, num_filters=filters // filters_multiplier) + + self.cell_0 = FirstCell(in_channels_left=filters, out_channels_left=filters//2, + in_channels_right=2*filters, out_channels_right=filters) + self.cell_1 = NormalCell(in_channels_left=2*filters, out_channels_left=filters, + in_channels_right=6*filters, out_channels_right=filters) + self.cell_2 = NormalCell(in_channels_left=6*filters, out_channels_left=filters, + in_channels_right=6*filters, out_channels_right=filters) + self.cell_3 = NormalCell(in_channels_left=6*filters, out_channels_left=filters, + in_channels_right=6*filters, out_channels_right=filters) + self.cell_4 = NormalCell(in_channels_left=6*filters, out_channels_left=filters, + in_channels_right=6*filters, out_channels_right=filters) + self.cell_5 = NormalCell(in_channels_left=6*filters, out_channels_left=filters, + in_channels_right=6*filters, out_channels_right=filters) + + self.reduction_cell_0 = ReductionCell0(in_channels_left=6*filters, out_channels_left=2*filters, + in_channels_right=6*filters, out_channels_right=2*filters) + + self.cell_6 = FirstCell(in_channels_left=6*filters, out_channels_left=filters, + in_channels_right=8*filters, out_channels_right=2*filters) + self.cell_7 = NormalCell(in_channels_left=8*filters, out_channels_left=2*filters, + in_channels_right=12*filters, out_channels_right=2*filters) + self.cell_8 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters, + in_channels_right=12*filters, out_channels_right=2*filters) + self.cell_9 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters, + in_channels_right=12*filters, out_channels_right=2*filters) + self.cell_10 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters, + in_channels_right=12*filters, out_channels_right=2*filters) + self.cell_11 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters, + in_channels_right=12*filters, out_channels_right=2*filters) + + self.reduction_cell_1 = ReductionCell1(in_channels_left=12*filters, out_channels_left=4*filters, + in_channels_right=12*filters, out_channels_right=4*filters) + + self.cell_12 = FirstCell(in_channels_left=12*filters, out_channels_left=2*filters, + in_channels_right=16*filters, out_channels_right=4*filters) + self.cell_13 = NormalCell(in_channels_left=16*filters, out_channels_left=4*filters, + in_channels_right=24*filters, out_channels_right=4*filters) + self.cell_14 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters, + in_channels_right=24*filters, out_channels_right=4*filters) + self.cell_15 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters, + in_channels_right=24*filters, out_channels_right=4*filters) + self.cell_16 = 
NormalCell(in_channels_left=24*filters, out_channels_left=4*filters, + in_channels_right=24*filters, out_channels_right=4*filters) + self.cell_17 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters, + in_channels_right=24*filters, out_channels_right=4*filters) + + self.relu = nn.ReLU() + self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) + self.dropout = nn.Dropout() + self.last_linear = nn.Linear(24*filters, self.num_classes) + + def features(self, input): + x_conv0 = self.conv0(input) + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + x_cell_4 = self.cell_4(x_cell_3, x_cell_2) + x_cell_5 = self.cell_5(x_cell_4, x_cell_3) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) + + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + x_cell_10 = self.cell_10(x_cell_9, x_cell_8) + x_cell_11 = self.cell_11(x_cell_10, x_cell_9) + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) + + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + x_cell_16 = self.cell_16(x_cell_15, x_cell_14) + x_cell_17 = self.cell_17(x_cell_16, x_cell_15) + return x_cell_17 + + def logits(self, features): + x = self.relu(features) + x = self.avg_pool(x) + x = x.view(x.size(0), -1) + x = self.dropout(x) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +def nasnetalarge(num_classes=1001, pretrained='imagenet'): + r"""NASNetALarge model architecture from the + `"NASNet" <https://arxiv.org/abs/1707.07012>`_ paper. + """ + if pretrained: + settings = pretrained_settings['nasnetalarge'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + + # both 'imagenet'&'imagenet+background' are loaded from same parameters + model = NASNetALarge(num_classes=1001) + model.load_state_dict(model_zoo.load_url(settings['url'])) + + if pretrained == 'imagenet': + new_last_linear = nn.Linear(model.last_linear.in_features, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear + + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + + model.mean = settings['mean'] + model.std = settings['std'] + else: + model = NASNetALarge(num_classes=num_classes) + return model + + +if __name__ == "__main__": + + model = NASNetALarge() + input = Variable(torch.randn(2, 3, 331, 331)) + + output = model(input) + print(output.size()) + + diff --git a/models/nasnet_mobile.py b/models/nasnet_mobile.py new file mode 100644 index 0000000..4146822 --- /dev/null +++ b/models/nasnet_mobile.py @@ -0,0 +1,661 @@ +""" +NASNet Mobile +Thanks to Anastasiia (https://github.com/DagnyT) for the great help, support and motivation! 
+ + +------------------------------------------------------------------------------------ + Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M) +------------------------------------------------------------------------------------ +| NASNet-A (4 @ 1056) | 74.08% | 91.74% | 564 M | 5.3 | +------------------------------------------------------------------------------------ +# References: + - [Learning Transferable Architectures for Scalable Image Recognition] + (https://arxiv.org/abs/1707.07012) +""" +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from torch.autograd import Variable +import numpy as np + +pretrained_settings = { + 'nasnetamobile': { + 'imagenet': { + #'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar', + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], # resize 256 + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + # 'imagenet+background': { + # # 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + # 'input_space': 'RGB', + # 'input_size': [3, 224, 224], # resize 256 + # 'input_range': [0, 1], + # 'mean': [0.5, 0.5, 0.5], + # 'std': [0.5, 0.5, 0.5], + # 'num_classes': 1001 + # } + } +} + + +class MaxPoolPad(nn.Module): + + def __init__(self): + super(MaxPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:].contiguous() + return x + + +class AvgPoolPad(nn.Module): + + def __init__(self, stride=2, padding=1): + super(AvgPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:].contiguous() + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, dw_kernel, + stride=dw_stride, + padding=dw_padding, + bias=bias, + groups=in_channels) + self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, name=None, bias=False): + super(BranchSeparables, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias) + self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + self.name = name + + def forward(self, x): + x = self.relu(x) + if self.name == 'specific': + x = nn.ZeroPad2d((1, 0, 1, 0))(x) + x = self.separable_1(x) + if self.name == 'specific': + x = x[:, :, 1:, 1:].contiguous() + + x = self.bn_sep_1(x) + x = self.relu1(x) + x = 
self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesStem(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparablesStem, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesReduction(BranchSeparables): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False): + BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias) + self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0)) + + def forward(self, x): + x = self.relu(x) + x = self.padding(x) + x = self.separable_1(x) + x = x[:, :, 1:, 1:].contiguous() + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + def __init__(self, stem_filters, num_filters=42): + super(CellStem0, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2) + self.comb_iter_0_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, 
x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_filters, num_filters): + super(CellStem1, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_filters, self.num_filters, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True)) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False)) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True) + + self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, name='specific', bias=False) + self.comb_iter_0_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, name='specific', bias=False) + + # self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, name='specific', bias=False) + + # self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, name='specific', bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, name='specific', bias=False) + # self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.relu(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class FirstCell(nn.Module): + + def 
__init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(FirstCell, self).__init__() + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + def forward(self, x, x_prev): + x_relu = self.relu(x_prev) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NormalCell(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, 
affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + + self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False) + self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, 
out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, name='specific', bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, name='specific', bias=False) + + # self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, name='specific', bias=False) + + # self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, name='specific', bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, name='specific', bias=False) + # self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_right =MaxPoolPad() + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + 
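# Added note: the last two pairings below reuse x_comb_iter_0; its 3x3 average
+ # pool is summed with x_comb_iter_1, and its 3x3 separable conv is summed with
+ # a max-pooled copy of x_right, before iters 1 through 4 are concatenated
+ # along the channel dimension.
+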
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NASNetAMobile(nn.Module): + """NASNetAMobile (4 @ 1056) """ + + def __init__(self, num_classes=1001, stem_filters=32, penultimate_filters=1056, filters_multiplier=2): + super(NASNetAMobile, self).__init__() + self.num_classes = num_classes + self.stem_filters = stem_filters + self.penultimate_filters = penultimate_filters + self.filters_multiplier = filters_multiplier + + filters = self.penultimate_filters // 24 + # 24 is default value for the architecture + + self.conv0 = nn.Sequential() + self.conv0.add_module('conv', nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3, padding=0, stride=2, + bias=False)) + self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_filters, eps=0.001, momentum=0.1, affine=True)) + + self.cell_stem_0 = CellStem0(self.stem_filters, num_filters=filters // (filters_multiplier ** 2)) + self.cell_stem_1 = CellStem1(self.stem_filters, num_filters=filters // filters_multiplier) + + self.cell_0 = FirstCell(in_channels_left=filters, out_channels_left=filters//2, # 1, 0.5 + in_channels_right=2*filters, out_channels_right=filters) # 2, 1 + self.cell_1 = NormalCell(in_channels_left=2*filters, out_channels_left=filters, # 2, 1 + in_channels_right=6*filters, out_channels_right=filters) # 6, 1 + self.cell_2 = NormalCell(in_channels_left=6*filters, out_channels_left=filters, # 6, 1 + in_channels_right=6*filters, out_channels_right=filters) # 6, 1 + self.cell_3 = NormalCell(in_channels_left=6*filters, out_channels_left=filters, # 6, 1 + in_channels_right=6*filters, out_channels_right=filters) # 6, 1 + + self.reduction_cell_0 = ReductionCell0(in_channels_left=6*filters, out_channels_left=2*filters, # 6, 2 + in_channels_right=6*filters, out_channels_right=2*filters) # 6, 2 + + self.cell_6 = FirstCell(in_channels_left=6*filters, out_channels_left=filters, # 6, 1 + in_channels_right=8*filters, out_channels_right=2*filters) # 8, 2 + self.cell_7 = NormalCell(in_channels_left=8*filters, out_channels_left=2*filters, # 8, 2 + in_channels_right=12*filters, out_channels_right=2*filters) # 12, 2 + self.cell_8 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters, # 12, 2 + in_channels_right=12*filters, out_channels_right=2*filters) # 12, 2 + self.cell_9 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters, # 12, 2 + in_channels_right=12*filters, out_channels_right=2*filters) # 12, 2 + + self.reduction_cell_1 = ReductionCell1(in_channels_left=12*filters, out_channels_left=4*filters, # 12, 4 + in_channels_right=12*filters, out_channels_right=4*filters) # 12, 4 + + self.cell_12 = FirstCell(in_channels_left=12*filters, out_channels_left=2*filters, # 12, 2 + in_channels_right=16*filters, out_channels_right=4*filters) # 16, 4 + self.cell_13 = NormalCell(in_channels_left=16*filters, out_channels_left=4*filters, # 16, 4 + in_channels_right=24*filters, out_channels_right=4*filters) # 24, 4 + self.cell_14 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters, # 24, 4 + in_channels_right=24*filters, out_channels_right=4*filters) # 24, 4 + self.cell_15 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters, # 24, 4 + 
in_channels_right=24*filters, out_channels_right=4*filters) # 24, 4
+
+ self.relu = nn.ReLU()
+ self.avg_pool = nn.AvgPool2d(7, stride=1, padding=0)
+ self.dropout = nn.Dropout()
+ self.last_linear = nn.Linear(24*filters, self.num_classes)
+
+ def features(self, input):
+ x_conv0 = self.conv0(input)
+ x_stem_0 = self.cell_stem_0(x_conv0)
+ x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
+
+ x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
+ x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
+ x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
+ x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
+
+ x_reduction_cell_0 = self.reduction_cell_0(x_cell_3, x_cell_2)
+
+ x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_3)
+ x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
+ x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
+ x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
+
+ x_reduction_cell_1 = self.reduction_cell_1(x_cell_9, x_cell_8)
+
+ x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_9)
+ x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
+ x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
+ x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
+ return x_cell_15
+
+ def logits(self, features):
+ x = self.relu(features)
+ x = self.avg_pool(x)
+ x = x.view(x.size(0), -1)
+ x = self.dropout(x)
+ x = self.last_linear(x)
+ return x
+
+ def forward(self, input):
+ x = self.features(input)
+ x = self.logits(x)
+ return x
+
+
+def nasnetamobile(num_classes=1001, pretrained='imagenet'):
+ r"""NASNetAMobile model architecture from the
+ `"NASNet" <https://arxiv.org/abs/1707.07012>`_ paper.
+ """
+ if pretrained:
+ settings = pretrained_settings['nasnetamobile'][pretrained]
+ assert num_classes == settings['num_classes'], \
+ "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
+
+ # both 'imagenet' and 'imagenet+background' are loaded from the same parameters
+ model = NASNetAMobile(num_classes=num_classes)
+ model.load_state_dict(model_zoo.load_url(settings['url'], map_location=None))
+
+ # if pretrained == 'imagenet':
+ # new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
+ # new_last_linear.weight.data = model.last_linear.weight.data[1:]
+ # new_last_linear.bias.data = model.last_linear.bias.data[1:]
+ # model.last_linear = new_last_linear
+
+ model.input_space = settings['input_space']
+ model.input_size = settings['input_size']
+ model.input_range = settings['input_range']
+
+ model.mean = settings['mean']
+ model.std = settings['std']
+ else:
+ settings = pretrained_settings['nasnetamobile']['imagenet']
+ model = NASNetAMobile(num_classes=num_classes)
+ model.input_space = settings['input_space']
+ model.input_size = settings['input_size']
+ model.input_range = settings['input_range']
+
+ model.mean = settings['mean']
+ model.std = settings['std']
+ return model
+
+
+if __name__ == "__main__":
+
+ model = NASNetAMobile()
+ input = Variable(torch.randn(2, 3, 224, 224))
+ output = model(input)
+
+ print(output.size()) diff --git a/models/pnasnet.py b/models/pnasnet.py new file mode 100644 index 0000000..c169c69 --- /dev/null +++ b/models/pnasnet.py @@ -0,0 +1,401 @@ +from __future__ import print_function, division, absolute_import
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+
+pretrained_settings = {
+ 'pnasnet5large': {
+ 'imagenet': {
+ 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
+ 'input_space': 'RGB',
+ 'input_size': [3, 331, 331],
+ 'input_range': [0, 1],
+ 'mean': [0.5, 0.5, 0.5],
+ 'std': [0.5, 0.5, 0.5],
+ 
'num_classes': 1000 + }, + 'imagenet+background': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', + 'input_space': 'RGB', + 'input_size': [3, 331, 331], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class MaxPool(nn.Module): + + def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False): + super(MaxPool, self).__init__() + self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None + self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) + + def forward(self, x): + if self.zero_pad: + x = self.zero_pad(x) + x = self.pool(x) + if self.zero_pad: + x = x[:, :, 1:, 1:] + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride, + dw_padding): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, + kernel_size=dw_kernel_size, + stride=dw_stride, padding=dw_padding, + groups=in_channels, bias=False) + self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, + kernel_size=1, bias=False) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + stem_cell=False, zero_pad=False): + super(BranchSeparables, self).__init__() + padding = kernel_size // 2 + middle_channels = out_channels if stem_cell else in_channels + self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None + self.relu_1 = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, middle_channels, + kernel_size, dw_stride=stride, + dw_padding=padding) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) + self.relu_2 = nn.ReLU() + self.separable_2 = SeparableConv2d(middle_channels, out_channels, + kernel_size, dw_stride=1, + dw_padding=padding) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.relu_1(x) + if self.zero_pad: + x = self.zero_pad(x) + x = self.separable_1(x) + if self.zero_pad: + x = x[:, :, 1:, 1:].contiguous() + x = self.bn_sep_1(x) + x = self.relu_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class ReluConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1): + super(ReluConvBn, self).__init__() + self.relu = nn.ReLU() + self.conv = nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + bias=False) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.relu(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class FactorizedReduction(nn.Module): + + def __init__(self, in_channels, out_channels): + super(FactorizedReduction, self).__init__() + self.relu = nn.ReLU() + self.path_1 = nn.Sequential(OrderedDict([ + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', nn.Conv2d(in_channels, out_channels // 2, + kernel_size=1, bias=False)), + ])) + self.path_2 = nn.Sequential(OrderedDict([ + ('pad', nn.ZeroPad2d((0, 1, 0, 1))), + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', nn.Conv2d(in_channels, out_channels // 2, + kernel_size=1, bias=False)), + ])) + self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.relu(x) + + x_path1 = self.path_1(x) + + x_path2 = self.path_2.pad(x) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = 
self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + + out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + return out + + +class CellBase(nn.Module): + + def cell_forward(self, x_left, x_right): + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) + x_comb_iter_3_right = self.comb_iter_3_right(x_right) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_left) + if self.comb_iter_4_right: + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + else: + x_comb_iter_4_right = x_right + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat( + [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, + x_comb_iter_4], 1) + return x_out + + +class CellStem0(CellBase): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, + out_channels_right): + super(CellStem0, self).__init__() + self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, + kernel_size=1) + self.comb_iter_0_left = BranchSeparables(in_channels_left, + out_channels_left, + kernel_size=5, stride=2, + stem_cell=True) + self.comb_iter_0_right = nn.Sequential(OrderedDict([ + ('max_pool', MaxPool(3, stride=2)), + ('conv', nn.Conv2d(in_channels_left, out_channels_left, + kernel_size=1, bias=False)), + ('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)), + ])) + self.comb_iter_1_left = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=7, stride=2) + self.comb_iter_1_right = MaxPool(3, stride=2) + self.comb_iter_2_left = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=5, stride=2) + self.comb_iter_2_right = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=3, stride=2) + self.comb_iter_3_left = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=3) + self.comb_iter_3_right = MaxPool(3, stride=2) + self.comb_iter_4_left = BranchSeparables(in_channels_right, + out_channels_right, + kernel_size=3, stride=2, + stem_cell=True) + self.comb_iter_4_right = ReluConvBn(out_channels_right, + out_channels_right, + kernel_size=1, stride=2) + + def forward(self, x_left): + x_right = self.conv_1x1(x_left) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class Cell(CellBase): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, + out_channels_right, is_reduction=False, zero_pad=False, + match_prev_layer_dimensions=False): + super(Cell, self).__init__() + + # If `is_reduction` is set to `True` stride 2 is used for + # convolutional and pooling layers to reduce the spatial size of + # the output of a cell approximately by a factor of 2. + stride = 2 if is_reduction else 1 + + # If `match_prev_layer_dimensions` is set to `True` + # `FactorizedReduction` is used to reduce the spatial size + # of the left input of a cell approximately by a factor of 2. 
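+ # For example, `cell_stem_1` later in `PNASNet5Large` is built with both
+ # flags set: its left input (`x_conv_0`) has twice the spatial size of its
+ # right input, so FactorizedReduction (two stride-2 avg-pool paths whose
+ # outputs are concatenated) brings the two inputs into agreement.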
+ self.match_prev_layer_dimensions = match_prev_layer_dimensions + if match_prev_layer_dimensions: + self.conv_prev_1x1 = FactorizedReduction(in_channels_left, + out_channels_left) + else: + self.conv_prev_1x1 = ReluConvBn(in_channels_left, + out_channels_left, kernel_size=1) + + self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, + kernel_size=1) + self.comb_iter_0_left = BranchSeparables(out_channels_left, + out_channels_left, + kernel_size=5, stride=stride, + zero_pad=zero_pad) + self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad) + self.comb_iter_1_left = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=7, stride=stride, + zero_pad=zero_pad) + self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad) + self.comb_iter_2_left = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=5, stride=stride, + zero_pad=zero_pad) + self.comb_iter_2_right = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=3, stride=stride, + zero_pad=zero_pad) + self.comb_iter_3_left = BranchSeparables(out_channels_right, + out_channels_right, + kernel_size=3) + self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad) + self.comb_iter_4_left = BranchSeparables(out_channels_left, + out_channels_left, + kernel_size=3, stride=stride, + zero_pad=zero_pad) + if is_reduction: + self.comb_iter_4_right = ReluConvBn(out_channels_right, + out_channels_right, + kernel_size=1, stride=stride) + else: + self.comb_iter_4_right = None + + def forward(self, x_left, x_right): + x_left = self.conv_prev_1x1(x_left) + x_right = self.conv_1x1(x_right) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class PNASNet5Large(nn.Module): + def __init__(self, num_classes=1001): + super(PNASNet5Large, self).__init__() + self.num_classes = num_classes + self.conv_0 = nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)), + ('bn', nn.BatchNorm2d(96, eps=0.001)) + ])) + self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, + in_channels_right=96, + out_channels_right=54) + self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108, + in_channels_right=270, out_channels_right=108, + match_prev_layer_dimensions=True, + is_reduction=True) + self.cell_0 = Cell(in_channels_left=270, out_channels_left=216, + in_channels_right=540, out_channels_right=216, + match_prev_layer_dimensions=True) + self.cell_1 = Cell(in_channels_left=540, out_channels_left=216, + in_channels_right=1080, out_channels_right=216) + self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216, + in_channels_right=1080, out_channels_right=216) + self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216, + in_channels_right=1080, out_channels_right=216) + self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432, + in_channels_right=1080, out_channels_right=432, + is_reduction=True, zero_pad=True) + self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432, + in_channels_right=2160, out_channels_right=432, + match_prev_layer_dimensions=True) + self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432, + in_channels_right=2160, out_channels_right=432) + self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432, + in_channels_right=2160, out_channels_right=432) + self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864, + in_channels_right=2160, out_channels_right=864, + is_reduction=True) + self.cell_9 = Cell(in_channels_left=2160, 
out_channels_left=864, + in_channels_right=4320, out_channels_right=864, + match_prev_layer_dimensions=True) + self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864, + in_channels_right=4320, out_channels_right=864) + self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, + in_channels_right=4320, out_channels_right=864) + self.relu = nn.ReLU() + self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) + self.dropout = nn.Dropout(0.5) + self.last_linear = nn.Linear(4320, num_classes) + + def features(self, x): + x_conv_0 = self.conv_0(x) + x_stem_0 = self.cell_stem_0(x_conv_0) + x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) + x_cell_0 = self.cell_0(x_stem_0, x_stem_1) + x_cell_1 = self.cell_1(x_stem_1, x_cell_0) + x_cell_2 = self.cell_2(x_cell_0, x_cell_1) + x_cell_3 = self.cell_3(x_cell_1, x_cell_2) + x_cell_4 = self.cell_4(x_cell_2, x_cell_3) + x_cell_5 = self.cell_5(x_cell_3, x_cell_4) + x_cell_6 = self.cell_6(x_cell_4, x_cell_5) + x_cell_7 = self.cell_7(x_cell_5, x_cell_6) + x_cell_8 = self.cell_8(x_cell_6, x_cell_7) + x_cell_9 = self.cell_9(x_cell_7, x_cell_8) + x_cell_10 = self.cell_10(x_cell_8, x_cell_9) + x_cell_11 = self.cell_11(x_cell_9, x_cell_10) + return x_cell_11 + + def logits(self, features): + x = self.relu(features) + x = self.avg_pool(x) + x = x.view(x.size(0), -1) + x = self.dropout(x) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +def pnasnet5large(num_classes=1001, pretrained='imagenet'): + r"""PNASNet-5 model architecture from the + `"Progressive Neural Architecture Search" + `_ paper. + """ + if pretrained: + settings = pretrained_settings['pnasnet5large'][pretrained] + assert num_classes == settings[ + 'num_classes'], 'num_classes should be {}, but is {}'.format( + settings['num_classes'], num_classes) + + # both 'imagenet'&'imagenet+background' are loaded from same parameters + model = PNASNet5Large(num_classes=1001) + model.load_state_dict(model_zoo.load_url(settings['url'])) + + if pretrained == 'imagenet': + new_last_linear = nn.Linear(model.last_linear.in_features, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear + + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + + model.mean = settings['mean'] + model.std = settings['std'] + else: + model = PNASNet5Large(num_classes=num_classes) + return model diff --git a/models/xception.py b/models/xception.py new file mode 100644 index 0000000..7783c47 --- /dev/null +++ b/models/xception.py @@ -0,0 +1,235 @@ +""" +Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) + +@author: tstandley +Adapted by cadene + +Creates an Xception Model as defined in: + +Francois Chollet +Xception: Deep Learning with Depthwise Separable Convolutions +https://arxiv.org/pdf/1610.02357.pdf + +This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: + +Loss:0.9173 Prec@1:78.892 Prec@5:94.292 + +REMEMBER to set your image size to 3x299x299 for both test and validation + +normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + +The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 +""" +from __future__ import print_function, division, absolute_import +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from torch.nn import init + +__all__ = ['xception'] + +pretrained_settings = { + 'xception': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000, + 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } + } +} + + +class SeparableConv2d(nn.Module): + def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): + super(SeparableConv2d,self).__init__() + + self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) + self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) + + def forward(self,x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): + super(Block, self).__init__() + + if out_filters != in_filters or strides!=1: + self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip=None + + rep=[] + + filters=in_filters + if grow_first: + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps-1): + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3,strides,1)) + self.rep = nn.Sequential(*rep) + + def forward(self,inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x+=skip + return x + + +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + def __init__(self, num_classes=1000): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception, self).__init__() + self.num_classes = num_classes + + self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32,64,3,bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + #do relu here + + self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True) + self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True) + 
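# Added note: block3 finishes the "entry flow", downsampling once more and
+ # widening to 728 channels; blocks 4 through 11 are the eight stride-1
+ # "middle flow" blocks, and block12 opens the "exit flow".
+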
self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True) + + self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True) + + self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True) + + self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) + + self.conv3 = SeparableConv2d(1024,1536,3,1,1) + self.bn3 = nn.BatchNorm2d(1536) + self.relu3 = nn.ReLU(inplace=True) + + #do relu here + self.conv4 = SeparableConv2d(1536,2048,3,1,1) + self.bn4 = nn.BatchNorm2d(2048) + + self.fc = nn.Linear(2048, num_classes) + + # #------- init weights -------- + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. / n)) + # elif isinstance(m, nn.BatchNorm2d): + # m.weight.data.fill_(1) + # m.bias.data.zero_() + # #----------------------------- + + def features(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.relu3(x) + + x = self.conv4(x) + x = self.bn4(x) + return x + + def logits(self, features): + x = nn.ReLU(inplace=True)(features) + + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +def xception(num_classes=1000, pretrained='imagenet'): + model = Xception(num_classes=num_classes) + if pretrained: + settings = pretrained_settings['xception'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + + model = Xception(num_classes=num_classes) + model.load_state_dict(model_zoo.load_url(settings['url'])) + + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + + # TODO: ugly + model.last_linear = model.fc + del model.fc + return model diff --git a/notebooks/1.0_DenseNetGanDetection.ipynb b/notebooks/1.0_DenseNetGanDetection.ipynb new file mode 100644 index 0000000..c24fba7 --- /dev/null +++ b/notebooks/1.0_DenseNetGanDetection.ipynb @@ -0,0 +1,621 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **AmilGan Detection using DenseNet**\n", + "#### Amil Khan | March 1, 2019 | Version 2\n", + "***" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from collections import OrderedDict\n", + "import math\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.autograd import Variable\n", + "from functools import 
reduce\n", + "import torch.utils.model_zoo as model_zoo\n", + "import torch.nn.functional as F\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data.sampler import SubsetRandomSampler\n", + "from torchvision import datasets\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "plt.rcParams['figure.dpi'] = 200\n", + "%config InlineBackend.figure_format = 'retina'\n", + "train_on_gpu = torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Loading, Preprocessing, Wrangling, Cleansing\n", + "\n", + "I chose to functionalize everything in the data preprocessing pipeline for two reasons:\n", + "\n", + "- **Reproducibility** Many times, there is a large body of beautiful code that has been written, but no documentation. The data processing step is usually where people get stuck.\n", + "- **Iterability** I wanted to iterate fast when chnaging parameters in the module, as well as have one block of code that will take care of everything after restarting the kernal.\n", + "\n", + "**Inputs**: \n", + "- `path_to_train`: Path to your training set folder (I am using PyTorch's `ImageFolder` module)\n", + "- `path_to_test`: Path to your test set folder\n", + "- `num_workers`: number of subprocesses to use for data loading\n", + "- `batch_size`: how many samples per batch to load\n", + "- `valid_size`: percentage of training set to use as validation\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def DataConstructor2000(path_to_train, path_to_test, classes=None, num_workers=4, batch_size=32, valid_size = 0.2):\n", + " \n", + " \n", + " # Transformations to the image, edit as need be\n", + " transform = transforms.Compose([\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.Resize([32,32]),\n", + " transforms.ToTensor()])\n", + " \n", + " train_dataset = datasets.ImageFolder(path_to_train, transform=transform)\n", + " print(\"Successfully Loaded Training Set.\")\n", + "\n", + " test_dataset = datasets.ImageFolder(path_to_test, transform=transform)\n", + " print(\"Successfully Loaded Test Set.\")\n", + "\n", + " \n", + " # obtain training indices that will be used for validation\n", + " num_train = len(train_dataset)\n", + " indices = list(range(num_train))\n", + " np.random.shuffle(indices)\n", + " split = int(np.floor(valid_size * num_train))\n", + " train_idx, valid_idx = indices[split:], indices[:split]\n", + "\n", + " # define samplers for obtaining training and validation batches\n", + " train_sampler = SubsetRandomSampler(train_idx)\n", + " valid_sampler = SubsetRandomSampler(valid_idx)\n", + "\n", + " train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, \n", + " sampler=train_sampler, num_workers=num_workers)\n", + "\n", + " valid_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, \n", + " sampler=valid_sampler, num_workers=num_workers)\n", + "\n", + " test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=20, \n", + " num_workers=num_workers, shuffle=True)\n", + " if classes != None:\n", + " print(\"Number of Classes:\", len(classes))\n", + " return train_loader, valid_loader, test_loader, classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader, valid_loader,test_loader, classes = 
DataConstructor2000(path_to_train='/workspace/Documents/pretrained-models.pytorch-master/pretrainedmodels/training/',path_to_test='/workspace/Documents/pretrained-models.pytorch-master/pretrainedmodels/test/',\n",
+ " classes=['Fake','Real'], num_workers=40, batch_size=150, valid_size=0.3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Visualize for Confirmation\n",
+ "You do not need to change anything here. It should run right out of the box. But feel free to change what you need."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# helper function to un-normalize and display an image\n",
+ "def imshow(img):\n",
+ "# img = img / 2 + 0.5 # unnormalize if you added normalization in the transformation step\n",
+ " plt.imshow(np.transpose(img, (1, 2, 0))) # convert from Tensor image\n",
+ "\n",
+ "# obtain one batch of training images\n",
+ "dataiter = iter(train_loader)\n",
+ "images, labels = next(dataiter)\n",
+ "images = images.numpy() # convert images to numpy for display\n",
+ "print(images.shape)\n",
+ "\n",
+ "# plot the images in the batch, along with the corresponding labels\n",
+ "fig = plt.figure(figsize=(25, 4))\n",
+ "# display 20 images\n",
+ "for idx in np.arange(20):\n",
+ " ax = fig.add_subplot(2, 20//2, idx+1, xticks=[], yticks=[])\n",
+ " imshow(images[idx])\n",
+ " ax.set_title(classes[labels[idx]])\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Time to Define the Model\n",
+ "\n",
+ "In this notebook, I am using DenseNet for our GAN detection problem.\n",
+ "\n",
+ "__Abstract.__ Recent work has shown that convolutional networks can be substantially deeper, more accurate, and efficient to train if they contain shorter connections between layers close to the input and those close to the output. In this paper, we embrace this observation and introduce the Dense Convolutional Network (`DenseNet`), which connects each layer to every other layer in a feed-forward fashion. Whereas traditional convolutional networks with $L$ layers have $L$ connections---one between each layer and its subsequent layer---our network has $\\frac{L(L+1)}{2}$ direct connections. For each layer, the feature-maps of all preceding layers are used as inputs, and its own feature-maps are used as inputs into all subsequent layers. DenseNets have several compelling advantages: they alleviate the vanishing-gradient problem, strengthen feature propagation, encourage feature reuse, and substantially reduce the number of parameters."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Bottleneck(nn.Module):\n", + " def __init__(self, nChannels, growthRate):\n", + " super(Bottleneck, self).__init__()\n", + " interChannels = 4*growthRate\n", + " self.bn1 = nn.BatchNorm2d(nChannels)\n", + " self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,\n", + " bias=False)\n", + " self.bn2 = nn.BatchNorm2d(interChannels)\n", + " self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,\n", + " padding=1, bias=False)\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(F.relu(self.bn1(x)))\n", + " out = self.conv2(F.relu(self.bn2(out)))\n", + " out = torch.cat((x, out), 1)\n", + " return out\n", + "\n", + "class SingleLayer(nn.Module):\n", + " def __init__(self, nChannels, growthRate):\n", + " super(SingleLayer, self).__init__()\n", + " self.bn1 = nn.BatchNorm2d(nChannels)\n", + " self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,\n", + " padding=1, bias=False)\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(F.relu(self.bn1(x)))\n", + " out = torch.cat((x, out), 1)\n", + " return out\n", + "\n", + "class Transition(nn.Module):\n", + " def __init__(self, nChannels, nOutChannels):\n", + " super(Transition, self).__init__()\n", + " self.bn1 = nn.BatchNorm2d(nChannels)\n", + " self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,\n", + " bias=False)\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(F.relu(self.bn1(x)))\n", + " out = F.avg_pool2d(out, 2)\n", + " return out\n", + "\n", + "\n", + "class DenseNet(nn.Module):\n", + " def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):\n", + " super(DenseNet, self).__init__()\n", + "\n", + " nDenseBlocks = (depth-4) // 3\n", + " if bottleneck:\n", + " nDenseBlocks //= 2\n", + "\n", + " nChannels = 2*growthRate\n", + " self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,\n", + " bias=False)\n", + " self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)\n", + " nChannels += nDenseBlocks*growthRate\n", + " nOutChannels = int(math.floor(nChannels*reduction))\n", + " self.trans1 = Transition(nChannels, nOutChannels)\n", + "\n", + " nChannels = nOutChannels\n", + " self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)\n", + " nChannels += nDenseBlocks*growthRate\n", + " nOutChannels = int(math.floor(nChannels*reduction))\n", + " self.trans2 = Transition(nChannels, nOutChannels)\n", + "\n", + " nChannels = nOutChannels\n", + " self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)\n", + " nChannels += nDenseBlocks*growthRate\n", + "\n", + " self.bn1 = nn.BatchNorm2d(nChannels)\n", + " self.fc = nn.Linear(nChannels, nClasses)\n", + "\n", + " for m in self.modules():\n", + " if isinstance(m, nn.Conv2d):\n", + " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n", + " m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n", + " elif isinstance(m, nn.BatchNorm2d):\n", + " m.weight.data.fill_(1)\n", + " m.bias.data.zero_()\n", + " elif isinstance(m, nn.Linear):\n", + " m.bias.data.zero_()\n", + "\n", + " def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):\n", + " layers = []\n", + " for i in range(int(nDenseBlocks)):\n", + " if bottleneck:\n", + " layers.append(Bottleneck(nChannels, growthRate))\n", + " else:\n", + " layers.append(SingleLayer(nChannels, growthRate))\n", + " nChannels += growthRate\n", + " return nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " out = self.conv1(x)\n", + " out = self.trans1(self.dense1(out))\n", + " out = self.trans2(self.dense2(out))\n", + " out = self.dense3(out)\n", + " out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8))\n", + " out = F.log_softmax(self.fc(out))\n", + " return out\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = DenseNet(growthRate=12, depth=100, reduction=0.5,\n", + "bottleneck=True, nClasses=2).cuda()\n", + "model = nn.DataParallel(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Choosing a Loss Function and Optimizer\n", + "I went with Cross-Entropy Loss. __Cross-entropy loss__, or log loss, measures the performance of a classification model whose output is a probability value between 0 and 1. Cross-entropy loss increases as the predicted probability diverges from the actual label. Hence, predicting a probability of .012 when the actual observation label is 1 would be bad and result in a high loss value. A perfect model would have a log loss of 0.\n", + " $$\\text{loss}(x, class) = -\\log\\left(\\frac{\\exp(x[class])}{\\sum_j \\exp(x[j])}\\right)\n", + " = -x[class] + \\log\\left(\\sum_j \\exp(x[j])\\right)$$\n", + " \n", + "I opted with good old __Stochastic Gradient Descent__. Nuff said." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# specify loss function (categorical cross-entropy)\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "# specify optimizer\n", + "optimizer = optimizer = torch.optim.SGD([\n", + " {'params': list(model.parameters())[:-1], 'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-3},\n", + " {'params': list(model.parameters())[-1], 'lr': 5e-5, 'momentum': 0.9, 'weight_decay': 1e-5}\n", + "])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Time to Train\n", + "This is where stripes are earned." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# number of epochs to train the model\n", + "n_epochs = 40\n", + "\n", + "valid_loss_min = np.Inf # track change in validation loss\n", + "training_vis = []\n", + "valid_vis = []\n", + "\n", + "for epoch in range(1, n_epochs+1):\n", + "\n", + " # keep track of training and validation loss\n", + " train_loss = 0.0\n", + " valid_loss = 0.0\n", + " \n", + " ###################\n", + " # train the model #\n", + " ###################\n", + " model.train()\n", + " for data, target in train_loader:\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " \n", + " # clear the gradients of all optimized variables\n", + " optimizer.zero_grad()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # backward pass: compute gradient of the loss with respect to model parameters\n", + " loss.backward()\n", + " # perform a single optimization step (parameter update)\n", + " optimizer.step()\n", + " # update training loss\n", + " train_loss += loss.item()*data.size(0)\n", + "\n", + " \n", + " ###################### \n", + " # validate the model #\n", + " ######################\n", + " model.eval()\n", + " for data, target in valid_loader:\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " \n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update average validation loss \n", + " valid_loss += loss.item()*data.size(0)\n", + " \n", + " # calculate average losses\n", + " train_loss = train_loss/len(train_loader.dataset)\n", + " valid_loss = valid_loss/len(valid_loader.dataset)\n", + " \n", + " training_vis.append(train_loss)\n", + " valid_vis.append(valid_loss)\n", + " \n", + " # print training/validation statistics \n", + " print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n", + " epoch, train_loss, valid_loss))\n", + " \n", + " # save model if validation loss has decreased\n", + " if valid_loss <= valid_loss_min:\n", + " print('\\nValidation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(\n", + " valid_loss_min,\n", + " valid_loss))\n", + " torch.save(model.state_dict(), 'DenseNet35656_gan-detector_Final.pt')\n", + " valid_loss_min = valid_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Model\n", + "\n", + "I included here the ability to load a model from previous training runs. Uncomment/Modify what you need to and go HAM. In this case, load the model, the optimizer and the criterion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model = TheModelClass(*args, **kwargs)\n", + "# model.load_state_dict(torch.load('./DenseNet35656_gan-detector_after72.pt'))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test Set Evaluation \n", + "Earlier we loaded in our test data under the name `test_loader`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# track test loss\n", + "test_loss = 0.0\n", + "class_correct = list(0. for i in range(2))\n", + "class_total = list(0. for i in range(2))\n", + "\n", + "model.eval()\n", + "# iterate over test data\n", + "for data, target in test_loader:\n", + "# print(target)\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + "# print(output)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update test loss \n", + " test_loss += loss.item()*data.size(0)\n", + " # convert output probabilities to predicted class\n", + " _, pred = torch.max(output, 1)\n", + "# print(pred)\n", + " # compare predictions to true label\n", + " correct_tensor = pred.eq(target.data.view_as(pred))\n", + " correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())\n", + " # calculate test accuracy for each object class\n", + " for i in range(2):\n", + "# print(i)\n", + " label = target.data[i]\n", + " class_correct[label] += correct[i].item()\n", + " class_total[label] += 1\n", + "\n", + "# average test loss\n", + "test_loss = test_loss/len(test_loader.dataset)\n", + "print('Test Loss: {:.6f}\\n'.format(test_loss))\n", + "\n", + "for i in range(2):\n", + " if class_total[i] > 0:\n", + " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n", + " classes[i], 100 * class_correct[i] / class_total[i],\n", + " np.sum(class_correct[i]), np.sum(class_total[i])))\n", + " else:\n", + " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n", + "\n", + "print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n", + " 100. * np.sum(class_correct) / np.sum(class_total),\n", + " np.sum(class_correct), np.sum(class_total)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Model Performance\n", + "\n", + "Most likely, we will want to see how our model performed throughout each epoch. In this plot, we are visualizing training and validation loss." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), valid_vis)\n", + "plt.plot(range(epoch), valid_vis)\n", + "# np.savetxt('DenseNet.txt', [training_vis, valid_vis])\n", + "# plt.savefig()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Misclassified\n", + "\n", + "Similarly, we will want to see which types of images it correctly classified. In our case, we plot a randomly sampled batch of our test set and place the correct label in parentheses, and the predicted without." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain one batch of test images\n", + "dataiter = iter(test_loader)\n", + "images, labels = next(dataiter)\n", + "\n", + "# get sample outputs\n", + "output = model(images)\n", + "# convert output probabilities to predicted class\n", + "_, preds = torch.max(output, 1)\n", + "# prep images for display\n", + "images = images.numpy()\n", + "labels = labels.numpy()\n", + "print(images.shape)\n", + "# plot the images in the batch, along with predicted and true labels\n", + "fig = plt.figure(figsize=(20, 4))\n", + "for idx in np.arange(20):\n", + " ax = fig.add_subplot(2, 20 // 2, idx+1, xticks=[], yticks=[])\n", + " ax.imshow(np.transpose(images[idx], (1, 2, 0))) # CHW -> HWC for display\n", + " ax.set_title(\"{} ({})\".format(classes[preds[idx]], classes[labels[idx]]),\n", + " color=(\"green\" if classes[preds[idx]]==classes[labels[idx]] else \"red\"))\n", + "plt.savefig('Densenet_misclass.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus Sanity Check Visualization\n", + "\n", + "Here, we plot the `RGB` channels of the image, but with a twist. We plot the corresponding RGB value inside the color. Super cool sanity check and overall visualization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rgb_img = np.squeeze(images[3])\n", + "channels = ['red channel', 'green channel', 'blue channel']\n", + "\n", + "fig = plt.figure(figsize = (36, 36)) \n", + "for idx in np.arange(rgb_img.shape[0]):\n", + " ax = fig.add_subplot(1, 3, idx + 1)\n", + " img = rgb_img[idx]\n", + " ax.imshow(img, cmap='gray')\n", + " ax.set_title(channels[idx])\n", + " width, height = img.shape\n", + " thresh = img.max()/2.5\n", + " for x in range(width):\n", + " for y in range(height):\n", + " val = round(img[x][y],2) if img[x][y] !=0 else 0\n", + " ax.annotate(str(val), xy=(y,x),\n", + " horizontalalignment='center',\n", + " verticalalignment='center', size=8,\n", + " color='white' if img[x][y] {:.6f}). Saving model ...'.format(\n", + " valid_loss_min,\n", + " valid_loss))\n", + " torch.save(model.state_dict(), 'NasnetMobile35656_gan-detector_Final.pt')\n", + " valid_loss_min = valid_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the Saved Model\n", + "\n", + "Since we saved the weights with the lowest validation loss, we can easily load the model and make our predictions on the test set if need be." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.load_state_dict(torch.load('NasnetMobile35656_gan-detector_Final.pt'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## How Well Can We Generalize?\n", + "\n", + "Here, we will compute the test error and test accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# track test loss\n", + "test_loss = 0.0\n", + "class_correct = list(0. for i in range(2))\n", + "class_total = list(0. 
for i in range(2))\n", + "test_vis = []\n", + "model.eval()\n", + "# iterate over test data\n", + "for data, target in valid_loader:\n", + "# print(target)\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + "# print(output)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update test loss \n", + " test_loss += loss.item()*data.size(0)\n", + " \n", + " # convert output probabilities to predicted class\n", + " _, pred = torch.max(output, 1)\n", + "# print(pred)\n", + " # compare predictions to true label\n", + " correct_tensor = pred.eq(target.data.view_as(pred))\n", + " correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())\n", + " # calculate test accuracy for each object class\n", + " for i in range(2):\n", + "# print(i)\n", + " label = target.data[i]\n", + " class_correct[label] += correct[i].item()\n", + " class_total[label] += 1\n", + "\n", + "# average test loss\n", + "test_loss = test_loss/len(test_loader.dataset)\n", + "print('Test Loss: {:.6f}\\n'.format(test_loss))\n", + "test_vis.append(test_loss)\n", + "for i in range(2):\n", + " if class_total[i] > 0:\n", + " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n", + " classes[i], 100 * class_correct[i] / class_total[i],\n", + " np.sum(class_correct[i]), np.sum(class_total[i])))\n", + " else:\n", + " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n", + "\n", + "print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n", + " 100. * np.sum(class_correct) / np.sum(class_total),\n", + " np.sum(class_correct), np.sum(class_total)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize the Training and Validation Errors\n", + "\n", + "To see how our model is doing as it runs through different epochs, we need to visualize the training. I tried to make it pretty, but honestly this works for now. Orange is the __Validation__ and Blue is the __Training__." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(range(epoch), training_vis)\n", + "# plt.scatter(range(epoch), training_vis)\n", + "plt.plot(range(epoch), valid_vis)\n", + "# plt.scatter(range(epoch), valid_vis)\n", + "plt.title('Training vs. Validation')\n", + "plt.xlabel('Number of Epochs')\n", + "plt.ylabel('Error')\n", + "plt.savefig('Nasnet_line.svg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.savetxt('NasNet_Final.txt', np.array([training_vis, valid_vis]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Which images did we get Wrong and Right?\n", + "\n", + "I could have made this a function, but it works. 
\n", + "\n", + "Format: __Predicted (Truth)__" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain one batch of test images\n", + "dataiter = iter(test_loader)\n", + "images, labels = dataiter.next()\n", + "\n", + "# get sample outputs\n", + "output = model(images)\n", + "# convert output probabilities to predicted class\n", + "_, preds = torch.max(output, 1)\n", + "# prep images for display\n", + "images = images.numpy()\n", + "labels = labels.numpy()\n", + "print(images.shape)\n", + "# plot the images in the batch, along with predicted and true labels\n", + "fig = plt.figure(figsize=(20, 4))\n", + "for idx in np.arange(20):\n", + " ax = fig.add_subplot(2, 20/2, idx+1, xticks=[], yticks=[])\n", + " ax.imshow(np.swapaxes((images[idx]),axis1=0, axis2=2))\n", + " ax.set_title(\"{} ({})\".format(classes[preds[idx]], classes[labels[idx]]),\n", + " color=(\"green\" if classes[preds[idx]]==classes[labels[idx]] else \"red\"))\n", + "plt.savefig('PNASNET_mis.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exploratory Analysis/Sanity Check\n", + "\n", + "Basically just plots the corresponding value with the color." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rgb_img = np.squeeze(images[3])\n", + "channels = ['red channel', 'green channel', 'blue channel']\n", + "\n", + "fig = plt.figure(figsize = (36, 36)) \n", + "for idx in np.arange(rgb_img.shape[0]):\n", + " ax = fig.add_subplot(1, 3, idx + 1)\n", + " img = rgb_img[idx]\n", + " ax.imshow(img, cmap='gray')\n", + " ax.set_title(channels[idx])\n", + " width, height = img.shape\n", + " thresh = img.max()/2.5\n", + " for x in range(width):\n", + " for y in range(height):\n", + " val = round(img[x][y],2) if img[x][y] !=0 else 0\n", + " ax.annotate(str(val), xy=(y,x),\n", + " horizontalalignment='center',\n", + " verticalalignment='center', size=8,\n", + " color='white' if img[x][y]`_ paper.\n", + " \"\"\"\n", + " model = PNASNet5Large(num_classes=num_classes)\n", + " return model.cuda()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = pnasnet5large(num_classes=2)\n", + "model = nn.DataParallel(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Choosing a Loss Function and Optimizer\n", + "I went with Cross-Entropy Loss. __Cross-entropy loss__, or log loss, measures the performance of a classification model whose output is a probability value between 0 and 1. Cross-entropy loss increases as the predicted probability diverges from the actual label. Hence, predicting a probability of .012 when the actual observation label is 1 would be bad and result in a high loss value. A perfect model would have a log loss of 0.\n", + " $$\\text{loss}(x, class) = -\\log\\left(\\frac{\\exp(x[class])}{\\sum_j \\exp(x[j])}\\right)\n", + " = -x[class] + \\log\\left(\\sum_j \\exp(x[j])\\right)$$\n", + " \n", + "I opted with good old __Stochastic Gradient Descent__. Nuff said." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# specify loss function (categorical cross-entropy)\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "# specify optimizer\n", + "optimizer = torch.optim.SGD([\n", + " {'params': list(model.parameters())[:-1], 'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-3},\n", + " {'params': list(model.parameters())[-1], 'lr': 5e-5, 'momentum': 0.9, 'weight_decay': 1e-5}\n", + "])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Time to Train\n", + "This is where stripes are earned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# number of epochs to train the model\n", + "n_epochs = 1\n", + "\n", + "valid_loss_min = np.Inf # track change in validation loss\n", + "training_vis = []\n", + "valid_vis = []\n", + "\n", + "for epoch in range(1, n_epochs+1):\n", + "\n", + " # keep track of training and validation loss\n", + " train_loss = 0.0\n", + " valid_loss = 0.0\n", + " \n", + " ###################\n", + " # train the model #\n", + " ###################\n", + " model.train()\n", + " for data, target in train_loader:\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " \n", + " # clear the gradients of all optimized variables\n", + " optimizer.zero_grad()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # backward pass: compute gradient of the loss with respect to model parameters\n", + " loss.backward()\n", + " # perform a single optimization step (parameter update)\n", + " optimizer.step()\n", + " # update training loss\n", + " train_loss += loss.item()*data.size(0)\n", + "\n", + " \n", + " ###################### \n", + " # validate the model #\n", + " ######################\n", + " model.eval()\n", + " for data, target in valid_loader:\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " \n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update average validation loss \n", + " valid_loss += loss.item()*data.size(0)\n", + " \n", + " # calculate average losses\n", + " train_loss = train_loss/len(train_loader.dataset)\n", + " valid_loss = valid_loss/len(valid_loader.dataset)\n", + " \n", + " training_vis.append(train_loss)\n", + " valid_vis.append(valid_loss)\n", + " \n", + " # print training/validation statistics \n", + " print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n", + " epoch, train_loss, valid_loss))\n", + " \n", + " # save model if validation loss has decreased\n", + " if valid_loss <= valid_loss_min:\n", + " print('\\nValidation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(\n", + " valid_loss_min,\n", + " valid_loss))\n", + " torch.save(model.state_dict(), 'PNASNET_gan-detector_test.pt')\n", + " valid_loss_min = valid_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Model\n", + "\n", + "I included here the ability to load a model from previous training runs. Uncomment/Modify what you need to and go HAM. 
In this case, load the model, the optimizer and the criterion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model = TheModelClass(*args, **kwargs)\n", + "# model.load_state_dict(torch.load('PNASNET_gan-detector_after96.pt'))\n", + "# torch.save(model.state_dict(), 'InceptionResnet_gan-detector_95.pt')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test Set Evaluation \n", + "Earlier we loaded in our test data under the name `test_loader`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# track test loss\n", + "test_loss = 0.0\n", + "class_correct = list(0. for i in range(2))\n", + "class_total = list(0. for i in range(2))\n", + "\n", + "model.eval()\n", + "# iterate over test data\n", + "for data, target in test_loader:\n", + "# print(target)\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + "# print(output)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update test loss \n", + " test_loss += loss.item()*data.size(0)\n", + " # convert output probabilities to predicted class\n", + " _, pred = torch.max(output, 1)\n", + "# print(pred)\n", + " # compare predictions to true label\n", + " correct_tensor = pred.eq(target.data.view_as(pred))\n", + " correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())\n", + " # tally per-class accuracy over the samples in this batch\n", + " for i in range(len(target)):\n", + "# print(i)\n", + " label = target.data[i]\n", + " class_correct[label] += correct[i].item()\n", + " class_total[label] += 1\n", + "\n", + "# average test loss\n", + "test_loss = test_loss/len(test_loader.dataset)\n", + "print('Test Loss: {:.6f}\\n'.format(test_loss))\n", + "\n", + "for i in range(2):\n", + " if class_total[i] > 0:\n", + " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n", + " classes[i], 100 * class_correct[i] / class_total[i],\n", + " np.sum(class_correct[i]), np.sum(class_total[i])))\n", + " else:\n", + " print('Test Accuracy of %5s: N/A (no test examples)' % (classes[i]))\n", + "\n", + "print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n", + " 100. * np.sum(class_correct) / np.sum(class_total),\n", + " np.sum(class_correct), np.sum(class_total)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Model Performance\n", + "\n", + "Most likely, we will want to see how our model performed across epochs. In this plot, we are visualizing training and validation loss."
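+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A hedged, optional variant of the plot that follows: the same data with axis labels and a legend. Note that with `n_epochs = 1` above, each series is a single point, which is why the markers matter here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# labeled variant of the loss plot (same data as the cell below)\n", + "epochs = range(1, epoch + 1)\n", + "plt.plot(epochs, training_vis, marker='o', label='Training')\n", + "plt.plot(epochs, valid_vis, marker='o', label='Validation')\n", + "plt.xlabel('Epoch')\n", + "plt.ylabel('Loss')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " "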
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), valid_vis)\n", + "plt.plot(range(epoch), valid_vis)\n", + "# plt.savefig()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### **Save the Training and Validation Losses**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.savetxt('PNASNET.txt', np.array([training_vis, valid_vis]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Misclassified\n", + "\n", + "Similarly, we will want to see which images the model classified correctly and which it missed. In our case, we plot a randomly sampled batch of our test set and place the correct label in parentheses, and the predicted label without." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain one batch of test images\n", + "dataiter = iter(test_loader)\n", + "images, labels = next(dataiter)\n", + "\n", + "# get sample outputs\n", + "output = model(images)\n", + "# convert output probabilities to predicted class\n", + "_, preds = torch.max(output, 1)\n", + "# prep images for display\n", + "images = images.numpy()\n", + "labels = labels.numpy()\n", + "print(images.shape)\n", + "# plot the images in the batch, along with predicted and true labels\n", + "fig = plt.figure(figsize=(25, 4))\n", + "for idx in np.arange(20):\n", + " ax = fig.add_subplot(2, 20 // 2, idx+1, xticks=[], yticks=[])\n", + " ax.imshow(np.transpose(images[idx], (1, 2, 0))) # CHW -> HWC for display\n", + " ax.set_title(\"{} ({})\".format(classes[preds[idx]], classes[labels[idx]]),\n", + " color=(\"green\" if classes[preds[idx]]==classes[labels[idx]] else \"red\"))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus Sanity Check Visualization\n", + "\n", + "Here, we plot the `RGB` channels of the image, but with a twist. We plot the corresponding RGB value inside the color. Super cool sanity check and overall visualization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rgb_img = np.squeeze(images[3])\n", + "channels = ['red channel', 'green channel', 'blue channel']\n", + "\n", + "fig = plt.figure(figsize = (36, 36)) \n", + "for idx in np.arange(rgb_img.shape[0]):\n", + " ax = fig.add_subplot(1, 3, idx + 1)\n", + " img = rgb_img[idx]\n", + " ax.imshow(img, cmap='gray')\n", + " ax.set_title(channels[idx])\n", + " width, height = img.shape\n", + " thresh = img.max()/2.5\n", + " for x in range(width):\n", + " for y in range(height):\n", + " val = round(img[x][y],2) if img[x][y] !=0 else 0\n", + " ax.annotate(str(val), xy=(y,x),\n", + " horizontalalignment='center',\n", + " verticalalignment='center', size=8,\n", + " color='white' if img[x][y] {:.6f}). Saving model ...'.format(\n", + " valid_loss_min,\n", + " valid_loss))\n", + " torch.save(model.state_dict(), 'ResNet35656_gan-detector_after82.pt')\n", + " valid_loss_min = valid_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Model\n", + "\n", + "I included here the ability to load a model from previous training runs. Uncomment/Modify what you need to and go HAM. In this case, load the model, the optimizer and the criterion."
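+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The note above mentions restoring the optimizer as well, but the saves in this document only checkpoint model weights. Below is a hedged sketch of a fuller checkpoint (the filename is illustrative), left commented out like the cell that follows." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}\n", + "# torch.save(checkpoint, 'PNASNET_gan-detector_checkpoint.pt')\n", + "# state = torch.load('PNASNET_gan-detector_checkpoint.pt')\n", + "# model.load_state_dict(state['model'])\n", + "# optimizer.load_state_dict(state['optimizer'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " "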
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model = TheModelClass(*args, **kwargs)\n", + "# model.load_state_dict(torch.load('ResNet35656_gan-detector_after82.pt'))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test Set Evaluation \n", + "Earlier we loaded in our test data under the name `test_loader`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# track test loss\n", + "test_loss = 0.0\n", + "class_correct = list(0. for i in range(2))\n", + "class_total = list(0. for i in range(2))\n", + "\n", + "model.eval()\n", + "# iterate over test data\n", + "for data, target in test_loader:\n", + "# print(target)\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + "# print(output)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update test loss \n", + " test_loss += loss.item()*data.size(0)\n", + " # convert output probabilities to predicted class\n", + " _, pred = torch.max(output, 1)\n", + "# print(pred)\n", + " # compare predictions to true label\n", + " correct_tensor = pred.eq(target.data.view_as(pred))\n", + " correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())\n", + " # tally per-class accuracy over the samples in this batch\n", + " for i in range(len(target)):\n", + "# print(i)\n", + " label = target.data[i]\n", + " class_correct[label] += correct[i].item()\n", + " class_total[label] += 1\n", + "\n", + "# average test loss\n", + "test_loss = test_loss/len(test_loader.dataset)\n", + "print('Test Loss: {:.6f}\\n'.format(test_loss))\n", + "\n", + "for i in range(2):\n", + " if class_total[i] > 0:\n", + " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n", + " classes[i], 100 * class_correct[i] / class_total[i],\n", + " np.sum(class_correct[i]), np.sum(class_total[i])))\n", + " else:\n", + " print('Test Accuracy of %5s: N/A (no test examples)' % (classes[i]))\n", + "\n", + "print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n", + " 100. * np.sum(class_correct) / np.sum(class_total),\n", + " np.sum(class_correct), np.sum(class_total)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Model Performance\n", + "\n", + "Most likely, we will want to see how our model performed across epochs. In this plot, we are visualizing training and validation loss."
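+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A small, assumed addition before plotting: report the epoch with the lowest validation loss, which is the checkpoint the save-on-improvement logic above actually kept." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best = int(np.argmin(valid_vis))\n", + "print('Best epoch: {} (validation loss {:.6f})'.format(best + 1, valid_vis[best]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " "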
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), valid_vis)\n", + "plt.plot(range(epoch), valid_vis)\n", + "plt.savefig('Resnet_77_scatter.svg')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### **Save the Training and Validation Losses**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.savetxt('Resnet_77.txt', np.array([training_vis, valid_vis]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Misclassified\n", + "\n", + "Similarly, we will want to see which images the model classified correctly and which it missed. In our case, we plot a randomly sampled batch of our test set and place the correct label in parentheses, and the predicted label without." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain one batch of test images\n", + "dataiter = iter(test_loader)\n", + "images, labels = next(dataiter)\n", + "\n", + "# get sample outputs\n", + "output = model(images)\n", + "# convert output probabilities to predicted class\n", + "_, preds = torch.max(output, 1)\n", + "# prep images for display\n", + "images = images.numpy()\n", + "labels = labels.numpy()\n", + "print(images.shape)\n", + "# plot the images in the batch, along with predicted and true labels\n", + "fig = plt.figure(figsize=(25, 4))\n", + "for idx in np.arange(20):\n", + " ax = fig.add_subplot(2, 20 // 2, idx+1, xticks=[], yticks=[])\n", + " ax.imshow(np.transpose(images[idx], (1, 2, 0))) # CHW -> HWC for display\n", + " ax.set_title(\"{} ({})\".format(classes[preds[idx]], classes[labels[idx]]),\n", + " color=(\"green\" if classes[preds[idx]]==classes[labels[idx]] else \"red\"))\n", + "plt.savefig('ResNet_miss.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus Sanity Check Visualization\n", + "\n", + "Here, we plot the `RGB` channels of the image, but with a twist. We plot the corresponding RGB value inside the color. Super cool sanity check and overall visualization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rgb_img = np.squeeze(images[3])\n", + "channels = ['red channel', 'green channel', 'blue channel']\n", + "\n", + "fig = plt.figure(figsize = (36, 36)) \n", + "for idx in np.arange(rgb_img.shape[0]):\n", + " ax = fig.add_subplot(1, 3, idx + 1)\n", + " img = rgb_img[idx]\n", + " ax.imshow(img, cmap='gray')\n", + " ax.set_title(channels[idx])\n", + " width, height = img.shape\n", + " thresh = img.max()/2.5\n", + " for x in range(width):\n", + " for y in range(height):\n", + " val = round(img[x][y],2) if img[x][y] !=0 else 0\n", + " ax.annotate(str(val), xy=(y,x),\n", + " horizontalalignment='center',\n", + " verticalalignment='center', size=8,\n", + " color='white' if img[x][y] {:.6f}). 
Saving model ...'.format(\n", + " valid_loss_min,\n", + " valid_loss))\n", + " torch.save(model.state_dict(), 'Xception35656_gan-detector_Final.pt')\n", + " valid_loss_min = valid_loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model = TheModelClass(*args, **kwargs)\n", + "model.load_state_dict(torch.load('Xception35656_gan-detector_Final.pt'))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# track test loss\n", + "test_loss = 0.0\n", + "class_correct = list(0. for i in range(2))\n", + "class_total = list(0. for i in range(2))\n", + "\n", + "model.eval()\n", + "# iterate over test data\n", + "for data, target in test_loader:\n", + "# print(target)\n", + " # move tensors to GPU if CUDA is available\n", + " if train_on_gpu:\n", + " data, target = data.cuda(), target.cuda()\n", + " # forward pass: compute predicted outputs by passing inputs to the model\n", + " output = model(data)\n", + "# print(output)\n", + " # calculate the batch loss\n", + " loss = criterion(output, target)\n", + " # update test loss \n", + " test_loss += loss.item()*data.size(0)\n", + " # convert output probabilities to predicted class\n", + " _, pred = torch.max(output, 1)\n", + "# print(pred)\n", + " # compare predictions to true label\n", + " correct_tensor = pred.eq(target.data.view_as(pred))\n", + " correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())\n", + " # tally per-class accuracy over the samples in this batch\n", + " for i in range(len(target)):\n", + "# print(i)\n", + " label = target.data[i]\n", + " class_correct[label] += correct[i].item()\n", + " class_total[label] += 1\n", + "\n", + "# average test loss\n", + "test_loss = test_loss/len(test_loader.dataset)\n", + "print('Test Loss: {:.6f}\\n'.format(test_loss))\n", + "\n", + "for i in range(2):\n", + " if class_total[i] > 0:\n", + " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n", + " classes[i], 100 * class_correct[i] / class_total[i],\n", + " np.sum(class_correct[i]), np.sum(class_total[i])))\n", + " else:\n", + " print('Test Accuracy of %5s: N/A (no test examples)' % (classes[i]))\n", + "\n", + "print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n", + " 100. * np.sum(class_correct) / np.sum(class_total),\n", + " np.sum(class_correct), np.sum(class_total)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Model Performance\n", + "\n", + "Most likely, we will want to see how our model performed across epochs. In this plot, we are visualizing training and validation loss."
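+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Beyond this single run, each section of this document saved its loss history with `np.savetxt`, so as a hedged convenience you can overlay the validation curves across runs. The filenames below are the ones used earlier; adjust them to whatever you actually saved." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for name in ['NasNet_Final.txt', 'Resnet_77.txt', 'Xception_85.txt', 'PNASNET.txt']:\n", + "    try:\n", + "        train_curve, valid_curve = np.loadtxt(name)\n", + "    except OSError:\n", + "        print('missing:', name)\n", + "        continue\n", + "    plt.plot(valid_curve, label=name.rsplit('.', 1)[0])\n", + "plt.xlabel('Epoch')\n", + "plt.ylabel('Validation loss')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " "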
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), training_vis)\n", + "plt.scatter(range(epoch), valid_vis)\n", + "plt.plot(range(epoch), valid_vis)\n", + "plt.savefig('Xception_line.svg')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### **Save the Training and Validation Losses**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.savetxt('Xception_85.txt', np.array([training_vis, valid_vis]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Misclassified\n", + "\n", + "Similarly, we will want to see which images the model classified correctly and which it missed. In our case, we plot a randomly sampled batch of our test set and place the correct label in parentheses, and the predicted label without." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain one batch of test images\n", + "dataiter = iter(test_loader)\n", + "images, labels = next(dataiter)\n", + "\n", + "# get sample outputs\n", + "output = model(images)\n", + "# convert output probabilities to predicted class\n", + "_, preds = torch.max(output, 1)\n", + "# prep images for display\n", + "images = images.numpy()\n", + "labels = labels.numpy()\n", + "print(images.shape)\n", + "# plot the images in the batch, along with predicted and true labels\n", + "fig = plt.figure(figsize=(20, 4))\n", + "for idx in np.arange(20):\n", + " ax = fig.add_subplot(2, 20 // 2, idx+1, xticks=[], yticks=[])\n", + " ax.imshow(np.transpose(images[idx], (1, 2, 0))) # CHW -> HWC for display\n", + " ax.set_title(\"{} ({})\".format(classes[preds[idx]], classes[labels[idx]]),\n", + " color=(\"green\" if classes[preds[idx]]==classes[labels[idx]] else \"red\"))\n", + " \n", + "plt.savefig('Xception_mis.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus Sanity Check Visualization\n", + "\n", + "Here, we plot the `RGB` channels of the image, but with a twist. We plot the corresponding RGB value inside the color. Super cool sanity check and overall visualization."
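+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One caveat before the cell below: it writes one number per pixel, which stays legible only for small inputs. An assumed workaround is to crop a small patch first (the size here is illustrative) and annotate that instead." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# annotate a small corner patch instead of the full image\n", + "rgb_img_patch = np.squeeze(images[3])[:, :8, :8]  # 8x8 corner, still CHW\n", + "print(rgb_img_patch.shape)  # swap this in for rgb_img below on large inputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " "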
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rgb_img = np.squeeze(images[3])\n", + "channels = ['red channel', 'green channel', 'blue channel']\n", + "\n", + "fig = plt.figure(figsize = (36, 36)) \n", + "for idx in np.arange(rgb_img.shape[0]):\n", + " ax = fig.add_subplot(1, 3, idx + 1)\n", + " img = rgb_img[idx]\n", + " ax.imshow(img, cmap='gray')\n", + " ax.set_title(channels[idx])\n", + " width, height = img.shape\n", + " thresh = img.max()/2.5\n", + " for x in range(width):\n", + " for y in range(height):\n", + " val = round(img[x][y],2) if img[x][y] !=0 else 0\n", + " ax.annotate(str(val), xy=(y,x),\n", + " horizontalalignment='center',\n", + " verticalalignment='center', size=8,\n", + " color='white' if img[x][y]=0.5.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4a9f020 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +from setuptools import find_packages, setup + +setup( + name='src', + packages=find_packages(), + version='0.1.0', + description='Detecting GAN Generated images using Convolutional Neural Networks', + author='Amil Khan', + license='MIT', +) diff --git a/src/data/make_dataset.py b/src/data/make_dataset.py new file mode 100644 index 0000000..96b377a --- /dev/null +++ b/src/data/make_dataset.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +import click +import logging +from pathlib import Path +from dotenv import find_dotenv, load_dotenv + + +@click.command() +@click.argument('input_filepath', type=click.Path(exists=True)) +@click.argument('output_filepath', type=click.Path()) +def main(input_filepath, output_filepath): + """ Runs data processing scripts to turn raw data from (../raw) into + cleaned data ready to be analyzed (saved in ../processed). + """ + logger = logging.getLogger(__name__) + logger.info('making final data set from raw data') + + +if __name__ == '__main__': + log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + logging.basicConfig(level=logging.INFO, format=log_fmt) + + # not used in this stub but often useful for finding various files + project_dir = Path(__file__).resolve().parents[2] + + # find .env automagically by walking up directories until it's found, then + # load up the .env entries as environment variables + load_dotenv(find_dotenv()) + + main() diff --git a/test_environment.py b/test_environment.py new file mode 100644 index 0000000..d0ac4a7 --- /dev/null +++ b/test_environment.py @@ -0,0 +1,25 @@ +import sys + +REQUIRED_PYTHON = "python3" + + +def main(): + system_major = sys.version_info.major + if REQUIRED_PYTHON == "python": + required_major = 2 + elif REQUIRED_PYTHON == "python3": + required_major = 3 + else: + raise ValueError("Unrecognized python interpreter: {}".format( + REQUIRED_PYTHON)) + + if system_major != required_major: + raise TypeError( + "This project requires Python {}. Found: Python {}".format( + required_major, sys.version)) + else: + print(">>> Development environment passes all tests!") + + +if __name__ == '__main__': + main() diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..c32fbd8 --- /dev/null +++ b/tox.ini @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 79 +max-complexity = 10