/*
 *                    BioJava development code
 *
 * This code may be freely distributed and modified under the
 * terms of the GNU Lesser General Public Licence.  This should
 * be distributed with the code.  If you do not have a copy,
 * see:
 *
 *      http://www.gnu.org/copyleft/lesser.html
 *
 * Copyright for this code is held jointly by the individual
 * authors.  These should be listed in @author doc comments.
 *
 * For more information on the BioJava project and its aims,
 * or to join the biojava-l mailing list, visit the home page
 * at:
 *
 *      http://www.biojava.org/
 *
 */


/*
 * @(#)TrainRegression.java     0.1 00/01/15
 *
 * By Matthew Pocock <mrp@sanger.ac.uk>
 */

package org.biojava.stats.svm.tools;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

import org.biojava.stats.svm.PolynomialKernel;
import org.biojava.stats.svm.SMORegressionTrainer;
import org.biojava.stats.svm.SVMRegressionModel;
import org.biojava.stats.svm.SparseVector;

/**
 * Command-line tool that trains an SVM regression model on a file of
 * labelled examples and prints the per-example fit to standard error.
 *
 * <p>Each non-empty line of the training file that does not start with
 * {@code #} is parsed with {@code SVM_Light.parseExample}. The model uses a
 * {@link PolynomialKernel} of order 2 wrapped around
 * {@code SparseVector.kernel}, and is trained with
 * {@link SMORegressionTrainer}.
 *
 * @author Ewan Birney
 * @author Matthew Pocock
 */
public class TrainRegression {
  /**
   * Trains the regression model and reports the fit.
   *
   * @param args {@code args[0]} is the training-example file;
   *     {@code args[1]} is the model output file (required, but currently
   *     unused because model writing is disabled — see the TODO below)
   * @throws Exception if fewer than two arguments are supplied
   * @throws Throwable anything raised while reading, parsing or training
   */
  public static void main(String[] args) throws Throwable {
    if (args.length < 2) {
      throw new Exception("usage: stats.svm.tools.TrainRegression <train_examples> <model_file>");
    }
    String trainFile = args[0];

    List examples = new ArrayList();
    BufferedReader r = new BufferedReader(new FileReader(trainFile));
    try {
      String line;
      while ((line = r.readLine()) != null) {
        // Skip blank lines and '#' comment lines.
        if (line.length() == 0 || line.startsWith("#")) {
          continue;
        }
        examples.add(SVM_Light.parseExample(line));
      }
    } finally {
      // Close the reader even if reading or parsing a line throws.
      r.close();
    }

    SVMRegressionModel model = new SVMRegressionModel(examples.size());
    double[] target = new double[examples.size()];
    for (int i = 0; i < examples.size(); ++i) {
      SVM_Light.LabelledVector ex = (SVM_Light.LabelledVector) examples.get(i);
      model.addVector(ex.getVector());
      target[i] = ex.getLabel();
    }

    // Hard-coded kernel: order-2 polynomial over the sparse-vector kernel.
    PolynomialKernel k = new PolynomialKernel();
    k.setNestedKernel(SparseVector.kernel);
    k.setOrder(2);
    model.setKernel(k);
    System.out.println("Calculating kernel " + k);
    model.calcKernel();

    // Hard-coded trainer parameters (epsilon = 1e-11, C = 1000).
    SMORegressionTrainer trainer = new SMORegressionTrainer();
    trainer.setEpsilon(0.00000000001);
    trainer.setC(1000);
    System.out.println("\nTraining");
    trainer.trainModel(model, target);
    System.out.println("\nDone");

    // Per-example report: target, internalClassify(i), the two alpha values,
    // and internalClassify(i) minus the model threshold.
    for (int i = 0; i < model.size(); ++i) {
      System.err.println("y=" + target[i] + "\tf(x)=" + model.internalClassify(i)
          + " (" + model.getAlpha(i) + ",\t"
          + model.getAlphaStar(i) + ")" + "\t" + (model.internalClassify(i) - model.getThreshold()));
    }
    System.err.println("b=" + model.getThreshold());

    // TODO: model writing is disabled, so args[1] (<model_file>) is never
    // used. Re-enable once SVM_Light supports regression models:
    // SVM_Light.writeModelFile(model, args[1]);
  }
}