package edu.rice.cs.bioinfo.programs.phylonet.algos.network;

import edu.rice.cs.bioinfo.programs.phylonet.structs.network.NetNode;
import edu.rice.cs.bioinfo.programs.phylonet.structs.network.Network;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.TMutableNode;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.TNode;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.Tree;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.sti.STITree;
import edu.rice.cs.bioinfo.programs.phylonet.structs.tree.model.sti.STITreeCluster;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/rice/cs/bioinfo/programs/phylonet/algos/network/MDCOnNetwork.class */
/**
 * Counts extra lineages of gene trees inside a phylogenetic network (the criterion used by
 * "minimize deep coalescence", MDC).
 *
 * <p>The network is first unfolded into a multi-labelled tree: a node reached along k distinct
 * root paths is emitted k times, with copies named {@code name_1 .. name_k}. A reticulation
 * node's name is additionally suffixed with {@code TO<parent name>}, so its copies are grouped
 * per incoming edge. Each gene tree is then scored against that unfolded tree.
 */
public class MDCOnNetwork {
    /** When {@code true}, every evaluated allele-to-copy assignment and its score are printed. */
    boolean _printDetail = false;

    /**
     * Enables or disables printing of every evaluated allele-to-copy assignment to stdout.
     *
     * @param z {@code true} to print details
     */
    public void setPrintDetails(boolean z) {
        this._printDetail = z;
    }

    /**
     * Computes, for each gene tree, the minimum number of extra lineages over all ways of
     * placing its alleles onto the copies of their species in the unfolded network.
     *
     * <p>Note: the network is modified in place. Unary nodes whose child has in-degree 1 are
     * spliced out, and nodes with an empty name receive generated names.
     *
     * @param network the species network (branch lengths and probabilities read as doubles)
     * @param list the gene trees to score
     * @param map gene-tree leaf (allele) name to species name; if {@code null}, the leaves of
     *     each gene tree are assumed to be named after their species
     * @return one minimum extra-lineage count per gene tree, in the order of {@code list}
     * @throws IllegalArgumentException if a gene-tree leaf has no species mapping, or its
     *     species is not a node of the network
     */
    public List<Integer> countExtraCoal(Network network, List<Tree> list, Map<String, String> map) {
        @SuppressWarnings("unchecked")
        Network<Double> typedNetwork = network;

        // Number of copies of each (possibly renamed) network node in the unfolded tree.
        Map<String, Integer> copyCounts = new HashMap<>();
        Tree unfolded = networkToTree(typedNetwork, copyCounts);

        // Maps "name_k" back to "name" for duplicated nodes only, so that the lineages of all
        // copies of the same network node can be pooled when scoring.
        Map<String, String> copyToNode = new HashMap<>();
        for (Map.Entry<String, Integer> entry : copyCounts.entrySet()) {
            int copies = entry.getValue();
            if (copies > 1) {
                for (int k = 1; k <= copies; k++) {
                    copyToNode.put(entry.getKey() + "_" + k, entry.getKey());
                }
            }
        }

        List<Integer> result = new ArrayList<>();
        for (Tree geneTree : list) {
            result.add(minExtraCoal(geneTree, unfolded, map, copyCounts, copyToNode));
        }
        return result;
    }

