Skip to content

Commit

Permalink
[Feature] Plot estimated coefficient functions (#5)
Browse files Browse the repository at this point in the history
* CHG: add plot modules

* CHG: add the essential manual; update the env settings; add unit tests

* FIX: correct the global variable issue

* CHG: adjust the plot settings

* CHG: update the release note; increment version number
  • Loading branch information
egpivo authored Jan 9, 2024
1 parent 42192e2 commit 570a59b
Show file tree
Hide file tree
Showing 12 changed files with 210 additions and 9 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: QuantRegGLasso
Title: Adaptively Weighted Group Lasso for Semiparametic Quantile Rgression Models
Version: 0.3.0
Version: 0.4.0
Authors@R: c(person(
given = "Wen-Ting",
family = "Wang",
Expand Down Expand Up @@ -35,7 +35,7 @@ BugReports: https://github.com/egpivo/QuantRegGLasso/issues
Depends:
R (>= 3.4.0)
Imports:
Rcpp (>= 1.0.10)
Rcpp (>= 1.0.10), ggplot2
LinkingTo: Rcpp, RcppArmadillo
Suggests:
knitr,
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ export(qrglasso)
importFrom(Rcpp, evalCpp)
importFrom(splines, bs)
importFrom(stats, runif)
importFrom(graphics, par)
import(ggplot2)
S3method(plot, qrglasso.predict)
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
## QuantRegGLasso 0.3.0 (Release Date: 2024-01-09)
### Overview
- Added a `plot.qrglasso.predict` function for displaying top k coefficient functions via `qrglasso.predict` object.

---

## QuantRegGLasso 0.3.0 (Release Date: 2024-01-08)
### Overview
- Added a `predict` function for estimating top k coefficient functions via `qrglasso` object
- Added a helper function to meet the pre-conditions of `predict`

---

## QuantRegGLasso 0.2.0 (Release Date: 2024-01-06)
### Overview
Expand Down
4 changes: 4 additions & 0 deletions R/global.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
utils::globalVariables(c(
"z",
"coefficient"
))
35 changes: 35 additions & 0 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,38 @@ check_predict_parameters <- function(qrglasso_object, top_k, degree, boundaries)
stop("Please enter a smaller degree")
}
}

#' Internal function: Plot sequentially
#' @keywords internal
#' @param objs Valid ggplot2 objects
#' @return `NULL`
#'
plot_sequentially <- function(objs) {
originalPar <- par(no.readonly = TRUE)
on.exit(par(originalPar))
par(ask = TRUE)
suppressWarnings({
for (obj in objs) {
suppressWarnings(print(obj))
}
})
par(ask = FALSE)
}

#' Internal function: Plot 2D fields for cross validation results
#' @keywords internal
#' @param data A dataframe contains columns ``z``, ``coefficient``
#' @param variate A character represent the title
#' @return A ggplot object
plot_coefficient_function <- function(data, variate) {
default_theme <- theme_classic() +
theme(
text = element_text(size = 24),
plot.title = element_text(hjust = 0.5)
)
result <- ggplot(data, aes(x = z, y = coefficient)) +
geom_point(col="#4634eb") +
ggtitle(variate) +
default_theme
return(result)
}
57 changes: 52 additions & 5 deletions R/qrglasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,15 @@ qrglasso <-
#'
#' @description Predict the top-k coefficient functions
#'
#' @param qrglasso_object An `qrglasso` class object.
#' @param qrglasso_object An \code{qrglasso} class object.
#' @param top_k Integer. A matrix of the top K estimated functions. Default is 5.
#' @param degree Integer. Degree of the piecewise polynomial. Default is 2.
#' @param boundaries Array. Two boundary points. Default is c(0, 1).
#' @param is_approx Logical. If TRUE, the size of covariate indexes will be 1e6; otherwise, 1e4. Default is FALSE.
#' @seealso \link{qrglasso}
#' @return A prediction matrix of Y at the new locations, x_new.
#' @seealso \code{\link{qrglasso}}
#' @return A list containing:
#' \item{coef_functions}{Matrix. Top-k coefficient function estimates with dimenstion (\eqn{m \times k}) where $m$ is size of `z`.}
#' \item{z}{Array. Index predictors used in generation}
#' @examples
#' set.seed(123)
#' n <- 100
Expand All @@ -164,7 +166,7 @@ qrglasso <-
#' # Call qrglasso with default parameters
#' result <- qrglasso(Y = Y, W = W, L = 5)
#' estimate <- predict(result)
#' print(dim(estimate))
#' print(dim(estimate$coef_functions))
#'
predict <-
function(qrglasso_object,
Expand All @@ -181,5 +183,50 @@ predict <-
z <- approx_bsplines$z
gamma_hat <- qrglasso_object$gamma[, which.min(qrglasso_object$BIC[, 1])]
estimate <- bsplines %*% matrix(gamma_hat, nrow = dim(bsplines)[2])
return(estimate[, 1:min(top_k, dim(estimate)[2])])
obj.predict <- list(
coef_functions = as.matrix(estimate[, 1:min(top_k, dim(estimate)[2])]),
z = z
)
class(obj.predict) <- "qrglasso.predict"
return(obj.predict)
}

#' @title Display the estimated coefficient functions
#'
#' @description Display the estimated coefficient functions by BIC
#'
#' @param x An object of class \code{qrglasso.predict} for the \code{plot} method
#' @param ... Not used directly
#' @return \code{NULL}
#' @seealso \code{\link{qrglasso}}
#'
#' @export
#' @method plot qrglasso.predict
#' @examples
#' set.seed(123)
#' n <- 100
#' p <- 5
#' L <- 5
#' Y <- matrix(rnorm(n), n, 1)
#' W <- matrix(rnorm(n * p * (L - 1)), n, p * (L - 1))
#'
#' result <- qrglasso(Y = Y, W = W, L = 5)
#' estimate <- predict(result, top_k = 2)
#' plot(estimate)
#'
plot.qrglasso.predict <- function(x, ...) {
if (!inherits(x, "qrglasso.predict")) {
stop("Invalid object! Please enter a `qrglasso.predict` object")
}
originalPar <- par(no.readonly = TRUE)
k <- dim(x$coef_functions)[2]
result <- list()
for (i in 1:k) {
variate <- paste0("Coefficient function - g", i)
data <- data.frame(z = x$z, coefficient = x$coef_functions[,i])
result[[variate]] <- plot_coefficient_function(data, variate)
}
plot_sequentially(result)
par(originalPar)
}

35 changes: 35 additions & 0 deletions man/plot.qrglasso.predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/plot_coefficient_function.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/plot_sequentially.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test_helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,32 @@ test_that("check_predict_parameters correctly validates parameters", {
expect_error(check_predict_parameters(mock_qrglasso, top_k = 3, degree = 11, boundaries = c(0, 1)),
"Please enter a smaller degree")
})

# Test plot_sequentially
test_that("plot_sequentially prints ggplot2 objects", {
# Create mock ggplot2 objects for testing
plot1 <- ggplot(mtcars, aes(x = mpg, y = disp)) + geom_point()
plot2 <- ggplot(mtcars, aes(x = wt, y = hp)) + geom_point()

# Capture the current environment before calling the function
original_env <- environment()

# Test the function
plot_sequentially(list(plot1, plot2))

# Verify that the environment has been changed
expect_equal(original_env, environment())
})

# Test plot_coefficient_function
test_that("plot_coefficient_function returns a ggplot object", {
# Create a mock dataframe for testing
data <- data.frame(z = 1:10, coef = rnorm(10))
variate <- "Test Variate"

# Test the function
plot_result <- plot_coefficient_function(data, variate)

# Verify the output
expect_true(is.ggplot(plot_result))
})
1 change: 1 addition & 0 deletions tests/testthat/test_qrglasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ test_that("predictcoefficient functions", {
# Invalid boundaries order
expect_error(predict(mock_qrglasso, boundaries = c(1, 0)))
})

0 comments on commit 570a59b

Please sign in to comment.