001/* 002 * BioJava development code 003 * 004 * This code may be freely distributed and modified under the 005 * terms of the GNU Lesser General Public Licence. This should 006 * be distributed with the code. If you do not have a copy, 007 * see: 008 * 009 * http://www.gnu.org/copyleft/lesser.html 010 * 011 * Copyright for this code is held jointly by the individual 012 * authors. These should be listed in @author doc comments. 013 * 014 * For more information on the BioJava project and its aims, 015 * or to join the biojava-l mailing list, visit the home page 016 * at: 017 * 018 * http://www.biojava.org/ 019 * 020 */ 021 022 023/* 024 * @(#)Train.java 0.1 00/01/15 025 * 026 * By Thomas Down <td2@sanger.ac.uk> 027 */ 028 029package org.biojava.stats.svm.tools; 030 031import java.io.BufferedReader; 032import java.io.FileReader; 033import java.util.Iterator; 034 035import org.biojava.stats.svm.CachingKernel; 036import org.biojava.stats.svm.DiagonalCachingKernel; 037import org.biojava.stats.svm.NormalizingKernel; 038import org.biojava.stats.svm.PolynomialKernel; 039import org.biojava.stats.svm.SMOTrainer; 040import org.biojava.stats.svm.SVMClassifierModel; 041import org.biojava.stats.svm.SVMTarget; 042import org.biojava.stats.svm.SimpleSVMTarget; 043import org.biojava.stats.svm.SparseVector; 044import org.biojava.stats.svm.TrainingEvent; 045import org.biojava.stats.svm.TrainingListener; 046 047/** 048 * @author Ewan Birney 049 * @author Matthew Pocock 050 */ 051public class Train { 052 public static void main(String[] args) throws Throwable { 053 if (args.length != 2) { 054 throw new Exception("usage: stats.svm.tools.Classify <train_examples> <model_file>"); 055 } 056 String trainFile = args[0]; 057 String modelFile = args[1]; 058 059 BufferedReader r = new BufferedReader(new FileReader(trainFile)); 060 String line; 061 062 SVMTarget target = new SimpleSVMTarget(); 063 while ((line = r.readLine()) != null) { 064 if (line.length() == 0 || line.startsWith("#")) { 065 continue; 066 } 067 SVM_Light.LabelledVector ex = SVM_Light.parseExample(line); 068 target.addItemTarget(ex.getVector(), ex.getLabel()); 069 } 070 r.close(); 071 072 PolynomialKernel pK = new PolynomialKernel(); 073 pK.setOrder(2.0); 074 pK.setNestedKernel(SparseVector.kernel); 075 DiagonalCachingKernel gcK = new DiagonalCachingKernel(); 076 gcK.setNestedKernel(pK); 077 NormalizingKernel nK = new NormalizingKernel(); 078 nK.setNestedKernel(gcK); 079 CachingKernel cK = new CachingKernel(); 080 cK.setNestedKernel(nK); 081 082 SMOTrainer trainer = new SMOTrainer(); 083 trainer.setEpsilon(1.0e-9); 084 trainer.setC(1000); 085 TrainingListener tl = new TrainingListener() { 086 public void trainingCycleComplete(TrainingEvent e) { 087 System.out.print('.'); 088 } 089 public void trainingComplete(TrainingEvent e) { 090 System.out.println(""); 091 } 092 }; 093 094 System.out.println("Training"); 095 SVMClassifierModel model = trainer.trainModel(target, cK, tl); 096 System.out.println("Done"); 097 098 for(Iterator i = target.items().iterator(); i.hasNext(); ) { 099 Object item = i.next(); 100 System.out.println(target.getTarget(item) + "\t" + 101 model.classify(item) + "\t(" + 102 model.getAlpha(item) + ")"); 103 } 104 105 SVM_Light.writeModelFile(model, modelFile); 106 } 107}