Skip to content

Commit

Permalink
Improves perfomrance
Browse files Browse the repository at this point in the history
  • Loading branch information
PTNobel committed Dec 16, 2024
1 parent 5e5f0ff commit f8f54e8
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions utils/sherlock_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
import sys
if len(sys.argv) == 1 or len(sys.argv) > 3:
raise RuntimeError()

elif len(sys.argv) == 2:
task_id = 0xEE364A
else len(sys.argv) == 3:
elif len(sys.argv) == 3:
task_id = int(sys.argv[2])

data_dir = "/oak/stanford/groups/candes/for_parth"
Expand All @@ -31,17 +30,17 @@

print('Data loading')
X = ad.matrix.concatenate(
[ad.matrix.dense(covars_dense)] +
[ad.matrix.dense(covars_dense, n_threads=32)] +
[
ad.matrix.snp_unphased(
ad.io.snp_unphased(
os.path.join(cache_dir, f"EUR_subset_chr{chr}.snpdat"),
)
), n_threads=32
)
for chr in chromosomes],
axis=1,
)
print(X.shape)
print(f'{X.shape=}')

rng = np.random.default_rng(task_id)
P = np.random.permutation(y.shape[-1])
Expand All @@ -52,14 +51,18 @@
y_train = y[train_mask]
X_test = X[test_mask]
y_test = y[test_mask]
print(f'{X_train.shape=}')
print(f'{X_test.shape=}')

t0 = time.monotonic()
ti_solve = time.monotonic()
state = ad.grpnet(
X=X_train,
glm=ad.glm.gaussian(y_train),
early_exit=False,
min_ratio=1e-6,
n_threads=32,
)
tf = time.monotonic()
print(f"{tf-t0} seconds for solve")
tf_solve = time.monotonic()


loss = torch.nn.MSELoss()
Expand All @@ -72,6 +75,8 @@
oos[i] = loss(torch.from_numpy(y_hat_test[i]), torch.from_numpy(y_test))
ins[i] = loss(torch.from_numpy(y_hat_train[i]), torch.from_numpy(y_train))

ld, alo, ts, r2 = ai.get_alo_for_sweep(y_train, state, loss)
ti_alo = time.monotonic()
ld, alo, ts, r2 = ai.get_alo_for_sweep(y_train, state, loss, 10)
tf_alo = time.monotonic()

np.savez(sys.argv[1], lamda=ld, alo=alo, oos=oos, in_sample=ins, ts=ts, r2=r2)
np.savez(sys.argv[1], lamda=ld, alo=alo, oos=oos, in_sample=ins, ts=ts, r2=r2, solve_time=tf_solve - ti_solve, alo_time=tf_alo - ti_alo)

0 comments on commit f8f54e8

Please sign in to comment.