    /**
     * Scores one gene tree: the minimum extra-lineage count over every assignment of each
     * allele to one copy of its species.
     *
     * @param geneTree the gene tree to score
     * @param unfolded the unfolded network tree
     * @param alleleToSpecies allele-to-species map, or {@code null} for identity on this tree
     * @param copyCounts number of copies of each node name in {@code unfolded}
     * @param copyToNode copy name to original node name, for duplicated nodes only
     * @return the minimum extra-lineage count
     */
    private int minExtraCoal(Tree geneTree, Tree unfolded, Map<String, String> alleleToSpecies,
            Map<String, Integer> copyCounts, Map<String, String> copyToNode) {
        String[] alleles = geneTree.getLeaves();
        if (alleleToSpecies == null) {
            // Built per gene tree (not once from the first tree) so that gene trees with
            // differing leaf sets are all covered.
            alleleToSpecies = new HashMap<>();
            for (String allele : alleles) {
                alleleToSpecies.put(allele, allele);
            }
        }

        // Group the alleles by species, in order of first appearance.
        List<String> species = new ArrayList<>();
        List<List<String>> allelesOfSpecies = new ArrayList<>();
        List<Integer> copiesOfSpecies = new ArrayList<>();
        for (String allele : alleles) {
            String sp = alleleToSpecies.get(allele);
            int pos = species.indexOf(sp);
            if (pos >= 0) {
                allelesOfSpecies.get(pos).add(allele);
                continue;
            }
            if (sp == null) {
                throw new IllegalArgumentException(
                        "No species mapping for gene-tree leaf '" + allele + "'");
            }
            Integer copies = copyCounts.get(sp);
            if (copies == null) {
                throw new IllegalArgumentException("Species '" + sp + "' of gene-tree leaf '"
                        + allele + "' is not a node of the network");
            }
            species.add(sp);
            List<String> group = new ArrayList<>();
            group.add(allele);
            allelesOfSpecies.add(group);
            copiesOfSpecies.add(copies);
        }

        // assignment.get(s)[a] is the 1-based copy of species s on which allele a is placed.
        List<int[]> assignment = new ArrayList<>();
        for (List<String> group : allelesOfSpecies) {
            int[] copyIds = new int[group.size()];
            Arrays.fill(copyIds, 1);
            assignment.add(copyIds);
        }

        // Exhaustive search. The number of assignments is the product over species of
        // copies^alleles, so it grows exponentially with the alleles per duplicated species.
        int best = Integer.MAX_VALUE;
        do {
            Map<String, String> alleleToCopy = new HashMap<>();
            for (int s = 0; s < species.size(); s++) {
                String sp = species.get(s);
                List<String> group = allelesOfSpecies.get(s);
                int[] copyIds = assignment.get(s);
                for (int a = 0; a < group.size(); a++) {
                    alleleToCopy.put(group.get(a), sp + "_" + copyIds[a]);
                }
            }
            int extra = countExtraCoal(geneTree, unfolded, alleleToCopy, copyToNode);
            if (this._printDetail) {
                System.out.println(alleleToCopy + ":  " + extra);
            }
            best = Math.min(extra, best);
        } while (mergeNumberAddOne(assignment, copiesOfSpecies));
        return best;
    }

    /**
     * Advances a mixed-radix counter (odometer) to the next allele-to-copy assignment.
     *
     * <p>Each array in {@code list} holds one digit per allele of a species; digit values range
     * over {@code 1..list2.get(i)}. The first digit that can still grow is incremented, and
     * every digit before it is reset to 1.
     *
     * @param list the per-species digit arrays, modified in place
     * @param list2 the maximum digit value (number of copies) for each species
     * @return {@code true} if a new assignment was produced; {@code false} once every
     *     combination has been visited (all digits are then back at 1)
     */
    private boolean mergeNumberAddOne(List<int[]> list, List<Integer> list2) {
        for (int i = 0; i < list.size(); i++) {
            int[] digits = list.get(i);
            int max = list2.get(i);
            for (int d = 0; d < digits.length; d++) {
                if (digits[d] < max) {
                    digits[d]++;
                    return true;
                }
                digits[d] = 1; // carry into the next digit
            }
        }
        return false;
    }

