001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022package org.biojava.stats.svm; 023 024/** 025 * Train a regression support vector machine using the Sequential Minimal 026 * Optimization algorithm. See "A Tutorial on Support Vector Regression" 027 * by Smola and Scholkopf. 028 * 029 * <p> 030 * WARNING: This doesn't work right now -- use and fix (both 031 * at own risk...) 032 * 033 * @author Matthew Pocock 034 * @author Thomas Down 035 */ 036 037public class SMORegressionTrainer { 038 private double C = 1000; 039 private double epsilon = 0.000001; 040 041 // Working variables for the trainer: protected by the 042 // synchronization on trainModel. 043 044 private SVMRegressionModel model; 045 private double[] target; 046 private double[] E; 047 048 public void setC(double C) { 049 this.C = C; 050 } 051 052 public void setEpsilon(double epsilon) { 053 this.epsilon = epsilon; 054 } 055 056 private boolean takeStep(int i1, int i2) { 057 //System.out.print("+"); 058 059 if (i1 == i2) { 060 return false; 061 } 062 063 double y1 = target[i1]; 064 double y2 = target[i2]; 065 double alpha1 = model.getAlpha(i1); 066 double alpha2 = model.getAlpha(i2); 067 double alpha1star = model.getAlphaStar(i1); 068 double alpha2star = model.getAlphaStar(i2); 069 double phi1 = getError(i1); 070 double phi2 = getError(i2); 071 072 System.out.println("y1=" + y1 + "\ty2=" + y2); 073 System.out.println("alpha1=" + alpha1 + "\talpha2=" + alpha2); 074 System.out.println("alpha1star=" + alpha1star + "\talpha2star=" + alpha2star); 075 System.out.println("phi1=" + phi1 + "\tphi2=" + phi2); 076 077 double k11 = model.getKernelValue(i1, i1); 078 double k12 = model.getKernelValue(i1, i2); 079 double k22 = model.getKernelValue(i2, i2); 080 double eta = k11 + k22 - 2.0 * k12; // from improvement 081 // double eta = 2.0 * k12 - k11 - k22; // from tutorial, but always gives negative eta 082 083 System.out.println("k11=" + k11 + "\tk12=" + k12 + "\tk22=" + k22); 084 System.out.println("eta=" + eta); 085 086 boolean case1, case2, case3, case4, finnished, changed; 087 double deltaphi = phi2 - phi1; 088 case1 = case2 = case3 = case4 = finnished = changed = false; 089 090 System.out.println("deltaphi=" + deltaphi); 091 092 double L, H; 093 double a1 = 0.0, a2 = 0.0; 094 double gamma = alpha1 - alpha1star + alpha2 - alpha2star; 095 System.out.println("gamma=" + gamma); 096 System.out.println("epsilon=" + epsilon); 097 098 if(eta <= 0) { 099 System.out.println("Negative eta"); 100 return false; 101 } 102 while (!finnished) { 103 if(!case1 && 104 (alpha1 > 0.0 || (alpha1star == 0.0 && deltaphi > 0.0)) && 105 (alpha2 > 0.0 || (alpha2star == 0.0 && deltaphi < 0.0)) 106 ) { 107 L = Math.max(0.0, gamma - C); 108 H = Math.min(C, gamma); 109 System.out.println("L=" + L + "\tH=" + H); 110 if (L < H) { 111 a2 = alpha2 - deltaphi / eta; 112 System.out.println("Ideal a2 = " + a2); 113 a2 = Math.min(a2, H); 114 a2 = Math.max(L, a2); 115 a1 = alpha1 - (a2 - alpha2); 116 System.out.println("a1=" + a1 + ", a2=" + a2); 117 // updatae alpah1, alpha2 if change greater than some eps 118 if(Math.abs(a2 - alpha2) > epsilon * (a2 + alpha2 + 1.0 + epsilon)) { 119 model.setAlpha(i1, a1); 120 model.setAlpha(i2, a2); 121 alpha1 = a1; 122 alpha2 = a2; 123 System.out.println("case1 worked"); 124 changed = true; 125 } else { 126 System.out.println("case1: change too small: " + (a2 - alpha2)); 127 } 128 } else { 129 System.out.println("case1: L > H"); 130 finnished = true; 131 } 132 case1 = true; 133 } else if (!case2 && 134 (alpha1 > 0.0 || (alpha1star == 0.0 && deltaphi > 2.0 * epsilon)) && 135 (alpha2star > 0.0 || (alpha2 == 0.0 && deltaphi > 2.0 * epsilon)) 136 ) { 137 L = Math.max(0.0, gamma); 138 H = Math.min(C, C + gamma); 139 System.out.println("L=" + L + "\tH=" + H); 140 if(L < H) { 141 a2 = alpha2star + (deltaphi - 2.0 * epsilon) / eta; 142 System.out.println("Ideal a2 = " + a2); 143 a2 = Math.min(a2, H); 144 a2 = Math.max(L, a2); 145 a1 = alpha1 + (a2 - alpha2star); 146 System.out.println("a1=" + a1 + ", a2=" + a2); 147 // updatae alpah1, alpha2star if change greater than some eps 148 if(Math.abs(a2 - alpha2star) > epsilon * (a2 + alpha2star + 1.0 + epsilon)) { 149 model.setAlpha(i1, a1); 150 model.setAlphaStar(i2, a2); 151 alpha1 = a1; 152 alpha2star = a2; 153 System.out.println("case2 worked"); 154 changed = true; 155 } else { 156 System.out.println("case2: change too small: " + (a2 - alpha2star)); 157 } 158 } else { 159 System.out.println("case2: L > H"); 160 finnished = true; 161 } 162 case2 = true; 163 } else if (!case3 && 164 (alpha1star > 0.0 || (alpha1 == 0.0 && deltaphi < 2.0 * epsilon)) && 165 (alpha2 > 0.0 || (alpha2star == 0.0 && deltaphi < 2.0 * epsilon)) 166 ) { 167 L = Math.max(0.0, -gamma); 168 H = Math.min(C, -gamma + C); 169 System.out.println("L=" + L + "\tH=" + H); 170 if(L < H) { 171 a2 = alpha2 - (deltaphi + 2.0 * epsilon) / eta; // according to improvement 172 //a2 = alpha2 - (deltaphi - 2.0 * epsilon) / eta; // according to tutorial 173 System.out.println("Ideal a2 = " + a2); 174 a2 = Math.min(a2, H); 175 a2 = Math.max(L, a2); 176 a1 = alpha1star + (a2 - alpha2); // according to improvement 177 //a1 = alpha1star - (a2 - alpha2); // according to tutorial 178 System.out.println("a1=" + a1 + ", a2=" + a2); 179 // update alpha1star, alpha2 if change is greater than some eps 180 if(Math.abs(a2 - alpha2) > epsilon * (a2 + alpha2 + 1.0 + epsilon)) { 181 model.setAlphaStar(i1, a1); 182 model.setAlpha(i2, a2); 183 alpha1star = a1; 184 alpha2 = a2; 185 System.out.println("case3 worked"); 186 changed = true; 187 } else { 188 System.out.println("case3: change too small: " + (a2 - alpha2)); 189 } 190 } else { 191 System.out.println("case3: L > H"); 192 finnished = true; 193 } 194 case3 = true; 195 } else if(!case4 && 196 (alpha1star > 0.0 || (alpha1 == 0 && deltaphi < 0.0)) && 197 (alpha2star > 0.0 || (alpha2 == 0 && deltaphi > 0.0)) 198 ) { 199 L = Math.max(0.0, -gamma - C); 200 H = Math.min(C, -gamma); 201 System.out.println("L=" + L + "\tH=" + H); 202 if(L < H) { 203 a2 = alpha2star + deltaphi/eta; 204 System.out.println("Ideal a2 = " + a2); 205 a2 = Math.min(a2, H); 206 a2 = Math.max(L, a2); 207 a1 = alpha1star - (a2 - alpha2star); 208 System.out.println("a1=" + a1 + ", a2=" + a2); 209 // update alpha1star, alpha2star if change is larger than some eps 210 if(Math.abs(a2 - alpha2star) > epsilon * (a2 + alpha2star + 1.0 + epsilon)) { 211 model.setAlphaStar(i1, a1); 212 model.setAlphaStar(i2, a2); 213 alpha1star = a1; 214 alpha2star = a2; 215 System.out.println("case4 worked"); 216 changed = true; 217 } else { 218 System.out.println("case4: change too small: " + (a2 - alpha2star)); 219 } 220 } else { 221 System.out.println("case4: L > H"); 222 finnished = true; 223 } 224 case4 = true; 225 } else { 226 finnished = true; 227 } 228 System.out.println("!!!Errors: " + getError(i1) + " " + getError(i2)); 229 230 // update deltaphi 231 /* 232 233 deltaphi = phi2 - phi1 + eta * ( (alpha1 - alpha1star) 234 + (alpha1old - alpha1starold) ); 235 236 */ 237 deltaphi = getError(i2) - getError(i1); 238 System.out.println("deltaphi=" + deltaphi); 239 } 240 241 242 // Calculate new threshold 243 double b; 244 double bOld = model.getThreshold(); 245 System.out.println("b was " + bOld); 246 247 /* 248 249 if(!isBound(alpha1)) { 250 b = phi1 + y1*(alpha1 - alpha1old)*k11 + y2*(alpha2 - alpha2old)*k12 + bOld; 251 } else if(!isBound(alpha2)) { 252 b = phi2 + y1*(alpha1 - alpha1old)*k12 + y2*(alpha2 - alpha2old)*k22 + bOld; 253 } else if(!isBound(alpha1star)) { 254 b = phi1 + y1*(alpha1star - alpha1starold)*k11 + y2*(alpha2star - alpha2starold)*k12 + bOld; 255 } else if(!isBound(alpha2star)) { 256 b = phi2 + y1*(alpha1star - alpha1starold)*k12 + y2*(alpha2star - alpha2starold)*k22 + bOld; 257 } else { 258 // no suitable alpha found to infer b - take middle of allowed interval 259 double b1 = phi1 + y1*(alpha1 - alpha1old)*k11 + y2*(alpha2 - alpha2old)*k12 + bOld; 260 double b2 = phi2 + y1*(alpha1 - alpha1old)*k12 + y2*(alpha2 - alpha2old)*k22 + bOld; 261 b = (b1 + b2) / 2.0; 262 } 263 model.setThreshold(b); 264 System.out.println("b is " + b); 265 266 */ 267 268 // double bOld = model.getThreshold(); 269 if (!isBound(alpha1)) { 270 b = y1 - model.internalClassify(i1) + bOld - epsilon; 271 } else if (!isBound(alpha1star)) { 272 b = y1 - model.internalClassify(i1) + bOld - epsilon; 273 } else { 274 b = y1 - model.internalClassify(i1) + bOld; 275 } 276 model.setThreshold(b); 277 278 // Update error cache 279/* E[i1] = 0; 280 E[i2] = 0; 281 for (int l = 0; l < E.length; ++l) { 282 if (l==i1 || l==i2) { 283 continue; 284 } 285 if (!(isBound(model.getAlpha(l)))) { 286 E[l] += y1*(a1-alpha1)*model.getKernelValue(i1, l) + y2*(a2-alpha2)*model.getKernelValue(i2, l) + bOld - b; 287 } 288 }*/ 289 290 if(changed) { 291 System.out.println("Successfuly changed things"); 292 return true; 293 } else { 294 System.out.println("Nothing changed"); 295 return false; 296 } 297 } 298 299 private int examineExample(int i2) { 300 double alpha2 = model.getAlpha(i2); 301 double alpha2star = model.getAlphaStar(i2); 302 double phi2 = getError(i2); 303 304 /* 305 306 if (Math.abs(phi2) < epsilon) 307 return 0; 308 309 */ 310 311 if ( (+phi2 < epsilon && alpha2star < C) || 312 (+phi2 < epsilon && alpha2star > 0.0) || 313 (-phi2 > epsilon && alpha2 < C) || 314 (-phi2 > epsilon && alpha2 > 0.0)) 315 { 316 317 /* 318 319 int secondChoice = -1; 320 double step = 0.0; 321 for (int l = 0; l < model.size(); ++l) { 322 if (!isBound(model.getAlpha(l))) { 323 double thisStep = Math.abs(getError(l) - phi2); 324 if (thisStep > step) { 325 step = thisStep; 326 secondChoice = l; 327 } 328 } 329 } 330 331 if (secondChoice >= 0) { 332 System.out.println("Using secondChoice heuristic for " + secondChoice + ", " + i2); 333 if (takeStep(secondChoice, i2)) { 334 return 1; 335 } 336 } 337 338 int randomStart = (int) Math.floor(Math.random() * model.size()); 339 for (int l = 0; l < model.size(); ++l) { 340 int i1 = (l + randomStart) % model.size(); 341 if (!isBound(model.getAlpha(i1))) { 342 System.out.println("Using unbound huristic for " + i1 + ", " + i2); 343 if (takeStep(i1, i2)) { 344 return 1; 345 } 346 } 347 } 348 349 */ 350 351 int randomStart = (int) Math.floor(Math.random() * model.size()); 352 353 // The second pass should look at ALL alphas, but 354 // we've already checked the non-bound ones. 355 for (int l = 0; l < model.size(); ++l) { 356 int i1 = (l + randomStart) % model.size(); 357 // if (isBound(model.getAlpha(i1))) { 358 System.out.println("Using bound huristic for " + i1 + ", " + i2); 359 if (takeStep(i1, i2)) { 360 return 1; 361 } 362 // } 363 } 364 } 365 return 0; 366 } 367 368 private boolean isBound(double alpha) { 369 return (alpha <= 0 || alpha >= C); 370 } 371 372 private double getError(int i) { 373// if (E[i] == Double.NEGATIVE_INFINITY || 374// isBound(model.getAlpha(i))) { 375 E[i] = model.internalClassify(i) - target[i]; 376 System.out.println("Calculated error: " + E[i]); 377// } 378 return E[i]; 379 } 380 381 public synchronized void trainModel(SVMRegressionModel m, double[] t) { 382 model = m; 383 target = t; 384 385 E = new double[model.size()]; 386 for (int i = 0; i < t.length; ++i) { 387 E[i] = Double.NEGATIVE_INFINITY; 388 } 389 390 int numChanged = 0; 391 boolean examineAll = true; 392 //int SigFig = -100; 393 394 while (numChanged > 0 || examineAll /*|| SigFig < 3*/) { 395 System.out.print("."); 396 397 numChanged = 0; 398 if (examineAll) { 399 System.out.println("Running full iteration"); 400 for (int i = 0; i < model.size(); ++i) { 401 numChanged += examineExample(i); 402 } 403 } else { 404 System.out.println("Running non-bounds iteration"); 405 for (int i = 0; i < model.size(); ++i) { 406 double alpha = model.getAlpha(i); 407 if (!isBound(alpha)) { 408 numChanged += examineExample(i); 409 } 410 } 411 } 412 413 if (examineAll) { 414 examineAll = false; 415 } else { 416 examineAll = (numChanged == 0); 417 } 418 } 419 420 E = null; 421 } 422}