021package org.biojava.bio.dist;
023import java.io.IOException;
024import java.io.InputStream;
025import java.io.OutputStream;
026import java.util.ArrayList;
027import java.util.HashMap;
028import java.util.HashSet;
029import java.util.Iterator;
030import java.util.LinkedList;
031import java.util.List;
032import java.util.Random;
034import org.biojava.bio.Annotation;
035import org.biojava.bio.BioError;
036import org.biojava.bio.BioException;
037import org.biojava.bio.alignment.Alignment;
038import org.biojava.bio.seq.Sequence;
039import org.biojava.bio.seq.SequenceFactory;
040import org.biojava.bio.seq.impl.SimpleSequenceFactory;
041import org.biojava.bio.symbol.Alphabet;
042import org.biojava.bio.symbol.AlphabetIndex;
043import org.biojava.bio.symbol.AlphabetManager;
044import org.biojava.bio.symbol.AtomicSymbol;
045import org.biojava.bio.symbol.BasisSymbol;
046import org.biojava.bio.symbol.FiniteAlphabet;
047import org.biojava.bio.symbol.IllegalAlphabetException;
048import org.biojava.bio.symbol.IllegalSymbolException;
049import org.biojava.bio.symbol.Location;
050import org.biojava.bio.symbol.LocationTools;
051import org.biojava.bio.symbol.PackedSymbolListFactory;
052import org.biojava.bio.symbol.PointLocation;
053import org.biojava.bio.symbol.SimpleSymbolListFactory;
054import org.biojava.bio.symbol.Symbol;
055import org.biojava.bio.symbol.SymbolList;
056import org.biojava.bio.symbol.SymbolListFactory;
057import org.biojava.bio.symbol.SymbolListViews;
058import org.biojava.utils.AssertionFailure;
059import org.biojava.utils.ChangeVetoException;
060import org.xml.sax.SAXException;
/**
 * A class to hold static methods for calculations and manipulations using
 * Distributions.
 *
 * @author Mark Schreiber
 * @author Matthew Pocock
 * @since 1.2
 */
