001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023/*
024 * @(#)TrainRegression.java      0.1 00/01/15
025 *
026 * By Matthew Pocock <mrp@sanger.ac.uk>
027 */
028
029package org.biojava.stats.svm.tools;
030
031import java.io.BufferedReader;
032import java.io.FileReader;
033import java.util.ArrayList;
034import java.util.List;
035
036import org.biojava.stats.svm.PolynomialKernel;
037import org.biojava.stats.svm.SMORegressionTrainer;
038import org.biojava.stats.svm.SVMRegressionModel;
039import org.biojava.stats.svm.SparseVector;
040
041/**
042 * @author Ewan Birney
043 * @author Matthew Pocock
044 */
045public class TrainRegression {
046  public static void main(String[] args) throws Throwable {
047    if (args.length < 2) {
048            throw new Exception("usage: stats.svm.tools.TrainRegression <train_examples> <model_file>");
049    }
050    String trainFile = args[0];
051
052    List examples = new ArrayList();
053    BufferedReader r = new BufferedReader(new FileReader(trainFile));
054    String line;
055
056    while ((line = r.readLine()) != null) {
057            if (line.length() == 0 || line.startsWith("#")) {
058        continue;
059      }
060            examples.add(SVM_Light.parseExample(line));
061    }
062    r.close();
063  
064        SVMRegressionModel model = new SVMRegressionModel(examples.size());
065    double[] target = new double[examples.size()];
066    for (int i = 0; i < examples.size(); ++i) {
067            SVM_Light.LabelledVector ex = (SVM_Light.LabelledVector) examples.get(i);
068            model.addVector(ex.getVector());
069            target[i] = ex.getLabel();
070    }
071
072    PolynomialKernel k = new PolynomialKernel();
073    k.setNestedKernel(SparseVector.kernel);
074    k.setOrder(2);
075    model.setKernel(k);
076    System.out.println("Calculating kernel " + k);
077    model.calcKernel();
078    SMORegressionTrainer trainer = new SMORegressionTrainer();
079    trainer.setEpsilon(0.00000000001);
080    trainer.setC(1000);
081    System.out.println("\nTraining");
082    trainer.trainModel(model, target);
083    System.out.println("\nDone");
084
085    for (int i=0; i < model.size(); ++i) {
086      System.err.println("y=" + target[i] + "\tf(x)=" + model.internalClassify(i)
087                         + "    (" + model.getAlpha(i) + ",\t"
088                         + model.getAlphaStar(i) + ")" + "\t" + (model.internalClassify(i) - model.getThreshold()));
089    }
090    System.err.println("b=" + model.getThreshold());
091
092//    SVM_Light.writeModelFile(model, modelFile);
093  }
094}