Skip to content

Commit

Permalink
clearn: cexact: properly initialize storage before learning
Browse files Browse the repository at this point in the history
  • Loading branch information
RenatoGeh committed Aug 29, 2023
1 parent 984f8e7 commit 4488c12
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions pasp/cexact.c
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ bool prob_obs(program_t *P, observations_t *obs, bool lstables_sat, prob_storage
PyErr_SetString(PyExc_ValueError, "received NULL prob_storage_t as argument!");
goto cleanup;
}
if (!prob_obs_reuse(P, obs, lstables_sat, ret, Q, derive)) goto cleanup;
if (!prob_obs_reuse(P, obs, lstables_sat, ret, Q, derive, num_procs)) goto cleanup;

return true;
cleanup:
Expand All @@ -798,10 +798,9 @@ bool prob_obs(program_t *P, observations_t *obs, bool lstables_sat, prob_storage
}

bool prob_obs_reuse(program_t *P, observations_t *obs, bool lstable_sat, prob_storage_t *ret,
prob_storage_t Q[NUM_PROCS], bool derive) {
prob_storage_t Q[NUM_PROCS], bool derive, size_t num_procs) {
total_choice_t theta;
size_t total_choice_n = get_num_facts(P);
size_t num_procs = estimate_nprocs(total_choice_n + P->AD_n + P->NA_n);
bool busy_procs[NUM_PROCS] = {0};
storage_t S[NUM_PROCS] = {{0}};
size_t i;
Expand Down
2 changes: 1 addition & 1 deletion pasp/cexact.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ bool prob_obs(program_t *P, observations_t *obs, bool lstables_sat, prob_storage
/* Same as prob_obs, but reuse the prob_storage_t's in Q. It's memory safe to assign ret to &Q[0]
* or NULL; the latter used if the user prefers to access data directly from Q. */
bool prob_obs_reuse(program_t *P, observations_t *obs, bool lstable_sat, prob_storage_t *ret,
prob_storage_t Q[NUM_PROCS], bool derive);
prob_storage_t Q[NUM_PROCS], bool derive, size_t num_procs);

#endif
4 changes: 2 additions & 2 deletions pasp/clearn.c
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ bool learn(program_t *P, PyArrayObject *obs, PyArrayObject *obs_counts,
if (!forward_neural(P, &O)) goto cleanup;

/* Compute probabilities. */
if (!prob_obs_reuse(P, &O, lstable_sat, NULL, Q, derive)) goto cleanup;
if (!prob_obs_reuse(P, &O, lstable_sat, NULL, Q, derive, num_procs)) goto cleanup;

alg[which](P, &Q[0], N, eta, obs_counts, &O);

Expand Down Expand Up @@ -387,7 +387,7 @@ bool learn_batch(program_t *P, PyArrayObject *obs, size_t niters, double eta, si
if (!forward_neural(P, &O)) goto cleanup;

/* Compute probabilities. */
if (!prob_obs_reuse(P, &O, lstable_sat, NULL, Q, derive)) goto cleanup;
if (!prob_obs_reuse(P, &O, lstable_sat, NULL, Q, derive, num_procs)) goto cleanup;

alg[which](P, &Q[0], &O, eta, smooth);

Expand Down
2 changes: 1 addition & 1 deletion pasp/cstorage.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ bool prob_storage_learnable(prob_storage_t *S) { return S->n || S->m || S->pr ||

size_t init_prob_storage_seq(prob_storage_t Q[NUM_PROCS], program_t *P, observations_t *O) {
size_t total_choice_n = get_num_facts(P);
size_t num_procs = estimate_nprocs(total_choice_n + P->AD_n);
size_t num_procs = estimate_nprocs(total_choice_n + P->AD_n + P->NA_n);
size_t i = 0;

for (i = 0; i < num_procs; ++i) {
Expand Down

0 comments on commit 4488c12

Please sign in to comment.