In this lab we will go through the model building, validation, and interpretation of tree models. The focus will be on the rpart package. Recall that when the response variable \(Y\) is continuous, we fit a regression tree; when the response variable \(Y\) is categorical, we fit a classification tree. We build tree models for our familiar datasets, the Boston Housing data and the Credit Card Default data, for the regression tree and the classification tree, respectively.
Load the data, and randomly split it into training and testing samples.
library(tidyverse)
library(MASS)
data(Boston)
index <- sample(nrow(Boston),nrow(Boston)*0.90)
boston.train <- Boston[index,]
boston.test <- Boston[-index,]
We will use the ‘rpart’ library for model building and ‘rpart.plot’ for plotting.
install.packages('rpart')
install.packages('rpart.plot')
library(rpart)
library(rpart.plot)
The simple form of the rpart() function is similar to lm() and glm(): it takes a formula argument, in which you specify the response and predictor variables, and a data argument for the dataset.
boston.rpart <- rpart(formula = medv ~ ., data = boston.train)
boston.rpart
## n= 455
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 455 39064.4100 22.48703
## 2) rm< 6.825 375 15001.1500 19.61307
## 4) lstat>=14.4 156 2892.2470 14.77564
## 8) crim>=7.46495 63 715.3571 11.54762 *
## 9) crim< 7.46495 93 1075.7180 16.96237 *
## 5) lstat< 14.4 219 5858.0300 23.05890
## 10) dis>=1.5511 212 2814.7590 22.56557
## 20) rm< 6.5445 174 1515.5840 21.62471 *
## 21) rm>=6.5445 38 439.8737 26.87368 *
## 11) dis< 1.5511 7 1429.0200 38.00000 *
## 3) rm>=6.825 80 6446.9140 35.95875
## 6) rm< 7.437 54 2105.9960 31.43148
## 12) lstat>=9.65 9 427.9956 23.17778 *
## 13) lstat< 9.65 45 942.2658 33.08222 *
## 7) rm>=7.437 26 935.4015 45.36154 *
prp(boston.rpart, digits = 4, extra = 1)
Make sure you know how to interpret this tree model! For example, what median home value (medv) would the tree predict for a census tract with the following characteristics?
crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0.05 | 0 | 3.41 | 0 | 0.49 | 6.42 | 66.1 | 3.09 | 2 | 270 | 17.8 | 392.18 | 8.81 |
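To check your interpretation, you can feed this observation to the fitted tree and compare with the terminal node you reach by hand (a minimal sketch; new.obs is simply the row from the table above):
# hypothetical single observation taken from the table above
new.obs <- data.frame(crim = 0.05, zn = 0, indus = 3.41, chas = 0, nox = 0.49,
                      rm = 6.42, age = 66.1, dis = 3.09, rad = 2, tax = 270,
                      ptratio = 17.8, black = 392.18, lstat = 8.81)
predict(boston.rpart, newdata = new.obs)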
The prediction for a regression tree is also similar to that for lm() and glm() models.
boston.test.pred.tree = predict(boston.rpart, newdata=boston.test)
MSPE.tree<-
Calculate the mean squared prediction error (MSPE) for a linear regression model using all variables, then compare the results. What is your conclusion? Further, try to compare the regression tree with the best linear regression model obtained from a variable selection procedure.
boston.lm<-
pred.lm<-
MSPE.lm<-
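One possible way to fill in the blanks above (a sketch; the full linear model is used here, and your variable-selection model may differ):
# MSPE of the regression tree on the testing sample
MSPE.tree <- mean((boston.test$medv - boston.test.pred.tree)^2)
# linear regression with all variables, for comparison
boston.lm <- lm(medv ~ ., data = boston.train)
pred.lm <- predict(boston.lm, newdata = boston.test)
MSPE.lm <- mean((boston.test$medv - pred.lm)^2)
c(tree = MSPE.tree, lm = MSPE.lm)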
In rpart(), the cp (complexity parameter) argument is one of the parameters used to control the complexity of the tree. The help document for rpart tells you that "Any split that does not decrease the overall lack of fit by a factor of cp is not attempted". For a regression tree, the overall R-squared must increase by cp at each step. Basically, the smaller the cp value, the larger (more complex) the tree that rpart will attempt to fit. The default value of cp is 0.01.
The idea of pruning a tree is that we start with a large tree (small cp value) that may overfit the data, and then cut it down by choosing an appropriate cp value (via cross-validation).
boston.largetree <- rpart(formula = medv ~ ., data = boston.train, cp = 0.001)
prp(boston.largetree)
The plotcp() function gives the relationship between 10-fold cross-validation error and size of tree.
plotcp(boston.largetree)
From left to right, cp decreases, which means the tree grows larger. You can observe from the graph above that the cross-validation error (xerror) does not always go down as the tree becomes more complex. The analogy is that adding more variables to a regression model does not necessarily improve its ability to predict future observations. A good choice of cp for pruning is often the leftmost value for which the mean lies below the horizontal line. In the Boston Housing example, you may conclude that a tree model with more than 10 splits is not helpful.
To look at the error versus the size of the tree more carefully, you can look at the following table:
cptable <- printcp(boston.largetree)
##
## Regression tree:
## rpart(formula = medv ~ ., data = boston.train, cp = 0.001)
##
## Variables actually used in tree construction:
## [1] age crim dis lstat nox ptratio rm tax
##
## Root node error: 39064/455 = 85.856
##
## n= 455
##
## CP nsplit rel error xerror xstd
## 1 0.4509566 0 1.00000 1.00467 0.087217
## 2 0.1600144 1 0.54904 0.61683 0.061860
## 3 0.0871769 2 0.38903 0.43070 0.050364
## 4 0.0413228 3 0.30185 0.40514 0.049951
## 5 0.0281886 4 0.26053 0.35831 0.048549
## 6 0.0219970 5 0.23234 0.35149 0.048858
## 7 0.0188339 6 0.21034 0.34591 0.048816
## 8 0.0074651 7 0.19151 0.31264 0.045726
## 9 0.0070284 8 0.18404 0.28287 0.041752
## 10 0.0054920 9 0.17702 0.28258 0.043522
## 11 0.0051444 10 0.17152 0.28898 0.043793
## 12 0.0045901 12 0.16124 0.28983 0.043822
## 13 0.0040034 13 0.15665 0.28788 0.043577
## 14 0.0036502 14 0.15264 0.28773 0.044308
## 15 0.0036042 15 0.14899 0.28460 0.044163
## 16 0.0031818 16 0.14539 0.28286 0.044166
## 17 0.0029959 17 0.14221 0.27765 0.044103
## 18 0.0024727 18 0.13921 0.27652 0.044995
## 19 0.0023450 19 0.13674 0.27572 0.044984
## 20 0.0018380 20 0.13439 0.27678 0.044979
## 21 0.0016324 21 0.13255 0.27571 0.044921
## 22 0.0011454 23 0.12929 0.27729 0.044929
## 23 0.0010845 24 0.12814 0.27622 0.044919
## 24 0.0010462 25 0.12706 0.27516 0.044908
## 25 0.0010000 26 0.12601 0.27484 0.044911
The root node error is the error when you do not do anything smart in prediction; in the regression case, it is the mean squared error (MSE) when you use the average of medv as the prediction. That is,
sum((boston.train$medv - mean(boston.train$medv))^2)/nrow(boston.train)
## [1] 85.85585
The first two columns, CP and nsplit, tell you how large the tree is. The third column, rel error, is a relative error: the MSE divided by the root node error. Therefore, rel error \(\times\) root node error gives the MSE (training error). For example, for the last row, (rel error) \(\times\) (root node error) equals the in-sample MSE you get from predict():
mean((predict(boston.largetree) - boston.train$medv)^2)
## [1] 10.81898
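As a quick check (a sketch using the cptable object saved above), the last rel error times the root node error reproduces this number:
# (rel error of the last row) x (root node error) = in-sample MSE
root.mse <- mean((boston.train$medv - mean(boston.train$medv))^2)
tail(cptable[, "rel error"], 1) * root.mse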
xerror is the cross-validation (by default, 10-fold) error. You can see that the rel error (training error) always decreases as the model gets more complex, while the cross-validation error (a measure of performance on future observations) does not. That is why we prune the tree to avoid overfitting.
To prune the tree, we need to choose an appropriate cp value from the cp table. Then we can use the function prune() with the chosen cp value to cut the tree down.
## cp that corresponds to the minimum cross-validation error
cp.min <- cptable[which.min(cptable[, "xerror"]), "CP"]
tree.prune.min <- prune(boston.largetree, cp = cp.min)
## cp by the 1-SE rule: the smallest tree whose xerror is within one standard
## error of the minimum (you can also eyeball this from the cp table)
ind.min <- which.min(cptable[, "xerror"])
xerror.1se <- cptable[ind.min, "xerror"] + cptable[ind.min, "xstd"]
cp.1se <- cptable[which(cptable[, "xerror"] <= xerror.1se)[1], "CP"]
tree.prune.1se <- prune(boston.largetree, cp = cp.1se)
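To see what pruning buys you, one could compare the out-of-sample MSPE of the two pruned trees (a sketch; the numbers depend on your random split):
# out-of-sample MSPE for the two pruned trees
MSPE.prune.min <- mean((boston.test$medv - predict(tree.prune.min, boston.test))^2)
MSPE.prune.1se <- mean((boston.test$medv - predict(tree.prune.1se, boston.test))^2)
c(min.cv = MSPE.prune.min, one.se = MSPE.prune.1se)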
Load the data, rename the response variable (because its original name is too long), convert the categorical variables to factors, and randomly split into training and testing samples.
credit.data <- read_csv("https://www.dropbox.com/s/tnoo06n8m842uit/credit_card_default.csv?dl=1")
# rename
credit.data<- rename(credit.data, default=`default payment next month`)
# convert categorical data to factor
credit.data$EDUCATION<- as.factor(credit.data$EDUCATION)
credit.data$MARRIAGE<- as.factor(credit.data$MARRIAGE)
# random splitting
index <- sample(nrow(credit.data),nrow(credit.data)*0.80)
credit.train = credit.data[index,]
credit.test = credit.data[-index,]
You need to tell R that you want a classification tree by specifying method = "class", since the default is to fit a regression tree.
credit.rpart0 <- rpart(formula = default ~ ., data = credit.train, method = "class")
prp(credit.rpart0, extra = 1)
However, this tree minimizes the symmetric cost, which is equivalent to the misclassification rate. Note that in the predict() function, we need type = "class" in order to get binary predictions.
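For example, a minimal sketch of class predictions and the confusion matrix for this tree on the testing sample:
# class predictions and confusion matrix for the symmetric-cost tree
credit.test.pred0 <- predict(credit.rpart0, credit.test, type = "class")
table(observed = credit.test$default, predicted = credit.test.pred0)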
Recall the example in logistic regression. In the credit scoring case it means that false negatives (predicting 0 when the truth is 1, i.e., giving out loans that end up in default) cost more than false positives (predicting 1 when the truth is 0, i.e., rejecting loans that you should not reject).
Here we make the assumption that a false negative costs 5 times as much as a false positive. In real life the cost structure should be carefully researched.
credit.rpart1 <- rpart(formula = default ~ .,
data = credit.train,
method = "class",
parms = list(loss=matrix(c(0,5,1,0), nrow = 2)))
prp(credit.rpart1, extra = 1)
The parms argument is a list, in which the loss matrix can be specified via the loss component. The diagonal elements are 0, and the off-diagonal elements tell you the loss (cost) of classifying something wrong. For binary classification, the numbers in c() specify the cost in this sequence: c(0, False Negative, False Positive, 0). If you have a symmetric cost, you can ignore the parms argument. For more advanced controls, you should carefully read the help document for the rpart function.
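To see the orientation of the loss matrix used above, you can print it with labels (rows are the true class, columns the predicted class; classes are ordered 0, 1):
# the loss matrix passed to rpart above, labeled for clarity
matrix(c(0, 5, 1, 0), nrow = 2,
       dimnames = list(truth = c("0", "1"), predicted = c("0", "1")))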
For each of the above classification trees, obtain the FPR and FNR for the testing sample.
Compute the cost (on the testing sample) for each model, respectively.
Compare these tree models with logistic regression in terms of MR, FPR, and FNR. (Use the same cost ratio to select the optimal cutoff.) A sketch for the asymmetric-cost tree is given below.
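Here is the sketch, assuming the same 5:1 cost ratio as above (the object names are ours, not part of the lab):
# FPR, FNR and average asymmetric cost for credit.rpart1 on the testing sample
pred.class1 <- predict(credit.rpart1, credit.test, type = "class")
obs <- credit.test$default
FPR <- sum(obs == 0 & pred.class1 == "1") / sum(obs == 0)
FNR <- sum(obs == 1 & pred.class1 == "0") / sum(obs == 1)
# cost per observation: 5 for each false negative, 1 for each false positive
cost <- (5 * sum(obs == 1 & pred.class1 == "0") +
           sum(obs == 0 & pred.class1 == "1")) / length(obs)
c(FPR = FPR, FNR = FNR, cost = cost)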
In predict() for the tree model, try type = "prob". What can you say about these predicted probabilities?
To get the ROC curve, we obtain the predicted probability of Y being 1 from the fitted tree.
credit.test.prob.rpart<- predict(credit.rpart1, credit.test, type="prob")
credit.test.prob.rpart has two columns: the first is the predicted probability that Y = 0 and the second is the predicted probability that Y = 1. We need the second column.
library(ROCR)
pred = prediction(credit.test.prob.rpart[,2], credit.test$default)
perf = performance(pred, "tpr", "fpr")
plot(perf)
slot(performance(pred, "auc"), "y.values")[[1]]
## [1] 0.7217353
Compare the tree model with logistic regression in terms of AUC.
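A sketch of the comparison, using a plain logistic regression with all predictors (your final logistic model may differ):
# AUC of a logistic regression fitted on the same training sample
credit.glm <- glm(default ~ ., data = credit.train, family = binomial)
prob.glm <- predict(credit.glm, newdata = credit.test, type = "response")
pred.glm <- prediction(prob.glm, credit.test$default)
slot(performance(pred.glm, "auc"), "y.values")[[1]]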
Prune the classification tree. How does the pruned tree perform?
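One way to start (a sketch, mirroring what we did for the regression tree): grow a larger tree with a small cp, then prune back at the cp with the smallest cross-validation error.
# grow a larger asymmetric-cost tree, then prune it
credit.largetree <- rpart(default ~ ., data = credit.train, method = "class",
                          parms = list(loss = matrix(c(0, 5, 1, 0), nrow = 2)),
                          cp = 0.001)
cp.tab <- printcp(credit.largetree)
cp.credit <- cp.tab[which.min(cp.tab[, "xerror"]), "CP"]
credit.rpart.prune <- prune(credit.largetree, cp = cp.credit)
prp(credit.rpart.prune, extra = 1)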
Next, we apply a classification tree to the handwritten digit data, where each row stores a 28 \(\times\) 28 grey-scale image as 784 pixel values, with the digit label in the first column.
digit<- data.matrix(read_csv("https://www.dropbox.com/s/ulujvi2a4ykfzju/train.csv?dl=1"))
dim(digit)
## [1] 42000 785
## visualize the data
plotTrain <- function(data, index){
op <- par(no.readonly=TRUE)
x <- ceiling(sqrt(length(index)))
par(mfrow=c(x, x), mar=c(.1, .1, .1, .1))
for (i in index){ #reverse and transpose each matrix to rotate images
m <- matrix(data[i,-1], nrow=28, byrow=TRUE)
m <- apply(m, 2, rev)
image(t(m), col=grey.colors(255), axes=FALSE)
text(0.05, 0.2, col="white", cex=1.2, data[i, 1])
}
par(op) #reset the original graphics parameters
}
plotTrain(data=digit, index=1:100)
index<- sample(1:nrow(digit), 0.6*nrow(digit))
train<- digit[index,]
test<- digit[-index,]
Currently, each cell uses an integer from 0 to 255 to represent the grey scale. We rescale it to the 0-1 range.
## standardize X
train.x <- train[,-1] #remove 'label' column
test.x<- test[,-1]
train.y <- train[,1] #label column
test.y<- test[,1]
train.x <- train.x/255
test.x <- test.x/255
Here we use a classification tree to train a classifier, and then compare it with the multinomial logit model (from the last lab).
Due to the large size of the data, we only use the first 3,000 observations as the training sample.
fit.tree <- rpart(y ~., method = "class", data = data.frame(y=train.y[1:3000], x=train.x[1:3000,]), cp=0.00001)
plotcp(fit.tree)
printcp(fit.tree)
##
## Classification tree:
## rpart(formula = y ~ ., data = data.frame(y = train.y[1:3000],
## x = train.x[1:3000, ]), method = "class", cp = 1e-05)
##
## Variables actually used in tree construction:
## [1] x.pixel127 x.pixel151 x.pixel152 x.pixel153 x.pixel154 x.pixel155
## [7] x.pixel156 x.pixel157 x.pixel178 x.pixel181 x.pixel183 x.pixel204
## [13] x.pixel208 x.pixel210 x.pixel211 x.pixel236 x.pixel239 x.pixel242
## [19] x.pixel245 x.pixel262 x.pixel267 x.pixel270 x.pixel271 x.pixel288
## [25] x.pixel290 x.pixel292 x.pixel294 x.pixel295 x.pixel297 x.pixel316
## [31] x.pixel318 x.pixel322 x.pixel324 x.pixel326 x.pixel328 x.pixel347
## [37] x.pixel350 x.pixel353 x.pixel355 x.pixel372 x.pixel377 x.pixel379
## [43] x.pixel380 x.pixel386 x.pixel399 x.pixel400 x.pixel405 x.pixel406
## [49] x.pixel431 x.pixel434 x.pixel435 x.pixel439 x.pixel457 x.pixel458
## [55] x.pixel485 x.pixel486 x.pixel489 x.pixel490 x.pixel491 x.pixel492
## [61] x.pixel515 x.pixel516 x.pixel517 x.pixel541 x.pixel542 x.pixel568
## [67] x.pixel571 x.pixel581 x.pixel596 x.pixel597 x.pixel600 x.pixel608
## [73] x.pixel624 x.pixel632 x.pixel652 x.pixel655 x.pixel656 x.pixel657
## [79] x.pixel686 x.pixel97
##
## Root node error: 2666/3000 = 0.88867
##
## n= 3000
##
## CP nsplit rel error xerror xstd
## 1 0.09039760 0 1.00000 1.00750 0.0062892
## 2 0.08777194 1 0.90960 0.90885 0.0080974
## 3 0.08252063 2 0.82183 0.84546 0.0088803
## 4 0.06189047 3 0.73931 0.74794 0.0096993
## 5 0.04538635 5 0.61553 0.61965 0.0102195
## 6 0.03600900 6 0.57014 0.57052 0.0102714
## 7 0.02775694 7 0.53413 0.54089 0.0102647
## 8 0.02625656 8 0.50638 0.52138 0.0102447
## 9 0.02550638 9 0.48012 0.51725 0.0102389
## 10 0.02138035 10 0.45461 0.48687 0.0101788
## 11 0.01762941 11 0.43323 0.46812 0.0101264
## 12 0.01050263 12 0.41560 0.43586 0.0100082
## 13 0.00937734 13 0.40510 0.42611 0.0099653
## 14 0.00862716 15 0.38635 0.42123 0.0099426
## 15 0.00750188 16 0.37772 0.41110 0.0098928
## 16 0.00675169 17 0.37022 0.40885 0.0098812
## 17 0.00637659 20 0.34996 0.40023 0.0098351
## 18 0.00600150 21 0.34359 0.39272 0.0097927
## 19 0.00525131 23 0.33158 0.38785 0.0097641
## 20 0.00375094 25 0.32108 0.36497 0.0096175
## 21 0.00356339 29 0.30608 0.35109 0.0095186
## 22 0.00337584 31 0.29895 0.34734 0.0094905
## 23 0.00325081 36 0.28207 0.33983 0.0094326
## 24 0.00312578 39 0.27232 0.33871 0.0094237
## 25 0.00300075 42 0.26294 0.33683 0.0094088
## 26 0.00281320 44 0.25694 0.33346 0.0093815
## 27 0.00262566 47 0.24756 0.33271 0.0093754
## 28 0.00243811 52 0.23443 0.32858 0.0093413
## 29 0.00225056 54 0.22956 0.32183 0.0092839
## 30 0.00187547 61 0.21380 0.30983 0.0091770
## 31 0.00168792 65 0.20630 0.30758 0.0091562
## 32 0.00150038 67 0.20293 0.30983 0.0091770
## 33 0.00112528 78 0.18642 0.30645 0.0091457
## 34 0.00075019 85 0.17854 0.30458 0.0091281
## 35 0.00056264 92 0.17329 0.30758 0.0091562
## 36 0.00037509 94 0.17217 0.30720 0.0091527
## 37 0.00001000 95 0.17179 0.31470 0.0092212
# make prediction
pred.y.tree<- predict(fit.tree, data.frame(y=test.y, x=test.x), type = "class")
# accuracy rate
mean(test.y==pred.y.tree)
## [1] 0.7297619
We got pretty good accuracy. As we learn more advanced machine learning algorithms, you will see that the accuracy rate can get close to 99%.
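To see which digits are confused most often, you can also look at the confusion matrix (a quick sketch):
# confusion matrix of true vs. predicted digits on the testing sample
conf <- table(truth = test.y, predicted = pred.y.tree)
conf
# per-digit accuracy
round(diag(conf) / rowSums(conf), 3)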
plotResults <- function(testdata, index, preds){
op <- par(no.readonly=TRUE)
x <- ceiling(sqrt(length(index)))
par(mfrow=c(x,x), mar=c(.1,.1,.1,.1))
for (i in index){
m <- matrix(testdata[i,], nrow=28, byrow=TRUE)
m <- apply(m, 2, rev)
image(t(m), col=grey.colors(255), axes=FALSE)
text(0.05,0.1,col="green", cex=1.2, preds[i])
}
par(op)
}
Here are the first 100 images in the test set and their predicted values:
plotResults(testdata=test.x, index=1:100, preds=pred.y.tree)
We see it did a very good job.
- Use rpart() to fit regression and classification trees.
- Know how to interpret a tree.
- Use predict() for prediction, and know how to assess the performance.
- Know how to use the cp plot/table to prune a large tree.