Grid Search
scikits.learn.grid_search is a module for optimizing the parameters of a model (e.g. a Support Vector Classifier) using cross-validation.
It is implemented in Python and uses the NumPy and SciPy packages. The computation can be run in parallel using the multiprocessing package.
GridSearchCV
- class scikits.learn.grid_search.GridSearchCV(estimator, param_grid, loss_func=None, fit_params={}, n_jobs=1, iid=True)
Grid search on the parameters of a classifier.
Important members are fit, predict.
GridSearchCV implements a “fit” method and a “predict” method like any classifier, except that the parameters of the classifier used to predict are optimized by cross-validation.
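The search itself can be sketched in plain Python without the library. The names `expand_grid`, `grid_search` and `toy_score` below are hypothetical and are not part of scikits.learn. The sketch only illustrates the idea of enumerating every grid point and keeping the best-scoring one, with the cross-validated score replaced by a stand-in function.

```python
from itertools import product


def expand_grid(param_grid):
    """Return one dict per grid point (hypothetical helper, not library API)."""
    names = sorted(param_grid)  # fixed order so the enumeration is reproducible
    return [dict(zip(names, values))
            for values in product(*(param_grid[name] for name in names))]


def grid_search(param_grid, score):
    """Return (best_params, best_score), where larger scores are better.

    `score` stands in for the cross-validated score of an estimator
    built with the given parameters.
    """
    best_params, best_score = None, None
    for params in expand_grid(param_grid):
        current = score(params)
        if best_score is None or current > best_score:
            best_params, best_score = params, current
    return best_params, best_score


def toy_score(params):
    # Illustrative stand-in only: prefers C close to 10 and the 'rbf' kernel.
    bonus = 0.5 if params["kernel"] == "rbf" else 0.0
    return -abs(params["C"] - 10) + bonus


grid = {"kernel": ("linear", "rbf"), "C": [1, 10]}
points = expand_grid(grid)          # 2 kernels x 2 values of C
best_params, best_score = grid_search(grid, toy_score)
```

In the real class, the role of `toy_score` is played by fitting the estimator on each training fold and evaluating it on the held-out fold.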
Parameters :
estimator: object type that implements the “fit” and “predict” methods
    An object of that type is instantiated for each grid point.
param_grid: dict
    A dictionary of parameters that are used to generate the grid.
loss_func: callable, optional
    Function that takes two arguments and compares them to evaluate the performance of the prediction (smaller is better). If None is passed, the score of the estimator is maximized.
fit_params: dict, optional
    Parameters to pass to the fit method.
n_jobs: int, optional
    Number of jobs to run in parallel (default 1).
iid: boolean, optional
    If True, the data is assumed to be identically distributed across the folds, and the loss minimized is the total loss per sample rather than the mean loss across the folds.
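The difference between the two aggregation modes described for iid only matters when folds have unequal sizes. The following is a minimal sketch of one reading of that description. The function `aggregate_loss` and the numbers are hypothetical and are not taken from the library's source.

```python
def aggregate_loss(fold_losses, fold_sizes, iid=True):
    """Combine per-fold mean losses into a single number.

    fold_losses: mean loss measured on each test fold
    fold_sizes:  number of samples in each test fold
    """
    if iid:
        # Weight each fold by its size: total loss divided by total samples.
        total = sum(loss * n for loss, n in zip(fold_losses, fold_sizes))
        return total / sum(fold_sizes)
    # Otherwise every fold counts equally, whatever its size.
    return sum(fold_losses) / len(fold_losses)


fold_losses = [0.5, 0.25]  # illustrative mean loss per fold
fold_sizes = [30, 10]      # illustrative number of samples per fold

per_sample = aggregate_loss(fold_losses, fold_sizes, iid=True)
per_fold = aggregate_loss(fold_losses, fold_sizes, iid=False)
```

With equal fold sizes the two values coincide. Here the larger fold pulls the per-sample value toward its own loss.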
Examples
>>> import numpy as np
>>> from scikits.learn.cross_val import LeaveOneOut
>>> from scikits.learn.svm import SVR
>>> from scikits.learn.grid_search import GridSearchCV
>>> X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
>>> y = np.array([1, 1, 2, 2])
>>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
>>> svr = SVR()
>>> clf = GridSearchCV(svr, parameters, n_jobs=1)
>>> clf.fit(X, y).predict([[-0.8, -1]])
array([ 1.14])
Methods
fit(X, y)     self     Fit the model.
predict(X)    array    Predict using the model.
- fit(X, y, refit=True, cv=None, **kw)
Run fit with all sets of parameters. Returns the best classifier.
Parameters :
X: array, [n_samples, n_features]
    Training vector, where n_samples is the number of samples and n_features is the number of features.
y: array, [n_samples]
    Target vector relative to X.
cv: cross-validation generator
    See the scikits.learn.cross_val module.
refit: boolean
    Refit the best estimator with the entire dataset.
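The roles of cv and refit can be illustrated with a small library-free sketch. Everything here is hypothetical: `ShrunkMean` is a toy regressor, `leave_one_out` mimics a cross-validation generator that yields (train, test) index lists, and `fit_grid` is a simplified stand-in for the fit method. The leave-one-out fallback when cv is None is an assumption of this sketch, not a statement about the library's default.

```python
class ShrunkMean:
    """Toy regressor: predicts the training mean scaled by `shrink`."""

    def __init__(self, shrink=1.0):
        self.shrink = shrink

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.shrink * self.mean_ for _ in X]


def leave_one_out(n):
    """Yield (train_indices, test_indices), holding out one sample each time."""
    for i in range(n):
        yield [j for j in range(n) if j != i], [i]


def fit_grid(X, y, shrink_values, cv=None, refit=True):
    # Materialize the folds so they can be reused for every grid point.
    folds = list(cv) if cv is not None else list(leave_one_out(len(y)))
    best_shrink, best_loss = None, None
    for shrink in shrink_values:
        loss = 0.0
        for train, test in folds:
            model = ShrunkMean(shrink)
            model.fit([X[i] for i in train], [y[i] for i in train])
            preds = model.predict([X[i] for i in test])
            # Squared error on the held-out samples (smaller is better).
            loss += sum((p - y[i]) ** 2 for p, i in zip(preds, test))
        if best_loss is None or loss < best_loss:
            best_shrink, best_loss = shrink, loss
    best = ShrunkMean(best_shrink)
    if refit:
        best.fit(X, y)  # retrain the winning configuration on all the data
    return best


X = [[0], [1], [2], [3]]
y = [2.0, 2.0, 2.0, 2.0]
best = fit_grid(X, y, shrink_values=[0.5, 1.0, 1.5])
```

With refit=False the returned object in this sketch carries only the selected parameter and has not been trained on the full data, so it cannot predict until it is fitted.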
Examples
See Parameter estimation using grid search with a nested cross-validation for an example of Grid Search computation on the digits dataset.