OpenStructure
lddt.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import mol
4 from ost import conop
5 
6 # use cdist of scipy, fallback to (slower) numpy implementation if scipy is not
7 # available
try:
    from scipy.spatial.distance import cdist
except ImportError:
    # Only a missing SciPy should trigger the fallback; any other error
    # raised while importing is a real problem and must not be hidden.
    def cdist(p1, p2):
        """Return pairwise Euclidean distances between two point sets.

        Pure-NumPy replacement for :func:`scipy.spatial.distance.cdist`
        based on the expansion |x - y|^2 = |x|^2 - 2*x.y + |y|^2.

        :param p1: Array of shape (m, d)
        :param p2: Array of shape (n, d)
        :returns: Array of shape (m, n) with the distance between each row
                  of *p1* and each row of *p2*
        """
        x2 = np.sum(p1**2, axis=1).reshape(-1, 1)  # (m, 1)
        y2 = np.sum(p2**2, axis=1)  # (n)
        xy = np.matmul(p1, p2.T)  # (m, n)
        # Floating point cancellation can make the squared distance of
        # (nearly) identical points slightly negative, which would turn
        # into NaN under np.sqrt. Clamp at zero before taking the root.
        return np.sqrt(np.maximum(x2 - 2 * xy + y2, 0.0))  # (m, n)
17 
class CustomCompound:
    """ Defines atoms for custom compounds

    lDDT requires the reference atoms of a compound which are typically
    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
    container allows to handle arbitrary compounds which are not
    necessarily in the compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """
    def __init__(self, atom_names):
        # names of the reference atoms of this compound
        self.atom_names = atom_names

    @staticmethod
    def FromResidue(res):
        """ Construct custom compound from residue

        :param res: Residue from which reference atom names are extracted,
                    hydrogen/deuterium atoms are filtered out
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound`
        :raises: :class:`RuntimeError` if the remaining atom names are not
                 unique
        """
        at_names = [a.name for a in res.atoms if a.element not in ["H", "D"]]
        # atom names are used as identifiers, duplicates would be ambiguous
        if len(at_names) != len(set(at_names)):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return CustomCompound(at_names)
46 
class SymmetrySettings:
    """Container for symmetric compounds

    lDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is defined as a renaming operation on one or more atoms that
    leads to a chemically equivalent residue. Example would be OD1 and OD2 in
    ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
    residue.

    Use :func:`AddSymmetricCompound` to define a symmetry which can then
    directly be accessed through the *symmetric_compounds* member.
    """
    def __init__(self):
        # compound name => list of atom name pairs defining the renaming
        self.symmetric_compounds = dict()

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        An already registered symmetry for *name* is replaced.

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define renaming
                                operation, i.e. after applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if an element of *symmetric_atoms*
                 does not have exactly two items
        """
        # validate everything first so nothing is stored on invalid input
        for pair in symmetric_atoms:
            if len(pair) != 2:
                raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
80 
81 
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids

    :returns: :class:`SymmetrySettings` with symmetries registered for
              ASP, GLU, LEU, VAL, ARG, PHE and TYR
    """
    symmetry_settings = SymmetrySettings()

    # ASP
    symmetry_settings.AddSymmetricCompound("ASP", [("OD1", "OD2")])

    # GLU
    symmetry_settings.AddSymmetricCompound("GLU", [("OE1", "OE2")])

    # LEU
    symmetry_settings.AddSymmetricCompound("LEU", [("CD1", "CD2")])

    # VAL
    symmetry_settings.AddSymmetricCompound("VAL", [("CG1", "CG2")])

    # ARG
    symmetry_settings.AddSymmetricCompound("ARG", [("NH1", "NH2")])

    # PHE - both ring atom pairs are swapped together
    symmetry_settings.AddSymmetricCompound(
        "PHE", [("CD1", "CD2"), ("CE1", "CE2")]
    )

    # TYR - both ring atom pairs are swapped together
    symmetry_settings.AddSymmetricCompound(
        "TYR", [("CD1", "CD2"), ("CE1", "CE2")]
    )

    return symmetry_settings
