/*
 * Decompiled with CFR 0.152.
 */
package tools.mdsd.probdist.api.entity;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import tools.mdsd.probdist.api.entity.CPDEvaluator;
import tools.mdsd.probdist.api.entity.CategoricalValue;
import tools.mdsd.probdist.api.entity.Conditionable;
import tools.mdsd.probdist.api.entity.ProbabilityDistributionFunction;
import tools.mdsd.probdist.api.entity.UnivariateProbabilitiyMassFunction;
import tools.mdsd.probdist.api.exception.ProbabilityDistributionException;
import tools.mdsd.probdist.api.factory.IProbabilityDistributionFactory;
import tools.mdsd.probdist.api.random.ISeedProvider;
import tools.mdsd.probdist.distributionfunction.DistributionfunctionFactory;
import tools.mdsd.probdist.distributionfunction.Domain;
import tools.mdsd.probdist.distributionfunction.ParamRepresentation;
import tools.mdsd.probdist.distributionfunction.Parameter;
import tools.mdsd.probdist.distributionfunction.ProbabilityDistribution;
import tools.mdsd.probdist.distributionfunction.SimpleParameter;
import tools.mdsd.probdist.distributionfunction.TabularCPD;
import tools.mdsd.probdist.distributionfunction.TabularCPDEntry;

