/*
 * Decompiled with CFR 0.152.
 */
package moa.evaluation;

import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Prediction;
import java.io.Serializable;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.evaluation.LearningPerformanceEvaluator;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;

/**
 * Test-then-train classification evaluator that accumulates accuracy together with
 * three chance-corrected statistics, all of the form {@code (p0 - pc) / (1 - pc)}:
 * <ul>
 *   <li>Kappa: {@code pc} is the agreement expected from the predicted and true class
 *       marginals.</li>
 *   <li>Kappa Temporal: {@code pc} is the accuracy of a "no-change" classifier that
 *       predicts the previously seen true class.</li>
 *   <li>Kappa M: {@code pc} is the accuracy of a classifier that predicts the class
 *       observed most often so far.</li>
 * </ul>
 *
 * <p>Every statistic is backed by an {@link Estimator} obtained from
 * {@link #newEstimator()}, so subclasses can change how values are aggregated without
 * touching the bookkeeping in {@link #addResult(Example, double[])}.
 */
public class BasicClassificationPerformanceEvaluator
extends AbstractOptionHandler
implements LearningPerformanceEvaluator<Example<Instance>> {
    private static final long serialVersionUID = 1L;
    /** Receives {@code weight} for each correct prediction and {@code 0.0} otherwise. */
    protected Estimator weightCorrect;
    /** Entry {@code i} receives {@code weight} when the TRUE class is {@code i}. */
    protected Estimator[] columnKappa;
    /** Entry {@code i} receives {@code weight} when the PREDICTED class is {@code i}. */
    protected Estimator[] rowKappa;
    /** Number of classes the per-class estimators were sized for. */
    protected int numClasses;
    /** Accuracy of the baseline that predicts the previously seen true class. */
    private Estimator weightCorrectNoChangeClassifier;
    /** Accuracy of the baseline that predicts the most frequent true class so far. */
    private Estimator weightMajorityClassifier;
    /** True class of the last counted instance (the no-change baseline's prediction). */
    private int lastSeenClass;
    /** Sum of the weights of all counted instances. */
    private double totalWeightObserved;

    /**
     * Clears all statistics, keeping the current number of classes.
     */
    @Override
    public void reset() {
        this.reset(this.numClasses);
    }

    /**
     * Clears all statistics and re-sizes the per-class estimators.
     *
     * @param numClasses number of classes to allocate per-class estimators for
     */
    public void reset(int numClasses) {
        this.numClasses = numClasses;
        this.rowKappa = new Estimator[numClasses];
        this.columnKappa = new Estimator[numClasses];
        for (int i = 0; i < this.numClasses; ++i) {
            this.rowKappa[i] = this.newEstimator();
            this.columnKappa[i] = this.newEstimator();
        }
        this.weightCorrect = this.newEstimator();
        this.weightCorrectNoChangeClassifier = this.newEstimator();
        this.weightMajorityClassifier = this.newEstimator();
        this.lastSeenClass = 0;
        this.totalWeightObserved = 0.0;
    }

    /**
     * Records one test-then-train result.
     *
     * <p>Instances with a missing class label, or with a weight that is not strictly
     * positive, are ignored. This way the classifier and both baselines are always
     * averaged over exactly the same instances. On the first counted instance the
     * evaluator is reset and sized from {@code inst.dataset().numClasses()}.
     *
     * @param example    the labelled example that was just predicted
     * @param classVotes the classifier's votes; the arg-max is taken as the prediction
     */
    @Override
    public void addResult(Example<Instance> example, double[] classVotes) {
        Instance inst = example.getData();
        double weight = inst.weight();
        // "!(weight > 0.0)" also rejects NaN weights.
        if (inst.classIsMissing() || !(weight > 0.0)) {
            return;
        }
        int trueClass = (int) inst.classValue();
        int predictedClass = Utils.maxIndex(classVotes);
        if (this.totalWeightObserved == 0.0) {
            this.reset(inst.dataset().numClasses());
        }
        // The baselines must predict BEFORE the current label is recorded. Asking for
        // the majority class after updating columnKappa would let it peek at trueClass.
        int majorityClass = this.getMajorityClass();

        this.totalWeightObserved += weight;
        this.weightCorrect.add(predictedClass == trueClass ? weight : 0.0);
        for (int i = 0; i < this.numClasses; ++i) {
            this.rowKappa[i].add(predictedClass == i ? weight : 0.0);
            this.columnKappa[i].add(trueClass == i ? weight : 0.0);
        }
        this.weightCorrectNoChangeClassifier.add(this.lastSeenClass == trueClass ? weight : 0.0);
        this.weightMajorityClassifier.add(majorityClass == trueClass ? weight : 0.0);
        this.lastSeenClass = trueClass;
    }

    /**
     * Returns the class with the largest true-class estimate seen so far.
     *
     * <p>Ties go to the lowest index. Class 0 is returned when no estimate is strictly
     * positive, which includes the NaN estimates produced by empty estimators.
     */
    private int getMajorityClass() {
        int majorityClass = 0;
        double maxProbClass = 0.0;
        for (int i = 0; i < this.numClasses; ++i) {
            double estimate = this.columnKappa[i].estimation();
            if (estimate > maxProbClass) {
                majorityClass = i;
                maxProbClass = estimate;
            }
        }
        return majorityClass;
    }

    /**
     * Returns the current statistics. Percentages are the fractional values times 100.
     */
    @Override
    public Measurement[] getPerformanceMeasurements() {
        return new Measurement[]{
            new Measurement("classified instances", this.getTotalWeightObserved()),
            new Measurement("classifications correct (percent)", this.getFractionCorrectlyClassified() * 100.0),
            new Measurement("Kappa Statistic (percent)", this.getKappaStatistic() * 100.0),
            new Measurement("Kappa Temporal Statistic (percent)", this.getKappaTemporalStatistic() * 100.0),
            new Measurement("Kappa M Statistic (percent)", this.getKappaMStatistic() * 100.0)
        };
    }

    /** Returns the sum of the weights of all counted instances. */
    public double getTotalWeightObserved() {
        return this.totalWeightObserved;
    }

    /**
     * Returns the estimate of the {@code weightCorrect} estimator. With the default
     * {@link BasicEstimator} this is NaN before any instance has been counted.
     */
    public double getFractionCorrectlyClassified() {
        return this.weightCorrect.estimation();
    }

    /** Returns {@code 1 - getFractionCorrectlyClassified()}. */
    public double getFractionIncorrectlyClassified() {
        return 1.0 - this.getFractionCorrectlyClassified();
    }

    /**
     * Returns Cohen's kappa.
     *
     * <p>Chance agreement is {@code sum_i P(pred = i) * P(true = i)}. The result is 0.0
     * when nothing has been observed.
     */
    public double getKappaStatistic() {
        if (this.getTotalWeightObserved() > 0.0) {
            double pc = 0.0;
            for (int i = 0; i < this.numClasses; ++i) {
                pc += this.rowKappa[i].estimation() * this.columnKappa[i].estimation();
            }
            return chanceCorrected(this.getFractionCorrectlyClassified(), pc);
        }
        return 0.0;
    }

    /**
     * Returns kappa relative to the no-change baseline, or 0.0 when nothing has been
     * observed.
     */
    public double getKappaTemporalStatistic() {
        if (this.getTotalWeightObserved() > 0.0) {
            return chanceCorrected(this.getFractionCorrectlyClassified(),
                    this.weightCorrectNoChangeClassifier.estimation());
        }
        return 0.0;
    }

    /**
     * Returns kappa relative to the majority-class baseline, or 0.0 when nothing has
     * been observed.
     */
    private double getKappaMStatistic() {
        if (this.getTotalWeightObserved() > 0.0) {
            return chanceCorrected(this.getFractionCorrectlyClassified(),
                    this.weightMajorityClassifier.estimation());
        }
        return 0.0;
    }

    /**
     * Computes {@code (p0 - pc) / (1 - pc)}.
     *
     * <p>When {@code pc == 1} the result is NaN or infinite (IEEE division by zero),
     * because kappa is undefined if the baseline is already perfect.
     */
    private static double chanceCorrected(double p0, double pc) {
        return (p0 - pc) / (1.0 - pc);
    }

    /** Appends a description of the current measurements to {@code sb}. */
    @Override
    public void getDescription(StringBuilder sb, int indent) {
        Measurement.getMeasurementsDescription(this.getPerformanceMeasurements(), sb, indent);
    }

    /**
     * Intentionally a no-op: this evaluator only consumes {@code double[]} class votes
     * through {@link #addResult(Example, double[])}.
     */
    @Override
    public void addResult(Example<Instance> testInst, Prediction prediction) {
    }

    /** Nothing to prepare: this evaluator declares no options. */
    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    }

    /**
     * Factory for all estimators used by this evaluator. Subclasses may override it to
     * change how values are aggregated.
     */
    protected Estimator newEstimator() {
        return new BasicEstimator();
    }

    /**
     * Arithmetic mean of every value passed to {@link #add(double)}.
     *
     * <p>The count grows by one per call regardless of the value. For non-unit instance
     * weights the result is therefore the mean of the weighted values, not a
     * weight-normalised fraction. Returns NaN ({@code 0.0 / 0.0}) before the first call.
     *
     * <p>NOTE(review): this could be a static nested class. Changing that would alter the
     * serialized form and the {@code outer.new BasicEstimator()} syntax, so it is left
     * as an inner class.
     */
    public class BasicEstimator
    implements Estimator {
        protected double len;
        protected double sum;

        @Override
        public void add(double value) {
            this.sum += value;
            this.len += 1.0;
        }

        @Override
        public double estimation() {
            return this.sum / this.len;
        }
    }

    /** Running aggregate of per-instance values. */
    public static interface Estimator
    extends Serializable {
        /** Adds one observation. */
        public void add(double var1);

        /** Returns the current aggregate of the observations added so far. */
        public double estimation();
    }
}