114 
115 
117  """lDDT scorer object for a specific target
118 
119  Sets up everything to score models of that target. lDDT (local distance
120  difference test) is defined as fraction of pairwise distances which exhibit
121  a difference < threshold when considering target and model. In case of
122  multiple thresholds, the average is returned. See
123 
124  V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
125  superposition-free score for comparing protein structures and models using
126  distance difference tests, Bioinformatics, 2013
127 
128  :param target: The target
129  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
130  :param compound_lib: Compound library from which a compound for each residue
131  is extracted based on its name. Uses
132  :func:`ost.conop.GetDefaultLib` if not given, raises
133  if this returns no valid compound library. Atoms
134  defined in the compound are searched in the residue and
135  build the reference for scoring. If the residue has
136  atoms with names ["A", "B", "C"] but the corresponding
137  compound only has ["A", "B"], "A" and "B" are
138  considered for scoring. If the residue has atoms
139  ["A", "B"] but the compound has ["A", "B", "C"], "C" is
140  considered missing and does not influence scoring, even
141  if present in the model.
142  :param custom_compounds: Custom compounds defining reference atoms. If
143  given, *custom_compounds* take precedent over
144  *compound_lib*.
145  :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
146  key and :class:`CustomCompound` as value.
147  :type compound_lib: :class:`ost.conop.CompoundLib`
148  :param inclusion_radius: All pairwise distances < *inclusion_radius* are
149  considered for scoring
150  :type inclusion_radius: :class:`float`
151  :param sequence_separation: Only pairwise distances between atoms of
152  residues which are further apart than this
153  threshold are considered. Residue distance is
154  based on resnum. The default (0) considers all
155  pairwise distances except intra-residue
156  distances.
157  :type sequence_separation: :class:`int`
158  :param symmetry_settings: Define residues exhibiting internal symmetry, uses
159  :func:`GetDefaultSymmetrySettings` if not given.
160  :type symmetry_settings: :class:`SymmetrySettings`
161  :param seqres_mapping: Mapping of model residues at the scoring stage
162  happens with residue numbers defining their location
163  in a reference sequence (SEQRES) using one based
164  indexing. If the residue numbers in *target* don't
165  correspond to that SEQRES, you can specify the
166  mapping manually. You can provide a dictionary to
167  specify a reference sequence (SEQRES) for one or more
168  chain(s). Key: chain name, value: alignment
169  (seq1: SEQRES, seq2: sequence of residues in chain).
170  Example: The residues in a chain with name "A" have
171  sequence "YEAH" and residue numbers [42,43,44,45].
172  You can provide an alignment with seq1 "``HELLYEAH``"
173  and seq2 "``----YEAH``". "Y" gets assigned residue
174  number 5, "E" gets assigned 6 and so on no matter
175  what the original residue numbers were.
176  :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
177  :class:`ost.seq.AlignmentHandle`)
178  :param bb_only: Only consider atoms with name "CA" in case of amino acids and
179  "C3'" for Nucleotides. this invalidates *compound_lib*.
180  Raises if any residue in *target* is not
181  `r.chem_class.IsPeptideLinking()` or
182  `r.chem_class.IsNucleotideLinking()`
183  :type bb_only: :class:`bool`
184  :raises: :class:`RuntimeError` if *target* contains compound which is not in
185  *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
186  specifies symmetric atoms that are not present in the according
187  compound in *compound_lib*, :class:`RuntimeError` if
188  *seqres_mapping* is not provided and *target* contains residue
189  numbers with insertion codes or the residue numbers for each chain
190  are not monotonically increasing, :class:`RuntimeError` if
191  *seqres_mapping* is provided but an alignment is invalid
192  (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
193  residues in corresponding chains).
194  """
195  def __init__(
196  self,
197  target,
198  compound_lib=None,
199  custom_compounds=None,
200  inclusion_radius=15,
201  sequence_separation=0,
202  symmetry_settings=None,
203  seqres_mapping=dict(),
204  bb_only=False
205  ):
206 
207  self.targettarget = target
208  self.inclusion_radiusinclusion_radius = inclusion_radius
209  self.sequence_separationsequence_separation = sequence_separation
210  if compound_lib is None:
211  compound_lib = conop.GetDefaultLib()
212  if compound_lib is None:
213  raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
214  "returns no valid compound library")
215  self.compound_libcompound_lib = compound_lib
216  self.custom_compoundscustom_compounds = custom_compounds
217  if symmetry_settings is None:
219  else:
220  self.symmetry_settingssymmetry_settings = symmetry_settings
221 
222  # whether to only consider atoms with name "CA" (amino acids) or C3'
223  # (nucleotides), invalidates *compound_lib*
224  self.bb_onlybb_only=bb_only
225 
226  # names of heavy atoms of each unique compound present in *target* as
227  # extracted from *compound_lib*, e.g.
228  # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
229  self.compound_anamescompound_anames = dict()
230 
231  # stores symmetry information for those compounds as defined in
232  # *symmetry_settings*
233  self.compound_symmetric_atomscompound_symmetric_atoms = dict()
234 
235  # list of len(target.chains) containing all chain names in *target*
236  self.chain_nameschain_names = list()
237 
238  # list of len(target.residues) containing all compound names in *target*
239  self.compound_namescompound_names = list()
240 
241  # list of len(target.residues) defining start pos in internal reference
242  # positions for each residue
243  self.res_start_indicesres_start_indices = list()
244 
245  # list of len(target.residues) defining residue numbers in target
246  self.res_resnumsres_resnums = list()
247 
248  # list of len(target.chains) defining start pos in internal reference
249  # positions for each chain
250  self.chain_start_indiceschain_start_indices = list()
251 
252  # list of len(target.chains) defining start pos in self.compound_names
253  # for each chain
254  self.chain_res_start_indiceschain_res_start_indices = list()
255 
256  # maps residues in *target* to indices in
257  # self.compound_names/self.res_start_indices. A residue gets identified
258  # by a tuple (first element: chain name, second element: residue number,
259  # residue number is either the actual residue number in *target* or
260  # given by *seqres_mapping*)
261  self.res_mapperres_mapper = dict()
262 
263  # number of atoms as specified in compounds. not all are necessarily
264  # covered by structure
265  self.n_atomsn_atoms = None
266 
267  # stores an index for each AtomHandle in *target*
268  # (atom hashcode => index)
269  self.atom_indicesatom_indices = dict()
270 
271  # store indices of all atoms that have symmetry properties
272  self.symmetric_atomssymmetric_atoms = set()
273 
274  # the actual target positions in a numpy array of shape (self.n_atoms,3)
275  self.positionspositions = None
276 
277  # setup members defined above
278  self._SetupEnv_SetupEnv(self.compound_libcompound_lib, self.custom_compoundscustom_compounds,
279  self.symmetry_settingssymmetry_settings, seqres_mapping, self.bb_onlybb_only)
280 
281  # distance related members are lazily computed as they're affected
282  # by different flavours of lDDT (e.g. lDDT including inter-chain
283  # contacts or not etc.)
284 
285  # stores for each atom the other atoms within inclusion_radius
286  self._ref_indices_ref_indices = None
287  # the corresponding distances
288  self._ref_distances_ref_distances = None
289 
290  # The following lists will be sparsely populated. We keep for each
291  # symmetry related atom the distances towards all atoms which are NOT
292  # affected by symmetry. So we can evaluate two symmetric versions
293  # against the fixed stuff later on and select the better scoring one.
294  self._sym_ref_indices_sym_ref_indices = None
295  self._sym_ref_distances_sym_ref_distances = None
296 
297  # exactly the same as above but without interchain contacts
298  # => single-chain (sc)
299  self._ref_indices_sc_ref_indices_sc = None
300  self._ref_distances_sc_ref_distances_sc = None
301  self._sym_ref_indices_sc_sym_ref_indices_sc = None
302  self._sym_ref_distances_sc_sym_ref_distances_sc = None
303 
304  # exactly the same as above but without intrachain contacts
305  # => inter-chain (ic)
306  self._ref_indices_ic_ref_indices_ic = None
307  self._ref_distances_ic_ref_distances_ic = None
308  self._sym_ref_indices_ic_sym_ref_indices_ic = None
309  self._sym_ref_distances_ic_sym_ref_distances_ic = None
310 
311  # input parameter checking
312  self._ProcessSequenceSeparation_ProcessSequenceSeparation()
313 
314  @property
315  def ref_indices(self):
316  if self._ref_indices_ref_indices is None:
317  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
318  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
319  self.atom_indicesatom_indices,
320  self.inclusion_radiusinclusion_radius)
321  return self._ref_indices_ref_indices
322 
323  @property
324  def ref_distances(self):
325  if self._ref_distances_ref_distances is None:
326  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
327  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
328  self.atom_indicesatom_indices,
329  self.inclusion_radiusinclusion_radius)
330  return self._ref_distances_ref_distances
331 
332  @property
333  def sym_ref_indices(self):
334  if self._sym_ref_indices_sym_ref_indices is None:
335  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
336  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
337  self.ref_indicesref_indices, self.ref_distancesref_distances)
338  return self._sym_ref_indices_sym_ref_indices
339 
340  @property
341  def sym_ref_distances(self):
342  if self._sym_ref_distances_sym_ref_distances is None:
343  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
344  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
345  self.ref_indicesref_indices, self.ref_distancesref_distances)
346  return self._sym_ref_distances_sym_ref_distances
347 
348  @property
349  def ref_indices_sc(self):
350  if self._ref_indices_sc_ref_indices_sc is None:
351  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
352  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
353  self.chain_start_indiceschain_start_indices,
354  self.ref_indicesref_indices,
355  self.ref_distancesref_distances)
356  return self._ref_indices_sc_ref_indices_sc
357 
358  @property
359  def ref_distances_sc(self):
360  if self._ref_distances_sc_ref_distances_sc is None:
361  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
362  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
363  self.chain_start_indiceschain_start_indices,
364  self.ref_indicesref_indices,
365  self.ref_distancesref_distances)
366  return self._ref_distances_sc_ref_distances_sc
367 
368  @property
370  if self._sym_ref_indices_sc_sym_ref_indices_sc is None:
371  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
372  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
373  self.symmetric_atomssymmetric_atoms,
374  self.ref_indices_scref_indices_sc,
375  self.ref_distances_scref_distances_sc)
376  return self._sym_ref_indices_sc_sym_ref_indices_sc
377 
378  @property
380  if self._sym_ref_distances_sc_sym_ref_distances_sc is None:
381  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
382  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
383  self.symmetric_atomssymmetric_atoms,
384  self.ref_indices_scref_indices_sc,
385  self.ref_distances_scref_distances_sc)
386  return self._sym_ref_distances_sc_sym_ref_distances_sc
387 
388  @property
389  def ref_indices_ic(self):
390  if self._ref_indices_ic_ref_indices_ic is None:
391  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
392  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
393  self.chain_start_indiceschain_start_indices,
394  self.ref_indicesref_indices,
395  self.ref_distancesref_distances)
396  return self._ref_indices_ic_ref_indices_ic
397 
398  @property
399  def ref_distances_ic(self):
400  if self._ref_distances_ic_ref_distances_ic is None:
401  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
402  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
403  self.chain_start_indiceschain_start_indices,
404  self.ref_indicesref_indices,
405  self.ref_distancesref_distances)
406  return self._ref_distances_ic_ref_distances_ic
407 
408  @property
410  if self._sym_ref_indices_ic_sym_ref_indices_ic is None:
411  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
412  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
413  self.symmetric_atomssymmetric_atoms,
414  self.ref_indices_icref_indices_ic,
415  self.ref_distances_icref_distances_ic)
416  return self._sym_ref_indices_ic_sym_ref_indices_ic
417 
418  @property
420  if self._sym_ref_distances_ic_sym_ref_distances_ic is None:
421  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
422  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
423  self.symmetric_atomssymmetric_atoms,
424  self.ref_indices_icref_indices_ic,
425  self.ref_distances_icref_distances_ic)
426  return self._sym_ref_distances_ic_sym_ref_distances_ic
427 
428  def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
429  local_lddt_prop=None, local_contact_prop=None,
430  chain_mapping=None, no_interchain=False,
431  no_intrachain=False, penalize_extra_chains=False,
432  residue_mapping=None, return_dist_test=False,
433  check_resnames=True, add_mdl_contacts=False):
434  """Computes lDDT of *model* - globally and per-residue
435 
436  :param model: Model to be scored - models are preferably scored upon
437  performing stereo-chemistry checks in order to punish for
438  non-sensical irregularities. This must be done separately
439  as a pre-processing step. Target contacts that are not
440  covered by *model* are considered not conserved, thus
441  decreasing lDDT score. This also includes missing model
442  chains or model chains for which no mapping is provided in
443  *chain_mapping*.
444  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
445  :param thresholds: Thresholds of distance differences to be considered
446  as correct - see docs in constructor for more info.
447  default: [0.5, 1.0, 2.0, 4.0]
448  :type thresholds: :class:`list` of :class:`floats`
449  :param local_lddt_prop: If set, per-residue scores will be assigned as
450  generic float property of that name
451  :type local_lddt_prop: :class:`str`
452  :param local_contact_prop: If set, number of expected contacts as well
453  as number of conserved contacts will be
454  assigned as generic int property.
455  Excected contacts will be set as
456  <local_contact_prop>_exp, conserved contacts
457  as <local_contact_prop>_cons. Values
458  are summed over all thresholds.
459  :type local_contact_prop: :class:`str`
460  :param chain_mapping: Mapping of model chains (key) onto target chains
461  (value). This is required if target or model have
462  more than one chain.
463  :type chain_mapping: :class:`dict` with :class:`str` as keys/values
464  :param no_interchain: Whether to exclude interchain contacts
465  :type no_interchain: :class:`bool`
466  :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
467  consider interface related contacts)
468  :type no_intrachain: :class:`bool`
469  :param penalize_extra_chains: Whether to include a fixed penalty for
470  additional chains in the model that are
471  not mapped to the target. ONLY AFFECTS
472  RETURNED GLOBAL SCORE. In detail: adds the
473  number of intra-chain contacts of each
474  extra chain to the expected contacts, thus
475  adding a penalty.
476  :type penalize_extra_chains: :class:`bool`
477  :param residue_mapping: By default, residue mapping is based on residue
478  numbers. That means, a model chain and the
479  respective target chain map to the same
480  underlying reference sequence (SEQRES).
481  Alternatively, you can specify one or
482  several alignment(s) between model and target
483  chains by providing a dictionary. key: Name
484  of chain in model (respective target chain is
485  extracted from *chain_mapping*),
486  value: Alignment with first sequence
487  corresponding to target chain and second
488  sequence to model chain. There is NO reference
489  sequence involved, so the two sequences MUST
490  exactly match the actual residues observed in
491  the respective target/model chains (ATOMSEQ).
492  :type residue_mapping: :class:`dict` with key: :class:`str`,
493  value: :class:`ost.seq.AlignmentHandle`
494  :param return_dist_test: Whether to additionally return the underlying
495  per-residue data for the distance difference
496  test. Adds five objects to the return tuple.
497  First: Number of total contacts summed over all
498  thresholds
499  Second: Number of conserved contacts summed
500  over all thresholds
501  Third: list with length of scored residues.
502  Contains indices referring to model.residues.
503  Fourth: numpy array of size
504  len(scored_residues) containing the number of
505  total contacts,
506  Fifth: numpy matrix of shape
507  (len(scored_residues), len(thresholds))
508  specifying how many for each threshold are
509  conserved.
510  :param check_resnames: On by default. Enforces residue name matches
511  between mapped model and target residues.
512  :type check_resnames: :class:`bool`
513  :param add_mdl_contacts: Adds model contacts - Only using contacts that
514  are within a certain distance threshold in the
515  target does not penalize for added model
516  contacts. If set to True, this flag will also
517  consider target contacts that are within the
518  specified distance threshold in the model but
519  not necessarily in the target. No contact will
520  be added if the respective atom pair is not
521  resolved in the target.
522  :type add_mdl_contacts: :class:`bool`
523 
524  :returns: global and per-residue lDDT scores as a tuple -
525  first element is global lDDT score (None if *target* has no
526  contacts) and second element a list of per-residue scores with
527  length len(*model*.residues). None is assigned to residues that
528  are not covered by target. If a residue is covered but has no
529  contacts in *target*, 0.0 is assigned.
530  """
531  if chain_mapping is None:
532  if len(self.chain_nameschain_names) > 1 or len(model.chains) > 1:
533  raise NotImplementedError("Must provide chain mapping if "
534  "target or model have > 1 chains.")
535  chain_mapping = {model.chains[0].GetName(): self.chain_nameschain_names[0]}
536  else:
537  # check whether chains specified in mapping exist
538  for model_chain, target_chain in chain_mapping.items():
539  if target_chain not in self.chain_nameschain_names:
540  raise RuntimeError(f"Target chain specified in "
541  f"chain_mapping ({target_chain}) does "
542  f"not exist. Target has chains: "
543  f"{self.chain_names}")
544  ch = model.FindChain(model_chain)
545  if not ch.IsValid():
546  raise RuntimeError(f"Model chain specified in "
547  f"chain_mapping ({model_chain}) does "
548  f"not exist. Model has chains: "
549  f"{[c.GetName() for c in model.chains]}")
550 
551  # data objects defining model data - see _ProcessModel for rough
552  # description
553  pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
554  res_indices, symmetries = self._ProcessModel_ProcessModel(model, chain_mapping,
555  residue_mapping = residue_mapping,
556  thresholds = thresholds,
557  check_resnames = check_resnames)
558 
559  if no_interchain and no_intrachain:
560  raise RuntimeError("no_interchain and no_intrachain flags are "
561  "mutually exclusive")
562 
563  if no_interchain:
564  sym_ref_indices = self.sym_ref_indices_scsym_ref_indices_sc
565  sym_ref_distances = self.sym_ref_distances_scsym_ref_distances_sc
566  ref_indices = self.ref_indices_scref_indices_sc
567  ref_distances = self.ref_distances_scref_distances_sc
568  elif no_intrachain:
569  sym_ref_indices = self.sym_ref_indices_icsym_ref_indices_ic
570  sym_ref_distances = self.sym_ref_distances_icsym_ref_distances_ic
571  ref_indices = self.ref_indices_icref_indices_ic
572  ref_distances = self.ref_distances_icref_distances_ic
573  else:
574  sym_ref_indices = self.sym_ref_indicessym_ref_indices
575  sym_ref_distances = self.sym_ref_distancessym_ref_distances
576  ref_indices = self.ref_indicesref_indices
577  ref_distances = self.ref_distancesref_distances
578 
579  if add_mdl_contacts:
580  ref_indices, ref_distances, \
581  sym_ref_indices, sym_ref_distances = \
582  self._AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
583  ref_indices, ref_distances,
584  no_interchain, no_intrachain)
585 
586  self._ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
587  sym_ref_distances)
588 
589  per_res_exp = np.asarray([self._GetNExp_GetNExp(res_ref_atom_indices[idx],
590  ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
591  per_res_conserved = self._EvalResidues_EvalResidues(pos, thresholds,
592  res_atom_indices,
593  ref_indices, ref_distances)
594 
595  n_thresh = len(thresholds)
596 
597  # do per-residue scores
598  per_res_lDDT = [None] * len(model.residues)
599  for idx in range(len(res_indices)):
600  n_exp = n_thresh * per_res_exp[idx]
601  if n_exp > 0:
602  score = np.sum(per_res_conserved[idx,:]) / n_exp
603  per_res_lDDT[res_indices[idx]] = score
604  else:
605  per_res_lDDT[res_indices[idx]] = 0.0
606 
607  # do full model score
608  n_distances = sum([len(x) for x in ref_indices])
609  if penalize_extra_chains:
610  n_distances += self._GetExtraModelChainPenalty_GetExtraModelChainPenalty(model, chain_mapping)
611 
612  lDDT_tot = int(n_thresh * n_distances)
613  lDDT_cons = int(np.sum(per_res_conserved))
614  lDDT = None
615  if lDDT_tot > 0:
616  lDDT = float(lDDT_cons) / lDDT_tot
617 
618  # set properties if necessary
619  if local_lddt_prop:
620  residues = model.residues
621  for idx in res_indices:
622  residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
623 
624  if local_contact_prop:
625  residues = model.residues
626  exp_prop = local_contact_prop + "_exp"
627  conserved_prop = local_contact_prop + "_cons"
628 
629  for i, r_idx in enumerate(res_indices):
630  residues[r_idx].SetIntProp(exp_prop,
631  n_thresh * int(per_res_exp[i]))
632  residues[r_idx].SetIntProp(conserved_prop,
633  int(np.sum(per_res_conserved[i,:])))
634 
635  if return_dist_test:
636  return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
637  per_res_exp, per_res_conserved
638  else:
639  return lDDT, per_res_lDDT
640 
641  def GetNChainContacts(self, target_chain, no_interchain=False):
642  """Returns number of contacts expected for a certain chain in *target*
643 
644  :param target_chain: Chain in *target* for which you want the number
645  of expected contacts
646  :type target_chain: :class:`str`
647  :param no_interchain: Whether to exclude interchain contacts
648  :type no_interchain: :class:`bool`
649  :raises: :class:`RuntimeError` if specified chain doesnt exist
650  """
651  if target_chain not in self.chain_nameschain_names:
652  raise RuntimeError(f"Specified chain name ({target_chain}) not in "
653  f"target")
654  ch_idx = self.chain_nameschain_names.index(target_chain)
655  s = self.chain_start_indiceschain_start_indices[ch_idx]
656  e = self.n_atomsn_atoms
657  if ch_idx + 1 < len(self.chain_nameschain_names):
658  e = self.chain_start_indiceschain_start_indices[ch_idx+1]
659  if no_interchain:
660  return self._GetNExp_GetNExp(list(range(s, e)), self.ref_indices_scref_indices_sc)
661  else:
662  return self._GetNExp_GetNExp(list(range(s, e)), self.ref_indicesref_indices)
663 
664  def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
665  thresholds = [0.5, 1.0, 2.0, 4.0],
666  check_resnames = True):
667  """ Helper that generates data structures from model
668  """
669 
670  # initialize positions with values far in nirvana. If a position is not
671  # set, it should be far away from any position in model.
672  max_pos = model.bounds.GetMax()
673  max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
674  max_coordinate += 42 * max(thresholds)
675  pos = np.ones((self.n_atomsn_atoms, 3), dtype=np.float32) * max_coordinate
676 
677  # for each scored residue in model a list of indices describing the
678  # atoms from the reference that should be there
679  res_ref_atom_indices = list()
680 
681  # for each scored residue in model a list of indices of atoms that are
682  # actually there
683  res_atom_indices = list()
684 
685  # and the respective hash codes
686  # this is required if add_mdl_contacts is set to True
687  res_atom_hashes = list()
688 
689  # indices of the scored residues
690  res_indices = list()
691 
692  # Will contain one element per symmetry group
693  symmetries = list()
694 
695  current_model_res_idx = -1
696  for ch in model.chains:
697  model_ch_name = ch.GetName()
698  if model_ch_name not in chain_mapping:
699  current_model_res_idx += len(ch.residues)
700  continue # additional model chain which is not mapped
701  target_ch_name = chain_mapping[model_ch_name]
702 
703  rnums = self._GetChainRNums_GetChainRNums(ch, residue_mapping, model_ch_name,
704  target_ch_name)
705 
706  for r, rnum in zip(ch.residues, rnums):
707  current_model_res_idx += 1
708  res_mapper_key = (target_ch_name, rnum)
709  if res_mapper_key not in self.res_mapperres_mapper:
710  continue
711  r_idx = self.res_mapperres_mapper[res_mapper_key]
712  if check_resnames and r.name != self.compound_namescompound_names[r_idx]:
713  raise RuntimeError(
714  f"Residue name mismatch for {r}, "
715  f" expect {self.compound_names[r_idx]}"
716  )
717  res_start_idx = self.res_start_indicesres_start_indices[r_idx]
718  rname = self.compound_namescompound_names[r_idx]
719  anames = self.compound_anamescompound_anames[rname]
720  atoms = [r.FindAtom(aname) for aname in anames]
721  res_ref_atom_indices.append(
722  list(range(res_start_idx, res_start_idx + len(anames)))
723  )
724  res_atom_indices.append(list())
725  res_atom_hashes.append(list())
726  res_indices.append(current_model_res_idx)
727  for a_idx, a in enumerate(atoms):
728  if a.IsValid():
729  p = a.GetPos()
730  pos[res_start_idx + a_idx][0] = p[0]
731  pos[res_start_idx + a_idx][1] = p[1]
732  pos[res_start_idx + a_idx][2] = p[2]
733  res_atom_indices[-1].append(res_start_idx + a_idx)
734  res_atom_hashes[-1].append(a.handle.GetHashCode())
735  if rname in self.compound_symmetric_atomscompound_symmetric_atoms:
736  sym_indices = list()
737  for sym_tuple in self.compound_symmetric_atomscompound_symmetric_atoms[rname]:
738  a_one = atoms[sym_tuple[0]]
739  a_two = atoms[sym_tuple[1]]
740  if a_one.IsValid() and a_two.IsValid():
741  sym_indices.append(
742  (
743  res_start_idx + sym_tuple[0],
744  res_start_idx + sym_tuple[1],
745  )
746  )
747  if len(sym_indices) > 0:
748  symmetries.append(sym_indices)
749 
750  return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
751  res_indices, symmetries)
752 
753 
def _GetExtraModelChainPenalty(self, model, chain_mapping):
    """Counts distances in extra model chains to be added as penalty

    Every chain in *model* whose name is not contained in *chain_mapping*
    is scored against itself: a temporary :class:`lDDTScorer` is set up on
    a selection of that chain, using the compound library, symmetry
    settings, inclusion radius and bb_only flag of this scorer. The lengths
    of all entries in the *ref_indices* of these temporary scorers are
    summed up.

    :param model: Model that potentially contains unmapped chains
    :param chain_mapping: Container of mapped model chain names, only
                          membership is evaluated here
    :returns: :class:`int` - summed number of contacts in unmapped chains
    """
    penalty = 0
    for chain in model.chains:
        ch_name = chain.GetName()
        if ch_name in chain_mapping:
            continue  # mapped chains are scored the regular way
        mdl_sel = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
        dummy_scorer = lDDTScorer(mdl_sel, self.compound_lib,
                                  symmetry_settings=self.symmetry_settings,
                                  inclusion_radius=self.inclusion_radius,
                                  bb_only=self.bb_only)
        penalty += sum(len(x) for x in dummy_scorer.ref_indices)
    return penalty
769 
770  def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
771  target_ch_name):
772  """Map residues in model chain to target residues
773 
774  There are two options: one is simply using residue numbers,
775  the other is a custom mapping as given in *residue_mapping*
776  """
777  if residue_mapping and model_ch_name in residue_mapping:
778  # extract residue numbers from target chain
779  ch_idx = self.chain_nameschain_names.index(target_ch_name)
780  start_idx = self.chain_res_start_indiceschain_res_start_indices[ch_idx]
781  if ch_idx < len(self.chain_nameschain_names) - 1:
782  end_idx = self.chain_res_start_indiceschain_res_start_indices[ch_idx+1]
783  else:
784  end_idx = len(self.compound_namescompound_names)
785  target_rnums = self.res_resnumsres_resnums[start_idx:end_idx]
786  # get sequences from alignment and do consistency checks
787  target_seq = residue_mapping[model_ch_name].GetSequence(0)
788  model_seq = residue_mapping[model_ch_name].GetSequence(1)
789  if len(target_seq.GetGaplessString()) != len(target_rnums):
790  raise RuntimeError(f"Try to perform residue mapping for "
791  f"model chain {model_ch_name} which "
792  f"maps to {target_ch_name} in target. "
793  f"Target sequence in alignment suggests "
794  f"{len(target_seq.GetGaplessString())} "
795  f"residues but {len(target_rnums)} are "
796  f"expected.")
797  if len(model_seq.GetGaplessString()) != len(ch.residues):
798  raise RuntimeError(f"Try to perform residue mapping for "
799  f"model chain {model_ch_name} which "
800  f"maps to {target_ch_name} in target. "
801  f"Model sequence in alignment suggests "
802  f"{len(model_seq.GetGaplessString())} "
803  f"residues but {len(ch.residues)} are "
804  f"expected.")
805  rnums = list()
806  target_idx = -1
807  for col in residue_mapping[model_ch_name]:
808  if col[0] != '-':
809  target_idx += 1
810  # handle match
811  if col[0] != '-' and col[1] != '-':
812  rnums.append(target_rnums[target_idx])
813  # insertion in model adds None to rnum
814  if col[0] == '-' and col[1] != '-':
815  rnums.append(None)
816  else:
817  rnums = [r.GetNumber() for r in ch.residues]
818 
819  return rnums
820 
821 
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
              seqres_mapping, bb_only):
    """Sets target related lDDTScorer members defined in constructor

    No distance related members - see _SetupDistances

    Iterates all residues of *self.target* and appends to / fills
    *chain_names*, *chain_start_indices*, *chain_res_start_indices*,
    *res_start_indices*, *res_mapper*, *compound_names*, *res_resnums*,
    *atom_indices* and *symmetric_atoms*. Compound information for residue
    names seen for the first time is added through _SetupCompound.
    Finally sets *positions* and *n_atoms*.

    Every atom name of a compound gets an atom index, also if the atom is
    missing in the target residue. Missing atoms get a zero position and
    no entry in *atom_indices*.
    """
    residue_numbers = self._GetTargetResidueNumbers(self.target,
                                                    seqres_mapping)
    current_idx = 0
    positions = list()
    for chain in self.target.chains:
        ch_name = chain.GetName()
        self.chain_names.append(ch_name)
        self.chain_start_indices.append(current_idx)
        self.chain_res_start_indices.append(len(self.compound_names))
        for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
            if r.name not in self.compound_anames:
                # sets compound info in self.compound_anames and
                # self.compound_symmetric_atoms
                self._SetupCompound(r, compound_lib, custom_compounds,
                                    symmetry_settings, bb_only)

            self.res_start_indices.append(current_idx)
            self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
            self.compound_names.append(r.name)
            self.res_resnums.append(rnum)

            atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]]
            for a in atoms:
                if a.IsValid():
                    self.atom_indices[a.handle.GetHashCode()] = current_idx
                    p = a.GetPos()
                    positions.append(np.asarray([p[0], p[1], p[2]],
                                                dtype=np.float32))
                else:
                    # placeholder to keep the atom indexing consistent
                    positions.append(np.zeros(3, dtype=np.float32))
                current_idx += 1

            if r.name in self.compound_symmetric_atoms:
                for sym_tuple in self.compound_symmetric_atoms[r.name]:
                    for a_idx in sym_tuple:
                        a = atoms[a_idx]
                        if a.IsValid():
                            hashcode = a.handle.GetHashCode()
                            self.symmetric_atoms.add(
                                self.atom_indices[hashcode]
                            )

    if positions:
        self.positions = np.vstack(positions)
    else:
        # np.vstack refuses empty input - target without any atoms
        self.positions = np.zeros((0, 3), dtype=np.float32)
    self.n_atoms = current_idx
