001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022package org.biojava.stats.svm;
023
024/**
025 * Train a regression support vector machine using the Sequential Minimal
026 * Optimization algorithm.  See "A Tutorial on Support Vector Regression"
027 * by Smola and Scholkopf.
028 *
029 * <p>
030 * WARNING: This doesn't work right now -- use and fix (both
031 * at own risk...)
032 *
033 * @author Matthew Pocock
034 * @author Thomas Down
035 */
036
037public class SMORegressionTrainer {
038  private double C = 1000;
039  private double epsilon = 0.000001;
040    
041  // Working variables for the trainer: protected by the
042  // synchronization on trainModel.
043
044  private SVMRegressionModel model;
045  private double[] target;
046  private double[] E;
047
048  public void setC(double C) {
049    this.C = C;
050  }
051
052  public void setEpsilon(double epsilon) {
053    this.epsilon = epsilon;
054  }
055
056  private boolean takeStep(int i1, int i2) {
057    //System.out.print("+");
058
059    if (i1 == i2) {
060            return false;
061    }
062    
063    double y1 = target[i1];
064    double y2 = target[i2];
065    double alpha1 = model.getAlpha(i1);
066    double alpha2 = model.getAlpha(i2);
067    double alpha1star = model.getAlphaStar(i1);
068    double alpha2star = model.getAlphaStar(i2);
069    double phi1 = getError(i1);
070    double phi2 = getError(i2);
071
072    System.out.println("y1=" + y1 + "\ty2=" + y2);
073    System.out.println("alpha1=" + alpha1 + "\talpha2=" + alpha2);
074    System.out.println("alpha1star=" + alpha1star + "\talpha2star=" + alpha2star);
075    System.out.println("phi1=" + phi1 + "\tphi2=" + phi2);
076    
077    double k11 = model.getKernelValue(i1, i1);
078    double k12 = model.getKernelValue(i1, i2);
079    double k22 = model.getKernelValue(i2, i2);
080    double eta = k11 + k22 - 2.0 * k12; // from improvement
081    // double eta = 2.0 * k12 - k11 - k22; // from tutorial, but always gives negative eta
082
083    System.out.println("k11=" + k11 + "\tk12=" + k12 + "\tk22=" + k22);
084    System.out.println("eta=" + eta);
085    
086    boolean case1, case2, case3, case4, finnished, changed;
087    double deltaphi = phi2 - phi1;
088    case1 = case2 = case3 = case4 = finnished = changed = false;
089
090    System.out.println("deltaphi=" + deltaphi);
091    
092    double L, H;
093    double a1 = 0.0, a2 = 0.0;
094    double gamma = alpha1 - alpha1star + alpha2 - alpha2star;
095    System.out.println("gamma=" + gamma);
096    System.out.println("epsilon=" + epsilon);
097    
098    if(eta <= 0) {
099      System.out.println("Negative eta");
100      return false;
101    }
102    while (!finnished) {
103      if(!case1 &&
104      (alpha1 > 0.0 || (alpha1star == 0.0 && deltaphi > 0.0)) &&
105      (alpha2 > 0.0 || (alpha2star == 0.0 && deltaphi < 0.0))
106      ) {
107        L = Math.max(0.0, gamma - C);
108        H = Math.min(C, gamma);
109        System.out.println("L=" + L + "\tH=" + H);
110        if (L < H) {
111          a2 = alpha2 - deltaphi / eta;
112          System.out.println("Ideal a2 = " + a2);
113          a2 = Math.min(a2, H);
114          a2 = Math.max(L, a2);
115          a1 = alpha1 - (a2 - alpha2);
116          System.out.println("a1=" + a1 + ", a2=" + a2);
117          // updatae alpah1, alpha2 if change greater than some eps
118          if(Math.abs(a2 - alpha2) > epsilon * (a2 + alpha2 + 1.0 + epsilon)) {
119            model.setAlpha(i1, a1);
120            model.setAlpha(i2, a2);
121            alpha1 = a1;
122            alpha2 = a2;
123            System.out.println("case1 worked");
124            changed = true;
125          } else {
126            System.out.println("case1: change too small: " + (a2 - alpha2));
127          }
128        } else {
129          System.out.println("case1: L > H");
130          finnished = true;
131        }
132        case1 = true;
133      } else if (!case2 &&
134      (alpha1     > 0.0 || (alpha1star == 0.0 && deltaphi > 2.0 * epsilon)) &&
135      (alpha2star > 0.0 || (alpha2     == 0.0 && deltaphi > 2.0 * epsilon))
136      ) {
137        L = Math.max(0.0, gamma);
138        H = Math.min(C, C + gamma);
139        System.out.println("L=" + L + "\tH=" + H);
140        if(L < H) {
141          a2 = alpha2star + (deltaphi - 2.0 * epsilon) / eta;
142          System.out.println("Ideal a2 = " + a2);
143          a2 = Math.min(a2, H);
144          a2 = Math.max(L, a2);
145          a1 = alpha1 + (a2 - alpha2star);
146          System.out.println("a1=" + a1 + ", a2=" + a2);
147          // updatae alpah1, alpha2star if change greater than some eps
148          if(Math.abs(a2 - alpha2star) > epsilon * (a2 + alpha2star + 1.0 + epsilon)) {
149            model.setAlpha(i1, a1);
150            model.setAlphaStar(i2, a2);
151            alpha1     = a1;
152            alpha2star = a2;
153            System.out.println("case2 worked");
154            changed = true;
155          } else {
156            System.out.println("case2: change too small: " + (a2 - alpha2star));
157          }
158        } else {
159          System.out.println("case2: L > H");
160          finnished = true;
161        }
162        case2 = true;
163      } else if (!case3 &&
164      (alpha1star > 0.0 || (alpha1     == 0.0 && deltaphi < 2.0 * epsilon)) &&
165      (alpha2     > 0.0 || (alpha2star == 0.0 && deltaphi < 2.0 * epsilon))
166      ) {
167        L = Math.max(0.0, -gamma);
168        H = Math.min(C, -gamma + C);
169        System.out.println("L=" + L + "\tH=" + H);
170        if(L < H) {
171          a2 = alpha2 - (deltaphi + 2.0 * epsilon) / eta; // according to improvement
172          //a2 = alpha2 - (deltaphi - 2.0 * epsilon) / eta; // according to tutorial
173          System.out.println("Ideal a2 = " + a2);
174          a2 = Math.min(a2, H);
175          a2 = Math.max(L, a2);
176          a1 = alpha1star + (a2 - alpha2); // according to improvement
177          //a1 = alpha1star - (a2 - alpha2); // according to tutorial
178          System.out.println("a1=" + a1 + ", a2=" + a2);
179          // update alpha1star, alpha2 if change is greater than some eps
180          if(Math.abs(a2 - alpha2) > epsilon * (a2 + alpha2 + 1.0 + epsilon)) {
181            model.setAlphaStar(i1, a1);
182            model.setAlpha(i2, a2);
183            alpha1star = a1;
184            alpha2     = a2;
185            System.out.println("case3 worked");
186            changed = true;
187          } else {
188            System.out.println("case3: change too small: " + (a2 - alpha2));
189          }
190        } else {
191          System.out.println("case3: L > H");
192          finnished = true;
193        }
194        case3 = true;
195      } else if(!case4 &&
196      (alpha1star > 0.0 || (alpha1 == 0 && deltaphi < 0.0)) &&
197      (alpha2star > 0.0 || (alpha2 == 0 && deltaphi > 0.0))
198      ) {
199        L = Math.max(0.0, -gamma - C);
200        H = Math.min(C, -gamma);
201        System.out.println("L=" + L + "\tH=" + H);
202        if(L < H) {
203          a2 = alpha2star + deltaphi/eta;
204          System.out.println("Ideal a2 = " + a2);
205          a2 = Math.min(a2, H);
206          a2 = Math.max(L, a2);
207          a1 = alpha1star - (a2 - alpha2star);
208          System.out.println("a1=" + a1 + ", a2=" + a2);
209          // update alpha1star, alpha2star if change is larger than some eps
210          if(Math.abs(a2 - alpha2star) > epsilon * (a2 + alpha2star + 1.0 + epsilon)) {
211            model.setAlphaStar(i1, a1);
212            model.setAlphaStar(i2, a2);
213            alpha1star = a1;
214            alpha2star = a2;
215            System.out.println("case4 worked");
216            changed = true;
217          } else {
218            System.out.println("case4: change too small: " + (a2 - alpha2star));
219          }
220        } else {
221          System.out.println("case4: L > H");
222          finnished = true;
223        }
224        case4 = true;
225      } else {
226        finnished = true;
227      }
228      System.out.println("!!!Errors: " + getError(i1) + "    " + getError(i2));
229
230      // update deltaphi
231      /*
232
233      deltaphi = phi2 - phi1 + eta * ( (alpha1 - alpha1star)
234                                     + (alpha1old - alpha1starold) );
235
236      */
237      deltaphi = getError(i2) - getError(i1);
238      System.out.println("deltaphi=" + deltaphi);
239    }
240    
241
242    // Calculate new threshold
243        double b;
244    double bOld = model.getThreshold();
245    System.out.println("b was " + bOld);
246
247    /*
248
249    if(!isBound(alpha1)) {
250      b = phi1 + y1*(alpha1 - alpha1old)*k11 + y2*(alpha2 - alpha2old)*k12 + bOld;
251    } else if(!isBound(alpha2)) {
252      b = phi2 + y1*(alpha1 - alpha1old)*k12 + y2*(alpha2 - alpha2old)*k22 + bOld;
253    } else if(!isBound(alpha1star)) {
254      b = phi1 + y1*(alpha1star - alpha1starold)*k11 + y2*(alpha2star - alpha2starold)*k12 + bOld;
255    } else if(!isBound(alpha2star)) {
256      b = phi2 + y1*(alpha1star - alpha1starold)*k12 + y2*(alpha2star - alpha2starold)*k22 + bOld;
257    } else {
258      // no suitable alpha found to infer b - take middle of allowed interval
259      double b1 = phi1 + y1*(alpha1 - alpha1old)*k11 + y2*(alpha2 - alpha2old)*k12 + bOld;
260      double b2 = phi2 + y1*(alpha1 - alpha1old)*k12 + y2*(alpha2 - alpha2old)*k22 + bOld;
261      b = (b1 + b2) / 2.0;
262    }
263    model.setThreshold(b);
264    System.out.println("b is " + b);
265
266    */
267
268    // double bOld = model.getThreshold();
269    if (!isBound(alpha1)) {
270        b = y1 - model.internalClassify(i1) + bOld - epsilon;
271    } else if (!isBound(alpha1star)) {
272        b = y1 - model.internalClassify(i1) + bOld - epsilon;
273    } else {
274        b = y1 - model.internalClassify(i1) + bOld;
275    }
276    model.setThreshold(b);   
277    
278    // Update error cache
279/*    E[i1] = 0;
280    E[i2] = 0;
281        for (int l = 0; l < E.length; ++l) {
282          if (l==i1 || l==i2) {
283                  continue;
284      }
285          if (!(isBound(model.getAlpha(l)))) {
286        E[l] += y1*(a1-alpha1)*model.getKernelValue(i1, l) + y2*(a2-alpha2)*model.getKernelValue(i2, l) + bOld - b;
287          }
288        }*/
289
290    if(changed) {
291      System.out.println("Successfuly changed things");
292      return true;
293    } else {
294      System.out.println("Nothing changed");
295      return false;
296    }
297  }
298  
299  private int examineExample(int i2) {
300    double alpha2 = model.getAlpha(i2);
301    double alpha2star = model.getAlphaStar(i2);
302    double phi2 = getError(i2);
303
304    /*
305
306    if (Math.abs(phi2) < epsilon)
307        return 0;
308
309    */
310
311    if ( (+phi2 < epsilon && alpha2star < C)   ||
312         (+phi2 < epsilon && alpha2star > 0.0) ||
313         (-phi2 > epsilon && alpha2 < C)       ||
314         (-phi2 > epsilon && alpha2 > 0.0))
315     {
316
317         /*
318
319            int secondChoice = -1;
320            double step = 0.0;
321            for (int l = 0; l < model.size(); ++l) {
322        if (!isBound(model.getAlpha(l))) {
323          double thisStep = Math.abs(getError(l) - phi2);
324          if (thisStep > step) {
325            step = thisStep;
326            secondChoice = l;
327          }
328        }
329            }
330
331            if (secondChoice >= 0) {
332        System.out.println("Using secondChoice heuristic for " + secondChoice + ", " + i2);
333        if (takeStep(secondChoice, i2)) {
334          return 1;
335        }
336            }
337
338            int randomStart = (int) Math.floor(Math.random() * model.size());
339            for (int l = 0; l < model.size(); ++l) {
340        int i1 = (l + randomStart) % model.size();
341        if (!isBound(model.getAlpha(i1))) {
342          System.out.println("Using unbound huristic for " + i1 + ", " + i2);
343                    if (takeStep(i1, i2)) {
344            return 1;
345          }
346            }
347      }
348      
349         */
350      
351         int randomStart = (int) Math.floor(Math.random() * model.size());
352
353            // The second pass should look at ALL alphas, but
354            // we've already checked the non-bound ones.
355            for (int l = 0; l < model.size(); ++l) {
356        int i1 = (l + randomStart) % model.size();
357        // if (isBound(model.getAlpha(i1))) {
358          System.out.println("Using bound huristic for " + i1 + ", " + i2);
359          if (takeStep(i1, i2)) {
360            return 1;
361          }
362          // }
363            }
364    }
365          return 0;
366  }
367
368  private boolean isBound(double alpha) {
369    return (alpha <= 0 || alpha >= C);
370  }
371
372  private double getError(int i) {
373//    if (E[i] == Double.NEGATIVE_INFINITY || 
374//          isBound(model.getAlpha(i))) {
375            E[i] = model.internalClassify(i) - target[i];
376      System.out.println("Calculated error: " + E[i]);
377//    }
378    return E[i];
379  }
380
381  public synchronized void trainModel(SVMRegressionModel m, double[] t) {
382    model = m;
383    target = t;
384
385    E = new double[model.size()];
386    for (int i = 0; i < t.length; ++i) {
387            E[i] = Double.NEGATIVE_INFINITY;
388    }
389
390    int numChanged = 0;
391    boolean examineAll = true;
392    //int SigFig = -100;
393
394    while (numChanged > 0 || examineAll /*|| SigFig < 3*/) {
395        System.out.print(".");
396
397        numChanged = 0;
398        if (examineAll) {
399            System.out.println("Running full iteration");
400            for (int i = 0; i < model.size(); ++i) {
401                numChanged += examineExample(i);
402            } 
403        } else {
404            System.out.println("Running non-bounds iteration");
405            for (int i = 0; i < model.size(); ++i) {
406                double alpha = model.getAlpha(i);
407                if (!isBound(alpha)) {
408                    numChanged += examineExample(i);
409                }
410            }
411        }
412        
413        if (examineAll) {
414            examineAll = false;
415        } else {
416            examineAll = (numChanged == 0);
417        }
418    }
419
420    E = null;
421  }
422}