Mis a jour le 2016-05-22, 16:22

Arbres de décision (rpart)

Objectif : prédire une variable en fonction d'attributs pour une liste d'individus. On suppose avoir une liste d'individus caractérisés par des variables explicatives, et on cherche à prédire une variable expliquée. L'apprentissage se fait par partionnement récursif des instances selon des règles sur les variables explicatives. Deux types d'arbres de décision :
Implémentation documentée ici : rpart. Attention, il faut charger la librairie par library(rpart)
Exemple d'utilisation simple avec variable prédite de type facteur (arbre de classification) :
fr <- data.frame(x = runif(1000, 0, 3), y = runif(1000, 2, 5))
fr$z <- factor(ifelse(fr$x < 2, "a", ifelse(fr$y > 4, "b", "a")))
fit <- rpart(z ~ x + y, fr, method = "class")
L'objet retourné est de la classe rpart. L'argument method="class" est optionnel (automatique par défaut car la variable prédite est de type facteur).
Impression de l'arbre sous forme textuelle pour un arbre de classification :
n= 1000 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 1000 116 a (0.8840000 0.1160000)  
  2) y< 4.014725 691   0 a (1.0000000 0.0000000) *
  3) y>=4.014725 309 116 a (0.6245955 0.3754045)  
    6) x< 2.003892 193   0 a (1.0000000 0.0000000) *
    7) x>=2.003892 116   0 b (0.0000000 1.0000000) *
A chaque noeud, on a : On peut avoir tout le détail en faisant : summary(fit).
Exemple d'utilisation avec variable prédite de type numérique (arbre de régression) :
fr <- data.frame(x = runif(1000, 0, 3), y = runif(1000, 2, 5))
fr$z <- ifelse(fr$x < 2, 2 * fr$x, 3 * fr$y)
fit <- rpart(z ~ x + y, fr, method = "anova")
  
Impression de l'arbre sous forme textuelle pour un arbre de régression :
n= 1000

node), split, n, deviance, yval
      * denotes terminal node

 1) root 1000 18975.56000  4.807525
   2) x< 2.003762 668   889.05200  1.996973
     4) x< 1.0084 341   115.48120  1.021080 *
     5) x>=1.0084 327   110.15180  3.014648 *
   3) x>=2.003762 332  2192.93300 10.462490
     6) y< 3.548979 179   326.58590  8.429348
      12) y< 2.812628 90    40.22737  7.272919 *
      13) y>=2.812628 89    44.28705  9.598771 *
     7) y>=3.548979 153   260.75250 12.841140
      14) y< 4.306183 81    33.33708 11.754910 *
      15) y>=4.306183 72    24.32592 14.063150 *
  
A chaque noeud, on a :
Impression d'un arbre sous forme graphique :
plot(fit, uniform = TRUE, branch = 0.5, margin = 0.1)
text(fit, all = FALSE, use.n = TRUE)
avec les paramètres suivants pour plot (qui dessine l'arbre) :
et les paramètres suivants pour text (qui étiquette l'arbre) :
Réglage de la complexité de l'arbre : plus l'arbre est complexe (beaucoup de noeuds), plus il va bien apprendre l'échantillon d'apprentissage, mais aussi plus il va faire de l'overfitting : adaptation uniquement à l'échantillon d'apprentissage, mais beaucoup d'erreurs sur un nouvel échantillon de test. La complexité doit donc être pénalisée, d'où un paramètre cp (complexity parameter) :
Choix de la complexité de l'arbre :
printcp(fit) affiche pour chaque valeur seuil du paramètre cp :
fr <- data.frame(x = runif(1000, 0, 3), y = runif(1000, 2, 5))
fr$z <- factor(ifelse(jitter(fr$x, amount = 0.5) < 2, "a", ifelse(jitter(fr$y, amount = 0.5) > 4, "b", "a")))
fit <- rpart(z ~ x + y, fr, method = "class", control = rpart.control(cp = 0.00001))
printcp(fit)
affiche :
Root node error: 111/1000 = 0.111

n= 1000 

        CP nsplit rel error  xerror     xstd
1 0.342342      0   1.00000 1.00000 0.089493
2 0.018018      2   0.31532 0.44144 0.061499
3 0.006006      4   0.27928 0.36036 0.055827
4 0.000010      7   0.26126 0.39640 0.058430
Pour chaque valeur de cp seuil, on a le nombre de splits de l'arbre correspondant (nsplit), l'erreur croisée xerror et son écart-type xstd. On prend en général la première valeur de cp (i.e. la plus grande) qui est à moins de un écart-type du minimum de xerror. Une fois que la valeur de cp est choisie, on peut récupérer l'arbre correspondant par :
prune(fit, cp = 0.07)
(renvoie un objet de la classe rpart, mais avec certains noeuds supprimés par rapport à l'objet de départ).
Prédiction de valeurs, en sortant une matrice avec les facteurs en colonne et les probabilités des classes en valeurs :
predict(fit, data.frame(x = c(1, 3, 1, 3), y = c(2, 2, 5, 5)))
  a b
1 1 0
2 1 0
3 1 0
4 0 1
  
Prédiction de valeurs, en sortant un vecteur de facteurs :
predict(fit, data.frame(x = c(1, 3, 1, 3), y = c(2, 2, 5, 5)), type = "class")
1 2 3 4 
a a a b 
  

Copyright Aymeric Duclert
programmer en R, tutoriel R, graphes en R