Kaggle Digit Recognizer: Mahout Random Forest attempt
I’ve written previously about the K-means approach that Jen and I took when trying to solve Kaggle’s Digit Recognizer. Having stalled at about 80% accuracy, we decided to try one of the algorithms suggested in the tutorials section - the random forest!
We initially used a Clojure random forests library but struggled to build the random forest from the training set data in a reasonable amount of time, so we switched to Mahout’s version, which is based on Leo Breiman’s random forests paper.
There’s a really good explanation of how ensembles work on the Factual blog, which we found quite useful in helping us understand how random forests are supposed to work:
One of the most powerful Machine Learning techniques we turn to is ensembling. Ensemble methods build surprisingly strong models out of a collection of weak models called base learners, and typically require far less tuning when compared to models like Support Vector Machines. Most ensemble methods use decision trees as base learners and many ensembling techniques, like Random Forests and Adaboost, are specific to tree ensembles.
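The core idea is easy to see in miniature. Here’s a contrived sketch of majority voting - the WeakLearner interface and everything else in it is made up purely for illustration, it isn’t Mahout’s API:

import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative only: WeakLearner stands in for a decision tree or any
// other base learner.
interface WeakLearner {
  int predict(double[] features);
}

class Ensemble {
  private final List<WeakLearner> learners;

  Ensemble(List<WeakLearner> learners) {
    this.learners = learners;
  }

  // Each base learner votes and the most popular label wins, which is
  // roughly what a random forest does at classification time.
  int predict(double[] features) {
    Map<Integer, Integer> votes = new HashMap<Integer, Integer>();
    for (WeakLearner learner : learners) {
      int label = learner.predict(features);
      Integer count = votes.get(label);
      votes.put(label, count == null ? 1 : count + 1);
    }
    int bestLabel = -1;
    int bestCount = -1;
    for (Map.Entry<Integer, Integer> entry : votes.entrySet()) {
      if (entry.getValue() > bestCount) {
        bestLabel = entry.getKey();
        bestCount = entry.getValue();
      }
    }
    return bestLabel;
  }
}

The individual learners can be fairly inaccurate as long as they do better than random and make their mistakes independently of each other, because the voting averages those mistakes away.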
We were able to adapt BreimanExample.java (https://github.com/apache/mahout/blob/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java) from the examples section of the Mahout repository to do what we wanted.
To start with we wrote the following code to build the random forest:
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;

// Mahout import paths as of trunk at the time of writing
import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.ref.SequentialBuilder;
import org.apache.mahout.common.RandomUtils;

public class MahoutKaggleDigitRecognizer {

  public static void main(String[] args) throws Exception {
    // "L" marks the label column, each "N" a numerical attribute. This
    // builds the same descriptor as the long hand-written literal we
    // started with: "L" followed by 784 "N"s, one per pixel column.
    StringBuilder descriptorBuilder = new StringBuilder("L");
    for (int i = 0; i < 784; i++) {
      descriptorBuilder.append(" N");
    }
    String descriptor = descriptorBuilder.toString();

    String[] trainDataValues = fileAsStringArray("data/train.csv");
    Data data = DataLoader.loadData(DataLoader.generateDataset(descriptor, false, trainDataValues), trainDataValues);

    int numberOfTrees = 100;
    DecisionForest forest = buildForest(numberOfTrees, data);
  }

  private static DecisionForest buildForest(int numberOfTrees, Data data) {
    // Breiman's heuristic: consider floor(log2(nbAttributes)) + 1 attributes
    // per split (plain java.lang.Math here, equivalent to the Maths.log(2, n)
    // call in Mahout's BreimanExample)
    int m = (int) Math.floor(Math.log(data.getDataset().nbAttributes()) / Math.log(2) + 1);
    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
    treeBuilder.setM(m);
    return new SequentialBuilder(RandomUtils.getRandom(), treeBuilder, data.clone()).build(numberOfTrees);
  }

  private static String[] fileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
    BufferedReader br = new BufferedReader(new FileReader(file));
    String strLine;
    br.readLine(); // discard the header row
    while ((strLine = br.readLine()) != null) {
      list.add(strLine);
    }
    br.close();
    return list.toArray(new String[list.size()]);
  }
}
The training data file looks a bit like this:
label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8...,pixel783
1,0,0,0,0,0,0,...,0
0,0,0,0,0,0,0,...,0
So in this case the label is in the first column, which is represented by an L in the descriptor, and the next 784 columns are the numerical values of the pixels in the image (hence the 784 N’s in the descriptor).
We’re telling it to create a random forest containing 100 trees, and since an entry can only be classified into a finite set of categories rather than onto a continuous scale, we pass false as the 2nd (regression) argument of DataLoader.generateDataset (https://github.com/apache/mahout/blob/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java#L184).
The m value determines how many randomly selected attributes (pixel values in this case) are considered at each split when constructing a tree, and log2(number_of_attributes) + 1 is supposedly a good value for it - it’s the heuristic used in Breiman’s paper.
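Plugging the numbers in for our data set shows just how few of the pixels each split actually looks at:

// log2(784) ≈ 9.61, so each split considers only 10 of the 784 pixels
int m = (int) Math.floor(Math.log(784) / Math.log(2) + 1); // m = 10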
We then wrote the following code to predict the labels of the test data set:
public class MahoutKaggleDigitRecognizer {

  public static void main(String[] args) throws Exception {
    ...
    String[] testDataValues = testFileAsStringArray("data/test.csv");
    Data test = DataLoader.loadData(data.getDataset(), testDataValues);

    Random rng = RandomUtils.getRandom();
    for (int i = 0; i < test.size(); i++) {
      Instance oneSample = test.get(i);
      // classify returns an index into the data set's values, not the
      // label itself (see the notes below)
      double classify = forest.classify(test.getDataset(), rng, oneSample);
      int label = data.getDataset().valueOf(0, String.valueOf((int) classify));
      System.out.println("Label: " + label);
    }
  }

  private static String[] testFileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
    BufferedReader br = new BufferedReader(new FileReader(file));
    String strLine;
    br.readLine(); // discard the header row
    while ((strLine = br.readLine()) != null) {
      // prepend a '-' placeholder where the label would be (see below)
      list.add("-," + strLine);
    }
    br.close();
    return list.toArray(new String[list.size()]);
  }
}
There were a couple of things that we found confusing when working out how to do this:
- The format of the test data needs to be identical to that of the training data: a label followed by 784 numerical values. Since the test data obviously doesn’t have a label, Mahout expects us to pass a '-' where the label would go, otherwise it throws an exception - that’s what the '-' on the list.add line is for.
- We initially thought the value returned by DecisionForest.classify (https://github.com/apache/mahout/blob/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java#L90) was the prediction, but in actual fact it’s an index which we then need to look up on the data set.
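For an actual Kaggle submission the predictions need to go into a CSV file rather than onto stdout. Here’s a minimal sketch of how the loop above could write one instead (assuming the ImageId,Label header this competition uses - the file name is arbitrary):

PrintWriter out = new PrintWriter(new FileWriter("submission.csv"));
out.println("ImageId,Label");
for (int i = 0; i < test.size(); i++) {
  double classify = forest.classify(test.getDataset(), rng, test.get(i));
  int label = data.getDataset().valueOf(0, String.valueOf((int) classify));
  out.println((i + 1) + "," + label); // ImageIds are 1-based
}
out.close();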
When we ran this algorithm against the test data set we got an accuracy of 83.8% with 10 trees, 84.4% with 50 trees, 96.28% with 100 trees, and 96.33% with 200 trees, which is where we’ve currently peaked.
The amount of time it takes to build the forest as we increase the number of trees is also starting to become a problem, so our next step is either to look at a way to parallelise its creation or to do some sort of feature extraction to try and improve the accuracy.
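One relatively cheap way to use more cores in the meantime would be to build several smaller forests concurrently and have them vote at classification time. Here’s a rough sketch - note that voting across sub-forests only approximates a single large forest’s vote, and it assumes that giving each task its own Data clone makes SequentialBuilder safe to run in parallel; Mahout’s Hadoop-based partial implementation of forest building may well be the more natural route:

// Needs java.util.concurrent imports on top of the ones used earlier.
// Build one sub-forest per core, each on its own copy of the data.
final int treesPerForest = 50;
final int m = 10; // as computed earlier
int cores = Runtime.getRuntime().availableProcessors();
ExecutorService pool = Executors.newFixedThreadPool(cores);
List<Future<DecisionForest>> futures = new ArrayList<Future<DecisionForest>>();
for (int i = 0; i < cores; i++) {
  final Data dataCopy = data.clone(); // one copy per task
  futures.add(pool.submit(new Callable<DecisionForest>() {
    public DecisionForest call() throws Exception {
      DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
      treeBuilder.setM(m);
      return new SequentialBuilder(RandomUtils.getRandom(), treeBuilder, dataCopy).build(treesPerForest);
    }
  }));
}
List<DecisionForest> subForests = new ArrayList<DecisionForest>();
for (Future<DecisionForest> future : futures) {
  subForests.add(future.get()); // blocks until that sub-forest is built
}
pool.shutdown();

Classification would then run each sample through every sub-forest and take a majority vote over their answers, along the same lines as the Ensemble sketch earlier.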
The code is on GitHub if you’re interested in playing with it or have any suggestions on how to improve it.