Uczenie maszynowe w R

Dariusz Brzeziński

Część I: Klasyfikacja

Agenda

  • Podstawowe metody uczenia maszynowego
    • Zasady uczenia i testowania
    • Optymalizacja parametrów
    • Wstępne przetwarzanie w procesie uczenia
  • Klasyfikacja
  • Biblioteka caret

Uczenie maszynowe

Dział informatyki zajmujący się przewidywaniem danych na podstawie danych.

Uczenie maszynowe wywodzi się ze sztucznej inteligencji i statystyki, a algorytmy z tej dziedziny często pojawiają się przy okazji terminów: data science, statistical data analysis, deep learning, AI, knowledge discovery czy data mining.

Esencja uczenia maszynowego

  • Istnieje wzorzec w problemie
  • Nie potrafimy zamodelować wzorca matematycznie
  • Mamy dane dotyczące problemu

Yaser Abu-Mostafa

Metody uczenia maszynowego

  • Nadzorowane
    • klasyfikacja (Regresja logistyczna, CART, k-NN, Sieci neuronowe, Naiwny Bayes, Ripper, Bagging, Boosting, Random Forest)
    • regresja (Regresja liniowa, Regresja nieliniowa, GBM, SGD)
  • Nienadzorowane
    • grupowanie (k-means, AHC, DBSCAN, Affinity Propagation, SOM)
    • reguły asocjacyjne (Apriori, FP-Growth, Prefix Span)

Zasady uczenia i testowania

Klasyfikacja jest procesem trzyetapowym:

  1. Konstrukcja modelu w oparciu o zbiór uczący
  2. Ocena modelu na przykładach testowych
  3. Użycie modelu na nowych danych

Podział na zbiór uczący i testowy

Podział na zbiór uczący i testowy może być wykonany na kilka sposobów w zależności od rozmiarów i charakterystyki zbioru danych:

  • oddzielny zbiór testowy (ang. holdout set)
  • ocena krzyżowa (ang. k-fold cross-validation)
  • pojedyncze przykłady testowe (ang. leave-one-out cv)
  • losowanie ze zwracaniem (ang. bootstraping)
  • wielokrotna ocena krzyżowa (ang. repeated cross-validation)

Stratyfikowanie danych

Istotny elementem podziału danych na zbiór uczący i testowy jest stratyfikacja (losowanie warstwowe).

Proporcje klas/wartości powinny być porównywalne w zbiorze uczącym i testowym! Szczególnie w przypadku występowania niezbalansowania klas!

Miary oceny (szczegóły za tydzień)

Klasyfikacja

Accuracy, AUC ROC, G-mean, Precision, Recall, F-score, Sensitivity, Specificity, Kappa

Regresja

R2 , MAE, MSE, RMSE

Grupowanie

RAND, Precision, Recall, F-score, Homogenity, Completeness, Silhouette, ARI, AMI

Problem przeuczenia

Ocena klasyfikatorów na zbiorze uczącym nie jest wiarygodna jeśli rozważamy predykcję nowych faktów.

Nadmierne dopasowanie do specyfiki danych uczących powiązane jest najczęściej z utratą zdolności uogólniania.

Ocena zdolności predykcyjnych klasyfikatora powinna być wykonywana na danych, które nie były wykorzystane do uczenia klasyfikatora!

Diagnozowanie przeuczenia

plot of chunk unnamed-chunk-2

Jeśli zarówno błąd na zbiorze testowym jak i uczącym jest wysoki, zachodzi problem “niedouczenia” (ang. high bias). Możliwe kroki:

  • skorzystać z silniejszego modelu,
  • dodać zmienne wielomianowe,
  • zmniejszyć regularyzację.

Diagnozowanie przeuczenia

plot of chunk unnamed-chunk-3

Jeśli błąd na zbiorze testowym jest znacząco wyższy niż na uczącym, prawdopodobnie doszło do przeuczenia (ang. high variance). Możliwe kroki:

  • zdobyć więcej przykładów,
  • ograniczyć liczbę atrybutów,
  • zwiększyć regularyzację.

Optymalizacja parametrów

Do rozwiązania problemu klasyfikacji istnieje wiele rożnych algorytmów, które z reguły posiadają wiele parametrów wpływających na ich zdolność predykcji.

Problem wyboru najlepszego modelu klasyfikacyjnego (model = dane + algorytm + parametry) po angielsku zwany jest model selection.

Optymalizacja parametrów

Problem znalezienia najlepszego modelu klasyfikacyjnego wymaga uczenia i testowania wielu modeli, aby, korzystając z wybranej miary oceny, wybrać najlepszy.

