package edu.rice.cs.bioinfo.programs.phylonet.commands;

import edu.rice.cs.bioinfo.library.language.pyson._1_0.ir.blockcontents.Parameter;
import edu.rice.cs.bioinfo.library.language.pyson._1_0.ir.blockcontents.ParameterIdent;
import edu.rice.cs.bioinfo.library.language.pyson._1_0.ir.blockcontents.ParameterIdentList;
import edu.rice.cs.bioinfo.library.language.pyson._1_0.ir.blockcontents.SyntaxCommand;
import edu.rice.cs.bioinfo.library.language.richnewick._1_1.reading.ast.NetworkNonEmpty;
import edu.rice.cs.bioinfo.library.language.richnewick._1_1.reading.ast.Networks;
import edu.rice.cs.bioinfo.library.language.richnewick.reading.RichNewickReader;
import edu.rice.cs.bioinfo.library.programming.Container;
import edu.rice.cs.bioinfo.library.programming.Func3;
import edu.rice.cs.bioinfo.library.programming.Proc;
import edu.rice.cs.bioinfo.library.programming.Proc1;
import edu.rice.cs.bioinfo.library.programming.Proc3;
import edu.rice.cs.bioinfo.programs.phylonet.algos.network.GeneTreeProbability;
import edu.rice.cs.bioinfo.programs.phylonet.structs.network.NetNode;
import edu.rice.cs.bioinfo.programs.phylonet.structs.network.Network;
import edu.rice.cs.bioinfo.programs.phylonet.structs.network.io.RnNewickPrinter;
import edu.rice.cs.bioinfo.programs.phylonet.structs.network.model.bni.BniNetwork;
import edu.rice.cs.bioinfo.programs.phylonet.structs.network.model.bni.NetworkFactoryFromRNNetwork;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.io.NewickReader;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.Tree;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.sti.STITree;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.util.Trees;
import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.optimization.GoalType;
import org.apache.commons.math3.optimization.univariate.BrentOptimizer;

@CommandName("SearchBranchLengthsMaxGTProb")
/* loaded from: input_file:edu/rice/cs/bioinfo/programs/phylonet/commands/SearchBranchLengthsMaxGTProb.class */
public class SearchBranchLengthsMaxGTProb extends CommandBaseFileOut {
    private HashMap<String, String> _taxonMap;
    private boolean _printDetail;
    private NetworkNonEmpty _speciesNetwork;
    private List<NetworkNonEmpty> _geneTrees;
    private ParameterIdentList _geneTreeParam;
    private double _maxBranchLength;
    private int _maxAssigmentAttemptsPerBranchParam;
    private int _assigmentRounds;
    private double _improvementThreshold;
    private Func3<Network<Double>, List<Tree>, List<Integer>, Double> _computeGTProbStrategyCalGTProb;
    private Func3<Network<Double>, List<Tree>, List<Integer>, Double> _computeGTProbStrategyBox;
    private Func3<Network<Double>, List<Tree>, List<Integer>, Double> _computeGTProbStrategy;

    /* loaded from: input_file:edu/rice/cs/bioinfo/programs/phylonet/commands/SearchBranchLengthsMaxGTProb$Box.class */
    /**
     * Holds two parameter triples and the log-likelihood that relates them.
     *
     * <p>{@code (t0, t1, gamma)} yields the weights {@code P1..P3}, computed once in the
     * constructor. {@code (t0star, t1star, gammastar)} yields {@code P1star..P3star}, which are
     * recomputed on every {@link #callnLikelihood()} call. The likelihood is
     * {@code n * sum(P_i * ln(P_i star))}.
     *
     * <p>NOTE(review): the three values look like probabilities of three alternative outcomes
     * (they sum to 1 for any input) — confirm the intended model with the author.
     */
    public class Box {
        // Reference parameters and the three weights derived from them.
        public double t0;
        public double t1;
        public double gamma;
        public double P1;
        public double P2;
        public double P3;

