Skip to content

Commit

Permalink
adjusted=False for BalancedAccuracy in pb (#6)
Browse files Browse the repository at this point in the history
* adjusted=False for BalancedAccuracy in pb

* lint
  • Loading branch information
frcaud authored Mar 21, 2023
1 parent 76b6384 commit a48776b
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,23 @@


N_FOLDS = 5
problem_title = 'Predict schizophrenia from \
brain grey matter (classification)'
problem_title = "Predict schizophrenia from \
brain grey matter (classification)"

_target_column_name = 'diagnosis'
_prediction_label_names = ['control', 'schizophrenia']
_target_column_name = "diagnosis"
_prediction_label_names = ["control", "schizophrenia"]

# A type (class) which will be used to create wrapper objects for y_pred
Predictions = rw.prediction_types.make_multiclass(
label_names=_prediction_label_names)
label_names=_prediction_label_names
)
# An object implementing the workflow
workflow = rw.workflows.Estimator()

# Score types
score_types = [
rw.score_types.ROCAUC(name='auc'),
rw.score_types.BalancedAccuracy(name='bacc')
rw.score_types.ROCAUC(name="auc"),
rw.score_types.BalancedAccuracy(name="bacc", adjusted=False),
]


Expand All @@ -32,8 +34,8 @@ def get_cv(X, y):
return cv_train.split(X, y)


def _read_data(path, dataset, datatype=['rois', 'vbm']):
""" Read data.
def _read_data(path, dataset, datatype=["rois", "vbm"]):
"""Read data.
Parameters
----------
Expand All @@ -54,45 +56,47 @@ def _read_data(path, dataset, datatype=['rois', 'vbm']):
"""
# Read target
participants = pd.read_csv(os.path.join(
path, 'data', "%s_participants.csv" % dataset))
participants = pd.read_csv(
os.path.join(path, "data", "%s_participants.csv" % dataset)
)
y_arr = participants[_target_column_name].values

x_arr_l = []
# Read ROIs
if 'rois' in datatype:
rois = pd.read_csv(os.path.join(
path, 'data', "%s_rois.csv" % dataset))
x_rois_arr = rois.loc[:, 'l3thVen_GM_Vol':]
if "rois" in datatype:
rois = pd.read_csv(os.path.join(path, "data", "%s_rois.csv" % dataset))
x_rois_arr = rois.loc[:, "l3thVen_GM_Vol":]
assert x_rois_arr.shape[1] == 284
x_arr_l.append(x_rois_arr)

# Read 3d images and mask
if 'vbm' in datatype:
imgs_arr_zip = np.load(os.path.join(path, 'data',
"%s_vbm.npz" % dataset))
x_img_arr = imgs_arr_zip['imgs_arr'].squeeze()
mask_arr = imgs_arr_zip['mask_arr']
if "vbm" in datatype:
imgs_arr_zip = np.load(
os.path.join(path, "data", "%s_vbm.npz" % dataset)
)
x_img_arr = imgs_arr_zip["imgs_arr"].squeeze()
mask_arr = imgs_arr_zip["mask_arr"]
x_img_arr = x_img_arr[:, mask_arr]
x_arr_l.append(x_img_arr)

x_arr = np.concatenate(x_arr_l, axis=1)

if datatype == ['rois', 'vbm']: # TODO: Remove this check
if datatype == ["rois", "vbm"]: # TODO: Remove this check
assert np.all(x_arr[:, :284] == x_rois_arr)
assert np.all(x_arr[:, 284:] == x_img_arr)

return x_arr, y_arr


def get_train_data(path='.', datatype=['rois', 'vbm']):
dataset = 'train'
def get_train_data(path=".", datatype=["rois", "vbm"]):
dataset = "train"
return _read_data(path, dataset, datatype)


def get_test_data(path='.', datatype=['rois', 'vbm']):
dataset = 'test'
def get_test_data(path=".", datatype=["rois", "vbm"]):
dataset = "test"
return _read_data(path, dataset, datatype)


# x_arr, y_arr = get_train_data()
# x_arr, y_arr = get_test_data()

0 comments on commit a48776b

Please sign in to comment.