Dariusz Brzeziński
Część I: Klasyfikacja
caretDział informatyki zajmujący się przewidywaniem danych na podstawie danych.
Uczenie maszynowe wywodzi się ze sztucznej inteligencji i statystyki, a algorytmy z tej dziedziny często pojawiają się przy okazji terminów: data science, statistical data analysis, deep learning, AI, knowledge discovery czy data mining.
Yaser Abu-Mostafa
Klasyfikacja jest procesem trzyetapowym:
Podział na zbiór uczący i testowy może być wykonany na kilka sposobów w zależności od rozmiarów i charakterystyki zbioru danych:
Istotny elementem podziału danych na zbiór uczący i testowy jest stratyfikacja (losowanie warstwowe).
Proporcje klas/wartości powinny być porównywalne w zbiorze uczącym i testowym! Szczególnie w przypadku występowania niezbalansowania klas!
Klasyfikacja
Accuracy, AUC ROC, G-mean, Precision, Recall, F-score, Sensitivity, Specificity, Kappa
Regresja
R2 , MAE, MSE, RMSE
Grupowanie
RAND, Precision, Recall, F-score, Homogenity, Completeness, Silhouette, ARI, AMI
Ocena klasyfikatorów na zbiorze uczącym nie jest wiarygodna jeśli rozważamy predykcję nowych faktów.
Nadmierne dopasowanie do specyfiki danych uczących powiązane jest najczęściej z utratą zdolności uogólniania.
Ocena zdolności predykcyjnych klasyfikatora powinna być wykonywana na danych, które nie były wykorzystane do uczenia klasyfikatora!
Jeśli zarówno błąd na zbiorze testowym jak i uczącym jest wysoki, zachodzi problem “niedouczenia” (ang. high bias). Możliwe kroki:
Jeśli błąd na zbiorze testowym jest znacząco wyższy niż na uczącym, prawdopodobnie doszło do przeuczenia (ang. high variance). Możliwe kroki:
Do rozwiązania problemu klasyfikacji istnieje wiele rożnych algorytmów, które z reguły posiadają wiele parametrów wpływających na ich zdolność predykcji.
Problem wyboru najlepszego modelu klasyfikacyjnego (model = dane + algorytm + parametry) po angielsku zwany jest model selection.
Problem znalezienia najlepszego modelu klasyfikacyjnego wymaga uczenia i testowania wielu modeli, aby, korzystając z wybranej miary oceny, wybrać najlepszy.
Aby nie doprowadzić do przeuczenia, model powinien być wybierany w oparciu o inne dane niż zbiór testowy!
W procesie uczenia wszelkie operacje na danych powinny się opierać tylko na zbiorze uczącym. Jeśli chcemy znormalizować dane (\( \mu=0 \), \( \sigma=1 \)), powinniśmy to robić w oparciu o rozkład danych w zbiorze uczącym, a nie korzystając ze wszystkich danych.
Podczas testowania należy korzystać z przetwarzania wstępnego stworzonego na zbiorze uczącym, a nie testowym!
Oznacza to m.in. inne przetwarzanie wstępne dla każdej iteracji oceny krzyżowej…
Popularne schematy: holdout, 10-fold cv + holdout, 5x2-fold cv + holdout, holdout + 10x10 cv
Zainstalujmy bibliotekę caret:
install.packages("caret", dependencies = c("Depends", "Suggests"))
library(mlbench)
data(Sonar)
kable(summary(Sonar))
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 | V11 | V12 | V13 | V14 | V15 | V16 | V17 | V18 | V19 | V20 | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | V29 | V30 | V31 | V32 | V33 | V34 | V35 | V36 | V37 | V38 | V39 | V40 | V41 | V42 | V43 | V44 | V45 | V46 | V47 | V48 | V49 | V50 | V51 | V52 | V53 | V54 | V55 | V56 | V57 | V58 | V59 | V60 | Class | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Min. :0.00150 | Min. :0.00060 | Min. :0.00150 | Min. :0.00580 | Min. :0.00670 | Min. :0.01020 | Min. :0.0033 | Min. :0.00550 | Min. :0.00750 | Min. :0.0113 | Min. :0.0289 | Min. :0.0236 | Min. :0.0184 | Min. :0.0273 | Min. :0.0031 | Min. :0.0162 | Min. :0.0349 | Min. :0.0375 | Min. :0.0494 | Min. :0.0656 | Min. :0.0512 | Min. :0.0219 | Min. :0.0563 | Min. :0.0239 | Min. :0.0240 | Min. :0.0921 | Min. :0.0481 | Min. :0.0284 | Min. :0.0144 | Min. :0.0613 | Min. :0.0482 | Min. :0.0404 | Min. :0.0477 | Min. :0.0212 | Min. :0.0223 | Min. :0.0080 | Min. :0.0351 | Min. :0.0383 | Min. :0.0371 | Min. :0.0117 | Min. :0.0360 | Min. :0.0056 | Min. :0.0000 | Min. :0.0000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.000000 | Min. :0.000800 | Min. :0.000500 | Min. :0.001000 | Min. :0.00060 | Min. :0.000400 | Min. :0.00030 | Min. :0.000300 | Min. :0.000100 | Min. :0.000600 | M:111 | |
| 1st Qu.:0.01335 | 1st Qu.:0.01645 | 1st Qu.:0.01895 | 1st Qu.:0.02438 | 1st Qu.:0.03805 | 1st Qu.:0.06703 | 1st Qu.:0.0809 | 1st Qu.:0.08042 | 1st Qu.:0.09703 | 1st Qu.:0.1113 | 1st Qu.:0.1293 | 1st Qu.:0.1335 | 1st Qu.:0.1661 | 1st Qu.:0.1752 | 1st Qu.:0.1646 | 1st Qu.:0.1963 | 1st Qu.:0.2059 | 1st Qu.:0.2421 | 1st Qu.:0.2991 | 1st Qu.:0.3506 | 1st Qu.:0.3997 | 1st Qu.:0.4069 | 1st Qu.:0.4502 | 1st Qu.:0.5407 | 1st Qu.:0.5258 | 1st Qu.:0.5442 | 1st Qu.:0.5319 | 1st Qu.:0.5348 | 1st Qu.:0.4637 | 1st Qu.:0.4114 | 1st Qu.:0.3456 | 1st Qu.:0.2814 | 1st Qu.:0.2579 | 1st Qu.:0.2176 | 1st Qu.:0.1794 | 1st Qu.:0.1543 | 1st Qu.:0.1601 | 1st Qu.:0.1743 | 1st Qu.:0.1740 | 1st Qu.:0.1865 | 1st Qu.:0.1631 | 1st Qu.:0.1589 | 1st Qu.:0.1552 | 1st Qu.:0.1269 | 1st Qu.:0.09448 | 1st Qu.:0.06855 | 1st Qu.:0.06425 | 1st Qu.:0.04512 | 1st Qu.:0.02635 | 1st Qu.:0.01155 | 1st Qu.:0.008425 | 1st Qu.:0.007275 | 1st Qu.:0.005075 | 1st Qu.:0.005375 | 1st Qu.:0.00415 | 1st Qu.:0.004400 | 1st Qu.:0.00370 | 1st Qu.:0.003600 | 1st Qu.:0.003675 | 1st Qu.:0.003100 | R: 97 | |
| Median :0.02280 | Median :0.03080 | Median :0.03430 | Median :0.04405 | Median :0.06250 | Median :0.09215 | Median :0.1070 | Median :0.11210 | Median :0.15225 | Median :0.1824 | Median :0.2248 | Median :0.2490 | Median :0.2640 | Median :0.2811 | Median :0.2817 | Median :0.3047 | Median :0.3084 | Median :0.3683 | Median :0.4350 | Median :0.5425 | Median :0.6177 | Median :0.6649 | Median :0.6997 | Median :0.6985 | Median :0.7211 | Median :0.7545 | Median :0.7456 | Median :0.7319 | Median :0.6808 | Median :0.6071 | Median :0.4904 | Median :0.4296 | Median :0.3912 | Median :0.3510 | Median :0.3127 | Median :0.3211 | Median :0.3063 | Median :0.3127 | Median :0.2835 | Median :0.2781 | Median :0.2595 | Median :0.2451 | Median :0.2225 | Median :0.1777 | Median :0.14800 | Median :0.12135 | Median :0.10165 | Median :0.07810 | Median :0.04470 | Median :0.01790 | Median :0.013900 | Median :0.011400 | Median :0.009550 | Median :0.009300 | Median :0.00750 | Median :0.006850 | Median :0.00595 | Median :0.005800 | Median :0.006400 | Median :0.005300 | NA | |
| Mean :0.02916 | Mean :0.03844 | Mean :0.04383 | Mean :0.05389 | Mean :0.07520 | Mean :0.10457 | Mean :0.1217 | Mean :0.13480 | Mean :0.17800 | Mean :0.2083 | Mean :0.2360 | Mean :0.2502 | Mean :0.2733 | Mean :0.2966 | Mean :0.3202 | Mean :0.3785 | Mean :0.4160 | Mean :0.4523 | Mean :0.5048 | Mean :0.5630 | Mean :0.6091 | Mean :0.6243 | Mean :0.6470 | Mean :0.6727 | Mean :0.6754 | Mean :0.6999 | Mean :0.7022 | Mean :0.6940 | Mean :0.6421 | Mean :0.5809 | Mean :0.5045 | Mean :0.4390 | Mean :0.4172 | Mean :0.4032 | Mean :0.3926 | Mean :0.3848 | Mean :0.3638 | Mean :0.3397 | Mean :0.3258 | Mean :0.3112 | Mean :0.2893 | Mean :0.2783 | Mean :0.2465 | Mean :0.2141 | Mean :0.19723 | Mean :0.16063 | Mean :0.12245 | Mean :0.09142 | Mean :0.05193 | Mean :0.02042 | Mean :0.016069 | Mean :0.013420 | Mean :0.010709 | Mean :0.010941 | Mean :0.00929 | Mean :0.008222 | Mean :0.00782 | Mean :0.007949 | Mean :0.007941 | Mean :0.006507 | NA | |
| 3rd Qu.:0.03555 | 3rd Qu.:0.04795 | 3rd Qu.:0.05795 | 3rd Qu.:0.06450 | 3rd Qu.:0.10028 | 3rd Qu.:0.13412 | 3rd Qu.:0.1540 | 3rd Qu.:0.16960 | 3rd Qu.:0.23342 | 3rd Qu.:0.2687 | 3rd Qu.:0.3016 | 3rd Qu.:0.3312 | 3rd Qu.:0.3513 | 3rd Qu.:0.3862 | 3rd Qu.:0.4529 | 3rd Qu.:0.5357 | 3rd Qu.:0.6594 | 3rd Qu.:0.6791 | 3rd Qu.:0.7314 | 3rd Qu.:0.8093 | 3rd Qu.:0.8170 | 3rd Qu.:0.8320 | 3rd Qu.:0.8486 | 3rd Qu.:0.8722 | 3rd Qu.:0.8737 | 3rd Qu.:0.8938 | 3rd Qu.:0.9171 | 3rd Qu.:0.9003 | 3rd Qu.:0.8521 | 3rd Qu.:0.7352 | 3rd Qu.:0.6420 | 3rd Qu.:0.5803 | 3rd Qu.:0.5561 | 3rd Qu.:0.5961 | 3rd Qu.:0.5934 | 3rd Qu.:0.5565 | 3rd Qu.:0.5189 | 3rd Qu.:0.4405 | 3rd Qu.:0.4349 | 3rd Qu.:0.4244 | 3rd Qu.:0.3875 | 3rd Qu.:0.3842 | 3rd Qu.:0.3245 | 3rd Qu.:0.2717 | 3rd Qu.:0.23155 | 3rd Qu.:0.20037 | 3rd Qu.:0.15443 | 3rd Qu.:0.12010 | 3rd Qu.:0.06853 | 3rd Qu.:0.02527 | 3rd Qu.:0.020825 | 3rd Qu.:0.016725 | 3rd Qu.:0.014900 | 3rd Qu.:0.014500 | 3rd Qu.:0.01210 | 3rd Qu.:0.010575 | 3rd Qu.:0.01043 | 3rd Qu.:0.010350 | 3rd Qu.:0.010325 | 3rd Qu.:0.008525 | NA | |
| Max. :0.13710 | Max. :0.23390 | Max. :0.30590 | Max. :0.42640 | Max. :0.40100 | Max. :0.38230 | Max. :0.3729 | Max. :0.45900 | Max. :0.68280 | Max. :0.7106 | Max. :0.7342 | Max. :0.7060 | Max. :0.7131 | Max. :0.9970 | Max. :1.0000 | Max. :0.9988 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :0.9657 | Max. :0.9306 | Max. :1.0000 | Max. :0.9647 | Max. :1.0000 | Max. :1.0000 | Max. :0.9497 | Max. :1.0000 | Max. :0.9857 | Max. :0.9297 | Max. :0.8995 | Max. :0.8246 | Max. :0.7733 | Max. :0.7762 | Max. :0.70340 | Max. :0.72920 | Max. :0.55220 | Max. :0.33390 | Max. :0.19810 | Max. :0.08250 | Max. :0.100400 | Max. :0.070900 | Max. :0.039000 | Max. :0.035200 | Max. :0.04470 | Max. :0.039400 | Max. :0.03550 | Max. :0.044000 | Max. :0.036400 | Max. :0.043900 | NA | 
library(caret)
set.seed(23)
inTraining <-
    createDataPartition(
        # atrybut do stratyfikacji
        y = Sonar$Class,
        # procent w zbiorze uczącym
        p = .75,
        # chcemy indeksy a nie listę
        list = FALSE)
