001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.bio.dp.twohead;
024
025import java.io.Serializable;
026import java.util.ArrayList;
027import java.util.Collections;
028import java.util.HashMap;
029import java.util.Iterator;
030import java.util.List;
031import java.util.Map;
032
033import org.biojava.bio.BioError;
034import org.biojava.bio.BioException;
035import org.biojava.bio.alignment.Alignment;
036import org.biojava.bio.alignment.SimpleAlignment;
037import org.biojava.bio.dp.BackPointer;
038import org.biojava.bio.dp.DP;
039import org.biojava.bio.dp.DPMatrix;
040import org.biojava.bio.dp.EmissionState;
041import org.biojava.bio.dp.IllegalTransitionException;
042import org.biojava.bio.dp.MarkovModel;
043import org.biojava.bio.dp.ScoreType;
044import org.biojava.bio.dp.SimpleStatePath;
045import org.biojava.bio.dp.State;
046import org.biojava.bio.dp.StatePath;
047import org.biojava.bio.symbol.Alphabet;
048import org.biojava.bio.symbol.DoubleAlphabet;
049import org.biojava.bio.symbol.GappedSymbolList;
050import org.biojava.bio.symbol.IllegalAlphabetException;
051import org.biojava.bio.symbol.IllegalSymbolException;
052import org.biojava.bio.symbol.SimpleGappedSymbolList;
053import org.biojava.bio.symbol.SimpleSymbolList;
054import org.biojava.bio.symbol.SymbolList;
055import org.biojava.utils.SmallMap;
056
057/**
058 * Algorithms for dynamic programming (alignments) between pairs
059 * of SymbolLists.
060 * Based on a single-head DP implementation by Matt Pocock.
061 *
062 * @author Thomas Down
063 * @author Matthew Pocock
064 */
065
066public class PairwiseDP extends DP implements Serializable {
067  private final HashMap emissionCache;
068  private final CellCalculatorFactory ccFactory;
069
070  public PairwiseDP(MarkovModel mm, CellCalculatorFactoryMaker ccfm)
071  throws
072    IllegalSymbolException,
073    IllegalTransitionException,
074    BioException
075  {
076    super(mm);
077    Alphabet alpha = mm.emissionAlphabet();
078    emissionCache = new HashMap();
079    emissionCache.put(ScoreType.PROBABILITY, new EmissionCache(
080      alpha, getStates(), getDotStatesIndex(), ScoreType.PROBABILITY)
081    );
082    emissionCache.put(ScoreType.ODDS, new EmissionCache(
083      alpha, getStates(), getDotStatesIndex(), ScoreType.ODDS)
084    );
085    emissionCache.put(ScoreType.NULL_MODEL, new EmissionCache(
086      alpha, getStates(), getDotStatesIndex(), ScoreType.NULL_MODEL)
087    );
088    this.ccFactory = ccfm.make(this);
089  }
090  
091  private EmissionCache getEmissionCache(ScoreType scoreType) {
092    return (EmissionCache) emissionCache.get(scoreType);
093  }
094
095  //
096  // BACKWARD
097  //
098
099  public void update() {
100    super.update();
101    // workaround for bug in vm
102    if(emissionCache != null) {
103      for(Iterator i = emissionCache.values().iterator(); i.hasNext(); ) {
104        ((EmissionCache) i.next()).clear();
105      }
106    }
107  }
108
109  private Cell run(PairDPCursor cursor, CellCalculator cc)
110      throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException
111  {
112    Cell[][] cells = cursor.press();
113    if(cursor.hasNext()) {
114      cursor.next(cells);
115      cc.initialize(cells);
116    }
117    while(cursor.hasNext()) {
118      cursor.next(cells);
119      cc.calcCell(cells);
120    }
121    return cells[0][0];
122  }
123  
124  private double runFB(PairDPCursor cursor, CellCalculator cc) 
125      throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException
126  {
127    Cell cell = run(cursor, cc);
128    
129    // Terminate!
130    State [] states = getStates();
131    int l = 0;
132    State magicalState = getModel().magicalState();
133    while (states[l] != magicalState) {
134      ++l;
135    }
136      
137    return cell.scores[l];
138  }
139  
140  public double backward(SymbolList[] seqs, ScoreType scoreType) 
141  throws IllegalSymbolException,
142  IllegalAlphabetException,
143  IllegalTransitionException {
144    return backwardMatrix(seqs, scoreType).getScore();
145  }
146
147  public DPMatrix backwardMatrix(SymbolList[] seqs, ScoreType scoreType) 
148  throws IllegalSymbolException,
149  IllegalAlphabetException,
150  IllegalTransitionException {
151    if (seqs.length != 2) {
152      throw new IllegalArgumentException("This DP object only runs on pairs.");
153    }
154    lockModel();
155    PairDPMatrix matrix = new PairDPMatrix(this, seqs[0], seqs[1]);
156    PairDPCursor cursor = new BackMatrixPairDPCursor(
157      seqs[0], seqs[1],
158      2, 2,
159      matrix,
160      getEmissionCache(scoreType)
161    );
162    CellCalculator cc = ccFactory.backwards(scoreType);
163    double score = runFB(cursor, cc);
164    unlockModel();
165    matrix.setScore(score);
166    return matrix;
167  }
168
169  public DPMatrix backwardMatrix(SymbolList[] seqs, DPMatrix d, ScoreType scoreType) 
170  throws IllegalSymbolException,
171  IllegalAlphabetException,
172  IllegalTransitionException {
173    return backwardMatrix(seqs, scoreType);
174  }
175
176  public double forward(SymbolList[] seqs, ScoreType scoreType) 
177  throws IllegalSymbolException,
178  IllegalAlphabetException,
179  IllegalTransitionException {
180    if (seqs.length != 2) {
181      throw new IllegalArgumentException("This DP object only runs on pairs.");
182    }
183    lockModel();
184    PairDPCursor cursor = new LightPairDPCursor(
185      seqs[0], seqs[1],
186      2, 2, getStates().length, getEmissionCache(scoreType)
187    );
188    CellCalculator cc = ccFactory.forwards(scoreType);
189    double score = runFB(cursor, cc);
190    unlockModel();
191    return score;
192  }
193
194  public DPMatrix forwardMatrix(SymbolList[] seqs, ScoreType scoreType) 
195  throws
196    IllegalSymbolException,
197    IllegalAlphabetException,
198    IllegalTransitionException
199  {
200    if (seqs.length != 2) {
201      throw new IllegalArgumentException("This DP object only runs on pairs.");
202    }
203    lockModel();
204    PairDPMatrix matrix = new PairDPMatrix(this, seqs[0], seqs[1]);
205    PairDPCursor cursor = new MatrixPairDPCursor(
206      seqs[0], seqs[1],
207      2, 2, matrix, getEmissionCache(scoreType)
208    );
209    CellCalculator cc = ccFactory.forwards(scoreType);
210    double score = runFB(cursor, cc);
211    matrix.setScore(score);
212    unlockModel();
213    return matrix;
214  }
215
216  public DPMatrix forwardMatrix(SymbolList[] seqs, DPMatrix d, ScoreType scoreType) 
217  throws
218    IllegalSymbolException,
219    IllegalAlphabetException,
220    IllegalTransitionException
221  {
222    return forwardMatrix(seqs, scoreType);
223  }
224
225  public StatePath viterbi(SymbolList[] seqs, ScoreType scoreType) 
226  throws
227    IllegalSymbolException,
228    IllegalAlphabetException,
229    IllegalTransitionException
230  {
231    if (seqs.length != 2) {
232      throw new IllegalArgumentException("This DP object only runs on pairs.");
233    }
234    lockModel();
235    SymbolList seq0 = seqs[0];
236    SymbolList seq1 = seqs[1];
237    State magic = getModel().magicalState();
238    BackPointer TERMINAL_BP = new BackPointer(magic);
239    PairDPCursor cursor = new LightPairDPCursor(
240      seq0, seq1,
241      2, 2, getStates().length, getEmissionCache(scoreType)
242    );
243    CellCalculator cc = ccFactory.viterbi(scoreType, TERMINAL_BP);
244    Cell currentCell = run(cursor, cc);
245  
246    // Terminate!
247
248    int l = 0;
249    State [] states = getStates();
250    State magicalState = getModel().magicalState();
251    while (states[l] != magicalState) {
252      ++l;
253    }
254
255    // Traceback...  
256
257    BackPointer[] bpCol = currentCell.backPointers;
258    BackPointer bp = bpCol[l];
259    List statel = new ArrayList();
260    GappedSymbolList gap0 = new SimpleGappedSymbolList(seq0);
261    GappedSymbolList gap1 = new SimpleGappedSymbolList(seq1);
262    int i0 = seq0.length()+1;
263    int i1 = seq1.length()+1;
264  
265    // parse 1
266    //System.out.println("Parse 1");
267    for(BackPointer bpi = bp.back; bpi != TERMINAL_BP; bpi = bpi.back) {
268      try {
269        //System.out.println("bp.back" + bp.back);
270      /*System.out.print(
271        "Processing " + bpi.state.getName()
272      );*/
273      statel.add(bpi.state);
274      if(bpi.state instanceof EmissionState) { 
275        int [] advance = ((EmissionState) bpi.state).getAdvance();
276        //System.out.print( "\t" + advance[0] + " " + advance[1]);
277        if(advance[0] == 0) {
278          gap0.addGapInSource(i0);
279          //System.out.println(gap0.seqString());
280          //System.out.print("\t-");
281        } else {
282          i0--;
283            //System.out.print("\t" + seq0.symbolAt(i0).getToken());
284        }
285        if(advance[1] == 0) {
286          gap1.addGapInSource(i1);
287          //System.out.println(gap1.seqString());
288          //System.out.print(" -");
289        } else {
290          i1--;
291            //System.out.print(" " + seq1.symbolAt(i1).getToken());
292        }
293      }
294      //System.out.println("\tat " + i0 + ", " + i1);
295      } catch (IndexOutOfBoundsException ie) {
296        while(bpi != TERMINAL_BP) {
297          //System.out.println(bpi.state.getName());
298          bpi = bpi.back;
299        }
300        throw new BioError(ie); 
301      }
302    }
303    //System.out.println(gap0.seqString());
304    //System.out.println(gap1.seqString());
305    double [] scoreA = new double[statel.size()];
306    Map aMap = new SmallMap();
307    aMap.put(seq0, gap0);
308    aMap.put(seq1, gap1);
309    Alignment ali = new SimpleAlignment(aMap);
310    GappedSymbolList gappedAli = new SimpleGappedSymbolList(ali);
311
312    // parse 2
313    //System.out.println("Parse 2");
314    int di = statel.size()-1;
315    int dj = ali.length()+1;
316    for(BackPointer bpi = bp.back; bpi != TERMINAL_BP; bpi = bpi.back) {
317      scoreA[di] = bpi.score;
318      if(bpi.state instanceof EmissionState) {
319        dj--;
320      } else {
321        gappedAli.addGapInSource(dj);
322      }
323      di--;
324    }
325
326    Collections.reverse(statel);
327    SymbolList statesSL = new SimpleSymbolList(getModel().stateAlphabet(), statel);
328    SymbolList scoresSL = DoubleAlphabet.fromArray(scoreA);
329    StatePath sp = new SimpleStatePath(currentCell.scores[l], gappedAli, statesSL, scoresSL);
330    unlockModel();
331    return sp;
332  }
333}