package pl.poznan.put.cs.idss.jrs.wrappers;

import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.GregorianCalendar;
import java.util.Random;
import pl.poznan.put.cs.idss.jrs.approximations.ConsistencyMeasure;
import pl.poznan.put.cs.idss.jrs.approximations.MonotonicDecisionClassContainer;
import pl.poznan.put.cs.idss.jrs.approximations.MonotonicUnionContainer;
import pl.poznan.put.cs.idss.jrs.approximations.RoughMembershipMeasure;
import pl.poznan.put.cs.idss.jrs.approximations.StandardDecisionClassContainer;
import pl.poznan.put.cs.idss.jrs.approximations.StandardUnionContainer;
import pl.poznan.put.cs.idss.jrs.classifiers.ClassificationResultsFoldValidationContainer;
import pl.poznan.put.cs.idss.jrs.classifiers.ClassificationResultsValidationContainer;
import pl.poznan.put.cs.idss.jrs.classifiers.ClassificationStatisticsCollector;
import pl.poznan.put.cs.idss.jrs.classifiers.ClassificationStatisticsPresenter;
import pl.poznan.put.cs.idss.jrs.classifiers.Classifier;
import pl.poznan.put.cs.idss.jrs.classifiers.ensembles.EnsembleClassificationStatisticsFoldCollector;
import pl.poznan.put.cs.idss.jrs.classifiers.ensembles.MajorityVotingMethod;
import pl.poznan.put.cs.idss.jrs.core.SerialIOException;
import pl.poznan.put.cs.idss.jrs.core.Transfer;
import pl.poznan.put.cs.idss.jrs.core.UnknownValueException;
import pl.poznan.put.cs.idss.jrs.core.mem.MemoryContainer;
import pl.poznan.put.cs.idss.jrs.core.mem.MemoryContainerDecisionsManager;
import pl.poznan.put.cs.idss.jrs.core.mem.RandomizableMemoryContainer;
import pl.poznan.put.cs.idss.jrs.ensembles.Bagging;
import pl.poznan.put.cs.idss.jrs.validators.CrossValidation;

/* loaded from: input_file:pl/poznan/put/cs/idss/jrs/wrappers/BaggingWrapper.class */
/**
 * Wraps a {@link Bagging} ensemble of {@link SimpleClassifierWrapper} classifiers.
 *
 * <p>The ensemble can be trained on {@link #learningMemoryContainer} and evaluated in two ways:
 * on a separate container with {@link #validate(MemoryContainer, long)}, or by k-fold
 * cross-validation with {@link #crossValidate(int, long, ConsistencyMeasure)}.
 *
 * <p>Sampling and combination:
 * <ul>
 *   <li>Each ensemble member is a clone of {@link #baseClassifier}.</li>
 *   <li>Each member is built on a sample drawn with {@code Bagging.resampleWithWeights}.</li>
 *   <li>Members are combined with a {@link MajorityVotingMethod}.</li>
 *   <li>When no {@link ConsistencyMeasure} is set, sampling weights come from
 *       {@code Bagging.getEqualWeights}.</li>
 *   <li>Otherwise consistency-based (VC / HVDM-VC) weights are used, optionally transformed by
 *       {@link #shiftWeights(double[])}.</li>
 * </ul>
 *
 * <p>When a results file name is set, training/test splits, bootstrap samples and classification
 * results are written to files derived from that name.
 *
 * <p>Instances are mutable (training replaces {@code classifiers}, {@code generator}, ...) and
 * perform no synchronization.
 */