871 
872  def _GetTargetResidueNumbers(self, target, seqres_mapping):
873  """Returns residue numbers for each chain in target as dict
874 
875  They're either directly extracted from the raw residue number
876  from the structure or from user provided alignments
877  """
878  residue_numbers = dict()
879  for ch in target.chains:
880  ch_name = ch.GetName()
881  rnums = list()
882  if ch_name in seqres_mapping:
883  seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
884  atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
885  # SEQRES must not contain gaps
886  if "-" in seqres:
887  raise RuntimeError(
888  "SEQRES in seqres_mapping must not " "contain gaps"
889  )
890  atomseq_from_chain = [r.one_letter_code for r in ch.residues]
891  if atomseq.replace("-", "") != atomseq_from_chain:
892  raise RuntimeError(
893  "ATOMSEQ in seqres_mapping must match "
894  "raw sequence extracted from chain "
895  "residues"
896  )
897  rnum = 0
898  for seqres_olc, atomseq_olc in zip(seqres, atomseq):
899  if seqres_olc != "-":
900  rnum += 1
901  if atomseq_olc != "-":
902  if seqres_olc != atomseq_olc:
903  raise RuntimeError(
904  f"Residue with number {rnum} in "
905  f"chain {ch_name} has SEQRES "
906  f"ATOMSEQ mismatch"
907  )
908  rnums.append(mol.ResNum(rnum))
909  else:
910  rnums = [r.GetNumber() for r in ch.residues]
911  assert len(rnums) == len(ch.residues)
912  residue_numbers[ch_name] = rnums
913  return residue_numbers
914 
915  def _SetupCompound(self, r, compound_lib, custom_compounds,
916  symmetry_settings, bb_only):
917  """fill self.compound_anames/self.compound_symmetric_atoms
918  """
919  if bb_only:
920  # throw away compound_lib info
921  if r.chem_class.IsPeptideLinking():
922  self.compound_anamescompound_anames[r.name] = ["CA"]
923  elif r.chem_class.IsNucleotideLinking():
924  self.compound_anamescompound_anames[r.name] = ["C3'"]
925  else:
926  raise RuntimeError(f"Only support amino acids and nucleotides "
927  f"if bb_only is True, failed with {str(r)}")
928  self.compound_symmetric_atomscompound_symmetric_atoms[r.name] = list()
929  else:
930  atom_names = list()
931  symmetric_atoms = list()
932  if custom_compounds is not None and r.GetName() in custom_compounds:
933  atom_names = list(custom_compounds[r.GetName()].atom_names)
934  else:
935  compound = compound_lib.FindCompound(r.name)
936  if compound is None:
937  raise RuntimeError(f"no entry for {r} in compound_lib")
938  for atom_spec in compound.GetAtomSpecs():
939  if atom_spec.element not in ["H", "D"]:
940  atom_names.append(atom_spec.name)
941  if r.name in symmetry_settings.symmetric_compounds:
942  for pair in symmetry_settings.symmetric_compounds[r.name]:
943  try:
944  a = atom_names.index(pair[0])
945  b = atom_names.index(pair[1])
946  except:
947  msg = f"Could not find symmetric atoms "
948  msg += f"({pair[0]}, {pair[1]}) for {r.name} "
949  msg += f"as specified in SymmetrySettings in "
950  msg += f"compound from component dictionary. "
951  msg += f"Atoms in compound: {atom_names}"
952  raise RuntimeError(msg)
953  symmetric_atoms.append((a, b))
954  self.compound_anamescompound_anames[r.name] = atom_names
955  if len(symmetric_atoms) > 0:
956  self.compound_symmetric_atomscompound_symmetric_atoms[r.name] = symmetric_atoms
957 
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
                    ref_indices, ref_distances, no_interchain,
                    no_intrachain):
    """Extends reference contacts with contacts only observed in model

    Contacts are searched in *model* among atoms that are also present in
    the target. Contacts that are not yet in *ref_indices* are added, the
    respective distances are computed from target positions.

    The lists *ref_indices* and *ref_distances* are updated in place.

    :param model: Model in which contacts are searched
    :param res_atom_indices: Per residue list of atom indices
    :param res_atom_hashes: Per residue list of model atom hashes,
                            parallel to *res_atom_indices*
    :param ref_indices: Per atom contact indices
    :param ref_distances: Per atom contact distances
    :param no_interchain: Only keep model contacts within the same chain
    :param no_intrachain: Only keep model contacts between chains
    :returns: :class:`tuple` (*ref_indices*, *ref_distances*,
              sym_ref_indices, sym_ref_distances), the latter two are
              recomputed from the updated contacts.
    """
    # buildup an index map for mdl atoms that are also present in target
    in_target = np.zeros(self.n_atoms, dtype=bool)
    for i in self.atom_indices.values():
        in_target[i] = True
    mdl_atom_indices = dict()
    for at_indices, at_hashes in zip(res_atom_indices, res_atom_hashes):
        for i, h in zip(at_indices, at_hashes):
            if in_target[i]:
                mdl_atom_indices[h] = i

    # get contacts for mdl - the contacts are only from atom pairs that
    # are also present in target, as we only provide the respective
    # hashes in mdl_atom_indices
    mdl_ref_indices, mdl_ref_distances = \
    lDDTScorer._SetupDistances(model, self.n_atoms, mdl_atom_indices,
                               self.inclusion_radius)
    if no_interchain:
        mdl_ref_indices, mdl_ref_distances = \
        lDDTScorer._SetupDistancesSC(self.n_atoms,
                                     self.chain_start_indices,
                                     mdl_ref_indices,
                                     mdl_ref_distances)

    if no_intrachain:
        mdl_ref_indices, mdl_ref_distances = \
        lDDTScorer._SetupDistancesIC(self.n_atoms,
                                     self.chain_start_indices,
                                     mdl_ref_indices,
                                     mdl_ref_distances)

    # update ref_indices/ref_distances => add mdl contacts
    for i in range(self.n_atoms):
        mask = np.isin(mdl_ref_indices[i], ref_indices[i],
                       assume_unique=True, invert=True)
        if np.sum(mask) > 0:
            added_mdl_indices = mdl_ref_indices[i][mask]
            ref_indices[i] = np.append(ref_indices[i],
                                       added_mdl_indices)

            # distances need to be recomputed from ref positions
            tmp = self.positions.take(added_mdl_indices, axis=0)
            np.subtract(tmp, self.positions[i][None, :], out=tmp)
            np.square(tmp, out=tmp)
            tmp = tmp.sum(axis=1)
            np.sqrt(tmp, out=tmp)  # distances against all relevant atoms
            ref_distances[i] = np.append(ref_distances[i], tmp)

    # recompute symmetry related indices/distances
    sym_ref_indices, sym_ref_distances = \
    lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
                                ref_indices, ref_distances)

    return (ref_indices, ref_distances, sym_ref_indices, sym_ref_distances)
