package tools.mdsd.probdist.api.entity;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import tools.mdsd.probdist.api.entity.Conditionable;
import tools.mdsd.probdist.api.exception.ProbabilityDistributionException;
import tools.mdsd.probdist.api.factory.IProbabilityDistributionFactory;
import tools.mdsd.probdist.api.parser.DefaultParameterParser;
import tools.mdsd.probdist.api.random.ISeedProvider;
import tools.mdsd.probdist.distributionfunction.DistributionfunctionFactory;
import tools.mdsd.probdist.distributionfunction.Domain;
import tools.mdsd.probdist.distributionfunction.Parameter;
import tools.mdsd.probdist.distributionfunction.ProbabilityDistribution;
import tools.mdsd.probdist.distributionfunction.SimpleParameter;
import tools.mdsd.probdist.distributionfunction.TabularCPD;
import tools.mdsd.probdist.distributionfunction.TabularCPDEntry;

/**
 * {@link CPDEvaluator} backed by a {@link TabularCPD}.
 *
 * <p>For every {@link TabularCPDEntry} of the table a separate univariate probability mass
 * function (PMF) is realised. Each PMF is built from a fresh {@link ProbabilityDistribution} that
 * copies the random variables and the instantiated type of the given template distribution. Its
 * single parameter has the template's first parameter type and the entry's
 * {@link SimpleParameter} type and value as representation.
 *
 * <p>Queries select the table entry whose categorical conditionals match the queried
 * conditionals, then delegate to that entry's PMF. {@link #init(Optional)} must be called before
 * {@link #getCPDGiven(List)} or {@link #evaluate(CategoricalValue, List)}.
 */
public class TabularCPDEvaluator implements CPDEvaluator {
    private final IProbabilityDistributionFactory<CategoricalValue> probabilityDistributionFactory;

    /**
     * One realised PMF per table entry.
     *
     * <p>A LinkedHashMap keeps the iteration order of {@code tabularCPD.getCpdEntries()}, so
     * lookups scan the table in its declared order. A plain HashMap would iterate in key-hash
     * order, making the result of {@code findFirst()} depend on hash codes rather than on the
     * table.
     */
    private final Map<TabularCPDEntry, UnivariateProbabilitiyMassFunction> entryToPMF = new LinkedHashMap<>();

    private boolean initialized = false;

    /**
     * Realises one PMF for every entry of {@code tabularCPD}.
     *
     * @param tabularCPD the table whose entries define the conditional PMFs
     * @param probabilityDistribution template distribution; its random variables, instantiated
     *     type and first parameter type are reused for every realised PMF
     * @param iProbabilityDistributionFactory factory used to realise each PMF
     * @throws ProbabilityDistributionException if the factory has no realisation for a PMF
     */
    public TabularCPDEvaluator(TabularCPD tabularCPD, ProbabilityDistribution probabilityDistribution, IProbabilityDistributionFactory<CategoricalValue> iProbabilityDistributionFactory) {
        this.probabilityDistributionFactory = iProbabilityDistributionFactory;
        initPMFEntries(tabularCPD, probabilityDistribution);
    }

    /**
     * Initialises every realised PMF with the given seed provider. Subsequent calls are no-ops
     * once initialisation has succeeded.
     *
     * @param optional optional seed provider, passed through to each PMF
     */
    @Override
    public void init(Optional<ISeedProvider> optional) {
        if (this.initialized) {
            return;
        }
        for (UnivariateProbabilitiyMassFunction pmf : this.entryToPMF.values()) {
            pmf.init(optional);
        }
        // Mark as initialised only after every PMF succeeded, so a failed init can be retried.
        this.initialized = true;
    }

    /** Realises a PMF for each table entry, in table order. */
    private void initPMFEntries(TabularCPD tabularCPD, ProbabilityDistribution probabilityDistribution) {
        for (TabularCPDEntry tabularCPDEntry : tabularCPD.getCpdEntries()) {
            this.entryToPMF.put(tabularCPDEntry, createPMFRealisation(probabilityDistribution, tabularCPDEntry));
        }
    }

