OpenStructure
contact_score.py
Go to the documentation of this file.
1 import itertools
2 import numpy as np
3 
4 import time
5 from ost import mol
6 from ost import geom
7 from ost import io
8 
class ContactEntity:
    """ Helper object for Contact-score computation

    Wraps a structure and lazily derives interchain contacts and interface
    residues from it. Results are cached on first access.

    :param ent: Structure to process. Must provide ``residues``,
                ``CreateFullView()`` and ``Select()`` (presumably an
                :class:`ost.mol.EntityHandle` or
                :class:`ost.mol.EntityView` - confirm against callers).
    :param contact_d: Two residues are in contact if any pair of their
                      considered atoms is within this distance.
    :type contact_d: :class:`float`
    :param contact_mode: "aa" to consider all atoms, "repr" to only consider
                         representative atoms (CB for peptide residues, CA
                         for GLY, C3' for nucleotide residues).
    :type contact_mode: :class:`str`
    :raises: :class:`RuntimeError` if *contact_mode* is invalid or, in
             "repr" mode, if a residue is neither peptide nor nucleotide
             linking or lacks its representative atom.
    """
    def __init__(self, ent, contact_d = 5.0, contact_mode="aa"):

        if contact_mode not in ["aa", "repr"]:
            raise RuntimeError("contact_mode must be in [\"aa\", \"repr\"]")

        if contact_mode == "repr":
            # fail early: every residue needs a representative atom,
            # otherwise it would silently vanish from the selection below
            for r in ent.residues:
                repr_at = None
                if r.IsPeptideLinking():
                    cb = r.FindAtom("CB")
                    if cb.IsValid():
                        repr_at = cb
                    elif r.GetName() == "GLY":
                        ca = r.FindAtom("CA")
                        if ca.IsValid():
                            repr_at = ca
                elif r.IsNucleotideLinking():
                    c3 = r.FindAtom("C3'")
                    if c3.IsValid():
                        repr_at = c3
                else:
                    raise RuntimeError(f"Only support peptide and nucleotide "
                                       f"residues in \"repr\" contact mode. "
                                       f"Problematic residue: {r}")
                if repr_at is None:
                    raise RuntimeError(f"Residue {r} has no required "
                                       f"representative atom (CB for peptide "
                                       f"residues (CA for GLY) C3' for "
                                       f"nucleotide residues.")

        self._contact_mode = contact_mode

        if self.contact_mode == "aa":
            self._view = ent.CreateFullView()
        elif self.contact_mode == "repr":
            pep_query = "(peptide=true and (aname=\"CB\" or (rname=\"GLY\" and aname=\"CA\")))"
            nuc_query = "(nucleotide=True and aname=\"C3'\")"
            self._view = ent.Select(" or ".join([pep_query, nuc_query]))
        self._contact_d = contact_d

        # the following attributes will be lazily evaluated
        self._chain_names = None
        self._interacting_chains = None
        self._sequence = dict()
        self._contacts = None
        self._hr_contacts = None
        self._interface_residues = None
        self._hr_interface_residues = None

    @property
    def view(self):
        """ The structure depending on *contact_mode*

        Full view in case of "aa", view that only contains representative
        atoms in case of "repr".

        :type: :class:`ost.mol.EntityView`
        """
        return self._view

    @property
    def contact_mode(self):
        """ The contact mode

        Can either be "aa", meaning that all atoms are considered to identify
        contacts, or "repr" which only considers distances between
        representative atoms. For peptides that is CB (CA for GLY), for
        nucleotides that is C3'.

        :type: :class:`str`
        """
        return self._contact_mode

    @property
    def contact_d(self):
        """ Pairwise distance of residues to be considered as contacts

        Given at :class:`ContactEntity` construction

        :type: :class:`float`
        """
        return self._contact_d

    @property
    def chain_names(self):
        """ Chain names in :attr:`~view`

        Names are sorted

        :type: :class:`list` of :class:`str`
        """
        if self._chain_names is None:
            self._chain_names = sorted([ch.name for ch in self.view.chains])
        return self._chain_names

    @property
    def interacting_chains(self):
        """ Pairs of chains in :attr:`~view` with at least one contact

        :type: :class:`list` of :class:`tuples`
        """
        if self._interacting_chains is None:
            self._interacting_chains = list(self.contacts.keys())
        return self._interacting_chains

    @property
    def contacts(self):
        """ Interchain contacts

        Organized as :class:`dict` with key (cname1, cname2) and values being
        a set of tuples with the respective residue indices.
        cname1 < cname2 evaluates to True.
        """
        if self._contacts is None:
            self._SetupContacts()
        return self._contacts

    @property
    def hr_contacts(self):
        """ Human readable interchain contacts

        Human readable version of :attr:`~contacts`. Simple list with tuples
        containing two strings specifying the residues in contact. Format:
        <cname>.<rnum>.<ins_code>
        """
        if self._hr_contacts is None:
            self._SetupContacts()
        return self._hr_contacts

    @property
    def interface_residues(self):
        """ Interface residues

        Residues in each chain that are in contact with any other chain.
        Organized as :class:`dict` with key cname and values the respective
        residue indices in a :class:`set`.
        """
        if self._interface_residues is None:
            self._SetupInterfaceResidues()
        return self._interface_residues

    @property
    def hr_interface_residues(self):
        """ Human readable interface residues

        Human readable version of :attr:`interface_residues`. :class:`list` of
        strings specifying the interface residues in format:
        <cname>.<rnum>.<ins_code>
        The order of the list is not defined.
        """
        # must check the human readable cache itself; checking
        # _interface_residues would return None whenever
        # interface_residues has been evaluated before
        if self._hr_interface_residues is None:
            self._SetupHRInterfaceResidues()
        return self._hr_interface_residues

    def GetChain(self, chain_name):
        """ Get chain by name

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        :raises: :class:`RuntimeError` if :attr:`~view` has no such chain
        """
        chain = self.view.FindChain(chain_name)
        if not chain.IsValid():
            raise RuntimeError(f"view has no chain named \"{chain_name}\"")
        return chain

    def GetSequence(self, chain_name):
        """ Get sequence of chain

        Returns sequence of specified chain as raw :class:`str`. The result
        is cached per chain name.

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        """
        if chain_name not in self._sequence:
            ch = self.GetChain(chain_name)
            s = ''.join([r.one_letter_code for r in ch.residues])
            self._sequence[chain_name] = s
        return self._sequence[chain_name]

    def _SetupContacts(self):
        """ Fills the caches behind :attr:`contacts` and :attr:`hr_contacts`

        Uses axis aligned bounding boxes per chain and per residue to
        discard pairs that cannot be in contact before computing actual
        atom-atom distances.

        :raises: :class:`RuntimeError` if a chain has no atoms
        """
        self._contacts = dict()
        self._hr_contacts = list()

        # tag each residue with its index within its chain in self.view
        for ch in self.view.chains:
            for r_idx, r in enumerate(ch.residues):
                r.SetIntProp("contact_idx", r_idx)

        # all lists below are indexed like self.chain_names
        residue_lists = list()
        min_res_x = list()
        min_res_y = list()
        min_res_z = list()
        max_res_x = list()
        max_res_y = list()
        max_res_z = list()
        per_res_pos = list()
        min_chain_pos = list()
        max_chain_pos = list()

        for cname in self.chain_names:
            ch = self.view.FindChain(cname)
            if ch.GetAtomCount() == 0:
                raise RuntimeError(f"Chain without atoms observed: \"{cname}\"")
            residue_lists.append([r for r in ch.residues])
            res_pos = list()
            for r in residue_lists[-1]:
                # (n_atoms, 3) array with atom positions of this residue
                pos = np.zeros((r.GetAtomCount(), 3))
                for at_idx, at in enumerate(r.atoms):
                    p = at.GetPos()
                    pos[(at_idx, 0)] = p[0]
                    pos[(at_idx, 1)] = p[1]
                    pos[(at_idx, 2)] = p[2]
                res_pos.append(pos)
            # per residue bounding boxes, one row per residue
            min_res_pos = np.vstack([p.min(0) for p in res_pos])
            max_res_pos = np.vstack([p.max(0) for p in res_pos])
            min_res_x.append(min_res_pos[:, 0])
            min_res_y.append(min_res_pos[:, 1])
            min_res_z.append(min_res_pos[:, 2])
            max_res_x.append(max_res_pos[:, 0])
            max_res_y.append(max_res_pos[:, 1])
            max_res_z.append(max_res_pos[:, 2])
            # per chain bounding box
            min_chain_pos.append(min_res_pos.min(0))
            max_chain_pos.append(max_res_pos.max(0))
            per_res_pos.append(res_pos)

        # operate on squared contact_d (scd) to save some square roots
        scd = self.contact_d * self.contact_d

        for ch1_idx in range(len(self.chain_names)):
            for ch2_idx in range(ch1_idx + 1, len(self.chain_names)):
                # chains which fulfill the following expressions have no
                # contact within self.contact_d
                if np.max(min_chain_pos[ch1_idx] - max_chain_pos[ch2_idx]) > self.contact_d:
                    continue
                if np.max(min_chain_pos[ch2_idx] - max_chain_pos[ch1_idx]) > self.contact_d:
                    continue

                # same thing for residue positions but all at once
                skip_one = np.subtract.outer(min_res_x[ch1_idx], max_res_x[ch2_idx]) > self.contact_d
                skip_one = np.logical_or(skip_one, np.subtract.outer(min_res_y[ch1_idx], max_res_y[ch2_idx]) > self.contact_d)
                skip_one = np.logical_or(skip_one, np.subtract.outer(min_res_z[ch1_idx], max_res_z[ch2_idx]) > self.contact_d)
                skip_two = np.subtract.outer(min_res_x[ch2_idx], max_res_x[ch1_idx]) > self.contact_d
                skip_two = np.logical_or(skip_two, np.subtract.outer(min_res_y[ch2_idx], max_res_y[ch1_idx]) > self.contact_d)
                skip_two = np.logical_or(skip_two, np.subtract.outer(min_res_z[ch2_idx], max_res_z[ch1_idx]) > self.contact_d)
                # skip_two is (n_res_ch2, n_res_ch1), transpose to match
                skip = np.logical_or(skip_one, skip_two.T)

                # identify residue pairs for which we cannot exclude a contact
                r1_indices, r2_indices = np.nonzero(np.logical_not(skip))
                ch1_per_res_pos = per_res_pos[ch1_idx]
                ch2_per_res_pos = per_res_pos[ch2_idx]
                for r1_idx, r2_idx in zip(r1_indices, r2_indices):
                    # compute pairwise squared distances using
                    # |x - y|^2 = |x|^2 - 2xy + |y|^2
                    p1 = ch1_per_res_pos[r1_idx]
                    p2 = ch2_per_res_pos[r2_idx]
                    x2 = np.sum(p1**2, axis=1) # (m)
                    y2 = np.sum(p2**2, axis=1) # (n)
                    xy = np.matmul(p1, p2.T) # (m, n)
                    x2 = x2.reshape(-1, 1)
                    squared_distances = x2 - 2*xy + y2 # (m, n)
                    if np.min(squared_distances) <= scd:
                        # its a contact!
                        r1 = residue_lists[ch1_idx][r1_idx]
                        r2 = residue_lists[ch2_idx][r2_idx]
                        cname_key = (self.chain_names[ch1_idx], self.chain_names[ch2_idx])
                        if cname_key not in self._contacts:
                            self._contacts[cname_key] = set()
                        self._contacts[cname_key].add((r1.GetIntProp("contact_idx"),
                                                       r2.GetIntProp("contact_idx")))
                        rnum1 = r1.GetNumber()
                        hr1 = f"{self.chain_names[ch1_idx]}.{rnum1.num}.{rnum1.ins_code}"
                        rnum2 = r2.GetNumber()
                        hr2 = f"{self.chain_names[ch2_idx]}.{rnum2.num}.{rnum2.ins_code}"
                        # strip null characters at the string ends
                        # (presumably an unset insertion code - confirm)
                        self._hr_contacts.append((hr1.strip("\u0000"),
                                                  hr2.strip("\u0000")))

    def _SetupInterfaceResidues(self):
        """ Fills the cache behind :attr:`interface_residues`

        Every chain gets an entry, chains without contacts an empty set.
        """
        self._interface_residues = {cname: set() for cname in self.chain_names}
        for k,v in self.contacts.items():
            for item in v:
                self._interface_residues[k[0]].add(item[0])
                self._interface_residues[k[1]].add(item[1])

    def _SetupHRInterfaceResidues(self):
        """ Fills the cache behind :attr:`hr_interface_residues`
        """
        interface_residues = set()
        for item in self.hr_contacts:
            interface_residues.add(item[0])
            interface_residues.add(item[1])
        self._hr_interface_residues = list(interface_residues)
