001/**
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 * Created on 5 Mar 2013
021 * Created by Andreas Prlic
022 *
023 * @since 3.0.6
024 */
025package org.biojava.nbio.structure.math;
026
027import java.io.Serializable;
028
029
030/**
031 *
032 *  A sparse vector, implemented using a symbol table.
033 *
034 *  Derived from http://introcs.cs.princeton.edu/java/44st/SparseVector.java.html
035 *
036 *  For additional documentation, see <a href="http://introcs.cs.princeton.edu/44st">Section 4.4</a> of
037 *  <i>Introduction to Programming in Java: An Interdisciplinary Approach</i> by Robert Sedgewick and Kevin Wayne.
038 */
039
040public class SparseVector implements Serializable{
041        /**
042         *
043         */
044        private static final long serialVersionUID = 1174668523213431927L;
045
046        private final int N;             // length
047
048        private SymbolTable<Integer, Double> symbolTable;  // the vector, represented by index-value pairs
049
050
051        /** Constructor. initialize the all 0s vector of length N
052         *
053         * @param N
054         */
055        public SparseVector(int N) {
056                this.N  = N;
057                this.symbolTable = new SymbolTable<Integer, Double>();
058        }
059
060        /** Setter method (should it be renamed to set?)
061        *
062        * @param i set symbolTable[i]
063        * @param value
064        */
065        public void put(int i, double value) {
066                if (i < 0 || i >= N) throw new IllegalArgumentException("Illegal index " + i + " should be > 0 and < " + N);
067                if (value == 0.0) symbolTable.delete(i);
068                else              symbolTable.put(i, value);
069        }
070
071        /** get a value
072         *
073         * @param i
074         * @return  return symbolTable[i]
075         */
076        public double get(int i) {
077                if (i < 0 || i >= N) throw new IllegalArgumentException("Illegal index " + i + " should be > 0 and < " + N);
078                if (symbolTable.contains(i)) return symbolTable.get(i);
079                else                return 0.0;
080        }
081
082        // return the number of nonzero entries
083        public int nnz() {
084                return symbolTable.size();
085        }
086
087        // return the size of the vector
088        public int size() {
089                return N;
090        }
091
092        /** Calculates the dot product of this vector a with b
093         *
094         * @param b
095         * @return
096         */
097        public double dot(SparseVector b) {
098                SparseVector a = this;
099                if (a.N != b.N) throw new IllegalArgumentException("Vector lengths disagree. " + a.N + " != " + b.N);
100                double sum = 0.0;
101
102                // iterate over the vector with the fewest nonzeros
103                if (a.symbolTable.size() <= b.symbolTable.size()) {
104                        for (int i : a.symbolTable)
105                                if (b.symbolTable.contains(i)) sum += a.get(i) * b.get(i);
106                }
107                else  {
108                        for (int i : b.symbolTable)
109                                if (a.symbolTable.contains(i)) sum += a.get(i) * b.get(i);
110                }
111                return sum;
112        }
113
114        /** Calculates the 2-norm
115         *
116         * @return
117         */
118        public double norm() {
119                SparseVector a = this;
120                return Math.sqrt(a.dot(a));
121        }
122
123        /** Calculates  alpha * a
124         *
125         * @param alpha
126         * @return
127         */
128        public SparseVector scale(double alpha) {
129                SparseVector a = this;
130                SparseVector c = new SparseVector(N);
131                for (int i : a.symbolTable) c.put(i, alpha * a.get(i));
132                return c;
133        }
134
135        /** Calcualtes return a + b
136         *
137         * @param b
138         * @return
139         */
140        public SparseVector plus(SparseVector b) {
141                SparseVector a = this;
142                if (a.N != b.N) throw new IllegalArgumentException("Vector lengths disagree : " + a.N + " != " + b.N);
143                SparseVector c = new SparseVector(N);
144                for (int i : a.symbolTable) c.put(i, a.get(i));                // c = a
145                for (int i : b.symbolTable) c.put(i, b.get(i) + c.get(i));     // c = c + b
146                return c;
147        }
148
149        @Override
150        public String toString() {
151                StringBuilder s = new StringBuilder();
152                for (int i : symbolTable) {
153                        s.append("(");
154                        s.append(i);
155                        s.append(", ");
156                        s.append(symbolTable.get(i));
157                        s.append(") ");
158                }
159                return s.toString();
160        }
161
162
163}
164