Skip to content

Commit 0f15cea

Browse files
authored
warn if taking derivs of a method calling model calculate without model, updateNodes, or constantNodes args (#1566)
* Fix misspelling in an error msg. * Add a bit more on alternatives to default samplers in documentation. * Provide a test of compilation in Chapter 4 of manual. * Make BNP docn more user-friendly for less theoretical users. * Make minor edit to nimDerivs roxygen. * Fix up nimDerivs roxygen to be more clear about derivs of nfs. * Add check that correct args passed to nimDerivs when taking deriv of method containing calculate call (NCT issue 557). * Fix test for new AD warning. * Extend checking of nimDerivs args to nested case, and hide checking behind an option. * Fix typo in function name. * Make slight change to comment.
1 parent 7d2f839 commit 0f15cea

5 files changed

Lines changed: 288 additions & 11 deletions

File tree

packages/nimble/R/RCfunction_core.R

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,66 @@ nf_checkDSLcode_buildDerivs <- function(code, buildDerivs) {
263263
if(isFALSE(buildDerivs) || !length(buildDerivs) || is.null(buildDerivs) ||
264264
(is.character(buildDerivs) && !methodName %in% buildDerivs) ||
265265
(is.list(buildDerivs) && !methodName %in% names(buildDerivs)))
266-
message(" [Note] Detected use of `nimDerivs` with a function or method, `", methodName, "`, for which `buildDerivs` has not been set. This nimbleFunction cannot be compiled.")
266+
messageIfVerbose(" [Note] Detected use of `nimDerivs` with a function or method, `", methodName, "`, for which `buildDerivs` has not been set. This nimbleFunction cannot be compiled.")
267267
}
268268

269269
}
270270
}
271271
invisible(NULL)
272272
}
273273

274+
nf_checkDSLcode_checkForCalc <- function(code) {
275+
code <- body(code)
276+
return(sum(all.names(code) == "calculate") != sum(all.vars(code)=="calculate"))
277+
}
278+
279+
nf_checkDSLcode_checkDerivsOf <- function(code) {
280+
code <- body(code)
281+
derivsFound <- which(findDerivsCalls(code))
282+
if(length(derivsFound)) {
283+
derivsOf <- sapply(derivsFound, function(i)
284+
return(code[[i]][[3]][[2]][[1]]))
285+
return(as.character(derivsOf[sapply(derivsOf, is.name)]))
286+
}
287+
return(NULL)
288+
}
289+
290+
findDerivsCalls <- function(code) {
291+
## This assumes `derivs()` call is from assignment like `var <- derivs()`.
292+
sapply(code, function(expr)
293+
length(expr) >= 3 && length(expr[[1]]) == 1 &&
294+
as.character(expr[[1]]) %in% c("=", "<-", "<<-") &&
295+
length(expr[[3]]) > 1 && length(expr[[3]][[1]]) == 1 &&
296+
as.character(expr[[3]][[1]]) %in% c('derivs', 'nimDerivs'))
297+
}
298+
299+
checkNestedCalcCall <- function(functionName, methodsWithCalc, methodsDerivsOf) {
300+
if(functionName %in% methodsWithCalc) return(TRUE)
301+
if(functionName %in% names(methodsDerivsOf))
302+
return(any(sapply(methodsDerivsOf[[functionName]], checkNestedCalcCall,
303+
methodsWithCalc, methodsDerivsOf)))
304+
return(FALSE)
305+
}
306+
307+
nf_checkDSLcode_calcDerivsArgs <- function(code, methodsWithCalc, methodsDerivsOf) {
308+
code <- body(code)
309+
derivsFound <- which(findDerivsCalls(code))
310+
for(idx in derivsFound) {
311+
argNames <- names(code[[idx]][[3]])
312+
call <- code[[idx]][[3]][[2]][[1]]
313+
if(length(call) == 1 && checkNestedCalcCall(as.character(call), methodsWithCalc, methodsDerivsOf) &&
314+
length(setdiff(c('model', 'constantNodes', 'updateNodes'), argNames)))
315+
messageIfVerbose(" [Warning] Detected use of `nimDerivs` on a function or method, `", code[[idx]][[3]][[2]][[1]], "`,\n",
316+
" that appears to contain the use of `calculate` on a model.\n",
317+
" If model calculations are done in the method being differentiated, the 'model'\n",
318+
" argument to 'nimDerivs' should be included to ensure correct restoration of\n",
319+
" values in the model, and the 'updateNodes' and 'constantNodes' arguments\n",
320+
" should also be provided (see Section 16.7.2 of the User Manual).")
321+
}
322+
invisible(NULL)
323+
}
324+
325+
274326
nf_checkDSLcode <- function(code, methodNames, setupVarNames, args, where = NULL) {
275327
validCalls <- c(names(sizeCalls),
276328
otherDSLcalls,

packages/nimble/R/nimbleFunction_Rderivs.R

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -703,15 +703,7 @@ nimDerivs_nf <- function(call = NA,
703703
if(e$restoreInfo$deepestDepth < e$restoreInfo$currentDepth)
704704
e$restoreInfo$deepestDepth <- e$restoreInfo$currentDepth
705705
}
706-
} else { # partial check for whether there is a model in the nimbleFunction
707-
if(is(derivFxn, 'refMethodDef') && is.nf(e$.self)) {
708-
isModel <- sapply(names(e), function(x) is.model(e[[x]]))
709-
if(any(isModel)) {
710-
modelElement <- names(e)[which(isModel)]
711-
warning("nimDerivs_nf: detected a model, ", paste(modelElement, collapse = ','), ", associated with the nimbleFunction whose method is being differentiated. If model calculations are done in the method being differentiated, the 'model' argument to 'nimDerivs' should be included to ensure correct restoration of values in the model.")
712-
}
713-
}
714-
}
706+
}
715707

