Skip to content

Commit

Permalink
recipes_source/recipes/reasoning_about_shapes.py ๋ฒˆ์—ญ (#779)
Browse files Browse the repository at this point in the history
* recipes_source/recipes/reasoning_about_shapes.py ๋ฒˆ์—ญ
  • Loading branch information
0seob authored Nov 26, 2023
1 parent 78d7662 commit 0970896
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
33 changes: 15 additions & 18 deletions recipes_source/recipes/reasoning_about_shapes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
"""
Reasoning about Shapes in PyTorch
PyTorch์˜ Shape๋“ค์— ๋Œ€ํ•œ ์ถ”๋ก 
=================================
๋ฒˆ์—ญ: `์ด์˜์„ญ <https://github.com/0seob>`_
When writing models with PyTorch, it is commonly the case that the parameters
to a given layer depend on the shape of the output of the previous layer. For
example, the ``in_features`` of an ``nn.Linear`` layer must match the
``size(-1)`` of the input. For some layers, the shape computation involves
complex equations, for example convolution operations.
์ผ๋ฐ˜์ ์œผ๋กœ PyTorch๋กœ ๋ชจ๋ธ์„ ์ž‘์„ฑํ•  ๋•Œ ํŠน์ • ๊ณ„์ธต์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ์ด์ „ ๊ณ„์ธต์˜ ์ถœ๋ ฅ shape์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.
์˜ˆ๋ฅผ ๋“ค์–ด, ``nn.Linear`` ๊ณ„์ธต์˜ ``in_features`` ๋Š” ์ž…๋ ฅ์˜ ``size(-1)`` ์™€ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
๋ช‡๋ช‡ ๊ณ„์ธต์˜ ๊ฒฝ์šฐ, shape ๊ณ„์‚ฐ์€ ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ๊ณผ ๊ฐ™์€ ๋ณต์žกํ•œ ๋ฐฉ์ •์‹์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
One way around this is to run the forward pass with random inputs, but this is
wasteful in terms of memory and compute.
์ด๋ฅผ ๋žœ๋คํ•œ ์ž…๋ ฅ์œผ๋กœ ์ˆœ์ „ํŒŒ(forward pass)๋ฅผ ์‹คํ–‰ํ•˜์—ฌ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, ์ด๋Š” ๋ฉ”๋ชจ๋ฆฌ์™€ ์ปดํ“จํŒ… ํŒŒ์›Œ๋ฅผ ๋‚ญ๋น„ํ•ฉ๋‹ˆ๋‹ค.
Instead, we can make use of the ``meta`` device to determine the output shapes
of a layer without materializing any data.
๋Œ€์‹ ์— ``meta`` ๋””๋ฐ”์ด์Šค๋ฅผ ํ™œ์šฉํ•œ๋‹ค๋ฉด ๋ฐ์ดํ„ฐ๋ฅผ ๊ตฌ์ฒดํ™”ํ•˜์ง€ ์•Š๊ณ ๋„ ๊ณ„์ธต์˜ ์ถœ๋ ฅ shape์„ ๊ฒฐ์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
"""

import torch
Expand All @@ -29,8 +26,8 @@


##########################################################################
# Observe that since data is not materialized, passing arbitrarily large
# inputs will not significantly alter the time taken for shape computation.
# ๋ฐ์ดํ„ฐ๊ฐ€ ๊ตฌ์ฒดํ™”๋˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์ž„์˜๋กœ ํฐ ์ž…๋ ฅ์„ ์ „๋‹ฌํ•ด๋„ shape ๊ณ„์‚ฐ์— ์†Œ์š”๋˜๋Š” ์‹œ๊ฐ„์ด
# ํฌ๊ฒŒ ๋ณ€๊ฒฝ๋˜์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค.

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
Expand All @@ -42,7 +39,7 @@


######################################################
# Consider an arbitrary network such as the following:
# ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์ž„์˜์˜ ๋„คํŠธ์›Œํฌ๋ฅผ ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค:

import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -61,23 +58,23 @@ def __init__(self):
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = torch.flatten(x, 1) # ๋ฐฐ์น˜๋ฅผ ์ œ์™ธํ•œ ๋ชจ๋“  ์ฐจ์›์„ ํ‰ํƒ„ํ™” ํ•ฉ๋‹ˆ๋‹ค.
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


###############################################################################
# We can view the intermediate shapes within an entire network by registering a
# forward hook to each layer that prints the shape of the output.
# ๊ฐ๊ฐ์˜ ๊ณ„์ธต์— ์ถœ๋ ฅ์˜ shape์„ ์ธ์‡„ํ•˜๋Š” forward hook์„ ๋“ฑ๋กํ•˜์—ฌ ๋„คํŠธ์›Œํฌ์˜
# ์ค‘๊ฐ„ shape์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

def fw_hook(module, input, output):
print(f"Shape of output to {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
# torch.device context manager(with ๊ตฌ๋ฌธ) ๋‚ด๋ถ€์—์„œ ์ƒ์„ฑ๋œ ๋ชจ๋“  tensor๋Š”
# meta ๋””๋ฐ”์ด์Šค ๋‚ด๋ถ€์— ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.
with torch.device("meta"):
net = Net()
inp = torch.randn((1024, 3, 32, 32))
Expand Down
4 changes: 2 additions & 2 deletions recipes_source/recipes_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ Recipes are bite-sized bite-sized, actionable examples of how to use specific Py
:tags: Basics

.. customcarditem::
:header: Reasoning about Shapes in PyTorch
:card_description: Learn how to use the meta device to reason about shapes in your model.
:header: PyTorch์˜ Shape์— ๋Œ€ํ•œ ์ถ”๋ก 
:card_description: meta ๋””๋ฐ”์ด์Šค๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์˜ shape์„ ์ถ”๋ก ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ์•„๋ด…๋‹ˆ๋‹ค.
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../recipes/recipes/reasoning_about_shapes.html
:tags: Basics
Expand Down

0 comments on commit 0970896

Please sign in to comment.