/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.functions;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.core.Utils;

/**
 * Stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of
 * Shalev-Shwartz et al. (2007) for binary classification.
 *
 * <p>The model is a dense weight vector with one entry per attribute followed by a trailing bias
 * term. The entry at the class attribute's index is never updated by training. A class value of
 * {@code 0} is treated as the negative class and any other class value as the positive class.
 * Training uses either the hinge loss (linear SVM) or the log loss (logistic regression).
 */
public class SPegasos extends AbstractClassifier {
    private static final long serialVersionUID = -3732968666673530290L;

    /** Regularization constant lambda currently in use; refreshed from the option on reset. */
    protected double m_lambda = 1.0E-4;

    /**
     * Command-line option for the regularization constant lambda.
     *
     * <p>NOTE(review): the lower bound is 0.0, but lambda = 0 makes the learning rate
     * {@code 1 / (lambda * t)} infinite during training — confirm whether 0 should be rejected.
     */
    public FloatOption lambdaRegularizationOption = new FloatOption("lambdaRegularization", 'l', "Lambda regularization parameter .", 1.0E-4, 0.0, 2.147483647E9);

    /** Loss-function code for the hinge loss (linear SVM); matches option index 0. */
    protected static final int HINGE = 0;

    /** Loss-function code for the log loss (logistic regression); matches option index 1. */
    protected static final int LOGLOSS = 1;

    /** Loss function currently in use: {@link #HINGE} or {@link #LOGLOSS}. */
    protected int m_loss = HINGE;

    /** Command-line option selecting the loss function; its chosen index is a loss code. */
    public MultiChoiceOption lossFunctionOption = new MultiChoiceOption("lossFunction", 'o', "The loss function to use.", new String[]{"HINGE", "LOGLOSS"}, new String[]{"Hinge loss (SVM)", "Log loss (logistic regression)"}, 0);

    /**
     * Weight vector: one entry per attribute, followed by the bias term in the last slot.
     * {@code null} until the first training instance arrives.
     */
    protected double[] m_weights;

    /** Iteration counter t used in the learning-rate schedule {@code 1 / (lambda * t)}. */
    protected double m_t;

    @Override
    public String getPurposeString() {
        return "Stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007).";
    }

    /**
     * Sets the regularization constant.
     *
     * @param lambda the new lambda value
     */
    public void setLambda(double lambda) {
        this.m_lambda = lambda;
    }

    /**
     * Returns the regularization constant.
     *
     * @return the current lambda value
     */
    public double getLambda() {
        return this.m_lambda;
    }

    /**
     * Selects the loss function.
     *
     * @param function {@link #HINGE} or {@link #LOGLOSS}
     */
    public void setLossFunction(int function) {
        this.m_loss = function;
    }

    /**
     * Returns the selected loss function code.
     *
     * @return {@link #HINGE} or {@link #LOGLOSS}
     */
    public int getLossFunction() {
        return this.m_loss;
    }

    /**
     * Resets the iteration counter to 2 and discards the weights.
     *
     * <p>The weights are reallocated on the next training instance.
     */
    public void reset() {
        this.m_t = 2.0;
        this.m_weights = null;
    }

