001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.bio.dp;
024
025import java.io.PrintStream;
026import java.util.Collections;
027import java.util.HashMap;
028import java.util.Iterator;
029import java.util.Map;
030import java.util.NoSuchElementException;
031
032import org.biojava.bio.Annotation;
033import org.biojava.bio.BioError;
034import org.biojava.bio.BioException;
035import org.biojava.bio.dist.Distribution;
036import org.biojava.bio.dist.DistributionFactory;
037import org.biojava.bio.dist.OrderNDistributionFactory;
038import org.biojava.bio.seq.io.SymbolTokenization;
039import org.biojava.bio.symbol.Alphabet;
040import org.biojava.bio.symbol.AlphabetManager;
041import org.biojava.bio.symbol.FiniteAlphabet;
042import org.biojava.bio.symbol.IllegalAlphabetException;
043import org.biojava.bio.symbol.IllegalSymbolException;
044import org.biojava.bio.symbol.Symbol;
045import org.biojava.utils.ChangeVetoException;
046import org.w3c.dom.Element;
047import org.w3c.dom.NodeList;
048
049/**
050 * @author Matthew Pocock
051 * @author Thomas Down
052 * @author Samiul Hasan
053 */
054public class XmlMarkovModel {
055  public static WeightMatrix readMatrix(Element root)
056  throws IllegalSymbolException, IllegalAlphabetException, BioException {
057    Element alphaE = (Element) root.getElementsByTagName("alphabet").item(0);
058    Alphabet sa = AlphabetManager.alphabetForName(
059      alphaE.getAttribute("name"));
060    if(! (sa instanceof FiniteAlphabet)) {
061      throw new IllegalAlphabetException(
062        "Can't read WeightMatrix over infinite alphabet " +
063        sa.getName() + " of type " + sa.getClass()
064      );
065    }
066    FiniteAlphabet seqAlpha = (FiniteAlphabet) sa;
067    SymbolTokenization symParser = seqAlpha.getTokenization("token");
068    SymbolTokenization nameParser = seqAlpha.getTokenization("name");
069
070    int columns = 0;
071    NodeList colL = root.getElementsByTagName("col");
072    for(int i = 0; i < colL.getLength(); i++) {
073      int indx = Integer.parseInt(((Element) colL.item(i)).getAttribute("indx"));
074      columns = Math.max(columns, indx);
075    }
076
077    WeightMatrix wm = new SimpleWeightMatrix(seqAlpha, columns, DistributionFactory.DEFAULT);
078
079    colL = root.getElementsByTagName("col");
080    for(int i = 0; i < colL.getLength(); i++) {
081      Element colE = (Element) colL.item(i);
082      int indx = Integer.parseInt(colE.getAttribute("indx")) - 1;
083      NodeList weights = colE.getElementsByTagName("weight");
084      for(int j = 0; j < weights.getLength(); j++) {
085        Element weightE = (Element) weights.item(j);
086        String symName = weightE.getAttribute("res");
087        if(symName == null || "".equals(symName)) {
088          symName = weightE.getAttribute("sym");
089        }       
090        Symbol sym;
091        if(symName.length() > 1) {
092          sym = nameParser.parseToken(symName);
093        } else {
094          sym = symParser.parseToken(symName);
095        }
096        try {
097          wm.getColumn(indx).setWeight(sym, Double.parseDouble(weightE.getAttribute("prob")));
098        } catch (ChangeVetoException cve) {
099          throw new BioError("Assertion failure: Should be able to set the weights");
100        }
101      }
102    }
103
104    return wm;
105  }
106
107  public static MarkovModel readModel(Element root)
108  throws BioException, IllegalSymbolException, IllegalAlphabetException {
109    if(root.getTagName().equals("WeightMatrix")) {
110      return new WMAsMM(readMatrix(root));
111    }
112
113    int heads = Integer.parseInt(root.getAttribute("heads"));
114    Element alphaE = (Element) root.getElementsByTagName("alphabet").item(0);
115    Alphabet seqAlpha = AlphabetManager.alphabetForName(
116      alphaE.getAttribute("name")
117    );
118    SimpleMarkovModel model = new SimpleMarkovModel(heads, seqAlpha);
119    int [] advance = new int[heads];
120    for(int i = 0; i < heads; i++) {
121      advance[i] = 1;
122    }
123
124    SymbolTokenization nameParser = null;
125    SymbolTokenization symbolParser = null;
126
127    try {
128      nameParser = seqAlpha.getTokenization("name");
129    } catch (NoSuchElementException nsee) {
130    }
131
132    try {
133      symbolParser = seqAlpha.getTokenization("token");
134    } catch (NoSuchElementException nsee) {
135    }
136
137    if(nameParser == null && symbolParser == null) {
138      throw new BioException(
139        "Couldn't find a parser for alphabet " +
140        seqAlpha.getName()
141      );
142    }
143
144    Map nameToState = new HashMap();
145    nameToState.put("_start_", model.magicalState());
146    nameToState.put("_end_", model.magicalState());
147    nameToState.put("_START_", model.magicalState());
148    nameToState.put("_END_", model.magicalState());
149    NodeList states = root.getElementsByTagName("state");
150    DistributionFactory dFact;
151    if( (seqAlpha.getAlphabets().size() > 1) &&
152        seqAlpha.getAlphabets().equals(
153          Collections.nCopies(
154            seqAlpha.getAlphabets().size(),
155            seqAlpha.getAlphabets().get(0)
156          )
157        )
158    ) {
159      dFact = OrderNDistributionFactory.DEFAULT;
160    } else {
161      dFact = DistributionFactory.DEFAULT;
162    }
163    for(int i = 0; i < states.getLength(); i++) {
164      Element stateE = (Element) states.item(i);
165      String name = stateE.getAttribute("name");
166      Distribution dis = dFact.createDistribution(seqAlpha);
167      EmissionState state = new SimpleEmissionState(
168        name, Annotation.EMPTY_ANNOTATION, advance, dis
169      );
170
171      nameToState.put(name, state);
172      NodeList weights = stateE.getElementsByTagName("weight");
173      for(int j = 0; j < weights.getLength(); j++) {
174        Element weightE = (Element) weights.item(j);
175        String symName = weightE.getAttribute("res");
176        if(symName == null || "".equals(symName)) {
177          symName = weightE.getAttribute("sym");
178        }
179        Symbol sym;
180        if(symName.length() == 1) {
181          if(symbolParser != null) {
182            sym = symbolParser.parseToken(symName);
183          } else {
184            sym = nameParser.parseToken(symName);
185          }
186        } else {
187          try {
188            if(nameParser != null) {
189              sym = nameParser.parseToken(symName);
190            } else {
191              sym = symbolParser.parseToken(symName);
192            }
193          } catch (IllegalSymbolException ise) {
194            throw new BioException("Can't extract symbol from " + weightE + " in " + stateE, ise);
195          }
196        }
197        try {
198          dis.setWeight(sym, Double.parseDouble(weightE.getAttribute("prob")));
199        } catch (ChangeVetoException cve) {
200          throw new BioError(
201            "Assertion failure: Should be able to edit distribution", cve
202          );
203        }
204      }
205
206      try {
207        model.addState(state);
208      } catch (ChangeVetoException cve) {
209        throw new BioError(
210         "Assertion failure: Should be able to add states to model",  cve
211        );
212      }
213    }
214
215    NodeList transitions = root.getElementsByTagName("transition");
216    for(int i = 0; i < transitions.getLength(); i++) {
217      Element transitionE = (Element) transitions.item(i);
218      State from = (State) nameToState.get(transitionE.getAttribute("from"));
219      State to = (State) nameToState.get(transitionE.getAttribute("to"));
220      try {
221        model.createTransition(from, to);
222      } catch (IllegalSymbolException ite) {
223        throw new BioError(
224
225          "We should have unlimited write-access to this model. " +
226          "Something is very wrong.", ite
227        );
228      } catch (ChangeVetoException cve) {
229        throw new BioError(
230
231          "We should have unlimited write-access to this model. " +
232          "Something is very wrong.", cve
233        );
234      }
235    }
236
237        for(int i = 0; i < transitions.getLength(); i++) {
238      Element transitionE = (Element) transitions.item(i);
239      State from = (State) nameToState.get(transitionE.getAttribute("from"));
240      State to = (State) nameToState.get(transitionE.getAttribute("to"));
241      double prob = Double.parseDouble(transitionE.getAttribute("prob"));
242      try {
243        model.getWeights(from).setWeight(to, prob);
244      } catch (IllegalSymbolException ite) {
245        throw new BioError(
246
247          "We should have unlimited write-access to this model. " +
248          "Something is very wrong.", ite
249        );
250      } catch (ChangeVetoException cve) {
251        throw new BioError(
252
253          "We should have unlimited write-access to this model. " +
254          "Something is very wrong.", cve
255        );
256      }
257    }
258    return model;
259  }
260
261  public static void writeMatrix(WeightMatrix matrix, PrintStream out) throws Exception {
262    FiniteAlphabet symA = (FiniteAlphabet) matrix.getAlphabet();
263
264    out.println("<MarkovModel>\n  <alphabet name=\"" + symA.getName() + "\"/>");
265
266    for(int i = 0; i < matrix.columns(); i++) {
267      out.println("  <col indx=\"" + (i+1) + "\">");
268      for(Iterator si = symA.iterator(); si.hasNext(); ) {
269        Symbol s = (Symbol) si.next();
270        out.println("    <weight sym=\"" + s.getName() +
271                             "\" prob=\"" + matrix.getColumn(i).getWeight(s) + "\"/>");
272        }
273      out.println("  </col>");
274    }
275
276    out.println("</MarkovModel>");
277  }
278
279  public static void writeModel(MarkovModel model, PrintStream out)
280  throws Exception {
281    model = DP.flatView(model);
282    FiniteAlphabet stateA = model.stateAlphabet();
283    FiniteAlphabet symA = (FiniteAlphabet) model.emissionAlphabet();
284
285    out.println("<MarkovModel heads=\"" + model.advance().length + "\">");
286    out.println("<alphabet name=\"" + symA.getName() + "\"/>");
287
288    // print out states & scores
289    for(Iterator stateI = stateA.iterator(); stateI.hasNext(); ) {
290      State s = (State) stateI.next();
291      if(! (s instanceof MagicalState)) {
292        out.println("  <state name=\"" + s.getName() + "\">");
293        if(s instanceof EmissionState) {
294          EmissionState es = (EmissionState) s;
295          Distribution dis = es.getDistribution();
296          for(Iterator symI = symA.iterator(); symI.hasNext(); ) {
297            Symbol sym = (Symbol) symI.next();
298            out.println("    <weight sym=\"" + sym.getName() +
299                        "\" prob=\"" + dis.getWeight(sym) + "\"/>");
300          }
301        }
302        out.println("  </state>");
303      }
304    }
305
306    // print out transitions
307    for(Iterator i = stateA.iterator(); i.hasNext(); ) {
308      State from = (State) i.next();
309      printTransitions(model, from, out);
310    }
311
312    out.println("</MarkovModel>");
313  }
314
315  static private void printTransitions(MarkovModel model, State from, PrintStream out) throws IllegalSymbolException {
316    for(Iterator i = model.transitionsFrom(from).iterator(); i.hasNext(); ) {
317      State to = (State) i.next();
318      try {
319      out.println("  <transition from=\"" + ((from instanceof MagicalState) ? "_start_" : from.getName()) +
320                             "\" to=\"" + ((to instanceof MagicalState) ? "_end_" : to.getName()) +
321                             "\" prob=\"" + model.getWeights(from).getWeight(to) + "\"/>");
322      } catch (IllegalSymbolException ite) {
323        throw new BioError("Transition listed in transitionsFrom(" +
324                           from.getName() + ") has dissapeared", ite);
325      }
326    }
327  }
328}