/*
 *                    BioJava development code
 *
 * This code may be freely distributed and modified under the
 * terms of the GNU Lesser General Public Licence.  This should
 * be distributed with the code.  If you do not have a copy,
 * see:
 *
 *      http://www.gnu.org/copyleft/lesser.html
 *
 * Copyright for this code is held jointly by the individual
 * authors.  These should be listed in @author doc comments.
 *
 * For more information on the BioJava project and its aims,
 * or to join the biojava-l mailing list, visit the home page
 * at:
 *
 *      http://www.biojava.org/
 *
 */
021
022package org.biojava.bio.dist;
023
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OptionalDataException;
import java.io.Serializable;

import org.biojava.bio.BioError;
import org.biojava.bio.symbol.Alphabet;
import org.biojava.bio.symbol.AlphabetIndex;
import org.biojava.bio.symbol.AlphabetManager;
import org.biojava.bio.symbol.AtomicSymbol;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.utils.ChangeAdapter;
import org.biojava.utils.ChangeEvent;
import org.biojava.utils.ChangeSupport;
import org.biojava.utils.ChangeVetoException;
042
043/**
044 * A simple implementation of a distribution, which works with any finite alphabet.
045 *
046 * @author Matthew Pocock
047 * @author Thomas Down
048 * @author Mark Schreiber
049 * @since 1.0
050 * @serial WARNING serialized versions of this class may not be compatible with later versions of BioJava
051 */
052public class SimpleDistribution
053extends AbstractDistribution implements Serializable{
054    static final long serialVersionUID = 7252850540926095728L;
055    
056    
057  private transient AlphabetIndex indexer;
058  private transient double[] weights = null;//because indexer is transient.
059  private Distribution nullModel;
060  private FiniteAlphabet alpha;
061  
062  private static class SymbolWeightMemento implements Serializable {
063      static final long serialVersionUID = 5223128163879670657L;
064      
065      public final Symbol symbol;
066      public final double weight;
067      
068      public SymbolWeightMemento(Symbol s, double weight) {
069          this.symbol = s;
070          this.weight = weight;
071      }
072  }
073  
074  private void writeObject(ObjectOutputStream oos)
075      throws IOException
076  {
077      oos.defaultWriteObject();
078      
079      if (weights != null) {// fix for bug 2360
080          SymbolWeightMemento[] swm = new SymbolWeightMemento[weights.length];
081          for (int w = 0; w < swm.length; ++w) {
082              swm[w] = new SymbolWeightMemento(indexer.symbolForIndex(w), weights[w]);
083          }
084          oos.writeObject(swm);
085      }
086  }
087
088  private void readObject(ObjectInputStream stream)
089    throws IOException, ClassNotFoundException
090  {
091    stream.defaultReadObject();
092    
093    //System.out.println("Alphabet for this dist is: "+alpha.getName());
094    indexer = AlphabetManager.getAlphabetIndex(alpha);
095    indexer.addChangeListener(
096      new ChangeAdapter(){
097        public void preChange(ChangeEvent ce) throws ChangeVetoException{
098          if(hasWeights()){
099            throw new ChangeVetoException(
100              ce,
101              "Can't allow the index to change as we have probabilities."
102            );
103          }
104        }
105      },AlphabetIndex.INDEX
106    );
107    weights = new double[alpha.size()];
108    
109    SymbolWeightMemento[] swm = (SymbolWeightMemento[]) stream.readObject();
110    for (int m = 0; m < swm.length; ++m) {
111        try {
112            weights[indexer.indexForSymbol(swm[m].symbol)] = swm[m].weight;
113        } catch (IllegalSymbolException ex) {
114            throw new IOException("Symbol in serialized stream: "+swm[m].symbol.getName()+" can't be found in the alphabet");
115        }
116    }
117  }
118
119  public Alphabet getAlphabet() {
120    return indexer.getAlphabet();
121  }
122
123  public Distribution getNullModel() {
124    return this.nullModel;
125  }
126
127
128
129  protected void setNullModelImpl(Distribution nullModel)
130
131  throws IllegalAlphabetException, ChangeVetoException {
132    this.nullModel = nullModel;
133  }
134
135
136  /**
137   * Indicate whether the weights array has been allocated yet.
138   *
139   * @return  true if the weights are allocated
140   */
141  protected boolean hasWeights() {
142    return weights != null;
143  }
144
145
146  /**
147   * Get the underlying array that stores the weights.
148   *
149   * <p>
150   * Modifying this will modify the state of the distribution.
151   * </p>
152   *
153   * @return  the weights array
154   */
155  protected double[] getWeights() {
156    if(weights == null) {
157      weights = new double[((FiniteAlphabet)getAlphabet()).size()];
158      for(int i = 0; i < weights.length; i++) {
159        weights[i] = Double.NaN;
160
161      }
162    }
163     return weights;
164  }
165
166
167
168  public double getWeightImpl(AtomicSymbol s)
169
170  throws IllegalSymbolException {
171    if(!hasWeights()) {
172      return Double.NaN;
173    } else {
174      int index = indexer.indexForSymbol(s);
175      return weights[index];
176    }
177  }
178
179
180  protected void setWeightImpl(AtomicSymbol s, double w)
181  throws IllegalSymbolException, ChangeVetoException {
182    double[] weights = getWeights();
183    if(w < 0.0) {
184      throw new IllegalArgumentException(
185        "Can't set weight to negative score: " +
186        s.getName() + " -> " + w
187      );
188    }
189    weights[indexer.indexForSymbol(s)] = w;
190  }
191
192  private void initialise(FiniteAlphabet alphabet) {
193    this.alpha = alphabet;
194    this.indexer = AlphabetManager.getAlphabetIndex(alphabet);
195    indexer.addChangeListener(
196      new ChangeAdapter() {
197        public void preChange(ChangeEvent ce) throws ChangeVetoException {
198          if(hasWeights()) {
199            throw new ChangeVetoException(
200              ce,
201              "Can't allow the index to change as we have probabilities."
202            );
203          }
204        }
205      },
206      AlphabetIndex.INDEX
207    );
208
209    try {
210      setNullModel(new UniformDistribution(alphabet));
211    } catch (Exception e) {
212      throw new BioError("This should never fail. Something is screwed!", e);
213    }
214  }
215
216  /**
217   * make an instance of SimpleDistribution for the specified Alphabet.
218   */
219  public SimpleDistribution(FiniteAlphabet alphabet)
220  {
221    initialise(alphabet);
222  }
223
224  /**
225   * make an instance of SimpleDistribution with weights identical
226   * to the specified Distribution.
227   *
228   * @param dist Distribution to copy the weights from.
229   */
230  public SimpleDistribution(Distribution dist)
231  {
232    try {
233    initialise((FiniteAlphabet) dist.getAlphabet());
234
235    // now copy over weights
236    int alfaSize = ((FiniteAlphabet)getAlphabet()).size();
237
238    for (int i = 0; i < alfaSize; i++) {
239      weights = new double[alfaSize];
240      weights[i] = dist.getWeight(indexer.symbolForIndex(i));
241    }
242    }
243    catch (IllegalSymbolException ise) {
244      System.err.println("an impossible error surely! "); ise.printStackTrace();
245    }
246  }
247
248  /**
249   * Register an SimpleDistribution.Trainer instance as the trainer for this distribution.
250   */
251  public void registerWithTrainer(DistributionTrainerContext dtc) {
252   dtc.registerTrainer(this, new Trainer());
253  }
254
255
256  /**
257   * A simple implementation of a trainer for this class.
258   *
259   * @author Matthew Pocock
260   * @since 1.0
261   */
262  protected class Trainer implements DistributionTrainer {
263    private final Count counts;
264
265    /**
266     * Create a new trainer.
267     */
268    public Trainer() {
269      counts = new IndexedCount(indexer);
270    }
271
272    public void addCount(DistributionTrainerContext dtc, AtomicSymbol sym, double times)
273    throws IllegalSymbolException {
274      try {
275          counts.increaseCount(sym, times);
276      } catch (ChangeVetoException cve) {
277        throw new BioError(
278          "Assertion Failure: Change to Count object vetoed", cve
279        );
280      }
281    }
282
283    public double getCount(DistributionTrainerContext dtc, AtomicSymbol sym)
284    throws IllegalSymbolException {
285      return counts.getCount(sym);
286    }
287
288
289
290    public void clearCounts(DistributionTrainerContext dtc) {
291      try {
292        int size = ((FiniteAlphabet) counts.getAlphabet()).size();
293        for(int i = 0; i < size; i++) {
294          counts.zeroCounts();
295        }
296      } catch (ChangeVetoException cve) {
297        throw new BioError(
298          "Assertion Failure: Change to Count object vetoed",cve
299        );
300      }
301    }
302
303
304
305    public void train(DistributionTrainerContext dtc, double weight)
306    throws ChangeVetoException {
307      if(!hasListeners())  {
308        trainImpl(dtc, weight);
309      } else {
310        ChangeSupport changeSupport = getChangeSupport(Distribution.WEIGHTS);
311        synchronized(changeSupport) {
312          ChangeEvent ce = new ChangeEvent(
313            SimpleDistribution.this,
314            Distribution.WEIGHTS
315          );
316          changeSupport.firePreChangeEvent(ce);
317          trainImpl(dtc, weight);
318          changeSupport.firePostChangeEvent(ce);
319        }
320      }
321    }
322
323
324
325    protected void trainImpl(DistributionTrainerContext dtc, double weight) {
326      //System.out.println("Training");
327      try {
328        Distribution nullModel = getNullModel();
329        double[] weights = getWeights();
330        double[] total = new double[weights.length];
331        double sum = 0.0;
332
333        for(int i = 0; i < total.length; i++) {
334          AtomicSymbol s = (AtomicSymbol) indexer.symbolForIndex(i);
335          sum +=
336            total[i] =
337              getCount(dtc, s) +
338              nullModel.getWeight(s) * weight;
339        }
340        double sum_inv = 1.0 / sum;
341        for(int i = 0; i < total.length; i++) {
342          //System.out.println("\t" + weights[i] + "\t" + total[i] * sum_inv);
343          weights[i] = total[i] * sum_inv;
344        }
345      } catch (IllegalSymbolException ise) {
346        throw new BioError(
347          "Assertion Failure: Should be impossible to mess up the symbols.",ise
348        );
349      }
350    }
351  }
352}
353
354
355