package pal.eval;

import java.util.Vector;
import pal.alignment.SitePattern;
import pal.misc.Identifier;
import pal.misc.PalObjectEvent;
import pal.misc.PalObjectListener;
import pal.misc.Utils;
import pal.substmodel.RateMatrix;
import pal.tree.Node;
import pal.tree.Tree;

/* loaded from: input_file:pal/eval/FastLikelihoodCalculator.class */
public class FastLikelihoodCalculator implements PalObjectListener, LikelihoodCalculator {
    /** Root of the calculator's node hierarchy; stays null until {@code setTree} is first called. */
    RootNode root_;
    /**
     * Cache-invalidation flag: set to true by the listener callbacks, {@code setRateMatrix} and
     * {@code updateSitePattern}; cleared by {@code RootNode.computeLikelihood()}.
     */
    boolean modelChanged_;
    /** Number of site patterns, taken from the site pattern ({@code getNumberOfPatterns()} / {@code numPatterns}). */
    int numberOfSites_;
    /** Number of states of the site pattern's data type ({@code getDataType().getNumStates()}). */
    int numberOfStates_;
    /** Rate matrix used to obtain transition probabilities; null until {@code setRateMatrix} is called and after {@code release()}. */
    RateMatrix model_;
    /** Site pattern currently supplying the sequences and site weights. */
    SitePattern sitePattern_;
    /** A branch length is treated as changed when it differs from the last used length by more than this amount. */
    private static double THRESHOLD = 1.0E-12d;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:pal/eval/FastLikelihoodCalculator$InternalNode.class */
    /**
     * Calculation node mirroring an internal (non-leaf) tree node.
     *
     * <p>It owns one {@link NNode} per child of its peer and caches, per child, the most recent
     * site/state probability array that child produced. A child returns {@code null} from
     * {@link NNode#calculateSiteStateProbabilities()} to signal "unchanged, keep using the cached
     * array".
     */
    public class InternalNode extends NNode {
        /** Calculation nodes for the peer's children, in the peer's child order. */
        private NNode[] children_;
        /** Cached result per child, indexed [child][site][state]; always the same length as {@code children_}. */
        private double[][][] childSiteStateProbs_;
        /** Scratch buffer: for the site being processed, the product over children of their state probabilities. */
        double[] endStateProbs_;
        private final FastLikelihoodCalculator this$0;

        /**
         * Builds this node and, recursively, calculation nodes for every child of {@code node}.
         *
         * @param fastLikelihoodCalculator the owning calculator
         * @param node the tree node this calculation node mirrors
         */
        public InternalNode(FastLikelihoodCalculator fastLikelihoodCalculator, Node node) {
            super(fastLikelihoodCalculator, node);
            this.this$0 = fastLikelihoodCalculator;
            this.children_ = new NNode[node.getChildCount()];
            for (int i = 0; i < this.children_.length; i++) {
                this.children_[i] = fastLikelihoodCalculator.create(node.getChild(i));
            }
            // Only the first dimension is allocated; each slot later holds a child's own result array.
            this.childSiteStateProbs_ = new double[this.children_.length][][];
        }

        /**
         * Sizes this node's buffers for the calculator's current state count and forwards the
         * model to every child.
         */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public void setModel(RateMatrix rateMatrix) {
            super.setModel(rateMatrix);
            if (this.endStateProbs_ == null || this.this$0.numberOfStates_ != this.endStateProbs_.length) {
                this.endStateProbs_ = new double[this.this$0.numberOfStates_];
            }
            for (int i = 0; i < this.children_.length; i++) {
                this.children_[i].setModel(rateMatrix);
            }
        }

        /** Loads the sequence for this node (if the pattern has one for it) and for every child. */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public void setupSequences(SitePattern sitePattern) {
            super.setupSequences(sitePattern);
            for (int i = 0; i < this.children_.length; i++) {
                this.children_[i].setupSequences(sitePattern);
            }
        }

        /** @return always {@code false} */
        public boolean isLeaf() {
            return false;
        }

        /**
         * Asks every child for its site/state probabilities and refreshes the per-child cache
         * with each non-null answer.
         *
         * @return {@code true} when no child reported a new result (every child returned null)
         * @throws RuntimeException if a child reports "unchanged" but nothing is cached for it
         */
        private final boolean populateChildSiteStateProbs() {
            boolean allUnchanged = true;
            for (int i = 0; i < this.children_.length; i++) {
                double[][] childProbs = this.children_[i].calculateSiteStateProbabilities();
                if (childProbs != null) {
                    this.childSiteStateProbs_[i] = childProbs;
                    allUnchanged = false;
                } else if (this.childSiteStateProbs_[i] == null) {
                    throw new RuntimeException("Assertion error : Not as should be!");
                }
            }
            return allUnchanged;
        }

