001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023package org.biojava.stats.svm;
024
025import java.util.Iterator;
026import java.util.Set;
027
028/**
029 * Train a support vector machine using the Sequential Minimal
030 * Optimization algorithm.  See Kernel Methods book.
031 *
032 * @author Thomas Down
033 * @author Matthew Pocock
034 */
035public class SMOTrainer {
036  private double _C = 1000;
037  private double _epsilon = 0.000001;
038
039  public void setC(double C) {
040    this._C = C;
041  }
042    
043  public double getC() {
044    return _C;
045  }
046
047  public void setEpsilon(double epsilon) {
048    this._epsilon = epsilon;
049  }
050    
051  public double getEpsilon() {
052    return _epsilon;
053  }
054
055  private boolean takeStep(SMOTrainingContext trainingContext, int i1, int i2) {
056    // //System.out.print("+");
057
058    if (i1 == i2) {
059            return false;
060    }
061    
062    double y1 = trainingContext.getTarget(i1);
063    double y2 = trainingContext.getTarget(i2);
064    double alpha1 = trainingContext.getAlpha(i1);
065    double alpha2 = trainingContext.getAlpha(i2);
066    double E1 = trainingContext.getError(i1);
067    double E2 = trainingContext.getError(i2);
068    double s = y1 * y2;
069    double C = trainingContext.getC();
070    double epsilon = trainingContext.getEpsilon();
071        
072        double L, H;
073    if (y2 != y1) /* preferred (s<0) */ {
074            // targets in opposite directions
075            L = Math.max(0, alpha2 - alpha1);
076            H = Math.min(C, C + alpha2 - alpha1);
077    } else {
078            // Equal targets.
079            L = Math.max(0, alpha1 + alpha2 - C);
080            H = Math.min(C, alpha1 + alpha2);
081    }
082    if (L == H) {
083      ////System.out.print("h");
084      return false;
085    }
086
087    double k11 = trainingContext.getKernelValue(i1, i1);
088    double k12 = trainingContext.getKernelValue(i1, i2);
089    double k22 = trainingContext.getKernelValue(i2, i2);
090    double eta = 2 * k12 - k11 - k22;
091        
092        double a1 = 0, a2 = 0;
093    if (eta > 0 && eta < epsilon) {
094      eta = 0.0;
095    }
096  
097        if (eta < 0) {
098            a2 = alpha2 - y2 * (E1 - E2) / eta;
099            if (a2 < L) {
100        a2 = L;
101      } else if (a2 > H) {
102        a2 = H;
103      }
104    } else {
105      //System.out.println("Positive eta!");
106
107      /*
108
109      double gamma = alpha1 + s*alpha2;
110      double v1 = model.classify(model.getVector(i1)) + model.getThreshold() - y1*alpha1*k11 - y2*alpha2*k12;
111      double v2 = model.classify(model.getVector(i2)) + model.getThreshold() - y1*alpha1*k12 - y2*alpha2*k22;
112
113      double Lobj = gamma - s * L + L - 0.5*k11*Math.pow(gamma - s*L,2) - 0.5*k22*Math.pow(L,2) - s*k12*(gamma-s*L)*L-y1*(gamma-s*L) - y1*(gamma - s*L)*v1 - y2*L*v2; 
114      double Hobj = gamma - s * H + H - 0.5*k11*Math.pow(gamma - s*H,2) - 0.5*k22*Math.pow(H,2) - s*k12*(gamma-s*H)*H-y1*(gamma-s*H) - y1*(gamma - s*H)*v1 - y2*H*v2;
115      if (Lobj > Hobj+epsilon) {
116        a2 = L;
117      } else if (Lobj < Hobj-epsilon) {
118        a2 = H;
119      } else {
120        a2 = alpha2;
121      }
122      */
123      ////System.out.print("+");
124      return false;
125    }
126        
127        a1 = alpha1 + s*(alpha2 - a2);
128    if (Math.abs(a1 - alpha1) < epsilon * (a1 + alpha1+1 +epsilon)) {
129      //    //System.out.print("s");
130      return false;
131    }
132
133    // Calculate new threshold
134        
135        double b;
136    double bOLD = trainingContext.getThreshold();
137
138        if (0 < a1 && a1 < C) {
139      // use "b1 formula"
140            // //System.out.println("b1");
141          b = E1 + y1*(a1 - alpha1)*k11 + y2*(a2 - alpha2)*k12 + bOLD;
142        } else if (0 < a2 && a2 < C) {
143          // use "b2 formula"
144          b = E2 + y1*(a1 - alpha1)*k12 + y2*(a2 - alpha2)*k22 + bOLD;
145            // //System.out.println("b2");
146        } else {
147            // Both are at bounds -- use `half way' method.
148            double b1, b2;
149            b1 = E1 + y1*(a1 - alpha1)*k11 + y2*(a2 - alpha2)*k12 + bOLD;
150            b2 = E2 + y1*(a1 - alpha1)*k12 + y2*(a2 - alpha2)*k22 + bOLD;
151            // //System.out.println("hybrid");
152            b = (b1 + b2) / 2.0;
153    }
154    trainingContext.setThreshold(b);
155    trainingContext.setAlpha(i1, a1);
156    trainingContext.setAlpha(i2, a2);
157
158    // Update error cache
159
160    trainingContext.resetError(i1);
161    trainingContext.resetError(i2);
162
163    for (int l = 0; l < trainingContext.size(); ++l) {
164      if (l==i1 || l==i2) {
165        continue;
166      }
167      if (!trainingContext.isBound(
168        trainingContext.getAlpha(l)
169      )) {
170        trainingContext.updateError(
171          l,
172          y1*(a1-alpha1)*trainingContext.getKernelValue(i1, l) +
173          y2*(a2-alpha2)*trainingContext.getKernelValue(i2, l) +
174          bOLD - b
175        );
176      }
177    }
178
179    return true;
180  }
181
182  private int examineExample(SMOTrainingContext trainingContext, int i2) {
183    double y2 = trainingContext.getTarget(i2);
184    double alpha2 = trainingContext.getAlpha(i2);
185    double E2 = trainingContext.getError(i2);
186    double r2 = E2 * y2;
187    double epsilon = trainingContext.getEpsilon();
188    double C = trainingContext.getC();
189
190    //System.out.println("r2 = " + r2);
191    //System.out.println("alpha2 = " + alpha2);
192    //System.out.println("epsilon = " + epsilon);
193    //System.out.println("C = " + C);
194    if ((r2 < -epsilon && alpha2 < C) || (r2 > epsilon && alpha2 > 0)) {
195            int secondChoice = -1;
196            double step = 0.0;
197      //System.out.println("First choice heuristic");
198            for (int l = 0; l < trainingContext.size(); ++l) {
199        if (!trainingContext.isBound(
200          trainingContext.getAlpha(l)
201        )) {
202          double thisStep = Math.abs(trainingContext.getError(l) - E2);
203          if (thisStep > step) {
204            step = thisStep;
205            secondChoice = l;
206          }
207        }
208            }
209
210            if (secondChoice >= 0) {
211        if (takeStep(trainingContext, secondChoice, i2)) {
212          return 1;
213        }
214            }
215
216      //System.out.println("Unbound");
217          int randomStart = (int) Math.floor(Math.random() * trainingContext.size());
218          for (int l = 0; l < trainingContext.size(); ++l) {
219        int i1 = (l + randomStart) % trainingContext.size();
220        if (!trainingContext.isBound(
221          trainingContext.getAlpha(i1)
222        )) {
223                    if (takeStep(trainingContext, i1, i2)) {
224            return 1;
225          }
226        }
227          }
228            // The second pass should look at ALL alphas, but
229            // we've already checked the non-bound ones.
230      //System.out.println("Bound");
231            for (int l = 0; l < trainingContext.size(); l++) {
232        int i1 = (l + randomStart) % trainingContext.size();
233        if (trainingContext.isBound(
234          trainingContext.getAlpha(i1)
235        )) {
236          if (takeStep(trainingContext, i1, i2)) {
237            return 1;
238          }
239        }
240            }
241    } else {
242      //System.out.print("Nothing to optimize");
243    }
244    return 0;
245  }
246
247  public SVMClassifierModel
248  trainModel(SVMTarget target, SVMKernel kernel, TrainingListener l) {
249    SMOTrainingContext trainingContext = new SMOTrainingContext(target, kernel, l);
250
251    int numChanged = 0;
252    boolean examineAll = true;
253    
254    while (numChanged > 0 || examineAll) {
255      numChanged = 0;
256            if (examineAll) {
257        //System.out.println("Running full iteration");
258        for(int i = 0; i < trainingContext.size(); i++) {
259          //System.out.println("Item " + i);
260          numChanged += examineExample(trainingContext, i);
261        } 
262            } else {
263        //System.out.println("Running non-bounds iteration");
264        for(int i = 0; i < trainingContext.size(); i++) {
265          double alpha = trainingContext.getAlpha(i);
266          if (!trainingContext.isBound(alpha)) {
267            numChanged += examineExample(trainingContext, i);
268          }
269        }
270      }
271      if (examineAll) {
272        examineAll = false;
273      } else {
274        examineAll = (numChanged == 0);
275      }
276            
277          trainingContext.trainingCycleCompleted();
278    }
279    trainingContext.trainingCompleted();
280    
281    return trainingContext.getModel();
282  }
283
284
285  final class SMOTrainingContext implements TrainingContext {
286    private double C;
287    private double epsilon;
288    private TrainingListener listener;
289    private int cycle = 0;
290    private TrainingEvent ourEvent;
291    private SVMTarget target;
292    private SVMClassifierModel model;
293
294    private Object [] items;
295    private double [] alphas;
296    private double [] targets;
297    private double [] E;
298
299    private boolean isBound(double alpha) {
300      return (alpha <= 0 || alpha >= getC());
301    }
302    
303    public int size() {
304      return items.length;
305    }
306    
307    public int getCurrentCycle() {
308      return cycle;
309    }
310    
311    public void trainingCycleCompleted() {
312      cycle++;
313      if(listener != null) {
314        listener.trainingCycleComplete(ourEvent);
315      }
316    }
317
318    public void trainingCompleted() {
319      for(int i = 0; i < size(); i++) {
320        if(getAlpha(i) == 0) {
321          model.removeItem(getItem(i));
322        }
323      }
324      if (listener != null) {
325        listener.trainingComplete(ourEvent);
326      }
327    }
328
329    public Object getItem(int i) {
330      return items[i];
331    }
332    
333    public double getAlpha(int i) {
334      return alphas[i];
335    }
336    
337    public void setAlpha(int i, double a) {
338      alphas[i] = a;
339      model.setAlpha(getItem(i), getAlpha(i) * getTarget(i));
340    }
341    
342    public double getTarget(int i) {
343      return targets[i];
344    }
345    
346    public double getC() {
347      return C;
348    }
349    
350    public double getEpsilon() {
351      return epsilon;
352    }
353    
354    public SVMTarget getTarget() {
355      return target;
356    }
357    
358    public SVMClassifierModel getModel() {
359      return model;
360    }
361    
362    public void setThreshold(double t) {
363      model.setThreshold(t);
364    }
365    
366    public double getThreshold() {
367      return model.getThreshold();
368    }
369    
370    public double getError(int i) {
371      double alpha = getAlpha(i);
372      if (isBound(alpha)) {
373        return E[i] = getModel().classify(getItem(i)) - getTarget(i);
374      }
375      return E[i];
376    }
377   
378    public void updateError(int i, double delta) {
379      E[i] += delta;
380    }
381
382    public void resetError(int i) {
383      E[i] = getModel().classify(getItem(i)) - getTarget(i);
384    }
385
386    public double getKernelValue(int i1, int i2) {
387      return getModel().getKernel().evaluate(getItem(i1), getItem(i2));
388    }
389    
390    public SMOTrainingContext(SVMTarget target, SVMKernel kernel, TrainingListener l) {
391      C = SMOTrainer.this.getC();
392      epsilon = SMOTrainer.this.getEpsilon();
393      model = new SimpleSVMClassifierModel(kernel, target);
394      model.setThreshold(0.0);
395      listener = l;
396      ourEvent = new TrainingEvent(this);
397      cycle = 0;
398      Set itemSet = target.items();
399      int size = itemSet.size();
400      items = new Object[size];
401      alphas = new double[size];
402      targets = new double[size];
403      E = new double[size];
404      Iterator itemI = itemSet.iterator();
405      for (int i = 0; itemI.hasNext(); ++i) {
406        Object item = itemI.next();
407        items[i] = item;
408        targets[i] = target.getTarget(item);
409        alphas[i] = model.getAlpha(item) / targets[i];
410        E[i] = - targets[i];
411      }
412    }
413  }
414}