MNIST – scikit-learn – code d'A. Geron

MNIST est le Hello World du Machine Learning.
Le présent article commente le code du §3 du livre d’Aurélien Geron.
Rappel de ce que nous avons écrit à son sujet.

Le livre Hands-On Machine Learning with Scikit-Learn and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems d’Aurélien Geron est très très bon. Il est complet, profond, bien écrit, avec de nombreux exemples de code. C’est un livre incontournable sur le sujet. Il n’y a pas trop de maths et ce qu’il y a n’est pas très compliqué. Certes, le livre demande des efforts de lecture mais l’apprentissage est progressif.

Le code, disponibile ici, décrit comment effectuer une classification sur MNIST avec seulement scikit-Learn.
Pour les commentaires, nous nous appuierons aussi sur Recognizing hand-written digits et Making your First Machine Learning Classifier in Scikit-learn.

Lecture des données

Toutes les API ML/DL proposent la lecture des données MNIST.

from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
mnist

Une autre façon de lire les données, préférable à notre avis, est d’écrire :

from sklearn.datasets import load_digits
digits = load_digits()

Helper functions

Il faut écrire quelques fonctions pour afficher une ou plusieurs images. C’est le cas de plot_digit, plot_digits. Ces fonctions n’ont par vocation à être commentées.

Répartition des données

Les données (mnist) contiennent à la fois les images (codage niveaux de gris) et les labels (les étiquettes : 0..9).
Par exemple :

some_digit = X[36000]
some_digit
array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,  86, 131, 225, 225, 225,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  13,  73, 197, 253, 252, 252, 252, 252,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         4,  29,  29, 154, 187, 252, 252, 253, 252, 252, 233, 145,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  29, 252, 253, 252, 252, 252, 252, 253, 204, 112,  37,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0, 169, 253, 255, 253, 228, 126,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,  98, 243, 252, 253, 252, 246, 130,  38,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,  98, 240, 252, 252, 253, 252, 252,
       252, 221,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0, 225, 252, 252, 236, 225,
       223, 230, 252, 252,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 146, 252, 157,
        50,   0,   0,  25, 205, 252,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,  26, 207, 253,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,  29,  19,   0,   0,
         0,   0,   0,   0,   0,   0,   0,  73, 205, 252,  79,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 120, 215, 209,
       175,   0,   0,   0,   0,   0,   0,   0,  19, 209, 252, 220,  79,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 174,
       252, 252, 239, 140,   0,   0,   0,   0,   0,  29, 104, 252, 249,
       177,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0, 174, 252, 252, 223,   0,   0,   0,   0,   0,   0, 174, 252,
       252, 223,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0, 141, 241, 253, 146,   0,   0,   0,   0, 169, 253,
       255, 253, 253,  84,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0, 178, 252, 154,  85,  85, 210, 225,
       243, 252, 215, 121,  27,   9,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,  66, 208, 220, 252, 253,
       252, 252, 214, 195,  31,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  19,  37,
        84, 146, 223, 114,  28,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0], dtype=uint8)

alors que :

my_label=y[36000]
my_label
5.0

On répartit les données en deux jeux (training set, test set), qu’on mélange ensuite :

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

Classification binaire

Pour un apprentissage progressif, l’auteur propose de commencer par une classification binaire. C’est un chiffre 5 ou pas.
Pour cela, on ne retient que les données pour lesquelles le label est 5.

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

On a X_train.shape : (60000, 784) et y_train_5.shape  (60000, ).
Par définition, y_train_5 est un tableau de booléens dont les valeurs sont vraies lorsque l’image est étiquetée à 5.
SGDClassifier est un classificateur linéaire (SVM, logistic regression, …) avec SGD training.

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

On utilise random_state pour qu’à chaque fois on obtienne les mêmes valeurs (aléatoires).
fit a ici 2 paramètres :

X : {array-like, sparse matrix}, shape (n_samples, n_features)

Training data

y : numpy array, shape (n_samples,)

Target values

 
Le test sur some_digit qui est un 5 retourne vrai, ce qui est de bon aloi.

sgd_clf.predict([some_digit])

ça y est ! on a notre logiciel (binaire). Mais que vaut-il ?

Mesure de la performance

Validation croisée

cross-validation (CV for short) : A test set should still be held out for final evaluation, but the validation set is no longer needed when doing CV. In the basic approach, called k-fold CV, the training set is split into k smaller sets (other approaches are described below, but generally follow the same principles). The following procedure is followed for each of the k “folds”:

  • A model is trained using k−1 of the folds as training data;
  • the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute a performance measure such as accuracy).

The performance measure reported by k-fold cross-validation is then the average of the values computed in the loop. This approach can be computationally expensive, but does not waste too much data (as is the case when fixing an arbitrary validation set), which is a major advantage in problems such as inverse inference where the number of samples is very small.

