001package org.biojavax.bio.phylo;
002
003import java.io.*;
004import java.lang.*;
005import java.util.*;
006import java.util.ArrayList;
007import java.util.List;
008
009import org.biojavax.bio.phylo.io.nexus.*;
010import org.jgrapht.*;
011import org.jgrapht.graph.*;
012
013
014/*
015  *   Phylogeny reconstruction methods based on distance
016  *
017  *   @author Bohyun Lee
018  */
019
020public class DistanceBasedTreeMethod {
021        
022         /*  
023                *
024                *              generating a tree graph that has been generated by UPGMA model
025                *
026                *              @ param t
027          *                     Nexus TaxaBlock,  contains number of taxa, taxa labels
028          *
029                *              @ param ch
030                *                               Nexus CharactersBlock, contains sequence information
031          *
032          *         @ returns   the generated graph in JGraphT type
033          */
034        
035        public static WeightedGraph<String, DefaultWeightedEdge> Upgma(TaxaBlock t, CharactersBlock ch){
036        
037                String v1, v2, v3;
038                int index_x = 0, index_y = 0, p_index = 0;
039                WeightedGraph<String, DefaultWeightedEdge> jgrapht =  new SimpleWeightedGraph<String, DefaultWeightedEdge>(DefaultWeightedEdge.class);
040        
041                int NTax = t.getDimensionsNTax();
042                List labels = t.getTaxLabels();
043                
044                String [] seq = new String[NTax];
045                double [][] distance = new double[NTax][];
046
047                for(int i = 0; i < NTax; i++){
048                        seq[i] = "";
049                        distance[i] = new double[NTax];
050                }
051
052                for (Iterator i = labels.iterator(); i.hasNext(); ) {
053                        String taxa = (String)i.next();
054                      List matrix = ch.getMatrixData(taxa);
055                        
056                      for (Iterator j = matrix.iterator(); j.hasNext(); ) {                     
057                                Object elem = j.next();
058                                
059                                if (elem instanceof Set) {
060                        
061                                        // This is a curly-braces {} enclosed
062                                                // set of values from the matrix.
063                                                 Set data = (Set)elem;
064                                
065                                } else if (elem instanceof List) {
066                              
067                                        // This is a round-braces () enclosed
068                                                // set of values from the matrix.
069                              
070                                        List data = (List)elem;
071                                } else {
072                                                      // Assume it's a string.
073                             
074                                         String data = elem.toString();
075                                          
076                                        if(data != null && data != " ")
077                                                seq[labels.indexOf(taxa)] += data;
078                                }
079                        }
080                }
081        
082                // build initial distance matrix
083                for( int i = 0; i < NTax; i++){
084                        for(int j = 0; j < NTax; j++){
085                                if(i == j) 
086                                        distance[i][j] = 0.0;
087                                else
088                                        distance[i][j] = MultipleHitCorrection.JukesCantor(seq[i], seq[j]);     
089                        }
090                }
091                
092                do{
093                        //find minimum distance pair
094                        double min_d = distance[0][1];
095                        for( int i = 0; i < NTax; i++){
096                                for(int j = i + 1; j < NTax; j++){
097                                        if( min_d >= distance[i][j]){ 
098                                                min_d = distance[i][j];
099                                                index_x = i;
100                                                index_y = j;
101                                        }
102                                }
103                        }
104
105                        // build a sub-tree by using jgrapht
106                        v1 = (String) labels.get(index_x);
107                        v2 = "p" + p_index;
108                        v3 = (String) labels.get(index_y);
109                        
110                        jgrapht.addVertex(v1);
111                        jgrapht.addVertex(v2);
112                        jgrapht.addVertex(v3);  
113                        jgrapht.addEdge(v1,v2);
114                        jgrapht.addEdge(v2,v3); 
115
116                        p_index++;
117                        
118                        //System.out.println(jgrapht.toString());
119
120                        //collapse a min_distance pair and re-build distance matrix
121                        for(int i = 0; i < NTax; i++){
122                                for(int j = i; j < NTax; j++){
123                                        if(i == j){
124                                                distance[i][j] = 0.0;
125                                        }else if(i == index_x && j == index_y){
126                                                for(int k = j+1; k < NTax; k++){
127                                        
128                                                        distance[i][j] =  (distance[k][i] + distance[k][j])/2;
129                                                        distance[j][i] = distance[i][j];
130                                                        labels.set(i, (Object) v2);
131                                                        labels.set(j, labels.get(k));
132                                                }
133                                        
134                                                labels.set(NTax-1, (Object) null);
135                                        }else if(j == index_x){
136                                                for(int k = j+1; k < NTax; k++){
137                                                        if(k == index_y){
138                                                                distance[i][j] =  (distance[i][j] + distance[i][k])/2;
139                                                                distance[j][i] = distance[i][j];
140                                                                labels.set(index_x, (Object) v2);
141                                                                labels.set(index_y, (Object) null);
142                                                        }
143                                                }
144                                        }
145                                        
146                                }
147                        }
148
149                        NTax--;
150
151                //iterate until tree is completed!
152                }while(NTax > 1);               
153                
154                return jgrapht;
155        }       
156
157        
158        /*  
159                *
160                *              generating a tree graph that has been generated by Neighbor-Joining model
161                *
162                *              @ param t
163          *                     Nexus TaxaBlock,  contains number of taxa, taxa labels
164          *
165                *              @ param ch
166                *                               Nexus CharactersBlock, contains sequence information
167          *
168          *         @ returns   the generated graph in JGraphT type
169          */
170        
171
172        public static WeightedGraph<String, DefaultWeightedEdge> NeighborJoining(TaxaBlock t, CharactersBlock ch){
173
174                String v1, v2, v3;
175                int index_x = 0, index_y = 0, p_index = 0;
176                WeightedGraph<String, DefaultWeightedEdge> jgrapht =  new SimpleWeightedGraph<String, DefaultWeightedEdge>(DefaultWeightedEdge.class);
177        
178                int NTax = t.getDimensionsNTax();
179                List labels = t.getTaxLabels();
180                
181                String [] seq = new String[NTax];
182                double []net_divergence = new double[NTax];
183                double [][] raw_distance = new double[NTax][];
184                double [][] distance = new double[NTax][];
185
186                for(int i = 0; i < NTax; i++){
187                        seq[i] = "";
188                        raw_distance[i] = new double[NTax];
189                        distance[i] = new double[NTax];
190                }
191
192                for (Iterator i = labels.iterator(); i.hasNext(); ) {
193                        String taxa = (String)i.next();
194                      List matrix = ch.getMatrixData(taxa);
195                        
196                      for (Iterator j = matrix.iterator(); j.hasNext(); ) {                     
197                                Object elem = j.next();
198                                
199                                if (elem instanceof Set) {
200                                        // This is a curly-braces {} enclosed
201                                                         // set of values from the matrix.
202                                        
203                                        Set data = (Set)elem;
204                        
205                                } else if (elem instanceof List) {
206                                        
207                                        // This is a round-braces () enclosed
208                                                // set of values from the matrix.
209                                        
210                                        List data = (List)elem;
211                        
212                                } else {
213                                
214                                          // Assume it's a string.
215                                                  String data = elem.toString();
216                                          
217                                         if(data != null && data != " ")
218                                                seq[labels.indexOf(taxa)] += data;
219                                }
220                        }
221                }
222                
223                
224                // build initial distance matrix
225                for( int i = 0; i < NTax; i++){
226                        for(int j = 0; j< NTax; j++){
227                                if(i == j) 
228                                        raw_distance[i][j] = 0.0;
229                                else
230                                        raw_distance[i][j] = MultipleHitCorrection.JukesCantor(seq[i], seq[j]); 
231                        
232                                net_divergence[i] =+ raw_distance[i][j];
233                        }
234                }
235                
236                //iterate until tree is completed!
237                do{
238                        // calculate distance matrix from raw_distances & net divergence
239                        for(int i = 0; i < NTax; i++){
240                                for(int j = 0; j < NTax; j++){
241                                        if(i == j)
242                                                distance[i][j] = 0.0;
243                                        else
244                                                distance[i][j] = raw_distance[i][j] - ((net_divergence[i] + net_divergence[j])/2) ;
245                                }
246                        }
247                        
248                        //find minimum distance pair
249                        double min_d = distance[0][1];
250                        for( int i = 0; i < NTax; i++){
251                                for(int j = i + 1; j < NTax; j++){
252                                        if( min_d >= distance[i][j]){ 
253                                                min_d = distance[i][j];
254                                                index_x = i;
255                                                index_y = j;
256                                        }
257                                }
258                        }
259
260                        // build a sub-tree by using jgrapht
261                        v1 = (String) labels.get(index_x);
262                        v2 = "p" + p_index;
263                        v3 = (String) labels.get(index_y);
264                        
265                        jgrapht.addVertex(v1);
266                        jgrapht.addVertex(v2);
267                        jgrapht.addVertex(v3);  
268                        jgrapht.addEdge(v1,v2);
269                        jgrapht.addEdge(v2,v3);
270                        
271                        //adding weight to the edge
272                        jgrapht.setEdgeWeight(jgrapht.getEdge(v1,v2), ((raw_distance[index_x][index_y]/2) + (net_divergence[index_x] - net_divergence[index_y])/(2*(NTax-2))) );
273                        jgrapht.setEdgeWeight(jgrapht.getEdge(v2,v3), raw_distance[index_x][index_y] - ((raw_distance[index_x][index_y]/2) + (net_divergence[index_x] - net_divergence[index_y])/(2*(NTax-2))) );               
274                        
275                        p_index++;
276                        
277                        //System.out.println(jgrapht.toString());
278
279                        //collapse a min_distance pair and re-build distance matrix
280                        for(int i = 0; i < NTax; i++){
281                                for(int j = i; j < NTax; j++){
282                                        if(i == j){
283                                                distance[i][j] = 0.0;
284                                        }else if(i == index_x && j == index_y){
285                                                for(int k = j+1; k < NTax; k++){
286                                                        raw_distance[i][j] =  (raw_distance[k][i] + raw_distance[k][j] - raw_distance[index_x][index_y])/2;
287                                                        raw_distance[j][i] = raw_distance[i][j];
288                                                        labels.set(i, (Object) v2);
289                                                        labels.set(j, (Object) labels.get(k));
290                                                }
291                                                labels.set(NTax-1, (Object) null);
292                                        }else if(j == index_x){
293                                                for(int k = j+1; k < NTax; k++){
294                                                        if(k == index_y){
295                                                                raw_distance[i][j] =  (raw_distance[i][j] + raw_distance[i][k] -raw_distance[index_x][index_y])/2;
296                                                                raw_distance[j][i] = raw_distance[i][j];
297                                                                labels.set(index_x, (Object) v2);
298                                                                labels.set(index_y, (Object) null);
299                                                        }
300                                                }
301                                        }
302                                        
303                                }
304                        }
305
306                        NTax--;
307
308                //iterate until tree is completed!
309                }while(NTax > 1);
310        return jgrapht;                 
311        }       
312}
313