001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.cox;
022
023import org.biojava.nbio.survival.cox.matrix.Matrix;
024import org.biojava.nbio.survival.cox.stats.ChiSq;
025import org.biojava.nbio.survival.cox.stats.Cholesky2;
026import org.biojava.nbio.survival.data.WorkSheet;
027
028import java.io.InputStream;
029import java.util.ArrayList;
030import java.util.Collections;
031
032/**
 *   This is a port of the R survival code used for doing Cox Regression. The algorithm was a fairly easy port from C code to Java, where the challenge was
 *   making the code a little more object friendly. In the R code everything is passed around as an array and a large portion of the code is spent extracting
 *   data from the array for use in different calculations. By organizing the data in a class for each data point, it was possible to simplify much of the code.
 *   Not all variants of the different methods that you can select for doing various statistical calculations are implemented. It wouldn't be difficult to go back
 *   and add them if they are important.
038 *
 *<p>In R you can pass in different parameters to override defaults, which requires parsing of the parameters. The Java code tries to be a little more exact
 *   in the code related to parameters, where using strata, weighting, robust and cluster are advanced options. Additionally, code from Bob Gray is implemented
 *   to do variance correction when using weighted parameters in a data set.
042 *   /Users/Scooter/NetBeansProjects/biojava3-survival/docs/wtexamples.docx
043 *
044 *<p>The CoxHelper class is meant to hide some of the implementation details.
045 *
046 *<p>Issues
047 *<ul>
048 *<li>sign in CoxMart?
049 *<li>double toler_chol = 1.818989e-12; Different value for some reason
050 *<li>In robust linear_predictor set to 0 which implies score = 1 but previous score value doesn't get reset
051 *</ul>
052 *
053 *  Cox regression fit, replacement for coxfit2 in order
 *    to be more frugal about memory: specifically that we
055 *    don't make copies of the input data.
056 *
057 *
058 *
059 *
060 *
061 * <p>
062 * the input parameters are
063 *
064 * <pre>
065 *      maxiter      :number of iterations
066 *      time(n)      :time of status or censoring for person i
067 *      status(n)    :status for the ith person    1=dead , 0=censored
068 *      covar(nv,n)  :covariates for person i.
069 *                       Note that S sends this in column major order.
070 *      strata(n)    :marks the strata.  Will be 1 if this person is the
071 *                      last one in a strata.  If there are no strata, the
072 *                      vector can be identically zero, since the nth person's
073 *                      value is always assumed to be = to 1.
074 *      offset(n)    :offset for the linear predictor
075 *      weights(n)   :case weights
076 *      init         :initial estimate for the coefficients
077 *      eps          :tolerance for convergence.  Iteration continues until
078 *                      the percent change in loglikelihood is <= eps.
 *      chol_tol     : tolerance for the Cholesky decomposition
080 *      method       : 0=Breslow, 1=Efron
081 *      doscale      : 0=don't scale the X matrix, 1=scale the X matrix
082 * </pre>
083 * returned parameters
084 * <pre>
085 *      means(nv)    : vector of column means of X
086 *      beta(nv)     :the vector of answers (at start contains initial est)
087 *      u(nv)        :score vector
088 *      imat(nv,nv)  :the variance matrix at beta=final
089 *                     (returned as a vector)
090 *      loglik(2)    :loglik at beta=initial values, at beta=final
091 *      sctest       :the score test at beta=initial
092 *      flag         :success flag  1000  did not converge
093 *                                  1 to nvar: rank of the solution
094 *      iterations         :actual number of iterations used
095 * </pre>
096 * work arrays
097 * <pre>
098 *      mark(n)
099 *      wtave(n)
100 *      a(nvar), a2(nvar)
101 *      cmat(nvar,nvar)       ragged array
102 *      cmat2(nvar,nvar)
103 *      newbeta(nvar)         always contains the "next iteration"
104 * </pre>
105 * calls functions:  cholesky2, chsolve2, chinv2
106 * <p>
107 * the data must be sorted by ascending time within strata
108 *
109 * @author Scooter Willis <willishf at gmail dot com>
110 */
111public class CoxR {
112
113        /**
114         *
115         * @param variables
116         * @param DataT
117         * @param useStrata
118         * @param useWeighted
119         * @param robust
120         * @param cluster
121         * @return
122         * @throws Exception
123         */
124        public CoxInfo process(ArrayList<String> variables, ArrayList<SurvivalInfo> DataT, boolean useStrata, boolean useWeighted, boolean robust, boolean cluster) throws Exception {
125                //from coxph.control.S
126                int maxiter2 = 20;
127                double eps2 = 1e-9;
128                double toler2 = Math.pow(eps2, .75);
129                int doscale2 = 1;
130                //int method2 = 0;
131                double[] beta = new double[variables.size()];
132                return process(variables, DataT, maxiter2, CoxMethod.Efron, eps2, toler2, beta, doscale2, useStrata, useWeighted, robust, cluster);
133
134        }
135
136        /**
137         *
138         * @param variables
139         * @param data
140         * @param maxiter
141         * @param method
142         * @param eps
143         * @param toler
144         * @param beta
145         * @param doscale
146         * @param useStrata
147         * @param useWeighted
148         * @param robust
149         * @param cluster
150         * @return
151         * @throws Exception
152         */
153        public CoxInfo process(ArrayList<String> variables, ArrayList<SurvivalInfo> data, int maxiter, CoxMethod method, double eps, double toler, double[] beta, int doscale, boolean useStrata, boolean useWeighted, boolean robust, boolean cluster) throws Exception {
154                //make sure data is converted to numbers if labels are used
155                SurvivalInfoHelper.categorizeData(data);
156                //create variables if testing for interaction
157                for (String variable : variables) {
158                        if (variable.indexOf(":") != -1) {
159                                String[] d = variable.split(":");
160                                SurvivalInfoHelper.addInteraction(d[0], d[1], data);
161                        }
162                }
163
164                Collections.sort(data);
165                // Collections.reverse(data);
166                CoxInfo coxInfo = new CoxInfo();
167                coxInfo.setSurvivalInfoList(data);
168
169
170
171                int i, j, k, person;
172                boolean gotofinish = false;
173                double cmat[][], imat[][];  /*ragged arrays covar[][], */
174                double wtave;
175                double a[], newbeta[];
176                double a2[], cmat2[][];
177                double scale[];
178                double denom = 0, zbeta, risk;
179                double temp, temp2;
180                int ndead;  /* actually, the sum of their weights */
181                double newlk = 0;
182                double dtime, d2;
183                double deadwt;  /*sum of case weights for the deaths*/
184                double efronwt; /* sum of weighted risk scores for the deaths*/
185                int halving;    /*are we doing step halving at the moment? */
186                @SuppressWarnings("unused")
187                int nrisk = 0;   /* number of subjects in the current risk set */
188
189                /* copies of scalar input arguments */
190                int nused, nvar;
191
192
193
194
195                /* vector inputs */
196                //  double *time, *weights, *offset;
197                //  int *status, *strata;
198
199                /* returned objects */
200                // double imat2[][];
201                double u[], loglik[], means[];
202
203
204                double sctest;
205                int flag = 0;
206                int iter = 0;
207                //SEXP rlist, rlistnames;
208                //  int nprotect;  /* number of protect calls I have issued */
209
210                /* get local copies of some input args */
211                nused = data.size(); // LENGTH(offset2);
212                nvar = variables.size(); // ncols(covar2);
213
214
215                //       imat2 = new double[nvar][nvar];
216//        nprotect++;
217                imat = new double[nvar][nvar]; //dmatrix(REAL(imat2),  nvar, nvar);
218                a = new double[nvar]; //(double *) R_alloc(2*nvar*nvar + 4*nvar, sizeof(double));
219                newbeta = new double[nvar]; //a + nvar;
220                a2 = new double[nvar]; //newbeta + nvar;
221                scale = new double[nvar]; //a2 + nvar;
222                cmat = new double[nvar][nvar]; //dmatrix(scale + nvar,   nvar, nvar);
223                cmat2 = new double[nvar][nvar]; //dmatrix(scale + nvar +nvar*nvar, nvar, nvar);
224
225                /*
226                 ** create output variables
227                 */
228//    PROTECT(beta2 = duplicate(ibeta));
229//    beta = REAL(beta2);
230                //  beta = new double[nvar];
231                // beta = beta2;
232                //  PROTECT(means2 = allocVector(REALSXP, nvar));
233                //  means = REAL(means2);
234                means = new double[nvar];
235                double[] sd = new double[nvar];
236                //double[] se = new double[nvar];
237
238                //   means = means2;
239                //   PROTECT(u2 = allocVector(REALSXP, nvar));
240                //   u = REAL(u2);
241                u = new double[nvar];
242                //   u = u2;
243//    PROTECT(loglik2 = allocVector(REALSXP, 2));
244//    loglik = REAL(loglik2);
245                loglik = new double[2];
246                //   loglik = loglik2;
247//    PROTECT(sctest2 = allocVector(REALSXP, 1));
248//    sctest = REAL(sctest2);
249//        sctest = new double[1];
250                //   sctest = sctest2;
251//    PROTECT(flag2 = allocVector(INTSXP, 1));
252//    flag = INTEGER(flag2);
253//        flag = new int[1];
254                //     flag = flag2;
255//    PROTECT(iter2 = allocVector(INTSXP, 1));
256//    iterations = INTEGER(iter2);
257//        iterations = new int[1];
258//        iterations = iter2;
259                //       nprotect += 7;
260
261                /*
262                 ** Subtract the mean from each covar, as this makes the regression
263                 **  much more stable.
264                 */
265                double[] time = new double[nused];
266                int[] status = new int[nused];
267                double[] offset = new double[nused];
268                double[] weights = new double[nused];
269                int[] strata = new int[nused];
270
271                double[][] covar = new double[nvar][nused];
272                ArrayList<String> clusterList = null;
273
274                if(cluster){
275                        clusterList = new ArrayList<String>();
276                }
277                //copy data over to local arrays to minimuze changing code
278                for (person = 0; person < nused; person++) {
279                        SurvivalInfo si = data.get(person);
280                        time[person] = si.getTime();
281                        status[person] = si.getStatus();
282                        offset[person] = si.getOffset();
283                        if(cluster){
284                                if(si.getClusterValue() == null && si.getClusterValue().length() == 0){
285                                        throw new Exception("Cluster value is not valid for " + si.toString());
286                                }
287                                clusterList.add(si.getClusterValue());
288                        }
289                        if (useWeighted) {
290                                weights[person] = si.getWeight();
291                        } else {
292                                weights[person] = 1.0;
293                        }
294                        if (useStrata) {
295                                strata[person] = si.getStrata();
296                        } else {
297                                strata[person] = 0;
298                        }
299                        for (i = 0; i < variables.size(); i++) {
300                                String variable = variables.get(i);
301                                covar[i][person] = si.getVariable(variable);
302                        }
303                }
304
305                double tempsd = 0;
306                i = 0;
307                for (i = 0; i < nvar; i++) {
308
309                        temp = 0;
310                        tempsd = 0;
311                        //calculate the mean sd
312
313                        for (person = 0; person < nused; person++) {
314
315                                temp += covar[i][person]; // * weights[person];
316                                tempsd += (covar[i][person]) * (covar[i][person]); //*weights[person] * weights[person]
317                        }
318                        temp /= nused;
319                        //   temp /= weightCount;
320                        means[i] = temp;
321                        tempsd /= nused;
322                        //  tempsd /= weightCount;
323                        tempsd = Math.sqrt(tempsd - temp * temp);
324                        sd[i] = tempsd; //standard deviation
325                        //subtract the mean
326                        for (person = 0; person < nused; person++) {
327                                covar[i][person] -= temp;
328                        }
329                        if (doscale == 1) {  /* and also scale it */
330                                temp = 0;
331                                for (person = 0; person < nused; person++) {
332                                        temp += Math.abs(covar[i][person]); //fabs
333                                }
334                                if (temp > 0) {
335                                        temp = nused / temp;   /* scaling */
336                                } else {
337                                        temp = 1.0; /* rare case of a constant covariate */
338                                }
339                                scale[i] = temp;
340                                for (person = 0; person < nused; person++) {
341                                        covar[i][person] *= temp;
342                                }
343                        }
344                }
345                if (doscale == 1) {
346                        for (i = 0; i < nvar; i++) {
347                                beta[i] /= scale[i]; /*rescale initial betas */
348                        }
349                } else {
350                        for (i = 0; i < nvar; i++) {
351                                scale[i] = 1.0;
352                        }
353                }
354
355                /*
356                 ** do the initial iteration step
357                 */
358                strata[nused - 1] = 1;
359                loglik[1] = 0;
360                for (i = 0; i < nvar; i++) {
361                        u[i] = 0;  //u = s1
362                        a2[i] = 0; //a2 = a
363                        for (j = 0; j < nvar; j++) {
364                                imat[i][j] = 0;  //s2
365                                cmat2[i][j] = 0; //a
366                        }
367                }
368
369                for (person = nused - 1; person >= 0;) {
370                        if (strata[person] == 1) {
371                                nrisk = 0;
372                                denom = 0;
373                                for (i = 0; i < nvar; i++) {
374                                        a[i] = 0;
375                                        for (j = 0; j < nvar; j++) {
376                                                cmat[i][j] = 0;
377                                        }
378                                }
379                        }
380
381                        dtime = time[person];
382                        ndead = 0; /*number of deaths at this time point */
383                        deadwt = 0;  /* sum of weights for the deaths */
384                        efronwt = 0;  /* sum of weighted risks for the deaths */
385                        while (person >= 0 && time[person] == dtime) {
386                                /* walk through the this set of tied times */
387                                nrisk++;
388                                zbeta = offset[person];    /* form the term beta*z (vector mult) */
389                                for (i = 0; i < nvar; i++) {
390                                        zbeta += beta[i] * covar[i][person]; //x
391                                }
392                                zbeta = coxsafe(zbeta);
393
394                                risk = Math.exp(zbeta) * weights[person]; //risk = v
395                                denom += risk;
396
397                                /* a is the vector of weighted sums of x, cmat sums of squares */
398                                for (i = 0; i < nvar; i++) {
399                                        a[i] += risk * covar[i][person]; //a = s1
400                                        for (j = 0; j <= i; j++) {
401                                                cmat[i][j] += risk * covar[i][person] * covar[j][person]; //cmat = s2;
402                                        }
403                                }
404
405                                if (status[person] == 1) {
406                                        ndead++;
407                                        deadwt += weights[person];
408                                        efronwt += risk;
409                                        loglik[1] += weights[person] * zbeta;
410
411                                        for (i = 0; i < nvar; i++) {
412                                                u[i] += weights[person] * covar[i][person];
413                                        }
414                                        if (method == CoxMethod.Efron) { /* Efron */
415                                                for (i = 0; i < nvar; i++) {
416                                                        a2[i] += risk * covar[i][person];
417                                                        for (j = 0; j <= i; j++) {
418                                                                cmat2[i][j] += risk * covar[i][person] * covar[j][person];
419                                                        }
420                                                }
421                                        }
422                                }
423
424                                person--;
425                                if (person >= 0 && strata[person] == 1) { //added catch of person = 0 and person-- = -1
426                                        break;  /*ties don't cross strata */
427                                }
428                        }
429
430
431                        if (ndead > 0) {  /* we need to add to the main terms */
432                                if (method == CoxMethod.Breslow) { /* Breslow */
433                                        loglik[1] -= deadwt * Math.log(denom);
434
435                                        for (i = 0; i < nvar; i++) {
436                                                temp2 = a[i] / denom;  /* mean */
437                                                u[i] -= deadwt * temp2;
438                                                for (j = 0; j <= i; j++) {
439                                                        imat[j][i] += deadwt * (cmat[i][j] - temp2 * a[j]) / denom;
440                                                }
441                                        }
442                                } else { /* Efron */
443                                        /*
444                                         ** If there are 3 deaths we have 3 terms: in the first the
445                                         **  three deaths are all in, in the second they are 2/3
446                                         **  in the sums, and in the last 1/3 in the sum.  Let k go
447                                         **  from 0 to (ndead -1), then we will sequentially use
448                                         **     denom - (k/ndead)*efronwt as the denominator
449                                         **     a - (k/ndead)*a2 as the "a" term
450                                         **     cmat - (k/ndead)*cmat2 as the "cmat" term
451                                         **  and reprise the equations just above.
452                                         */
453                                        for (k = 0; k < ndead; k++) {
454                                                temp = (double) k / ndead;
455                                                wtave = deadwt / ndead;
456                                                d2 = denom - temp * efronwt;
457                                                loglik[1] -= wtave * Math.log(d2);
458                                                for (i = 0; i < nvar; i++) {
459                                                        temp2 = (a[i] - temp * a2[i]) / d2;
460                                                        u[i] -= wtave * temp2;
461                                                        for (j = 0; j <= i; j++) {
462                                                                imat[j][i] += (wtave / d2)
463                                                                                * ((cmat[i][j] - temp * cmat2[i][j])
464                                                                                - temp2 * (a[j] - temp * a2[j]));
465                                                        }
466                                                }
467                                        }
468
469                                        for (i = 0; i < nvar; i++) {
470                                                a2[i] = 0;
471                                                for (j = 0; j < nvar; j++) {
472                                                        cmat2[i][j] = 0;
473                                                }
474                                        }
475                                }
476                        }
477                }   /* end  of accumulation loop */
478                loglik[0] = loglik[1]; /* save the loglik for iterations 0 */
479
480                /* am I done?
481                 **   update the betas and test for convergence
482                 */
483                for (i = 0; i < nvar; i++) /*use 'a' as a temp to save u0, for the score test*/ {
484                        a[i] = u[i];
485                }
486
487                flag = Cholesky2.process(imat, nvar, toler);
488                chsolve2(imat, nvar, a);        /* a replaced by  a *inverse(i) */
489
490                temp = 0;
491                for (i = 0; i < nvar; i++) {
492                        temp += u[i] * a[i];
493                }
494                sctest = temp;  /* score test */
495
496                /*
497                 **  Never, never complain about convergence on the first step.  That way,
498                 **  if someone HAS to they can force one iterations at a time.
499                 */
500                for (i = 0; i < nvar; i++) {
501                        newbeta[i] = beta[i] + a[i];
502                }
503                if (maxiter == 0) {
504                        chinv2(imat, nvar);
505                        for (i = 0; i < nvar; i++) {
506                                beta[i] *= scale[i];  /*return to original scale */
507                                u[i] /= scale[i];
508                                imat[i][i] *= scale[i] * scale[i];
509                                for (j = 0; j < i; j++) {
510                                        imat[j][i] *= scale[i] * scale[j];
511                                        imat[i][j] = imat[j][i];
512                                }
513                        }
514                        // goto finish;
515                        gotofinish = true;
516
517                }
518
519                /*
520                 ** here is the main loop
521                 */
522                if (gotofinish == false) {
523                        halving = 0;             /* =1 when in the midst of "step halving" */
524                        for (iter = 1; iter <= maxiter; iter++) {
525                                newlk = 0;
526                                for (i = 0; i < nvar; i++) {
527                                        u[i] = 0;
528                                        for (j = 0; j < nvar; j++) {
529                                                imat[i][j] = 0;
530                                        }
531                                }
532
533                                /*
534                                 ** The data is sorted from smallest time to largest
535                                 ** Start at the largest time, accumulating the risk set 1 by 1
536                                 */
537                                for (person = nused - 1; person >= 0;) {
538                                        if (strata[person] == 1) { /* rezero temps for each strata */
539                                                denom = 0;
540                                                nrisk = 0;
541                                                for (i = 0; i < nvar; i++) {
542                                                        a[i] = 0;
543                                                        for (j = 0; j < nvar; j++) {
544                                                                cmat[i][j] = 0;
545                                                        }
546                                                }
547                                        }
548
549                                        dtime = time[person];
550                                        deadwt = 0;
551                                        ndead = 0;
552                                        efronwt = 0;
553                                        while (person >= 0 && time[person] == dtime) {
554                                                nrisk++;
555                                                zbeta = offset[person];
556                                                for (i = 0; i < nvar; i++) {
557                                                        zbeta += newbeta[i] * covar[i][person];
558                                                }
559                                                zbeta = coxsafe(zbeta);
560
561
562                                                risk = Math.exp(zbeta) * weights[person];
563                                                denom += risk;
564
565                                                for (i = 0; i < nvar; i++) {
566                                                        a[i] += risk * covar[i][person];
567                                                        for (j = 0; j <= i; j++) {
568                                                                cmat[i][j] += risk * covar[i][person] * covar[j][person];
569                                                        }
570                                                }
571
572                                                if (status[person] == 1) {
573                                                        ndead++;
574                                                        deadwt += weights[person];
575                                                        newlk += weights[person] * zbeta;
576                                                        for (i = 0; i < nvar; i++) {
577                                                                u[i] += weights[person] * covar[i][person];
578                                                        }
579                                                        if (method == CoxMethod.Efron) { /* Efron */
580                                                                efronwt += risk;
581                                                                for (i = 0; i < nvar; i++) {
582                                                                        a2[i] += risk * covar[i][person];
583                                                                        for (j = 0; j <= i; j++) {
584                                                                                cmat2[i][j] += risk * covar[i][person] * covar[j][person];
585                                                                        }
586                                                                }
587                                                        }
588                                                }
589
590                                                person--;
591                                                if (person >= 0 && strata[person] == 1) { //added catch of person = 0 and person-- = -1
592                                                        break;  /*ties don't cross strata */
593                                                }
594                                        }
595
596                                        if (ndead > 0) {  /* add up terms*/
597                                                if (method == CoxMethod.Breslow) { /* Breslow */
598                                                        newlk -= deadwt * Math.log(denom);
599                                                        for (i = 0; i < nvar; i++) {
600                                                                temp2 = a[i] / denom;  /* mean */
601                                                                u[i] -= deadwt * temp2;
602                                                                for (j = 0; j <= i; j++) {
603                                                                        imat[j][i] += (deadwt / denom)
604                                                                                        * (cmat[i][j] - temp2 * a[j]);
605                                                                }
606                                                        }
607                                                } else { /* Efron */
608                                                        for (k = 0; k < ndead; k++) {
609                                                                temp = (double) k / ndead;
610                                                                wtave = deadwt / ndead;
611                                                                d2 = denom - temp * efronwt;
612                                                                newlk -= wtave * Math.log(d2);
613                                                                for (i = 0; i < nvar; i++) {
614                                                                        temp2 = (a[i] - temp * a2[i]) / d2;
615                                                                        u[i] -= wtave * temp2;
616                                                                        for (j = 0; j <= i; j++) {
617                                                                                imat[j][i] += (wtave / d2)
618                                                                                                * ((cmat[i][j] - temp * cmat2[i][j])
619                                                                                                - temp2 * (a[j] - temp * a2[j]));
620                                                                        }
621                                                                }
622                                                        }
623
624                                                        for (i = 0; i < nvar; i++) { /*in anticipation */
625                                                                a2[i] = 0;
626                                                                for (j = 0; j < nvar; j++) {
627                                                                        cmat2[i][j] = 0;
628                                                                }
629                                                        }
630                                                }
631                                        }
632                                }   /* end  of accumulation loop  */
633
634                                /* am I done?
635                                 **   update the betas and test for convergence
636                                 */
637                                flag = Cholesky2.process(imat, nvar, toler);
638
639                                if (Math.abs(1 - (loglik[1] / newlk)) <= eps && halving == 0) { /* all done */
640                                        loglik[1] = newlk;
641                                        chinv2(imat, nvar);     /* invert the information matrix */
642                                        for (i = 0; i < nvar; i++) {
643                                                beta[i] = newbeta[i] * scale[i];
644                                                u[i] /= scale[i];
645                                                imat[i][i] *= scale[i] * scale[i];
646                                                for (j = 0; j < i; j++) {
647                                                        imat[j][i] *= scale[i] * scale[j];
648                                                        imat[i][j] = imat[j][i];
649                                                }
650                                        }
651                                        //  goto finish;
652                                        gotofinish = true;
653                                        break;
654                                }
655
656                                if (iter == maxiter) {
657                                        break;  /*skip the step halving calc*/
658                                }
659
660                                if (newlk < loglik[1]) {    /*it is not converging ! */
661                                        halving = 1;
662                                        for (i = 0; i < nvar; i++) {
663                                                newbeta[i] = (newbeta[i] + beta[i]) / 2; /*half of old increment */
664                                        }
665                                } else {
666                                        halving = 0;
667                                        loglik[1] = newlk;
668                                        chsolve2(imat, nvar, u);
669                                        j = 0;
670                                        for (i = 0; i < nvar; i++) {
671                                                beta[i] = newbeta[i];
672                                                newbeta[i] = newbeta[i] + u[i];
673                                        }
674                                }
675                        }   /* return for another iteration */
676                }
677
678                if (gotofinish == false) {
679                        /*
680                         ** We end up here only if we ran out of iterations
681                         */
682                        loglik[1] = newlk;
683                        chinv2(imat, nvar);
684                        for (i = 0; i < nvar; i++) {
685                                beta[i] = newbeta[i] * scale[i];
686                                u[i] /= scale[i];
687                                imat[i][i] *= scale[i] * scale[i];
688                                for (j = 0; j < i; j++) {
689                                        imat[j][i] *= scale[i] * scale[j];
690                                        imat[i][j] = imat[j][i];
691                                }
692                        }
693                        flag = 1000;
694                }
695
696//finish:
697                /*
698                 for (j = 0; j < numCovariates; j++) {
699                 b[j] = b[j] / SD[j];
700                 * ix = j * (numCovariates + 1) + j
701                 SE[j] = Math.sqrt(a[ix(j, j, numCovariates + 1)]) / SD[j];
702                 //            o = o + ("   " + variables.get(j) + "    " + Fmt(b[j]) + Fmt(SE[j]) + Fmt(Math.exp(b[j])) + Fmt(Norm(Math.abs(b[j] / SE[j]))) + Fmt(Math.exp(b[j] - 1.95 * SE[j])) + Fmt(Math.exp(b[j] + 1.95 * SE[j])) + NL);
703                 CoxCoefficient coe = coxInfo.getCoefficient(variables.get(j));
704                 coe.coeff = b[j];
705                 coe.stdError = SE[j];
706                 coe.hazardRatio = Math.exp(b[j]);
707                 coe.pvalue = Norm(Math.abs(b[j] / SE[j]));
708                 coe.hazardRatioLoCI = Math.exp(b[j] - 1.95 * SE[j]);
709                 coe.hazardRatioHiCI = Math.exp(b[j] + 1.95 * SE[j]);
710                 }
711
712                 */
713
714                coxInfo.setScoreLogrankTest(sctest);
715                coxInfo.setDegreeFreedom(beta.length);
716                coxInfo.setScoreLogrankTestpvalue(ChiSq.chiSq(coxInfo.getScoreLogrankTest(), beta.length));
717                coxInfo.setVariance(imat);
718                coxInfo.u = u;
719
720                //     for (int n = 0; n < beta.length; n++) {
721                //         se[n] = Math.sqrt(imat[n][n]); // / sd[n];
722                //     }
723
724
725                //       System.out.println("coef,se, means,u");
726                for (int n = 0; n < beta.length; n++) {
727                        CoxCoefficient coe = new CoxCoefficient();
728                        coe.name = variables.get(n);
729                        coe.mean = means[n];
730                        coe.standardDeviation = sd[n];
731                        coe.coeff = beta[n];
732                        coe.stdError = Math.sqrt(imat[n][n]);
733                        coe.hazardRatio = Math.exp(coe.getCoeff());
734                        coe.z = coe.getCoeff() / coe.getStdError();
735                        coe.pvalue = ChiSq.norm(Math.abs(coe.getCoeff() / coe.getStdError()));
736                        double z = 1.959964;
737                        coe.hazardRatioLoCI = Math.exp(coe.getCoeff() - z * coe.getStdError());
738                        coe.hazardRatioHiCI = Math.exp(coe.getCoeff() + z * coe.getStdError());
739
740                        coxInfo.setCoefficient(coe.getName(), coe);
741                        // System.out.println(beta[n] + "," + se[n] + "," + means[n] + "," + sd[n] + "," + u[n]); //+ "," + imat[n] "," + loglik[n] + "," + sctest[n] + "," + iterations[n] + "," + flag[n]
742
743                }
744
745                coxInfo.maxIterations = maxiter;
746                coxInfo.eps = eps;
747                coxInfo.toler = toler;
748
749                coxInfo.iterations = iter;
750                coxInfo.flag = flag;
751                coxInfo.loglikInit = loglik[0];
752                coxInfo.loglikFinal = loglik[1];
753                coxInfo.method = method;
754
755                //    System.out.println("loglik[0]=" + loglik[0]);
756                //    System.out.println("loglik[1]=" + loglik[1]);
757
758                //    System.out.println("chisq? sctest[0]=" + sctest[0]);
759                //    System.out.println("?overall model p-value=" + chiSq(sctest[0], beta.length));
760
761
762                //      System.out.println();
763                //       for (int n = 0; n < covar[0].length; n++) {
764                //           System.out.print(n);
765                //           for (int variable = 0; variable < covar.length; variable++) {
766                //               System.out.print("\t" + covar[variable][n]);
767
768                //           }
769                //           System.out.println();
770                //       }
771                //      for (SurvivalInfo si : data) {
772                //          System.out.println(si.order + " " + si.getScore());
773                //      }
774//        coxInfo.dump();
775
776
777                coxphfitSCleanup(coxInfo, useWeighted, robust,clusterList);
778                return coxInfo;
779        }
780
781        /**
782         *
783         * @param ci
784         * @param useWeighted
785         * @param robust
786         * @param cluster
787         * @throws Exception
788         */
789        public void coxphfitSCleanup(CoxInfo ci, boolean useWeighted,boolean robust, ArrayList<String> cluster) throws Exception {
790                //Do cleanup found after coxfit6 is called in coxph.fit.S
791                //infs <- abs(coxfit$u %*% var)
792                //[ a1 b1] * [a1 b1]
793                //           [a2 b2]
794                double[][] du = new double[1][ci.u.length];
795                du[0] = ci.u;
796                double[] infs = Matrix.abs(Matrix.multiply(ci.u, ci.getVariance()));
797//        StdArrayIO.print(infs);
798
799                ArrayList<CoxCoefficient> coxCoefficients = new ArrayList<CoxCoefficient>(ci.getCoefficientsList().values());
800
801                for (int i = 0; i < infs.length; i++) {
802                        double inf = infs[i];
803                        double coe = coxCoefficients.get(i).getCoeff();
804                        if (inf > ci.eps && inf > (ci.toler * Math.abs(coe))) {
805                                ci.message = "Loglik converged before variable ";
806                        }
807                }
808
809                //sum(coef*coxfit$means)
810                double sumcoefmeans = 0;
811                for (CoxCoefficient cc : coxCoefficients) {
812                        sumcoefmeans = sumcoefmeans + cc.getCoeff() * cc.getMean();
813                }
814
815                // coxph.fit.S line 107
816                //lp <- c(x %*% coef) + offset - sum(coef*coxfit$means)
817                for (SurvivalInfo si : ci.survivalInfoList) {
818                        double offset = si.getOffset();
819                        double lp = 0;
820                        for (CoxCoefficient cc : coxCoefficients) {
821                                String name = cc.getName();
822                                double coef = cc.getCoeff();
823                                double value = si.getVariable(name);
824                                lp = lp + value * coef;
825                        }
826                        lp = lp + offset - sumcoefmeans;
827                        si.setLinearPredictor(lp);
828                        si.setScore(Math.exp(lp));
829
830//           System.out.println("lp score " + si.order + " " + si.time + " " + si.getWeight() + " " + si.getClusterValue() + " " + lp + " " + Math.exp(lp));
831                }
832//       ci.dump();
833                //begin code after call to coxfit6 in coxph.fit.S
834                //Compute the martingale residual for a Cox model
835                // appears to be C syntax error for = - vs -=
836                //(if (nullmodel) in coxph.fit
837                double[] res = CoxMart.process(ci.method, ci.survivalInfoList, false);
838
839                for(int i = 0; i < ci.survivalInfoList.size(); i++){
840                        SurvivalInfo si = ci.survivalInfoList.get(i);
841                        si.setResidual(res[i]);
842                }
843
844                //this represents the end of coxph.fit.S code and we pickup
845                //after call to fit <- fitter(X, Y, strats ....) in coxph.R
846
847                if (robust) {
848                        ci.setNaiveVariance(ci.getVariance());
849                        double[][] temp;
850                        double[][] temp0;
851
852                        if (cluster != null) {
853
854                                temp = ResidualsCoxph.process(ci, ResidualsCoxph.Type.dfbeta, useWeighted, cluster);
855                                //# get score for null model
856                                //    if (is.null(init))
857                                //          fit2$linear.predictors <- 0*fit$linear.predictors
858                                //    else
859                                //          fit2$linear.predictors <- c(X %*% init)
860                                //Set score to 1
861
862                                double[] templp = new double[ci.survivalInfoList.size()];
863                                double[] tempscore = new double[ci.survivalInfoList.size()];
864                                int i = 0;
865                                for (SurvivalInfo si : ci.survivalInfoList) {
866                                        templp[i] = si.getLinearPredictor();
867                                        tempscore[i] = si.getScore();
868                                        si.setLinearPredictor(0);
869                                        si.setScore(1.0); //this erases stored value which isn't how the R code does it
870                                        i++;
871                                }
872
873                                temp0 = ResidualsCoxph.process(ci, ResidualsCoxph.Type.score, useWeighted, cluster);
874
875                                i = 0;
876                                for (SurvivalInfo si : ci.survivalInfoList) {
877                                        si.setLinearPredictor(templp[i]);
878                                        si.setScore(tempscore[i]); //this erases stored value which isn't how the R code does it
879                                        i++;
880                                }
881
882
883                        } else {
884                                temp = ResidualsCoxph.process(ci, ResidualsCoxph.Type.dfbeta, useWeighted, null);
885                                //     fit2$linear.predictors <- 0*fit$linear.predictors
886                                double[] templp = new double[ci.survivalInfoList.size()];
887                                double[] tempscore = new double[ci.survivalInfoList.size()];
888                                int i = 0;
889                                for (SurvivalInfo si : ci.survivalInfoList) {
890                                        templp[i] = si.getLinearPredictor();
891                                        tempscore[i] = si.getScore();
892                                        si.setLinearPredictor(0);
893                                        si.setScore(1.0);
894                                }
895                                temp0 = ResidualsCoxph.process(ci, ResidualsCoxph.Type.score, useWeighted, null);
896
897                                i = 0;
898                                for (SurvivalInfo si : ci.survivalInfoList) {
899                                        si.setLinearPredictor(templp[i]);
900                                        si.setScore(tempscore[i]); //this erases stored value which isn't how the R code does it
901                                        i++;
902                                }
903                        }
904                        //fit$var<- t(temp) % * % temp
905                        double[][] ttemp = Matrix.transpose(temp);
906                        double[][] var = Matrix.multiply(ttemp, temp);
907                        ci.setVariance(var);
908                        //u<- apply(as.matrix(temp0), 2, sum)
909                        double[] u = new double[temp0[0].length];
910                        for (int i = 0; i < temp0[0].length; i++) {
911                                for (int j = 0; j < temp0.length; j++) {
912                                        u[i] = u[i] + temp0[j][i];
913                                }
914                        }
915                        //fit$rscore <- coxph.wtest(t(temp0)%*%temp0, u, control$toler.chol)$test
916                        double[][] wtemp = Matrix.multiply(Matrix.transpose(temp0),temp0);
917                        double toler_chol = 1.818989e-12;
918                  //  toler_chol = ci.toler;
919                        WaldTestInfo wti = WaldTest.process(wtemp,u,toler_chol);
920                        //not giving the correct value
921                        ci.setRscore(wti.getTest());
922                }
923                calculateWaldTestInfo(ci);
924
925
926
927
928        }
929
930        static public void calculateWaldTestInfo(CoxInfo ci){
931                if(ci.getNumberCoefficients() > 0){
932                        double toler_chol = 1.818989e-12;
933                  //  toler_chol = ci.toler;
934                        double[][] b = new double[1][ci.getNumberCoefficients()];
935                        int i = 0;
936                        for(CoxCoefficient coe : ci.getCoefficientsList().values()){
937                                b[0][i] = coe.getCoeff();
938                                i++;
939                        }
940                        ci.setWaldTestInfo(WaldTest.process(ci.getVariance(), b, toler_chol));
941                }
942        }
943
944        /**
945         * @param args the command line arguments
946         */
947        public static void main(String[] args) {
948                // TODO code application logic here
949                CoxR coxr = new CoxR();
950
951
952                if (true) {
953                        try {
954                           InputStream is = coxr.getClass().getClassLoader().getResourceAsStream("uis-complete.txt");
955
956
957
958
959                                WorkSheet worksheet = WorkSheet.readCSV(is, '\t');
960                                ArrayList<SurvivalInfo> survivalInfoList = new ArrayList<SurvivalInfo>();
961                                int i = 0;
962                                for (String row : worksheet.getRows()) {
963                                        double time = worksheet.getCellDouble(row, "TIME");
964                                        double age = worksheet.getCellDouble(row, "AGE");
965                                        double treat = worksheet.getCellDouble(row, "TREAT");
966                                        double c = worksheet.getCellDouble(row, "CENSOR");
967                                        int censor = (int) c;
968
969                                        SurvivalInfo si = new SurvivalInfo(time, censor);
970                                        si.setOrder(i);
971                                        si.addContinuousVariable("AGE", age);
972                                        si.addContinuousVariable("TREAT", treat);
973
974                                        survivalInfoList.add(si);
975                                        i++;
976                                }
977
978                                CoxR cox = new CoxR();
979                                ArrayList<String> variables = new ArrayList<String>();
980                                //               variables.add("AGE");
981
982                                variables.add("AGE");
983                                variables.add("TREAT");
984
985                                //       variables.add("TREAT:AGE");
986                          //  ArrayList<Integer> cluster = new ArrayList<Integer>();
987                                CoxInfo ci = cox.process(variables, survivalInfoList, false, true,false, false);
988                                System.out.println(ci);
989                        } catch (Exception e) {
990                                e.printStackTrace();
991                        }
992
993
994                }
995
996//        if (false) {
997//
998//            try {
999//
1000//
1001//                WorkSheet worksheet = WorkSheet.readCSV("/Users/Scooter/NetBeansProjects/AssayWorkbench/src/edu/scripps/assayworkbench/cox/uis-complete.txt", '\t');
1002//                ArrayList<String> rows = worksheet.getRows();
1003//                ArrayList<String> variables = new ArrayList<String>();
1004//                variables.add("AGE");
1005//                variables.add("TREAT");
1006//                double[] time2 = new double[rows.size()];
1007//                int[] status2 = new int[rows.size()];
1008//                double[][] covar2 = new double[variables.size()][rows.size()];
1009//                double[] offset2 = new double[rows.size()];
1010//                double[] weights2 = new double[rows.size()];
1011//                int[] strata2 = new int[rows.size()];
1012//
1013//
1014//                for (int i = 0; i < rows.size(); i++) {
1015//                    String row = rows.get(i);
1016//                    double time = worksheet.getCellDouble(row, "TIME");
1017//                    //      double age = worksheet.getCellDouble(row, "AGE");
1018//                    //      double treat = worksheet.getCellDouble(row, "TREAT");
1019//                    double c = worksheet.getCellDouble(row, "CENSOR");
1020//                    int censor = (int) c;
1021//
1022//                    time2[i] = time;
1023//                    status2[i] = censor;
1024//                    offset2[i] = 0;
1025//                    weights2[i] = 1;
1026//                    strata2[i] = 0;
1027//
1028//                    for (int j = 0; j < variables.size(); j++) {
1029//                        String variable = variables.get(j);
1030//                        double v = worksheet.getCellDouble(row, variable);
1031//                        covar2[j][i] = v;
1032//                    }
1033//
1034//
1035//                }
1036//                //from coxph.control.S
1037//                int maxiter2 = 20;
1038//                double eps2 = 1e-9;
1039//                double toler2 = Math.pow(eps2, .75);
1040//                int doscale2 = 1;
1041//                int method2 = 0;
1042//                //toler.chol = eps ^ .75
1043//                //toler.inf=sqrt(eps)
1044//                //outer.max=10
1045//
1046//                CoxR cox = new CoxR();
1047//                //        cox.coxfit6(maxiter2, time2, status2, covar2, offset2, weights2, strata2, method2, eps2, toler2, time2, doscale2);
1048//
1049//
1050//
1051//
1052//
1053//            } catch (Exception e) {
1054//                e.printStackTrace();
1055//            }
1056//        }
1057
1058        }
1059
1060        /* $Id: chinv2.c 11357 2009-09-04 15:22:46Z therneau $
1061         **
1062         ** matrix inversion, given the FDF' cholesky decomposition
1063         **
1064         ** input  **matrix, which contains the chol decomp of an n by n
1065         **   matrix in its lower triangle.
1066         **
1067         ** returned: the upper triangle + diagonal contain (FDF')^{-1}
1068         **            below the diagonal will be F inverse
1069         **
1070         **  Terry Therneau
1071         */
1072        void chinv2(double[][] matrix, int n) {
1073                double temp;
1074                int i, j, k;
1075
1076                /*
1077                 ** invert the cholesky in the lower triangle
1078                 **   take full advantage of the cholesky's diagonal of 1's
1079                 */
1080                for (i = 0; i < n; i++) {
1081                        if (matrix[i][i] > 0) {
1082                                matrix[i][i] = 1 / matrix[i][i];   /*this line inverts D */
1083                                for (j = (i + 1); j < n; j++) {
1084                                        matrix[j][i] = -matrix[j][i];
1085                                        for (k = 0; k < i; k++) /*sweep operator */ {
1086                                                matrix[j][k] += matrix[j][i] * matrix[i][k];
1087                                        }
1088                                }
1089                        }
1090                }
1091
1092                /*
1093                 ** lower triangle now contains inverse of cholesky
1094                 ** calculate F'DF (inverse of cholesky decomp process) to get inverse
1095                 **   of original matrix
1096                 */
1097                for (i = 0; i < n; i++) {
1098                        if (matrix[i][i] == 0) {  /* singular row */
1099                                for (j = 0; j < i; j++) {
1100                                        matrix[j][i] = 0;
1101                                }
1102                                for (j = i; j < n; j++) {
1103                                        matrix[i][j] = 0;
1104                                }
1105                        } else {
1106                                for (j = (i + 1); j < n; j++) {
1107                                        temp = matrix[j][i] * matrix[j][j];
1108                                        if (j != i) {
1109                                                matrix[i][j] = temp;
1110                                        }
1111                                        for (k = i; k < j; k++) {
1112                                                matrix[i][k] += temp * matrix[j][k];
1113                                        }
1114                                }
1115                        }
1116                }
1117        }
1118
1119        /*  $Id: chsolve2.c 11376 2009-12-14 22:53:57Z therneau $
1120         **
1121         ** Solve the equation Ab = y, where the cholesky decomposition of A and y
1122         **   are the inputs.
1123         **
1124         ** Input  **matrix, which contains the chol decomp of an n by n
1125         **   matrix in its lower triangle.
1126         **        y[n] contains the right hand side
1127         **
1128         **  y is overwriten with b
1129         **
1130         **  Terry Therneau
1131         */
1132        void chsolve2(double[][] matrix, int n, double[] y) {
1133                int i, j;
1134                double temp;
1135
1136                /*
1137                 ** solve Fb =y
1138                 */
1139                for (i = 0; i < n; i++) {
1140                        temp = y[i];
1141                        for (j = 0; j < i; j++) {
1142                                temp -= y[j] * matrix[i][j];
1143                        }
1144                        y[i] = temp;
1145                }
1146                /*
1147                 ** solve DF'z =b
1148                 */
1149                for (i = (n - 1); i >= 0; i--) {
1150                        if (matrix[i][i] == 0) {
1151                                y[i] = 0;
1152                        } else {
1153                                temp = y[i] / matrix[i][i];
1154                                for (j = i + 1; j < n; j++) {
1155                                        temp -= y[j] * matrix[j][i];
1156                                }
1157                                y[i] = temp;
1158                        }
1159                }
1160        }
1161
1162
1163
1164        /**
1165         *
1166         * @param x
1167         * @return
1168         */
1169        public double coxsafe(double x) {
1170                if (x < -200) {
1171                        return -200;
1172                }
1173                if (x > 22) {
1174                        return 22;
1175                }
1176                return x;
1177        }
1178}