beginner_source/flava_finetuning_tutorial.py translation (#778)
* beginner_source/nn_tutorial.py translation
chanmuzi authored May 12, 2024
1 parent dfc7f1e commit 8639a80
Showing 1 changed file with 48 additions and 60 deletions.
108 changes: 48 additions & 60 deletions beginner_source/flava_finetuning_tutorial.py
@@ -1,36 +1,37 @@
# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
============================================
**Translation:** `김찬 <https://github.com/chanmuzi>`__
"""


######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases like image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and end-to-end
# examples, aiming to enable and accelerate research in multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, i.e. visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder that combines
# the two embeddings. It is pretrained using contrastive, image-text matching, and
# text, image, and multimodal masking losses.



######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the ``bert tokenizer`` from Hugging Face for this
# tutorial, so you need to install ``datasets`` and ``transformers`` in addition to TorchMultimodal.
#
# .. note::
#
# When running this tutorial in Google Colab, install the required packages by
# creating a new cell and running the following commands:
#
# .. code-block::
#
@@ -40,32 +41,27 @@
#
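# The exact commands are collapsed in this diff view; as a hedged sketch
# (the package names are assumed from the surrounding text rather than read
# from the diff), they would be along the lines of:
#
# .. code-block::
#
#    !pip install torchmultimodal-nightly
#    !pip install datasets
#    !pip install transformers
#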

######################################################################
# Steps
# -----
#
# 1. Download the Hugging Face dataset to a directory on your computer by running the following command:
#
# .. code-block::
#
# wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
# tar xf vocab.tar.gz
#
# .. note::
# If you are running this tutorial in Google Colab, run these commands
# in a new cell and prepend these commands with an exclamation mark (!).
#
#
# 2. For this tutorial, we treat VQA as a classification task where
# the inputs are images and questions (text) and the output is an answer class.
# So we need to download the vocab file with answer classes and create the
# answer-to-label mapping.
#
# We also load the `textvqa
# dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ containing 34,602 training samples
# (images, questions, and answers) from Hugging Face.
#
# We see there are 3997 answer classes, including a class representing
# unknown answers.
#

with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
@@ -81,7 +77,7 @@
dataset = load_dataset("textvqa")
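
######################################################################
# The vocab-handling code above is partially collapsed in this diff; a minimal
# sketch of the same idea (the variable names below are assumptions, not the
# tutorial's exact code) looks like this:
#

# Read one answer class per line and build the answer <-> label mappings.
with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = [line.strip() for line in f if line.strip()]

answer_to_idx = {answer: idx for idx, answer in enumerate(vocab)}
idx_to_answer = {idx: answer for idx, answer in enumerate(vocab)}
print("Number of answer classes:", len(vocab))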

######################################################################
# Let's display a sample entry from the dataset:
#

import matplotlib.pyplot as plt
@@ -95,12 +91,10 @@
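
######################################################################
# The display code is collapsed above; a minimal sketch (the field names
# ``image``, ``question``, and ``answers`` are assumptions about the TextVQA
# schema, not read from the diff) could look like:
#

import matplotlib.pyplot as plt

sample = dataset["train"][0]            # first training example
plt.imshow(sample["image"])             # the image the question is about
plt.axis("off")
plt.show()
print("Question:", sample["question"])  # the question text
print("Answers:", sample["answers"])    # the list of annotated answers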


######################################################################
# 3. Next, we write the transform function to convert the image and text into
# Tensors consumable by our model:
#
# - For images, we use the transforms from torchvision to convert to Tensor
#   and resize to uniform sizes.
# - For text, we tokenize (and pad) it using the ``BertTokenizer`` from Hugging Face.
# - For answers (i.e. labels), we take the most frequently occurring answer
#   as the label to train with.
#

import torch
@@ -133,25 +127,21 @@ def transform(tokenizer, input):
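
######################################################################
# The body of ``transform`` is collapsed above; a minimal sketch of the same
# idea follows. The resize shape, padding length, and the ``answer_to_idx``
# mapping (built from the vocab file earlier) are assumptions, not the
# tutorial's exact code.
#

import torch
from collections import Counter
from torchvision import transforms


def transform_sketch(tokenizer, answer_to_idx, input):
    batch = {}

    # Images: convert to a Tensor, resize to a uniform shape, and add a
    # batch dimension.
    image_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize([224, 224])]
    )
    batch["image"] = image_transform(input["image"][0].convert("RGB")).unsqueeze(0)

    # Text: tokenize (and pad) the question with the BERT tokenizer.
    batch.update(
        tokenizer(
            input["question"],
            return_tensors="pt",
            padding="max_length",
            max_length=512,
        )
    )

    # Answers (labels): map each annotated answer to its class index
    # (0 assumed for unknown answers) and keep the most frequent one as the label.
    labels = [answer_to_idx.get(ans, 0) for ans in input["answers"][0]]
    batch["answers"] = torch.as_tensor([Counter(labels).most_common(1)[0][0]])
    return batch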


######################################################################
# 4. Finally, we import the ``flava_model_for_classification`` from
# ``torchmultimodal``. It loads the pretrained FLAVA checkpoint by default and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head,
# which finally gives the probability distribution over each possible
# answer.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))
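
######################################################################
# As a quick, hedged sanity check of the forward pass described above (the
# ``text``/``image``/``labels`` keyword names and the ``out.loss`` field are
# assumptions about this classification wrapper, and the dummy shapes are
# illustrative only):
#

import torch

dummy_image = torch.randn(1, 3, 224, 224)        # one resized RGB image
dummy_text = torch.randint(0, 30522, (1, 512))   # one tokenized, padded question
dummy_label = torch.tensor([0])                  # one answer-class index
out = model(text=dummy_text, image=dummy_image, labels=dummy_label)
print(out.loss)                                  # classification loss over answer classes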


######################################################################
# 5. We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations:
#

from torch import nn
@@ -177,14 +167,12 @@ def transform(tokenizer, input):
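
######################################################################
# The loop body is collapsed above; a minimal sketch of such a toy loop,
# reusing the ``transform_sketch`` and ``answer_to_idx`` sketches from earlier
# (the keyword names in the forward call are assumptions, not read from the
# diff), could look like:
#

from torch import optim
from transformers import BertTokenizer

MAX_STEPS = 3

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
optimizer = optim.AdamW(model.parameters())
model.train()

for step in range(MAX_STEPS):
    # Take one example at a time in batched (dict-of-lists) form and run it
    # through the sketch transform; batch size 1 keeps the illustration simple.
    example = dataset["train"][step : step + 1]
    batch = transform_sketch(tokenizer, answer_to_idx, example)

    optimizer.zero_grad()
    out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
    out.loss.backward()
    optimizer.step()
    print(f"Loss at iteration {step} = {out.loss.item():.4f}")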


######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics around how to finetune on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#
#
