-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
45 lines (33 loc) · 985 Bytes
/
train.py
File metadata and controls
45 lines (33 loc) · 985 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import torch
import gc
from rfdetr import RFDETRNano
# Root folder of the training dataset handed to ``model.train``.
DATASET_DIR = "./dataset_64k"
# Folder that receives training artifacts. The same value is created on disk,
# printed in the banner and passed to ``model.train`` so the three cannot
# drift apart (previously "/" was created/printed while training wrote to
# "./rfdetr_model").
OUTPUT_DIR = "./rfdetr_model"
# Number of object classes the detector is configured for.
NUM_CLASSES = 31


def _print_environment():
    """Print the PyTorch/CUDA setup and the dataset/output paths in use."""
    print(f"PyTorch version: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        # Device details are only queryable when a CUDA device is present.
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU memory: {total_gib:.2f} GB")
    print(f"Dataset: {DATASET_DIR}")
    print(f"Output: {OUTPUT_DIR}")


def main():
    """Train an RF-DETR Medium model on ``DATASET_DIR``.

    Frees cached GPU memory, ensures ``OUTPUT_DIR`` exists, prints an
    environment summary, then builds the model and runs training with the
    hyper-parameters below. Results are written to ``OUTPUT_DIR``.
    """
    # Start from a clean slate: drop cached CUDA blocks and collect garbage.
    torch.cuda.empty_cache()
    gc.collect()

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    _print_environment()

    print("\nModel Initialization")
    # Pin the distributed-launch variables to a one-process layout.
    # NOTE(review): presumably read by rfdetr / torch.distributed during
    # training setup -- confirm they are still required.
    os.environ['RANK'] = '0'
    os.environ['LOCAL_RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'

    # Imported here, after the environment variables are set, mirroring the
    # original statement order.
    from rfdetr import RFDETRMedium

    # rfdetr's model configuration field is lower-case ``num_classes``; the
    # upper-case spelling is not a recognised option.
    model = RFDETRMedium(num_classes=NUM_CLASSES)
    model.train(
        dataset_dir=DATASET_DIR,
        epochs=15,
        batch_size=16,
        grad_accum_steps=1,
        lr=1e-5,
        output_dir=OUTPUT_DIR,
    )
    print("Training Complete!")


# Only train when executed as a script, so importing this module (for example
# from a helper or a test) does not kick off a training run.
if __name__ == "__main__":
    main()