/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.rules.functions;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.LinkedList;
import java.util.Random;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.rules.functions.AMRulesRegressorFunction;
import moa.core.DoubleVector;
import moa.core.Measurement;

/**
 * Online linear regression model (a perceptron) used as an AMRules regression function.
 *
 * <p>Every numeric, non-class attribute and the target value are centred with a running weighted
 * mean and, when their running standard deviation exceeds {@code SD_THRESHOLD}, divided by it.
 * The weights (one per numeric attribute plus a trailing bias) are trained with a delta rule.
 * Unless the constant-rate flag is set, the learning rate decays as
 * {@code learningRatio / (1 + instancesSeen * learningRateDecay)}. A fading, weight-averaged
 * absolute error of the predictions made before each update is exposed by {@link #getCurrentError()}.
 */
public class Perceptron
extends AbstractClassifier
implements AMRulesRegressorFunction {
    /** Standard deviations at or below this value are treated as zero: values are centred but not scaled. */
    private static final double SD_THRESHOLD = 1.0E-7;
    private static final long serialVersionUID = 1L;
    public FlagOption constantLearningRatioDecayOption = new FlagOption("learningRatio_Decay_set_constant", 'd', "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
    public FloatOption learningRatioOption = new FloatOption("learningRatio", 'l', "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
    public FloatOption learningRateDecayOption = new FloatOption("learningRateDecay", 'm', " Learning Rate decay to use for training the Perceptron.", 0.001);
    public FloatOption fadingFactorOption = new FloatOption("fadingFactor", 'e', "Fading factor for the Perceptron accumulated error", 0.99, 0.0, 1.0);
    /** Faded sum of instance weights; denominator of {@link #getCurrentError()}. */
    private double nError;
    /** Fading factor applied to the error accumulators; read from the option on the first instance. */
    protected double fadingFactor;
    /** Current learning rate used by {@link #updateWeights(Instance, double)}. */
    protected double learningRatio;
    /** Learning-rate decay coefficient; read from the option on the first instance. */
    protected double learningRateDecay;
    /** Weights for each numeric attribute, followed by the bias term in the last slot. */
    protected double[] weightAttribute;
    /** Per numeric attribute: weighted sum of values. */
    public DoubleVector perceptronattributeStatistics = new DoubleVector();
    /** Per numeric attribute: weighted sum of squared values. */
    public DoubleVector squaredperceptronattributeStatistics = new DoubleVector();
    /** Total instance weight seen; drives the learning-rate decay. */
    protected double perceptronInstancesSeen;
    /** Total instance weight seen; the count used for means and standard deviations. */
    protected double perceptronYSeen;
    /** Faded, weighted sum of absolute prediction errors; numerator of {@link #getCurrentError()}. */
    protected double accumulatedError;
    /** {@code true} until the first training instance sets up indices and weights. */
    protected boolean initialisePerceptron;
    /** Weighted sum of target values. */
    protected double perceptronsumY;
    /** Weighted sum of squared target values. */
    protected double squaredperceptronsumY;
    /** Instance attribute indices of the numeric, non-class attributes used as inputs. */
    protected int[] numericAttributesIndex;

    /** Creates an untrained perceptron; weights are created on the first training instance. */
    public Perceptron() {
        this.initialisePerceptron = true;
    }

    /**
     * Copy constructor. The option objects are shared with {@code p}; weights, running statistics
     * and error accumulators are copied so the two models can then be trained independently.
     *
     * @param p perceptron to copy; it may still be untrained
     */
    public Perceptron(Perceptron p) {
        this.constantLearningRatioDecayOption = p.constantLearningRatioDecayOption;
        this.learningRatioOption = p.learningRatioOption;
        this.learningRateDecayOption = p.learningRateDecayOption;
        this.fadingFactorOption = p.fadingFactorOption;
        this.nError = p.nError;
        // Copy the numerator together with nError; otherwise getCurrentError() on the copy reports 0.
        this.accumulatedError = p.accumulatedError;
        this.fadingFactor = p.fadingFactor;
        this.learningRatio = p.learningRatio;
        this.learningRateDecay = p.learningRateDecay;
        if (p.weightAttribute != null) {
            this.weightAttribute = p.weightAttribute.clone();
        }
        this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
        this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
        this.perceptronInstancesSeen = p.perceptronInstancesSeen;
        this.initialisePerceptron = p.initialisePerceptron;
        this.perceptronsumY = p.perceptronsumY;
        this.squaredperceptronsumY = p.squaredperceptronsumY;
        this.perceptronYSeen = p.perceptronYSeen;
        // An untrained perceptron has no attribute index yet; cloning it unguarded would throw NPE.
        if (p.numericAttributesIndex != null) {
            this.numericAttributesIndex = p.numericAttributesIndex.clone();
        }
    }

    /** Replaces the weight vector (stored by reference; last slot is the bias). */
    public void setWeights(double[] w) {
        this.weightAttribute = w;
    }

    /** Returns the live weight vector (not a copy); the last slot is the bias. */
    public double[] getWeights() {
        return this.weightAttribute;
    }

    /** Returns the total instance weight used for learning-rate decay. */
    public double getInstancesSeen() {
        return this.perceptronInstancesSeen;
    }

    /** Overrides the total instance weight used for learning-rate decay. */
    public void setInstancesSeen(int pInstancesSeen) {
        this.perceptronInstancesSeen = pInstancesSeen;
    }

    /** Forgets everything: weights are re-drawn on the next training instance. */
    @Override
    public void resetLearningImpl() {
        this.initialisePerceptron = true;
        this.reset();
    }

    /**
     * Clears the error accumulators and all running statistics. The weights and the
     * initialisation flag are left untouched.
     */
    public void reset() {
        this.nError = 0.0;
        this.accumulatedError = 0.0;
        this.perceptronInstancesSeen = 0.0;
        this.perceptronattributeStatistics = new DoubleVector();
        this.squaredperceptronattributeStatistics = new DoubleVector();
        this.perceptronsumY = 0.0;
        this.squaredperceptronsumY = 0.0;
        this.perceptronYSeen = 0.0;
    }

    /** Clears only the error accumulators behind {@link #getCurrentError()}. */
    public void resetError() {
        this.nError = 0.0;
        this.accumulatedError = 0.0;
    }

    /**
     * Learns from one instance. The error accumulators are first updated with the prediction
     * made before learning; then the running statistics, the learning rate and the weights
     * are updated.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        double weight = inst.weight();
        this.accumulatedError = Math.abs(this.prediction(inst) - inst.classValue()) * weight + this.fadingFactor * this.accumulatedError;
        this.nError = weight + this.fadingFactor * this.nError;
        if (this.initialisePerceptron) {
            this.initialiseModel(inst);
        }
        this.perceptronInstancesSeen += weight;
        this.perceptronYSeen += weight;
        for (int j = 0; j < this.numericAttributesIndex.length; ++j) {
            int instAttIndex = Perceptron.modelAttIndexToInstanceAttIndex(this.numericAttributesIndex[j], inst);
            double value = inst.value(instAttIndex);
            this.perceptronattributeStatistics.addToValue(j, value * weight);
            this.squaredperceptronattributeStatistics.addToValue(j, value * value * weight);
        }
        double y = inst.classValue();
        this.perceptronsumY += y * weight;
        this.squaredperceptronsumY += y * y * weight;
        if (!this.constantLearningRatioDecayOption.isSet()) {
            this.learningRatio = this.learningRatioOption.getValue() / (1.0 + this.perceptronInstancesSeen * this.learningRateDecay);
        }
        this.updateWeights(inst, this.learningRatio);
    }

    /**
     * One-time setup from the first training instance. It records the numeric non-class
     * attributes, seeds the RNG, draws the initial weights and reads the rate options.
     */
    private void initialiseModel(Instance inst) {
        LinkedList<Integer> numericIndices = new LinkedList<Integer>();
        for (int i = 0; i < inst.numAttributes(); ++i) {
            if (inst.attribute(i).isNumeric() && i != inst.classIndex()) {
                numericIndices.add(i);
            }
        }
        this.numericAttributesIndex = new int[numericIndices.size()];
        int k = 0;
        for (Integer index : numericIndices) {
            this.numericAttributesIndex[k++] = index;
        }
        this.fadingFactor = this.fadingFactorOption.getValue();
        this.classifierRandom = new Random();
        this.classifierRandom.setSeed(this.randomSeedOption.getValue());
        this.initialisePerceptron = false;
        // One weight per numeric attribute plus a bias, each drawn uniformly from [-1, 1).
        this.weightAttribute = new double[this.numericAttributesIndex.length + 1];
        for (int i = 0; i < this.weightAttribute.length; ++i) {
            this.weightAttribute[i] = 2.0 * this.classifierRandom.nextDouble() - 1.0;
        }
        this.learningRatio = this.learningRatioOption.getValue();
        this.learningRateDecay = this.learningRateDecayOption.getValue();
    }

    /** Predicts in the original target scale; 0.0 before the first training instance. */
    private double prediction(Instance inst) {
        if (this.initialisePerceptron) {
            return 0.0;
        }
        double[] normalizedInstance = this.normalizedInstance(inst);
        double normalizedPrediction = this.prediction(normalizedInstance);
        return this.denormalizedPrediction(normalizedPrediction);
    }

    /**
     * Predicts in the normalised target scale.
     *
     * @return the normalised prediction, or 0.0 before the first training instance
     */
    public double normalizedPrediction(Instance inst) {
        if (this.initialisePerceptron) {
            // No attribute index exists yet on a never-trained model; prediction(double[]) would return 0 anyway.
            return 0.0;
        }
        double[] normalizedInstance = this.normalizedInstance(inst);
        return this.prediction(normalizedInstance);
    }

    /** Maps a normalised prediction back to the target scale using the running target mean/SD. */
    private double denormalizedPrediction(double normalizedPrediction) {
        if (this.initialisePerceptron) {
            return normalizedPrediction;
        }
        double meanY = Perceptron.safeMean(this.perceptronsumY, this.perceptronYSeen);
        double sdY = this.computeSD(this.squaredperceptronsumY, this.perceptronsumY, this.perceptronYSeen);
        if (sdY > SD_THRESHOLD) {
            return normalizedPrediction * sdY + meanY;
        }
        return normalizedPrediction + meanY;
    }

    /**
     * Linear combination of already-normalised inputs.
     *
     * @param instanceValues normalised attribute values; its last slot is ignored and the bias
     *                       weight is added instead
     * @return the normalised prediction, or 0.0 before the first training instance
     */
    public double prediction(double[] instanceValues) {
        double prediction = 0.0;
        if (!this.initialisePerceptron) {
            int biasIndex = instanceValues.length - 1;
            for (int j = 0; j < biasIndex; ++j) {
                prediction += this.weightAttribute[j] * instanceValues[j];
            }
            prediction += this.weightAttribute[biasIndex];
        }
        return prediction;
    }

    /**
     * Centres and scales each numeric input attribute with its running mean and standard
     * deviation. An attribute whose SD does not exceed {@code SD_THRESHOLD} is only centred.
     *
     * @return an array of length {@code numericAttributesIndex.length + 1}; the trailing bias slot is 0
     */
    public double[] normalizedInstance(Instance inst) {
        double[] normalizedInstance = new double[this.numericAttributesIndex.length + 1];
        for (int j = 0; j < this.numericAttributesIndex.length; ++j) {
            int instAttIndex = Perceptron.modelAttIndexToInstanceAttIndex(this.numericAttributesIndex[j], inst);
            double sum = this.perceptronattributeStatistics.getValue(j);
            double mean = Perceptron.safeMean(sum, this.perceptronYSeen);
            double sd = this.computeSD(this.squaredperceptronattributeStatistics.getValue(j), sum, this.perceptronYSeen);
            double centred = inst.value(instAttIndex) - mean;
            normalizedInstance[j] = sd > SD_THRESHOLD ? centred / sd : centred;
        }
        return normalizedInstance;
    }

    /**
     * Sample standard deviation from running sums.
     *
     * @param squaredVal sum of weighted squared values
     * @param val        sum of weighted values
     * @param size       total weight
     * @return the SD, or 0.0 when {@code size <= 1}
     */
    public double computeSD(double squaredVal, double val, double size) {
        if (size > 1.0) {
            // Rounding can make the variance estimate slightly negative; clamp so sqrt never yields NaN.
            return Math.sqrt(Math.max(0.0, (squaredVal - val * val / size) / (size - 1.0)));
        }
        return 0.0;
    }

    /**
     * Applies one delta-rule step towards the normalised target of {@code inst}. Afterwards, if
     * the L1 norm of all weights exceeds the number of numeric attributes, every weight is
     * divided by that norm.
     */
    public void updateWeights(Instance inst, double learningRatio) {
        double[] normalizedInstance = this.normalizedInstance(inst);
        double normalizedPredict = this.prediction(normalizedInstance);
        double normalizedY = this.normalizeActualClassValue(inst);
        double delta = normalizedY - normalizedPredict;
        double weight = inst.weight();
        int biasIndex = this.numericAttributesIndex.length;
        double sumWeights = 0.0;
        for (int j = 0; j < biasIndex; ++j) {
            this.weightAttribute[j] = this.weightAttribute[j] + learningRatio * delta * normalizedInstance[j] * weight;
            sumWeights += Math.abs(this.weightAttribute[j]);
        }
        this.weightAttribute[biasIndex] = this.weightAttribute[biasIndex] + learningRatio * delta * weight;
        sumWeights += Math.abs(this.weightAttribute[biasIndex]);
        if (sumWeights > (double) biasIndex) {
            for (int j = 0; j <= biasIndex; ++j) {
                this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
            }
        }
    }

    /** Divides every weight by the L1 norm of the weight vector; no-op when that norm is 0. */
    public void normalizeWeights() {
        double sumWeights = 0.0;
        for (double w : this.weightAttribute) {
            sumWeights += Math.abs(w);
        }
        if (sumWeights > 0.0) {
            for (int j = 0; j < this.weightAttribute.length; ++j) {
                this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
            }
        }
    }

    /** Centres (and, when its SD exceeds {@code SD_THRESHOLD}, scales) the class value of {@code inst}. */
    private double normalizeActualClassValue(Instance inst) {
        double meanY = Perceptron.safeMean(this.perceptronsumY, this.perceptronYSeen);
        double sdY = this.computeSD(this.squaredperceptronsumY, this.perceptronsumY, this.perceptronYSeen);
        double centred = inst.classValue() - meanY;
        return sdY > SD_THRESHOLD ? centred / sdY : centred;
    }

    /**
     * Weighted running mean. Returns 0 when no weight has been observed yet, instead of the
     * 0/0 = NaN that would otherwise poison the weights and the error accumulator.
     */
    private static double safeMean(double sum, double count) {
        return count > 0.0 ? sum / count : 0.0;
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Returns a single-element vote holding the prediction; {@code {0.0}} before training. */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        if (!this.initialisePerceptron) {
            return new double[]{this.prediction(inst)};
        }
        return new double[]{0.0};
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Appends the linear model as text, e.g. {@code " 0.5 X0 +0.25 X1 -0.1"}. Weights are rounded
     * to three decimals, zero weights after X0 are omitted, and the bias comes last.
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        if (this.weightAttribute == null || this.weightAttribute.length == 0) {
            return;
        }
        int biasIndex = this.weightAttribute.length - 1;
        for (int i = 0; i < biasIndex; ++i) {
            double w = this.weightAttribute[i];
            if (w > 0.0 && i > 0) {
                out.append(" +").append(Perceptron.round3(w)).append(" X").append(i);
            } else if (w < 0.0 || i == 0) {
                out.append(" ").append(Perceptron.round3(w)).append(" X").append(i);
            }
        }
        double bias = this.weightAttribute[biasIndex];
        out.append(bias >= 0.0 ? " +" : " ").append(Perceptron.round3(bias));
    }

    /** Rounds to three decimal places for display. */
    private static double round3(double value) {
        return (double) Math.round(value * 1000.0) / 1000.0;
    }

    /** Overrides the current learning rate (it is recomputed on training unless the constant flag is set). */
    public void setLearningRatio(double learningRatio) {
        this.learningRatio = learningRatio;
    }

    /**
     * Faded, weight-averaged absolute error of the predictions made before each training step.
     *
     * @return the error, or {@link Double#MAX_VALUE} when no weight has been accumulated
     */
    @Override
    public double getCurrentError() {
        if (this.nError > 0.0) {
            return this.accumulatedError / this.nError;
        }
        return Double.MAX_VALUE;
    }
}

