001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022package org.biojava.bio.dp; 023 024import java.util.Iterator; 025 026import org.biojava.bio.BioException; 027import org.biojava.bio.dist.Distribution; 028import org.biojava.bio.dist.DistributionTrainerContext; 029import org.biojava.bio.dist.SimpleDistributionTrainerContext; 030import org.biojava.bio.symbol.FiniteAlphabet; 031import org.biojava.bio.symbol.IllegalSymbolException; 032import org.biojava.bio.symbol.Symbol; 033import org.biojava.utils.ChangeVetoException; 034 035public class SimpleHMMTrainer 036 implements HMMTrainer 037{ 038 DistributionTrainerContext dtc = new SimpleDistributionTrainerContext(); 039 FiniteAlphabet states; 040 MarkovModel model; 041 042 public SimpleHMMTrainer(MarkovModel model) 043 throws IllegalSymbolException 044 { 045 this.model = model; 046 047 // go thru model and add the Distributions 048 states = model.stateAlphabet(); 049 050 for (Iterator stateI = states.iterator(); stateI.hasNext(); ) { 051 State thisState = (State) stateI.next(); 052 // add emission Distributions 053 if (thisState instanceof EmissionState) { 054 EmissionState thisEmitter = (EmissionState) thisState; 055 Distribution emissionDist = thisEmitter.getDistribution(); 056 dtc.registerDistribution(emissionDist); 057 emissionDist.registerWithTrainer(dtc); 058 } 059 060 // add transition Distributions 061 Distribution transDist = model.getWeights(thisState); 062 dtc.registerDistribution(transDist); 063 transDist.registerWithTrainer(dtc); 064 } 065 } 066 067 public void startCycle() 068 { 069 dtc.clearCounts(); 070 } 071 072 public void recordEmittedSymbol(State state, Symbol symbol, double weight) 073 throws IllegalSymbolException 074 { 075 // look up the emission Distribution I need 076 if (state instanceof EmissionState) { 077 Distribution emissionDist = ((EmissionState) state).getDistribution(); 078 dtc.addCount(emissionDist, symbol, weight); 079 } 080 else throw new IllegalSymbolException("specified State is not an EmissionState."); 081 } 082 083 public void recordTransition(State source, State dest, double weight) 084 throws IllegalArgumentException 085 { 086 // verify the transition 087 try { 088 if (model.containsTransition(source, dest)) { 089 Distribution transDist = model.getWeights(source); 090 dtc.addCount(transDist, dest, weight); 091 } 092 else throw new IllegalArgumentException("the specified transition is illegal for this model."); 093 } 094 catch (IllegalSymbolException ise) { 095 throw new IllegalArgumentException("either source or destination are not valid"); 096 } 097 } 098 099 public void completeCycle() 100 throws BioException 101 { 102 try { 103 dtc.train(); 104 } 105 catch (ChangeVetoException cve) { 106 throw new BioException(cve); 107 } 108 } 109} 110