001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.bio.dp;
024
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

import org.biojava.bio.BioError;
import org.biojava.bio.BioException;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.symbol.DoubleAlphabet;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.SimpleSymbolList;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;
import org.biojava.utils.ChangeEvent;
import org.biojava.utils.ChangeListener;
import org.biojava.utils.ChangeType;
import org.biojava.utils.ChangeVetoException;
052
053/**
054 * <p>
 * Objects that can perform dynamic programming operations upon sequences with
 * HMMs.
057 * </p>
058 *
059 * <p>
060 * The three main DP operations are Forwards, Backwards and Viterbi. Forwards
061 * and Backwards calculate the probability of the sequences having been made in
062 * any way by the model. Viterbi finds the most supported way that the sequence
063 * could have been made.
064 * </p>
065 *
066 * <p>
067 * Each of the functions can return the dynamic-programming matrix containing
068 * the intermediate results. This may be useful for model training, or for
069 * visualisation.
070 * </p>
071 *
072 * <p>
 * Each of the functions can be calculated using the model probabilities, the
 * null-model probabilities or the odds (ratio between the two). For Forwards
 * and Backwards, the odds calculations produce numbers with questionable basis
 * in reality. For Viterbi with odds, you will receive the path through the
 * model that is most different from the null model, and supported by the
 * probabilities.
079 * </p>
080 *
081 * @author Matthew Pocock
082 * @author Thomas Down
083 */
084public abstract class DP {
085  private static List NO_ADVANCE = new ArrayList();
086
087  private int[] getNoAdvance() {
088    int heads = getModel().advance().length;
089    int[] no_advance = (int[]) NO_ADVANCE.get(heads);
090
091    if (no_advance == null) {
092      no_advance = new int[heads];
093      for (int i = 0; i < heads; i++) {
094        no_advance[i] = 0;
095      }
096
097      NO_ADVANCE.add(heads, no_advance);
098    }
099
100    return no_advance;
101  }
102
103  /**
104   * Scores the SymbolList from symbol start to symbol (start+columns) with a
105   * weight matrix.
106   *
107   * @param matrix  the weight matrix used to evaluate the sequences
108   * @param symList the SymbolList to assess
109   * @param start   the index of the first symbol in the window to evaluate
110   * @return  the log probability or likelyhood of this weight matrix
111   *          having generated symbols start to (start + columns) of symList
112   */
113  public static double scoreWeightMatrix(
114          WeightMatrix matrix, SymbolList symList, int start)
115          throws IllegalSymbolException {
116    double score = 0;
117    int cols = matrix.columns();
118
119    for (int c = 0; c < cols; c++) {
120      score += Math.log(
121              matrix.getColumn(c).getWeight(symList.symbolAt(c + start)));
122    }
123
124    return score;
125  }
126
127  /**
128   * Scores the SymbolList from symbol start to symbol (start+columns) with a
129   * weight matrix using a particular ScoreType.
130   *
131   * <p>
132   * This method allows you to use score types such as ScoreType.ODDS. The other
133   * scoreWeightMatrix methods gives a result similar or identical to
134   * ScoreType.PROBABILITY.
135   * </p>
136   *
137   * @param matrix  the weight matrix used to evaluate the sequences
138   * @param symList the SymbolList to assess
139   * @param scoreType the score type to apply
140   * @param start   the index of the first symbol in the window to evaluate
141   * @return  the sum of log scores of this weight matrix
142   *          having generated symbols start to (start + columns) of symList
143   * @since 1.4
144   */
145  public static double scoreWeightMatrix(
146          WeightMatrix matrix,
147          SymbolList symList,
148          ScoreType scoreType,
149          int start)
150          throws IllegalSymbolException {
151    double score = 0;
152    int cols = matrix.columns();
153
154    for (int c = 0; c < cols; c++) {
155      score += Math.log(scoreType.calculateScore(
156              matrix.getColumn(c), symList.symbolAt(c + start)));
157    }
158
159    return score;
160  }
161  public static MarkovModel flatView(MarkovModel model)
162          throws IllegalAlphabetException, IllegalSymbolException {
163    for (Iterator i = model.stateAlphabet().iterator(); i.hasNext();) {
164      State s = (State) i.next();
165      if (
166              !(s instanceof DotState) &&
167              !(s instanceof EmissionState)
168      ) {
169        return new FlatModel(model);
170      }
171    }
172
173    return model;
174  }
175
176  public State[] stateList(MarkovModel mm)
177          throws IllegalSymbolException, IllegalTransitionException,
178          BioException {
179    FiniteAlphabet alpha = mm.stateAlphabet();
180
181    List emissionStates = new ArrayList();
182    HMMOrderByTransition comp = new HMMOrderByTransition(mm);
183    List dotStates = new LinkedList();
184    for (Iterator addStates = alpha.iterator(); addStates.hasNext();) {
185      Object state = addStates.next();
186      if (state instanceof MagicalState) {
187        emissionStates.add(0, state);
188      } else if (state instanceof EmissionState) {
189        emissionStates.add(state);
190      } else {
191        ListIterator checkOld = dotStates.listIterator();
192        int insertPos = -1;
193        while (checkOld.hasNext() && insertPos == -1) {
194          Object oldState = checkOld.next();
195          if (comp.compare(state, oldState) == HMMOrderByTransition.LESS_THAN) {
196            insertPos = checkOld.nextIndex() - 1;
197          }
198        }
199        if (insertPos >= 0) {
200          dotStates.add(insertPos, state);
201        } else {
202          dotStates.add(state);
203        }
204      }
205    }
206    Collections.sort(emissionStates, new Comparator() {
207      public int compare(Object o1, Object o2) {
208        State s = (State) o1;
209        State t = (State) o2;
210
211        // sort by advance
212        int[] sa;
213        if (s instanceof EmissionState) {
214          sa = ((EmissionState) s).getAdvance();
215        } else {
216          sa = getNoAdvance();
217        }
218
219        int[] ta;
220        if (t instanceof EmissionState) {
221          ta = ((EmissionState) t).getAdvance();
222        } else {
223          ta = getNoAdvance();
224        }
225
226        for (int i = 0; i < sa.length; i++) {
227          if (sa[i] > ta[i]) {
228            return -1;
229          } else if (sa[i] < ta[i]) {
230            return +1;
231          }
232        }
233
234        // give up - sort by name
235        return s.getName().compareTo(t.getName());
236      }
237    });
238    State[] sl = new State[emissionStates.size() + dotStates.size()];
239    int i = 0;
240    for (Iterator si = emissionStates.iterator(); si.hasNext();) {
241      EmissionState ex = (EmissionState) si.next();
242      int[] ad = ex.getAdvance();
243      if (ad.length != mm.advance().length) {
244        throw new BioException(
245                "State " + ex.getName() + " advances " + ad.length + " heads. " +
246                " however, the model " + mm.stateAlphabet().getName() +
247                " advances " + mm.advance().length + " heads."
248        );
249      }
250      for (int adi = 0; ad != null && adi < ad.length; adi++) {
251        if (ad[adi] != 0) {
252          ad = null;
253        }
254      }
255      if (ad != null) {
256        throw new Error(
257                "State " + ex.getName() + " has advance " + ad
258        );
259      }
260      sl[i++] = ex;
261    }
262    for (Iterator si = dotStates.iterator(); si.hasNext();) {
263      sl[i++] = (State) si.next();
264    }
265    return sl;
266  }
267
268  /**
269   * Returns a matrix for the specified States describing all
270   * valid Transitions between those States.
271   * <p>
272   * The matrix is 2-dimensional.  The primary array has an element
273   * corresponding to every State in the states argument.  That
274   * element is itself an array the elements of which identify 
275   * the States that can reach that State.  The source States 
276   * are identified by their index within the states [] array.
277   * @param model MarkovModel to be analysed.
278   * @param states The States for which the transition matrix is to be determined.
279   */
280  public static int[][] forwardTransitions(
281          MarkovModel model,
282          State[] states
283          ) throws IllegalSymbolException {
284    int stateCount = states.length;
285    int[][] transitions = new int[stateCount][];
286
287    for (int i = 0; i < stateCount; i++) {
288      int[] tmp = new int[stateCount];
289      int len = 0;
290      FiniteAlphabet trans = model.transitionsTo(states[i]);
291      for (int j = 0; j < stateCount; j++) {
292        if (trans.contains(states[j])) {
293          tmp[len++] = j;
294        }
295      }
296      int[] tmp2 = new int[len];
297      for (int j = 0; j < len; j++) {
298        tmp2[j] = tmp[j];
299      }
300      transitions[i] = tmp2;
301    }
302
303    return transitions;
304  }
305
306  /**
307   * Compute the log(score) of all transitions
308   * between the specified States.  The layout
309   * of the array is identical to that of the transitions
310   * array.
311   * <p>
312   * Note that all parameters <b>MUST</b> be
313   * consistent with each other!!!!
314   * <p>
315   * @param model The model for which the data is to be computed.
316   * @param states The States within that model for which scores are required.
317   * @param transitions The transition matrix obtained by calling forwardTransitions() with the above argument values.
318   * @param scoreType The type of score to be evaluated.
319   */
320  public static double[][] forwardTransitionScores(
321          MarkovModel model,
322          State[] states,
323          int[][] transitions,
324          ScoreType scoreType
325          ) {
326    // System.out.println("forwardTransitionScores");
327    int stateCount = states.length;
328    double[][] scores = new double[stateCount][];
329
330    for (int i = 0; i < stateCount; i++) {
331      State is = states[i];
332      scores[i] = new double[transitions[i].length];
333      for (int j = 0; j < scores[i].length; j++) {
334        try {
335          scores[i][j] = Math.log(scoreType.calculateScore(
336                  model.getWeights(states[transitions[i][j]]),
337                  is
338          ));
339          /*System.out.println(
340            states[transitions[i][j]] + "\t-> " +
341            is.getName() + "\t = " +
342            scores[i][j] + "\t(" +
343            scoreType.calculateScore(
344              model.getWeights(states[transitions[i][j]]),
345              is
346            )
347          );*/
348        } catch (IllegalSymbolException ite) {
349          throw new BioError(
350                  "Transition listed in transitions array has dissapeared.",
351                  ite);
352        }
353      }
354    }
355
356    return scores;
357  }
358
359  public static int[][] backwardTransitions(
360          MarkovModel model,
361          State[] states
362          ) throws IllegalSymbolException {
363    int stateCount = states.length;
364    int[][] transitions = new int[stateCount][];
365
366    for (int i = 0; i < stateCount; i++) {
367      int[] tmp = new int[stateCount];
368      int len = 0;
369      FiniteAlphabet trans = model.transitionsFrom(states[i]);
370      for (int j = 0; j < stateCount; j++) {
371        if (trans.contains(states[j])) {
372          tmp[len++] = j;
373        }
374      }
375      int[] tmp2 = new int[len];
376      for (int j = 0; j < len; j++) {
377        tmp2[j] = tmp[j];
378      }
379      transitions[i] = tmp2;
380    }
381
382    return transitions;
383  }
384
385  public static double[][] backwardTransitionScores(MarkovModel model,
386                                                    State[] states,
387                                                    int[][] transitions,
388                                                    ScoreType scoreType
389                                                    ) {
390    int stateCount = states.length;
391    double[][] scores = new double[stateCount][];
392
393    for (int i = 0; i < stateCount; i++) {
394      State is = states[i];
395      scores[i] = new double[transitions[i].length];
396      for (int j = 0; j < scores[i].length; j++) {
397        try {
398          scores[i][j] = Math.log(scoreType.calculateScore(
399                  model.getWeights(is),
400                  states[transitions[i][j]]
401          ));
402        } catch (IllegalSymbolException ite) {
403          throw new BioError(
404                  "Transition listed in transitions array has dissapeared",
405                  ite);
406        }
407      }
408    }
409
410    return scores;
411  }
412
413  private MarkovModel model;
414  private State[] states;
415  private int[][] forwardTransitions;
416  private int[][] backwardTransitions;
417  private int dotStatesIndex;
418  private int lockCount = 0;
419
420  public int getDotStatesIndex() {
421    return dotStatesIndex;
422  }
423
424  public MarkovModel getModel() {
425    return model;
426  }
427
428  public State[] getStates() {
429    return states;
430  }
431
432  public int[][] getForwardTransitions() {
433    return forwardTransitions;
434  }
435
436  private Map forwardTransitionScores;
437  private Map backwardTransitionScores;
438
439  public double[][] getForwardTransitionScores(ScoreType scoreType) {
440    double[][] ts = (double[][]) forwardTransitionScores.get(scoreType);
441    if (ts == null) {
442      forwardTransitionScores.put(scoreType, ts = forwardTransitionScores(
443              getModel(), getStates(), forwardTransitions, scoreType
444      ));
445    }
446    return ts;
447  }
448
449  public int[][] getBackwardTransitions() {
450    return backwardTransitions;
451  }
452
453  public double[][] getBackwardTransitionScores(ScoreType scoreType) {
454    double[][] ts = (double[][]) backwardTransitionScores.get(scoreType);
455    if (ts == null) {
456      backwardTransitionScores.put(scoreType, ts = backwardTransitionScores(
457              getModel(), getStates(), backwardTransitions, scoreType
458      ));
459    }
460    return ts;
461  }
462
463  public void lockModel() {
464    if (lockCount++ == 0) {
465      getModel().addChangeListener(ChangeListener.ALWAYS_VETO, ChangeType.UNKNOWN);
466    }
467  }
468
469  public void unlockModel() {
470    if (--lockCount == 0) {
471      getModel().removeChangeListener(ChangeListener.ALWAYS_VETO, ChangeType.UNKNOWN);
472    }
473  }
474
475  public void update() {
476    try {
477      if(this.states == null) {
478        this.states = stateList(model);
479        this.forwardTransitions = forwardTransitions(model, states);
480        this.backwardTransitions = backwardTransitions(model, states);
481
482        // Find first dot state
483        int i;
484        for (i = 0; i < states.length; ++i) {
485          if (!(states[i] instanceof EmissionState)) {
486            break;
487          }
488        }
489        dotStatesIndex = i;
490      }
491
492      this.forwardTransitionScores.clear();
493      this.backwardTransitionScores.clear();
494    } catch (Exception e) {
495      throw new BioError("Something is seriously wrong with the DP code", e);
496    }
497  }
498
499  public DP(MarkovModel model){
500    this.setModel(model);
501  }
502  
503  /**
504   * This method will result in a DP with no model. Use the setModel() method
505   * to set the model before use.
506   */
507  public DP(){}
508  
509  public void setModel(MarkovModel model){
510    this.model = model;
511    this.forwardTransitionScores = new HashMap();
512    this.backwardTransitionScores = new HashMap();
513    this.update();
514
515    model.addChangeListener(UPDATER, ChangeType.UNKNOWN);
516  }
517
518  public abstract double forward(SymbolList[] symList, ScoreType scoreType)
519          throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException;
520
521  public abstract double backward(SymbolList[] symList, ScoreType scoreType)
522          throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException;
523
524  public abstract DPMatrix forwardMatrix(SymbolList[] symList, ScoreType scoreType)
525          throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException;
526
527  public abstract DPMatrix backwardMatrix(SymbolList[] symList, ScoreType scoreType)
528          throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException;
529
530  public abstract DPMatrix forwardMatrix(SymbolList[] symList, DPMatrix matrix, ScoreType scoreType)
531          throws IllegalArgumentException, IllegalSymbolException,
532          IllegalAlphabetException, IllegalTransitionException;
533
534  public abstract DPMatrix backwardMatrix(SymbolList[] symList, DPMatrix matrix, ScoreType scoreType)
535          throws IllegalArgumentException, IllegalSymbolException,
536          IllegalAlphabetException, IllegalTransitionException;
537
538  public abstract StatePath viterbi(SymbolList[] symList, ScoreType scoreType)
539          throws IllegalSymbolException, IllegalArgumentException, IllegalAlphabetException, IllegalTransitionException;
540
541  public DPMatrix forwardsBackwards(SymbolList[] symList, ScoreType scoreType)
542          throws BioException {
543    try {
544      System.out.println("Making backward matrix");
545      final DPMatrix bMatrix = backwardMatrix(symList, scoreType);
546      System.out.println("Making forward matrix");
547      final DPMatrix fMatrix = forwardMatrix(symList, scoreType);
548
549      System.out.println("Making forward/backward matrix");
550      return new DPMatrix() {
551        public double getCell(int[] index) {
552          return fMatrix.getCell(index) + bMatrix.getCell(index);
553        }
554
555        public double getScore() {
556          return fMatrix.getScore();
557        }
558
559        public MarkovModel model() {
560          return fMatrix.model();
561        }
562
563        public SymbolList[] symList() {
564          return fMatrix.symList();
565        }
566
567        public State[] states() {
568          return fMatrix.states();
569        }
570      };
571    } catch (Exception e) {
572      throw new BioException("Couldn't build forwards-backwards matrix", e);
573    }
574  }
575
576  /**
577   * <p>
578   * Generates an alignment from a model.
579   * </p>
580   *
581   * <p>
582   * If the length is set to -1 then the model length will be sampled
583   * using the model's transition to the end state. If the length is
584   * fixed using length, then the transitions to the end state are implicitly
585   * invoked.
586   * </p>
587   *
588   * @param length  the length of the sequence to generate
589   * @return  a StatePath generated at random
590   */
591  public StatePath generate(int length)
592          throws IllegalSymbolException, BioException {
593    List tokenList = new ArrayList();
594    List stateList = new ArrayList();
595    List scoreList = new ArrayList();
596    double totScore = 0.0;
597    double symScore = 0.0;
598    int i = length;
599    State oldState;
600    Symbol token;
601
602    oldState = (State) model.getWeights(model.magicalState()).sampleSymbol();
603    symScore += model.getWeights(model.magicalState()).getWeight(oldState);
604
605    DoubleAlphabet dAlpha = DoubleAlphabet.getInstance();
606    if (oldState instanceof EmissionState) {
607      EmissionState eState = (EmissionState) oldState;
608      token = eState.getDistribution().sampleSymbol();
609      symScore += eState.getDistribution().getWeight(token);
610      stateList.add(oldState);
611      tokenList.add(token);
612      scoreList.add(dAlpha.getSymbol(symScore));
613      totScore += symScore;
614      symScore = 0.0;
615      i--;
616    }
617
618    while (i != 0) {
619      State newState = null;
620      Distribution dist = model.getWeights(oldState);
621      do {
622        newState = (State) dist.sampleSymbol();
623      } while (newState == model.magicalState() && i > 0);
624      try {
625        symScore += dist.getWeight(newState);
626      } catch (IllegalSymbolException ise) {
627        throw new BioError(
628                "Transition returned from sampleTransition is invalid",
629                ise);
630      }
631
632      if (newState == model.magicalState()) {
633        break;
634      }
635
636      if (newState instanceof EmissionState) {
637        EmissionState eState = (EmissionState) newState;
638        token = eState.getDistribution().sampleSymbol();
639        symScore += eState.getDistribution().getWeight(token);
640        stateList.add(newState);
641        tokenList.add(token);
642        scoreList.add(dAlpha.getSymbol(symScore));
643        totScore += symScore;
644        symScore = 0.0;
645        i--;
646      }
647      oldState = newState;
648    }
649
650    SymbolList tokens = new SimpleSymbolList(model.emissionAlphabet(), tokenList);
651    SymbolList states = new SimpleSymbolList(model.stateAlphabet(), stateList);
652    SymbolList scores = new SimpleSymbolList(dAlpha, scoreList);
653
654    return new SimpleStatePath(
655            totScore,
656            tokens,
657            states,
658            scores
659    );
660  }
661
662  public static class ReverseIterator implements Iterator, Serializable {
663    private SymbolList sym;
664    private int index;
665
666    public ReverseIterator(SymbolList sym) {
667      this.sym = sym;
668      index = sym.length();
669    }
670
671    public boolean hasNext() {
672      return index > 0;
673    }
674
675    public Object next() {
676      return sym.symbolAt(index--);
677    }
678
679    public void remove() throws UnsupportedOperationException {
680      throw new UnsupportedOperationException("This itterator can not cause modifications");
681    }
682  }
683
684  private final ChangeListener UPDATER = new ChangeListener() {
685    public void preChange(ChangeEvent ce)
686            throws ChangeVetoException {
687    }
688
689    public void postChange(ChangeEvent ce) {
690      if (ce.getType().isMatchingType(MarkovModel.ARCHITECTURE)) {
691        System.out.println("architecture alterred");
692        states = null;
693      }
694
695      if (
696              (ce.getType().isMatchingType(MarkovModel.ARCHITECTURE)) ||
697              (ce.getType().isMatchingType(MarkovModel.PARAMETER))
698      ) {
699        update();
700      }
701    }
702  };
703
704  private static class HMMOrderByTransition {
705    public final static Object GREATER_THAN = new Object();
706    public final static Object LESS_THAN = new Object();
707    public final static Object EQUAL = new Object();
708    public final static Object DISJOINT = new Object();
709
710    private MarkovModel mm;
711
712    private HMMOrderByTransition(MarkovModel mm) {
713      this.mm = mm;
714    }
715
716    public Object compare(Object o1, Object o2)
717            throws IllegalTransitionException, IllegalSymbolException {
718      if (o1 == o2) {
719        return EQUAL;
720      }
721      State s1 = (State) o1;
722      State s2 = (State) o2;
723
724      if (transitionsTo(s1, s2)) {
725        return LESS_THAN;
726      }
727      if (transitionsTo(s2, s1)) {
728        return GREATER_THAN;
729      }
730
731      return DISJOINT;
732    }
733
734    private boolean transitionsTo(State from, State to)
735            throws IllegalTransitionException, IllegalSymbolException {
736      Set checkedSet = new HashSet();
737      Set workingSet = new HashSet();
738      for (
739              Iterator i = mm.transitionsFrom(from).iterator();
740              i.hasNext();
741              ) {
742        workingSet.add(i.next());
743      }
744
745      while (workingSet.size() > 0) {
746        Set newWorkingSet = new HashSet();
747        for (Iterator i = workingSet.iterator(); i.hasNext();) {
748          State s = (State) i.next();
749          if (s instanceof EmissionState) {
750            continue;
751          }
752          if (s == from) {
753            throw new IllegalTransitionException(
754                    from, from, "Loop in dot states."
755            );
756          }
757          if (s == to) {
758            return true;
759          }
760          for (Iterator j = mm.transitionsFrom(s).iterator(); j.hasNext();) {
761            State s2 = (State) j.next();
762            if (!workingSet.contains(s2) && !checkedSet.contains(s2)) {
763              newWorkingSet.add(s2);
764            }
765          }
766          checkedSet.add(s);
767        }
768        workingSet = newWorkingSet;
769      }
770      return false;
771    }
772  }
773}
774