716708
## standardize the derivFxnCall arguments
717709
derivFxnCall <- match.call(derivFxn, derivFxnCall)

packages/nimble/R/nimbleFunction_core.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,25 @@ nimbleFunction <- function(setup = NULL,
8585
force(where) # so that we can get to namespace where a nf is defined by using topenv(parent.frame(2)) in getNimbleFunctionEnvironment()
8686
if(is.logical(setup)) if(setup) setup <- function() {} else setup <- NULL
8787

88+
8889
## Check for correct entries in `buildDerivs` separately from `nfMethodRC$new()` because
8990
## that only has access to `thisBuildDerivs`, and we need to check if `buildDerivs` is set
9091
## for the method on which `nimDerivs` is called.
9192
tmp <- sapply(c(list(run = run), methods), nf_checkDSLcode_buildDerivs, buildDerivs)
93+
94+
## Check that if a model calculate is in the code of `run` or another method on
95+
## which `derivs` is called, that the `model`, `updateNodes`,and `constantNodes`
96+
## arguments are provided.
97+
if(getNimbleOption('checkDerivsArgs') && length(buildDerivs)) {
98+
allMethods <- c(list(run = run), methods)
99+
if(is.character(buildDerivs)) nms <- buildDerivs else nms <- names(buildDerivs)
100+
methodsWithCalc <- sapply(allMethods[nms], nf_checkDSLcode_checkForCalc)
101+
methodsWithCalc <- nms[methodsWithCalc]
102+
methodsDerivsOf <- sapply(allMethods, nf_checkDSLcode_checkDerivsOf)
103+
methodsDerivsOf <- methodsDerivsOf[!sapply(methodsDerivsOf, is.null)]
104+
if(length(methodsWithCalc))
105+
tmp <- sapply(c(list(run = run), methods), nf_checkDSLcode_calcDerivsArgs, methodsWithCalc, methodsDerivsOf)
106+
}
92107

93108
if(is.null(setup)) {
94109
if(length(methods) > 0) stop('Cannot provide multiple methods if there is no setup function. Use "setup = function(){}" or "setup = TRUE" if you need a setup function that does not do anything', call. = FALSE)

packages/nimble/R/options.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ nimOptimMethod("bobyqa",
226226
useOldcWiseRule = FALSE, # This is a safety toggle for one change in sizeBinaryCwise, 1/24/23. After a while we can remove this.
227227
stripUnusedTypeDefs = TRUE,
228228
digits = NULL,
229-
enableVirtualNodeFunctionDefs = FALSE
229+
enableVirtualNodeFunctionDefs = FALSE,
230+
checkDerivsArgs = TRUE
230231
)
231232
)
232233

packages/nimble/tests/testthat/test-ADerrorTrapping.R

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,220 @@ test_that("Incorrect use of buildDerivs=TRUE in nimbleFunction with setup.", {
327327
)
328328
})
329329

