9 from scipy.spatial.distance
import cdist
12 x2 = np.sum(p1**2, axis=1)
13 y2 = np.sum(p2**2, axis=1)
14 xy = np.matmul(p1, p2.T)
15 x2 = x2.reshape(-1, 1)
16 return np.sqrt(x2 - 2*xy + y2)
19 """ Defines atoms for custom compounds
21 lDDT requires the reference atoms of a compound which are typically
22 extracted from a :class:`ost.conop.CompoundLib`. This lightweight
23 container allows handling arbitrary compounds which are not
24 necessarily in the compound library.
26 :param atom_names: Names of atoms of custom compound
27 :type atom_names: :class:`list` of :class:`str`
34 """ Construct custom compound from residue
36 :param res: Residue from which reference atom names are extracted,
37 hydrogen/deuterium atoms are filtered out
38 :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
39 :returns: :class:`CustomCompound`
41 at_names = [a.name
for a
in res.atoms
if a.element
not in [
"H",
"D"]]
42 if len(at_names) != len(set(at_names)):
43 raise RuntimeError(
"Duplicate atoms detected in CustomCompound")
48 """Container for symmetric compounds
50 lDDT considers symmetries and selects the one resulting in the highest
53 A symmetry is defined as a renaming operation on one or more atoms that
54 leads to a chemically equivalent residue. Example would be OD1 and OD2 in
55 ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
58 Use :func:`AddSymmetricCompound` to define a symmetry which can then
59 directly be accessed through the *symmetric_compounds* member.
65 """Adds symmetry for compound with *name*
67 :param name: Name of compound with symmetry
68 :type name: :class:`str`
69 :param symmetric_atoms: Pairs of atom names that define renaming
70 operation, i.e. after applying all switches
71 defined in the tuples, the resulting residue
72 should be chemically equivalent. Atom names
73 must refer to the PDB component dictionary.
74 :type symmetric_atoms: :class:`list` of :class:`tuple`
76 for pair
in symmetric_atoms:
78 raise RuntimeError(
"Expect pairs when defining symmetries")
83 """Constructs and returns :class:`SymmetrySettings` object for natural amino
89 symmetry_settings.AddSymmetricCompound(
"ASP", [(
"OD1",
"OD2")])
92 symmetry_settings.AddSymmetricCompound(
"GLU", [(
"OE1",
"OE2")])
95 symmetry_settings.AddSymmetricCompound(
"LEU", [(
"CD1",
"CD2")])
98 symmetry_settings.AddSymmetricCompound(
"VAL", [(
"CG1",
"CG2")])
101 symmetry_settings.AddSymmetricCompound(
"ARG", [(
"NH1",
"NH2")])
104 symmetry_settings.AddSymmetricCompound(
105 "PHE", [(
"CD1",
"CD2"), (
"CE1",
"CE2")]
109 symmetry_settings.AddSymmetricCompound(
110 "TYR", [(
"CD1",
"CD2"), (
"CE1",
"CE2")]
113 return symmetry_settings
117 """lDDT scorer object for a specific target
119 Sets up everything to score models of that target. lDDT (local distance
120 difference test) is defined as fraction of pairwise distances which exhibit
121 a difference < threshold when considering target and model. In case of
122 multiple thresholds, the average is returned. See
124 V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
125 superposition-free score for comparing protein structures and models using
126 distance difference tests, Bioinformatics, 2013
128 :param target: The target
129 :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
130 :param compound_lib: Compound library from which a compound for each residue
131 is extracted based on its name. Uses
132 :func:`ost.conop.GetDefaultLib` if not given, raises
133 if this returns no valid compound library. Atoms
134 defined in the compound are searched in the residue and
135 build the reference for scoring. If the residue has
136 atoms with names ["A", "B", "C"] but the corresponding
137 compound only has ["A", "B"], "A" and "B" are
138 considered for scoring. If the residue has atoms
139 ["A", "B"] but the compound has ["A", "B", "C"], "C" is
140 considered missing and does not influence scoring, even
141 if present in the model.
142 :param custom_compounds: Custom compounds defining reference atoms. If
143 given, *custom_compounds* take precedence over
145 :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
146 key and :class:`CustomCompound` as value.
147 :type compound_lib: :class:`ost.conop.CompoundLib`
148 :param inclusion_radius: All pairwise distances < *inclusion_radius* are
149 considered for scoring
150 :type inclusion_radius: :class:`float`
151 :param sequence_separation: Only pairwise distances between atoms of
152 residues which are further apart than this
153 threshold are considered. Residue distance is
154 based on resnum. The default (0) considers all
155 pairwise distances except intra-residue
157 :type sequence_separation: :class:`int`
158 :param symmetry_settings: Define residues exhibiting internal symmetry, uses
159 :func:`GetDefaultSymmetrySettings` if not given.
160 :type symmetry_settings: :class:`SymmetrySettings`
161 :param seqres_mapping: Mapping of model residues at the scoring stage
162 happens with residue numbers defining their location
163 in a reference sequence (SEQRES) using one based
164 indexing. If the residue numbers in *target* don't
165 correspond to that SEQRES, you can specify the
166 mapping manually. You can provide a dictionary to
167 specify a reference sequence (SEQRES) for one or more
168 chain(s). Key: chain name, value: alignment
169 (seq1: SEQRES, seq2: sequence of residues in chain).
170 Example: The residues in a chain with name "A" have
171 sequence "YEAH" and residue numbers [42,43,44,45].
172 You can provide an alignment with seq1 "``HELLYEAH``"
173 and seq2 "``----YEAH``". "Y" gets assigned residue
174 number 5, "E" gets assigned 6 and so on no matter
175 what the original residue numbers were.
176 :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
177 :class:`ost.seq.AlignmentHandle`)
178 :param bb_only: Only consider atoms with name "CA" in case of amino acids and
179 "C3'" for nucleotides. This invalidates *compound_lib*.
180 Raises if any residue in *target* is not
181 `r.chem_class.IsPeptideLinking()` or
182 `r.chem_class.IsNucleotideLinking()`
183 :type bb_only: :class:`bool`
184 :raises: :class:`RuntimeError` if *target* contains compound which is not in
185 *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
186 specifies symmetric atoms that are not present in the according
187 compound in *compound_lib*, :class:`RuntimeError` if
188 *seqres_mapping* is not provided and *target* contains residue
189 numbers with insertion codes or the residue numbers for each chain
190 are not monotonically increasing, :class:`RuntimeError` if
191 *seqres_mapping* is provided but an alignment is invalid
192 (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
193 residues in corresponding chains).
199 custom_compounds=None,
201 sequence_separation=0,
202 symmetry_settings=None,
203 seqres_mapping=dict(),
210 if compound_lib
is None:
211 compound_lib = conop.GetDefaultLib()
212 if compound_lib
is None:
213 raise RuntimeError(
"No compound_lib given and conop.GetDefaultLib "
214 "returns no valid compound library")
217 if symmetry_settings
is None:
318 lDDTScorer._SetupDistances(self.
targettarget, self.
n_atomsn_atoms,
327 lDDTScorer._SetupDistances(self.
targettarget, self.
n_atomsn_atoms,
352 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
362 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
372 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
382 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
392 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
402 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
412 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
422 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
428 def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
429 local_lddt_prop=None, local_contact_prop=None,
430 chain_mapping=None, no_interchain=False,
431 no_intrachain=False, penalize_extra_chains=False,
432 residue_mapping=None, return_dist_test=False,
433 check_resnames=True, add_mdl_contacts=False):
434 """Computes lDDT of *model* - globally and per-residue
436 :param model: Model to be scored - models are preferably scored upon
437 performing stereo-chemistry checks in order to punish for
438 non-sensical irregularities. This must be done separately
439 as a pre-processing step. Target contacts that are not
440 covered by *model* are considered not conserved, thus
441 decreasing lDDT score. This also includes missing model
442 chains or model chains for which no mapping is provided in
444 :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
445 :param thresholds: Thresholds of distance differences to be considered
446 as correct - see docs in constructor for more info.
447 default: [0.5, 1.0, 2.0, 4.0]
448 :type thresholds: :class:`list` of :class:`float`
449 :param local_lddt_prop: If set, per-residue scores will be assigned as
450 generic float property of that name
451 :type local_lddt_prop: :class:`str`
452 :param local_contact_prop: If set, number of expected contacts as well
453 as number of conserved contacts will be
454 assigned as generic int property.
455 Expected contacts will be set as
456 <local_contact_prop>_exp, conserved contacts
457 as <local_contact_prop>_cons. Values
458 are summed over all thresholds.
459 :type local_contact_prop: :class:`str`
460 :param chain_mapping: Mapping of model chains (key) onto target chains
461 (value). This is required if target or model have
463 :type chain_mapping: :class:`dict` with :class:`str` as keys/values
464 :param no_interchain: Whether to exclude interchain contacts
465 :type no_interchain: :class:`bool`
466 :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
467 consider interface related contacts)
468 :type no_intrachain: :class:`bool`
469 :param penalize_extra_chains: Whether to include a fixed penalty for
470 additional chains in the model that are
471 not mapped to the target. ONLY AFFECTS
472 RETURNED GLOBAL SCORE. In detail: adds the
473 number of intra-chain contacts of each
474 extra chain to the expected contacts, thus
476 :type penalize_extra_chains: :class:`bool`
477 :param residue_mapping: By default, residue mapping is based on residue
478 numbers. That means, a model chain and the
479 respective target chain map to the same
480 underlying reference sequence (SEQRES).
481 Alternatively, you can specify one or
482 several alignment(s) between model and target
483 chains by providing a dictionary. key: Name
484 of chain in model (respective target chain is
485 extracted from *chain_mapping*),
486 value: Alignment with first sequence
487 corresponding to target chain and second
488 sequence to model chain. There is NO reference
489 sequence involved, so the two sequences MUST
490 exactly match the actual residues observed in
491 the respective target/model chains (ATOMSEQ).
492 :type residue_mapping: :class:`dict` with key: :class:`str`,
493 value: :class:`ost.seq.AlignmentHandle`
494 :param return_dist_test: Whether to additionally return the underlying
495 per-residue data for the distance difference
496 test. Adds five objects to the return tuple.
497 First: Number of total contacts summed over all
499 Second: Number of conserved contacts summed
501 Third: list with length of scored residues.
502 Contains indices referring to model.residues.
503 Fourth: numpy array of size
504 len(scored_residues) containing the number of
506 Fifth: numpy matrix of shape
507 (len(scored_residues), len(thresholds))
508 specifying how many for each threshold are
510 :param check_resnames: On by default. Enforces residue name matches
511 between mapped model and target residues.
512 :type check_resnames: :class:`bool`
513 :param add_mdl_contacts: Adds model contacts - Only using contacts that
514 are within a certain distance threshold in the
515 target does not penalize for added model
516 contacts. If set to True, this flag will also
517 consider target contacts that are within the
518 specified distance threshold in the model but
519 not necessarily in the target. No contact will
520 be added if the respective atom pair is not
521 resolved in the target.
522 :type add_mdl_contacts: :class:`bool`
524 :returns: global and per-residue lDDT scores as a tuple -
525 first element is global lDDT score (None if *target* has no
526 contacts) and second element a list of per-residue scores with
527 length len(*model*.residues). None is assigned to residues that
528 are not covered by target. If a residue is covered but has no
529 contacts in *target*, 0.0 is assigned.
531 if chain_mapping
is None:
532 if len(self.
chain_nameschain_names) > 1
or len(model.chains) > 1:
533 raise NotImplementedError(
"Must provide chain mapping if "
534 "target or model have > 1 chains.")
535 chain_mapping = {model.chains[0].GetName(): self.
chain_nameschain_names[0]}
538 for model_chain, target_chain
in chain_mapping.items():
539 if target_chain
not in self.
chain_nameschain_names:
540 raise RuntimeError(f
"Target chain specified in "
541 f
"chain_mapping ({target_chain}) does "
542 f
"not exist. Target has chains: "
543 f
"{self.chain_names}")
544 ch = model.FindChain(model_chain)
546 raise RuntimeError(f
"Model chain specified in "
547 f
"chain_mapping ({model_chain}) does "
548 f
"not exist. Model has chains: "
549 f
"{[c.GetName() for c in model.chains]}")
553 pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
554 res_indices, symmetries = self.
_ProcessModel_ProcessModel(model, chain_mapping,
555 residue_mapping = residue_mapping,
556 thresholds = thresholds,
557 check_resnames = check_resnames)
559 if no_interchain
and no_intrachain:
560 raise RuntimeError(
"no_interchain and no_intrachain flags are "
561 "mutually exclusive")
580 ref_indices, ref_distances, \
581 sym_ref_indices, sym_ref_distances = \
582 self.
_AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
583 ref_indices, ref_distances,
584 no_interchain, no_intrachain)
586 self.
_ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
589 per_res_exp = np.asarray([self.
_GetNExp_GetNExp(res_ref_atom_indices[idx],
590 ref_indices)
for idx
in range(len(res_indices))], dtype=np.int32)
591 per_res_conserved = self.
_EvalResidues_EvalResidues(pos, thresholds,
593 ref_indices, ref_distances)
595 n_thresh = len(thresholds)
598 per_res_lDDT = [
None] * len(model.residues)
599 for idx
in range(len(res_indices)):
600 n_exp = n_thresh * per_res_exp[idx]
602 score = np.sum(per_res_conserved[idx,:]) / n_exp
603 per_res_lDDT[res_indices[idx]] = score
605 per_res_lDDT[res_indices[idx]] = 0.0
608 n_distances = sum([len(x)
for x
in ref_indices])
609 if penalize_extra_chains:
612 lDDT_tot = int(n_thresh * n_distances)
613 lDDT_cons = int(np.sum(per_res_conserved))
616 lDDT = float(lDDT_cons) / lDDT_tot
620 residues = model.residues
621 for idx
in res_indices:
622 residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
624 if local_contact_prop:
625 residues = model.residues
626 exp_prop = local_contact_prop +
"_exp"
627 conserved_prop = local_contact_prop +
"_cons"
629 for i, r_idx
in enumerate(res_indices):
630 residues[r_idx].SetIntProp(exp_prop,
631 n_thresh * int(per_res_exp[i]))
632 residues[r_idx].SetIntProp(conserved_prop,
633 int(np.sum(per_res_conserved[i,:])))
636 return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
637 per_res_exp, per_res_conserved
639 return lDDT, per_res_lDDT
642 """Returns number of contacts expected for a certain chain in *target*
644 :param target_chain: Chain in *target* for which you want the number
646 :type target_chain: :class:`str`
647 :param no_interchain: Whether to exclude interchain contacts
648 :type no_interchain: :class:`bool`
649 :raises: :class:`RuntimeError` if specified chain doesn't exist
651 if target_chain
not in self.
chain_nameschain_names:
652 raise RuntimeError(f
"Specified chain name ({target_chain}) not in "
654 ch_idx = self.
chain_nameschain_names.index(target_chain)
664 def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
665 thresholds = [0.5, 1.0, 2.0, 4.0],
666 check_resnames = True):
667 """ Helper that generates data structures from model
672 max_pos = model.bounds.GetMax()
673 max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
674 max_coordinate += 42 * max(thresholds)
675 pos = np.ones((self.
n_atomsn_atoms, 3), dtype=np.float32) * max_coordinate
679 res_ref_atom_indices = list()
683 res_atom_indices = list()
687 res_atom_hashes = list()
695 current_model_res_idx = -1
696 for ch
in model.chains:
697 model_ch_name = ch.GetName()
698 if model_ch_name
not in chain_mapping:
699 current_model_res_idx += len(ch.residues)
701 target_ch_name = chain_mapping[model_ch_name]
703 rnums = self.
_GetChainRNums_GetChainRNums(ch, residue_mapping, model_ch_name,
706 for r, rnum
in zip(ch.residues, rnums):
707 current_model_res_idx += 1
708 res_mapper_key = (target_ch_name, rnum)
709 if res_mapper_key
not in self.
res_mapperres_mapper:
711 r_idx = self.
res_mapperres_mapper[res_mapper_key]
712 if check_resnames
and r.name != self.
compound_namescompound_names[r_idx]:
714 f
"Residue name mismatch for {r}, "
715 f
" expect {self.compound_names[r_idx]}"
720 atoms = [r.FindAtom(aname)
for aname
in anames]
721 res_ref_atom_indices.append(
722 list(range(res_start_idx, res_start_idx + len(anames)))
724 res_atom_indices.append(list())
725 res_atom_hashes.append(list())
726 res_indices.append(current_model_res_idx)
727 for a_idx, a
in enumerate(atoms):
730 pos[res_start_idx + a_idx][0] = p[0]
731 pos[res_start_idx + a_idx][1] = p[1]
732 pos[res_start_idx + a_idx][2] = p[2]
733 res_atom_indices[-1].append(res_start_idx + a_idx)
734 res_atom_hashes[-1].append(a.handle.GetHashCode())
738 a_one = atoms[sym_tuple[0]]
739 a_two = atoms[sym_tuple[1]]
740 if a_one.IsValid()
and a_two.IsValid():
743 res_start_idx + sym_tuple[0],
744 res_start_idx + sym_tuple[1],
747 if len(sym_indices) > 0:
748 symmetries.append(sym_indices)
750 return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
751 res_indices, symmetries)
754 def _GetExtraModelChainPenalty(self, model, chain_mapping):
755 """Counts n distances in extra model chains to be added as penalty
758 for chain
in model.chains:
759 ch_name = chain.GetName()
760 if ch_name
not in chain_mapping:
762 mdl_sel = model.Select(f
"cname={mol.QueryQuoteName(ch_name)}")
764 symmetry_settings = sm,
767 penalty += sum([len(x)
for x
in dummy_scorer.ref_indices])
770 def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
772 """Map residues in model chain to target residues
774 There are two options: one is simply using residue numbers,
775 the other is a custom mapping as given in *residue_mapping*
777 if residue_mapping
and model_ch_name
in residue_mapping:
779 ch_idx = self.
chain_nameschain_names.index(target_ch_name)
785 target_rnums = self.
res_resnumsres_resnums[start_idx:end_idx]
787 target_seq = residue_mapping[model_ch_name].GetSequence(0)
788 model_seq = residue_mapping[model_ch_name].GetSequence(1)
789 if len(target_seq.GetGaplessString()) != len(target_rnums):
790 raise RuntimeError(f
"Try to perform residue mapping for "
791 f
"model chain {model_ch_name} which "
792 f
"maps to {target_ch_name} in target. "
793 f
"Target sequence in alignment suggests "
794 f
"{len(target_seq.GetGaplessString())} "
795 f
"residues but {len(target_rnums)} are "
797 if len(model_seq.GetGaplessString()) != len(ch.residues):
798 raise RuntimeError(f
"Try to perform residue mapping for "
799 f
"model chain {model_ch_name} which "
800 f
"maps to {target_ch_name} in target. "
801 f
"Model sequence in alignment suggests "
802 f
"{len(model_seq.GetGaplessString())} "
803 f
"residues but {len(ch.residues)} are "
807 for col
in residue_mapping[model_ch_name]:
811 if col[0] !=
'-' and col[1] !=
'-':
812 rnums.append(target_rnums[target_idx])
814 if col[0] ==
'-' and col[1] !=
'-':
817 rnums = [r.GetNumber()
for r
in ch.residues]
822 def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
823 seqres_mapping, bb_only):
824 """Sets target related lDDTScorer members defined in constructor
826 No distance related members - see _SetupDistances
832 for chain
in self.
targettarget.chains:
833 ch_name = chain.GetName()
837 for r, rnum
in zip(chain.residues, residue_numbers[ch_name]):
841 self.
_SetupCompound_SetupCompound(r, compound_lib, custom_compounds,
842 symmetry_settings, bb_only)
849 atoms = [r.FindAtom(an)
for an
in self.
compound_anamescompound_anames[r.name]]
852 self.
atom_indicesatom_indices[a.handle.GetHashCode()] = current_idx
854 positions.append(np.asarray([p[0], p[1], p[2]],
857 positions.append(np.zeros(3, dtype=np.float32))
862 for a_idx
in sym_tuple:
865 hashcode = a.handle.GetHashCode()
869 self.
positionspositions = np.vstack(positions)
870 self.
n_atomsn_atoms = current_idx
872 def _GetTargetResidueNumbers(self, target, seqres_mapping):
873 """Returns residue numbers for each chain in target as dict
875 They're either directly extracted from the raw residue number
876 from the structure or from user provided alignments
878 residue_numbers = dict()
879 for ch
in target.chains:
880 ch_name = ch.GetName()
882 if ch_name
in seqres_mapping:
883 seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
884 atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
888 "SEQRES in seqres_mapping must not " "contain gaps"
890 atomseq_from_chain = [r.one_letter_code
for r
in ch.residues]
891 if atomseq.replace(
"-",
"") != atomseq_from_chain:
893 "ATOMSEQ in seqres_mapping must match "
894 "raw sequence extracted from chain "
898 for seqres_olc, atomseq_olc
in zip(seqres, atomseq):
899 if seqres_olc !=
"-":
901 if atomseq_olc !=
"-":
902 if seqres_olc != atomseq_olc:
904 f
"Residue with number {rnum} in "
905 f
"chain {ch_name} has SEQRES "
910 rnums = [r.GetNumber()
for r
in ch.residues]
911 assert len(rnums) == len(ch.residues)
912 residue_numbers[ch_name] = rnums
913 return residue_numbers
915 def _SetupCompound(self, r, compound_lib, custom_compounds,
916 symmetry_settings, bb_only):
917 """fill self.compound_anames/self.compound_symmetric_atoms
921 if r.chem_class.IsPeptideLinking():
923 elif r.chem_class.IsNucleotideLinking():
926 raise RuntimeError(f
"Only support amino acids and nucleotides "
927 f
"if bb_only is True, failed with {str(r)}")
931 symmetric_atoms = list()
932 if custom_compounds
is not None and r.GetName()
in custom_compounds:
933 atom_names = list(custom_compounds[r.GetName()].atom_names)
935 compound = compound_lib.FindCompound(r.name)
937 raise RuntimeError(f
"no entry for {r} in compound_lib")
938 for atom_spec
in compound.GetAtomSpecs():
939 if atom_spec.element
not in [
"H",
"D"]:
940 atom_names.append(atom_spec.name)
941 if r.name
in symmetry_settings.symmetric_compounds:
942 for pair
in symmetry_settings.symmetric_compounds[r.name]:
944 a = atom_names.index(pair[0])
945 b = atom_names.index(pair[1])
947 msg = f
"Could not find symmetric atoms "
948 msg += f
"({pair[0]}, {pair[1]}) for {r.name} "
949 msg += f
"as specified in SymmetrySettings in "
950 msg += f
"compound from component dictionary. "
951 msg += f
"Atoms in compound: {atom_names}"
952 raise RuntimeError(msg)
953 symmetric_atoms.append((a, b))
955 if len(symmetric_atoms) > 0:
958 def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
959 ref_indices, ref_distances, no_interchain,
963 in_target = np.zeros(self.
n_atomsn_atoms, dtype=bool)
966 mdl_atom_indices = dict()
967 for at_indices, at_hashes
in zip(res_atom_indices, res_atom_hashes):
968 for i, h
in zip(at_indices, at_hashes):
970 mdl_atom_indices[h] = i
975 mdl_ref_indices, mdl_ref_distances = \
976 lDDTScorer._SetupDistances(model, self.
n_atomsn_atoms, mdl_atom_indices,
979 mdl_ref_indices, mdl_ref_distances = \
980 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
986 mdl_ref_indices, mdl_ref_distances = \
987 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
993 for i
in range(self.
n_atomsn_atoms):
994 mask = np.isin(mdl_ref_indices[i], ref_indices[i],
995 assume_unique=
True, invert=
True)
997 added_mdl_indices = mdl_ref_indices[i][mask]
998 ref_indices[i] = np.append(ref_indices[i],
1002 tmp = self.
positionspositions.take(added_mdl_indices, axis=0)
1003 np.subtract(tmp, self.
positionspositions[i][
None, :], out=tmp)
1004 np.square(tmp, out=tmp)
1005 tmp = tmp.sum(axis=1)
1006 np.sqrt(tmp, out=tmp)
1007 ref_distances[i] = np.append(ref_distances[i], tmp)
1010 sym_ref_indices, sym_ref_distances = \
1012 ref_indices, ref_distances)
1014 return (ref_indices, ref_distances, sym_ref_indices, sym_ref_distances)
1019 def _SetupDistances(structure, n_atoms, atom_index_mapping,
1022 """Compute distance related members of lDDTScorer
1024 Brute force all vs all distance computation kills lDDT for large
1025 complexes. Instead of building some KD tree data structure, we make use
1026 of expected spatial proximity of atoms in the same chain. Distances are
1027 computed as follows:
1029 - process each chain individually
1030 - perform crude collision detection
1031 - process potentially interacting chain pairs
1032 - concatenate distances from all processing steps
1034 ref_indices = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1035 ref_distances = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1037 indices = [list()
for _
in range(n_atoms)]
1038 distances = [list()
for _
in range(n_atoms)]
1039 per_chain_pos = list()
1040 per_chain_indices = list()
1043 for ch
in structure.chains:
1045 atom_indices = list()
1049 for r_idx, r
in enumerate(ch.residues):
1052 hash_code = a.handle.GetHashCode()
1053 if hash_code
in atom_index_mapping:
1055 pos_list.append(np.asarray([p[0], p[1], p[2]]))
1056 atom_indices.append(atom_index_mapping[hash_code])
1058 mask_start.extend([r_start_idx] * n_valid_atoms)
1059 mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1060 r_start_idx += n_valid_atoms
1062 if len(pos_list) == 0:
1066 pos = np.vstack(pos_list)
1067 atom_indices = np.asarray(atom_indices)
1068 dists =
cdist(pos, pos)
1071 far_away = 2 * inclusion_radius
1072 for idx
in range(atom_indices.shape[0]):
1073 dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1076 within_mask = dists < inclusion_radius
1077 for idx
in range(atom_indices.shape[0]):
1078 indices_to_append = atom_indices[within_mask[idx,:]]
1079 if indices_to_append.shape[0] > 0:
1080 full_at_idx = atom_indices[idx]
1081 indices[full_at_idx].append(indices_to_append)
1082 distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1084 per_chain_pos.append(pos)
1085 per_chain_indices.append(atom_indices)
1088 min_pos = [p.min(0)
for p
in per_chain_pos]
1089 max_pos = [p.max(0)
for p
in per_chain_pos]
1090 chain_pairs = list()
1091 for idx_one
in range(len(per_chain_pos)):
1092 for idx_two
in range(idx_one + 1, len(per_chain_pos)):
1093 if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1095 if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1097 chain_pairs.append((idx_one, idx_two))
1100 for pair
in chain_pairs:
1101 dists =
cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1102 within = dists <= inclusion_radius
1105 tmp = within.sum(axis=1)
1106 for idx
in range(tmp.shape[0]):
1111 at_idx = per_chain_indices[pair[0]][idx]
1112 indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1113 distances_to_insert = dists[idx, within[idx, :]]
1114 insertion_idx = len(indices[at_idx])
1115 for i
in range(insertion_idx):
1116 if indices_to_insert[0] > indices[at_idx][i][0]:
1119 indices[at_idx].insert(insertion_idx, indices_to_insert)
1120 distances[at_idx].insert(insertion_idx, distances_to_insert)
1123 tmp = within.sum(axis=0)
1124 for idx
in range(tmp.shape[0]):
1129 at_idx = per_chain_indices[pair[1]][idx]
1130 indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1131 distances_to_insert = dists[within[:, idx], idx]
1132 insertion_idx = len(indices[at_idx])
1133 for i
in range(insertion_idx):
1134 if indices_to_insert[0] > indices[at_idx][i][0]:
1137 indices[at_idx].insert(insertion_idx, indices_to_insert)
1138 distances[at_idx].insert(insertion_idx, distances_to_insert)
1141 for at_idx
in range(n_atoms):
1142 if len(indices[at_idx]) > 0:
1143 ref_indices[at_idx] = np.hstack(indices[at_idx])
1144 ref_distances[at_idx] = np.hstack(distances[at_idx])
1146 return (ref_indices, ref_distances)
1149 def _SetupDistancesSC(n_atoms, chain_start_indices,
1150 ref_indices, ref_distances):
1151 """Select subset of contacts only covering intra-chain contacts
1154 ref_indices_sc = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1155 ref_distances_sc = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1157 n_chains = len(chain_start_indices)
1158 for ch_idx
in range(n_chains):
1159 chain_s = chain_start_indices[ch_idx]
1161 if ch_idx + 1 < n_chains:
1162 chain_e = chain_start_indices[ch_idx+1]
1163 for i
in range(chain_s, chain_e):
1164 if len(ref_indices[i]) > 0:
1165 intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1166 ref_indices[i]<chain_e))[0]
1167 ref_indices_sc[i] = ref_indices[i][intra_idx]
1168 ref_distances_sc[i] = ref_distances[i][intra_idx]
1170 return (ref_indices_sc, ref_distances_sc)
1173 def _SetupDistancesIC(n_atoms, chain_start_indices,
1174 ref_indices, ref_distances):
1175 """Select subset of contacts only covering inter-chain contacts
1178 ref_indices_ic = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1179 ref_distances_ic = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1181 n_chains = len(chain_start_indices)
1182 for ch_idx
in range(n_chains):
1183 chain_s = chain_start_indices[ch_idx]
1185 if ch_idx + 1 < n_chains:
1186 chain_e = chain_start_indices[ch_idx+1]
1187 for i
in range(chain_s, chain_e):
1188 if len(ref_indices[i]) > 0:
1189 inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1190 ref_indices[i]>=chain_e))[0]
1191 ref_indices_ic[i] = ref_indices[i][inter_idx]
1192 ref_distances_ic[i] = ref_distances[i][inter_idx]
1194 return (ref_indices_ic, ref_distances_ic)
1197 def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1198 """Transfer indices/distances of non-symmetric atoms and return
1201 sym_ref_indices = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1202 sym_ref_distances = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1204 for idx
in symmetric_atoms:
1207 for i, d
in zip(ref_indices[idx], ref_distances[idx]):
1208 if i
not in symmetric_atoms:
1211 sym_ref_indices[idx] = indices
1212 sym_ref_distances[idx] = np.asarray(distances)
1214 return (sym_ref_indices, sym_ref_distances)
1216 def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1217 """Computes number of distance differences within given thresholds
1219 returns np.array with len(thresholds) elements
1221 a_p = pos[atom_idx, :]
1222 tmp = pos.take(ref_indices[atom_idx], axis=0)
1223 np.subtract(tmp, a_p[
None, :], out=tmp)
1224 np.square(tmp, out=tmp)
1225 tmp = tmp.sum(axis=1)
1226 np.sqrt(tmp, out=tmp)
1227 np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1228 np.absolute(tmp, out=tmp)
1229 return np.asarray([(tmp <= thresh).sum()
for thresh
in thresholds],
1233 self, pos, atom_indices, thresholds, ref_indices, ref_distances
1235 """Calls _EvalAtom for several atoms and sums up the computed number
1236 of distance differences within given thresholds
1238 returns numpy matrix of shape (n_atoms, len(thresholds))
1240 conserved = np.zeros((len(atom_indices), len(thresholds)),
1242 for a_idx, a
in enumerate(atom_indices):
1243 conserved[a_idx, :] = self.
_EvalAtom_EvalAtom(pos, a, thresholds,
1244 ref_indices, ref_distances)
1247 def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
1249 """Calls _EvalAtoms for a bunch of residues
1251 residues are defined in *res_atom_indices* as lists of atom indices
1252 returns numpy matrix of shape (n_residues, len(thresholds)).
1254 conserved = np.zeros((len(res_atom_indices), len(thresholds)),
1256 for rai_idx, rai
in enumerate(res_atom_indices):
1257 conserved[rai_idx,:] = np.sum(self.
_EvalAtoms_EvalAtoms(pos, rai, thresholds,
1258 ref_indices, ref_distances), axis=0)
1261 def _ProcessSequenceSeparation(self):
1263 raise NotImplementedError(
"Congratulations! You're the first one "
1264 "requesting a non-default "
1265 "sequence_separation in the new and "
1266 "awesome lDDT implementation. A crate of "
1267 "beer for Gabriel and he'll implement "
1270 def _GetNExp(self, atom_idx, ref_indices):
1271 """Returns number of close atoms around one or several atoms
1273 if isinstance(atom_idx, int):
1274 return len(ref_indices[atom_idx])
1275 elif isinstance(atom_idx, list):
1276 return sum([len(ref_indices[idx])
for idx
in atom_idx])
1278 raise RuntimeError(
"invalid input type")
1280 def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
1282 """Swaps symmetric positions in-place in order to maximize lDDT scores
1283 towards non-symmetric atoms.
1285 for sym
in symmetries:
1287 atom_indices = list()
1288 for sym_tuple
in sym:
1289 atom_indices += [sym_tuple[0], sym_tuple[1]]
1290 tot = self.
_GetNExp_GetNExp(atom_indices, sym_ref_indices)
1296 sym_one_conserved = self.
_EvalAtoms_EvalAtoms(
1306 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1308 sym_two_conserved = self.
_EvalAtoms_EvalAtoms(
1316 sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
1317 sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)
1319 if sym_one_score >= sym_two_score:
1324 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
def __init__(self, atom_names)
def AddSymmetricCompound(self, name, symmetric_atoms)
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices, ref_distances)
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, thresholds=[0.5, 1.0, 2.0, 4.0], check_resnames=True)
def sym_ref_indices(self)
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
def _ProcessSequenceSeparation(self)
def sym_ref_distances(self)
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
def ref_distances_ic(self)
def _GetTargetResidueNumbers(self, target, seqres_mapping)
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
def sym_ref_distances_ic(self)
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False)
def sym_ref_distances_sc(self)
def sym_ref_indices_ic(self)
def GetNChainContacts(self, target_chain, no_interchain=False)
def sym_ref_indices_sc(self)
def ref_distances_sc(self)
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
def _GetNExp(self, atom_idx, ref_indices)
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
def _GetExtraModelChainPenalty(self, model, chain_mapping)
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
def GetDefaultSymmetrySettings()