training <- Sonar[ inTraining,]
testing  <- Sonar[-inTraining,]
ctrl <- trainControl(
    # powtórzona ocena krzyżowa
    method = "repeatedcv",
    # liczba podziałów
    number = 2,
    # liczba powtórzeń
    repeats = 5)
Oprócz powtarzanej oceny krzyżowej domyślnie dostępne są również: tradycyjna ocena krzyżowa, bootstraping czy ocena krzyżowa z pojedynczym przykładem.
caret domyślnie optymalizuje parametry wybierając trzy wartości dla każdego zdefiniowanego dla modelu parametru optymalizacyjnego.
set.seed(23)
fit <- train(Class ~ .,
             data = training,
             method = "rf",
             trControl = ctrl,
             # Paramter dla algorytmu uczącego
             ntree = 10)
W powyższym przykładzie tworzony jest model klasyfikacyjny zgodnie z algorytmem Random Forest. caret obsługuje obecnie ponad 200 różnych implementacji algorytmów, których listę można znaleźć na stronie projektu.
fit
Random Forest 
157 samples
 60 predictor
  2 classes: 'M', 'R' 
No pre-processing
Resampling: Cross-Validated (2 fold, repeated 5 times) 
Summary of sample sizes: 79, 78, 78, 79, 78, 79, ... 
Resampling results across tuning parameters:
  mtry  Accuracy   Kappa    
   2    0.7209023  0.4311814
  31    0.7541870  0.5033600
  60    0.7324895  0.4598274
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 31.
rfClasses <- predict(fit, newdata = testing)
confusionMatrix(data = rfClasses, testing$Class)
Confusion Matrix and Statistics
          Reference
