/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.bayes;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.Arrays;
import moa.classifiers.AbstractClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.core.Utils;

/**
 * Incrementally trained multinomial Naive Bayes classifier.
 *
 * <p>For every class the model keeps:
 * <ul>
 *   <li>{@link #m_probOfClass}: a weighted instance count, seeded with the Laplace correction;</li>
 *   <li>{@link #m_classTotals}: the Laplace correction times the number of attributes plus the
 *       weighted sum of {@link #totalSize(Instance)} over the class's training instances; this is
 *       the normaliser for the per-attribute estimates;</li>
 *   <li>{@link #m_wordTotalForClass}: a sparse vector of weighted per-attribute value sums, seeded
 *       with the Laplace correction the first time an attribute is seen for the class.</li>
 * </ul>
 * Prediction combines these in log space and converts the scores with
 * {@link Utils#logs2probs(double[])}.
 */
public class NaiveBayesMultinomial
extends AbstractClassifier {
    /** Additive (Laplace) smoothing factor applied to class counts and attribute totals. */
    public FloatOption laplaceCorrectionOption = new FloatOption("laplaceCorrection", 'l', "Laplace correction factor.", 1.0, 0.0, 2.147483647E9);
    private static final long serialVersionUID = -7204398796974263187L;
    /** Per-class normaliser for attribute estimates (see class documentation). */
    protected double[] m_classTotals;
    /**
     * Header of the training stream, captured from the first training instance after a reset.
     * Used only to name classes and attributes in {@link #getModelDescription}; may be null if the
     * training instances carried no dataset reference.
     */
    protected Instances m_headerInfo;
    /** Number of classes, taken from the first training instance after a reset. */
    protected int m_numClasses;
    /** Laplace-seeded weighted number of training instances per class (not normalised). */
    protected double[] m_probOfClass;
    /** Per-class sparse vectors of weighted attribute-value totals, indexed by attribute index. */
    protected DoubleVector[] m_wordTotalForClass;
    /**
     * True while the model must be (re)built from the next training instance. Starts true so that
     * training or querying before an explicit reset does not touch uninitialised arrays.
     */
    protected boolean reset = true;

    @Override
    public String getPurposeString() {
        return "Multinomial Naive Bayes classifier: performs classic bayesian prediction while making naive assumption that all inputs are independent.";
    }

    /** Marks the model for lazy re-initialisation on the next training instance. */
    @Override
    public void resetLearningImpl() {
        this.reset = true;
    }

    /**
     * Updates the per-class statistics with one weighted training instance.
     *
     * @param inst the training instance; its class value is used as an index into the per-class
     *             arrays
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (this.reset) {
            this.initialiseModel(inst);
        }
        int classIndex = inst.classIndex();
        int classValue = (int)inst.classValue();
        double w = inst.weight();
        this.m_probOfClass[classValue] += w;
        this.m_classTotals[classValue] += w * this.totalSize(inst);
        DoubleVector wordTotals = this.m_wordTotalForClass[classValue];
        double laplace = this.laplaceCorrectionOption.getValue();
        for (int i = 0; i < inst.numValues(); ++i) {
            int index = inst.index(i);
            if (index == classIndex || inst.isMissing(i)) {
                continue;
            }
            // The first time an attribute is seen for this class, seed it with the Laplace
            // correction so it is smoothed like attributes that were never seen at all.
            double laplaceCorrection = wordTotals.getValue(index) == 0.0 ? laplace : 0.0;
            wordTotals.addToValue(index, w * inst.valueSparse(i) + laplaceCorrection);
        }
    }

    /**
     * Allocates and Laplace-seeds the per-class statistics from the shape of {@code inst}, and
     * remembers its header for {@link #getModelDescription}.
     */
    private void initialiseModel(Instance inst) {
        this.m_numClasses = inst.numClasses();
        this.m_headerInfo = inst.dataset();
        double laplace = this.laplaceCorrectionOption.getValue();
        int numAttributes = inst.numAttributes();
        this.m_probOfClass = new double[this.m_numClasses];
        Arrays.fill(this.m_probOfClass, laplace);
        // NOTE(review): numAttributes includes the class attribute, so the smoothing mass is
        // one attribute larger than the vocabulary; kept as-is to preserve predictions.
        this.m_classTotals = new double[this.m_numClasses];
        Arrays.fill(this.m_classTotals, laplace * (double)numAttributes);
        this.m_wordTotalForClass = new DoubleVector[this.m_numClasses];
        for (int c = 0; c < this.m_numClasses; ++c) {
            this.m_wordTotalForClass[c] = new DoubleVector();
        }
        this.reset = false;
    }

    /**
     * Scores every class for {@code instance}.
     *
     * <p>The log score of class c is
     * {@code log(m_probOfClass[c]) - totalSize * log(m_classTotals[c]) + sum_w count_w * log(total_{c,w})},
     * where an attribute total of zero is replaced by the Laplace correction. The scores are then
     * passed through {@link Utils#logs2probs(double[])}.
     *
     * @param instance the instance to classify
     * @return one entry per class; a zero-filled array of length 2 if the model is untrained
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        if (this.reset) {
            return new double[2];
        }
        double laplace = this.laplaceCorrectionOption.getValue();
        double[] probOfClassGivenDoc = new double[this.m_numClasses];
        double totalSize = this.totalSize(instance);
        for (int c = 0; c < this.m_numClasses; ++c) {
            probOfClassGivenDoc[c] = Math.log(this.m_probOfClass[c]) - totalSize * Math.log(this.m_classTotals[c]);
        }
        int classIndex = instance.classIndex();
        for (int i = 0; i < instance.numValues(); ++i) {
            int index = instance.index(i);
            if (index == classIndex || instance.isMissing(i)) {
                continue;
            }
            double wordCount = instance.valueSparse(i);
            for (int c = 0; c < this.m_numClasses; ++c) {
                double value = this.m_wordTotalForClass[c].getValue(index);
                probOfClassGivenDoc[c] += wordCount * Math.log(value == 0.0 ? laplace : value);
            }
        }
        return Utils.logs2probs(probOfClassGivenDoc);
    }

    /**
     * Sums the stored values of all non-class, non-missing attributes of {@code instance}.
     * Negative and NaN values are skipped.
     *
     * @param instance the instance whose "document length" is computed
     * @return the sum of the qualifying attribute values (0.0 if there are none)
     */
    public double totalSize(Instance instance) {
        int classIndex = instance.classIndex();
        double total = 0.0;
        for (int i = 0; i < instance.numValues(); ++i) {
            if (instance.index(i) == classIndex || instance.isMissing(i)) {
                continue;
            }
            double count = instance.valueSparse(i);
            // Positive test on purpose: NaN fails it and is skipped along with negative values.
            if (count >= 0.0) {
                total += count;
            }
        }
        return total;
    }

    /** This classifier reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Appends a textual dump of the learned statistics to {@code result}.
     *
     * <p>If the model is untrained, or no header was available to name classes and attributes,
     * a short explanatory line is printed instead of the tables.
     */
    @Override
    public void getModelDescription(StringBuilder result, int indent) {
        StringUtils.appendIndented(result, indent, "xxx MNB1 xxx\n\n");
        if (this.reset || this.m_wordTotalForClass == null) {
            result.append("Model has not been trained yet.\n");
            StringUtils.appendNewline(result);
            return;
        }
        if (this.m_headerInfo == null) {
            result.append("No header information available to describe the model.\n");
            StringUtils.appendNewline(result);
            return;
        }
        result.append("The independent probability of a class\n");
        result.append("--------------------------------------\n");
        // NOTE(review): these are Laplace-seeded weighted counts, not normalised probabilities.
        for (int c = 0; c < this.m_numClasses; ++c) {
            result.append(this.m_headerInfo.classAttribute().value(c)).append("\t").append(Double.toString(this.m_probOfClass[c])).append("\n");
        }
        result.append("\nThe probability of a word given the class\n");
        result.append("-----------------------------------------\n\t");
        for (int c = 0; c < this.m_numClasses; ++c) {
            result.append(this.m_headerInfo.classAttribute().value(c)).append("\t");
        }
        result.append("\n");
        double laplace = this.laplaceCorrectionOption.getValue();
        for (int w = 0; w < this.m_headerInfo.numAttributes(); ++w) {
            if (w == this.m_headerInfo.classIndex()) {
                continue;
            }
            result.append(this.m_headerInfo.attribute(w).name()).append("\t");
            for (int c = 0; c < this.m_numClasses; ++c) {
                double value = this.m_wordTotalForClass[c].getValue(w);
                if (value == 0.0) {
                    value = laplace;
                }
                result.append(value / this.m_classTotals[c]).append("\t");
            }
            result.append("\n");
        }
        StringUtils.appendNewline(result);
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}