1015 
1016 
1017 
1018  @staticmethod
1019  def _SetupDistances(structure, n_atoms, atom_index_mapping,
1020  inclusion_radius):
1021 
1022  """Compute distance related members of lDDTScorer
1023 
1024  Brute force all vs all distance computation kills lDDT for large
1025  complexes. Instead of building some KD tree data structure, we make use
1026  of expected spatial proximity of atoms in the same chain. Distances are
1027  computed as follows:
1028 
1029  - process each chain individually
1030  - perform crude collision detection
1031  - process potentially interacting chain pairs
1032  - concatenate distances from all processing steps
1033  """
1034  ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1035  ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1036 
1037  indices = [list() for _ in range(n_atoms)]
1038  distances = [list() for _ in range(n_atoms)]
1039  per_chain_pos = list()
1040  per_chain_indices = list()
1041 
1042  # Process individual chains
1043  for ch in structure.chains:
1044  pos_list = list()
1045  atom_indices = list()
1046  mask_start = list()
1047  mask_end = list()
1048  r_start_idx = 0
1049  for r_idx, r in enumerate(ch.residues):
1050  n_valid_atoms = 0
1051  for a in r.atoms:
1052  hash_code = a.handle.GetHashCode()
1053  if hash_code in atom_index_mapping:
1054  p = a.GetPos()
1055  pos_list.append(np.asarray([p[0], p[1], p[2]]))
1056  atom_indices.append(atom_index_mapping[hash_code])
1057  n_valid_atoms += 1
1058  mask_start.extend([r_start_idx] * n_valid_atoms)
1059  mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1060  r_start_idx += n_valid_atoms
1061 
1062  if len(pos_list) == 0:
1063  # nothing to do...
1064  continue
1065 
1066  pos = np.vstack(pos_list)
1067  atom_indices = np.asarray(atom_indices)
1068  dists = cdist(pos, pos)
1069 
1070  # apply masks
1071  far_away = 2 * inclusion_radius
1072  for idx in range(atom_indices.shape[0]):
1073  dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1074 
1075  # fish out and store close atoms within inclusion radius
1076  within_mask = dists < inclusion_radius
1077  for idx in range(atom_indices.shape[0]):
1078  indices_to_append = atom_indices[within_mask[idx,:]]
1079  if indices_to_append.shape[0] > 0:
1080  full_at_idx = atom_indices[idx]
1081  indices[full_at_idx].append(indices_to_append)
1082  distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1083 
1084  per_chain_pos.append(pos)
1085  per_chain_indices.append(atom_indices)
1086 
1087  # perform crude collision detection
1088  min_pos = [p.min(0) for p in per_chain_pos]
1089  max_pos = [p.max(0) for p in per_chain_pos]
1090  chain_pairs = list()
1091  for idx_one in range(len(per_chain_pos)):
1092  for idx_two in range(idx_one + 1, len(per_chain_pos)):
1093  if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1094  continue
1095  if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1096  continue
1097  chain_pairs.append((idx_one, idx_two))
1098 
1099  # process potentially interacting chains
1100  for pair in chain_pairs:
1101  dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1102  within = dists <= inclusion_radius
1103 
1104  # process pair[0]
1105  tmp = within.sum(axis=1)
1106  for idx in range(tmp.shape[0]):
1107  if tmp[idx] > 0:
1108  # even though not being a strict requirement, we perform an
1109  # insertion here such that the indices for each atom will be
1110  # sorted after the hstack operation
1111  at_idx = per_chain_indices[pair[0]][idx]
1112  indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1113  distances_to_insert = dists[idx, within[idx, :]]
1114  insertion_idx = len(indices[at_idx])
1115  for i in range(insertion_idx):
1116  if indices_to_insert[0] > indices[at_idx][i][0]:
1117  insertion_idx = i
1118  break
1119  indices[at_idx].insert(insertion_idx, indices_to_insert)
1120  distances[at_idx].insert(insertion_idx, distances_to_insert)
1121 
1122  # process pair[1]
1123  tmp = within.sum(axis=0)
1124  for idx in range(tmp.shape[0]):
1125  if tmp[idx] > 0:
1126  # even though not being a strict requirement, we perform an
1127  # insertion here such that the indices for each atom will be
1128  # sorted after the hstack operation
1129  at_idx = per_chain_indices[pair[1]][idx]
1130  indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1131  distances_to_insert = dists[within[:, idx], idx]
1132  insertion_idx = len(indices[at_idx])
1133  for i in range(insertion_idx):
1134  if indices_to_insert[0] > indices[at_idx][i][0]:
1135  insertion_idx = i
1136  break
1137  indices[at_idx].insert(insertion_idx, indices_to_insert)
1138  distances[at_idx].insert(insertion_idx, distances_to_insert)
1139 
1140  # concatenate distances from all processing steps
1141  for at_idx in range(n_atoms):
1142  if len(indices[at_idx]) > 0:
1143  ref_indices[at_idx] = np.hstack(indices[at_idx])
1144  ref_distances[at_idx] = np.hstack(distances[at_idx])
1145 
1146  return (ref_indices, ref_distances)
1147 
1148  @staticmethod
1149  def _SetupDistancesSC(n_atoms, chain_start_indices,
1150  ref_indices, ref_distances):
1151  """Select subset of contacts only covering intra-chain contacts
1152  """
1153  # init
1154  ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1155  ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1156 
1157  n_chains = len(chain_start_indices)
1158  for ch_idx in range(n_chains):
1159  chain_s = chain_start_indices[ch_idx]
1160  chain_e = n_atoms
1161  if ch_idx + 1 < n_chains:
1162  chain_e = chain_start_indices[ch_idx+1]
1163  for i in range(chain_s, chain_e):
1164  if len(ref_indices[i]) > 0:
1165  intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1166  ref_indices[i]<chain_e))[0]
1167  ref_indices_sc[i] = ref_indices[i][intra_idx]
1168  ref_distances_sc[i] = ref_distances[i][intra_idx]
1169 
1170  return (ref_indices_sc, ref_distances_sc)
1171 
1172  @staticmethod
1173  def _SetupDistancesIC(n_atoms, chain_start_indices,
1174  ref_indices, ref_distances):
1175  """Select subset of contacts only covering inter-chain contacts
1176  """
1177  # init
1178  ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1179  ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1180 
1181  n_chains = len(chain_start_indices)
1182  for ch_idx in range(n_chains):
1183  chain_s = chain_start_indices[ch_idx]
1184  chain_e = n_atoms
1185  if ch_idx + 1 < n_chains:
1186  chain_e = chain_start_indices[ch_idx+1]
1187  for i in range(chain_s, chain_e):
1188  if len(ref_indices[i]) > 0:
1189  inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1190  ref_indices[i]>=chain_e))[0]
1191  ref_indices_ic[i] = ref_indices[i][inter_idx]
1192  ref_distances_ic[i] = ref_distances[i][inter_idx]
1193 
1194  return (ref_indices_ic, ref_distances_ic)
1195 
1196  @staticmethod
1197  def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1198  """Transfer indices/distances of non-symmetric atoms and return
1199  """
1200 
1201  sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1202  sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1203 
1204  for idx in symmetric_atoms:
1205  indices = list()
1206  distances = list()
1207  for i, d in zip(ref_indices[idx], ref_distances[idx]):
1208  if i not in symmetric_atoms:
1209  indices.append(i)
1210  distances.append(d)
1211  sym_ref_indices[idx] = indices
1212  sym_ref_distances[idx] = np.asarray(distances)
1213 
1214  return (sym_ref_indices, sym_ref_distances)
1215 
1216  def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1217  """Computes number of distance differences within given thresholds
1218 
1219  returns np.array with len(thresholds) elements
1220  """
1221  a_p = pos[atom_idx, :]
1222  tmp = pos.take(ref_indices[atom_idx], axis=0)
1223  np.subtract(tmp, a_p[None, :], out=tmp)
1224  np.square(tmp, out=tmp)
1225  tmp = tmp.sum(axis=1)
1226  np.sqrt(tmp, out=tmp) # distances against all relevant atoms
1227  np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1228  np.absolute(tmp, out=tmp) # absolute dist diffs
1229  return np.asarray([(tmp <= thresh).sum() for thresh in thresholds],
1230  dtype=np.int32)
1231 
def _EvalAtoms(
    self, pos, atom_indices, thresholds, ref_indices, ref_distances
):
    """Calls _EvalAtom for several atoms and collects the computed number
    of distance differences within given thresholds

    :param pos: Positions, one row per atom
    :param atom_indices: Indices of atoms to evaluate
    :param thresholds: Thresholds on absolute distance differences
    :param ref_indices: Per atom contact indices
    :param ref_distances: Per atom contact distances
    :returns: numpy matrix of type int32 and shape
              (len(atom_indices), len(thresholds)), one row per atom in
              *atom_indices*
    """
    conserved = np.zeros((len(atom_indices), len(thresholds)),
                         dtype=np.int32)
    for row, atom_idx in enumerate(atom_indices):
        conserved[row, :] = self._EvalAtom(pos, atom_idx, thresholds,
                                           ref_indices, ref_distances)
    return conserved