        // Candidate parameters and the three values derived from them.
        public double t0star;
        public double t1star;
        public double gammastar;
        public double P1star;
        public double P2star;
        public double P3star;

        // Result of the most recent callnLikelihood() invocation.
        public double lnLikelihood;

        // Not written anywhere inside this class.
        public double MaxlnLikelihood;
        public double finalt0star;
        public double finalt1star;
        public double finalgammastar;

        // Multiplier applied to the weighted log sum.
        public int n;

        /**
         * Stores both parameter triples, derives {@code P1..P3} from the first triple and then
         * evaluates the likelihood once via {@link #callnLikelihood()}.
         *
         * @param d  value stored in {@code t0}
         * @param d2 value stored in {@code t1}
         * @param d3 value stored in {@code gamma}
         * @param d4 value stored in {@code t0star}
         * @param d5 value stored in {@code t1star}
         * @param d6 value stored in {@code gammastar}
         * @param i  value stored in {@code n}
         */
        public Box(double d, double d2, double d3, double d4, double d5, double d6, int i) {
            this.t0 = d;
            this.t1 = d2;
            this.gamma = d3;
            this.t0star = d4;
            this.t1star = d5;
            this.gammastar = d6;
            this.n = i;

            double[] reference = threeWayValues(d, d2, d3);
            this.P1 = reference[0];
            this.P2 = reference[1];
            this.P3 = reference[2];

            callnLikelihood();
        }

        /**
         * Recomputes {@code P1star..P3star} from the current star fields, then stores and returns
         * {@code n * (P1*ln(P1star) + P2*ln(P2star) + P3*ln(P3star))}.
         *
         * @return the freshly computed value, also left in {@code lnLikelihood}
         */
        public double callnLikelihood() {
            double[] candidate = threeWayValues(this.t0star, this.t1star, this.gammastar);
            this.P1star = candidate[0];
            this.P2star = candidate[1];
            this.P3star = candidate[2];

            double firstTerm = this.P1 * Math.log(this.P1star);
            double secondTerm = this.P2 * Math.log(this.P2star);
            double thirdTerm = this.P3 * Math.log(this.P3star);
            this.lnLikelihood = this.n * (firstTerm + secondTerm + thirdTerm);
            return this.lnLikelihood;
        }

        /**
         * Evaluates the three-term formula shared by the constructor and
         * {@link #callnLikelihood()}.
         *
         * @param length0 first length argument (enters through {@code exp(-length0)})
         * @param length1 second length argument (enters through {@code exp(-length1)})
         * @param weight  mixing weight applied to the {@code length0} terms; {@code 1 - weight}
         *                is applied to the {@code length1} terms
         * @return a three-element array holding the first, second and third value in that order
         */
        private double[] threeWayValues(double length0, double length1, double weight) {
            final double twoThirds = 0.6666666666666666d;
            double decay0 = Math.exp(-length0);
            double decay1 = Math.exp(-length1);
            double complement = 1.0d - weight;

            double thirdVia0 = (weight * decay0) / 3.0d;
            double thirdVia1 = (complement * decay1) / 3.0d;

            double first = (complement * (1.0d - (twoThirds * decay1))) + thirdVia0;
            double second = (weight * (1.0d - (twoThirds * decay0))) + thirdVia1;
            double third = thirdVia1 + thirdVia0;
            return new double[] {first, second, third};
        }
    }

