/*
 *                    BioJava development code
 *
 * This code may be freely distributed and modified under the
 * terms of the GNU Lesser General Public Licence.  This should
 * be distributed with the code.  If you do not have a copy,
 * see:
 *
 *      http://www.gnu.org/copyleft/lesser.html
 *
 * Copyright for this code is held jointly by the individual
 * authors.  These should be listed in @author doc comments.
 *
 * For more information on the BioJava project and its aims,
 * or to join the biojava-l mailing list, visit the home page
 * at:
 *
 *      http://www.biojava.org/
 *
 */


package org.biojava.stats.svm;

import java.util.Iterator;
import java.util.Set;

/**
 * Train a support vector machine using the Sequential Minimal
 * Optimization (SMO) algorithm. See Platt's chapter in the
 * "Advances in Kernel Methods" book.
 *
 * <p>The structure follows Platt's pseudo-code. The outer loop in
 * {@link #trainModel} alternates between passes over every example and
 * passes over the non-bound examples only. {@code examineExample} looks
 * for a partner for an example that violates the KKT conditions.
 * {@code takeStep} jointly optimizes the Lagrange multipliers (alphas)
 * of a pair of examples.</p>
 *
 * @author Thomas Down
 * @author Matthew Pocock
 */
public class SMOTrainer {
  private double _C = 1000;
  private double _epsilon = 0.000001;

  /**
   * Set the soft-margin parameter C, the upper bound on every alpha.
   * The value is read when {@link #trainModel} starts a training run.
   *
   * @param C the new upper bound for the alphas
   */
  public void setC(double C) {
    this._C = C;
  }

  /**
   * @return the soft-margin parameter C (upper bound on every alpha)
   */
  public double getC() {
    return _C;
  }

  /**
   * Set the numerical tolerance. It is used both as the KKT-violation
   * tolerance when choosing examples and as the minimum relative progress
   * a step must make. The value is read when {@link #trainModel} starts
   * a training run.
   *
   * @param epsilon the new tolerance
   */
  public void setEpsilon(double epsilon) {
    this._epsilon = epsilon;
  }

  /**
   * @return the numerical tolerance used during training
   */
  public double getEpsilon() {
    return _epsilon;
  }

  /**
   * Attempt one SMO step on the pair of examples (i1, i2).
   *
   * <p>The new alpha2 is moved towards the unconstrained optimum and then
   * clipped to the feasible segment [L, H]. alpha1 is derived so that
   * y1*alpha1 + y2*alpha2 stays unchanged (assuming +1/-1 targets).
   * Finally the threshold and the error cache are updated.</p>
   *
   * @param trainingContext the state of the current training run
   * @param i1 index of the first example
   * @param i2 index of the second example
   * @return true if the alphas were changed; false if the pair was skipped
   *     (identical indices, empty feasible segment, non-negative eta, or
   *     negligible progress)
   */
  private boolean takeStep(SMOTrainingContext trainingContext, int i1, int i2) {
    if (i1 == i2) {
      return false;
    }

    double y1 = trainingContext.getTarget(i1);
    double y2 = trainingContext.getTarget(i2);
    double alpha1 = trainingContext.getAlpha(i1);
    double alpha2 = trainingContext.getAlpha(i2);
    double E1 = trainingContext.getError(i1);
    double E2 = trainingContext.getError(i2);
    double s = y1 * y2;
    double C = trainingContext.getC();
    double epsilon = trainingContext.getEpsilon();

    // Ends of the segment that the new alpha2 may occupy while keeping
    // both alphas inside [0, C].
    double L;
    double H;
    if (y2 != y1) {
      // Targets in opposite directions.
      L = Math.max(0, alpha2 - alpha1);
      H = Math.min(C, C + alpha2 - alpha1);
    } else {
      // Equal targets.
      L = Math.max(0, alpha1 + alpha2 - C);
      H = Math.min(C, alpha1 + alpha2);
    }
    if (L == H) {
      return false;
    }

    double k11 = trainingContext.getKernelValue(i1, i1);
    double k12 = trainingContext.getKernelValue(i1, i2);
    double k22 = trainingContext.getKernelValue(i2, i2);
    // Platt's eta (2*k12 - k11 - k22). Only the eta < 0 case is handled.
    // The non-negative case (evaluating the objective at both ends of the
    // segment) is not implemented, so such pairs are skipped. Written as
    // !(eta < 0) so that a NaN eta is skipped as well.
    double eta = 2 * k12 - k11 - k22;
    if (!(eta < 0)) {
      return false;
    }

    double a2 = alpha2 - y2 * (E1 - E2) / eta;
    if (a2 < L) {
      a2 = L;
    } else if (a2 > H) {
      a2 = H;
    }

    double a1 = alpha1 + s * (alpha2 - a2);
    // Give up if the change is negligible relative to the alphas' size.
    if (Math.abs(a1 - alpha1) < epsilon * (a1 + alpha1 + 1 + epsilon)) {
      return false;
    }

    // New threshold. Each candidate is exact when its alpha ends up strictly
    // inside (0, C). When both alphas are at a bound, use the midpoint.
    double t1 = y1 * (a1 - alpha1);
    double t2 = y2 * (a2 - alpha2);
    double bOld = trainingContext.getThreshold();
    double b1 = E1 + t1 * k11 + t2 * k12 + bOld;
    double b2 = E2 + t1 * k12 + t2 * k22 + bOld;
    double b;
    if (0 < a1 && a1 < C) {
      b = b1;
    } else if (0 < a2 && a2 < C) {
      b = b2;
    } else {
      b = (b1 + b2) / 2.0;
    }

    trainingContext.setThreshold(b);
    trainingContext.setAlpha(i1, a1);
    trainingContext.setAlpha(i2, a2);

    // Error cache: the two optimized examples are recomputed from the model.
    // Other non-bound examples are updated incrementally. Bound examples are
    // recomputed lazily by getError.
    trainingContext.resetError(i1);
    trainingContext.resetError(i2);

    for (int l = 0; l < trainingContext.size(); ++l) {
      if (l == i1 || l == i2) {
        continue;
      }
      if (!trainingContext.isBound(trainingContext.getAlpha(l))) {
        trainingContext.updateError(
            l,
            t1 * trainingContext.getKernelValue(i1, l)
                + t2 * trainingContext.getKernelValue(i2, l)
                + bOld - b);
      }
    }

    return true;
  }

