Classification using Random forest in R

Introduction

Random forest (or decision tree forest) is one of the most popular ensemble models based on decision trees. These models tend to be more accurate than a single decision tree. The random forest algorithm can be used for both classification and regression.

The main drawback of decision trees is that they are prone to overfitting. The reason for this is that trees, if grown deep, are able to fit all kinds of variations in the data, including noise. Although it is possible to address this partially by pruning, the result often remains less than satisfactory.

Ensemble learning combines the results produced by several learners into a single prediction, with the aim of producing better classification or regression results than any individual learner.

In ensemble learning, bagging, boosting, and random forest are the three most common methods.

A random forest classifies by letting many classification trees vote. The idea is simple: a single classification tree produces a single classification result for a given input vector, whereas a random forest grows many classification trees and so obtains many results for the same input. The forest then takes the majority vote across all of its trees to classify the input, or averages the trees' outputs for regression.
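
The voting and averaging step can be sketched in a few lines of base R. The votes and numeric outputs below are made-up values for illustration only; they do not come from a fitted model.

```r
# Hypothetical class votes from five trees for a single input vector
tree_votes <- c("setosa", "versicolor", "setosa", "setosa", "virginica")

# Classification: count the votes and pick the most frequent class
vote_counts <- table(tree_votes)
majority_class <- names(vote_counts)[which.max(vote_counts)]
print(majority_class)

# Regression: average the (hypothetical) numeric outputs of the trees
tree_outputs <- c(2.1, 1.9, 2.4, 2.0, 2.2)
average_output <- mean(tree_outputs)
print(average_output)
```

Here three of the five trees vote for setosa, so that class wins; for regression the forest output is simply the mean of the tree outputs.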

Random forest creates a large number of decision trees. Every observation is fed into every decision tree, and the most common outcome for each observation is used as the final output. A new observation is classified the same way: it is fed into all the trees and the majority vote is taken. In ensemble terms, the trees are weak learners and the random forest is a strong learner.

It's worth noting that, relative to other ensemble-based methods, random forests are quite competitive and offer key advantages. For instance, random forests tend to be easier to use and less prone to overfitting. The following are the general strengths and weaknesses of random forest models.

Strengths of Random forest

  • Can solve both types of problems, classification and regression, and gives decent estimates for both.
  • It has an effective method for estimating missing data and maintains accuracy when a large proportion of the data are missing.
  • Identifies the most important features.
  • Can be used on data with an extremely large number of features or examples (high dimensionality). It can handle thousands of input variables and identify the most significant ones, so it is sometimes used as a dimensionality reduction method. Further, the model outputs variable importance, which can be a very handy feature.
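
As an illustration of the missing-data point, the randomForest package ships a simple imputation helper, na.roughfix() (a forest-based alternative, rfImpute(), also exists). The sketch below is only one possible approach: the missing values are injected artificially into a copy of iris, and the name iris_na is chosen here for illustration.

```r
library(randomForest)
set.seed(1234)

# Copy iris and blank out ten values in each predictor (illustrative only)
iris_na <- iris
for (j in 1:4) iris_na[sample(nrow(iris_na), 10), j] <- NA

# na.roughfix() fills numeric NAs with the column median
# (and factor NAs with the most frequent level)
iris_filled <- na.roughfix(iris_na)

sum(is.na(iris_na))      # missing values before imputation
sum(is.na(iris_filled))  # missing values after imputation
```

The filled data frame can then be passed to randomForest() like any complete data set.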

Weaknesses of Random forest

  • Unlike a decision tree, the model is not easily interpretable. Random forest can feel like a black-box approach to statistical modelers.
  • It does a good job at classification but is less suited to regression, because its predictions are averages of training outcomes rather than a precise continuous function. In regression it cannot predict beyond the range of the training data, and it may overfit data sets that are particularly noisy.
  • May require some work to tune the model to the data.
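
The range limitation in regression is easy to see on a toy problem. The sketch below uses synthetic data (y = 2x, observed only for x from 1 to 100) that is not part of the original example; the names train and rf_reg are chosen here for illustration.

```r
library(randomForest)
set.seed(1234)

# Simple linear relationship, observed only for x between 1 and 100
train <- data.frame(x = 1:100)
train$y <- 2 * train$x

rf_reg <- randomForest(y ~ x, data = train, ntree = 100)

# Predict far outside the training range; the true value would be 400
pred_outside <- predict(rf_reg, newdata = data.frame(x = 200))
print(pred_outside)

# The prediction cannot exceed the largest response seen in training
print(max(train$y))
```

Because every tree predicts an average of training responses, the forest's output at x = 200 stays at or below the training maximum of 200 instead of following the linear trend.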

Modeling in R

The random forest algorithm is implemented in the randomForest package for R, and the function of the same name, randomForest(), is used to fit a model.

Some of the commonly used parameters of the randomForest() function are:

  • x (or formula) - a data frame or matrix of predictors, or a formula describing the model to be fitted
  • data - input data frame
  • ntree - number of decision trees to be grown
  • mtry - number of features randomly sampled as candidates at each split
  • replace - TRUE or FALSE; whether cases are sampled with or without replacement
  • importance - whether the importance of the independent variables should be assessed
  • proximity - whether to calculate proximity measures between rows of the data frame
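
To show how these parameters fit together, here is a minimal sketch that calls randomForest() through its x/y interface rather than a formula. The particular values (200 trees, mtry of 2) are arbitrary choices for illustration, not recommendations, and rf_xy is a name introduced here.

```r
library(randomForest)
set.seed(1234)

rf_xy <- randomForest(
  x = iris[, 1:4],     # predictors
  y = iris$Species,    # response; a factor, so a classification forest is grown
  ntree = 200,         # number of trees
  mtry = 2,            # features sampled as candidates at each split
  replace = TRUE,      # bootstrap sampling with replacement
  importance = TRUE,   # assess variable importance
  proximity = TRUE     # compute the proximity matrix
)

rf_xy$ntree            # number of trees grown
rf_xy$mtry             # features tried at each split
dim(rf_xy$proximity)   # one row and one column per observation
```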

The randomForest package optionally produces two additional pieces of information: a measure of the importance of the predictor variables, and a measure of the internal structure of the data (the proximity of different data points to one another).

By default, the randomForest() function creates an ensemble of 500 trees. For classification, each split considers sqrt(p) randomly chosen features (rounded down), where p is the number of features in the training dataset and sqrt() refers to R's square root function; for regression the default is p/3. The goal of using a large number of trees is to train enough so that each feature has a chance to appear in several models.
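
The default number of features per split can be worked out by hand. The sketch below assumes the defaults described in the randomForest documentation: the square root of p rounded down for classification, and p/3 rounded down (but at least 1) for regression.

```r
# iris has four predictors plus the Species response
p <- ncol(iris) - 1

# Default mtry for a classification forest
mtry_classification <- floor(sqrt(p))

# Default mtry for a regression forest
mtry_regression <- max(floor(p / 3), 1)

print(c(p = p,
        classification = mtry_classification,
        regression = mtry_regression))
```

For iris this gives two features per split for classification, which is the value reported when the fitted model is printed.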

Let's see how randomForest() works with the iris data set. Apart from ntree and proximity, the parameters are left at their defaults. The set.seed() function ensures that the result can be replicated:

#install.packages("randomForest")
library("randomForest")
set.seed(1234)

Split the iris data into training and testing data:

ind = sample(2, nrow(iris), replace=TRUE, prob=c(0.7,0.3))
trainData = iris[ind==1,]
testData = iris[ind==2,]
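
Because sample() assigns each row independently, the split is only approximately 70/30 and the exact sizes vary with the seed and R version. The sketch below checks the sizes of the split above and shows one possible alternative that fixes the training size exactly; trainExact and testExact are names introduced here for illustration.

```r
set.seed(1234)

# Same split as above: each row lands in group 1 with probability 0.7
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]

table(ind)                # actual group sizes
table(trainData$Species)  # class balance in the training data

# Alternative: draw exactly 70 percent of the row indices for training
n_train <- round(0.7 * nrow(iris))
train_idx <- sample(nrow(iris), n_train)
trainExact <- iris[train_idx, ]
testExact <- iris[-train_idx, ]
```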

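Fit the random forest model:

# Grow 100 trees; proximity=TRUE also computes the proximity matrix
iris_rf = randomForest(Species~., data=trainData, ntree=100, proximity=TRUE)
# Without newdata, predict() returns the out-of-bag predictions
table(predict(iris_rf), trainData$Species)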

To look at a summary of the model's performance, we can simply type the resulting object's name:

iris_rf

The output notes that the random forest included 100 trees and tried two variables at each split. At first glance, the error rate of 5.36 percent in the confusion matrix may look poor compared with the resubstitution error that is often reported for other models. However, this confusion matrix does not show resubstitution error. Instead, it reflects the out-of-bag error rate (listed in the output as OOB estimate of error rate): each observation is classified using only the trees whose bootstrap sample did not contain it. Unlike resubstitution error, this is a reasonable estimate of the test set error and therefore of future performance.
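
The out-of-bag figure can also be read from the fitted object instead of the printed summary. The sketch below refits the same model so that it is self-contained; the exact error value depends on the seed and R version, so none is quoted here.

```r
library(randomForest)
set.seed(1234)

ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
iris_rf <- randomForest(Species ~ ., data = trainData, ntree = 100)

# err.rate has one row per tree; the last row is the error of the full forest
oob_error <- iris_rf$err.rate[iris_rf$ntree, "OOB"]
print(oob_error)

# predict() without newdata returns the out-of-bag predictions, so the
# share of mismatches should agree with the figure above
manual_oob <- mean(predict(iris_rf) != trainData$Species)
print(manual_oob)

# The OOB confusion matrix shown in the printed summary
print(iris_rf$confusion)
```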

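You can use the plot() function to plot the error rates of the forest object against the number of trees. For a classification forest these are the out-of-bag error rate and the error rate of each class: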

plot(iris_rf)

A forest of 100 decision trees has been built, and the plot shows how the error rate changes as trees are added. It seems to indicate that beyond about 40 trees there is no significant reduction in error rate.
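
Reading the plot by eye can be backed up with the numbers behind it. The following sketch refits the model, adds a legend to the plot, and looks up the number of trees at which the out-of-bag error is lowest. The legend colours and line types assume the defaults of matplot(), which plot() uses for randomForest objects.

```r
library(randomForest)
set.seed(1234)

ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
iris_rf <- randomForest(Species ~ ., data = trainData, ntree = 100)

# One column per curve: the OOB error and the error of each class
plot(iris_rf, main = "Error rate by number of trees")
legend("topright", legend = colnames(iris_rf$err.rate),
       col = 1:4, lty = 1:4)

# Number of trees at which the OOB error is smallest
oob_curve <- iris_rf$err.rate[, "OOB"]
best_ntree <- which.min(oob_curve)
print(best_ntree)
print(oob_curve[best_ntree])
```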

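You can then examine the importance of each attribute within the fitted classifier. Because the model was fitted without importance=TRUE, only the mean decrease in the Gini index is reported: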

importance(iris_rf)
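
If the permutation-based measure is wanted as well, the model has to be fitted with importance=TRUE. The sketch below is one way to do that; iris_rf_imp is a name introduced here so the original model is left untouched.

```r
library(randomForest)
set.seed(1234)

ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]

# importance=TRUE makes the forest record permutation importance as well
iris_rf_imp <- randomForest(Species ~ ., data = trainData,
                            ntree = 100, importance = TRUE)

# Full table: per-class columns, MeanDecreaseAccuracy and MeanDecreaseGini
imp <- importance(iris_rf_imp)
print(imp)

# type = 1 returns only the mean decrease in accuracy
print(importance(iris_rf_imp, type = 1))

# Dot chart of the importance measures
varImpPlot(iris_rf_imp)
```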

Next, apply the fitted random forest to the testing data. As with other classification methods, you can obtain the classification table:

irisPred = predict(iris_rf, newdata=testData)
table(irisPred, testData$Species)
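
Besides the predicted class, predict() can return the share of trees voting for each class, which shows how confident the forest is about each test observation. A minimal self-contained sketch, with irisProb as a name introduced here:

```r
library(randomForest)
set.seed(1234)

ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]
iris_rf <- randomForest(Species ~ ., data = trainData, ntree = 100)

# type = "prob" gives one column per class with the fraction of tree votes
irisProb <- predict(iris_rf, newdata = testData, type = "prob")
print(head(irisProb))

# Each row sums to 1 because every tree votes for exactly one class
print(range(rowSums(irisProb)))
```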

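Next, look at the margin of each test observation: the proportion of votes for the correct class minus the largest proportion of votes for any other class. A positive margin means a correct classification:

# margin(iris_rf, ...) would use the training votes, so pass the test-set votes explicitly
plot(margin(predict(iris_rf, newdata=testData, type="vote"), testData$Species))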

Let's determine the misclassification rate. First, build a confusion matrix. Each row of the matrix holds the predictions for one class, while each column holds the instances of one actual class.

CM = table(irisPred, testData$Species)

Second, compute the accuracy from the diagonal of this table. Applying the diag function to the table selects the diagonal elements, i.e., the number of points where random forest agrees with the true classification; the sum command adds these values up, and dividing by the total number of test points gives the accuracy.

accuracy = (sum(diag(CM)))/sum(CM)

The model has a high overall accuracy of 0.9473684, which corresponds to a misclassification rate of about 0.053.
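
The same confusion matrix yields the misclassification rate and a per-class breakdown. The sketch below repeats the earlier steps so that it runs on its own; the figures it prints may differ from the one quoted above depending on the R version, since sample() changed in R 3.6. The names misclassification_rate and recall are introduced here.

```r
library(randomForest)
set.seed(1234)

ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]
iris_rf <- randomForest(Species ~ ., data = trainData, ntree = 100)

irisPred <- predict(iris_rf, newdata = testData)
CM <- table(irisPred, testData$Species)  # rows: predicted, columns: actual

accuracy <- sum(diag(CM)) / sum(CM)
misclassification_rate <- 1 - accuracy
print(c(accuracy = accuracy, misclassification = misclassification_rate))

# Share of each actual class that was classified correctly
recall <- diag(CM) / colSums(CM)
print(recall)
```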