        /** @return the number of child calculation nodes */
        protected final int getNumberOfChildren() {
            return this.children_.length;
        }

        /**
         * Re-targets this calculation node (and its subtree) at a new tree node, reusing existing
         * child calculation nodes where possible.
         *
         * @param node the new peer
         * @return this node, or a freshly created node when {@code node} has no children
         */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public NNode switchNodes(Node node) {
            int childCount = node.getChildCount();
            if (childCount == 0) {
                return this.this$0.create(node);
            }
            if (childCount != this.children_.length) {
                NNode[] newChildren = new NNode[childCount];
                for (int i = 0; i < childCount; i++) {
                    if (i < this.children_.length) {
                        newChildren[i] = this.children_[i].switchNodes(node.getChild(i));
                    } else {
                        newChildren[i] = this.this$0.create(node.getChild(i));
                    }
                }
                this.children_ = newChildren;
                // Keep the per-child cache the same length as children_: the product loops iterate
                // over the cache, so a stale length would either overflow or multiply in factors
                // from children that no longer exist.
                double[][][] resizedCache = new double[childCount][][];
                System.arraycopy(this.childSiteStateProbs_, 0, resizedCache, 0,
                        Math.min(childCount, this.childSiteStateProbs_.length));
                this.childSiteStateProbs_ = resizedCache;
                // The set of factors changed, so this node's own cached result is invalid even if
                // every surviving child reports "unchanged".
                this.lastLength_ = Double.NEGATIVE_INFINITY;
            } else {
                for (int i = 0; i < childCount; i++) {
                    this.children_[i] = this.children_[i].switchNodes(node.getChild(i));
                }
            }
            setPeer(node);
            return this;
        }

        /**
         * Computes, for every site and every state at the parent end of this node's branch, the
         * sum over end states of (transition probability x product of the children's
         * probabilities for that end state).
         *
         * @return the refreshed site/state array, or {@code null} when no child changed, the
         *         model is unchanged and the branch length is unchanged
         */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public double[][] calculateSiteStateProbabilities() {
            // populateChildSiteStateProbs() must run first and unconditionally: it refreshes the cache.
            if (populateChildSiteStateProbs() && !this.this$0.modelChanged_ && !isBranchLengthChanged()) {
                return null;
            }
            double[][] transitionProbabilities = getTransitionProbabilities();
            double[][] siteStateProbabilities = getSiteStateProbabilities();
            int numberOfStates = this.this$0.numberOfStates_;
            for (int site = 0; site < siteStateProbabilities.length; site++) {
                for (int state = 0; state < numberOfStates; state++) {
                    double product = this.childSiteStateProbs_[0][site][state];
                    for (int child = 1; child < this.childSiteStateProbs_.length; child++) {
                        product *= this.childSiteStateProbs_[child][site][state];
                    }
                    this.endStateProbs_[state] = product;
                }
                for (int startState = 0; startState < numberOfStates; startState++) {
                    double sum = 0.0d;
                    for (int endState = 0; endState < numberOfStates; endState++) {
                        sum += transitionProbabilities[startState][endState] * this.endStateProbs_[endState];
                    }
                    siteStateProbabilities[site][startState] = sum;
                }
            }
            return siteStateProbabilities;
        }

        /**
         * Combines the children's probabilities without applying a branch of its own.
         *
         * @param dArr per-state weights applied to the per-state products
         * @param iArr per-site multipliers applied to each site's log value
         * @return sum over sites of {@code log(sum over states of dArr[state] * product) * iArr[site]}
         */
        public double calculateFinal(double[] dArr, int[] iArr) {
            populateChildSiteStateProbs();
            double total = 0.0d;
            for (int site = 0; site < this.this$0.numberOfSites_; site++) {
                double siteValue = 0.0d;
                for (int state = 0; state < this.this$0.numberOfStates_; state++) {
                    double product = this.childSiteStateProbs_[0][site][state];
                    for (int child = 1; child < this.childSiteStateProbs_.length; child++) {
                        product *= this.childSiteStateProbs_[child][site][state];
                    }
                    siteValue += dArr[state] * product;
                }
                total += Math.log(siteValue) * iArr[site];
            }
            return total;
        }