    /**
     * Creates the command, resets its optional settings to their defaults and builds the two
     * scoring strategies. The gene-tree-distribution strategy is selected initially.
     *
     * <p>All five arguments are forwarded unchanged to the superclass constructor.
     */
    public SearchBranchLengthsMaxGTProb(SyntaxCommand syntaxCommand, ArrayList<Parameter> arrayList, Map<String, NetworkNonEmpty> map, Proc3<String, Integer, Integer> proc3, RichNewickReader<Networks> richNewickReader) {
        super(syntaxCommand, arrayList, map, proc3, richNewickReader);

        // Defaults for everything the parameter check may later override.
        this._taxonMap = null;
        this._printDetail = false;
        this._maxAssigmentAttemptsPerBranchParam = -1;
        this._assigmentRounds = Integer.MAX_VALUE;

        // Strategy 1: sum over the trees of ln(probability) * count, using GeneTreeProbability.
        this._computeGTProbStrategyCalGTProb = new Func3<Network<Double>, List<Tree>, List<Integer>, Double>() {
            @Override
            public Double execute(Network<Double> network, List<Tree> list, List<Integer> list2) {
                Iterator<Double> probabilities = new GeneTreeProbability().calculateGTDistribution(network, list, SearchBranchLengthsMaxGTProb.this._taxonMap, SearchBranchLengthsMaxGTProb.this._printDetail).iterator();
                Iterator<Integer> counts = list2.iterator();
                double logTotal = 0.0d;

                // The trees themselves are not inspected here; they only drive how many
                // (probability, count) pairs are consumed.
                Iterator<Tree> trees = list.iterator();
                while (trees.hasNext()) {
                    trees.next();
                    double probability = probabilities.next().doubleValue();
                    int count = counts.next().intValue();
                    logTotal += Math.log(probability) * count;
                }

                if (Double.isNaN(logTotal)) {
                    throw new RuntimeException();
                }
                return Double.valueOf(logTotal);
            }
        };

        // Strategy 2: Box likelihood read from the nodes named "D", "F" and "E" of the network,
        // against fixed reference values (1.1, 1.1, 0.5). The tree and count lists are ignored.
        this._computeGTProbStrategyBox = new Func3<Network<Double>, List<Tree>, List<Integer>, Double>() {
            @Override
            public Double execute(Network<Double> network, List<Tree> list, List<Integer> list2) {
                double rootToD = network.findNode("D").getParentDistance(network.getRoot());
                double rootToF = network.findNode("F").getParentDistance(network.getRoot());
                double weightOfDForE = network.findNode("E").getParentProbability(network.findNode("D"));
                int geneTreeCount = SearchBranchLengthsMaxGTProb.this._geneTrees.size();

                Box box = new Box(1.1d, 1.1d, 0.5d, rootToD, rootToF, weightOfDForE, geneTreeCount);
                return Double.valueOf(box.callnLikelihood());
            }
        };

        this._computeGTProbStrategy = this._computeGTProbStrategyCalGTProb;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Reports the smallest parameter count this command declares.
     *
     * <p>NOTE(review): the parameter check of this class reads positional index 4, which needs
     * at least five parameters if indices are zero-based — confirm this bound against
     * {@code CommandBase}.
     *
     * @return the constant 4
     */
    @Override // edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBase
    public int getMinNumParams() {
        final int fewestParametersAccepted = 4;
        return fewestParametersAccepted;
    }

    /**
     * Reports the largest parameter count this command declares.
     *
     * @return the constant 8
     */
    @Override // edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBase
    protected int getMaxNumParams() {
        final int mostParametersAccepted = 8;
        return mostParametersAccepted;
    }

    /**
     * Validates the command parameters and stores the parsed values in this instance.
     *
     * <p>Positional parameters read, by index:
     * <ol start="0">
     *   <li>species network</li>
     *   <li>maximum branch length (parsed as {@code double})</li>
     *   <li>improvement threshold (parsed as {@code double})</li>
     *   <li>maximum optimizer evaluations per branch (parsed as {@code int})</li>
     *   <li>identifier list naming the gene trees</li>
     * </ol>
     * Switches: {@code a} (allele/taxon map), {@code am} (selects the Box scoring strategy)
     * and {@code p} (print detail).
     *
     * <p>Numeric text that cannot be parsed is reported through {@code errorDetected} and
     * makes the check fail, rather than escaping as a {@link NumberFormatException}. A missing
     * gene-tree list leaves {@code _geneTrees} empty instead of causing a
     * {@link NullPointerException}.
     *
     * <p>NOTE(review): index 4 is read here while {@code getMinNumParams()} returns 4 —
     * confirm the declared minimum.
     *
     * @return {@code true} when every check passed, {@code false} otherwise
     */
    @Override // edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBase
    protected boolean checkParamsForCommand() {
        this._speciesNetwork = assertAndGetNetwork(0);
        boolean noError = this._speciesNetwork != null;

        ParameterIdent maxBranchLengthParam = assertParameterIdent(1);
        if (maxBranchLengthParam == null) {
            noError = false;
        } else {
            try {
                this._maxBranchLength = Double.parseDouble(maxBranchLengthParam.Content);
            } catch (NumberFormatException e) {
                reportUnparsableNumber("the maximum branch length", maxBranchLengthParam.Content);
                noError = false;
            }
        }

        ParameterIdent improvementThresholdParam = assertParameterIdent(2);
        if (improvementThresholdParam == null) {
            noError = false;
        } else {
            try {
                this._improvementThreshold = Double.parseDouble(improvementThresholdParam.Content);
            } catch (NumberFormatException e) {
                reportUnparsableNumber("the improvement threshold", improvementThresholdParam.Content);
                noError = false;
            }
        }

        ParameterIdent maxAttemptsParam = assertParameterIdent(3);
        if (maxAttemptsParam == null) {
            noError = false;
        } else {
            try {
                this._maxAssigmentAttemptsPerBranchParam = Integer.parseInt(maxAttemptsParam.Content);
            } catch (NumberFormatException e) {
                reportUnparsableNumber("the maximum number of attempts per branch", maxAttemptsParam.Content);
                noError = false;
            }
        }

        this._geneTreeParam = assertParameterIdentList(4);
        noError = noError && this._geneTreeParam != null;
        this._geneTrees = new LinkedList<>();
        // Guard the dereference: the list parameter may be absent or of the wrong kind.
        if (this._geneTreeParam != null) {
            for (String geneTreeIdent : this._geneTreeParam.Elements) {
                // Short-circuits after the first failure, so later identifiers are not checked.
                noError = noError && assertNetworkExists(geneTreeIdent, this._geneTreeParam.getLine(), this._geneTreeParam.getColumn());
                if (noError) {
                    this._geneTrees.add(this.sourceIdentToNetwork.get(geneTreeIdent));
                }
            }
        }

        ParamExtractorAllelMap alleleMapSwitch = new ParamExtractorAllelMap("a", this.params, this.errorDetected);
        if (alleleMapSwitch.ContainsSwitch) {
            noError = noError && alleleMapSwitch.IsValidMap;
            if (alleleMapSwitch.IsValidMap) {
                this._taxonMap = alleleMapSwitch.ValueMap;
            }
        }

        // "am" swaps the scoring strategy; without it the gene-tree-distribution one is used.
        if (new ParamExtractor("am", this.params, this.errorDetected).ContainsSwitch) {
            this._computeGTProbStrategy = this._computeGTProbStrategyBox;
        } else {
            this._computeGTProbStrategy = this._computeGTProbStrategyCalGTProb;
        }

        ParamExtractor printDetailSwitch = new ParamExtractor("p", this.params, this.errorDetected);
        if (printDetailSwitch.ContainsSwitch) {
            this._printDetail = true;
        }

        noError = noError && checkForUnknownSwitches("p", "a", "am");
        // NOTE(review): the "am" extractor is not passed along here — confirm that is intended.
        checkAndSetOutFile(alleleMapSwitch, printDetailSwitch);
        return noError;
    }

    /**
     * Reports numeric parameter text that could not be parsed, at the position of the command.
     *
     * @param what description of the parameter, used in the message
     * @param text the offending parameter text
     */
    private void reportUnparsableNumber(String what, String text) {
        this.errorDetected.execute("Expected a number for " + what + " but found '" + text + "'.", Integer.valueOf(this._motivatingCommand.getLine()), Integer.valueOf(this._motivatingCommand.getColumn()));
    }

    /**
     * Searches for branch lengths and inheritance probabilities of the species network that
     * maximise the score returned by the selected strategy, and formats the outcome.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Parse every gene tree and collapse trees with the same rooted topology into a list
     *       of distinct trees plus a parallel list of counts.</li>
     *   <li>Build a mutable network from the species network and replace unset branch
     *       lengths / probabilities with starting values.</li>
     *   <li>Repeat rounds of one-dimensional Brent maximisations (one per eligible branch
     *       length, one per node with two parents) until a round brings no improvement or the
     *       improvement falls below {@code _improvementThreshold}.</li>
     *   <li>Print the resulting network as Rich Newick and return it with the best score.</li>
     * </ol>
     *
     * @return {@code "\nTotal log probability: <score>: <network>"}
     * @throws IllegalStateException if the best score after a round is lower than before it
     */
    @Override // edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBaseFileOut
    protected String produceResult() {
        StringBuffer stringBuffer = new StringBuffer();
        // arrayList: distinct gene-tree topologies; arrayList2: at the same index, the number
        // of input trees sharing that topology.
        final ArrayList arrayList = new ArrayList();
        final ArrayList arrayList2 = new ArrayList();
        Iterator<NetworkNonEmpty> it = this._geneTrees.iterator();
        while (it.hasNext()) {
            NewickReader newickReader = new NewickReader(new StringReader(NetworkTransformer.toENewickTree(it.next())));
            STITree<Double> sTITree = new STITree<>(true);
            try {
                newickReader.readTree(sTITree);
            } catch (Exception e) {
                // The failure is reported, but processing continues with sTITree as it is.
                // NOTE(review): confirm that continuing after a parse error is intended.
                this.errorDetected.execute(e.getMessage(), Integer.valueOf(this._motivatingCommand.getLine()), Integer.valueOf(this._motivatingCommand.getColumn()));
            }
            // Linear search for an already recorded tree with the same rooted topology;
            // i ends up as its index when z becomes true.
            boolean z = false;
            int i = 0;
            Iterator it2 = arrayList.iterator();
            while (true) {
                if (!it2.hasNext()) {
                    break;
                }
                if (Trees.haveSameRootedTopology((Tree) it2.next(), sTITree)) {
                    z = true;
                    break;
                }
                i++;
            }
            if (z) {
                // Known topology: bump its count.
                arrayList2.set(i, Integer.valueOf(((Integer) arrayList2.get(i)).intValue() + 1));
            } else {
                // New topology: record it with a count of one.
                arrayList.add(sTITree);
                arrayList2.add(1);
            }
        }
        final BniNetwork makeNetwork = new NetworkFactoryFromRNNetwork().makeNetwork((edu.rice.cs.bioinfo.library.language.richnewick._1_0.reading.ast.NetworkNonEmpty) this._speciesNetwork);
        // Give every edge a usable starting value before optimising.
        Iterator it3 = makeNetwork.bfs().iterator();
        while (it3.hasNext()) {
            NetNode netNode = (NetNode) it3.next();
            for (NetNode netNode2 : netNode.getChildren()) {
                double parentDistance = netNode2.getParentDistance(netNode);
                // Negative infinity or NaN is treated as "no length given" and replaced by 1.0.
                if (parentDistance == Double.NEGATIVE_INFINITY || Double.isNaN(parentDistance)) {
                    parentDistance = 1.0d;
                }
                netNode2.setParentDistance(netNode, parentDistance);
                if (netNode2.getParentNumber() == 2) {
                    Iterator it4 = netNode2.getParents().iterator();
                    while (it4.hasNext()) {
                        // NOTE(review): the probability is tested for each parent in turn but
                        // assigned for netNode (the parent currently visited) — confirm.
                        if (netNode2.getParentProbability((NetNode) it4.next()) == 1.0d) {
                            netNode2.setParentProbability(netNode, 0.5d);
                        }
                    }
                }
            }
        }
        // z2: keep running rounds. container: best score seen so far, shared with the
        // objective functions below, which update it as a side effect.
        boolean z2 = true;
        final Container container = new Container(this._computeGTProbStrategy.execute(makeNetwork, arrayList, arrayList2));
        for (int i2 = 0; i2 < this._assigmentRounds && z2; i2++) {
            // Score at the start of this round, used for the stopping tests at the end.
            double doubleValue = ((Double) container.getContents()).doubleValue();
            // Optimisation tasks collected for this round; they run after both passes below.
            ArrayList arrayList3 = new ArrayList();
            // Pass 1: one branch-length task per edge from the root to a non-leaf child.
            Iterator it5 = makeNetwork.bfs().iterator();
            while (it5.hasNext()) {
                final NetNode netNode3 = (NetNode) it5.next();
                for (final NetNode netNode4 : netNode3.getChildren()) {
                    if (netNode3.isRoot() && !netNode4.isLeaf()) {
                        arrayList3.add(new Proc() { // from class: edu.rice.cs.bioinfo.programs.phylonet.commands.SearchBranchLengthsMaxGTProb.1
                            @Override // edu.rice.cs.bioinfo.library.programming.Proc
                            public void execute() {
                                UnivariateFunction univariateFunction = new UnivariateFunction() { // from class: edu.rice.cs.bioinfo.programs.phylonet.commands.SearchBranchLengthsMaxGTProb.1.1
                                    /**
                                     * Scores the network with this edge set to length d. The new
                                     * length is kept only if it beats the best score so far;
                                     * otherwise the previous length is restored.
                                     */
                                    @Override // org.apache.commons.math3.analysis.UnivariateFunction
                                    public double value(double d) {
                                        double parentDistance2 = netNode4.getParentDistance(netNode3);
                                        netNode4.setParentDistance(netNode3, d);
                                        double doubleValue2 = ((Double) SearchBranchLengthsMaxGTProb.this._computeGTProbStrategy.execute(makeNetwork, arrayList, arrayList2)).doubleValue();
                                        // The printed text goes to a throwaway writer and is never read.
                                        // NOTE(review): looks like leftover debugging — confirm before removing.
                                        new RnNewickPrinter().print(makeNetwork, new StringWriter());
                                        if (doubleValue2 > ((Double) container.getContents()).doubleValue()) {
                                            container.setContents(Double.valueOf(doubleValue2));
                                        } else {
                                            netNode4.setParentDistance(netNode3, parentDistance2);
                                        }
                                        return doubleValue2;
                                    }
                                };
                                try {
                                    // Maximise over (Double.MIN_VALUE, _maxBranchLength) with the
                                    // configured evaluation limit.
                                    new BrentOptimizer(1.0E-12d, 1.0E-16d).optimize(SearchBranchLengthsMaxGTProb.this._maxAssigmentAttemptsPerBranchParam, univariateFunction, GoalType.MAXIMIZE, Double.MIN_VALUE, SearchBranchLengthsMaxGTProb.this._maxBranchLength);
                                } catch (TooManyEvaluationsException e2) {
                                    // Ignored: the objective has already kept the best value it saw.
                                }
                            }
                        });
                    }
                }
            }
            // Pass 2: one inheritance-probability task per non-root node with two parents.
            Iterator it6 = makeNetwork.bfs().iterator();
            while (it6.hasNext()) {
                final NetNode netNode5 = (NetNode) it6.next();
                if (!netNode5.isRoot() && netNode5.getParentNumber() == 2) {
                    Iterator it7 = netNode5.getParents().iterator();
                    final NetNode netNode6 = (NetNode) it7.next();
                    final NetNode netNode7 = (NetNode) it7.next();
                    arrayList3.add(new Proc() { // from class: edu.rice.cs.bioinfo.programs.phylonet.commands.SearchBranchLengthsMaxGTProb.2
                        @Override // edu.rice.cs.bioinfo.library.programming.Proc
                        public void execute() {
                            UnivariateFunction univariateFunction = new UnivariateFunction() { // from class: edu.rice.cs.bioinfo.programs.phylonet.commands.SearchBranchLengthsMaxGTProb.2.1
                                /**
                                 * Scores the network with probability d for the first parent
                                 * and 1 - d for the second. The pair is kept only if it beats
                                 * the best score so far; otherwise it is restored.
                                 */
                                @Override // org.apache.commons.math3.analysis.UnivariateFunction
                                public double value(double d) {
                                    double parentProbability = netNode5.getParentProbability(netNode6);
                                    netNode5.setParentProbability(netNode6, d);
                                    netNode5.setParentProbability(netNode7, 1.0d - d);
                                    double doubleValue2 = ((Double) SearchBranchLengthsMaxGTProb.this._computeGTProbStrategy.execute(makeNetwork, arrayList, arrayList2)).doubleValue();
                                    if (doubleValue2 > ((Double) container.getContents()).doubleValue()) {
                                        container.setContents(Double.valueOf(doubleValue2));
                                    } else {
                                        netNode5.setParentProbability(netNode6, parentProbability);
                                        netNode5.setParentProbability(netNode7, 1.0d - parentProbability);
                                    }
                                    return doubleValue2;
                                }
                            };
                            try {
                                // Maximise over (0, 1) with the configured evaluation limit.
                                new BrentOptimizer(1.0E-12d, 1.0E-16d).optimize(SearchBranchLengthsMaxGTProb.this._maxAssigmentAttemptsPerBranchParam, univariateFunction, GoalType.MAXIMIZE, 0.0d, 1.0d);
                            } catch (TooManyEvaluationsException e2) {
                                // Ignored: the objective has already kept the best value it saw.
                            }
                        }
                    });
                }
            }
            // Run the collected tasks in the order they were added.
            Iterator it8 = arrayList3.iterator();
            while (it8.hasNext()) {
                ((Proc) it8.next()).execute();
            }
            if (((Double) container.getContents()).doubleValue() == doubleValue) {
                // No task improved the score: stop.
                z2 = false;
            } else {
                if (((Double) container.getContents()).doubleValue() <= doubleValue) {
                    throw new IllegalStateException("Should never have decreased prob.");
                }
                // e^(new - old) - 1 is the relative gain when the score is a log value;
                // stop once that gain drops below the threshold.
                if (Math.pow(2.718281828459045d, ((Double) container.getContents()).doubleValue() - doubleValue) - 1.0d < this._improvementThreshold) {
                    z2 = false;
                }
            }
        }
        // Serialise the fitted network and hand it to richNewickGenerated before reporting.
        RnNewickPrinter rnNewickPrinter = new RnNewickPrinter();
        StringWriter stringWriter = new StringWriter();
        rnNewickPrinter.print(makeNetwork, stringWriter);
        String stringWriter2 = stringWriter.toString();
        richNewickGenerated(stringWriter2);
        stringBuffer.append("\nTotal log probability: " + container.getContents() + ": " + stringWriter2);
        return stringBuffer.toString();
    }

    /**
     * Bridge method (marked synthetic by the decompiler) that forwards the call, with its
     * argument unchanged, to the superclass implementation.
     *
     * @param proc1 passed through to {@code super.executeCommandHelp}
     * @throws IOException propagated from the superclass implementation
     */
    @Override // edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBaseFileOut, edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBase
    public /* bridge */ /* synthetic */ void executeCommandHelp(Proc1 proc1) throws IOException {
        super.executeCommandHelp(proc1);
    }

    /**
     * Bridge method (marked synthetic by the decompiler) that returns whatever the superclass
     * implementation returns.
     *
     * @return the result of {@code super.getRedirectOutputToFile()}
     */
    @Override // edu.rice.cs.bioinfo.programs.phylonet.commands.CommandBaseFileOut
    public /* bridge */ /* synthetic */ boolean getRedirectOutputToFile() {
        return super.getRedirectOutputToFile();
    }
}
