001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.bio.dp;
024
import java.io.Serializable;
import java.util.Random;

import org.biojava.bio.BioError;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dp.onehead.SingleDP;
import org.biojava.bio.dp.onehead.SingleDPMatrix;
import org.biojava.bio.symbol.AlphabetManager;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;
036
037/**
038 * <p>
039 * Train a hidden markov model using a sampling algorithm.
040 * </p>
041 *
042 * <p>
043 * Note: this class currently only works for one-head models.
044 * </p>
045 *
046 * @author Matthew Pocock
047 * @author Richard Holland
048 * @since 1.0
049 */
050public class BaumWelchSampler extends AbstractTrainer implements Serializable {
051
052          protected double singleSequenceIteration(
053                            ModelTrainer trainer,
054                            SymbolList symList
055                          ) throws IllegalSymbolException, IllegalTransitionException, IllegalAlphabetException {
056                            return this.singleSequenceIteration(trainer, symList, ScoreType.PROBABILITY);
057          }
058
059  protected double singleSequenceIteration(
060    ModelTrainer trainer,
061    SymbolList symList,
062    ScoreType scoreType
063  ) throws IllegalSymbolException, IllegalTransitionException, IllegalAlphabetException {
064    SingleDP dp = (SingleDP) getDP();
065    State [] states = dp.getStates();
066    int [][] backwardTransitions = dp.getBackwardTransitions();
067    double [][] backwardTransitionScores = dp.getBackwardTransitionScores(scoreType);
068    MarkovModel model = dp.getModel();
069
070    SymbolList [] rll = { symList };
071
072    // System.out.print("Forward...  ");
073    SingleDPMatrix fm = (SingleDPMatrix) dp.forwardMatrix(rll, scoreType);
074    double fs = fm.getScore();
075    // System.out.println("Score = " + fs);
076
077    // System.out.print("Backward... ");
078    SingleDPMatrix bm = (SingleDPMatrix) dp.backwardMatrix(rll, scoreType);
079    // System.out.println("Score = " + bs);
080
081    Symbol gap = AlphabetManager.getGapSymbol();
082
083    // state trainer
084
085    for (int i = 1; i <= symList.length(); i++) {
086      Symbol sym = symList.symbolAt(i);
087      double [] fsc = fm.scores[i];
088      double [] bsc = bm.scores[i];
089      double p = Math.random();
090      for (int s = 0; s < dp.getDotStatesIndex(); s++) {
091        if (! (states[s] instanceof MagicalState)) {
092          p -= Math.exp(fsc[s] + bsc[s] - fs);
093          if (p <= 0.0) {
094            trainer.addCount(
095              ((EmissionState) states[s]).getDistribution(),
096              sym,
097              1.0
098            );
099            break;
100          }
101        }
102      }
103    }
104
105    // transition trainer
106    for (int i = 0; i <= symList.length(); i++) {
107      Symbol sym = (i < symList.length())
108            ? symList.symbolAt(i + 1)
109            : gap;
110      double [] fsc = fm.scores[i];
111      double [] bsc = bm.scores[i+1];
112      double [] bsc2 = bm.scores[i];
113      double[] weightVector = dp.getEmission(sym, scoreType);
114      for (int s = 0; s < states.length; s++) {  // any -> emission transitions
115        int [] ts = backwardTransitions[s];
116        double [] tss = backwardTransitionScores[s];
117        Distribution dist = model.getWeights(states[s]);
118        double p = Math.random();
119        for (int tc = 0; tc < ts.length; tc++) {
120          int t = ts[tc];
121          double weight = 1.0;
122          if (states[t] instanceof EmissionState) {
123            weight = Math.exp(weightVector[t]);
124          }
125          if (weight != 0.0) {
126            if(t < dp.getDotStatesIndex()) {
127              p -= Math.exp(fsc[s] + tss[tc] + bsc[t] - fs) * weight;
128            } else {
129              p -= Math.exp(fsc[s] + tss[tc] + bsc2[t] - fs);
130            }
131            if (p <= 0.0) {
132              try {
133                trainer.addCount(
134                  dist,
135                  states[t],
136                  1.0
137                );
138              } catch (IllegalSymbolException ise) {
139                throw new BioError(
140                  "Transition in backwardTransitions dissapeared", ise);
141              }
142              break;
143            }
144          }
145        }
146      }
147    }
148    return fs;
149  }
150
151  public BaumWelchSampler(DP dp) {
152    super(dp);
153  }
154}