001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022 023package org.biojava.bio.dp; 024 025import java.io.Serializable; 026import java.util.ArrayList; 027import java.util.Collections; 028import java.util.Comparator; 029import java.util.HashMap; 030import java.util.HashSet; 031import java.util.Iterator; 032import java.util.LinkedList; 033import java.util.List; 034import java.util.ListIterator; 035import java.util.Map; 036import java.util.Set; 037 038import org.biojava.bio.BioError; 039import org.biojava.bio.BioException; 040import org.biojava.bio.dist.Distribution; 041import org.biojava.bio.symbol.DoubleAlphabet; 042import org.biojava.bio.symbol.FiniteAlphabet; 043import org.biojava.bio.symbol.IllegalAlphabetException; 044import org.biojava.bio.symbol.IllegalSymbolException; 045import org.biojava.bio.symbol.SimpleSymbolList; 046import org.biojava.bio.symbol.Symbol; 047import org.biojava.bio.symbol.SymbolList; 048import org.biojava.utils.ChangeEvent; 049import org.biojava.utils.ChangeListener; 050import org.biojava.utils.ChangeType; 051import org.biojava.utils.ChangeVetoException; 052 053/** 054 * <p> 055 * Objects that can perform dymamic programming operations upon sequences with 056 * HMMs. 057 * </p> 058 * 059 * <p> 060 * The three main DP operations are Forwards, Backwards and Viterbi. Forwards 061 * and Backwards calculate the probability of the sequences having been made in 062 * any way by the model. Viterbi finds the most supported way that the sequence 063 * could have been made. 064 * </p> 065 * 066 * <p> 067 * Each of the functions can return the dynamic-programming matrix containing 068 * the intermediate results. This may be useful for model training, or for 069 * visualisation. 070 * </p> 071 * 072 * <p> 073 * Each of the funcitons can be calculated using the model probabilities, the 074 * null-model probabilities or the odds (ratio between the two). For Forwards 075 * and Backwards, the odds calculations produce numbers with questionable basis 076 * in reality. For Viterbi with odds, you will recieve the path through the 077 * model that is most different from the null model, and supported by the 078 * probabilities. 079 * </p> 080 * 081 * @author Matthew Pocock 082 * @author Thomas Down 083 */ 084public abstract class DP { 085 private static List NO_ADVANCE = new ArrayList(); 086 087 private int[] getNoAdvance() { 088 int heads = getModel().advance().length; 089 int[] no_advance = (int[]) NO_ADVANCE.get(heads); 090 091 if (no_advance == null) { 092 no_advance = new int[heads]; 093 for (int i = 0; i < heads; i++) { 094 no_advance[i] = 0; 095 } 096 097 NO_ADVANCE.add(heads, no_advance); 098 } 099 100 return no_advance; 101 } 102 103 /** 104 * Scores the SymbolList from symbol start to symbol (start+columns) with a 105 * weight matrix. 106 * 107 * @param matrix the weight matrix used to evaluate the sequences 108 * @param symList the SymbolList to assess 109 * @param start the index of the first symbol in the window to evaluate 110 * @return the log probability or likelyhood of this weight matrix 111 * having generated symbols start to (start + columns) of symList 112 */ 113 public static double scoreWeightMatrix( 114 WeightMatrix matrix, SymbolList symList, int start) 115 throws IllegalSymbolException { 116 double score = 0; 117 int cols = matrix.columns(); 118 119 for (int c = 0; c < cols; c++) { 120 score += Math.log( 121 matrix.getColumn(c).getWeight(symList.symbolAt(c + start))); 122 } 123 124 return score; 125 } 126 127 /** 128 * Scores the SymbolList from symbol start to symbol (start+columns) with a 129 * weight matrix using a particular ScoreType. 130 * 131 * <p> 132 * This method allows you to use score types such as ScoreType.ODDS. The other 133 * scoreWeightMatrix methods gives a result similar or identical to 134 * ScoreType.PROBABILITY. 135 * </p> 136 * 137 * @param matrix the weight matrix used to evaluate the sequences 138 * @param symList the SymbolList to assess 139 * @param scoreType the score type to apply 140 * @param start the index of the first symbol in the window to evaluate 141 * @return the sum of log scores of this weight matrix 142 * having generated symbols start to (start + columns) of symList 143 * @since 1.4 144 */ 145 public static double scoreWeightMatrix( 146 WeightMatrix matrix, 147 SymbolList symList, 148 ScoreType scoreType, 149 int start) 150 throws IllegalSymbolException { 151 double score = 0; 152 int cols = matrix.columns(); 153 154 for (int c = 0; c < cols; c++) { 155 score += Math.log(scoreType.calculateScore( 156 matrix.getColumn(c), symList.symbolAt(c + start))); 157 } 158 159 return score; 160 } 161 public static MarkovModel flatView(MarkovModel model) 162 throws IllegalAlphabetException, IllegalSymbolException { 163 for (Iterator i = model.stateAlphabet().iterator(); i.hasNext();) { 164 State s = (State) i.next(); 165 if ( 166 !(s instanceof DotState) && 167 !(s instanceof EmissionState) 168 ) { 169 return new FlatModel(model); 170 } 171 } 172 173 return model; 174 } 175 176 public State[] stateList(MarkovModel mm) 177 throws IllegalSymbolException, IllegalTransitionException, 178 BioException { 179 FiniteAlphabet alpha = mm.stateAlphabet(); 180 181 List emissionStates = new ArrayList(); 182 HMMOrderByTransition comp = new HMMOrderByTransition(mm); 183 List dotStates = new LinkedList(); 184 for (Iterator addStates = alpha.iterator(); addStates.hasNext();) { 185 Object state = addStates.next(); 186 if (state instanceof MagicalState) { 187 emissionStates.add(0, state); 188 } else if (state instanceof EmissionState) { 189 emissionStates.add(state); 190 } else { 191 ListIterator checkOld = dotStates.listIterator(); 192 int insertPos = -1; 193 while (checkOld.hasNext() && insertPos == -1) { 194 Object oldState = checkOld.next(); 195 if (comp.compare(state, oldState) == HMMOrderByTransition.LESS_THAN) { 196 insertPos = checkOld.nextIndex() - 1; 197 } 198 } 199 if (insertPos >= 0) { 200 dotStates.add(insertPos, state); 201 } else { 202 dotStates.add(state); 203 } 204 } 205 } 206 Collections.sort(emissionStates, new Comparator() { 207 public int compare(Object o1, Object o2) { 208 State s = (State) o1; 209 State t = (State) o2; 210 211 // sort by advance 212 int[] sa; 213 if (s instanceof EmissionState) { 214 sa = ((EmissionState) s).getAdvance(); 215 } else { 216 sa = getNoAdvance(); 217 } 218 219 int[] ta; 220 if (t instanceof EmissionState) { 221 ta = ((EmissionState) t).getAdvance(); 222 } else { 223 ta = getNoAdvance(); 224 } 225 226 for (int i = 0; i < sa.length; i++) { 227 if (sa[i] > ta[i]) { 228 return -1; 229 } else if (sa[i] < ta[i]) { 230 return +1; 231 } 232 } 233 234 // give up - sort by name 235 return s.getName().compareTo(t.getName()); 236 } 237 }); 238 State[] sl = new State[emissionStates.size() + dotStates.size()]; 239 int i = 0; 240 for (Iterator si = emissionStates.iterator(); si.hasNext();) { 241 EmissionState ex = (EmissionState) si.next(); 242 int[] ad = ex.getAdvance(); 243 if (ad.length != mm.advance().length) { 244 throw new BioException( 245 "State " + ex.getName() + " advances " + ad.length + " heads. " + 246 " however, the model " + mm.stateAlphabet().getName() + 247 " advances " + mm.advance().length + " heads." 248 ); 249 } 250 for (int adi = 0; ad != null && adi < ad.length; adi++) { 251 if (ad[adi] != 0) { 252 ad = null; 253 } 254 } 255 if (ad != null) { 256 throw new Error( 257 "State " + ex.getName() + " has advance " + ad 258 ); 259 } 260 sl[i++] = ex; 261 } 262 for (Iterator si = dotStates.iterator(); si.hasNext();) { 263 sl[i++] = (State) si.next(); 264 } 265 return sl; 266 } 267 268 /** 269 * Returns a matrix for the specified States describing all 270 * valid Transitions between those States. 271 * <p> 272 * The matrix is 2-dimensional. The primary array has an element 273 * corresponding to every State in the states argument. That 274 * element is itself an array the elements of which identify 275 * the States that can reach that State. The source States 276 * are identified by their index within the states [] array. 277 * @param model MarkovModel to be analysed. 278 * @param states The States for which the transition matrix is to be determined. 279 */ 280 public static int[][] forwardTransitions( 281 MarkovModel model, 282 State[] states 283 ) throws IllegalSymbolException { 284 int stateCount = states.length; 285 int[][] transitions = new int[stateCount][]; 286 287 for (int i = 0; i < stateCount; i++) { 288 int[] tmp = new int[stateCount]; 289 int len = 0; 290 FiniteAlphabet trans = model.transitionsTo(states[i]); 291 for (int j = 0; j < stateCount; j++) { 292 if (trans.contains(states[j])) { 293 tmp[len++] = j; 294 } 295 } 296 int[] tmp2 = new int[len]; 297 for (int j = 0; j < len; j++) { 298 tmp2[j] = tmp[j]; 299 } 300 transitions[i] = tmp2; 301 } 302 303 return transitions; 304 } 305 306 /** 307 * Compute the log(score) of all transitions 308 * between the specified States. The layout 309 * of the array is identical to that of the transitions 310 * array. 311 * <p> 312 * Note that all parameters <b>MUST</b> be 313 * consistent with each other!!!! 314 * <p> 315 * @param model The model for which the data is to be computed. 316 * @param states The States within that model for which scores are required. 317 * @param transitions The transition matrix obtained by calling forwardTransitions() with the above argument values. 318 * @param scoreType The type of score to be evaluated. 319 */ 320 public static double[][] forwardTransitionScores( 321 MarkovModel model, 322 State[] states, 323 int[][] transitions, 324 ScoreType scoreType 325 ) { 326 // System.out.println("forwardTransitionScores"); 327 int stateCount = states.length; 328 double[][] scores = new double[stateCount][]; 329 330 for (int i = 0; i < stateCount; i++) { 331 State is = states[i]; 332 scores[i] = new double[transitions[i].length]; 333 for (int j = 0; j < scores[i].length; j++) { 334 try { 335 scores[i][j] = Math.log(scoreType.calculateScore( 336 model.getWeights(states[transitions[i][j]]), 337 is 338 )); 339 /*System.out.println( 340 states[transitions[i][j]] + "\t-> " + 341 is.getName() + "\t = " + 342 scores[i][j] + "\t(" + 343 scoreType.calculateScore( 344 model.getWeights(states[transitions[i][j]]), 345 is 346 ) 347 );*/ 348 } catch (IllegalSymbolException ite) { 349 throw new BioError( 350 "Transition listed in transitions array has dissapeared.", 351 ite); 352 } 353 } 354 } 355 356 return scores; 357 } 358 359 public static int[][] backwardTransitions( 360 MarkovModel model, 361 State[] states 362 ) throws IllegalSymbolException { 363 int stateCount = states.length; 364 int[][] transitions = new int[stateCount][]; 365 366 for (int i = 0; i < stateCount; i++) { 367 int[] tmp = new int[stateCount]; 368 int len = 0; 369 FiniteAlphabet trans = model.transitionsFrom(states[i]); 370 for (int j = 0; j < stateCount; j++) { 371 if (trans.contains(states[j])) { 372 tmp[len++] = j; 373 } 374 } 375 int[] tmp2 = new int[len]; 376 for (int j = 0; j < len; j++) { 377 tmp2[j] = tmp[j]; 378 } 379 transitions[i] = tmp2; 380 } 381 382 return transitions; 383 } 384 385 public static double[][] backwardTransitionScores(MarkovModel model, 386 State[] states, 387 int[][] transitions, 388 ScoreType scoreType 389 ) { 390 int stateCount = states.length; 391 double[][] scores = new double[stateCount][]; 392 393 for (int i = 0; i < stateCount; i++) { 394 State is = states[i]; 395 scores[i] = new double[transitions[i].length]; 396 for (int j = 0; j < scores[i].length; j++) { 397 try { 398 scores[i][j] = Math.log(scoreType.calculateScore( 399 model.getWeights(is), 400 states[transitions[i][j]] 401 )); 402 } catch (IllegalSymbolException ite) { 403 throw new BioError( 404 "Transition listed in transitions array has dissapeared", 405 ite); 406 } 407 } 408 } 409 410 return scores; 411 } 412 413 private MarkovModel model; 414 private State[] states; 415 private int[][] forwardTransitions; 416 private int[][] backwardTransitions; 417 private int dotStatesIndex; 418 private int lockCount = 0; 419 420 public int getDotStatesIndex() { 421 return dotStatesIndex; 422 } 423 424 public MarkovModel getModel() { 425 return model; 426 } 427 428 public State[] getStates() { 429 return states; 430 } 431 432 public int[][] getForwardTransitions() { 433 return forwardTransitions; 434 } 435 436 private Map forwardTransitionScores; 437 private Map backwardTransitionScores; 438 439 public double[][] getForwardTransitionScores(ScoreType scoreType) { 440 double[][] ts = (double[][]) forwardTransitionScores.get(scoreType); 441 if (ts == null) { 442 forwardTransitionScores.put(scoreType, ts = forwardTransitionScores( 443 getModel(), getStates(), forwardTransitions, scoreType 444 )); 445 } 446 return ts; 447 } 448 449 public int[][] getBackwardTransitions() { 450 return backwardTransitions; 451 } 452 453 public double[][] getBackwardTransitionScores(ScoreType scoreType) { 454 double[][] ts = (double[][]) backwardTransitionScores.get(scoreType); 455 if (ts == null) { 456 backwardTransitionScores.put(scoreType, ts = backwardTransitionScores( 457 getModel(), getStates(), backwardTransitions, scoreType 458 )); 459 } 460 return ts; 461 } 462 463 public void lockModel() { 464 if (lockCount++ == 0) { 465 getModel().addChangeListener(ChangeListener.ALWAYS_VETO, ChangeType.UNKNOWN); 466 } 467 } 468 469 public void unlockModel() { 470 if (--lockCount == 0) { 471 getModel().removeChangeListener(ChangeListener.ALWAYS_VETO, ChangeType.UNKNOWN); 472 } 473 } 474 475 public void update() { 476 try { 477 if(this.states == null) { 478 this.states = stateList(model); 479 this.forwardTransitions = forwardTransitions(model, states); 480 this.backwardTransitions = backwardTransitions(model, states); 481 482 // Find first dot state 483 int i; 484 for (i = 0; i < states.length; ++i) { 485 if (!(states[i] instanceof EmissionState)) { 486 break; 487 } 488 } 489 dotStatesIndex = i; 490 } 491 492 this.forwardTransitionScores.clear(); 493 this.backwardTransitionScores.clear(); 494 } catch (Exception e) { 495 throw new BioError("Something is seriously wrong with the DP code", e); 496 } 497 } 498 499 public DP(MarkovModel model){ 500 this.setModel(model); 501 } 502 503 /** 504 * This method will result in a DP with no model. Use the setModel() method 505 * to set the model before use. 506 */ 507 public DP(){} 508 509 public void setModel(MarkovModel model){ 510 this.model = model; 511 this.forwardTransitionScores = new HashMap(); 512 this.backwardTransitionScores = new HashMap(); 513 this.update(); 514 515 model.addChangeListener(UPDATER, ChangeType.UNKNOWN); 516 } 517 518 public abstract double forward(SymbolList[] symList, ScoreType scoreType) 519 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException; 520 521 public abstract double backward(SymbolList[] symList, ScoreType scoreType) 522 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException; 523 524 public abstract DPMatrix forwardMatrix(SymbolList[] symList, ScoreType scoreType) 525 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException; 526 527 public abstract DPMatrix backwardMatrix(SymbolList[] symList, ScoreType scoreType) 528 throws IllegalSymbolException, IllegalAlphabetException, IllegalTransitionException; 529 530 public abstract DPMatrix forwardMatrix(SymbolList[] symList, DPMatrix matrix, ScoreType scoreType) 531 throws IllegalArgumentException, IllegalSymbolException, 532 IllegalAlphabetException, IllegalTransitionException; 533 534 public abstract DPMatrix backwardMatrix(SymbolList[] symList, DPMatrix matrix, ScoreType scoreType) 535 throws IllegalArgumentException, IllegalSymbolException, 536 IllegalAlphabetException, IllegalTransitionException; 537 538 public abstract StatePath viterbi(SymbolList[] symList, ScoreType scoreType) 539 throws IllegalSymbolException, IllegalArgumentException, IllegalAlphabetException, IllegalTransitionException; 540 541 public DPMatrix forwardsBackwards(SymbolList[] symList, ScoreType scoreType) 542 throws BioException { 543 try { 544 System.out.println("Making backward matrix"); 545 final DPMatrix bMatrix = backwardMatrix(symList, scoreType); 546 System.out.println("Making forward matrix"); 547 final DPMatrix fMatrix = forwardMatrix(symList, scoreType); 548 549 System.out.println("Making forward/backward matrix"); 550 return new DPMatrix() { 551 public double getCell(int[] index) { 552 return fMatrix.getCell(index) + bMatrix.getCell(index); 553 } 554 555 public double getScore() { 556 return fMatrix.getScore(); 557 } 558 559 public MarkovModel model() { 560 return fMatrix.model(); 561 } 562 563 public SymbolList[] symList() { 564 return fMatrix.symList(); 565 } 566 567 public State[] states() { 568 return fMatrix.states(); 569 } 570 }; 571 } catch (Exception e) { 572 throw new BioException("Couldn't build forwards-backwards matrix", e); 573 } 574 } 575 576 /** 577 * <p> 578 * Generates an alignment from a model. 579 * </p> 580 * 581 * <p> 582 * If the length is set to -1 then the model length will be sampled 583 * using the model's transition to the end state. If the length is 584 * fixed using length, then the transitions to the end state are implicitly 585 * invoked. 586 * </p> 587 * 588 * @param length the length of the sequence to generate 589 * @return a StatePath generated at random 590 */ 591 public StatePath generate(int length) 592 throws IllegalSymbolException, BioException { 593 List tokenList = new ArrayList(); 594 List stateList = new ArrayList(); 595 List scoreList = new ArrayList(); 596 double totScore = 0.0; 597 double symScore = 0.0; 598 int i = length; 599 State oldState; 600 Symbol token; 601 602 oldState = (State) model.getWeights(model.magicalState()).sampleSymbol(); 603 symScore += model.getWeights(model.magicalState()).getWeight(oldState); 604 605 DoubleAlphabet dAlpha = DoubleAlphabet.getInstance(); 606 if (oldState instanceof EmissionState) { 607 EmissionState eState = (EmissionState) oldState; 608 token = eState.getDistribution().sampleSymbol(); 609 symScore += eState.getDistribution().getWeight(token); 610 stateList.add(oldState); 611 tokenList.add(token); 612 scoreList.add(dAlpha.getSymbol(symScore)); 613 totScore += symScore; 614 symScore = 0.0; 615 i--; 616 } 617 618 while (i != 0) { 619 State newState = null; 620 Distribution dist = model.getWeights(oldState); 621 do { 622 newState = (State) dist.sampleSymbol(); 623 } while (newState == model.magicalState() && i > 0); 624 try { 625 symScore += dist.getWeight(newState); 626 } catch (IllegalSymbolException ise) { 627 throw new BioError( 628 "Transition returned from sampleTransition is invalid", 629 ise); 630 } 631 632 if (newState == model.magicalState()) { 633 break; 634 } 635 636 if (newState instanceof EmissionState) { 637 EmissionState eState = (EmissionState) newState; 638 token = eState.getDistribution().sampleSymbol(); 639 symScore += eState.getDistribution().getWeight(token); 640 stateList.add(newState); 641 tokenList.add(token); 642 scoreList.add(dAlpha.getSymbol(symScore)); 643 totScore += symScore; 644 symScore = 0.0; 645 i--; 646 } 647 oldState = newState; 648 } 649 650 SymbolList tokens = new SimpleSymbolList(model.emissionAlphabet(), tokenList); 651 SymbolList states = new SimpleSymbolList(model.stateAlphabet(), stateList); 652 SymbolList scores = new SimpleSymbolList(dAlpha, scoreList); 653 654 return new SimpleStatePath( 655 totScore, 656 tokens, 657 states, 658 scores 659 ); 660 } 661 662 public static class ReverseIterator implements Iterator, Serializable { 663 private SymbolList sym; 664 private int index; 665 666 public ReverseIterator(SymbolList sym) { 667 this.sym = sym; 668 index = sym.length(); 669 } 670 671 public boolean hasNext() { 672 return index > 0; 673 } 674 675 public Object next() { 676 return sym.symbolAt(index--); 677 } 678 679 public void remove() throws UnsupportedOperationException { 680 throw new UnsupportedOperationException("This itterator can not cause modifications"); 681 } 682 } 683 684 private final ChangeListener UPDATER = new ChangeListener() { 685 public void preChange(ChangeEvent ce) 686 throws ChangeVetoException { 687 } 688 689 public void postChange(ChangeEvent ce) { 690 if (ce.getType().isMatchingType(MarkovModel.ARCHITECTURE)) { 691 System.out.println("architecture alterred"); 692 states = null; 693 } 694 695 if ( 696 (ce.getType().isMatchingType(MarkovModel.ARCHITECTURE)) || 697 (ce.getType().isMatchingType(MarkovModel.PARAMETER)) 698 ) { 699 update(); 700 } 701 } 702 }; 703 704 private static class HMMOrderByTransition { 705 public final static Object GREATER_THAN = new Object(); 706 public final static Object LESS_THAN = new Object(); 707 public final static Object EQUAL = new Object(); 708 public final static Object DISJOINT = new Object(); 709 710 private MarkovModel mm; 711 712 private HMMOrderByTransition(MarkovModel mm) { 713 this.mm = mm; 714 } 715 716 public Object compare(Object o1, Object o2) 717 throws IllegalTransitionException, IllegalSymbolException { 718 if (o1 == o2) { 719 return EQUAL; 720 } 721 State s1 = (State) o1; 722 State s2 = (State) o2; 723 724 if (transitionsTo(s1, s2)) { 725 return LESS_THAN; 726 } 727 if (transitionsTo(s2, s1)) { 728 return GREATER_THAN; 729 } 730 731 return DISJOINT; 732 } 733 734 private boolean transitionsTo(State from, State to) 735 throws IllegalTransitionException, IllegalSymbolException { 736 Set checkedSet = new HashSet(); 737 Set workingSet = new HashSet(); 738 for ( 739 Iterator i = mm.transitionsFrom(from).iterator(); 740 i.hasNext(); 741 ) { 742 workingSet.add(i.next()); 743 } 744 745 while (workingSet.size() > 0) { 746 Set newWorkingSet = new HashSet(); 747 for (Iterator i = workingSet.iterator(); i.hasNext();) { 748 State s = (State) i.next(); 749 if (s instanceof EmissionState) { 750 continue; 751 } 752 if (s == from) { 753 throw new IllegalTransitionException( 754 from, from, "Loop in dot states." 755 ); 756 } 757 if (s == to) { 758 return true; 759 } 760 for (Iterator j = mm.transitionsFrom(s).iterator(); j.hasNext();) { 761 State s2 = (State) j.next(); 762 if (!workingSet.contains(s2) && !checkedSet.contains(s2)) { 763 newWorkingSet.add(s2); 764 } 765 } 766 checkedSet.add(s); 767 } 768 workingSet = newWorkingSet; 769 } 770 return false; 771 } 772 } 773} 774