Aby nie doprowadzić do przeuczenia, model powinien być wybierany w oparciu o inne dane niż zbiór testowy!

zbiór danych = uczący + walidacyjny + testowy

Wstępne przetwarzanie w procesie uczenia

W procesie uczenia wszelkie operacje na danych powinny się opierać tylko na zbiorze uczącym. Jeśli chcemy znormalizować dane (\( \mu=0 \), \( \sigma=1 \)), powinniśmy to robić w oparciu o rozkład danych w zbiorze uczącym, a nie korzystając ze wszystkich danych.

Podczas testowania należy korzystać z przetwarzania wstępnego stworzonego na zbiorze uczącym, a nie testowym!

Oznacza to m.in. inne przetwarzanie wstępne dla każdej iteracji oceny krzyżowej…

Schematy wyboru modelu i estymacji błędu

  • holdout (uczący, walidacyjny, testowy)
  • k-fold cv (cv[uczący, walidacyjny], testowy)
  • nested k-fold cv (cv[cv[uczący, walidacyjny], testowy])

Popularne schematy: holdout, 10-fold cv + holdout, 5x2-fold cv + holdout, holdout + 10x10 cv

Przykład: Nested cross-validation

Nested cross-validation

Uczenie maszynowe w R

  • Bardzo dużo gotowych algorytmów uczących
  • Algorytmy rozproszone po różnych paczkach
  • Niespójności w nazewnictwie w zależności od paczki
  • W efekcie trudności w porównywaniu wielu metod

caret

  • Biblioteka do uczenia i testowania
  • Uspójniania wywołania między różnymi paczkami
  • Implementuje podstawowe schematy uczenia i optymalizacji parametrów
  • Ponad 200 algorytmów
  • Ułatwia zrównoleglanie uczenia na wiele procesorów

caret

Zainstalujmy bibliotekę caret:

install.packages("caret", dependencies = c("Depends", "Suggests"))

caret

caret - proces uczenia

caret - przykład

