001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022 023package org.biojava.bio.dp.twohead; 024 025import java.io.Serializable; 026import org.biojava.bio.BioError; 027import org.biojava.bio.dp.BackPointer; 028import org.biojava.bio.dp.DP; 029import org.biojava.bio.dp.EmissionState; 030import org.biojava.bio.dp.IllegalTransitionException; 031import org.biojava.bio.dp.ScoreType; 032import org.biojava.bio.dp.State; 033import org.biojava.bio.symbol.IllegalAlphabetException; 034import org.biojava.bio.symbol.IllegalSymbolException; 035 036/** 037 * @author Matthew Pocock 038 * @author Thomas Down 039 */ 040public class DPInterpreter implements CellCalculatorFactory, Serializable { 041 private final DP dp; 042 043 public DPInterpreter(DP dp) { 044 this.dp = dp; 045 } 046 047 public CellCalculator forwards(ScoreType scoreType) 048 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException { 049 return new Forward(dp, scoreType); 050 } 051 052 public CellCalculator backwards(ScoreType scoreType) 053 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException { 054 return new Backward(dp, scoreType); 055 } 056 057 public CellCalculator viterbi(ScoreType scoreType, BackPointer terminal) 058 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException { 059 return new Viterbi(dp, scoreType, terminal); 060 } 061 062 063 private static class Forward implements CellCalculator { 064 private final int[][] transitions; 065 private final double[][] transitionScores; 066 private final State[] states; 067 private final State magicalState; 068 069 public Forward(DP dp, ScoreType scoreType) 070 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException { 071 states = dp.getStates(); 072 073 transitions = dp.getForwardTransitions(); 074 transitionScores = dp.getForwardTransitionScores(scoreType); 075 magicalState = dp.getModel().magicalState(); 076 } 077 078 public void initialize(Cell [][] cells) 079 throws 080 IllegalSymbolException, 081 IllegalAlphabetException, 082 IllegalTransitionException 083 { 084 _calcCell(cells, true); 085 } 086 087 public void calcCell(Cell [][] cells) 088 throws 089 IllegalSymbolException, 090 IllegalAlphabetException, 091 IllegalTransitionException 092 { 093 _calcCell(cells, false); 094 } 095 096 public void _calcCell(Cell [][] cells, boolean initializationHack) 097 throws 098 IllegalSymbolException, 099 IllegalAlphabetException, 100 IllegalTransitionException 101 { 102 Cell curCell = cells[0][0]; 103 double[] curCol = curCell.scores; 104 double[] emissions = curCell.emissions; 105 //System.out.println("curCol = " + curCol); 106 107 STATELOOP: 108 for (int l = 0; l < states.length; ++l) { 109 State curState = states[l]; 110 //System.out.println("State = " + states[l].getName()); 111 try { 112 if(initializationHack && (curState instanceof EmissionState)) { 113 if(curState == magicalState) { 114 curCol[l] = 0.0; 115 } else { 116 curCol[l] = Double.NaN; 117 } 118 //System.out.println("Initialized state to " + curCol[l]); 119 continue STATELOOP; 120 } 121 122 //System.out.println("Calculating weight"); 123 double[] sourceScores; 124 double weight; 125 if (! (curState instanceof EmissionState)) { 126 weight = 0.0; 127 sourceScores = curCol; 128 } else { 129 weight = emissions[l]; 130 //System.out.println("Weight " + emissions[l]); 131 if(weight == Double.NEGATIVE_INFINITY || Double.isNaN(weight)) { 132 curCol[l] = Double.NaN; 133 continue STATELOOP; 134 } 135 int [] advance = ((EmissionState)curState).getAdvance(); 136 sourceScores = cells[advance[0]][advance[1]].scores; 137 //System.out.println("Values from " + advance[0] + ", " + advance[1] + " " + sourceScores); 138 } 139 //System.out.println("weight = " + weight); 140 141 int [] tr = transitions[l]; 142 double[] trs = transitionScores[l]; 143 144 /*for(int ci = 0; ci < tr.length; ci++) { 145 System.out.println( 146 "Source = " + states[tr[ci]].getName() + 147 "\t= " + sourceScores[tr[ci]] 148 ); 149 }*/ 150 151 // Calculate probabilities for states with transitions 152 // here. 153 154 // Find base for addition 155 double constant = Double.NaN; 156 double score = 0.0; 157 for ( 158 int ci = 0; 159 ci < tr.length; 160 ci++ 161 ) { 162 int trc = tr[ci]; 163 double trSc = sourceScores[trc]; 164 if(!Double.isNaN(trSc) && trSc != Double.NEGATIVE_INFINITY) { 165 if(Double.isNaN(constant)) { 166 constant = trSc; 167 } 168 double sk = trs[ci]; 169 if(!Double.isNaN(sk) && sk != Double.NEGATIVE_INFINITY) { 170 score += Math.exp(trSc + sk - constant); 171 } 172 } 173 } 174 if(Double.isNaN(constant)) { 175 curCol[l] = Double.NaN; 176 //System.out.println("found no source"); 177 } else { 178 curCol[l] = weight + Math.log(score) + constant; 179 } 180 } catch (Exception e) { 181 throw new BioError( 182 183 "Problem with state " + l + " -> " + states[l].getName(), e 184 ); 185 } catch (BioError e) { 186 throw new BioError( 187 188 "Error with state " + l + " -> " + states[l].getName(), e 189 ); 190 } 191 } 192 /*for (int l = 0; l < states.length; ++l) { 193 State curState = states[l]; 194 System.out.println( 195 "State = " + states[l].getName() + 196 "\t = " + curCol[l] 197 ); 198 }*/ 199 } 200 } 201 202 private static class Backward implements CellCalculator { 203 private final int[][] transitions; 204 private final double[][] transitionScores; 205 private final State[] states; 206 private final State magicalState; 207 208 public Backward( 209 DP dp, ScoreType scoreType 210 ) throws 211 IllegalSymbolException, 212 IllegalAlphabetException, 213 IllegalTransitionException 214 { 215 states = dp.getStates(); 216 transitions = dp.getBackwardTransitions(); 217 transitionScores = dp.getBackwardTransitionScores(scoreType); 218 magicalState = dp.getModel().magicalState(); 219 } 220 221 public void initialize(Cell [][] cells) 222 throws 223 IllegalSymbolException, 224 IllegalAlphabetException, 225 IllegalTransitionException 226 { 227 _calcCell(cells, true); 228 } 229 230 public void calcCell(Cell [][] cells) 231 throws 232 IllegalSymbolException, 233 IllegalAlphabetException, 234 IllegalTransitionException 235 { 236 _calcCell(cells, false); 237 } 238 239 public void _calcCell(Cell [][] cells, boolean initializationHack) 240 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException 241 { 242 Cell curCell = cells[0][0]; 243 double[] curCol = curCell.scores; 244 245 STATELOOP: 246 for (int l = states.length - 1; l >= 0; --l) { 247 //System.out.println("State = " + states[l].getName()); 248 State curState = states[l]; 249 if(initializationHack && (curState instanceof EmissionState)) { 250 if(curState == magicalState) { 251 curCol[l] = 0.0; 252 } else { 253 curCol[l] = Double.NaN; 254 } 255 continue STATELOOP; 256 } 257 int [] tr = transitions[l]; 258 double[] trs = transitionScores[l]; 259 260 // Calculate probabilities for states with transitions 261 // here. 262 263 double[] sourceScores = new double[tr.length]; 264 for (int ci = 0; ci < tr.length; ++ci) { 265 double weight; 266 267 int destI = tr[ci]; 268 State destS = states[destI]; 269 Cell targetCell; 270 if (destS instanceof EmissionState) { 271 int [] advance = ((EmissionState)destS).getAdvance(); 272 targetCell = cells[advance[0]][advance[1]]; 273 weight = targetCell.emissions[destI]; 274 if(Double.isNaN(weight)) { 275 curCol[l] = Double.NaN; 276 continue STATELOOP; 277 } 278 } else { 279 targetCell = curCell; 280 weight = 0.0; 281 } 282 sourceScores[ci] = targetCell.scores[destI] + weight; 283 } 284 /*for(int ci = 0; ci < tr.length; ci++) { 285 System.out.println( 286 "Source = " + states[tr[ci]].getName() + 287 "\t= " + sourceScores[ci] 288 ); 289 }*/ 290 291 // Find base for addition 292 double constant = Double.NaN; 293 double score = 0.0; 294 for( 295 int ci = 0; 296 ci < tr.length; 297 ci++ 298 ) { 299 double skc = sourceScores[ci]; 300 if(skc != Double.NEGATIVE_INFINITY && !Double.isNaN(skc)) { 301 if(Double.isNaN(constant)) { 302 constant = skc; 303 } 304 double sk = trs[ci]; 305 if(!Double.isNaN(sk) && sk != Double.NEGATIVE_INFINITY) { 306 score += Math.exp(skc + sk - constant); 307 } 308 } 309 } 310 if(Double.isNaN(constant)) { 311 curCol[l] = Double.NaN; 312 } else { 313 curCol[l] = Math.log(score) + constant; 314 //System.out.println(curCol[l]); 315 } 316 } 317 /*for (int l = 0; l < states.length; ++l) { 318 State curState = states[l]; 319 System.out.println( 320 "State = " + states[l].getName() + 321 "\t = " + curCol[l] 322 ); 323 }*/ 324 } 325 } 326 327 328 private class Viterbi implements CellCalculator { 329 private final int[][] transitions; 330 private final double[][] transitionScores; 331 private final State[] states; 332 private final BackPointer TERMINAL_BP; 333 private final State magicalState; 334 335 public Viterbi(DP dp, ScoreType scoreType, BackPointer terminal) 336 throws 337 IllegalSymbolException, 338 IllegalAlphabetException, 339 IllegalTransitionException 340 { 341 TERMINAL_BP = terminal; 342 states = dp.getStates(); 343 transitions = dp.getForwardTransitions(); 344 transitionScores = dp.getForwardTransitionScores(scoreType); 345 magicalState = dp.getModel().magicalState(); 346 } 347 348 public void initialize(Cell[][] cells) 349 throws 350 IllegalSymbolException, 351 IllegalAlphabetException, 352 IllegalTransitionException 353 { 354 _calcCell(cells, true); 355 } 356 357 public void calcCell(Cell[][] cells) 358 throws 359 IllegalSymbolException, 360 IllegalAlphabetException, 361 IllegalTransitionException 362 { 363 _calcCell(cells, false); 364 } 365 366 public void _calcCell(Cell [][] cells, boolean initializationHack) 367 throws 368 IllegalSymbolException, 369 IllegalAlphabetException, 370 IllegalTransitionException 371 { 372 Cell curCell = cells[0][0]; 373 double[] curCol = curCell.scores; 374 BackPointer[] curBPs = curCell.backPointers; 375 double[] emissions = curCell.emissions; 376 //System.out.println("Scores " + curCol); 377 STATELOOP: 378 for (int l = 0; l < states.length; ++l) { 379 State curState = states[l]; 380 //System.out.println("State = " + l + "=" + states[l].getName()); 381 try { 382 //System.out.println("trying initialization"); 383 if(initializationHack && (curState instanceof EmissionState)) { 384 if(curState == magicalState) { 385 curCol[l] = 0.0; 386 curBPs[l] = TERMINAL_BP; 387 } else { 388 curCol[l] = Double.NaN; 389 curBPs[l] = null; 390 } 391 //System.out.println("Initialized state to " + curCol[l]); 392 continue STATELOOP; 393 } 394 395 double weight; 396 double[] sourceScores; 397 BackPointer[] oldBPs; 398 if(! (curState instanceof EmissionState)) { 399 weight = 0.0; 400 sourceScores = curCol; 401 oldBPs = curBPs; 402 } else { 403 weight = emissions[l]; 404 if(weight == Double.NEGATIVE_INFINITY || Double.isNaN(weight)) { 405 curCol[l] = Double.NaN; 406 curBPs[l] = null; 407 continue STATELOOP; 408 } 409 int [] advance = ((EmissionState)curState).getAdvance(); 410 Cell oldCell = cells[advance[0]][advance[1]]; 411 sourceScores = oldCell.scores; 412 oldBPs = oldCell.backPointers; 413 //System.out.println("Looking back " + advance[0] + ", " + advance[1]); 414 } 415 //System.out.println("weight = " + weight); 416 417 double score = Double.NEGATIVE_INFINITY; 418 int [] tr = transitions[l]; 419 double[] trs = transitionScores[l]; 420 421 int bestK = -1; // index into states[l] 422 for (int kc = 0; kc < tr.length; ++kc) { 423 int k = tr[kc]; // actual state index 424 double sk = sourceScores[k]; 425 426 /*System.out.println("kc is " + kc); 427 System.out.println("with from " + k + "=" + states[k].getName()); 428 System.out.println("prevScore = " + sk);*/ 429 if (sk != Double.NEGATIVE_INFINITY && !Double.isNaN(sk)) { 430 double t = trs[kc]; 431 //System.out.println("Transition score = " + t); 432 double newScore = t + sk; 433 if (newScore > score) { 434 score = newScore; 435 bestK = k; 436 //System.out.println("New best source at " + kc + " is " + score); 437 } 438 } 439 } 440 if (bestK != -1) { 441 curCol[l] = weight + score; 442 /*System.out.println("Weight = " + weight); 443 System.out.println("Score = " + score); 444 System.out.println( 445 "Creating " + states[bestK].getName() + 446 " -> " + states[l].getName() + 447 " (" + curCol[l] + ")" 448 );*/ 449 try { 450 State s = states[l]; 451 curBPs[l] = new BackPointer( 452 s, 453 oldBPs[bestK], 454 curCol[l] 455 ); 456 } catch (Throwable t) { 457 throw new BioError( 458 459 "Couldn't generate backpointer for " + states[l].getName() + 460 " back to " + states[bestK].getName(), t 461 ); 462 } 463 } else { 464 //System.out.println("No where to come from");; 465 curBPs[l] = null; 466 curCol[l] = Double.NaN; 467 } 468 } catch (Exception e) { 469 throw new BioError( 470 471 "Problem with state " + l + " -> " + states[l].getName(),e 472 ); 473 } catch (BioError e) { 474 throw new BioError( 475 476 "Error with state " + l + " -> " + states[l].getName(),e 477 ); 478 } 479 } 480 /*System.out.println("backpointers:"); 481 for(int l = 0; l < states.length; l++) { 482 System.out.print(states[l].getName() + "\t" + curCol[l] + "\t"); 483 BackPointer b = curBPs[l]; 484 if(b != null) { 485 for(BackPointer bb = b; bb.back != bb; bb = bb.back) { 486 System.out.print(bb.state.getName() + " -> "); 487 } 488 System.out.println("!"); 489 } else { 490 System.out.print("\n"); 491 } 492 }*/ 493 initializationHack = false; 494 } 495 } 496 497 public static class Maker implements CellCalculatorFactoryMaker { 498 public CellCalculatorFactory make(DP dp) { 499 return new DPInterpreter(dp); 500 } 501 } 502}