    /**
     * Unfolds the network into a multi-labelled tree by breadth-first expansion from the root.
     * A child is re-expanded once per path, so nodes below reticulations are duplicated.
     *
     * <p>Side effects on the network: binary (unary) nodes are spliced out first, and children
     * with an empty name are renamed {@code "hnode<n>"}, with n seeded from the current time.
     *
     * @param network the network to unfold (modified in place)
     * @param map output: for each node name, how many copies were created. Copies are named
     *     {@code name_1 .. name_k}.
     * @return the unfolded tree
     */
    private Tree networkToTree(Network<Double> network, Map<String, Integer> map) {
        removeBinaryNodes(network);
        STITree tree = new STITree();
        Deque<NetNode<Double>> netQueue = new ArrayDeque<>();
        Deque<TMutableNode> treeQueue = new ArrayDeque<>();
        netQueue.offer(network.getRoot());
        treeQueue.offer((TMutableNode) tree.getRoot());
        long nameSeed = System.currentTimeMillis();
        while (!netQueue.isEmpty()) {
            NetNode<Double> parent = netQueue.poll();
            TMutableNode parentCopy = treeQueue.poll();
            for (NetNode<Double> child : parent.getChildren()) {
                if (child.getName().equals("")) {
                    child.setName("hnode" + nameSeed++);
                }
                String name = child.getName();
                if (child.isNetworkNode()) {
                    // Reticulation copies are distinguished by the parent they were reached from.
                    name = child.getName() + "TO" + parent.getName();
                }
                Integer previous = map.get(name);
                int copy = (previous == null ? 0 : previous) + 1;
                map.put(name, copy);
                TMutableNode childCopy = parentCopy.createChild(name + "_" + copy);
                double distance = child.getParentDistance(parent);
                // NEGATIVE_INFINITY is recorded as a zero-length branch (presumably the
                // "no distance" sentinel of NetNode; confirm).
                childCopy.setParentDistance(
                        distance == Double.NEGATIVE_INFINITY ? 0.0d : distance);
                netQueue.offer(child);
                treeQueue.offer(childCopy);
            }
        }
        return tree;
    }

    /**
     * Counts extra lineages of one gene tree in the unfolded network tree, under a fixed
     * allele-to-copy assignment.
     *
     * <p>Every cluster of {@code tree2} is scored, with two exceptions: clusters spanning all
     * leaves, and unary nodes with zero branch length. Scoring first computes the number of
     * gene lineages leaving the cluster. A node that was not duplicated contributes
     * {@code max(0, lineages - 1)} directly. For duplicated nodes, the lineages of all copies
     * are summed first, and the pooled total contributes {@code max(0, total - 1)}.
     *
     * @param tree the gene tree
     * @param tree2 the unfolded network tree
     * @param map gene-tree leaf name to leaf name of {@code tree2} (e.g. {@code "A_2"})
     * @param map2 copy name to original node name, for duplicated nodes only
     * @return the extra-lineage count
     */
    private int countExtraCoal(Tree tree, Tree tree2, Map<String, String> map,
            Map<String, String> map2) {
        String[] leaves = tree2.getLeaves();
        Map<String, Integer> leafIndex = firstIndexMap(leaves);
        int extra = 0;
        Map<TNode, BitSet> clusterOf = new HashMap<>();
        Map<String, Integer> pooledLineages = new HashMap<>();
        for (TNode node : tree2.postTraverse()) {
            BitSet bits = new BitSet(leaves.length);
            if (node.isLeaf()) {
                Integer index = leafIndex.get(node.getName());
                if (index != null) {
                    bits.set(index);
                }
            } else {
                for (TNode child : node.getChildren()) {
                    bits.or(clusterOf.get(child));
                }
            }
            clusterOf.put(node, bits);

            // A unary node with zero branch length has the same cluster as its child; skip it.
            if (node.getChildCount() == 1 && node.getParentDistance() == 0.0d) {
                continue;
            }
            STITreeCluster cluster = new STITreeCluster(leaves);
            cluster.setCluster(bits);
            if (cluster.getClusterSize() >= leaves.length) {
                continue; // clusters spanning every leaf are not scored
            }
            int lineages = getClusterCoalNum(tree, cluster, map);
            String name = node.getName();
            String original = (name == null) ? null : map2.get(name);
            if (original == null) {
                extra += Math.max(0, lineages - 1);
            } else {
                Integer pooled = pooledLineages.get(original);
                pooledLineages.put(original, (pooled == null ? 0 : pooled) + lineages);
            }
        }
        for (int pooled : pooledLineages.values()) {
            extra += Math.max(0, pooled - 1);
        }
        return extra;
    }

