001/*
002 *                    BioJava development code
003 *
004 * This code may be freely distributed and modified under the
005 * terms of the GNU Lesser General Public Licence.  This should
006 * be distributed with the code.  If you do not have a copy,
007 * see:
008 *
009 *      http://www.gnu.org/copyleft/lesser.html
010 *
011 * Copyright for this code is held jointly by the individual
012 * authors.  These should be listed in @author doc comments.
013 *
014 * For more information on the BioJava project and its aims,
015 * or to join the biojava-l mailing list, visit the home page
016 * at:
017 *
018 *      http://www.biojava.org/
019 *
020 */
021
022package org.biojava.nbio.structure.geometry;
023
024import javax.vecmath.Matrix3d;
025import javax.vecmath.Matrix4d;
026import javax.vecmath.Point3d;
027import javax.vecmath.Vector3d;
028
029import org.slf4j.Logger;
030import org.slf4j.LoggerFactory;
031
032/**
033 * Implementation of the Quaternion-Based Characteristic Polynomial algorithm
034 * for RMSD and Superposition calculations.
035 * <p>
036 * Usage:
037 * <p>
038 * The input consists of 2 Point3d arrays of equal length. The input coordinates
039 * are not changed.
040 *
041 * <pre>
042 *    Point3d[] x = ...
043 *    Point3d[] y = ...
044 *    SuperPositionQCP qcp = new SuperPositionQCP();
045 *    qcp.set(x, y);
046 * </pre>
047 * <p>
048 * or with weighting factors [0 - 1]]
049 *
050 * <pre>
051 *    double[] weights = ...
052 *    qcp.set(x, y, weights);
053 * </pre>
054 * <p>
055 * For maximum efficiency, create a SuperPositionQCP object once and reuse it.
056 * <p>
057 * A. Calculate rmsd only
058 *
059 * <pre>
060 * double rmsd = qcp.getRmsd();
061 * </pre>
062 * <p>
063 * B. Calculate a 4x4 transformation (rotation and translation) matrix
064 *
065 * <pre>
066 * Matrix4d rottrans = qcp.getTransformationMatrix();
067 * </pre>
068 * <p>
069 * C. Get transformated points (y superposed onto the reference x)
070 *
071 * <pre>
072 * Point3d[] ySuperposed = qcp.getTransformedCoordinates();
073 * </pre>
074 * <p>
075 * Citations:
076 * <p>
077 * Liu P, Agrafiotis DK, & Theobald DL (2011) Reply to comment on: "Fast
078 * determination of the optimal rotation matrix for macromolecular
079 * superpositions." Journal of Computational Chemistry 32(1):185-186.
080 * [http://dx.doi.org/10.1002/jcc.21606]
081 * <p>
082 * Liu P, Agrafiotis DK, & Theobald DL (2010) "Fast determination of the optimal
083 * rotation matrix for macromolecular superpositions." Journal of Computational
084 * Chemistry 31(7):1561-1563. [http://dx.doi.org/10.1002/jcc.21439]
085 * <p>
086 * Douglas L Theobald (2005) "Rapid calculation of RMSDs using a
087 * quaternion-based characteristic polynomial." Acta Crystallogr A
088 * 61(4):478-480. [http://dx.doi.org/10.1107/S0108767305015266 ]
089 * <p>
090 * This is an adoption of the original C code QCProt 1.4 (2012, October 10) to
091 * Java. The original C source code is available from
092 * http://theobald.brandeis.edu/qcp/ and was developed by
093 * <p>
094 * Douglas L. Theobald Department of Biochemistry MS 009 Brandeis University 415
095 * South St Waltham, MA 02453 USA
096 * <p>
097 * dtheobald@brandeis.edu
098 * <p>
099 * Pu Liu Johnson & Johnson Pharmaceutical Research and Development, L.L.C. 665
100 * Stockton Drive Exton, PA 19341 USA
101 * <p>
102 * pliu24@its.jnj.com
103 * <p>
104 *
105 * @author Douglas L. Theobald (original C code)
106 * @author Pu Liu (original C code)
107 * @author Peter Rose (adopted to Java)
108 * @author Aleix Lafita (adopted to Java)
109 */
110public final class SuperPositionQCP extends SuperPositionAbstract {
111
112        private static final Logger logger = LoggerFactory.getLogger(SuperPositionQCP.class);
113
114        private double evec_prec = 1E-6;
115        private double eval_prec = 1E-11;
116
117        private Point3d[] x;
118        private Point3d[] y;
119
120        private double[] weight;
121        private double wsum;
122
123        private Point3d[] xref;
124        private Point3d[] yref;
125        private Point3d xtrans;
126        private Point3d ytrans;
127
128        private double e0;
129        private Matrix3d rotmat = new Matrix3d();
130        private Matrix4d transformation = new Matrix4d();
131        private double rmsd = 0;
132        private double Sxy, Sxz, Syx, Syz, Szx, Szy;
133        private double SxxpSyy, Szz, mxEigenV, SyzmSzy, SxzmSzx, SxymSyx;
134        private double SxxmSyy, SxypSyx, SxzpSzx;
135        private double Syy, Sxx, SyzpSzy;
136        private boolean rmsdCalculated = false;
137        private boolean transformationCalculated = false;
138        private boolean centered = false;
139
140        /**
141         * Default constructor for the quaternion based superposition algorithm.
142         *
143         * @param centered
144         *            true if the point arrays are centered at the origin (faster),
145         *            false otherwise
146         */
147        public SuperPositionQCP(boolean centered) {
148                super(centered);
149        }
150
151        /**
152         * Constructor with option to set the precision values.
153         *
154         * @param centered
155         *            true if the point arrays are centered at the origin (faster),
156         *            false otherwise
157         * @param evec_prec
158         *            required eigenvector precision
159         * @param eval_prec
160         *            required eigenvalue precision
161         */
162        public SuperPositionQCP(boolean centered, double evec_prec, double eval_prec) {
163                super(centered);
164                this.evec_prec = evec_prec;
165                this.eval_prec = eval_prec;
166        }
167
168        /**
169         * Sets the two input coordinate arrays. These input arrays must be of equal
170         * length. Input coordinates are not modified.
171         *
172         * @param x
173         *            3d points of reference coordinate set
174         * @param y
175         *            3d points of coordinate set for superposition
176         */
177        private void set(Point3d[] x, Point3d[] y) {
178                this.x = x;
179                this.y = y;
180                rmsdCalculated = false;
181                transformationCalculated = false;
182        }
183
184        /**
185         * Sets the two input coordinate arrays and weight array. All input arrays
186         * must be of equal length. Input coordinates are not modified.
187         *
188         * @param x
189         *            3d points of reference coordinate set
190         * @param y
191         *            3d points of coordinate set for superposition
192         * @param weight
193         *            a weight in the inclusive range [0,1] for each point
194         */
195        private void set(Point3d[] x, Point3d[] y, double[] weight) {
196                this.x = x;
197                this.y = y;
198                this.weight = weight;
199                rmsdCalculated = false;
200                transformationCalculated = false;
201        }
202
203        /**
204         * Return the RMSD of the superposition of input coordinate set y onto x.
205         * Note, this is the fasted way to calculate an RMSD without actually
206         * superposing the two sets. The calculation is performed "lazy", meaning
207         * calculations are only performed if necessary.
208         *
209         * @return root mean square deviation for superposition of y onto x
210         */
211        private double getRmsd() {
212                if (!rmsdCalculated) {
213                        calcRmsd(x, y);
214                        rmsdCalculated = true;
215                }
216                return rmsd;
217        }
218
219        /**
220         * Weighted superposition.
221         *
222         * @param fixed
223         * @param moved
224         * @param weight
225         *            array of weigths for each equivalent point position
226         * @return
227         */
228        public Matrix4d weightedSuperpose(Point3d[] fixed, Point3d[] moved, double[] weight) {
229                set(moved, fixed, weight);
230                getRotationMatrix();
231                if (!centered) {
232                        calcTransformation();
233                } else {
234                        transformation.set(rotmat);
235                }
236                return transformation;
237        }
238
239        private Matrix3d getRotationMatrix() {
240                getRmsd();
241                if (!transformationCalculated) {
242                        calcRotationMatrix();
243                        transformationCalculated = true;
244                }
245                return rotmat;
246        }
247
248        /**
249         * Calculates the RMSD value for superposition of y onto x. This requires
250         * the coordinates to be precentered.
251         *
252         * @param x
253         *            3d points of reference coordinate set
254         * @param y
255         *            3d points of coordinate set for superposition
256         */
257        private void calcRmsd(Point3d[] x, Point3d[] y) {
258                if (centered) {
259                        innerProduct(y, x);
260                } else {
261                        // translate to origin
262                        xref = CalcPoint.clonePoint3dArray(x);
263                        xtrans = CalcPoint.centroid(xref);
264                        logger.debug("x centroid: " + xtrans);
265                        xtrans.negate();
266                        CalcPoint.translate(new Vector3d(xtrans), xref);
267
268                        yref = CalcPoint.clonePoint3dArray(y);
269                        ytrans = CalcPoint.centroid(yref);
270                        logger.debug("y centroid: " + ytrans);
271                        ytrans.negate();
272                        CalcPoint.translate(new Vector3d(ytrans), yref);
273                        innerProduct(yref, xref);
274                }
275                calcRmsd(wsum);
276        }
277
278        /**
279         * Superposition coords2 onto coords1 -- in other words, coords2 is rotated,
280         * coords1 is held fixed
281         */
282        private void calcTransformation() {
283
284                // transformation.set(rotmat,new Vector3d(0,0,0), 1);
285                transformation.set(rotmat);
286                // long t2 = System.nanoTime();
287                // System.out.println("create transformation: " + (t2-t1));
288                // System.out.println("m3d -> m4d");
289                // System.out.println(transformation);
290
291                // combine with x -> origin translation
292                Matrix4d trans = new Matrix4d();
293                trans.setIdentity();
294                trans.setTranslation(new Vector3d(xtrans));
295                transformation.mul(transformation, trans);
296                // System.out.println("setting xtrans");
297                // System.out.println(transformation);
298
299                // combine with origin -> y translation
300                ytrans.negate();
301                Matrix4d transInverse = new Matrix4d();
302                transInverse.setIdentity();
303                transInverse.setTranslation(new Vector3d(ytrans));
304                transformation.mul(transInverse, transformation);
305                // System.out.println("setting ytrans");
306                // System.out.println(transformation);
307        }
308
309        /**
310         * Calculates the inner product between two coordinate sets x and y
311         * (optionally weighted, if weights set through
312         * {@link #set(Point3d[], Point3d[], double[])}). It also calculates an
313         * upper bound of the most positive root of the key matrix.
314         * http://theobald.brandeis.edu/qcp/qcprot.c
315         *
316         * @param coords1
317         * @param coords2
318         * @return
319         */
320        private void innerProduct(Point3d[] coords1, Point3d[] coords2) {
321                double x1, x2, y1, y2, z1, z2;
322                double g1 = 0.0, g2 = 0.0;
323
324                Sxx = 0;
325                Sxy = 0;
326                Sxz = 0;
327                Syx = 0;
328                Syy = 0;
329                Syz = 0;
330                Szx = 0;
331                Szy = 0;
332                Szz = 0;
333
334                if (weight != null) {
335                        wsum = 0;
336                        for (int i = 0; i < coords1.length; i++) {
337
338                                wsum += weight[i];
339
340                                x1 = weight[i] * coords1[i].x;
341                                y1 = weight[i] * coords1[i].y;
342                                z1 = weight[i] * coords1[i].z;
343
344                                g1 += x1 * coords1[i].x + y1 * coords1[i].y + z1 * coords1[i].z;
345
346                                x2 = coords2[i].x;
347                                y2 = coords2[i].y;
348                                z2 = coords2[i].z;
349
350                                g2 += weight[i] * (x2 * x2 + y2 * y2 + z2 * z2);
351
352                                Sxx += (x1 * x2);
353                                Sxy += (x1 * y2);
354                                Sxz += (x1 * z2);
355
356                                Syx += (y1 * x2);
357                                Syy += (y1 * y2);
358                                Syz += (y1 * z2);
359
360                                Szx += (z1 * x2);
361                                Szy += (z1 * y2);
362                                Szz += (z1 * z2);
363                        }
364                } else {
365                        for (int i = 0; i < coords1.length; i++) {
366                                g1 += coords1[i].x * coords1[i].x + coords1[i].y * coords1[i].y + coords1[i].z * coords1[i].z;
367                                g2 += coords2[i].x * coords2[i].x + coords2[i].y * coords2[i].y + coords2[i].z * coords2[i].z;
368
369                                Sxx += coords1[i].x * coords2[i].x;
370                                Sxy += coords1[i].x * coords2[i].y;
371                                Sxz += coords1[i].x * coords2[i].z;
372
373                                Syx += coords1[i].y * coords2[i].x;
374                                Syy += coords1[i].y * coords2[i].y;
375                                Syz += coords1[i].y * coords2[i].z;
376
377                                Szx += coords1[i].z * coords2[i].x;
378                                Szy += coords1[i].z * coords2[i].y;
379                                Szz += coords1[i].z * coords2[i].z;
380                        }
381                        wsum = coords1.length;
382                }
383
384                e0 = (g1 + g2) * 0.5;
385        }
386
387        private int calcRmsd(double len) {
388                double Sxx2 = Sxx * Sxx;
389                double Syy2 = Syy * Syy;
390                double Szz2 = Szz * Szz;
391
392                double Sxy2 = Sxy * Sxy;
393                double Syz2 = Syz * Syz;
394                double Sxz2 = Sxz * Sxz;
395
396                double Syx2 = Syx * Syx;
397                double Szy2 = Szy * Szy;
398                double Szx2 = Szx * Szx;
399
400                double SyzSzymSyySzz2 = 2.0 * (Syz * Szy - Syy * Szz);
401                double Sxx2Syy2Szz2Syz2Szy2 = Syy2 + Szz2 - Sxx2 + Syz2 + Szy2;
402
403                double c2 = -2.0 * (Sxx2 + Syy2 + Szz2 + Sxy2 + Syx2 + Sxz2 + Szx2 + Syz2 + Szy2);
404                double c1 = 8.0 * (Sxx * Syz * Szy + Syy * Szx * Sxz + Szz * Sxy * Syx - Sxx * Syy * Szz - Syz * Szx * Sxy
405                                - Szy * Syx * Sxz);
406
407                SxzpSzx = Sxz + Szx;
408                SyzpSzy = Syz + Szy;
409                SxypSyx = Sxy + Syx;
410                SyzmSzy = Syz - Szy;
411                SxzmSzx = Sxz - Szx;
412                SxymSyx = Sxy - Syx;
413                SxxpSyy = Sxx + Syy;
414                SxxmSyy = Sxx - Syy;
415
416                double Sxy2Sxz2Syx2Szx2 = Sxy2 + Sxz2 - Syx2 - Szx2;
417
418                double c0 = Sxy2Sxz2Syx2Szx2 * Sxy2Sxz2Syx2Szx2
419                                + (Sxx2Syy2Szz2Syz2Szy2 + SyzSzymSyySzz2) * (Sxx2Syy2Szz2Syz2Szy2 - SyzSzymSyySzz2)
420                                + (-(SxzpSzx) * (SyzmSzy) + (SxymSyx) * (SxxmSyy - Szz))
421                                                * (-(SxzmSzx) * (SyzpSzy) + (SxymSyx) * (SxxmSyy + Szz))
422                                + (-(SxzpSzx) * (SyzpSzy) - (SxypSyx) * (SxxpSyy - Szz))
423                                                * (-(SxzmSzx) * (SyzmSzy) - (SxypSyx) * (SxxpSyy + Szz))
424                                + (+(SxypSyx) * (SyzpSzy) + (SxzpSzx) * (SxxmSyy + Szz))
425                                                * (-(SxymSyx) * (SyzmSzy) + (SxzpSzx) * (SxxpSyy + Szz))
426                                + (+(SxypSyx) * (SyzmSzy) + (SxzmSzx) * (SxxmSyy - Szz))
427                                                * (-(SxymSyx) * (SyzpSzy) + (SxzmSzx) * (SxxpSyy - Szz));
428
429                mxEigenV = e0;
430
431                int i;
432                for (i = 1; i < 51; ++i) {
433                        double oldg = mxEigenV;
434                        double x2 = mxEigenV * mxEigenV;
435                        double b = (x2 + c2) * mxEigenV;
436                        double a = b + c1;
437                        double delta = ((a * mxEigenV + c0) / (2.0 * x2 * mxEigenV + b + a));
438                        mxEigenV -= delta;
439
440                        if (Math.abs(mxEigenV - oldg) < Math.abs(eval_prec * mxEigenV))
441                                break;
442                }
443
444                if (i == 50) {
445                        logger.warn(String.format("More than %d iterations needed!", i));
446                } else {
447                        logger.info(String.format("%d iterations needed!", i));
448                }
449
450                /*
451                 * the fabs() is to guard against extremely small, but *negative*
452                 * numbers due to floating point error
453                 */
454                rmsd = Math.sqrt(Math.abs(2.0 * (e0 - mxEigenV) / len));
455
456                return 1;
457        }
458
459        private int calcRotationMatrix() {
460                double a11 = SxxpSyy + Szz - mxEigenV;
461                double a12 = SyzmSzy;
462                double a13 = -SxzmSzx;
463                double a14 = SxymSyx;
464                double a21 = SyzmSzy;
465                double a22 = SxxmSyy - Szz - mxEigenV;
466                double a23 = SxypSyx;
467                double a24 = SxzpSzx;
468                double a31 = a13;
469                double a32 = a23;
470                double a33 = Syy - Sxx - Szz - mxEigenV;
471                double a34 = SyzpSzy;
472                double a41 = a14;
473                double a42 = a24;
474                double a43 = a34;
475                double a44 = Szz - SxxpSyy - mxEigenV;
476                double a3344_4334 = a33 * a44 - a43 * a34;
477                double a3244_4234 = a32 * a44 - a42 * a34;
478                double a3243_4233 = a32 * a43 - a42 * a33;
479                double a3143_4133 = a31 * a43 - a41 * a33;
480                double a3144_4134 = a31 * a44 - a41 * a34;
481                double a3142_4132 = a31 * a42 - a41 * a32;
482                double q1 = a22 * a3344_4334 - a23 * a3244_4234 + a24 * a3243_4233;
483                double q2 = -a21 * a3344_4334 + a23 * a3144_4134 - a24 * a3143_4133;
484                double q3 = a21 * a3244_4234 - a22 * a3144_4134 + a24 * a3142_4132;
485                double q4 = -a21 * a3243_4233 + a22 * a3143_4133 - a23 * a3142_4132;
486
487                double qsqr = q1 * q1 + q2 * q2 + q3 * q3 + q4 * q4;
488
489                /*
490                 * The following code tries to calculate another column in the adjoint
491                 * matrix when the norm of the current column is too small. Usually this
492                 * commented block will never be activated. To be absolutely safe this
493                 * should be uncommented, but it is most likely unnecessary.
494                 */
495                if (qsqr < evec_prec) {
496                        q1 = a12 * a3344_4334 - a13 * a3244_4234 + a14 * a3243_4233;
497                        q2 = -a11 * a3344_4334 + a13 * a3144_4134 - a14 * a3143_4133;
498                        q3 = a11 * a3244_4234 - a12 * a3144_4134 + a14 * a3142_4132;
499                        q4 = -a11 * a3243_4233 + a12 * a3143_4133 - a13 * a3142_4132;
500                        qsqr = q1 * q1 + q2 * q2 + q3 * q3 + q4 * q4;
501
502                        if (qsqr < evec_prec) {
503                                double a1324_1423 = a13 * a24 - a14 * a23, a1224_1422 = a12 * a24 - a14 * a22;
504                                double a1223_1322 = a12 * a23 - a13 * a22, a1124_1421 = a11 * a24 - a14 * a21;
505                                double a1123_1321 = a11 * a23 - a13 * a21, a1122_1221 = a11 * a22 - a12 * a21;
506
507                                q1 = a42 * a1324_1423 - a43 * a1224_1422 + a44 * a1223_1322;
508                                q2 = -a41 * a1324_1423 + a43 * a1124_1421 - a44 * a1123_1321;
509                                q3 = a41 * a1224_1422 - a42 * a1124_1421 + a44 * a1122_1221;
510                                q4 = -a41 * a1223_1322 + a42 * a1123_1321 - a43 * a1122_1221;
511                                qsqr = q1 * q1 + q2 * q2 + q3 * q3 + q4 * q4;
512
513                                if (qsqr < evec_prec) {
514                                        q1 = a32 * a1324_1423 - a33 * a1224_1422 + a34 * a1223_1322;
515                                        q2 = -a31 * a1324_1423 + a33 * a1124_1421 - a34 * a1123_1321;
516                                        q3 = a31 * a1224_1422 - a32 * a1124_1421 + a34 * a1122_1221;
517                                        q4 = -a31 * a1223_1322 + a32 * a1123_1321 - a33 * a1122_1221;
518                                        qsqr = q1 * q1 + q2 * q2 + q3 * q3 + q4 * q4;
519
520                                        if (qsqr < evec_prec) {
521                                                /*
522                                                 * if qsqr is still too small, return the identity
523                                                 * matrix.
524                                                 */
525                                                rotmat.setIdentity();
526
527                                                return 0;
528                                        }
529                                }
530                        }
531                }
532
533                double normq = Math.sqrt(qsqr);
534                q1 /= normq;
535                q2 /= normq;
536                q3 /= normq;
537                q4 /= normq;
538
539                logger.debug("q: " + q1 + " " + q2 + " " + q3 + " " + q4);
540
541                double a2 = q1 * q1;
542                double x2 = q2 * q2;
543                double y2 = q3 * q3;
544                double z2 = q4 * q4;
545
546                double xy = q2 * q3;
547                double az = q1 * q4;
548                double zx = q4 * q2;
549                double ay = q1 * q3;
550                double yz = q3 * q4;
551                double ax = q1 * q2;
552
553                rotmat.m00 = a2 + x2 - y2 - z2;
554                rotmat.m01 = 2 * (xy + az);
555                rotmat.m02 = 2 * (zx - ay);
556
557                rotmat.m10 = 2 * (xy - az);
558                rotmat.m11 = a2 - x2 + y2 - z2;
559                rotmat.m12 = 2 * (yz + ax);
560
561                rotmat.m20 = 2 * (zx + ay);
562                rotmat.m21 = 2 * (yz - ax);
563                rotmat.m22 = a2 - x2 - y2 + z2;
564
565                return 1;
566        }
567
568        @Override
569        public double getRmsd(Point3d[] fixed, Point3d[] moved) {
570                set(moved, fixed);
571                return getRmsd();
572        }
573
574        @Override
575        public Matrix4d superpose(Point3d[] fixed, Point3d[] moved) {
576                set(moved, fixed);
577                getRotationMatrix();
578                if (!centered) {
579                        calcTransformation();
580                } else {
581                        transformation.set(rotmat);
582                }
583                return transformation;
584        }
585
586        /**
587         * @param fixed
588         * @param moved
589         * @param weight
590         *            array of weigths for each equivalent point position
591         * @return weighted RMSD.
592         */
593        public double getWeightedRmsd(Point3d[] fixed, Point3d[] moved, double[] weight) {
594                set(moved, fixed, weight);
595                return getRmsd();
596        }
597
598        /**
599         * The QCP method can be used as a two-step calculation: first compute the
600         * RMSD (fast) and then compute the superposition.
601         *
602         * This method assumes that the RMSD of two arrays of points has been
603         * already calculated using {@link #getRmsd(Point3d[], Point3d[])} method
604         * and calculates the transformation of the same two point arrays.
605         *
606         * @param fixed
607         * @param moved
608         * @return transformation matrix as a Matrix4d to superpose moved onto fixed
609         *         point arrays
610         */
611        public Matrix4d superposeAfterRmsd() {
612
613                if (!rmsdCalculated) {
614                        throw new IllegalStateException("The RMSD was not yet calculated. Use the superpose() method instead.");
615                }
616
617                getRotationMatrix();
618                if (!centered) {
619                        calcTransformation();
620                } else {
621                        transformation.set(rotmat);
622                }
623                return transformation;
624        }
625
626}