public class BaggingWrapper implements ClassificationStatisticsPresenter {
    /** Data set the ensemble is trained on (and split for cross-validation). */
    protected RandomizableMemoryContainer learningMemoryContainer;
    /** Prototype classifier; every ensemble member is a {@code clone()} of it. */
    protected SimpleClassifierWrapper baseClassifier;
    /** Members of the most recently trained ensemble ({@code null} after each cross-validation fold). */
    protected SimpleClassifierWrapper[] classifiers;
    /** Source of resampling and weighting operations. */
    protected Bagging bagging;
    /** Random source passed to {@code resampleWithWeights}. */
    protected Random generator;
    /** Measure used for consistency-based weights; {@code null} selects equal weights. */
    protected ConsistencyMeasure consistencyMeasure;
    /** Number of ensemble members to train. */
    protected int numberOfClassifiers;
    /** Aggregation method code; every value currently results in majority voting. */
    int aggregationMethod;
    /** Base name for output files, or {@code null} to write no files. */
    protected String resultsFileName;
    /** Statistics gathered by the last {@code crossValidate} run ({@code null} before that). */
    protected EnsembleClassificationStatisticsFoldCollector classificationStatistics;
    /** When {@code true}, consistency-based weights are passed through {@link #shiftWeights(double[])}. */
    protected boolean inconsistencyVCBagging;

    /** Creates an unconfigured wrapper (no data, no classifier, no bagging, zero members). */
    public BaggingWrapper() {
        this(null, null, 0, null, null, null, 0);
    }

    /** Creates a wrapper for the given learning data. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer) {
        this(randomizableMemoryContainer, null, 0, null, null, null, 0);
    }

    /** Creates a wrapper with the given base classifier and a single ensemble member. */
    public BaggingWrapper(SimpleClassifierWrapper simpleClassifierWrapper) {
        this(null, simpleClassifierWrapper, 0, null, null, null, 1);
    }

    /** Creates a wrapper using the given bagging implementation. */
    public BaggingWrapper(Bagging bagging) {
        this(null, null, 0, bagging, null, null, 0);
    }

    /** Creates a wrapper using the given bagging implementation and random source. */
    public BaggingWrapper(Bagging bagging, Random random) {
        this(null, null, 0, bagging, random, null, 0);
    }

    /** Creates a wrapper for the given data and base classifier, with a single ensemble member. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, null, null, null, 1);
    }

    /** Creates a wrapper for the given data and bagging implementation. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, Bagging bagging) {
        this(randomizableMemoryContainer, null, 0, bagging, null, null, 0);
    }

    /** Creates a wrapper for the given data, bagging implementation and random source. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, Bagging bagging, Random random) {
        this(randomizableMemoryContainer, null, 0, bagging, random, null, 0);
    }

    /** Creates a wrapper with the given base classifier (single member) and bagging implementation. */
    public BaggingWrapper(SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging) {
        this(null, simpleClassifierWrapper, 0, bagging, null, null, 1);
    }

    /** Creates a wrapper with the given base classifier (single member), bagging and random source. */
    public BaggingWrapper(SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging, Random random) {
        this(null, simpleClassifierWrapper, 0, bagging, random, null, 1);
    }

