@@ -68,8 +68,11 @@ test_that('glmnet prediction, single lambda', {
6868 y = hpc $ input_fields
6969 )
7070
71- uni_pred <- c(5.05125589060219 , 4.86977761622526 , 4.90912345599309 , 4.93931874108359 ,
72- 5.08755154547758 )
71+ # glmn_mod <- glmnet::glmnet(x = as.matrix(hpc[, num_pred]), y = hpc$input_fields,
72+ # alpha = .3, nlambda = 15)
73+
74+ uni_pred <- c(640.599944271351 , 196.646976529848 , 186.279646400216 , 194.673852228774 ,
75+ 198.126819755653 )
7376
7477 expect_equal(uni_pred , predict(res_xy , hpc [1 : 5 , num_pred ])$ .pred , tolerance = 0.0001 )
7578
@@ -80,8 +83,8 @@ test_that('glmnet prediction, single lambda', {
8083 control = ctrl
8184 )
8285
83- form_pred <- c(5.23960117346944 , 5.08769210344022 , 5.15129212608077 , 5.12000510716518 ,
84- 5.26736239856889 )
86+ form_pred <- c(570.504089227118 , 162.413061474088 , 167.022896537861 , 157.609071878082 ,
87+ 165.887783741483 )
8588
8689 expect_equal(form_pred , predict(res_form , hpc [1 : 5 ,])$ .pred , tolerance = 0.0001 )
8790})
@@ -118,16 +121,16 @@ test_that('glmnet prediction, multiple lambda', {
118121 mult_pred <-
119122 tibble :: tribble(
120123 ~ penalty , ~ .pred ,
121- 0.01 , 5.01352459498158 ,
122- 0.1 , 5.05124049139868 ,
123- 0.01 , 4.71767499960808 ,
124- 0.1 , 4.87103404621362 ,
125- 0.01 , 4.7791916685127 ,
126- 0.1 , 4.91028250633598 ,
127- 0.01 , 4.83366808792755 ,
128- 0.1 , 4.9399094532023 ,
129- 0.01 , 5.07269451405628 ,
130- 0.1 , 5.08728178043569
124+ 0.01 , 639.672880668187 ,
125+ 0.1 , 639.672880668187 ,
126+ 0.01 , 197.744613311359 ,
127+ 0.1 , 197.744613311359 ,
128+ 0.01 , 187.737940787615 ,
129+ 0.1 , 187.737940787615 ,
130+ 0.01 , 195.780487678662 ,
131+ 0.1 , 195.780487678662 ,
132+ 0.01 , 199.217707535882 ,
133+ 0.1 , 199.217707535882
131134 )
132135
133136 expect_equal(
@@ -163,16 +166,16 @@ test_that('glmnet prediction, multiple lambda', {
163166 form_pred <-
164167 tibble :: tribble(
165168 ~ penalty , ~ .pred ,
166- 0.01 , 5.09237402805557 ,
167- 0.1 , 5.24228948237804 ,
168- 0.01 , 4.75071416991856 ,
169- 0.1 , 5.09448280355765 ,
170- 0.01 , 4.89375747015535 ,
171- 0.1 , 5.15636527125752 ,
172- 0.01 , 4.82338959520112 ,
173- 0.1 , 5.12592317615935 ,
174- 0.01 , 5.15481201301174 ,
175- 0.1 , 5.26930099973607
169+ 0.01 , 570.474473760044 ,
170+ 0.1 , 570.474473760044 ,
171+ 0.01 , 164.040104978709 ,
172+ 0.1 , 164.040104978709 ,
173+ 0.01 , 168.709676954287 ,
174+ 0.1 , 168.709676954287 ,
175+ 0.01 , 159.173862504055 ,
176+ 0.1 , 159.173862504055 ,
177+ 0.01 , 167.559854709074 ,
178+ 0.1 , 167.559854709074
176179 )
177180
178181 expect_equal(
@@ -190,7 +193,7 @@ test_that('glmnet prediction, all lambda', {
190193 skip_if(run_glmnet )
191194
192195 hpc_all <- linear_reg(mixture = .3 ) %> %
193- set_engine(" glmnet" )
196+ set_engine(" glmnet" , nlambda = 7 )
194197
195198 res_xy <- fit_xy(
196199 hpc_all ,
@@ -202,7 +205,7 @@ test_that('glmnet prediction, all lambda', {
202205 all_pred <- predict(res_xy $ fit , newx = as.matrix(hpc [1 : 5 , num_pred ]))
203206 all_pred <- stack(as.data.frame(all_pred ))
204207 all_pred $ penalty <- rep(res_xy $ fit $ lambda , each = 5 )
205- all_pred $ rows <- rep(1 : 5 , 2 )
208+ all_pred $ rows <- rep(1 : 5 , length( res_xy $ fit $ lambda ) )
206209 all_pred <- all_pred [order(all_pred $ rows , all_pred $ penalty ), ]
207210 all_pred <- all_pred [, c(" penalty" , " values" )]
208211 names(all_pred ) <- c(" penalty" , " .pred" )
@@ -223,7 +226,7 @@ test_that('glmnet prediction, all lambda', {
223226 form_pred <- predict(res_form $ fit , newx = form_mat )
224227 form_pred <- stack(as.data.frame(form_pred ))
225228 form_pred $ penalty <- rep(res_form $ fit $ lambda , each = 5 )
226- form_pred $ rows <- rep(1 : 5 , 2 )
229+ form_pred $ rows <- rep(1 : 5 , length( res_form $ fit $ lambda ) )
227230 form_pred <- form_pred [order(form_pred $ rows , form_pred $ penalty ), ]
228231 form_pred <- form_pred [, c(" penalty" , " values" )]
229232 names(form_pred ) <- c(" penalty" , " .pred" )
0 commit comments