  /**
   * Check example i2. If it violates the KKT conditions by more than epsilon,
   * try to find a partner i1 and take a step on the pair.
   *
   * <p>Partners are tried in this order:</p>
   * <ol>
   *   <li>the non-bound example with the largest |E1 - E2|;</li>
   *   <li>every non-bound example;</li>
   *   <li>every bound example.</li>
   * </ol>
   * <p>The last two scans start at a random offset.</p>
   *
   * @param trainingContext the state of the current training run
   * @param i2 index of the example to examine
   * @return 1 if a step was taken, 0 otherwise
   */
  private int examineExample(SMOTrainingContext trainingContext, int i2) {
    double y2 = trainingContext.getTarget(i2);
    double alpha2 = trainingContext.getAlpha(i2);
    double E2 = trainingContext.getError(i2);
    double r2 = E2 * y2;
    double epsilon = trainingContext.getEpsilon();
    double C = trainingContext.getC();

    boolean violatesKkt =
        (r2 < -epsilon && alpha2 < C) || (r2 > epsilon && alpha2 > 0);
    if (!violatesKkt) {
      return 0;
    }

    int size = trainingContext.size();

    // Second-choice heuristic: maximize the step size |E1 - E2|.
    int secondChoice = -1;
    double step = 0.0;
    for (int l = 0; l < size; ++l) {
      if (!trainingContext.isBound(trainingContext.getAlpha(l))) {
        double thisStep = Math.abs(trainingContext.getError(l) - E2);
        if (thisStep > step) {
          step = thisStep;
          secondChoice = l;
        }
      }
    }
    if (secondChoice >= 0 && takeStep(trainingContext, secondChoice, i2)) {
      return 1;
    }

    // Fall back to scanning the non-bound examples from a random offset.
    int randomStart = (int) Math.floor(Math.random() * size);
    for (int l = 0; l < size; ++l) {
      int i1 = (l + randomStart) % size;
      if (!trainingContext.isBound(trainingContext.getAlpha(i1))
          && takeStep(trainingContext, i1, i2)) {
        return 1;
      }
    }

    // Finally scan the bound examples. The non-bound ones were already tried.
    for (int l = 0; l < size; ++l) {
      int i1 = (l + randomStart) % size;
      if (trainingContext.isBound(trainingContext.getAlpha(i1))
          && takeStep(trainingContext, i1, i2)) {
        return 1;
      }
    }
    return 0;
  }

  /**
   * Train a support vector classifier.
   *
   * <p>Passes over every example alternate with passes over only the
   * non-bound examples. Training stops when a full pass changes nothing.</p>
   *
   * @param target the training items and their target values
   * @param kernel the kernel used to compare items
   * @param l listener notified after every pass and once at the end;
   *     may be null
   * @return the trained model; items whose alpha is exactly zero at the end
   *     are removed from it
   */
  public SVMClassifierModel
  trainModel(SVMTarget target, SVMKernel kernel, TrainingListener l) {
    SMOTrainingContext trainingContext = new SMOTrainingContext(target, kernel, l);

    int numChanged = 0;
    boolean examineAll = true;

    while (numChanged > 0 || examineAll) {
      numChanged = 0;
      for (int i = 0; i < trainingContext.size(); i++) {
        if (examineAll || !trainingContext.isBound(trainingContext.getAlpha(i))) {
          numChanged += examineExample(trainingContext, i);
        }
      }

      // After a full pass, switch to non-bound passes. Return to a full
      // pass only when a non-bound pass made no progress.
      if (examineAll) {
        examineAll = false;
      } else {
        examineAll = (numChanged == 0);
      }

      trainingContext.trainingCycleCompleted();
    }
    trainingContext.trainingCompleted();

    return trainingContext.getModel();
  }


