R Notebook

6.5.1 Best Subset Selection

We run best subset selection on the Hitters data, predicting the Salary column from the other variables. Salary contains missing values, so first we drop those rows.

library(ISLR)
names(Hitters)
##  [1] "AtBat"     "Hits"      "HmRun"     "Runs"      "RBI"      
##  [6] "Walks"     "Years"     "CAtBat"    "CHits"     "CHmRun"   
## [11] "CRuns"     "CRBI"      "CWalks"    "League"    "Division" 
## [16] "PutOuts"   "Assists"   "Errors"    "Salary"    "NewLeague"
dim(Hitters)
## [1] 322  20
sum(is.na(Hitters$Salary))
## [1] 59
hitters = na.omit(Hitters)
dim(hitters)
## [1] 263  20

The regsubsets() function in the leaps library performs best subset selection. The syntax is the same as lm(), and summary() reports the best set of variables for each model size.

library(leaps)
mod.best = regsubsets(Salary~. , hitters)
summary(mod.best)
## Subset selection object
## Call: regsubsets.formula(Salary ~ ., hitters)
## 19 Variables  (and intercept)
##            Forced in Forced out
## AtBat          FALSE      FALSE
## Hits           FALSE      FALSE
## HmRun          FALSE      FALSE
## Runs           FALSE      FALSE
## RBI            FALSE      FALSE
## Walks          FALSE      FALSE
## Years          FALSE      FALSE
## CAtBat         FALSE      FALSE
## CHits          FALSE      FALSE
## CHmRun         FALSE      FALSE
## CRuns          FALSE      FALSE
## CRBI           FALSE      FALSE
## CWalks         FALSE      FALSE
## LeagueN        FALSE      FALSE
## DivisionW      FALSE      FALSE
## PutOuts        FALSE      FALSE
## Assists        FALSE      FALSE
## Errors         FALSE      FALSE
## NewLeagueN     FALSE      FALSE
## 1 subsets of each size up to 8
## Selection Algorithm: exhaustive
##          AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns
## 1  ( 1 ) " "   " "  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 2  ( 1 ) " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 3  ( 1 ) " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 4  ( 1 ) " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 5  ( 1 ) "*"   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 6  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    " "  
## 7  ( 1 ) " "   "*"  " "   " "  " " "*"   " "   "*"    "*"   "*"    " "  
## 8  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   " "    " "   "*"    "*"  
##          CRBI CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN
## 1  ( 1 ) "*"  " "    " "     " "       " "     " "     " "    " "       
## 2  ( 1 ) "*"  " "    " "     " "       " "     " "     " "    " "       
## 3  ( 1 ) "*"  " "    " "     " "       "*"     " "     " "    " "       
## 4  ( 1 ) "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 5  ( 1 ) "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 6  ( 1 ) "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 7  ( 1 ) " "  " "    " "     "*"       "*"     " "     " "    " "       
## 8  ( 1 ) " "  "*"    " "     "*"       "*"     " "     " "    " "

An asterisk means that variable is included in the model. For example, the best one-variable model contains only CRBI, and the best two-variable model contains Hits and CRBI. By default regsubsets() reports best subset models with up to eight variables; adjust this with the nvmax argument.
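
The same inclusion pattern is also available programmatically: the summary stores it as a logical matrix in its which component (a quick sketch):

summary(mod.best)$which[1:3,] # one row per model size, one column per term (intercept included)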

mod.best = regsubsets(Salary~. ,data=hitters,nvmax=19)
summary(mod.best)
## Subset selection object
## Call: regsubsets.formula(Salary ~ ., data = hitters, nvmax = 19)
## 19 Variables  (and intercept)
##            Forced in Forced out
## AtBat          FALSE      FALSE
## Hits           FALSE      FALSE
## HmRun          FALSE      FALSE
## Runs           FALSE      FALSE
## RBI            FALSE      FALSE
## Walks          FALSE      FALSE
## Years          FALSE      FALSE
## CAtBat         FALSE      FALSE
## CHits          FALSE      FALSE
## CHmRun         FALSE      FALSE
## CRuns          FALSE      FALSE
## CRBI           FALSE      FALSE
## CWalks         FALSE      FALSE
## LeagueN        FALSE      FALSE
## DivisionW      FALSE      FALSE
## PutOuts        FALSE      FALSE
## Assists        FALSE      FALSE
## Errors         FALSE      FALSE
## NewLeagueN     FALSE      FALSE
## 1 subsets of each size up to 19
## Selection Algorithm: exhaustive
##           AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns
## 1  ( 1 )  " "   " "  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 2  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 3  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 4  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 5  ( 1 )  "*"   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 6  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    " "  
## 7  ( 1 )  " "   "*"  " "   " "  " " "*"   " "   "*"    "*"   "*"    " "  
## 8  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   "*"    "*"  
## 9  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"  
## 10  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"  
## 11  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"  
## 12  ( 1 ) "*"   "*"  " "   "*"  " " "*"   " "   "*"    " "   " "    "*"  
## 13  ( 1 ) "*"   "*"  " "   "*"  " " "*"   " "   "*"    " "   " "    "*"  
## 14  ( 1 ) "*"   "*"  "*"   "*"  " " "*"   " "   "*"    " "   " "    "*"  
## 15  ( 1 ) "*"   "*"  "*"   "*"  " " "*"   " "   "*"    "*"   " "    "*"  
## 16  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   " "   "*"    "*"   " "    "*"  
## 17  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   " "   "*"    "*"   " "    "*"  
## 18  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   "*"   "*"    "*"   " "    "*"  
## 19  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   "*"   "*"    "*"   "*"    "*"  
##           CRBI CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN
## 1  ( 1 )  "*"  " "    " "     " "       " "     " "     " "    " "       
## 2  ( 1 )  "*"  " "    " "     " "       " "     " "     " "    " "       
## 3  ( 1 )  "*"  " "    " "     " "       "*"     " "     " "    " "       
## 4  ( 1 )  "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 5  ( 1 )  "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 6  ( 1 )  "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 7  ( 1 )  " "  " "    " "     "*"       "*"     " "     " "    " "       
## 8  ( 1 )  " "  "*"    " "     "*"       "*"     " "     " "    " "       
## 9  ( 1 )  "*"  "*"    " "     "*"       "*"     " "     " "    " "       
## 10  ( 1 ) "*"  "*"    " "     "*"       "*"     "*"     " "    " "       
## 11  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     " "    " "       
## 12  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     " "    " "       
## 13  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 14  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 15  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 16  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 17  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    "*"       
## 18  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    "*"       
## 19  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    "*"
names(summary(mod.best))
## [1] "which"  "rsq"    "rss"    "adjr2"  "cp"     "bic"    "outmat" "obj"
summary(mod.best)$adjr2
##  [1] 0.3188503 0.4208024 0.4450753 0.4672734 0.4808971 0.4972001 0.5007849
##  [8] 0.5137083 0.5180572 0.5222606 0.5225706 0.5217245 0.5206736 0.5195431
## [15] 0.5178661 0.5162219 0.5144464 0.5126097 0.5106270
summary(mod.best)$rsq
##  [1] 0.3214501 0.4252237 0.4514294 0.4754067 0.4908036 0.5087146 0.5141227
##  [8] 0.5285569 0.5346124 0.5404950 0.5426153 0.5436302 0.5444570 0.5452164
## [15] 0.5454692 0.5457656 0.5459518 0.5460945 0.5461159

Plotting RSS, adjusted R^2, Cp, and BIC helps when deciding which model to select.

library(ggplot2)
sum.mat = data.frame('rsq' = summary(mod.best)$rsq,'adjr2'=summary(mod.best)$adjr2,'bic'=summary(mod.best)$bic,'num.var' = 1:19)
# R^2 and adjusted R^2 versus the number of variables
ggplot(sum.mat,aes(x=num.var,y=rsq)) + geom_line(color='red')+ labs(title='rsq and adj rsq with best subset selection',caption='red: rsq, blue: adj rsq') + geom_line(color='blue',aes(x=num.var,y=adjr2)) + geom_vline(xintercept =sum.mat$num.var[which.max(sum.mat$adjr2)],color='green')

# BIC versus the number of variables
ggplot(sum.mat,aes(x=num.var, y=bic)) + geom_line(color='red') + labs(title='bic with best subset selection') + geom_vline(xintercept = sum.mat$num.var[which.min(sum.mat$bic)],color='green')
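
leaps also provides a built-in plot() method for regsubsets objects that displays, for each value of the chosen criterion, which variables are included; scale can be 'r2', 'adjr2', 'Cp', or 'bic':

plot(mod.best, scale='Cp')
plot(mod.best, scale='bic')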

Forward and Backward stepwise selection

Use the same regsubsets() function that performs best subset selection, but set method = 'forward' or method = 'backward'.

mod.fwd = regsubsets(Salary~. , data=hitters, nvmax = 19, method = 'forward')
summary(mod.fwd)
## Subset selection object
## Call: regsubsets.formula(Salary ~ ., data = hitters, nvmax = 19, method = "forward")
## 19 Variables  (and intercept)
##            Forced in Forced out
## AtBat          FALSE      FALSE
## Hits           FALSE      FALSE
## HmRun          FALSE      FALSE
## Runs           FALSE      FALSE
## RBI            FALSE      FALSE
## Walks          FALSE      FALSE
## Years          FALSE      FALSE
## CAtBat         FALSE      FALSE
## CHits          FALSE      FALSE
## CHmRun         FALSE      FALSE
## CRuns          FALSE      FALSE
## CRBI           FALSE      FALSE
## CWalks         FALSE      FALSE
## LeagueN        FALSE      FALSE
## DivisionW      FALSE      FALSE
## PutOuts        FALSE      FALSE
## Assists        FALSE      FALSE
## Errors         FALSE      FALSE
## NewLeagueN     FALSE      FALSE
## 1 subsets of each size up to 19
## Selection Algorithm: forward
##           AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns
## 1  ( 1 )  " "   " "  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 2  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 3  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 4  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 5  ( 1 )  "*"   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "  
## 6  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    " "  
## 7  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    " "  
## 8  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    "*"  
## 9  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"  
## 10  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"  
## 11  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"  
## 12  ( 1 ) "*"   "*"  " "   "*"  " " "*"   " "   "*"    " "   " "    "*"  
## 13  ( 1 ) "*"   "*"  " "   "*"  " " "*"   " "   "*"    " "   " "    "*"  
## 14  ( 1 ) "*"   "*"  "*"   "*"  " " "*"   " "   "*"    " "   " "    "*"  
## 15  ( 1 ) "*"   "*"  "*"   "*"  " " "*"   " "   "*"    "*"   " "    "*"  
## 16  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   " "   "*"    "*"   " "    "*"  
## 17  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   " "   "*"    "*"   " "    "*"  
## 18  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   "*"   "*"    "*"   " "    "*"  
## 19  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   "*"   "*"    "*"   "*"    "*"  
##           CRBI CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN
## 1  ( 1 )  "*"  " "    " "     " "       " "     " "     " "    " "       
## 2  ( 1 )  "*"  " "    " "     " "       " "     " "     " "    " "       
## 3  ( 1 )  "*"  " "    " "     " "       "*"     " "     " "    " "       
## 4  ( 1 )  "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 5  ( 1 )  "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 6  ( 1 )  "*"  " "    " "     "*"       "*"     " "     " "    " "       
## 7  ( 1 )  "*"  "*"    " "     "*"       "*"     " "     " "    " "       
## 8  ( 1 )  "*"  "*"    " "     "*"       "*"     " "     " "    " "       
## 9  ( 1 )  "*"  "*"    " "     "*"       "*"     " "     " "    " "       
## 10  ( 1 ) "*"  "*"    " "     "*"       "*"     "*"     " "    " "       
## 11  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     " "    " "       
## 12  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     " "    " "       
## 13  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 14  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 15  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 16  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 17  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    "*"       
## 18  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    "*"       
## 19  ( 1 ) "*"  "*"    "*"     "*"       "*"     "*"     "*"    "*"
mod.bwd = regsubsets(Salary~. , data=hitters, nvmax = 19, method = 'backward')
summary(mod.bwd)

What if we want to know which variables the seven-variable model selects, and with what coefficient values? For the forward fit, for example:

coef(mod.fwd,7)
##  (Intercept)        AtBat         Hits        Walks         CRBI 
##  109.7873062   -1.9588851    7.4498772    4.9131401    0.8537622 
##       CWalks    DivisionW      PutOuts 
##   -0.3053070 -127.1223928    0.2533404
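
Per the star tables above, exhaustive search and forward selection agree through six variables but diverge at seven; a quick check of the variable names:

names(coef(mod.best,7)) # exhaustive: Hits, Walks, CAtBat, CHits, CHmRun, DivisionW, PutOuts
names(coef(mod.fwd,7))  # forward: AtBat, Hits, Walks, CRBI, CWalks, DivisionW, PutOuts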

Choosing among models using the validation set approach and cross-validation

First, fix the indices of the training and test observations.

set.seed(1)
train = sample(c(TRUE,FALSE),nrow(hitters),replace=TRUE)
test = !train

Run best subset selection on the training data.

mod.best = regsubsets(Salary~. , data=hitters[train,],nvmax=19)

Build a model matrix from the test data.

test.mat = model.matrix(Salary~. , data=hitters[test,])

Now compute the validation errors.

val.error = rep(0,19)
for (i in 1:19){
  coefi = coef(mod.best,id=i)
  pred = test.mat[,names(coefi)] %*% coefi # %*% is matrix multiplication
  val.error[i] = mean((hitters$Salary[test]-pred)^2)
}
which.min(val.error) # the smallest validation error comes from the ten-variable model
## [1] 10
coef(mod.best,9) # coefficients of the nine-variable model
##  (Intercept)        AtBat         Hits        Walks       CAtBat 
## -116.8513468   -1.5669672    7.6177014    3.5505374   -0.1888594 
##        CHits       CHmRun       CWalks      LeagueN      PutOuts 
##    1.1121891    1.3421445   -0.7221434   84.0143083    0.2433223

We had to do it this way because regsubsets() has no predict() method. So let's write one ourselves.

predict.regsubsets = function(object,newdata,id,...){
  form = as.formula(object$call[[2]]) # object is a regsubsets() fit; recover the model formula from its call
  mat = model.matrix(form,newdata) # model matrix built from the new data
  coefi = coef(object, id=id) # coefficients of the best id-variable model
  xvars = names(coefi) # names of the selected variables
  mat[,xvars]%*%coefi # predictions
}
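
With a predict method available, k-fold cross-validation over model sizes is straightforward; a sketch of the standard approach (fold assignment is random, so results depend on the seed):

k = 10
set.seed(1)
folds = sample(1:k, nrow(hitters), replace=TRUE) # assign each observation to a fold
cv.errors = matrix(NA, k, 19)
for (j in 1:k){
  fit = regsubsets(Salary~. , data=hitters[folds!=j,], nvmax=19)
  for (i in 1:19){
    pred = predict(fit, hitters[folds==j,], id=i) # dispatches to predict.regsubsets
    cv.errors[j,i] = mean((hitters$Salary[folds==j]-pred)^2)
  }
}
which.min(apply(cv.errors, 2, mean)) # model size with the smallest mean CV error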

Ridge Regression

glmnet(), which fits both ridge and lasso models, lives in the glmnet package. Its syntax differs a little from other model-fitting functions: it takes a matrix x and a vector y instead of a formula.

x = model.matrix(Salary~. , hitters)[,-1]
y = hitters$Salary

model.matrix() is handy for building x: it not only produces a matrix containing the 19 predictors but also automatically expands categorical variables into dummy variables. This matters because glmnet() only accepts numerical inputs.

library(glmnet)
## Loading required package: Matrix

## Loading required package: foreach

## Loaded glmnet 2.0-16
grid = 10^seq(10,-2,length=100)
mod.ridge = glmnet(x,y,alpha=0,lambda=grid) 

glmnet() performs ridge regression when alpha = 0 and the lasso when alpha = 1. By default glmnet() chooses its own sequence of lambda values; here we supply our own grid running from 10^10 down to 10^-2. glmnet() also standardizes the variables by default (the text recommends standardizing when fitting ridge or the lasso); pass standardize=FALSE to turn this off. The fitted coefficients form a matrix with one column per lambda value: here 20 rows (19 variables plus the intercept) by 100 columns.
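
For reference, disabling the standardization is just a flag (a sketch; the default is usually what you want):

mod.ridge.raw = glmnet(x, y, alpha=0, lambda=grid, standardize=FALSE) # coefficients on the original scales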

dim(coef(mod.ridge))
## [1]  20 100

As the text showed, a larger lambda means a heavier penalty, so the coefficient estimates, and hence their l2 norm, shrink.

mod.ridge$lambda[50] # lambda is about 11498 here
## [1] 11497.57
coef(mod.ridge)[,50] # the corresponding coefficients
##   (Intercept)         AtBat          Hits         HmRun          Runs 
## 407.356050200   0.036957182   0.138180344   0.524629976   0.230701523 
##           RBI         Walks         Years        CAtBat         CHits 
##   0.239841459   0.289618741   1.107702929   0.003131815   0.011653637 
##        CHmRun         CRuns          CRBI        CWalks       LeagueN 
##   0.087545670   0.023379882   0.024138320   0.025015421   0.085028114 
##     DivisionW       PutOuts       Assists        Errors    NewLeagueN 
##  -6.215440973   0.016482577   0.002612988  -0.020502690   0.301433531
sqrt(sum(coef(mod.ridge)[-1,50]^2)) # l2 norm, excluding the intercept
## [1] 6.360612
mod.ridge$lambda[60] # lambda is about 705 here
## [1] 705.4802
coef(mod.ridge)[,60] # the corresponding coefficients
##  (Intercept)        AtBat         Hits        HmRun         Runs 
##  54.32519950   0.11211115   0.65622409   1.17980910   0.93769713 
##          RBI        Walks        Years       CAtBat        CHits 
##   0.84718546   1.31987948   2.59640425   0.01083413   0.04674557 
##       CHmRun        CRuns         CRBI       CWalks      LeagueN 
##   0.33777318   0.09355528   0.09780402   0.07189612  13.68370191 
##    DivisionW      PutOuts      Assists       Errors   NewLeagueN 
## -54.65877750   0.11852289   0.01606037  -0.70358655   8.61181213
sqrt(sum(coef(mod.ridge)[-1,60]^2)) # l2 norm, excluding the intercept
## [1] 57.11001

We can confirm that the l2 norm grew as lambda decreased!

predict() can also be used to obtain the ridge coefficients at a new value of lambda. For example, with lambda = 50:

predict(mod.ridge,s=50,type='coefficients')[1:20,]
##   (Intercept)         AtBat          Hits         HmRun          Runs 
##  4.876610e+01 -3.580999e-01  1.969359e+00 -1.278248e+00  1.145892e+00 
##           RBI         Walks         Years        CAtBat         CHits 
##  8.038292e-01  2.716186e+00 -6.218319e+00  5.447837e-03  1.064895e-01 
##        CHmRun         CRuns          CRBI        CWalks       LeagueN 
##  6.244860e-01  2.214985e-01  2.186914e-01 -1.500245e-01  4.592589e+01 
##     DivisionW       PutOuts       Assists        Errors    NewLeagueN 
## -1.182011e+02  2.502322e-01  1.215665e-01 -3.278600e+00 -9.496680e+00

Next we split the data into training and test sets and fit on the training part. There are generally two ways to make such a split: generate a TRUE/FALSE vector, as above, or sample the training indices from 1:n.

set.seed(1)
train = sample(1:nrow(x),nrow(x)/2)
test = -train
y.test = y[test]

Now fit on the training data (here with lambda = 4) and evaluate performance on the test data.

mod.ridge = glmnet(x[train,],y[train],alpha=0,lambda=grid, thresh=1e-12)
ridge.pred = predict(mod.ridge,s=4,newx=x[test,])
mean((y.test - ridge.pred)^2)
## [1] 101036.8

What about the null model, which contains only an intercept?

mean((mean(y[train])-y.test)^2)
## [1] 193253.1

Now let's cross-validate! The glmnet library already provides cv.glmnet(), which uses 10 folds by default; change this with nfolds.

set.seed(1)
cv.out = cv.glmnet(x[train,],y[train],alpha=0)
cv.out$lambda.min
## [1] 211.7416
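
The CV curve can also be inspected visually; cv.glmnet objects have a plot() method:

plot(cv.out) # CV MSE versus log(lambda)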

The lambda giving the smallest cross-validation error is about 212. What is the test MSE at that value?

ridge.pred = predict(mod.ridge,s=212,newx=x[test,])
mean((y[test]-ridge.pred)^2)
## [1] 96015.27
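
Equivalently (a small sketch), we can pass lambda.min directly instead of rounding it:

bestlam = cv.out$lambda.min
ridge.pred = predict(mod.ridge, s=bestlam, newx=x[test,])
mean((y.test - ridge.pred)^2)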

Now that CV has given us the lambda with the smallest test MSE, let's refit on the full data set!

final.out = glmnet(x,y,alpha=0,lambda=212)
coef(final.out)
## 20 x 1 sparse Matrix of class "dgCMatrix"
##                       s0
## (Intercept)   9.81495442
## AtBat         0.03191025
## Hits          1.00790629
## HmRun         0.14204716
## Runs          1.11283540
## RBI           0.87320884
## Walks         1.80310238
## Years         0.14156997
## CAtBat        0.01114964
## CHits         0.06486802
## CHmRun        0.45137750
## CRuns         0.12884114
## CRBI          0.13722324
## CWalks        0.02928067
## LeagueN      27.15822017
## DivisionW   -91.58770575
## PutOuts       0.19138790
## Assists       0.04243399
## Errors       -1.81028711
## NewLeagueN    7.23220165

As we saw earlier, ridge never sets any coefficient exactly to zero; that is, it performs no variable selection!

The Lasso

The lasso is fit with glmnet() just like ridge, except that we set alpha = 1.

mod.lasso = glmnet(x[train,],y[train],alpha=1,lambda=grid)

Now cross-validate.

set.seed(1)
cv.out = cv.glmnet(x[train,],y[train],alpha=1)
cv.out$lambda.min
## [1] 16.78016

Computing the lasso's test MSE:

lasso.pred = predict(mod.lasso,s=17,newx=x[test,])
mean((y[test]-lasso.pred)^2)
## [1] 100755.1

The test MSE is nearly the same as ridge's with lambda = 212, but the lasso has one advantage: unlike ridge, it sets some coefficients exactly to zero.

final.lasso.out = glmnet(x,y,lambda=17,alpha=1)
coef(final.lasso.out)
## 20 x 1 sparse Matrix of class "dgCMatrix"
##                       s0
## (Intercept)   20.2735038
## AtBat          .        
## Hits           1.8673382
## HmRun          .        
## Runs           .        
## RBI            .        
## Walks          2.2161744
## Years          .        
## CAtBat         .        
## CHits          .        
## CHmRun         .        
## CRuns          0.2074062
## CRBI           0.4121545
## CWalks         .        
## LeagueN        1.2663496
## DivisionW   -103.1080013
## PutOuts        0.2202108
## Assists        .        
## Errors         .        
## NewLeagueN     .
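
To pull out only the surviving variables (a small sketch):

lasso.coef = coef(final.lasso.out)[,1] # flatten the sparse column into a named vector
lasso.coef[lasso.coef != 0] # keep just the nonzero coefficients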

The lasso therefore has a clear advantage over ridge in terms of interpretability.

Lab 3: PCR and PLS regression

Principal Components Regression

PCR can be performed with pcr() from the pls library. The basic interface is similar to lm(); scale=TRUE standardizes the predictors so that their scales do not influence the result, and validation='CV' chooses M (the number of principal components) by ten-fold cross-validation by default.

library(pls)
## 
## Attaching package: 'pls'

## The following object is masked from 'package:stats':
## 
##     loadings
set.seed(2)
mod.pcr = pcr(Salary~. , data=hitters, scale=TRUE, validation = 'CV')
summary(mod.pcr)
## Data:    X dimension: 263 19 
##  Y dimension: 263 1
## Fit method: svdpc
## Number of components considered: 19
## 
## VALIDATION: RMSEP
## Cross-validated using 10 random segments.
##        (Intercept)  1 comps  2 comps  3 comps  4 comps  5 comps  6 comps
## CV             452    348.9    352.2    353.5    352.8    350.1    349.1
## adjCV          452    348.7    351.8    352.9    352.1    349.3    348.0
##        7 comps  8 comps  9 comps  10 comps  11 comps  12 comps  13 comps
## CV       349.6    350.9    352.9     353.8     355.0     356.2     363.5
## adjCV    348.5    349.8    351.6     352.3     353.4     354.5     361.6
##        14 comps  15 comps  16 comps  17 comps  18 comps  19 comps
## CV        355.2     357.4     347.6     350.1     349.2     352.6
## adjCV     352.8     355.2     345.5     347.6     346.7     349.8
## 
## TRAINING: % variance explained
##         1 comps  2 comps  3 comps  4 comps  5 comps  6 comps  7 comps
## X         38.31    60.16    70.84    79.03    84.29    88.63    92.26
## Salary    40.63    41.58    42.17    43.22    44.90    46.48    46.69
##         8 comps  9 comps  10 comps  11 comps  12 comps  13 comps  14 comps
## X         94.96    96.28     97.26     97.98     98.65     99.15     99.47
## Salary    46.75    46.86     47.76     47.82     47.85     48.10     50.40
##         15 comps  16 comps  17 comps  18 comps  19 comps
## X          99.75     99.89     99.97     99.99    100.00
## Salary     50.55     53.01     53.85     54.61     54.61

The CV score is reported for every M from 0 to 19. Note that these are RMSEP values (root mean squared error of prediction), so they must be squared to obtain MSEs: for example, the smallest CV value above, 347.6 at M = 16, corresponds to an MSE of 347.6^2 ≈ 120,826.

The CV scores can also be plotted with validationplot(); setting val.type='MSEP' displays MSE rather than RMSE.

validationplot(mod.pcr,val.type='MSEP')

The plot reaches its minimum at M = 16, but the error is barely different from what we get at M = 1, so we may conclude that M = 1 is sufficient.

Now let's run PCR on the training data, again choosing M by CV, and evaluate it on the test set.

set.seed(1)
mod.pcr = pcr(Salary~. , data=hitters, subset=train, scale=TRUE, validation='CV')
validationplot(mod.pcr, val.type="MSEP")

How can we access the MSEP values behind mod.pcr's plot? The CV error is lowest at M = 7, so we compute the test MSE there.
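
One way is the MSEP() helper from pls (a sketch; its val array holds the CV and adjCV curves, indexed by estimate, response, and model size):

msep = MSEP(mod.pcr, estimate='CV')$val['CV','Salary',] # CV MSEP for 0 through 19 components
which.min(msep) - 1 # minus 1 because the first entry is the intercept-only model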

pcr.pred = predict(mod.pcr,x[test,],ncomp=7)
mean((y[test] - pcr.pred)^2)
## [1] 96556.22

The test MSE is competitively low, similar to ridge's, but PCR performs no variable selection and yields no coefficient estimates for the original predictors, so the model has become harder to interpret!

Partial Least Squares

PLS is fit with the plsr() function in the pls library; the syntax is similar to pcr().

set.seed(1)
mod.pls = plsr(Salary~. , data=hitters, subset=train, scale=TRUE, validation='CV')
summary(mod.pls)
## Data:    X dimension: 131 19 
##  Y dimension: 131 1
## Fit method: kernelpls
## Number of components considered: 19
## 
## VALIDATION: RMSEP
## Cross-validated using 10 random segments.
##        (Intercept)  1 comps  2 comps  3 comps  4 comps  5 comps  6 comps
## CV           464.6    394.2    391.5    393.1    395.0    415.0    424.0
## adjCV        464.6    393.4    390.2    391.1    392.9    411.5    418.8
##        7 comps  8 comps  9 comps  10 comps  11 comps  12 comps  13 comps
## CV       424.5    415.8    404.6     407.1     412.0     414.4     410.3
## adjCV    418.9    411.4    400.7     402.2     407.2     409.3     405.6
##        14 comps  15 comps  16 comps  17 comps  18 comps  19 comps
## CV        406.2     408.6     410.5     408.8     407.8     410.2
## adjCV     401.8     403.9     405.6     404.1     403.2     405.5
## 
## TRAINING: % variance explained
##         1 comps  2 comps  3 comps  4 comps  5 comps  6 comps  7 comps
## X         38.12    53.46    66.05    74.49    79.33    84.56    87.09
## Salary    33.58    38.96    41.57    42.43    44.04    45.59    47.05
##         8 comps  9 comps  10 comps  11 comps  12 comps  13 comps  14 comps
## X         90.74    92.55     93.94     97.23     97.88     98.35     98.85
## Salary    47.53    48.42     49.68     50.04     50.54     50.78     50.92
##         15 comps  16 comps  17 comps  18 comps  19 comps
## X          99.11     99.43     99.78     99.99    100.00
## Salary     51.04     51.11     51.15     51.16     51.18

The smallest CV error occurs at M = 2. Computing the test MSE there:

pls.pred = predict(mod.pls, newdata=hitters[test,], ncomp=2)
mean((y[test]-pls.pred)^2)
## [1] 101417.5
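
Finally, as in the text, we could refit PLS on the full data set using the chosen M = 2 (a short sketch):

final.pls = plsr(Salary~. , data=hitters, scale=TRUE, ncomp=2)
summary(final.pls)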

Comparing the PCR and PLS results: PCR needed M = 7 to explain about 46% of the variance in Salary, while PLS reaches essentially the same 46% with only M = 2. This is because PCR looks for directions that explain as much variance as possible in the predictors alone (an unsupervised criterion), whereas PLS looks for directions that explain variance in both the predictors and the response (a supervised criterion).