/*
 *                    BioJava development code
 *
 * This code may be freely distributed and modified under the
 * terms of the GNU Lesser General Public Licence.  This should
 * be distributed with the code.  If you do not have a copy,
 * see:
 *
 *      http://www.gnu.org/copyleft/lesser.html
 *
 * Copyright for this code is held jointly by the individual
 * authors.  These should be listed in @author doc comments.
 *
 * For more information on the BioJava project and its aims,
 * or to join the biojava-l mailing list, visit the home page
 * at:
 *
 *      http://www.biojava.org/
 *
 */
/*
 * @(#)SVM_Light.java      0.1 00/01/15
 *
 * By Thomas Down <td2@sanger.ac.uk>
 */

package org.biojava.stats.svm.tools;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.StringTokenizer;

import org.biojava.stats.svm.CachingKernel;
import org.biojava.stats.svm.PolynomialKernel;
import org.biojava.stats.svm.RadialBaseKernel;
import org.biojava.stats.svm.SVMClassifierModel;
import org.biojava.stats.svm.SVMKernel;
import org.biojava.stats.svm.SimpleSVMClassifierModel;
import org.biojava.stats.svm.SparseVector;

/**
 * Static helpers for reading and writing example lines and model files in
 * the textual "SVM-light" layout: a numeric label followed by
 * whitespace-separated <code>dimension:value</code> pairs, with an optional
 * trailing <code>#</code> comment.
 *
 * @author Thomas Down
 * @author Greg Cox
 */

public class SVM_Light {
    /**
     * A sparse vector paired with its numeric label and an optional
     * free-text comment.
     */
    public static class LabelledVector {
        private final SparseVector v;
        private final double label;
        private final String comment;

        /**
         * Creates a labelled vector without a comment.
         *
         * @param v     the vector
         * @param label the numeric label attached to the vector
         */
        public LabelledVector(SparseVector v, double label) {
            this(v, label, null);
        }

        /**
         * Creates a labelled vector with a comment.
         *
         * @param v       the vector
         * @param label   the numeric label attached to the vector
         * @param comment free text associated with the vector; may be
         *                <code>null</code>
         */
        public LabelledVector(SparseVector v, double label, String comment) {
            this.v = v;
            this.label = label;
            this.comment = comment;
        }

        /**
         * @return the vector passed to the constructor
         */
        public SparseVector getVector() {
            return v;
        }

        /**
         * @return the label passed to the constructor
         */
        public double getLabel() {
            return label;
        }

        /**
         * @return the comment, or <code>null</code> if none was supplied
         */
        public String getComment() {
            return comment;
        }
    }

    /**
     * Parses one example line of the form
     * <code>label dim:value dim:value ... # comment</code>.
     * <p>
     * Everything after the first <code>#</code> is returned as the comment
     * and is not parsed. The first remaining token is the label; every
     * following token must be a <code>dimension:value</code> pair.
     *
     * @param ex the line to parse; must not be <code>null</code>
     * @return the parsed vector, label and comment (comment is
     *         <code>null</code> when the line has no <code>#</code>)
     * @throws NumberFormatException if the line holds no label, a pair lacks
     *         the <code>:</code> separator, or a number cannot be parsed
     */
    public static LabelledVector parseExample(String ex)
        throws NumberFormatException
    {
        String comment = null;
        int hashPos = ex.indexOf('#');
        if (hashPos >= 0) {
            comment = ex.substring(hashPos + 1);
            ex = ex.substring(0, hashPos);
        }

        StringTokenizer toke = new StringTokenizer(ex);
        if (!toke.hasMoreTokens()) {
            // Report a missing label through the declared exception type
            // instead of letting nextToken() raise NoSuchElementException.
            throw new NumberFormatException("Empty example");
        }
        double label = Double.parseDouble(toke.nextToken());

        // The tokens left after the label are exactly the dim:value pairs.
        int size = toke.countTokens();
        SparseVector v = new SparseVector(size);
        while (toke.hasMoreTokens()) {
            String dim = toke.nextToken();
            int cut = dim.indexOf(':');
            if (cut < 0) {
                throw new NumberFormatException("Bad dimension " + dim);
            }
            int dnum = Integer.parseInt(dim.substring(0, cut));
            double dval = Double.parseDouble(dim.substring(cut + 1));
            v.put(dnum, dval);
        }

        return new LabelledVector(v, label, comment);
    }