300 
301 
class ContactScorerResultICS:
    """
    Holds data relevant to compute ics

    :param n_trg_contacts: Number of contacts in target
    :type n_trg_contacts: :class:`int`
    :param n_mdl_contacts: Number of contacts in model
    :type n_mdl_contacts: :class:`int`
    :param n_union: Size of the union of target and model contacts
    :type n_union: :class:`int`
    :param n_intersection: Number of contacts present in both, target
                           and model
    :type n_intersection: :class:`int`
    """
    def __init__(self, n_trg_contacts, n_mdl_contacts, n_union, n_intersection):
        self._n_trg_contacts = n_trg_contacts
        self._n_mdl_contacts = n_mdl_contacts
        self._n_union = n_union
        self._n_intersection = n_intersection

    @property
    def n_trg_contacts(self):
        """ Number of contacts in target

        :type: :class:`int`
        """
        return self._n_trg_contacts

    @property
    def n_mdl_contacts(self):
        """ Number of contacts in model

        :type: :class:`int`
        """
        return self._n_mdl_contacts

    @property
    def precision(self):
        """ Precision of model contacts

        The fraction of model contacts that are also present in target.
        0.0 if there are no model contacts.

        :type: :class:`float`
        """
        if self._n_mdl_contacts != 0:
            return self._n_intersection / self._n_mdl_contacts
        else:
            return 0.0

    @property
    def recall(self):
        """ Recall of model contacts

        The fraction of target contacts that are also present in model.
        0.0 if there are no target contacts.

        :type: :class:`float`
        """
        if self._n_trg_contacts != 0:
            return self._n_intersection / self._n_trg_contacts
        else:
            return 0.0

    @property
    def ics(self):
        """ The Interface Contact Similarity score (ICS)

        Combination of :attr:`precision` and :attr:`recall` using the
        F1-measure. 0.0 if both are zero.

        :type: :class:`float`
        """
        p = self.precision
        r = self.recall
        nominator = p*r
        denominator = p + r
        if denominator != 0.0:
            return 2*nominator/denominator
        else:
            return 0.0
