001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.cox.stats;
022
023import org.biojava.nbio.survival.cox.CoxInfo;
024import org.biojava.nbio.survival.cox.CoxMethod;
025import org.biojava.nbio.survival.cox.SurvivalInfo;
026
027import java.util.ArrayList;
028
029/**
030 *
031 * @author Scooter Willis 
032 */
033public class AgScore {
034
035        /**
036         *
037         * @param method
038         * @param survivalInfoList
039         * @param coxInfo
040         * @param useStrata
041         * @return
042         */
043        public static double[][] process(CoxMethod method, ArrayList<SurvivalInfo> survivalInfoList, CoxInfo coxInfo, boolean useStrata) {
044                int i, k;
045                //double temp;
046                int n = survivalInfoList.size();
047
048                ArrayList<String> variables = new ArrayList<>(coxInfo.getCoefficientsList().keySet());
049                int nvar = variables.size();
050
051
052                int dd;
053
054                double[] event = new double[n];
055                double[] start = new double[n];
056                double[] stop = new double[n];
057
058                double[] strata = new double[n];
059                double[] weights = new double[n];
060                double[] score = new double[n];
061
062                double[] a = new double[nvar];
063                double[] a2 = new double[nvar];
064                double[] mean = new double[nvar];
065                double[] mh1 = new double[nvar];
066                double[] mh2 = new double[nvar];
067                double[] mh3 = new double[nvar];
068
069                double denom = 0;
070                double time = 0;
071                double e_denom = 0;
072                double meanwt = 0;
073                double deaths = 0;
074                double risk;
075                double[][] covar = new double[nvar][n];
076                double[][] resid = new double[nvar][n];
077                double hazard;
078                double downwt, temp1, temp2, d2;
079
080
081                int person = 0;
082
083                //  n = *nx;
084                //  nvar  = *nvarx;
085                for (int p = 0; p < n; p++) {
086                        SurvivalInfo si = survivalInfoList.get(p);
087                        stop[p] = si.getTime();
088                        event[p] = si.getStatus();
089                        if (useStrata) {
090                                strata[p] = si.getStrata();
091                        } else {
092                                strata[p] = 0;
093                        }
094                        weights[p] = si.getWeight();
095                        score[p] = si.getScore();
096
097                        for (int v = 0; v < variables.size(); v++) {
098                                String variable = variables.get(v);
099                                Double value = si.getVariable(variable);
100                                covar[v][p] = value;
101                        }
102
103                }
104
105                for (person = 0; person < n;) {
106                        if (event[person] == 0) {
107                                person++;
108                        } else {
109                                /*
110                                 ** compute the mean over the risk set, also hazard at this time
111                                 */
112                                denom = 0;
113                                e_denom = 0;
114                                meanwt = 0;
115                                deaths = 0;
116                                for (i = 0; i < nvar; i++) {
117                                        a[i] = 0;
118                                        a2[i] = 0;
119                                }
120                                time = stop[person];
121                                for (k = person; k < n; k++) {
122                                        if (start[k] < time) {
123                                                risk = score[k] * weights[k];
124                                                denom += risk;
125                                                for (i = 0; i < nvar; i++) {
126                                                        a[i] = a[i] + risk * covar[i][k];
127                                                }
128                                                if (stop[k] == time && event[k] == 1) {
129                                                        deaths++;
130                                                        e_denom += risk;
131                                                        meanwt += weights[k];
132                                                        for (i = 0; i < nvar; i++) {
133                                                                a2[i] = a2[i] + risk * covar[i][k];
134                                                        }
135                                                }
136                                        }
137                                        if (strata[k] == 1) {
138                                                break;
139                                        }
140                                }
141
142                                /* add things in for everyone in the risk set*/
143                                if (deaths < 2 || method == CoxMethod.Breslow) {
144                                        /* easier case */
145                                        hazard = meanwt / denom;
146                                        for (i = 0; i < nvar; i++) {
147                                                mean[i] = a[i] / denom;
148                                        }
149                                        for (k = person; k < n; k++) {
150                                                if (start[k] < time) {
151                                                        risk = score[k];
152                                                        for (i = 0; i < nvar; i++) {
153                                                                resid[i][k] -= (covar[i][k] - mean[i]) * risk * hazard;
154                                                        }
155                                                        if (stop[k] == time) {
156                                                                person++;
157                                                                if (event[k] == 1) {
158                                                                        for (i = 0; i < nvar; i++) {
159                                                                                resid[i][k] += (covar[i][k] - mean[i]);
160                                                                        }
161                                                                }
162                                                        }
163                                                }
164                                                if (strata[k] == 1) {
165                                                        break;
166                                                }
167                                        }
168                                } else {
169                                        /*
170                                         ** If there are 3 deaths, let m1, m2, m3 be the three
171                                         **   weighted means,  h1, h2, h3 be the three hazard jumps.
172                                         ** Then temp1 = h1 + h2 + h3
173                                         **      temp2 = h1 + (2/3)h2 + (1/3)h3
174                                         **      mh1   = m1*h1 + m2*h2 + m3*h3
175                                         **      mh2   = m1*h1 + (2/3)m2*h2 + (1/3)m3*h3
176                                         **      mh3   = (1/3)*(m1+m2+m3)
177                                         */
178                                        temp1 = 0;
179                                        temp2 = 0;
180                                        for (i = 0; i < nvar; i++) {
181                                                mh1[i] = 0;
182                                                mh2[i] = 0;
183                                                mh3[i] = 0;
184                                        }
185                                        meanwt /= deaths;
186                                        for (dd = 0; dd < deaths; dd++) {
187                                                downwt = dd / deaths;
188                                                d2 = denom - downwt * e_denom;
189                                                hazard = meanwt / d2;
190                                                temp1 += hazard;
191                                                temp2 += (1 - downwt) * hazard;
192                                                for (i = 0; i < nvar; i++) {
193                                                        mean[i] = (a[i] - downwt * a2[i]) / d2;
194                                                        mh1[i] += mean[i] * hazard;
195                                                        mh2[i] += mean[i] * (1 - downwt) * hazard;
196                                                        mh3[i] += mean[i] / deaths;
197                                                }
198                                        }
199                                        for (k = person; k < n; k++) {
200                                                if (start[k] < time) {
201                                                        risk = score[k];
202                                                        if (stop[k] == time && event[k] == 1) {
203                                                                for (i = 0; i < nvar; i++) {
204                                                                        resid[i][k] += covar[i][k] - mh3[i];
205                                                                        resid[i][k] -= risk * covar[i][k] * temp2;
206                                                                        resid[i][k] += risk * mh2[i];
207                                                                }
208                                                        } else {
209                                                                for (i = 0; i < nvar; i++) {
210                                                                        resid[i][k] -= risk * (covar[i][k] * temp1 - mh1[i]);
211                                                                }
212                                                        }
213                                                }
214                                                if (strata[k] == 1) {
215                                                        break;
216                                                }
217                                        }
218                                        for (; stop[person] == time; person++) {
219                                                if (strata[person] == 1) {
220                                                        break;
221                                                }
222                                        }
223                                }
224                        }
225                }
226
227
228                //appears to be backward internally
229                double[][] flipresid = new double[n][nvar];
230
231                for (int s = 0; s < resid.length; s++) {
232                        for (int t = 0; t < resid[0].length; t++) {
233                                flipresid[t][s] = resid[s][t];
234                        }
235                }
236
237                return flipresid;
238
239        }
240
241        /**
242         * @param args the command line arguments
243         */
244        public static void main(String[] args) {
245                // TODO code application logic here
246        }
247}