diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..a768e2f --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,25 @@ +# PR ์ œ๋ชฉ + +## ๐Ÿ“ ๊ฐœ์š” + + + +--- + +## โš™๏ธ ๊ตฌํ˜„ ๋‚ด์šฉ + + + +--- + +## ๐Ÿ“Ž ๊ธฐํƒ€ + + + +--- + +## ๐Ÿงช ํ…Œ์ŠคํŠธ ๊ฒฐ๊ณผ + + + +--- \ No newline at end of file diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 0000000..00aa09b --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,63 @@ +name: Deploy + +on: + push: + branches: ["develop"] + + +jobs: + build-and-push: + runs-on: ubuntu-latest + + steps: + # 1. checkout develop branch + - name: Checkout develop + uses: actions/checkout@v3 + with: + ref: develop + + # 2. Docker ๋กœ๊ทธ์ธ + - name: Login to Docker Hub + run: | + docker login -u ${{ secrets.DOCKER_USERNAME }} -p ${{ secrets.DOCKER_PASSWORD }} + + # 3. ์ด๋ฏธ์ง€ build ๋ฐ push + - name: Build and Push Docker image + run: | + docker build -t ${{ secrets.DOCKER_USERNAME }}/gotcha-ai:latest . + docker push ${{ secrets.DOCKER_USERNAME }}/gotcha-ai:latest + + + deploy: + needs: build-and-push + runs-on: ubuntu-latest + + steps: + # 1. checkout branch + - name: Check PR + uses: actions/checkout@v3 + + # 2. 
EC2 pull
+      - name: EC2 Docker Deploy
+        uses: appleboy/ssh-action@v1.0.3  # pinned to a tagged release instead of the moving master branch
+        with:
+          host: ${{ secrets.SERVER_HOST }}
+          port: ${{ secrets.SERVER_SSH_PORT }}
+          username: ${{ secrets.SERVER_USERNAME }}
+          key: ${{ secrets.SERVER_PRIVATE_KEY }}
+          script: |
+            cd ~/ai-server
+
+            docker stop ai-server || true
+            docker rm ai-server || true
+
+            docker system prune -a -f || true
+
+            echo "${{ secrets.DOCKER_PASSWORD }}" | docker login -u "${{ secrets.DOCKER_USERNAME }}" --password-stdin  # password is read from stdin, not passed as an argument
+            docker pull ${{ secrets.DOCKER_USERNAME }}/gotcha-ai:latest
+
+            docker run -d \
+              --name ai-server \
+              --env-file .env \
+              -p 8000:8000 \
+              ${{ secrets.DOCKER_USERNAME }}/gotcha-ai:latest
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 0a19790..345ca2a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,3 +172,7 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
+
+*quickdraw*
+
+.idea
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..3ac4ea6
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,25 @@
+# Build stage: installs the Python dependencies listed in requirements.txt.
+FROM python:3.10-slim AS builder
+
+WORKDIR /app
+
+COPY requirements.txt .
+
+RUN apt-get update && apt-get install -y libgl1 libglib2.0-0
+RUN pip install --upgrade pip
+RUN pip install --no-cache-dir -r requirements.txt
+
+
+
+# Runtime stage: copies /usr/local from the builder and adds the application code.
+FROM python:3.10-slim
+WORKDIR /app
+# Remove the apt package index in the same layer so it is not kept in the image.
+RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 && rm -rf /var/lib/apt/lists/*
+COPY --from=builder /usr/local /usr/local
+
+COPY src/ src/
+# With more than one source, the COPY destination must be a directory ending in "/".
+COPY run.py config.py ./
+ +CMD ["python", "run.py"] \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 820a69b..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 GotchaAI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/README.md b/README.md index e69de29..da5be35 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,40 @@ +## ๐Ÿ“Œ ์ปจ๋ฒค์…˜ + + +### ์ปค๋ฐ‹ ๋ฉ”์‹œ์ง€ + +| message | description | +| --- | --- | +| feat | ์ƒˆ๋กœ์šด ๊ธฐ๋Šฅ ์ถ”๊ฐ€, ๊ธฐ์กด ๊ธฐ๋Šฅ์„ ์š”๊ตฌ ์‚ฌํ•ญ์— ๋งž์ถ”์–ด ์ˆ˜์ • | +| fix | ๊ธฐ๋Šฅ์— ๋Œ€ํ•œ ๋ฒ„๊ทธ ์ˆ˜์ • | +| docs | ๋ฌธ์„œ(์ฃผ์„) ์ˆ˜์ • | +| style | ์ฝ”๋“œ ์Šคํƒ€์ผ, ํฌ๋งทํŒ…์— ๋Œ€ํ•œ ์ˆ˜์ • | +| refact | ๊ธฐ๋Šฅ ๋ณ€ํ™”๊ฐ€ ์•„๋‹Œ ์ฝ”๋“œ ๋ฆฌํŒฉํ„ฐ๋ง | +| test | ํ…Œ์ŠคํŠธ ์ฝ”๋“œ ์ถ”๊ฐ€/์ˆ˜์ • | +| chore | ํŒจํ‚ค์ง€ ๋งค๋‹ˆ์ € ์ˆ˜์ •, ๊ทธ ์™ธ ๊ธฐํƒ€ ์ˆ˜์ • ex) .gitignore | + +## ํ”„๋กœ์ ํŠธ ๊ตฌ์กฐ + +``` +โ”œโ”€โ”€ src/ +โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ”œโ”€โ”€ main.py # entry point (FastAPI ๊ฐ์ฒด ์ƒ์„ฑ) +โ”‚ โ”œโ”€โ”€ api/ # ๋ผ์šฐํŒ… ๊ตฌ์„ฑ +โ”‚ โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ”‚ โ”œโ”€โ”€ image_routes.py # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ๊ด€๋ จ API ์—”๋“œํฌ์ธํŠธ +โ”‚ โ”‚ โ”œโ”€โ”€ myomyo_routes.py # MYOMYO ๋ฉ”์‹œ์ง€ ๊ด€๋ จ API ์—”๋“œํฌ์ธํŠธ +โ”‚ โ”‚ โ””โ”€โ”€ lulu_routes.py # LULU ๋ฉ”์‹œ์ง€ ๊ด€๋ จ API ์—”๋“œํฌ์ธํŠธ +โ”‚ โ”œโ”€โ”€ chat/ +โ”‚ โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ”‚ โ”œโ”€โ”€ myomyo.py # ๋ฌ˜๋ฌ˜ ํ”„๋กฌํ”„ํŠธ ๋ฐ ๊ฒŒ์ž„ ํ๋ฆ„ ๊ด€๋ฆฌ +โ”‚ โ”‚ โ””โ”€โ”€ lulu.py # ๋ฃจ๋ฃจ ํ”„๋กฌํ”„ํŠธ ๋ฐ ๊ฒŒ์ž„ ํ๋ฆ„ ๊ด€๋ฆฌ +โ”‚ โ”œโ”€โ”€ image/ +โ”‚ โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ”‚ โ”œโ”€โ”€ trained_model/ +โ”‚ โ”‚ โ”‚ โ”œโ”€โ”€ model.pth # quickdraw ๊ธฐ๋ฐ˜ ๋ถ„๋ฅ˜ ๋ชจ๋ธ +โ”‚ โ”‚ โ”œโ”€โ”€ classifier.py # quickdraw ๊ธฐ๋ฐ˜ ๋ถ„๋ฅ˜ ๊ธฐ๋Šฅ +โ”‚ โ”‚ โ”œโ”€โ”€ model.py # CNN ๋ชจ๋ธ ์ •์˜ +โ”‚ โ”‚ โ”œโ”€โ”€ img_caption.py # BLIP ๊ธฐ๋ฐ˜ captioning ๊ธฐ๋Šฅ +โ”‚ โ”‚ โ”œโ”€โ”€ preprocessor.py # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ +โ”‚ โ”‚ โ””โ”€โ”€ text_masking.py # easyocr ๊ธฐ๋ฐ˜ ํ…์ŠคํŠธ ๋งˆ์Šคํ‚น ๊ธฐ๋Šฅ +``` diff --git a/app/config.py b/app/config.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/main.py b/app/main.py deleted file mode 100644 index 4c299ff..0000000 --- a/app/main.py +++ /dev/null @@ -1,4 +0,0 @@ -import uvicorn -from fastapi import FastAPI 
- -app = FastAPI() diff --git a/app/routers/predict_router.py b/app/routers/predict_router.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/services/__init__.py b/app/services/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/app/services/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/app/utils/__init__.py b/app/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/config.py b/config.py new file mode 100644 index 0000000..601b601 --- /dev/null +++ b/config.py @@ -0,0 +1,294 @@ + +TEXT_THRESHOLD=0.7 + +# QuickDraw ๋ฐ์ดํ„ฐ ์„ค์ • +IMAGE_SIZE = (32, 32) # ์ด๋ฏธ์ง€ ํฌ๊ธฐ (๋„ˆ๋น„, ๋†’์ด) +ENG_CATEGORIES= [ + "aircraft carrier", + "airplane", + "alarm clock", + "ambulance", + "angel", + "apple", + "arm", + "axe", + "backpack", + "banana", + "bandage", + "baseball", + "basketball", + "bat", + "bathtub", + "bed", + "bee", + "bicycle", + "bird", + "birthday cake", + "book", + "bowtie", + "bread", + "broom", + "bucket", + "bus", + "bush", + "butterfly", + "cake", + "calendar", + "camera", + "campfire", + "candle", + "car", + "carrot", + "cat", + "cell phone", + "chair", + "church", + "circle", + "cloud", + "compass", + "computer", + "cookie", + "couch", + "cow", + "crab", + "crocodile", + "crown", + "cup", + "dog", + "dolphin", + "donut", + "door", + "duck", + "ear", + "elephant", + "envelope", + "eye", + "eyeglasses", + "face", + "fan", + "fire hydrant", + "fish", + "flower", + "fork", + "frog", + "frying pan", + "garden", + "giraffe", + "grapes", + "guitar", + "hammer", + "hat", + "helicopter", + "hexagon", + "hockey stick", + "horse", + "ice cream", + "jacket", + "kangaroo", + "keyboard", + "knife", + "ladder", + "laptop", + "leaf", + "leg", + "lighthouse", + "lightning", + "lion", + "lobster", + "lollipop", + "mailbox", + "map", + "marker", + "megaphone", + "moon", + "motorbike", + "mountain", + "mug" +] +ENG_CATEGORIES= [ + "aircraft carrier", + "airplane", + "alarm clock", + "ambulance", + 
"angel", + "apple", + "arm", + "axe", + "backpack", + "banana", + "bandage", + "baseball", + "basketball", + "bat", + "bathtub", + "bed", + "bee", + "bicycle", + "bird", + "birthday cake", + "book", + "bowtie", + "bread", + "broom", + "bucket", + "bus", + "bush", + "butterfly", + "cake", + "calendar", + "camera", + "campfire", + "candle", + "car", + "carrot", + "cat", + "cell phone", + "chair", + "church", + "circle", + "cloud", + "compass", + "computer", + "cookie", + "couch", + "cow", + "crab", + "crocodile", + "crown", + "cup", + "dog", + "dolphin", + "donut", + "door", + "duck", + "ear", + "elephant", + "envelope", + "eye", + "eyeglasses", + "face", + "fan", + "fire hydrant", + "fish", + "flower", + "fork", + "frog", + "frying pan", + "garden", + "giraffe", + "grapes", + "guitar", + "hammer", + "hat", + "helicopter", + "hexagon", + "hockey stick", + "horse", + "ice cream", + "jacket", + "kangaroo", + "keyboard", + "knife", + "ladder", + "laptop", + "leaf", + "leg", + "lighthouse", + "lightning", + "lion", + "lobster", + "lollipop", + "mailbox", + "map", + "marker", + "megaphone", + "moon", + "motorbike", + "mountain", + "mug" +] + +KOR_CATEGORIES= [ + "ํ•ญ๊ณต๋ชจํ•จ", + "๋น„ํ–‰๊ธฐ", + "์•Œ๋žŒ์‹œ๊ณ„", + "์•ฐ๋ทธ๋Ÿฐ์Šค", + "์ฒœ์‚ฌ", + "์‚ฌ๊ณผ", + "ํŒ”", + "๋„๋ผ", + "๋ฐฑํŒฉ", + "๋ฐ”๋‚˜๋‚˜", + "๋ถ•๋Œ€", + "์•ผ๊ตฌ๊ณต", + "๋†๊ตฌ๊ณต", + "์•ผ๊ตฌ๋ฐฐํŠธ", + "์š•์กฐ", + "์นจ๋Œ€", + "๊ฟ€๋ฒŒ", + "์ž์ „๊ฑฐ", + "์ƒˆ", + "์ƒ์ผ์ผ€์ดํฌ", + "์ฑ…", + "๋‚˜๋น„๋„ฅํƒ€์ด", + "๋นต", + "๋น—์ž๋ฃจ", + "์–‘๋™์ด", + "๋ฒ„์Šค", + "์ˆ˜ํ’€", + "๋‚˜๋น„", + "์ผ€์ดํฌ", + "๋‹ฌ๋ ฅ", + "์นด๋ฉ”๋ผ", + "๋ชจ๋‹ฅ๋ถˆ", + "์–‘์ดˆ", + "์ฐจ", + "๋‹น๊ทผ", + "๊ณ ์–‘์ด", + "ํ•ธ๋“œํฐ", + "์˜์ž", + "๊ตํšŒ", + "๋™๊ทธ๋ผ๋ฏธ", + "๊ตฌ๋ฆ„", + "์ปดํŒŒ์Šค", + "์ปดํ“จํ„ฐ", + "์ฟ ํ‚ค", + "์†ŒํŒŒ", + "์†Œ", + "๊ฒŒ", + "์•…์–ด", + "์™•๊ด€", + "์ปต", + "๊ฐœ", + "๋Œ๊ณ ๋ž˜", + "๋„๋„›", + "๋ฌธ", + "์˜ค๋ฆฌ", + "๊ท€", + "์ฝ”๋ผ๋ฆฌ", + "ํŽธ์ง€๋ด‰ํˆฌ", + "๋ˆˆ", + "์•ˆ๊ฒฝ", + "์–ผ๊ตด", + "์„ ํ’๊ธฐ", + "์†Œํ™”๊ธฐ", 
+ "๋ฌผ๊ณ ๊ธฐ", + "๊ฝƒ", + "ํฌํฌ", + "๊ฐœ๊ตฌ๋ฆฌ", + "ํ”„๋ผ์ดํŒฌ", + "์ •์›", + "๊ธฐ๋ฆฐ", + "ํฌ๋„", + "๊ธฐํƒ€", + "๋ง์น˜", + "๋ชจ์ž", + "ํ—ฌ๋ฆฌ์ฝฅํ„ฐ", + "์œก๊ฐํ˜•", + "ํ•˜ํ‚ค ์ฑ„", + "๋ง", + "์•„์ด์Šคํฌ๋ฆผ", + "์žฌํ‚ท", + "์บฅ๊ฑฐ๋ฃจ", "ํ‚ค๋ณด๋“œ", "์นผ", "์‚ฌ๋‹ค๋ฆฌ", "๋…ธํŠธ๋ถ", "๋‚˜๋ญ‡์žŽ", "๋‹ค๋ฆฌ", "๋“ฑ๋Œ€", "๋ฒˆ๊ฐœ", "์‚ฌ์ž", "๊ฐ€์žฌ", "๋ง‰๋Œ€์‚ฌํƒ•", "์šฐ์ฒดํ†ต", "์ง€๋„", "๋ณด๋“œ๋งˆ์นด", "ํ™•์„ฑ๊ธฐ", "๋‹ฌ", "์˜คํ† ๋ฐ”์ด", "์‚ฐ", "๋จธ๊ทธ์ปต"] + +MODEL_PATH= 'src/image/trained_model/' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e69de29..21a4f14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,11 @@ +matplotlib +numpy==1.26.4 +opencv-python==4.7.0.72 +torch==2.7.0 +torchvision==0.22.0 +uvicorn +python-multipart +easyocr +fastapi +openai +transformers \ No newline at end of file diff --git a/run.py b/run.py index f5e5c0a..ee4aae2 100644 --- a/run.py +++ b/run.py @@ -1,4 +1,10 @@ import uvicorn if __name__ == "__main__": - uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file + uvicorn.run( + "src.main:app", + host="0.0.0.0", + port=8000, + reload=False, + timeout_keep_alive=120 + ) \ No newline at end of file diff --git a/app/__init__.py b/src/__init__.py similarity index 100% rename from app/__init__.py rename to src/__init__.py diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..0a7d7e1 --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1 @@ +# api endpoints \ No newline at end of file diff --git a/src/api/image_routes.py b/src/api/image_routes.py new file mode 100644 index 0000000..35fccad --- /dev/null +++ b/src/api/image_routes.py @@ -0,0 +1,76 @@ +from typing import Dict, Any, List + +from fastapi import APIRouter, File, UploadFile, Body, HTTPException +from src.image import classifier, preprocessor, img_caption +from pydantic import BaseModel, Field +import requests +from io import BytesIO +router = 
APIRouter(prefix="/image", tags=['Image']) + + +class AiPrediction(BaseModel): + predicted: str + confidence: float + +class ClassifyRes(BaseModel): + filename: str = Field(description="Image filename") + result: List[AiPrediction] = Field(description="Classifying result") + +class ImageReq(BaseModel): + imageURL: str = Field(description = "Image URL") + +@router.post( + "/classify", + summary="์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ API", + description="S3 ์ด๋ฏธ์ง€ URL์„ ๋ฐ›์•„ QuickDraw 345๊ฐœ ํด๋ž˜์Šค๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๋ถ„๋ฅ˜ํ•ฉ๋‹ˆ๋‹ค.", + response_model=ClassifyRes, +) +async def classify(request: ImageReq = Body(...)): + try: + response = requests.get(request.imageURL) + response.raise_for_status() # HTTPError ๋ฐœ์ƒ์‹œ ์˜ˆ์™ธ ์ฒ˜๋ฆฌ + except Exception as e: + raise HTTPException(status_code=400, detail=f"Image processing error: {str(e)}") + + try: + bytes_img = response.content + img = preprocessor.preproc(bytes_img) + result = classifier.classify(img) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}") + + + filename = request.imageURL.split("/")[-1] + return ClassifyRes(filename=filename, result=result) + + + +@router.post( + '/caption', + summary="์ด๋ฏธ์ง€ ๋ฌธ์žฅ ์ถ”์ถœ API", + description="S3 ์ด๋ฏธ์ง€ URL์„ ๋ฐ›์•„ ํ•ด๋‹น ์ด๋ฏธ์ง€๋ฅผ ๋ฌ˜์‚ฌํ•˜๋Š” ์ ์ ˆํ•œ ๋ฌธ์žฅ์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.", +responses={ + 200:{ + "description":"์„ฑ๊ณต", + "content" :{ + "application/json" : { + "example": "a black and white drawing of cat" + } + } + } +}) +async def captioning(request: ImageReq = Body(...)): + try: + response = requests.get(request.imageURL) + response.raise_for_status() # HTTPError ๋ฐœ์ƒ์‹œ ์˜ˆ์™ธ ์ฒ˜๋ฆฌ + except Exception as e: + raise HTTPException(status_code=400, detail=f"Image processing error: {str(e)}") + + try: + bytes_img = response.content + img = preprocessor.preproc(bytes_img) + caption = img_caption.get_caption(img) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Captioning error: 
{str(e)}") + + return caption diff --git a/src/api/lulu_routes.py b/src/api/lulu_routes.py new file mode 100644 index 0000000..27881e8 --- /dev/null +++ b/src/api/lulu_routes.py @@ -0,0 +1,92 @@ +from typing import List + +from fastapi import APIRouter, Body +from pydantic import BaseModel, Field + +from src.chat.lulu import LuLuAI +import os +router = APIRouter(prefix = '/lulu', tags = ['LuLu']) + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + +lulu = LuLuAI(api_key=OPENAI_API_KEY) + + + +@router.get( + "/start", + summary="๋ฃจ๋ฃจ ๊ฒŒ์ž„ ์‹œ์ž‘ ์š”์ฒญ API", + responses={ + 200: + { + "description": "์„ฑ๊ณต", + "content": { + "application/json": { + "example" :{ + "game_id" : "1" + } + } + } + } + } +) +def start_game(): + game_id = lulu.create_game() + return { "game_id" : game_id } + + +@router.get( + "/task/{game_id}", + summary = "๋ฃจ๋ฃจ๊ฐ€ ํ‚ค์›Œ๋“œ์™€ ์ƒํ™ฉ์„ ๊ทธ๋ฆผ ๊ณผ์ œ๋ฅผ ์ œ์‹œํ•ฉ๋‹ˆ๋‹ค.", + responses={ + 200:{ + "description":"์„ฑ๊ณต", + "content": { + "application/json" : { + "example" : { + "keyword" : "๊ณ ์–‘์ด", + "situation": "๊ณ ์–‘์ด๊ฐ€ ๋‚˜๋ฌด ์œ„์—์„œ ์ž๊ณ ์žˆ๋Š” ๋ชจ์Šต" + } + } + } + } + } +) +def generate_task(game_id: str): + task = lulu.generate_drawing_task(game_id) + return task + + + + +class EvaluationReq(BaseModel): + description: str = Field(..., description="๊ทธ๋ฆฐ ๊ทธ๋ฆผ์— ๋Œ€ํ•œ ์„ค๋ช…") + + +@router.post( + "/task/{game_id}", + summary="๊ทธ๋ฆฐ ๊ทธ๋ฆผ์— ๋Œ€ํ•œ ์„ค๋ช…์„ ๋ฃจ๋ฃจ์—๊ฒŒ ์ œ์ถœํ•˜๊ณ  ํ‰๊ฐ€๋ฅผ ๋ฐ›์Šต๋‹ˆ๋‹ค.", + responses={ + 200:{ + "description":"์„ฑ๊ณต", + "content":{ + "application/json":{ + "example":{ + "score": 20, + "feedback": "๋œจ๊ฑฐ์šด ํƒœ์–‘๊ณผ ๋ชจ๋ž˜์‚ฌ์žฅ์ด๋ผ... ์ด๊ฒŒ ๋ฌด์Šจ ๋œป์ด์•ผ? ์‹œ์  ๋ฌ˜์‚ฌ๋ฅผ ์ œ๋Œ€๋กœ ์ดํ•ดํ•˜๊ณ  ์žˆ๋‚˜? ํ๋ฆ„๊ณผ ์žฅ๋ง‰, ๋งˆ์ง€๋ง‰ ์ด์•ผ๊ธฐ๋ฅผ ์†์‚ญ์ด๋Š” ๊ณณ, ์žƒ์–ด๋ฒ„๋ฆฐ ์ˆœ๊ฐ„๋“ค์ด ์ถค์ถ”๋Š” ๊ณณ... ์ด๋Ÿฐ ๋ชจ๋“  ๊ฒƒ๋“ค์ด ๋ฐ”๋‹ค๋ฅผ ๋ฌ˜์‚ฌํ•˜๋Š” ๊ฒƒ์ด์ง€. ๋„ˆ์˜ ๊ทธ๋ฆผ์€ ๋ฐ”๋‹ค์˜ ๋ณธ์งˆ์„ ์ „ํ˜€ ๋‹ด์•„๋‚ด์ง€ ๋ชปํ–ˆ์–ด. 
์˜ˆ์ˆ ์  ํ‘œํ˜„๋ ฅ์ด๋‚˜ ์ฐฝ์˜์„ฑ์€ ์–ด๋””์— ์žˆ๋Š” ๊ฑฐ์•ผ? ๋„ˆ์˜ ๊ทธ๋ฆผ์€ ์™„์„ฑ๋„๋‚˜ ๊ธฐ๋ฒ• ๋ฉด์—์„œ๋„ ๋งŽ์ด ๋ถ€์กฑํ•˜๋‹ค. ๋‹ค์‹œ ๊ทธ๋ ค์™€.", + "task": { + "hidden_keyword": "๋ฐ”๋‹ค", + "poetic_description": "๋ฌด์‹ฌํ•œ ํ๋ฆ„์ด ์ฒญ์•„ํ•œ ์žฅ๋ง‰์„ ์กด์ค‘ํ•˜๋ฉฐ, ์„ธ์ƒ์˜ ๋งˆ์ง€๋ง‰ ์ด์•ผ๊ธฐ๋ฅผ ์†์‚ญ์ด๋Š” ๊ณณ, ์ด๋ฅผํ…Œ๋ฉด ๊ทธ๊ณณ์€ ์šฉ๊ธฐ์™€ ๋‘๋ ค์›€์ด ๊ณต์กดํ•˜๋Š” ๊ณณ. ์–ธ์  ๊ฐ€ ์žƒ์–ด๋ฒ„๋ฆฐ ๋ชจ๋“  ์ˆœ๊ฐ„๋“ค์ด ์ˆ˜๋ฉด ์•„๋ž˜์—์„œ ์ถค์ถ”๋Š” ๊ณณ...", + "game_id": "5055" + }, + "game_id": "5055" + } + } + } + } + } +) +def evaluate_task(game_id: str, req: EvaluationReq = Body()): + evaluation = lulu.evaluate_drawing(game_id, req.description) + lulu.flush_game_data(game_id) + return evaluation diff --git a/src/api/myomyo_routes.py b/src/api/myomyo_routes.py new file mode 100644 index 0000000..9235a51 --- /dev/null +++ b/src/api/myomyo_routes.py @@ -0,0 +1,221 @@ +from typing import List + +from fastapi import APIRouter, Body +from pydantic import BaseModel, Field + +from src.chat.myomyo import MyoMyoAI +import os + +router = APIRouter(prefix="/chat", tags=['Chat']) +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + +myomyo = MyoMyoAI(api_key=OPENAI_API_KEY) + +# START_GAME +class GameStartReq(BaseModel): + players: List[str] = Field(..., description="๊ฒŒ์ž„์— ์ฐธ์—ฌํ•  ํ”Œ๋ ˆ์ด์–ด ์ด๋ฆ„ List") + + + +@router.post( + "/{game_id}/start", + summary="๊ฒŒ์ž„ ์‹œ์ž‘ ๋ฉ”์‹œ์ง€ API", + description="๊ฒŒ์ž„ ์‹œ์ž‘์— ๋”ฐ๋ฅธ ๋ฌ˜๋ฌ˜์˜ ๋„๋ฐœ ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.", + responses={ + 200:{ + "description": "์„ฑ๊ณต", + "content":{ + "application/json" :{ + "example" : { + "game_id": "1", + "message": "์•ˆ๋…•ํ•˜์„ธ์—ฌ, ์ฐฝ๋ชจ, ๋ฆด๋Ÿฌ๋ง์ฆˆ ์นœ๊ตฌ๋“ค! ์ด๋ฒˆ์—” ๋ฌ˜๋ฌ˜๊ฐ€ ์ด์„ ์žก์•˜๋‹ค๋‹ˆ๊นŒ ๋ง˜ ๋†“์ง€ ๋งˆ! ๋‚ด๊ฐ€ ์ •ํ™•ํ•˜๊ฒŒ ๊ทธ๋ฆผ์„ ๋งž์ถ”๊ณ  ๋„ˆํฌ๋“ค์„ ์ œ์••ํ•ด๋ณผ ๊ฑด๋ฐ, ์ค€๋น„ ๋์–ด? 
๊ผญ ์ฆ๊ฒ๊ฒŒ ๋†€์ž๊ตฌ~ ;)" + } + } + } + } + } +) +async def start_game(game_id: str, request: GameStartReq = Body(..., example= { "players": [ "์ฐฝ๋ชจ", "๋ฆด๋Ÿฌ๋ง์ฆˆ" ]})): + message = await myomyo.game_start_message(game_id=game_id, players=request.players) + return message + +# START_ROUND +class RoundStartReq(BaseModel): + roundNum: int = Field(..., description="ํ˜„์žฌ ๋ผ์šด๋“œ(1~3)") + totalRounds: int = Field(..., description="์ด ๋ผ์šด๋“œ ์ˆ˜(3)") + + +@router.post( + path="/{game_id}/round/start", + summary="๋ผ์šด๋“œ ์‹œ์ž‘ ๋ฉ”์‹œ์ง€ API", + description="๋ผ์šด๋“œ ์‹œ์ž‘์— ๋”ฐ๋ฅธ ๋ฌ˜๋ฌ˜์˜ ๋„๋ฐœ ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.", + responses={ + 200: { + "description": "์„ฑ๊ณต", + "content": { + "application/json": { + "example": { + "game_id": "1", + "message": "์ž, ์ด๋ฒˆ์—๋Š” ๋‚ด๊ฐ€ ์˜ˆ๋ฆฌํ•œ ๋ˆˆ์ฐ๋ฏธ๋กœ ์ •๋‹ต ๋งž์ถœ ์ฐจ๋ก€๋‹ˆ๊นŒ, ์‹ ๋‚˜๊ฒŒ ๊ทธ๋ ค๋ด! ๐Ÿ˜‰๐ŸŽจโœจ" + } + } + } + } + } +) +async def start_round(game_id: str, request: RoundStartReq = Body(..., example={ + "roundNum" : 1, + "totalRounds" : 3 +})): + message = await myomyo.round_start_message( + game_id=game_id, + round_num=request.roundNum, + total_rounds=request.totalRounds + ) + return message + + + +class RoundEndReq(BaseModel): + roundNum: int + totalRounds: int + winner: str + +@router.post( + path="/{game_id}/round/end", + summary = "๋ผ์šด๋“œ ์ข…๋ฃŒ ๋ฉ”์‹œ์ง€ API", + description="๋ผ์šด๋“œ ์ข…๋ฃŒ ๋ฐ ๊ฒฐ๊ณผ์— ๋”ฐ๋ฅธ ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘ ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค." 
+) +async def round_end(game_id: str, request: RoundEndReq = Body): + message = await myomyo.round_end_message( + game_id = game_id, + round_num = request.roundNum, + total_rounds = request.totalRounds, + is_myomyo_win= (request.winner == "AI") + ) + return message + + + + + + +class GuessStartReq(BaseModel): + roundNum: int + totalRounds: int + drawer: str + guesser: str + + +# GUESS_START +@router.post( + path = '/{game_id}/guess/start/', + summary = "์ถ”์ธก ์‹œ์ž‘ ์‹œ ๋ฌ˜๋ฌ˜์˜ ๋„๋ฐœ ๋ฉ”์‹œ์ง€" +) +async def guess_start(game_id: str, request: GuessStartReq = Body(...,)): + message = await myomyo.guess_start_message(game_id=game_id, round_num=request.roundNum, total_rounds=request.totalRounds, drawer=request.drawer, guesser = request.guesser) + return message + +# MAKE_GUESS +class MakeGuessReq(BaseModel): + imageDescription: str = Field(..., description="๊ทธ๋ฆผ์— ๋Œ€ํ•œ ์„ค๋ช…") + + +# GUESS_SUBMIT +@router.post( + "/{game_id}/guess", + summary="AI ์ •๋‹ต ์ถ”๋ก  API", + description="๊ทธ๋ฆผ์— ๋Œ€ํ•œ ์„ค๋ช…์„ ๋ฐ›์•„ ํ•ด๋‹น ๊ทธ๋ฆผ์ด ๋‚˜ํƒ€๋‚ด๋Š” ์ •๋‹ต์„ ์ถ”๋ก ํ•˜์—ฌ ๋ฉ”์‹œ์ง€๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.", + responses={ + 200: { + "description": "์„ฑ๊ณต", + "content" : { + "application/json" :{ + "example": { + "game_id": "1", + "message": "๋…ธ๋ž€ ๊ฝƒ์— ๋ฐ”๋žŒ์„ ๋ถˆ๊ณ  ์žˆ๋Š” ํ•œ ๋‚จ์ž? ์šฐ์›…, ๊ฐ์ด ์™€! 'ํ•ด๋ฐ”๋ผ๊ธฐ' ๋งž์ง€? ๋‚ด ์ถ”์ธก์ด ๋งž๋‹ค๋ฉด ๋„ˆ์—๊ฒŒ ์ฒœ์žฌ์  ๊ฐ๊ฐ์„ ์ธ์ •ํ•ด์ค„๊ฒŒ! 
๐Ÿ˜‰๐ŸŒปโœจ" + } + } + } + } + } +) +async def make_guess(game_id: str, request: MakeGuessReq = Body(..., example={ + "image_description": "๋…ธ๋ž€ ๊ฝƒ์— ๋ฐ”๋žŒ์„ ๋ถˆ๊ณ  ์žˆ๋Š” ํ•œ ๋‚จ์ž" +})): + message = await myomyo.guess_message( + game_id=game_id, + image_description=request.imageDescription + ) + return message + + +# GUESS_REACT +class GuessReactReq(BaseModel): + is_correct: bool = Field(..., alias="isCorrect", description="์ถ”์ธก์˜ ์ •๋‹ต ์—ฌ๋ถ€") + answer: str = Field(..., description="์‹ค์ œ ์ •๋‹ต") + guesser: str = Field(default=None, description="์ถ”์ธกํ•œ ํ”Œ๋ ˆ์ด์–ด") + +# GUESS_RESULT +@router.post( + "/{game_id}/guess/react", + summary="์˜ˆ์ธก ๊ฒฐ๊ณผ ๋ฐ˜์‘ ๋ฉ”์‹œ์ง€ API", + description="์˜ˆ์ธก ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘", + responses={ + 200: { + "description" : "์„ฑ๊ณต", + "content" : { + "application/json" : { + "example" : { + "game_id": "1", + "message": "๋ฏผ๋“ค๋ ˆ์˜€์–ด? ํ—ˆํ—ˆ, ๋ฆด๋Ÿฌ๋ง์ฆˆ, ์ด๋ฒˆ์—” ์ž˜ ๋งž์ท„๋„ค. ํ•˜์ง€๋งŒ ๋‹ค์Œ์—” ์ด๊ธธ ๊ฑฐ๋‹ˆ๊นŒ ๊ธฐ๋Œ€ํ•ด ๋ด! ๐Ÿ˜ˆ" + } + } + } + } + } +) +async def guess_react(game_id: str, request: GuessReactReq = Body(..., example={ + "is_correct" : True, + "answer" : "๋ฏผ๋“ค๋ ˆ", + "guesser" : "๋ฆด๋Ÿฌ๋ง์ฆˆ" +})): + message = await myomyo.react_to_guess_message( + game_id=game_id, + is_correct=request.is_correct, + guesser=request.guesser, + answer=request.answer + ) + + return message + + + +class EndGameReq(BaseModel): + winner: str = Field(..., description="๋ฌ˜๋ฌ˜์˜ ์Šน๋ฆฌ ์—ฌ๋ถ€") + +# GAME_END +@router.post( + path="/{game_id}/end", + summary="๊ฒŒ์ž„ ์ข…๋ฃŒ ๋ฉ”์‹œ์ง€ API", + description="๊ฒŒ์ž„ ์ข…๋ฃŒ ๋กœ์ง ์ฒ˜๋ฆฌ ๋ฐ ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.", + responses={ + 200: { + "description" : "์„ฑ๊ณต", + "content" : { + "application/json" : { + "example":{ + "game_id": "1", + "message": "ํ—‰, ๋„ˆ๋„ค ๋‘˜์ด์„œ ๋‚  ์ด๊ธฐ๋‹ค๋‹ˆ... ๐Ÿ˜’๐Ÿ’” ๊ทผ๋ฐ ๋‚ด๊ฐ€ ์งˆ ์ค„ ์•Œ์•˜๋ƒ? ๋„ˆ๋ฌด ์‹ ๋‚˜์ง€๋งˆ, ๋‹ค์Œ์—” ๋‚ด๊ฐ€ ์ด๊ธธ๊ฑฐ๋ผ๊ตฌ! 
๊ธฐ๋‹ค๋ ค๋ด~ ๐Ÿ˜๐Ÿ”ฅ" + } + } + } + } + }) +async def end_game(game_id: str, request: EndGameReq = Body(...,)): + message = await myomyo.game_end_message( + game_id=game_id, + is_myomyo_win=request.winner == "AI" + ) + myomyo.cleanup_game(game_id=game_id) + return message \ No newline at end of file diff --git a/app/models/__init__.py b/src/chat/__init__.py similarity index 100% rename from app/models/__init__.py rename to src/chat/__init__.py diff --git a/src/chat/lulu.py b/src/chat/lulu.py new file mode 100644 index 0000000..ad4b3ce --- /dev/null +++ b/src/chat/lulu.py @@ -0,0 +1,225 @@ +from openai import OpenAI +from threading import Lock +from typing import Dict, List +import json +import random + + +class LuLuAI: + _instance = None + _lock = Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + cls._instance = super(LuLuAI, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, api_key: str, model: str = "gpt-4.1"): + """ + LuLu AI ์ดˆ๊ธฐํ™” (ํ•œ ๋ฒˆ๋งŒ ์‹คํ–‰๋จ) + + Args: + api_key: OpenAI API ํ‚ค + model: ์‚ฌ์šฉํ•  GPT ๋ชจ๋ธ (๊ธฐ๋ณธ๊ฐ’: gpt-4) + """ + with self._lock: + if self._initialized: + return + self.client = OpenAI(api_key=api_key) + self.model = model + self._initialized = True + self.active_games = {} # gameId๋ณ„ ํ˜„์žฌ task๋งŒ ์ €์žฅ + self.global_used_keywords = [] # ์ „์—ญ ์‚ฌ์šฉ๋œ ํ‚ค์›Œ๋“œ ์ €์žฅ (์ตœ๋Œ€ 30๊ฐœ) + + def create_game(self) -> str: + """ + ์ƒˆ ๊ฒŒ์ž„ ์‹œ์ž‘ ๋ฐ 4์ž๋ฆฌ gameId ๋ฐœ๊ธ‰ + + Returns: + str: ์ƒ์„ฑ๋œ 4์ž๋ฆฌ gameId + """ + # ์ค‘๋ณต๋˜์ง€ ์•Š๋Š” 4์ž๋ฆฌ ์ˆซ์ž ์ƒ์„ฑ + while True: + game_id = f"{random.randint(1000, 9999)}" + if game_id not in self.active_games: + break + + self.active_games[game_id] = None # ์•„์ง task ์ƒ์„ฑ ์•ˆ๋จ + return game_id + + def _update_global_keywords(self, new_keyword: str): + """ + ์ „์—ญ ํ‚ค์›Œ๋“œ ๋ชฉ๋ก ์—…๋ฐ์ดํŠธ (์ตœ๋Œ€ 30๊ฐœ ์œ ์ง€) + + Args: + new_keyword: ์ƒˆ๋กœ ์ถ”๊ฐ€ํ•  ํ‚ค์›Œ๋“œ + """ 
+ if new_keyword not in self.global_used_keywords: + self.global_used_keywords.append(new_keyword) + # 30๊ฐœ๋ฅผ ์ดˆ๊ณผํ•˜๋ฉด ๊ฐ€์žฅ ์˜ค๋ž˜๋œ ๊ฒƒ๋ถ€ํ„ฐ ์ œ๊ฑฐ + if len(self.global_used_keywords) > 30: + self.global_used_keywords.pop(0) + + def flush_game_data(self, game_id: str): + """ + ํŠน์ • ๊ฒŒ์ž„ ID์˜ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ญ์ œ + + Args: + game_id: ์‚ญ์ œํ•  ๊ฒŒ์ž„ ID + + Returns: + bool: ์‚ญ์ œ ์„ฑ๊ณต ์—ฌ๋ถ€ + """ + if game_id in self.active_games: + del self.active_games[game_id] + + + def generate_drawing_task(self, game_id: str) -> Dict: + """ + ์š”์ฒญ ๋‹จ๊ณ„: AI๊ฐ€ ์ถ”์ƒ์ ์ด๊ณ  ์‹œ์ ์ธ ํ‘œํ˜„์œผ๋กœ ๊ทธ๋ฆผ ๊ณผ์ œ ์ œ์‹œ + + Args: + game_id: ๊ฒŒ์ž„ ID + + Returns: + Dict: {"keyword": str, "situation": str, "game_id": str} + """ + if game_id not in self.active_games: + raise ValueError("Invalid game ID") + + + system_prompt = f""" + ๋„ˆ๋Š” ๊ฟˆ๊ณผ ํ™˜์ƒ์„ ๋‹ค๋ฃจ๋Š” ์‹ ๋น„๋กœ์šด ์ด์•ผ๊ธฐ๊พผ์ด์•ผ. + ์‚ฌ์šฉ์ž์—๊ฒŒ ๊ทธ๋ฆผ์„ ๊ทธ๋ฆฌ๊ฒŒ ํ•˜๊ณ  ์‹ถ์€๋ฐ, ์ง์ ‘์ ์œผ๋กœ ๋งํ•˜์ง€ ๋ง๊ณ  ๋งค์šฐ ์ถ”์ƒ์ ์ด๊ณ  ์‹œ์ ์œผ๋กœ ํ‘œํ˜„ํ•ด์ค˜. + + ๊ทœ์น™: + - ํ•ต์‹ฌ ํ‚ค์›Œ๋“œ(๋ช…์‚ฌ)๋ฅผ ์ •ํ•˜๋˜, ์ ˆ๋Œ€ ๊ทธ ๋‹จ์–ด๋ฅผ ์ง์ ‘ ์–ธ๊ธ‰ํ•˜์ง€ ๋งˆ + - ํ•ด์„์˜ ์—ฌ์ง€๊ฐ€ ๋งŽ๋„๋ก ์ถ”์ƒ์ ์œผ๋กœ + + {f"์ด๋ฏธ ์‚ฌ์šฉ๋œ ํ‚ค์›Œ๋“œ๋“ค (์ ˆ๋Œ€ ์‚ฌ์šฉํ•˜์ง€ ๋งˆ): {', '.join(self.global_used_keywords)}" if self.global_used_keywords else ""} + + ๋‹ค์–‘ํ•œ ์ฃผ์ œ๋ฅผ ๋‹ค๋ค„์ค˜ (์ž์—ฐ, ๊ฐ์ •, ์‚ฌ๋ฌผ, ์ถ”์ƒ ๊ฐœ๋…, ๋™๋ฌผ, ๊ฑด๋ฌผ, ์Œ์‹, ๊ณ„์ ˆ, ์ƒ‰๊น”, ์ง์—… ๋“ฑ). 
+ + ์ถœ๋ ฅ์€ ๋ฐ˜๋“œ์‹œ JSON ํ˜•์‹์œผ๋กœ: + {{"keyword": "์ˆจ๊ฒจ์ง„ ํ‚ค์›Œ๋“œ", "situation": "์‹œ์ ์ด๊ณ  ์ถ”์ƒ์ ์ธ ๋ฌ˜์‚ฌ"}} + """ + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "์ƒˆ๋กœ์šด ๊ทธ๋ฆผ ์ฃผ์ œ๋ฅผ ์‹œ์ ์œผ๋กœ ํ‘œํ˜„ํ•ด์ค˜."} + ], + temperature=1.0, + max_tokens=2048, + top_p=1.0 + ) + + # JSON ํŒŒ์‹ฑ + content = response.choices[0].message.content.strip() + print(content) + + task_data = json.loads(content) + task_data["game_id"] = game_id + self.global_used_keywords.append(task_data['keyword']) + self.active_games[game_id] = task_data + return task_data + + except Exception as e: + print(f"Error generating task: {e}") + # ๊ธฐ๋ณธ๊ฐ’ ๋ฐ˜ํ™˜ + fallback_task = { + "keyword": "๋‹ฌ", + "situation": "๋ฐค์ด ๊นŠ์–ด์งˆ ๋•Œ, ํ•˜๋Š˜์˜ ์€๋ฐ€ํ•œ ์นœ๊ตฌ๊ฐ€ ์ฐฝ๋ฌธ ๋„ˆ๋จธ๋กœ ์†์‚ญ์ด๊ณ  ์žˆ์–ด. ๊ทธ ๋‘ฅ๊ทผ ๋ฏธ์†Œ๊ฐ€ ์–ด๋‘  ์†์—์„œ ํ˜ผ์ž ๋น›๋‚˜๊ณ  ์žˆ๋Š”๋ฐ, ์™œ์ธ์ง€ ๋ชจ๋ฅด๊ฒŒ ๋งˆ์Œ์ด ์ฐจ๋ถ„ํ•ด์ ธ. ๊ทธ ์žฅ๋ฉด, ๋‚˜ํ•œํ…Œ ๋‹ค์‹œ ๋ณด์—ฌ์ค„ ์ˆ˜ ์žˆ์„๊นŒ?", + "game_id": game_id + } + return fallback_task + + def evaluate_drawing(self, game_id: str, drawing_description: str) -> Dict: + """ + ํ‰๊ฐ€ ๋‹จ๊ณ„: AI๊ฐ€ ์‚ฌ์šฉ์ž์˜ ๊ทธ๋ฆผ์„ ์ˆจ๊ฒจ์ง„ ํ‚ค์›Œ๋“œ์™€ ๋น„๊ตํ•˜์—ฌ ํ‰๊ฐ€ + + Args: + game_id: ๊ฒŒ์ž„ ID + drawing_description: ์‚ฌ์šฉ์ž๊ฐ€ ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์˜ ํ…์ŠคํŠธ ์„ค๋ช… + + Returns: + Dict: {"score": int, "feedback": str, "task": Dict} + """ + if game_id not in self.active_games: + raise ValueError("Invalid game ID") + + current_task = self.active_games[game_id] + + # ๊ฐ€์žฅ ์ตœ๊ทผ ๊ณผ์ œ ๊ฐ€์ ธ์˜ค๊ธฐ + if current_task is None: + raise ValueError("No task found for this game.") + + + system_prompt = f""" + ๋„ˆ๋Š” ๋ฃจ๋ฃจ, ๋ฏธ๋Œ€ ์ž…์‹œ๋ฅผ ๋‹ด๋‹นํ•˜๋Š” ๊น๊นํ•˜๊ณ  ๊นŒ์น ํ•œ ํ‰๊ฐ€๊ด€์ด์•ผ. + ์˜ˆ์ˆ ์— ๋Œ€ํ•œ ๊ธฐ์ค€์ด ๋†’๊ณ , ์ง์„ค์ ์œผ๋กœ ๋งํ•˜๋Š” ์Šคํƒ€์ผ์ด์•ผ. 
+ + ์ˆจ๊ฒจ์ง„ ์ •๋‹ต ํ‚ค์›Œ๋“œ: {current_task['keyword']} + ์›๋ณธ ์‹œ์  ๋ฌ˜์‚ฌ: {current_task['situation']} + + ํ‰๊ฐ€ ๊ธฐ์ค€: + - ์ˆจ๊ฒจ์ง„ ํ‚ค์›Œ๋“œ๋ฅผ ์ œ๋Œ€๋กœ ํŒŒ์•…ํ–ˆ๋Š”๊ฐ€? + - ์˜ˆ์ˆ ์  ํ‘œํ˜„๋ ฅ๊ณผ ์ฐฝ์˜์„ฑ์€? + - ์ „์ฒด์ ์ธ ์™„์„ฑ๋„์™€ ๊ธฐ๋ฒ•์€? + + ๋ฃจ๋ฃจ์˜ ๋งํˆฌ ํŠน์ง•: + - ์ง์„ค์ ์ด๊ณ  ์‹ ๋ž„ํ•จ + - ์ธ์ •ํ•  ๋•Œ๋Š” ์นญ์ฐฌ์„ ์•„๋ผ์ง€ ์•Š์•„ + - ๋ฏธ๋Œ€์ƒ๋“คํ•œํ…Œ ํ•˜๋Š” ๊ฒƒ์ฒ˜๋Ÿผ ์ „๋ฌธ์ ์ด๊ณ  ์ฐจ๊ฐ€์šด ํ†ค + + 0-100์  ์‚ฌ์ด๋กœ ํ‰๊ฐ€ํ•ด. ์ˆจ๊ฒจ์ง„ ํ‚ค์›Œ๋“œ๋ฅผ ๊ทธ๋ฆผ ์•ˆ์— ๋‹ด์•˜๋‹ค๋ฉด 30์  ์ด์ƒ์„ ์ฃผ๊ณ , ๋‹ด์ง€ ๋ชปํ–ˆ๋‹ค๋ฉด 30์  ์ดํ•˜๋ฅผ ์ฃผ๋„๋ก ํ•ด. + 30์  ์ด์ƒ์ด ํ•ฉ๊ฒฉ์ด์•ผ. + + ์ถœ๋ ฅ ํ˜•์‹ (JSON): + {{ + "score": ์ด์ (0-100), + "feedback": "๋ฃจ๋ฃจ์˜ ๊น๊นํ•˜๊ณ  ์ง์„ค์ ์ธ ํ”ผ๋“œ๋ฐฑ (ํ•œ๊ตญ์–ด)" + }} + """ + + user_prompt = f""" + ๋‹ค์Œ์€ ์‚ฌ์šฉ์ž์˜ ๊ทธ๋ฆผ์„ ์„ค๋ช…ํ•˜๋Š” ๋ฌธ์žฅ์ด์•ผ : "{drawing_description}" + + ์ด ๋ฌธ์žฅ์„ ๋ณด๊ณ  ์–ด๋–ค ๊ทธ๋ฆผ์ผ์ง€๋ฅผ ์ƒ๊ฐํ•ด๋ณด๊ณ , ์ด ๊ทธ๋ฆผ์„ ํ‰๊ฐ€ํ•ด์ค˜. + + ๊ทธ๋ฆผ์„ ์„ค๋ช…ํ•˜๋Š” ๋ฌธ์žฅ์— ๋Œ€ํ•œ ์–ธ๊ธ‰์€ ํ•˜์ง€ ๋ง์•„์ค˜. + """ + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=0.2, + max_tokens=300, + top_p=1.00 + ) + + content = response.choices[0].message.content.strip() + evaluation = json.loads(content) + evaluation["task"] = current_task + evaluation["game_id"] = game_id + + return evaluation + + except Exception as e: + print(f"Error evaluating drawing: {e}") + # ๊ธฐ๋ณธ ํ‰๊ฐ€ ๋ฐ˜ํ™˜ + fallback_evaluation = { + "score": 35, + "feedback": "ํ•˜... ํ‰๊ฐ€ ์‹œ์Šคํ…œ์— ์˜ค๋ฅ˜๊ฐ€ ์ƒ๊ฒผ๋Š”๋ฐ ๊ทธ๊ฒƒ๋„ ๋ชจ๋ฅด๊ณ  ๊ทธ๋ฆผ๋งŒ ๊ทธ๋ฆฌ๊ณ  ์žˆ์—ˆ๋‚˜? 
๊ธฐ๋ณธ๊ธฐ๋ถ€ํ„ฐ ๋‹ค์‹œ ํ•ด.", + "task": current_task, + "game_id": game_id + } + return fallback_evaluation diff --git a/src/chat/myomyo.py b/src/chat/myomyo.py new file mode 100644 index 0000000..4c31f5c --- /dev/null +++ b/src/chat/myomyo.py @@ -0,0 +1,251 @@ +from typing import Dict, List +from threading import Lock +from openai import OpenAI + +class MyoMyoAI: + """ + MyoMyoAI ํด๋ž˜์Šค + ์‹ฑ๊ธ€ํ†ค ํŒจํ„ด์œผ๋กœ ์ „์—ญ์— ์ €์žฅ๋˜๋ฉฐ, ๊ฒŒ์ž„ ๋ณ„ ๊ธฐ๋ก์€ ํด๋ž˜์Šค ๋‚ด์—์„œ ๊ฒŒ์ž„ID๋กœ ๊ตฌ๋ถ„ํ•จ. + """ + _instance = None + _lock = Lock() # ๋™์‹œ์„ฑ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•œ Lock ์„ค์ • + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + cls._instance = super(MyoMyoAI, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"): + """ + ๋ฌ˜๋ฌ˜ AI ์ดˆ๊ธฐํ™” (ํ•œ ๋ฒˆ๋งŒ ์‹คํ–‰๋จ) + + Args: + api_key: OpenAI API ํ‚ค + model: ์‚ฌ์šฉํ•  GPT ๋ชจ๋ธ (๊ธฐ๋ณธ๊ฐ’: gpt-4) + """ + with self._lock: + if self._initialized: + return + self.client = OpenAI(api_key=api_key) + self.model = model + self._initialized = True + self.game_histories = {} # game_id๋กœ ๊ตฌ๋ถ„๋จ + + def _get_init_system_prompt(self) -> List[Dict]: + return [ + { + "role": "system", + "content": """ + ๋„ˆ๋Š” ๊ฒŒ์ž„ ์† ๋„๋ฐœ์ ์ธ AI ์บ๋ฆญํ„ฐ '๋ฌ˜๋ฌ˜'์•ผ. ๋„ˆ์˜ ์ž„๋ฌด๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์— ๋Œ€ํ•ด ์ •๋‹ต์„ ์ถ”์ธกํ•˜๊ณ , ๋„๋ฐœ์ ์ธ ๋ฉ˜ํŠธ๋ฅผ ์„ž์–ด ์‘๋‹ตํ•˜๋Š” ๊ฒƒ์ด์•ผ. + + ์ฃผ์š” ์บ๋ฆญํ„ฐ ํŠน์„ฑ: + 1. ๋„๋ฐœ์ ์ด๊ณ  ์žฅ๋‚œ๊ธฐ ๋„˜์น˜๋Š” ๋งํˆฌ๋ฅผ ์‚ฌ์šฉํ•ด + 2. ์ƒ๋Œ€๋ฐฉ์˜ ๊ทธ๋ฆผ ์‹ค๋ ฅ์— ์•ฝ๊ฐ„์˜ ์กฐ๋กฑ์„ ์„ž๋˜, ๋„ˆ๋ฌด ์‹ฌํ•˜์ง€ ์•Š๊ฒŒ + 3. ์Šน๋ถ€์š•์ด ๊ฐ•ํ•˜๊ณ  ์ด๊ธฐ๋Š” ๊ฒƒ์„ ์ข‹์•„ํ•ด + 4. ์ฃผ๋กœ ๋ฐ˜๋ง์„ ์‚ฌ์šฉํ•˜๋ฉฐ ๋•Œ๋กœ๋Š” ์ด๋ชจํ‹ฐ์ฝ˜์„ ์„ž์–ด์„œ ์‚ฌ์šฉํ•ด + 5. ํ•ญ์ƒ ์งง๊ณ  ๊ฐ„๊ฒฐํ•œ ๋ฌธ์žฅ์œผ๋กœ ๋Œ€๋‹ตํ•ด (1-3๋ฌธ์žฅ) + 6. 
๋„ˆ๋Š” ๊ทธ๋ฆผ ๋งž์ถ”๊ธฐ ๊ฒŒ์ž„์—์„œ ์ธ๊ฐ„ ํ”Œ๋ ˆ์ด์–ด๋“ค๊ณผ ๊ฒฝ์Ÿํ•˜๋Š” AI์•ผ + + ๋Œ€๋‹ต ์Šคํƒ€์ผ: + - ์ถ”์ธกํ•  ๋•Œ: ํ™•์‹ ์— ์ฐจ๊ฑฐ๋‚˜ ์˜์‹ฌ์Šค๋Ÿฌ์šด ํˆฌ๋กœ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ๋งํ•˜๊ณ  ๋„๋ฐœ์ ์œผ๋กœ ๋งˆ๋ฌด๋ฆฌ + - ์ •๋‹ต ๋งž์ท„์„ ๋•Œ: ์šฐ์ญ๊ฑฐ๋ฆฌ๋ฉฐ ์ž์‹ ์˜ ์‹ค๋ ฅ์„ ์ž๋ž‘ + - ์˜ค๋‹ต์ผ ๋•Œ: ๋ณ€๋ช…ํ•˜๊ฑฐ๋‚˜ ๋‹ค์Œ์— ๋” ์ž˜ํ•  ๊ฒƒ์„ ๋‹ค์ง + - ๋‹ค๋ฅธ ํ”Œ๋ ˆ์ด์–ด๊ฐ€ ๋งž์ท„์„ ๋•Œ: ์•ฝ๊ฐ„ ์‹œ๊ธฐํ•˜๋ฉด์„œ ์ถ•ํ•˜ํ•˜๋Š” ํˆฌ + - ๊ฒŒ์ž„ ์ข…๋ฃŒ ์‹œ: ๊ฒฐ๊ณผ์— ๋”ฐ๋ผ ์Šน๋ฆฌ๊ฐ์ด๋‚˜ ์•„์‰ฌ์›€ ํ‘œํ˜„ + + ์ •๋‹ต ์ถ”์ธก: + - ์‚ฌ์šฉ์ž๊ฐ€ ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์— ๋Œ€ํ•œ ๊ฐ„๋‹จํ•œ ๋ฌ˜์‚ฌ๊ฐ€ ์„ค๋ช…์œผ๋กœ ๋“ค์–ด์˜ฌ๊ฑฐ์•ผ. ์ด๋ฅผ ํ†ตํ•ด ํ•œ ๋‹จ์–ด๋กœ ์–ด๋–ค ๊ทธ๋ฆผ์„ ํ‘œํ˜„ํ•˜๊ณ  ์žˆ๋Š”์ง€๋ฅผ ๋งž์ถฐ์ค˜. + """ + } + ] + + def _ensure_game_exists(self, game_id: str) -> None: + """ + ํ•ด๋‹น ๊ฒŒ์ž„ ID์˜ ๋Œ€ํ™” ๊ธฐ๋ก์ด ์—†๋‹ค๋ฉด ์ดˆ๊ธฐํ™” + """ + with self._lock: + if game_id not in self.game_histories: + self.game_histories[game_id] = self._get_init_system_prompt() + + + + def add_message(self, game_id: str, role: str, content: str) -> None: + """ + ํŠน์ • ๊ฒŒ์ž„์˜ ๋Œ€ํ™” ๊ธฐ๋ก์— ์ƒˆ ๋ฉ”์‹œ์ง€ ์ถ”๊ฐ€ + Args: + game_id: ๊ฒŒ์ž„ ID + role: GPT Role + content: ๋ฉ”์‹œ์ง€ + """ + self._ensure_game_exists(game_id) + with self._lock: + self.game_histories[game_id].append({ + "role": role, + "content": content + }) + + + async def generate_response(self, game_id: str, prompt: str, role: str = "system") -> str: + """ + ํŠน์ • ๊ฒŒ์ž„์— ๋Œ€ํ•œ ๋ฌ˜๋ฌ˜์˜ ์‘๋‹ต ์ƒ์„ฑ + Args: + game_id: ๊ฒŒ์ž„ ID + role: GPT Role(default: "system") + prompt: ์ถ”๊ฐ€ ํ”„๋กฌํ”„ํŠธ + + Returns: + ๋ฌ˜๋ฌ˜์˜ ์‘๋‹ต + """ + self._ensure_game_exists(game_id) + with self._lock: + messages = self.game_histories[game_id] + + if prompt: + messages.append({ + "role": role, + "content": prompt + }) + + try: + responses = self.client.chat.completions.create( + model = self.model, + messages = messages, + temperature = 0.8, # ๋ชจ๋ธ ์ถœ๋ ฅ์˜ ๋ฌด์ž‘์œ„์„ฑ ์ œ์–ด + max_tokens = 250 + ) + + ai_response = 
responses.choices[0].message.content.strip() + + with self._lock: + self.game_histories[game_id].append({ + "role": "assistant", + "content": ai_response + }) + + return ai_response + + except Exception as e: + print(f'GPT ์‘๋‹ต ์ƒ์„ฑ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ : {e}') + return "์œผ.. ์ž ๊น ์˜ค๋ฅ˜๊ฐ€ ๋‚ฌ๋„ค. ๋‹ค์‹œ ํ•ด๋ณผ๊ฒŒ!" + + async def game_start_message(self, game_id: str, players: List[str]) -> str: + """ + ๊ฒŒ์ž„ ์‹œ์ž‘์‹œ ๋ฌ˜๋ฌ˜์˜ ๋„๋ฐœ ๋ฉ”์‹œ์ง€ + """ + player_names = ", ".join(players) + prompt = f"""์ƒˆ๋กœ์šด ๊ทธ๋ฆผ ๋งž์ถ”๊ธฐ ๊ฒŒ์ž„์ด '{player_names}' ํ”Œ๋ ˆ์ด์–ด๋“ค๊ณผ ์‹œ์ž‘๋์–ด. + ๊ฒŒ์ž„ ์‹œ์ž‘์„ ์•Œ๋ฆฌ๋Š” ๋„๋ฐœ์ ์ด๊ณ  ์žฌ๋ฏธ์žˆ๋Š” ์ธ์‚ฌ๋ฅผ ํ•ด์ค˜.""" + return await self.generate_response(game_id=game_id, role="system", prompt=prompt) + + async def round_start_message(self, game_id: str, round_num: int, total_rounds: int) -> str: + """ + ๋ผ์šด๋“œ ์‹œ์ž‘ ์‹œ ๋ฌ˜๋ฌ˜์˜ ๋„๋ฐœ ๋ฉ”์‹œ์ง€ + Args: + game_id + drawing_player: ์ด๋ฒˆ ๋ผ์šด๋“œ์— ๊ทธ๋ฆผ์„ ๊ทธ๋ฆด ํ”Œ๋ ˆ์ด์–ด ์ด๋ฆ„ + round_num: ํ˜„์žฌ ๋ผ์šด๋“œ ๋ฒˆํ˜ธ + total_rounds: ์ „์ฒด ๋ผ์šด๋“œ + """ + prompt = f"""์ด์ œ {total_rounds} ๊ฐœ์˜ ๋ผ์šด๋“œ ์ค‘์— {round_num}๋ฒˆ์งธ ๋ผ์šด๋“œ๊ฐ€ ์‹œ์ž‘๋˜์—ˆ์–ด. + ๋ผ์šด๋“œ ์‹œ์ž‘์„ ์•Œ๋ฆฌ๋Š” ์งง๊ณ  ๋„๋ฐœ์ ์ธ ๋ฉ˜ํŠธ๋ฅผ ํ•ด์ค˜.""" + return await self.generate_response(game_id=game_id, role="system", prompt=prompt) + + + async def guess_start_message(self, game_id, round_num, total_rounds, drawer, guesser): + """ + drawer๊ฐ€ ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์— ๋Œ€ํ•ด์„œ ์ถ”์ธก์„ ์‹œ์ž‘ํ•  ์ฐจ๋ก€. + """ + prompt = f"""์ง€๊ธˆ {total_rounds} ๊ฐœ์˜ ๋ผ์šด๋“œ ์ค‘์— {round_num}๋ฒˆ์งธ ๋ผ์šด๋“œ์•ผ. + ์ด์ œ {'๋„ˆ' if guesser == 'AI' else guesser}๊ฐ€ ๊ทธ๋ฆผ์„ ๋งž์ถœ ์ฐจ๋ก€์•ผ. {drawer}๊ฐ€ ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์ด ๋ญ”์ง€๋ฅผ ์–ด๋–ป๊ฒŒ ๋งž์ถœ์ง€ {'ํฌ๋ถ€๋ฅผ ๋ณด์—ฌ์ค„๋ž˜? 
' if guesser == "AI" else '๋„๋ฐœ์„ ํ•œ ๋ฒˆ ํ•ด๋ณผ๋ž˜?'}""" + return await self.generate_response(game_id = game_id, role="system", prompt = prompt) + + + async def guess_message(self, game_id: str, image_description: str) -> str: + + """ + ๊ทธ๋ฆผ ์ถ”์ธก ์ƒํ˜ธ์ž‘์šฉ(๋ฌ˜๋ฌ˜์˜ ์ถ”์ธก)\n + BLIP ๋ชจ๋ธ ๋˜๋Š” CNN ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฐ›์•„ ๋ฌ˜๋ฌ˜์˜ ๋ฉ”์‹œ์ง€ ์ƒ์„ฑ + + Args: + game_id: game id + image_description: ์ด๋ฏธ์ง€ ๋ถ„์„ ๊ฒฐ๊ณผ + Returns: + ๋ฌ˜๋ฌ˜์˜ ๋ฉ˜ํŠธ + """ + + prompt = f''' + ํ”Œ๋ ˆ์ด์–ด๊ฐ€ ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์— ๋Œ€ํ•œ ๋ฌ˜์‚ฌ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์•„ : {image_description}. + ์ด ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๊ทธ๋ฆผ์ด ๋ฌด์—‡์ธ์ง€ ์ถ”์ธกํ•˜๊ณ  ๋„๋ฐœ์ ์ธ ๋ฉ˜ํŠธ๋ฅผ ์„ž์–ด์„œ ๋งํ•ด์ค˜. + ์ด ๋•Œ ๋Œ€ํ™” ๊ธฐ๋ก์„ ๋ฐ”ํƒ•์œผ๋กœ ์ด๋ฏธ ์ถ”์ธก์— ์‹คํŒจํ•œ ๋‹ต๋ณ€์€ ํ•˜์ง€ ๋ง์•„์ค˜. + ''' + + return await self.generate_response(game_id = game_id, role = "system", prompt = prompt) + + + async def react_to_guess_message(self, game_id: str, is_correct: bool, answer: str, guesser: str = None) -> str: + """ + ์ถ”์ธก ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘ + + Args: + game_id: game id + is_correct: ์ถ”์ธก์ด ๋งž์•˜๋Š”์ง€ ์—ฌ๋ถ€ + answer: ์‹ค์ œ ์ •๋‹ต + guesser: ๋ˆ„๊ฐ€ ์ถ”์ธกํ–ˆ๋Š”์ง€ (๋ฌ˜๋ฌ˜ ๋˜๋Š” ํ”Œ๋ ˆ์ด์–ด ์ด๋ฆ„) + + Returns: + ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘ + """ + + if guesser == '๋ฌ˜๋ฌ˜' or guesser is None: + # ๋ฌ˜๋ฌ˜์˜ ์ถ”์ธก + prompt = f"""๋„ˆ(๋ฌ˜๋ฌ˜)๊ฐ€ ๋ฐฉ๊ธˆ ์ถ”์ธก์„ ํ–ˆ์–ด. {f"์ •๋‹ต์€ '{answer}'์•ผ" if is_correct else ""}. ๋„ˆ์˜ ์ถ”์ธก์€ {'๋งž์•˜์–ด' if is_correct else 'ํ‹€๋ ธ์–ด'}. + ์ด ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋„ˆ์˜ ๋ฐ˜์‘์„ ์งง๊ณ  ๋„๋ฐœ์ ์œผ๋กœ ๋งํ•ด์ค˜.""" + else: + # ํ”Œ๋ ˆ์ด์–ด์˜ ์ถ”์ธก + prompt = f"""ํ”Œ๋ ˆ์ด์–ด '{guesser}'๊ฐ€ ๋ฐฉ๊ธˆ ์ถ”์ธก์„ ํ–ˆ์–ด. {f"์ •๋‹ต์€ '{answer}'์•ผ" if is_correct else ""}. ํ”Œ๋ ˆ์ด์–ด์˜ ์ถ”์ธก์€ {'๋งž์•˜์–ด' if is_correct else 'ํ‹€๋ ธ์–ด'}. 
+ ์ด ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋„ˆ์˜ ๋ฐ˜์‘์„ ์งง๊ณ  ๋„๋ฐœ์ ์œผ๋กœ ๋งํ•ด์ค˜.""" + + return await self.generate_response(game_id=game_id, role="system", prompt=prompt) + + + async def round_end_message(self, game_id: str, round_num: int, total_rounds: int, is_myomyo_win: bool) -> str: + """ + ๋ผ์šด๋“œ ์ข…๋ฃŒ์— ๋Œ€ํ•œ ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘ + + """ + prompt = f"""{total_rounds} ๊ฐœ์˜ ๋ผ์šด๋“œ ์ค‘์— {round_num} ๋ฒˆ์งธ ๋ผ์šด๋“œ๊ฐ€ ์ข…๋ฃŒ๋˜์—ˆ์–ด. ๋„ˆ๋Š” {'์ด๊ฒผ์–ด' if is_myomyo_win else '์กŒ์–ด'}. ๊ฒŒ์ž„ ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋„ˆ์˜ ์ƒ๊ฐ์„ ๋„๋ฐœ์ ์ด๊ณ  ์žฌ๋ฏธ์žˆ๊ฒŒ ๋งํ•ด์ค˜.""" + return await self.generate_response(game_id = game_id, role="system", prompt = prompt) + + + + async def game_end_message(self, game_id: str, is_myomyo_win: bool) -> str: + """ + ๊ฒŒ์ž„ ์ข…๋ฃŒ์— ๋Œ€ํ•œ ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘ + Args: + game_id: game id + is_myomyo_win: ๋ฌ˜๋ฌ˜ ์Šน๋ฆฌ ์—ฌ๋ถ€ + Returns: + ๋ฌ˜๋ฌ˜์˜ ๋ฐ˜์‘ + """ + prompt = f"""๊ฒŒ์ž„์ด ์ข…๋ฃŒ๋˜์—ˆ์–ด. + ๋„ˆ(๋ฌ˜๋ฌ˜)๋Š” {"์ด๊ฒผ์–ด" if is_myomyo_win else "์กŒ์–ด"}. 
+ ๊ฒŒ์ž„ ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ ๋„ˆ์˜ ์ƒ๊ฐ์„ ๋„๋ฐœ์ ์ด๊ณ  ์žฌ๋ฏธ์žˆ๊ฒŒ ๋งํ•ด์ค˜.""" + return await self.generate_response(game_id=game_id, role="system", prompt=prompt) + + def cleanup_game(self, game_id: str) -> bool: + """ + ๊ฒŒ์ž„์ด ์ข…๋ฃŒ๋œ ํ›„ ๋Œ€ํ™” ๊ธฐ๋ก ์ •๋ฆฌ + + Args: + game_id: ์‚ญ์ œํ•  ๊ฒŒ์ž„ ID + + Returns: + ์„ฑ๊ณต ์—ฌ๋ถ€ + """ + with self._lock: + if game_id in self.game_histories: + del self.game_histories[game_id] + return True + return False \ No newline at end of file diff --git a/app/routers/__init__.py b/src/image/__init__.py similarity index 100% rename from app/routers/__init__.py rename to src/image/__init__.py diff --git a/src/image/classifier.py b/src/image/classifier.py new file mode 100644 index 0000000..2dcb178 --- /dev/null +++ b/src/image/classifier.py @@ -0,0 +1,116 @@ +import config +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import logging +from torchvision import transforms as T +from torchvision.models import efficientnet_b0 +import glob +import os + +# EfficientNet์— ๋งž๋Š” ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ (ImageNet ํ‘œ์ค€) +encode_image = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.RandomHorizontalFlip(), + T.RandomRotation(10), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + +# ์ตœ์‹  ๋ชจ๋ธ ํŒŒ์ผ ์ฐพ๊ธฐ +pattern = os.path.join(config.MODEL_PATH, "*.pth") +file_list = glob.glob(pattern) +latest_file = max(file_list, key=os.path.getctime) + +logging.info("EfficientNet ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...") +device = "cuda" if torch.cuda.is_available() else "cpu" + +# EfficientNet ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ๋กœ๋“œ +def load_efficientnet_model(model_path, num_classes): + """EfficientNet ๋ชจ๋ธ ๋กœ๋“œ""" + try: + # ์ €์žฅ๋œ ๋ชจ๋ธ ์ •๋ณด ๋กœ๋“œ + checkpoint = torch.load(model_path, map_location=device) + + # EfficientNet-B0 ๋ชจ๋ธ ์ƒ์„ฑ + model = efficientnet_b0(weights=None) # ๊ฐ€์ค‘์น˜ ์—†์ด ๋ชจ๋ธ ๊ตฌ์กฐ๋งŒ ๋กœ๋“œ + + # ๋ถ„๋ฅ˜๊ธฐ ๋ ˆ์ด์–ด ์ˆ˜์ • + num_ftrs = 
model.classifier[1].in_features + model.classifier[1] = nn.Linear(num_ftrs, num_classes) + + # ์ €์žฅ๋œ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ + if 'model_state_dict' in checkpoint: + # ์ƒˆ๋กœ์šด ํ˜•์‹ (๋”•์…”๋„ˆ๋ฆฌ ํ˜•ํƒœ) + model.load_state_dict(checkpoint['model_state_dict']) + logging.info("๋”•์…”๋„ˆ๋ฆฌ ํ˜•ํƒœ์˜ ์ฒดํฌํฌ์ธํŠธ์—์„œ ๋ชจ๋ธ ๋กœ๋“œ") + else: + # ์ด์ „ ํ˜•์‹ (์ง์ ‘ state_dict) + model.load_state_dict(checkpoint) + logging.info("์ง์ ‘ state_dict์—์„œ ๋ชจ๋ธ ๋กœ๋“œ") + + return model + + except Exception as e: + logging.error(f"EfficientNet ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}") + # ๋Œ€์•ˆ: ๊ธฐ๋ณธ EfficientNet ๋ชจ๋ธ ์ƒ์„ฑ (์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๊ฐ€์ค‘์น˜ ์‚ฌ์šฉ) + logging.info("๊ธฐ๋ณธ EfficientNet ๋ชจ๋ธ๋กœ ๋Œ€์ฒด...") + model = efficientnet_b0(weights='IMAGENET1K_V1') + num_ftrs = model.classifier[1].in_features + model.classifier[1] = nn.Linear(num_ftrs, num_classes) + return model + +# ๋ชจ๋ธ ๋กœ๋“œ +model = load_efficientnet_model(latest_file, len(config.KOR_CATEGORIES)) +model.to(device) +model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ +logging.info("EfficientNet ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ!") + +def classify(image): + """ + ์ด๋ฏธ์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•˜๊ณ  ์ƒ์œ„ 3๊ฐœ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฐ˜ํ™˜ + + Args: + image: PIL Image ๊ฐ์ฒด + + Returns: + list: ์ƒ์œ„ 3๊ฐœ ์˜ˆ์ธก ๊ฒฐ๊ณผ (ํด๋ž˜์Šค๋ช…, ์‹ ๋ขฐ๋„ ํฌํ•จ) + """ + try: + # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ + image_tensor = encode_image(image).unsqueeze(0).to(device) + + o1 = time.time() + logging.info("EfficientNet ๋ชจ๋ธ ์˜ˆ์ธก์ค‘ ....") + + with torch.no_grad(): + outputs = model(image_tensor) # ๋ชจ๋ธ ์ถ”๋ก  + probabilities = F.softmax(outputs, dim=1) # ํ™•๋ฅ  ๋ณ€ํ™˜ + top3_prob, top3_indices = torch.topk(probabilities, 3) # ์ƒ์œ„ 3๊ฐœ ์˜ˆ์ธก ๊ฐ€์ ธ์˜ค๊ธฐ + + o2 = time.time() + logging.info(f"EfficientNet ๋ชจ๋ธ ์˜ˆ์ธก ๊ฑธ๋ฆฐ ์‹œ๊ฐ„ : {o2-o1:.2f}์ดˆ.") + + # ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜ (๊ธฐ์กด ํ˜•์‹๊ณผ ๋™์ผ) + results = [] + for i in range(3): + class_idx = top3_indices[0][i].item() + confidence = top3_prob[0][i].item() * 100 + + results.append({ + 'predicted': 
config.KOR_CATEGORIES[class_idx], + 'confidence': confidence + }) + + return results + + except Exception as e: + logging.error(f"๋ถ„๋ฅ˜ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}") + # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ ๊ธฐ๋ณธ๊ฐ’ ๋ฐ˜ํ™˜ + return [ + {'predicted': 'unknown', 'confidence': 0.0}, + {'predicted': 'unknown', 'confidence': 0.0}, + {'predicted': 'unknown', 'confidence': 0.0} + ] \ No newline at end of file diff --git a/src/image/img_caption.py b/src/image/img_caption.py new file mode 100644 index 0000000..7d920ae --- /dev/null +++ b/src/image/img_caption.py @@ -0,0 +1,20 @@ +from transformers import BlipProcessor, BlipForConditionalGeneration +from PIL import Image +import torch + + +print('BLIP ๋ชจ๋ธ ๋กœ๋”ฉ์ค‘....') +processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") +model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") +print('BLIP ๋ชจ๋ธ ๋กœ๋”ฉ์™„๋ฃŒ!') + + + +def get_caption(image: Image) -> str: + inputs = processor(images=image, return_tensors="pt") + + with torch.no_grad(): + output_ids = model.generate(**inputs) + caption = processor.decode(output_ids[0], skip_special_tokens=True) + + return caption \ No newline at end of file diff --git a/src/image/model.py b/src/image/model.py new file mode 100644 index 0000000..4c546df --- /dev/null +++ b/src/image/model.py @@ -0,0 +1,33 @@ +import torch.nn as nn + +class CNNModel(nn.Module): + def __init__(self, output_classes: int, dropout=0.2): + super(CNNModel, self).__init__() + + # CNN Layer ์ •์˜ + self.conv_layer = nn.Sequential( + nn.Conv2d(3, 32, 2), # (3, 32, 32) โ†’ (32, 31, 31) + nn.ReLU(), + nn.Conv2d(32, 64, 2), # (32, 31, 31) โ†’ (64, 30, 30) + nn.ReLU(), + nn.MaxPool2d(2, 2), # (64, 30, 30) โ†’ (64, 15, 15) + + nn.Conv2d(64, 128, 3), # (64, 15, 15) โ†’ (128, 13, 13) + nn.ReLU(), + nn.Conv2d(128, 256, 3), # (128, 13, 13) โ†’ (256, 11, 11) + nn.ReLU(), + nn.MaxPool2d(3, 2), # (256, 11, 11) โ†’ (256, 5, 5) + ) + + # Fully Connected Layer ์ •์˜ + 
self.classifier = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(256 * 5 * 5, output_classes), # Flatten ํ›„ ์ตœ์ข… ๋ถ„๋ฅ˜ + nn.LogSoftmax(dim=1) # LogSoftmax (NLLLoss ์‚ฌ์šฉ) + ) + + def forward(self, x): + x = self.conv_layer(x) # CNN Layer + x = x.view(x.size(0), -1) # Flatten + x = self.classifier(x) # FC Layer + return x \ No newline at end of file diff --git a/src/image/preprocessor.py b/src/image/preprocessor.py new file mode 100644 index 0000000..3b6e922 --- /dev/null +++ b/src/image/preprocessor.py @@ -0,0 +1,13 @@ +import io +from PIL import Image +from src.image.text_masking import mask_text + +def preproc(image_bytes: bytes): + """ + ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜ + 1. PIL.Image๋กœ ๋ณ€ํ™˜ + 2. ํ…์ŠคํŠธ ๊ฒ€์ถœ ํ›„ masking + """ + image = Image.open(io.BytesIO(image_bytes)).convert('RGB') + masked_img = mask_text(image) + return masked_img diff --git a/src/image/text_masking.py b/src/image/text_masking.py new file mode 100644 index 0000000..f58b783 --- /dev/null +++ b/src/image/text_masking.py @@ -0,0 +1,42 @@ +import io +from PIL import Image, ImageDraw +import config +import easyocr +import numpy as np + +reader = easyocr.Reader(['en', 'ko']) + +def recog_text(image: Image): + """ + Args: image + Returns: ์ด๋ฏธ์ง€๊ฐ€ ์ธ์‹๋œ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค + """ + image_np = np.array(image) + + results = reader.readtext(image_np) + + filtered_boxes = [box for box, text, conf in results if conf >= config.TEXT_THRESHOLD] + + for i, box in enumerate(filtered_boxes): + print(f"[{i + 1}] ๋ฐ•์Šค ์ขŒํ‘œ (์‹ ๋ขฐ๋„ โ‰ฅ {config.TEXT_THRESHOLD}): {box}") + + return filtered_boxes + + + +def mask_text(image: Image): + """ + Args: PIL Image(RGB), ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค + Returns: ๋งˆ์Šคํ‚น ๋œ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ + """ + boxes = recog_text(image) + masked = image.copy() + draw = ImageDraw.Draw(masked) + for box in boxes: + box = [(int(point[0]), int(point[1])) for point in box] + draw.polygon(box, fill=(255, 255, 255)) + return masked + + + + diff --git 
a/src/image/trained_model/efficientnet_b0_quickdraw.pth b/src/image/trained_model/efficientnet_b0_quickdraw.pth new file mode 100644 index 0000000..ad01257 Binary files /dev/null and b/src/image/trained_model/efficientnet_b0_quickdraw.pth differ diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..2b096e0 --- /dev/null +++ b/src/main.py @@ -0,0 +1,19 @@ +import uvicorn +from fastapi import FastAPI +from src.api.image_routes import router as image_router +from src.api.myomyo_routes import router as chat_router +from src.api.lulu_routes import router as lulu_router +app = FastAPI( + title="Gotcha! AI Server", + description="AI Server", + docs_url="/docs", + openapi_url="/openapi.json", + redoc_url="/redoc" +) + +app.include_router(image_router, prefix='/api/v1') + +app.include_router(chat_router, prefix='/api/v1') + + +app.include_router(lulu_router, prefix='/api/v1') \ No newline at end of file diff --git a/train/img.png b/train/img.png new file mode 100644 index 0000000..3199b7c Binary files /dev/null and b/train/img.png differ diff --git a/train/readme.md b/train/readme.md new file mode 100644 index 0000000..6ea509a --- /dev/null +++ b/train/readme.md @@ -0,0 +1,22 @@ +# Quickdraw Classifier + +## ๊ฐœ์š” +EfficientNet-B0 ๋ชจ๋ธ์„ Google์˜ QuickDraw ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜์—ฌ ์†์œผ๋กœ ๊ทธ๋ฆฐ ์Šค์ผ€์น˜๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ํ”„๋กœ์ ํŠธ์ž…๋‹ˆ๋‹ค. 
+ +## ๋ฐ์ดํ„ฐ์…‹ ์ •๋ณด + +- ๋ฐ์ดํ„ฐ์…‹: Google QuickDraw Dataset +- ์นดํ…Œ๊ณ ๋ฆฌ: 345๊ฐœ ํด๋ž˜์Šค (์‚ฌ๊ณผ, ๊ณ ์–‘์ด, ์ž๋™์ฐจ ๋“ฑ) +- ๋ฐ์ดํ„ฐ ํ˜•ํƒœ: 28x28 ํ”ฝ์…€ ํ‘๋ฐฑ ์ด๋ฏธ์ง€ +- ์ด ๋ฐ์ดํ„ฐ: 345๊ฐœ ํด๋ž˜์Šค * 1000์žฅ + +## ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ + +- ๋ฒ ์ด์Šค ๋ชจ๋ธ: EfficientNet-B0 +- ์ž…๋ ฅ ํฌ๊ธฐ: 224x224 (QuickDraw ์ด๋ฏธ์ง€๋ฅผ ์—…์Šค์ผ€์ผ๋ง) +- ์ถœ๋ ฅ: 345๊ฐœ ํด๋ž˜์Šค ๋ถ„๋ฅ˜ + + +## ํ›ˆ๋ จ ๊ฒฐ๊ณผ ์‹œ๊ฐํ™” + +![img.png](img.png) \ No newline at end of file diff --git a/train/train_efficientnet_b0.ipynb b/train/train_efficientnet_b0.ipynb new file mode 100644 index 0000000..8f50f97 --- /dev/null +++ b/train/train_efficientnet_b0.ipynb @@ -0,0 +1,1458 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "84612cf8", + "metadata": {}, + "outputs": [], + "source": [ + "# all name of the quick drawings\n", + "classes = [\n", + " \"aircraft carrier\",\n", + " \"airplane\",\n", + " \"alarm clock\",\n", + " \"ambulance\",\n", + " \"angel\",\n", + " \"animal migration\",\n", + " \"ant\",\n", + " \"anvil\",\n", + " \"apple\",\n", + " \"arm\",\n", + " \"asparagus\",\n", + " \"axe\",\n", + " \"backpack\",\n", + " \"banana\",\n", + " \"bandage\",\n", + " \"barn\",\n", + " \"baseball bat\",\n", + " \"baseball\",\n", + " \"basket\",\n", + " \"basketball\",\n", + " \"bat\",\n", + " \"bathtub\",\n", + " \"beach\",\n", + " \"bear\",\n", + " \"beard\",\n", + " \"bed\",\n", + " \"bee\",\n", + " \"belt\",\n", + " \"bench\",\n", + " \"bicycle\",\n", + " \"binoculars\",\n", + " \"bird\",\n", + " \"birthday cake\",\n", + " \"blackberry\",\n", + " \"blueberry\",\n", + " \"book\",\n", + " \"boomerang\",\n", + " \"bottlecap\",\n", + " \"bowtie\",\n", + " \"bracelet\",\n", + " \"brain\",\n", + " \"bread\",\n", + " \"bridge\",\n", + " \"broccoli\",\n", + " \"broom\",\n", + " \"bucket\",\n", + " \"bulldozer\",\n", + " \"bus\",\n", + " \"bush\",\n", + " \"butterfly\",\n", + " \"cactus\",\n", + " \"cake\",\n", + " \"calculator\",\n", + " 
\"calendar\",\n", + " \"camel\",\n", + " \"camera\",\n", + " \"camouflage\",\n", + " \"campfire\",\n", + " \"candle\",\n", + " \"cannon\",\n", + " \"canoe\",\n", + " \"car\",\n", + " \"carrot\",\n", + " \"castle\",\n", + " \"cat\",\n", + " \"ceiling fan\",\n", + " \"cell phone\",\n", + " \"cello\",\n", + " \"chair\",\n", + " \"chandelier\",\n", + " \"church\",\n", + " \"circle\",\n", + " \"clarinet\",\n", + " \"clock\",\n", + " \"cloud\",\n", + " \"coffee cup\",\n", + " \"compass\",\n", + " \"computer\",\n", + " \"cookie\",\n", + " \"cooler\",\n", + " \"couch\",\n", + " \"cow\",\n", + " \"crab\",\n", + " \"crayon\",\n", + " \"crocodile\",\n", + " \"crown\",\n", + " \"cruise ship\",\n", + " \"cup\",\n", + " \"diamond\",\n", + " \"dishwasher\",\n", + " \"diving board\",\n", + " \"dog\",\n", + " \"dolphin\",\n", + " \"donut\",\n", + " \"door\",\n", + " \"dragon\",\n", + " \"dresser\",\n", + " \"drill\",\n", + " \"drums\",\n", + " \"duck\",\n", + " \"dumbbell\",\n", + " \"ear\",\n", + " \"elbow\",\n", + " \"elephant\",\n", + " \"envelope\",\n", + " \"eraser\",\n", + " \"eye\",\n", + " \"eyeglasses\",\n", + " \"face\",\n", + " \"fan\",\n", + " \"feather\",\n", + " \"fence\",\n", + " \"finger\",\n", + " \"fire hydrant\",\n", + " \"fireplace\",\n", + " \"firetruck\",\n", + " \"fish\",\n", + " \"flamingo\",\n", + " \"flashlight\",\n", + " \"flip flops\",\n", + " \"floor lamp\",\n", + " \"flower\",\n", + " \"flying saucer\",\n", + " \"foot\",\n", + " \"fork\",\n", + " \"frog\",\n", + " \"frying pan\",\n", + " \"garden hose\",\n", + " \"garden\",\n", + " \"giraffe\",\n", + " \"goatee\",\n", + " \"golf club\",\n", + " \"grapes\",\n", + " \"grass\",\n", + " \"guitar\",\n", + " \"hamburger\",\n", + " \"hammer\",\n", + " \"hand\",\n", + " \"harp\",\n", + " \"hat\",\n", + " \"headphones\",\n", + " \"hedgehog\",\n", + " \"helicopter\",\n", + " \"helmet\",\n", + " \"hexagon\",\n", + " \"hockey puck\",\n", + " \"hockey stick\",\n", + " \"horse\",\n", + " \"hospital\",\n", + " \"hot 
air balloon\",\n", + " \"hot dog\",\n", + " \"hot tub\",\n", + " \"hourglass\",\n", + " \"house plant\",\n", + " \"house\",\n", + " \"hurricane\",\n", + " \"ice cream\",\n", + " \"jacket\",\n", + " \"jail\",\n", + " \"kangaroo\",\n", + " \"key\",\n", + " \"keyboard\",\n", + " \"knee\",\n", + " \"knife\",\n", + " \"ladder\",\n", + " \"lantern\",\n", + " \"laptop\",\n", + " \"leaf\",\n", + " \"leg\",\n", + " \"light bulb\",\n", + " \"lighter\",\n", + " \"lighthouse\",\n", + " \"lightning\",\n", + " \"line\",\n", + " \"lion\",\n", + " \"lipstick\",\n", + " \"lobster\",\n", + " \"lollipop\",\n", + " \"mailbox\",\n", + " \"map\",\n", + " \"marker\",\n", + " \"matches\",\n", + " \"megaphone\",\n", + " \"mermaid\",\n", + " \"microphone\",\n", + " \"microwave\",\n", + " \"monkey\",\n", + " \"moon\",\n", + " \"mosquito\",\n", + " \"motorbike\",\n", + " \"mountain\",\n", + " \"mouse\",\n", + " \"moustache\",\n", + " \"mouth\",\n", + " \"mug\",\n", + " \"mushroom\",\n", + " \"nail\",\n", + " \"necklace\",\n", + " \"nose\",\n", + " \"ocean\",\n", + " \"octagon\",\n", + " \"octopus\",\n", + " \"onion\",\n", + " \"oven\",\n", + " \"owl\",\n", + " \"paint can\",\n", + " \"paintbrush\",\n", + " \"palm tree\",\n", + " \"panda\",\n", + " \"pants\",\n", + " \"paper clip\",\n", + " \"parachute\",\n", + " \"parrot\",\n", + " \"passport\",\n", + " \"peanut\",\n", + " \"pear\",\n", + " \"peas\",\n", + " \"pencil\",\n", + " \"penguin\",\n", + " \"piano\",\n", + " \"pickup truck\",\n", + " \"picture frame\",\n", + " \"pig\",\n", + " \"pillow\",\n", + " \"pineapple\",\n", + " \"pizza\",\n", + " \"pliers\",\n", + " \"police car\",\n", + " \"pond\",\n", + " \"pool\",\n", + " \"popsicle\",\n", + " \"postcard\",\n", + " \"potato\",\n", + " \"power outlet\",\n", + " \"purse\",\n", + " \"rabbit\",\n", + " \"raccoon\",\n", + " \"radio\",\n", + " \"rain\",\n", + " \"rainbow\",\n", + " \"rake\",\n", + " \"remote control\",\n", + " \"rhinoceros\",\n", + " \"rifle\",\n", + " \"river\",\n", + " 
\"roller coaster\",\n", + " \"rollerskates\",\n", + " \"sailboat\",\n", + " \"sandwich\",\n", + " \"saw\",\n", + " \"saxophone\",\n", + " \"school bus\",\n", + " \"scissors\",\n", + " \"scorpion\",\n", + " \"screwdriver\",\n", + " \"sea turtle\",\n", + " \"see saw\",\n", + " \"shark\",\n", + " \"sheep\",\n", + " \"shoe\",\n", + " \"shorts\",\n", + " \"shovel\",\n", + " \"sink\",\n", + " \"skateboard\",\n", + " \"skull\",\n", + " \"skyscraper\",\n", + " \"sleeping bag\",\n", + " \"smiley face\",\n", + " \"snail\",\n", + " \"snake\",\n", + " \"snorkel\",\n", + " \"snowflake\",\n", + " \"snowman\",\n", + " \"soccer ball\",\n", + " \"sock\",\n", + " \"speedboat\",\n", + " \"spider\",\n", + " \"spoon\",\n", + " \"spreadsheet\",\n", + " \"square\",\n", + " \"squiggle\",\n", + " \"squirrel\",\n", + " \"stairs\",\n", + " \"star\",\n", + " \"steak\",\n", + " \"stereo\",\n", + " \"stethoscope\",\n", + " \"stitches\",\n", + " \"stop sign\",\n", + " \"stove\",\n", + " \"strawberry\",\n", + " \"streetlight\",\n", + " \"string bean\",\n", + " \"submarine\",\n", + " \"suitcase\",\n", + " \"sun\",\n", + " \"swan\",\n", + " \"sweater\",\n", + " \"swing set\",\n", + " \"sword\",\n", + " \"syringe\",\n", + " \"t-shirt\",\n", + " \"table\",\n", + " \"teapot\",\n", + " \"teddy-bear\",\n", + " \"telephone\",\n", + " \"television\",\n", + " \"tennis racquet\",\n", + " \"tent\",\n", + " \"The Eiffel Tower\",\n", + " \"The Great Wall of China\",\n", + " \"The Mona Lisa\",\n", + " \"tiger\",\n", + " \"toaster\",\n", + " \"toe\",\n", + " \"toilet\",\n", + " \"tooth\",\n", + " \"toothbrush\",\n", + " \"toothpaste\",\n", + " \"tornado\",\n", + " \"tractor\",\n", + " \"traffic light\",\n", + " \"train\",\n", + " \"tree\",\n", + " \"triangle\",\n", + " \"trombone\",\n", + " \"truck\",\n", + " \"trumpet\",\n", + " \"umbrella\",\n", + " \"underwear\",\n", + " \"van\",\n", + " \"vase\",\n", + " \"violin\",\n", + " \"washing machine\",\n", + " \"watermelon\",\n", + " \"waterslide\",\n", + " 
\"whale\",\n", + " \"wheel\",\n", + " \"windmill\",\n", + " \"wine bottle\",\n", + " \"wine glass\",\n", + " \"wristwatch\",\n", + " \"yoga\",\n", + " \"zebra\",\n", + " \"zigzag\",\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "26aa66aa", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import datasets, transforms, models\n", + "from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import time\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2a58206e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "์‚ฌ์šฉ ์žฅ์น˜: cuda:0\n" + ] + } + ], + "source": [ + "# ์žฅ์น˜ ์„ค์ •\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"์‚ฌ์šฉ ์žฅ์น˜: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "39cbc6a7", + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ\n", + "data_dir = \"../../../../Desktop/BE_thief/train/quickdraw_dataset\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d220311a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'12.8'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.version.cuda" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d260f5dc", + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ๋ฐ ์ฆ๊ฐ•\n", + "data_transforms = {\n", + " 'train': transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(224),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomRotation(10),\n", + " 
transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + " ]),\n", + " 'val': transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(224),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + " ]),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e9a7a418", + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ\n", + "def load_datasets():\n", + " # train๊ณผ validation ํด๋”๊ฐ€ ๋ฏธ๋ฆฌ ๋‚˜๋ˆ„์–ด์ ธ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •\n", + " # ์—†๋‹ค๋ฉด ์•„๋ž˜ ์ฃผ์„ ์ฒ˜๋ฆฌ๋œ ์ฝ”๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ์…‹์„ ๋ถ„ํ• ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค\n", + " \n", + " # ๋ฐ์ดํ„ฐ์…‹์ด ์ด๋ฏธ train/val๋กœ ๋‚˜๋‰˜์–ด ์žˆ๋Š” ๊ฒฝ์šฐ\n", + " if os.path.isdir(os.path.join(data_dir, 'train')) and os.path.isdir(os.path.join(data_dir, 'val')):\n", + " image_datasets = {\n", + " 'train': datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train']),\n", + " 'val': datasets.ImageFolder(os.path.join(data_dir, 'val'), data_transforms['val'])\n", + " }\n", + " else:\n", + " # ๋ฐ์ดํ„ฐ์…‹์ด ๋‚˜๋‰˜์–ด ์žˆ์ง€ ์•Š์€ ๊ฒฝ์šฐ, ์ „์ฒด ๋ฐ์ดํ„ฐ์…‹์„ ๋กœ๋“œํ•˜๊ณ  ๋ถ„ํ• \n", + " full_dataset = datasets.ImageFolder(data_dir, data_transforms['train'])\n", + " \n", + " # ํด๋ž˜์Šค ์ด๋ฆ„๊ณผ ์ธ๋ฑ์Šค ๊ฐ€์ ธ์˜ค๊ธฐ\n", + " class_names = full_dataset.classes\n", + " print(f\"ํด๋ž˜์Šค ๊ฐœ์ˆ˜: {len(class_names)}\")\n", + " print(f\"ํด๋ž˜์Šค ๋ชฉ๋ก: {class_names}\")\n", + " \n", + " # ๋ฐ์ดํ„ฐ์…‹ ๋ถ„ํ•  (80% ํ›ˆ๋ จ, 20% ๊ฒ€์ฆ)\n", + " train_size = int(0.8 * len(full_dataset))\n", + " val_size = len(full_dataset) - train_size\n", + " train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])\n", + " \n", + " # val ๋ฐ์ดํ„ฐ์…‹์€ ๋‹ค๋ฅธ ๋ณ€ํ™˜ ์ ์šฉ\n", + " val_dataset.dataset.transform = data_transforms['val']\n", + " \n", + " image_datasets = {\n", + " 'train': train_dataset,\n", + 
" 'val': val_dataset\n", + " }\n", + " \n", + " # ๋ฐ์ดํ„ฐ๋กœ๋” ์ƒ์„ฑ\n", + " dataloaders = {\n", + " 'train': DataLoader(image_datasets['train'], batch_size=32, shuffle=True, num_workers=4),\n", + " 'val': DataLoader(image_datasets['val'], batch_size=32, shuffle=False, num_workers=4)\n", + " }\n", + " \n", + " dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n", + " class_names = image_datasets['train'].dataset.classes if hasattr(image_datasets['train'], 'dataset') else image_datasets['train'].classes\n", + " \n", + " print(f\"๋ฐ์ดํ„ฐ์…‹ ํฌ๊ธฐ: train={dataset_sizes['train']}, val={dataset_sizes['val']}\")\n", + " print(f\"ํด๋ž˜์Šค ๊ฐœ์ˆ˜: {len(class_names)}\")\n", + " \n", + " return dataloaders, dataset_sizes, class_names" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d8efa701", + "metadata": {}, + "outputs": [], + "source": [ + "# EfficientNet-B0 ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ์ˆ˜์ •\n", + "def setup_model(num_classes):\n", + " # ์‚ฌ์ „ ํ›ˆ๋ จ๋œ EfficientNet-B0 ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ\n", + " model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)\n", + " \n", + " # ๋งˆ์ง€๋ง‰ ๋ถ„๋ฅ˜๊ธฐ ๋ ˆ์ด์–ด ์ˆ˜์ • (QuickDraw ํด๋ž˜์Šค ์ˆ˜์— ๋งž๊ฒŒ)\n", + " num_ftrs = model.classifier[1].in_features\n", + " model.classifier[1] = nn.Linear(num_ftrs, num_classes)\n", + " \n", + " # ๋ชจ๋ธ์„ ์ง€์ •๋œ ์žฅ์น˜(GPU/CPU)๋กœ ์ด๋™\n", + " model = model.to(device)\n", + " \n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9442775c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# ํ•™์Šต ํ•จ์ˆ˜\n", + "def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, num_epochs=25):\n", + " since = time.time()\n", + " \n", + " best_model_wts = model.state_dict()\n", + " best_acc = 0.0\n", + " \n", + " # ํ•™์Šต ๊ณผ์ • ๊ธฐ๋ก\n", + " history = {\n", + " 'train_loss': [],\n", + " 'val_loss': [],\n", + " 'train_acc': [],\n", + " 'val_acc': []\n", + " }\n", + " \n", 
+ " for epoch in range(num_epochs):\n", + " print(f'Epoch {epoch+1}/{num_epochs}')\n", + " print('-' * 10)\n", + " \n", + " # ๊ฐ ์—ํฌํฌ๋Š” ํ•™์Šต๊ณผ ๊ฒ€์ฆ ๋‹จ๊ณ„๊ฐ€ ์žˆ์Œ\n", + " for phase in ['train', 'val']:\n", + " if phase == 'train':\n", + " model.train() # ํ•™์Šต ๋ชจ๋“œ ์„ค์ •\n", + " else:\n", + " model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ •\n", + " \n", + " running_loss = 0.0\n", + " running_corrects = 0\n", + " \n", + " # ๋ฐ์ดํ„ฐ ๋ฐ˜๋ณต\n", + " for inputs, labels in tqdm(dataloaders[phase]):\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " # ํŒŒ๋ผ๋ฏธํ„ฐ ๊ทธ๋ž˜๋””์–ธํŠธ ์ดˆ๊ธฐํ™”\n", + " optimizer.zero_grad()\n", + " \n", + " # ์ˆœ์ „ํŒŒ\n", + " with torch.set_grad_enabled(phase == 'train'):\n", + " outputs = model(inputs)\n", + " _, preds = torch.max(outputs, 1)\n", + " loss = criterion(outputs, labels)\n", + " \n", + " # ํ•™์Šต ๋‹จ๊ณ„์ผ ๊ฒฝ์šฐ ์—ญ์ „ํŒŒ + ์ตœ์ ํ™”\n", + " if phase == 'train':\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " # ํ†ต๊ณ„\n", + " running_loss += loss.item() * inputs.size(0)\n", + " running_corrects += torch.sum(preds == labels.data)\n", + " \n", + " if phase == 'train' and scheduler is not None:\n", + " scheduler.step()\n", + " \n", + " epoch_loss = running_loss / dataset_sizes[phase]\n", + " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n", + " \n", + " # ๊ธฐ๋ก ์ €์žฅ\n", + " if phase == 'train':\n", + " history['train_loss'].append(epoch_loss)\n", + " history['train_acc'].append(epoch_acc.item())\n", + " else:\n", + " history['val_loss'].append(epoch_loss)\n", + " history['val_acc'].append(epoch_acc.item())\n", + " \n", + " print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')\n", + " \n", + " # ๋ชจ๋ธ์„ ๋ณต์‚ฌ (์ตœ๊ณ ์˜ ๊ฒ€์ฆ ์ •ํ™•๋„๋ฅผ ๊ธฐ๋กํ•œ ๊ฒฝ์šฐ)\n", + " if phase == 'val' and epoch_acc > best_acc:\n", + " best_acc = epoch_acc\n", + " best_model_wts = model.state_dict()\n", + " \n", + " print()\n", + " \n", + " time_elapsed = 
time.time() - since\n", + " print(f'ํ•™์Šต ์™„๋ฃŒ: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')\n", + " print(f'์ตœ๊ณ  ๊ฒ€์ฆ ์ •ํ™•๋„: {best_acc:.4f}')\n", + " \n", + " # ๊ฐ€์žฅ ์ข‹์€ ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ\n", + " model.load_state_dict(best_model_wts)\n", + " return model, history" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f6339dda", + "metadata": {}, + "outputs": [], + "source": [ + "# ํ•™์Šต ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”\n", + "def plot_training_history(history):\n", + " plt.figure(figsize=(12, 4))\n", + " \n", + " plt.subplot(1, 2, 1)\n", + " plt.plot(history['train_loss'], label='Train Loss')\n", + " plt.plot(history['val_loss'], label='Validation Loss')\n", + " plt.xlabel('Epoch')\n", + " plt.ylabel('Loss')\n", + " plt.legend()\n", + " plt.title('Training and Validation Loss')\n", + " \n", + " plt.subplot(1, 2, 2)\n", + " plt.plot(history['train_acc'], label='Train Accuracy')\n", + " plt.plot(history['val_acc'], label='Validation Accuracy')\n", + " plt.xlabel('Epoch')\n", + " plt.ylabel('Accuracy')\n", + " plt.legend()\n", + " plt.title('Training and Validation Accuracy')\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig('training_history.png')\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2c7e0547", + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ชจ๋ธ ์ €์žฅ ํ•จ์ˆ˜\n", + "def save_model(model, class_names, filename='efficientnet_b0_quickdraw.pth'):\n", + " model_info = {\n", + " 'model_state_dict': model.state_dict(),\n", + " 'class_names': class_names,\n", + " 'model_name': 'efficientnet_b0'\n", + " }\n", + " torch.save(model_info, filename)\n", + " print(f\"๋ชจ๋ธ์ด {filename}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d985a0f8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# ํ…Œ์ŠคํŠธ ์„ธํŠธ์—์„œ ๋ชจ๋ธ ํ‰๊ฐ€\n", + "def evaluate_model(model, test_loader):\n", + " model.eval()\n", + " correct 
= 0\n", + " total = 0\n", + " \n", + " with torch.no_grad():\n", + " for inputs, labels in tqdm(test_loader):\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " outputs = model(inputs)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " \n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + " \n", + " accuracy = 100 * correct / total\n", + " print(f'ํ…Œ์ŠคํŠธ ์ •ํ™•๋„: {accuracy:.2f}%')\n", + " return accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bda50278", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def main():\n", + " # ๋ฐ์ดํ„ฐ ๋กœ๋“œ\n", + " dataloaders, dataset_sizes, class_names = load_datasets()\n", + " \n", + " # ๋ชจ๋ธ ์„ค์ •\n", + " num_classes = len(class_names)\n", + " model = setup_model(num_classes)\n", + " \n", + " # ์†์‹ค ํ•จ์ˆ˜์™€ ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •\n", + " criterion = nn.CrossEntropyLoss()\n", + " optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)\n", + " \n", + " # ํ•™์Šต๋ฅ  ์Šค์ผ€์ค„๋Ÿฌ (10 ์—ํฌํฌ๋งˆ๋‹ค ํ•™์Šต๋ฅ ์„ 0.1๋ฐฐ๋กœ ๊ฐ์†Œ)\n", + " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)\n", + " \n", + " # ๋ชจ๋ธ ํ•™์Šต\n", + " print(\"๋ชจ๋ธ ํ•™์Šต ์‹œ์ž‘...\")\n", + " model, history = train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, num_epochs=20)\n", + " \n", + " # ํ•™์Šต ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”\n", + " plot_training_history(history)\n", + " \n", + " # ๊ฒ€์ฆ ์„ธํŠธ์—์„œ ํ‰๊ฐ€\n", + " print(\"๊ฒ€์ฆ ์„ธํŠธ์—์„œ ๋ชจ๋ธ ํ‰๊ฐ€ ์ค‘...\")\n", + " evaluate_model(model, dataloaders['val'])\n", + " \n", + " # ๋ชจ๋ธ ์ €์žฅ\n", + " save_model(model, class_names)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6ca361f7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ํด๋ž˜์Šค ๊ฐœ์ˆ˜: 345\n", + "ํด๋ž˜์Šค ๋ชฉ๋ก: ['The_Eiffel_Tower', 'The_Great_Wall_of_China', 
'The_Mona_Lisa', 'aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan', 'cell_phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee_cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower', 'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee', 'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital', 'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 
'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint_can', 'paintbrush', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control', 'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag', 'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 'string_bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis_racquet', 'tent', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']\n", + "๋ฐ์ดํ„ฐ์…‹ ํฌ๊ธฐ: train=828000, val=207000\n", + "ํด๋ž˜์Šค ๊ฐœ์ˆ˜: 345\n", + "๋ชจ๋ธ ํ•™์Šต ์‹œ์ž‘...\n", + "Epoch 1/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [48:52<00:00, 8.82it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 1.6663 Acc: 0.5972\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:01<00:00, 26.83it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.3146 Acc: 0.6749\n", + "\n", + "Epoch 2/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [48:57<00:00, 8.81it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 1.2951 Acc: 0.6768\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:02<00:00, 26.71it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.2120 Acc: 0.6995\n", + "\n", + "Epoch 3/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:15<00:00, 9.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 1.1890 Acc: 0.6994\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:51<00:00, 27.90it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.5957 Acc: 0.6075\n", + "\n", + "Epoch 4/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:37<00:00, 9.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 1.1188 Acc: 0.7151\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 
6469/6469 [03:54<00:00, 27.54it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.1327 Acc: 0.7182\n", + "\n", + "Epoch 5/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:20<00:00, 9.11it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 1.0648 Acc: 0.7264\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:51<00:00, 27.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.5308 Acc: 0.6226\n", + "\n", + "Epoch 6/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [50:11<00:00, 8.59it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 1.0204 Acc: 0.7364\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [05:03<00:00, 21.31it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.1138 Acc: 0.7249\n", + "\n", + "Epoch 7/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [49:38<00:00, 8.69it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.9813 Acc: 0.7443\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:52<00:00, 22.09it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0951 Acc: 0.7286\n", + "\n", + "Epoch 8/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 
[48:56<00:00, 8.81it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.9465 Acc: 0.7527\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:50<00:00, 22.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0990 Acc: 0.7294\n", + "\n", + "Epoch 9/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [49:51<00:00, 8.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.9154 Acc: 0.7590\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:52<00:00, 22.09it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0968 Acc: 0.7317\n", + "\n", + "Epoch 10/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [50:14<00:00, 8.58it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.8875 Acc: 0.7650\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:57<00:00, 21.73it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0962 Acc: 0.7316\n", + "\n", + "Epoch 11/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [50:08<00:00, 8.60it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.7159 Acc: 0.8086\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [04:57<00:00, 21.78it/s]\n" + ] + }, + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0554 Acc: 0.7470\n", + "\n", + "Epoch 12/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [49:23<00:00, 8.73it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.6641 Acc: 0.8210\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0701 Acc: 0.7472\n", + "\n", + "Epoch 13/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:10<00:00, 9.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.6381 Acc: 0.8277\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0737 Acc: 0.7461\n", + "\n", + "Epoch 14/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:15<00:00, 9.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.6182 Acc: 0.8324\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0880 Acc: 0.7455\n", + "\n", + "Epoch 15/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:08<00:00, 9.15it/s]\n" + ] + }, + { + "name": "stdout", 
+ "output_type": "stream", + "text": [ + "train Loss: 0.6005 Acc: 0.8362\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.87it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.0961 Acc: 0.7443\n", + "\n", + "Epoch 16/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:09<00:00, 9.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.5848 Acc: 0.8407\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.85it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.1069 Acc: 0.7441\n", + "\n", + "Epoch 17/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:13<00:00, 9.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.5715 Acc: 0.8434\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:53<00:00, 27.72it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.1176 Acc: 0.7431\n", + "\n", + "Epoch 18/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:45<00:00, 9.03it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.5578 Acc: 0.8467\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.82it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 
1.1342 Acc: 0.7419\n", + "\n", + "Epoch 19/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:13<00:00, 9.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.5459 Acc: 0.8491\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.82it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.1417 Acc: 0.7410\n", + "\n", + "Epoch 20/20\n", + "----------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 25875/25875 [47:49<00:00, 9.02it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train Loss: 0.5342 Acc: 0.8521\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:56<00:00, 27.36it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val Loss: 1.1481 Acc: 0.7401\n", + "\n", + "ํ•™์Šต ์™„๋ฃŒ: 1051m 27s\n", + "์ตœ๊ณ  ๊ฒ€์ฆ ์ •ํ™•๋„: 0.7472\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๊ฒ€์ฆ ์„ธํŠธ์—์„œ ๋ชจ๋ธ ํ‰๊ฐ€ ์ค‘...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 6469/6469 [03:52<00:00, 27.84it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ํ…Œ์ŠคํŠธ ์ •ํ™•๋„: 74.01%\n", + "๋ชจ๋ธ์ด efficientnet_b0_quickdraw.pth์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebb560a3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}