  /**
   * Mutable state of a single training run. It holds the alphas, targets
   * and the error cache, and mirrors each alpha * target into the model
   * being built. C and epsilon are copied from the enclosing trainer when
   * the context is created.
   */
  final class SMOTrainingContext implements TrainingContext {
    private final double C;
    private final double epsilon;
    private final TrainingListener listener;
    private int cycle = 0;
    private final TrainingEvent ourEvent;
    private final SVMTarget target;
    private final SVMClassifierModel model;

    private final Object [] items;
    private final double [] alphas;
    private final double [] targets;
    // Cached error (model output - target) per item. Kept current for
    // non-bound items; recomputed on demand for bound ones.
    private final double [] E;

    /**
     * @return true if alpha is at (or beyond) either end of [0, C]
     */
    private boolean isBound(double alpha) {
      return (alpha <= 0 || alpha >= getC());
    }

    public int size() {
      return items.length;
    }

    public int getCurrentCycle() {
      return cycle;
    }

    /** Advance the cycle counter and notify the listener, if any. */
    public void trainingCycleCompleted() {
      cycle++;
      if (listener != null) {
        listener.trainingCycleComplete(ourEvent);
      }
    }

    /**
     * Remove items whose alpha is exactly zero from the model, then
     * notify the listener, if any.
     */
    public void trainingCompleted() {
      for (int i = 0; i < size(); i++) {
        if (getAlpha(i) == 0) {
          model.removeItem(getItem(i));
        }
      }
      if (listener != null) {
        listener.trainingComplete(ourEvent);
      }
    }

    public Object getItem(int i) {
      return items[i];
    }

    public double getAlpha(int i) {
      return alphas[i];
    }

    /** Set alpha i and store alpha * target in the model for that item. */
    public void setAlpha(int i, double a) {
      alphas[i] = a;
      model.setAlpha(getItem(i), getAlpha(i) * getTarget(i));
    }

    public double getTarget(int i) {
      return targets[i];
    }

    public double getC() {
      return C;
    }

    public double getEpsilon() {
      return epsilon;
    }

    /** @return the training target this context was created for */
    public SVMTarget getTarget() {
      return target;
    }

    public SVMClassifierModel getModel() {
      return model;
    }

    public void setThreshold(double t) {
      model.setThreshold(t);
    }

    public double getThreshold() {
      return model.getThreshold();
    }

    /**
     * Error (model output minus target) for item i. For bound alphas the
     * value is recomputed from the model and re-cached. Otherwise the
     * cached value maintained during training is returned.
     */
    public double getError(int i) {
      double alpha = getAlpha(i);
      if (isBound(alpha)) {
        return E[i] = getModel().classify(getItem(i)) - getTarget(i);
      }
      return E[i];
    }

    /** Add delta to the cached error of item i. */
    public void updateError(int i, double delta) {
      E[i] += delta;
    }

    /** Recompute the cached error of item i from the model. */
    public void resetError(int i) {
      E[i] = getModel().classify(getItem(i)) - getTarget(i);
    }

    public double getKernelValue(int i1, int i2) {
      return getModel().getKernel().evaluate(getItem(i1), getItem(i2));
    }

    public SMOTrainingContext(SVMTarget target, SVMKernel kernel, TrainingListener l) {
      this.target = target;
      C = SMOTrainer.this.getC();
      epsilon = SMOTrainer.this.getEpsilon();
      model = new SimpleSVMClassifierModel(kernel, target);
      model.setThreshold(0.0);
      listener = l;
      ourEvent = new TrainingEvent(this);
      cycle = 0;

      Set itemSet = target.items();
      int size = itemSet.size();
      items = new Object[size];
      alphas = new double[size];
      targets = new double[size];
      E = new double[size];
      Iterator itemI = itemSet.iterator();
      for (int i = 0; itemI.hasNext(); ++i) {
        Object item = itemI.next();
        items[i] = item;
        targets[i] = target.getTarget(item);
        // The model stores alpha * target; recover the bare alpha.
        alphas[i] = model.getAlpha(item) / targets[i];
        // Initial error, assuming a model output of zero for every item.
        E[i] = - targets[i];
      }
    }
  }
}