001package org.biojava.bio.program.ssaha;
002
003import java.io.ByteArrayOutputStream;
004import java.io.File;
005import java.io.IOException;
006import java.io.ObjectOutputStream;
007import java.io.RandomAccessFile;
008import java.nio.IntBuffer;
009import java.nio.MappedByteBuffer;
010import java.nio.channels.FileChannel;
011
012import org.biojava.bio.BioException;
013import org.biojava.bio.BioRuntimeException;
014import org.biojava.bio.seq.db.SequenceDB;
015import org.biojava.bio.seq.io.ParseException;
016import org.biojava.bio.seq.io.SeqIOAdapter;
017import org.biojava.bio.symbol.Alphabet;
018import org.biojava.bio.symbol.IllegalAlphabetException;
019import org.biojava.bio.symbol.IllegalSymbolException;
020import org.biojava.bio.symbol.Packing;
021import org.biojava.bio.symbol.Symbol;
022import org.biojava.utils.AssertionFailure;
023import org.biojava.utils.Constants;
024
/**
 * <p>
 * Builder for a data store that is backed by java.nio.MappedByteBuffer
 * regions. Because every section offset is stored as an int, the total size
 * of the store file, and therefore of the hash table, cannot exceed 2 GB
 * (Integer.MAX_VALUE bytes).
 * </p>
 *
 * <p>
 * The data store file has the following structure:
 * </p>
 * <pre>
 * file: header, hash table, nameArray, nameTable, hitTable
 *
 * header:
 *   int hashTablePos, // byte offset in file
 *   int hitTablePos,  // byte offset in file
 *   int nameArrayPos, // byte offset in file
 *   int nameTablePos, // byte offset in file
 *   int wordLength,
 *   int serializedPackingLength,
 *   byte[] serializedPacking
 *
 * hash table:
 *   int hashTableSize,  // size in bytes, including this field
 *   int[words] hits     // byte offset into hitTable (relative to the end of
 *                       // its size field), or -1 if the word never occurs,
 *                       // or -2 if it occurs at least threshold times
 *
 * nameArray:
 *   int nameArraySize,  // size in bytes, including this field
 *   nameRecord[sequenceCount] records
 *
 * nameRecord:
 *   int nameTableOffset      // byte offset into nameTable, relative to the
 *                            // end of its size field
 *   int sequenceStartOffset  // start of the sequence in concatenated
 *                            // coordinates (each sequence is followed by a
 *                            // gap of 100 positions)
 *
 * nameTable:
 *   int nameTableSize,  // size in bytes, including this field
 *   (int nameLength, char[nameLength] name)[sequenceCount] names
 *
 * hitTable:
 *   int hitTableSize,   // size in bytes, including this field
 *   hitTableRecord[] hits
 *
 * hitTableRecord:
 *   int hitCount,
 *   hit[hitCount] hit
 *
 * hit:
 *   int offset          // 1-based position in concatenated coordinates
 * </pre>
 *
 * @author Matthew Pocock
 * @author Thomas Down
 */