    /**
     * Renders the stored entries of a vector as space-separated
     * <code>dimension:value</code> pairs, in index order.
     *
     * @param v the vector to render
     * @return the textual form; empty when the vector has no entries
     */
    public static String vectorToString(SparseVector v) {
        StringBuffer sb = new StringBuffer();

        for (int i = 0; i < v.size(); ++i) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(v.getDimAtIndex(i));
            sb.append(':');
            sb.append(v.getValueAtIndex(i));
        }
        return sb.toString();
    }

    /**
     * Reads a model file and builds a classifier model from it.
     * <p>
     * The file starts with nine header lines (format, kernel type, the
     * <code>-d</code>, <code>-g</code>, <code>-s</code>, <code>-r</code> and
     * <code>-u</code> kernel parameters, number of support vectors,
     * threshold). Only the kernel type, <code>-d</code>, <code>-g</code> and
     * threshold are used. Every following non-blank line is parsed with
     * {@link #parseExample(String)} and added to the model, using the
     * leading number of the line as the item's alpha.
     *
     * @param fileName path of the model file
     * @return the model described by the file
     * @throws IOException if the file cannot be read, is truncated, names an
     *         unsupported kernel type, or contains unparseable numbers
     */
    public static SVMClassifierModel readModelFile(String fileName)
        throws IOException
    {
        BufferedReader r = new BufferedReader(new FileReader(fileName));
        try {
            readHeaderLine(r); // format
            String kType = firstToken(readHeaderLine(r));
            String dParam = firstToken(readHeaderLine(r));
            String gParam = firstToken(readHeaderLine(r));
            readHeaderLine(r); // s
            readHeaderLine(r); // r
            readHeaderLine(r); // u
            readHeaderLine(r); // numSV
            String threshString = firstToken(readHeaderLine(r));

            SVMKernel kernel = createKernel(kType, dParam, gParam);

            SimpleSVMClassifierModel model = new SimpleSVMClassifierModel(kernel);
            model.setThreshold(Double.parseDouble(threshString));
            String line;
            while ((line = r.readLine()) != null) {
                if (line.trim().length() == 0) {
                    // Tolerate blank lines, e.g. a trailing newline.
                    continue;
                }
                LabelledVector ex = parseExample(line);
                model.addItemAlpha(ex.getVector(), ex.getLabel());
            }

            return model;
        } catch (NumberFormatException ex) {
            // initCause keeps the underlying parse error available to callers.
            IOException ioe = new IOException("Couldn't parse model file");
            ioe.initCause(ex);
            throw ioe;
        } finally {
            // Release the file handle on both success and failure.
            r.close();
        }
    }

    /**
     * Reads one mandatory header line.
     *
     * @throws IOException if the end of the file is reached instead
     */
    private static String readHeaderLine(BufferedReader r) throws IOException {
        String line = r.readLine();
        if (line == null) {
            throw new IOException("Couldn't parse model file: truncated header");
        }
        return line;
    }

    /**
     * Builds the kernel selected by the header's kernel-type code:
     * 0 is the plain sparse-vector kernel, 1 a polynomial kernel of order
     * <code>dParam</code>, 2 a radial-base kernel of width
     * <code>gParam</code>.
     *
     * @throws IOException if the code is not 0, 1 or 2
     * @throws NumberFormatException if a needed parameter is not numeric
     */
    private static SVMKernel createKernel(String kType, String dParam, String gParam)
        throws IOException, NumberFormatException
    {
        switch (Integer.parseInt(kType)) {
        case 0:
            return SparseVector.kernel;
        case 1: {
            PolynomialKernel poly = new PolynomialKernel();
            poly.setOrder(Integer.parseInt(dParam));
            poly.setNestedKernel(SparseVector.kernel);
            return poly;
        }
        case 2: {
            RadialBaseKernel rbk = new RadialBaseKernel();
            rbk.setWidth(Double.parseDouble(gParam));
            rbk.setNestedKernel(SparseVector.kernel);
            return rbk;
        }
        default:
            throw new IOException("Couldn't create kernel");
        }
    }

    /**
     * Writes a model to a file in the layout read by
     * {@link #readModelFile(String)}.
     * <p>
     * Any {@link CachingKernel} wrappers are unwrapped first. The remaining
     * kernel must be <code>SparseVector.kernel</code>, a
     * {@link PolynomialKernel} or a {@link RadialBaseKernel}. Header
     * parameters that do not apply to the kernel are written with the fixed
     * placeholder values d=3, g=1, s=1, r=1, u="empty". Only items whose
     * alpha is non-zero are counted and written.
     *
     * @param model    the model to write; its items are expected to be
     *                 {@link SparseVector} instances
     * @param fileName path of the file to create or overwrite
     * @throws IOException if the kernel type is unsupported, the file cannot
     *         be opened, or an error occurs while writing
     */
    public static void writeModelFile(SVMClassifierModel model, String fileName)
        throws IOException {
        SVMKernel k = model.getKernel();

        int kType = 0;
        int d = 3;
        double g = 1;
        double s = 1;
        double r = 1;
        String u = "empty";

        while (k instanceof CachingKernel) {
            k = ((CachingKernel) k).getNestedKernel();
        }

        if (k == SparseVector.kernel) {
            kType = 0;
        } else if (k instanceof PolynomialKernel) {
            kType = 1;
            d = (int) ((PolynomialKernel) k).getOrder();
        } else if (k instanceof RadialBaseKernel) {
            kType = 2;
            g = ((RadialBaseKernel) k).getWidth();
        } else {
            throw new IOException("Can't write SVM_Light file with kernel type " + k.getClass().toString());
        }

        PrintWriter pw = new PrintWriter(new FileWriter(fileName));
        try {
            pw.println("SVM-light Version V3.01");
            pw.println("" + kType + " # kernel type");
            pw.println("" + d + " # kernel parameter -d");
            pw.println("" + g + " # kernel parameter -g");
            pw.println("" + s + " # kernel parameter -s");
            pw.println("" + r + " # kernel parameter -r");
            pw.println(u + " # kernel parameter -u");

            int numSV = 0;
            for (Iterator i = model.items().iterator(); i.hasNext(); ) {
                Object item = i.next();
                if (model.getAlpha(item) != 0) {
                    numSV++;
                }
            }

            pw.println("" + numSV + " # number of support vectors");
            pw.println("" + model.getThreshold() + " # threshold b");

            for (Iterator i = model.items().iterator(); i.hasNext(); ) {
                Object item = i.next();
                double alpha = model.getAlpha(item);
                if (alpha == 0) {
                    continue;
                }
                // Look the alpha up by the item itself, not by the iterator.
                pw.print(alpha);

                SparseVector v = (SparseVector) item;
                for (int j = 0; j <= v.maxIndex(); ++j) {
                    double x = v.get(j);
                    if (x != 0.0) {
                        pw.print(" " + j + ":" + x);
                    }
                }
                pw.println("");
            }

            // PrintWriter never throws on write failure; ask it explicitly.
            if (pw.checkError()) {
                throw new IOException("Error writing model file " + fileName);
            }
        } finally {
            pw.close();
        }
    }

    /**
     * Returns the part of a string that precedes its first space character.
     *
     * @param s the string to cut; must not be <code>null</code>
     * @return the text before the first space, or <code>s</code> itself when
     *         it contains no space
     */
    public static String firstToken(String s) {
        int ndx = s.indexOf(" ");
        if (ndx < 0) {
            return s;
        } else {
            return s.substring(0, ndx);
        }
    }
}