001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.bio.dp.twohead;
024
025import java.io.Serializable;
026import org.biojava.bio.BioError;
027import org.biojava.bio.dp.BackPointer;
028import org.biojava.bio.dp.DP;
029import org.biojava.bio.dp.EmissionState;
030import org.biojava.bio.dp.IllegalTransitionException;
031import org.biojava.bio.dp.ScoreType;
032import org.biojava.bio.dp.State;
033import org.biojava.bio.symbol.IllegalAlphabetException;
034import org.biojava.bio.symbol.IllegalSymbolException;
035
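/**
 * A {@link CellCalculatorFactory} that creates forward, backward and Viterbi
 * cell calculators which work directly on the state and transition tables
 * obtained from a {@link DP} object.
 *
 * @author Matthew Pocock
 * @author Thomas Down
 */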
040public class DPInterpreter implements CellCalculatorFactory, Serializable {
041  private final DP dp;
042
043  public DPInterpreter(DP dp) {
044    this.dp = dp;
045  }
046
047  public CellCalculator forwards(ScoreType scoreType)
048  throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException {
049    return new Forward(dp, scoreType);
050  }
051
052  public CellCalculator backwards(ScoreType scoreType)
053  throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException {
054    return new Backward(dp, scoreType);
055  }
056
057  public CellCalculator viterbi(ScoreType scoreType, BackPointer terminal)
058  throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException {
059    return new Viterbi(dp, scoreType, terminal);
060  }
061
062
063  private static class Forward implements CellCalculator {
064    private final int[][] transitions;
065    private final double[][] transitionScores;
066    private final State[] states;
067    private final State magicalState;
068
069    public Forward(DP dp, ScoreType scoreType)
070    throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException {
071      states = dp.getStates();
072
073      transitions = dp.getForwardTransitions();
074      transitionScores = dp.getForwardTransitionScores(scoreType);
075      magicalState = dp.getModel().magicalState();
076    }
077
078    public void initialize(Cell [][] cells)
079    throws
080      IllegalSymbolException,
081      IllegalAlphabetException,
082      IllegalTransitionException
083    {
084      _calcCell(cells, true);
085    }
086
087    public void calcCell(Cell [][] cells)
088    throws
089      IllegalSymbolException,
090      IllegalAlphabetException,
091      IllegalTransitionException
092    {
093      _calcCell(cells, false);
094    }
095
096    public void _calcCell(Cell [][] cells, boolean initializationHack)
097    throws
098      IllegalSymbolException,
099      IllegalAlphabetException,
100      IllegalTransitionException
101    {
102      Cell curCell = cells[0][0];
103      double[] curCol = curCell.scores;
104      double[] emissions = curCell.emissions;
105      //System.out.println("curCol = " + curCol);
106
107     STATELOOP:
108      for (int l = 0; l < states.length; ++l) {
109        State curState = states[l];
110        //System.out.println("State = " + states[l].getName());
111        try {
112          if(initializationHack && (curState instanceof EmissionState)) {
113            if(curState == magicalState) {
114              curCol[l] = 0.0;
115            } else {
116              curCol[l] = Double.NaN;
117            }
118            //System.out.println("Initialized state to " + curCol[l]);
119            continue STATELOOP;
120          }
121
122          //System.out.println("Calculating weight");
123          double[] sourceScores;
124          double weight;
125          if (! (curState instanceof EmissionState)) {
126            weight = 0.0;
127            sourceScores = curCol;
128          } else {
129            weight = emissions[l];
130            //System.out.println("Weight " + emissions[l]);
131            if(weight == Double.NEGATIVE_INFINITY || Double.isNaN(weight)) {
132              curCol[l] = Double.NaN;
133              continue STATELOOP;
134            }
135            int [] advance = ((EmissionState)curState).getAdvance();
136            sourceScores = cells[advance[0]][advance[1]].scores;
137            //System.out.println("Values from " + advance[0] + ", " + advance[1] + " " + sourceScores);
138          }
139          //System.out.println("weight = " + weight);
140
141          int [] tr = transitions[l];
142          double[] trs = transitionScores[l];
143
144          /*for(int ci = 0; ci < tr.length; ci++) {
145            System.out.println(
146              "Source = " + states[tr[ci]].getName() +
147              "\t= " + sourceScores[tr[ci]]
148            );
149          }*/
150
151          // Calculate probabilities for states with transitions
152          // here.
153
154          // Find base for addition
155          double constant = Double.NaN;
156          double score = 0.0;
157          for (
158            int ci = 0;
159            ci < tr.length;
160            ci++
161          ) {
162            int trc = tr[ci];
163            double trSc = sourceScores[trc];
164            if(!Double.isNaN(trSc) && trSc != Double.NEGATIVE_INFINITY) {
165              if(Double.isNaN(constant)) {
166                constant = trSc;
167              }
168              double sk = trs[ci];
169              if(!Double.isNaN(sk) && sk != Double.NEGATIVE_INFINITY) {
170                score += Math.exp(trSc + sk - constant);
171              }
172            }
173          }
174          if(Double.isNaN(constant)) {
175            curCol[l] = Double.NaN;
176            //System.out.println("found no source");
177          } else {
178            curCol[l] = weight + Math.log(score) + constant;
179          }
180        } catch (Exception e) {
181          throw new BioError(
182
183            "Problem with state " + l + " -> " + states[l].getName(), e
184          );
185        } catch (BioError e) {
186          throw new BioError(
187
188            "Error  with state " + l + " -> " + states[l].getName(), e
189          );
190        }
191      }
192      /*for (int l = 0; l < states.length; ++l) {
193        State curState = states[l];
194        System.out.println(
195          "State = " + states[l].getName() +
196          "\t = " + curCol[l]
197        );
198      }*/
199    }
200  }
201
202  private static class Backward implements CellCalculator {
203    private final int[][] transitions;
204    private final double[][] transitionScores;
205    private final State[] states;
206    private final State magicalState;
207
208    public Backward(
209      DP dp, ScoreType scoreType
210    ) throws
211      IllegalSymbolException,
212      IllegalAlphabetException,
213      IllegalTransitionException
214    {
215      states = dp.getStates();
216      transitions = dp.getBackwardTransitions();
217      transitionScores = dp.getBackwardTransitionScores(scoreType);
218      magicalState = dp.getModel().magicalState();
219    }
220
221    public void initialize(Cell [][] cells)
222    throws
223      IllegalSymbolException,
224      IllegalAlphabetException,
225      IllegalTransitionException
226    {
227      _calcCell(cells, true);
228    }
229
230    public void calcCell(Cell [][] cells)
231    throws
232      IllegalSymbolException,
233      IllegalAlphabetException,
234      IllegalTransitionException
235    {
236      _calcCell(cells, false);
237    }
238
239    public void _calcCell(Cell [][] cells, boolean initializationHack)
240    throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException
241    {
242      Cell curCell = cells[0][0];
243      double[] curCol = curCell.scores;
244
245     STATELOOP:
246      for (int l = states.length - 1; l >= 0; --l) {
247        //System.out.println("State = " + states[l].getName());
248        State curState = states[l];
249        if(initializationHack && (curState instanceof EmissionState)) {
250          if(curState == magicalState) {
251            curCol[l] = 0.0;
252          } else {
253            curCol[l] = Double.NaN;
254          }
255          continue STATELOOP;
256        }
257        int [] tr = transitions[l];
258        double[] trs = transitionScores[l];
259
260        // Calculate probabilities for states with transitions
261        // here.
262
263            double[] sourceScores = new double[tr.length];
264        for (int ci = 0; ci < tr.length; ++ci) {
265          double weight;
266
267          int destI = tr[ci];
268          State destS = states[destI];
269          Cell targetCell;
270          if (destS instanceof EmissionState) {
271            int [] advance = ((EmissionState)destS).getAdvance();
272            targetCell = cells[advance[0]][advance[1]];
273            weight = targetCell.emissions[destI];
274            if(Double.isNaN(weight)) {
275              curCol[l] = Double.NaN;
276              continue STATELOOP;
277            }
278          } else {
279            targetCell = curCell;
280            weight = 0.0;
281          }
282          sourceScores[ci] = targetCell.scores[destI] + weight;
283        }
284        /*for(int ci = 0; ci < tr.length; ci++) {
285          System.out.println(
286            "Source = " + states[tr[ci]].getName() +
287            "\t= " + sourceScores[ci]
288          );
289        }*/
290
291        // Find base for addition
292        double constant = Double.NaN;
293        double score = 0.0;
294        for(
295          int ci = 0;
296          ci < tr.length;
297          ci++
298        ) {
299          double skc = sourceScores[ci];
300          if(skc != Double.NEGATIVE_INFINITY && !Double.isNaN(skc)) {
301            if(Double.isNaN(constant)) {
302              constant = skc;
303            }
304            double sk = trs[ci];
305            if(!Double.isNaN(sk) && sk != Double.NEGATIVE_INFINITY) {
306              score += Math.exp(skc + sk - constant);
307            }
308          }
309        }
310        if(Double.isNaN(constant)) {
311          curCol[l] = Double.NaN;
312        } else {
313          curCol[l] = Math.log(score) + constant;
314          //System.out.println(curCol[l]);
315        }
316      }
317      /*for (int l = 0; l < states.length; ++l) {
318        State curState = states[l];
319        System.out.println(
320          "State = " + states[l].getName() +
321          "\t = " + curCol[l]
322        );
323      }*/
324    }
325  }
326
327
328  private class Viterbi implements CellCalculator {
329    private final int[][] transitions;
330    private final double[][] transitionScores;
331    private final State[] states;
332    private final BackPointer TERMINAL_BP;
333    private final State magicalState;
334
335    public Viterbi(DP dp, ScoreType scoreType, BackPointer terminal)
336    throws
337      IllegalSymbolException,
338      IllegalAlphabetException,
339      IllegalTransitionException
340    {
341      TERMINAL_BP = terminal;
342      states = dp.getStates();
343      transitions = dp.getForwardTransitions();
344      transitionScores = dp.getForwardTransitionScores(scoreType);
345      magicalState = dp.getModel().magicalState();
346    }
347
348    public void initialize(Cell[][] cells)
349    throws
350      IllegalSymbolException,
351      IllegalAlphabetException,
352      IllegalTransitionException
353    {
354      _calcCell(cells, true);
355    }
356
357    public void calcCell(Cell[][] cells)
358    throws
359      IllegalSymbolException,
360      IllegalAlphabetException,
361      IllegalTransitionException
362    {
363      _calcCell(cells, false);
364    }
365
366    public void _calcCell(Cell [][] cells, boolean initializationHack)
367    throws
368      IllegalSymbolException,
369      IllegalAlphabetException,
370      IllegalTransitionException
371    {
372      Cell curCell = cells[0][0];
373      double[] curCol = curCell.scores;
374      BackPointer[] curBPs = curCell.backPointers;
375      double[] emissions = curCell.emissions;
376      //System.out.println("Scores " + curCol);
377     STATELOOP:
378      for (int l = 0; l < states.length; ++l) {
379        State curState = states[l];
380            //System.out.println("State = " + l + "=" + states[l].getName());
381        try {
382          //System.out.println("trying initialization");
383          if(initializationHack && (curState instanceof EmissionState)) {
384            if(curState == magicalState) {
385              curCol[l] = 0.0;
386              curBPs[l] = TERMINAL_BP;
387            } else {
388              curCol[l] = Double.NaN;
389              curBPs[l] = null;
390            }
391            //System.out.println("Initialized state to " + curCol[l]);
392            continue STATELOOP;
393          }
394
395          double weight;
396          double[] sourceScores;
397          BackPointer[] oldBPs;
398          if(! (curState instanceof EmissionState)) {
399            weight = 0.0;
400            sourceScores = curCol;
401            oldBPs = curBPs;
402          } else {
403            weight = emissions[l];
404            if(weight == Double.NEGATIVE_INFINITY || Double.isNaN(weight)) {
405              curCol[l] = Double.NaN;
406              curBPs[l] = null;
407              continue STATELOOP;
408            }
409            int [] advance = ((EmissionState)curState).getAdvance();
410            Cell oldCell = cells[advance[0]][advance[1]];
411            sourceScores = oldCell.scores;
412            oldBPs = oldCell.backPointers;
413            //System.out.println("Looking back " + advance[0] + ", " + advance[1]);
414          }
415          //System.out.println("weight = " + weight);
416
417          double score = Double.NEGATIVE_INFINITY;
418          int [] tr = transitions[l];
419          double[] trs = transitionScores[l];
420
421          int bestK = -1; // index into states[l]
422          for (int kc = 0; kc < tr.length; ++kc) {
423            int k = tr[kc]; // actual state index
424            double sk = sourceScores[k];
425
426            /*System.out.println("kc is " + kc);
427            System.out.println("with from " + k + "=" + states[k].getName());
428            System.out.println("prevScore = " + sk);*/
429            if (sk != Double.NEGATIVE_INFINITY && !Double.isNaN(sk)) {
430              double t = trs[kc];
431              //System.out.println("Transition score = " + t);
432              double newScore = t + sk;
433              if (newScore > score) {
434                score = newScore;
435                bestK = k;
436                //System.out.println("New best source at " + kc + " is " + score);
437              }
438            }
439          }
440          if (bestK != -1) {
441            curCol[l] = weight + score;
442            /*System.out.println("Weight = " + weight);
443            System.out.println("Score = " + score);
444            System.out.println(
445              "Creating " + states[bestK].getName() +
446              " -> " + states[l].getName() +
447              " (" + curCol[l] + ")"
448            );*/
449            try {
450              State s = states[l];
451              curBPs[l] = new BackPointer(
452                s,
453                oldBPs[bestK],
454                curCol[l]
455              );
456            } catch (Throwable t) {
457              throw new BioError(
458
459                "Couldn't generate backpointer for " + states[l].getName() +
460                " back to " + states[bestK].getName(), t
461              );
462            }
463          } else {
464            //System.out.println("No where to come from");;
465            curBPs[l] = null;
466            curCol[l] = Double.NaN;
467          }
468        } catch (Exception e) {
469          throw new BioError(
470
471            "Problem with state " + l + " -> " + states[l].getName(),e
472          );
473        } catch (BioError e) {
474          throw new BioError(
475
476            "Error  with state " + l + " -> " + states[l].getName(),e
477          );
478        }
479      }
480      /*System.out.println("backpointers:");
481      for(int l = 0; l < states.length; l++) {
482        System.out.print(states[l].getName() + "\t" + curCol[l] + "\t");
483        BackPointer b = curBPs[l];
484        if(b != null) {
485          for(BackPointer bb = b; bb.back != bb; bb = bb.back) {
486            System.out.print(bb.state.getName() + " -> ");
487          }
488          System.out.println("!");
489        } else {
490          System.out.print("\n");
491        }
492      }*/
493      initializationHack = false;
494    }
495  }
496
497  public static class Maker implements CellCalculatorFactoryMaker {
498    public CellCalculatorFactory make(DP dp) {
499      return new DPInterpreter(dp);
500    }
501  }
502}