001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.stats.svm.tools;
024
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Cursor;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Paint;
import java.awt.Point;
import java.awt.Rectangle;
import java.awt.Shape;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.awt.geom.AffineTransform;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;

import javax.swing.ButtonGroup;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JComponent;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JRadioButton;
import javax.swing.SwingUtilities;

import org.biojava.stats.svm.PolynomialKernel;
import org.biojava.stats.svm.RadialBaseKernel;
import org.biojava.stats.svm.SMOTrainer;
import org.biojava.stats.svm.SVMClassifierModel;
import org.biojava.stats.svm.SVMKernel;
import org.biojava.stats.svm.SVMTarget;
import org.biojava.stats.svm.SimpleSVMTarget;
066
067/**
068 * A simple toy example that allows you to put points on a canvas, and find a
069 * polynomial hyperplane to seperate them.
070 *
071 * @author Ewan Birney
072 * @author Matthew Pocock
073 * @author Thomas Down
074 * @author Michael L Heurer
075 */
076public class ClassifierExample {
077  /**
078   * Entry point for the application. The arguments are ignored.
079   */
080  public static void main(String args[]) {
081    JFrame f = new JFrame();
082    f.addWindowListener(new WindowAdapter() {
083      public void windowClosing(WindowEvent we) {
084        System.exit(0);
085      }
086    });
087    f.getContentPane().setLayout(new BorderLayout());
088    final PointClassifier pc = new PointClassifier();
089    f.getContentPane().add(BorderLayout.CENTER, pc);
090    JPanel panel = new JPanel();
091    panel.setLayout(new FlowLayout());
092    ButtonGroup bGroup = new ButtonGroup();
093    final JRadioButton rbPos = new JRadioButton("postive");
094    bGroup.add(rbPos);
095    final JRadioButton rbNeg = new JRadioButton("negative");
096    bGroup.add(rbNeg);
097    ActionListener addTypeAction = new ActionListener() {
098      public void actionPerformed(ActionEvent ae) {
099        //JRadioButton rb = (JRadioButton) ae.getSource();
100        pc.setAddPos(rbPos.isSelected());
101      }
102    };
103    rbPos.addActionListener(addTypeAction);
104    panel.add(rbPos);
105    rbNeg.addActionListener(addTypeAction);
106    panel.add(rbNeg);
107    ActionListener classifyAction = new ActionListener() {
108      public void actionPerformed(ActionEvent ae) {
109        pc.classify();
110      }
111    };
112    JButton classifyB = new JButton("classify");
113    classifyB.addActionListener(classifyAction);
114    panel.add(classifyB);
115    ActionListener clearAction = new ActionListener() {
116      public void actionPerformed(ActionEvent ae) {
117        pc.clear();
118      }
119    };
120    JButton clearB = new JButton("clear");
121    clearB.addActionListener(clearAction);
122    panel.add(clearB);
123    rbPos.setSelected(pc.getAddPos());
124    rbNeg.setSelected(!pc.getAddPos());
125    
126    JComboBox kernelBox = new JComboBox();
127    kernelBox.addItem("polynomeal");
128    kernelBox.addItem("rbf");
129    
130    kernelBox.addItemListener(new ItemListener() {
131      public void itemStateChanged(ItemEvent e) {
132        if(e.getStateChange() == ItemEvent.SELECTED) {
133          Object o = e.getItem();
134          if(o.equals("polynomeal")) {
135            pc.setKernel(PointClassifier.polyKernel);
136          } else if(o.equals("rbf")) {
137            pc.setKernel(PointClassifier.rbfKernel);
138          }
139        }
140      }
141    });
142    panel.add(kernelBox);
143    
144    f.getContentPane().add(BorderLayout.NORTH, panel);
145    f.setSize(400, 300);
146    f.setVisible(true);
147  }
148
149  /**
150   * An extention of JComponent that contains the points & encapsulates the
151   * classifier.
152   */
153  public static class PointClassifier extends JComponent {
154    // public kernels
155    public static SVMKernel polyKernel;
156    public static SVMKernel rbfKernel;
157    public static SMOTrainer trainer;
158    
159    static {
160      trainer = new SMOTrainer();
161      trainer.setC(1.0E+7);
162      trainer.setEpsilon(1.0E-9);
163      
164      SVMKernel k = new SVMKernel() {
165        public double evaluate(Object a, Object b) {
166          Point2D pa = (Point2D) a;
167          Point2D pb = (Point2D) b;
168
169          double dot = pa.getX() * pb.getX() + pa.getY() * pb.getY();
170          return dot;
171        }
172      };
173
174      PolynomialKernel pk = new PolynomialKernel();
175      pk.setNestedKernel(k);
176      pk.setOrder(2.0);
177      pk.setConstant(1.0);
178      pk.setMultiplier(0.0000001);
179      
180      RadialBaseKernel rb = new RadialBaseKernel();
181      rb.setNestedKernel(k);
182      rb.setWidth(10000.0);
183      
184      polyKernel = pk;
185      rbfKernel = rb;
186    }
187    
188    // private variables that should only be diddled by internal methods
189    private SVMTarget target;
190    private SVMClassifierModel model;
191
192    {
193      target = new SimpleSVMTarget();
194      model = null;
195    }
196
197    // private variables containing state that may be diddled by beany methods
198    private boolean addPos;
199    private Shape posShape;
200    private Shape negShape;
201    private Paint svPaint;
202    private Paint plainPaint;
203    private Paint posPaint;
204    private Paint negPaint;
205    private SVMKernel kernel;
206
207    /**
208     * Set the kernel used for classification.
209     *
210     * @param kernel  the SVMKernel to use
211     */
212    public void setKernel(SVMKernel kernel) {
213      firePropertyChange("kernel", this.kernel, kernel);
214      this.kernel = kernel;
215    }
216    
217    /**
218     * Retrieve the currently used kernel
219     *
220     * @return the current value of the kernel.
221     */
222    public SVMKernel getKernel() {
223      return this.kernel;
224    }
225
226    /**
227     * Set a flag so that newly added points will be in the positive class or
228     * negative class, depending on wether addPos is true or false respectively.
229     *
230     * @param addPos  boolean to flag which class to add new points to
231     */
232    public void setAddPos(boolean addPos) {
233      firePropertyChange("addPos", this.addPos, addPos);
234      this.addPos = addPos;
235    }
236
237    /**
238     * Retrieve the current value of addPos.
239     *
240     * @return  true if new points will be added to the positive examples and
241     *          false if they will be added to the negative examples.
242     */
243    public boolean getAddPos() {
244      return addPos;
245    }
246
247    /**
248     * Set the Shape to represent the positive points.
249     * <p>
250     * The shape should be positioned so that 0, 0 is the center or focus.
251     *
252     * @param posShape the Shape to use
253     */
254    public void setPosShape(Shape posShape) {
255      firePropertyChange("posShape", this.posShape, posShape);
256      this.posShape = posShape;
257    }
258
259    /**
260     * Retrieve the shape used to represent positive points.
261     *
262     * @return the current positive Shape
263     */
264    public Shape getPosShape() {
265      return posShape;
266    }
267
268    /**
269     * Set the Shape to represent the negative points.
270     * <p>
271     * The shape should be positioned so that 0, 0 is the center or focus.
272     *
273     * @param negShape the Shape to use
274     */
275    public void setNegShape(Shape negShape) {
276      firePropertyChange("negShape", this.negShape, negShape);
277      this.negShape = negShape;
278    }
279
280    /**
281     * Retrieve the shape used to represent negative points.
282     *
283     * @return the current negative Shape
284     */
285    public Shape getNegShape() {
286      return negShape;
287    }
288
289    /**
290     * Remove all points from the canvas, and discard any model.
291     */
292    public void clear() {
293      target.clear();
294      model = null;
295      repaint();
296    }
297
298    /**
299     * Learn a model from the current points.
300     * <p>
301     * This may take some time for complicated models.
302     */
303    public void classify() {
304      new Thread() {
305        public void run() {
306          Cursor c = getCursor();
307          setCursor(new Cursor(Cursor.WAIT_CURSOR));
308          System.out.println("Training");
309          model = trainer.trainModel(target, kernel, null);
310
311          System.out.println("Threshold = " + model.getThreshold());
312          for(Iterator i = model.items().iterator(); i.hasNext(); ) {
313            Object item = i.next();
314            System.out.println(item + "\t" +
315                               target.getTarget(item) + "\t" +
316                               model.getAlpha(item) + "\t" +
317                               model.classify(item)
318            );
319          }
320
321          PointClassifier.this.model = model;
322          setCursor(c);
323          repaint();
324        }
325      }.start();
326    }
327
328    /**
329     * Make a new PointClassifier.
330     * <p>
331     * Hooks up the mouse listener & cursor.
332     * Chooses default colors & Shapes.
333     */
334    public PointClassifier() {
335      setCursor(new Cursor(Cursor.CROSSHAIR_CURSOR));
336      addPos = true;
337      setPosShape(new Rectangle2D.Double(-2.0, -2.0, 5.0, 5.0));
338      setNegShape(new Ellipse2D.Double(-2.0, -2.0, 5.0, 5.0));
339      setKernel(polyKernel);
340      plainPaint = Color.black;
341      svPaint = Color.green;
342      posPaint = Color.red;
343      negPaint = Color.blue;
344
345      addMouseListener(new MouseAdapter() {
346        public void mouseReleased(MouseEvent me) {
347          Point p = me.getPoint();
348          if(getAddPos()) {
349            target.addItemTarget(p, +1.0);
350          } else {
351            target.addItemTarget(p, -1.0);
352          }
353          model = null;
354          repaint();
355        }
356      });
357    }
358  
359    /**
360     * Renders this component to display the points, and if present, the
361     * support vector machine.
362     */
363    public void paintComponent(Graphics g) {
364      Graphics2D g2 = (Graphics2D) g;
365      AffineTransform at = new AffineTransform();
366      int i = 0;
367      Rectangle r = g2.getClipBounds();
368      int step = 3;
369
370      if(model != null) {
371        Rectangle rr = new Rectangle(r.x, r.y, step, step);
372        Point p = new Point(r.x, r.y);
373        for(int x = r.x; x < r.x + r.width; x+=step) {
374          p.x = x;
375          rr.x = x;
376          for(int y = r.y; y < r.y + r.height; y+=step) {
377            p.y = y;
378            rr.y = y;
379            double s = model.classify(p);
380            if(s <= -1.0) {
381              g2.setPaint(negPaint);
382            } else if(s >= +1.0) {
383              g2.setPaint(posPaint);
384            } else {
385              g2.setPaint(Color.white);
386            }
387            g2.fill(rr);
388          }
389        }
390      }
391
392      Set supportVectors = Collections.EMPTY_SET;
393      if(model != null) {
394        supportVectors = model.items();
395      }
396      for(Iterator it = target.items().iterator(); it.hasNext(); i++) {
397        Point2D p = (Point2D) it.next();
398        at.setToTranslation(p.getX(), p.getY());
399        Shape glyph;
400        if(target.getTarget(p) > 0) {
401          glyph = getPosShape();
402        } else {
403          glyph = getNegShape();
404        }
405        Shape s = at.createTransformedShape(glyph);
406        if(supportVectors.contains(p)) {
407          g2.setPaint(svPaint);
408        } else {
409          g2.setPaint(plainPaint);
410        }
411        g2.draw(s);
412      }
413    }
414  }
415}