library(mlbench)
data(Sonar)
kable(summary(Sonar))
V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12 V13 V14 V15 V16 V17 V18 V19 V20 V21 V22 V23 V24 V25 V26 V27 V28 V29 V30 V31 V32 V33 V34 V35 V36 V37 V38 V39 V40 V41 V42 V43 V44 V45 V46 V47 V48 V49 V50 V51 V52 V53 V54 V55 V56 V57 V58 V59 V60 Class
Min. :0.00150 Min. :0.00060 Min. :0.00150 Min. :0.00580 Min. :0.00670 Min. :0.01020 Min. :0.0033 Min. :0.00550 Min. :0.00750 Min. :0.0113 Min. :0.0289 Min. :0.0236 Min. :0.0184 Min. :0.0273 Min. :0.0031 Min. :0.0162 Min. :0.0349 Min. :0.0375 Min. :0.0494 Min. :0.0656 Min. :0.0512 Min. :0.0219 Min. :0.0563 Min. :0.0239 Min. :0.0240 Min. :0.0921 Min. :0.0481 Min. :0.0284 Min. :0.0144 Min. :0.0613 Min. :0.0482 Min. :0.0404 Min. :0.0477 Min. :0.0212 Min. :0.0223 Min. :0.0080 Min. :0.0351 Min. :0.0383 Min. :0.0371 Min. :0.0117 Min. :0.0360 Min. :0.0056 Min. :0.0000 Min. :0.0000 Min. :0.00000 Min. :0.00000 Min. :0.00000 Min. :0.00000 Min. :0.00000 Min. :0.00000 Min. :0.000000 Min. :0.000800 Min. :0.000500 Min. :0.001000 Min. :0.00060 Min. :0.000400 Min. :0.00030 Min. :0.000300 Min. :0.000100 Min. :0.000600 M:111
1st Qu.:0.01335 1st Qu.:0.01645 1st Qu.:0.01895 1st Qu.:0.02438 1st Qu.:0.03805 1st Qu.:0.06703 1st Qu.:0.0809 1st Qu.:0.08042 1st Qu.:0.09703 1st Qu.:0.1113 1st Qu.:0.1293 1st Qu.:0.1335 1st Qu.:0.1661 1st Qu.:0.1752 1st Qu.:0.1646 1st Qu.:0.1963 1st Qu.:0.2059 1st Qu.:0.2421 1st Qu.:0.2991 1st Qu.:0.3506 1st Qu.:0.3997 1st Qu.:0.4069 1st Qu.:0.4502 1st Qu.:0.5407 1st Qu.:0.5258 1st Qu.:0.5442 1st Qu.:0.5319 1st Qu.:0.5348 1st Qu.:0.4637 1st Qu.:0.4114 1st Qu.:0.3456 1st Qu.:0.2814 1st Qu.:0.2579 1st Qu.:0.2176 1st Qu.:0.1794 1st Qu.:0.1543 1st Qu.:0.1601 1st Qu.:0.1743 1st Qu.:0.1740 1st Qu.:0.1865 1st Qu.:0.1631 1st Qu.:0.1589 1st Qu.:0.1552 1st Qu.:0.1269 1st Qu.:0.09448 1st Qu.:0.06855 1st Qu.:0.06425 1st Qu.:0.04512 1st Qu.:0.02635 1st Qu.:0.01155 1st Qu.:0.008425 1st Qu.:0.007275 1st Qu.:0.005075 1st Qu.:0.005375 1st Qu.:0.00415 1st Qu.:0.004400 1st Qu.:0.00370 1st Qu.:0.003600 1st Qu.:0.003675 1st Qu.:0.003100 R: 97
Median :0.02280 Median :0.03080 Median :0.03430 Median :0.04405 Median :0.06250 Median :0.09215 Median :0.1070 Median :0.11210 Median :0.15225 Median :0.1824 Median :0.2248 Median :0.2490 Median :0.2640 Median :0.2811 Median :0.2817 Median :0.3047 Median :0.3084 Median :0.3683 Median :0.4350 Median :0.5425 Median :0.6177 Median :0.6649 Median :0.6997 Median :0.6985 Median :0.7211 Median :0.7545 Median :0.7456 Median :0.7319 Median :0.6808 Median :0.6071 Median :0.4904 Median :0.4296 Median :0.3912 Median :0.3510 Median :0.3127 Median :0.3211 Median :0.3063 Median :0.3127 Median :0.2835 Median :0.2781 Median :0.2595 Median :0.2451 Median :0.2225 Median :0.1777 Median :0.14800 Median :0.12135 Median :0.10165 Median :0.07810 Median :0.04470 Median :0.01790 Median :0.013900 Median :0.011400 Median :0.009550 Median :0.009300 Median :0.00750 Median :0.006850 Median :0.00595 Median :0.005800 Median :0.006400 Median :0.005300 NA
Mean :0.02916 Mean :0.03844 Mean :0.04383 Mean :0.05389 Mean :0.07520 Mean :0.10457 Mean :0.1217 Mean :0.13480 Mean :0.17800 Mean :0.2083 Mean :0.2360 Mean :0.2502 Mean :0.2733 Mean :0.2966 Mean :0.3202 Mean :0.3785 Mean :0.4160 Mean :0.4523 Mean :0.5048 Mean :0.5630 Mean :0.6091 Mean :0.6243 Mean :0.6470 Mean :0.6727 Mean :0.6754 Mean :0.6999 Mean :0.7022 Mean :0.6940 Mean :0.6421 Mean :0.5809 Mean :0.5045 Mean :0.4390 Mean :0.4172 Mean :0.4032 Mean :0.3926 Mean :0.3848 Mean :0.3638 Mean :0.3397 Mean :0.3258 Mean :0.3112 Mean :0.2893 Mean :0.2783 Mean :0.2465 Mean :0.2141 Mean :0.19723 Mean :0.16063 Mean :0.12245 Mean :0.09142 Mean :0.05193 Mean :0.02042 Mean :0.016069 Mean :0.013420 Mean :0.010709 Mean :0.010941 Mean :0.00929 Mean :0.008222 Mean :0.00782 Mean :0.007949 Mean :0.007941 Mean :0.006507 NA
3rd Qu.:0.03555 3rd Qu.:0.04795 3rd Qu.:0.05795 3rd Qu.:0.06450 3rd Qu.:0.10028 3rd Qu.:0.13412 3rd Qu.:0.1540 3rd Qu.:0.16960 3rd Qu.:0.23342 3rd Qu.:0.2687 3rd Qu.:0.3016 3rd Qu.:0.3312 3rd Qu.:0.3513 3rd Qu.:0.3862 3rd Qu.:0.4529 3rd Qu.:0.5357 3rd Qu.:0.6594 3rd Qu.:0.6791 3rd Qu.:0.7314 3rd Qu.:0.8093 3rd Qu.:0.8170 3rd Qu.:0.8320 3rd Qu.:0.8486 3rd Qu.:0.8722 3rd Qu.:0.8737 3rd Qu.:0.8938 3rd Qu.:0.9171 3rd Qu.:0.9003 3rd Qu.:0.8521 3rd Qu.:0.7352 3rd Qu.:0.6420 3rd Qu.:0.5803 3rd Qu.:0.5561 3rd Qu.:0.5961 3rd Qu.:0.5934 3rd Qu.:0.5565 3rd Qu.:0.5189 3rd Qu.:0.4405 3rd Qu.:0.4349 3rd Qu.:0.4244 3rd Qu.:0.3875 3rd Qu.:0.3842 3rd Qu.:0.3245 3rd Qu.:0.2717 3rd Qu.:0.23155 3rd Qu.:0.20037 3rd Qu.:0.15443 3rd Qu.:0.12010 3rd Qu.:0.06853 3rd Qu.:0.02527 3rd Qu.:0.020825 3rd Qu.:0.016725 3rd Qu.:0.014900 3rd Qu.:0.014500 3rd Qu.:0.01210 3rd Qu.:0.010575 3rd Qu.:0.01043 3rd Qu.:0.010350 3rd Qu.:0.010325 3rd Qu.:0.008525 NA
Max. :0.13710 Max. :0.23390 Max. :0.30590 Max. :0.42640 Max. :0.40100 Max. :0.38230 Max. :0.3729 Max. :0.45900 Max. :0.68280 Max. :0.7106 Max. :0.7342 Max. :0.7060 Max. :0.7131 Max. :0.9970 Max. :1.0000 Max. :0.9988 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :0.9657 Max. :0.9306 Max. :1.0000 Max. :0.9647 Max. :1.0000 Max. :1.0000 Max. :0.9497 Max. :1.0000 Max. :0.9857 Max. :0.9297 Max. :0.8995 Max. :0.8246 Max. :0.7733 Max. :0.7762 Max. :0.70340 Max. :0.72920 Max. :0.55220 Max. :0.33390 Max. :0.19810 Max. :0.08250 Max. :0.100400 Max. :0.070900 Max. :0.039000 Max. :0.035200 Max. :0.04470 Max. :0.039400 Max. :0.03550 Max. :0.044000 Max. :0.036400 Max. :0.043900 NA

