Dariusz Brzeziński
Część I: Klasyfikacja
caret
Dział informatyki zajmujący się przewidywaniem danych na podstawie danych.
Uczenie maszynowe wywodzi się ze sztucznej inteligencji i statystyki, a algorytmy z tej dziedziny często pojawiają się przy okazji terminów: data science, statistical data analysis, deep learning, AI, knowledge discovery czy data mining.
Yaser Abu-Mostafa
Klasyfikacja jest procesem trzyetapowym:
Podział na zbiór uczący i testowy może być wykonany na kilka sposobów w zależności od rozmiarów i charakterystyki zbioru danych:
Istotny elementem podziału danych na zbiór uczący i testowy jest stratyfikacja (losowanie warstwowe).
Proporcje klas/wartości powinny być porównywalne w zbiorze uczącym i testowym! Szczególnie w przypadku występowania niezbalansowania klas!
Klasyfikacja
Accuracy, AUC ROC, G-mean, Precision, Recall, F-score, Sensitivity, Specificity, Kappa
Regresja
R2 , MAE, MSE, RMSE
Grupowanie
RAND, Precision, Recall, F-score, Homogenity, Completeness, Silhouette, ARI, AMI
Ocena klasyfikatorów na zbiorze uczącym nie jest wiarygodna jeśli rozważamy predykcję nowych faktów.
Nadmierne dopasowanie do specyfiki danych uczących powiązane jest najczęściej z utratą zdolności uogólniania.
Ocena zdolności predykcyjnych klasyfikatora powinna być wykonywana na danych, które nie były wykorzystane do uczenia klasyfikatora!
Jeśli zarówno błąd na zbiorze testowym jak i uczącym jest wysoki, zachodzi problem “niedouczenia” (ang. high bias). Możliwe kroki:
Jeśli błąd na zbiorze testowym jest znacząco wyższy niż na uczącym, prawdopodobnie doszło do przeuczenia (ang. high variance). Możliwe kroki:
Do rozwiązania problemu klasyfikacji istnieje wiele rożnych algorytmów, które z reguły posiadają wiele parametrów wpływających na ich zdolność predykcji.
Problem wyboru najlepszego modelu klasyfikacyjnego (model = dane + algorytm + parametry) po angielsku zwany jest model selection.
Problem znalezienia najlepszego modelu klasyfikacyjnego wymaga uczenia i testowania wielu modeli, aby, korzystając z wybranej miary oceny, wybrać najlepszy.
Aby nie doprowadzić do przeuczenia, model powinien być wybierany w oparciu o inne dane niż zbiór testowy!
W procesie uczenia wszelkie operacje na danych powinny się opierać tylko na zbiorze uczącym. Jeśli chcemy znormalizować dane (\( \mu=0 \), \( \sigma=1 \)), powinniśmy to robić w oparciu o rozkład danych w zbiorze uczącym, a nie korzystając ze wszystkich danych.
Podczas testowania należy korzystać z przetwarzania wstępnego stworzonego na zbiorze uczącym, a nie testowym!
Oznacza to m.in. inne przetwarzanie wstępne dla każdej iteracji oceny krzyżowej…
Popularne schematy: holdout, 10-fold cv + holdout, 5x2-fold cv + holdout, holdout + 10x10 cv
Zainstalujmy bibliotekę caret
:
install.packages("caret", dependencies = c("Depends", "Suggests"))
library(mlbench)
data(Sonar)
kable(summary(Sonar))
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 | V11 | V12 | V13 | V14 | V15 | V16 | V17 | V18 | V19 | V20 | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | V29 | V30 | V31 | V32 | V33 | V34 | V35 | V36 | V37 | V38 | V39 | V40 | V41 | V42 | V43 | V44 | V45 | V46 | V47 | V48 | V49 | V50 | V51 | V52 | V53 | V54 | V55 | V56 | V57 | V58 | V59 | V60 | Class | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Min. :0.00150 | Min. :0.00060 | Min. :0.00150 | Min. :0.00580 | Min. :0.00670 | Min. :0.01020 | Min. :0.0033 | Min. :0.00550 | Min. :0.00750 | Min. :0.0113 | Min. :0.0289 | Min. :0.0236 | Min. :0.0184 | Min. :0.0273 | Min. :0.0031 | Min. :0.0162 | Min. :0.0349 | Min. :0.0375 | Min. :0.0494 | Min. :0.0656 | Min. :0.0512 | Min. :0.0219 | Min. :0.0563 | Min. :0.0239 | Min. :0.0240 | Min. :0.0921 | Min. :0.0481 | Min. :0.0284 | Min. :0.0144 | Min. :0.0613 | Min. :0.0482 | Min. :0.0404 | Min. :0.0477 | Min. :0.0212 | Min. :0.0223 | Min. :0.0080 | Min. :0.0351 | Min. :0.0383 | Min. :0.0371 | Min. :0.0117 | Min. :0.0360 | Min. :0.0056 | Min. :0.0000 | Min. :0.0000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.00000 | Min. :0.000000 | Min. :0.000800 | Min. :0.000500 | Min. :0.001000 | Min. :0.00060 | Min. :0.000400 | Min. :0.00030 | Min. :0.000300 | Min. :0.000100 | Min. :0.000600 | M:111 | |
1st Qu.:0.01335 | 1st Qu.:0.01645 | 1st Qu.:0.01895 | 1st Qu.:0.02438 | 1st Qu.:0.03805 | 1st Qu.:0.06703 | 1st Qu.:0.0809 | 1st Qu.:0.08042 | 1st Qu.:0.09703 | 1st Qu.:0.1113 | 1st Qu.:0.1293 | 1st Qu.:0.1335 | 1st Qu.:0.1661 | 1st Qu.:0.1752 | 1st Qu.:0.1646 | 1st Qu.:0.1963 | 1st Qu.:0.2059 | 1st Qu.:0.2421 | 1st Qu.:0.2991 | 1st Qu.:0.3506 | 1st Qu.:0.3997 | 1st Qu.:0.4069 | 1st Qu.:0.4502 | 1st Qu.:0.5407 | 1st Qu.:0.5258 | 1st Qu.:0.5442 | 1st Qu.:0.5319 | 1st Qu.:0.5348 | 1st Qu.:0.4637 | 1st Qu.:0.4114 | 1st Qu.:0.3456 | 1st Qu.:0.2814 | 1st Qu.:0.2579 | 1st Qu.:0.2176 | 1st Qu.:0.1794 | 1st Qu.:0.1543 | 1st Qu.:0.1601 | 1st Qu.:0.1743 | 1st Qu.:0.1740 | 1st Qu.:0.1865 | 1st Qu.:0.1631 | 1st Qu.:0.1589 | 1st Qu.:0.1552 | 1st Qu.:0.1269 | 1st Qu.:0.09448 | 1st Qu.:0.06855 | 1st Qu.:0.06425 | 1st Qu.:0.04512 | 1st Qu.:0.02635 | 1st Qu.:0.01155 | 1st Qu.:0.008425 | 1st Qu.:0.007275 | 1st Qu.:0.005075 | 1st Qu.:0.005375 | 1st Qu.:0.00415 | 1st Qu.:0.004400 | 1st Qu.:0.00370 | 1st Qu.:0.003600 | 1st Qu.:0.003675 | 1st Qu.:0.003100 | R: 97 | |
Median :0.02280 | Median :0.03080 | Median :0.03430 | Median :0.04405 | Median :0.06250 | Median :0.09215 | Median :0.1070 | Median :0.11210 | Median :0.15225 | Median :0.1824 | Median :0.2248 | Median :0.2490 | Median :0.2640 | Median :0.2811 | Median :0.2817 | Median :0.3047 | Median :0.3084 | Median :0.3683 | Median :0.4350 | Median :0.5425 | Median :0.6177 | Median :0.6649 | Median :0.6997 | Median :0.6985 | Median :0.7211 | Median :0.7545 | Median :0.7456 | Median :0.7319 | Median :0.6808 | Median :0.6071 | Median :0.4904 | Median :0.4296 | Median :0.3912 | Median :0.3510 | Median :0.3127 | Median :0.3211 | Median :0.3063 | Median :0.3127 | Median :0.2835 | Median :0.2781 | Median :0.2595 | Median :0.2451 | Median :0.2225 | Median :0.1777 | Median :0.14800 | Median :0.12135 | Median :0.10165 | Median :0.07810 | Median :0.04470 | Median :0.01790 | Median :0.013900 | Median :0.011400 | Median :0.009550 | Median :0.009300 | Median :0.00750 | Median :0.006850 | Median :0.00595 | Median :0.005800 | Median :0.006400 | Median :0.005300 | NA | |
Mean :0.02916 | Mean :0.03844 | Mean :0.04383 | Mean :0.05389 | Mean :0.07520 | Mean :0.10457 | Mean :0.1217 | Mean :0.13480 | Mean :0.17800 | Mean :0.2083 | Mean :0.2360 | Mean :0.2502 | Mean :0.2733 | Mean :0.2966 | Mean :0.3202 | Mean :0.3785 | Mean :0.4160 | Mean :0.4523 | Mean :0.5048 | Mean :0.5630 | Mean :0.6091 | Mean :0.6243 | Mean :0.6470 | Mean :0.6727 | Mean :0.6754 | Mean :0.6999 | Mean :0.7022 | Mean :0.6940 | Mean :0.6421 | Mean :0.5809 | Mean :0.5045 | Mean :0.4390 | Mean :0.4172 | Mean :0.4032 | Mean :0.3926 | Mean :0.3848 | Mean :0.3638 | Mean :0.3397 | Mean :0.3258 | Mean :0.3112 | Mean :0.2893 | Mean :0.2783 | Mean :0.2465 | Mean :0.2141 | Mean :0.19723 | Mean :0.16063 | Mean :0.12245 | Mean :0.09142 | Mean :0.05193 | Mean :0.02042 | Mean :0.016069 | Mean :0.013420 | Mean :0.010709 | Mean :0.010941 | Mean :0.00929 | Mean :0.008222 | Mean :0.00782 | Mean :0.007949 | Mean :0.007941 | Mean :0.006507 | NA | |
3rd Qu.:0.03555 | 3rd Qu.:0.04795 | 3rd Qu.:0.05795 | 3rd Qu.:0.06450 | 3rd Qu.:0.10028 | 3rd Qu.:0.13412 | 3rd Qu.:0.1540 | 3rd Qu.:0.16960 | 3rd Qu.:0.23342 | 3rd Qu.:0.2687 | 3rd Qu.:0.3016 | 3rd Qu.:0.3312 | 3rd Qu.:0.3513 | 3rd Qu.:0.3862 | 3rd Qu.:0.4529 | 3rd Qu.:0.5357 | 3rd Qu.:0.6594 | 3rd Qu.:0.6791 | 3rd Qu.:0.7314 | 3rd Qu.:0.8093 | 3rd Qu.:0.8170 | 3rd Qu.:0.8320 | 3rd Qu.:0.8486 | 3rd Qu.:0.8722 | 3rd Qu.:0.8737 | 3rd Qu.:0.8938 | 3rd Qu.:0.9171 | 3rd Qu.:0.9003 | 3rd Qu.:0.8521 | 3rd Qu.:0.7352 | 3rd Qu.:0.6420 | 3rd Qu.:0.5803 | 3rd Qu.:0.5561 | 3rd Qu.:0.5961 | 3rd Qu.:0.5934 | 3rd Qu.:0.5565 | 3rd Qu.:0.5189 | 3rd Qu.:0.4405 | 3rd Qu.:0.4349 | 3rd Qu.:0.4244 | 3rd Qu.:0.3875 | 3rd Qu.:0.3842 | 3rd Qu.:0.3245 | 3rd Qu.:0.2717 | 3rd Qu.:0.23155 | 3rd Qu.:0.20037 | 3rd Qu.:0.15443 | 3rd Qu.:0.12010 | 3rd Qu.:0.06853 | 3rd Qu.:0.02527 | 3rd Qu.:0.020825 | 3rd Qu.:0.016725 | 3rd Qu.:0.014900 | 3rd Qu.:0.014500 | 3rd Qu.:0.01210 | 3rd Qu.:0.010575 | 3rd Qu.:0.01043 | 3rd Qu.:0.010350 | 3rd Qu.:0.010325 | 3rd Qu.:0.008525 | NA | |
Max. :0.13710 | Max. :0.23390 | Max. :0.30590 | Max. :0.42640 | Max. :0.40100 | Max. :0.38230 | Max. :0.3729 | Max. :0.45900 | Max. :0.68280 | Max. :0.7106 | Max. :0.7342 | Max. :0.7060 | Max. :0.7131 | Max. :0.9970 | Max. :1.0000 | Max. :0.9988 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :1.0000 | Max. :0.9657 | Max. :0.9306 | Max. :1.0000 | Max. :0.9647 | Max. :1.0000 | Max. :1.0000 | Max. :0.9497 | Max. :1.0000 | Max. :0.9857 | Max. :0.9297 | Max. :0.8995 | Max. :0.8246 | Max. :0.7733 | Max. :0.7762 | Max. :0.70340 | Max. :0.72920 | Max. :0.55220 | Max. :0.33390 | Max. :0.19810 | Max. :0.08250 | Max. :0.100400 | Max. :0.070900 | Max. :0.039000 | Max. :0.035200 | Max. :0.04470 | Max. :0.039400 | Max. :0.03550 | Max. :0.044000 | Max. :0.036400 | Max. :0.043900 | NA |
library(caret)
set.seed(23)
inTraining <-
createDataPartition(
# atrybut do stratyfikacji
y = Sonar$Class,
# procent w zbiorze uczącym
p = .75,
# chcemy indeksy a nie listę
list = FALSE)
training <- Sonar[ inTraining,]
testing <- Sonar[-inTraining,]
ctrl <- trainControl(
# powtórzona ocena krzyżowa
method = "repeatedcv",
# liczba podziałów
number = 2,
# liczba powtórzeń
repeats = 5)
Oprócz powtarzanej oceny krzyżowej domyślnie dostępne są również: tradycyjna ocena krzyżowa, bootstraping czy ocena krzyżowa z pojedynczym przykładem.
caret
domyślnie optymalizuje parametry wybierając trzy wartości dla każdego zdefiniowanego dla modelu parametru optymalizacyjnego.
set.seed(23)
fit <- train(Class ~ .,
data = training,
method = "rf",
trControl = ctrl,
# Paramter dla algorytmu uczącego
ntree = 10)
W powyższym przykładzie tworzony jest model klasyfikacyjny zgodnie z algorytmem Random Forest. caret
obsługuje obecnie ponad 200 różnych implementacji algorytmów, których listę można znaleźć na stronie projektu.
fit
Random Forest
157 samples
60 predictor
2 classes: 'M', 'R'
No pre-processing
Resampling: Cross-Validated (2 fold, repeated 5 times)
Summary of sample sizes: 79, 78, 78, 79, 78, 79, ...
Resampling results across tuning parameters:
mtry Accuracy Kappa
2 0.7209023 0.4311814
31 0.7541870 0.5033600
60 0.7324895 0.4598274
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 31.
rfClasses <- predict(fit, newdata = testing)
confusionMatrix(data = rfClasses, testing$Class)
Confusion Matrix and Statistics
Reference
Prediction M R
M 22 4
R 5 20
Accuracy : 0.8235
95% CI : (0.6913, 0.916)
No Information Rate : 0.5294
P-Value [Acc > NIR] : 1.117e-05
Kappa : 0.6467
Mcnemar's Test P-Value : 1
Sensitivity : 0.8148
Specificity : 0.8333
Pos Pred Value : 0.8462
Neg Pred Value : 0.8000
Prevalence : 0.5294
Detection Rate : 0.4314
Detection Prevalence : 0.5098
Balanced Accuracy : 0.8241
'Positive' Class : M
rfGrid <- expand.grid(mtry = 10:30)
gridCtrl <- trainControl(
method = "repeatedcv",
summaryFunction = twoClassSummary,
classProbs = TRUE,
number = 2,
repeats = 5)
set.seed(23)
fitTune <- train(Class ~ .,
data = training,
method = "rf",
metric = "ROC",
preProc = c("center", "scale"),
trControl = gridCtrl,
tuneGrid = rfGrid,
ntree = 30)
fitTune
Random Forest
157 samples
60 predictor
2 classes: 'M', 'R'
Pre-processing: centered (60), scaled (60)
Resampling: Cross-Validated (2 fold, repeated 5 times)
Summary of sample sizes: 79, 78, 78, 79, 78, 79, ...
Resampling results across tuning parameters:
mtry ROC Sens Spec
10 0.8881095 0.8523810 0.7267267
11 0.8859958 0.8523810 0.7262012
12 0.8719755 0.8452381 0.6963213
13 0.8666622 0.8214286 0.7075075
14 0.8808228 0.8404762 0.7427928
15 0.8717593 0.8142857 0.7185435
16 0.8737872 0.8404762 0.7236486
17 0.8682388 0.8309524 0.7346096
18 0.8732992 0.8214286 0.7290541
19 0.8807441 0.8452381 0.7291291
20 0.8702658 0.8333333 0.7403904
21 0.8687732 0.8095238 0.7346847
22 0.8658882 0.8285714 0.7186937
23 0.8759965 0.8404762 0.6854354
24 0.8687473 0.8380952 0.6965465
25 0.8617564 0.8119048 0.6963213
26 0.8756265 0.8428571 0.7121622
27 0.8667257 0.8119048 0.7258258
28 0.8723053 0.8190476 0.7267267
29 0.8683174 0.7952381 0.7482733
30 0.8714643 0.8333333 0.7181682
ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 10.
ggplot(fitTune) + theme_bw()
rfTuneClasses <- predict(fitTune,
newdata = testing)
confusionMatrix(data = rfTuneClasses,
testing$Class)
Confusion Matrix and Statistics
Reference
Prediction M R
M 25 3
R 2 21
Accuracy : 0.902
95% CI : (0.7859, 0.9674)
No Information Rate : 0.5294
P-Value [Acc > NIR] : 1.209e-08
Kappa : 0.8028
Mcnemar's Test P-Value : 1
Sensitivity : 0.9259
Specificity : 0.8750
Pos Pred Value : 0.8929
Neg Pred Value : 0.9130
Prevalence : 0.5294
Detection Rate : 0.4902
Detection Prevalence : 0.5490
Balanced Accuracy : 0.9005
'Positive' Class : M
library(pROC)
rfTuneProbs <- predict(fitTune,
newdata = testing,
type="prob")
rocCurve <- roc(response = testing$Class,
predictor = rfTuneProbs[, "M"],
levels = rev(levels(testing$Class)))
Jeśli algorytm na to pozwala, oprócz predykcji w postaci klas, można również uzyskać wartości prawdopodobieństw wskazania każdej z klas.
plot(rocCurve)
trainControl
parametr method = "none"
search = "random"
w trainControl
i tuneLength
w train
library(modeldata)
data(mlc_churn)
churnData <- data.frame(mlc_churn)