001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022 023package org.biojava.bio.dp; 024 025import java.io.Serializable; 026 027import org.biojava.bio.BioError; 028import org.biojava.bio.dist.Distribution; 029import org.biojava.bio.dp.onehead.SingleDP; 030import org.biojava.bio.dp.onehead.SingleDPMatrix; 031import org.biojava.bio.symbol.AlphabetManager; 032import org.biojava.bio.symbol.IllegalAlphabetException; 033import org.biojava.bio.symbol.IllegalSymbolException; 034import org.biojava.bio.symbol.Symbol; 035import org.biojava.bio.symbol.SymbolList; 036 037/** 038 * <p> 039 * Train a hidden markov model using a sampling algorithm. 040 * </p> 041 * 042 * <p> 043 * Note: this class currently only works for one-head models. 044 * </p> 045 * 046 * @author Matthew Pocock 047 * @author Richard Holland 048 * @since 1.0 049 */ 050public class BaumWelchSampler extends AbstractTrainer implements Serializable { 051 052 protected double singleSequenceIteration( 053 ModelTrainer trainer, 054 SymbolList symList 055 ) throws IllegalSymbolException, IllegalTransitionException, IllegalAlphabetException { 056 return this.singleSequenceIteration(trainer, symList, ScoreType.PROBABILITY); 057 } 058 059 protected double singleSequenceIteration( 060 ModelTrainer trainer, 061 SymbolList symList, 062 ScoreType scoreType 063 ) throws IllegalSymbolException, IllegalTransitionException, IllegalAlphabetException { 064 SingleDP dp = (SingleDP) getDP(); 065 State [] states = dp.getStates(); 066 int [][] backwardTransitions = dp.getBackwardTransitions(); 067 double [][] backwardTransitionScores = dp.getBackwardTransitionScores(scoreType); 068 MarkovModel model = dp.getModel(); 069 070 SymbolList [] rll = { symList }; 071 072 // System.out.print("Forward... "); 073 SingleDPMatrix fm = (SingleDPMatrix) dp.forwardMatrix(rll, scoreType); 074 double fs = fm.getScore(); 075 // System.out.println("Score = " + fs); 076 077 // System.out.print("Backward... "); 078 SingleDPMatrix bm = (SingleDPMatrix) dp.backwardMatrix(rll, scoreType); 079 // System.out.println("Score = " + bs); 080 081 Symbol gap = AlphabetManager.getGapSymbol(); 082 083 // state trainer 084 085 for (int i = 1; i <= symList.length(); i++) { 086 Symbol sym = symList.symbolAt(i); 087 double [] fsc = fm.scores[i]; 088 double [] bsc = bm.scores[i]; 089 double p = Math.random(); 090 for (int s = 0; s < dp.getDotStatesIndex(); s++) { 091 if (! (states[s] instanceof MagicalState)) { 092 p -= Math.exp(fsc[s] + bsc[s] - fs); 093 if (p <= 0.0) { 094 trainer.addCount( 095 ((EmissionState) states[s]).getDistribution(), 096 sym, 097 1.0 098 ); 099 break; 100 } 101 } 102 } 103 } 104 105 // transition trainer 106 for (int i = 0; i <= symList.length(); i++) { 107 Symbol sym = (i < symList.length()) 108 ? symList.symbolAt(i + 1) 109 : gap; 110 double [] fsc = fm.scores[i]; 111 double [] bsc = bm.scores[i+1]; 112 double [] bsc2 = bm.scores[i]; 113 double[] weightVector = dp.getEmission(sym, scoreType); 114 for (int s = 0; s < states.length; s++) { // any -> emission transitions 115 int [] ts = backwardTransitions[s]; 116 double [] tss = backwardTransitionScores[s]; 117 Distribution dist = model.getWeights(states[s]); 118 double p = Math.random(); 119 for (int tc = 0; tc < ts.length; tc++) { 120 int t = ts[tc]; 121 double weight = 1.0; 122 if (states[t] instanceof EmissionState) { 123 weight = Math.exp(weightVector[t]); 124 } 125 if (weight != 0.0) { 126 if(t < dp.getDotStatesIndex()) { 127 p -= Math.exp(fsc[s] + tss[tc] + bsc[t] - fs) * weight; 128 } else { 129 p -= Math.exp(fsc[s] + tss[tc] + bsc2[t] - fs); 130 } 131 if (p <= 0.0) { 132 try { 133 trainer.addCount( 134 dist, 135 states[t], 136 1.0 137 ); 138 } catch (IllegalSymbolException ise) { 139 throw new BioError( 140 "Transition in backwardTransitions dissapeared", ise); 141 } 142 break; 143 } 144 } 145 } 146 } 147 } 148 return fs; 149 } 150 151 public BaumWelchSampler(DP dp) { 152 super(dp); 153 } 154}