001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.cox;
022
023import org.apache.commons.math.stat.correlation.Covariance;
024import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
025import org.biojava.nbio.survival.cox.matrix.Matrix;
026
027import java.util.ArrayList;
028import java.util.Collections;
029import java.util.LinkedHashMap;
030
031/**
032 *
033 * @author Scooter Willis 
034 */
035public class CoxCC {
036
037        /**
038         *
039         * @param ci
040         * @throws Exception
041         */
042        static public void process(CoxInfo ci) throws Exception {
043                ArrayList<SurvivalInfo> survivalInfoList = ci.survivalInfoList;
044                //r
045                ArrayList<String> variables = new ArrayList<>(ci.getCoefficientsList().keySet());
046
047                ArrayList<Integer> strataClass = new ArrayList<>(survivalInfoList.size());
048                double[] wt = new double[survivalInfoList.size()];
049                for (int i = 0; i < survivalInfoList.size(); i++) {
050                        SurvivalInfo si = survivalInfoList.get(i);
051                        strataClass.add(si.getStrata());
052                        wt[i] = si.getWeight();
053                }
054
055
056                double[][] r = ResidualsCoxph.process(ci, ResidualsCoxph.Type.score, false, null); // dn not use weighted
057
058                // ArrayList<String> variables = ci.survivalInfoList.get(0).getDataVariables();
059//        if (false) {
060//            for (int i = 0; i < survivalInfoList.size(); i++) {
061//                SurvivalInfo si = survivalInfoList.get(i);
062//                System.out.print("Cox cc " + si.getOrder());
063//                for (int j = 0; j < variables.size(); j++) {
064//                    System.out.print(" " + r[i][j]);
065//                }
066//                System.out.println();
067//            }
068//        }
069
070                double[][] rvar = null;
071
072                if (ci.getNaiveVariance() != null) {
073                        rvar = ci.getNaiveVariance();
074                } else {
075                        rvar = ci.getVariance();
076                }
077                //nj
078                LinkedHashMap<Integer, Double> nj = new LinkedHashMap<>();
079                Collections.sort(strataClass);
080                for (Integer value : strataClass) {
081                        Double count = nj.get(value);
082                        if (count == null) {
083                                count = 0.0;
084                        }
085                        count++;
086                        nj.put(value, count);
087                }
088                //Nj
089                LinkedHashMap<Integer, Double> Nj = new LinkedHashMap<>();
090                //N = N + Nj[key];
091                double N = 0;
092                for (int i = 0; i < survivalInfoList.size(); i++) {
093                        SurvivalInfo si = survivalInfoList.get(i);
094                        Integer strata = si.getStrata();
095                        Double weight = si.getWeight();
096                        Double sum = Nj.get(strata);
097                        if (sum == null) {
098                                sum = 0.0;
099                        }
100                        sum = sum + weight;
101                        Nj.put(strata, sum);
102
103                }
104
105                for(Double value : Nj.values()){
106                        N = N + value;
107                }
108
109                LinkedHashMap<Integer, Double> k1j = new LinkedHashMap<>();
110                for (Integer key : nj.keySet()) {
111                        double _nj = (nj.get(key)); //trying to copy what R is doing on precision
112                        double _Nj = (Nj.get(key));
113                        //         System.out.println("nj=" + _nj + " Nj=" + _Nj);
114                        k1j.put(key, _Nj * ((_Nj / _nj) - 1));
115                }
116
117                double[][] V = new double[variables.size()][variables.size()];
118
119                for (Integer i : k1j.keySet()) {
120                        //          System.out.println("Strata=" + i + " " + k1j.get(i) + " " + Nj.get(i) + " " + nj.get(i));
121                        if (nj.get(i) > 1) {
122                                LinkedHashMap<String, DescriptiveStatistics> variableStatsMap = new LinkedHashMap<>();
123
124                                for (int p = 0; p < survivalInfoList.size(); p++) {
125                                        SurvivalInfo si = survivalInfoList.get(p);
126                                        if (si.getStrata() != i) {
127                                                continue;
128                                        }
129                                        //              System.out.print(si.order + " ");
130                                        for (int col = 0; col < variables.size(); col++) {
131                                                String v = variables.get(col);
132                                                DescriptiveStatistics ds = variableStatsMap.get(v);
133                                                if (ds == null) {
134                                                        ds = new DescriptiveStatistics();
135                                                        variableStatsMap.put(v, ds);
136                                                }
137                                                ds.addValue(r[p][col]);
138                                                //                  System.out.print(si.getResidualVariable(v) + "  ");
139                                        }
140                                        //              System.out.println();
141                                }
142                                //calculate variance covariance matrix var(r[class==levels(class)[i],],use='comp')
143                                double[][] var_covar = new double[variables.size()][variables.size()];
144                                for (int m = 0; m < variables.size(); m++) {
145                                        String var_m = variables.get(m);
146                                        for (int n = 0; n < variables.size(); n++) {
147                                                String var_n = variables.get(n);
148                                                if (m == n) {
149                                                        DescriptiveStatistics ds = variableStatsMap.get(var_m);
150                                                        var_covar[m][n] = ds.getVariance();
151                                                } else {
152                                                        DescriptiveStatistics ds_m = variableStatsMap.get(var_m);
153                                                        DescriptiveStatistics ds_n = variableStatsMap.get(var_n);
154                                                        Covariance cv = new Covariance();
155                                                        double covar = cv.covariance(ds_m.getValues(), ds_n.getValues(), true);
156                                                        var_covar[m][n] = covar;
157                                                }
158                                        }
159                                }
160                 //              System.out.println();
161                 //              System.out.println("sstrat=" + i);
162                 //              StdArrayIO.print(var_covar);
163
164                                           V = Matrix.add(V, Matrix.scale(var_covar, k1j.get(i))  );
165
166                 //       for (int m = 0; m < V.length; m++) {
167                 //           for (int n = 0; n < V.length; n++) {
168                 //               V[m][n] = V[m][n] + (k1j.get(i) * var_covar[m][n]);
169                  //
170                 //           }
171                  //      }
172                        }
173                }
174                //     System.out.println("V");
175                //     StdArrayIO.print(V);
176                //     System.out.println();
177                //z$var <- rvar + rvar %*% V %*% rvar # replace variance in z
178                double[][] imat1 = Matrix.multiply(rvar, V);
179                imat1 = Matrix.multiply(imat1, rvar);
180                imat1 = Matrix.add(rvar, imat1);
181                //  System.out.println("New var");
182                //  StdArrayIO.print(imat1);
183                ci.setVariance(imat1);
184
185                //need to update walsh stats for overall model
186                CoxR.calculateWaldTestInfo(ci);
187                //per Bob/Kathryn email on 4/23/2014 in a weighted model LogRank p-value is no longer valid so should erase it
188                ci.setScoreLogrankTest(Double.NaN);
189                ci.setScoreLogrankTestpvalue(Double.NaN);
190        }
191
192        /**
193         * @param args the command line arguments
194         */
195        public static void main(String[] args) {
196                // TODO code application logic here
197        }
198}