1246 
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
                  ref_distances):
    """Calls _EvalAtoms for a bunch of residues

    residues are defined in *res_atom_indices* as lists of atom indices.
    The per-atom counts of a residue are summed up.

    :param pos: Positions, one row per atom
    :param thresholds: Thresholds on absolute distance differences
    :param res_atom_indices: Per residue list of atom indices
    :param ref_indices: Per atom contact indices
    :param ref_distances: Per atom contact distances
    :returns: numpy matrix of type int32 and shape
              (n_residues, len(thresholds)).
    """
    conserved = np.zeros((len(res_atom_indices), len(thresholds)),
                         dtype=np.int32)
    for res_row, atom_indices in enumerate(res_atom_indices):
        per_atom = self._EvalAtoms(pos, atom_indices, thresholds,
                                   ref_indices, ref_distances)
        conserved[res_row, :] = np.sum(per_atom, axis=0)
    return conserved
1260 
1261  def _ProcessSequenceSeparation(self):
1262  if self.sequence_separationsequence_separation != 0:
1263  raise NotImplementedError("Congratulations! You're the first one "
1264  "requesting a non-default "
1265  "sequence_separation in the new and "
1266  "awesome lDDT implementation. A crate of "
1267  "beer for Gabriel and he'll implement "
1268  "it.")
1269 
1270  def _GetNExp(self, atom_idx, ref_indices):
1271  """Returns number of close atoms around one or several atoms
1272  """
1273  if isinstance(atom_idx, int):
1274  return len(ref_indices[atom_idx])
1275  elif isinstance(atom_idx, list):
1276  return sum([len(ref_indices[idx]) for idx in atom_idx])
1277  else:
1278  raise RuntimeError("invalid input type")
1279 
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
                       sym_ref_distances):
    """Swaps symmetric positions in-place in order to maximize lDDT scores
    towards non-symmetric atoms.

    :param pos: Positions, one row per atom - rows of symmetric atoms
                potentially get swapped
    :param thresholds: Thresholds on absolute distance differences
    :param symmetries: Each element is a list of index pairs that are
                       swapped together
    :param sym_ref_indices: Per atom contact indices towards non-symmetric
                            atoms
    :param sym_ref_distances: Per atom contact distances towards
                              non-symmetric atoms
    """
    for sym in symmetries:

        atom_indices = list()
        for sym_tuple in sym:
            atom_indices += [sym_tuple[0], sym_tuple[1]]
        tot = self._GetNExp(atom_indices, sym_ref_indices)

        if tot == 0:
            continue  # nothing to do

        # score as is
        sym_one_conserved = self._EvalAtoms(
            pos,
            atom_indices,
            thresholds,
            sym_ref_indices,
            sym_ref_distances,
        )

        # switch positions and score again
        for pair in sym:
            pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]

        sym_two_conserved = self._EvalAtoms(
            pos,
            atom_indices,
            thresholds,
            sym_ref_indices,
            sym_ref_distances,
        )

        sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
        sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)

        if sym_one_score >= sym_two_score:
            # switch back, initial positions were better or equal
            # for the equal case: we still switch back to reproduce the old
            # lDDT behaviour
            for pair in sym:
                pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
