001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022package org.biojava.stats.svm.tools;
023
024import java.io.Serializable;
025import java.util.BitSet;
026
027import org.biojava.bio.symbol.SuffixTree;
028import org.biojava.stats.svm.SVMKernel;
029
030/**
031 * Computes the dot-product of two suffix-trees as the sum of the products
032 * of the counts of all nodes they have in common.
033 * <p>
034 * This implementation allows you to scale the sub-space for each word length
035 * independently.
036 *
037 * @author Matthew Pocock
038 */
039public class SuffixTreeKernel implements SVMKernel, Serializable {
040  /**
041   * The <span class="type">DepthScaler</span> that will scale each sub-space.
042   * This defaults to <span class="type">UniformScaler</type>.
043   */
044  private DepthScaler depthScaler = new UniformScaler();
045  
046  /**
047   * Retrieve the current <span class="type">DepthScaler</span>.
048   *
049   * @return the current <span class="type">DepthScaler</span>
050   */
051  public DepthScaler getDepthScaler() {
052    return depthScaler;
053  }
054  
055  /**
056   * Change the current <span class="type">DepthScaler</span> to
057   * <span class="arg">depthScaler</span>.
058   *
059   * @param depthScaler  the new <span class="type">DepthScaler</span> to use
060   */
061  public void setDepthScaler(DepthScaler depthScaler) {
062    this.depthScaler = depthScaler;
063  }
064  
065  /**
066   * Calculate the dot product between the
067   * <span class="type">SuffixTree</span>s <span class="arg">a</span> and
068   * <span class="arg">b</span>.
069   * <p>
070   * This is the sum of the dot products of each subspace for a given word
071   * length. Each subspace is scaled using the
072   * <span class="type">DepthScaler</span> returned by
073   * <span class="method">getDepthScaler</span>.
074   *
075   * @param a  the first <span class="type">Object</span>
076   * @param b  the second <span class="type">Object</span>
077   * @return <span class="arg">a</span>.<span class="arg">b</span>
078   * @throws <span class="type">ClassCastException</span> if either
079   *         <span class="arg">a</span> or <span class="arg">b</span> are not
080   *         castable to <span class="type">SuffixTree</span>
081   */
082  public double evaluate(Object a, Object b) {
083    SuffixTree st1 = (SuffixTree) a;
084    SuffixTree st2 = (SuffixTree) b;
085    SuffixTree.SuffixNode n1 = st1.getRoot();
086    SuffixTree.SuffixNode n2 = st2.getRoot();
087      
088    return dot(st1, n1, st2, n2, st1.getAlphabet().size(), 0);
089  }
090  
091  /**
092   * Recursive method to compute the dot product of the
093   * <span class="type">SuffixTree.SuffixNode</span>s
094   * <span class="arg">n1</span> and <span class="arg">n2</span>.
095   * <p>
096   * This scales <span class="arg">n1</span>.
097   * <span class="method">getNumber</span><code>()</code> *
098   * <span class="arg">n2</span>.
099   * <span class="method">getNumber</span><code>()</code>
100   * by <span class="const">this</span>.<span class="method">getDepthScaler</span>
101   * (<span class="arg">depth</span>), and then returns the sum of this and the
102   * dot products for all children of the suffix nodes.
103   */
104  private double dot(SuffixTree st1,
105                     SuffixTree.SuffixNode n1,
106                     SuffixTree st2,
107                     SuffixTree.SuffixNode n2,
108                     int size,
109                     int depth)
110  {
111    double scale = getDepthScaler().getScale(depth);
112    double dot = n1.getNumber() * n2.getNumber() * scale * scale;
113    for(int i = 0; i < size; i++) {
114      if(n1.hasChild(i) && n2.hasChild(i)) {
115        dot += dot(st1, st1.getChild(n1, i), st2, st2.getChild(n2, i), size, depth+1);
116      }
117    }
118    return dot;
119  }
120    
121  public String toString() {
122    return new String("Suffix tree kernel");
123  }
124  
125  /**
126   * Encapsulates the scale factor to apply at a given depth.
127   *
128   * @author Matthew Pocock
129   */
130  public interface DepthScaler {
131    /**
132     * Retrieve the scaling factor at a given depth
133     *
134     * @param depth  word length
135     * @return the scaling factor for the subspace at that length
136     */
137    double getScale(int depth);
138  }
139  
140  /**
141   * Scales by 4^depth - equivalent to dividing by a probablistic flatt prior
142   * null model
143   *
144   * @author Matthew Pocock
145   */
146  public static class NullModelScaler implements DepthScaler, Serializable {
147    public double getScale(int depth) {
148      return Math.pow(4.0, (double) depth);
149    }
150  }
151  
152  /**
153   * Scale all depths by 1.0
154   *
155   * @author Matthew Pocock
156   */
157  public static class UniformScaler implements DepthScaler, Serializable {
158    public double getScale(int depth) {
159      return 1.0;
160    }
161  }
162  
163  /**
164   * Scale using a <span class="type">BitSet</span> to allow/disallow depths.
165   *
166   * @author Matthew Pocock
167   */
168  public static class SelectionScalar implements DepthScaler, Serializable {
169    private BitSet bSet;
170    
171    /**
172     * Make a new <span class="type">SelectionScalar</span> that masks in different
173     * depths.
174     *
175     * @param bSet  the mask for which depths to allow
176     */
177    public SelectionScalar(BitSet bSet) {
178      this.bSet = new BitSet();
179      this.bSet.or(bSet);
180    }
181    
182    /**
183     * @return 1.0 or 0.0 depending on whether the bit at
184     *         <span class="arg">depth</span> is set or not
185     */
186    public double getScale(int depth) {
187      if(bSet.get(depth)) {
188        return 1.0;
189      } else {
190        return 0.0;
191      }
192    }
193  }
194  
195  /**
196   * Scale using a multiple of two <span class="type">DepthScaler</span>s.
197   *
198   * @author Matthew Pocock
199   */
200  public static class MultipleScalar implements DepthScaler, Serializable {
201    private DepthScaler a;
202    private DepthScaler b;
203    
204    public MultipleScalar(DepthScaler a, DepthScaler b) {
205      this.a = a;
206      this.b = b;
207    }
208    
209    public double getScale(int depth) {
210      return a.getScale(depth) * b.getScale(depth);
211    }
212  }
213}