001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.stats.svm;
022
023import java.util.HashSet;
024import java.util.Set;
025
026/**
027 * Adds a class specific constant to k(x, x).
028 *
029 * @author Matthew Pocock
030 */
031public class DiagonalAddKernel extends NestedKernel {
032  private Set posClass;
033  private Set negClass;
034
035  {
036    posClass = new HashSet();
037    negClass = new HashSet();
038  }
039  
040  public void addPos(Object o) {
041    posClass.add(o);
042  }
043  
044  public void addNeg(Object o) {
045    negClass.add(o);
046  }
047  
048  /**
049   * The scale vactor.
050   */
051  private double lambda = 1.0;
052  
053  /**
054   * Set the scale factor.
055   *
056   * @param l  the new scale factor
057   */
058  public void setLambda(double l) {
059    this.lambda = l;
060  }
061  
062  /**
063   * Retrieve the scale factor.
064   *
065   * @return the current scale factor
066   */
067  public double getLambda() {
068    return lambda;
069  }
070  
071  /**
072   * Return the dot product of a, b.
073   * <p>
074   * This is equal to
075   * <code>k(a, b) + d(a, b) * ||class(a)|| / (||class||)</code>
076   * where d(a, b) is zero if a != b, and 1 if a == b. class(a) is the set of all
077   * items in the same class as a. class is all items with a classification.
078   */
079  public double evaluate(Object a, Object b) {
080    double dot = getNestedKernel().evaluate(a, b);
081    if(a == b) {
082      int size = 0;
083      if(posClass.contains(a)) {
084        size = posClass.size();
085      } else if(negClass.contains(a)) {
086        size = negClass.size();
087      }
088      dot += getLambda() * size / (posClass.size() + negClass.size());
089    }
090    return dot;
091  }
092  
093  public String toString() {
094    return
095     "DiagonalAdd K(a, b | l, s+, s-, k) = k(a, b) + d[a, b]; d[a, b] = " +
096     "{ a != b, 0; a == b, l * {class(a == +), s+; class(a == -), s-} }; k = " +
097     getNestedKernel().toString();
098  }
099}