/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;

/**
 * Online bagging ensemble in which every member is paired with its own
 * {@link ADWIN} instance that is fed the member's 0/1 prediction error.
 *
 * <p>For each training instance, every member is trained with a weight drawn
 * from {@code MiscUtils.poisson(1.0, classifierRandom)}. When at least one
 * detector reports a change together with an increased error estimate, the
 * member with the highest current error estimate is reset and receives a
 * fresh detector.
 */
public class OzaBagAdwin extends AbstractClassifier {
    private static final long serialVersionUID = 1L;

    /** Base learner that is copied once per ensemble member. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    /** Number of members in the ensemble (default 10, minimum 1). */
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);

    /** Ensemble members; {@code null} until {@link #resetLearningImpl()} assigns it. */
    protected Classifier[] ensemble;

    /** One detector per ensemble member, index-aligned with {@link #ensemble}. */
    protected ADWIN[] ADError;

    /** Returns the short, human-readable description of this learner. */
    @Override
    public String getPurposeString() {
        return "Bagging for evolving data streams using ADWIN.";
    }

    /**
     * Rebuilds the ensemble from the configured base learner and gives every
     * member a new {@link ADWIN} detector.
     */
    @Override
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        Classifier baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();

        this.ensemble = new Classifier[size];
        for (int i = 0; i < size; i++) {
            this.ensemble[i] = baseLearner.copy();
        }
        this.ADError = new ADWIN[size];
        for (int i = 0; i < size; i++) {
            this.ADError[i] = new ADWIN();
        }
    }

    /**
     * Trains every member on a Poisson-weighted copy of {@code inst}, feeds
     * each member's 0/1 error on {@code inst} into its detector and, if any
     * detector flagged a change with a rising error estimate, resets the
     * member whose error estimate is currently the highest.
     *
     * @param inst the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        boolean changeDetected = false;
        for (int i = 0; i < this.ensemble.length; i++) {
            int k = MiscUtils.poisson(1.0, this.classifierRandom);
            if (k > 0) {
                // Train on a copy so the caller's instance keeps its original weight.
                Instance weightedInst = inst.copy();
                weightedInst.setWeight(inst.weight() * (double) k);
                this.ensemble[i].trainOnInstance(weightedInst);
            }

            boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(inst);
            double previousEstimate = this.ADError[i].getEstimation();
            // setInput must run for every member: it updates the detector.
            // Its boolean result is treated here as the "change" signal.
            boolean detectorFired = this.ADError[i].setInput(correctlyClassifies ? 0.0 : 1.0);
            if (detectorFired && this.ADError[i].getEstimation() > previousEstimate) {
                changeDetected = true;
            }
        }

        if (changeDetected) {
            this.resetWorstMember();
        }
    }

    /**
     * Resets the member with the largest strictly positive error estimate and
     * replaces its detector; does nothing when no estimate exceeds zero.
     */
    private void resetWorstMember() {
        double max = 0.0;
        int imax = -1;
        for (int i = 0; i < this.ensemble.length; i++) {
            double estimate = this.ADError[i].getEstimation();
            if (max < estimate) {
                max = estimate;
                imax = i;
            }
        }
        if (imax != -1) {
            this.ensemble[imax].resetLearning();
            this.ADError[imax] = new ADWIN();
        }
    }

    /**
     * Sums the normalized votes of all members. Members whose votes do not
     * sum to a positive value are skipped.
     *
     * @param inst the instance to classify
     * @return the backing array of the combined vote vector
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        for (Classifier member : this.ensemble) {
            DoubleVector vote = new DoubleVector(member.getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    /** This learner draws random Poisson weights, so it reports itself as randomizable. */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally writes nothing: no textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /**
     * Reports the ensemble size as the single measurement, using 0 when the
     * ensemble has not been created yet.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        double size = this.ensemble != null ? (double) this.ensemble.length : 0.0;
        return new Measurement[]{new Measurement("ensemble size", size)};
    }

    /**
     * Returns a shallow copy of the ensemble array so callers cannot replace
     * slots of the internal array.
     *
     * @return a copy of the members, or an empty array when the ensemble has
     *         not been created yet
     */
    @Override
    public Classifier[] getSubClassifiers() {
        // Mirror the null handling of getModelMeasurementsImpl instead of
        // throwing a NullPointerException before the ensemble exists.
        if (this.ensemble == null) {
            return new Classifier[0];
        }
        return (Classifier[]) this.ensemble.clone();
    }
}

