001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.kaplanmeier.figure;
022
023import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
024import org.biojava.nbio.survival.cox.SurvivalInfo;
025import org.biojava.nbio.survival.cox.comparators.SurvivalInfoValueComparator;
026
027import javax.imageio.ImageIO;
028import javax.swing.*;
029import java.awt.*;
030import java.awt.image.BufferedImage;
031import java.io.File;
032import java.text.DecimalFormat;
033import java.util.ArrayList;
034import java.util.Collections;
035
036/**
037 *
038 * @author Scooter Willis 
039 */
040public class ExpressionFigure extends JPanel {
041
042        private static final long serialVersionUID = 1L;
043
044        ArrayList<String> title = new ArrayList<>();
045        /**
046         *
047         */
048        public int top;
049        /**
050         *
051         */
052        public int bottom;
053        /**
054         *
055         */
056        public int left;
057        /**
058         *
059         */
060        public int right;
061        int titleHeight;
062        int xAxisLabelHeight;
063        int labelWidth;
064        Double maxTime = null;
065        Double minX = 0.0;
066        Double maxX = 10.0;
067        Double minY = 0.0;
068        Double maxY = 1.0;
069        Double mean = 0.0;
070        FontMetrics fm;
071        KMFigureInfo kmfi = new KMFigureInfo();
072//    LinkedHashMap<String, ArrayList<CensorStatus>> survivalData = new LinkedHashMap<String, ArrayList<CensorStatus>>();
073        ArrayList<String> lineInfoList = new ArrayList<>();
074        ArrayList<SurvivalInfo> siList = new ArrayList<>();
075        String variable = "";
076        private String fileName = "";
077
078        /**
079         *
080         */
081        public ExpressionFigure() {
082                super();
083                setSize(500, 400);
084                setBackground(Color.WHITE);
085        }
086
087        /**
088         * The data used to draw the graph
089         * @return
090         */
091
092        public ArrayList<SurvivalInfo> getSurvivalInfoList(){
093                return siList;
094        }
095
096        /**
097         *
098         * @param kmfi
099         */
100        public void setKMFigureInfo(KMFigureInfo kmfi) {
101                this.kmfi = kmfi;
102                if (kmfi.width != null && kmfi.height != null) {
103                        this.setSize(kmfi.width, kmfi.height);
104                }
105        }
106
107        /**
108         *
109         * @param lineInfoList
110         */
111        public void setFigureLineInfo(ArrayList<String> lineInfoList) {
112                this.lineInfoList = lineInfoList;
113                this.repaint();
114        }
115
116        /**
117         *
118         * @param title
119         * @param _siList
120         * @param variable
121         */
122        public void setSurvivalInfo(ArrayList<String> title, ArrayList<SurvivalInfo> _siList, String variable) {
123                this.siList = new ArrayList<>();
124                this.title = title;
125                this.variable = variable;
126
127                minX = 0.0;
128                maxX = (double) _siList.size();
129                minY = 0.0;
130                maxY = null;
131                DescriptiveStatistics ds = new DescriptiveStatistics();
132                for (SurvivalInfo si : _siList) {
133                        this.siList.add(si);
134                        String v = si.getOriginalMetaData(variable);
135                        Double value = Double.parseDouble(v);
136                        ds.addValue(value);
137                        if (maxTime == null || maxTime < si.getTime()) {
138                                maxTime = si.getTime();
139                        }
140
141                }
142                SurvivalInfoValueComparator sivc = new SurvivalInfoValueComparator(variable);
143                Collections.sort(this.siList, sivc);
144                mean = ds.getMean();
145                minY = ds.getMin();
146                maxY = ds.getMax();
147                minY = (double) Math.floor(minY);
148                maxY = (double) Math.ceil(maxY);
149
150
151                this.repaint();
152        }
153        DecimalFormat df = new DecimalFormat("#.#");
154
155        private void setRenderingHints(Graphics2D g) {
156                RenderingHints rh = new RenderingHints(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
157                rh.put(RenderingHints.KEY_DITHERING, RenderingHints.VALUE_DITHER_ENABLE);
158                rh.put(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_GASP);
159                rh.put(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
160
161                g.setRenderingHints(rh);
162
163        }
164
165        @Override
166        public void paintComponent(Graphics g) // draw graphics in the panel
167        {
168                //int width = getWidth();             // width of window in pixels
169                //int height = getHeight();           // height of window in pixels
170                setFigureDimensions();
171                super.paintComponent(g);            // call superclass to make panel display correctly
172                Graphics2D g2 = (Graphics2D)g;
173                setRenderingHints(g2);
174                drawExpressionLevels(g);
175                drawFigureLineInfo(g);
176                drawLegend(g);
177                // Drawing code goes here
178        }
179
180        private void drawFigureLineInfo(Graphics g) {
181                g.setColor(Color.BLACK);
182                Font font = g.getFont();
183                Font f = new Font(font.getFontName(), Font.BOLD, font.getSize());
184                g.setFont(f);
185                fm = getFontMetrics(getFont());
186                int yoffset = fm.getHeight() * lineInfoList.size();
187
188                int x = getX(kmfi.figureLineInfoLowerPercentX * maxX);
189                int y = getY(kmfi.figureLineInfoLowerPercentY) - yoffset;
190
191                for (String line : lineInfoList) {
192                        g.drawString(line, x, y);
193                        y = y + fm.getHeight();
194                }
195
196        }
197
198        private void drawExpressionLevels(Graphics g) {
199                Graphics2D g2 = (Graphics2D) g;
200                g2.setStroke(kmfi.kmStroke);
201                g2.setColor(Color.blue);
202                Double py = null;
203                for (int x = 0; x < siList.size(); x++) {
204                        SurvivalInfo si = siList.get(x);
205                        String v = si.getOriginalMetaData(variable);
206                        Double y = Double.parseDouble(v);
207                        if (si.getStatus() == 1) {
208                                g2.setColor(Color.RED);
209                        } else {
210                                g2.setColor(Color.LIGHT_GRAY);
211                        }
212                        g2.drawLine(getX(x), getY((maxY - minY)), getX(x), getY(maxY - y));
213
214                        if (py == null) {
215                                py = y;
216                        }
217                        if (mean >= py && mean <= y) {
218                                g2.setColor(Color.green);
219                                g2.drawLine(getX(x), getY(maxY - minY), getX(x), getY(maxY - y));
220                                g2.drawLine(getX(x - 1), getY(maxY - minY), getX(x - 1), getY(maxY - y));
221                        }
222                        py = y;
223
224                        //    g2.setColor(Color.black);
225                        //    double yt = getYFromPercentage(1.0 - ((double)x)/((double)siList.size())  );
226                        //    g2.drawOval(getX(x) - 2, ((int)yt) - 2, 4, 4);
227                }
228
229        }
230
231//    private int getYFromPercentage(double percentage) {
232//        double d = top + (((bottom - top) * percentage));
233//        return (int) d;
234//    }
235
236        private int getX(double value) {
237                double d = left + (((right - left) * value) / (maxX - minX));
238                return (int) d;
239        }
240
241        private int getY(double value) {
242
243                double d = top + (((bottom - top) * value) / (maxY - minY));
244                return (int) d;
245        }
246
247        /**
248         * @return the fileName
249         */
250        public String getFileName() {
251                return fileName;
252        }
253
254        private void drawLegend(Graphics g) {
255                Graphics2D g2 = (Graphics2D) g;
256                Font font = g2.getFont();
257
258                font = new Font(font.getFontName(), Font.BOLD, font.getSize());
259                g2.setFont(font);
260                fm = getFontMetrics(font);
261                int fontHeight = fm.getHeight();
262                for (int i = 0; i < title.size(); i++) {
263                        if (fm.stringWidth(title.get(i)) > .8 * this.getWidth()) {
264
265                                Font f = new Font(font.getFontName(), Font.BOLD, 10);
266                                g2.setFont(f);
267                                fm = getFontMetrics(f);
268                        }
269                        g2.drawString(title.get(i), (getSize().width - fm.stringWidth(title.get(i))) / 2, ((i + 1) * fontHeight));
270                        g2.setFont(font);
271                }
272                // draw the maxY and minY values
273                g2.drawString(df.format(minY), left - fm.stringWidth(df.format(minY)) - 20, bottom + titleHeight / 6);
274                g2.drawLine(left - 5, bottom, left, bottom);
275                double ySize = maxY - minY;
276                double increment = kmfi.yaxisPercentIncrement * ySize;
277                increment = Math.ceil(increment);
278                //  increment = increment * 10.0;
279                double d = minY + increment;
280                //double graphHeight = top - bottom;
281                String label = "";
282                while (d < maxY) {
283                        int yvalue = getY(maxY - d);
284                        label = df.format(d);
285                        g2.drawString(label, left - (fm.stringWidth(label)) - 20, yvalue + titleHeight / 6); //
286
287                        g2.drawLine(left - 5, yvalue, left, yvalue);
288                        d = d + (increment);
289                }
290
291                label = df.format(maxY);
292                g2.drawString(label, left - (fm.stringWidth(label)) - 20, top + (titleHeight) / 6);
293                g2.drawLine(left - 5, top, left, top);
294
295                double timeDistance = maxX - minX;
296                double timeIncrement = timeDistance * kmfi.xaxisPercentIncrement;
297                double timeInt = (int) timeIncrement;
298                if (timeInt < 1.0) {
299                        timeInt = 1.0;
300                }
301                double adjustedPercentIncrement = timeInt / timeDistance;
302
303                d = adjustedPercentIncrement; //kmfi.xaxisPercentIncrement;
304                while (d <= 1.0) {
305                        label = df.format((minX * kmfi.timeScale) + d * ((maxX - minX) * kmfi.timeScale)); //
306                        if (d + adjustedPercentIncrement > 1.0) { //if this is the last one then adjust
307                                g2.drawString(label, left + (int) (d * (right - left)) - (int) (.5 * fm.stringWidth(label)), bottom + fm.getHeight() + 5);
308                        } else {
309                                g2.drawString(label, left + (int) (d * (right - left)) - (fm.stringWidth(label) / 2), bottom + fm.getHeight() + 5);
310                        }
311                        g2.drawLine(left + (int) (d * (right - left)), bottom, left + (int) (d * (right - left)), bottom + 5);
312                        d = d + adjustedPercentIncrement; //kmfi.xaxisPercentIncrement;
313                }
314
315
316                // draw the vertical and horizontal lines
317                g2.setStroke(kmfi.axisStroke);
318                g2.drawLine(left, top, left, bottom);
319                g2.drawLine(left, bottom, right, bottom);
320        }
321
322        private void setFigureDimensions() {
323                fm = getFontMetrics(getFont());
324                titleHeight = kmfi.titleHeight;//fm.getHeight();
325                xAxisLabelHeight = titleHeight;
326                labelWidth = Math.max(fm.stringWidth(df.format(minY)),
327                                fm.stringWidth(df.format(maxY))) + 5;
328                top = kmfi.padding + titleHeight;
329                bottom = this.getHeight() - kmfi.padding - xAxisLabelHeight;
330                left = kmfi.padding + labelWidth;
331                right = this.getWidth() - kmfi.padding;
332
333        }
334
335        /**
336         *
337         * @param fileName
338         */
339        public void savePNG(String fileName) {
340                if (fileName.startsWith("null")) {
341                        return;
342                }
343                this.fileName = fileName;
344                BufferedImage image = new BufferedImage(this.getWidth(), this.getHeight(), BufferedImage.TYPE_INT_RGB);
345                Graphics2D graphics2D = image.createGraphics();
346
347                this.paint(graphics2D);
348                try {
349                        ImageIO.write(image, "png", new File(fileName));
350                } catch (Exception ex) {
351                        ex.printStackTrace();
352                }
353
354        }
355
356        /**
357         * @param args the command line arguments
358         */
359        public static void main(String[] args) {
360                // TODO code application logic here
361                try {
362
363                        ExpressionFigure expressionFigure = new ExpressionFigure();
364
365
366
367                        JFrame application = new JFrame();
368                        application.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
369                        application.add(expressionFigure);
370                        expressionFigure.setSize(500, 400);
371
372                        application.setSize(500, 400);         // window is 500 pixels wide, 400 high
373                        application.setVisible(true);
374
375                        ArrayList<String> titles = new ArrayList<>();
376                        titles.add("Line 1");
377                        titles.add("line 2");
378
379                        ArrayList<String> figureInfo = new ArrayList<>();
380
381                        ArrayList<SurvivalInfo> survivalInfoList = new ArrayList<>();
382
383                        for (int i = 0; i < 600; i++) {
384                                double r = Math.random();
385                                double v = r * 10000;
386                                double t = Math.random() * 5.0;
387                                r = Math.random();
388                                int e = 0;
389                                if (r < .3) {
390                                        e = 1;
391                                }
392                                SurvivalInfo si = new SurvivalInfo(t, e);
393                                si.addContinuousVariable("META_GENE", v);
394                                survivalInfoList.add(si);
395
396                        }
397
398
399                        expressionFigure.setSurvivalInfo(titles, survivalInfoList, "META_GENE");
400
401                        expressionFigure.setFigureLineInfo(figureInfo);
402
403                        expressionFigure.savePNG("/Users/Scooter/Downloads/test.png");
404
405                } catch (Exception e) {
406                        e.printStackTrace();
407                }
408        }
409}