001package org.biojava.bio.program.ssaha; 002 003import java.io.ByteArrayOutputStream; 004import java.io.File; 005import java.io.IOException; 006import java.io.ObjectOutputStream; 007import java.io.RandomAccessFile; 008import java.nio.BufferOverflowException; 009import java.nio.IntBuffer; 010import java.nio.MappedByteBuffer; 011import java.nio.channels.FileChannel; 012 013import org.biojava.bio.BioException; 014import org.biojava.bio.seq.Sequence; 015import org.biojava.bio.seq.SequenceIterator; 016import org.biojava.bio.seq.db.SequenceDB; 017import org.biojava.bio.symbol.IllegalAlphabetException; 018import org.biojava.bio.symbol.Packing; 019import org.biojava.bio.symbol.PackingFactory; 020import org.biojava.utils.Constants; 021 022/** 023 * <p> 024 * Builder for a data store that is backed by a java.nio.MappedByteBuffer. 025 * This has a limitation that the total size of the mapped buffer and 026 * therefore the hash table can not exceed 2 gigs. 027 * </p> 028 * 029 * <p> 030 * The data store file has the following structure. 031 * <pre> 032 * file: header, hash table, nameArray, nameTable, hitTable 033 * 034 * header: 035 * int hashTablePos, // byte offset in file 036 * int hitTablePos, // byte offset in file 037 * int nameArrayPos, // byte offset in file 038 * int nameTablePos, // byte offset in file 039 * int wordLength, 040 * int serializedPackingLength, 041 * byte[] serializedPacking 042 * 043 * hash table: 044 * int hashTableLength, 045 * int[hashTableLength] hits // index into hitTable 046 * 047 * nameArray: 048 * int nameArrayLength, 049 * int[nameArrayLength] nameArray // byte offset into nameTable 050 * 051 * nameTable: 052 * int nameTableSize, // size in bytes 053 * (short nameLength, char[nameLength] name)[nameTableSize] names 054 * 055 * hitTable: 056 * int hitTableSize, // size in bytes 057 * hitTableRecord[hitTableSize] hits 058 * 059 * hitTableRecord: 060 * int hitCount, 061 * hitRecord[hitCount] hit 062 * 063 * hit: 064 * int seqIndex, // index into nameArray 065 * int offset // offset into the sequence 066 * </pre> 067 * </p> 068 * 069 * @author Matthew Pocock 070 */ 071public class MappedDataStoreFactory 072implements DataStoreFactory { 073 public DataStore getDataStore(File storeFile) 074 throws IOException { 075 return new MappedDataStore(storeFile); 076 } 077 078 public DataStore buildDataStore( 079 File storeFile, 080 SequenceDB seqDB, 081 Packing packing, 082 int wordLength, 083 int threshold 084 ) throws 085 IllegalAlphabetException, 086 IOException, 087 BioException 088 { 089 ByteArrayOutputStream packingStream = new ByteArrayOutputStream(); 090 ObjectOutputStream packingSerializer = new ObjectOutputStream(packingStream); 091 packingSerializer.writeObject(packing); 092 packingSerializer.flush(); 093 094 final int structDataSize = 095 6 * Constants.BYTES_IN_INT + 096 packingStream.toByteArray().length; 097 098 final int hashTablePos; 099 final int hitTablePos; 100 final int nameArrayPos; 101 final int nameTablePos; 102 103 storeFile.createNewFile(); 104 final RandomAccessFile store = new RandomAccessFile(storeFile, "rw"); 105 final FileChannel channel = store.getChannel(); 106 107 // allocate array for k-tuple -> hit list 108 //System.out.println("Word length:\t" + wordLength); 109 int words = 1 << ( 110 (int) packing.wordSize() * 111 (int) wordLength 112 ); 113 //System.out.println("Words:\t" + words); 114 115 hashTablePos = structDataSize; 116 int hashTableSize = 117 (int) Constants.BYTES_IN_INT + // hash table length 118 words * (int) Constants.BYTES_IN_INT; // hash table entries 119 120 //System.out.println("Allocated:\t" + hashTableSize); 121 final MappedByteBuffer hashTable_MB = channel.map( 122 FileChannel.MapMode.READ_WRITE, 123 hashTablePos, 124 hashTableSize 125 ); 126 final IntBuffer hashTable = hashTable_MB.asIntBuffer(); 127 hashTable.put(0, hashTableSize); // write length of k-tuple array 128 129 // initialize counts to zero 130 for(int i = 0; i < words; i++) { 131 hashTable.put(i+1, 0); 132 } 133 hashTable.position(0); 134 135 // 1st pass 136 // writes counts as ints for each k-tuple 137 // count up the space required for sequence names 138 // 139 int seqCount = 0; 140 int nameChars = 0; 141 for(SequenceIterator i = seqDB.sequenceIterator(); i.hasNext(); ) { 142 Sequence seq = i.nextSequence(); 143 if(seq.length() > wordLength) { 144 seqCount++; 145 nameChars += seq.getName().length(); 146 147 int word = PackingFactory.primeWord(seq, wordLength, packing); 148 //PackingFactory.binary(word); 149 addCount(hashTable, word); 150 for(int j = wordLength + 1; j <= seq.length(); j++) { 151 word = PackingFactory.nextWord(seq, word, j, wordLength, packing); 152 //PackingFactory.binary(word); 153 addCount(hashTable, word); 154 } 155 } 156 } 157 158 // map the space for sequence index->name 159 // 160 nameArrayPos = hashTablePos + hashTableSize; 161 int nameArraySize = (seqCount + 1) * Constants.BYTES_IN_INT; 162 //System.out.println("seqCount:\t" + seqCount); 163 //System.out.println("nameArraySize:\t" + nameArraySize); 164 final MappedByteBuffer nameArray_MB = channel.map( 165 FileChannel.MapMode.READ_WRITE, 166 nameArrayPos, 167 nameArraySize 168 ); 169 final IntBuffer nameArray = nameArray_MB.asIntBuffer(); 170 nameArray.put(0, nameArraySize); 171 172 // map the space for sequence names as short length, char* name 173 // 174 nameTablePos = nameArrayPos + nameArraySize; 175 int nameTableSize = 176 Constants.BYTES_IN_INT + 177 seqCount * Constants.BYTES_IN_INT + 178 nameChars * Constants.BYTES_IN_CHAR; 179 //System.out.println("nameTableSize:\t" + nameTableSize); 180 final MappedByteBuffer nameTable = channel.map( 181 FileChannel.MapMode.READ_WRITE, 182 nameTablePos, 183 nameTableSize 184 ); 185 nameTable.putInt(0, nameTableSize); 186 nameTable.position(Constants.BYTES_IN_INT); 187 188 // add up the number of k-tuples 189 // 190 int kmersUsed = 0; 191 int hitCount = 0; 192 for(int i = 0; i < words; i++) { 193 int counts = hashTable.get(i + 1); 194 if(counts > 0 && counts < threshold) { 195 hitCount++; 196 kmersUsed += counts; 197 } 198 } 199 200 // map the space for hits 201 hitTablePos = nameTablePos + nameTableSize; 202 long hitTableSize = 203 (long) Constants.BYTES_IN_INT + // size 204 (long) kmersUsed * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT) + // list elements 205 (long) hitCount * Constants.BYTES_IN_INT; // size of lists 206 //System.out.println("hitTableSize:\t" + hitTableSize); 207 //System.out.println("hitTableSize:\t" + (int) hitTableSize); 208 //System.out.println("hitTablePos:\t" + hitTablePos); 209 final MappedByteBuffer hitTable = channel.map( 210 FileChannel.MapMode.READ_WRITE, 211 hitTablePos, 212 (int) hitTableSize 213 ); 214 hitTable.putInt(0, (int) hitTableSize); 215 hitTable.position(Constants.BYTES_IN_INT); 216 217 // write locations of hit arrays 218 int hitOffset = 0; 219 for(int i = 0; i < words; i++) { 220 int counts = hashTable.get(i+1); 221 if(counts > 0 && counts < threshold) { 222 try { 223 // record location of a block of the form: 224 // n,(seqID,offset)1,(seqID,offset)2,...,(seqID,offset)n 225 if(hitOffset < 0) { 226 throw new IndexOutOfBoundsException("Hit offset negative"); 227 } 228 hashTable.put(i + 1, hitOffset); // wire hash table to hit table 229 hitTable.putInt(hitOffset + Constants.BYTES_IN_INT, 0); // initialy we have no hits 230 hitOffset += 231 Constants.BYTES_IN_INT + 232 counts * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT); 233 } catch (IndexOutOfBoundsException e) { 234 System.out.println("counts:\t" + counts); 235 System.out.println("word:\t" + i); 236 System.out.println("hitOffset:\t" + hitOffset); 237 throw e; 238 } 239 } else { 240 // too many hits - set the number of hits to the flag value -1 241 hashTable.put(i + 1, -1); 242 } 243 } 244 245 // 2nd parse 246 // write sequence array and names 247 // write hitTable 248 int seqNumber = 0; 249 nameTable.position(Constants.BYTES_IN_INT); 250 for(SequenceIterator i = seqDB.sequenceIterator(); i.hasNext(); ) { 251 Sequence seq = i.nextSequence(); 252 253 if(seq.length() > wordLength) { 254 try { 255 256 // write sequence name reference into nameArray 257 nameArray.put(seqNumber + 1, nameTable.position()-Constants.BYTES_IN_INT); 258 259 // write sequence name length and chars into nameTable 260 String name = seq.getName(); 261 nameTable.putInt(name.length()); 262 for(int j = 0; j < name.length(); j++) { 263 nameTable.putChar((char) name.charAt(j)); 264 } 265 266 // write k-mer seq,offset 267 int word = PackingFactory.primeWord(seq, wordLength, packing); 268 writeRecord(hashTable, hitTable, 1, seqNumber, word); 269 for(int j = wordLength+1; j <= seq.length(); j++) { 270 word = PackingFactory.nextWord(seq, word, j, wordLength, packing); 271 writeRecord(hashTable, hitTable, j - wordLength + 1, seqNumber, word); 272 } 273 } catch (BufferOverflowException e) { 274 System.out.println("name:\t" + seq.getName()); 275 System.out.println("seqNumber:\t" + seqNumber); 276 System.out.println("na pos:\t" + nameArray.position()); 277 System.out.println("nt pos:\t" + nameTable.position()); 278 throw e; 279 } 280 seqNumber++; 281 } 282 } 283 284 //validateNames(seqCount, nameArray, nameTable); 285 286 final MappedByteBuffer rootBuffer = channel.map( 287 FileChannel.MapMode.READ_WRITE, 288 0, 289 structDataSize 290 ); 291 292 rootBuffer.position(0); 293 rootBuffer.putInt(hashTablePos); 294 rootBuffer.putInt(hitTablePos); 295 rootBuffer.putInt(nameArrayPos); 296 rootBuffer.putInt(nameTablePos); 297 rootBuffer.putInt(wordLength); 298 rootBuffer.putInt(packingStream.toByteArray().length); 299 rootBuffer.put(packingStream.toByteArray()); 300 301 rootBuffer.force(); 302 hashTable_MB.force(); 303 hitTable.force(); 304 nameArray_MB.force(); 305 nameTable.force(); 306 307 return getDataStore(storeFile); 308 } 309 310 private void addCount(IntBuffer buffer, int word) { 311 int count = buffer.get(word+1); 312 count++; 313 buffer.put(word+1, count); 314 } 315 316 private void writeRecord( 317 IntBuffer hashTable, 318 MappedByteBuffer hitTable, 319 int offset, 320 int seqNumber, 321 int word 322 ) { 323 int kmerPointer = hashTable.get(word+1); 324 if(kmerPointer != -1) { 325 kmerPointer += Constants.BYTES_IN_INT; 326 327 int hitCount = hitTable.getInt(kmerPointer); 328 int pos = kmerPointer + hitCount * (Constants.BYTES_IN_INT + Constants.BYTES_IN_INT) + Constants.BYTES_IN_INT; 329 330 hitTable.position(pos); 331 hitTable.putInt(seqNumber); 332 hitTable.putInt(offset); 333 hitTable.putInt(kmerPointer, hitCount + 1); 334 } 335 } 336 337}