kNN

El método nearest neighbours es útil en clasificación

Hay información más detallada en el apartado de generalidades

kNN es un algoritmo sencillo, pero poco sofisticado - \(k\) hace referencia al número de “vecinos” que se utiliza.

Básico
  • el modelo se construye durante la fase de predicción
  • menos sensible al ruido
  • sensible a outliers
  • peor rendimiento en espacios multidimensionales

La clasificación se realiza midiendo la distancia entre observaciones no categorizadas e infieriendo sus clases por sus vecinos más cercanos

La función knn entrena el modelo y hace las predicciones en un mismo tiempo

knn(train,    # el modelo de entrenamiento
    test,     # los datos de prueba
    cl,       # las etiquetas de 'train'
    k = 1)    # el número de 'vecinos' que se consideran

k-NN coloca cada elemento utilizando las coordenadas de las características a analizar, y mide la distancia Euclidea entre los distintos elementos. Para los elementos \(p\) y \(q\) con características \(p_{i}\) y \(q_{i}\):

\[ dist(p, q) = \sqrt{(p_{1} - q_{1})^2 + (p_{2} - q_{2})^2 + ... + (p_{n} - q_{n})^2} \]

El valor de k

Si utilizamos \(k = 1\), el elemento se clasifica respecto a su vecino más cercano; sin embargo, si utilizamos \(k = 3\), considera los tres vecinos más cercanos y los clasifica según el que sea más numeroso (o el más próximo en caso de que sean distintos).

Valores de \(k\) muy altos tiende a ignorar patrones pequeños - hay un seso de clasificar todo como el elemento más numeroso. Hay estrategias que ponderan el peso de los vecinos en función de otros parámetros, como su distancia.

Valores de \(k\) muy pequeños son más sensibles al ruido, y hacen más fácil una clasificación incorrecta.

Valor de k

Una recomendación habitual es empezar con la raíz cuadrada del número de elementos de entrenamiento, \(\sqrt{n}\).

Preparar los datos

Lo primero que hay que hacer es eliminar las variables de identificación, que no son relevantes para el análisis, y separar las variables que tienen que ver con el diagnóstico (aunque esto debe hacerse una vez separadas las muestras de entrenamiento y prueba)

Los datos deben estar normalizados para que pueda generarse un modelo útil; por ejemplo, si una variable va del 1 al 10 y otra del 1 a 1.000, esta última siempre tendrá más peso al calcular las distancias.

Estrategias de normalización

Min-max

Convierte los valores de un rango en un número entre el 0 y el 1.

\[ X_{new} = \frac{X - X_{min}}{X_{max} - X_{min}} \]

Es menos robusta porque las muestras de prueba sobre las que se aplica el modelo pueden tener valores fuera del rango que hemos utilizado durante el entrenamiento.

Puntación escalar o z-score

Es un método más robusto porque asume que los nuevos datos proceden de la misma distribución normal, y por tanto comparte las mismas características.

\[ X_{new} = \frac{X - \mu}{\sigma} \]

Variables categóricas

Se convierten en \(0\) y \(1\); en caso de que sea una variable con varias categorías, cada una de ellas se convierte en una variable binomial. La versión denominada dummy coding genera menos problemas en modelos lineales, aunque one-hot encoding suele usarse en machine learning.

Dummy coding Templado Caliente
Frío 0 0
Templado 1 0
Caliente 0 1
One-hot encoding Frío Templado Caliente
Frío 1 0 0
Templado 0 1 0
Caliente 0 0 1
Lazy learning

Los algoritmos de clasificación se consideran “lazy” porque no existe abstracción ni generalización, y por tanto no pueden considerarse aprendizaje.

El modelo almacena los datos de entrenamiento y después los compara con los datos nuevos. No se construye ningún modelo, no se parametriza, y por tanto se considera machine learning no paramétrico

Datos de entrenamiento

Hay que generar una base de datos de entrenamiento y otra de prueba; lo mejor es separar el data frame utilizando algunas columnas. Para generar los dato de entrenamiento se debería elegir una muestra aleatoria de la muestras original.

En caso de que la muestra ya estuviese aleatorizada, bastaría con seleccionar unporcentaje de casos:

training <- original[1:75]
testing <- original[76:100]

En los casos en los que esto no es así, hay otras estrategias como la que ofrece sample.split del paquete caTools.

Acepta un vector factorizado y un porcentaje, y devuelve un vector Booleano con una selección de casos que resepta la proporción original de la muestra.

