001package org.biojava.bio.program.ssaha;
002
003import java.io.ByteArrayInputStream;
004import java.io.File;
005import java.io.FileInputStream;
006import java.io.IOException;
007import java.io.ObjectInputStream;
008import java.nio.IntBuffer;
009import java.nio.MappedByteBuffer;
010import java.nio.channels.FileChannel;
011
012import org.biojava.bio.BioError;
013import org.biojava.bio.symbol.FiniteAlphabet;
014import org.biojava.bio.symbol.IllegalSymbolException;
015import org.biojava.bio.symbol.Packing;
016import org.biojava.bio.symbol.SymbolList;
017
018/**
019 * An implementation of DataStore that will map onto a file using the NIO
020 * constructs. You should obtain one of these by using the methods in
021 * MappedDataStoreFactory.
022 *
023 * @author Matthew Pocock
024 * @author Thomas Down
025 */
026
027public class CompactedDataStore implements DataStore {
028  private final Packing packing;
029  private final int wordLength;
030  private final IntBuffer hashTable;
031  private final MappedByteBuffer hitTable;
032  private final IntBuffer nameArray;
033  private final MappedByteBuffer nameTable;
034  private final int numSequences;
035  
036  CompactedDataStore(File dataStoreFile)
037      throws IOException 
038  {
039    FileChannel channel = new FileInputStream(dataStoreFile).getChannel();
040    
041    MappedByteBuffer rootBuffer = channel.map(
042      FileChannel.MapMode.READ_ONLY,
043      0,
044      4 * 6
045    );
046    rootBuffer.position(0);
047    
048    final int hashTablePos = rootBuffer.getInt();
049    final int hitTablePos = rootBuffer.getInt();
050    final int nameArrayPos = rootBuffer.getInt();
051    final int nameTablePos = rootBuffer.getInt();
052    wordLength = rootBuffer.getInt();
053    
054    // extend root map to include the serialized packing
055    int packingStreamLength = rootBuffer.getInt();
056    //System.out.println("hashTablePos:\t" + hashTablePos);
057    //System.out.println("hitTablePos:\t" + hitTablePos);
058    //System.out.println("nameArrayPos:\t" + nameArrayPos);
059    //System.out.println("nameTablePos:\t" + nameTablePos);
060    //System.out.println("packingStreamLength:\t" + packingStreamLength);
061    rootBuffer = channel.map(
062      FileChannel.MapMode.READ_ONLY,
063      0,
064      4 * 6 + packingStreamLength
065    );
066    rootBuffer.position(4 * 6);
067    byte[] packingBuffer = new byte[packingStreamLength];
068    rootBuffer.get(packingBuffer);
069    ByteArrayInputStream packingStream = new ByteArrayInputStream(packingBuffer);
070    ObjectInputStream packingSerializer = new ObjectInputStream(packingStream);
071    
072    try {
073      this.packing = (Packing) packingSerializer.readObject();
074    } catch (ClassNotFoundException cnfe) {
075      throw new Error("Can't restore packing", cnfe);
076    }
077    
078    // map in buffer for the hash table
079    MappedByteBuffer hashTable_MB = channel.map(
080      FileChannel.MapMode.READ_ONLY,
081      hashTablePos,
082      4
083    );
084    hashTable_MB.position(0);
085    int hashTableSize = hashTable_MB.getInt();
086    hashTable = channel.map(
087      FileChannel.MapMode.READ_ONLY,
088      hashTablePos + 4,
089      hashTableSize - 4
090    ).asIntBuffer();
091    
092    // map in buffer for hit table
093    MappedByteBuffer hitTable_MB = channel.map(
094      FileChannel.MapMode.READ_ONLY,
095      hitTablePos,
096      4
097    );
098    hitTable_MB.position(0);
099    int hitTableSize = hitTable_MB.getInt();
100    hitTable = channel.map(
101      FileChannel.MapMode.READ_ONLY,
102      hitTablePos + 4,
103      hitTableSize - 4
104    );
105    
106    // map in buffer for names array
107    MappedByteBuffer nameArray_MB = channel.map(
108      FileChannel.MapMode.READ_ONLY,
109      nameArrayPos,
110      4
111    );
112    nameArray_MB.position(0);
113    int nameArraySize = nameArray_MB.getInt();
114    numSequences = nameArraySize / 8;
115    // System.err.println("numSequences: " + numSequences);
116    nameArray = channel.map(
117      FileChannel.MapMode.READ_ONLY,
118      nameArrayPos + 4,
119      nameArraySize - 4
120    ).asIntBuffer();
121    
122    // map in buffer for names table
123    MappedByteBuffer nameTable_MB = channel.map(
124      FileChannel.MapMode.READ_ONLY,
125      nameTablePos,
126      4
127    );
128    nameTable_MB.position(0);
129    int nameTableSize = nameTable_MB.getInt();
130    nameTable = channel.map(
131      FileChannel.MapMode.READ_ONLY,
132      nameTablePos + 4,
133      nameTableSize - 4
134    );
135  }
136  
137  public FiniteAlphabet getAlphabet() {
138    return packing.getAlphabet();
139  }
140  
141  public void search(
142    String seqID,
143    SymbolList symList,
144    SearchListener listener
145  ) {
146    try {
147        int word = 0;
148        int lengthFromUnknown = 0;
149        listener.startSearch(seqID);
150        for(int pos = 1; pos <= symList.length(); pos++) {
151            word = word >> (int) packing.wordSize();
152            int p = packing.pack(symList.symbolAt(pos));
153            if (p < 0) {
154                lengthFromUnknown = 0;
155            } else {
156                lengthFromUnknown++;
157                word |= (int) p << ((int) (wordLength - 1) * packing.wordSize());
158            }
159            
160            if (lengthFromUnknown >= wordLength) {
161                fireHits(word, pos - wordLength + 1, listener);
162            }
163        }
164        listener.endSearch(seqID);
165    } catch (IllegalSymbolException ise) {
166      throw new BioError("Assertion Failure: Symbol dissapeared");
167    }
168  }
169  
170  public String seqNameForID(int id) {
171    int offset = nameArray.get(id);
172    nameTable.position(offset);
173    int length = nameTable.getInt();
174    StringBuffer sbuff = new StringBuffer(length);
175    for(int i = 0; i < length; i++) {
176      sbuff.append(nameTable.getChar());
177    }
178    return sbuff.toString();
179  }
180  
181    private int seqIDForPos(int pos) {
182        if (numSequences == 1) {
183            return 0;
184        } else {
185            int maxBound = numSequences - 1;
186            int minBound = 0;
187
188            while (true) {
189                int mid = (minBound + maxBound) / 2;
190                int offset = nameArray.get((mid * 2) + 1);
191                int endOffset = Integer.MAX_VALUE;
192                if (mid < (numSequences - 1)) {
193                    endOffset = nameArray.get((mid * 2) + 3);
194                }
195                if (pos > offset && pos < endOffset) {
196                    return mid * 2;
197                } else if (pos < offset) {
198                    maxBound = mid - 1;
199                } else if (pos > endOffset) {
200                    minBound = mid + 1;
201                } else {
202                    throw new Error("Ooops: could not locate seq name for " +
203                                    "\tpos: " + pos +
204                                    "\tmid: " + mid +
205                                    "\toffset: " + offset +
206                                    "\tendOffset: " + endOffset +
207                                    "\tminBound: " + minBound +
208                                    "\tmaxBound: " + maxBound);
209                }
210            }
211        }
212    }
213
214    private int offsetForID(int id) {
215        if (numSequences == 1) {
216            return 0;
217        } else {
218            return nameArray.get(id + 1);
219        }
220    }
221
222  public void fireHits(
223    int word,
224    int offset,
225    SearchListener listener
226  ) {
227    int hitOffset = hashTable.get(word);
228    if(hitOffset >= 0) {
229      try {
230        hitTable.position(hitOffset);
231      } catch (IllegalArgumentException e) {
232        System.out.println("word:\t" + word);
233        System.out.println("offset:\t" + offset);
234        System.out.println("hitOffset\t" + hitOffset);
235        throw e;
236        
237      }
238      int hits = hitTable.getInt();
239      
240      for(int i = 0; i < hits; i++) {
241          int pos = hitTable.getInt();
242          int id = seqIDForPos(pos);
243
244          listener.hit(
245                       id,
246                       offset,
247                       pos - offsetForID(id),
248                       wordLength
249                      );
250      }
251    } else if (hitOffset == -2) {
252        System.err.println("Hit an elided word!");
253    }
254  }
255}