330+
test_that("Warning message works for use of nimDerivs with model calculate and incorrect args", {
331+
expect_silent(
332+
mynf <- nimbleFunction(
333+
setup = function(model){
334+
paramNodes = 'psi[1:4]'
335+
},
336+
run = function(x = double(1), alpha=double(1)) {
337+
returnType(double(0))
338+
inds <- 1:length(x)
339+
340+
tmp <- derivs(dens_calc(x), inds, order = c(0,1), model = model, constantNodes = "", updateNodes = "")
341+
342+
## Do AD on ddirch directly. This works.
343+
tmp <- derivs(dens_direct(x, alpha), inds, order = c(0,1))
344+
345+
## Do AD on model$calculate. This works.
346+
tmp <- derivs(model$calculate(paramNodes), wrt = paramNodes, order = c(0,1))
347+
348+
return(dens_calc(x))
349+
},
350+
methods = list(
351+
## This mimics calcPrior_p in nimbleQuad
352+
dens_calc = function(x = double(1)) {
353+
values(model, paramNodes) <<- x
354+
result <- model$calculate(paramNodes)
355+
returnType(double(0))
356+
return(result)
357+
},
358+
dens_direct = function(x = double(1), alpha=double(1)) {
359+
result <- ddirch(x, alpha, log = TRUE)
360+
calculate <- 7
361+
returnType(double(0))
362+
return(result)
363+
}
364+
),
365+
buildDerivs = c('dens_calc','dens_direct')
366+
))
367+
368+
expect_message(
369+
mynf <- nimbleFunction(
370+
setup = function(model){
371+
paramNodes = 'psi[1:4]'
372+
},
373+
run = function(x = double(1), alpha=double(1)) {
374+
returnType(double(0))
375+
inds <- 1:length(x)
376+
377+
tmp = derivs(dens_calc(x), inds, order = c(0,1), constantNodes = "", updateNodes = "")
378+
379+
return(dens_calc(x))
380+
},
381+
methods = list(
382+
## This mimics calcPrior_p in nimbleQuad
383+
dens_calc = function(x = double(1)) {
384+
values(model, paramNodes) <<- x
385+
result <- model$calculate(paramNodes)
386+
returnType(double(0))
387+
return(result)
388+
},
389+
dens_direct = function(x = double(1), alpha=double(1)) {
390+
result <- ddirch(x, alpha, log = TRUE)
391+
calculate <- 7
392+
returnType(double(0))
393+
return(result)
394+
}
395+
),
396+
buildDerivs = c('dens_calc','dens_direct')
397+
), "appears to contain the use of `calculate` on a model")
398+
399+
expect_message(
400+
mynf <- nimbleFunction(
401+
setup = function(model){
402+
paramNodes = 'psi[1:4]'
403+
},
404+
run = function(x = double(1), alpha=double(1)) {
405+
returnType(double(0))
406+
inds <- 1:length(x)
407+
408+
tmp <- derivs(dens_calc(x), inds, order = c(0,1), model = model)
409+
410+
return(dens_calc(x))
411+
},
412+
methods = list(
413+
## This mimics calcPrior_p in nimbleQuad
414+
dens_calc = function(x = double(1)) {
415+
values(model, paramNodes) <<- x
416+
result <- model$calculate(paramNodes)
417+
returnType(double(0))
418+
return(result)
419+
},
420+
dens_direct = function(x = double(1), alpha=double(1)) {
421+
result <- ddirch(x, alpha, log = TRUE)
422+
calculate <- 7
423+
returnType(double(0))
424+
return(result)
425+
}
426+
),
427+
buildDerivs = c('dens_calc','dens_direct')
428+
), "appears to contain the use of `calculate` on a model")
429+
})
430+
431+
432+
433+
434+
435+
test_that("Warning message works for use of nimDerivs with nested model calculate and incorrect args", {
436+
expect_silent(
437+
mynf <- nimbleFunction(
438+
setup = function(model){
439+
paramNodes = 'psi[1:4]'
440+
},
441+
run = function(x = double(1), alpha=double(1)) {
442+
},
443+
methods = list(
444+
inner_logLik = function(reTransform = double(1)) {
445+
values(model, randomEffectsNodes) <<- reTransform
446+
ans <- model$calculate(innerCalcNodes)
447+
return(ans)
448+
returnType(double())
449+
},
450+
gr_inner_logLik_internal = function(reTransform = double(1)) {
451+
ans <- derivs(inner_logLik(reTransform), wrt = re_indices_inner, order = 1, model = model,
452+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
453+
return(ans$jacobian[1,])
454+
returnType(double(1))
455+
},
456+
## Double taping for efficiency
457+
he_inner_logLik_internal = function(reTransform = double(1)) {
458+
ans <- derivs(gr_inner_logLik_internal(reTransform), wrt = re_indices_inner, order = 1, model = model,
459+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
460+
return(ans$jacobian)
461+
returnType(double(2))
462+
},
463+
he_inner_logLik_internal = function(reTransform = double(1)) {
464+
ans <- derivs(gr_inner_logLik_internal(reTransform), wrt = re_indices_inner, order = 0, model = model,
465+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
466+
return(ans$value)
467+
returnType(double(2))
468+
}
469+
), buildDerivs = c('inner_logLik','gr_inner_logLik_internal','he_inner_logLik_internal')
470+
)
471+
)
472+
473+
expect_message(
474+
mynf <- nimbleFunction(
475+
setup = function(model){
476+
paramNodes = 'psi[1:4]'
477+
},
478+
run = function(x = double(1), alpha=double(1)) {
479+
},
480+
methods = list(
481+
inner_logLik = function(reTransform = double(1)) {
482+
values(model, randomEffectsNodes) <<- reTransform
483+
ans <- model$calculate(innerCalcNodes)
484+
return(ans)
485+
returnType(double())
486+
},
487+
gr_inner_logLik_internal = function(reTransform = double(1)) {
488+
ans <- derivs(inner_logLik(reTransform), wrt = re_indices_inner, order = 1, model = model,
489+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
490+
return(ans$jacobian[1,])
491+
returnType(double(1))
492+
},
493+
he_inner_logLik_internal = function(reTransform = double(1)) {
494+
ans <- derivs(gr_inner_logLik_internal(reTransform), wrt = re_indices_inner, order = 1, model = model,
495+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
496+
return(ans$jacobian)
497+
returnType(double(2))
498+
},
499+
he_inner_logLik = function(reTransform = double(1)) {
500+
ans <- derivs(gr_inner_logLik_internal(reTransform), wrt = re_indices_inner, order = 0, model = model,
501+
constantNodes = inner_constantNodes) # Missing `updateNodes`.
502+
return(ans$value)
503+
returnType(double(2))
504+
}
505+
),
506+
buildDerivs = c('inner_logLik','gr_inner_logLik_internal','he_inner_logLik_internal')
507+
), "appears to contain the use of `calculate` on a model")
508+
509+
expect_message(
510+
mynf <- nimbleFunction(
511+
setup = function(model){
512+
paramNodes = 'psi[1:4]'
513+
},
514+
run = function(x = double(1), alpha=double(1)) {
515+
},
516+
methods = list(
517+
inner_logLik = function(reTransform = double(1)) {
518+
values(model, randomEffectsNodes) <<- reTransform
519+
ans <- model$calculate(innerCalcNodes)
520+
return(ans)
521+
returnType(double())
522+
},
523+
gr_inner_logLik_internal = function(reTransform = double(1)) {
524+
ans <- derivs(inner_logLik(reTransform), wrt = re_indices_inner, order = 1, model = model,
525+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
526+
return(ans$jacobian[1,])
527+
returnType(double(1))
528+
},
529+
he_inner_logLik_internal = function(reTransform = double(1)) {
530+
ans <- derivs(gr_inner_logLik_internal(reTransform), wrt = re_indices_inner, order = 1,
531+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes) # Missing `model`.
532+
return(ans$jacobian)
533+
returnType(double(2))
534+
},
535+
he_inner_logLik = function(reTransform = double(1)) {
536+
ans <- derivs(gr_inner_logLik_internal(reTransform), wrt = re_indices_inner, order = 0, model = model,
537+
updateNodes = inner_updateNodes, constantNodes = inner_constantNodes)
538+
return(ans$value)
539+
returnType(double(2))
540+
}
541+
),
542+
buildDerivs = c('inner_logLik','gr_inner_logLik_internal','he_inner_logLik_internal')
543+
), "appears to contain the use of `calculate` on a model")
544+
})
545+
546+

0 commit comments

Comments
 (0)