001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021/*
022 * @(#)SVM_Light.java      0.1 00/01/15
023 *
024 * By Thomas Down <td2@sanger.ac.uk>
025 */
026
027package org.biojava.stats.svm.tools;
028
029import java.io.BufferedReader;
030import java.io.FileReader;
031import java.io.FileWriter;
032import java.io.IOException;
033import java.io.PrintWriter;
034import java.util.Iterator;
035import java.util.StringTokenizer;
036
037import org.biojava.stats.svm.CachingKernel;
038import org.biojava.stats.svm.PolynomialKernel;
039import org.biojava.stats.svm.RadialBaseKernel;
040import org.biojava.stats.svm.SVMClassifierModel;
041import org.biojava.stats.svm.SVMKernel;
042import org.biojava.stats.svm.SimpleSVMClassifierModel;
043import org.biojava.stats.svm.SparseVector;
044
045/**
046 * @author Thomas Down
047 * @author Greg Cox
048 */
049
050public class SVM_Light {
051    public static class LabelledVector {
052        private SparseVector v;
053        private double label;
054        private String comment = null;
055
056        public LabelledVector(SparseVector v, double label) {
057            this.v = v;
058            this.label = label;
059        }
060
061        public LabelledVector(SparseVector v, double label, String comment) {
062            this.v = v;
063            this.label = label;
064            this.comment = comment;
065        }
066
067        public SparseVector getVector() {
068            return v;
069        }
070
071        public double getLabel() {
072            return label;
073        }
074
075        public String getComment() {
076            return comment;
077        }
078    }
079
080    public static LabelledVector parseExample(String ex)
081        throws NumberFormatException
082    {
083        String comment = null;
084        int hashPos = ex.indexOf('#');
085        if (hashPos >= 0) {
086            comment = ex.substring(hashPos + 1);
087            ex = ex.substring(0, hashPos);
088        }
089
090        StringTokenizer toke = new StringTokenizer(ex);
091        double label = Double.parseDouble(toke.nextToken());
092
093        int size = toke.countTokens();
094        SparseVector v = new SparseVector(size);
095        while (toke.hasMoreTokens()) {
096            String dim = toke.nextToken();
097            int cut = dim.indexOf(':');
098            if (cut < 0) {
099                throw new NumberFormatException("Bad dimension "+dim);
100            }
101            int dnum = Integer.parseInt(dim.substring(0, cut));
102            double dval = Double.parseDouble(dim.substring(cut + 1));
103            v.put(dnum, dval);
104        }
105
106        return new LabelledVector(v, label, comment);
107    }
108
109    public static String vectorToString(SparseVector v) {
110        StringBuffer sb = new StringBuffer();
111        boolean first = true;
112
113        for (int i = 0; i < v.size(); ++i) {
114            double x = v.getValueAtIndex(i);
115
116            if (first) {
117                first = false;
118            } else {
119                sb.append(' ');
120            }
121
122            sb.append(v.getDimAtIndex(i));
123            sb.append(':');
124            sb.append(x);
125        }
126        return sb.substring(0);
127    }
128
129    public static SVMClassifierModel readModelFile(String fileName)
130        throws IOException
131    {
132        BufferedReader r = new BufferedReader(new FileReader(fileName));
133        r.readLine(); // format
134        String kType = firstToken(r.readLine());
135        String dParam = firstToken(r.readLine());
136        String gParam = firstToken(r.readLine());
137        r.readLine(); // s
138        r.readLine(); // r
139        r.readLine(); // u
140        r.readLine(); // numSV
141        String threshString = firstToken(r.readLine());
142
143        SVMKernel kernel = null;
144        try {
145            switch (Integer.parseInt(kType)) {
146            case 0:
147                kernel = SparseVector.kernel;
148                break;
149            case 1:
150                int order = Integer.parseInt(dParam);
151                PolynomialKernel k = new PolynomialKernel();
152                k.setOrder(order);
153                k.setNestedKernel(SparseVector.kernel);
154                kernel = k;
155                break;
156            case 2:
157                RadialBaseKernel rbk = new RadialBaseKernel();
158                double width = Double.parseDouble(gParam);
159                rbk.setWidth(width);
160                rbk.setNestedKernel(SparseVector.kernel);
161                kernel = rbk;
162                break;
163            default:
164                throw new IOException("Couldn't create kernel");
165            }
166
167            SimpleSVMClassifierModel model = new SimpleSVMClassifierModel(kernel);
168            model.setThreshold(Double.parseDouble(threshString));
169            String line;
170            while ((line = r.readLine()) != null) {
171                LabelledVector ex = parseExample(line);
172                model.addItemAlpha(ex.getVector(), ex.getLabel());
173            }
174            r.close();
175
176            return model;
177        } catch (NumberFormatException ex) {
178            throw new IOException("Couldn't parse model file");
179        }
180    }
181
182  public static void writeModelFile(SVMClassifierModel model, String fileName)
183  throws IOException {
184    SVMKernel k = model.getKernel();
185
186    int kType = 0;
187    int d = 3;
188    double g = 1;
189    double s = 1;
190    double r = 1;
191    String u = "empty";
192
193    while(k instanceof CachingKernel) {
194      k = ((CachingKernel) k).getNestedKernel();
195    }
196    
197    if (k == SparseVector.kernel) {
198      kType = 0;
199    } else if (k instanceof PolynomialKernel) {
200      kType = 1;
201      d = (int) ((PolynomialKernel) k).getOrder();
202    } else if (k instanceof RadialBaseKernel) {
203      kType = 2;
204      g = ((RadialBaseKernel) k).getWidth();
205    } else {
206      throw new IOException("Can't write SVM_Light file with kernel type " + k.getClass().toString());
207    }
208
209
210    PrintWriter pw = new PrintWriter(new FileWriter(fileName));
211    pw.println("SVM-light Version V3.01");
212    pw.println("" + kType + " # kernel type");
213    pw.println("" + d + " # kernel parameter -d");
214    pw.println("" + g + " # kernel parameter -g");
215    pw.println("" + s + " # kernel parameter -s");
216    pw.println("" + r + " # kernel parameter -r");
217    pw.println(u + " # kernel parameter -u");
218
219    int numSV = 0;
220    for(Iterator i = model.items().iterator(); i.hasNext(); ) {
221      Object item = i.next();
222      if (model.getAlpha(item) != 0) {
223        numSV++;
224      }
225    }
226
227    pw.println("" + numSV + " # number of support vectors");
228    pw.println("" + model.getThreshold() + " # threshold b");
229
230    for(Iterator i = model.items().iterator(); i.hasNext(); ) {
231      Object item = i.next();
232      if (model.getAlpha(item) == 0) {
233        continue;
234      }
235      pw.print(model.getAlpha(i));
236
237      SparseVector v = (SparseVector) item;
238      for (int j = 0; j <= v.maxIndex(); ++j) {
239        double x = v.get(j);
240        if (x != 0.0)
241        pw.print(" " + j + ":" + x);
242      }
243      pw.println("");
244    }
245
246    pw.close();
247  }
248
249  public static String firstToken(String s) {
250    int ndx = s.indexOf(" ");
251    if (ndx < 0) {
252            return s;
253    } else {
254      return s.substring(0, ndx);
255    }
256  }
257}