package org.biojava.bio.program.ssaha;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.RandomAccessFile;
import java.nio.BufferOverflowException;
import java.nio.IntBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

import org.biojava.bio.BioException;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.SequenceIterator;
import org.biojava.bio.seq.db.SequenceDB;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.Packing;
import org.biojava.bio.symbol.PackingFactory;
import org.biojava.utils.Constants;
/**
 * <p>
 * Builder for a data store that is backed by a java.nio.MappedByteBuffer.
 * Because a single MappedByteBuffer can map at most 2^31 - 1 bytes, the
 * total size of the mapped buffer, and therefore of the hash table, cannot
 * exceed 2 gigabytes.
 * </p>
 *
 * <p>
 * The data store file has the following structure.
 * <pre>
 * file: header, hash table, nameArray, nameTable, hitTable
 *
 * header:
 *   int hashTablePos, // byte offset in file
 *   int hitTablePos,  // byte offset in file
 *   int nameArrayPos, // byte offset in file
 *   int nameTablePos, // byte offset in file
 *   int wordLength,
 *   int serializedPackingLength,
 *   byte[] serializedPacking
 *
 *  hash table:
 *    int hashTableSize, // size in bytes
 *    int[] hashTable    // one byte offset into hitTable per possible word
 *
 *  nameArray:
 *    int nameArraySize, // size in bytes
 *    int[] nameArray    // one byte offset into nameTable per sequence
 *
 *  nameTable:
 *    int nameTableSize, // size in bytes
 *    (int nameLength, char[nameLength] name) names // one entry per sequence
 *
 *  hitTable:
 *    int hitTableSize, // size in bytes
 *    hitTableRecord[] hits // one record per word kept
 *
 *  hitTableRecord:
 *    int hitCount,
 *    hitRecord[hitCount] hit
 *
 *  hit:
 *    int seqIndex, // index into nameArray
 *    int offset    // offset of the word within the sequence
 * </pre>
 * </p>
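 *
 * <p>
 * A minimal usage sketch is shown below; <code>seqDB</code> and
 * <code>dnaPacking</code> are placeholders for a sequence database and a
 * Packing obtained elsewhere.
 * <pre>
 * DataStoreFactory factory = new MappedDataStoreFactory();
 * DataStore store = factory.buildDataStore(
 *   new File("seqs.store"), // file to create
 *   seqDB,                  // a SequenceDB to index
 *   dnaPacking,             // a Packing for the sequences' alphabet
 *   10,                     // word (k-tuple) length
 *   1000                    // ignore words occurring 1000 or more times
 * );
 *
 * // later, re-open the store without rebuilding it
 * DataStore reopened = factory.getDataStore(new File("seqs.store"));
 * </pre>
 * </p>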
 *
 * @author Matthew Pocock
 */
public class MappedDataStoreFactory
implements DataStoreFactory {
  public DataStore getDataStore(File storeFile)
  throws IOException {
    return new MappedDataStore(storeFile);
  }

  public DataStore buildDataStore(
    File storeFile,
    SequenceDB seqDB,
    Packing packing,
    int wordLength,
    int threshold
  ) throws
    IllegalAlphabetException,
    IOException,
    BioException
  {
    // serialize the Packing so it can be stored in the header and
    // reconstructed when the store is re-opened
    ByteArrayOutputStream packingStream = new ByteArrayOutputStream();
    ObjectOutputStream packingSerializer = new ObjectOutputStream(packingStream);
    packingSerializer.writeObject(packing);
    packingSerializer.flush();

    // header: four section offsets, the word length, the packing length,
    // then the serialized packing itself
    final int structDataSize =
      6 * Constants.BYTES_IN_INT +
      packingStream.toByteArray().length;

    final int hashTablePos;
    final int hitTablePos;
    final int nameArrayPos;
    final int nameTablePos;

    storeFile.createNewFile();
    final RandomAccessFile store = new RandomAccessFile(storeFile, "rw");
    final FileChannel channel = store.getChannel();

    // number of distinct k-tuples (words) for this packing and word length
    //System.out.println("Word length:\t" + wordLength);
    int words = 1 << (
      (int) packing.wordSize() *
      (int) wordLength
    );
    //System.out.println("Words:\t" + words);

    hashTablePos = structDataSize;
    int hashTableSize =
      (int) Constants.BYTES_IN_INT + // hash table length
      words * (int) Constants.BYTES_IN_INT; // hash table entries

    //System.out.println("Allocated:\t" + hashTableSize);
    final MappedByteBuffer hashTable_MB = channel.map(
      FileChannel.MapMode.READ_WRITE,
      hashTablePos,
      hashTableSize
    );
    final IntBuffer hashTable = hashTable_MB.asIntBuffer();
    hashTable.put(0, hashTableSize); // write the size in bytes of the hash table

    // initialize counts to zero
    for(int i = 0; i < words; i++) {
      hashTable.put(i+1, 0);
    }
    hashTable.position(0);

    // 1st pass
    // count the occurrences of each k-tuple and add up the space required
    // for sequence names; sequences no longer than the word length are skipped
    //
    int seqCount = 0;
    int nameChars = 0;
    for(SequenceIterator i = seqDB.sequenceIterator(); i.hasNext(); ) {
      Sequence seq = i.nextSequence();
      if(seq.length() > wordLength) {
        seqCount++;
        nameChars += seq.getName().length();

        int word = PackingFactory.primeWord(seq, wordLength, packing);
        //PackingFactory.binary(word);
        addCount(hashTable, word);
        for(int j = wordLength + 1; j <= seq.length(); j++) {
          word = PackingFactory.nextWord(seq, word, j, wordLength, packing);
          //PackingFactory.binary(word);
          addCount(hashTable, word);
        }
      }
    }

    // map the space for the sequence index -> name offset array
    //
    nameArrayPos = hashTablePos + hashTableSize;
    int nameArraySize = (seqCount + 1) * Constants.BYTES_IN_INT;
    //System.out.println("seqCount:\t" + seqCount);
    //System.out.println("nameArraySize:\t" + nameArraySize);
    final MappedByteBuffer nameArray_MB = channel.map(
      FileChannel.MapMode.READ_WRITE,
      nameArrayPos,
      nameArraySize
    );
    final IntBuffer nameArray = nameArray_MB.asIntBuffer();
    nameArray.put(0, nameArraySize);

    // map the space for sequence names, stored as (int length, char[] name) records
    //
    nameTablePos = nameArrayPos + nameArraySize;
    int nameTableSize =
      Constants.BYTES_IN_INT +
      seqCount * Constants.BYTES_IN_INT +
      nameChars * Constants.BYTES_IN_CHAR;
    //System.out.println("nameTableSize:\t" + nameTableSize);
    final MappedByteBuffer nameTable = channel.map(
      FileChannel.MapMode.READ_WRITE,
      nameTablePos,
      nameTableSize
    );
    nameTable.putInt(0, nameTableSize);
    nameTable.position(Constants.BYTES_IN_INT);

    // count the k-tuples that will be kept: hitCount is the number of distinct
    // words below the threshold, kmersUsed the total number of hits to store
    //
    int kmersUsed = 0;
    int hitCount = 0;
    for(int i = 0; i < words; i++) {
      int counts = hashTable.get(i + 1);
      if(counts > 0 && counts < threshold) {
        hitCount++;
        kmersUsed += counts;
      }
    }

    // map the space for hits
    hitTablePos = nameTablePos + nameTableSize;
    long hitTableSize =
      (long) Constants.BYTES_IN_INT +                            // size
      (long) kmersUsed * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT) +  // list elements
      (long) hitCount * Constants.BYTES_IN_INT;                  // size of lists
    //System.out.println("hitTableSize:\t" + hitTableSize);
    //System.out.println("hitTableSize:\t" + (int) hitTableSize);
    //System.out.println("hitTablePos:\t" + hitTablePos);
    final MappedByteBuffer hitTable = channel.map(
      FileChannel.MapMode.READ_WRITE,
      hitTablePos,
      (int) hitTableSize
    );
    hitTable.putInt(0, (int) hitTableSize);
    hitTable.position(Constants.BYTES_IN_INT);

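    // Each kept word gets a block in the hit table laid out as
    //   int hitCount, (int seqIndex, int offset) * hitCount
    // The hash table entry for a word holds the byte offset of its block,
    // measured from just after the hit table's leading size int.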
    // write locations of hit arrays
    int hitOffset = 0;
    for(int i = 0; i < words; i++) {
      int counts = hashTable.get(i+1);
      if(counts > 0 && counts < threshold) {
        try {
          // record location of a block of the form:
          // n,(seqID,offset)1,(seqID,offset)2,...,(seqID,offset)n
          if(hitOffset < 0) {
            throw new IndexOutOfBoundsException("Hit offset negative");
          }
          hashTable.put(i + 1, hitOffset); // wire hash table to hit table
          hitTable.putInt(hitOffset + Constants.BYTES_IN_INT, 0); // initially we have no hits
          hitOffset +=
            Constants.BYTES_IN_INT +
            counts * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT);
        } catch (IndexOutOfBoundsException e) {
          System.out.println("counts:\t" + counts);
          System.out.println("word:\t" + i);
          System.out.println("hitOffset:\t" + hitOffset);
          throw e;
        }
      } else {
        // no hits, or too many - mark this word with the flag value -1
        hashTable.put(i + 1, -1);
      }
    }

    // 2nd pass
    // write sequence array and names
    // write hitTable
    int seqNumber = 0;
    nameTable.position(Constants.BYTES_IN_INT);
    for(SequenceIterator i = seqDB.sequenceIterator(); i.hasNext(); ) {
      Sequence seq = i.nextSequence();

      if(seq.length() > wordLength) {
        try {

          // write sequence name reference into nameArray
          nameArray.put(seqNumber + 1, nameTable.position() - Constants.BYTES_IN_INT);

          // write sequence name length and chars into nameTable
          String name = seq.getName();
          nameTable.putInt(name.length());
          for(int j = 0; j < name.length(); j++) {
            nameTable.putChar(name.charAt(j));
          }

          // write a (seqIndex, offset) hit record for each k-mer
          int word = PackingFactory.primeWord(seq, wordLength, packing);
          writeRecord(hashTable, hitTable, 1, seqNumber, word);
          for(int j = wordLength+1; j <= seq.length(); j++) {
            word = PackingFactory.nextWord(seq, word, j, wordLength, packing);
            writeRecord(hashTable, hitTable, j - wordLength + 1, seqNumber, word);
          }
        } catch (BufferOverflowException e) {
          System.out.println("name:\t" + seq.getName());
          System.out.println("seqNumber:\t" + seqNumber);
          System.out.println("na pos:\t" + nameArray.position());
          System.out.println("nt pos:\t" + nameTable.position());
          throw e;
        }
        seqNumber++;
      }
    }

    //validateNames(seqCount, nameArray, nameTable);

    // write the header now that all section offsets and sizes are known
    final MappedByteBuffer rootBuffer = channel.map(
      FileChannel.MapMode.READ_WRITE,
      0,
      structDataSize
    );

    rootBuffer.position(0);
    rootBuffer.putInt(hashTablePos);
    rootBuffer.putInt(hitTablePos);
    rootBuffer.putInt(nameArrayPos);
    rootBuffer.putInt(nameTablePos);
    rootBuffer.putInt(wordLength);
    rootBuffer.putInt(packingStream.toByteArray().length);
    rootBuffer.put(packingStream.toByteArray());

    // flush all mapped regions to disk before re-opening the store
    rootBuffer.force();
    hashTable_MB.force();
    hitTable.force();
    nameArray_MB.force();
    nameTable.force();

    return getDataStore(storeFile);
  }

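  /**
   * Increment the occurrence count recorded for <code>word</code>. Slot 0 of
   * the buffer holds the table size in bytes, so the count for a word lives
   * at index <code>word + 1</code>.
   */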
  private void addCount(IntBuffer buffer, int word) {
    int count = buffer.get(word+1);
    count++;
    buffer.put(word+1, count);
  }

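  /**
   * Append a (seqNumber, offset) hit record to the hit list for
   * <code>word</code> and bump that list's hit count. Words flagged with -1
   * in the hash table (no hits, or at least as many as the threshold) are
   * skipped.
   */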
  private void writeRecord(
    IntBuffer hashTable,
    MappedByteBuffer hitTable,
    int offset,
    int seqNumber,
    int word
  ) {
    int kmerPointer = hashTable.get(word+1);
    if(kmerPointer != -1) { // -1 flags a word that is not being stored
      kmerPointer += Constants.BYTES_IN_INT;

      int hitCount = hitTable.getInt(kmerPointer);
      int pos = kmerPointer
        + hitCount * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT)
        + Constants.BYTES_IN_INT;

      hitTable.position(pos);
      hitTable.putInt(seqNumber);
      hitTable.putInt(offset);
      hitTable.putInt(kmerPointer, hitCount + 1);
    }
  }

}