001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.cox;
022
023import org.biojava.nbio.survival.cox.matrix.Matrix;
024import org.biojava.nbio.survival.cox.stats.ChiSq;
025import org.biojava.nbio.survival.cox.stats.Cholesky2;
026import org.biojava.nbio.survival.data.WorkSheet;
027
028import java.io.InputStream;
029import java.util.ArrayList;
030import java.util.Collections;
031
032/**
 *   This is a port of the R survival code used for Cox regression. Porting the algorithm from C to Java was fairly
 *   straightforward; the main challenge was making the code more object-friendly. In the R code everything is passed
 *   around as arrays, and a large portion of the code is spent extracting data from those arrays for use in different
 *   calculations. Organizing the data into one object per data point simplified much of the code.
 *   Not every variant of the statistical methods selectable in R is implemented; it would not be difficult to go
 *   back and add them if they are needed.
038 *
 *<p>In R you can pass in different parameters to override defaults, which requires parsing those parameters. The Java
 *   code tries to be more explicit about its parameters: using strata, weighting, robust and cluster are advanced
 *   options. Additionally, code from Bob Gray is included to perform variance correction when weights are used in a
 *   data set. See the original author's notes (local path, not part of this distribution):
 *   /Users/Scooter/NetBeansProjects/biojava3-survival/docs/wtexamples.docx
043 *
044 *<p>The CoxHelper class is meant to hide some of the implementation details.
045 *
046 *<p>Issues
047 *<ul>
048 *<li>sign in CoxMart?
049 *<li>double toler_chol = 1.818989e-12; Different value for some reason
050 *<li>In robust linear_predictor set to 0 which implies score = 1 but previous score value doesn't get reset
051 *</ul>
052 *
 *  Cox regression fit, replacement for coxfit2 in order
 *    to be more frugal about memory: specifically, that we
 *    don't make copies of the input data. (In this Java port the
 *    input is copied into local arrays before fitting.)
056 *
057 *
058 *
059 *
060 *
061 * <p>
062 * the input parameters are
063 *
064 * <pre>
 *      maxiter      :maximum number of iterations
066 *      time(n)      :time of status or censoring for person i
067 *      status(n)    :status for the ith person    1=dead , 0=censored
068 *      covar(nv,n)  :covariates for person i.
069 *                       Note that S sends this in column major order.
 *      strata(n)    :marks the strata.  Will be 1 if this person is the
 *                      last one in a stratum.  If there are no strata, the
 *                      vector can be identically zero, since the nth person's
 *                      value is always assumed to be = to 1.
074 *      offset(n)    :offset for the linear predictor
075 *      weights(n)   :case weights
076 *      init         :initial estimate for the coefficients
 *      eps          :tolerance for convergence.  Iteration continues until
 *                      the relative change in loglikelihood is &lt;= eps.
 *      chol_tol     : tolerance for the Cholesky decomposition
 *      method       : 0=Breslow, 1=Efron (CoxMethod.Breslow / CoxMethod.Efron in this port)
081 *      doscale      : 0=don't scale the X matrix, 1=scale the X matrix
082 * </pre>
083 * returned parameters
084 * <pre>
085 *      means(nv)    : vector of column means of X
086 *      beta(nv)     :the vector of answers (at start contains initial est)
087 *      u(nv)        :score vector
088 *      imat(nv,nv)  :the variance matrix at beta=final
089 *                     (returned as a vector)
090 *      loglik(2)    :loglik at beta=initial values, at beta=final
091 *      sctest       :the score test at beta=initial
092 *      flag         :success flag  1000  did not converge
093 *                                  1 to nvar: rank of the solution
094 *      iterations         :actual number of iterations used
095 * </pre>
096 * work arrays
097 * <pre>
098 *      mark(n)
099 *      wtave(n)
100 *      a(nvar), a2(nvar)
101 *      cmat(nvar,nvar)       ragged array
102 *      cmat2(nvar,nvar)
103 *      newbeta(nvar)         always contains the "next iteration"
104 * </pre>
105 * calls functions:  cholesky2, chsolve2, chinv2
106 * <p>
107 * the data must be sorted by ascending time within strata
108 *
109 * @author Scooter Willis 
110 */
111public class CoxR {
112
113        /**
114         *
115         * @param variables
116         * @param DataT
117         * @param useStrata
118         * @param useWeighted
119         * @param robust
120         * @param cluster
121         * @return
122         * @throws Exception
123         */
124        public CoxInfo process(ArrayList<String> variables, ArrayList<SurvivalInfo> DataT, boolean useStrata, boolean useWeighted, boolean robust, boolean cluster) throws Exception {
125                //from coxph.control.S
126                int maxiter2 = 20;
127                double eps2 = 1e-9;
128                double toler2 = Math.pow(eps2, .75);
129                int doscale2 = 1;
130                //int method2 = 0;
131                double[] beta = new double[variables.size()];
132                return process(variables, DataT, maxiter2, CoxMethod.Efron, eps2, toler2, beta, doscale2, useStrata, useWeighted, robust, cluster);
133
134        }
135
136        /**
137         *
138         * @param variables
139         * @param data
140         * @param maxiter
141         * @param method
142         * @param eps
143         * @param toler
144         * @param beta
145         * @param doscale
146         * @param useStrata
147         * @param useWeighted
148         * @param robust
149         * @param cluster
150         * @return
151         * @throws Exception
152         */
153        public CoxInfo process(ArrayList<String> variables, ArrayList<SurvivalInfo> data, int maxiter, CoxMethod method, double eps, double toler, double[] beta, int doscale, boolean useStrata, boolean useWeighted, boolean robust, boolean cluster) throws Exception {
154                //make sure data is converted to numbers if labels are used
155                SurvivalInfoHelper.categorizeData(data);
156                //create variables if testing for interaction
157                for (String variable : variables) {
158                        if (variable.indexOf(":") != -1) {
159                                String[] d = variable.split(":");
160                                SurvivalInfoHelper.addInteraction(d[0], d[1], data);
161                        }
162                }
163
164                Collections.sort(data);
165                // Collections.reverse(data);
166                CoxInfo coxInfo = new CoxInfo();
167                coxInfo.setSurvivalInfoList(data);
168
169
170
171                int i, j, k, person;
172                boolean gotofinish = false;
173                double[][] cmat, imat;  /*ragged arrays covar[][], */
174                double wtave;
175                double[] a, newbeta;
176                double[] a2;
177                double[][] cmat2;
178                double[] scale;
179                double denom = 0, zbeta, risk;
180                double temp, temp2;
181                int ndead;  /* actually, the sum of their weights */
182                double newlk = 0;
183                double dtime, d2;
184                double deadwt;  /*sum of case weights for the deaths*/
185                double efronwt; /* sum of weighted risk scores for the deaths*/
186                int halving;    /*are we doing step halving at the moment? */
187                @SuppressWarnings("unused")
188                int nrisk = 0;   /* number of subjects in the current risk set */
189
190                /* copies of scalar input arguments */
191                int nused, nvar;
192
193
194
195
196                /* vector inputs */
197                //  double *time, *weights, *offset;
198                //  int *status, *strata;
199
200                /* returned objects */
201                // double imat2[][];
202                double[] u, loglik, means;
203
204
205                double sctest;
206                int flag = 0;
207                int iter = 0;
208                //SEXP rlist, rlistnames;
209                //  int nprotect;  /* number of protect calls I have issued */
210
211                /* get local copies of some input args */
212                nused = data.size(); // LENGTH(offset2);
213                nvar = variables.size(); // ncols(covar2);
214
215
216                //       imat2 = new double[nvar][nvar];
217//        nprotect++;
218                imat = new double[nvar][nvar]; //dmatrix(REAL(imat2),  nvar, nvar);
219                a = new double[nvar]; //(double *) R_alloc(2*nvar*nvar + 4*nvar, sizeof(double));
220                newbeta = new double[nvar]; //a + nvar;
221                a2 = new double[nvar]; //newbeta + nvar;
222                scale = new double[nvar]; //a2 + nvar;
223                cmat = new double[nvar][nvar]; //dmatrix(scale + nvar,   nvar, nvar);
224                cmat2 = new double[nvar][nvar]; //dmatrix(scale + nvar +nvar*nvar, nvar, nvar);
225
226                /*
227                 ** create output variables
228                 */
229//    PROTECT(beta2 = duplicate(ibeta));
230//    beta = REAL(beta2);
231                //  beta = new double[nvar];
232                // beta = beta2;
233                //  PROTECT(means2 = allocVector(REALSXP, nvar));
234                //  means = REAL(means2);
235                means = new double[nvar];
236                double[] sd = new double[nvar];
237                //double[] se = new double[nvar];
238
239                //   means = means2;
240                //   PROTECT(u2 = allocVector(REALSXP, nvar));
241                //   u = REAL(u2);
242                u = new double[nvar];
243                //   u = u2;
244//    PROTECT(loglik2 = allocVector(REALSXP, 2));
245//    loglik = REAL(loglik2);
246                loglik = new double[2];
247                //   loglik = loglik2;
248//    PROTECT(sctest2 = allocVector(REALSXP, 1));
249//    sctest = REAL(sctest2);
250//        sctest = new double[1];
251                //   sctest = sctest2;
252//    PROTECT(flag2 = allocVector(INTSXP, 1));
253//    flag = INTEGER(flag2);
254//        flag = new int[1];
255                //     flag = flag2;
256//    PROTECT(iter2 = allocVector(INTSXP, 1));
257//    iterations = INTEGER(iter2);
258//        iterations = new int[1];
259//        iterations = iter2;
260                //       nprotect += 7;
261
262                /*
263                 ** Subtract the mean from each covar, as this makes the regression
264                 **  much more stable.
265                 */
266                double[] time = new double[nused];
267                int[] status = new int[nused];
268                double[] offset = new double[nused];
269                double[] weights = new double[nused];
270                int[] strata = new int[nused];
271
272                double[][] covar = new double[nvar][nused];
273                ArrayList<String> clusterList = null;
274
275                if(cluster){
276                        clusterList = new ArrayList<>();
277                }
278                //copy data over to local arrays to minimuze changing code
279                for (person = 0; person < nused; person++) {
280                        SurvivalInfo si = data.get(person);
281                        time[person] = si.getTime();
282                        status[person] = si.getStatus();
283                        offset[person] = si.getOffset();
284                        if(cluster){
285                                if(si.getClusterValue() == null && si.getClusterValue().length() == 0){
286                                        throw new Exception("Cluster value is not valid for " + si.toString());
287                                }
288                                clusterList.add(si.getClusterValue());
289                        }
290                        if (useWeighted) {
291                                weights[person] = si.getWeight();
292                        } else {
293                                weights[person] = 1.0;
294                        }
295                        if (useStrata) {
296                                strata[person] = si.getStrata();
297                        } else {
298                                strata[person] = 0;
299                        }
300                        for (i = 0; i < variables.size(); i++) {
301                                String variable = variables.get(i);
302                                covar[i][person] = si.getVariable(variable);
303                        }
304                }
305
306                double tempsd = 0;
307                i = 0;
308                for (i = 0; i < nvar; i++) {
309
310                        temp = 0;
311                        tempsd = 0;
312                        //calculate the mean sd
313
314                        for (person = 0; person < nused; person++) {
315
316                                temp += covar[i][person]; // * weights[person];
317                                tempsd += (covar[i][person]) * (covar[i][person]); //*weights[person] * weights[person]
318                        }
319                        temp /= nused;
320                        //   temp /= weightCount;
321                        means[i] = temp;
322                        tempsd /= nused;
323                        //  tempsd /= weightCount;
324                        tempsd = Math.sqrt(tempsd - temp * temp);
325                        sd[i] = tempsd; //standard deviation
326                        //subtract the mean
327                        for (person = 0; person < nused; person++) {
328                                covar[i][person] -= temp;
329                        }
330                        if (doscale == 1) {  /* and also scale it */
331                                temp = 0;
332                                for (person = 0; person < nused; person++) {
333                                        temp += Math.abs(covar[i][person]); //fabs
334                                }
335                                if (temp > 0) {
336                                        temp = nused / temp;   /* scaling */
337                                } else {
338                                        temp = 1.0; /* rare case of a constant covariate */
339                                }
340                                scale[i] = temp;
341                                for (person = 0; person < nused; person++) {
342                                        covar[i][person] *= temp;
343                                }
344                        }
345                }
346                if (doscale == 1) {
347                        for (i = 0; i < nvar; i++) {
348                                beta[i] /= scale[i]; /*rescale initial betas */
349                        }
350                } else {
351                        for (i = 0; i < nvar; i++) {
352                                scale[i] = 1.0;
353                        }
354                }
355
356                /*
357                 ** do the initial iteration step
358                 */
359                strata[nused - 1] = 1;
360                loglik[1] = 0;
361                for (i = 0; i < nvar; i++) {
362                        u[i] = 0;  //u = s1
363                        a2[i] = 0; //a2 = a
364                        for (j = 0; j < nvar; j++) {
365                                imat[i][j] = 0;  //s2
366                                cmat2[i][j] = 0; //a
367                        }
368                }
369
370                for (person = nused - 1; person >= 0;) {
371                        if (strata[person] == 1) {
372                                nrisk = 0;
373                                denom = 0;
374                                for (i = 0; i < nvar; i++) {
375                                        a[i] = 0;
376                                        for (j = 0; j < nvar; j++) {
377                                                cmat[i][j] = 0;
378                                        }
379                                }
380                        }
381
382                        dtime = time[person];
383                        ndead = 0; /*number of deaths at this time point */
384                        deadwt = 0;  /* sum of weights for the deaths */
385                        efronwt = 0;  /* sum of weighted risks for the deaths */
386                        while (person >= 0 && time[person] == dtime) {
387                                /* walk through the this set of tied times */
388                                nrisk++;
389                                zbeta = offset[person];    /* form the term beta*z (vector mult) */
390                                for (i = 0; i < nvar; i++) {
391                                        zbeta += beta[i] * covar[i][person]; //x
392                                }
393                                zbeta = coxsafe(zbeta);
394
395                                risk = Math.exp(zbeta) * weights[person]; //risk = v
396                                denom += risk;
397
398                                /* a is the vector of weighted sums of x, cmat sums of squares */
399                                for (i = 0; i < nvar; i++) {
400                                        a[i] += risk * covar[i][person]; //a = s1
401                                        for (j = 0; j <= i; j++) {
402                                                cmat[i][j] += risk * covar[i][person] * covar[j][person]; //cmat = s2;
403                                        }
404                                }
405
406                                if (status[person] == 1) {
407                                        ndead++;
408                                        deadwt += weights[person];
409                                        efronwt += risk;
410                                        loglik[1] += weights[person] * zbeta;
411
412                                        for (i = 0; i < nvar; i++) {
413                                                u[i] += weights[person] * covar[i][person];
414                                        }
415                                        if (method == CoxMethod.Efron) { /* Efron */
416                                                for (i = 0; i < nvar; i++) {
417                                                        a2[i] += risk * covar[i][person];
418                                                        for (j = 0; j <= i; j++) {
419                                                                cmat2[i][j] += risk * covar[i][person] * covar[j][person];
420                                                        }
421                                                }
422                                        }
423                                }
424
425                                person--;
426                                if (person >= 0 && strata[person] == 1) { //added catch of person = 0 and person-- = -1
427                                        break;  /*ties don't cross strata */
428                                }
429                        }
430
431
432                        if (ndead > 0) {  /* we need to add to the main terms */
433                                if (method == CoxMethod.Breslow) { /* Breslow */
434                                        loglik[1] -= deadwt * Math.log(denom);
435
436                                        for (i = 0; i < nvar; i++) {
437                                                temp2 = a[i] / denom;  /* mean */
438                                                u[i] -= deadwt * temp2;
439                                                for (j = 0; j <= i; j++) {
440                                                        imat[j][i] += deadwt * (cmat[i][j] - temp2 * a[j]) / denom;
441                                                }
442                                        }
443                                } else { /* Efron */
444                                        /*
445                                         ** If there are 3 deaths we have 3 terms: in the first the
446                                         **  three deaths are all in, in the second they are 2/3
447                                         **  in the sums, and in the last 1/3 in the sum.  Let k go
448                                         **  from 0 to (ndead -1), then we will sequentially use
449                                         **     denom - (k/ndead)*efronwt as the denominator
450                                         **     a - (k/ndead)*a2 as the "a" term
451                                         **     cmat - (k/ndead)*cmat2 as the "cmat" term
452                                         **  and reprise the equations just above.
453                                         */
454                                        for (k = 0; k < ndead; k++) {
455                                                temp = (double) k / ndead;
456                                                wtave = deadwt / ndead;
457                                                d2 = denom - temp * efronwt;
458                                                loglik[1] -= wtave * Math.log(d2);
459                                                for (i = 0; i < nvar; i++) {
460                                                        temp2 = (a[i] - temp * a2[i]) / d2;
461                                                        u[i] -= wtave * temp2;
462                                                        for (j = 0; j <= i; j++) {
463                                                                imat[j][i] += (wtave / d2)
464                                                                                * ((cmat[i][j] - temp * cmat2[i][j])
465                                                                                - temp2 * (a[j] - temp * a2[j]));
466                                                        }
467                                                }
468                                        }
469
470                                        for (i = 0; i < nvar; i++) {
471                                                a2[i] = 0;
472                                                for (j = 0; j < nvar; j++) {
473                                                        cmat2[i][j] = 0;
474                                                }
475                                        }
476                                }
477                        }
478                }   /* end  of accumulation loop */
479                loglik[0] = loglik[1]; /* save the loglik for iterations 0 */
480
481                /* am I done?
482                 **   update the betas and test for convergence
483                 */
484                for (i = 0; i < nvar; i++) /*use 'a' as a temp to save u0, for the score test*/ {
485                        a[i] = u[i];
486                }
487
488                flag = Cholesky2.process(imat, nvar, toler);
489                chsolve2(imat, nvar, a);        /* a replaced by  a *inverse(i) */
490
491                temp = 0;
492                for (i = 0; i < nvar; i++) {
493                        temp += u[i] * a[i];
494                }
495                sctest = temp;  /* score test */
496
497                /*
498                 **  Never, never complain about convergence on the first step.  That way,
499                 **  if someone HAS to they can force one iterations at a time.
500                 */
501                for (i = 0; i < nvar; i++) {
502                        newbeta[i] = beta[i] + a[i];
503                }
504                if (maxiter == 0) {
505                        chinv2(imat, nvar);
506                        for (i = 0; i < nvar; i++) {
507                                beta[i] *= scale[i];  /*return to original scale */
508                                u[i] /= scale[i];
509                                imat[i][i] *= scale[i] * scale[i];
510                                for (j = 0; j < i; j++) {
511                                        imat[j][i] *= scale[i] * scale[j];
512                                        imat[i][j] = imat[j][i];
513                                }
514                        }
515                        // goto finish;
516                        gotofinish = true;
517
518                }
519
520                /*
521                 ** here is the main loop
522                 */
523                if (!gotofinish) {
524                        halving = 0;             /* =1 when in the midst of "step halving" */
525                        for (iter = 1; iter <= maxiter; iter++) {
526                                newlk = 0;
527                                for (i = 0; i < nvar; i++) {
528                                        u[i] = 0;
529                                        for (j = 0; j < nvar; j++) {
530                                                imat[i][j] = 0;
531                                        }
532                                }
533
534                                /*
535                                 ** The data is sorted from smallest time to largest
536                                 ** Start at the largest time, accumulating the risk set 1 by 1
537                                 */
538                                for (person = nused - 1; person >= 0;) {
539                                        if (strata[person] == 1) { /* rezero temps for each strata */
540                                                denom = 0;
541                                                nrisk = 0;
542                                                for (i = 0; i < nvar; i++) {
543                                                        a[i] = 0;
544                                                        for (j = 0; j < nvar; j++) {
545                                                                cmat[i][j] = 0;
546                                                        }
547                                                }
548                                        }
549
550                                        dtime = time[person];
551                                        deadwt = 0;
552                                        ndead = 0;
553                                        efronwt = 0;
554                                        while (person >= 0 && time[person] == dtime) {
555                                                nrisk++;
556                                                zbeta = offset[person];
557                                                for (i = 0; i < nvar; i++) {
558                                                        zbeta += newbeta[i] * covar[i][person];
559                                                }
560                                                zbeta = coxsafe(zbeta);
561
562
563                                                risk = Math.exp(zbeta) * weights[person];
564                                                denom += risk;
565
566                                                for (i = 0; i < nvar; i++) {
567                                                        a[i] += risk * covar[i][person];
568                                                        for (j = 0; j <= i; j++) {
569                                                                cmat[i][j] += risk * covar[i][person] * covar[j][person];
570                                                        }
571                                                }
572
573                                                if (status[person] == 1) {
574                                                        ndead++;
575                                                        deadwt += weights[person];
576                                                        newlk += weights[person] * zbeta;
577                                                        for (i = 0; i < nvar; i++) {
578                                                                u[i] += weights[person] * covar[i][person];
579                                                        }
580                                                        if (method == CoxMethod.Efron) { /* Efron */
581                                                                efronwt += risk;
582                                                                for (i = 0; i < nvar; i++) {
583                                                                        a2[i] += risk * covar[i][person];
584                                                                        for (j = 0; j <= i; j++) {
585                                                                                cmat2[i][j] += risk * covar[i][person] * covar[j][person];
586                                                                        }
587                                                                }
588                                                        }
589                                                }
590
591                                                person--;
592                                                if (person >= 0 && strata[person] == 1) { //added catch of person = 0 and person-- = -1
593                                                        break;  /*ties don't cross strata */
594                                                }
595                                        }
596
597                                        if (ndead > 0) {  /* add up terms*/
598                                                if (method == CoxMethod.Breslow) { /* Breslow */
599                                                        newlk -= deadwt * Math.log(denom);
600                                                        for (i = 0; i < nvar; i++) {
601                                                                temp2 = a[i] / denom;  /* mean */
602                                                                u[i] -= deadwt * temp2;
603                                                                for (j = 0; j <= i; j++) {
604                                                                        imat[j][i] += (deadwt / denom)
605                                                                                        * (cmat[i][j] - temp2 * a[j]);
606                                                                }
607                                                        }
608                                                } else { /* Efron */
609                                                        for (k = 0; k < ndead; k++) {
610                                                                temp = (double) k / ndead;
611                                                                wtave = deadwt / ndead;
612                                                                d2 = denom - temp * efronwt;
613                                                                newlk -= wtave * Math.log(d2);
614                                                                for (i = 0; i < nvar; i++) {
615                                                                        temp2 = (a[i] - temp * a2[i]) / d2;
616                                                                        u[i] -= wtave * temp2;
617                                                                        for (j = 0; j <= i; j++) {
618                                                                                imat[j][i] += (wtave / d2)
619                                                                                                * ((cmat[i][j] - temp * cmat2[i][j])
620                                                                                                - temp2 * (a[j] - temp * a2[j]));
621                                                                        }
622                                                                }
623                                                        }
624
625                                                        for (i = 0; i < nvar; i++) { /*in anticipation */
626                                                                a2[i] = 0;
627                                                                for (j = 0; j < nvar; j++) {
628                                                                        cmat2[i][j] = 0;
629                                                                }
630                                                        }
631                                                }
632                                        }
633                                }   /* end  of accumulation loop  */
634
635                                /* am I done?
636                                 **   update the betas and test for convergence
637                                 */
638                                flag = Cholesky2.process(imat, nvar, toler);
639
640                                if (Math.abs(1 - (loglik[1] / newlk)) <= eps && halving == 0) { /* all done */
641                                        loglik[1] = newlk;
642                                        chinv2(imat, nvar);     /* invert the information matrix */
643                                        for (i = 0; i < nvar; i++) {
644                                                beta[i] = newbeta[i] * scale[i];
645                                                u[i] /= scale[i];
646                                                imat[i][i] *= scale[i] * scale[i];
647                                                for (j = 0; j < i; j++) {
648                                                        imat[j][i] *= scale[i] * scale[j];
649                                                        imat[i][j] = imat[j][i];
650                                                }
651                                        }
652                                        //  goto finish;
653                                        gotofinish = true;
654                                        break;
655                                }
656
657                                if (iter == maxiter) {
658                                        break;  /*skip the step halving calc*/
659                                }
660
661                                if (newlk < loglik[1]) {    /*it is not converging ! */
662                                        halving = 1;
663                                        for (i = 0; i < nvar; i++) {
664                                                newbeta[i] = (newbeta[i] + beta[i]) / 2; /*half of old increment */
665                                        }
666                                } else {
667                                        halving = 0;
668                                        loglik[1] = newlk;
669                                        chsolve2(imat, nvar, u);
670                                        j = 0;
671                                        for (i = 0; i < nvar; i++) {
672                                                beta[i] = newbeta[i];
673                                                newbeta[i] = newbeta[i] + u[i];
674                                        }
675                                }
676                        }   /* return for another iteration */
677                }
678
679                if (!gotofinish) {
680                        /*
681                         ** We end up here only if we ran out of iterations
682                         */
683                        loglik[1] = newlk;
684                        chinv2(imat, nvar);
685                        for (i = 0; i < nvar; i++) {
686                                beta[i] = newbeta[i] * scale[i];
687                                u[i] /= scale[i];
688                                imat[i][i] *= scale[i] * scale[i];
689                                for (j = 0; j < i; j++) {
690                                        imat[j][i] *= scale[i] * scale[j];
691                                        imat[i][j] = imat[j][i];
692                                }
693                        }
694                        flag = 1000;
695                }
696
697//finish:
698                /*
699                 for (j = 0; j < numCovariates; j++) {
700                 b[j] = b[j] / SD[j];
701                 * ix = j * (numCovariates + 1) + j
702                 SE[j] = Math.sqrt(a[ix(j, j, numCovariates + 1)]) / SD[j];
703                 //            o = o + ("   " + variables.get(j) + "    " + Fmt(b[j]) + Fmt(SE[j]) + Fmt(Math.exp(b[j])) + Fmt(Norm(Math.abs(b[j] / SE[j]))) + Fmt(Math.exp(b[j] - 1.95 * SE[j])) + Fmt(Math.exp(b[j] + 1.95 * SE[j])) + NL);
704                 CoxCoefficient coe = coxInfo.getCoefficient(variables.get(j));
705                 coe.coeff = b[j];
706                 coe.stdError = SE[j];
707                 coe.hazardRatio = Math.exp(b[j]);
708                 coe.pvalue = Norm(Math.abs(b[j] / SE[j]));
709                 coe.hazardRatioLoCI = Math.exp(b[j] - 1.95 * SE[j]);
710                 coe.hazardRatioHiCI = Math.exp(b[j] + 1.95 * SE[j]);
711                 }
712
713                 */
714
715                coxInfo.setScoreLogrankTest(sctest);
716                coxInfo.setDegreeFreedom(beta.length);
717                coxInfo.setScoreLogrankTestpvalue(ChiSq.chiSq(coxInfo.getScoreLogrankTest(), beta.length));
718                coxInfo.setVariance(imat);
719                coxInfo.u = u;
720
721                //     for (int n = 0; n < beta.length; n++) {
722                //         se[n] = Math.sqrt(imat[n][n]); // / sd[n];
723                //     }
724
725
726                //       System.out.println("coef,se, means,u");
727                for (int n = 0; n < beta.length; n++) {
728                        CoxCoefficient coe = new CoxCoefficient();
729                        coe.name = variables.get(n);
730                        coe.mean = means[n];
731                        coe.standardDeviation = sd[n];
732                        coe.coeff = beta[n];
733                        coe.stdError = Math.sqrt(imat[n][n]);
734                        coe.hazardRatio = Math.exp(coe.getCoeff());
735                        coe.z = coe.getCoeff() / coe.getStdError();
736                        coe.pvalue = ChiSq.norm(Math.abs(coe.getCoeff() / coe.getStdError()));
737                        double z = 1.959964;
738                        coe.hazardRatioLoCI = Math.exp(coe.getCoeff() - z * coe.getStdError());
739                        coe.hazardRatioHiCI = Math.exp(coe.getCoeff() + z * coe.getStdError());
740
741                        coxInfo.setCoefficient(coe.getName(), coe);
742                        // System.out.println(beta[n] + "," + se[n] + "," + means[n] + "," + sd[n] + "," + u[n]); //+ "," + imat[n] "," + loglik[n] + "," + sctest[n] + "," + iterations[n] + "," + flag[n]
743
744                }
745
746                coxInfo.maxIterations = maxiter;
747                coxInfo.eps = eps;
748                coxInfo.toler = toler;
749
750                coxInfo.iterations = iter;
751                coxInfo.flag = flag;
752                coxInfo.loglikInit = loglik[0];
753                coxInfo.loglikFinal = loglik[1];
754                coxInfo.method = method;
755
756                //    System.out.println("loglik[0]=" + loglik[0]);
757                //    System.out.println("loglik[1]=" + loglik[1]);
758
759                //    System.out.println("chisq? sctest[0]=" + sctest[0]);
760                //    System.out.println("?overall model p-value=" + chiSq(sctest[0], beta.length));
761
762
763                //      System.out.println();
764                //       for (int n = 0; n < covar[0].length; n++) {
765                //           System.out.print(n);
766                //           for (int variable = 0; variable < covar.length; variable++) {
767                //               System.out.print("\t" + covar[variable][n]);
768
769                //           }
770                //           System.out.println();
771                //       }
772                //      for (SurvivalInfo si : data) {
773                //          System.out.println(si.order + " " + si.getScore());
774                //      }
775//        coxInfo.dump();
776
777
778                coxphfitSCleanup(coxInfo, useWeighted, robust,clusterList);
779                return coxInfo;
780        }
781
782        /**
783         *
784         * @param ci
785         * @param useWeighted
786         * @param robust
787         * @param cluster
788         * @throws Exception
789         */
790        public void coxphfitSCleanup(CoxInfo ci, boolean useWeighted,boolean robust, ArrayList<String> cluster) throws Exception {
791                //Do cleanup found after coxfit6 is called in coxph.fit.S
792                //infs <- abs(coxfit$u %*% var)
793                //[ a1 b1] * [a1 b1]
794                //           [a2 b2]
795                double[][] du = new double[1][ci.u.length];
796                du[0] = ci.u;
797                double[] infs = Matrix.abs(Matrix.multiply(ci.u, ci.getVariance()));
798//        StdArrayIO.print(infs);
799
800                ArrayList<CoxCoefficient> coxCoefficients = new ArrayList<>(ci.getCoefficientsList().values());
801
802                for (int i = 0; i < infs.length; i++) {
803                        double inf = infs[i];
804                        double coe = coxCoefficients.get(i).getCoeff();
805                        if (inf > ci.eps && inf > (ci.toler * Math.abs(coe))) {
806                                ci.message = "Loglik converged before variable ";
807                        }
808                }
809
810                //sum(coef*coxfit$means)
811                double sumcoefmeans = 0;
812                for (CoxCoefficient cc : coxCoefficients) {
813                        sumcoefmeans = sumcoefmeans + cc.getCoeff() * cc.getMean();
814                }
815
816                // coxph.fit.S line 107
817                //lp <- c(x %*% coef) + offset - sum(coef*coxfit$means)
818                for (SurvivalInfo si : ci.survivalInfoList) {
819                        double offset = si.getOffset();
820                        double lp = 0;
821                        for (CoxCoefficient cc : coxCoefficients) {
822                                String name = cc.getName();
823                                double coef = cc.getCoeff();
824                                double value = si.getVariable(name);
825                                lp = lp + value * coef;
826                        }
827                        lp = lp + offset - sumcoefmeans;
828                        si.setLinearPredictor(lp);
829                        si.setScore(Math.exp(lp));
830
831//           System.out.println("lp score " + si.order + " " + si.time + " " + si.getWeight() + " " + si.getClusterValue() + " " + lp + " " + Math.exp(lp));
832                }
833//       ci.dump();
834                //begin code after call to coxfit6 in coxph.fit.S
835                //Compute the martingale residual for a Cox model
836                // appears to be C syntax error for = - vs -=
837                //(if (nullmodel) in coxph.fit
838                double[] res = CoxMart.process(ci.method, ci.survivalInfoList, false);
839
840                for(int i = 0; i < ci.survivalInfoList.size(); i++){
841                        SurvivalInfo si = ci.survivalInfoList.get(i);
842                        si.setResidual(res[i]);
843                }
844
845                //this represents the end of coxph.fit.S code and we pickup
846                //after call to fit <- fitter(X, Y, strats ....) in coxph.R
847
848                if (robust) {
849                        ci.setNaiveVariance(ci.getVariance());
850                        double[][] temp;
851                        double[][] temp0;
852
853                        if (cluster != null) {
854
855                                temp = ResidualsCoxph.process(ci, ResidualsCoxph.Type.dfbeta, useWeighted, cluster);
856                                //# get score for null model
857                                //    if (is.null(init))
858                                //          fit2$linear.predictors <- 0*fit$linear.predictors
859                                //    else
860                                //          fit2$linear.predictors <- c(X %*% init)
861                                //Set score to 1
862
863                                double[] templp = new double[ci.survivalInfoList.size()];
864                                double[] tempscore = new double[ci.survivalInfoList.size()];
865                                int i = 0;
866                                for (SurvivalInfo si : ci.survivalInfoList) {
867                                        templp[i] = si.getLinearPredictor();
868                                        tempscore[i] = si.getScore();
869                                        si.setLinearPredictor(0);
870                                        si.setScore(1.0); //this erases stored value which isn't how the R code does it
871                                        i++;
872                                }
873
874                                temp0 = ResidualsCoxph.process(ci, ResidualsCoxph.Type.score, useWeighted, cluster);
875
876                                i = 0;
877                                for (SurvivalInfo si : ci.survivalInfoList) {
878                                        si.setLinearPredictor(templp[i]);
879                                        si.setScore(tempscore[i]); //this erases stored value which isn't how the R code does it
880                                        i++;
881                                }
882
883
884                        } else {
885                                temp = ResidualsCoxph.process(ci, ResidualsCoxph.Type.dfbeta, useWeighted, null);
886                                //     fit2$linear.predictors <- 0*fit$linear.predictors
887                                double[] templp = new double[ci.survivalInfoList.size()];
888                                double[] tempscore = new double[ci.survivalInfoList.size()];
889                                int i = 0;
890                                for (SurvivalInfo si : ci.survivalInfoList) {
891                                        templp[i] = si.getLinearPredictor();
892                                        tempscore[i] = si.getScore();
893                                        si.setLinearPredictor(0);
894                                        si.setScore(1.0);
895                                }
896                                temp0 = ResidualsCoxph.process(ci, ResidualsCoxph.Type.score, useWeighted, null);
897
898                                i = 0;
899                                for (SurvivalInfo si : ci.survivalInfoList) {
900                                        si.setLinearPredictor(templp[i]);
901                                        si.setScore(tempscore[i]); //this erases stored value which isn't how the R code does it
902                                        i++;
903                                }
904                        }
905                        //fit$var<- t(temp) % * % temp
906                        double[][] ttemp = Matrix.transpose(temp);
907                        double[][] var = Matrix.multiply(ttemp, temp);
908                        ci.setVariance(var);
909                        //u<- apply(as.matrix(temp0), 2, sum)
910                        double[] u = new double[temp0[0].length];
911                        for (int i = 0; i < temp0[0].length; i++) {
912                                for (int j = 0; j < temp0.length; j++) {
913                                        u[i] = u[i] + temp0[j][i];
914                                }
915                        }
916                        //fit$rscore <- coxph.wtest(t(temp0)%*%temp0, u, control$toler.chol)$test
917                        double[][] wtemp = Matrix.multiply(Matrix.transpose(temp0),temp0);
918                        double toler_chol = 1.818989e-12;
919                  //  toler_chol = ci.toler;
920                        WaldTestInfo wti = WaldTest.process(wtemp,u,toler_chol);
921                        //not giving the correct value
922                        ci.setRscore(wti.getTest());
923                }
924                calculateWaldTestInfo(ci);
925
926
927
928
929        }
930
931        static public void calculateWaldTestInfo(CoxInfo ci){
932                if(ci.getNumberCoefficients() > 0){
933                        double toler_chol = 1.818989e-12;
934                  //  toler_chol = ci.toler;
935                        double[][] b = new double[1][ci.getNumberCoefficients()];
936                        int i = 0;
937                        for(CoxCoefficient coe : ci.getCoefficientsList().values()){
938                                b[0][i] = coe.getCoeff();
939                                i++;
940                        }
941                        ci.setWaldTestInfo(WaldTest.process(ci.getVariance(), b, toler_chol));
942                }
943        }
944
945        /**
946         * @param args the command line arguments
947         */
948        public static void main(String[] args) {
949                // TODO code application logic here
950                CoxR coxr = new CoxR();
951
952
953                if (true) {
954                        try {
955                           InputStream is = coxr.getClass().getClassLoader().getResourceAsStream("uis-complete.txt");
956
957
958
959
960                                WorkSheet worksheet = WorkSheet.readCSV(is, '\t');
961                                ArrayList<SurvivalInfo> survivalInfoList = new ArrayList<>();
962                                int i = 0;
963                                for (String row : worksheet.getRows()) {
964                                        double time = worksheet.getCellDouble(row, "TIME");
965                                        double age = worksheet.getCellDouble(row, "AGE");
966                                        double treat = worksheet.getCellDouble(row, "TREAT");
967                                        double c = worksheet.getCellDouble(row, "CENSOR");
968                                        int censor = (int) c;
969
970                                        SurvivalInfo si = new SurvivalInfo(time, censor);
971                                        si.setOrder(i);
972                                        si.addContinuousVariable("AGE", age);
973                                        si.addContinuousVariable("TREAT", treat);
974
975                                        survivalInfoList.add(si);
976                                        i++;
977                                }
978
979                                CoxR cox = new CoxR();
980                                ArrayList<String> variables = new ArrayList<>();
981                                //               variables.add("AGE");
982
983                                variables.add("AGE");
984                                variables.add("TREAT");
985
986                                //       variables.add("TREAT:AGE");
987                          //  ArrayList<Integer> cluster = new ArrayList<Integer>();
988                                CoxInfo ci = cox.process(variables, survivalInfoList, false, true,false, false);
989                                System.out.println(ci);
990                        } catch (Exception e) {
991                                e.printStackTrace();
992                        }
993
994
995                }
996
997//        if (false) {
998//
999//            try {
1000//
1001//
1002//                WorkSheet worksheet = WorkSheet.readCSV("/Users/Scooter/NetBeansProjects/AssayWorkbench/src/edu/scripps/assayworkbench/cox/uis-complete.txt", '\t');
1003//                ArrayList<String> rows = worksheet.getRows();
1004//                ArrayList<String> variables = new ArrayList<String>();
1005//                variables.add("AGE");
1006//                variables.add("TREAT");
1007//                double[] time2 = new double[rows.size()];
1008//                int[] status2 = new int[rows.size()];
1009//                double[][] covar2 = new double[variables.size()][rows.size()];
1010//                double[] offset2 = new double[rows.size()];
1011//                double[] weights2 = new double[rows.size()];
1012//                int[] strata2 = new int[rows.size()];
1013//
1014//
1015//                for (int i = 0; i < rows.size(); i++) {
1016//                    String row = rows.get(i);
1017//                    double time = worksheet.getCellDouble(row, "TIME");
1018//                    //      double age = worksheet.getCellDouble(row, "AGE");
1019//                    //      double treat = worksheet.getCellDouble(row, "TREAT");
1020//                    double c = worksheet.getCellDouble(row, "CENSOR");
1021//                    int censor = (int) c;
1022//
1023//                    time2[i] = time;
1024//                    status2[i] = censor;
1025//                    offset2[i] = 0;
1026//                    weights2[i] = 1;
1027//                    strata2[i] = 0;
1028//
1029//                    for (int j = 0; j < variables.size(); j++) {
1030//                        String variable = variables.get(j);
1031//                        double v = worksheet.getCellDouble(row, variable);
1032//                        covar2[j][i] = v;
1033//                    }
1034//
1035//
1036//                }
1037//                //from coxph.control.S
1038//                int maxiter2 = 20;
1039//                double eps2 = 1e-9;
1040//                double toler2 = Math.pow(eps2, .75);
1041//                int doscale2 = 1;
1042//                int method2 = 0;
1043//                //toler.chol = eps ^ .75
1044//                //toler.inf=sqrt(eps)
1045//                //outer.max=10
1046//
1047//                CoxR cox = new CoxR();
1048//                //        cox.coxfit6(maxiter2, time2, status2, covar2, offset2, weights2, strata2, method2, eps2, toler2, time2, doscale2);
1049//
1050//
1051//
1052//
1053//
1054//            } catch (Exception e) {
1055//                e.printStackTrace();
1056//            }
1057//        }
1058
1059        }
1060
1061        /* $Id: chinv2.c 11357 2009-09-04 15:22:46Z therneau $
1062         **
1063         ** matrix inversion, given the FDF' cholesky decomposition
1064         **
1065         ** input  **matrix, which contains the chol decomp of an n by n
1066         **   matrix in its lower triangle.
1067         **
1068         ** returned: the upper triangle + diagonal contain (FDF')^{-1}
1069         **            below the diagonal will be F inverse
1070         **
1071         **  Terry Therneau
1072         */
1073        void chinv2(double[][] matrix, int n) {
1074                double temp;
1075                int i, j, k;
1076
1077                /*
1078                 ** invert the cholesky in the lower triangle
1079                 **   take full advantage of the cholesky's diagonal of 1's
1080                 */
1081                for (i = 0; i < n; i++) {
1082                        if (matrix[i][i] > 0) {
1083                                matrix[i][i] = 1 / matrix[i][i];   /*this line inverts D */
1084                                for (j = (i + 1); j < n; j++) {
1085                                        matrix[j][i] = -matrix[j][i];
1086                                        for (k = 0; k < i; k++) /*sweep operator */ {
1087                                                matrix[j][k] += matrix[j][i] * matrix[i][k];
1088                                        }
1089                                }
1090                        }
1091                }
1092
1093                /*
1094                 ** lower triangle now contains inverse of cholesky
1095                 ** calculate F'DF (inverse of cholesky decomp process) to get inverse
1096                 **   of original matrix
1097                 */
1098                for (i = 0; i < n; i++) {
1099                        if (matrix[i][i] == 0) {  /* singular row */
1100                                for (j = 0; j < i; j++) {
1101                                        matrix[j][i] = 0;
1102                                }
1103                                for (j = i; j < n; j++) {
1104                                        matrix[i][j] = 0;
1105                                }
1106                        } else {
1107                                for (j = (i + 1); j < n; j++) {
1108                                        temp = matrix[j][i] * matrix[j][j];
1109                                        if (j != i) {
1110                                                matrix[i][j] = temp;
1111                                        }
1112                                        for (k = i; k < j; k++) {
1113                                                matrix[i][k] += temp * matrix[j][k];
1114                                        }
1115                                }
1116                        }
1117                }
1118        }
1119
1120        /*  $Id: chsolve2.c 11376 2009-12-14 22:53:57Z therneau $
1121         **
1122         ** Solve the equation Ab = y, where the cholesky decomposition of A and y
1123         **   are the inputs.
1124         **
1125         ** Input  **matrix, which contains the chol decomp of an n by n
1126         **   matrix in its lower triangle.
1127         **        y[n] contains the right hand side
1128         **
1129         **  y is overwriten with b
1130         **
1131         **  Terry Therneau
1132         */
1133        void chsolve2(double[][] matrix, int n, double[] y) {
1134                int i, j;
1135                double temp;
1136
1137                /*
1138                 ** solve Fb =y
1139                 */
1140                for (i = 0; i < n; i++) {
1141                        temp = y[i];
1142                        for (j = 0; j < i; j++) {
1143                                temp -= y[j] * matrix[i][j];
1144                        }
1145                        y[i] = temp;
1146                }
1147                /*
1148                 ** solve DF'z =b
1149                 */
1150                for (i = (n - 1); i >= 0; i--) {
1151                        if (matrix[i][i] == 0) {
1152                                y[i] = 0;
1153                        } else {
1154                                temp = y[i] / matrix[i][i];
1155                                for (j = i + 1; j < n; j++) {
1156                                        temp -= y[j] * matrix[j][i];
1157                                }
1158                                y[i] = temp;
1159                        }
1160                }
1161        }
1162
1163
1164
1165        /**
1166         *
1167         * @param x
1168         * @return
1169         */
1170        public double coxsafe(double x) {
1171                if (x < -200) {
1172                        return -200;
1173                }
1174                if (x > 22) {
1175                        return 22;
1176                }
1177                return x;
1178        }
1179}