package org.biojava.bio.program.ssaha;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.RandomAccessFile;
import java.nio.BufferOverflowException;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

import org.biojava.bio.BioException;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.SequenceIterator;
import org.biojava.bio.seq.db.SequenceDB;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.Packing;
import org.biojava.bio.symbol.PackingFactory;
import org.biojava.utils.Constants;
import org.biojava.utils.io.LargeBuffer;
/**
 * <p>
 * Builder for a data store that has no practical file size limit.
 * </p>
 *
 * <p>
 * This implementation of the data store factory uses longs as indices
 * internally, so it can be used with files exceeding 2 GB in size.
 * </p>
 *
 * <p>
 * The data store file has the following structure.
 * <pre>
 * file: header, hash table, nameTable, hitTable
 *
 * header:
 *   long hashTablePos, // byte offset in file
 *   long hitTablePos,  // byte offset in file
 *   long nameTablePos, // byte offset in file
 *   int wordLength,
 *   int serializedPackingLength,
 *   byte[] serializedPacking
 *
 * hashTable:
 *   long hashTableSize, // size in bytes
 *   long[] hits         // one entry per possible word: -1 or byte offset into hitTable
 *
 * nameTable:
 *   int nameTableSize, // size in bytes
 *   (short nameLength, char[nameLength] name)* names // repeated to fill nameTableSize bytes
 *
 * hitTable:
 *   long hitTableSize, // size in bytes
 *   hitTableRecord* hits // repeated to fill hitTableSize bytes
 *
 * hitTableRecord:
 *   int hitCount,
 *   hitRecord[hitCount] hit
 *
 * hitRecord:
 *   int seqOffset, // byte offset into sequence names table
 *   int pos        // biological position in sequence
 * </pre>
 * </p>
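 *
 * <p>
 * A minimal usage sketch; the file name and parameter values below are
 * illustrative only, and <code>seqDB</code> and <code>packing</code>
 * stand in for an existing sequence database and a packing matching
 * its alphabet:
 * </p>
 * <pre>
 * DataStoreFactory dsf = new NIODataStoreFactory();
 * DataStore store = dsf.buildDataStore(
 *   new File("seqs.store"), // where to build the data store
 *   seqDB,                  // the SequenceDB to index
 *   packing,                // Packing for the sequences' alphabet
 *   10,                     // word (k-tuple) length
 *   1000                    // words with this many or more hits are skipped
 * );
 * </pre>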
 * @author Matthew Pocock
 */
public class NIODataStoreFactory
implements DataStoreFactory {
  public DataStore getDataStore(File storeFile)
  throws IOException {
    return new NIODataStore(storeFile);
  }

  public DataStore buildDataStore(
    File storeFile,
    SequenceDB seqDB,
    Packing packing,
    int wordLength,
    int threshold
  ) throws
    IllegalAlphabetException,
    IOException,
    BioException
  {
    ByteArrayOutputStream packingStream = new ByteArrayOutputStream();
    ObjectOutputStream packingSerializer = new ObjectOutputStream(packingStream);
    packingSerializer.writeObject(packing);
    packingSerializer.flush();

    final int structDataSize =
      3 * Constants.BYTES_IN_LONG + // table positions
      2 * Constants.BYTES_IN_INT +  // word length & packing length
      packingStream.toByteArray().length;

    final long hashTablePos;
    final long hitTablePos;
    final long nameTablePos;

    storeFile.createNewFile();
    final RandomAccessFile store = new RandomAccessFile(storeFile, "rw");
    final FileChannel channel = store.getChannel();

    // allocate array for k-tuple -> hit list
    System.out.println("Word length:\t" + wordLength);
    int words = 1 << (packing.wordSize() * wordLength);
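    // For example, with a 2-bit-per-symbol DNA packing and wordLength
    // 10 this is 1 << 20 = 1,048,576 possible k-tuples. Note that
    // wordSize() * wordLength must stay below 31 or the shift
    // overflows the int.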
    System.out.println("Words:\t" + words);

    hashTablePos = structDataSize;
    long hashTableSize =
      Constants.BYTES_IN_LONG +               // hash table length
      (long) words * Constants.BYTES_IN_LONG; // hash table entries

    System.out.println("Allocated:\t" + hashTableSize);
    final MappedByteBuffer hashTable_MB = channel.map(
      FileChannel.MapMode.READ_WRITE,
      hashTablePos,
      hashTableSize
    );
    final LongBuffer hashTable = hashTable_MB.asLongBuffer();
    hashTable.put(0, hashTableSize); // record the hash table's size in bytes
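    // Slot 0 of the LongBuffer holds the table's own size, so the
    // entry for a given word lives at index word + 1; every get/put
    // below offsets its index by one for this reason.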

    // initialize counts to zero
    for(int i = 0; i < words; i++) {
      hashTable.put(i + 1, 0);
    }
    hashTable.position(0);
    System.out.println("Initialized hash table");

    // 1st pass:
    // write counts as longs for each k-tuple and
    // total up the space required for sequence names
    //
    System.out.println("First pass");
    int seqCount = 0;
    int nameChars = 0;
    for(SequenceIterator i = seqDB.sequenceIterator(); i.hasNext(); ) {
      Sequence seq = i.nextSequence();
      System.out.println(seq.getName());
      if(seq.length() > wordLength) {
        seqCount++;
        nameChars += seq.getName().length();

        int word = PackingFactory.primeWord(seq, wordLength, packing);
        //PackingFactory.binary(word);
        addCount(hashTable, word);
        for(int j = wordLength + 2; j <= seq.length(); j++) {
          word = PackingFactory.nextWord(seq, word, j, wordLength, packing);
          //PackingFactory.binary(word);
          addCount(hashTable, word);
        }
      }
    }
    System.out.println();
    System.out.println("Done");

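    // The first pass above only counted. With exact totals for names
    // and hits now known, each remaining region of the file can be
    // sized and mapped up front before the second pass writes the
    // real data.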
    // map the space for sequence names as (short length, char[] name) records
    //
    nameTablePos = hashTablePos + hashTableSize;
    int nameTableSize =
      Constants.BYTES_IN_INT +              // bytes in table
      seqCount * Constants.BYTES_IN_SHORT + // string lengths
      nameChars * Constants.BYTES_IN_CHAR;  // characters
    System.out.println("nameTableSize:\t" + nameTableSize);
    final MappedByteBuffer nameTable = channel.map(
      FileChannel.MapMode.READ_WRITE,
      nameTablePos,
      nameTableSize
    );
    nameTable.putInt(0, nameTableSize);
    nameTable.position(Constants.BYTES_IN_INT);

    // add up the number of k-tuples that will be indexed
    //
    long kmersUsed = 0; // number of k-mers with counts below the threshold
    long hitCount = 0;  // total number of individual hits
    for(int i = 0; i < words; i++) {
      long counts = hashTable.get(i + 1);
      if(counts > 0 && counts < threshold) {
        hitCount += counts;
        kmersUsed++;
      }
    }
    System.out.println("hitCount: " + hitCount);
    System.out.println("kmersUsed: " + kmersUsed);

    hitTablePos = nameTablePos + nameTableSize;
    long hitTableSize =
      (long) Constants.BYTES_IN_LONG +                                       // size
      (long) hitCount * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT) +  // list elements
      (long) kmersUsed * Constants.BYTES_IN_INT;                             // size of lists
    System.out.println("hitTableSize:\t" + hitTableSize);
    System.out.println("hitTablePos:\t" + hitTablePos);
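    // The three terms mirror the on-disk layout: an 8-byte size
    // header, one (seqOffset, pos) int pair per individual hit, and a
    // 4-byte hit count per k-mer that survived the threshold.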

    final LargeBuffer hitTable = new LargeBuffer(
      channel,
      FileChannel.MapMode.READ_WRITE,
      hitTablePos,
      hitTableSize
    );
    hitTable.position(0);
    hitTable.putLong(hitTableSize);

    // write locations of hit arrays
    System.out.println("Writing pointers from hash to hits");
    long hitOffset = 0;
    for(int i = 0; i < words; i++) {
      long counts = hashTable.get(i + 1);
      if(counts > 0 && counts < threshold) {
        try {
          // record location of a block of the form:
          // n,(seqOffset,pos)1,(seqOffset,pos)2,...,(seqOffset,pos)n
          if(hitOffset < 0) {
            throw new IndexOutOfBoundsException("Hit offset negative");
          }
          hashTable.put(i + 1, hitOffset); // wire hash table to hit table
          hitTable.putInt(hitOffset + Constants.BYTES_IN_LONG, 0); // initially we have no hits
          hitOffset +=
            Constants.BYTES_IN_INT +
            counts * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT);
        } catch (IndexOutOfBoundsException e) {
          System.out.println("counts:\t" + counts);
          System.out.println("word:\t" + i);
          System.out.println("hitOffset:\t" + hitOffset);
          throw e;
        }
      } else {
        // no hits at all, or too many - flag the entry with -1
        hashTable.put(i + 1, -1);
      }
    }
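    // Invariant from here on: each hash table entry is either -1 (word
    // unused or over the threshold) or the byte offset of that word's
    // (count, hits...) block, counted from the first byte after the
    // hit table's 8-byte size header.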

    // 2nd pass:
    // write the sequence names and fill in the hitTable
    System.out.println("Second pass");
    int seqNumber = 0;
    nameTable.position(Constants.BYTES_IN_INT);
    for(SequenceIterator i = seqDB.sequenceIterator(); i.hasNext(); ) {
      Sequence seq = i.nextSequence();
      System.out.println(seq.getName());
      seqNumber++;
      if(seq.length() > wordLength) {
        try {
          // write sequence name length and chars into nameTable
          int namePos = nameTable.position() - Constants.BYTES_IN_INT;
          String name = seq.getName();
          nameTable.putShort((short) name.length());
          for(int j = 0; j < name.length(); j++) {
            nameTable.putChar(name.charAt(j));
          }

          // write each k-mer's (seqOffset, pos) hit record
          int word = PackingFactory.primeWord(seq, wordLength, packing);
          writeRecord(hashTable, hitTable, 1, namePos, word);
          for(int j = wordLength + 2; j <= seq.length(); j++) {
            word = PackingFactory.nextWord(seq, word, j, wordLength, packing);
            writeRecord(hashTable, hitTable, j - wordLength, namePos, word);
          }
        } catch (BufferOverflowException e) {
          System.out.println("name:\t" + seq.getName());
          System.out.println("seqNumber:\t" + seqNumber);
          System.out.println("nt pos:\t" + nameTable.position());
          throw e;
        }
      }
    }
    System.out.println();
    System.out.println("Done");

    final MappedByteBuffer rootBuffer = channel.map(
      FileChannel.MapMode.READ_WRITE,
      0,
      structDataSize
    );

    // the header layout here must match the structure documented in
    // the class javadoc
    rootBuffer.position(0);
    rootBuffer.putLong(hashTablePos);
    rootBuffer.putLong(hitTablePos);
    rootBuffer.putLong(nameTablePos);
    rootBuffer.putInt(wordLength);
    rootBuffer.putInt(packingStream.toByteArray().length);
    rootBuffer.put(packingStream.toByteArray());

    rootBuffer.force();
    hashTable_MB.force();
    hitTable.force();
    nameTable.force();

    System.out.println("Committed");

    // release the file handle (this also closes the channel); the
    // mapped buffers have already been forced to disk
    store.close();

    return getDataStore(storeFile);
  }

  private void addCount(LongBuffer buffer, int word) {
    // increment the occurrence count for this word; entries are offset
    // by one because slot 0 holds the table size
    long count = buffer.get(word + 1);
    count++;
    buffer.put(word + 1, count);
  }

  private void writeRecord(
    LongBuffer hashTable,
    LargeBuffer hitTable,
    int offset,
    int seqOffset, // byte offset of the sequence's name in the name table
    int word
  ) throws IOException {
    long kmerPointer = hashTable.get(word + 1);
    if(kmerPointer != -1) { // -1 flags words that are not indexed
      // convert the stored offset into an absolute position past the
      // hit table's 8-byte size header
      kmerPointer += Constants.BYTES_IN_LONG;

      int hitCount = hitTable.getInt(kmerPointer);
      // append after the count int and the hitCount existing (seqOffset, pos) pairs
      long pos = kmerPointer + hitCount * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT) + Constants.BYTES_IN_INT;

      hitTable.position(pos);
      hitTable.putInt(seqOffset);
      hitTable.putInt(offset);
      hitTable.putInt(kmerPointer, hitCount + 1);
    }
  }

}