001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022
023/*
024 * @(#)Classify.java      0.1 00/01/15
025 *
026 * By Thomas Down <td2@sanger.ac.uk>
027 */
028
029package org.biojava.stats.svm.tools;
030
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;

import org.biojava.stats.svm.SVMClassifierModel;
037
038/**
039 * @author Ewan Birney;
040 * @author Matthew Pocock
041 */
042public class Classify {
043    public static void main(String[] args) throws Throwable {
044        if (args.length < 3) {
045            throw new Exception("usage: stats.svm.tools.Classify <model> <test_examples> <results_log>");
046        }
047
048        String modelName = args[0];
049        String examplesName = args[1];
050        String resultsName = args[2];
051
052        SVMClassifierModel model = SVM_Light.readModelFile(modelName);
053
054        BufferedReader r = new BufferedReader(new FileReader(examplesName));
055        PrintWriter w = new PrintWriter(new FileWriter(resultsName));
056        
057        String line;
058        int right = 0, wrong = 0;
059
060        while ((line = r.readLine()) != null) {
061            if (line.length() == 0 || line.startsWith("#"))
062                continue;
063            SVM_Light.LabelledVector ex = SVM_Light.parseExample(line);
064            double result = model.classify(ex.getVector());
065            w.println(result);
066
067            if (sign(result) == sign(ex.getLabel()))
068                right++;
069            else
070                wrong++;
071        }
072
073        System.out.println("" + ((double) right / (right + wrong))*100 + "% correct");
074
075        r.close();
076        w.close();
077    }
078
079    public static int sign(double d) {
080        if (d < 0)
081            return -1;
082        else if (d == 0)
083            return 0;
084        return 1;
085    }
086}