Encore une fois, merci à Jeremy Howard, pour son merveilleux cours sur le Machine Learning.
Cet article doit tout (sauf les erreurs) à son cours n°2 : https://course.fast.ai/videos/?lesson=2
Nous développons ici, dans Google Colab, une application, qui permet de détecter automatiquement si un drapeau est un drapeau breton, un drapeau pirate ou un drapeau de la Juventus de Turin.
Collecte des données
Obtention des URLs
Tout d’abord, nous allons collecter nos images. Grâce à Google.
Le plus simple est d’utiliser Google Chrome comme éditeur. Mais ce qui suit peut aussi être effectué avec les autres navigateurs.
- Effectuer une requête dans Google Images (par exemple : drapeaux bretons)
- Descendre pour obtenir plus de résultats
- Ouvrir une console JavaScript CmdOptJ sur Mac
- Exécuter la commande suivante dans la console :
urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('\n')));
- Sauvegarder le fichier (c’est la liste des URLs des images présentées dans Google Images)
- Recommencez avec les drapeaux Pirates et ceux de la Juventus
Sauvegarde des données
Maintenant que nous avons les URLs de nos images, il faut les stocker dans l’environnement Google Colab.
- Tout d’abord, ouvrir l’environnement :
https://colab.research.google.com/notebooks/welcome.ipynb#recent=true
- créer les dossiers
!mkdir flags
!mkdir flags/bretons
!mkdir flags/pirates
!mkdir flags/juventus
- Copier les fichiers des URLS des drapeaux. On utilisera files.upload() qui permet d’importer les fichiers souhaités présents sur votre disque
from google.colab import files
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
- déplacer chaque fichier dans son dossier
!mv drapeaux_bretons.csv flags
!mv drapeaux_juventus.csv flags
!mv drapeaux_pirates.csv flags
Import des images
Recommencez trois fois en changeant les paramètres :
(évidemment l’import de fastai.vision ne se fait qu’une seule fois)
from fastai.vision import *
# On recommence ça 3 fois, avec chaque lot d'images
folder = 'bretons'
file = 'drapeaux_bretons.csv'
path = Path('flags')
dest = path/folder
download_images(path/file, dest, max_pics=200)
folder = 'pirates'
file = 'drapeaux_pirates.csv'
dest = path/folder
download_images(path/file, dest, max_pics=200)
folder = 'juventus'
file = 'drapeaux_juventus.csv'
dest = path/folder
download_images(path/file, dest, max_pics=200)
A la suite de ces lignes, vous devez avoir toutes vos images dans notre environnement.
Vérification des images
On supprime les images qu’on ne peut pas ouvrir
classes = ['bretons','juventus','pirates']
for c in classes:
print(c)
verify_images(path/c, delete=True, max_size=500)
Voir quelques images
On crée un jeu de données train et un pour la validation (20%).
Voir : https://docs.fast.ai/vision.data.html#ImageDataBunch
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)
data.show_batch(rows=3, figsize=(7,
Résumé du jeu de données
# Les classes, le nombre de classes, la taille du training set, la taille du validation set
data.classes, data.c, len(data.train_ds), len(data.valid_ds)
(['bretons', 'juventus', 'pirates'], 3, 430, 107)
Entraînement du modèle
Resnet
On utilise un CNN de type ResNet (34).
learn = create_cnn(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
learn.save('stage-1')
On obtient une erreur de 5%, ce qui n’est pas si mal.
epoch train_loss valid_loss error_rate
1 1.238096 0.631061 0.308411
2 0.781144 0.306836 0.121495
3 0.584778 0.205743 0.065421
4 0.456692 0.187582 0.056075
learning rate
On cherche le learning rate le plus approprié, en appliquant la méthode décrite ci-dessous :
The method
Deep Learning 2: Part 1 Lesson 1 – Hiromi Suenagalearn.lr_find()
helps you find an optimal learning rate. It uses the technique developed in the 2015 paper Cyclical Learning Rates for Training Neural Networks, where we simply keep increasing the learning rate from a very small value, until the loss stops decreasing. We can plot the learning rate across batches to see what this looks like.
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(2, max_lr=slice(3e-5,3e-4))
L’erreur chute à 4%
epoch train_loss valid_loss error_rate
1 0.212879 0.153322 0.037383
2 0.167536 0.135117 0.046729
learn.save('stage-2')
Interprétation
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
Nettoyage des données
Il ya des erreurs de classification car souvent il y a des données qui ne sont pas de qualité. Une intervention « humaine » est parfois nécessaire.
from fastai.widgets import *
ds, idxs = DatasetFormatter().from_toplosses(learn, ds_type=DatasetType.Valid)
L’idéal pour corriger les données, c’est d’utiliser les Widgets de fastai. Le problème, c’est qu’ils ne marchent pas avec Colab.
Matrice de confusion
learn.load('stage-2');
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

interp.plot_top_losses(25)
interp.most_confused()
from fastai.widgets import *
ds, idxs = interp.top_losses()
top_loss_paths=data.valid_ds.x[idxs]
top_loss_paths
ImageItemList (107 items)
[Image (3, 200, 500), Image (3, 500, 500), Image (3, 400, 400), Image (3, 500, 253), Image (3, 374, 500)]...
Path: flags
x=top_loss_paths[0]
x
Maintenant , il faut trouver un moyen de supprimer les images erronées sur GCT (Google Colab Tools)
Pour cela, on va d’abord les identifier.
interp.top_losses()
dv = data.valid_ds
# L'image ayant la plus mauvaise prédiction (98 dans notre cas)
dv[98][0]
On obtient ainsi la liste des images les moins bien classées (dans le validation set). Pour connaitre le nom du fichier, sachant l’id :
data.valid_ds.x.items[98]
# retourne par exemple : PosixPath('flags/bretons/00000115.jpg')
Puis pour détruire l’image :
!rm flags/bretons/00000115.jpg
Ensuite
On itère après avoir mis à jour les données (supprimer les images erronées).