    /** Builds the PMF model for one table entry and asks the factory to realise it. */
    private UnivariateProbabilitiyMassFunction createPMFRealisation(ProbabilityDistribution template, TabularCPDEntry tabularCPDEntry) {
        ProbabilityDistribution pmfEntry = createPMFEntry(template);
        setParamRepresentation(pmfEntry, tabularCPDEntry);
        return getPMFRealisation(template, pmfEntry);
    }

    /**
     * Realises {@code pmfEntry} through the factory.
     *
     * @param template the template distribution, used only for the error message
     * @param pmfEntry the per-entry distribution model to realise
     * @throws ProbabilityDistributionException if the factory has no realisation
     */
    private UnivariateProbabilitiyMassFunction getPMFRealisation(ProbabilityDistribution template, ProbabilityDistribution pmfEntry) {
        return (UnivariateProbabilitiyMassFunction) this.probabilityDistributionFactory.getInstanceOf(pmfEntry)
            .orElseThrow(() -> new ProbabilityDistributionException(String.format(
                "There is no realisation for the PDF: %s", template.getInstantiated().getEntityName())));
    }

    /** Creates a new distribution model with the template's structure and one fresh parameter. */
    private ProbabilityDistribution createPMFEntry(ProbabilityDistribution template) {
        ProbabilityDistribution pmfEntry = createPMFStructure(template);
        pmfEntry.getParams().add(createParam(template));
        return pmfEntry;
    }

    /**
     * Creates a new distribution model sharing the template's random variables and instantiated
     * type.
     *
     * <p>NOTE(review): if {@code randomVariables} is an EMF containment reference, {@code addAll}
     * moves the elements out of the template instead of copying them. Confirm that the reference
     * is non-containment.
     */
    private ProbabilityDistribution createPMFStructure(ProbabilityDistribution template) {
        ProbabilityDistribution structure = DistributionfunctionFactory.eINSTANCE.createProbabilityDistribution();
        structure.getRandomVariables().addAll(template.getRandomVariables());
        structure.setInstantiated(template.getInstantiated());
        return structure;
    }

    /** Creates a parameter with the same instantiated type as the template's first parameter. */
    private Parameter createParam(ProbabilityDistribution template) {
        Parameter parameter = DistributionfunctionFactory.eINSTANCE.createParameter();
        parameter.setInstantiated(((Parameter) template.getParams().get(0)).getInstantiated());
        return parameter;
    }

    /**
     * Sets a copy of the entry's {@link SimpleParameter} (type and value) as the representation
     * of the distribution's first parameter.
     */
    private void setParamRepresentation(ProbabilityDistribution probabilityDistribution, TabularCPDEntry tabularCPDEntry) {
        SimpleParameter copy = DistributionfunctionFactory.eINSTANCE.createSimpleParameter();
        SimpleParameter entry = tabularCPDEntry.getEntry();
        copy.setType(entry.getType());
        copy.setValue(entry.getValue());
        ((Parameter) probabilityDistribution.getParams().get(0)).setRepresentation(copy);
    }

    /**
     * Returns the probability of {@code categoricalValue} under the PMF selected by the given
     * conditionals.
     *
     * @throws IllegalStateException if {@link #init(Optional)} has not been called
     * @throws ProbabilityDistributionException if no table entry matches the conditionals
     */
    @Override
    public Double evaluate(CategoricalValue categoricalValue, List<Conditionable.Conditional<CategoricalValue>> list) {
        return getCPDGiven(list).probability(categoricalValue);
    }

    /**
     * Returns the PMF of the first table entry (in table order) whose conditionals match
     * {@code list}.
     *
     * @throws IllegalStateException if {@link #init(Optional)} has not been called
     * @throws IllegalArgumentException if the number of queried conditionals differs from the
     *     number of conditionals of a table entry
     * @throws ProbabilityDistributionException if no table entry matches the conditionals
     */
    @Override
    public ProbabilityDistributionFunction<CategoricalValue> getCPDGiven(List<Conditionable.Conditional<CategoricalValue>> list) {
        if (!this.initialized) {
            throw new IllegalStateException("not initialized");
        }
        return this.entryToPMF.get(findCPDEntryMatching(list));
    }

