001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022package org.biojava.bio.dp;
023
024import java.util.Iterator;
025
026import org.biojava.bio.BioException;
027import org.biojava.bio.dist.Distribution;
028import org.biojava.bio.dist.DistributionTrainerContext;
029import org.biojava.bio.dist.SimpleDistributionTrainerContext;
030import org.biojava.bio.symbol.FiniteAlphabet;
031import org.biojava.bio.symbol.IllegalSymbolException;
032import org.biojava.bio.symbol.Symbol;
033import org.biojava.utils.ChangeVetoException;
034
035public class SimpleHMMTrainer
036    implements HMMTrainer
037{
038    DistributionTrainerContext dtc = new SimpleDistributionTrainerContext();
039    FiniteAlphabet states;
040    MarkovModel model;
041
042    public SimpleHMMTrainer(MarkovModel model)
043        throws IllegalSymbolException
044    {
045        this.model = model;
046
047        // go thru model and add the Distributions
048        states = model.stateAlphabet();
049
050        for (Iterator stateI = states.iterator(); stateI.hasNext(); ) {
051            State thisState = (State) stateI.next();
052            // add emission Distributions
053            if (thisState instanceof EmissionState) {
054                EmissionState thisEmitter = (EmissionState) thisState;
055                Distribution emissionDist = thisEmitter.getDistribution();
056                dtc.registerDistribution(emissionDist);
057                emissionDist.registerWithTrainer(dtc);
058            }
059
060            // add transition Distributions
061            Distribution transDist = model.getWeights(thisState);
062            dtc.registerDistribution(transDist);
063            transDist.registerWithTrainer(dtc);
064        }
065    }
066
067    public void startCycle()
068    {
069        dtc.clearCounts();
070    }
071
072    public void recordEmittedSymbol(State state, Symbol symbol, double weight)
073        throws IllegalSymbolException
074    {
075        // look up the emission Distribution I need
076        if (state instanceof EmissionState) {
077            Distribution emissionDist = ((EmissionState) state).getDistribution();
078            dtc.addCount(emissionDist, symbol, weight);
079        }
080        else throw new IllegalSymbolException("specified State is not an EmissionState.");
081    }
082
083    public void recordTransition(State source, State dest, double weight)
084        throws IllegalArgumentException
085    {
086        // verify the transition
087        try {
088        if (model.containsTransition(source, dest)) {
089            Distribution transDist = model.getWeights(source);
090            dtc.addCount(transDist, dest, weight);
091        }
092        else throw new IllegalArgumentException("the specified transition is illegal for this model.");
093        }
094        catch (IllegalSymbolException ise) {
095            throw new IllegalArgumentException("either source or destination are not valid");
096        }
097    }
098
099    public void completeCycle()
100        throws BioException
101    {
102        try {
103            dtc.train();
104        }
105        catch (ChangeVetoException cve) {
106            throw new BioException(cve);
107        }
108    }
109}
110