Por ejemplo, dado una base de datos disease con una variable $diagnosis sobre la cuál queremos clasificar los grupos, haríamos:

index <- caTools::sample.split(disease$diagnosis, 0.75)
training_data <- disease[index,]    # 75% de la muestra
test_data <- disease[!index,]       # 25% de la muestra

Entrenamiento del modelo

Hay distintos paquetes que incluyen kNN, pero en este caso utilizaremos el paquete class.

kNN - sintaxis

training_data y testing_datano deberían contener la variable que contiene la clasificación.

training_data_classes es la variable de clasificación de los datos de entrenamiento (es bueno guardar las dos para poder comparar)

library(class)

training_data_classes <- training_data$id
training_data <- training_data[,-1]
testing_data_classes <- testing_data$id
testing_data <- testing_data[,-1]

modelo <- knn(train = training_data,
              test = testing_data,
              clstu = training_data_classes,
              k = número de vecinos)

class hace referencia al vector factorizado en los que se pretenden clasificar los datos.

Para calcular el valor inicial de k podemos utilizar \(k = \sqrt{n}\).

Ejemplo

Sin estandarización

library(class)      # kNN
library(gmodels)    # CrossTable
library(caTools)    # sample.split aleatorio

original_data <- read.csv("data/wisc_bc_data.csv")

# Eliminamos id o variables irrelevantes
original_data <- original_data[-1]

# Generamos training y testing data
i <- caTools::sample.split(original_data$diagnosis, 0.75)
trn_data <- original_data[i,]
tst_data <- original_data[-i,]

# Eliminamos las clases
trn_labels <- trn_data$diagnosis
trn_data <- trn_data[-1]

tst_labels <- tst_data$diagnosis
tst_data <- tst_data[-1]

# Aproximamos un valor útil de k 
k_value = round(sqrt(length(original_data[,1])))

# Generamos la prediccion
tst_predictions <- class::knn(
    train = trn_data, 
    test = tst_data,
    cl = trn_labels,
    k = k_value
)

# Y evaluamos los resultados del modelo

gmodels::CrossTable(
    x = tst_labels,
    y = tst_predictions,
    prop.chisq = FALSE)

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       348 |         8 |       356 | 
             |     0.978 |     0.022 |     0.627 | 
             |     0.911 |     0.043 |           | 
             |     0.613 |     0.014 |           | 
-------------|-----------|-----------|-----------|
           M |        34 |       178 |       212 | 
             |     0.160 |     0.840 |     0.373 | 
             |     0.089 |     0.957 |           | 
             |     0.060 |     0.313 |           | 
-------------|-----------|-----------|-----------|
Column Total |       382 |       186 |       568 | 
             |     0.673 |     0.327 |           | 
-------------|-----------|-----------|-----------|

 

K-value en este caso era 24.

Con estandarización

Utilizamos scale() para devolver las puntuaciones escalares; lo ideal es hacer con el con junto inicial de datos pero tenemos que eliminar aquellas variables que no sean numéricas:

library(class)      # kNN
library(gmodels)    # CrossTable
library(caTools)    # sample.split aleatorio

original_data <- read.csv("data/wisc_bc_data.csv")

# Eliminamos id o variables irrelevantes
original_data <- original_data[-1]

# Extraemos las etiquetas...
original_labels <- original_data$diagnosis

# ... y las eliminamos también
original_data <- original_data[-1]

# Escalamos las variables
original_data <- scale(original_data)

# Generamos training y testing data
i <- caTools::sample.split(original_labels, 0.75)
trn_data <- original_data[i,]
trn_labels <- original_labels[i]
tst_data <- original_data[-i,]
tst_labels <- original_labels[-i]

# Aproximamos un valor útil de k 
k_value = round(sqrt(length(original_data[,1])))

# Generamos la prediccion
tst_predictions <- class::knn(
    train = trn_data, 
    test = tst_data,
    cl = trn_labels,
    k = k_value
)

# Y evaluamos los resultados del modelo

CrossTable(
    x = tst_labels,
    y = tst_predictions,
    prop.chisq = FALSE)

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       355 |         1 |       356 | 
             |     0.997 |     0.003 |     0.627 | 
             |     0.932 |     0.005 |           | 
             |     0.625 |     0.002 |           | 
-------------|-----------|-----------|-----------|
           M |        26 |       186 |       212 | 
             |     0.123 |     0.877 |     0.373 | 
             |     0.068 |     0.995 |           | 
             |     0.046 |     0.327 |           | 