    /**
     * Computes the dot product between an instance's stored attribute values and a weight vector.
     *
     * <p>The class attribute, missing values and the trailing bias weight are excluded. Only the
     * instance's stored values are visited, so the cost is proportional to
     * {@code inst1.numValues()}. Attribute indices that fall outside the weight vector (excluding
     * the bias slot) are skipped.
     *
     * @param inst1 the instance
     * @param weights the weight vector; its last entry is the bias and is not included
     * @param classIndex index of the class attribute, which is excluded from the sum
     * @return the bias-free dot product
     */
    protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
        double result = 0.0;
        int numWeights = weights.length - 1; // last slot is the bias, not an attribute weight
        int numStored = inst1.numValues();
        for (int p = 0; p < numStored; ++p) {
            int attIndex = inst1.index(p);
            if (attIndex < 0 || attIndex >= numWeights
                    || attIndex == classIndex || inst1.isMissingSparse(p)) {
                continue;
            }
            result += inst1.valueSparse(p) * weights[attIndex];
        }
        return result;
    }

    /**
     * Returns {@code -dL/dz}, the magnitude of the loss (sub)gradient with respect to the margin z.
     *
     * <p>Hinge loss: 1 when {@code z < 1}, otherwise 0. Log loss: {@code 1 / (1 + e^z)}. It is
     * evaluated in two branches so that {@code Math.exp} is only ever called with a non-positive
     * argument.
     *
     * @param z the signed margin {@code y * (w.x + b)}
     * @return the (sub)gradient magnitude
     */
    protected double dloss(double z) {
        if (this.m_loss == HINGE) {
            return z < 1.0 ? 1.0 : 0.0;
        }
        if (z < 0.0) {
            return 1.0 / (Math.exp(z) + 1.0);
        }
        double t = Math.exp(-z);
        return t / (t + 1.0);
    }

    @Override
    public void resetLearningImpl() {
        this.reset();
        this.setLambda(this.lambdaRegularizationOption.getValue());
        this.setLossFunction(this.lossFunctionOption.getChosenIndex());
    }

    /**
     * Performs one Pegasos step on the given instance.
     *
     * <p>The step has three parts: shrink, loss update and projection. Instances with a missing
     * class value only trigger weight allocation.
     *
     * @param instance the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.m_weights == null) {
            this.m_weights = new double[instance.numAttributes() + 1];
        }
        if (instance.classIsMissing()) {
            return;
        }
        int classIndex = instance.classIndex();
        int biasIndex = this.m_weights.length - 1;
        double learningRate = 1.0 / (this.m_lambda * this.m_t);
        double y = instance.classValue() == 0.0 ? -1.0 : 1.0;
        double wx = dotProd(instance, this.m_weights, classIndex);
        double z = y * (wx + this.m_weights[biasIndex]);

        // Regularization step: shrink every non-class, non-bias weight by (1 - 1/t).
        this.scaleWeights(1.0 - 1.0 / this.m_t, classIndex);

        // Loss step. Log loss always contributes; hinge loss only when the margin is violated.
        if (this.m_loss == LOGLOSS || z < 1.0) {
            double loss = this.dloss(z);
            int numStored = instance.numValues();
            for (int p = 0; p < numStored; ++p) {
                int attIndex = instance.index(p);
                if (attIndex == classIndex || instance.isMissingSparse(p)) {
                    continue;
                }
                this.m_weights[attIndex] += learningRate * loss * (instance.valueSparse(p) * y);
            }
            this.m_weights[biasIndex] += learningRate * loss * y;
        }

        // Projection step: rescale so that lambda * ||w||^2 <= 1. The bias and class slots are
        // excluded from both the norm and the rescaling.
        double norm = 0.0;
        for (int k = 0; k < biasIndex; ++k) {
            if (k != classIndex) {
                norm += this.m_weights[k] * this.m_weights[k];
            }
        }
        double scale2 = Math.min(1.0, 1.0 / (this.m_lambda * norm));
        if (scale2 < 1.0) {
            this.scaleWeights(Math.sqrt(scale2), classIndex);
        }
        this.m_t += 1.0;
    }

    /**
     * Multiplies every weight except the bias and the class-attribute slot by {@code factor}.
     *
     * @param factor the multiplier
     * @param classIndex index of the class attribute, whose slot is left untouched
     */
    private void scaleWeights(double factor, int classIndex) {
        for (int j = 0; j < this.m_weights.length - 1; ++j) {
            if (j != classIndex) {
                this.m_weights[j] *= factor;
            }
        }
    }

    /**
     * Returns a two-element vote array for the instance: index 0 is the negative class and index 1
     * the positive class.
     *
     * <p>With log loss the votes are class probabilities from the logistic function. With hinge
     * loss the predicted class gets a vote of 1. Before any training, both votes are 0.
     *
     * @param inst the instance to classify
     * @return the vote array of length 2
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        double[] votes = new double[2];
        if (this.m_weights == null) {
            return votes;
        }
        double wx = dotProd(inst, this.m_weights, inst.classIndex());
        double z = wx + this.m_weights[this.m_weights.length - 1];
        if (this.m_loss == LOGLOSS) {
            // Compute the smaller probability via exp of a non-positive argument, then complement.
            if (z <= 0.0) {
                votes[0] = 1.0 / (1.0 + Math.exp(z));
                votes[1] = 1.0 - votes[0];
            } else {
                votes[1] = 1.0 / (1.0 + Math.exp(-z));
                votes[0] = 1.0 - votes[1];
            }
        } else {
            votes[z <= 0.0 ? 0 : 1] = 1.0;
        }
        return votes;
    }

    @Override
    public void getModelDescription(StringBuilder result, int indent) {
        StringUtils.appendIndented(result, indent, this.toString());
        StringUtils.appendNewline(result);
    }

    /**
     * Renders the loss function, every attribute weight (one per line) and the signed bias.
     *
     * @return a human-readable description of the model
     */
    @Override
    public String toString() {
        if (this.m_weights == null) {
            return "SPegasos: No model built yet.\n";
        }
        int biasIndex = this.m_weights.length - 1;
        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        buff.append(this.m_loss == HINGE ? "Hinge loss (SVM)\n\n" : "Log loss (logistic regression)\n\n");
        for (int i = 0; i < biasIndex; ++i) {
            buff.append(i > 0 ? " + " : "   ");
            buff.append(Utils.doubleToString(this.m_weights[i], 12, 4)).append(" \n");
        }
        double bias = this.m_weights[biasIndex];
        if (bias > 0.0) {
            buff.append(" + ").append(Utils.doubleToString(bias, 12, 4));
        } else {
            buff.append(" - ").append(Utils.doubleToString(-bias, 12, 4));
        }
        return buff.toString();
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}