    /** Finds the first table entry matching the queried conditionals, or throws. */
    private TabularCPDEntry findCPDEntryMatching(List<Conditionable.Conditional<CategoricalValue>> list) {
        return this.entryToPMF.keySet().stream()
            .filter(entryMatching(list))
            .findFirst()
            .orElseThrow(() -> new ProbabilityDistributionException(String.format(
                "The conditionals %s are not included in the CPD table with %s.",
                describeConditionals(list), describeEntries(this.entryToPMF.keySet()))));
    }

    /**
     * Matches a table entry against the queried conditionals, irrespective of order.
     *
     * <p>Each queried conditional must be paired with a distinct table conditional that has the
     * same value space and value (multiset equality). Earlier code removed every equal table
     * conditional per query, so {@code (true,false)} wrongly matched the row {@code (true,true)}.
     *
     * @throws IllegalArgumentException (from the returned predicate) if the entry and the query
     *     have different numbers of conditionals
     */
    private Predicate<TabularCPDEntry> entryMatching(List<Conditionable.Conditional<CategoricalValue>> queried) {
        return tabularCPDEntry -> {
            List<Conditionable.Conditional<CategoricalValue>> unmatched = toCPDConditionals(tabularCPDEntry.getConditonals());
            if (unmatched.size() != queried.size()) {
                throw new IllegalArgumentException("The number of queried conditionals do not match the size of the tabular conditionals");
            }
            for (Conditionable.Conditional<CategoricalValue> conditional : queried) {
                if (!removeFirstMatch(unmatched, conditional)) {
                    return false;
                }
            }
            return unmatched.isEmpty();
        };
    }

    /**
     * Removes the first element of {@code candidates} equal to {@code target}.
     *
     * @return {@code true} if an element was removed
     */
    private boolean removeFirstMatch(List<Conditionable.Conditional<CategoricalValue>> candidates, Conditionable.Conditional<CategoricalValue> target) {
        Predicate<Conditionable.Conditional<CategoricalValue>> matchesTarget = isEqualTo(target);
        for (Iterator<Conditionable.Conditional<CategoricalValue>> it = candidates.iterator(); it.hasNext(); ) {
            if (matchesTarget.test(it.next())) {
                it.remove();
                return true;
            }
        }
        return false;
    }

    /** Tests for the same value space (identity comparison) and an equal categorical value. */
    private Predicate<Conditionable.Conditional<CategoricalValue>> isEqualTo(Conditionable.Conditional<CategoricalValue> expected) {
        return candidate -> {
            if (candidate.getValueSpace() != expected.getValueSpace()) {
                return false;
            }
            return Objects.equals(((CategoricalValue) candidate.getValue()).value,
                ((CategoricalValue) expected.getValue()).value);
        };
    }

    /**
     * Converts a table entry's conditional strings into {@link Domain#CATEGORY} conditionals.
     * Returns a mutable list because callers remove matched elements from it.
     */
    private List<Conditionable.Conditional<CategoricalValue>> toCPDConditionals(List<String> list) {
        return list.stream()
            .map(str -> new Conditionable.Conditional<CategoricalValue>(Domain.CATEGORY, CategoricalValue.create(str)))
            .collect(Collectors.toCollection(ArrayList::new));
    }

    /** Comma-separated rendering of the conditionals' values, for error messages. */
    private String describeConditionals(List<Conditionable.Conditional<CategoricalValue>> conditionals) {
        return conditionals.stream()
            .map(conditional -> String.valueOf(conditional.getValue()))
            .collect(Collectors.joining(","));
    }

    /** Comma-separated rendering of the entries' parameter values, for error messages. */
    private String describeEntries(Set<TabularCPDEntry> entries) {
        return entries.stream()
            .map(entry -> entry.getEntry().getValue())
            .collect(Collectors.joining(","));
    }
}