078
079public class CompactedDataStoreFactory implements DataStoreFactory {
080  public DataStore getDataStore(File storeFile)
081      throws IOException 
082  {
083      return new CompactedDataStore(storeFile);
084  }
085  
086  public DataStore buildDataStore(
087    File storeFile,
088    SequenceDB seqDB,
089    Packing packing,
090    int wordLength,
091    int threshold
092  ) throws
093    IllegalAlphabetException,
094    IOException,
095    BioException
096  {
097      return this.buildDataStore(storeFile,
098                                 new SequenceStreamer.SequenceDBStreamer(seqDB),
099                                 packing,
100                                 wordLength,
101                                 1,
102                                 threshold);
103  }
104
105  public DataStore buildDataStore(
106    File storeFile,
107    SequenceStreamer streamer,
108    Packing packing,
109    int wordLength,
110    int stepSize,
111    int threshold
112  ) throws
113    IllegalAlphabetException,
114    IOException,
115    BioException
116  { 
117    ByteArrayOutputStream packingStream = new ByteArrayOutputStream();
118    ObjectOutputStream packingSerializer = new ObjectOutputStream(packingStream);
119    packingSerializer.writeObject(packing);
120    packingSerializer.flush();
121    
122    final int structDataSize =
123      6 * Constants.BYTES_IN_INT +
124      packingStream.toByteArray().length;
125    
126    final int hashTablePos;
127    final int hitTablePos;
128    final int nameArrayPos;
129    final int nameTablePos;
130    
131    storeFile.createNewFile();
132    final RandomAccessFile store = new RandomAccessFile(storeFile, "rw");
133    final FileChannel channel = store.getChannel();
134    
135    // allocate array for k-tuple -> hit list
136    //System.out.println("Word length:\t" + wordLength);
137    int words = 1 << (
138      (int) packing.wordSize() *
139      (int) wordLength
140    );
141    //System.out.println("Words:\t" + words);
142    
143    hashTablePos = structDataSize;
144    int hashTableSize =
145      (int) Constants.BYTES_IN_INT + // hash table length
146      words * (int) Constants.BYTES_IN_INT; // hash table entries
147    
148    //System.out.println("Allocated:\t" + hashTableSize);
149    if(hashTableSize < words) {
150      throw new AssertionFailure(
151        "Possible underflow. number of words: " + words +
152        "\tsize of hash table: " + hashTableSize +
153        "\tcompared to Integer.MAX_VALUE " + Integer.MAX_VALUE);
154    }
155
156    final MappedByteBuffer hashTable_MB = channel.map(
157      FileChannel.MapMode.READ_WRITE,
158      hashTablePos,
159      hashTableSize
160    );
161    final IntBuffer hashTable = hashTable_MB.asIntBuffer();
162    hashTable.put(0, hashTableSize); // write length of k-tuple array
163    
164    // initialize counts to zero
165    for(int i = 0; i < words; i++) {
166      hashTable.put(i+1, 0);
167    }
168    hashTable.position(0);
169    
170    // System.err.println("And so it begins...");
171
172    // 1st pass
173    // writes counts as ints for each k-tuple
174    // count up the space required for sequence names
175    //
176
177    FirstPassListener fpl = new FirstPassListener(packing, wordLength, stepSize, hashTable);
178    streamer.reset();
179    while (streamer.hasNext()) {
180        streamer.streamNext(fpl);
181    }
182    
183    // map the space for sequence index->name
184    //
185    nameArrayPos = hashTablePos + hashTableSize;
186    int nameArraySize = ((fpl.seqCount * 2) + 1) * Constants.BYTES_IN_INT;
187    //System.out.println("seqCount:\t" + seqCount);
188    //System.out.println("nameArraySize:\t" + nameArraySize);
189    final MappedByteBuffer nameArray_MB = channel.map(
190      FileChannel.MapMode.READ_WRITE,
191      nameArrayPos,
192      nameArraySize
193    );
194    final IntBuffer nameArray = nameArray_MB.asIntBuffer();
195    nameArray.put(0, nameArraySize);
196    
197    // map the space for sequence names as short length, char* name
198    //
199    nameTablePos = nameArrayPos + nameArraySize;
200    int nameTableSize =
201      Constants.BYTES_IN_INT +
202      fpl.seqCount * Constants.BYTES_IN_INT +
203      fpl.nameChars * Constants.BYTES_IN_CHAR;
204    //System.out.println("nameTableSize:\t" + nameTableSize);
205    final MappedByteBuffer nameTable = channel.map(
206      FileChannel.MapMode.READ_WRITE,
207      nameTablePos,
208      nameTableSize
209    );
210    nameTable.putInt(0, nameTableSize);
211    nameTable.position(Constants.BYTES_IN_INT);
212    
213    // add up the number of k-tuples
214    //
215    int kmersUsed = 0;
216    int hitCount = 0;
217    for(int i = 0; i < words; i++) {
218      int counts = hashTable.get(i + 1);
219      if(counts > 0 && counts < threshold) {
220        hitCount++;
221        kmersUsed += counts;
222      }
223    }
224    
225    // map the space for hits
226    hitTablePos = nameTablePos + nameTableSize;
227    long hitTableSize =
228      (long) Constants.BYTES_IN_INT +                            // size
229      (long) kmersUsed * (Constants.BYTES_IN_INT) +              // list elements
230      (long) hitCount * Constants.BYTES_IN_INT;                  // size of lists
231    //System.out.println("hitTableSize:\t" + hitTableSize);
232    //System.out.println("hitTableSize:\t" + (int) hitTableSize);
233    //System.out.println("hitTablePos:\t" + hitTablePos);
234    final MappedByteBuffer hitTable = channel.map(
235      FileChannel.MapMode.READ_WRITE,
236      hitTablePos,
237      (int) hitTableSize
238    );
239    hitTable.putInt(0, (int) hitTableSize);
240    hitTable.position(Constants.BYTES_IN_INT);
241    
242    // write locations of hit arrays
243    int hitOffset = 0;
244    for(int i = 0; i < words; i++) {
245      int counts = hashTable.get(i+1);
246      if(counts > 0 && counts < threshold) {
247        try {
248        // record location of a block of the form:
249        // n,(seqID,offset)1,(seqID,offset)2,...,(seqID,offset)n
250        if(hitOffset < 0) {
251          throw new IndexOutOfBoundsException("Hit offset negative");
252        }
253        hashTable.put(i + 1, hitOffset); // wire hash table to hit table
254        hitTable.putInt(hitOffset + Constants.BYTES_IN_INT, 0); // initialy we have no hits
255        hitOffset +=
256          Constants.BYTES_IN_INT +
257          counts * (Constants.BYTES_IN_INT);
258        } catch (IndexOutOfBoundsException e) {
259          System.out.println("counts:\t" + counts);
260          System.out.println("word:\t" + i);
261          System.out.println("hitOffset:\t" + hitOffset);
262          throw e;
263        }
264      } else if (counts == 0) {
265        // nothing - set the number of hits to the flag value -1
266        hashTable.put(i + 1, -1);
267      } else {
268        // too many hits - set the number of hits to the flag value -2
269        hashTable.put(i + 1, -2);
270      }
271    }
272    
273    // System.err.println("Second pass...");
274
275    // 2nd parse
276    // write sequence array and names
277    // write hitTable
278    
279    SecondPassListener spl = new SecondPassListener(packing,
280                                                    wordLength,
281                                                    stepSize,
282                                                    hashTable,
283                                                    nameArray,
284                                                    nameTable,
285                                                    hitTable);
286    streamer.reset();
287    while (streamer.hasNext()) {
288        streamer.streamNext(spl);
289    }
290    
291    //validateNames(seqCount, nameArray, nameTable);
292    
293    final MappedByteBuffer rootBuffer = channel.map(
294      FileChannel.MapMode.READ_WRITE,
295      0,
296      structDataSize
297    );
298    
299    rootBuffer.position(0);
300    rootBuffer.putInt(hashTablePos);
301    rootBuffer.putInt(hitTablePos);
302    rootBuffer.putInt(nameArrayPos);
303    rootBuffer.putInt(nameTablePos);
304    rootBuffer.putInt(wordLength);
305    rootBuffer.putInt(packingStream.toByteArray().length);
306    rootBuffer.put(packingStream.toByteArray());
307    
308    rootBuffer.force();
309    hashTable_MB.force();
310    hitTable.force();
311    nameArray_MB.force();
312    nameTable.force();
313    
314    return getDataStore(storeFile);
315  }
316  
317    private abstract class PackingListener extends SeqIOAdapter {
318        private final Packing packing;
319        private final int wordLength;
320        private final int stepSize;
321        private int pos = -1;
322        private int word = 0;
323        private int lengthFromUnknown = 0;
324
325        public PackingListener(Packing packing,
326                               int wordLength,
327                               int stepSize) 
328        {
329            this.packing = packing;
330            this.wordLength = wordLength;
331            this.stepSize = stepSize;
332        }
333
334        public void startSequence() 
335            throws ParseException
336        {
337            pos = 0;
338            word = 0;
339            lengthFromUnknown = 0;
340        }
341
342        public void endSequence()
343            throws ParseException
344        {
345            foundLength(pos);
346            pos = -1;
347        }
348
349        public void foundLength(int length)
350            throws ParseException
351        {
352        }
353
354        public abstract void processWord(int word, int pos)
355            throws ParseException;
356
357        public void addSymbols(Alphabet alpha, Symbol[] syms, int start, int length)
358            throws IllegalAlphabetException
359        {
360            if (alpha != packing.getAlphabet()) {
361                throw new IllegalAlphabetException("Alphabet " + alpha.getName() + " doesn't match packing");
362            }
363
364            int stepCounter = stepSize;
365            for (int i = start; i < (start + length); ++i) {
366                word = word >> (int) packing.wordSize();
367                try {
368                    int p = packing.pack(syms[i]);
369                    if (p < 0) {
370                        lengthFromUnknown = 0;
371                    } else {
372                        lengthFromUnknown++;
373                        word |= (int) p << ((int) (wordLength - 1) * packing.wordSize());
374                    }
375                } catch (IllegalSymbolException ex) {
376                    throw new BioRuntimeException(ex);
377                }
378
379                ++pos;
380                // System.out.println("Pos = " + pos + "        lengthFromUnknown = " + lengthFromUnknown);
381                if (--stepCounter == 0) {
382                    stepCounter = stepSize;
383                    if (lengthFromUnknown >= wordLength) {
384                        try {
385                            processWord(word, pos - wordLength + 1);
386                        } catch (ParseException ex) {
387                            throw new BioRuntimeException(ex);
388                        }
389                    }
390                }
391            }
392        }
393    }
394
395    private class FirstPassListener extends PackingListener {
396        private final IntBuffer hashTable;
397        int seqCount = 0;
398        int nameChars = 0;
399
400        FirstPassListener(Packing packing,
401                          int wordLength,
402                          int stepSize,
403                          IntBuffer hashTable) 
404        {
405            super(packing, wordLength, stepSize);
406            this.hashTable = hashTable;
407        }
408
409        public void startSequence()
410            throws ParseException
411        {
412            super.startSequence();
413            ++seqCount;
414        }
415
416        public void setName(String name) 
417            throws ParseException
418        {
419            //System.err.println(this + " setting name to " + name);
420            nameChars += name.length();
421        }
422
423        public void processWord(int word, int pos)
424            throws ParseException
425        {
426            addCount(hashTable, word);
427        }
428    }
429
430    private class SecondPassListener extends PackingListener {
431        private final IntBuffer hashTable;
432        private final IntBuffer nameArray;
433        private final MappedByteBuffer nameTable;
434        private final MappedByteBuffer hitTable;
435
436        private int seqNumber = 0;
437        private int concatOffset = 0;
438
439        private String name = ""; // fixme: we need to be cleverer
440        private int length = -1;
441
442//      public void startSequence()
443//          throws ParseException
444//      {
445//          super.startSequence();
446//          System.out.println("Starting sequence");
447//      }
448
449        SecondPassListener(Packing packing, 
450                           int wordLength, 
451                           int stepSize,
452                           IntBuffer hashTable,
453                           IntBuffer nameArray,
454                           MappedByteBuffer nameTable,
455                           MappedByteBuffer hitTable) 
456        {
457            super(packing, wordLength, stepSize);
458
459            if( (hashTable == null) ||
460                (nameArray == null) ||
461                (nameTable == null) ||
462                (hitTable  == null) )
463            {
464              throw new NullPointerException(
465                "Buffers must not be null. " +
466                "\thashTable: " + hashTable +
467                "\tnameArray: " + nameArray +
468                "\tnameTable: " + nameTable +
469                "\thitTable: " + hitTable );
470            }
471
472            this.hashTable = hashTable;
473            this.nameArray = nameArray;
474            this.nameTable = nameTable;
475            this.hitTable = hitTable;
476
477            nameTable.position(Constants.BYTES_IN_INT);
478        }
479
480        public void setName(String name) {
481            //System.err.println(this + " setting name to " + name);
482            this.name = name;
483        }
484
485        public void foundLength(int length) {
486            this.length = length;
487        }
488
489        public void endSequence() 
490            throws ParseException
491        {
492            super.endSequence();
493            
494            nameArray.put((seqNumber * 2) + 1, nameTable.position()-Constants.BYTES_IN_INT);
495            nameArray.put((seqNumber * 2) + 2, concatOffset);
496            // write sequence name length and chars into nameTable
497            nameTable.putInt(name.length());
498            for(int j = 0; j < name.length(); j++) {
499                nameTable.putChar((char) name.charAt(j));
500            }
501
502            ++seqNumber;
503            concatOffset += (length + 100);
504        }
505
506        public void processWord(int word, int pos)
507            throws ParseException
508        {
509            if (pos < 1) {
510                throw new ParseException("pos < 1");
511            }
512            writeRecord(hashTable, hitTable, pos + concatOffset, seqNumber, word);
513        }
514    }
515
516  private void addCount(IntBuffer buffer, int word) {
517    int count = buffer.get(word+1);
518    count++;
519    buffer.put(word+1, count);
520  }
521  
522  private void writeRecord(
523    IntBuffer hashTable,
524    MappedByteBuffer hitTable,
525    int offset,
526    int seqNumber,
527    int word
528  ) {
529    int kmerPointer = hashTable.get(word+1);
530    if(kmerPointer >= 0) {
531      kmerPointer += Constants.BYTES_IN_INT;
532
533      int hitCount = hitTable.getInt(kmerPointer);
534      int pos = kmerPointer + hitCount * (Constants.BYTES_IN_INT) + Constants.BYTES_IN_INT;
535      
536      hitTable.position(pos);
537      hitTable.putInt(offset);
538      hitTable.putInt(kmerPointer, hitCount + 1);
539    }
540  }
541  
542}