001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.survival.cox;
022
023import java.io.PrintStream;
024import java.util.ArrayList;
025import java.util.Collections;
026import java.util.LinkedHashMap;
027
028/**
029 * Used to work with SurvivalInfo
030 * @author Scooter Willis 
031 */
032public class SurvivalInfoHelper {
033
034        /**
035         * For each analysis this allows outputing of the data used in the calculations to a printstream/file. This then
036         * allows the file to be loaded into R and calculations can be verified.
037         * @param DataT
038         * @param ps
039         * @param delimiter
040         */
041        public static void dump(ArrayList<SurvivalInfo> DataT, PrintStream ps, String delimiter) {
042                ArrayList<String> variables = DataT.get(0).getDataVariables();
043                ps.print("Seq" + delimiter);
044                for (String variable : variables) {
045                        ps.print(variable + delimiter);
046                }
047                ps.print("TIME" + delimiter + "STATUS" + delimiter + "WEIGHT" + delimiter + "STRATA");
048
049                ps.println();
050                for (SurvivalInfo si : DataT) {
051                        ps.print(si.getOrder() + delimiter);
052                        for (String variable : variables) {
053                                Double value = si.getVariable(variable);
054                                ps.print(value + delimiter);
055                        }
056
057                        ps.print(si.getTime() + delimiter + si.getStatus() + delimiter + si.getWeight() + delimiter + si.getStrata());
058
059                        ps.println();
060                }
061
062
063        }
064
065        /**
066         * If any not numeric value then categorical
067         * @param values
068         * @return
069         */
070        private static boolean isCategorical(LinkedHashMap<String, Double> values) {
071                try {
072                        for (String value : values.keySet()) {
073                                Double.parseDouble(value);
074                        }
075                        return false;
076                } catch (Exception e) {
077                        return true;
078                }
079
080        }
081
082        /**
083         * Take a collection of categorical data and convert it to numeric to be used in cox calculations
084         * @param DataT
085         */
086        public static void categorizeData(ArrayList<SurvivalInfo> DataT) {
087
088                //Go through and get all variable value pairs
089                LinkedHashMap<String, LinkedHashMap<String, Double>> valueMap = new LinkedHashMap<>();
090                for (SurvivalInfo si : DataT) {
091
092                        for (String key : si.unknownDataType.keySet()) {
093                                LinkedHashMap<String, Double> map = valueMap.get(key);
094                                if (map == null) {
095                                        map = new LinkedHashMap<>();
096                                        valueMap.put(key, map);
097                                }
098                                map.put(si.unknownDataType.get(key), null);
099                        }
100                }
101
102                for (String variable : valueMap.keySet()) {
103                        LinkedHashMap<String, Double> values = valueMap.get(variable);
104                        if (isCategorical(values)) {
105                                ArrayList<String> categories = new ArrayList<>(values.keySet());
106                                Collections.sort(categories); //go ahead and put in alphabetical order
107                                if (categories.size() == 2) {
108                                        for (String value : values.keySet()) {
109                                                int index = categories.indexOf(value);
110                                                values.put(value, index + 0.0);
111                                        }
112                                } else {
113                                        for (String value : values.keySet()) {
114                                                int index = categories.indexOf(value);
115                                                values.put(value, index + 1.0);
116                                        }
117                                }
118
119                        } else {
120                                for (String value : values.keySet()) {
121                                        Double d = Double.parseDouble(value);
122                                        values.put(value, d);
123                                }
124                        }
125                }
126
127                for (SurvivalInfo si : DataT) {
128                        for (String key : si.unknownDataType.keySet()) {
129                                LinkedHashMap<String, Double> map = valueMap.get(key);
130                                String value = si.unknownDataType.get(key);
131                                Double d = map.get(value);
132                                si.data.put(key, d);
133                        }
134                }
135
136                for (SurvivalInfo si : DataT) {
137                        si.unknownDataType.clear();
138                }
139
140        }
141
142        /**
143         * To test for interactions use two variables and create a third variable where the two are multiplied together.
144         * @param variable1
145         * @param variable2
146         * @param survivalInfoList
147         * @return
148         */
149        public static ArrayList<String> addInteraction(String variable1, String variable2, ArrayList<SurvivalInfo> survivalInfoList) {
150                ArrayList<String> variables = new ArrayList<>();
151                variables.add(variable1);
152                variables.add(variable2);
153                variables.add(variable1 + ":" + variable2);
154                for (SurvivalInfo si : survivalInfoList) {
155                        Double value1 = si.getVariable(variable1);
156                        Double value2 = si.getVariable(variable2);
157                        Double value3 = value1 * value2;
158                        si.addContinuousVariable(variable1 + ":" + variable2, value3);
159                }
160                return variables;
161        }
162
163        /**
164         * Need to allow a range of values similar to cut in R and a continuous c
165         *
166         * @param range
167         * @param variable
168         * @param groupName
169         * @param survivalInfoList
170         * @throws Exception
171         */
172        public static void groupByRange(double[] range, String variable, String groupName, ArrayList<SurvivalInfo> survivalInfoList) throws Exception {
173                ArrayList<String> labels = new ArrayList<>();
174                for (int i = 0; i < range.length; i++) {
175                        String label = "";
176                        if (i == 0) {
177                                label = "[<=" + range[i] + "]";
178                        } else if (i == range.length - 1) {
179                                label = "[" + (range[i - 1] + 1) + "-" + range[i] + "]";
180                                labels.add(label);
181                                label = "[>" + range[i] + "]";
182                        } else {
183                                label = "[" + (range[i - 1] + 1) + "-" + range[i] + "]";
184                        }
185                        labels.add(label);
186                }
187                ArrayList<String> validLabels = new ArrayList<>();
188
189                //need to find the categories so we can set 1 and 0 and not include ranges with no values
190                for (SurvivalInfo si : survivalInfoList) {
191                        Double value = si.getContinuousVariable(variable);
192                        if (value == null) {
193                                throw new Exception("Variable " + variable + " not found in " + si.toString());
194                        }
195                        int rangeIndex = getRangeIndex(range, value);
196                        String label = labels.get(rangeIndex);
197                        if (!validLabels.contains(groupName + "_" + label)) {
198                                validLabels.add(groupName + "_" + label);
199                        }
200                }
201                Collections.sort(validLabels);
202                System.out.println("Valid Lables:" + validLabels);
203                for (SurvivalInfo si : survivalInfoList) {
204                        Double value = si.getContinuousVariable(variable);
205                        if (value == null) {
206                                throw new Exception("Variable " + variable + " not found in " + si.toString());
207                        }
208                        int rangeIndex = getRangeIndex(range, value);
209                        String label = labels.get(rangeIndex);
210                        String inLable = groupName + "_" + label;
211                        for (String gl : validLabels) {
212                                if (gl.equals(inLable)) {
213                                        si.addContinuousVariable(gl, 1.0);
214                                } else {
215                                        si.addContinuousVariable(gl, 0.0);
216                                }
217                        }
218                }
219
220        }
221
222        /**
223         *
224         * @param groupName
225         * @param survivalInfoList
226         * @return
227         */
228        public static ArrayList<String> getGroupCategories(String groupName, ArrayList<SurvivalInfo> survivalInfoList) {
229                return survivalInfoList.get(0).getGroupCategories(groupName);
230        }
231
232        private static int getRangeIndex(double[] range, double value) throws Exception {
233                for (int i = 0; i < range.length; i++) {
234                        if (i == 0 && value <= range[i]) {
235                                return i;
236                        }
237                        if (value <= range[i]) {
238                                return i;
239                        }
240
241                }
242
243                if (value > range[range.length - 1]) {
244                        return range.length;
245                }
246                throw new Exception("Value " + value + " not found in range ");
247        }
248}