Using scikit-learn regressions with PyMVPA

This is practically part two of the example on using scikit-learn classifiers with PyMVPA. Just like classifiers, implementations of regression algorithms in scikit-learn use the estimator and predictor API. Consequently, the same wrapper class (SKLLearnerAdapter) as before is applicable when using scikit-learn regressions in PyMVPA.
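
The adapter idea can be illustrated without either library. The following is a minimal sketch under stated assumptions: `MeanRegressor`, `ToyDataset` and `ToyLearnerAdapter` are hypothetical stand-ins invented for this illustration, not PyMVPA or scikit-learn classes. The real SKLLearnerAdapter is likely more involved. The sketch only shows how a `fit(X, y)`/`predict(X)` estimator can be driven through a dataset-based interface.

```python
# Hypothetical stand-ins: none of these classes are PyMVPA or scikit-learn code.

class MeanRegressor:
    """Toy estimator with a scikit-learn-style fit/predict interface."""

    def fit(self, X, y):
        # "Learn" a single number: the mean of the training targets.
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        # Predict the stored mean for every sample.
        return [self.mean_ for _ in X]


class ToyDataset:
    """Bundles samples and targets in one object."""

    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets


class ToyLearnerAdapter:
    """Offers train(dataset)/predict(samples) on top of fit/predict."""

    def __init__(self, estimator):
        self.estimator = estimator

    def train(self, ds):
        # Unpack the dataset into the two arguments the estimator expects.
        self.estimator.fit(ds.samples, ds.targets)

    def predict(self, samples):
        return self.estimator.predict(samples)


ds = ToyDataset(samples=[[0.0], [1.0], [2.0]], targets=[1.0, 2.0, 3.0])
learner = ToyLearnerAdapter(MeanRegressor())
learner.train(ds)
predictions = learner.predict([[5.0], [6.0]])
print(predictions)
```

Because the adapter only relies on `fit` and `predict`, any estimator following that convention can be swapped in without changing the calling code.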

The example demonstrates this by mimicking the “Decision Tree Regression” example from the scikit-learn documentation, applying only the minimal modifications necessary to run the scikit-learn decision tree regression implementation (with two different parameter settings) on a PyMVPA dataset.


import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
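
To make the structure of this toy data explicit, the sketch below regenerates it and inspects which targets receive noise. The names `y_clean`, `noise` and `changed` are introduced here for illustration only and do not appear in the original example.

```python
import numpy as np

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)   # 80 sorted samples in [0, 5), one feature
y_clean = np.sin(X).ravel()                # noise-free targets

y = y_clean.copy()
noise = 3 * (0.5 - rng.rand(16))           # 16 values, each within [-1.5, 1.5]
y[::5] += noise                            # perturb every fifth target only

changed = np.flatnonzero(y != y_clean)     # indices whose target was altered
print(X.shape, y.shape)
print(len(y[::5]))
print(changed)
```

Only every fifth target is perturbed, so most of the points still lie exactly on the sine curve and the noisy ones act as outliers.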

So far the code is identical to the scikit-learn example. The first difference is the import of the adapter class. We also use the dataset_wizard helper as a convenient way to convert the data into a proper Dataset.

# this first import is only required to run the example as part of the test suite
from mvpa2 import cfg
from mvpa2.clfs.skl.base import SKLLearnerAdapter
from mvpa2.datasets import dataset_wizard

ds_train = dataset_wizard(samples=X, targets=y)

The following lines show the only significant modification with respect to a pure scikit-learn implementation: the regression is wrapped into the adapter. The result is a PyMVPA learner, so it can be called with a dataset that contains both samples and targets.

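clf_1 = SKLLearnerAdapter(DecisionTreeRegressor(max_depth=2))
clf_2 = SKLLearnerAdapter(DecisionTreeRegressor(max_depth=5))

# train both learners on the dataset before predicting
clf_1.train(ds_train)
clf_2.train(ds_train)

# predict on a dense grid of test samples
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = clf_1.predict(X_test)
y_2 = clf_2.predict(X_test)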

# plot the results
# which clearly show the overfitting for the second depth setting
import pylab as pl

pl.scatter(X, y, c="k", label="data")
pl.plot(X_test, y_1, c="g", label="max_depth=2", linewidth=2)
pl.plot(X_test, y_2, c="r", label="max_depth=5", linewidth=2)
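pl.title("Decision Tree Regression")
pl.legend()

# only display the figure when the examples are run interactively
if cfg.getboolean('examples', 'interactive', True):
    pl.show()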
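
The effect of the depth setting can also be checked numerically rather than by eye. The sketch below is an illustration under assumptions: it uses plain scikit-learn without the PyMVPA adapter so that it runs on its own, and the `random_state=0` argument and the `results` dictionary are additions that are not part of the original example. It reports the number of leaves and the training error for both depth settings.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Same toy data as in the example above.
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

results = {}
for depth in (2, 5):
    tree = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X, y)
    # Mean squared error on the training data itself.
    train_mse = float(np.mean((tree.predict(X) - y) ** 2))
    results[depth] = (tree.get_n_leaves(), train_mse)
    print("max_depth=%d: %d leaves, training MSE %.4f"
          % (depth, results[depth][0], train_mse))
```

A binary tree of depth d has at most 2**d leaves, so the deeper tree can carve the input range into many more constant segments. Its lower training error partly reflects fitting the injected noise, which is the overfitting visible in the plot.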

See also

The full source code of this example is included in the PyMVPA source distribution (doc/examples/