001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021package org.biojava.stats.svm; 022 023import java.util.HashSet; 024import java.util.Set; 025 026/** 027 * Adds a class specific constant to k(x, x). 028 * 029 * @author Matthew Pocock 030 */ 031public class DiagonalAddKernel extends NestedKernel { 032 private Set posClass; 033 private Set negClass; 034 035 { 036 posClass = new HashSet(); 037 negClass = new HashSet(); 038 } 039 040 public void addPos(Object o) { 041 posClass.add(o); 042 } 043 044 public void addNeg(Object o) { 045 negClass.add(o); 046 } 047 048 /** 049 * The scale vactor. 050 */ 051 private double lambda = 1.0; 052 053 /** 054 * Set the scale factor. 055 * 056 * @param l the new scale factor 057 */ 058 public void setLambda(double l) { 059 this.lambda = l; 060 } 061 062 /** 063 * Retrieve the scale factor. 064 * 065 * @return the current scale factor 066 */ 067 public double getLambda() { 068 return lambda; 069 } 070 071 /** 072 * Return the dot product of a, b. 073 * <p> 074 * This is equal to 075 * <code>k(a, b) + d(a, b) * ||class(a)|| / (||class||)</code> 076 * where d(a, b) is zero if a != b, and 1 if a == b. class(a) is the set of all 077 * items in the same class as a. class is all items with a classification. 078 */ 079 public double evaluate(Object a, Object b) { 080 double dot = getNestedKernel().evaluate(a, b); 081 if(a == b) { 082 int size = 0; 083 if(posClass.contains(a)) { 084 size = posClass.size(); 085 } else if(negClass.contains(a)) { 086 size = negClass.size(); 087 } 088 dot += getLambda() * size / (posClass.size() + negClass.size()); 089 } 090 return dot; 091 } 092 093 public String toString() { 094 return 095 "DiagonalAdd K(a, b | l, s+, s-, k) = k(a, b) + d[a, b]; d[a, b] = " + 096 "{ a != b, 0; a == b, l * {class(a == +), s+; class(a == -), s-} }; k = " + 097 getNestedKernel().toString(); 098 } 099}