001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021package org.biojava.nbio.structure.symmetry.core;
022
023import org.biojava.nbio.structure.geometry.CalcPoint;
024import org.biojava.nbio.structure.geometry.SuperPositions;
025import org.slf4j.Logger;
026import org.slf4j.LoggerFactory;
027
028import javax.vecmath.AxisAngle4d;
029import javax.vecmath.Matrix4d;
030import javax.vecmath.Point3d;
031import javax.vecmath.Vector3d;
032
033import java.util.*;
034import java.util.Map.Entry;
035
036/**
037 * 
038 * 
039 * @author Peter Rose
040 *
041 */
042public class HelixSolver {
043
044        private static final Logger logger = LoggerFactory
045                        .getLogger(HelixSolver.class);
046
047        private QuatSymmetrySubunits subunits = null;
048        private int fold = 1;
049        private HelixLayers helixLayers = new HelixLayers();
050        private QuatSymmetryParameters parameters = null;
051        boolean modified = true;
052
053        public HelixSolver(QuatSymmetrySubunits subunits, int fold,
054                        QuatSymmetryParameters parameters) {
055                this.subunits = subunits;
056                this.fold = fold;
057                this.parameters = parameters;
058        }
059
060        public HelixLayers getSymmetryOperations() {
061                if (modified) {
062                        solve();
063                        modified = false;
064                }
065                return helixLayers;
066        }
067
068        private void solve() {
069                if (!preCheck()) {
070                        return;
071                }
072
073                HelicalRepeatUnit unit = new HelicalRepeatUnit(subunits);
074                List<Point3d> repeatUnitCenters = unit.getRepeatUnitCenters();
075                List<Point3d[]> repeatUnits = unit.getRepeatUnits();
076                Set<List<Integer>> permutations = new HashSet<List<Integer>>();
077
078                double minRise = parameters.getMinimumHelixRise() * fold; // for n-start
079                                                                                                                                        // helix,
080                                                                                                                                        // the rise
081                                                                                                                                        // must be
082                                                                                                                                        // steeper
083                Map<Integer[], Integer> interactionMap = unit
084                                .getInteractingRepeatUnits();
085
086                int maxLayerLineLength = 0;
087
088                for (Entry<Integer[], Integer> entry : interactionMap.entrySet()) {
089                        Integer[] pair = entry.getKey();
090                        logger.debug("HelixSolver: pair: " + Arrays.toString(pair));
091                        
092                        int contacts = entry.getValue();
093                        Point3d[] h1 = CalcPoint.clonePoint3dArray(repeatUnits.get(pair[0]));
094                        Point3d[] h2 = CalcPoint.clonePoint3dArray(repeatUnits.get(pair[1]));
095
096                        // trial superposition of repeat unit pairs to get a seed
097                        // permutation
098                        Matrix4d transformation = SuperPositions.superposeAndTransform(h2, h1);
099
100                        double rmsd = CalcPoint.rmsd(h1, h2);
101                        double rise = getRise(transformation,
102                                        repeatUnitCenters.get(pair[0]),
103                                        repeatUnitCenters.get(pair[1]));
104                        double angle = getAngle(transformation);
105
106                        logger.debug(
107                                        "Original rmsd: {}, Original rise {}, Original angle: {}",
108                                        rmsd, rise, Math.toDegrees(angle));
109
110                        if (rmsd > parameters.getRmsdThreshold()) {
111                                continue;
112                        }
113
114                        if (Math.abs(rise) < minRise) {
115                                continue;
116                        }
117
118                        // determine which subunits are permuted by the transformation
119                        List<Integer> permutation = getPermutation(transformation);
120
121                        // check permutations for validity
122
123                        // don't save redundant permutations
124                        if (permutations.contains(permutation)) {
125                                continue;
126                        }
127                        permutations.add(permutation);
128                        logger.debug("Permutation: " + permutation);
129                        
130
131                        // keep track of which subunits are permuted
132                        Set<Integer> permSet = new HashSet<Integer>();
133                        int count = 0;
134                        boolean valid = true;
135                        for (int i = 0; i < permutation.size(); i++) {
136                                if (permutation.get(i) == i) {
137                                        valid = false;
138                                        break;
139                                }
140                                if (permutation.get(i) != -1) {
141                                        permSet.add(permutation.get(i));
142                                        permSet.add(i);
143                                        count++;
144                                }
145
146                        }
147
148                        // a helix a repeat unit cannot map onto itself
149                        if (!valid) {
150                                logger.debug("Invalid mapping");
151                                continue;
152                        }
153
154                        // all subunits must be involved in a permutation
155                        if (permSet.size() != subunits.getSubunitCount()) {
156                                logger.debug("Not all subunits involved in permutation");
157                                continue;
158                        }
159
160                        // if all subunit permutation values are set, then it can't be
161                        // helical symmetry (must be cyclic symmetry)
162                        if (count == permutation.size()) {
163                                continue;
164                        }
165
166                        // superpose all permuted subunits
167                        List<Point3d> point1 = new ArrayList<Point3d>();
168                        List<Point3d> point2 = new ArrayList<Point3d>();
169                        List<Point3d> centers = subunits.getOriginalCenters();
170                        for (int j = 0; j < permutation.size(); j++) {
171                                if (permutation.get(j) != -1) {
172                                        point1.add(new Point3d(centers.get(j)));
173                                        point2.add(new Point3d(centers.get(permutation.get(j))));
174                                }
175                        }
176
177                        h1 = new Point3d[point1.size()];
178                        h2 = new Point3d[point2.size()];
179                        point1.toArray(h1);
180                        point2.toArray(h2);
181
182                        // calculate subunit rmsd if at least 3 subunits are available
183                        double subunitRmsd = 0;
184                        if (point1.size() > 2) {
185                                transformation = SuperPositions.superposeAndTransform(h2, h1);
186
187                                subunitRmsd = CalcPoint.rmsd(h1, h2);
188                                rise = getRise(transformation, repeatUnitCenters.get(pair[0]),
189                                                repeatUnitCenters.get(pair[1]));
190                                angle = getAngle(transformation);
191
192                                logger.debug("Subunit rmsd: {}, Subunit rise: {}, Subunit angle: {}", subunitRmsd, rise, Math.toDegrees(angle));
193
194                                if (subunitRmsd > parameters.getRmsdThreshold()) {
195                                        continue;
196                                }
197
198                                if (Math.abs(rise) < minRise) {
199                                        continue;
200                                }
201
202                                if (subunitRmsd > parameters.getHelixRmsdToRiseRatio()
203                                                * Math.abs(rise)) {
204                                        continue;
205                                }
206                        }
207
208                        // superpose all C alpha traces
209                        point1.clear();
210                        point2.clear();
211                        List<Point3d[]> traces = subunits.getTraces();
212                        for (int j = 0; j < permutation.size(); j++) {
213                                if (permutation.get(j) == -1) {
214                                        continue;
215                                }
216                                for (Point3d p : traces.get(j)) {
217                                        point1.add(new Point3d(p));
218                                }
219                                for (Point3d p : traces.get(permutation.get(j))) {
220                                        point2.add(new Point3d(p));
221                                }
222                        }
223
224                        h1 = new Point3d[point1.size()];
225                        h2 = new Point3d[point2.size()];
226                        point1.toArray(h1);
227                        point2.toArray(h2);
228                        Point3d[] h3 = CalcPoint.clonePoint3dArray(h1);
229                        transformation = SuperPositions.superposeAndTransform(h2, h1);
230
231                        Point3d xtrans = CalcPoint.centroid(h3);
232
233                        xtrans.negate();
234
235                        double traceRmsd = CalcPoint.rmsd(h1, h2);
236
237                        rise = getRise(transformation, repeatUnitCenters.get(pair[0]),
238                                        repeatUnitCenters.get(pair[1]));
239                        angle = getAngle(transformation);
240
241                        logger.debug("Trace rmsd: " + traceRmsd);
242                        logger.debug("Trace rise: " + rise);
243                        logger.debug("Trace angle: " + Math.toDegrees(angle));
244                        logger.debug("Permutation: " + permutation);
245
246                        if (traceRmsd > parameters.getRmsdThreshold()) {
247                                continue;
248                        }
249
250                        if (Math.abs(rise) < minRise) {
251                                continue;
252                        }
253
254                        // This prevents translational repeats to be counted as helices
255                        if (angle < Math.toRadians(parameters.getMinimumHelixAngle())) {
256                                continue;
257                        }
258
259                        if (traceRmsd > parameters.getHelixRmsdToRiseRatio()
260                                        * Math.abs(rise)) {
261                                continue;
262                        }
263
264                        AxisAngle4d a1 = new AxisAngle4d();
265                        a1.set(transformation);
266
267                        // save this helix rot-translation
268                        Helix helix = new Helix();
269                        helix.setTransformation(transformation);
270                        helix.setPermutation(permutation);
271                        helix.setRise(rise);
272                        // Old version of Vecmath on LINUX doesn't set element m33 to 1.
273                        // Here we make sure it's 1.
274                        transformation.setElement(3, 3, 1.0);
275                        transformation.invert();
276                        QuatSymmetryScores scores = QuatSuperpositionScorer.calcScores(
277                                        subunits, transformation, permutation);
278                        scores.setRmsdCenters(subunitRmsd);
279                        helix.setScores(scores);
280                        helix.setFold(fold);
281                        helix.setContacts(contacts);
282                        helix.setRepeatUnits(unit.getRepeatUnitIndices());
283                        logger.debug("Layerlines: " + helix.getLayerLines());
284                        
285                        for (List<Integer> line : helix.getLayerLines()) {
286                                maxLayerLineLength = Math.max(maxLayerLineLength, line.size());
287                        }
288
289                        // TODO
290                        // checkSelfLimitingHelix(helix);
291
292                        helixLayers.addHelix(helix);
293
294                }
295                if (maxLayerLineLength < 3) {
296                        // System.out.println("maxLayerLineLength: " + maxLayerLineLength);
297                        helixLayers.clear();
298                }
299
300                return;
301        }
302
303        @SuppressWarnings("unused")
304        private void checkSelfLimitingHelix(Helix helix) {
305                HelixExtender he = new HelixExtender(subunits, helix);
306                Point3d[] extendedHelix = he.extendHelix(1);
307
308                int overlap1 = 0;
309                for (Point3d[] trace : subunits.getTraces()) {
310                        for (Point3d pt : trace) {
311                                for (Point3d pe : extendedHelix) {
312                                        if (pt.distance(pe) < 5.0) {
313                                                overlap1++;
314                                        }
315                                }
316                        }
317                }
318
319                extendedHelix = he.extendHelix(-1);
320
321                int overlap2 = 0;
322                for (Point3d[] trace : subunits.getTraces()) {
323                        for (Point3d pt : trace) {
324                                for (Point3d pe : extendedHelix) {
325                                        if (pt.distance(pe) < 3.0) {
326                                                overlap2++;
327                                        }
328                                }
329                        }
330                }
331                System.out.println("SelfLimiting helix: " + overlap1 + ", " + overlap2);
332        }
333
334        private boolean preCheck() {
335                if (subunits.getSubunitCount() < 3) {
336                        return false;
337                }
338                List<Integer> folds = this.subunits.getFolds();
339                int maxFold = folds.get(folds.size() - 1);
340                return maxFold > 1;
341        }
342
343        /**
344         * Returns a permutation of subunit indices for the given helix
345         * transformation. An index of -1 is used to indicate subunits that do not
346         * superpose onto any other subunit.
347         * 
348         * @param transformation
349         * @return
350         */
351        private List<Integer> getPermutation(Matrix4d transformation) {
352                double rmsdThresholdSq = Math
353                                .pow(this.parameters.getRmsdThreshold(), 2);
354
355                List<Point3d> centers = subunits.getOriginalCenters();
356                List<Integer> seqClusterId = subunits.getClusterIds();
357
358                List<Integer> permutations = new ArrayList<Integer>(centers.size());
359                double[] dSqs = new double[centers.size()];
360                boolean[] used = new boolean[centers.size()];
361                Arrays.fill(used, false);
362
363                for (int i = 0; i < centers.size(); i++) {
364                        Point3d tCenter = new Point3d(centers.get(i));
365                        transformation.transform(tCenter);
366                        int permutation = -1;
367                        double minDistSq = Double.MAX_VALUE;
368                        for (int j = 0; j < centers.size(); j++) {
369                                if (seqClusterId.get(i) == seqClusterId.get(j)) {
370                                        if (!used[j]) {
371                                                double dSq = tCenter.distanceSquared(centers.get(j));
372                                                if (dSq < minDistSq && dSq <= rmsdThresholdSq) {
373                                                        minDistSq = dSq;
374                                                        permutation = j;
375                                                        dSqs[j] = dSq;
376                                                }
377                                        }
378                                }
379                        }
380                        // can't map to itself
381                        if (permutations.size() == permutation) {
382                                permutation = -1;
383                        }
384
385                        if (permutation != -1) {
386                                used[permutation] = true;
387                        }
388
389                        permutations.add(permutation);
390                }
391
392                return permutations;
393        }
394
395        /**
396         * Returns the rise of a helix given the subunit centers of two adjacent
397         * subunits and the helix transformation
398         * 
399         * @param transformation
400         *            helix transformation
401         * @param p1
402         *            center of one subunit
403         * @param p2
404         *            center of an adjacent subunit
405         * @return
406         */
407        private static double getRise(Matrix4d transformation, Point3d p1,
408                        Point3d p2) {
409                AxisAngle4d axis = getAxisAngle(transformation);
410                Vector3d h = new Vector3d(axis.x, axis.y, axis.z);
411                Vector3d p = new Vector3d();
412                p.sub(p1, p2);
413                return p.dot(h);
414        }
415
416        /**
417         * Returns the pitch angle of the helix
418         * 
419         * @param transformation
420         *            helix transformation
421         * @return
422         */
423        private static double getAngle(Matrix4d transformation) {
424                return getAxisAngle(transformation).angle;
425        }
426
427        /**
428         * Returns the AxisAngle of the helix transformation
429         * 
430         * @param transformation
431         *            helix transformation
432         * @return
433         */
434        private static AxisAngle4d getAxisAngle(Matrix4d transformation) {
435                AxisAngle4d axis = new AxisAngle4d();
436                axis.set(transformation);
437                return axis;
438        }
439}