def __init__(self, atom_names)
Definition: lddt.py:29
def AddSymmetricCompound(self, name, symmetric_atoms)
Definition: lddt.py:64
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices, ref_distances)
Definition: lddt.py:1248
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
Definition: lddt.py:916
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, thresholds=[0.5, 1.0, 2.0, 4.0], check_resnames=True)
Definition: lddt.py:666
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
Definition: lddt.py:771
def _ProcessSequenceSeparation(self)
Definition: lddt.py:1261
def sym_ref_distances(self)
Definition: lddt.py:341
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
Definition: lddt.py:1281
def ref_distances_ic(self)
Definition: lddt.py:399
def _GetTargetResidueNumbers(self, target, seqres_mapping)
Definition: lddt.py:872
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1216
def sym_ref_distances_ic(self)
Definition: lddt.py:419
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False)
Definition: lddt.py:433
def sym_ref_distances_sc(self)
Definition: lddt.py:379
def sym_ref_indices_ic(self)
Definition: lddt.py:409
def GetNChainContacts(self, target_chain, no_interchain=False)
Definition: lddt.py:641
def sym_ref_indices_sc(self)
Definition: lddt.py:369
def ref_distances_sc(self)
Definition: lddt.py:359
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
Definition: lddt.py:960
def _GetNExp(self, atom_idx, ref_indices)
Definition: lddt.py:1270
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
Definition: lddt.py:205
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
Definition: lddt.py:823
def _GetExtraModelChainPenalty(self, model, chain_mapping)
Definition: lddt.py:754
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1234
def GetDefaultSymmetrySettings()
Definition: lddt.py:82
def cdist(p1, p2)
Definition: lddt.py:11