caret - podział zbioru danych

library(caret)
set.seed(23)
inTraining <-
    createDataPartition(
        # atrybut do stratyfikacji
        y = Sonar$Class,
        # procent w zbiorze uczącym
        p = .75,
        # chcemy indeksy a nie listę
        list = FALSE)

training <- Sonar[ inTraining,]
testing  <- Sonar[-inTraining,]

caret - schemat uczenia

ctrl <- trainControl(
    # powtórzona ocena krzyżowa
    method = "repeatedcv",
    # liczba podziałów
    number = 2,
    # liczba powtórzeń
    repeats = 5)

Oprócz powtarzanej oceny krzyżowej domyślnie dostępne są również: tradycyjna ocena krzyżowa, bootstraping czy ocena krzyżowa z pojedynczym przykładem.

caret domyślnie optymalizuje parametry wybierając trzy wartości dla każdego zdefiniowanego dla modelu parametru optymalizacyjnego.

caret - uczenie

set.seed(23)
fit <- train(Class ~ .,
             data = training,
             method = "rf",
             trControl = ctrl,
             # Paramter dla algorytmu uczącego
             ntree = 10)

W powyższym przykładzie tworzony jest model klasyfikacyjny zgodnie z algorytmem Random Forest. caret obsługuje obecnie ponad 200 różnych implementacji algorytmów, których listę można znaleźć na stronie projektu.

caret - uczenie

fit
Random Forest 

157 samples
 60 predictor
  2 classes: 'M', 'R' 

No pre-processing
Resampling: Cross-Validated (2 fold, repeated 5 times) 
Summary of sample sizes: 79, 78, 78, 79, 78, 79, ... 
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
   2    0.7209023  0.4311814
  31    0.7541870  0.5033600
  60    0.7324895  0.4598274

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 31.

caret - predykcja

rfClasses <- predict(fit, newdata = testing)
confusionMatrix(data = rfClasses, testing$Class)
Confusion Matrix and Statistics

          Reference
Prediction  M  R
         M 22  4
         R  5 20

               Accuracy : 0.8235         
                 95% CI : (0.6913, 0.916)
    No Information Rate : 0.5294         
    P-Value [Acc > NIR] : 1.117e-05      

                  Kappa : 0.6467         

 Mcnemar's Test P-Value : 1              

            Sensitivity : 0.8148         
            Specificity : 0.8333         
         Pos Pred Value : 0.8462         
         Neg Pred Value : 0.8000         
             Prevalence : 0.5294         
         Detection Rate : 0.4314         
   Detection Prevalence : 0.5098         
      Balanced Accuracy : 0.8241         

       'Positive' Class : M              