    /** Creates a wrapper for the given data, base classifier (single member) and bagging. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, bagging, null, null, 1);
    }

    /** Creates a wrapper for the given data, base classifier (single member), bagging and random source. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging, Random random) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, bagging, random, null, 1);
    }

    /** Creates a wrapper with consistency-based weighting and a single ensemble member. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging, ConsistencyMeasure consistencyMeasure) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, bagging, null, consistencyMeasure, 1);
    }

    /** Creates a wrapper with consistency-based weighting, a random source and a single ensemble member. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging, Random random, ConsistencyMeasure consistencyMeasure) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, bagging, random, consistencyMeasure, 1);
    }

    /** Creates a wrapper with consistency-based weighting and {@code numberOfClassifiers} members. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging, ConsistencyMeasure consistencyMeasure, int numberOfClassifiers) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, bagging, null, consistencyMeasure, numberOfClassifiers);
    }

    /** Creates a wrapper with consistency-based weighting, a random source and {@code numberOfClassifiers} members. */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, Bagging bagging, Random random, ConsistencyMeasure consistencyMeasure, int numberOfClassifiers) {
        this(randomizableMemoryContainer, simpleClassifierWrapper, 0, bagging, random, consistencyMeasure, numberOfClassifiers);
    }

    /**
     * Fully specified constructor; every other constructor delegates here.
     *
     * @param randomizableMemoryContainer learning data (may be {@code null})
     * @param simpleClassifierWrapper prototype classifier cloned for each member (may be {@code null})
     * @param aggregationMethod aggregation code (currently every value yields majority voting)
     * @param bagging bagging implementation (may be {@code null})
     * @param random random source for resampling, or {@code null} to create one lazily
     * @param consistencyMeasure measure for consistency-based weights, or {@code null} for equal weights
     * @param numberOfClassifiers number of ensemble members
     */
    public BaggingWrapper(RandomizableMemoryContainer randomizableMemoryContainer, SimpleClassifierWrapper simpleClassifierWrapper, int aggregationMethod, Bagging bagging, Random random, ConsistencyMeasure consistencyMeasure, int numberOfClassifiers) {
        this.learningMemoryContainer = randomizableMemoryContainer;
        this.baseClassifier = simpleClassifierWrapper;
        this.aggregationMethod = aggregationMethod;
        this.bagging = bagging;
        this.generator = random;
        this.consistencyMeasure = consistencyMeasure;
        this.numberOfClassifiers = numberOfClassifiers;
        this.classifiers = null;
        this.inconsistencyVCBagging = false;
    }

    /** Returns the sampling weights for {@link #learningMemoryContainer}; see {@link #getWeights(MemoryContainer)}. */
    protected double[] getWeights() {
        return getWeights(this.learningMemoryContainer);
    }

    /**
     * Computes per-object sampling weights for {@code memoryContainer}.
     *
     * @param memoryContainer data to weight
     * @return {@code null} if {@link #bagging} or {@code memoryContainer} is {@code null};
     *     equal weights if no consistency measure is set; otherwise consistency-based weights,
     *     passed through {@link #shiftWeights(double[])} when {@link #inconsistencyVCBagging} is set
     */
    protected double[] getWeights(MemoryContainer memoryContainer) {
        if (this.bagging == null || memoryContainer == null) {
            return null;
        }
        if (this.consistencyMeasure == null) {
            return this.bagging.getEqualWeights(memoryContainer);
        }
        double[] weights = getConsistencyWeights(memoryContainer);
        if (this.inconsistencyVCBagging) {
            weights = shiftWeights(weights);
        }
        return weights;
    }

    /**
     * Selects the container type and weighting variant for consistency-based weights.
     *
     * <p>The three choices are independent:
     * <ul>
     *   <li>{@code getNumHVDMObjects() == -1} selects {@code getVCWeights}; any other value
     *       selects {@code getHVDMVCWeights}.</li>
     *   <li>A first decision attribute with preference type 0 selects a decision-class container;
     *       any other type selects a union container (presumably type 0 means "no preference
     *       order" — TODO confirm).</li>
     *   <li>A {@link RoughMembershipMeasure} selects the Standard container; any other measure
     *       selects the Monotonic container.</li>
     * </ul>
     *
     * <p>Each call keeps the concrete container type as its static argument type so the same
     * {@code Bagging} overload is chosen as before.
     */
    private double[] getConsistencyWeights(MemoryContainer memoryContainer) {
        boolean plainVC = this.consistencyMeasure.getNumHVDMObjects() == -1;
        boolean decisionClasses = memoryContainer.getAttribute(MemoryContainerDecisionsManager.getFirstDecisionAttributeIndex(memoryContainer)).getPreferenceType() == 0;
        boolean roughMembership = this.consistencyMeasure instanceof RoughMembershipMeasure;
        if (decisionClasses) {
            if (roughMembership) {
                StandardDecisionClassContainer container = new StandardDecisionClassContainer(memoryContainer);
                return plainVC ? this.bagging.getVCWeights(container, this.consistencyMeasure) : this.bagging.getHVDMVCWeights(container, this.consistencyMeasure);
            }
            MonotonicDecisionClassContainer container = new MonotonicDecisionClassContainer(memoryContainer);
            return plainVC ? this.bagging.getVCWeights(container, this.consistencyMeasure) : this.bagging.getHVDMVCWeights(container, this.consistencyMeasure);
        }
        if (roughMembership) {
            StandardUnionContainer container = new StandardUnionContainer(memoryContainer);
            return plainVC ? this.bagging.getVCWeights(container, this.consistencyMeasure) : this.bagging.getHVDMVCWeights(container, this.consistencyMeasure);
        }
        MonotonicUnionContainer container = new MonotonicUnionContainer(memoryContainer);
        return plainVC ? this.bagging.getVCWeights(container, this.consistencyMeasure) : this.bagging.getHVDMVCWeights(container, this.consistencyMeasure);
    }

    /**
     * Trains on {@link #learningMemoryContainer} and classifies that same container.
     *
     * <p>The current wall-clock time is used in case a random generator has to be created.
     */
    public ClassificationResultsValidationContainer validate() throws UnknownValueException {
        return validate(this.learningMemoryContainer, System.currentTimeMillis());
    }

    /**
     * Trains the ensemble on {@link #learningMemoryContainer} and classifies {@code memoryContainer}.
     *
     * <p>If {@link #generator} is {@code null} it is created from {@code j} via
     * {@code Bagging.getRandomNumberGenerator}; an existing generator is reused.
     *
     * <p>When a results file name is set, each bootstrap sample is written to
     * {@code <name>_boot_<member>.isf` and the results to {@code <name>.cls} / {@code <name>.raw}.
     *
     * @param memoryContainer data to classify
     * @param j value passed to {@code getRandomNumberGenerator} (presumably a seed)
     * @return the validation results, or {@code null} if training/classification threw
     *     (the exception is printed)
     * @throws UnknownValueException if the classifier, learning data, test data or bagging is
     *     missing, or fewer than one member is configured
     */
    public ClassificationResultsValidationContainer validate(MemoryContainer memoryContainer, long j) throws UnknownValueException {
        if (this.baseClassifier == null || this.learningMemoryContainer == null || memoryContainer == null || this.bagging == null || this.numberOfClassifiers <= 0) {
            throw new UnknownValueException("Unknown classification method or not set test container");
        }
        ClassificationResultsValidationContainer results = null;
        try {
            // Per-object flags handed to resampleWithWeights and reset after every draw.
            boolean[] drawn = new boolean[this.learningMemoryContainer.size()];
            if (this.generator == null) {
                this.generator = this.bagging.getRandomNumberGenerator(j);
            }
            double[] weights = getWeights();
            this.classifiers = new SimpleClassifierWrapper[this.numberOfClassifiers];
            for (int member = 0; member < this.numberOfClassifiers; member++) {
                MemoryContainer bootstrap = this.bagging.resampleWithWeights(this.learningMemoryContainer, this.generator, true, weights, drawn);
                this.classifiers[member] = (SimpleClassifierWrapper) this.baseClassifier.clone();
                this.classifiers[member].build(bootstrap);
                Arrays.fill(drawn, false);
                if (this.resultsFileName != null) {
                    saveBootstrapSample(this.resultsFileName + "_boot_" + member + ".isf", bootstrap);
                }
            }
            results = new ClassificationResultsValidationContainer(new Classifier(new MajorityVotingMethod(this.classifiers), MemoryContainerDecisionsManager.getFirstDecisionAttributeIndex(this.learningMemoryContainer)), this.learningMemoryContainer, memoryContainer);
            if (this.resultsFileName != null) {
                try {
                    results.writeClassificationResults(this.resultsFileName + ".cls");
                    results.writeClassificationResultsRAW(this.resultsFileName + ".raw");
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return results;
    }

    /** Cross-validates using the currently configured consistency measure. */
    public ClassificationResultsFoldValidationContainer crossValidate(int i, long j) throws UnknownValueException {
        return crossValidate(i, j, this.consistencyMeasure);
    }

    /**
     * Runs stratified {@code i}-fold cross-validation of the bagging ensemble on
     * {@link #learningMemoryContainer}.
     *
     * <p>Setup:
     * <ul>
     *   <li>{@code consistencyMeasure} permanently replaces the configured measure.</li>
     *   <li>{@code j} seeds both the stratification generator and, re-created for every fold,
     *       the resampling {@link #generator}.</li>
     *   <li>When the measure's {@code getNumAttributes()} is not -1, weights are recomputed
     *       after every member is built.</li>
     * </ul>
     *
     * <p>With a results file name {@code <name>}, each fold {@code f} writes:
     * <ul>
     *   <li>{@code <name>_f_l.isf} and {@code <name>_f_t.isf} (training and test split);</li>
     *   <li>{@code <name>_f_boot_<member>.isf} (one bootstrap sample per member);</li>
     *   <li>{@code <name>_f.cls}, {@code .raw} and {@code .mtx} (fold results).</li>
     * </ul>
     * After all folds, aggregated {@code <name>.cls}, {@code .raw} and {@code .mtx} files are
     * written.
     *
     * @param i number of folds
     * @param j value passed to the random generator factories (presumably a seed)
     * @param consistencyMeasure measure for weighting, or {@code null} for equal weights
     * @return the per-fold results, or {@code null} if setup failed (the exception is printed)
     * @throws UnknownValueException if the classifier, learning data or bagging is missing,
     *     or {@code i} / the member count is not positive
     */
    public ClassificationResultsFoldValidationContainer crossValidate(int i, long j, ConsistencyMeasure consistencyMeasure) {
        if (this.baseClassifier == null || this.learningMemoryContainer == null || this.bagging == null || this.numberOfClassifiers <= 0 || i <= 0) {
            throw new UnknownValueException("Unknown classification method or not set test container");
        }
        this.consistencyMeasure = consistencyMeasure;
        // Per-fold names are derived locally so the field is never left modified if a fold throws.
        String baseFileName = this.resultsFileName;
        ClassificationResultsFoldValidationContainer foldResults = null;
        try {
            CrossValidation crossValidation = new CrossValidation(this.learningMemoryContainer, MemoryContainerDecisionsManager.getFirstDecisionAttributeIndex(this.learningMemoryContainer), i);
            crossValidation.stratify(crossValidation.getRandomNumberGenerator(j));
            foldResults = new ClassificationResultsFoldValidationContainer(this.learningMemoryContainer);
            this.classificationStatistics = new EnsembleClassificationStatisticsFoldCollector();
            for (int fold = 0; fold < i; fold++) {
                MemoryContainer trainDataSet = crossValidation.getTrainDataSet(fold);
                MemoryContainer testDataSet = crossValidation.getTestDataSet(fold);
                String foldFileName = baseFileName == null ? null : baseFileName + "_" + fold;
                if (foldFileName != null) {
                    try {
                        try (FileOutputStream out = new FileOutputStream(foldFileName + "_l.isf")) {
                            Transfer.saveSimpleIsf(out, trainDataSet);
                        }
                        try (FileOutputStream out = new FileOutputStream(foldFileName + "_t.isf")) {
                            Transfer.saveSimpleIsf(out, testDataSet);
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                    } catch (SerialIOException e) {
                        e.printStackTrace();
                    }
                }
                // Per-object flags handed to resampleWithWeights and reset after every draw.
                boolean[] drawn = new boolean[trainDataSet.size()];
                this.generator = this.bagging.getRandomNumberGenerator(j);
                double[] weights = getWeights(trainDataSet);
                this.classifiers = new SimpleClassifierWrapper[this.numberOfClassifiers];
                for (int member = 0; member < this.numberOfClassifiers; member++) {
                    MemoryContainer bootstrap = this.bagging.resampleWithWeights(trainDataSet, this.generator, true, weights, drawn);
                    this.classifiers[member] = (SimpleClassifierWrapper) this.baseClassifier.clone();
                    this.classifiers[member].build(bootstrap);
                    Arrays.fill(drawn, false);
                    if (this.consistencyMeasure != null && this.consistencyMeasure.getNumAttributes() != -1) {
                        weights = getWeights(trainDataSet);
                    }
                    if (foldFileName != null) {
                        // Indexed by member so each bootstrap sample gets its own file.
                        saveBootstrapSample(foldFileName + "_boot_" + member + ".isf", bootstrap);
                    }
                }
                // Every aggregationMethod code currently maps to plain majority voting.
                MajorityVotingMethod majorityVotingMethod = new MajorityVotingMethod(this.classifiers);
                foldResults.addFold(new Classifier(majorityVotingMethod, MemoryContainerDecisionsManager.getFirstDecisionAttributeIndex(this.learningMemoryContainer)), testDataSet);
                if (majorityVotingMethod instanceof ClassificationStatisticsPresenter) {
                    majorityVotingMethod.getClassificationStatisticsCollector().addNumberOfClassifiers(this.numberOfClassifiers);
                    majorityVotingMethod.getClassificationStatisticsCollector().addNumberOfObjects(testDataSet.size());
                    this.classificationStatistics.addStatistics(majorityVotingMethod.getClassificationStatisticsCollector());
                    majorityVotingMethod.getClassificationStatisticsCollector().clearCounters();
                }
                if (foldFileName != null) {
                    try {
                        foldResults.writeClassificationResults(foldFileName + ".cls");
                        foldResults.writeClassificationResultsRAW(foldFileName + ".raw");
                        foldResults.writeMisclassificationMatrix(fold, foldFileName + ".mtx");
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
                trainDataSet.clear();
                testDataSet.clear();
                this.classifiers = null;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        // foldResults is null when setup failed before it was created; nothing to write then.
        if (baseFileName != null && foldResults != null) {
            try {
                foldResults.writeClassificationResults(baseFileName + ".cls");
                foldResults.writeClassificationResultsRAW(baseFileName + ".raw");
                foldResults.writeMisclassificationMatrix(baseFileName + ".mtx");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return foldResults;
    }

    /**
     * Writes one bootstrap sample to {@code fileName}, closing the stream afterwards.
     * Any failure is printed and otherwise ignored (best-effort diagnostic output).
     */
    private static void saveBootstrapSample(String fileName, MemoryContainer sample) {
        try (FileOutputStream out = new FileOutputStream(fileName)) {
            Transfer.saveSimpleIsf(out, sample);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Returns the base name for output files, or {@code null} if no files are written. */
    public String getResultsFileName() {
        return this.resultsFileName;
    }

    /** Sets the base name for output files; {@code null} disables file output. */
    public void setResultsFileName(String str) {
        this.resultsFileName = str;
    }

    /** Sets the random source used for bootstrap resampling. */
    public void setBaggingRandomGenerator(Random random) {
        this.generator = random;
    }

    /**
     * Returns the cross-validation statistics as text.
     * Throws {@link NullPointerException} if no statistics collector has been created yet
     * (this class creates it in {@code crossValidate}).
     */
    @Override // pl.poznan.put.cs.idss.jrs.classifiers.ClassificationStatisticsPresenter
    public String classificationStatisticsToString() {
        return this.classificationStatistics.toString();
    }

    /** Returns the cross-validation statistics collector, or {@code null} before {@code crossValidate}. */
    @Override // pl.poznan.put.cs.idss.jrs.classifiers.ClassificationStatisticsPresenter
    public ClassificationStatisticsCollector getClassificationStatisticsCollector() {
        return this.classificationStatistics;
    }

    /** Returns whether consistency-based weights are transformed by {@link #shiftWeights(double[])}. */
    public boolean getInconsistencyVCBagging() {
        return this.inconsistencyVCBagging;
    }

    /** Enables or disables transforming consistency-based weights with {@link #shiftWeights(double[])}. */
    public void setInconsistencyVCBagging(boolean z) {
        this.inconsistencyVCBagging = z;
    }

    /**
     * Maps each weight {@code w} to {@code (1 - w) + m}.
     *
     * <p>{@code m} is the smallest weight, capped at 1.0 because the running minimum starts at
     * 1.0. NaN entries never lower the minimum.
     *
     * @param dArr weights to transform (not modified)
     * @return a new array of transformed weights, or {@code null} if {@code dArr} is {@code null}
     */
    protected double[] shiftWeights(double[] dArr) {
        if (dArr == null) {
            return null;
        }
        double min = 1.0d;
        for (double w : dArr) {
            // Explicit comparison (not Math.min) so NaN and -0.0 are treated as before.
            if (min > w) {
                min = w;
            }
        }
        double[] shifted = new double[dArr.length];
        for (int k = 0; k < dArr.length; k++) {
            shifted[k] = (1.0d - dArr[k]) + min;
        }
        return shifted;
    }
}