Le code est simple :

from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")

on obtient les résultats suivants :

array([0.9502 , 0.96565, 0.96495])

Un taux d’exactitude (accuracy) supérieur à 95% est le minimum qu’on puisse attendre, car sachant qu’il y a 1/10 de chiffres 5, en répondant toujours (pas un 5) on obtient 90% de bonnes réponses !

Matrice de confusion

On calcule les prédictions et on regarde la matrice de confusion.

from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
array([[53272,  1307],
       [ 1077,  4344]])

Que veut dire ce résultat ?
Il y a 53272 chiffres différents de 5 qui ont été correctement classés, 1307 ont été mal classés.
Il y 4344 chiffres 5 qui ont été bien classés et 1077 qui ont été mal classés.

précision

La précision permet de répondre à la question suivante :

  • Quelle proportion d’identifications positives était effectivement correcte ?

Pour rappel, la précision peut être définie comme suit :
[latex]\text{Précision} = \frac{VP}{VP+FP}[/latex] Comme nous avons :

Vrais Positifs (VP) : 4344 Faux positifs (FP) : 1307
Faux négatifs (FN) : 1077 Vrais négatifs (VN) : 53272

 
precision =   4344 / (4344 + 1307) = 0,7687

from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred)
0.7687135020350381

recall (rappel)

Le rappel permet de répondre à la question suivante :

  • Quelle proportion de résultats positifs réels a été identifiée correctement ?

Mathématiquement, le rappel est défini comme suit :

[latex]\text{Rappel} = \frac{VP}{VP+FN}[/latex] recall =   4344 / (4344 + 1077) = 0,8013

recall_score(y_train_5, y_train_pred)
0.801328168234643

Pour évaluer les performances d’un modèle de façon complète, vous devez analyser à la fois la précision et le rappel. Malheureusement, précision et rappel sont fréquemment en tension. Ceci est dû au fait que l’amélioration de la précision se fait généralement au détriment du rappel et réciproquement.

Le reste du code explique comment trouver le bon seuil en fonction de la précision ou du rappel souhaité. Cette partie du code n’est pas commentée.

roc

Une courbe ROC (receiver operating characteristic) est un graphique représentant les performances d’un modèle de classification pour tous les seuils de classification. Cette courbe trace le taux de vrais positifs en fonction du taux de faux positifs :

Le taux de vrais positifs (TVP) est l’équivalent du rappel. Il est donc défini comme suit :
[latex]TVP = \frac{VP} {VP + FN}[/latex] Le taux de faux positifs (TFP) est défini comme suit :
[latex]TFP = \frac{FP} {FP + VN}[/latex] La courbe roc se calcule simplement :

from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)

(pour le y_scores voir le code)
L’affichage de la courbe roc :

O, a l’habitude de calculer l’aire sour la courbe (AUC)

AUC signifie « aire sous la courbe ROC ». Cette valeur mesure l’intégralité de l’aire à deux dimensions située sous l’ensemble de la courbe ROC (par calculs d’intégrales) de (0,0) à (1,1).
L’AUC fournit une mesure agrégée des performances pour tous les seuils de classification possibles. On peut interpréter l’AUC comme une mesure de la probabilité pour que le modèle classe un exemple positif aléatoire au-dessus d’un exemple négatif aléatoire.

 

from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)
0.9624496555967155

Les valeurs d’AUC sont comprises dans une plage de 0 à 1. Un modèle dont 100 % des prédictions sont erronées a un AUC de 0,0. Si toutes ses prédictions sont correctes, son AUC est de 1,0.

Avec un(e) AUC de 0.96, doit-on conclure que la performance est bonne ? Pas forcément.
L’auteur fait ensuite un test en utilisant un autre classifier (RandomForestClassifier) qui obtient – pour cet exemple – de bien meilleures performances.

Multiclass classification

Maintenant qu’on s’est à peu près dire si c’est un 5 ou pas, on pourrait s’intéresser aux autres chiffres !
Il existe de nombreuses façons de régler le problème.
OvA (One versus All) et OvO (One vers One) sont présentés dans le livre.  Avec OvA on aura 10 classifiers alors que pour OvO, on en aura 45.

sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])

L’apprentissage est fait sur la totalité du jeu de données (hors le test).
10 classifiers binaires sont exécutés (OvA). Pour chaque digit testé, on retient la valeur maximale.

array([[-311402.62954431, -363517.28355739, -446449.5306454 ,
        -183226.61023518, -414337.15339485,  161855.74572176,
        -452576.39616343, -471957.14962573, -518542.33997148,
        -536774.63961222]])

Le reste du code n’est pas expliqué car il a peu d’utilité pour notre leçon.

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée.