001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023/*
024 * @(#)Train.java      0.1 00/01/15
025 *
026 * By Thomas Down <td2@sanger.ac.uk>
027 */
028
029package org.biojava.stats.svm.tools;
030
031import java.io.BufferedReader;
032import java.io.FileReader;
033import java.util.Iterator;
034
035import org.biojava.stats.svm.CachingKernel;
036import org.biojava.stats.svm.DiagonalCachingKernel;
037import org.biojava.stats.svm.NormalizingKernel;
038import org.biojava.stats.svm.PolynomialKernel;
039import org.biojava.stats.svm.SMOTrainer;
040import org.biojava.stats.svm.SVMClassifierModel;
041import org.biojava.stats.svm.SVMTarget;
042import org.biojava.stats.svm.SimpleSVMTarget;
043import org.biojava.stats.svm.SparseVector;
044import org.biojava.stats.svm.TrainingEvent;
045import org.biojava.stats.svm.TrainingListener;
046
047/**
048 * @author Ewan Birney
049 * @author Matthew Pocock
050 */
051public class Train {
052  public static void main(String[] args) throws Throwable {
053    if (args.length != 2) {
054            throw new Exception("usage: stats.svm.tools.Classify <train_examples> <model_file>");
055    }
056    String trainFile = args[0];
057    String modelFile = args[1];
058
059    BufferedReader r = new BufferedReader(new FileReader(trainFile));
060    String line;
061
062    SVMTarget target = new SimpleSVMTarget();
063    while ((line = r.readLine()) != null) {
064      if (line.length() == 0 || line.startsWith("#")) {
065        continue;
066      }
067            SVM_Light.LabelledVector ex = SVM_Light.parseExample(line);
068      target.addItemTarget(ex.getVector(), ex.getLabel());
069    }
070    r.close();
071
072    PolynomialKernel pK = new PolynomialKernel();
073    pK.setOrder(2.0);
074    pK.setNestedKernel(SparseVector.kernel);
075    DiagonalCachingKernel gcK = new DiagonalCachingKernel();
076    gcK.setNestedKernel(pK);
077    NormalizingKernel nK = new NormalizingKernel();
078    nK.setNestedKernel(gcK);
079    CachingKernel cK = new CachingKernel();
080    cK.setNestedKernel(nK);
081
082    SMOTrainer trainer = new SMOTrainer();
083    trainer.setEpsilon(1.0e-9);
084    trainer.setC(1000);
085    TrainingListener tl = new TrainingListener() {
086            public void trainingCycleComplete(TrainingEvent e) {
087        System.out.print('.');
088            }
089            public void trainingComplete(TrainingEvent e) {
090        System.out.println("");
091            }
092    };
093    
094    System.out.println("Training");
095          SVMClassifierModel model = trainer.trainModel(target, cK, tl);
096    System.out.println("Done");
097
098    for(Iterator i = target.items().iterator(); i.hasNext(); ) {
099      Object item = i.next();
100            System.out.println(target.getTarget(item) + "\t" +
101                         model.classify(item) + "\t(" +
102                         model.getAlpha(item) + ")");
103    }
104
105    SVM_Light.writeModelFile(model, modelFile);
106  }
107}