/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.rules.multilabel.core;

import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.Arrays;
import java.util.LinkedList;
import moa.classifiers.MultiLabelLearner;
import moa.classifiers.rules.core.Utils;
import moa.classifiers.rules.multilabel.attributeclassobservers.AttributeStatisticsObserver;
import moa.classifiers.rules.multilabel.attributeclassobservers.NominalStatisticsObserver;
import moa.classifiers.rules.multilabel.attributeclassobservers.NumericStatisticsObserver;
import moa.classifiers.rules.multilabel.core.AttributeExpansionSuggestion;
import moa.classifiers.rules.multilabel.core.LearningLiteral;
import moa.classifiers.rules.multilabel.core.splitcriteria.MultiLabelSplitCriterion;
import moa.classifiers.rules.multilabel.functions.AMRulesFunction;
import moa.core.AutoExpandVector;
import moa.core.DoubleVector;
import moa.core.ObjectRepository;
import moa.tasks.TaskMonitor;

/**
 * Learning literal for numeric (regression) outputs.
 *
 * <p>For every output it learns, the literal accumulates the weighted statistics
 * {@code [sum of weights, weighted sum of y, weighted sum of y^2]}. It feeds per-attribute
 * statistics observers and uses a Hoeffding-bound test on the observers' best split
 * suggestions to decide whether to expand.
 */
public class LearningLiteralRegression extends LearningLiteral {
    private static final long serialVersionUID = 1L;

    /** Standard deviations at or below this value are treated as zero when normalizing. */
    private static final double MIN_STANDARD_DEVIATION = 1.0E-7;

    /**
     * Creates a literal without an explicit output set. If {@code outputsToLearn} is still
     * unset when the first instance is trained on, every output target is learned.
     */
    public LearningLiteralRegression() {
    }

    /**
     * Creates a literal that learns only the given outputs.
     *
     * @param outputsToLearn indices of the output attributes to learn, forwarded to
     *                       {@link LearningLiteral}
     */
    public LearningLiteralRegression(int[] outputsToLearn) {
        super(outputsToLearn);
    }

    /**
     * Computes, for each learned output, the absolute difference between the predicted and the
     * true value after both have been standardized with this literal's running statistics.
     *
     * @param prediction the learner's prediction for {@code instance}; vote {@code 0} of each
     *                   learned output is taken as the predicted value
     * @param instance   the instance holding the true output values
     * @return one absolute normalized error per entry of {@code outputsToLearn}
     */
    @Override
    protected double[] getNormalizedErrors(Prediction prediction, Instance instance) {
        double[] errors = new double[this.outputsToLearn.length];
        for (int i = 0; i < this.outputsToLearn.length; ++i) {
            int output = this.outputsToLearn[i];
            double predY = this.normalizeOutputValue(i, prediction.getVote(output, 0));
            double trueY = this.normalizeOutputValue(i, instance.valueOutputAttribute(output));
            errors[i] = Math.abs(predY - trueY);
        }
        return errors;
    }

    /**
     * Standardizes {@code value} as {@code (value - mean) / sd} using the statistics of one
     * learned output.
     *
     * @param outputToLearnIndex position in {@code outputsToLearn} (not the raw output index)
     * @param value              the raw output value to normalize
     * @return the standardized value, or {@code 0.0} when the standard deviation (as computed
     *         by {@link Utils#computeSD}) is not above {@link #MIN_STANDARD_DEVIATION} or is NaN
     */
    private double normalizeOutputValue(int outputToLearnIndex, double value) {
        DoubleVector stats = this.literalStatistics[outputToLearnIndex];
        double weight = stats.getValue(0);
        double sum = stats.getValue(1);
        double squaredSum = stats.getValue(2);
        double sdY = Utils.computeSD(squaredSum, sum, weight);
        if (sdY > MIN_STANDARD_DEVIATION) {
            double meanY = sum / weight;
            return (value - meanY) / sdY;
        }
        return 0.0;
    }

