/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.drift;

import com.yahoo.labs.samoa.instances.Instance;
import java.util.LinkedList;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.classifiers.meta.WEKAClassifier;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Meta-classifier that wraps a base learner and monitors its 0/1 prediction
 * error with a {@link ChangeDetector}.
 *
 * <p>While the detector reports a warning, a background ("new") classifier is
 * trained alongside the active one. When the detector reports a change, the
 * background classifier replaces the active classifier and a fresh background
 * classifier is created from {@link #baseLearnerOption}.
 */
public class DriftDetectionMethodClassifier
extends AbstractClassifier {
    private static final long serialVersionUID = 1L;

    /** Base learner that is copied for both the active and the background model. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "bayes.NaiveBayes");

    /** Change detector fed with 0.0 (correct) / 1.0 (wrong) per training instance. */
    public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'd', "Drift detection method to use.", ChangeDetector.class, "DDM");

    /** Classifier currently used for predictions. */
    protected Classifier classifier;

    /** Background classifier trained only while the detector is in its warning zone. */
    protected Classifier newclassifier;

    /** Detector monitoring the active classifier's error stream. */
    protected ChangeDetector driftDetectionMethod;

    /**
     * True once the detector has been in control since the last warning, meaning
     * the background classifier must be reset when the next warning starts.
     */
    protected boolean newClassifierReset;

    /** Detector state observed for the most recent training instance. */
    protected int ddmLevel;

    public static final int DDM_INCONTROL_LEVEL = 0;
    public static final int DDM_WARNING_LEVEL = 1;
    public static final int DDM_OUTCONTROL_LEVEL = 2;

    /** Number of model replacements since the last call to {@link #getModelMeasurementsImpl()}. */
    protected int changeDetected = 0;

    /** Number of warning episodes started since the last call to {@link #getModelMeasurementsImpl()}. */
    protected int warningDetected = 0;

    @Override
    public String getPurposeString() {
        return "Classifier that replaces the current classifier with a new one when a change is detected in accuracy.";
    }

    /**
     * Creates fresh active and background classifiers and a fresh detector, and
     * clears the detection state and counters so no statistics from before the
     * reset leak into later measurements.
     */
    @Override
    public void resetLearningImpl() {
        this.classifier = ((Classifier)this.getPreparedClassOption(this.baseLearnerOption)).copy();
        this.newclassifier = this.classifier.copy();
        this.classifier.resetLearning();
        this.newclassifier.resetLearning();
        this.driftDetectionMethod = ((ChangeDetector)this.getPreparedClassOption(this.driftDetectionMethodOption)).copy();
        // NOTE(review): starting at false means the very first warning episode is
        // not counted in warningDetected (the background model is already fresh,
        // so no reset is needed). Kept as-is to preserve existing statistics.
        this.newClassifierReset = false;
        this.ddmLevel = DDM_INCONTROL_LEVEL;
        this.changeDetected = 0;
        this.warningDetected = 0;
    }

    /**
     * Feeds the active classifier's prediction error on {@code inst} to the
     * detector, then reacts to the detector state:
     * <ul>
     *   <li>warning: train the background classifier (resetting it first if a
     *       new warning episode has started);</li>
     *   <li>change: promote the background classifier to active and start a new
     *       background classifier;</li>
     *   <li>in control: mark the background classifier for reset at the next
     *       warning.</li>
     * </ul>
     * In every case the (possibly new) active classifier is then trained on
     * {@code inst}.
     *
     * @param inst the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        int trueClass = (int)inst.classValue();
        boolean prediction = Utils.maxIndex(this.classifier.getVotesForInstance(inst)) == trueClass;
        this.driftDetectionMethod.input(prediction ? 0.0 : 1.0);

        // A detected change must take precedence over a warning. Otherwise a
        // detector that raises both flags would never trigger a model swap.
        if (this.driftDetectionMethod.getChange()) {
            this.ddmLevel = DDM_OUTCONTROL_LEVEL;
        } else if (this.driftDetectionMethod.getWarningZone()) {
            this.ddmLevel = DDM_WARNING_LEVEL;
        } else {
            this.ddmLevel = DDM_INCONTROL_LEVEL;
        }

        switch (this.ddmLevel) {
            case DDM_WARNING_LEVEL: {
                if (this.newClassifierReset) {
                    // First warning after an in-control period: start a new episode.
                    ++this.warningDetected;
                    this.newclassifier.resetLearning();
                    this.newClassifierReset = false;
                }
                this.newclassifier.trainOnInstance(inst);
                break;
            }
            case DDM_OUTCONTROL_LEVEL: {
                ++this.changeDetected;
                this.classifier = this.newclassifier;
                if (this.classifier instanceof WEKAClassifier) {
                    // WEKA wrappers buffer instances and must be built explicitly.
                    ((WEKAClassifier)this.classifier).buildClassifier();
                }
                this.newclassifier = ((Classifier)this.getPreparedClassOption(this.baseLearnerOption)).copy();
                this.newclassifier.resetLearning();
                break;
            }
            case DDM_INCONTROL_LEVEL: {
                this.newClassifierReset = true;
                break;
            }
            default: {
                break;
            }
        }
        this.classifier.trainOnInstance(inst);
    }

    /** Returns the active classifier's class votes for {@code inst}. */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        return this.classifier.getVotesForInstance(inst);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Delegates to the active classifier's model description.
     * Assumes the base learner extends {@link AbstractClassifier}.
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        ((AbstractClassifier)this.classifier).getModelDescription(out, indent);
    }

    /**
     * Reports the number of changes and warning episodes since the previous
     * call, followed by the active classifier's own measurements (if any).
     * Resets both counters as a side effect, so each call reports a delta.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        LinkedList<Measurement> measurementList = new LinkedList<Measurement>();
        measurementList.add(new Measurement("Change detected", this.changeDetected));
        measurementList.add(new Measurement("Warning detected", this.warningDetected));
        Measurement[] modelMeasurements = ((AbstractClassifier)this.classifier).getModelMeasurements();
        if (modelMeasurements != null) {
            for (Measurement measurement : modelMeasurements) {
                measurementList.add(measurement);
            }
        }
        this.changeDetected = 0;
        this.warningDetected = 0;
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }
}

