Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Plot estimated coefficient functions #5

Merged
merged 5 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: QuantRegGLasso
Title: Adaptively Weighted Group Lasso for Semiparametic Quantile Rgression Models
Version: 0.3.0
Version: 0.4.0
Authors@R: c(person(
given = "Wen-Ting",
family = "Wang",
Expand Down Expand Up @@ -35,7 +35,7 @@ BugReports: https://github.com/egpivo/QuantRegGLasso/issues
Depends:
R (>= 3.4.0)
Imports:
Rcpp (>= 1.0.10)
Rcpp (>= 1.0.10), ggplot2
LinkingTo: Rcpp, RcppArmadillo
Suggests:
knitr,
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ export(qrglasso)
importFrom(Rcpp, evalCpp)
importFrom(splines, bs)
importFrom(stats, runif)
importFrom(graphics, par)
import(ggplot2)
S3method(plot, qrglasso.predict)
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
## QuantRegGLasso 0.3.0 (Release Date: 2024-01-09)
### Overview
- Added a `plot.qrglasso.predict` function for displaying top k coefficient functions via `qrglasso.predict` object.

---

## QuantRegGLasso 0.3.0 (Release Date: 2024-01-08)
### Overview
- Added a `predict` function for estimating top k coefficient functions via `qrglasso` object
- Added a helper function to meet the pre-conditions of `predict`

---

## QuantRegGLasso 0.2.0 (Release Date: 2024-01-06)
### Overview
Expand Down
4 changes: 4 additions & 0 deletions R/global.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
utils::globalVariables(c(
"z",
"coefficient"
))
35 changes: 35 additions & 0 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,38 @@ check_predict_parameters <- function(qrglasso_object, top_k, degree, boundaries)
stop("Please enter a smaller degree")
}
}

#' Internal function: Plot sequentially
#' @keywords internal
#' @param objs Valid ggplot2 objects
#' @return `NULL`
#'
plot_sequentially <- function(objs) {
originalPar <- par(no.readonly = TRUE)
on.exit(par(originalPar))
par(ask = TRUE)
suppressWarnings({
for (obj in objs) {
suppressWarnings(print(obj))
}
})
par(ask = FALSE)
}

#' Internal function: Plot 2D fields for cross validation results
#' @keywords internal
#' @param data A dataframe contains columns ``z``, ``coefficient``
#' @param variate A character represent the title
#' @return A ggplot object
plot_coefficient_function <- function(data, variate) {
default_theme <- theme_classic() +
theme(
text = element_text(size = 24),
plot.title = element_text(hjust = 0.5)
)
result <- ggplot(data, aes(x = z, y = coefficient)) +
geom_point(col="#4634eb") +
ggtitle(variate) +
default_theme
return(result)
}
57 changes: 52 additions & 5 deletions R/qrglasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,15 @@ qrglasso <-
#'
#' @description Predict the top-k coefficient functions
#'
#' @param qrglasso_object An `qrglasso` class object.
#' @param qrglasso_object An \code{qrglasso} class object.
#' @param top_k Integer. A matrix of the top K estimated functions. Default is 5.
#' @param degree Integer. Degree of the piecewise polynomial. Default is 2.
#' @param boundaries Array. Two boundary points. Default is c(0, 1).
#' @param is_approx Logical. If TRUE, the size of covariate indexes will be 1e6; otherwise, 1e4. Default is FALSE.
#' @seealso \link{qrglasso}
#' @return A prediction matrix of Y at the new locations, x_new.
#' @seealso \code{\link{qrglasso}}
#' @return A list containing:
#' \item{coef_functions}{Matrix. Top-k coefficient function estimates with dimenstion (\eqn{m \times k}) where $m$ is size of `z`.}
#' \item{z}{Array. Index predictors used in generation}
#' @examples
#' set.seed(123)
#' n <- 100
Expand All @@ -164,7 +166,7 @@ qrglasso <-
#' # Call qrglasso with default parameters
#' result <- qrglasso(Y = Y, W = W, L = 5)
#' estimate <- predict(result)
#' print(dim(estimate))
#' print(dim(estimate$coef_functions))
#'
predict <-
function(qrglasso_object,
Expand All @@ -181,5 +183,50 @@ predict <-
z <- approx_bsplines$z
gamma_hat <- qrglasso_object$gamma[, which.min(qrglasso_object$BIC[, 1])]
estimate <- bsplines %*% matrix(gamma_hat, nrow = dim(bsplines)[2])
return(estimate[, 1:min(top_k, dim(estimate)[2])])
obj.predict <- list(
coef_functions = as.matrix(estimate[, 1:min(top_k, dim(estimate)[2])]),
z = z
)
class(obj.predict) <- "qrglasso.predict"
return(obj.predict)
}

#' @title Display the estimated coefficient functions
#'
#' @description Display the estimated coefficient functions by BIC
#'
#' @param x An object of class \code{qrglasso.predict} for the \code{plot} method
#' @param ... Not used directly
#' @return \code{NULL}
#' @seealso \code{\link{qrglasso}}
#'
#' @export
#' @method plot qrglasso.predict
#' @examples
#' set.seed(123)
#' n <- 100
#' p <- 5
#' L <- 5
#' Y <- matrix(rnorm(n), n, 1)
#' W <- matrix(rnorm(n * p * (L - 1)), n, p * (L - 1))
#'
#' result <- qrglasso(Y = Y, W = W, L = 5)
#' estimate <- predict(result, top_k = 2)
#' plot(estimate)
#'
plot.qrglasso.predict <- function(x, ...) {
if (!inherits(x, "qrglasso.predict")) {
stop("Invalid object! Please enter a `qrglasso.predict` object")
}
originalPar <- par(no.readonly = TRUE)
k <- dim(x$coef_functions)[2]
result <- list()
for (i in 1:k) {
variate <- paste0("Coefficient function - g", i)
data <- data.frame(z = x$z, coefficient = x$coef_functions[,i])
result[[variate]] <- plot_coefficient_function(data, variate)
}
plot_sequentially(result)
par(originalPar)
}

35 changes: 35 additions & 0 deletions man/plot.qrglasso.predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/plot_coefficient_function.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/plot_sequentially.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test_helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,32 @@ test_that("check_predict_parameters correctly validates parameters", {
expect_error(check_predict_parameters(mock_qrglasso, top_k = 3, degree = 11, boundaries = c(0, 1)),
"Please enter a smaller degree")
})

# Test plot_sequentially
test_that("plot_sequentially prints ggplot2 objects", {
# Create mock ggplot2 objects for testing
plot1 <- ggplot(mtcars, aes(x = mpg, y = disp)) + geom_point()
plot2 <- ggplot(mtcars, aes(x = wt, y = hp)) + geom_point()

# Capture the current environment before calling the function
original_env <- environment()

# Test the function
plot_sequentially(list(plot1, plot2))

# Verify that the environment has been changed
expect_equal(original_env, environment())
})

# Test plot_coefficient_function
test_that("plot_coefficient_function returns a ggplot object", {
# Create a mock dataframe for testing
data <- data.frame(z = 1:10, coef = rnorm(10))
variate <- "Test Variate"

# Test the function
plot_result <- plot_coefficient_function(data, variate)

# Verify the output
expect_true(is.ggplot(plot_result))
})
1 change: 1 addition & 0 deletions tests/testthat/test_qrglasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ test_that("predictcoefficient functions", {
# Invalid boundaries order
expect_error(predict(mock_qrglasso, boundaries = c(1, 0)))
})