001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.cox;
022
023import java.util.ArrayList;
024
025/**
026 *
027 * @author Scooter Willis 
028 */
029public class CoxScore {
030
031        /**
032         *
033         * @param method
034         * @param survivalInfoList
035         * @param coxInfo
036         * @param useStrata
037         * @return
038         */
039        public static double[][] process(CoxMethod method, ArrayList<SurvivalInfo> survivalInfoList, CoxInfo coxInfo, boolean useStrata) {
040                int i, j, k;
041                double temp;
042                int n = survivalInfoList.size();
043
044                ArrayList<String> variables = new ArrayList<>(coxInfo.getCoefficientsList().keySet());
045                int nvar = variables.size();
046
047                double deaths;
048                int dd;
049                double[] time = new double[n];
050                double[] status = new double[n];
051                double[] strata = new double[n];
052                double[] weights = new double[n];
053                double[] score = new double[n];
054                double[] a = new double[nvar];
055                double[] a2 = new double[nvar];
056                double denom = 0, e_denom;
057                double risk;
058                double[][] covar = new double[nvar][n];
059                double[][] resid = new double[nvar][n];
060                double hazard, meanwt;
061                double downwt, temp2;
062                double mean;
063
064                //  n = *nx;
065                //  nvar  = *nvarx;
066                for (int p = 0; p < n; p++) {
067                        SurvivalInfo si = survivalInfoList.get(p);
068                        time[p] = si.getTime();
069                        status[p] = si.getStatus();
070                        if (useStrata) {
071                                strata[p] = si.getStrata();
072                        } else {
073                                strata[p] = 0;
074                        }
075                        weights[p] = si.getWeight();
076                        score[p] = si.getScore();
077
078                        for(int v = 0; v < variables.size(); v++){
079                                String variable = variables.get(v);
080                                Double value = si.getVariable(variable);
081                                covar[v][p] = value;
082                        }
083
084                }
085
086
087
088                //  a = scratch;
089                //  a2 = a+nvar;
090        /*
091                 **  Set up the ragged array
092                 */
093                //   covar=  dmatrix(covar2, n, nvar);
094                //   resid=  dmatrix(resid2, n, nvar);
095
096                e_denom = 0;
097                deaths = 0;
098                meanwt = 0;
099                for (i = 0; i < nvar; i++) {
100                        a2[i] = 0;
101                }
102                strata[n - 1] = 1;  /*failsafe */
103                for (i = n - 1; i >= 0; i--) {
104                        if (strata[i] == 1) {
105                                denom = 0;
106                                for (j = 0; j < nvar; j++) {
107                                        a[j] = 0;
108                                }
109                        }
110
111                        risk = score[i] * weights[i];
112                        denom += risk;
113                        if (status[i] == 1) {
114                                deaths++;
115                                e_denom += risk;
116                                meanwt += weights[i];
117                                for (j = 0; j < nvar; j++) {
118                                        a2[j] += risk * covar[j][i];
119                                }
120                        }
121                        for (j = 0; j < nvar; j++) {
122                                a[j] += risk * covar[j][i];
123                                resid[j][i] = 0;
124                        }
125
126                        if (deaths > 0 && (i == 0 || strata[i - 1] == 1 || time[i] != time[i - 1])) {
127                                /* last obs of a set of tied death times */
128                                if (deaths < 2 || method == CoxMethod.Breslow) {
129                                        hazard = meanwt / denom;
130                                        for (j = 0; j < nvar; j++) {
131                                                temp = (a[j] / denom);     /* xbar */
132                                                for (k = i; k < n; k++) {
133                                                        temp2 = covar[j][k] - temp;
134                                                        if (time[k] == time[i] && status[k] == 1) {
135                                                                resid[j][k] += temp2;
136                                                        }
137                                                        resid[j][k] -= temp2 * score[k] * hazard;
138                                                        if (strata[k] == 1) {
139                                                                break;
140                                                        }
141                                                }
142                                        }
143                                } else {  /* the harder case */
144                                        meanwt /= deaths;
145                                        for (dd = 0; dd < deaths; dd++) {
146                                                downwt = dd / deaths;
147                                                temp = denom - downwt * e_denom;
148                                                hazard = meanwt / temp;
149                                                for (j = 0; j < nvar; j++) {
150                                                        mean = (a[j] - downwt * a2[j]) / temp;
151                                                        for (k = i; k < n; k++) {
152                                                                temp2 = covar[j][k] - mean;
153                                                                if (time[k] == time[i] && status[k] == 1) {
154                                                                        resid[j][k] += temp2 / deaths;
155                                                                        resid[j][k] -= temp2 * score[k] * hazard
156                                                                                        * (1 - downwt);
157                                                                } else {
158                                                                        resid[j][k] -= temp2 * score[k] * hazard;
159                                                                }
160                                                                if (strata[k] == 1) {
161                                                                        break;
162                                                                }
163                                                        }
164                                                }
165                                        }
166                                }
167                                e_denom = 0;
168                                deaths = 0;
169                                meanwt = 0;
170                                for (j = 0; j < nvar; j++) {
171                                        a2[j] = 0;
172                                }
173                        }
174                }
175
176                for (int p = 0; p < n; p++) {
177                        SurvivalInfo si = survivalInfoList.get(p);
178                        for (int v = 0; v < variables.size(); v++) {
179                                si.setResidualVariable(variables.get(v), resid[v][p]);
180                        }
181
182                }
183
184                //appears to be backward internally
185                double[][] flipresid = new double[n][nvar];
186
187                for(int s = 0; s < resid.length; s++){
188                        for(int t = 0; t  < resid[0].length; t++){
189                                flipresid[t][s] = resid[s][t];
190                        }
191                }
192
193                return flipresid;
194
195        }
196
197        /**
198         * @param args the command line arguments
199         */
200        public static void main(String[] args) {
201                // TODO code application logic here
202        }
203}