Large margin nearest neighbor (LMNN) classification is a statistical machine learning algorithm for metric learning. It learns a pseudometric designed for k-nearest neighbor classification. The algorithm is based on semidefinite programming, a sub-class of convex optimization.
The goal of supervised learning (more specifically classification) is to learn a decision rule that can categorize data instances into pre-defined classes. The k-nearest neighbor rule assumes a training data set of labeled instances (i.e. the classes are known). It classifies a new data instance with the class obtained from the majority vote of the k closest (labeled) training instances. Closeness is measured with a pre-defined metric. Large margin nearest neighbors is an algorithm that learns this global (pseudo-)metric in a supervised fashion to improve the classification accuracy of the k-nearest neighbor rule.
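To make the voting rule concrete, here is a minimal sketch in Python; it assumes the pre-defined metric is the Euclidean distance, and the function name and defaults are illustrative rather than taken from any reference implementation.

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_new, k=3):
    """Classify x_new by majority vote among its k closest training instances."""
    # Distance from x_new to every labeled training instance (Euclidean here)
    dists = np.linalg.norm(X_train - x_new, axis=1)
    # Indices of the k nearest labeled instances
    nearest = np.argsort(dists)[:k]
    # Majority vote over their known class labels
    return Counter(y_train[nearest]).most_common(1)[0][0]
```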
Setup
The main intuition behind LMNN is to learn a pseudometric under which all data instances in the training set are surrounded by at least $k$ instances that share the same class label. If this is achieved, the leave-one-out error (a special case of cross validation) is minimized. Let the training data consist of a data set $D=\{(\vec x_1,y_1),\dots,(\vec x_n,y_n)\}\subset R^d\times C$, where the set of possible class categories is $C=\{1,\dots,c\}$.
The algorithm learns a pseudometric of the type

$$d(\vec x_i,\vec x_j)=(\vec x_i-\vec x_j)^{\top}\mathbf{M}(\vec x_i-\vec x_j).$$

For $d(\cdot,\cdot)$ to be well defined, the matrix $\mathbf{M}$ needs to be positive semi-definite. The Euclidean metric is a special case, where $\mathbf{M}$ is the identity matrix. This generalization is often (falsely) referred to as the Mahalanobis metric.
Figure 1 illustrates the effect of the metric under varying $\mathbf{M}$: the set of points at equal distance from a center is a circle under the Euclidean metric, whereas under a general learned metric it becomes an ellipsoid.
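As a sketch of how such a pseudometric can be evaluated (the factorization trick is standard, but the names here are illustrative), writing $\mathbf{M}=\mathbf{L}^{\top}\mathbf{L}$ for an arbitrary real matrix $\mathbf{L}$ makes positive semi-definiteness hold by construction:

```python
import numpy as np

def pseudometric_sq(x_i, x_j, M):
    """d(x_i, x_j) = (x_i - x_j)^T M (x_i - x_j)."""
    diff = x_i - x_j
    return diff @ M @ diff

# M = L^T L is positive semi-definite for any real matrix L,
# so d is a valid (squared) pseudometric; M = I recovers the
# squared Euclidean distance.
L = np.array([[2.0, 0.0], [0.5, 1.0]])
M = L.T @ L
print(pseudometric_sq(np.array([1.0, 0.0]), np.array([0.0, 1.0]), M))
```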
The algorithm distinguishes between two types of special data points: target neighbors and impostors.
Target neighbors
Target neighbors are selected before learning. Each instance $\vec x_i$ has exactly $k$ different target neighbors within $D$, which all share the same class label $y_i$. The target neighbors are the data points that should become nearest neighbors under the learned metric. Let us denote the set of target neighbors for a data point $\vec x_i$ as $N_i$.
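The selection rule is left open above; a common choice (assumed in this sketch) is to pick each instance's $k$ nearest same-class points under the Euclidean metric:

```python
import numpy as np

def select_target_neighbors(X, y, k=3):
    """For each x_i, return the indices of its k nearest neighbors that
    share the class label y_i (Euclidean distance, chosen before learning).
    Assumes every class contains more than k instances."""
    n = len(X)
    targets = np.zeros((n, k), dtype=int)
    for i in range(n):
        same = np.flatnonzero(y == y[i])
        same = same[same != i]                       # a point is not its own neighbor
        dists = np.linalg.norm(X[same] - X[i], axis=1)
        targets[i] = same[np.argsort(dists)[:k]]     # k closest same-class points
    return targets
```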
Impostors
An impostor of a data point $\vec x_i$ is another data point $\vec x_j$ with a different class label (i.e. $y_j\neq y_i$) which is one of the nearest neighbors of $\vec x_i$. During learning, the algorithm tries to minimize the number of impostors for all data instances in the training set.
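Under a current metric $\mathbf{M}$ and a fixed set of target neighbors, the impostors of $\vec x_i$ can be listed as below; treating "one of the nearest neighbors" as "closer than the farthest target neighbor" is a simplification for this sketch (the exact margin condition appears in the optimization problem below).

```python
import numpy as np

def impostors_of(i, X, y, M, target_idx):
    """Differently labeled points that come closer to x_i (under M)
    than its farthest target neighbor."""
    def d(a, b):
        diff = a - b
        return diff @ M @ diff
    radius = max(d(X[i], X[j]) for j in target_idx)  # perimeter set by target neighbors
    others = np.flatnonzero(y != y[i])               # candidates with a different label
    return [l for l in others if d(X[i], X[l]) < radius]
```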
Algorithm
Large margin nearest neighbors optimizes the matrix $\mathbf{M}$ with the help of semidefinite programming. The objective is twofold: for every data point $\vec x_i$, the target neighbors should be close and the impostors should be far away.
The first optimization goal is achieved by minimizing the average distance between instances and their target neighbors:

$$\sum_{i,\,j\in N_i} d(\vec x_i,\vec x_j).$$
The second goal is achieved by constraining impostors $\vec x_l$ to be one unit further away than target neighbors $\vec x_j$ (and therefore pushing them out of the local neighborhood of $\vec x_i$):

$$d(\vec x_i,\vec x_j)+1\le d(\vec x_i,\vec x_l)\quad\forall i,\; j\in N_i,\; l \text{ with } y_l\ne y_i.$$
The margin of exactly one unit fixes the scale of the matrix $\mathbf{M}$: any alternative choice $c>0$ would result in a rescaling of $\mathbf{M}$ by a factor of $1/c$.
The final optimization problem becomes:

$$\min_{\mathbf{M},\,\xi}\;\sum_{i,\,j\in N_i} d(\vec x_i,\vec x_j)\;+\;\lambda\sum_{i,j,l}\xi_{ijl}$$

subject to, for all $i$, $j\in N_i$, and $l$ with $y_l\ne y_i$:

$$d(\vec x_i,\vec x_j)+1-\xi_{ijl}\le d(\vec x_i,\vec x_l)$$
$$\xi_{ijl}\ge 0$$
$$\mathbf{M}\succeq 0$$
Here the slack variables $\xi_{ijl}$ absorb the amount of violation of the impostor constraints, and their overall sum is minimized; the hyperparameter $\lambda>0$ trades off this term against the first objective. The last constraint ensures that $\mathbf{M}$ is positive semi-definite.
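As a concrete (non-authoritative) illustration of the pull/push structure, the sketch below evaluates the objective for a candidate metric, parameterized as $\mathbf{M}=\mathbf{L}^{\top}\mathbf{L}$ so that the last constraint holds automatically; each hinge term takes the value of the corresponding slack variable $\xi_{ijl}$, and the weight `lam` is illustrative.

```python
import numpy as np

def lmnn_objective(L, X, y, targets, lam=0.5):
    """Pull term + lam * push term of the LMNN program, with M = L^T L.
    targets[i] lists the target neighbors of x_i."""
    M = L.T @ L
    def d(a, b):
        diff = a - b
        return diff @ M @ diff
    pull, push = 0.0, 0.0
    for i in range(len(X)):
        different = np.flatnonzero(y != y[i])        # potential impostors
        for j in targets[i]:
            pull += d(X[i], X[j])                    # keep target neighbors close
            for l in different:
                # Hinge = value of the slack variable xi_{ijl}:
                # positive only if x_l violates the unit margin.
                push += max(0.0, d(X[i], X[j]) + 1.0 - d(X[i], X[l]))
    return pull + lam * push
```

With this parameterization the problem is no longer convex in $\mathbf{L}$, so gradient-based minimization only finds a local optimum; the formulation above instead optimizes $\mathbf{M}$ directly, which keeps the problem a semidefinite program.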
Extensions and efficient solvers
LMNN was extended to multiple local metrics in a 2008 follow-up paper. This extension significantly improves the classification error, but involves a more expensive optimization problem. In their 2009 publication in the Journal of Machine Learning Research, Weinberger and Saul derive an efficient solver for the semidefinite program. It can learn a metric for the MNIST handwritten digit data set in several hours, involving billions of pairwise constraints. An open-source MATLAB implementation is freely available at the authors' web page.
Kumar et al. extended the algorithm to incorporate local invariances to multivariate polynomial transformations and improved regularization.