caret - optymalizacja parametrów

rfGrid <- expand.grid(mtry = 10:30)
gridCtrl <- trainControl(
    method = "repeatedcv",
    summaryFunction = twoClassSummary,
    classProbs = TRUE,
    number = 2,
    repeats = 5)

set.seed(23)
fitTune <- train(Class ~ .,
             data = training,
             method = "rf",
             metric = "ROC",
             preProc = c("center", "scale"),
             trControl = gridCtrl,
             tuneGrid = rfGrid,
             ntree = 30)

caret - optymalizacja parametrów

fitTune
Random Forest 

157 samples
 60 predictor
  2 classes: 'M', 'R' 

Pre-processing: centered (60), scaled (60) 
Resampling: Cross-Validated (2 fold, repeated 5 times) 
Summary of sample sizes: 79, 78, 78, 79, 78, 79, ... 
Resampling results across tuning parameters:

  mtry  ROC        Sens       Spec     
  10    0.8881095  0.8523810  0.7267267
  11    0.8859958  0.8523810  0.7262012
  12    0.8719755  0.8452381  0.6963213
  13    0.8666622  0.8214286  0.7075075
  14    0.8808228  0.8404762  0.7427928
  15    0.8717593  0.8142857  0.7185435
  16    0.8737872  0.8404762  0.7236486
  17    0.8682388  0.8309524  0.7346096
  18    0.8732992  0.8214286  0.7290541
  19    0.8807441  0.8452381  0.7291291
  20    0.8702658  0.8333333  0.7403904
  21    0.8687732  0.8095238  0.7346847
  22    0.8658882  0.8285714  0.7186937
  23    0.8759965  0.8404762  0.6854354
  24    0.8687473  0.8380952  0.6965465
  25    0.8617564  0.8119048  0.6963213
  26    0.8756265  0.8428571  0.7121622
  27    0.8667257  0.8119048  0.7258258
  28    0.8723053  0.8190476  0.7267267
  29    0.8683174  0.7952381  0.7482733
  30    0.8714643  0.8333333  0.7181682

ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 10.

caret - wizualizacje

ggplot(fitTune) + theme_bw()

plot of chunk unnamed-chunk-13

caret - predykcja

rfTuneClasses <- predict(fitTune,
                         newdata = testing)
confusionMatrix(data = rfTuneClasses,
                testing$Class)
Confusion Matrix and Statistics

          Reference
Prediction  M  R
         M 25  3
         R  2 21

               Accuracy : 0.902           
                 95% CI : (0.7859, 0.9674)
    No Information Rate : 0.5294          
    P-Value [Acc > NIR] : 1.209e-08       

                  Kappa : 0.8028          

 Mcnemar's Test P-Value : 1               

            Sensitivity : 0.9259          
            Specificity : 0.8750          
         Pos Pred Value : 0.8929          
         Neg Pred Value : 0.9130          
             Prevalence : 0.5294          
         Detection Rate : 0.4902          
   Detection Prevalence : 0.5490          
      Balanced Accuracy : 0.9005          

       'Positive' Class : M               

caret - predykcja

library(pROC)
rfTuneProbs <- predict(fitTune,
                       newdata = testing,
                       type="prob")
rocCurve <- roc(response = testing$Class,
                predictor = rfTuneProbs[, "M"],
                levels = rev(levels(testing$Class)))

Jeśli algorytm na to pozwala, oprócz predykcji w postaci klas, można również uzyskać wartości prawdopodobieństw wskazania każdej z klas.

caret - predykcja

plot(rocCurve)

plot of chunk unnamed-chunk-16

caret - uwagi

  • aby nauczyć model bez optymalizacji parametrów, należy w ustawić w trainControl parametr method = "none"
  • oprócz dokładnego przeszukiwania przestrzeni parametrów można również testować tylko losowe wartości za pomocą parametru search = "random" w trainControl i tuneLength w train
  • caret oferuje sporo różnych wizualizacji wyników, ale są one dostosowane do różnych systemów graficznych

Zadanie

  • Załaduj zbiór danych churn:
library(modeldata)
data(mlc_churn)
churnData <- data.frame(mlc_churn)
  • Podziel ten zbiór na uczący i testowy (75% w zbiorze uczącym)
  • Przetestuj dwa algorytmy klasyfikacyjne
  • Zastanów się czy warto wstępnie przetworzyć zbiór
  • Określ przestrzeń przeszukiwania parametrów
  • Porównaj algorytmy za pomocą wykresu