370 
class ContactScorerResultIPS:
    """
    Holds data relevant to compute ips

    :param n_trg_int_res: Number of interface residues in target
    :type n_trg_int_res: :class:`int`
    :param n_mdl_int_res: Number of interface residues in model
    :type n_mdl_int_res: :class:`int`
    :param n_union: Size of the union of target and model interface residues
    :type n_union: :class:`int`
    :param n_intersection: Number of interface residues present in both,
                           target and model
    :type n_intersection: :class:`int`
    """
    def __init__(self, n_trg_int_res, n_mdl_int_res, n_union, n_intersection):
        self._n_trg_int_res = n_trg_int_res
        self._n_mdl_int_res = n_mdl_int_res
        self._n_union = n_union
        self._n_intersection = n_intersection

    @property
    def n_trg_int_res(self):
        """ Number of interface residues in target

        :type: :class:`int`
        """
        # the attribute set in __init__ is _n_trg_int_res; _n_trg_contacts
        # does not exist on this class
        return self._n_trg_int_res

    @property
    def n_mdl_int_res(self):
        """ Number of interface residues in model

        :type: :class:`int`
        """
        return self._n_mdl_int_res

    @property
    def precision(self):
        """ Precision of model interface residues

        The fraction of model interface residues that are also interface
        residues in target. 0.0 if there are no model interface residues.

        :type: :class:`float`
        """
        if self._n_mdl_int_res != 0:
            return self._n_intersection / self._n_mdl_int_res
        else:
            return 0.0

    @property
    def recall(self):
        """ Recall of model interface residues

        The fraction of target interface residues that are also interface
        residues in model. 0.0 if there are no target interface residues.

        :type: :class:`float`
        """
        if self._n_trg_int_res != 0:
            return self._n_intersection / self._n_trg_int_res
        else:
            return 0.0

    @property
    def ips(self):
        """ The Interface Patch Similarity score (IPS)

        Jaccard coefficient of interface residues in model/target, i.e. the
        size of the intersection divided by the size of the union given at
        construction. 0.0 if the union is empty.

        :type: :class:`float`
        """
        if(self._n_union > 0):
            return self._n_intersection/self._n_union
        return 0.0
