Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 11, 2024
1 parent 15797c7 commit eb6048c
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 121 deletions.
35 changes: 21 additions & 14 deletions examples/dgcnn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import ModelNet, MedShapeNet
from torch_geometric.datasets import MedShapeNet, ModelNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, DynamicEdgeConv, global_max_pool

Expand All @@ -16,7 +16,7 @@
'--dataset',
type=str,
default='modelnet10',
choices=['modelnet10', 'modelnet40','medshapenet'],
choices=['modelnet10', 'modelnet40', 'medshapenet'],
help='Dataset name.',
)
parser.add_argument(
Expand All @@ -42,27 +42,33 @@

print('The Dataset is: ', args.dataset)
if args.dataset == 'modelnet40':
print('Loading training data')
train_dataset = ModelNet(root, '40', True, transform, pre_transform)
print('Loading test data')
test_dataset = ModelNet(root, '40', False, transform, pre_transform)
print('Loading training data')
train_dataset = ModelNet(root, '40', True, transform, pre_transform)
print('Loading test data')
test_dataset = ModelNet(root, '40', False, transform, pre_transform)
elif args.dataset == 'medshapenet':
print('Loading training data')
train_dataset = MedShapeNet(root=root, size=50, split="train", pre_transform=pre_transform, transform=transform, force_reload=False)
print('Loading test data')
test_dataset = MedShapeNet(root=root, size=50, split="test", pre_transform=pre_transform, transform=transform, force_reload=False)
print('Loading training data')
train_dataset = MedShapeNet(root=root, size=50, split="train",
pre_transform=pre_transform,
transform=transform, force_reload=False)
print('Loading test data')
test_dataset = MedShapeNet(root=root, size=50, split="test",
pre_transform=pre_transform,
transform=transform, force_reload=False)
else:
print('Loading training data')
train_dataset = ModelNet(root, '10', True, transform, pre_transform)
print('Loading test data')
test_dataset = ModelNet(root, '10', False, transform, pre_transform)
print('Loading training data')
train_dataset = ModelNet(root, '10', True, transform, pre_transform)
print('Loading test data')
test_dataset = ModelNet(root, '10', False, transform, pre_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
num_workers=num_workers)

print('Running model')


class Net(torch.nn.Module):
def __init__(self, out_channels, k=20, aggr='max'):
super().__init__()
Expand Down Expand Up @@ -115,6 +121,7 @@ def test(loader):
correct += pred.eq(data.y).sum().item()
return correct / len(loader.dataset)


for epoch in range(1, num_epochs):
loss = train()
test_acc = test(test_loader)
Expand Down
212 changes: 105 additions & 107 deletions torch_geometric/datasets/medshapenet.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Callable, List, Optional
import glob
import os
import os.path as osp
import numpy as np
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset


class MedShapeNet(InMemoryDataset):
r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
Dataset of 3D Medical Shapes for Computer Vision"
<https://arxiv.org/abs/2308.16139>`_ paper,
    containing 8 different types of structures (classes).
Expand All @@ -26,7 +26,7 @@ class MedShapeNet(InMemoryDataset):
Args:
root (str): Root directory where the dataset should be saved.
        size (int): Number of individual 3D structures to download per
        size (int): Number of individual 3D structures to download per
type (classes).
train (bool, optional): If :obj:`True`, loads the training dataset,
otherwise the test dataset. (default: :obj:`True`)
Expand All @@ -45,7 +45,7 @@ class MedShapeNet(InMemoryDataset):
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
def __init__(
def __init__(
self,
root: str,
size: int = 100,
Expand All @@ -66,104 +66,102 @@ def __init__(
path = self.processed_paths[2]
self.load(path)


@property
def raw_file_names(self) -> List[str]:
return [
'3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
]

@property
def processed_file_names(self) -> List[str]:
return ['train.pt','val','test.pt']

@property
def raw_paths(self) -> List[str]:
r"""The absolute filepaths that must be present in order to skip
downloading."""
if isinstance(self.raw_file_names, list):
return [osp.join(self.raw_dir, f) for f in self.raw_file_names]
else:
return [osp.join(self.raw_dir, self.raw_file_names)]

def process(self) -> None:
from MedShapeNet import MedShapeNet as msn
from torch.utils.data import random_split
import urllib3

msn_instance = msn()

urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50)

list_of_datasets = msn_instance.datasets(False)
list_of_datasets = list(filter(lambda x: x not in ['medshapenetcore/ASOCA',
'medshapenetcore/AVT',
'medshapenetcore/AutoImplantCraniotomy',
'medshapenetcore/FaceVR'],
list_of_datasets))

train_size = int(0.7 * self.size) # 70% for training
val_size = int(0.15 * self.size) # 15% for validation
test_size = self.size - train_size - val_size # Remainder for testing

train_list, val_list, test_list = [], [], []
for dataset in list_of_datasets:
self.newpath = self.root + '/' + dataset.split("/")[1]
if not os.path.exists(self.newpath):
os.makedirs(self.newpath)
stl_files = msn_instance.dataset_files(dataset, '.stl')
stl_files = stl_files[:self.size]

train_data, val_data, test_data = random_split(stl_files, [train_size,
val_size,
test_size])

train_list.extend([stl_files[idx] for idx in train_data.indices])
val_list.extend([stl_files[idx] for idx in val_data.indices])
test_list.extend([stl_files[idx] for idx in test_data.indices])

for stl_file in stl_files:
msn_instance.download_stl_as_numpy(bucket_name = dataset,
stl_file = stl_file,
output_dir = self.newpath,
print_output=False)


class_mapping = {
'3DTeethSeg': 0,
'CoronaryArteries': 1,
'FLARE': 2,
'KITS': 3,
'PULMONARY': 4,
'SurgicalInstruments': 5,
'ThoracicAorta_Saitta': 6,
'ToothFairy': 7
}

for dataset, path in zip([train_list, val_list, test_list],
self.processed_paths):
data_list = []
for item in dataset:
class_name = item.split("/")[0]
item = item.split("stl")[0]
target = class_mapping[class_name]
file = osp.join(self.root, item + 'npz')

data = np.load(file)
pre_data_list = Data(
pos = torch.tensor(data["vertices"], dtype=torch.float),
face = torch.tensor(data["faces"], dtype=torch.long).t().contiguous()
)
pre_data_list.y = torch.tensor([target], dtype=torch.long)
data_list.append(pre_data_list)

if self.pre_filter is not None:
data_list = [d for d in data_list if self.pre_filter(d)]

if self.pre_transform is not None:
data_list = [self.pre_transform(d) for d in data_list]

self.save(data_list,path)


    @property
    def raw_file_names(self) -> List[str]:
        # One raw sub-directory per anatomical structure (class). These names
        # must stay in sync with the keys of the ``class_mapping`` dict built
        # in ``process``, which assigns the integer labels.
        return [
            '3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
            'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
        ]

@property
def processed_file_names(self) -> List[str]:
return ['train.pt', 'val', 'test.pt']

@property
def raw_paths(self) -> List[str]:
r"""The absolute filepaths that must be present in order to skip
downloading.
"""
if isinstance(self.raw_file_names, list):
return [osp.join(self.raw_dir, f) for f in self.raw_file_names]
else:
return [osp.join(self.raw_dir, self.raw_file_names)]

    def process(self) -> None:
        r"""Download up to ``self.size`` STL meshes per supported MedShapeNet
        bucket, split them 70/15/15 into train/val/test, convert each mesh to
        a :class:`~torch_geometric.data.Data` object and save one processed
        file per split.

        NOTE(review): downloading happens here rather than in a separate
        ``download`` step, so ``process`` needs network access.
        """
        # Optional third-party dependencies are imported lazily so the module
        # can be imported without the ``MedShapeNet`` package installed.
        import urllib3
        from MedShapeNet import MedShapeNet as msn
        from torch.utils.data import random_split

        msn_instance = msn()

        # NOTE(review): this pool object is created but never assigned or
        # used — presumably intended to raise the connection-pool size for
        # the downloads below; verify it actually has that effect.
        urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50)

        # All available buckets minus the ones this dataset does not support.
        list_of_datasets = msn_instance.datasets(False)
        list_of_datasets = list(
            filter(
                lambda x: x not in [
                    'medshapenetcore/ASOCA', 'medshapenetcore/AVT',
                    'medshapenetcore/AutoImplantCraniotomy',
                    'medshapenetcore/FaceVR'
                ], list_of_datasets))

        train_size = int(0.7 * self.size)  # 70% for training
        val_size = int(0.15 * self.size)  # 15% for validation
        test_size = self.size - train_size - val_size  # Remainder for testing

        # Per-bucket: download the meshes as ``.npz`` files and record which
        # relative paths belong to which split.
        train_list, val_list, test_list = [], [], []
        for dataset in list_of_datasets:
            # NOTE(review): mutating ``self.newpath`` inside the loop leaves
            # only the last bucket's directory on the instance.
            self.newpath = self.root + '/' + dataset.split("/")[1]
            if not os.path.exists(self.newpath):
                os.makedirs(self.newpath)
            stl_files = msn_instance.dataset_files(dataset, '.stl')
            stl_files = stl_files[:self.size]

            # Random (unseeded) split of this bucket's files into the three
            # splits; ``.indices`` maps back into ``stl_files``.
            train_data, val_data, test_data = random_split(
                stl_files, [train_size, val_size, test_size])

            train_list.extend([stl_files[idx] for idx in train_data.indices])
            val_list.extend([stl_files[idx] for idx in val_data.indices])
            test_list.extend([stl_files[idx] for idx in test_data.indices])

            for stl_file in stl_files:
                msn_instance.download_stl_as_numpy(bucket_name=dataset,
                                                   stl_file=stl_file,
                                                   output_dir=self.newpath,
                                                   print_output=False)

        # Integer label per class; keys mirror ``raw_file_names``.
        class_mapping = {
            '3DTeethSeg': 0,
            'CoronaryArteries': 1,
            'FLARE': 2,
            'KITS': 3,
            'PULMONARY': 4,
            'SurgicalInstruments': 5,
            'ThoracicAorta_Saitta': 6,
            'ToothFairy': 7
        }

        # Build and save one processed file per split, in the same order as
        # ``processed_paths``.
        for dataset, path in zip([train_list, val_list, test_list],
                                 self.processed_paths):
            data_list = []
            for item in dataset:
                # ``item`` looks like '<ClassName>/<file>.stl'; the mesh was
                # stored as '.npz' by ``download_stl_as_numpy`` above.
                class_name = item.split("/")[0]
                # NOTE(review): splitting on the substring 'stl' breaks if
                # the file stem itself contains 'stl'; a suffix strip would
                # be safer.
                item = item.split("stl")[0]
                target = class_mapping[class_name]
                file = osp.join(self.root, item + 'npz')

                data = np.load(file)
                pre_data_list = Data(
                    pos=torch.tensor(data["vertices"], dtype=torch.float),
                    face=torch.tensor(data["faces"],
                                      dtype=torch.long).t().contiguous())
                pre_data_list.y = torch.tensor([target], dtype=torch.long)
                data_list.append(pre_data_list)

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, path)

0 comments on commit eb6048c

Please sign in to comment.