Warning: This documentation is for scikits.learn version 0.8.

scikits.learn.neighbors.NeighborsClassifier

class scikits.learn.neighbors.NeighborsClassifier(n_neighbors=5, algorithm='auto', window_size=1)

Classifier implementing the k-nearest neighbors algorithm.

Parameters :

n_neighbors : int, optional

Default number of neighbors. Defaults to 5.

window_size : int, optional

Window size passed to BallTree. Defaults to 1.

algorithm : {‘auto’, ‘ball_tree’, ‘brute’}, optional

Algorithm used to compute the nearest neighbors. ‘ball_tree’ will construct a BallTree, while ‘brute’ will perform a brute-force search. ‘auto’ will guess the most appropriate algorithm based on the current dataset.
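For intuition about the ‘brute’ option, the sketch below performs a brute-force search in plain Python: it measures the distance from the query to every training sample and keeps the k closest. It assumes Euclidean distance and breaks ties by sample index; the function brute_force_kneighbors is hypothetical and not part of scikits.learn.

```python
import math


def brute_force_kneighbors(samples, point, k):
    """Return (distances, indices) of the k samples closest to point.

    Hypothetical helper: compares the query against every training
    sample, assuming Euclidean distance (requires Python 3.8+ for math.dist).
    """
    # Pair each sample's distance to the query with the sample's index.
    ranked = [(math.dist(s, point), i) for i, s in enumerate(samples)]
    # Sort by distance (equal distances fall back to the lower index).
    ranked.sort()
    nearest = ranked[:k]
    return [d for d, _ in nearest], [i for _, i in nearest]


samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
print(brute_force_kneighbors(samples, [1., 1., 1.], 1))  # ([0.5], [2])
```

Every query costs one distance computation per training sample; a tree structure such as BallTree is the alternative offered for avoiding that full scan.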

See also

BallTree

References

http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm

Examples

>>> samples = [[0, 0, 1], [1, 0, 0]]
>>> labels = [0, 1]
>>> from scikits.learn.neighbors import NeighborsClassifier
>>> neigh = NeighborsClassifier(n_neighbors=1)
>>> neigh.fit(samples, labels)
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
>>> print neigh.predict([[0,0,0]])
[1]

Methods

__init__(n_neighbors=5, algorithm='auto', window_size=1)
fit(X, y, **params)

Fit the model using X as training data and y as target values.

Parameters :

X : array-like, shape = [n_samples, n_features]

Training data.

y : array-like, shape = [n_samples]

Target values, array of integer values.

params : keyword arguments, optional

Override the keyword arguments passed to __init__.
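A minimal sketch of how keyword arguments passed to fit could override the constructor values, written against a hypothetical class SketchNeighborsClassifier (not the scikits.learn implementation). In this sketch fitting only stores the training data; the real class may additionally build a BallTree.

```python
class SketchNeighborsClassifier:
    """Hypothetical stand-in illustrating fit(X, y, **params) only."""

    _param_names = ('n_neighbors', 'algorithm', 'window_size')

    def __init__(self, n_neighbors=5, algorithm='auto', window_size=1):
        self.n_neighbors = n_neighbors
        self.algorithm = algorithm
        self.window_size = window_size

    def fit(self, X, y, **params):
        # Keyword arguments given to fit override the values set in __init__.
        for name, value in params.items():
            if name not in self._param_names:
                raise ValueError("unknown parameter: %s" % name)
            setattr(self, name, value)
        # In this sketch, fitting just keeps a copy of the training data.
        self._X = [list(row) for row in X]
        self._y = list(y)
        # Returning self allows chaining, e.g. clf.fit(X, y).n_neighbors.
        return self


clf = SketchNeighborsClassifier(n_neighbors=5)
clf.fit([[0, 0], [1, 1]], [0, 1], n_neighbors=1)
print(clf.n_neighbors)  # 1
```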

kneighbors(X, return_distance=True, **params)

Finds the k nearest neighbors of a point.

Returns the indices of the neighbors and, if return_distance=True, the distances to them.

Parameters :

X : array-like

The query point, or a list of query points.

n_neighbors : int

Number of neighbors to get (default is the value passed to the constructor).

return_distance : boolean, optional. Defaults to True.

If False, distances will not be returned.

Returns :

dist : array

Distances from the query point to each of its neighbors; only returned if return_distance=True.

ind : array

Indices of the nearest points in the population matrix.

Examples

In the following example, we fit a NeighborsClassifier on an array representing our data set and ask which point is closest to [1,1,1]:

>>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
>>> labels = [0, 0, 1]
>>> from scikits.learn.neighbors import NeighborsClassifier
>>> neigh = NeighborsClassifier(n_neighbors=1)
>>> neigh.fit(samples, labels)
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
>>> print neigh.kneighbors([1., 1., 1.]) 
(array([[ 0.5]]), array([[2]]...))

As you can see, it returns [[0.5]] and [[2]], which means that the nearest element is at distance 0.5 and is the third element of samples (indices start at 0). You can also query for multiple points:

>>> X = [[0., 1., 0.], [1., 0., 1.]]
>>> neigh.kneighbors(X, return_distance=False) 
array([[1],
       [2]]...)
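
The two queries above can be reproduced by hand with a hypothetical brute-force kneighbors function, assuming Euclidean distance. It is a sketch for checking the numbers, not the library's code, and it returns plain nested lists instead of arrays.

```python
import math


def kneighbors(samples, X, n_neighbors=1, return_distance=True):
    """Hypothetical brute-force sketch of kneighbors (Euclidean distance).

    X may be a single point or a list of points; one row of results is
    produced per query point, mirroring the 2-D output shown above.
    """
    # Treat a single point as a one-row batch of queries.
    if X and not isinstance(X[0], (list, tuple)):
        X = [X]
    dist, ind = [], []
    for point in X:
        ranked = sorted((math.dist(s, point), i) for i, s in enumerate(samples))
        nearest = ranked[:n_neighbors]
        dist.append([d for d, _ in nearest])
        ind.append([i for _, i in nearest])
    return (dist, ind) if return_distance else ind


samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
print(kneighbors(samples, [1., 1., 1.]))  # ([[0.5]], [[2]])
print(kneighbors(samples, [[0., 1., 0.], [1., 0., 1.]], return_distance=False))  # [[1], [2]]
```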

predict(X, **params)

Predict the class labels for the provided data.

Parameters :

X : array-like, shape = [n_samples, n_features]

A 2-D array of test points, one per row.

n_neighbors : int

Number of neighbors to get (default is the value passed to the constructor).

Returns :

labels : array, shape = [n_samples]

Predicted class labels, one for each sample in X.
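Prediction in k-nearest neighbors classification is typically a majority vote among the labels of the nearest training samples. The hypothetical predict function below sketches that idea with Euclidean distance and, on tied votes, picks the label whose nearest occurrence is closest; how scikits.learn itself breaks ties is not specified on this page.

```python
import math
from collections import Counter


def predict(samples, labels, X, n_neighbors=5):
    """Hypothetical sketch: label each row of X by a majority vote of its
    n_neighbors nearest training samples (Euclidean distance)."""
    predictions = []
    for point in X:
        # Training indices ordered from nearest to farthest.
        ranked = sorted((math.dist(s, point), i) for i, s in enumerate(samples))
        votes = Counter(labels[i] for _, i in ranked[:n_neighbors])
        # most_common keeps first-encountered order among equal counts, so a
        # tie goes to the label whose nearest occurrence is closest.
        predictions.append(votes.most_common(1)[0][0])
    return predictions


samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
labels = [0, 0, 1]
print(predict(samples, labels, [[0., 1., 0.], [1., 1., 1.]], n_neighbors=1))  # [0, 1]
```

With n_neighbors=1 each query simply takes the label of its single nearest sample; larger values let more neighbors vote.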

score(X, y)

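Returns the mean accuracy on the given test data and labels, i.e. the fraction of samples in X whose predicted label equals the corresponding entry of y.

Parameters :

X : array-like, shape = [n_samples, n_features]

Test samples.

y : array-like, shape = [n_samples]

Labels for X.

Returns :

z : float

Mean accuracy of self.predict(X) with respect to y.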
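Both the fraction of correctly classified samples and the matching error rate (one minus that fraction) can be computed from a list of predictions with a helper like the one below. The name mean_accuracy is hypothetical; it takes predictions directly rather than calling predict.

```python
def mean_accuracy(y_pred, y_true):
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must have the same length")
    matches = sum(1 for p, t in zip(y_pred, y_true) if p == t)
    return matches / len(y_true)


acc = mean_accuracy([0, 1, 1, 0], [0, 1, 0, 0])
print(acc)      # 0.75
print(1 - acc)  # 0.25, the corresponding error rate
```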