-------------|-----------|-----------|-----------|
Column Total |       381 |       187 |       568 | 
             |     0.671 |     0.329 |           | 
-------------|-----------|-----------|-----------|

 

Buscar k-values alternativos

k-values

Siempre hay que redondear el k-value, tiene que ser un número entero de forma explícita.

kvals <- c(1, 5, 11, 15, 21, 27)
for (kval in kvals) {
   tst_predictions <- knn(
        train = trn_data,
        test = tst_data,
        cl = trn_labels,
        k = round(kval))

    CrossTable(
        x = tst_labels,
        y = tst_predictions,
        prop.chisq = FALSE)
}

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       354 |         2 |       356 | 
             |     0.994 |     0.006 |     0.627 | 
             |     0.981 |     0.010 |           | 
             |     0.623 |     0.004 |           | 
-------------|-----------|-----------|-----------|
           M |         7 |       205 |       212 | 
             |     0.033 |     0.967 |     0.373 | 
             |     0.019 |     0.990 |           | 
             |     0.012 |     0.361 |           | 
-------------|-----------|-----------|-----------|
Column Total |       361 |       207 |       568 | 
             |     0.636 |     0.364 |           | 
-------------|-----------|-----------|-----------|

 

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       353 |         3 |       356 | 
             |     0.992 |     0.008 |     0.627 | 
             |     0.957 |     0.015 |           | 
             |     0.621 |     0.005 |           | 
-------------|-----------|-----------|-----------|
           M |        16 |       196 |       212 | 
             |     0.075 |     0.925 |     0.373 | 
             |     0.043 |     0.985 |           | 
             |     0.028 |     0.345 |           | 
-------------|-----------|-----------|-----------|
Column Total |       369 |       199 |       568 | 
             |     0.650 |     0.350 |           | 
-------------|-----------|-----------|-----------|

 

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       353 |         3 |       356 | 
             |     0.992 |     0.008 |     0.627 | 
             |     0.951 |     0.015 |           | 
             |     0.621 |     0.005 |           | 
-------------|-----------|-----------|-----------|
           M |        18 |       194 |       212 | 
             |     0.085 |     0.915 |     0.373 | 
             |     0.049 |     0.985 |           | 
             |     0.032 |     0.342 |           | 
-------------|-----------|-----------|-----------|
Column Total |       371 |       197 |       568 | 
             |     0.653 |     0.347 |           | 
-------------|-----------|-----------|-----------|

 

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       354 |         2 |       356 | 
             |     0.994 |     0.006 |     0.627 | 
             |     0.941 |     0.010 |           | 
             |     0.623 |     0.004 |           | 
-------------|-----------|-----------|-----------|
           M |        22 |       190 |       212 | 
             |     0.104 |     0.896 |     0.373 | 
             |     0.059 |     0.990 |           | 
             |     0.039 |     0.335 |           | 
-------------|-----------|-----------|-----------|
Column Total |       376 |       192 |       568 | 
             |     0.662 |     0.338 |           | 
-------------|-----------|-----------|-----------|

 

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       355 |         1 |       356 | 
             |     0.997 |     0.003 |     0.627 | 
             |     0.932 |     0.005 |           | 
             |     0.625 |     0.002 |           | 
-------------|-----------|-----------|-----------|
           M |        26 |       186 |       212 | 
             |     0.123 |     0.877 |     0.373 | 
             |     0.068 |     0.995 |           | 
             |     0.046 |     0.327 |           | 
-------------|-----------|-----------|-----------|
Column Total |       381 |       187 |       568 | 
             |     0.671 |     0.329 |           | 
-------------|-----------|-----------|-----------|

 

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|

 
Total Observations in Table:  568 

 
             | tst_predictions 
  tst_labels |         B |         M | Row Total | 
-------------|-----------|-----------|-----------|
           B |       355 |         1 |       356 | 
             |     0.997 |     0.003 |     0.627 | 
             |     0.934 |     0.005 |           | 
             |     0.625 |     0.002 |           | 
-------------|-----------|-----------|-----------|
           M |        25 |       187 |       212 | 
             |     0.118 |     0.882 |     0.373 | 
             |     0.066 |     0.995 |           | 
             |     0.044 |     0.329 |           | 
-------------|-----------|-----------|-----------|
Column Total |       380 |       188 |       568 | 
             |     0.669 |     0.331 |           | 
-------------|-----------|-----------|-----------|