001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022 023package org.biojava.bio.dp; 024 025import java.io.Serializable; 026 027import org.biojava.bio.dist.Distribution; 028import org.biojava.bio.dp.onehead.SingleDP; 029import org.biojava.bio.dp.onehead.SingleDPMatrix; 030import org.biojava.bio.symbol.AlphabetManager; 031import org.biojava.bio.symbol.IllegalAlphabetException; 032import org.biojava.bio.symbol.IllegalSymbolException; 033import org.biojava.bio.symbol.Symbol; 034import org.biojava.bio.symbol.SymbolList; 035 036/** 037* <p> 038* Train a hidden markov model using maximum likelihood. 039* </p> 040* 041* <p> 042* Note: this class currently only works for one-head models. 043* </p> 044* 045* @author Matthew Pocock 046* @author Thomas Down 047* @author Todd Riley 048* @since 1.0 049*/ 050public class BaumWelchTrainer extends AbstractTrainer implements Serializable { 051 protected double singleSequenceIteration( 052 ModelTrainer trainer, 053 SymbolList symList 054 ) throws IllegalSymbolException, IllegalTransitionException, IllegalAlphabetException { 055 ScoreType scoreType = ScoreType.PROBABILITY; 056 SingleDP dp = (SingleDP) getDP(); 057 State [] states = dp.getStates(); 058 int [][] backwardTransitions = dp.getBackwardTransitions(); 059 double [][] backwardTransitionScores = dp.getBackwardTransitionScores(scoreType); 060 MarkovModel model = dp.getModel(); 061 062 SymbolList [] rll = { symList }; 063 064 // System.out.print("Forward... "); 065 SingleDPMatrix fm = (SingleDPMatrix) dp.forwardMatrix(rll, scoreType); 066 double fs = fm.getScore(); 067 // System.out.println("Score = " + fs); 068 069 // System.out.print("Backward... "); 070 SingleDPMatrix bm = (SingleDPMatrix) dp.backwardMatrix(rll, scoreType); 071 // System.out.println("Score = " + bs); 072 073 Symbol gap = AlphabetManager.getGapSymbol(); 074 075 // state trainer 076 for (int i = 1; i <= symList.length(); i++) { 077 Symbol sym = symList.symbolAt(i); 078 double [] fsc = fm.scores[i]; 079 double [] bsc = bm.scores[i]; 080 for (int s = 0; s < dp.getDotStatesIndex(); s++) { 081 if (! (states[s] instanceof MagicalState)) { 082 trainer.addCount( 083 ((EmissionState) states[s]).getDistribution(), 084 sym, 085 mathExp(fsc[s] + bsc[s] - fs) 086 ); 087 } 088 } 089 } 090 091 // transition trainer 092 for (int i = 0; i <= symList.length(); i++) { 093 Symbol sym = (i < symList.length()) 094 ? symList.symbolAt(i + 1) 095 : gap; 096 double [] fsc = fm.scores[i]; 097 double [] bsc = bm.scores[i+1]; 098 double [] bsc2 = bm.scores[i]; 099 double[] weightVector = dp.getEmission(sym, scoreType); 100 for (int s = 0; s < states.length; s++) { // any -> emission transitions 101 int [] ts = backwardTransitions[s]; 102 double [] tss = backwardTransitionScores[s]; 103 Distribution dist = model.getWeights(states[s]); 104 for (int tc = 0; tc < ts.length; tc++) { 105 int t = ts[tc]; 106 if(t < dp.getDotStatesIndex()) { 107 double weight = mathExp(weightVector[t]); 108 if (weight != 0.0) { 109 trainer.addCount( 110 dist, states[t], 111 mathExp( 112 fsc[s] + tss[tc] + bsc[t] 113 - 114 fs 115 ) * weight 116 ); 117 } 118 } else { 119 trainer.addCount( 120 dist, states[t], 121 mathExp( 122 fsc[s] + tss[tc] + bsc2[t] 123 - 124 fs 125 ) 126 ); 127 } 128 } 129 } 130 } 131 132 return fs; 133 } 134 135 136 public double mathExp(double arg) { 137 //Double argObj = new Double(arg); 138 Double resultObj; 139 140 if (Double.isNaN(arg)) { 141 //System.err.println("NaN encountered as arg to Math.exp in BaumWelch Training Loop"); 142 arg = Double.NEGATIVE_INFINITY; 143 //System.exit(-1); 144 } 145 resultObj = new Double(Math.exp(arg)); 146 return(resultObj.doubleValue()); 147 } 148 149 150 151 public BaumWelchTrainer(DP dp) { 152 super(dp); 153 } 154} 155 156