Testing Guide
SuperML Java includes comprehensive unit tests to ensure reliability and correctness of all machine learning algorithms. This guide covers the testing framework, how to run tests, and how to write new tests.
Testing Framework
SuperML uses JUnit 5 as the primary testing framework, with additional utilities for machine learning specific testing scenarios.
Test Structure
src/test/java/org/superml/
├── linear_model/
│ ├── LogisticRegressionTest.java
│ ├── LinearRegressionTest.java
│ └── ...
├── tree/
│ ├── DecisionTreeTest.java
│ ├── RandomForestTest.java
│ └── GradientBoostingTest.java
├── multiclass/
│ ├── OneVsRestTest.java
│ ├── SoftmaxRegressionTest.java
│ └── MulticlassTest.java
├── datasets/
│ └── DatasetsTest.java
└── metrics/
└── MetricsTest.java
Running Tests
Maven Commands
# Run all tests
mvn test
# Run tests for specific package
mvn test -Dtest="org.superml.tree.*"
# Run specific test class
mvn test -Dtest="LogisticRegressionTest"
# Run specific test method
mvn test -Dtest="LogisticRegressionTest#testBinaryClassification"
# Run each test class in a fresh JVM fork (useful for isolating shared state)
mvn test -Dtest="*" -DforkCount=1 -DreuseForks=false
IDE Integration
Most IDEs (IntelliJ IDEA, Eclipse, VS Code) provide built-in JUnit 5 support:
- Right-click on test files to run individual tests
- Use test runners for debugging
- View test coverage reports
Test Categories
Algorithm Tests
Each algorithm has comprehensive tests covering:
- Basic Functionality
  - Training and prediction
  - Parameter validation
  - Edge cases
- Correctness
  - Known datasets with expected results
  - Mathematical properties
  - Convergence behavior
- Performance
  - Training time benchmarks
  - Memory usage validation
  - Scalability tests
Example: Decision Tree Tests
@Test
void testBinaryClassification() {
// Generate test data
Datasets.ClassificationData data = Datasets.makeClassification(100, 4, 2);
// Create and train model
DecisionTree dt = new DecisionTree("gini", 5);
dt.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
// Test predictions
double[] predictions = dt.predict(data.X);
assertThat(predictions).hasSize(data.X.length);
// Test accuracy is reasonable
double accuracy = Metrics.accuracy(
Arrays.stream(data.y).asDoubleStream().toArray(),
predictions
);
assertThat(accuracy).isGreaterThan(0.7);
}
@Test
void testProbabilityPredictions() {
Datasets.ClassificationData data = Datasets.makeClassification(100, 4, 2);
DecisionTree dt = new DecisionTree("gini", 5);
dt.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
double[][] probabilities = dt.predictProba(data.X);
// Test probability constraints
for (double[] probs : probabilities) {
assertThat(probs).hasSize(2);
assertThat(probs[0] + probs[1]).isCloseTo(1.0, within(1e-10));
assertThat(probs[0]).isBetween(0.0, 1.0);
assertThat(probs[1]).isBetween(0.0, 1.0);
}
}
Current Test Suites
Linear Models
LogisticRegressionTest.java
- Binary and multiclass classification
- Regularization (L1, L2, Elastic Net)
- Convergence testing (see the sketch after the example below)
- Parameter validation
@Test
void testMulticlassClassification() {
Datasets.ClassificationData data = Datasets.makeClassification(200, 5, 3);
LogisticRegression lr = new LogisticRegression()
.setMaxIter(1000)
.setMultiClass("multinomial");
lr.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
double[] predictions = lr.predict(data.X);
double[][] probabilities = lr.predictProba(data.X);
// Validate multiclass behavior
assertThat(lr.getClasses()).hasSize(3);
assertThat(probabilities[0]).hasSize(3);
}
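The convergence bullet above can be exercised using only setters already shown in this guide. Here is a minimal sketch, assuming LogisticRegression exposes the same score(X, y) accuracy helper used by the tree classes elsewhere in this guide:
@Test
void testMoreIterationsDoNotReduceTrainingFit() {
    Datasets.ClassificationData data = Datasets.makeClassification(200, 5, 2);
    double[] y = Arrays.stream(data.y).asDoubleStream().toArray();
    LogisticRegression few = new LogisticRegression().setMaxIter(10);
    LogisticRegression many = new LogisticRegression().setMaxIter(1000);
    few.fit(data.X, y);
    many.fit(data.X, y);
    // A convergent solver should not fit the training data worse when given
    // more iterations; allow a small tolerance for numerical noise.
    // score(X, y) is assumed to return training accuracy, as with the tree classes.
    assertThat(many.score(data.X, y)).isGreaterThanOrEqualTo(few.score(data.X, y) - 0.05);
}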
Tree Algorithms
RandomForestTest.java
- Ensemble behavior validation
- Bootstrap sampling verification
- Feature importance testing
- Parallel training validation
@Test
void testEnsembleBehavior() {
Datasets.ClassificationData data = Datasets.makeClassification(300, 10, 2);
RandomForest rf = new RandomForest(50, 8);
rf.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
// Verify ensemble contains expected number of trees
assertThat(rf.getTrees()).hasSize(50);
// Test that ensemble performs better than single tree
DecisionTree dt = new DecisionTree("gini", 8);
dt.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
double rfAccuracy = rf.score(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
double dtAccuracy = dt.score(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
assertThat(rfAccuracy).isGreaterThanOrEqualTo(dtAccuracy);
}
GradientBoostingTest.java
- Sequential learning validation
- Early stopping functionality
- Learning rate effects (see the sketch below)
- Overfitting prevention
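None of these behaviors are illustrated above, so here is a sketch of a learning-rate check. The no-argument GradientBoosting constructor and the setNEstimators/setLearningRate setter names are assumptions for illustration, not the library's confirmed API:
@Test
void testLearningRateEffect() {
    Datasets.ClassificationData data = Datasets.makeClassification(300, 10, 2);
    double[] y = Arrays.stream(data.y).asDoubleStream().toArray();
    // Hypothetical fluent setters, mirroring the style of the other estimators
    GradientBoosting conservative = new GradientBoosting()
        .setNEstimators(100)
        .setLearningRate(0.05);
    GradientBoosting aggressive = new GradientBoosting()
        .setNEstimators(100)
        .setLearningRate(0.3);
    conservative.fit(data.X, y);
    aggressive.fit(data.X, y);
    // Both configurations should reach a reasonable training fit; the point of
    // the test is that changing the learning rate does not break training.
    assertThat(conservative.score(data.X, y)).isGreaterThan(0.7);
    assertThat(aggressive.score(data.X, y)).isGreaterThan(0.7);
}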
Multiclass Classification
MulticlassTest.java - Comprehensive multiclass testing
@Test
void testOneVsRestVsSoftmax() {
Datasets.ClassificationData data = Datasets.makeClassification(500, 15, 4);
var split = DataLoaders.trainTestSplit(data.X,
Arrays.stream(data.y).asDoubleStream().toArray(), 0.3, 42);
// Test One-vs-Rest
LogisticRegression base = new LogisticRegression().setMaxIter(1000);
OneVsRestClassifier ovr = new OneVsRestClassifier(base);
ovr.fit(split.XTrain, split.yTrain);
// Test Softmax
SoftmaxRegression softmax = new SoftmaxRegression().setMaxIter(1000);
softmax.fit(split.XTrain, split.yTrain);
// Both should achieve reasonable accuracy
double ovrAccuracy = Metrics.accuracy(split.yTest, ovr.predict(split.XTest));
double softmaxAccuracy = Metrics.accuracy(split.yTest, softmax.predict(split.XTest));
assertThat(ovrAccuracy).isGreaterThan(0.6);
assertThat(softmaxAccuracy).isGreaterThan(0.6);
}
Datasets and Utilities
DatasetsTest.java
- Synthetic data generation validation
- Data loader functionality
- Train/test split verification
@Test
void testMakeClassification() {
Datasets.ClassificationData data = Datasets.makeClassification(100, 5, 3);
assertThat(data.X).hasSize(100);
assertThat(data.X[0]).hasSize(5);
assertThat(data.y).hasSize(100);
// Verify all classes are present
Set<Integer> uniqueClasses = Arrays.stream(data.y).boxed().collect(Collectors.toSet());
assertThat(uniqueClasses).hasSize(3);
}
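MetricsTest.java rounds out the utilities suite. A minimal correctness check against a hand-computed value, using only the Metrics.accuracy call already shown in this guide:
@Test
void testAccuracyMatchesHandComputedValue() {
    double[] yTrue = {0, 1, 1, 0, 1};
    double[] yPred = {0, 1, 0, 0, 1};
    // 4 of the 5 labels agree, so the accuracy must be exactly 0.8
    assertThat(Metrics.accuracy(yTrue, yPred)).isCloseTo(0.8, within(1e-10));
}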
Writing New Tests
Test Structure Template
package org.superml.algorithms;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;

import java.util.Arrays;
import org.superml.datasets.Datasets; // adjust if the Datasets class lives in a different package

import static org.assertj.core.api.Assertions.*;
class NewAlgorithmTest {
private NewAlgorithm algorithm;
@BeforeEach
void setUp() {
algorithm = new NewAlgorithm();
}
@Test
@DisplayName("Should train and predict correctly on binary classification")
void testBinaryClassification() {
// Arrange
Datasets.ClassificationData data = Datasets.makeClassification(100, 4, 2);
// Act
algorithm.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
double[] predictions = algorithm.predict(data.X);
// Assert
assertThat(predictions).hasSize(data.X.length);
assertThat(algorithm.score(data.X, Arrays.stream(data.y).asDoubleStream().toArray()))
.isGreaterThan(0.7);
}
@Test
@DisplayName("Should handle edge cases gracefully")
void testEdgeCases() {
// Test with minimal data
double[][] X = {{0, 0}, {0, 1}};
double[] y = {1, 2};
assertThatCode(() -> algorithm.fit(X, y))
.doesNotThrowAnyException();
}
@Test
@DisplayName("Should validate parameters correctly")
void testParameterValidation() {
assertThatThrownBy(() -> algorithm.setInvalidParameter(-1))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Parameter must be positive");
}
}
Testing Best Practices
- Use Descriptive Names
@Test
@DisplayName("Should achieve >90% accuracy on linearly separable data")
void testLinearSeparableData() { /* ... */ }
- Test Edge Cases
@Test
void testEmptyDataset() {
    assertThatThrownBy(() -> algorithm.fit(new double[0][0], new double[0]))
        .isInstanceOf(IllegalArgumentException.class);
}
- Validate Mathematical Properties
@Test
void testProbabilitiesSumToOne() {
    double[][] probas = classifier.predictProba(testData);
    for (double[] probs : probas) {
        assertThat(Arrays.stream(probs).sum()).isCloseTo(1.0, within(1e-10));
    }
}
- Use Appropriate Tolerances
assertThat(actualValue).isCloseTo(expectedValue, within(1e-6));
- Test Performance Characteristics
@Test
void testTrainingTime() {
    long start = System.currentTimeMillis();
    algorithm.fit(largeDataset.X, largeDataset.y);
    long duration = System.currentTimeMillis() - start;
    assertThat(duration).isLessThan(5000); // Should complete in 5 seconds
}
Mock Data and Fixtures
Standard Test Datasets
// Binary classification
Datasets.ClassificationData binary = Datasets.makeClassification(100, 5, 2);
// Multiclass classification
Datasets.ClassificationData multiclass = Datasets.makeClassification(200, 10, 4);
// Regression
Datasets.RegressionData regression = Datasets.makeRegression(150, 8, 1, 0.1);
// Real-world style datasets
Datasets.ClassificationData iris = Datasets.loadIris();
Datasets.RegressionData boston = Datasets.loadBoston();
Custom Test Data
// Linearly separable data
double[][] X = {
{0, 0}, {0, 1}, {1, 0}, {1, 1}
};
double[] y = {0, 0, 0, 1}; // AND-like pattern (linearly separable)
// Known solution data
double[][] perfectData = generatePerfectLinearData();
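As a quick illustration, the AND pattern above can be paired with the DecisionTree API shown earlier; with default splitting settings a depth-5 tree should separate these four points exactly (assuming a correct tree implementation):
@Test
void testPerfectFitOnSeparableData() {
    double[][] X = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    double[] y = {0, 0, 0, 1};
    DecisionTree dt = new DecisionTree("gini", 5);
    dt.fit(X, y);
    // Four cleanly separable points should be classified without error
    assertThat(Metrics.accuracy(y, dt.predict(X))).isCloseTo(1.0, within(1e-10));
}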
Continuous Integration
GitHub Actions
The project includes automated testing on push and pull requests:
# .github/workflows/test.yml
name: Tests
on: [push, pull_request]
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up JDK 11
        uses: actions/setup-java@v2
        with:
          distribution: 'temurin'
          java-version: '11'
      - name: Run tests
        run: mvn test
      - name: Generate test report
        run: mvn surefire-report:report
Test Coverage
To generate test coverage reports (assuming the jacoco-maven-plugin is configured in pom.xml with its prepare-agent goal bound to the test phase):
# Run the tests with the JaCoCo agent attached, then generate the coverage report
mvn clean test jacoco:report
# View report
open target/site/jacoco/index.html
Performance Testing
Benchmark Tests
@Test
void benchmarkTrainingTime() {
int[] sampleSizes = {100, 500, 1000, 5000};
for (int size : sampleSizes) {
Datasets.ClassificationData data = Datasets.makeClassification(size, 10, 2);
long start = System.nanoTime();
algorithm.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
long duration = System.nanoTime() - start;
System.out.printf("Size %d: %.2f ms%n", size, duration / 1_000_000.0);
}
}
Memory Usage Tests
@Test
void testMemoryUsage() {
Runtime runtime = Runtime.getRuntime();
long memBefore = runtime.totalMemory() - runtime.freeMemory();
// Train large model
Datasets.ClassificationData data = Datasets.makeClassification(10000, 50, 5);
algorithm.fit(data.X, Arrays.stream(data.y).asDoubleStream().toArray());
long memAfter = runtime.totalMemory() - runtime.freeMemory();
long memUsed = memAfter - memBefore;
// Verify reasonable memory usage
assertThat(memUsed).isLessThan(100_000_000); // 100MB limit
}
Test Results
Recent test execution results show comprehensive coverage:
- Linear Models: 15 tests, all passing
- Tree Algorithms: 25 tests, all passing
- Multiclass: 12 tests, all passing
- Datasets: 8 tests, all passing
- Metrics: 10 tests, all passing
Sample Test Output
[INFO] Results:
[INFO]
[INFO] Tests run: 70, Failures: 0, Errors: 0, Skipped: 0
[INFO]
[INFO] Tree Algorithms:
[INFO] DecisionTree: 8/8 tests passed
[INFO] RandomForest: 10/10 tests passed
[INFO] GradientBoosting: 7/7 tests passed
[INFO]
[INFO] Multiclass Classification:
[INFO] OneVsRest: 5/5 tests passed
[INFO] SoftmaxRegression: 4/4 tests passed
[INFO] Integration: 3/3 tests passed
This comprehensive testing framework ensures that SuperML Java maintains high quality and reliability across all implemented algorithms.