072public final class DistributionTools {
074  /**
075   * Overide the constructer to prevent subclassing.
076   */
077  private DistributionTools(){}
079  /**
080   * Writes a Distribution to XML that can be read with the readFromXML method.
081   *
082   * @param d the Distribution to write.
083   * @param os where to write it to.
084   * @throws IOException if writing fails
085   */
086  public static void writeToXML(Distribution d, OutputStream os) throws IOException{
087    new XMLDistributionWriter().writeDistribution(d, os);
088  }
090  /**
091   * Read a distribution from XML.
092   *
093   * @param is  an InputStream to read from
094   * @return  a Distribution parameterised by the xml in is
095   * @throws IOException  if is failed
096   * @throws SAXException if is could not be processed as XML
097   */
098  public static Distribution readFromXML(InputStream is)throws IOException, SAXException{
099    XMLDistributionReader writer = new XMLDistributionReader();
100    return writer.parseXML(is);
101  }
103  /**
104   * Randomizes the weights of a <code>Distribution</code>.
105   *
106   * @param d the <code>Distribution</code> to randomize
107   * @throws ChangeVetoException if the Distribution is locked
108   */
109  public static void randomizeDistribution(Distribution d)
110    throws ChangeVetoException{
111    Random rand = new Random();
112    FiniteAlphabet a = (FiniteAlphabet)d.getAlphabet();
113    AlphabetIndex ind = AlphabetManager.getAlphabetIndex(a);
114    DistributionTrainerContext dtc = new SimpleDistributionTrainerContext();
115    dtc.registerDistribution(d);
117    for(int i = 0; i < a.size(); i++){
118      try {
119        dtc.addCount(d,ind.symbolForIndex(i),rand.nextDouble());
120      }
121      catch (IllegalSymbolException ex) {
122        throw new BioError("Alphabet has Illegal Symbols!!", ex);
123      }
124    }
126    dtc.train();
127  }
129  /**
130   * Make a distribution from a count.
131   *
132   * @param c the count
133   * @return a Distrubution over the same <code>FiniteAlphabet</code> as <code>c</code>
134   * and trained with the counts of <code>c</code>
135   */
136  public static Distribution countToDistribution(Count c){
137    FiniteAlphabet a  = (FiniteAlphabet)c.getAlphabet();
138    Distribution d = null;
139    try{
140      d = DistributionFactory.DEFAULT.createDistribution(a);
141      AlphabetIndex index =
142          AlphabetManager.getAlphabetIndex(a);
143      DistributionTrainerContext dtc = new SimpleDistributionTrainerContext();
144      dtc.registerDistribution(d);
146      for(int i = 0; i < a.size(); i++){
147        dtc.addCount(d, index.symbolForIndex(i),
148         c.getCount((AtomicSymbol)index.symbolForIndex(i)));
149      }
150      dtc.train();
151    } catch (IllegalAlphabetException iae) {
152      throw new AssertionFailure("Assertion failure: Alphabets don't match");
153    }catch(IllegalSymbolException ise){
154      throw new AssertionFailure("Assertion Error: Cannot convert Count to Distribution", ise);
155    } catch (ChangeVetoException cve) {
156      throw new AssertionFailure("Assertion failure: distributions or counts got locked.", cve);
157    }
158    return d;
159  }
161  /**
162   * Compares the emission spectra of two distributions.
163   *
164   * @return true if alphabets and symbol weights are equal for the two distributions.
165   * @throws BioException if one or both of the Distributions are over infinite alphabets.
166   * @since 1.2
167   * @param a A <code>Distribution</code> with the same <code>Alphabet</code> as
168   * <code>b</code>
169   * @param b A <code>Distribution</code> with the same <code>Alphabet</code> as
170   * <code>a</code>
171   */
172  public static final boolean areEmissionSpectraEqual(Distribution a, Distribution b)
173    throws BioException{
174      //are either of the Dists infinite
175      if(a.getAlphabet() instanceof FiniteAlphabet == false
176          || b.getAlphabet() instanceof FiniteAlphabet == false){
177        throw new IllegalAlphabetException("Cannot compare emission spectra over infinite alphabet");
178      }
179      //are alphabets equal?
180      if(!(a.getAlphabet().equals(b.getAlphabet()))){
181        return false;
182      }
183      //are emissions equal?
184      for(Iterator i = ((FiniteAlphabet)a.getAlphabet()).iterator();i.hasNext();){
185        Symbol s = (Symbol)i.next();
186        if(a.getWeight(s) != b.getWeight(s)) return false;
187      }
188      return true;
189  }
191  /**
192   * Compares the emission spectra of two distribution arrays.
193   *
194   * @return true if alphabets and symbol weights are equal for each pair
195   * of distributions. Will return false if the arrays are of unequal length.
196   * @throws BioException if one of the Distributions is over an infinite
197   * alphabet.
198   * @since 1.3
199   * @param a A <code>Distribution[]</code> consisting of <code>Distributions</code>
200   * over a <code>FiniteAlphabet </code>
201   * @param b A <code>Distribution[]</code> consisting of <code>Distributions</code>
202   * over a <code>FiniteAlphabet </code>
203   */
204  public static final boolean areEmissionSpectraEqual(Distribution[] a,
205                                                      Distribution[] b)
206    throws BioException{
207      if(a.length != b.length) return false;
208      for (int i = 0; i < a.length; i++) {
209        if(areEmissionSpectraEqual(a[i], b[i]) == false){
210          return false;
211        }
212      }
213      return true;
214    }
216  /**
217   * A method to calculate the Kullback-Liebler Distance (relative entropy).
218   *
219   * @param logBase  - the log base for the entropy calculation. 2 is standard.
220   * @param observed - the observed frequence of <code>Symbols </code>.
221   * @param expected - the excpected or background frequency.
222   * @return  - A HashMap mapping Symbol to <code>(Double)</code> relative entropy.
223   * @since 1.2
224   */
225  public static final HashMap KLDistance(Distribution observed,
226                                   Distribution expected,
227                                   double logBase){
228    Iterator alpha = ((FiniteAlphabet)observed.getAlphabet()).iterator();
229    HashMap kldist = new HashMap(((FiniteAlphabet)observed.getAlphabet()).size());
231    while(alpha.hasNext()){
232      Symbol s = (Symbol)alpha.next();
233      try{
234        double obs = observed.getWeight(s);
235        double exp = expected.getWeight(s);
236        if(obs == 0.0){
237          kldist.put(s,new Double(0.0));
238        }else{
239          double entropy = obs * (Math.log(obs/exp))/Math.log(logBase);
240          kldist.put(s,new Double(entropy));
241        }
242      }catch(IllegalSymbolException ise){
243        ise.printStackTrace(System.err);
244      }
245    }
246    return kldist;
247  }
249  /**
250   * A method to calculate the Shannon Entropy for a Distribution.
251   *
252   * @param logBase  - the log base for the entropy calculation. 2 is standard.
253   * @param observed - the observed frequence of <code>Symbols </code>.
254   * @return  - A HashMap mapping Symbol to <code>(Double)</code> entropy.
255   * @since 1.2
256   */
257  public static final HashMap shannonEntropy(Distribution observed, double logBase){
258    Iterator alpha = ((FiniteAlphabet)observed.getAlphabet()).iterator();
259    HashMap entropy = new HashMap(((FiniteAlphabet)observed.getAlphabet()).size());
261    while(alpha.hasNext()){
262      Symbol s = (Symbol)alpha.next();
263      try{
264        double obs = observed.getWeight(s);
265        if(obs == 0.0){
266         // entropy.put(s,new Double(0.0));
267        }else{
268          double e = -(Math.log(obs))/Math.log(logBase);
269          entropy.put(s,new Double(e));
270        }
271      }catch(IllegalSymbolException ise){
272        ise.printStackTrace(System.err);
273      }
274    }
275    return entropy;
276  }
278  /**
279   * Calculates the total Entropy for a Distribution. Entropies for individual
280   * <code>Symbols</code> are weighted by their probability of occurence.
281   * @param observed the observed frequence of <code>Symbols </code>.
282   * @return the total entropy of the <code>Distribution </code>.
283   */
284  public static double totalEntropy(Distribution observed){
285    HashMap ent = shannonEntropy(observed, 2.0);
286    double totalEntropy = 0.0;
287    try{
288    for(Iterator i = ent.keySet().iterator(); i.hasNext();){
289      Symbol sym = (Symbol) i.next();
290      totalEntropy += observed.getWeight(sym)*((Double)ent.get(sym)).doubleValue();
291    }
292    }
293    catch(Exception e){
294      e.printStackTrace(System.err);
295    }
297    return totalEntropy;
298  }
300  /**
301   * Calculates the total bits of information for a distribution.
302   * @param observed - the observed frequence of <code>Symbols </code>.
303   * @return the total information content of the <code>Distribution </code>.
304   * @since 1.2
305   */
306  public static final double bitsOfInformation(Distribution observed){
307    double totalEntropy = totalEntropy(observed);
308    int size = ((FiniteAlphabet)observed.getAlphabet()).size();
310    return Math.log((double)size)/Math.log(2.0) - totalEntropy;
311  }
313  /**
314   * Equivalent to distOverAlignment(a, false, 0.0).
315   *
316   * @param a  the Alignment
317   * @return   an array of Distribution instances representing columns of the
318   *     alignment
319   * @throws IllegalAlphabetException  if the alignment alphabet is not
320   *    compattible
321   */
322  public static Distribution[] distOverAlignment(Alignment a)
323      throws IllegalAlphabetException{
324    return distOverAlignment(a,false,0.0);
325  }
327  /**
328   * Creates a joint distribution.
329   *
330   * @throws IllegalAlphabetException if all sequences don't use the same alphabet
331   * @param a the <code>Alignment </code>to build the <code>Distribution[]</code> over.
332   * @param countGaps if true gaps will be included in the distributions
334   * @param nullWeight the number of pseudo counts to add to each distribution
335   * @param cols a list of positions in the alignment to include in the joint distribution
336   * @return a <code>Distribution</code>
337   * @since 1.2
338   */
339  public static final Distribution jointDistOverAlignment(Alignment a,
340                                                 boolean countGaps,
341                                                 double nullWeight,
342                                                 int[] cols)
343  throws IllegalAlphabetException {
344        List<String> seqs = a.getLabels();
345        FiniteAlphabet alpha =
346          (FiniteAlphabet)((SymbolList)a.symbolListForLabel(seqs.get(0))).getAlphabet();
347        for(int i = 1; i < seqs.size();i++){
348                FiniteAlphabet test = (FiniteAlphabet)((SymbolList)a.symbolListForLabel(seqs.get(i))).getAlphabet();
349                if(test != alpha){
350                        throw new IllegalAlphabetException("Cannot Calculate jointDistOverAlignment() for alignments with"+
351                        "mixed alphabets");
352                }
353        }
354        List<Alphabet> a_list = new ArrayList();
355        for(int i=0; i<cols.length; i++){
356                a_list.add(alpha);
357        }
358        Distribution dist;
359        DistributionTrainerContext dtc = new SimpleDistributionTrainerContext();
360        dist =
361          DistributionFactory.DEFAULT.
362          createDistribution(AlphabetManager.getCrossProductAlphabet(a_list));
363        dtc.setNullModelWeight(nullWeight);
364    try{
366        dtc.registerDistribution(dist);
367        Location loc= new PointLocation(cols[0]);
368        for (int j = 0; j < cols.length; j++)
369            {
370                Location lj = new PointLocation(cols[j]);
371                loc = LocationTools.union(loc, lj);
372            }
373            Alignment subalign = a.subAlignment(new HashSet(seqs), loc);
374            Iterator s_it = subalign.symbolListIterator();
375        while(s_it.hasNext()){
376            SymbolList syml = (SymbolList) s_it.next();
377            Symbol s= SymbolListViews.orderNSymbolList(syml,syml.length()).symbolAt(1);
378            if(countGaps == false && syml.toList().contains(a.getAlphabet().getGapSymbol())){
379                    //do nothing, not counting gaps
380            }else{
381            dtc.addCount(dist,s,1.0);// count the symbol
382            }
383        }
384        dtc.train();
385    }catch(Exception e){
386      e.printStackTrace(System.err);
387    }
388    return dist;
390  /**
391   * Creates an array of distributions, one for each column of the alignment.
392   *
393   * @throws IllegalAlphabetException if all sequences don't use the same alphabet
394   * @param a the <code>Alignment </code>to build the <code>Distribution[]</code> over.
395   * @param countGaps if true gaps will be included in the distributions
396   * @param nullWeight the number of pseudo counts to add to each distribution,
397   * pseudo counts will not affect gaps, no gaps, no gap counts.
398   * @return a <code>Distribution[]</code> where each member of the array is a
399   * <code>Distribution </code>of the <code>Symbols </code>found at that position
400   * of the <code>Alignment </code>.
401   * @since 1.2
402   */
403  public static final Distribution[] distOverAlignment(Alignment a,
404                                                 boolean countGaps,
405                                                 double nullWeight)
406  throws IllegalAlphabetException {
408    List<String> seqs = a.getLabels();
410    FiniteAlphabet alpha = (FiniteAlphabet)((SymbolList)a.symbolListForLabel(seqs.get(0))).getAlphabet();
411    for(int i = 1; i < seqs.size();i++){
412        FiniteAlphabet test = (FiniteAlphabet)((SymbolList)a.symbolListForLabel(seqs.get(i))).getAlphabet();
413        if(test != alpha){
414          throw new IllegalAlphabetException("Cannot Calculate distOverAlignment() for alignments with"+
415          "mixed alphabets");
416        }
417    }
419    Distribution[] pos = new Distribution[a.length()];
420    DistributionTrainerContext dtc = new SimpleDistributionTrainerContext();
421    dtc.setNullModelWeight(nullWeight);
423    double[] adjRatios = null;
424    if(countGaps){
425      adjRatios = new double[a.length()];
426    }
428    try{
429      for(int i = 0; i < a.length(); i++){// For each position
430        double gapCount = 0.0;
431        double totalCount = 0.0;
433        pos[i] = DistributionFactory.DEFAULT.createDistribution(alpha);
434        dtc.registerDistribution(pos[i]);
436        for(Iterator<String> j = seqs.iterator(); j.hasNext();){// of each sequence
437          String seqLabel = j.next();
438          Symbol s = a.symbolAt(seqLabel,i + 1);
440          /*If this is working over a flexible alignment there is a possibility
441          that s could be null if this Sequence is not really preset in this
442          region of the Alignment. In this case it will be skipped*/
443          if(s == null)
444            continue;
446          Symbol gap = alpha.getGapSymbol();
447          if(countGaps &&
448             s.equals(gap)){
449             gapCount++; totalCount++;
450          }else{
451            dtc.addCount(pos[i],s,1.0);// count the symbol
452            totalCount++;
453          }
454        }
456        if(countGaps){
457          adjRatios[i] = 1.0 - (gapCount / totalCount);
458        }
459      }
461      dtc.train();
463      if(countGaps){//need to adjust counts for gaps
464        for (int i = 0; i < adjRatios.length; i++) {
465          Distribution d = pos[i];
466          for (Iterator iter = ((FiniteAlphabet)d.getAlphabet()).iterator();
467                            iter.hasNext(); ) {
468            Symbol sym = (Symbol)iter.next();
469            d.setWeight(sym, (d.getWeight(sym) * adjRatios[i]));
470          }
471        }
472      }
474    }catch(Exception e){
475      e.printStackTrace(System.err);
476    }
477    return pos;
478  }
481  /**
482   * Creates an array of distributions, one for each column of the alignment.
483   * No pseudo counts are used.
484   * @param countGaps if true gaps will be included in the distributions
485   * @param a the <code>Alignment </code>to build the <code>Distribution[]</code> over.
486   * @throws IllegalAlphabetException if the alignment is not composed from sequences all
487   *         with the same alphabet
488   * @return a <code>Distribution[]</code> where each member of the array is a
489   * <code>Distribution </code>of the <code>Symbols </code>found at that position
490   * of the <code>Alignment </code>.
491   * @since 1.2
492   */
493  public static final Distribution[] distOverAlignment(Alignment a,
494                                                 boolean countGaps)
495  throws IllegalAlphabetException {
496    return distOverAlignment(a,countGaps,0.0);
497  }
499  /**
500   * Averages two or more distributions. NOTE the current implementation ignore the null model.
501   * @since 1.2
502   * @param dists the <code>Distributions </code>to average
503   * @return a <code>Distribution </code>were the weight of each <code>Symbol </code>
504   * is the average of the weights of that <code>Symbol </code>in each <code>Distribution </code>.
505   */
506  public static final Distribution average (Distribution [] dists){
508    Alphabet alpha = dists[0].getAlphabet();
509    //check if all alphabets are the same
510    for (int i = 1; i < dists.length; i++) {
511      if(!(dists[i].getAlphabet().equals(alpha))){
512        throw new IllegalArgumentException("All alphabets must be the same");
513      }
514    }
516    try{
517      Distribution average = DistributionFactory.DEFAULT.createDistribution(alpha);
518      DistributionTrainerContext dtc = new SimpleDistributionTrainerContext();
519      dtc.registerDistribution(average);
521      for (int i = 0; i < dists.length; i++) {// for each distribution
522        for(Iterator iter = ((FiniteAlphabet)dists[i].getAlphabet()).iterator(); iter.hasNext(); ){//for each symbol
523          Symbol sym = (Symbol)iter.next();
524          dtc.addCount(average,sym,dists[i].getWeight(sym));
525        }
526      }
529      dtc.train();
530      return average;
531    } catch(IllegalAlphabetException iae){//The following throw unchecked exceptions as they shouldn't happen
532       throw new AssertionFailure("Distribution contains an illegal alphabet", iae);
533    } catch(IllegalSymbolException ise){
534       throw new AssertionFailure("Distribution contains an illegal symbol", ise);
535    } catch(ChangeVetoException cve){
536       throw new AssertionFailure("The Distribution has become locked", cve);
537    }
538  }
540  /**
541   * Produces a sequence by randomly sampling the Distribution.
542   *
543   * @param name the name for the sequence
544   * @param d the distribution to sample. If this distribution is of order N a
545   * seed sequence is generated allowed to 'burn in' for 1000 iterations and used
546   * to produce a sequence over the conditioned alphabet.
547   * @param length the number of symbols in the sequence.
548   * @return a Sequence with name and urn = to name and an Empty Annotation.
549   */
550  public static final Sequence generateSequence(String name, Distribution d, int length){
551    SymbolList sl = generateSymbolList(d, length);
552    SequenceFactory fact = new SimpleSequenceFactory();
553    return fact.createSequence(sl, name, name, Annotation.EMPTY_ANNOTATION);
554    //return new SimpleSequence(sl,name,name,Annotation.EMPTY_ANNOTATION);
555  }
558 * Produces a <code>SymbolList</code> by randomly sampling a Distribution.
559 *
560 * @param d the distribution to sample. If this distribution is of order N a
561 * seed sequence is generated allowed to 'burn in' for 1000 iterations and used
562 * to produce a sequence over the conditioned alphabet.
563 * @param length the number of symbols in the sequence.
564 * @return a SymbolList or length <code>length</code>
565 */
566  public static final SymbolList generateSymbolList(Distribution d, int length){
567    if(d instanceof OrderNDistribution)
568      return generateOrderNSymbolList((OrderNDistribution)d, length);
570    SymbolList sl = null;
572    List l = new ArrayList(length);
573    for (int i = 0; i < length; i++) {
574      l.add(d.sampleSymbol());
575    }
577    try {
578      SymbolListFactory fact;
579      if(length < 10000){
580        fact = new SimpleSymbolListFactory();
581      }else{
582        fact = new PackedSymbolListFactory();
583      }
585      Symbol[] syms = new Symbol[length];
586      l.toArray(syms);
588      sl = fact.makeSymbolList(syms, length, d.getAlphabet());
589      //sl = new SimpleSymbolList(d.getAlphabet(),l);
590    }
591    catch (IllegalAlphabetException ex) {
592      //shouldn't happen but...
593      throw new BioError("Distribution emitting Symbols not from its Alphabet?");
594    }
596    return sl;
597  }
599  private static final SymbolList generateOrderNSymbolList(OrderNDistribution d, int length){
600    SymbolList sl = null;
601    List l = new ArrayList(length);
603    /*
604     * When emitting an orderN sequence a seed sequence is required that is of the
605     * length of the conditioning alphabet. The emissions will also be allowed
606     * to 'burn in' for 1000 emissions so that the 'end effect' of the seed
607     * is negated.
608     */
609     FiniteAlphabet cond = (FiniteAlphabet)d.getConditioningAlphabet();
610     UniformDistribution uni = new UniformDistribution(cond);
611     BasisSymbol seed = (BasisSymbol)uni.sampleSymbol();
612     //using the linked list the seed becomes like a history buffer.
613     LinkedList ll = new LinkedList(seed.getSymbols());
615    try {
617      for(int i = 0; i < 1000+ length; i++){
618         //get a symbol using the seed
619         Symbol sym = d.getDistribution(seed).sampleSymbol();
620         if(i >= 1000){
621           l.add(sym);
622         }
623         //add the symbol to the end of the seed
624         ll.addLast(sym);
625         //remove the first basis symbol of the seed
626         ll.removeFirst();
627         //regenerate the seed
628         seed = (BasisSymbol)cond.getSymbol(ll);
629       }
631       SymbolListFactory fact;
632       if(length < 10000){
633         fact = new SimpleSymbolListFactory();
634       }else{
635         fact = new PackedSymbolListFactory();
636       }
638       Symbol[] syms = new Symbol[l.size()];
639       l.toArray(syms);
640       sl = fact.makeSymbolList(syms, length, d.getConditionedAlphabet());
641       //sl = new SimpleSymbolList(d.getConditionedAlphabet(),l);
642    }
643    catch (IllegalSymbolException ex) {
644      //shouldn't happen but...
645      throw new BioError("Distribution emitting Symbols not from its Alphabet?",ex);
646    }catch(IllegalAlphabetException ex){
647      //shouldn't happen but...
648      throw new BioError("Distribution emitting Symbols not from its Alphabet?",ex);
649    }
651    return sl;
652  }
654  /**
655   * Generate a sequence by sampling a distribution.
656   *
657   * @deprecated use generateSequence() or generateSymbolList() instead.
658   * @param name    the name of the sequence
659   * @param d       the distribution to sample
660   * @param length  the length of the sequence
661   * @return        a new sequence with the required composition
662   */
663  protected static final Sequence generateOrderNSequence(String name, OrderNDistribution d, int length){
665    SymbolList sl = generateOrderNSymbolList(d, length);
666    SequenceFactory fact = new SimpleSequenceFactory();
668    return fact.createSequence(sl, name, name, Annotation.EMPTY_ANNOTATION);
669    //return new SimpleSequence(sl, name, name, Annotation.EMPTY_ANNOTATION);
670  }
672}//End of class