001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.bio.dp;
024
025import java.io.Serializable;
026
027import org.biojava.bio.dist.Distribution;
028import org.biojava.bio.dp.onehead.SingleDP;
029import org.biojava.bio.dp.onehead.SingleDPMatrix;
030import org.biojava.bio.symbol.AlphabetManager;
031import org.biojava.bio.symbol.IllegalAlphabetException;
032import org.biojava.bio.symbol.IllegalSymbolException;
033import org.biojava.bio.symbol.Symbol;
034import org.biojava.bio.symbol.SymbolList;
035
036/**
037* <p>
* Train a hidden Markov model by maximum-likelihood estimation using the Baum-Welch (forward-backward) algorithm.
039* </p>
040*
041* <p>
042* Note: this class currently only works for one-head models.
043* </p>
044*
045* @author Matthew Pocock
046* @author Thomas Down
047* @author Todd Riley
048* @since 1.0
049*/
050public class BaumWelchTrainer extends AbstractTrainer implements Serializable {
051 protected double singleSequenceIteration(
052   ModelTrainer trainer,
053   SymbolList symList
054 ) throws IllegalSymbolException, IllegalTransitionException, IllegalAlphabetException {
055   ScoreType scoreType = ScoreType.PROBABILITY;
056   SingleDP dp = (SingleDP) getDP();
057   State [] states = dp.getStates();
058   int [][] backwardTransitions = dp.getBackwardTransitions();
059   double [][] backwardTransitionScores = dp.getBackwardTransitionScores(scoreType);
060   MarkovModel model = dp.getModel();
061
062   SymbolList [] rll = { symList };
063
064   // System.out.print("Forward...  ");
065   SingleDPMatrix fm = (SingleDPMatrix) dp.forwardMatrix(rll, scoreType);
066   double fs = fm.getScore();
067   // System.out.println("Score = " + fs);
068
069   // System.out.print("Backward... ");
070   SingleDPMatrix bm = (SingleDPMatrix) dp.backwardMatrix(rll, scoreType);
071   // System.out.println("Score = " + bs);
072
073   Symbol gap = AlphabetManager.getGapSymbol();
074
075   // state trainer
076   for (int i = 1; i <= symList.length(); i++) {
077     Symbol sym = symList.symbolAt(i);
078     double [] fsc = fm.scores[i];
079     double [] bsc = bm.scores[i];
080     for (int s = 0; s < dp.getDotStatesIndex(); s++) {
081       if (! (states[s] instanceof MagicalState)) {
082         trainer.addCount(
083           ((EmissionState) states[s]).getDistribution(),
084           sym,
085           mathExp(fsc[s] + bsc[s] - fs)
086         );
087       }
088     }
089   }
090
091   // transition trainer
092   for (int i = 0; i <= symList.length(); i++) {
093     Symbol sym = (i < symList.length())
094           ? symList.symbolAt(i + 1)
095           : gap;
096     double [] fsc = fm.scores[i];
097     double [] bsc = bm.scores[i+1];
098     double [] bsc2 = bm.scores[i];
099     double[] weightVector = dp.getEmission(sym, scoreType);
100     for (int s = 0; s < states.length; s++) {  // any -> emission transitions
101       int [] ts = backwardTransitions[s];
102       double [] tss = backwardTransitionScores[s];
103       Distribution dist = model.getWeights(states[s]);
104       for (int tc = 0; tc < ts.length; tc++) {
105         int t = ts[tc];
106         if(t < dp.getDotStatesIndex()) {
107           double weight = mathExp(weightVector[t]);
108           if (weight != 0.0) {
109             trainer.addCount(
110               dist, states[t],
111               mathExp(
112                 fsc[s] + tss[tc] + bsc[t]
113                 -
114                 fs
115               ) * weight
116             );
117           }
118         } else {
119           trainer.addCount(
120             dist, states[t],
121             mathExp(
122               fsc[s] + tss[tc] + bsc2[t]
123               -
124               fs
125             )
126           );
127         }
128       }
129     }
130   }
131
132   return fs;
133 }
134
135
136   public double mathExp(double arg) {
137                //Double argObj = new Double(arg);
138                Double resultObj;
139
140                if (Double.isNaN(arg)) {
141                    //System.err.println("NaN encountered as arg to Math.exp in BaumWelch Training Loop");
142                    arg = Double.NEGATIVE_INFINITY;
143           //System.exit(-1);
144                }
145                resultObj = new Double(Math.exp(arg));
146                return(resultObj.doubleValue());
147   }
148
149
150
151 public BaumWelchTrainer(DP dp) {
152   super(dp);
153 }
154}
155
156