        /** @return every leaf calculation node below this node, in child order */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public LeafNode[] getLeafNodes() {
            Vector<LeafNode> leaves = new Vector<LeafNode>();
            for (int i = 0; i < this.children_.length; i++) {
                for (LeafNode leafNode : this.children_[i].getLeafNodes()) {
                    leaves.addElement(leafNode);
                }
            }
            LeafNode[] result = new LeafNode[leaves.size()];
            leaves.copyInto(result);
            return result;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:pal/eval/FastLikelihoodCalculator$LeafNode.class */
    /**
     * Calculation node mirroring a tip of the tree. Its site/state probabilities come straight
     * from the node's sequence and the transition probabilities along its branch.
     */
    public class LeafNode extends NNode {
        private final FastLikelihoodCalculator this$0;

        /**
         * @param fastLikelihoodCalculator the owning calculator
         * @param node the tree tip this calculation node mirrors
         */
        public LeafNode(FastLikelihoodCalculator fastLikelihoodCalculator, Node node) {
            super(fastLikelihoodCalculator, node);
            this.this$0 = fastLikelihoodCalculator;
        }

        /** @return always zero */
        public double computeLikelihood() {
            return 0.0d;
        }

        /** @return always {@code true} */
        public boolean isLeaf() {
            return true;
        }

        /**
         * Replaces the peer. When the new peer's identifier name differs from the current one,
         * the remembered branch length is reset so the next calculation is not skipped.
         */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        protected final void setPeer(Node node) {
            String currentName = peer_.getIdentifier().getName();
            String incomingName = node.getIdentifier().getName();
            boolean sameTip = currentName.equals(incomingName);
            if (!sameTip) {
                lastLength_ = Double.NEGATIVE_INFINITY;
            }
            peer_ = node;
        }

        /**
         * Re-targets this leaf at {@code node}; if {@code node} has children a leaf cannot
         * represent it, so a newly created node is returned instead.
         */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public NNode switchNodes(Node node) {
            if (node.getChildCount() == 0) {
                setPeer(node);
                return this;
            }
            return this$0.create(node);
        }

        /** @return a one-element array holding this leaf */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public LeafNode[] getLeafNodes() {
            LeafNode[] self = new LeafNode[1];
            self[0] = this;
            return self;
        }

        /**
         * Fills the site/state array from the sequence: a negative (unknown) state yields 1 for
         * every state, otherwise entry [site][state] is the transition probability
         * [state][observed].
         *
         * @return the refreshed array, or {@code null} when neither the model nor the branch
         *         length changed
         */
        @Override // pal.eval.FastLikelihoodCalculator.NNode
        public double[][] calculateSiteStateProbabilities() {
            boolean needsRefresh = this$0.modelChanged_ || isBranchLengthChanged();
            if (!needsRefresh) {
                return null;
            }
            final byte[] observed = getSequence();
            final double[][] transition = getTransitionProbabilitiesReverse();
            final double[][] result = getSiteStateProbabilities();
            for (int site = 0; site < observed.length; site++) {
                final byte observedState = observed[site];
                final boolean unknown = observedState < 0;
                final double[] row = result[site];
                for (int state = 0; state < this$0.numberOfStates_; state++) {
                    row[state] = unknown ? 1.0d : transition[state][observedState];
                }
            }
            return result;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:pal/eval/FastLikelihoodCalculator$NNode.class */
    /**
     * Base class for the calculator's mirror of a tree node. It holds the node's peer, its
     * sequence (if any), a transition-probability matrix for its branch and a reusable
     * site/state probability array, and remembers the branch length the matrix was computed for.
     */
    public abstract class NNode {
        /** Transition probabilities for this node's branch, sized [states][states]. */
        private double[][] transitionProbs_;
        /**
         * Branch length the cached values were computed for. Negative infinity means "nothing
         * valid is cached", because {@link #isBranchLengthChanged()} then always reports a change.
         */
        double lastLength_ = Double.NEGATIVE_INFINITY;
        /** The tree node this calculation node mirrors. */
        Node peer_;
        /** Per-site states for this node; -1 marks a state outside the calculator's state range. */
        private byte[] sequence_;
        /** Reusable result array, sized [sites][states]. */
        private double[][] siteStateProbabilities_;
        private final FastLikelihoodCalculator this$0;

        /**
         * @param fastLikelihoodCalculator the owning calculator
         * @param node the tree node to mirror
         */
        public NNode(FastLikelihoodCalculator fastLikelihoodCalculator, Node node) {
            this.this$0 = fastLikelihoodCalculator;
            this.peer_ = node;
        }

        /**
         * Ensures the buffers match the calculator's current state and site counts. Whenever a
         * buffer is reallocated the cached values are gone, so the node is marked for
         * recalculation.
         *
         * @param rateMatrix the model being installed (not read here; subclasses use it)
         */
        public void setModel(RateMatrix rateMatrix) {
            int states = this.this$0.numberOfStates_;
            int sites = this.this$0.numberOfSites_;
            if (this.transitionProbs_ == null || states != this.transitionProbs_.length) {
                this.transitionProbs_ = new double[states][states];
                this.lastLength_ = Double.NEGATIVE_INFINITY;
            }
            // Checked independently of the transition matrix: the site count can change while
            // the state count stays the same.
            boolean resultShapeStale = this.siteStateProbabilities_ == null
                    || this.siteStateProbabilities_.length != sites
                    || (sites > 0 && this.siteStateProbabilities_[0].length != states);
            if (resultShapeStale) {
                this.siteStateProbabilities_ = new double[sites][states];
                this.lastLength_ = Double.NEGATIVE_INFINITY;
            }
        }

        /** Replaces the mirrored tree node. */
        protected void setPeer(Node node) {
            this.peer_ = node;
        }

        /** @return true when the peer's branch length differs from the last used length by more than the threshold */
        public final boolean isBranchLengthChanged() {
            return Math.abs(this.peer_.getBranchLength() - this.lastLength_) > FastLikelihoodCalculator.THRESHOLD;
        }

        /** @return the reusable [sites][states] result array */
        protected final double[][] getSiteStateProbabilities() {
            return this.siteStateProbabilities_;
        }

        /**
         * Installs a copy of {@code bArr} as this node's sequence, replacing states at or above
         * the calculator's state count with -1, and marks the node for recalculation.
         */
        public final void setSequence(byte[] bArr) {
            this.sequence_ = Utils.getCopy(bArr);
            for (int i = 0; i < this.sequence_.length; i++) {
                if (bArr[i] >= this.this$0.numberOfStates_) {
                    this.sequence_[i] = -1;
                }
            }
            // Anything computed from the previous sequence is stale.
            this.lastLength_ = Double.NEGATIVE_INFINITY;
        }

        /** @return true once a sequence has been installed */
        public final boolean hasSequence() {
            return this.sequence_ != null;
        }

        /** @return the internal sequence array (not a copy), or null if none was installed */
        public final byte[] getSequence() {
            return this.sequence_;
        }

        /**
         * @return the transition matrix for this node's branch, recomputed first if the model or
         *         the branch length changed
         */
        protected double[][] getTransitionProbabilities() {
            refreshTransitionProbabilities();
            return this.transitionProbs_;
        }

        /**
         * @return the same matrix as {@link #getTransitionProbabilities()}, obtained through the
         *         same refresh logic
         */
        protected double[][] getTransitionProbabilitiesReverse() {
            refreshTransitionProbabilities();
            return this.transitionProbs_;
        }

        /** Recomputes the transition matrix when the model or branch length changed and records the length used. */
        private void refreshTransitionProbabilities() {
            if (this.this$0.modelChanged_ || isBranchLengthChanged()) {
                double branchLength = this.peer_.getBranchLength();
                this.this$0.model_.setDistance(branchLength);
                this.this$0.model_.getTransitionProbabilities(this.transitionProbs_);
                this.lastLength_ = branchLength;
            }
        }

        /** Renders each byte {@code b} as the character {@code 'A' + b}; not called anywhere in this class. */
        private String toString(byte[] bArr) {
            char[] cArr = new char[bArr.length];
            for (int i = 0; i < cArr.length; i++) {
                cArr[i] = (char) (65 + bArr[i]);
            }
            return new String(cArr);
        }

        /**
         * Loads this node's sequence from {@code sitePattern}, looked up by the peer's identifier
         * name. Does nothing when the peer has no identifier or the pattern does not contain it.
         * States at or above the calculator's state count are stored as -1. The node is marked
         * for recalculation only when the stored sequence actually changes, so repeated calls
         * with the same data keep cached results.
         */
        public void setupSequences(SitePattern sitePattern) {
            Identifier identifier = this.peer_.getIdentifier();
            if (identifier == null) {
                return;
            }
            int whichIdNumber = sitePattern.whichIdNumber(identifier.getName());
            if (whichIdNumber < 0) {
                return;
            }
            byte[] source = sitePattern.pattern[whichIdNumber];
            boolean changed = false;
            // Reallocate when the pattern length differs, not only on first use.
            if (this.sequence_ == null || this.sequence_.length != source.length) {
                this.sequence_ = new byte[source.length];
                changed = true;
            }
            for (int i = 0; i < this.sequence_.length; i++) {
                byte state = source[i] >= this.this$0.numberOfStates_ ? (byte) -1 : source[i];
                if (this.sequence_[i] != state) {
                    this.sequence_[i] = state;
                    changed = true;
                }
            }
            if (changed) {
                this.lastLength_ = Double.NEGATIVE_INFINITY;
            }
        }

        /**
         * @return this node's site/state probabilities, or {@code null} to signal that the
         *         previously returned array is still valid
         */
        public abstract double[][] calculateSiteStateProbabilities();

        /** @return all leaf calculation nodes at or below this node */
        public abstract LeafNode[] getLeafNodes();

        /**
         * Re-targets this calculation node at {@code node}.
         *
         * @return the node to use from now on (this node, or a replacement)
         */
        public abstract NNode switchNodes(Node node);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:pal/eval/FastLikelihoodCalculator$RootNode.class */
    /**
     * Calculation node for the tree root. On top of {@link InternalNode} it keeps the model's
     * equilibrium frequencies and the site pattern's weights, which are the two inputs to the
     * final log-likelihood sum.
     */
    public class RootNode extends InternalNode {
        /** Per-state weights passed to {@code calculateFinal}; taken from the rate matrix in {@code setModel}. */
        double[] equilibriumProbabilities_;
        /** Per-site multipliers passed to {@code calculateFinal}; taken from the site pattern in {@code setupSequences}. */
        int[] siteWeightings_;
        private final FastLikelihoodCalculator this$0;

        /** Creates a root with no equilibrium frequencies yet. */
        public RootNode(FastLikelihoodCalculator fastLikelihoodCalculator, Node node) {
            this(fastLikelihoodCalculator, node, null);
        }

        /**
         * @param fastLikelihoodCalculator the owning calculator
         * @param node the root tree node
         * @param dArr initial equilibrium frequencies (may be null)
         */
        public RootNode(FastLikelihoodCalculator fastLikelihoodCalculator, Node node, double[] dArr) {
            super(fastLikelihoodCalculator, node);
            this$0 = fastLikelihoodCalculator;
            equilibriumProbabilities_ = dArr;
        }

        /**
         * Runs the final calculation and then clears the calculator's model-changed flag.
         *
         * @return the value of {@code calculateFinal} for the stored frequencies and weights
         */
        public double computeLikelihood() {
            final double logLikelihood = calculateFinal(equilibriumProbabilities_, siteWeightings_);
            this$0.modelChanged_ = false;
            return logLikelihood;
        }

        /** Propagates the model through the subtree, then stores its equilibrium frequencies. */
        @Override // pal.eval.FastLikelihoodCalculator.InternalNode, pal.eval.FastLikelihoodCalculator.NNode
        public void setModel(RateMatrix rateMatrix) {
            super.setModel(rateMatrix);
            equilibriumProbabilities_ = rateMatrix.getEquilibriumFrequencies();
        }

        /** Propagates the site pattern through the subtree, then stores its weights. */
        @Override // pal.eval.FastLikelihoodCalculator.InternalNode, pal.eval.FastLikelihoodCalculator.NNode
        public void setupSequences(SitePattern sitePattern) {
            super.setupSequences(sitePattern);
            siteWeightings_ = sitePattern.weight;
        }
    }

    public FastLikelihoodCalculator(SitePattern sitePattern) {
        this.modelChanged_ = false;
        this.sitePattern_ = sitePattern;
        this.numberOfSites_ = sitePattern.getNumberOfPatterns();
        this.numberOfStates_ = sitePattern.getDataType().getNumStates();
    }

    /**
     * Creates a fully configured calculator. The tree is installed before the rate matrix.
     *
     * @param sitePattern source of the site count, the state count and the sequences
     * @param tree tree whose root seeds the calculation nodes
     * @param rateMatrix substitution model to install
     */
    public FastLikelihoodCalculator(SitePattern sitePattern, Tree tree, RateMatrix rateMatrix) {
        this(sitePattern);
        // Order matters: setRateMatrix pushes the model into the root created by setTree.
        this.setTree(tree);
        this.setRateMatrix(rateMatrix);
    }

    /**
     * Listener callback: flags the model as changed.
     *
     * @param palObjectEvent the notification (not inspected)
     */
    @Override // pal.misc.PalObjectListener
    public void parametersChanged(PalObjectEvent palObjectEvent) {
        modelChanged_ = true;
    }

    /**
     * Listener callback: flags the model as changed.
     *
     * @param palObjectEvent the notification (not inspected)
     */
    @Override // pal.misc.PalObjectListener
    public void structureChanged(PalObjectEvent palObjectEvent) {
        modelChanged_ = true;
    }

    /**
     * Installs a rate matrix: registers this calculator as its listener, pushes it through the
     * node hierarchy and flags the model as changed. Passing the matrix that is already
     * installed does nothing. A tree must already have been set, because the model is applied
     * to the root node.
     *
     * @param rateMatrix the model to install; must not be null
     * @throws NullPointerException if {@code rateMatrix} is null
     */
    public final void setRateMatrix(RateMatrix rateMatrix) {
        if (rateMatrix == null) {
            // Fail before touching any state rather than after model_ has been overwritten.
            throw new NullPointerException("rateMatrix must not be null");
        }
        if (this.model_ == rateMatrix) {
            return;
        }
        if (this.model_ != null) {
            // Stop listening to the model being replaced; otherwise it keeps a reference to this
            // calculator and its change events keep invalidating the cache.
            this.model_.removePalObjectListener(this);
        }
        this.model_ = rateMatrix;
        this.model_.addPalObjectListener(this);
        this.root_.setModel(this.model_);
        this.modelChanged_ = true;
    }

    /**
     * Detaches this calculator from its rate matrix: unregisters the listener and drops the
     * reference. Does nothing when no rate matrix is installed.
     */
    @Override // pal.eval.LikelihoodCalculator
    public void release() {
        // Explicit null check instead of catching NullPointerException, which would also hide
        // unrelated failures raised inside removePalObjectListener.
        if (this.model_ != null) {
            this.model_.removePalObjectListener(this);
            this.model_ = null;
        }
    }

    /**
     * Points the calculator at a tree. The first call builds the node hierarchy from the tree's
     * root; later calls re-target the existing hierarchy so unchanged parts keep their cached
     * results. In both cases the current rate matrix (if one is installed) is applied to the
     * nodes and the sequences are loaded from the current site pattern.
     *
     * @param tree the tree to evaluate
     * @throws RuntimeException if re-targeting would replace the root calculation node
     */
    public final void setTree(Tree tree) {
        if (this.root_ == null) {
            this.root_ = new RootNode(this, tree.getRoot());
        } else if (this.root_.switchNodes(tree.getRoot()) != this.root_) {
            throw new RuntimeException("Assertion error : new tree generates new Root NNode (tree probably contains only one branch)");
        }
        // No model may be installed yet (setTree before setRateMatrix, or after release());
        // RootNode.setModel dereferences the model, so skip it in that case.
        if (this.model_ != null) {
            this.root_.setModel(this.model_);
        }
        this.root_.setupSequences(this.sitePattern_);
    }

    /**
     * Switches the calculator to a different site pattern: reloads every node's sequence and,
     * when the number of patterns changed, re-applies the model so nodes can resize their
     * buffers. Cached results are always invalidated, because they were computed from the
     * previous data even when the pattern count is unchanged.
     *
     * @param sitePattern the new site pattern
     */
    public final void updateSitePattern(SitePattern sitePattern) {
        this.sitePattern_ = sitePattern;
        this.root_.setupSequences(sitePattern);
        if (sitePattern.numPatterns != this.numberOfSites_) {
            this.numberOfSites_ = sitePattern.numPatterns;
            // RootNode.setModel dereferences the model, so only re-apply one that exists.
            if (this.model_ != null) {
                this.root_.setModel(this.model_);
            }
        }
        this.modelChanged_ = true;
    }

    /**
     * Delegates to the root calculation node.
     *
     * @return the value produced by {@code RootNode.computeLikelihood()}
     */
    @Override // pal.eval.LikelihoodCalculator
    public double calculateLogLikelihood() {
        final RootNode root = root_;
        return root.computeLikelihood();
    }

    /**
     * Factory for calculation nodes: a tree node without children becomes a {@link LeafNode},
     * any other node an {@link InternalNode}.
     *
     * @param node the tree node to mirror
     * @return the new calculation node
     */
    final NNode create(Node node) {
        if (node.getChildCount() == 0) {
            return new LeafNode(this, node);
        }
        return new InternalNode(this, node);
    }
}
