changelog and iterable lists
jdhenaos committed Dec 11, 2024
1 parent e2f7faf commit 8a1e2d4
Showing 2 changed files with 24 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added `MedShapeNet` Dataset ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823))
 - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
@@ -25,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Adapt `dgcnn_classification` example to work with `ModelNet` and `MedShapeNet` Datasets ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823))
 - Dropped Python 3.8 support ([#9606](https://github.com/pyg-team/pytorch_geometric/pull/9606))
 - Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))
32 changes: 22 additions & 10 deletions torch_geometric/datasets/medshapenet.py
@@ -88,33 +88,45 @@ def process(self) -> None:
         from MedShapeNet import MedShapeNet as msn
         from torch.utils.data import random_split
         import urllib3

         msn_instance = msn()

-        pool = urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50)
+        urllib3.HTTPConnectionPool("medshapenet.ddns.net", maxsize=50)

         list_of_datasets = msn_instance.datasets(False)
-        list_of_datasets = list(filter(lambda x: x not in ['medshapenetcore/ASOCA','medshapenetcore/AVT','medshapenetcore/AutoImplantCraniotomy','medshapenetcore/FaceVR'], list_of_datasets))
+        list_of_datasets = list(filter(lambda x: x not in ['medshapenetcore/ASOCA',
+                                                           'medshapenetcore/AVT',
+                                                           'medshapenetcore/AutoImplantCraniotomy',
+                                                           'medshapenetcore/FaceVR'],
+                                       list_of_datasets))

         train_size = int(0.7 * self.size)  # 70% for training
         val_size = int(0.15 * self.size)  # 15% for validation
         test_size = self.size - train_size - val_size  # Remainder for testing

-        train_list, val_list, test_list = [], [], []
+        train_list: List[str] = []
+        val_list: List[str] = []
+        test_list: List[str] = []
         for dataset in list_of_datasets:
             self.newpath = self.root + '/' + dataset.split("/")[1]
             if not os.path.exists(self.newpath):
                 os.makedirs(self.newpath)
             stl_files = msn_instance.dataset_files(dataset, '.stl')
             stl_files = stl_files[:self.size]

-            train_data, val_data, test_data = random_split(stl_files, [train_size, val_size, test_size])
-            train_list.extend(train_data)
-            val_list.extend(val_data)
-            test_list.extend(test_data)
+            train_data, val_data, test_data = random_split(stl_files, [train_size,
+                                                                       val_size,
+                                                                       test_size])
+            train_list.extend(list(train_data))
+            val_list.extend(list(val_data))
+            test_list.extend(list(test_data))

             for stl_file in stl_files:
-                msn_instance.download_stl_as_numpy(bucket_name = dataset, stl_file = stl_file, output_dir = self.newpath, print_output=False)
+                msn_instance.download_stl_as_numpy(bucket_name = dataset,
+                                                   stl_file = stl_file,
+                                                   output_dir = self.newpath,
+                                                   print_output=False)

         class_mapping = {
             '3DTeethSeg': 0,
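The split arithmetic in the hunk above (70% train, 15% validation, remainder to test) can be sketched in plain Python. This is a dependency-free illustration, not the committed code: `split_files` is a hypothetical helper, and `random.sample` stands in for `torch.utils.data.random_split`; the sizing rule itself (remainder absorbs integer-truncation loss) is taken from the diff.

```python
import random

def split_files(stl_files, size):
    """Split a file list using the commit's 70/15/remainder sizing rule."""
    train_size = int(0.7 * size)              # 70% for training
    val_size = int(0.15 * size)               # 15% for validation
    test_size = size - train_size - val_size  # remainder absorbs rounding loss

    shuffled = random.sample(stl_files, size)  # shuffle without replacement
    train = shuffled[:train_size]
    val = shuffled[train_size:train_size + val_size]
    test = shuffled[train_size + val_size:]
    return train, val, test

files = [f"shape_{i}.stl" for i in range(10)]
train, val, test = split_files(files, len(files))
# int(0.7 * 10) = 7 train, int(0.15 * 10) = 1 val, 10 - 7 - 1 = 2 test
```

Because both fractional sizes are truncated with `int()`, giving the test split the remainder guarantees the three sizes always sum to `size`, which `random_split` requires of its `lengths` argument.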
