Tree-Based Algorithms Guide
SuperML Java provides comprehensive implementations of tree-based machine learning algorithms, including Decision Trees, Random Forest, and Gradient Boosting. These algorithms are among the most popular and effective machine learning methods for both classification and regression tasks.
Overview
Tree-based algorithms work by recursively splitting the feature space into regions and making predictions based on the majority class (classification) or average value (regression) in each region.
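For intuition, a fitted tree is just a nested set of threshold tests on individual features, with one prediction per region. The hand-written rule below is purely illustrative (it is not how the library represents trees internally); it shows the kind of decision function a depth-2 classification tree encodes:

```java
// Illustrative only: a hand-coded "depth-2 tree" over two features.
// A real DecisionTree learns the thresholds and leaf values from data.
static double toyTreePredict(double[] x) {
    if (x[0] < 2.5) {                    // first split on feature 0
        return (x[1] < 1.0) ? 0.0 : 1.0; // second split on feature 1
    }
    return 1.0;                          // majority class of the remaining region
}
```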
Available Algorithms
| Algorithm | Class | Best For | Key Features |
|---|---|---|---|
| Decision Tree | DecisionTree | Interpretable models, feature selection | Fast training, easy to understand |
| Random Forest | RandomForest | Robust predictions, feature importance | Ensemble method, handles overfitting |
| Gradient Boosting | GradientBoosting | High accuracy, competitions | Sequential learning, excellent performance |
Decision Trees
Decision Trees are the foundation of tree-based algorithms. They create a model that predicts target values by learning simple decision rules inferred from data features.
Basic Usage
import org.superml.tree.DecisionTree;
import org.superml.datasets.Datasets;
import java.util.Arrays;
// DataLoaders provides trainTestSplit; import it from wherever your SuperML version places it.

// Create and configure decision tree
DecisionTree dt = new DecisionTree("gini", 10) // criterion, max_depth
    .setMinSamplesSplit(5)
    .setMinSamplesLeaf(2)
    .setRandomState(42);

// Load data: 1000 samples, 20 features, 2 classes
var dataset = Datasets.makeClassification(1000, 20, 2);
var split = DataLoaders.trainTestSplit(dataset.X,
    Arrays.stream(dataset.y).asDoubleStream().toArray(), 0.2, 42); // 20% test, seed 42

// Train
dt.fit(split.XTrain, split.yTrain);

// Predict classes and class probabilities
double[] predictions = dt.predict(split.XTest);
double[][] probabilities = dt.predictProba(split.XTest);
Parameters
- criterion: Splitting criterion
  - Classification: "gini", "entropy"
  - Regression: "mse"
- max_depth: Maximum tree depth (prevents overfitting)
- min_samples_split: Minimum samples required to split a node
- min_samples_leaf: Minimum samples required in a leaf node
- max_features: Number of features to consider for splits
- random_state: Random seed for reproducibility
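These options map onto the constructor arguments and fluent setters shown above. A configuration sketch is below; note that setMaxFeatures on DecisionTree is assumed here by analogy with RandomForest, so confirm it against your version of the API:

```java
// Configuration sketch for the parameters listed above.
// setMaxFeatures on DecisionTree is an assumption (this guide only shows it on RandomForest).
DecisionTree configured = new DecisionTree("entropy", 8) // criterion, max_depth
    .setMinSamplesSplit(10)  // need at least 10 samples to split a node
    .setMinSamplesLeaf(4)    // each leaf keeps at least 4 samples
    .setMaxFeatures(8)       // consider 8 candidate features per split
    .setRandomState(42);     // reproducible splits
```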
Regression Example
// Regression tree: "mse" criterion, max depth 15
DecisionTree regressor = new DecisionTree("mse", 15)
    .setMinSamplesSplit(10)
    .setMinSamplesLeaf(5);

// 1000 samples, 10 features, 1 target, noise 0.1
var dataset = Datasets.makeRegression(1000, 10, 1, 0.1);
regressor.fit(dataset.X, dataset.y);

// testX / testY are a held-out split prepared as in the classification example
double[] predictions = regressor.predict(testX);
double r2Score = Metrics.r2Score(testY, predictions);
Random Forest
Random Forest builds multiple decision trees and combines their predictions through voting (classification) or averaging (regression). It uses bootstrap sampling and random feature selection to create diverse trees.
Basic Usage
import org.superml.tree.RandomForest;
// Create Random Forest
RandomForest rf = new RandomForest(100, 10) // n_estimators, max_depth
    .setMaxFeatures(5)
    .setBootstrap(true)
    .setRandomState(42);
// Train
rf.fit(XTrain, yTrain);
// Predict
double[] predictions = rf.predict(XTest);
double[][] probabilities = rf.predictProba(XTest);
// Get feature importances
double[] importances = rf.getFeatureImportances();
Advanced Configuration
RandomForest rf = new RandomForest()
    .setNEstimators(200)
    .setMaxDepth(15)
    .setMinSamplesSplit(5)
    .setMinSamplesLeaf(2)
    .setMaxFeatures(10)   // or -1 for auto
    .setBootstrap(true)
    .setMaxSamples(0.8)   // sample 80% of the rows for each tree
    .setNJobs(4)          // train trees on 4 threads
    .setRandomState(42);
Parallel Training
Random Forest supports parallel training for faster performance:
// Use all available cores
RandomForest rfAllCores = new RandomForest().setNJobs(-1);

// Use a specific number of threads
RandomForest rfFourThreads = new RandomForest().setNJobs(4);
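To see what parallelism buys you on your own data, time the same forest trained single-threaded and on all cores. The timing code below is plain Java; the actual speedup depends on dataset size and core count:

```java
// Compare wall-clock training time: 1 thread vs. all available cores.
RandomForest singleThread = new RandomForest(200, 10).setNJobs(1).setRandomState(42);
RandomForest allCores = new RandomForest(200, 10).setNJobs(-1).setRandomState(42);

long t0 = System.nanoTime();
singleThread.fit(XTrain, yTrain);
long singleMs = (System.nanoTime() - t0) / 1_000_000;

long t1 = System.nanoTime();
allCores.fit(XTrain, yTrain);
long parallelMs = (System.nanoTime() - t1) / 1_000_000;

System.out.println("1 thread:  " + singleMs + " ms");
System.out.println("all cores: " + parallelMs + " ms");
```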
Gradient Boosting
Gradient Boosting builds an ensemble of weak learners (typically shallow trees) sequentially, where each new tree corrects the errors of the previous ones.
Basic Usage
import org.superml.tree.GradientBoosting;
// Create Gradient Boosting model
GradientBoosting gb = new GradientBoosting(100, 0.1, 6) // n_estimators, learning_rate, max_depth
    .setSubsample(0.8)
    .setRandomState(42);
// Train
gb.fit(XTrain, yTrain);
// Predict
double[] predictions = gb.predict(XTest);
double[] rawPredictions = gb.predictRaw(XTest); // before sigmoid/softmax
Early Stopping
Gradient Boosting supports early stopping to prevent overfitting:
GradientBoosting gb = new GradientBoosting()
    .setNEstimators(1000)
    .setLearningRate(0.1)
    .setValidationFraction(0.1) // hold out 10% of the training data for validation
    .setNIterNoChange(10)       // stop after 10 rounds without improvement
    .setTol(1e-4);              // minimum improvement that counts
gb.fit(XTrain, yTrain);
// Get training history
List<Double> trainScores = gb.getTrainScores();
List<Double> validScores = gb.getValidationScores();
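The histories returned by getTrainScores() and getValidationScores() show where improvement flattened out and the run stopped. A minimal sketch for inspecting them (whether the score is a loss or an accuracy depends on the implementation):

```java
import java.util.List;

// Print per-iteration scores to see where early stopping kicked in.
List<Double> train = gb.getTrainScores();
List<Double> valid = gb.getValidationScores();
for (int i = 0; i < valid.size(); i++) {
    System.out.printf("iter %3d  train=%.4f  valid=%.4f%n", i, train.get(i), valid.get(i));
}
System.out.println("Boosting stopped after " + valid.size() + " iterations");
```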
Progressive Predictions
You can get predictions at any stage of boosting:
// Predict using only first 50 estimators
double[] earlyPreds = gb.predictAtIteration(XTest, 50);
// Compare with full model
double[] fullPreds = gb.predict(XTest);
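A common use of staged predictions is to score the ensemble at several sizes on held-out data and keep the smallest one that performs well. A sketch assuming a classification task, the Metrics.accuracy helper used elsewhere in this guide, and yTest as the held-out labels:

```java
// Evaluate the ensemble at increasing sizes to find a good cut-off.
int bestIter = 0;
double bestAcc = 0.0;
for (int n = 10; n <= 100; n += 10) {
    double[] staged = gb.predictAtIteration(XTest, n);
    double acc = Metrics.accuracy(yTest, staged);
    if (acc > bestAcc) {
        bestAcc = acc;
        bestIter = n;
    }
    System.out.println(n + " estimators -> accuracy " + acc);
}
System.out.println("Best: " + bestIter + " estimators (accuracy " + bestAcc + ")");
```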
Performance Comparison
The snippet below trains all three algorithms on the same synthetic dataset and compares their test accuracy:
// Generate dataset
var dataset = Datasets.makeClassification(1000, 20, 2);
var split = DataLoaders.trainTestSplit(dataset.X,
    Arrays.stream(dataset.y).asDoubleStream().toArray(), 0.2, 42);
// Decision Tree
DecisionTree dt = new DecisionTree("gini", 10);
dt.fit(split.XTrain, split.yTrain);
double dtAccuracy = Metrics.accuracy(split.yTest, dt.predict(split.XTest));
// Random Forest
RandomForest rf = new RandomForest(100, 10);
rf.fit(split.XTrain, split.yTrain);
double rfAccuracy = Metrics.accuracy(split.yTest, rf.predict(split.XTest));
// Gradient Boosting
GradientBoosting gb = new GradientBoosting(100, 0.1, 6);
gb.fit(split.XTrain, split.yTrain);
double gbAccuracy = Metrics.accuracy(split.yTest, gb.predict(split.XTest));
System.out.println("Decision Tree: " + dtAccuracy);
System.out.println("Random Forest: " + rfAccuracy);
System.out.println("Gradient Boosting: " + gbAccuracy);
Best Practices
Decision Trees
- Start with moderate depth (5-15) to avoid overfitting
- Use min_samples_split and min_samples_leaf for regularization
- Consider pruning for better generalization
Random Forest
- More trees generally improve performance (100-500 is common)
- Use bootstrap sampling (bootstrap=true)
- Set max_features to sqrt(n_features) for classification
- Use parallel training for large datasets
Gradient Boosting
- Start with small learning rate (0.01-0.3)
- Use shallow trees (depth 3-8)
- Enable early stopping for optimal performance
- Consider subsample < 1.0 for regularization
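Putting those recommendations together, here is a sketch of a conservatively configured model using the constructor and setters shown earlier in this guide:

```java
// Small learning rate, shallow trees, row subsampling, and early stopping.
GradientBoosting tuned = new GradientBoosting(500, 0.05, 4) // n_estimators, learning_rate, max_depth
    .setSubsample(0.8)          // train each tree on 80% of the rows
    .setValidationFraction(0.1) // hold out 10% to monitor improvement
    .setNIterNoChange(15)       // stop after 15 rounds without improvement
    .setTol(1e-4)
    .setRandomState(42);
tuned.fit(XTrain, yTrain);
```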
General Tips
- Always validate on a separate test set
- Use cross-validation for model selection (see the sketch after these tips)
- Monitor for overfitting, especially with Decision Trees
- Random Forest and Gradient Boosting usually outperform single trees
- Gradient Boosting often achieves the best performance but takes longer to train
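If your SuperML version does not ship a cross-validation helper, a manual k-fold loop needs nothing beyond fit, predict, and Metrics.accuracy. The sketch below splits contiguous blocks, so shuffle the rows first if your data is ordered:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Manual 5-fold cross-validation for a Random Forest classifier.
int k = 5;
int n = X.length; // X: double[][] features, y: double[] labels
double totalAcc = 0.0;
for (int fold = 0; fold < k; fold++) {
    int start = fold * n / k;
    int end = (fold + 1) * n / k;

    // Everything outside [start, end) is training data for this fold.
    List<double[]> trX = new ArrayList<>();
    List<Double> trY = new ArrayList<>();
    for (int i = 0; i < n; i++) {
        if (i < start || i >= end) { trX.add(X[i]); trY.add(y[i]); }
    }
    double[][] foldXTrain = trX.toArray(new double[0][]);
    double[] foldYTrain = trY.stream().mapToDouble(Double::doubleValue).toArray();
    double[][] foldXTest = Arrays.copyOfRange(X, start, end);
    double[] foldYTest = Arrays.copyOfRange(y, start, end);

    RandomForest model = new RandomForest(100, 10).setRandomState(42);
    model.fit(foldXTrain, foldYTrain);
    totalAcc += Metrics.accuracy(foldYTest, model.predict(foldXTest));
}
System.out.println("Mean CV accuracy: " + (totalAcc / k));
```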
Feature Importance
Tree-based models provide feature importance scores:
// Random Forest feature importance
RandomForest rf = new RandomForest(100, 10);
rf.fit(XTrain, yTrain);
double[] importances = rf.getFeatureImportances();

// Print the importance score of every feature
for (int i = 0; i < importances.length; i++) {
    System.out.println("Feature " + i + ": " + importances[i]);
}
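To see only the most influential features, sort the indices by importance. This uses only the standard library on top of the importances array returned above:

```java
// Rank features by importance and print the top 5.
Integer[] order = new Integer[importances.length];
for (int i = 0; i < order.length; i++) order[i] = i;
java.util.Arrays.sort(order, (a, b) -> Double.compare(importances[b], importances[a]));

for (int rank = 0; rank < Math.min(5, order.length); rank++) {
    int f = order[rank];
    System.out.printf("#%d  feature %d  importance %.4f%n", rank + 1, f, importances[f]);
}
```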
Integration with Other Components
Tree algorithms work seamlessly with other SuperML components:
// With preprocessing
StandardScaler scaler = new StandardScaler();
double[][] scaledX = scaler.fitTransform(XTrain);
RandomForest rf = new RandomForest();
rf.fit(scaledX, yTrain);

// With pipelines
Pipeline pipeline = new Pipeline()
    .addStep("scaler", new StandardScaler())
    .addStep("classifier", new RandomForest(100, 10));
pipeline.fit(XTrain, yTrain); // fit the scaler and classifier in sequence

// With multiclass (one binary DecisionTree per class)
OneVsRestClassifier ovr = new OneVsRestClassifier(new DecisionTree());
ovr.fit(XTrain, yTrain);
Complete Example
See TreeAlgorithmsExample.java for a comprehensive demonstration of all tree-based algorithms with both classification and regression examples.