/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.multitarget;

import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.Classifier;
import moa.classifiers.MultiTargetRegressor;
import moa.core.DoubleVector;
import moa.core.FastVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.options.ClassOption;
import moa.streams.InstanceStream;

/**
 * Multi-target regressor that trains one independent copy of a base learner
 * per output target.
 *
 * <p>Each incoming multi-label instance is projected into one single-target
 * instance per output (all input attributes plus that one output as the class
 * attribute), and the projection is handed to the learner responsible for that
 * output.
 */
public class BasicMultiTargetRegressor
extends AbstractMultiLabelLearner
implements MultiTargetRegressor {
    private static final long serialVersionUID = 1L;

    /** Option selecting the base learner that is copied once per output target. */
    public ClassOption baseLearnerOption;

    /** One learner per output target; {@code null} until the first training instance arrives. */
    protected Classifier[] ensemble;

    /** Whether the ensemble has been built since the last reset. */
    protected boolean hasStarted = false;

    /** Lazily built single-target header for each output; indexed like {@link #ensemble}. */
    protected InstancesHeader[] header;

    /** Creates the learner and registers its options via {@link #init()}. */
    public BasicMultiTargetRegressor() {
        this.init();
    }

    /** Creates the base-learner option; the default value is {@code rules.AMRulesRegressor}. */
    protected void init() {
        this.baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "rules.AMRulesRegressor");
    }

    /**
     * Discards all learned state.
     *
     * <p>The cached headers are dropped together with the ensemble: otherwise
     * the learners created after a reset would never get their model context
     * set (the header for their index would already exist), and a stream with
     * a different number of outputs would index past the old header array.
     */
    @Override
    public void resetLearningImpl() {
        this.hasStarted = false;
        this.ensemble = null;
        this.header = null;
    }

    /**
     * Trains every per-output learner on its projection of {@code instance}.
     * On the first call after a reset the ensemble is built with one copy of
     * the configured base learner per output target.
     *
     * @param instance the multi-target training instance
     */
    @Override
    public void trainOnInstanceImpl(MultiLabelInstance instance) {
        if (!this.hasStarted) {
            this.ensemble = new Classifier[instance.numberOutputTargets()];
            Classifier baseLearner = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
            baseLearner.resetLearning();
            for (int i = 0; i < this.ensemble.length; ++i) {
                this.ensemble[i] = baseLearner.copy();
            }
            this.hasStarted = true;
        }
        for (int i = 0; i < this.ensemble.length; ++i) {
            Instance weightedInst = this.transformInstance(instance, i);
            this.ensemble[i].trainOnInstance(weightedInst);
        }
    }

    /**
     * Projects a multi-target instance onto a single output target.
     *
     * <p>The result carries all input attribute values of {@code inst} plus the
     * value of output {@code outputIndex} as its class value, with weight 1.0.
     * The header for the output is built on first use and then also installed
     * as the model context of the corresponding learner.
     *
     * @param inst        the multi-target instance to project
     * @param outputIndex index of the output target to keep as the class
     * @return a dense single-target instance bound to the header of that output
     */
    protected Instance transformInstance(MultiLabelInstance inst, int outputIndex) {
        if (this.header == null) {
            this.header = new InstancesHeader[this.ensemble.length];
        }
        if (this.header[outputIndex] == null) {
            FastVector<Attribute> attributes = new FastVector<Attribute>();
            for (int attributeIndex = 0; attributeIndex < inst.numInputAttributes(); ++attributeIndex) {
                attributes.addElement(inst.inputAttribute(attributeIndex));
            }
            // The selected output becomes the last attribute, i.e. the class.
            attributes.addElement(inst.outputAttribute(outputIndex));
            this.header[outputIndex] = new InstancesHeader(new Instances(this.getCLICreationString(InstanceStream.class), attributes, 0));
            this.header[outputIndex].setClassIndex(attributes.size() - 1);
            this.ensemble[outputIndex].setModelContext(this.header[outputIndex]);
        }
        int numAttributes = this.header[outputIndex].numInputAttributes();
        // One extra slot at the end holds the class value set below.
        double[] attVals = new double[numAttributes + 1];
        for (int attributeIndex = 0; attributeIndex < numAttributes; ++attributeIndex) {
            attVals[attributeIndex] = inst.valueInputAttribute(attributeIndex);
        }
        DenseInstance instance = new DenseInstance(1.0, attVals);
        instance.setDataset(this.header[outputIndex]);
        instance.setClassValue(inst.valueOutputAttribute(outputIndex));
        return instance;
    }

    /** {@inheritDoc} Always {@code true} for this learner. */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Reports the measurements of the learner for the first output target.
     *
     * @return that learner's measurements, or {@code null} when no ensemble
     *         exists yet (before training or after a reset) or it is empty
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        if (this.ensemble != null && this.ensemble.length > 0) {
            return this.ensemble[0].getModelMeasurements();
        }
        return null;
    }

    /**
     * Appends the description of every per-output model to {@code out}.
     * Nothing is written when there is no ensemble yet or when the learners
     * are not {@link AbstractClassifier} instances.
     *
     * @param out    buffer receiving the description
     * @param indent current indentation level
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        if (this.ensemble == null || this.ensemble.length == 0) {
            return;
        }
        // All members are copies of the same base learner, so checking the first suffices.
        if (this.ensemble[0] instanceof AbstractClassifier) {
            for (int i = 0; i < this.ensemble.length; ++i) {
                StringUtils.appendIndented(out, indent + 1, "Model output attribute #" + i + "\n");
                ((AbstractClassifier)this.ensemble[i]).getModelDescription(out, indent + 1);
            }
        }
    }

    /**
     * Predicts every output target of {@code instance}.
     *
     * @param instance the instance to predict
     * @return a prediction holding, for each output, the first vote of its
     *         learner at vote index 0; {@code null} if no training instance
     *         has been seen since the last reset
     */
    @Override
    public Prediction getPredictionForInstance(MultiLabelInstance instance) {
        if (!this.hasStarted) {
            return null;
        }
        MultiLabelPrediction prediction = new MultiLabelPrediction(this.ensemble.length);
        for (int i = 0; i < this.ensemble.length; ++i) {
            double[] votes = this.ensemble[i].getVotesForInstance(this.transformInstance(instance, i));
            // A learner may return no votes; leave that output unset rather than
            // failing the whole prediction with an index error.
            if (votes != null && votes.length > 0) {
                prediction.setVote(i, 0, votes[0]);
            }
        }
        return prediction;
    }
}