Prediction  M  R
         M 22  4
         R  5 20
               Accuracy : 0.8235         
                 95% CI : (0.6913, 0.916)
    No Information Rate : 0.5294         
    P-Value [Acc > NIR] : 1.117e-05      
                  Kappa : 0.6467         
 Mcnemar's Test P-Value : 1              
            Sensitivity : 0.8148         
            Specificity : 0.8333         
         Pos Pred Value : 0.8462         
         Neg Pred Value : 0.8000         
             Prevalence : 0.5294         
         Detection Rate : 0.4314         
   Detection Prevalence : 0.5098         
      Balanced Accuracy : 0.8241         
       'Positive' Class : M              
rfGrid <- expand.grid(mtry = 10:30)
gridCtrl <- trainControl(
    method = "repeatedcv",
    summaryFunction = twoClassSummary,
    classProbs = TRUE,
    number = 2,
    repeats = 5)
set.seed(23)
fitTune <- train(Class ~ .,
             data = training,
             method = "rf",
             metric = "ROC",
             preProc = c("center", "scale"),
             trControl = gridCtrl,
             tuneGrid = rfGrid,
             ntree = 30)
fitTune
Random Forest 
157 samples
 60 predictor
  2 classes: 'M', 'R' 
Pre-processing: centered (60), scaled (60) 
Resampling: Cross-Validated (2 fold, repeated 5 times) 
Summary of sample sizes: 79, 78, 78, 79, 78, 79, ... 
Resampling results across tuning parameters:
  mtry  ROC        Sens       Spec     
  10    0.8881095  0.8523810  0.7267267
  11    0.8859958  0.8523810  0.7262012
  12    0.8719755  0.8452381  0.6963213
  13    0.8666622  0.8214286  0.7075075
  14    0.8808228  0.8404762  0.7427928
  15    0.8717593  0.8142857  0.7185435
  16    0.8737872  0.8404762  0.7236486
  17    0.8682388  0.8309524  0.7346096
  18    0.8732992  0.8214286  0.7290541
  19    0.8807441  0.8452381  0.7291291
  20    0.8702658  0.8333333  0.7403904
  21    0.8687732  0.8095238  0.7346847
  22    0.8658882  0.8285714  0.7186937
  23    0.8759965  0.8404762  0.6854354
  24    0.8687473  0.8380952  0.6965465
  25    0.8617564  0.8119048  0.6963213
  26    0.8756265  0.8428571  0.7121622
  27    0.8667257  0.8119048  0.7258258
  28    0.8723053  0.8190476  0.7267267
  29    0.8683174  0.7952381  0.7482733
  30    0.8714643  0.8333333  0.7181682
ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 10.
ggplot(fitTune) + theme_bw()
rfTuneClasses <- predict(fitTune,
                         newdata = testing)