    /** Intentionally empty: this literal needs no preparation step. */
    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    }

    /**
     * Evaluates the current split candidates. If the Hoeffding-bound test accepts the best one,
     * this method creates the expanded and the other-branch child literals.
     *
     * <p>A split is accepted when the merit gap between the best and the second-best candidate
     * exceeds the Hoeffding bound, or when the bound has fallen below {@code tieThreshold}. When
     * at least one candidate exists, {@code bestSuggestion} is set to the highest-merit one. With
     * fewer than two candidates no split is made.
     *
     * @param splitConfidence confidence parameter passed to the Hoeffding bound computation
     * @param tieThreshold    bound below which the best candidate is accepted as a tie-break
     * @return {@code true} if a split was accepted and the child literals were created
     */
    @Override
    public boolean tryToExpand(double splitConfidence, double tieThreshold) {
        AttributeExpansionSuggestion[] bestSplitSuggestions = this.getBestSplitSuggestions(this.splitCriterion);
        if (bestSplitSuggestions.length == 0) {
            // No observer produced a candidate yet; indexing the last element would throw.
            return false;
        }
        Arrays.sort(bestSplitSuggestions);
        int last = bestSplitSuggestions.length - 1;
        if (last == 0) {
            // A single candidate has no runner-up to be compared against, so never split.
            this.bestSuggestion = bestSplitSuggestions[last];
            return false;
        }

        double hoeffdingBound = LearningLiteralRegression.computeHoeffdingBound(
                this.splitCriterion.getRangeOfMerit(this.literalStatistics), splitConfidence, this.weightSeen);
        AttributeExpansionSuggestion best = bestSplitSuggestions[last];
        AttributeExpansionSuggestion secondBest = bestSplitSuggestions[last - 1];
        this.bestSuggestion = best;
        boolean shouldSplit = best.merit - secondBest.merit > hoeffdingBound || hoeffdingBound < tieThreshold;
        if (!shouldSplit) {
            return false;
        }

        DoubleVector[][] resultingStatistics = best.getResultingNodeStatistics();
        double[] branchMerits = this.splitCriterion.getBranchesSplitMerits(resultingStatistics);
        DoubleVector[] newLiteralStatistics;
        if (branchMerits[1] > branchMerits[0]) {
            // Keep the better branch: flip the predicate so that it selects branch 1.
            best.getPredicate().negateCondition();
            newLiteralStatistics = this.getBranchStatistics(resultingStatistics, 1);
        } else {
            newLiteralStatistics = this.getBranchStatistics(resultingStatistics, 0);
        }
        int[] newOutputs = this.outputSelector.getNextOutputIndices(
                newLiteralStatistics, this.literalStatistics, this.outputsToLearn);
        if (this.learner instanceof AMRulesFunction) {
            ((AMRulesFunction) this.learner).resetWithMemory();
        }
        // Both children start from a copy of the (possibly reset) current learner.
        this.expandedLearningLiteral = new LearningLiteralRegression(newOutputs);
        this.expandedLearningLiteral.setLearner((MultiLabelLearner) this.learner.copy());
        this.otherBranchLearningLiteral = new LearningLiteralRegression(newOutputs);
        this.otherBranchLearningLiteral.setLearner((MultiLabelLearner) this.learner.copy());
        return true;
    }

    /**
     * Extracts the statistics of one branch for every learned output.
     *
     * @param resultingStatistics per-output statistics of the split, indexed {@code [output][branch]}
     * @param indexBranch         the branch to select
     * @return one statistics vector per learned output for the selected branch
     */
    private DoubleVector[] getBranchStatistics(DoubleVector[][] resultingStatistics, int indexBranch) {
        DoubleVector[] selBranchStats = new DoubleVector[resultingStatistics.length];
        for (int i = 0; i < resultingStatistics.length; ++i) {
            selBranchStats[i] = resultingStatistics[i][indexBranch];
        }
        return selBranchStats;
    }

    /**
     * Collects the best split suggestion of every non-null attribute observer.
     *
     * @param criterion split criterion used by the observers to evaluate candidates
     * @return the non-null suggestions, in observer order; empty if no observers exist yet
     */
    private AttributeExpansionSuggestion[] getBestSplitSuggestions(MultiLabelSplitCriterion criterion) {
        if (this.attributeObservers == null) {
            // Nothing has been trained yet, so there are no candidates.
            return new AttributeExpansionSuggestion[0];
        }
        LinkedList<AttributeExpansionSuggestion> bestSuggestions = new LinkedList<AttributeExpansionSuggestion>();
        for (int i = 0; i < this.attributeObservers.size(); ++i) {
            AttributeStatisticsObserver obs = (AttributeStatisticsObserver) this.attributeObservers.get(i);
            if (obs == null) {
                continue;
            }
            // NOTE(review): i is the observer slot, which trainOnInstance fills by position among
            // mask-selected attributes; it equals the input-attribute index only when the mask
            // selects every attribute. Confirm how LearningLiteral maps this index.
            AttributeExpansionSuggestion suggestion = obs.getBestEvaluatedSplitSuggestion(criterion, this.literalStatistics, i);
            if (suggestion != null) {
                bestSuggestions.add(suggestion);
            }
        }
        return bestSuggestions.toArray(new AttributeExpansionSuggestion[bestSuggestions.size()]);
    }

    /**
     * Updates the literal with one training instance.
     *
     * <p>This method lazily initializes whatever is still unset: the attribute mask, the outputs
     * to learn (all outputs) and the per-output statistics. It then:
     * <ol>
     *   <li>adds the instance's weighted output statistics;</li>
     *   <li>updates the observers of the attributes selected by the mask;</li>
     *   <li>records the learner's prediction error before training;</li>
     *   <li>trains the learner.</li>
     * </ol>
     *
     * @param instance the training instance
     */
    @Override
    public void trainOnInstance(MultiLabelInstance instance) {
        if (this.attributesMask == null) {
            this.initializeAttibutesMask(instance);
        }
        if (!this.hasStarted) {
            if (this.outputsToLearn == null) {
                int numOutputs = instance.numberOutputTargets();
                this.outputsToLearn = new int[numOutputs];
                for (int o = 0; o < numOutputs; ++o) {
                    this.outputsToLearn[o] = o;
                }
            }
            this.literalStatistics = new DoubleVector[this.outputsToLearn.length];
            for (int o = 0; o < this.outputsToLearn.length; ++o) {
                this.literalStatistics[o] = new DoubleVector(new double[3]);
            }
            this.hasStarted = true;
        }

        // Per-output statistics contributed by this instance: [weight, weight*y, weight*y^2].
        double weight = instance.weight();
        DoubleVector[] exampleStatistics = new DoubleVector[this.outputsToLearn.length];
        for (int o = 0; o < this.outputsToLearn.length; ++o) {
            double target = instance.valueOutputAttribute(this.outputsToLearn[o]);
            exampleStatistics[o] = new DoubleVector(new double[]{weight, weight * target, weight * target * target});
            this.literalStatistics[o].addValues(exampleStatistics[o].getArrayRef());
        }

        if (this.attributeObservers == null) {
            this.attributeObservers = new AutoExpandVector<AttributeStatisticsObserver>();
        }
        int observerIndex = 0;
        for (int i = 0; i < instance.numInputAttributes(); ++i) {
            if (!this.attributesMask[i]) {
                continue;
            }
            AttributeStatisticsObserver obs = (AttributeStatisticsObserver) this.attributeObservers.get(observerIndex);
            if (obs == null) {
                if (instance.attribute(i).isNumeric()) {
                    obs = (NumericStatisticsObserver) this.numericStatisticsObserver.copy();
                } else if (instance.attribute(i).isNominal()) {
                    obs = (NominalStatisticsObserver) this.nominalStatisticsObserver.copy();
                }
                this.attributeObservers.set(observerIndex, obs);
            }
            // Attributes that are neither numeric nor nominal get no observer. Skip them instead
            // of throwing, but still consume the slot so later attributes keep their positions.
            if (obs != null) {
                obs.observeAttribute(instance.valueInputAttribute(i), exampleStatistics);
            }
            ++observerIndex;
        }

        // Measure the error of the learner's prediction before training on this instance.
        Prediction prediction = this.learner.getPredictionForInstance(instance);
        if (prediction != null) {
            this.errorMeasurer.addPrediction(prediction, instance);
        }
        this.learner.trainOnInstance(instance);
        this.weightSeen += instance.weight();
    }
}

