001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022 023package org.biojava.stats.svm.tools; 024 025import java.awt.BorderLayout; 026import java.awt.Color; 027import java.awt.Cursor; 028import java.awt.FlowLayout; 029import java.awt.Graphics; 030import java.awt.Graphics2D; 031import java.awt.Paint; 032import java.awt.Point; 033import java.awt.Rectangle; 034import java.awt.Shape; 035import java.awt.event.ActionEvent; 036import java.awt.event.ActionListener; 037import java.awt.event.ItemEvent; 038import java.awt.event.ItemListener; 039import java.awt.event.MouseAdapter; 040import java.awt.event.MouseEvent; 041import java.awt.event.WindowAdapter; 042import java.awt.event.WindowEvent; 043import java.awt.geom.AffineTransform; 044import java.awt.geom.Ellipse2D; 045import java.awt.geom.Point2D; 046import java.awt.geom.Rectangle2D; 047import java.util.Collections; 048import java.util.Iterator; 049import java.util.Set; 050 051import javax.swing.ButtonGroup; 052import javax.swing.JButton; 053import javax.swing.JComboBox; 054import javax.swing.JComponent; 055import javax.swing.JFrame; 056import javax.swing.JPanel; 057import javax.swing.JRadioButton; 058 059import org.biojava.stats.svm.PolynomialKernel; 060import org.biojava.stats.svm.RadialBaseKernel; 061import org.biojava.stats.svm.SMOTrainer; 062import org.biojava.stats.svm.SVMClassifierModel; 063import org.biojava.stats.svm.SVMKernel; 064import org.biojava.stats.svm.SVMTarget; 065import org.biojava.stats.svm.SimpleSVMTarget; 066 067/** 068 * A simple toy example that allows you to put points on a canvas, and find a 069 * polynomial hyperplane to seperate them. 070 * 071 * @author Ewan Birney 072 * @author Matthew Pocock 073 * @author Thomas Down 074 * @author Michael L Heurer 075 */ 076public class ClassifierExample { 077 /** 078 * Entry point for the application. The arguments are ignored. 079 */ 080 public static void main(String args[]) { 081 JFrame f = new JFrame(); 082 f.addWindowListener(new WindowAdapter() { 083 public void windowClosing(WindowEvent we) { 084 System.exit(0); 085 } 086 }); 087 f.getContentPane().setLayout(new BorderLayout()); 088 final PointClassifier pc = new PointClassifier(); 089 f.getContentPane().add(BorderLayout.CENTER, pc); 090 JPanel panel = new JPanel(); 091 panel.setLayout(new FlowLayout()); 092 ButtonGroup bGroup = new ButtonGroup(); 093 final JRadioButton rbPos = new JRadioButton("postive"); 094 bGroup.add(rbPos); 095 final JRadioButton rbNeg = new JRadioButton("negative"); 096 bGroup.add(rbNeg); 097 ActionListener addTypeAction = new ActionListener() { 098 public void actionPerformed(ActionEvent ae) { 099 //JRadioButton rb = (JRadioButton) ae.getSource(); 100 pc.setAddPos(rbPos.isSelected()); 101 } 102 }; 103 rbPos.addActionListener(addTypeAction); 104 panel.add(rbPos); 105 rbNeg.addActionListener(addTypeAction); 106 panel.add(rbNeg); 107 ActionListener classifyAction = new ActionListener() { 108 public void actionPerformed(ActionEvent ae) { 109 pc.classify(); 110 } 111 }; 112 JButton classifyB = new JButton("classify"); 113 classifyB.addActionListener(classifyAction); 114 panel.add(classifyB); 115 ActionListener clearAction = new ActionListener() { 116 public void actionPerformed(ActionEvent ae) { 117 pc.clear(); 118 } 119 }; 120 JButton clearB = new JButton("clear"); 121 clearB.addActionListener(clearAction); 122 panel.add(clearB); 123 rbPos.setSelected(pc.getAddPos()); 124 rbNeg.setSelected(!pc.getAddPos()); 125 126 JComboBox kernelBox = new JComboBox(); 127 kernelBox.addItem("polynomeal"); 128 kernelBox.addItem("rbf"); 129 130 kernelBox.addItemListener(new ItemListener() { 131 public void itemStateChanged(ItemEvent e) { 132 if(e.getStateChange() == ItemEvent.SELECTED) { 133 Object o = e.getItem(); 134 if(o.equals("polynomeal")) { 135 pc.setKernel(PointClassifier.polyKernel); 136 } else if(o.equals("rbf")) { 137 pc.setKernel(PointClassifier.rbfKernel); 138 } 139 } 140 } 141 }); 142 panel.add(kernelBox); 143 144 f.getContentPane().add(BorderLayout.NORTH, panel); 145 f.setSize(400, 300); 146 f.setVisible(true); 147 } 148 149 /** 150 * An extention of JComponent that contains the points & encapsulates the 151 * classifier. 152 */ 153 public static class PointClassifier extends JComponent { 154 // public kernels 155 public static SVMKernel polyKernel; 156 public static SVMKernel rbfKernel; 157 public static SMOTrainer trainer; 158 159 static { 160 trainer = new SMOTrainer(); 161 trainer.setC(1.0E+7); 162 trainer.setEpsilon(1.0E-9); 163 164 SVMKernel k = new SVMKernel() { 165 public double evaluate(Object a, Object b) { 166 Point2D pa = (Point2D) a; 167 Point2D pb = (Point2D) b; 168 169 double dot = pa.getX() * pb.getX() + pa.getY() * pb.getY(); 170 return dot; 171 } 172 }; 173 174 PolynomialKernel pk = new PolynomialKernel(); 175 pk.setNestedKernel(k); 176 pk.setOrder(2.0); 177 pk.setConstant(1.0); 178 pk.setMultiplier(0.0000001); 179 180 RadialBaseKernel rb = new RadialBaseKernel(); 181 rb.setNestedKernel(k); 182 rb.setWidth(10000.0); 183 184 polyKernel = pk; 185 rbfKernel = rb; 186 } 187 188 // private variables that should only be diddled by internal methods 189 private SVMTarget target; 190 private SVMClassifierModel model; 191 192 { 193 target = new SimpleSVMTarget(); 194 model = null; 195 } 196 197 // private variables containing state that may be diddled by beany methods 198 private boolean addPos; 199 private Shape posShape; 200 private Shape negShape; 201 private Paint svPaint; 202 private Paint plainPaint; 203 private Paint posPaint; 204 private Paint negPaint; 205 private SVMKernel kernel; 206 207 /** 208 * Set the kernel used for classification. 209 * 210 * @param kernel the SVMKernel to use 211 */ 212 public void setKernel(SVMKernel kernel) { 213 firePropertyChange("kernel", this.kernel, kernel); 214 this.kernel = kernel; 215 } 216 217 /** 218 * Retrieve the currently used kernel 219 * 220 * @return the current value of the kernel. 221 */ 222 public SVMKernel getKernel() { 223 return this.kernel; 224 } 225 226 /** 227 * Set a flag so that newly added points will be in the positive class or 228 * negative class, depending on wether addPos is true or false respectively. 229 * 230 * @param addPos boolean to flag which class to add new points to 231 */ 232 public void setAddPos(boolean addPos) { 233 firePropertyChange("addPos", this.addPos, addPos); 234 this.addPos = addPos; 235 } 236 237 /** 238 * Retrieve the current value of addPos. 239 * 240 * @return true if new points will be added to the positive examples and 241 * false if they will be added to the negative examples. 242 */ 243 public boolean getAddPos() { 244 return addPos; 245 } 246 247 /** 248 * Set the Shape to represent the positive points. 249 * <p> 250 * The shape should be positioned so that 0, 0 is the center or focus. 251 * 252 * @param posShape the Shape to use 253 */ 254 public void setPosShape(Shape posShape) { 255 firePropertyChange("posShape", this.posShape, posShape); 256 this.posShape = posShape; 257 } 258 259 /** 260 * Retrieve the shape used to represent positive points. 261 * 262 * @return the current positive Shape 263 */ 264 public Shape getPosShape() { 265 return posShape; 266 } 267 268 /** 269 * Set the Shape to represent the negative points. 270 * <p> 271 * The shape should be positioned so that 0, 0 is the center or focus. 272 * 273 * @param negShape the Shape to use 274 */ 275 public void setNegShape(Shape negShape) { 276 firePropertyChange("negShape", this.negShape, negShape); 277 this.negShape = negShape; 278 } 279 280 /** 281 * Retrieve the shape used to represent negative points. 282 * 283 * @return the current negative Shape 284 */ 285 public Shape getNegShape() { 286 return negShape; 287 } 288 289 /** 290 * Remove all points from the canvas, and discard any model. 291 */ 292 public void clear() { 293 target.clear(); 294 model = null; 295 repaint(); 296 } 297 298 /** 299 * Learn a model from the current points. 300 * <p> 301 * This may take some time for complicated models. 302 */ 303 public void classify() { 304 new Thread() { 305 public void run() { 306 Cursor c = getCursor(); 307 setCursor(new Cursor(Cursor.WAIT_CURSOR)); 308 System.out.println("Training"); 309 model = trainer.trainModel(target, kernel, null); 310 311 System.out.println("Threshold = " + model.getThreshold()); 312 for(Iterator i = model.items().iterator(); i.hasNext(); ) { 313 Object item = i.next(); 314 System.out.println(item + "\t" + 315 target.getTarget(item) + "\t" + 316 model.getAlpha(item) + "\t" + 317 model.classify(item) 318 ); 319 } 320 321 PointClassifier.this.model = model; 322 setCursor(c); 323 repaint(); 324 } 325 }.start(); 326 } 327 328 /** 329 * Make a new PointClassifier. 330 * <p> 331 * Hooks up the mouse listener & cursor. 332 * Chooses default colors & Shapes. 333 */ 334 public PointClassifier() { 335 setCursor(new Cursor(Cursor.CROSSHAIR_CURSOR)); 336 addPos = true; 337 setPosShape(new Rectangle2D.Double(-2.0, -2.0, 5.0, 5.0)); 338 setNegShape(new Ellipse2D.Double(-2.0, -2.0, 5.0, 5.0)); 339 setKernel(polyKernel); 340 plainPaint = Color.black; 341 svPaint = Color.green; 342 posPaint = Color.red; 343 negPaint = Color.blue; 344 345 addMouseListener(new MouseAdapter() { 346 public void mouseReleased(MouseEvent me) { 347 Point p = me.getPoint(); 348 if(getAddPos()) { 349 target.addItemTarget(p, +1.0); 350 } else { 351 target.addItemTarget(p, -1.0); 352 } 353 model = null; 354 repaint(); 355 } 356 }); 357 } 358 359 /** 360 * Renders this component to display the points, and if present, the 361 * support vector machine. 362 */ 363 public void paintComponent(Graphics g) { 364 Graphics2D g2 = (Graphics2D) g; 365 AffineTransform at = new AffineTransform(); 366 int i = 0; 367 Rectangle r = g2.getClipBounds(); 368 int step = 3; 369 370 if(model != null) { 371 Rectangle rr = new Rectangle(r.x, r.y, step, step); 372 Point p = new Point(r.x, r.y); 373 for(int x = r.x; x < r.x + r.width; x+=step) { 374 p.x = x; 375 rr.x = x; 376 for(int y = r.y; y < r.y + r.height; y+=step) { 377 p.y = y; 378 rr.y = y; 379 double s = model.classify(p); 380 if(s <= -1.0) { 381 g2.setPaint(negPaint); 382 } else if(s >= +1.0) { 383 g2.setPaint(posPaint); 384 } else { 385 g2.setPaint(Color.white); 386 } 387 g2.fill(rr); 388 } 389 } 390 } 391 392 Set supportVectors = Collections.EMPTY_SET; 393 if(model != null) { 394 supportVectors = model.items(); 395 } 396 for(Iterator it = target.items().iterator(); it.hasNext(); i++) { 397 Point2D p = (Point2D) it.next(); 398 at.setToTranslation(p.getX(), p.getY()); 399 Shape glyph; 400 if(target.getTarget(p) > 0) { 401 glyph = getPosShape(); 402 } else { 403 glyph = getNegShape(); 404 } 405 Shape s = at.createTransformedShape(glyph); 406 if(supportVectors.contains(p)) { 407 g2.setPaint(svPaint); 408 } else { 409 g2.setPaint(plainPaint); 410 } 411 g2.draw(s); 412 } 413 } 414 } 415}