confusionMatrix(data = rfTuneClasses,
                testing$Class)
Confusion Matrix and Statistics
          Reference
Prediction  M  R
         M 25  3
         R  2 21
               Accuracy : 0.902           
                 95% CI : (0.7859, 0.9674)
    No Information Rate : 0.5294          
    P-Value [Acc > NIR] : 1.209e-08       
                  Kappa : 0.8028          
 Mcnemar's Test P-Value : 1               
            Sensitivity : 0.9259          
            Specificity : 0.8750          
         Pos Pred Value : 0.8929          
         Neg Pred Value : 0.9130          
             Prevalence : 0.5294          
         Detection Rate : 0.4902          
   Detection Prevalence : 0.5490          
      Balanced Accuracy : 0.9005          
       'Positive' Class : M               
library(pROC)
rfTuneProbs <- predict(fitTune,
                       newdata = testing,
                       type="prob")
rocCurve <- roc(response = testing$Class,
                predictor = rfTuneProbs[, "M"],
                levels = rev(levels(testing$Class)))
Jeśli algorytm na to pozwala, oprócz predykcji w postaci klas, można również uzyskać wartości prawdopodobieństw wskazania każdej z klas.
plot(rocCurve)
trainControl parametr method = "none"search = "random"  w trainControl i tuneLength w trainlibrary(modeldata)
data(mlc_churn)
churnData <- data.frame(mlc_churn)