    /**
     * Returns the number of gene-tree lineages leaving the given cluster, i.e. the number of
     * maximal gene-tree clades whose leaves all map into the cluster. At a polytomy (more than
     * two children) that is only partially inside the cluster, the children lying inside it are
     * still merged into a single lineage.
     *
     * @param tree the gene tree
     * @param sTITreeCluster the species-tree cluster
     * @param map gene-tree leaf name to a taxon of {@code sTITreeCluster}
     * @return the number of lineages leaving the cluster
     * @throws IllegalArgumentException if a gene-tree leaf maps to no taxon of the cluster
     */
    private int getClusterCoalNum(Tree tree, STITreeCluster sTITreeCluster,
            Map<String, String> map) {
        String[] taxa = sTITreeCluster.getTaxa();
        Map<String, Integer> taxonIndex = firstIndexMap(taxa);
        int lineages = 0;
        Map<TNode, BitSet> taxaBelow = new HashMap<>();
        for (TNode node : tree.postTraverse()) {
            BitSet bits = new BitSet(taxa.length);
            if (node.isLeaf()) {
                String target = map.get(node.getName());
                Integer index = taxonIndex.get(target);
                if (index == null) {
                    throw new IllegalArgumentException("Gene-tree leaf '" + node.getName()
                            + "' is mapped to '" + target
                            + "', which is not a leaf of the species tree");
                }
                bits.set(index);
                if (sTITreeCluster.containsCluster(bits)) {
                    lineages++;
                }
            } else {
                int childCount = node.getChildCount();
                int childrenInside = 0;
                for (TNode child : node.getChildren()) {
                    BitSet childBits = taxaBelow.get(child);
                    bits.or(childBits);
                    if (childCount > 2 && sTITreeCluster.containsCluster(childBits)) {
                        childrenInside++;
                    }
                }
                if (sTITreeCluster.containsCluster(bits)) {
                    // Whole clade inside: its children coalesce into one lineage.
                    lineages -= childCount - 1;
                } else if (childrenInside > 1) {
                    // Polytomy partially inside: the inside children coalesce into one lineage.
                    lineages -= childrenInside - 1;
                }
            }
            taxaBelow.put(node, bits);
        }
        return lineages;
    }

    /**
     * Maps each distinct string in {@code names} to the index of its first occurrence.
     *
     * @param names the strings to index
     * @return name to first index
     */
    private static Map<String, Integer> firstIndexMap(String[] names) {
        Map<String, Integer> index = new HashMap<>();
        for (int i = 0; i < names.length; i++) {
            if (!index.containsKey(names[i])) {
                index.put(names[i], i);
            }
        }
        return index;
    }

    /**
     * Splices out every node with in-degree 1 and out-degree 1 whose child also has in-degree 1.
     * The parent adopts the child directly. The new branch length is the sum of the two lengths,
     * and the new inheritance probability is the product of the two probabilities.
     *
     * @param network the network to modify in place
     */
    private void removeBinaryNodes(Network<Double> network) {
        // Collect first, then modify, so the BFS traversal is not disturbed.
        List<NetNode<Double>> candidates = new ArrayList<>();
        for (NetNode<Double> node : network.bfs()) {
            if (node.getIndeg() == 1 && node.getOutdeg() == 1) {
                candidates.add(node);
            }
        }
        for (NetNode<Double> node : candidates) {
            NetNode<Double> child = node.getChildren().iterator().next();
            if (child.getIndeg() == 1) {
                NetNode<Double> parent = node.getParents().iterator().next();
                double distance = node.getParentDistance(parent) + child.getParentDistance(node);
                double probability =
                        node.getParentProbability(parent) * child.getParentProbability(node);
                parent.removeChild(node);
                node.removeChild(child);
                parent.adoptChild(child, distance);
                child.setParentProbability(parent, probability);
            }
        }
    }
}