public class TabularCPDEvaluator
implements CPDEvaluator {
    private final Map<TabularCPDEntry, UnivariateProbabilitiyMassFunction> entryToPMF = new HashMap<TabularCPDEntry, UnivariateProbabilitiyMassFunction>();
    private final IProbabilityDistributionFactory<CategoricalValue> probabilityDistributionFactory;
    private boolean initialized = false;

    public TabularCPDEvaluator(TabularCPD tabularCPD, ProbabilityDistribution distribution, IProbabilityDistributionFactory<CategoricalValue> probabilityDistributionFactory) {
        this.probabilityDistributionFactory = probabilityDistributionFactory;
        this.initPMFEntries(tabularCPD, distribution);
    }

    @Override
    public void init(Optional<ISeedProvider> seedProvider) {
        if (this.initialized) {
            return;
        }
        this.initialized = true;
        for (Map.Entry<TabularCPDEntry, UnivariateProbabilitiyMassFunction> entry : this.entryToPMF.entrySet()) {
            UnivariateProbabilitiyMassFunction value = entry.getValue();
            value.init(seedProvider);
        }
    }

    private void initPMFEntries(TabularCPD tabularCPD, ProbabilityDistribution distribution) {
        for (TabularCPDEntry each : tabularCPD.getCpdEntries()) {
            UnivariateProbabilitiyMassFunction pmfRealisation = this.createPMFRealisation(distribution, each);
            this.entryToPMF.put(each, pmfRealisation);
        }
    }

    private UnivariateProbabilitiyMassFunction createPMFRealisation(ProbabilityDistribution distribution, TabularCPDEntry cpdEntry) {
        ProbabilityDistribution pmfEntry = this.createPMFEntry(distribution);
        this.setParamRepresentation(pmfEntry, cpdEntry);
        return this.getPMFRealisation(distribution, pmfEntry);
    }

    private UnivariateProbabilitiyMassFunction getPMFRealisation(ProbabilityDistribution distribution, ProbabilityDistribution pmfEntry) {
        return (UnivariateProbabilitiyMassFunction)this.probabilityDistributionFactory.getInstanceOf(pmfEntry).orElseThrow(() -> new ProbabilityDistributionException(String.format("There is no realisation for the PDF: %s", distribution.getInstantiated().getEntityName())));
    }

    private ProbabilityDistribution createPMFEntry(ProbabilityDistribution distribution) {
        ProbabilityDistribution pmfEntry = this.createPMFStructure(distribution);
        Parameter param = this.createParam(distribution);
        pmfEntry.getParams().add((Object)param);
        return pmfEntry;
    }

    private ProbabilityDistribution createPMFStructure(ProbabilityDistribution distribution) {
        DistributionfunctionFactory factory = DistributionfunctionFactory.eINSTANCE;
        ProbabilityDistribution structure = factory.createProbabilityDistribution();
        structure.getRandomVariables().addAll((Collection)distribution.getRandomVariables());
        structure.setInstantiated(distribution.getInstantiated());
        return structure;
    }

    private Parameter createParam(ProbabilityDistribution distribution) {
        DistributionfunctionFactory factory = DistributionfunctionFactory.eINSTANCE;
        Parameter param = factory.createParameter();
        param.setInstantiated(((Parameter)distribution.getParams().get(0)).getInstantiated());
        return param;
    }

    private void setParamRepresentation(ProbabilityDistribution pmfEntry, TabularCPDEntry cpdEntry) {
        SimpleParameter sParam = DistributionfunctionFactory.eINSTANCE.createSimpleParameter();
        SimpleParameter entry = cpdEntry.getEntry();
        sParam.setType(entry.getType());
        sParam.setValue(entry.getValue());
        Parameter param = (Parameter)pmfEntry.getParams().get(0);
        param.setRepresentation((ParamRepresentation)sParam);
    }

    @Override
    public Double evaluate(CategoricalValue value, List<Conditionable.Conditional<CategoricalValue>> conditionals) {
        return this.getCPDGiven(conditionals).probability(value);
    }

    @Override
    public ProbabilityDistributionFunction<CategoricalValue> getCPDGiven(List<Conditionable.Conditional<CategoricalValue>> conditionals) {
        if (!this.initialized) {
            throw new RuntimeException("not initilaized");
        }
        TabularCPDEntry cpdEntryMatching = this.findCPDEntryMatching(conditionals);
        return this.entryToPMF.get(cpdEntryMatching);
    }

    private TabularCPDEntry findCPDEntryMatching(List<Conditionable.Conditional<CategoricalValue>> conditionals) {
        return this.entryToPMF.keySet().stream().filter(this.entryMatching(conditionals)).findFirst().orElseThrow(() -> new ProbabilityDistributionException(String.format("The conditionals %1s are not included in the CPD table with %2s.", this.toString(conditionals), this.toString(this.entryToPMF.keySet()))));
    }

    private Predicate<TabularCPDEntry> entryMatching(List<Conditionable.Conditional<CategoricalValue>> queriedConditionals) {
        return e -> {
            List<Conditionable.Conditional<CategoricalValue>> entryConditionals = this.toCPDConditionals((List<String>)e.getConditonals());
            if (entryConditionals.size() != queriedConditionals.size()) {
                throw new IllegalArgumentException("The number of queried conditionals do not match the size of the tabular conditionals");
            }
            for (Conditionable.Conditional each : queriedConditionals) {
                entryConditionals.removeIf(this.isEqualTo(each));
            }
            return entryConditionals.isEmpty();
        };
    }

    private Predicate<Conditionable.Conditional<CategoricalValue>> isEqualTo(Conditionable.Conditional<CategoricalValue> other) {
        return given -> {
            if (given.getValueSpace() != other.getValueSpace()) {
                return false;
            }
            return ((String)((CategoricalValue)given.getValue()).value).equals(((CategoricalValue)conditional.getValue()).value);
        };
    }

    private List<Conditionable.Conditional<CategoricalValue>> toCPDConditionals(List<String> conditonals) {
        return conditonals.stream().map(each -> new Conditionable.Conditional<CategoricalValue>(Domain.CATEGORY, CategoricalValue.create(each))).collect(Collectors.toList());
    }

    private String toString(List<Conditionable.Conditional<CategoricalValue>> conditionals) {
        if (conditionals.size() == 1) {
            return conditionals.get(0).getValue().toString();
        }
        StringBuilder builder = new StringBuilder();
        for (Conditionable.Conditional<CategoricalValue> each : conditionals) {
            builder.append(String.format(",%s", each.getValue().toString()));
        }
        return builder.toString().replaceFirst(",", "");
    }

    private String toString(Set<TabularCPDEntry> entries) {
        ArrayList entryList = Lists.newArrayList(entries);
        if (entryList.size() == 1) {
            return ((TabularCPDEntry)entryList.get(0)).getEntry().getValue();
        }
        StringBuilder builder = new StringBuilder();
        for (TabularCPDEntry each : entryList) {
            builder.append(String.format(",%s", each.getEntry().getValue().toString()));
        }
        return builder.toString().replaceFirst(",", "");
    }
}