437 
class ContactScorer:
    """ Helper object to compute Contact scores

    Tightly integrated into the mechanisms from the chain_mapping module.
    The prefered way to derive an object of type :class:`ContactScorer` is
    through the static constructor: :func:`~FromMappingResult`.

    Usage is the same as for :class:`ost.mol.alg.QSScorer`

    :param target: Structure used to construct :attr:`~cent1`
    :param chem_groups: Groups of chemically equivalent chains in *target*
    :type chem_groups: :class:`list` of :class:`list` of :class:`str`
    :param model: Structure used to construct :attr:`~cent2`
    :param alns: Alignments, accessed with key ``(t_chain, m_chain)``
    :type alns: :class:`dict`
    :param contact_mode: Passed on to :class:`ContactEntity`
    :type contact_mode: :class:`str`
    :param contact_d: Passed on to :class:`ContactEntity`
    :type contact_d: :class:`float`
    :raises: :class:`RuntimeError` if the chain names in *chem_groups* do
             not exactly match the chain names of the processed *target*
    """

    def __init__(self, target, chem_groups, model, alns,
                 contact_mode="aa", contact_d=5.0):
        self._cent1 = ContactEntity(target, contact_mode = contact_mode,
                                    contact_d = contact_d)
        # ensure that target chain names match the ones in chem_groups
        chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups))
        if self._cent1.chain_names != sorted(chem_group_ch_names):
            raise RuntimeError(f"Expect exact same chain names in chem_groups "
                               f"and in target (which is processed to only "
                               f"contain peptides/nucleotides). target: "
                               f"{self._cent1.chain_names}, chem_groups: "
                               f"{chem_group_ch_names}")

        self._chem_groups = chem_groups
        self._cent2 = ContactEntity(model, contact_mode = contact_mode,
                                    contact_d = contact_d)
        self._alns = alns

        # cache for mapped interface scores
        # key: tuple of tuple ((cent1_ch1, cent1_ch2),
        #                      (cent2_ch1, cent2_ch2))
        # value: tuple with four numbers required for computation of
        #        per-interface scores.
        #        The first two are relevant for ICS, the others for per
        #        interface IPS.
        #        1: n_union_contacts
        #        2: n_intersection_contacts
        #        3: n_union_interface_residues
        #        4: n_intersection_interface_residues
        self._mapped_cache_interface = dict()

        # cache for mapped single chain scores
        # for interface residues of single chains
        # key: tuple: (cent1_ch, cent2_ch)
        # value: tuple with two numbers required for computation of IPS
        #        1: n_union
        #        2: n_intersection
        self._mapped_cache_sc = dict()

    @staticmethod
    def FromMappingResult(mapping_result, contact_mode="aa", contact_d = 5.0):
        """ The preferred way to get a :class:`ContactScorer`

        Static constructor that derives an object of type :class:`ContactScorer`
        using a :class:`ost.mol.alg.chain_mapping.MappingResult`

        :param mapping_result: Data source
        :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult`
        :param contact_mode: Passed on to the :class:`ContactScorer`
                             constructor
        :type contact_mode: :class:`str`
        :param contact_d: Passed on to the :class:`ContactScorer` constructor
        :type contact_d: :class:`float`
        """
        contact_scorer = ContactScorer(mapping_result.target,
                                       mapping_result.chem_groups,
                                       mapping_result.model,
                                       mapping_result.alns,
                                       contact_mode = contact_mode,
                                       contact_d = contact_d)
        return contact_scorer

    @property
    def cent1(self):
        """ Represents *target*

        :type: :class:`ContactEntity`
        """
        return self._cent1

    @property
    def chem_groups(self):
        """ Groups of chemically equivalent chains in *target*

        Provided at object construction

        :type: :class:`list` of :class:`list` of :class:`str`
        """
        return self._chem_groups

    @property
    def cent2(self):
        """ Represents *model*

        :type: :class:`ContactEntity`
        """
        return self._cent2

    @property
    def alns(self):
        """ Alignments between chains in :attr:`~cent1` and :attr:`~cent2`

        Provided at object construction. Each alignment is accessible with
        ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the
        respective chain in :attr:`~cent1`, second sequence the one from
        :attr:`~cent2`.

        :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
               :class:`ost.seq.AlignmentHandle`
        """
        return self._alns

    def ScoreICS(self, mapping, check=True):
        """ Computes ICS given chain mapping

        Again, the preferred way is to get *mapping* is from an object
        of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

        :param mapping: see
                        :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
        :type mapping: :class:`list` of :class:`list` of :class:`str`
        :param check: Perform input checks, can be disabled for speed purposes
                      if you know what you're doing.
        :type check: :class:`bool`
        :returns: Result object of type :class:`ContactScorerResultICS`
        :raises: :class:`RuntimeError` if *check* is True and *mapping* is
                 inconsistent with :attr:`~chem_groups` or :attr:`~cent2`
        """
        flat_mapping = self._FlattenMapping(mapping, check)
        return self.ICSFromFlatMapping(flat_mapping)

    def ScoreICSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
        """ Computes ICS scores only considering one interface

        This only works for interfaces that are computed in :func:`ScoreICS`,
        i.e. interfaces for which the alignments are set up correctly.

        :param trg_ch1: Name of first interface chain in target
        :type trg_ch1: :class:`str`
        :param trg_ch2: Name of second interface chain in target
        :type trg_ch2: :class:`str`
        :param mdl_ch1: Name of first interface chain in model
        :type mdl_ch1: :class:`str`
        :param mdl_ch2: Name of second interface chain in model
        :type mdl_ch2: :class:`str`
        :returns: Result object of type :class:`ContactScorerResultICS`
        :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or
                 trg_ch2/mdl_ch2 is available.
        """
        self._RequireAln("trg_ch1", trg_ch1, "mdl_ch1", mdl_ch1)
        self._RequireAln("trg_ch2", trg_ch2, "mdl_ch2", mdl_ch2)
        trg_int = (trg_ch1, trg_ch2)
        mdl_int = (mdl_ch1, mdl_ch2)

        n_trg = len(self._InterfaceContacts(self.cent1, trg_int))
        n_mdl = len(self._InterfaceContacts(self.cent2, mdl_int))

        n_union, n_intersection, _, _ = self._MappedInterfaceScores(trg_int, mdl_int)
        return ContactScorerResultICS(n_trg, n_mdl, n_union, n_intersection)

    def ICSFromFlatMapping(self, flat_mapping):
        """ Same as :func:`ScoreICS` but with flat mapping

        :param flat_mapping: Dictionary with target chain names as keys and
                             the mapped model chain names as value
        :type flat_mapping: :class:`dict` with :class:`str` as key and value
        :returns: Result object of type :class:`ContactScorerResultICS`
        """
        n_trg = sum([len(x) for x in self.cent1.contacts.values()])
        n_mdl = sum([len(x) for x in self.cent2.contacts.values()])
        n_union = 0
        n_intersection = 0

        processed_cent2_interfaces = set()
        for int1 in self.cent1.interacting_chains:
            if int1[0] in flat_mapping and int1[1] in flat_mapping:
                int2 = (flat_mapping[int1[0]], flat_mapping[int1[1]])
                a, b, _, _ = self._MappedInterfaceScores(int1, int2)
                n_union += a
                n_intersection += b
                # keys in cent2.contacts are sorted chain name pairs
                processed_cent2_interfaces.add((min(int2), max(int2)))

        # process interfaces that only exist in cent2
        r_flat_mapping = {v:k for k,v in flat_mapping.items()} # reverse mapping
        for int2 in self.cent2.interacting_chains:
            if int2 not in processed_cent2_interfaces:
                if int2[0] in r_flat_mapping and int2[1] in r_flat_mapping:
                    int1 = (r_flat_mapping[int2[0]], r_flat_mapping[int2[1]])
                    a, b, _, _ = self._MappedInterfaceScores(int1, int2)
                    n_union += a
                    n_intersection += b

        return ContactScorerResultICS(n_trg, n_mdl,
                                      n_union, n_intersection)

    def ScoreIPS(self, mapping, check=True):
        """ Computes IPS given chain mapping

        Again, the preferred way is to get *mapping* is from an object
        of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

        :param mapping: see
                        :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
        :type mapping: :class:`list` of :class:`list` of :class:`str`
        :param check: Perform input checks, can be disabled for speed purposes
                      if you know what you're doing.
        :type check: :class:`bool`
        :returns: Result object of type :class:`ContactScorerResultIPS`
        :raises: :class:`RuntimeError` if *check* is True and *mapping* is
                 inconsistent with :attr:`~chem_groups` or :attr:`~cent2`
        """
        flat_mapping = self._FlattenMapping(mapping, check)
        return self.IPSFromFlatMapping(flat_mapping)

    def ScoreIPSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
        """ Computes IPS scores only considering one interface

        This only works for interfaces that are computed in :func:`ScoreIPS`,
        i.e. interfaces for which the alignments are set up correctly.

        :param trg_ch1: Name of first interface chain in target
        :type trg_ch1: :class:`str`
        :param trg_ch2: Name of second interface chain in target
        :type trg_ch2: :class:`str`
        :param mdl_ch1: Name of first interface chain in model
        :type mdl_ch1: :class:`str`
        :param mdl_ch2: Name of second interface chain in model
        :type mdl_ch2: :class:`str`
        :returns: Result object of type :class:`ContactScorerResultIPS`
        :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or
                 trg_ch2/mdl_ch2 is available.
        """
        self._RequireAln("trg_ch1", trg_ch1, "mdl_ch1", mdl_ch1)
        self._RequireAln("trg_ch2", trg_ch2, "mdl_ch2", mdl_ch2)
        trg_int = (trg_ch1, trg_ch2)
        mdl_int = (mdl_ch1, mdl_ch2)

        # the result object expects numbers of interface residues, not
        # numbers of contacts
        trg_contacts = self._InterfaceContacts(self.cent1, trg_int)
        mdl_contacts = self._InterfaceContacts(self.cent2, mdl_int)
        n_trg = self._CountInterfaceResidues(trg_contacts)
        n_mdl = self._CountInterfaceResidues(mdl_contacts)

        _, _, n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int)
        return ContactScorerResultIPS(n_trg, n_mdl, n_union, n_intersection)

    def IPSFromFlatMapping(self, flat_mapping):
        """ Same as :func:`ScoreIPS` but with flat mapping

        :param flat_mapping: Dictionary with target chain names as keys and
                             the mapped model chain names as value
        :type flat_mapping: :class:`dict` with :class:`str` as key and value
        :returns: Result object of type :class:`ContactScorerResultIPS`
        """
        n_trg = sum([len(x) for x in self.cent1.interface_residues.values()])
        n_mdl = sum([len(x) for x in self.cent2.interface_residues.values()])
        n_union = 0
        n_intersection = 0

        processed_cent2_chains = set()
        for trg_ch in self.cent1.chain_names:
            if trg_ch in flat_mapping:
                a, b = self._MappedSCScores(trg_ch, flat_mapping[trg_ch])
                n_union += a
                n_intersection += b
                processed_cent2_chains.add(flat_mapping[trg_ch])
            else:
                # unmapped target chain: only contributes to the union
                n_union += len(self.cent1.interface_residues[trg_ch])

        # unmapped model chains: only contribute to the union
        for mdl_ch in self._cent2.chain_names:
            if mdl_ch not in processed_cent2_chains:
                n_union += len(self.cent2.interface_residues[mdl_ch])

        return ContactScorerResultIPS(n_trg, n_mdl,
                                      n_union, n_intersection)

    def _FlattenMapping(self, mapping, check):
        """ Optionally validates *mapping* and flattens it into a dict

        Returns a dict with target chain names from :attr:`~chem_groups` as
        keys and the respective model chain names as values. Entries mapped
        to None are dropped.
        """
        if check:
            # ensure that dimensionality of mapping matches self.chem_groups
            if len(self.chem_groups) != len(mapping):
                raise RuntimeError("Dimensions of self.chem_groups and mapping "
                                   "must match")
            for a,b in zip(self.chem_groups, mapping):
                if len(a) != len(b):
                    raise RuntimeError("Dimensions of self.chem_groups and "
                                       "mapping must match")
            # ensure that chain names in mapping are all present in cent2
            for name in itertools.chain.from_iterable(mapping):
                if name is not None and name not in self.cent2.chain_names:
                    raise RuntimeError(f"Each chain in mapping must be present "
                                       f"in self.cent2. No match for "
                                       f"\"{name}\"")

        flat_mapping = dict()
        for a, b in zip(self.chem_groups, mapping):
            flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})
        return flat_mapping

    def _RequireAln(self, trg_label, trg_ch, mdl_label, mdl_ch):
        """ Raises RuntimeError if :attr:`~alns` has no (trg_ch, mdl_ch) entry

        The labels are the parameter names reported in the error message.
        """
        if (trg_ch, mdl_ch) not in self.alns:
            raise RuntimeError(f"No aln between {trg_label} ({trg_ch}) and "
                               f"{mdl_label} ({mdl_ch}) available. Did you "
                               f"construct the ContactScorer object from a "
                               f"MappingResult and are {trg_label} and "
                               f"{mdl_label} mapped to each other?")

    @staticmethod
    def _InterfaceContacts(cent, interface):
        """ Contacts of *interface* in *cent*, regardless of chain order

        Returns the contact set stored in ``cent.contacts`` for the chain
        name pair or its reverse, an empty set if there is none. Tuples in
        the returned set are NOT reordered.
        """
        contacts = cent.contacts
        if interface in contacts:
            return contacts[interface]
        reversed_interface = (interface[1], interface[0])
        if reversed_interface in contacts:
            return contacts[reversed_interface]
        return set()

    @staticmethod
    def _CountInterfaceResidues(contacts):
        """ Number of distinct residues on both sides of a contact set
        """
        return (len({c[0] for c in contacts}) +
                len({c[1] for c in contacts}))

    def _MappedInterfaceScores(self, int1, int2):
        """ Cached version of :func:`_InterfaceScores`

        A cached result is also used if both interfaces have been processed
        with swapped chain order.
        """
        key_one = (int1, int2)
        if key_one in self._mapped_cache_interface:
            return self._mapped_cache_interface[key_one]
        key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
        if key_two in self._mapped_cache_interface:
            return self._mapped_cache_interface[key_two]

        a, b, c, d = self._InterfaceScores(int1, int2)
        self._mapped_cache_interface[key_one] = (a, b, c, d)
        return (a, b, c, d)

    def _InterfaceScores(self, int1, int2):
        """ Union/intersection counts for one mapped interface

        :param int1: Chain name pair in :attr:`~cent1`
        :param int2: Chain name pair in :attr:`~cent2`, int2[i] is the
                     chain mapped to int1[i]
        :returns: :class:`tuple` (contact_union, contact_intersection,
                  intres_union, intres_intersection)
        """
        if int1 in self.cent1.contacts:
            ref_contacts = self.cent1.contacts[int1]
        elif (int1[1], int1[0]) in self.cent1.contacts:
            ref_contacts = self.cent1.contacts[(int1[1], int1[0])]
            # need to reverse contacts
            ref_contacts = set([(x[1], x[0]) for x in ref_contacts])
        else:
            ref_contacts = set() # no contacts at all

        if int2 in self.cent2.contacts:
            mdl_contacts = self.cent2.contacts[int2]
        elif (int2[1], int2[0]) in self.cent2.contacts:
            mdl_contacts = self.cent2.contacts[(int2[1], int2[0])]
            # need to reverse contacts
            mdl_contacts = set([(x[1], x[0]) for x in mdl_contacts])
        else:
            mdl_contacts = set() # no contacts at all

        # indices in contacts lists are specific to the respective
        # structures, need manual mapping from alignments
        ch1_aln = self.alns[(int1[0], int2[0])]
        ch2_aln = self.alns[(int1[1], int2[1])]
        mapped_ref_contacts = set()
        mapped_mdl_contacts = set()
        for c in ref_contacts:
            mapped_c = (ch1_aln.GetPos(0, c[0]), ch2_aln.GetPos(0, c[1]))
            mapped_ref_contacts.add(mapped_c)
        for c in mdl_contacts:
            mapped_c = (ch1_aln.GetPos(1, c[0]), ch2_aln.GetPos(1, c[1]))
            mapped_mdl_contacts.add(mapped_c)

        contact_union = len(mapped_ref_contacts.union(mapped_mdl_contacts))
        contact_intersection = len(mapped_ref_contacts.intersection(mapped_mdl_contacts))

        # above, we computed the union and intersection on actual
        # contacts. Here, we do the same on interface residues

        # process interface residues of chain one in interface
        tmp_ref = set([x[0] for x in mapped_ref_contacts])
        tmp_mdl = set([x[0] for x in mapped_mdl_contacts])
        intres_union = len(tmp_ref.union(tmp_mdl))
        intres_intersection = len(tmp_ref.intersection(tmp_mdl))

        # process interface residues of chain two in interface
        tmp_ref = set([x[1] for x in mapped_ref_contacts])
        tmp_mdl = set([x[1] for x in mapped_mdl_contacts])
        intres_union += len(tmp_ref.union(tmp_mdl))
        intres_intersection += len(tmp_ref.intersection(tmp_mdl))

        return (contact_union, contact_intersection,
                intres_union, intres_intersection)

    def _MappedSCScores(self, ref_ch, mdl_ch):
        """ Cached version of :func:`_SCScores`
        """
        if (ref_ch, mdl_ch) in self._mapped_cache_sc:
            return self._mapped_cache_sc[(ref_ch, mdl_ch)]
        n_union, n_intersection = self._SCScores(ref_ch, mdl_ch)
        self._mapped_cache_sc[(ref_ch, mdl_ch)] = (n_union, n_intersection)
        return (n_union, n_intersection)

    def _SCScores(self, ch1, ch2):
        """ Union/intersection counts of interface residues for mapped chains

        :param ch1: Chain name in :attr:`~cent1`
        :param ch2: Chain name in :attr:`~cent2`
        :returns: :class:`tuple` (n_union, n_intersection)
        """
        ref_int_res = self.cent1.interface_residues[ch1]
        mdl_int_res = self.cent2.interface_residues[ch2]
        aln = self.alns[(ch1, ch2)]
        # residue indices are structure specific, compare them as
        # alignment positions
        mapped_ref_int_res = set()
        mapped_mdl_int_res = set()
        for r_idx in ref_int_res:
            mapped_ref_int_res.add(aln.GetPos(0, r_idx))
        for r_idx in mdl_int_res:
            mapped_mdl_int_res.add(aln.GetPos(1, r_idx))
        return(len(mapped_ref_int_res.union(mapped_mdl_int_res)),
               len(mapped_ref_int_res.intersection(mapped_mdl_int_res)))
874 
# specify public interface: the names exported by "from <module> import *"
__all__ = ('ContactEntity', 'ContactScorerResultICS', 'ContactScorerResultIPS', 'ContactScorer')
def __init__(self, ent, contact_d=5.0, contact_mode="aa")
def ScoreICS(self, mapping, check=True)
def __init__(self, target, chem_groups, model, alns, contact_mode="aa", contact_d=5.0)
def ScoreIPS(self, mapping, check=True)
def FromMappingResult(mapping_result, contact_mode="aa", contact_d=5.0)
def ScoreICSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2)
def ScoreIPSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2)
def _MappedSCScores(self, ref_ch, mdl_ch)
def _MappedInterfaceScores(self, int1, int2)
def IPSFromFlatMapping(self, flat_mapping)
def ICSFromFlatMapping(self, flat_mapping)
def __init__(self, n_trg_contacts, n_mdl_contacts, n_union, n_intersection)
def __init__(self, n_trg_int_res, n_mdl_int_res, n_union, n_intersection)