-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
37 lines (32 loc) · 968 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Rao Yulong
import pickle
import parameter
import numpy as np
import pandas as pd
import random
import torch
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
seed = 2021
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
def save_obj(obj, name ):
with open(name + '.pkl', 'wb') as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
def load_obj(name ):
with open( name + '.pkl', 'rb') as f:
return pickle.load(f)
class presDataset(torch.utils.data.Dataset):
def __init__(self, a, b):
self.pS_array, self.pH_array = a, b
def __getitem__(self, idx):
sid = self.pS_array[idx]
hid = self.pH_array[idx]
return sid, hid
def __len__(self):
return self.pH_array.shape[0]