001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.kaplanmeier.figure;
022
023import org.biojava.nbio.survival.cox.*;
024
025import javax.imageio.ImageIO;
026import javax.swing.*;
027import java.awt.*;
028import java.awt.geom.AffineTransform;
029import java.awt.image.BufferedImage;
030import java.io.File;
031import java.io.FileWriter;
032import java.text.DecimalFormat;
033import java.util.ArrayList;
034import java.util.Collections;
035import java.util.LinkedHashMap;
036
037/**
038 *
039 * @author Scooter Willis <willishf at gmail dot com>
040 */
041public class KaplanMeierFigure extends JPanel {
042
        private static final long serialVersionUID = 1L;

        /** Title lines drawn centered at the top of the figure, one per line. */
        ArrayList<String> title = new ArrayList<String>();
        /**
         * Pixel y coordinate of the top edge of the plotting area; recomputed by
         * setFigureDimensions() on every paint.
         */
        private int top;
        /**
         * Pixel y coordinate of the bottom edge (x axis) of the plotting area;
         * recomputed by setFigureDimensions() on every paint.
         */
        private int bottom;
        /**
         * Pixel x coordinate of the left edge (y axis) of the plotting area;
         * recomputed by setFigureDimensions() on every paint.
         */
        private int left;
        /** Pixel x position of the rotated y-axis legend; also reserved in the left margin. */
        private int yaxisLabel = 20;
        /**
         * Pixel x coordinate of the right edge of the plotting area; recomputed by
         * setFigureDimensions() on every paint.
         */
        private int right;
        // Layout sizes set in setFigureDimensions().
        int titleHeight;
        int xAxisLabelHeight;
        int labelWidth;
        // Time range of the x axis. maxTime is recomputed by setSurvivalData(...);
        // drawLegend() overwrites both when kmfi.xAxisLabels is non-empty.
        double minTime = 0.0;
        double maxTime = 10.0;
        // Range of the y axis as a survival fraction (0.0-1.0).
        double minPercentage = 0.0;
        double maxPercentage = 1.0;
        // Metrics of the most recently selected font; reassigned while painting.
        FontMetrics fm;
        // Styling and layout options of the figure.
        KMFigureInfo kmfi = new KMFigureInfo();
        // Raw time/status records keyed by legend (group) name.
        LinkedHashMap<String, ArrayList<CensorStatus>> survivalData = new LinkedHashMap<String, ArrayList<CensorStatus>>();
        // Extra text lines (e.g. hazard ratio, p-value, counts) drawn inside the plot.
        ArrayList<String> lineInfoList = new ArrayList<String>();
        // Fitted survival curves, one StrataInfo per group; this is what gets drawn.
        SurvFitInfo sfi = new SurvFitInfo();
        // Last file name accepted by savePNG/savePNGKMNumRisk.
        private String fileName = "";
        // X-axis tick values and their pixel x coordinates, rebuilt by drawLegend().
        private ArrayList<Double> xAxisTimeValues = new ArrayList<Double>();
        private ArrayList<Integer> xAxisTimeCoordinates = new ArrayList<Integer>();
078
079        /**
080         *
081         */
082        public KaplanMeierFigure() {
083                super();
084                setSize(500, 400);
085                setBackground(Color.WHITE);
086        }
087
088        /**
089         * Get the name of the groups that are being plotted in the figure
090         *
091         * @return
092         */
093        public ArrayList<String> getGroups() {
094                return new ArrayList<String>(survivalData.keySet());
095        }
096
097        /**
098         * To get the median percentile for a particular group pass the value of
099         * .50.
100         *
101         * @param group
102         * @param percentile
103         * @return
104         */
105        public Double getSurvivalTimePercentile(String group, double percentile) {
106
107                StrataInfo si = sfi.getStrataInfoHashMap().get(group);
108                ArrayList<Double> percentage = si.getSurv();
109                Integer percentileIndex = null;
110                for (int i = 0; i < percentage.size(); i++) {
111                        if (percentage.get(i) == percentile) {
112                                if (i + 1 < percentage.size()) {
113                                        percentileIndex = i + 1;
114                                }
115                                break;
116                        } else if (percentage.get(i) < percentile) {
117                                percentileIndex = i;
118                                break;
119                        }
120                }
121                if (percentileIndex != null) {
122                        return si.getTime().get(percentileIndex);
123                } else {
124                        return null;
125                }
126        }
127
128        /**
129         *
130         * @param kmfi
131         */
132        public void setKMFigureInfo(KMFigureInfo kmfi) {
133                this.kmfi = kmfi;
134                if (kmfi.width != null && kmfi.height != null) {
135                        this.setSize(kmfi.width, kmfi.height);
136                }
137        }
138
139        public KMFigureInfo getKMFigureInfo() {
140                return kmfi;
141        }
142
143        /**
144         *
145         * @param lineInfoList
146         */
147        public void setFigureLineInfo(ArrayList<String> lineInfoList) {
148                this.lineInfoList = lineInfoList;
149                this.repaint();
150        }
151
152        /**
153         *
154         * @param title Title of figures
155         * @param ci
156         * @param strataVariable The column that based on value will do a figure
157         * line
158         * @param legendMap Map the value in the column to something readable
159         * @param useWeighted
160         * @throws Exception
161         */
162        public void setCoxInfo(ArrayList<String> title, CoxInfo ci, String strataVariable, LinkedHashMap<String, String> legendMap, Boolean useWeighted) throws Exception {
163                LinkedHashMap<String, ArrayList<CensorStatus>> survivalData = new LinkedHashMap<String, ArrayList<CensorStatus>>();
164                ArrayList<SurvivalInfo> siList = ci.getSurvivalInfoList();
165                int n = 0;
166                int event = 0;
167                for (SurvivalInfo si : siList) {
168                        String strata = si.getOriginalMetaData(strataVariable);
169                        String legend = legendMap.get(strata);
170                        if (legend == null) {
171
172                                legend = strata;
173                        }
174                        ArrayList<CensorStatus> censorStatusList = survivalData.get(legend);
175                        if (censorStatusList == null) {
176                                censorStatusList = new ArrayList<CensorStatus>();
177                                survivalData.put(legend, censorStatusList);
178                        }
179                        CensorStatus cs = new CensorStatus(strata, si.getTime(), si.getStatus() + "");
180                        cs.weight = si.getWeight();
181                        censorStatusList.add(cs);
182                        n++;
183                        if (si.getStatus() == 1) {
184                                event++;
185                        }
186                }
187
188                setSurvivalData(title, survivalData, useWeighted);
189                CoxCoefficient cc = ci.getCoefficient(strataVariable);
190                //DecimalFormat df = new DecimalFormat("#.##");
191                String line1 = "HR=" + fmt(cc.getHazardRatio(), 2, 0) + " (CI:" + fmt(cc.getHazardRatioLoCI(), 2, 0) + "-" + fmt(cc.getHazardRatioHiCI(), 2, 0) + ")";
192                String line2 = "p=" + fmt(cc.getPvalue(), 3, 0);
193           // String line2 = "logrank P=" + fmt(ci.getScoreLogrankTestpvalue(), 3, 0);
194                String line3 = "n=" + n + " events=" + event;
195//        System.out.println("setCoxInfo=" + cc.pvalue + " " + title);
196
197
198                ArrayList<String> lines = new ArrayList<String>();
199                lines.add(line1);
200                lines.add(line2);
201                lines.add(line3);
202                setFigureLineInfo(lines);
203        }
204
205        /**
206         *
207         * @param d
208         * @param precision
209         * @param pad
210         * @return
211         */
212        public static String fmt(Double d, int precision, int pad) {
213                String value = "";
214                DecimalFormat dfe = new DecimalFormat("0.00E0");
215                String dpad = "0.";
216                double p = 1.0;
217                for (int i = 0; i < (precision); i++) {
218                        dpad = dpad + "0";
219                        p = p / 10.0;
220                }
221                DecimalFormat df = new DecimalFormat(dpad);
222                if (Math.abs(d) >= p) {
223                        value = df.format(d);
224                } else {
225                        value = dfe.format(d);
226                }
227                int length = value.length();
228                int extra = pad - length;
229                if (extra > 0) {
230                        for (int i = 0; i < extra; i++) {
231                                value = " " + value;
232                        }
233                }
234                return value;
235        }
236
237        /**
238         *
239         * @return
240         */
241        public SurvFitInfo getSurvivalFitInfo() {
242                return sfi;
243        }
244
245        /**
246         * Allow setting of points in the figure where weighted correction has been
247         * done and percentage has already been calculated.
248         *
249         * @param title
250         * @param sfi
251         * @param userSetMaxTime
252         */
253        public void setSurvivalData(ArrayList<String> title, SurvFitInfo sfi, Double userSetMaxTime) {
254                this.title = title;
255                LinkedHashMap<String, StrataInfo> strataInfoHashMap = sfi.getStrataInfoHashMap();
256                Double mTime = null;
257                for (StrataInfo si : strataInfoHashMap.values()) {
258                        for (double t : si.getTime()) {
259                                if (mTime == null || t > mTime) {
260                                        mTime = t;
261                                }
262                        }
263                }
264
265                int evenCheck = Math.round(mTime.floatValue());
266                if (evenCheck % 2 == 1) {
267                        evenCheck = evenCheck + 1;
268                }
269                this.maxTime = evenCheck;
270
271                if (userSetMaxTime != null && userSetMaxTime > maxTime) {
272                        this.maxTime = userSetMaxTime;
273                }
274                this.sfi = sfi;
275                if (sfi.getStrataInfoHashMap().size() == 1) {
276                        return;
277                }
278                this.repaint();
279        }
280
281        /**
282         * The data will set the max time which will result in off time points for
283         * tick marks
284         *
285         * @param title
286         * @param survivalData
287         * @param useWeighted
288         * @throws Exception
289         */
290        public void setSurvivalData(ArrayList<String> title, LinkedHashMap<String, ArrayList<CensorStatus>> survivalData, Boolean useWeighted) throws Exception {
291                this.setSurvivalData(title, survivalData, null, useWeighted);
292        }
293
294        /**
295         *
296         * @param title
297         * @param survivalData
298         * @param userSetMaxTime
299         * @param useWeighted
300         * @throws Exception
301         */
302        public void setSurvivalData(ArrayList<String> title, LinkedHashMap<String, ArrayList<CensorStatus>> survivalData, Double userSetMaxTime, Boolean useWeighted) throws Exception {
303                this.title = title;
304                this.survivalData = survivalData;
305                Double mTime = null;
306                ArrayList<String> labels = new ArrayList<String>(survivalData.keySet());
307                Collections.sort(labels);
308                for (String legend : labels) {
309                        ArrayList<CensorStatus> censorStatusList = survivalData.get(legend);
310                        for (CensorStatus cs : censorStatusList) {
311
312                                if (mTime == null || cs.time > mTime) {
313                                        mTime = cs.time;
314                                }
315                        }
316                }
317
318                int evenCheck = Math.round(mTime.floatValue());
319                if (evenCheck % 2 == 1) {
320                        evenCheck = evenCheck + 1;
321                }
322                this.maxTime = evenCheck;
323
324                if (userSetMaxTime != null && userSetMaxTime > maxTime) {
325                        this.maxTime = userSetMaxTime;
326                }
327
328                //calculate percentages
329                SurvFitKM survFitKM = new SurvFitKM();
330                sfi = survFitKM.process(survivalData, useWeighted);
331                this.repaint();
332        }
333
334        /**
335         * Save data from survival curve to text file
336         *
337         * @param fileName
338         * @throws Exception
339         */
340        public void saveSurvivalData(String fileName) throws Exception {
341                FileWriter fw = new FileWriter(fileName);
342                fw.write("index\tTIME\tSTATUS\tGROUP\r\n");
343                int index = 0;
344                for (String group : survivalData.keySet()) {
345                        ArrayList<CensorStatus> sd = survivalData.get(group);
346                        for (CensorStatus cs : sd) {
347                                String line = index + "\t" + cs.time + "\t" + cs.censored + "\t" + cs.group + "\r\n";
348                                index++;
349                                fw.write(line);
350                        }
351                }
352                fw.close();
353        }
        /**
         * Formatter for axis tick labels (at most one fractional digit). Also used by
         * setFigureDimensions() to measure the y-axis label width.
         */
        DecimalFormat df = new DecimalFormat("#.#");
355
356        @Override
357        public void paintComponent(Graphics g) // draw graphics in the panel
358        {
359                int width = getWidth();             // width of window in pixels
360                int height = getHeight();           // height of window in pixels
361                setFigureDimensions();
362                g.setColor(Color.white);
363                g.clearRect(0, 0, width, height);
364
365                super.paintComponent(g);            // call superclass to make panel display correctly
366
367                drawLegend(g);
368                drawSurvivalCurves(g);
369                drawFigureLineInfo(g);
370                // Drawing code goes here
371        }
372
373        private void drawFigureLineInfo(Graphics g) {
374                Graphics2D g2 = (Graphics2D) g;
375                setRenderingHints(g2);
376                g2.setColor(Color.BLACK);
377                fm = getFontMetrics(getFont());
378                int yoffset = fm.getHeight() * lineInfoList.size();
379
380                int x = getTimeX(kmfi.figureLineInfoLowerPercentX * maxTime);
381                int y = getPercentageY(kmfi.figureLineInfoLowerPercentY) - yoffset;
382
383                for (String line : lineInfoList) {
384                        g2.drawString(line, x, y);
385                        y = y + fm.getHeight();
386                }
387
388        }
389
390        private void drawSurvivalCurves(Graphics g) {
391                Graphics2D g2 = (Graphics2D) g;
392                setRenderingHints(g2);
393                g2.setStroke(kmfi.kmStroke);
394
395
396                int colorIndex = 0;
397                ArrayList<String> labels = new ArrayList<String>(sfi.getStrataInfoHashMap().keySet());
398                Collections.sort(labels);
399
400                LinkedHashMap<String, StrataInfo> strataInfoHashMap = sfi.getStrataInfoHashMap();
401
402                for (String legend : labels) {
403                        StrataInfo si = strataInfoHashMap.get(legend);
404                        g2.setColor(kmfi.legendColor[colorIndex]);
405                        colorIndex++;
406
407                        for (int i = 0; i < si.getSurv().size() - 1; i++) {
408                                double p0time = si.getTime().get(i);
409                                double p1time = si.getTime().get(i + 1);
410                                double p0percentage = si.getSurv().get(i);
411                                double p1percentage = si.getSurv().get(i + 1);
412                                if (i == 0) {
413                                        g2.drawLine(getTimeX(0), getPercentageY(1), getTimeX(p0time), getPercentageY(1));
414                                        g2.drawLine(getTimeX(p0time), getPercentageY(1), getTimeX(p0time), getPercentageY(p0percentage));
415                                }
416                                g2.drawLine(getTimeX(p0time), getPercentageY(p0percentage), getTimeX(p1time), getPercentageY(p0percentage));
417
418                                g2.drawLine(getTimeX(p1time), getPercentageY(p0percentage), getTimeX(p1time), getPercentageY(p1percentage));
419                                // if (si.getStatus().get(i) == 0) {
420                                if (i > 0 && si.getNcens().get(i) > 0) {
421                                        g2.drawLine(getTimeX(p0time), getPercentageY(p0percentage) - 4, getTimeX(p0time), getPercentageY(p0percentage) + 4);
422                                        g2.drawLine(getTimeX(p0time) - 4, getPercentageY(p0percentage), getTimeX(p0time) + 4, getPercentageY(p0percentage));
423                                }
424                        }
425
426
427                }
428
429                String maxString = "";
430                for (String legend : labels) {
431                        if (legend.length() > maxString.length()) {
432                                maxString = legend;
433                        }
434                }
435
436                int offset = fm.stringWidth(maxString);
437                int x = getTimeX(kmfi.legendUpperPercentX * maxTime) - offset;
438                int y = getPercentageY(kmfi.legendUpperPercentY);
439
440                colorIndex = 0;
441                for (String legend : labels) {
442                        g2.setColor(kmfi.legendColor[colorIndex]);
443                        colorIndex++;
444                        g2.drawLine(x - 20, y - (fm.getHeight() / 3), x - 5, y - (fm.getHeight() / 3));
445                        g2.drawString(legend, x, y);
446                        y = y + fm.getHeight();
447                }
448
449
450        }
451
452        /**
453         * Get the X coordinate based on a time value
454         *
455         * @param value
456         * @return
457         */
458        private int getTimeX(double value) {
459                double d = left + (((right - left) * value) / (maxTime - minTime));
460                return (int) d;
461        }
462
463        /**
464         * Get the Y coordinate based on percent value 0.0-1.0
465         *
466         * @param value
467         * @return
468         */
469        private int getPercentageY(double value) {
470                value = 1.0 - value;
471                double d = top + (((bottom - top) * value) / (maxPercentage - minPercentage));
472                return (int) d;
473        }
474
475        /**
476         * @return the fileName
477         */
478        public String getFileName() {
479                return fileName;
480        }
481
482        /**
483         * @return the top
484         */
485        public int getTop() {
486                return top;
487        }
488
489        /**
490         * @return the bottom
491         */
492        public int getBottom() {
493                return bottom;
494        }
495
496        /**
497         * @return the left
498         */
499        public int getLeft() {
500                return left;
501        }
502
503        /**
504         * @return the right
505         */
506        public int getRight() {
507                return right;
508        }
509
510        /**
511         * @return the xAxisTimeValues
512         */
513        public ArrayList<Double> getxAxisTimeValues() {
514                return xAxisTimeValues;
515        }
516
517        /**
518         * @return the xAxisTimeValues
519         */
520        public ArrayList<Integer> getxAxisTimeCoordinates() {
521                return xAxisTimeCoordinates;
522        }
523
524        class PlotInfo {
525
526                double time;
527                double atRisk;
528                double censored;
529                double events;
530                double percentage;
531
532                @Override
533                public String toString() {
534                        return time + "\t" + atRisk + "\t" + censored + "\t" + events + "\t" + (atRisk - events) + "\t" + percentage;
535                }
536        }
537
538        /**
539         * Do higher quality rendering options
540         *
541         * @param g
542         */
543        private void setRenderingHints(Graphics2D g) {
544                RenderingHints rh = new RenderingHints(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
545                rh.put(RenderingHints.KEY_DITHERING, RenderingHints.VALUE_DITHER_ENABLE);
546                rh.put(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_GASP);
547                rh.put(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
548
549                g.setRenderingHints(rh);
550
551        }
552
553        /**
554         * Setup the axis, labels etc
555         *
556         * @param g
557         */
558        private void drawLegend(Graphics g) {
559                Graphics2D g2 = (Graphics2D) g;
560                setRenderingHints(g2);
561                g2.setColor(Color.BLACK);
562                Font font = g2.getFont();
563                Font f = new Font(font.getFontName(), Font.BOLD, font.getSize());
564                g2.setFont(f);
565                fm = getFontMetrics(f);
566                int fontHeight = fm.getHeight();
567                for (int i = 0; i < title.size(); i++) {
568                        if (fm.stringWidth(title.get(i)) > .8 * this.getWidth()) {
569                                f = new Font(font.getFontName(), Font.BOLD, 10);
570                                g2.setFont(f);
571                                fm = getFontMetrics(f);
572                        }
573                        g2.drawString(title.get(i), (getSize().width - fm.stringWidth(title.get(i))) / 2, ((i + 1) * fontHeight));
574                        // g2.setFont(font);
575                }
576                // draw the maxPercentage and minPercentage values
577                String label = df.format(minPercentage);
578                g2.drawString(label, left - 5 - (fm.stringWidth(label)), bottom + titleHeight / 6);
579                g2.drawLine(left - 5, bottom, left, bottom);
580                double d = minPercentage + kmfi.yaxisPercentIncrement;
581                //double graphHeight = top - bottom;
582
583                while (d < maxPercentage) {
584                        int yvalue = bottom - (int) (d * (bottom - top));
585                        label = df.format(d * 100);
586                        g2.drawString(label, left - 5 - (fm.stringWidth(label)), yvalue + titleHeight / 6); //
587
588                        g2.drawLine(left - 5, yvalue, left, yvalue);
589                        d = d + kmfi.yaxisPercentIncrement;
590                }
591
592                label = df.format(maxPercentage * 100);
593                g2.drawString(label, left - 5 - (fm.stringWidth(label)), top + (titleHeight) / 6);
594                g2.drawLine(left - 5, top, left, top);
595
596                // Create a rotation transformation for the font.
597                AffineTransform fontAT = new AffineTransform();
598
599
600                // Derive a new font using a rotatation transform
601                fontAT.rotate(270 * java.lang.Math.PI / 180);
602                Font theDerivedFont = f.deriveFont(fontAT);
603
604                // set the derived font in the Graphics2D context
605                g2.setFont(theDerivedFont);
606
607                // Render a string using the derived font
608                int yaxisHeight = fm.stringWidth(kmfi.yAxisLegend);
609                g2.drawString(kmfi.yAxisLegend, yaxisLabel, (bottom - (int) (.5 * (bottom - top))) + yaxisHeight / 2);
610
611                // put the original font back
612                g2.setFont(f);
613
614
615
616                double timeDistance = maxTime - minTime;
617                double timeIncrement = timeDistance * kmfi.xaxisPercentIncrement;
618                double timeInt = (int) Math.floor(timeIncrement);
619                if (timeInt < 1.0) {
620                        timeInt = 1.0;
621                }
622                adjustedPercentIncrement = timeInt / timeDistance;
623
624                d = adjustedPercentIncrement; //kmfi.xaxisPercentIncrement;
625                xAxisTimeValues.clear();
626                xAxisTimeCoordinates.clear();
627
628                //if we don't have time values then use percentage to set time. Not perfect but allows different tics
629                if (kmfi.xAxisLabels.isEmpty()) {
630                        xAxisTimeValues.add(minTime);
631                        xAxisTimeCoordinates.add(left);
632                        while (d <= 1.0) {
633                                double xaxisTime = ((minTime * kmfi.timeScale) + d * ((maxTime - minTime) * kmfi.timeScale)); //
634                                xAxisTimeValues.add(xaxisTime);
635
636                                Integer coordinate = left + (int) (d * (right - left));
637                                xAxisTimeCoordinates.add(coordinate);
638                                //       System.out.println(d + " " + left + " " + right + " " + coordinate + " " + minTime + " " + maxTime);
639                                d = d + adjustedPercentIncrement; //kmfi.xaxisPercentIncrement;
640                        }
641                } else {
642                        minTime = kmfi.xAxisLabels.get(0);
643                        maxTime = kmfi.xAxisLabels.get(kmfi.xAxisLabels.size() - 1);
644                        for (Double xaxisTime : kmfi.xAxisLabels) {
645                                xAxisTimeValues.add(xaxisTime);
646                                d = (xaxisTime - minTime) / (maxTime - minTime);
647                                Integer coordinate = left + (int) (d * (right - left));
648                                xAxisTimeCoordinates.add(coordinate);
649                        }
650                }
651
652                for (int i = 0; i < xAxisTimeValues.size(); i++) {
653                        Double xaxisTime = xAxisTimeValues.get(i);
654                        Integer xCoordinate = xAxisTimeCoordinates.get(i);
655                        label = df.format(xaxisTime);
656                        if (i == xAxisTimeValues.size() - 1) {
657                                g2.drawString(label, xCoordinate - (fm.stringWidth(label)), bottom + fm.getHeight() + 5);
658                        } else {
659                                g2.drawString(label, xCoordinate - (fm.stringWidth(label) / 2), bottom + fm.getHeight() + 5);
660                        }
661                        g2.drawLine(xCoordinate, bottom, xCoordinate, bottom + 5);
662                }
663
664                // draw the vertical and horizontal lines
665                g2.setStroke(kmfi.axisStroke);
666                g2.drawLine(left, top, left, bottom);
667                g2.drawLine(left, bottom, right, bottom);
668
669                // draw xAxis legend
670                g2.drawString(kmfi.xAxisLegend, getSize().width / 2 - (fm.stringWidth(kmfi.xAxisLegend) / 2), bottom + 2 * fm.getHeight() + 10);
671        }
672        Double adjustedPercentIncrement = 0.0;
673
674        /**
675         * Get the percentage increment for the time axis
676         *
677         * @return
678         */
679        public Double getTimeAxisIncrementPercentage() {
680                return adjustedPercentIncrement;
681        }
682
683        /**
684         * Reset the various bounds used to draw graph
685         */
686        private void setFigureDimensions() {
687                fm = getFontMetrics(getFont());
688                titleHeight = kmfi.titleHeight;//fm.getHeight();
689                xAxisLabelHeight = titleHeight;
690                labelWidth = Math.max(fm.stringWidth(df.format(minPercentage)),
691                                fm.stringWidth(df.format(maxPercentage))) + 5;
692                top = kmfi.padding + titleHeight;
693                bottom = this.getHeight() - kmfi.padding - xAxisLabelHeight;
694                left = kmfi.padding + labelWidth + yaxisLabel;
695                right = this.getWidth() - kmfi.padding;
696
697        }
698
699        /**
700         * Combine the KM and Num risk into one image
701         *
702         * @param fileName
703         */
704        public void savePNGKMNumRisk(String fileName) {
705                if (fileName.startsWith("null") || fileName.startsWith("Null") || fileName.startsWith("NULL")) {
706                        return;
707                }
708                this.fileName = fileName;
709
710                NumbersAtRiskPanel numbersAtRiskPanel = new NumbersAtRiskPanel();
711                numbersAtRiskPanel.setKaplanMeierFigure(this);
712                numbersAtRiskPanel.setSize(this.getWidth(), numbersAtRiskPanel.getHeight());
713                BufferedImage imageKM = new BufferedImage(this.getWidth(), this.getHeight(), BufferedImage.TYPE_INT_RGB);
714                Graphics2D graphics2D = imageKM.createGraphics();
715
716                this.paint(graphics2D);
717
718                BufferedImage imageNumRisk = new BufferedImage(numbersAtRiskPanel.getWidth(), numbersAtRiskPanel.getHeight(), BufferedImage.TYPE_INT_RGB);
719                Graphics2D graphics2DNumRisk = imageNumRisk.createGraphics();
720                numbersAtRiskPanel.paint(graphics2DNumRisk);
721
722
723                BufferedImage image = new BufferedImage(numbersAtRiskPanel.getWidth(), numbersAtRiskPanel.getHeight() + this.getHeight(), BufferedImage.TYPE_INT_RGB);
724                Graphics2D g = image.createGraphics();
725
726                g.drawImage(imageKM, 0, 0, null);
727                g.drawImage(imageNumRisk, 0, this.getHeight(), null);
728
729                try {
730                        ImageIO.write(image, "png", new File(fileName));
731                } catch (Exception ex) {
732                        ex.printStackTrace();
733                }
734
735        }
736
737        /**
738         *
739         * @param fileName
740         */
741        public void savePNG(String fileName) {
742                if (fileName.startsWith("null") || fileName.startsWith("Null") || fileName.startsWith("NULL")) {
743                        return;
744                }
745                this.fileName = fileName;
746                BufferedImage image = new BufferedImage(this.getWidth(), this.getHeight(), BufferedImage.TYPE_INT_RGB);
747                Graphics2D graphics2D = image.createGraphics();
748
749                this.paint(graphics2D);
750                try {
751                        ImageIO.write(image, "png", new File(fileName));
752                } catch (Exception ex) {
753                        ex.printStackTrace();
754                }
755
756        }
757
758        /**
759         * @param args the command line arguments
760         */
761        public static void main(String[] args) {
762                // TODO code application logic here
763                try {
764
765                        KaplanMeierFigure kaplanMeierFigure = new KaplanMeierFigure();
766                        LinkedHashMap<String, ArrayList<CensorStatus>> survivalDataHashMap = new LinkedHashMap<String, ArrayList<CensorStatus>>();
767
768//            if (false) { //http://sph.bu.edu/otlt/MPH-Modules/BS/BS704_Survival/
769//                ArrayList<CensorStatus> graph1 = new ArrayList<CensorStatus>();
770//                graph1.add(new CensorStatus("A", 24.0, "0"));
771//                graph1.add(new CensorStatus("A", 3.0, "1"));
772//                graph1.add(new CensorStatus("A", 11.0, "0"));
773//                graph1.add(new CensorStatus("A", 19.0, "0"));
774//                graph1.add(new CensorStatus("A", 24.0, "0"));
775//                graph1.add(new CensorStatus("A", 13.0, "0"));
776//
777//                graph1.add(new CensorStatus("A", 14.0, "1"));
778//                graph1.add(new CensorStatus("A", 2.0, "0"));
779//                graph1.add(new CensorStatus("A", 18.0, "0"));
780//                graph1.add(new CensorStatus("A", 17.0, "0"));
781//                graph1.add(new CensorStatus("A", 24.0, "0"));
782//                graph1.add(new CensorStatus("A", 21.0, "0"));
783//                graph1.add(new CensorStatus("A", 12.0, "0"));
784//
785//                graph1.add(new CensorStatus("A", 1.0, "1"));
786//                graph1.add(new CensorStatus("A", 10.0, "0"));
787//                graph1.add(new CensorStatus("A", 23.0, "1"));
788//                graph1.add(new CensorStatus("A", 6.0, "0"));
789//                graph1.add(new CensorStatus("A", 5.0, "1"));
790//                graph1.add(new CensorStatus("A", 9.0, "0"));
791//                graph1.add(new CensorStatus("A", 17.0, "1"));
792//
793//                survivalDataHashMap.put("Label 1", graph1);
794//
795//
796//
797//            }
798
799
800                        if (true) {
801
802
803
804                                ArrayList<CensorStatus> graph1 = new ArrayList<CensorStatus>();
805                                graph1.add(new CensorStatus("A", 1.0, "1"));
806                                graph1.add(new CensorStatus("A", 1.0, "1"));
807                                graph1.add(new CensorStatus("A", 1.0, "1"));
808                                graph1.add(new CensorStatus("A", 2.0, "1"));
809                                graph1.add(new CensorStatus("A", 2.0, "1"));
810                                graph1.add(new CensorStatus("A", 3.0, "1"));
811
812                                graph1.add(new CensorStatus("A", 4.0, "1"));
813                                graph1.add(new CensorStatus("A", 4.0, "1"));
814                                graph1.add(new CensorStatus("A", 4.0, "1"));
815                                graph1.add(new CensorStatus("A", 4.0, "1"));
816                                graph1.add(new CensorStatus("A", 4.0, "1"));
817                                graph1.add(new CensorStatus("A", 4.0, "1"));
818                                graph1.add(new CensorStatus("A", 4.0, "0"));
819
820                                graph1.add(new CensorStatus("A", 5.0, "1"));
821                                graph1.add(new CensorStatus("A", 5.0, "1"));
822
823                                graph1.add(new CensorStatus("A", 8.0, "0"));
824                                graph1.add(new CensorStatus("A", 8.0, "0"));
825                                graph1.add(new CensorStatus("A", 8.0, "0"));
826                                graph1.add(new CensorStatus("A", 8.0, "0"));
827                                graph1.add(new CensorStatus("A", 8.0, "0"));
828                                graph1.add(new CensorStatus("A", 8.0, "0"));
829                                graph1.add(new CensorStatus("A", 8.0, "1"));
830
831                                graph1.add(new CensorStatus("A", 9.0, "1"));
832                                graph1.add(new CensorStatus("A", 9.0, "1"));
833                                graph1.add(new CensorStatus("A", 9.0, "1"));
834                                graph1.add(new CensorStatus("A", 9.0, "1"));
835                                graph1.add(new CensorStatus("A", 9.0, "1"));
836
837
838                                graph1.add(new CensorStatus("A", 13.0, "0"));
839                                graph1.add(new CensorStatus("A", 13.0, "0"));
840                                graph1.add(new CensorStatus("A", 13.0, "1"));
841
842                                survivalDataHashMap.put("Label 1", graph1);
843
844                                ArrayList<CensorStatus> graph2 = new ArrayList<CensorStatus>();
845                                graph2.add(new CensorStatus("A", 1.0, "1"));
846                                graph2.add(new CensorStatus("A", 1.0, "1"));
847                                graph2.add(new CensorStatus("A", 1.0, "0"));
848                                graph2.add(new CensorStatus("A", 3.0, "0"));
849                                graph2.add(new CensorStatus("A", 3.0, "1"));
850                                graph2.add(new CensorStatus("A", 4.0, "1"));
851
852                                graph2.add(new CensorStatus("A", 4.0, "1"));
853                                graph2.add(new CensorStatus("A", 4.0, "1"));
854                                graph2.add(new CensorStatus("A", 4.0, "1"));
855                                graph2.add(new CensorStatus("A", 5.0, "1"));
856                                graph2.add(new CensorStatus("A", 5.0, "0"));
857                                graph2.add(new CensorStatus("A", 5.0, "0"));
858                                graph2.add(new CensorStatus("A", 5.0, "0"));
859
860                                graph2.add(new CensorStatus("A", 6.0, "1"));
861                                graph2.add(new CensorStatus("A", 6.0, "0"));
862
863                                graph2.add(new CensorStatus("A", 7.0, "0"));
864                                graph2.add(new CensorStatus("A", 7.0, "0"));
865                                graph2.add(new CensorStatus("A", 7.0, "0"));
866                                graph2.add(new CensorStatus("A", 7.0, "0"));
867                                graph2.add(new CensorStatus("A", 8.0, "1"));
868                                graph2.add(new CensorStatus("A", 8.0, "1"));
869                                graph2.add(new CensorStatus("A", 8.0, "1"));
870
871                                graph2.add(new CensorStatus("A", 8.0, "1"));
872                                graph2.add(new CensorStatus("A", 8.0, "1"));
873                                graph2.add(new CensorStatus("A", 8.0, "0"));
874                                graph2.add(new CensorStatus("A", 9.0, "0"));
875                                graph2.add(new CensorStatus("A", 9.0, "1"));
876
877
878                                graph2.add(new CensorStatus("A", 10.0, "0"));
879                                graph2.add(new CensorStatus("A", 10.0, "0"));
880                                graph2.add(new CensorStatus("A", 10.0, "0"));
881
882                                survivalDataHashMap.put("Label 2", graph2);
883                        }
884
885                        ArrayList<String> figureInfo = new ArrayList<String>();
886                        //DecimalFormat dfe = new DecimalFormat("0.00E0");
887                        //DecimalFormat df = new DecimalFormat("0.00");
888
889
890
891                        JFrame application = new JFrame();
892                        application.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
893                        application.add(kaplanMeierFigure);
894                        kaplanMeierFigure.setSize(500, 400);
895
896                        application.setSize(500, 400);         // window is 500 pixels wide, 400 high
897                        application.setVisible(true);
898
899                        ArrayList<String> titles = new ArrayList<String>();
900                        titles.add("Line 1");
901                        titles.add("line 2");
902                        kaplanMeierFigure.setSurvivalData(titles, survivalDataHashMap, true);
903
904                        //   figureInfo.add("HR=2.1 95% CI(1.8-2.5)");
905                        //   figureInfo.add("p-value=.001");
906                        kaplanMeierFigure.setFigureLineInfo(figureInfo);
907
908                        kaplanMeierFigure.savePNGKMNumRisk("/Users/Scooter/Downloads/test.png");
909
910                } catch (Exception e) {
911                        e.printStackTrace();
912                }
913        }
914}