Module proteinflow.data

Classes for downloading and manipulating protein data.

  • ProteinEntry: a class for manipulating proteinflow pickle files,
  • PDBEntry: a class for manipulating raw PDB files,
  • SAbDabEntry: a class for manipulating SAbDab files with specific methods for antibody data.

A ProteinEntry object can be created from a proteinflow pickle file, a PDB file or a SAbDab file directly and can be used to process the data and extract additional features. The processed data can be saved as a proteinflow pickle file or a PDB file.

Expand source code
"""
Classes for downloading and manipulating protein data.

- `ProteinEntry`: a class for manipulating proteinflow pickle files,
- `PDBEntry`: a class for manipulating raw PDB files,
- `SAbDabEntry`: a class for manipulating SAbDab files with specific methods for antibody data.

A `ProteinEntry` object can be created from a proteinflow pickle file, a PDB file or a SAbDab file directly
and can be used to process the data and extract additional features. The processed data can be saved as a
proteinflow pickle file or a PDB file.

"""

import itertools
import os
import pickle
import string
import tempfile
import warnings
from collections import defaultdict

import Bio.PDB
import numpy as np
import pandas as pd
from Bio import pairwise2
from biopandas.pdb import PandasPdb
from editdistance import eval as edit_distance
from torch import Tensor, from_numpy

try:
    import MDAnalysis as mda
except ImportError:
    pass
try:
    from methodtools import lru_cache
except ImportError:

    def lru_cache(*args, **kwargs):
        """Return a no-op decorator, used when `methodtools` is not installed.

        Any positional or keyword arguments (for example cache options such
        as `maxsize`) are accepted and ignored, so decorated call sites keep
        working without caching when the fallback is in use.
        """

        def wrapper(func):
            return func

        return wrapper


from proteinflow.constants import (
    _PMAP,
    ALPHABET,
    ALPHABET_REVERSE,
    ATOM_MASKS,
    BACKBONE_ORDER,
    CDR_ALPHABET,
    CDR_REVERSE,
    CDR_VALUES,
    COLORS,
    D3TO1,
    MAIN_ATOM_DICT,
    SIDECHAIN_ORDER,
)
from proteinflow.data.utils import (
    CustomMmcif,
    PDBBuilder,
    PDBError,
    _annotate_sse,
    _Atom,
    _dihedral_angle,
    _retrieve_chain_names,
)
from proteinflow.download import download_fasta, download_pdb
from proteinflow.extra import _get_view, requires_extra
from proteinflow.ligand import _get_ligands
from proteinflow.metrics import (
    ablang_pll,
    blosum62_score,
    ca_rmsd,
    confidence_from_file,
    esm_pll,
    esmfold_generate,
    igfold_generate,
    immunebuilder_generate,
    long_repeat_num,
    tm_score,
)


def interpolate_coords(crd, mask, fill_ends=True):
    """Fill in missing values in a coordinates array with linear interpolation.

    The input arrays are not modified; new arrays are returned.

    Parameters
    ----------
    crd : np.ndarray
        Coordinates array whose first axis indexes residues, e.g. of shape
        `(L, 4, 3)` or `(L, 14, 3)`
    mask : np.ndarray
        Mask array of shape `(L,)` where 1 indicates residues with known coordinates and 0
        indicates missing values
    fill_ends : bool, default True
        If `True`, fill in missing values at the ends of the protein sequence with the edge values;
        otherwise fill them in with zeros

    Returns
    -------
    crd : np.ndarray
        Interpolated coordinates array of the same shape as the input
    mask : np.ndarray
        Interpolated mask array of shape `(L,)` where 1 indicates residues with known or interpolated
        coordinates and 0 indicates missing values

    """
    crd = np.asarray(crd)
    mask = np.asarray(mask)
    # `astype` always copies, so the caller's array stays untouched; integer
    # input is promoted to float because NaN is used as the "missing" marker.
    crd = crd.astype(crd.dtype if np.issubdtype(crd.dtype, np.floating) else float)
    if crd.shape[0] == 0:
        return crd, np.zeros_like(mask)
    crd[(1 - mask).astype(bool)] = np.nan
    flat = pd.DataFrame(crd.reshape((crd.shape[0], -1)))
    if fill_ends:
        # the default direction is forward-only, which would leave leading
        # gaps unfilled; "both" extends the edge values to either end
        flat = flat.interpolate(limit_direction="both")
    else:
        flat = flat.interpolate(limit_area="inside")
    # copy to guarantee a writable array regardless of pandas' memory model
    crd = np.array(flat.values).reshape(crd.shape)
    still_missing = np.isnan(crd)
    # a residue counts as filled only if none of its values is still NaN
    residue_missing = still_missing.reshape((crd.shape[0], -1)).any(axis=1)
    interpolated_mask = np.zeros_like(mask)
    interpolated_mask[~residue_missing] = 1
    crd[still_missing] = 0
    return crd, interpolated_mask


class ProteinEntry:
    """A class to interact with proteinflow data files.

    Per-chain data (sequences, coordinates, masks, CDR annotations and
    prediction masks) is stored in dictionaries keyed by chain ID.
    """

    # Full atom order per residue type: the shared backbone order followed by
    # the residue-specific sidechain order taken from `SIDECHAIN_ORDER`.
    ATOM_ORDER = {k: BACKBONE_ORDER + v for k, v in SIDECHAIN_ORDER.items()}
    """A dictionary mapping 3-letter residue names to the order of atoms in the coordinates array."""

    def __init__(
        self,
        seqs,
        crds,
        masks,
        chain_ids,
        predict_masks=None,
        cdrs=None,
        protein_id=None,
    ):
        """Initialize a `ProteinEntry` object.

        Parameters
        ----------
        seqs : list of str
            Amino acid sequences of the protein (one-letter code)
        crds : list of np.ndarray
            Coordinates of the protein, `numpy` arrays of shape `(L, 14, 3)`,
            in the order of `N, C, CA, O`
        masks : list of np.ndarray
            Mask arrays where 1 indicates residues with known coordinates and 0
            indicates missing values
        cdrs : list of np.ndarray
            `'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
            and non-CDR residues are marked with `'-'`
        chain_ids : list of str
            Chain IDs of the protein
        predict_masks : list of np.ndarray, optional
            Mask arrays where 1 indicates residues that were generated by a model and 0
            indicates residues with known coordinates
        cdrs : list of np.ndarray, optional
            `'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
        protein_id : str, optional
            ID of the protein

        """
        if crds[0].shape[1] != 14:
            raise ValueError(
                "Coordinates array must have 14 atoms in the order of N, C, CA, O, sidechain atoms"
            )
        self.seq = {x: seq for x, seq in zip(chain_ids, seqs)}
        self.crd = {x: crd for x, crd in zip(chain_ids, crds)}
        self.mask = {x: mask for x, mask in zip(chain_ids, masks)}
        self.mask_original = {x: mask for x, mask in zip(chain_ids, masks)}
        if cdrs is None:
            cdrs = [None for _ in chain_ids]
        self.cdr = {x: cdr for x, cdr in zip(chain_ids, cdrs)}
        if predict_masks is None:
            predict_masks = [None for _ in chain_ids]
        self.predict_mask = {x: mask for x, mask in zip(chain_ids, predict_masks)}
        self.id = protein_id

    def get_id(self):
        """Return the ID of the protein."""
        return self.id

    def interpolate_coords(self, fill_ends=True):
        """Interpolate missing coordinates chain by chain.

        Every chain is passed to the module-level `interpolate_coords`
        function and the stored coordinates and mask of that chain are
        replaced with the pair it returns.

        Parameters
        ----------
        fill_ends : bool, default True
            Forwarded to the module-level function: if `True`, missing values
            at the ends of the sequence are filled with the edge values,
            otherwise with zeros

        """
        for chain_id in self.get_chains():
            filled = interpolate_coords(
                self.crd[chain_id], self.mask[chain_id], fill_ends=fill_ends
            )
            self.crd[chain_id], self.mask[chain_id] = filled

    def cut_missing_edges(self):
        """Cut off the ends of the protein sequence that have missing coordinates."""
        for chain in self.get_chains():
            mask = self.mask[chain]
            known_ind = np.where(mask == 1)[0]
            start, end = known_ind[0], known_ind[-1] + 1
            self.seq[chain] = self.seq[chain][start:end]
            self.crd[chain] = self.crd[chain][start:end]
            self.mask[chain] = self.mask[chain][start:end]
            if self.cdr[chain] is not None:
                self.cdr[chain] = self.cdr[chain][start:end]

    def get_chains(self):
        """Get the chain IDs of the protein.

        Returns
        -------
        chains : list of str
            Chain IDs of the protein

        """
        return sorted(self.seq.keys())

    def _get_chains_list(self, chains):
        """Get a list of chains to iterate over."""
        if chains is None:
            chains = self.get_chains()
        return chains

    def get_chain_type_dict(self, chains=None):
        """Get the chain types of the protein.

        If the CDRs are not annotated, this function will return `None`.
        If there is no light or heavy chain, the corresponding key will be missing.
        If there is no antigen chain, the `'antigen'` key will map to an empty list.

        Parameters
        ----------
        chains : list of str, default None
            Chain IDs to consider

        Returns
        -------
        chain_type_dict : dict
            A dictionary with keys `'heavy'`, `'light'` and `'antigen'` and values
            the corresponding chain IDs

        """
        if not self.has_cdr():
            return None
        chain_type_dict = {"antigen": []}
        chains = self._get_chains_list(chains)
        for chain, cdr in self.cdr.items():
            if chain not in chains:
                continue
            u = np.unique(cdr)
            if "H1" in u:
                chain_type_dict["heavy"] = chain
            elif "L1" in u:
                chain_type_dict["light"] = chain
            else:
                chain_type_dict["antigen"].append(chain)
        return chain_type_dict

    def get_length(self, chains=None):
        """Get the total length of a set of chains.

        Parameters
        ----------
        chain : str, optional
            Chain ID; if `None`, the length of the whole protein is returned

        Returns
        -------
        length : int
            Length of the chain

        """
        chains = self._get_chains_list(chains)
        return sum([len(self.seq[x]) for x in chains])

    def get_cdr_length(self, chains):
        """Get the length of the CDR regions of a set of chains.

        Parameters
        ----------
        chain : str
            Chain ID

        Returns
        -------
        length : int
            Length of the CDR regions of the chain

        """
        if not self.has_cdr():
            return {x: None for x in ["H1", "H2", "H3", "L1", "L2", "L3"]}
        return {
            x: len(self.get_sequence(chains=chains, cdr=x))
            for x in ["H1", "H2", "H3", "L1", "L2", "L3"]
        }

    def has_cdr(self):
        """Check if the protein is from the SAbDab database.

        Returns
        -------
        is_sabdab : bool
            True if the protein is from the SAbDab database

        """
        return list(self.cdr.values())[0] is not None

    def has_predict_mask(self):
        """Check if the protein has a predicted mask.

        Returns
        -------
        has_predict_mask : bool
            True if the protein has a predicted mask

        """
        return list(self.predict_mask.values())[0] is not None

    def __len__(self):
        """Get the total length of the protein chains."""
        return self.get_length(self.get_chains())

    def get_sequence(self, chains=None, encode=False, cdr=None, only_known=False):
        """Get the amino acid sequence of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the sequences of the specified chains is returned (in the same order);
            otherwise, all sequences are concatenated in alphabetical order of the chain IDs
        encode : bool, default False
            If `True`, the sequence is encoded as a `'numpy'` array of integers
            where each integer corresponds to the index of the amino acid in
            `proteinflow.constants.ALPHABET`
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned
        only_known : bool, default False
            If `True`, only the residues with known coordinates are returned

        Returns
        -------
        seq : str or np.ndarray
            Amino acid sequence of the protein (one-letter code) or an encoded
            sequence as a `'numpy'` array of integers

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        seq = "".join([self.seq[c] for c in chains]).replace("B", "")
        if encode:
            seq = np.array([ALPHABET_REVERSE[aa] for aa in seq])
        elif cdr is not None or only_known:
            seq = np.array(list(seq))
        if cdr is not None:
            cdr_arr = self.get_cdr(chains=chains)
            seq = seq[cdr_arr == cdr]
        if only_known:
            seq = seq[self.get_mask(chains=chains, cdr=cdr).astype(bool)]
        if not encode and not isinstance(seq, str):
            seq = "".join(seq)
        return seq

    def get_coordinates(self, chains=None, bb_only=False, cdr=None, only_known=False):
        """Get the coordinates of the protein.

        Backbone atoms are in the order of `N, C, CA, O`; for the full-atom
        order see `ProteinEntry.ATOM_ORDER` (sidechain atoms come after the
        backbone atoms).

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the coordinates of the specified chains are returned (in the same order);
            otherwise, all coordinates are concatenated in alphabetical order of the chain IDs
        bb_only : bool, default False
            If `True`, only the backbone atoms are returned
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned
        only_known : bool, default False
            If `True`, only return the coordinates of residues with known coordinates

        Returns
        -------
        crd : np.ndarray
            Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)`
            or `(L, 4, 3)` if `bb_only=True`

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        crd = np.concatenate([self.crd[c] for c in chains], axis=0)
        if cdr is not None:
            crd = crd[self.cdr == cdr]
        if bb_only:
            crd = crd[:, :4, :]
        if only_known:
            crd = crd[self.get_mask(chains=chains, cdr=cdr).astype(bool)]
        return crd

    def get_mask(self, chains=None, cdr=None, original=False):
        """Get the mask of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the masks of the specified chains are returned (in the same order);
            otherwise, all masks are concatenated in alphabetical order of the chain IDs
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned
        original : bool, default False
            If `True`, return the original mask (before interpolation)

        Returns
        -------
        mask : np.ndarray
            Mask array where 1 indicates residues with known coordinates and 0
            indicates missing values

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        mask = np.concatenate(
            [self.mask_original[c] if original else self.mask[c] for c in chains],
            axis=0,
        )
        if cdr is not None:
            mask = mask[self.cdr == cdr]
        return mask

    def get_cdr(self, chains=None, encode=False):
        """Get the CDR information of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the CDR information of the specified chains is
            returned (in the same order); otherwise, all CDR information is concatenated in
            alphabetical order of the chain IDs
        encode : bool, default False
            If `True`, the CDR information is encoded as a `'numpy'` array of
            integers where each integer corresponds to the index of the CDR
            type in `proteinflow.constants.CDR_ALPHABET`

        Returns
        -------
        cdr : np.ndarray or None
            A `'numpy'` array of shape `(L,)` where CDR residues are marked
            with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR
            residues are marked with `'-'` or an encoded array of integers
            ir `encode=True`; `None` if CDR information is not available
        chains : list of str, optional
            If specified, only the CDR information of the specified chains is
            returned (in the same order); otherwise, all CDR information is concatenated in
            alphabetical order of the chain IDs

        """
        chains = self._get_chains_list(chains)
        if self.cdr is None:
            return None
        cdr = np.concatenate([self.cdr[c] for c in chains], axis=0)
        if encode:
            cdr = np.array([CDR_REVERSE[aa] for aa in cdr])
        return cdr

    def get_atom_mask(self, chains=None, cdr=None):
        """Get the atom mask of the protein.

        Parameters
        ----------
        chains : str, optional
            If specified, only the atom masks of the specified chains are returned (in the same order);
            otherwise, all atom masks are concatenated in alphabetical order of the chain IDs
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned

        Returns
        -------
        atom_mask : np.ndarray
            Atom mask array where 1 indicates atoms with known coordinates and 0
            indicates missing or non-existing values, shaped `(L, 14, 3)`

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        seq = "".join([self.seq[c] for c in chains])
        atom_mask = np.concatenate([ATOM_MASKS[aa] for aa in seq])
        atom_mask[self.mask == 0] = 0
        if cdr is not None:
            atom_mask = atom_mask[self.cdr == cdr]
        return atom_mask

    @staticmethod
    def decode_cdr(cdr):
        """Translate integer-encoded CDR labels back to their string form.

        Parameters
        ----------
        cdr : np.ndarray
            A `'numpy'` array of shape `(L,)` (a tensor or a list is
            converted first) whose integer values index
            `proteinflow.constants.CDR_ALPHABET`

        Returns
        -------
        cdr : np.ndarray
            A `'numpy'` array of shape `(L,)` with the corresponding entries
            of `CDR_ALPHABET` (CDR types such as `'H1'`, `'L1'`, ... and
            `'-'` for non-CDR residues)

        """
        indices = ProteinEntry._to_numpy(cdr).astype(int)
        labels = [CDR_ALPHABET[index] for index in indices]
        return np.array(labels)

    @staticmethod
    def _to_numpy(arr):
        if isinstance(arr, Tensor):
            arr = arr.detach().cpu().numpy()
        if isinstance(arr, list):
            arr = np.array(arr)
        return arr

    @staticmethod
    def decode_sequence(seq):
        """Translate an integer-encoded sequence back to one-letter code.

        Parameters
        ----------
        seq : np.ndarray
            A `'numpy'` array (a tensor or a list is converted first) whose
            integer values index `proteinflow.constants.ALPHABET`

        Returns
        -------
        seq : str
            The concatenated `ALPHABET` entries, i.e. the amino acid sequence
            in one-letter code

        """
        indices = ProteinEntry._to_numpy(seq).astype(int)
        letters = [ALPHABET[index] for index in indices]
        return "".join(letters)

    def _rename_chains(self, chain_dict):
        """Rename the chains of the protein (with no safeguards)."""
        for old_chain, new_chain in chain_dict.items():
            self.seq[new_chain] = self.seq.pop(old_chain)
            self.crd[new_chain] = self.crd.pop(old_chain)
            self.mask[new_chain] = self.mask.pop(old_chain)
            self.mask_original[new_chain] = self.mask_original.pop(old_chain)
            self.cdr[new_chain] = self.cdr.pop(old_chain)
            self.predict_mask[new_chain] = self.predict_mask.pop(old_chain)

    def rename_chains(self, chain_dict):
        """Rename the chains of the protein.

        Parameters
        ----------
        chain_dict : dict
            A dictionary mapping old chain IDs to new chain IDs

        """
        for chain in self.get_chains():
            if chain not in chain_dict:
                chain_dict[chain] = chain
        self._rename_chains({k: k * 5 for k in self.get_chains()})
        self._rename_chains({k * 5: v for k, v in chain_dict.items()})

    def get_predicted_entry(self):
        """Return a `ProteinEntry` object that only contains predicted residues.

        Returns
        -------
        entry : ProteinEntry
            The truncated `ProteinEntry` object

        """
        if self.predict_mask is None:
            raise ValueError("Predicted mask not available")
        entry_dict = self.to_dict()
        for chain in self.get_chains():
            mask_ = self.predict_mask[chain].astype(bool)
            if mask_.sum() == 0:
                entry_dict.pop(chain)
                continue
            if mask_.sum() == len(mask_):
                continue
            seq_arr = np.array(list(entry_dict[chain]["seq"]))
            entry_dict[chain]["seq"] = "".join(seq_arr[mask_])
            entry_dict[chain]["crd_bb"] = entry_dict[chain]["crd_bb"][mask_]
            entry_dict[chain]["crd_sc"] = entry_dict[chain]["crd_sc"][mask_]
            entry_dict[chain]["msk"] = entry_dict[chain]["msk"][mask_]
            entry_dict[chain]["predict_msk"] = entry_dict[chain]["predict_msk"][mask_]
            if "cdr" in entry_dict[chain]:
                entry_dict[chain]["cdr"] = entry_dict[chain]["cdr"][mask_]
        return ProteinEntry.from_dict(entry_dict)

    def get_predicted_chains(self):
        """Return a list of chain IDs that contain predicted residues.

        Returns
        -------
        chains : list of str
            Chain IDs

        """
        if not self.has_predict_mask():
            raise ValueError("Predicted mask not available")
        return [k for k, v in self.predict_mask.items() if v.sum() != 0]

    def merge(self, entry):
        """Merge another `ProteinEntry` object into this one.

        Parameters
        ----------
        entry : ProteinEntry
            The merged `ProteinEntry` object

        """
        for chain in entry.get_chains():
            if chain.split("_")[0] in {x.split("_")[0] for x in self.get_chains()}:
                raise ValueError("Chain IDs must be unique")
            self.seq[chain] = entry.seq[chain]
            self.crd[chain] = entry.crd[chain]
            self.mask[chain] = entry.mask[chain]
            self.mask_original[chain] = entry.mask_original[chain]
            self.cdr[chain] = entry.cdr[chain]
            self.predict_mask[chain] = entry.predict_mask[chain]
        if not all([x is None for x in self.predict_mask.values()]):
            for k, v in self.predict_mask.items():
                if v is None:
                    self.predict_mask[k] = np.zeros(len(self.get_sequence(k)))

    @staticmethod
    def from_arrays(
        seqs,
        crds,
        masks,
        chain_id_dict,
        chain_id_array,
        predict_masks=None,
        cdrs=None,
        protein_id=None,
    ):
        """Build a protein entry from concatenated per-residue arrays.

        The arrays are split into chains using `chain_id_array`; sequences
        and CDR labels are decoded from their integer encoding. Coordinates
        with fewer than 14 atoms per residue are placed into the first four
        atom slots of a zero-filled `(L, 14, 3)` array.

        Parameters
        ----------
        seqs : np.ndarray
            Amino acid sequences of the protein (encoded as integers, see `proteinflow.constants.ALPHABET`), `'numpy'` array of shape `(L,)`
        crds : np.ndarray
            Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)` or `(L, 4, 3)`
        masks : np.ndarray
            Mask array where 1 indicates residues with known coordinates and 0
            indicates missing values, `'numpy'` array of shape `(L,)`
        chain_id_dict : dict
            A dictionary mapping chain IDs to the integers used in `chain_id_array`
        chain_id_array : np.ndarray
            A `'numpy'` array of chain IDs encoded as integers
        predict_masks : np.ndarray, optional
            Mask array where 1 indicates residues that were generated by a model and 0
            indicates residues with known coordinates, `'numpy'` array of shape `(L,)`
        cdrs : np.ndarray, optional
            A `'numpy'` array of shape `(L,)` where residues are marked
            with the corresponding CDR type (encoded as integers, see `proteinflow.constants.CDR_ALPHABET`)
        protein_id : str, optional
            Protein ID

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        to_numpy = ProteinEntry._to_numpy
        chain_ids, sequences, coordinates, known_masks = [], [], [], []
        predicted = [] if predict_masks is not None else None
        cdr_labels = [] if cdrs is not None else None
        for chain_id, index in chain_id_dict.items():
            selector = chain_id_array == index
            chain_ids.append(chain_id)
            sequences.append(ProteinEntry.decode_sequence(seqs[selector]))
            if crds.shape[1] == 14:
                chain_crd = to_numpy(crds[selector])
            else:
                # backbone-only input: pad the sidechain slots with zeros
                chain_crd = np.zeros((crds[selector].shape[0], 14, 3))
                chain_crd[:, :4, :] = to_numpy(crds[selector])
            coordinates.append(chain_crd)
            known_masks.append(to_numpy(masks[selector]))
            if predicted is not None:
                predicted.append(to_numpy(predict_masks[selector]))
            if cdr_labels is not None:
                cdr_labels.append(ProteinEntry.decode_cdr(cdrs[selector]))
        return ProteinEntry(
            seqs=sequences,
            crds=coordinates,
            masks=known_masks,
            chain_ids=chain_ids,
            predict_masks=predicted,
            cdrs=cdr_labels,
            protein_id=protein_id,
        )

    @staticmethod
    def from_dict(dictionary):
        """Load a protein entry from a dictionary.

        Parameters
        ----------
        dictionary : dict
            A nested dictionary where first-level keys are chain IDs and
            second-level keys are the following:
            - `'seq'` : amino acid sequence (one-letter code)
            - `'crd_bb'` : backbone coordinates, shaped `(L, 4, 3)`
            - `'crd_sc'` : sidechain coordinates, shaped `(L, 10, 3)`
            - `'msk'` : mask array where 1 indicates residues with known coordinates and 0
                indicates missing values, shaped `(L,)`
            - `'cdr'` (optional): CDR information, shaped `(L,)` where CDR residues are marked
                with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR residues are marked with `'-'`
            - `'predict_msk'` (optional): mask array where 1 indicates residues that were generated by a model and 0
                indicates residues with known coordinates, shaped `(L,)`
            It can also contain a `'protein_id'` first-level key.

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        chains = sorted([x for x in dictionary.keys() if x != "protein_id"])
        seq = [dictionary[k]["seq"] for k in chains]
        crd = [
            np.concatenate([dictionary[k]["crd_bb"], dictionary[k]["crd_sc"]], axis=1)
            for k in chains
        ]
        mask = [dictionary[k]["msk"] for k in chains]
        cdr = [dictionary[k].get("cdr", None) for k in chains]
        predict_mask = [dictionary[k].get("predict_msk", None) for k in chains]
        return ProteinEntry(
            seqs=seq,
            crds=crd,
            masks=mask,
            cdrs=cdr,
            chain_ids=chains,
            predict_masks=predict_mask,
            protein_id=dictionary.get("protein_id"),
        )

    @staticmethod
    def from_pdb_entry(pdb_entry):
        """Build a protein entry from a `PDBEntry` object.

        For every chain the FASTA sequence, the mask, the coordinates and,
        for `SAbDabEntry` objects, the CDR annotation are collected into a
        dictionary that is then passed to `ProteinEntry.from_dict`.

        Parameters
        ----------
        pdb_entry : PDBEntry
            A `PDBEntry` object

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        sequences = pdb_entry.get_fasta()
        is_antibody = isinstance(pdb_entry, SAbDabEntry)
        entry_dict = {}
        # each item returned by `get_chains()` is unpacked as a one-element sequence
        for (chain_id,) in pdb_entry.get_chains():
            chain_data = {}
            entry_dict[chain_id] = chain_data
            sequence = sequences[chain_id]

            known = pdb_entry.get_mask([chain_id])[chain_id]
            if is_antibody:
                chain_data["cdr"] = pdb_entry.get_cdr([chain_id])[chain_id]
            chain_data["seq"] = sequence
            chain_data["msk"] = known

            coordinates = pdb_entry.get_coordinates_array(chain_id)
            backbone = coordinates[:, :4, :]
            chain_data["crd_bb"] = backbone
            chain_data["crd_sc"] = coordinates[:, 4:, :]
            # NOTE(review): this counts the zero-valued coordinate components
            # of the 4x3 backbone block per residue and marks the residue as
            # missing when the count equals 4 -- confirm that this is the
            # intended "backbone is missing" test.
            zero_components = (backbone == 0).sum(-1).sum(-1)
            known[zero_components == 4] = 0
        entry_dict["protein_id"] = pdb_entry.pdb_id
        return ProteinEntry.from_dict(entry_dict)

    @staticmethod
    def from_pdb(
        pdb_path,
        fasta_path=None,
        heavy_chain=None,
        light_chain=None,
        antigen_chains=None,
    ):
        """Build a protein entry from a PDB file.

        The file is loaded as a `SAbDabEntry` when a heavy or a light chain
        is given and as a plain `PDBEntry` otherwise.

        Parameters
        ----------
        pdb_path : str
            Path to the PDB file
        fasta_path : str, optional
            Path to the FASTA file; if not specified, the sequence is extracted
            from the PDB file
        heavy_chain : str, optional
            Chain ID of the heavy chain (to load a SAbDab entry)
        light_chain : str, optional
            Chain ID of the light chain (to load a SAbDab entry)
        antigen_chains : list of str, optional
            Chain IDs of the antigen chains (to load a SAbDab entry)

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        is_antibody = heavy_chain is not None or light_chain is not None
        if not is_antibody:
            pdb_entry = PDBEntry(pdb_path=pdb_path, fasta_path=fasta_path)
        else:
            pdb_entry = SAbDabEntry(
                pdb_path=pdb_path,
                fasta_path=fasta_path,
                heavy_chain=heavy_chain,
                light_chain=light_chain,
                antigen_chains=antigen_chains,
            )
        return ProteinEntry.from_pdb_entry(pdb_entry)

    @staticmethod
    def from_id(
        pdb_id,
        local_folder=".",
        heavy_chain=None,
        light_chain=None,
        antigen_chains=None,
    ):
        """Build a protein entry from a PDB ID.

        The structure is loaded through `SAbDabEntry.from_id` when a heavy or
        a light chain is given and through `PDBEntry.from_id` otherwise.

        Parameters
        ----------
        pdb_id : str
            PDB ID of the protein
        local_folder : str, default "."
            Path to the local folder where the PDB file is saved
        heavy_chain : str, optional
            Chain ID of the heavy chain (to load a SAbDab entry)
        light_chain : str, optional
            Chain ID of the light chain (to load a SAbDab entry)
        antigen_chains : list of str, optional
            Chain IDs of the antigen chains (to load a SAbDab entry)

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        is_antibody = heavy_chain is not None or light_chain is not None
        if not is_antibody:
            # NOTE(review): `local_folder` is only forwarded on the SAbDab
            # path; confirm whether `PDBEntry.from_id` should receive it too.
            pdb_entry = PDBEntry.from_id(pdb_id=pdb_id)
        else:
            pdb_entry = SAbDabEntry.from_id(
                pdb_id=pdb_id,
                local_folder=local_folder,
                heavy_chain=heavy_chain,
                light_chain=light_chain,
                antigen_chains=antigen_chains,
            )
        return ProteinEntry.from_pdb_entry(pdb_entry)

    @staticmethod
    def from_pickle(path):
        """Build a protein entry from a pickled dictionary.

        The file is unpickled and the resulting object is handed to
        `ProteinEntry.from_dict`. Unpickling can execute arbitrary code, so
        only files from trusted sources should be loaded.

        Parameters
        ----------
        path : str
            Path to the pickle file

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        with open(path, "rb") as handle:
            unpickled = pickle.load(handle)
        entry = ProteinEntry.from_dict(unpickled)
        return entry

    @staticmethod
    def retrieve_ligands_from_pickle(path):
        """Retrieve ligands from a pickle file.

        Parameters
        ----------
        path : str
            Path to the pickle file

        Returns
        -------
        chain2ligand : dict
            A dictionary where keys are chain IDs and values are ligand names

        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        chain2ligand = {}
        for chain in data:
            if "ligand" not in data[chain]:
                continue
            chain2ligand[chain] = data[chain]["ligand"]
        return chain2ligand

    def to_dict(self):
        """Export the protein entry as a nested dictionary.

        Returns
        -------
        dictionary : dict
            A nested dictionary where first-level keys are chain IDs and
            second-level keys are the following:
            - `'seq'` : the sequence stored for the chain
            - `'crd_bb'` : the first four atoms of the coordinate array (backbone),
                shaped `(L, 4, 3)`
            - `'crd_sc'` : the remaining atoms of the coordinate array (sidechain),
                shaped `(L, 10, 3)`
            - `'msk'` : mask array where 1 indicates residues with known coordinates and 0
                indicates missing values, shaped `(L,)`
            - `'cdr'` (only if stored): CDR information, shaped `(L,)`
            - `'predict_msk'` (only if stored): mask array where 1 indicates residues that were
                generated by a model, shaped `(L,)`
            If the entry has an id, it is added under the first-level key `protein_id`.

        """
        exported = {}
        for chain_id in self.get_chains():
            coordinates = self.crd[chain_id]
            chain_data = {
                "seq": self.seq[chain_id],
                "crd_bb": coordinates[:, :4],
                "crd_sc": coordinates[:, 4:],
                "msk": self.mask[chain_id],
            }
            # Optional per-chain annotations are only written when present
            cdr = self.cdr[chain_id]
            if cdr is not None:
                chain_data["cdr"] = cdr
            predicted = self.predict_mask[chain_id]
            if predicted is not None:
                chain_data["predict_msk"] = predicted
            exported[chain_id] = chain_data
        if self.id is not None:
            exported["protein_id"] = self.id
        return exported

    def to_pdb(
        self,
        path,
        only_ca=False,
        skip_oxygens=False,
        only_backbone=False,
        title=None,
    ):
        """Write the protein entry to a PDB file through `PDBBuilder`.

        Parameters
        ----------
        path : str
            Path to the output PDB file
        only_ca : bool, default False
            If `True`, only CA atoms are saved
        skip_oxygens : bool, default False
            If `True`, oxygen atoms are not saved
        only_backbone : bool, default False
            If `True`, only backbone atoms are saved
        title : str, optional
            Title of the PDB file (by default either the protein id or "Untitled")

        Raises
        ------
        ValueError
            If any chain ID differs from its own first character in upper case

        """
        # A chain ID is accepted only if it equals its first character upper-cased
        rejected = [x[0].upper() != x for x in self.get_chains()]
        if any(rejected):
            raise ValueError(
                "Chain IDs must be single uppercase letters, please rename with `rename_chains` before saving."
            )
        builder = PDBBuilder(
            self,
            only_ca=only_ca,
            skip_oxygens=skip_oxygens,
            only_backbone=only_backbone,
        )
        if title is None:
            title = "Untitled" if self.id is None else self.id
        builder.save_pdb(path, title=title)

    def to_pickle(self, path):
        """Pickle the dictionary returned by `to_dict` into a file.

        The output file holds a nested dictionary where first-level keys are chain IDs and
        second-level keys are the following:
        - `'crd_bb'`: a `numpy` array of shape `(L, 4, 3)` with backbone atom coordinates (N, C, CA, O),
        - `'crd_sc'`: a `numpy` array of shape `(L, 10, 3)` with sidechain atom coordinates
            (check `proteinflow.sidechain_order()` for the order of atoms),
        - `'msk'`: a `numpy` array of shape `(L,)` where ones correspond to residues with known coordinates and
            zeros to missing values,
        - `'seq'`: the sequence of the chain.

        If CDR information is stored for a chain, it is added under the `'cdr'` key.

        If a prediction mask is available, it is added under the `'predict_msk'` key: a `numpy` array of
        shape `(L,)` where ones correspond to residues that were generated by a model and zeros to
        residues with known coordinates.

        Parameters
        ----------
        path : str
            Path to the pickle file

        """
        # Build the dictionary before opening the file so that a failure in
        # `to_dict` does not leave an empty file behind
        payload = self.to_dict()
        with open(path, "wb") as handle:
            pickle.dump(payload, handle)

    def dihedral_angles(self, chains=None):
        """Calculate the backbone dihedral angles of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the dihedral angles of the specified chains are returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        angles : np.ndarray
            A `'numpy'` array of shape `(L, 2)` with backbone dihedral angles
            as returned by `_dihedral_angle`; the first column is the angle
            labelled psi in the code and the second the one labelled phi
            (NOTE(review): earlier documentation advertised the order
            (phi, psi) - confirm what callers expect)

        """
        per_chain = []
        for chain in self._get_chains_list(chains):
            crd = self.get_coordinates([chain])
            mask = self.get_mask([chain])
            # Backbone atom order is N, C, CA, O
            # psi: N, CA, C of residue i followed by N of residue i + 1;
            # the last residue has no successor, hence the trailing zero row
            psi_atoms = np.concatenate([crd[:-1, [0, 2, 1], :], crd[1:, [0], :]], 1)
            psi_atoms = np.pad(psi_atoms, ((0, 1), (0, 0), (0, 0)))
            psi = _dihedral_angle(psi_atoms, mask)
            # phi: C of residue i - 1 followed by N, CA, C of residue i;
            # the first residue has no predecessor, hence the leading zero row
            phi_atoms = np.concatenate([crd[:-1, [1], :], crd[1:, [0, 2, 1]]], 1)
            phi_atoms = np.pad(phi_atoms, ((1, 0), (0, 0), (0, 0)))
            phi = _dihedral_angle(phi_atoms, mask)
            per_chain.append(np.stack([psi, phi], -1))
        return np.concatenate(per_chain, 0)

    def secondary_structure(self, chains=None):
        """One-hot encode the secondary structure of the protein.

        The labels come from `_annotate_sse`, applied to the first four atoms
        (backbone) of every residue, chain by chain.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the secondary structure of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        sse : np.ndarray
            A `'numpy'` array of shape `(L, 3)` with secondary structure
            elements encoded as one-hot vectors (alpha-helix, beta-sheet, loop);
            residues with an empty label are encoded as zeros

        """
        # Label -> (alpha-helix, beta-sheet, loop) one-hot vector
        one_hot = {"c": [0, 0, 1], "b": [0, 1, 0], "a": [1, 0, 0], "": [0, 0, 0]}
        encoded = []
        for chain in self._get_chains_list(chains):
            backbone = self.get_coordinates([chain])[:, :4]
            encoded.extend([one_hot[label] for label in _annotate_sse(backbone)])
        return np.array(encoded)

    def sidechain_coordinates(self, chains=None):
        """Return the sidechain part of the coordinate array.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the sidechain coordinates of the specified chains are returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        crd : np.ndarray
            A `'numpy'` array of shape `(L, 10, 3)` with sidechain atom
            coordinates (check `proteinflow.sidechain_order()` for the order of
            atoms); missing values are marked with zeros

        """
        chain_ids = self._get_chains_list(chains)
        coordinates = self.get_coordinates(chain_ids)
        # The first four atoms of every residue are the backbone
        return coordinates[:, 4:, :]

    def chemical_features(self, chains=None):
        """Calculate chemical features of the protein.

        Every residue of the selected chains is passed through `_PMAP` and
        the results are stacked into one array.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the chemical features of the specified chains are returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        features : np.ndarray
            A `'numpy'` array with one row of chemical features per residue
            (hydropathy, volume, charge, polarity, acceptor/donor); the number
            of columns is whatever `_PMAP` returns (NOTE(review): earlier
            documentation stated `(L, 4)`, which does not match the feature
            list - confirm against `_PMAP`)

        """
        chain_ids = self._get_chains_list(chains)
        per_residue = [
            _PMAP(residue) for chain in chain_ids for residue in self.seq[chain]
        ]
        return np.array(per_residue)

    def sidechain_orientation(self, chains=None):
        """Calculate the (global) sidechain orientation of the protein.

        For every residue type with an entry in `MAIN_ATOM_DICT`, the
        orientation is the vector from backbone atom 2 (CA in the N, C, CA, O
        order) to the sidechain atom selected by `MAIN_ATOM_DICT`. Residue
        types whose entry is `None` receive a random direction instead. All
        vectors are normalized to unit length.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the sidechain orientation of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        orientation : np.ndarray
            A `'numpy'` array of shape `(L, 3)` with sidechain orientation
            vectors; residues whose two reference atoms coincide (e.g. both
            zero because they are missing) yield zero vectors

        """
        chains = self._get_chains_list(chains)
        crd = self.get_coordinates(chains=chains)
        crd_bb, crd_sc = crd[:, :4, :], crd[:, 4:, :]
        seq = self.get_sequence(chains=chains, encode=True)
        orientation = np.zeros((crd_sc.shape[0], 3))
        for i in range(1, 21):
            residue_mask = seq == i
            main_atom = MAIN_ATOM_DICT[i]
            if main_atom is not None:
                orientation[residue_mask] = (
                    crd_sc[residue_mask, main_atom, :] - crd_bb[residue_mask, 2, :]
                )
            else:
                # Select by the encoded sequence: the previous `self.seq == i`
                # compared the per-chain dictionary with an integer, so this
                # branch never assigned anything
                orientation[residue_mask] = np.random.rand(
                    *orientation[residue_mask].shape
                )
        # The small constant keeps zero vectors from dividing by zero
        orientation /= np.expand_dims(np.linalg.norm(orientation, axis=-1), -1) + 1e-7
        return orientation

    @lru_cache()
    def is_valid_pair(self, chain1, chain2, cutoff=10):
        """Check if two chains are a valid pair based on the distance between them.

        Only atom 2 of every residue with known coordinates is used (CA in
        the N, C, CA, O order). The pair is valid when at least three
        (residue of `chain1`, residue of `chain2`) combinations are closer
        than `cutoff`. Residues outside the other chain's bounding box
        enlarged by `3 * cutoff`, or whose atom sits exactly at the origin,
        are ignored.

        Parameters
        ----------
        chain1 : str
            Chain ID of the first chain
        chain2 : str
            Chain ID of the second chain
        cutoff : int, optional
            Distance threshold (in Angstroms) below which two residues count as close

        Returns
        -------
        valid : bool
            `True` if the two chains are a valid pair, `False` otherwise

        """
        margin = cutoff * 3
        assert chain1 in self.get_chains(), f"Chain {chain1} not found"
        assert chain2 in self.get_chains(), f"Chain {chain2} not found"
        ca1 = self.get_coordinates(chains=[chain1], only_known=True)[:, 2, :]
        ca2 = self.get_coordinates(chains=[chain2], only_known=True)[:, 2, :]

        # Per-axis bounding boxes of both chains
        low1, high1 = ca1.min(axis=0), ca1.max(axis=0)
        low2, high2 = ca2.min(axis=0), ca2.max(axis=0)

        # Residues lying inside the other chain's box enlarged by the margin
        near1 = np.all((ca1 >= low2 - margin) & (ca1 <= high2 + margin), axis=1)
        near2 = np.all((ca2 >= low1 - margin) & (ca2 <= high1 + margin), axis=1)

        # Residues whose atom is not exactly (0, 0, 0)
        placed1 = (ca1 == 0).sum(-1) != 3
        placed2 = (ca2 == 0).sum(-1) != 3

        candidates1 = np.flatnonzero(near1 & placed1)
        candidates2 = np.flatnonzero(near2 & placed2)

        # All-against-all distances between the remaining candidates
        diff = ca1[candidates1, np.newaxis, :] - ca2[candidates2, :]
        distances = np.sqrt(np.sum(diff**2, axis=2))

        close_pairs = np.sum(distances < cutoff)
        return bool(close_pairs >= 3)

    def get_index_array(self, chains=None, index_bump=100):
        """Get the index array of the protein.

        The index array is a `'numpy'` array of shape `(L,)` with the index of each residue along the chain.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the index array of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs
        index_bump : int, default 100
            Gap inserted into the numbering between consecutive chains

        Returns
        -------
        index_array : np.ndarray
            An integer `'numpy'` array of shape `(L,)` with the index of each residue; numbering
            starts at 0 and each following chain starts `index_bump` after the
            end of the previous one

        """
        chain_ids = self._get_chains_list(chains)
        indices = np.zeros(self.get_length(chain_ids))
        write_pos = 0  # where the current chain starts in the output
        first_index = 0  # residue index given to the chain's first residue
        for chain_id in chain_ids:
            n_residues = self.get_length([chain_id])
            write_end = write_pos + n_residues
            indices[write_pos:write_end] = np.arange(
                first_index, first_index + n_residues
            )
            write_pos = write_end
            first_index += n_residues + index_bump
        return indices.astype(int)

    def get_chain_id_dict(self, chains=None):
        """Get the dictionary mapping from chain IDs to chain indices.

        The index of a chain is its position in `self.get_chains()`, also
        when only a subset of chains is requested.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the specified chains are included

        Returns
        -------
        chain_id_dict : dict
            A dictionary mapping from chain IDs to chain indices

        """
        requested = self._get_chains_list(chains)
        chain_id_dict = {}
        for position, chain_id in enumerate(self.get_chains()):
            if chain_id in requested:
                chain_id_dict[chain_id] = position
        return chain_id_dict

    def get_chain_id_array(self, chains=None, encode=True):
        """Get the chain ID array of the protein.

        The chain ID array is a `'numpy'` array of shape `(L,)` that labels each residue with its chain.
        With `encode=True` the label is the index given by `get_chain_id_dict()` (computed over
        all chains); otherwise it is the chain ID itself.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the chain ID array of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs
        encode : bool, default True
            If True, the chain ID is encoded as a number; otherwise, the chain ID is the chain ID string

        Returns
        -------
        chain_id_array : np.ndarray
            A `'numpy'` array of shape `(L,)` with the chain ID of each residue
            (float zeros-based array if `encode` is True, object array otherwise)

        """
        id_dict = self.get_chain_id_dict()
        total_length = self.get_length(chains)
        if encode:
            labels = np.zeros(total_length)
        else:
            labels = np.empty(total_length, dtype=object)
        cursor = 0
        for chain in self._get_chains_list(chains):
            chain_end = cursor + self.get_length([chain])
            labels[cursor:chain_end] = id_dict[chain] if encode else chain
            cursor = chain_end
        return labels

    def get_ligand_features(self, ligands, chains=None):
        """Get ligand coordinates, smiles, and chain mapping.

        Parameters
        ----------
        ligands : dict
            A dictionary mapping from chain IDs to a list of ligands, where each ligand is a dictionary
            with the keys `'smiles'` and `'X'` (atom coordinates)
        chains : list of str, optional
            If specified, only the ligands of the specified chains are returned (in the same order);
            otherwise, all ligands are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        X_ligands : torch.Tensor
            A `'torch'` tensor of shape `(N, 3)` with the ligand coordinates
        ligand_smiles : str
            A string with the ligand smiles separated by a dot
        ligand_chains : torch.Tensor
            A `'torch'` tensor of shape `(N, 1)` with, for each atom, the position of its chain
            in the list of selected chains
        """
        coordinate_blocks = []
        smiles_per_chain = []
        atom_owner = []
        for position, chain in enumerate(self._get_chains_list(chains)):
            chain_ligands = ligands[chain]
            smiles_per_chain.append(
                ".".join([ligand["smiles"] for ligand in chain_ligands])
            )
            atoms = np.concatenate([ligand["X"] for ligand in chain_ligands])
            coordinate_blocks.append(atoms)
            # One single-element row per atom, holding the owning chain's position
            atom_owner.extend([position] for _ in range(len(atoms)))
        X_ligands = from_numpy(np.concatenate(coordinate_blocks, 0))
        ligand_smiles = ".".join(smiles_per_chain)
        ligand_chains = Tensor(atom_owner)
        return X_ligands, ligand_smiles, ligand_chains

    def _get_highlight_mask_dict(self, highlight_mask=None):
        """Split a highlight mask into per-chain arrays over known residues.

        Returns an empty dictionary when `highlight_mask` is `None`;
        otherwise each chain ID maps to the entries of `highlight_mask` that
        belong to that chain and have a non-zero value in `self.get_mask()`.
        """
        chain_labels = self.get_chain_id_array(encode=False)
        known = self.get_mask().astype(bool)
        if highlight_mask is None:
            return {}
        return {
            chain: highlight_mask[known & (chain_labels == chain)]
            for chain in self.get_chains()
        }

    def _get_atom_dicts(
        self,
        highlight_mask=None,
        style="cartoon",
        opacity=1,
        colors=None,
        accent_color="#D96181",
    ):
        """Get the atom dictionaries of the protein.

        The entry is written to a temporary PDB file, loaded back as a
        `PDBEntry` and the call is delegated to `PDBEntry._get_atom_dicts`.
        The temporary file is removed before returning.
        """
        highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask)
        tmp_path = self._temp_pdb_file()
        try:
            pdb_entry = PDBEntry(tmp_path)
            return pdb_entry._get_atom_dicts(
                highlight_mask_dict=highlight_mask_dict,
                style=style,
                opacity=opacity,
                colors=colors,
                accent_color=accent_color,
            )
        finally:
            # `_temp_pdb_file` creates the file with `delete=False`, so it has
            # to be cleaned up here or it would pile up in the temp folder
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def get_predict_mask(self, chains=None, only_known=False):
        """Get the prediction mask of the protein.

        The prediction mask is a `'numpy'` array of shape `(L,)` with ones
        corresponding to residues that were generated by a model and zeros to
        residues with known coordinates. Availability is decided from the
        first stored chain: if its mask is `None`, `None` is returned.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the prediction mask of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs
        only_known : bool, default False
            If `True`, only residues with known coordinates are returned

        Returns
        -------
        predict_mask : np.ndarray or None
            A `'numpy'` array of shape `(L,)` with ones corresponding to residues that were generated by a model and
            zeros to residues with known coordinates, or `None` if no prediction mask is stored

        """
        first_stored = list(self.predict_mask.values())[0]
        if first_stored is None:
            return None
        chain_ids = self._get_chains_list(chains)
        combined = np.concatenate([self.predict_mask[chain] for chain in chain_ids])
        if not only_known:
            return combined
        known = self.get_mask(chains=chain_ids).astype(bool)
        return combined[known]

    def visualize(
        self,
        highlight_mask=None,
        style="cartoon",
        highlight_style=None,
        opacity=1,
        canvas_size=(400, 300),
    ):
        """Visualize the protein in a notebook.

        The entry is written to a temporary PDB file, loaded back as a
        `PDBEntry` and drawn with `PDBEntry.visualize`.

        Parameters
        ----------
        highlight_mask : np.ndarray, optional
            A `'numpy'` array of shape `(L,)` with the residues to highlight
            marked with 1 and the rest marked with 0; if not given and
            a prediction mask is stored, the predicted residues are highlighted
        style : str, default 'cartoon'
            The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
        highlight_style : str, optional
            The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
            (defaults to the same as `style`)
        opacity : float or dict, default 1
            Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
        canvas_size : tuple, default (400, 300)
            Shape of the canvas

        """
        if highlight_mask is not None:
            highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask)
        elif list(self.predict_mask.values())[0] is None:
            highlight_mask_dict = None
        else:
            # Fall back to the stored prediction mask, restricted to the
            # residues with known coordinates
            highlight_mask_dict = {}
            for chain in self.get_chains():
                chain_predicted = self.predict_mask[chain]
                known = self.get_mask([chain]).astype(bool)
                highlight_mask_dict[chain] = chain_predicted[known]
        # The PDB entry is built while the temporary file still exists and
        # drawn after the file has been discarded
        with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp:
            self.to_pdb(tmp.name)
            pdb_entry = PDBEntry(tmp.name)
        pdb_entry.visualize(
            highlight_mask_dict=highlight_mask_dict,
            style=style,
            highlight_style=highlight_style,
            opacity=opacity,
            canvas_size=canvas_size,
        )

    def blosum62_score(self, seq_before, average=True, only_predicted=True):
        """Calculate the BLOSUM62 score of the protein.

        Parameters
        ----------
        seq_before : str
            A string with the sequence before the mutation
        average : bool, default True
            If `True`, the score is averaged over the residues; otherwise, the score is summed
        only_predicted : bool, default True
            If `True` and prediction masks are available, only predicted residues are considered

        Returns
        -------
        score : float
            The BLOSUM62 score of the protein

        """
        seq_after = self.get_sequence(encode=False)
        if only_predicted and self.predict_mask is not None:
            # `self.predict_mask` is a per-chain dictionary, so availability
            # has to be read from `get_predict_mask`, which returns `None`
            # when no mask is stored
            predict_mask = self.get_predict_mask()
            if predict_mask is not None:
                selected = predict_mask.astype(bool)
                seq_before = np.array(list(seq_before))[selected]
                seq_after = np.array(list(seq_after))[selected]
        score = blosum62_score(seq_before, seq_after)
        if average:
            score /= len(seq_before)
        return score

    def long_repeat_num(self, thr=5):
        """Calculate the number of long repeats in the protein.

        If a prediction mask is available, only predicted residues are
        considered; otherwise the full sequence is used.

        Parameters
        ----------
        thr : int, default 5
            The threshold for the minimum length of the repeat

        Returns
        -------
        num : int
            The number of long repeats in the protein

        """
        seq = self.get_sequence(encode=False)
        if self.predict_mask is not None:
            # `get_predict_mask` returns `None` when no mask is stored; the
            # dictionary itself is never `None` in that case
            predict_mask = self.get_predict_mask()
            if predict_mask is not None:
                seq = np.array(list(seq))[predict_mask.astype(bool)]
        return long_repeat_num(seq, thr=thr)

    def esm_pll(
        self,
        esm_model_name="esm2_t30_150M_UR50D",
        esm_model_objects=None,
        average=False,
    ):
        """Calculate the ESM PLL score of the protein.

        If a prediction mask is available it is passed along per chain;
        otherwise every residue is marked with one.

        Parameters
        ----------
        esm_model_name : str, default "esm2_t30_150M_UR50D"
            Name of the ESM-2 model to use
        esm_model_objects : tuple, optional
            Tuple of ESM-2 model, batch converter and tok_to_idx dictionary (if not None, `esm_model_name` will be ignored)
        average : bool, default False
            If `True`, the score is averaged over the residues; otherwise, the score is summed

        Returns
        -------
        score : float
            The ESM PLL score of the protein

        """
        chains = self.get_chains()
        chain_sequences = [self.get_sequence(chains=[chain]) for chain in chains]
        predict_masks = []
        for chain, sequence in zip(chains, chain_sequences):
            chain_mask = None
            if self.predict_mask is not None:
                # Returns `None` when no prediction mask is stored
                chain_mask = self.get_predict_mask(chains=[chain])
            if chain_mask is None:
                # No mask available: every residue is marked with one
                chain_mask = np.ones(len(sequence))
            predict_masks.append(chain_mask.astype(float))
        return esm_pll(
            chain_sequences,
            predict_masks,
            esm_model_name=esm_model_name,
            esm_model_objects=esm_model_objects,
            average=average,
        )

    def ablang_pll(self, ablang_model_name="heavy", average=False):
        """Calculate the AbLang PLL score of the protein.

        The score is summed over the chains returned by
        `get_predicted_chains()`. If a prediction mask is available it is
        passed along per chain; otherwise every residue is marked with one.

        Parameters
        ----------
        ablang_model_name : str, default "heavy"
            Name of the AbLang model to use
        average : bool, default False
            If `True`, the score is divided by the number of marked residues; otherwise, the score is summed

        Returns
        -------
        score : float
            The AbLang PLL score of the protein

        """
        chains = self.get_predicted_chains()
        chain_sequences = [self.get_sequence(chains=[chain]) for chain in chains]
        predict_masks = []
        for chain, sequence in zip(chains, chain_sequences):
            chain_mask = None
            if self.predict_mask is not None:
                # Returns `None` when no prediction mask is stored
                chain_mask = self.get_predict_mask(chains=[chain])
            if chain_mask is None:
                # No mask available: every residue is marked with one
                chain_mask = np.ones(len(sequence))
            predict_masks.append(chain_mask.astype(float))
        out = sum(
            [
                ablang_pll(
                    sequence,
                    predict_mask,
                    ablang_model_name=ablang_model_name,
                    average=False,
                )
                for sequence, predict_mask in zip(chain_sequences, predict_masks)
            ]
        )
        if average:
            # Normalize by the masks that were actually used, so the
            # fallback case without a stored mask works as well
            out /= sum(predict_mask.sum() for predict_mask in predict_masks)
        return out

    def accuracy(self, seq_before):
        """Calculate the accuracy of the protein.

        The accuracy is the fraction of positions where `seq_before` and the
        sequence of this entry agree. If a prediction mask is available, only
        predicted residues are compared; otherwise all residues are.

        Parameters
        ----------
        seq_before : str
            A string with the sequence before the mutation

        Returns
        -------
        score : float
            The accuracy of the protein

        """
        seq_after = self.get_sequence(encode=False)
        seq_before = np.array(list(seq_before))
        seq_after = np.array(list(seq_after))
        if self.predict_mask is not None:
            # `get_predict_mask` returns `None` when no mask is stored; the
            # dictionary itself is never `None` in that case
            predict_mask = self.get_predict_mask()
            if predict_mask is not None:
                selected = predict_mask.astype(bool)
                seq_before = seq_before[selected]
                seq_after = seq_after[selected]
        return np.mean(seq_before == seq_after)

    def ca_rmsd(self, entry, only_predicted=True):
        """Calculate CA RMSD between two proteins.

        Only chains present in both entries are compared, using atom 2 of
        every residue with known coordinates (CA in the N, C, CA, O order).

        Parameters
        ----------
        entry : ProteinEntry
            A `ProteinEntry` object
        only_predicted : bool, default True
            If `True` and prediction masks are available, only predicted residues are considered

        Returns
        -------
        rmsd : float
            The CA RMSD between the two proteins

        """
        # Restricting to predicted residues only makes sense with a mask
        use_predicted = only_predicted and self.has_predict_mask()
        shared_chains = [x for x in self.get_chains() if x in entry.get_chains()]
        own_ca = self.get_coordinates(only_known=True, chains=shared_chains)[:, 2]
        other_ca = entry.get_coordinates(only_known=True, chains=shared_chains)[:, 2]
        if use_predicted:
            selected = self.get_predict_mask(
                only_known=True, chains=shared_chains
            ).astype(bool)
            own_ca = own_ca[selected]
            other_ca = other_ca[selected]
        return ca_rmsd(own_ca, other_ca)

    def tm_score(self, entry, chains=None):
        """Calculate TM score between two proteins.

        Both entries contribute atom 2 (CA in the N, C, CA, O order) and the
        sequence of their residues with known coordinates.

        Parameters
        ----------
        entry : ProteinEntry
            A `ProteinEntry` object
        chains : list of str, optional
            A list of chain IDs to consider

        Returns
        -------
        tm_score : float
            The TM score between the two proteins

        """
        selection = {"only_known": True, "chains": chains}
        own_ca = self.get_coordinates(**selection)[:, 2]
        other_ca = entry.get_coordinates(**selection)[:, 2]
        own_sequence = self.get_sequence(**selection)
        other_sequence = entry.get_sequence(**selection)
        return tm_score(own_ca, other_ca, own_sequence, other_sequence)

    def _temp_pdb_file(self):
        """Write the entry to a new temporary `.pdb` file and return its path.

        The file is created with `delete=False`, so it stays on disk and the
        caller is responsible for removing it.
        """
        handle = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
        with handle:
            pdb_path = handle.name
            self.to_pdb(pdb_path)
        return pdb_path

    @staticmethod
    def esmfold_metrics(entries, only_antibody=False):
        """Calculate ESMFold metrics for a list of entries.

        The sequences of the selected chains are folded with
        `esmfold_generate`, the outputs are read from
        ``esmfold_output/seq_{i}.pdb``, aligned to the original entries and
        saved next to them as ``seq_{i}_aligned.pdb``.

        Parameters
        ----------
        entries : list of ProteinEntry
            A list of `ProteinEntry` objects
        only_antibody : bool, default False
            If `True`, antigen chains are left out for entries that have CDR
            annotation (entries without CDR annotation keep all chains)

        Returns
        -------
        plddts_full : list of float
            A list of PLDDT scores averaged over all residues
        plddts_predicted : list of float
            A list of PLDDT scores averaged over predicted residues
        rmsds : list of float
            A list of RMSD values of aligned structures (predicted residues only)
        tm_scores : list of float
            A list of TM scores of aligned structures

        """
        # Chains to fold for each entry; reused for every later step so the
        # selection cannot drift between sequence generation and scoring.
        chains_list = [
            [
                x
                for x in entry.get_chains()
                if not entry.has_cdr()
                or not only_antibody
                or x not in entry.get_chain_type_dict()["antigen"]
            ]
            for entry in entries
        ]
        # One string per entry, chains separated by ":".
        sequences = [
            ":".join(
                entry.get_sequence(chains=[chain], only_known=True)
                for chain in chains
            )
            for entry, chains in zip(entries, chains_list)
        ]
        esmfold_generate(sequences)
        esmfold_paths = [
            os.path.join("esmfold_output", f"seq_{i}.pdb")
            for i in range(len(sequences))
        ]
        plddts_predicted = [
            confidence_from_file(
                path, entry.get_predict_mask(only_known=True, chains=chains)
            )
            for path, entry, chains in zip(esmfold_paths, entries, chains_list)
        ]
        plddts_full = [confidence_from_file(path) for path in esmfold_paths]
        rmsds = []
        tm_scores = []
        for entry, path, chains in zip(entries, esmfold_paths, chains_list):
            esm_entry = ProteinEntry.from_pdb(path)
            # Map the generated chain IDs back to the original ones
            # (assumes the output chains are named A, B, ... in input order
            # -- TODO confirm against `esmfold_generate`).
            esm_entry.rename_chains(dict(zip(string.ascii_uppercase, chains)))
            aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
            temp_file = entry._temp_pdb_file()
            try:
                esm_entry.align_structure(
                    reference_pdb_path=temp_file,
                    save_pdb_path=aligned_path,
                    chain_ids=(
                        entry.get_predicted_chains()
                        if entry.has_predict_mask()
                        else chains
                    ),
                )
            finally:
                # The reference file is only needed during the alignment;
                # it is created with delete=False and would otherwise leak.
                os.remove(temp_file)
            rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
            tm_scores.append(entry.tm_score(esm_entry, chains=chains))
        return plddts_full, plddts_predicted, rmsds, tm_scores

    @staticmethod
    def igfold_metrics(entries, use_openmm=False):
        """Calculate IgFold metrics for a list of entries.

        The sequences of all non-antigen chains are folded with
        `igfold_generate`, the outputs are read from
        ``igfold_output/seq_{i}.pdb`` (``igfold_refine_output`` when
        `use_openmm` is set), aligned to the original entries and saved next
        to them as ``seq_{i}_aligned.pdb``.

        Parameters
        ----------
        entries : list of ProteinEntry
            A list of `ProteinEntry` objects
        use_openmm : bool, default False
            Whether to use refinement with OpenMM

        Returns
        -------
        prmsds_full : list of float
            A list of confidence scores read from the generated files,
            computed over all residues
        prmsds_predicted : list of float
            A list of confidence scores read from the generated files,
            computed with the prediction mask of the entry
        rmsds : list of float
            A list of RMSD values of aligned structures (predicted residues only)
        tm_scores : list of float
            A list of TM scores between each entry and its generated structure

        """
        chains_list = [
            [
                x
                for x in entry.get_chains()
                if x not in entry.get_chain_type_dict()["antigen"]
            ]
            for entry in entries
        ]
        # One {chain ID: sequence} dictionary per entry.
        sequences = [
            {
                chain: entry.get_sequence(chains=[chain], only_known=True)
                for chain in chains
            }
            for entry, chains in zip(entries, chains_list)
        ]
        igfold_generate(sequences, use_openmm=use_openmm)
        folder = "igfold_refine_output" if use_openmm else "igfold_output"
        igfold_paths = [
            os.path.join(folder, f"seq_{i}.pdb") for i in range(len(sequences))
        ]
        prmsds_predicted = [
            confidence_from_file(
                path, entry.get_predict_mask(only_known=True, chains=chains)
            )
            for path, entry, chains in zip(igfold_paths, entries, chains_list)
        ]
        prmsds_full = [confidence_from_file(path) for path in igfold_paths]
        rmsds = []
        tm_scores = []
        for entry, path in zip(entries, igfold_paths):
            igfold_entry = ProteinEntry.from_pdb(path)
            aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
            temp_file = entry._temp_pdb_file()
            try:
                igfold_entry.align_structure(
                    reference_pdb_path=temp_file,
                    save_pdb_path=aligned_path,
                    chain_ids=entry.get_predicted_chains(),
                )
            finally:
                # The reference file is only needed during the alignment;
                # it is created with delete=False and would otherwise leak.
                os.remove(temp_file)
            rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
            # NOTE(review): no `chains` argument is passed here, so the
            # reference side includes every chain of `entry` (antigen
            # included), unlike `immunebuilder_metrics` -- confirm intent.
            tm_scores.append(entry.tm_score(igfold_entry))
        return prmsds_full, prmsds_predicted, rmsds, tm_scores

    @staticmethod
    def immunebuilder_metrics(entries, protein_type="antibody"):
        """Calculate ImmuneBuilder metrics for a list of entries.

        The heavy and / or light chain sequences are folded with
        `immunebuilder_generate`, the outputs are read from
        ``immunebuilder_output/seq_{i}.pdb``, aligned to the original entries
        and saved next to them as ``seq_{i}_aligned.pdb``.

        Parameters
        ----------
        entries : list of ProteinEntry
            A list of `ProteinEntry` objects
        protein_type : {"antibody", "nanobody", "tcr"}, default "antibody"
            The type of the protein

        Returns
        -------
        prmsds_full : list of float
            A list of PRMSD scores averaged over all residues
        prmsds_predicted : list of float
            A list of PRMSD scores averaged over predicted residues
        rmsds : list of float
            A list of RMSD values of aligned structures (predicted residues only)
        tm_scores : list of float
            A list of TM scores of aligned structures

        """
        chains_list = [
            [
                x
                for x in entry.get_chains()
                if x not in entry.get_chain_type_dict()["antigen"]
            ]
            for entry in entries
        ]
        sequences = []
        for entry in entries:
            chain_type_dict = entry.get_chain_type_dict()
            # Keys are "H" / "L", built from the first letter of the chain type.
            sequences.append(
                {
                    key[0].upper(): entry.get_sequence(
                        chains=[chain_type_dict[key]], only_known=True
                    )
                    for key in ["heavy", "light"]
                    if key in chain_type_dict
                }
            )
        immunebuilder_generate(sequences, protein_type=protein_type)
        generated_paths = [
            os.path.join("immunebuilder_output", f"seq_{i}.pdb")
            for i in range(len(sequences))
        ]
        prmsds_predicted = [
            confidence_from_file(
                path, entry.get_predict_mask(only_known=True, chains=chains)
            )
            for path, entry, chains in zip(generated_paths, entries, chains_list)
        ]
        prmsds_full = [confidence_from_file(path) for path in generated_paths]
        rmsds = []
        tm_scores = []
        for entry, path, chains in zip(entries, generated_paths, chains_list):
            generated_entry = ProteinEntry.from_pdb(path)
            chain_type_dict = entry.get_chain_type_dict()
            # Rename the generated "L" / "H" chains to the IDs of the entry.
            chain_rename_dict = {}
            if "light" in chain_type_dict:
                chain_rename_dict["L"] = chain_type_dict["light"]
            if "heavy" in chain_type_dict:
                chain_rename_dict["H"] = chain_type_dict["heavy"]
            generated_entry.rename_chains(chain_rename_dict)
            aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
            temp_file = entry._temp_pdb_file()
            try:
                generated_entry.align_structure(
                    reference_pdb_path=temp_file,
                    save_pdb_path=aligned_path,
                    chain_ids=entry.get_predicted_chains(),
                )
            finally:
                # The reference file is only needed during the alignment;
                # it is created with delete=False and would otherwise leak.
                os.remove(temp_file)
            rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
            tm_scores.append(entry.tm_score(generated_entry, chains=chains))
        return prmsds_full, prmsds_predicted, rmsds, tm_scores

    def align_structure(self, reference_pdb_path, save_pdb_path, chain_ids=None):
        """Align the structure to a reference structure using the CA atoms.

        The entry is written to a temporary PDB file, parsed with
        `Bio.PDB`, superimposed onto the reference and saved to
        `save_pdb_path`. The entry itself is not modified.

        For residues without a CA atom the C atom is used instead (with a
        warning); residues with neither are skipped. The atoms of the two
        structures are paired by position in the collected lists.

        Parameters
        ----------
        reference_pdb_path : str
            Path to the reference structure (in .pdb format)
        save_pdb_path : str
            Path where the aligned structure should be saved (in .pdb format)
        chain_ids : list of str, optional
            If specified, only the chains with the specified IDs are used to
            compute the superposition (the transformation is still applied
            to all atoms of the structure)

        """

        def alignment_atoms(model, label):
            """Collect one alignment atom per residue of the selected chains."""
            atoms = []
            for chain in model:
                if chain_ids is not None and chain.id not in chain_ids:
                    continue
                for residue in chain:
                    if "CA" in residue:
                        atoms.append(residue["CA"])
                    elif "C" in residue:
                        atoms.append(residue["C"])
                        warnings.warn(
                            f"Using a C atom instead of CA for alignment in the {label} structure"
                        )
            return atoms

        pdb_parser = Bio.PDB.PDBParser(QUIET=True)

        temp_file = self._temp_pdb_file()
        try:
            ref_structure = pdb_parser.get_structure("reference", reference_pdb_path)
            sample_structure = pdb_parser.get_structure("sample", temp_file)
        finally:
            # The temporary file is created with delete=False; remove it as
            # soon as parsing is done so that it does not leak.
            os.remove(temp_file)

        ref_model = ref_structure[0]
        sample_model = sample_structure[0]

        ref_atoms = alignment_atoms(ref_model, "reference")
        sample_atoms = alignment_atoms(sample_model, "sample")

        super_imposer = Bio.PDB.Superimposer()
        super_imposer.set_atoms(ref_atoms, sample_atoms)
        super_imposer.apply(sample_model.get_atoms())

        io = Bio.PDB.PDBIO()
        io.set_structure(sample_structure)
        io.save(save_pdb_path)

    @staticmethod
    @requires_extra("MDAnalysis")
    def combine_multiple_frames(files, output_path="combined.pdb"):
        """Write several structures into one multiframe PDB file.

        Files whose name ends with ``.pickle`` are loaded as proteinflow
        pickle files and converted to a temporary PDB file first; all other
        files are passed to MDAnalysis as they are.

        Parameters
        ----------
        files : list of str
            A list of PDB or proteinflow pickle files
        output_path : str, default 'combined.pdb'
            Path to the .pdb output file

        """
        with mda.Writer(output_path, multiframe=True) as writer:
            for file in files:
                pdb_path = (
                    ProteinEntry.from_pickle(file)._temp_pdb_file()
                    if file.endswith(".pickle")
                    else file
                )
                writer.write(mda.Universe(pdb_path))

    def set_predict_mask(self, mask_dict):
        """Validate and store the prediction mask.

        The dictionary is stored as it is (no copy is made) once every
        chain has been checked.

        Parameters
        ----------
        mask_dict : dict
            A dictionary mapping from chain IDs to a `np.ndarray` mask of 0s and 1s of the same length as the chain sequence

        Raises
        ------
        PDBError
            If a chain ID is not present in the entry or a mask does not
            have the same length as its chain

        """
        for chain_id in mask_dict:
            chain_is_known = chain_id in self.get_chains()
            if not chain_is_known:
                raise PDBError("Chain not found")
            mask_length = len(mask_dict[chain_id])
            if mask_length != self.get_length([chain_id]):
                raise PDBError("Mask length does not match sequence length")
        self.predict_mask = mask_dict

    def apply_mask(self, mask):
        """Apply a mask to the protein.

        Parameters
        ----------
        mask : np.ndarray
            A mask of shape `(L,)` where `L` is the length of the protein
            (the chains are concatenated in the order returned by
            `get_chains`); residues with a true / non-zero value are kept

        Returns
        -------
        entry : ProteinEntry
            A new `ProteinEntry` object

        """
        # Coerce to boolean so that 0/1 integer masks select residues
        # instead of being interpreted as fancy indices by numpy.
        mask = np.asarray(mask).astype(bool)
        start = 0
        out_dict = {}
        for chain in self.get_chains():
            length = self.get_length([chain])
            chain_mask = mask[start : start + length]
            start += length
            chain_dict = {
                "seq": self.decode_sequence(
                    self.get_sequence(chains=[chain], encode=True)[chain_mask]
                ),
                "crd_bb": self.get_coordinates(chains=[chain], bb_only=True)[
                    chain_mask
                ],
                # Atoms from index 4 onwards are stored as side chain atoms.
                "crd_sc": self.get_coordinates(chains=[chain])[:, 4:][chain_mask],
                "msk": self.get_mask(chains=[chain])[chain_mask],
            }
            if self.has_cdr():
                chain_dict["cdr"] = self.decode_cdr(
                    self.get_cdr([chain], encode=True)[chain_mask]
                )
            if self.has_predict_mask():
                chain_dict["predict_msk"] = self.predict_mask[chain][chain_mask]
            out_dict[chain] = chain_dict
        if self.id is not None:
            out_dict["protein_id"] = self.id
        return ProteinEntry.from_dict(out_dict)

    def get_protein_class(self):
        """Get the protein class.

        An entry with one chain is a "single_chain". Otherwise every pair of
        chain sequences is compared: if the lengths differ by more than 10%
        or the edit distance exceeds 10% of the longer sequence, the entry
        is a "heteromer"; if no pair differs that much it is a "homomer".

        Returns
        -------
        protein_class : str
            The protein class ("single_chain", "heteromer", "homomer")

        """
        chains = self.get_chains()
        if len(chains) == 1:
            return "single_chain"
        # Compare the sequences, not the chain identifiers.
        sequences = [self.get_sequence(chains=[chain]) for chain in chains]
        for seq1, seq2 in itertools.combinations(sequences, 2):
            shorter, longer = sorted((len(seq1), len(seq2)))
            if longer == 0:
                # Two empty sequences are identical; also avoids dividing by zero.
                continue
            if shorter < 0.9 * longer:
                return "heteromer"
            if edit_distance(seq1, seq2) / longer > 0.1:
                return "heteromer"
        return "homomer"


class PDBEntry:
    """A class for parsing PDB entries."""

    def __init__(self, pdb_path, fasta_path=None, load_ligand=False):
        """Initialize a PDBEntry object.

        If no FASTA path is provided, the sequences will be fully inferred
        from the PDB file.

        Parameters
        ----------
        pdb_path : str
            Path to the PDB file
        fasta_path : str, optional
            Path to the FASTA file
        load_ligand : bool, default False
            If `True`, ligands are retrieved while parsing the structure and
            stored in the `ligands` attribute (the attribute is not set
            otherwise)

        Raises
        ------
        PDBError
            If the structure or FASTA file cannot be found, ligand retrieval
            fails or some chains are missing from the FASTA file

        """
        self.pdb_path = pdb_path
        self.fasta_path = fasta_path
        # "1abc-1.pdb.gz" -> "1abc": drop the extensions and any "-" suffix
        self.pdb_id = os.path.basename(pdb_path).split(".")[0].split("-")[0]
        self.load_ligand = load_ligand
        if load_ligand:
            self.crd_df, self.seq_df, self.ligands = self._parse_structure()
        else:
            self.crd_df, self.seq_df = self._parse_structure()
        try:
            self.fasta_dict = self._parse_fasta()
        except FileNotFoundError as e:
            raise PDBError("FASTA file not found") from e

    @staticmethod
    def from_id(pdb_id, local_folder="."):
        """Initialize a `PDBEntry` object from a PDB Id.

        Downloads the PDB and FASTA files to the local folder.

        Parameters
        ----------
        pdb_id : str
            PDB Id of the protein
        local_folder : str, default '.'
            Folder where the downloaded files will be stored

        Returns
        -------
        entry : PDBEntry
            A `PDBEntry` object

        """
        pdb_path = download_pdb(pdb_id, local_folder)
        fasta_path = download_fasta(pdb_id, local_folder)
        return PDBEntry(pdb_path=pdb_path, fasta_path=fasta_path)

    def _clear_sequence_caches(self):
        """Drop memoized per-chain results after the chains have changed."""
        for method in (self._pdb_sequence, self._align_chain):
            # The fallback `lru_cache` decorator does not cache anything and
            # therefore provides no `cache_clear`.
            cache_clear = getattr(method, "cache_clear", None)
            if cache_clear is not None:
                cache_clear()

    def rename_chains(self, chain_dict):
        """Rename chains in the PDB entry.

        The entry is modified in place. Chains that do not appear in
        `chain_dict` keep their IDs.

        Parameters
        ----------
        chain_dict : dict
            A dictionary mapping from old chain IDs to new chain IDs

        Returns
        -------
        entry : PDBEntry
            A `PDBEntry` object (the same instance)

        """
        chains = self.get_chains()
        # Go through temporary names (the ID repeated five times) so that
        # swaps such as {"A": "B", "B": "A"} do not collide half-way.
        to_temporary = {chain: chain * 5 for chain in chains}
        from_temporary = {
            to_temporary[chain]: chain_dict.get(chain, chain) for chain in chains
        }
        for mapping in (to_temporary, from_temporary):
            self.crd_df["chain_id"] = self.crd_df["chain_id"].replace(mapping)
            self.seq_df["chain_id"] = self.seq_df["chain_id"].replace(mapping)
            self.fasta_dict = {mapping[k]: v for k, v in self.fasta_dict.items()}
        # Cached sequences and alignments are keyed by the old chain IDs.
        self._clear_sequence_caches()
        return self

    def merge(self, entry):
        """Merge two PDB entries.

        The other entry is appended to this one in place; the atom numbers
        are renumbered from 0.

        Parameters
        ----------
        entry : PDBEntry
            A `PDBEntry` object

        Returns
        -------
        entry : PDBEntry
            A `PDBEntry` object (the same instance)

        Raises
        ------
        ValueError
            If a chain ID of `entry` (the part before the first "_") is
            already used in this entry

        """
        if entry.pdb_id != self.pdb_id:
            self.pdb_id = f"{self.pdb_id}+{entry.pdb_id}"
        own_chains = {x.split("_")[0] for x in self.get_chains()}
        for chain in entry.get_chains():
            if chain.split("_")[0] in own_chains:
                raise ValueError("Chain IDs must be unique")
        self.crd_df = pd.concat([self.crd_df, entry.crd_df], ignore_index=True)
        self.seq_df = pd.concat([self.seq_df, entry.seq_df], ignore_index=True)
        self.crd_df.loc[:, "atom_number"] = np.arange(len(self.crd_df))
        self.fasta_dict.update(entry.fasta_dict)
        self._clear_sequence_caches()
        return self

    def _get_relevant_chains(self):
        """Get the chains that are included in the entry."""
        return list(self.seq_df["chain_id"].unique())

    @staticmethod
    def parse_fasta(fasta_path):
        """Read a fasta file.

        Parameters
        ----------
        fasta_path : str
            Path to the fasta file

        Returns
        -------
        out_dict : dict
            A dictionary containing all the (author) chains in a fasta file (keys)
            and their corresponding sequence (values)

        """
        with open(fasta_path) as f:
            lines = f.readlines()

        # Every record runs from its ">" header to the next header (or the end).
        header_indexes = [k for k, line in enumerate(lines) if line.startswith(">")]
        record_ends = header_indexes[1:] + [len(lines)]

        out_dict = {}
        for start, end in zip(header_indexes, record_ends):
            seq = "".join(lines[start + 1 : end]).replace("\n", "")
            for chain in _retrieve_chain_names(lines[start]):
                out_dict[chain] = seq

        return out_dict

    def _parse_fasta(self):
        """Build the chain-to-sequence dictionary for the relevant chains.

        The sequences are read from the FASTA file when a path was given and
        taken from the structure itself otherwise.
        """
        chains = self._get_relevant_chains()
        if self.fasta_path is None:
            seqs_dict = {k: self._pdb_sequence(k, suppress_check=True) for k in chains}
        else:
            seqs_dict = self.parse_fasta(self.fasta_path)
        # retrieve sequences that are relevant for this PDB from the fasta file
        seqs_dict = {k.upper(): v for k, v in seqs_dict.items()}
        # names like "AAA" (one character repeated three times) are reduced to "A"
        if all(len(x) == 3 and len(set(x)) == 1 for x in seqs_dict):
            seqs_dict = {k[0]: v for k, v in seqs_dict.items()}

        if not {x.split("-")[0].upper() for x in chains}.issubset(seqs_dict.keys()):
            raise PDBError("Some chains in the PDB do not appear in the fasta file")

        fasta_dict = {k: seqs_dict[k.split("-")[0].upper()] for k in chains}
        return fasta_dict

    def _parse_structure(self):
        """Parse the structure of the protein.

        Returns the coordinate and sequence dataframes, followed by the
        ligands when `load_ligand` is set.
        """
        cif = self.pdb_path.endswith("cif.gz")
        # load coordinates in a nice format
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if cif:
                    p = CustomMmcif().read_mmcif(self.pdb_path).get_model(1)
                else:
                    p = PandasPdb().read_pdb(self.pdb_path).get_model(1)
        except FileNotFoundError as e:
            raise PDBError("PDB / mmCIF file downloaded but not found") from e
        crd_df = p.df["ATOM"]
        crd_df = crd_df[crd_df["record_name"] == "ATOM"].reset_index()
        if "insertion" in crd_df.columns:
            # residue number + insertion code, e.g. "100_A"
            crd_df["unique_residue_number"] = crd_df.apply(
                lambda row: f"{row['residue_number']}_{row['insertion']}", axis=1
            )
        seq_df = p.amino3to1()

        if self.load_ligand:
            try:
                chain2ligands = _get_ligands(
                    self.pdb_id,
                    p,
                    self.pdb_path,
                )
            except Exception as e:
                raise PDBError("Failed to retrieve ligands") from e
            return crd_df, seq_df, chain2ligands

        return crd_df, seq_df

    def _get_chain(self, chain):
        """Check the chain ID (`None` is passed through unchanged)."""
        if chain is None:
            return chain
        if chain not in self.get_chains():
            raise PDBError("Chain not found")
        return chain

    def get_pdb_df(self, chain=None):
        """Return the PDB dataframe.

        If `chain` is provided, only information for this chain is returned.

        Parameters
        ----------
        chain : str, optional
            Chain identifier

        Returns
        -------
        df : pd.DataFrame
            A `BioPandas` style dataframe containing the PDB information

        """
        chain = self._get_chain(chain)
        if chain is None:
            return self.crd_df
        else:
            return self.crd_df[self.crd_df["chain_id"] == chain]

    def get_sequence_df(self, chain=None, suppress_check=False):
        """Return the sequence dataframe.

        If `chain` is provided, only information for this chain is returned.

        Parameters
        ----------
        chain : str, optional
            Chain identifier
        suppress_check : bool, default False
            If True, do not check if the chain is in the PDB

        Returns
        -------
        df : pd.DataFrame
            A dataframe containing the sequence and chain information
            (analogous to the `BioPandas.pdb.PandasPdb.amino3to1` method output)

        """
        if not suppress_check:
            chain = self._get_chain(chain)
        if chain is None:
            return self.seq_df
        else:
            return self.seq_df[self.seq_df["chain_id"] == chain]

    def get_fasta(self):
        """Return the fasta dictionary.

        Returns
        -------
        fasta_dict : dict
            A dictionary containing all the (author) chains in a fasta file (keys)
            and their corresponding sequence (values)

        """
        return self.fasta_dict

    def get_ligands(self):
        """Return the ligands dictionary.

        The ligands are only stored when the entry was created with
        `load_ligand=True`.

        Returns
        -------
        ligands : dict
            A dictionary containing all the chains in a pdb file (keys)
            and their corresponding processed ligands (values)

        """
        return self.ligands

    def get_chains(self):
        """Return the chains in the PDB.

        Returns
        -------
        chains : list
            A list of chain identifiers

        """
        return list(self.fasta_dict.keys())

    @lru_cache()
    def _pdb_sequence(self, chain, suppress_check=False):
        """Return the PDB sequence for a given chain ID."""
        return "".join(
            self.get_sequence_df(chain, suppress_check=suppress_check)["residue_name"]
        )

    @lru_cache()
    def _align_chain(self, chain):
        """Align the PDB sequence to the FASTA sequence for a given chain ID."""
        chain = self._get_chain(chain)
        pdb_seq = self._pdb_sequence(chain)
        # Global alignment scores: match 2, mismatch -10, gap open -0.5,
        # gap extend -0.1.
        aligned_seq, fasta_seq, *_ = pairwise2.align.globalms(
            pdb_seq, self.fasta_dict[chain], 2, -10, -0.5, -0.1
        )[0]
        # The FASTA sequence must not get gaps and the PDB sequence must be
        # recovered unchanged once the gaps are removed.
        if "-" in fasta_seq or "".join([x for x in aligned_seq if x != "-"]) != pdb_seq:
            raise PDBError("Incorrect alignment")
        return aligned_seq, fasta_seq

    def get_alignment(self, chains=None):
        """Return the alignment between the PDB and the FASTA sequence.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are aligned)

        Returns
        -------
        alignment : dict
            A dictionary containing the aligned sequences for each chain

        """
        if chains is None:
            chains = self.get_chains()
        return {chain: self._align_chain(chain)[0] for chain in chains}

    def get_mask(self, chains=None):
        """Return the mask of the alignment between the PDB and the FASTA sequence.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are aligned)

        Returns
        -------
        mask : dict
            A dictionary containing the `np.ndarray` mask for each chain (0 where the
            aligned sequence has gaps and 1 where it does not)

        """
        alignment = self.get_alignment(chains)
        return {
            chain: (np.array(list(seq)) != "-").astype(int)
            for chain, seq in alignment.items()
        }

    def has_unnatural_amino_acids(self, chains=None):
        """Check if the PDB contains unnatural amino acids.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are checked)

        Returns
        -------
        bool
            True if the PDB contains unnatural amino acids, False otherwise

        """
        if chains is None:
            # `None` makes `get_pdb_df` return the full dataframe
            chains = [None]
        for chain in chains:
            crd = self.get_pdb_df(chain)
            if not crd["residue_name"].isin(D3TO1.keys()).all():
                return True
        return False

    def get_coordinates_array(self, chain):
        """Return the coordinates of the PDB as a numpy array.

        The atom order is the same as in the `ProteinEntry.ATOM_ORDER` dictionary.
        The array has zeros where the mask has zeros and that is where the sequence
        alignment to the FASTA has gaps (unknown coordinates).

        Parameters
        ----------
        chain : str
            Chain identifier

        Returns
        -------
        crd_arr : np.ndarray
            A numpy array of shape (n_residues, 14, 3) containing the coordinates
            of the PDB (zeros where the coordinates are unknown)

        Raises
        ------
        PDBError
            If an unexpected atom is found or the residue numbers are not
            consistent with the sequence

        """
        # work on a copy: the residue numbers are overwritten below
        chain_crd = self.get_pdb_df(chain).copy()

        # align fasta and pdb and check criteria
        mask = self.get_mask([chain])[chain]

        # go over rows of coordinates
        crd_arr = np.zeros((len(mask), 14, 3))

        def arr_index(row):
            atom = row["atom_name"]
            if atom.startswith("H") or atom == "OXT":
                return -1  # ignore hydrogens and OXT
            order = ProteinEntry.ATOM_ORDER[row["residue_name"]]
            try:
                return order.index(atom)
            except ValueError:
                raise PDBError(f"Unexpected atoms ({atom})")

        indices = chain_crd.apply(arr_index, axis=1)
        indices = indices.astype(int)
        informative_mask = indices != -1
        # positions in the aligned sequence that have coordinates
        res_indices = np.where(mask == 1)[0]
        unique_numbers = self.get_unique_residue_numbers(chain)
        pdb_seq = self._pdb_sequence(chain)
        if len(unique_numbers) != len(pdb_seq):
            raise PDBError("Inconsistencies in the biopandas dataframe")
        # map every residue identifier to its row in the output array
        replace_dict = {x: y for x, y in zip(unique_numbers, res_indices)}
        chain_crd.loc[:, "unique_residue_number"] = chain_crd[
            "unique_residue_number"
        ].replace(replace_dict)
        crd_arr[
            chain_crd[informative_mask]["unique_residue_number"].astype(int),
            indices[informative_mask],
        ] = chain_crd[informative_mask][["x_coord", "y_coord", "z_coord"]]
        return crd_arr

    def get_unique_residue_numbers(self, chain):
        """Return the unique residue numbers (residue number + insertion code).

        Parameters
        ----------
        chain : str
            Chain identifier

        Returns
        -------
        unique_numbers : list
            A list of unique residue numbers

        """
        return self.get_pdb_df(chain)["unique_residue_number"].unique().tolist()

    def _get_atom_dicts(
        self,
        highlight_mask_dict=None,
        style="cartoon",
        highlight_style=None,
        opacity=1,
        colors=None,
        accent_color="#D96181",
    ):
        """Get the atom dictionaries for visualization.

        Returns one `_Atom` per row of the coordinate dataframe (sorted by
        chain and residue number) with its style stored under the "pymol"
        key. The parameters are the same as in `visualize`.
        """
        assert style in ["cartoon", "sphere", "stick", "line", "cross"]
        if highlight_style is None:
            highlight_style = style
        assert highlight_style in ["cartoon", "sphere", "stick", "line", "cross"]
        outstr = []
        df_ = self.crd_df.sort_values(["chain_id", "residue_number"], inplace=False)
        for _, row in df_.iterrows():
            outstr.append(_Atom(row))
        chains = self.get_chains()
        if colors is None:
            colors = COLORS
        # cycle through the colors if there are more chains than colors
        colors = {ch: colors[i % len(colors)] for i, ch in enumerate(chains)}
        # running residue count and last seen residue ID for every chain
        chain_counters = defaultdict(int)
        chain_last_res = defaultdict(lambda: None)
        if highlight_mask_dict is not None:
            for chain, mask in highlight_mask_dict.items():
                if chain in self.get_chains():
                    assert len(mask) == len(
                        self._pdb_sequence(chain)
                    ), "Mask length does not match sequence length"
        for at in outstr:
            if isinstance(opacity, dict):
                op_ = opacity[at["chain"]]
            else:
                op_ = opacity
            if at["resid"] != chain_last_res[at["chain"]]:
                # first atom of a new residue
                chain_last_res[at["chain"]] = at["resid"]
                chain_counters[at["chain"]] += 1
            at["pymol"] = {style: {"color": colors[at["chain"]], "opacity": op_}}
            if highlight_mask_dict is not None and at["chain"] in highlight_mask_dict:
                num = chain_counters[at["chain"]]
                if (
                    highlight_mask_dict[at["chain"]][num - 1] == 1
                    and accent_color is not None
                ):
                    at["pymol"] = {
                        highlight_style: {"color": accent_color, "opacity": op_}
                    }
        return outstr

    def visualize(
        self,
        highlight_mask_dict=None,
        style="cartoon",
        highlight_style=None,
        opacity=1,
        colors=None,
        accent_color="#D96181",
        canvas_size=(400, 300),
    ):
        """Visualize the protein in a notebook.

        Parameters
        ----------
        highlight_mask_dict : dict, optional
            A dictionary mapping from chain IDs to a mask of 0s and 1s of the same length as the chain sequence;
            the atoms corresponding to 1s will be highlighted in `accent_color`
        style : str, default 'cartoon'
            The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
        highlight_style : str, optional
            The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
            (defaults to the same as `style`)
        opacity : float or dict, default 1
            Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
        colors : list, optional
            A list of colors to use for different chains
        accent_color : str, optional
            The color of the highlighted atoms (use `None` to disable highlighting)
        canvas_size : tuple, default (400, 300)
            The shape of the canvas

        """
        outstr = self._get_atom_dicts(
            highlight_mask_dict,
            style=style,
            highlight_style=highlight_style,
            opacity=opacity,
            colors=colors,
            accent_color=accent_color,
        )
        vis_string = "".join([str(x) for x in outstr])
        view = _get_view(canvas_size)
        view.addModelsAsFrames(vis_string)
        # style every atom individually (serial numbers start at 1)
        for i, at in enumerate(outstr):
            view.setStyle(
                {"model": -1, "serial": i + 1},
                at["pymol"],
            )
        view.zoomTo()
        view.show()


class SAbDabEntry(PDBEntry):
    """A class for parsing SAbDab entries."""

    def __init__(
        self,
        pdb_path,
        fasta_path,
        heavy_chain=None,
        light_chain=None,
        antigen_chains=None,
    ):
        """Initialize the SAbDabEntry.

        Parameters
        ----------
        pdb_path : str
            Path to the PDB file
        fasta_path : str
            Path to the FASTA file
        heavy_chain : str, optional
            Heavy chain identifier (author chain name)
        light_chain : str, optional
            Light chain identifier (author chain name)
        antigen_chains : list, optional
            List of antigen chain identifiers (author chain names)

        Raises
        ------
        PDBError
            If neither the heavy nor the light chain is provided

        """
        if heavy_chain is None and light_chain is None:
            raise PDBError("At least one chain must be provided")
        # copy so that later changes to the caller's list do not leak into the entry
        antigen_chains = [] if antigen_chains is None else list(antigen_chains)
        self.chain_dict = {
            "heavy": heavy_chain,
            "light": light_chain,
            "antigen": antigen_chains,
        }
        # a missing heavy / light chain must not be registered under the `None` key,
        # otherwise `chain_type(None)` would report a chain type
        self.reverse_chain_dict = {}
        if heavy_chain is not None:
            self.reverse_chain_dict[heavy_chain] = "heavy"
        if light_chain is not None:
            self.reverse_chain_dict[light_chain] = "light"
        for antigen_chain in antigen_chains:
            self.reverse_chain_dict[antigen_chain] = "antigen"
        # the chain dictionaries have to exist before the parent constructor runs
        # because it calls `_get_relevant_chains` while parsing the FASTA data
        super().__init__(pdb_path, fasta_path)

    def _get_relevant_chains(self):
        """Get the chains that are included in the entry.

        The order is heavy chain, light chain, antigen chains; a heavy or light
        chain that was not provided is left out.
        """
        chains = []
        if self.chain_dict["heavy"] is not None:
            chains.append(self.chain_dict["heavy"])
        if self.chain_dict["light"] is not None:
            chains.append(self.chain_dict["light"])
        chains.extend(self.chain_dict["antigen"])
        return chains

    @staticmethod
    def from_id(
        pdb_id,
        local_folder=".",
        light_chain=None,
        heavy_chain=None,
        antigen_chains=None,
    ):
        """Create a SAbDabEntry from a PDB ID.

        Either the light or the heavy chain must be provided.

        Parameters
        ----------
        pdb_id : str
            PDB ID
        local_folder : str, optional
            Local folder to download the PDB and FASTA files
        light_chain : str, optional
            Light chain identifier (author chain name)
        heavy_chain : str, optional
            Heavy chain identifier (author chain name)
        antigen_chains : list, optional
            List of antigen chain identifiers (author chain names)

        Returns
        -------
        entry : SAbDabEntry
            A SAbDabEntry object

        """
        pdb_path = download_pdb(pdb_id, local_folder, sabdab=True)
        fasta_path = download_fasta(pdb_id, local_folder)
        return SAbDabEntry(
            pdb_path=pdb_path,
            fasta_path=fasta_path,
            light_chain=light_chain,
            heavy_chain=heavy_chain,
            antigen_chains=antigen_chains,
        )

    def _get_chain(self, chain):
        """Return the chain identifier.

        The aliases `"heavy"` and `"light"` are translated to the corresponding
        chain identifiers before the regular check of the parent class is applied.
        """
        if chain in ["heavy", "light"]:
            chain = self.chain_dict[chain]
        return super()._get_chain(chain)

    def heavy_chain(self):
        """Return the heavy chain identifier.

        Returns
        -------
        chain : str
            The heavy chain identifier (`None` if no heavy chain was provided)

        """
        return self.chain_dict["heavy"]

    def light_chain(self):
        """Return the light chain identifier.

        Returns
        -------
        chain : str
            The light chain identifier (`None` if no light chain was provided)

        """
        return self.chain_dict["light"]

    def antigen_chains(self):
        """Return the antigen chain identifiers.

        Returns
        -------
        chains : list
            The antigen chain identifiers

        """
        return self.chain_dict["antigen"]

    def chains(self):
        """Return the chains in the PDB.

        The heavy chain comes first, then the light chain, then the antigen chains.
        A heavy or light chain that was not provided is skipped, so the list never
        contains `None`.

        Returns
        -------
        chains : list
            A list of chain identifiers

        """
        return self._get_relevant_chains()

    def chain_type(self, chain):
        """Return the type of a chain.

        Parameters
        ----------
        chain : str
            Chain identifier

        Returns
        -------
        chain_type : str
            The type of the chain (heavy, light or antigen)

        Raises
        ------
        PDBError
            If the chain is not one of the heavy, light or antigen chains of the entry

        """
        if chain in self.reverse_chain_dict:
            return self.reverse_chain_dict[chain]
        raise PDBError("Chain not found")

    @lru_cache()
    def _get_chain_cdr(self, chain, align_to_fasta=True):
        """Return the CDRs for a given chain ID.

        For heavy and light chains the label of every residue is looked up in
        `CDR_VALUES` by its residue number; antigen chains are filled with `"-"`.
        If `align_to_fasta` is `True`, the array is expanded to the length of the
        aligned sequence with `"-"` at the positions of the alignment gaps.
        """
        chain = self._get_chain(chain)
        chain_crd = self.get_pdb_df(chain)
        # "H", "L" or "A"
        chain_type = self.chain_type(chain)[0].upper()
        pdb_seq = self._pdb_sequence(chain)
        unique_numbers = chain_crd["unique_residue_number"].unique()
        if len(unique_numbers) != len(pdb_seq):
            raise PDBError("Inconsistencies in the biopandas dataframe")
        if chain_type in ["H", "L"]:
            # unique residue numbers look like "<number>_<insertion code>"
            cdr_arr = [
                CDR_VALUES[chain_type][int(x.split("_")[0])] for x in unique_numbers
            ]
            cdr_arr = np.array(cdr_arr)
        else:
            cdr_arr = np.array(["-"] * len(unique_numbers), dtype=object)
        if align_to_fasta:
            aligned_seq, _ = self._align_chain(chain)
            aligned_seq_arr = np.array(list(aligned_seq))
            cdr_arr_aligned = np.array(["-"] * len(aligned_seq), dtype=object)
            cdr_arr_aligned[aligned_seq_arr != "-"] = cdr_arr
            cdr_arr = cdr_arr_aligned
        return cdr_arr

    def get_cdr(self, chains=None):
        """Return CDR arrays.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are processed)

        Returns
        -------
        cdrs : dict
            A dictionary containing the CDR arrays for each of the chains

        """
        if chains is None:
            chains = self.chains()
        return {chain: self._get_chain_cdr(chain) for chain in chains}

Sub-modules

proteinflow.data.torch

Subclasses of torch.utils.data.Dataset and torch.utils.data.DataLoader that are tuned for loading proteinflow data.

Functions

def interpolate_coords(crd, mask, fill_ends=True)

Fill in missing values in a coordinates array with linear interpolation.

Parameters

crd : np.ndarray
Coordinates array of shape (L, 4, 3)
mask : np.ndarray
Mask array of shape (L,) where 1 indicates residues with known coordinates and 0 indicates missing values
fill_ends : bool, default True
If True, fill in missing values at the ends of the protein sequence with the edge values; otherwise fill them in with zeros

Returns

crd : np.ndarray
Interpolated coordinates array of shape (L, 4, 3)
mask : np.ndarray
Interpolated mask array of shape (L,) where 1 indicates residues with known or interpolated coordinates and 0 indicates missing values
Expand source code
def interpolate_coords(crd, mask, fill_ends=True):
    """Fill in missing values in a coordinates array with linear interpolation.

    The input arrays are not modified.

    Parameters
    ----------
    crd : np.ndarray
        Coordinates array of shape `(L, 4, 3)`
    mask : np.ndarray
        Mask array of shape `(L,)` where 1 indicates residues with known coordinates and 0
        indicates missing values
    fill_ends : bool, default True
        If `True`, fill in missing values at the ends of the protein sequence with the edge values;
        otherwise fill them in with zeros

    Returns
    -------
    crd : np.ndarray
        Interpolated coordinates array of shape `(L, 4, 3)`
    mask : np.ndarray
        Interpolated mask array of shape `(L,)` where 1 indicates residues with known or interpolated
        coordinates and 0 indicates missing values

    """
    mask = np.asarray(mask)
    # work on a copy so that the caller's coordinates are not overwritten with NaNs
    crd = np.array(crd, copy=True)
    crd[(1 - mask).astype(bool)] = np.nan
    df = pd.DataFrame(crd.reshape((crd.shape[0], -1)))
    # `limit_direction="both"` is needed to fill the leading gap as well as the
    # trailing one; `limit_area="inside"` keeps both ends untouched instead
    interpolated = df.interpolate(
        limit_direction="both",
        limit_area=None if fill_ends else "inside",
    )
    # explicit copy: the array is modified below and must be writable
    crd = interpolated.to_numpy(copy=True).reshape(crd.shape)
    # whatever is still NaN could not be interpolated (unfilled ends or no known residues)
    nan_mask = np.isnan(crd)
    interpolated_mask = np.zeros_like(mask)
    interpolated_mask[~nan_mask[:, 0, 0]] = 1
    crd[nan_mask] = 0
    return crd, interpolated_mask
def lru_cache()

Make a dummy decorator.

Expand source code
def lru_cache():
    """Return a decorator that leaves the decorated function untouched.

    Used as a stand-in with the same call signature when no caching
    implementation is available.
    """

    def passthrough(target):
        # no caching: hand the function back unchanged
        return target

    return passthrough

Classes

class PDBEntry (pdb_path, fasta_path=None, load_ligand=False)

A class for parsing PDB entries.

Initialize a PDBEntry object.

If no FASTA path is provided, the sequences will be fully inferred from the PDB file.

Parameters

pdb_path : str
Path to the PDB file
fasta_path : str, optional
Path to the FASTA file
Expand source code
class PDBEntry:
    """A class for parsing PDB entries."""

    def __init__(self, pdb_path, fasta_path=None, load_ligand=False):
        """Initialize a PDBEntry object.

        If no FASTA path is provided, the sequences will be fully inferred
        from the PDB file.

        Parameters
        ----------
        pdb_path : str
            Path to the PDB file
        fasta_path : str, optional
            Path to the FASTA file
        load_ligand : bool, default False
            If `True`, ligands are extracted while parsing the structure and stored
            in the `ligands` attribute

        Raises
        ------
        PDBError
            If the structure or the FASTA file is not found, if ligands cannot be
            retrieved or if some chains do not appear in the FASTA file

        """
        self.pdb_path = pdb_path
        self.fasta_path = fasta_path
        # the ID is the file name up to the first "." or "-"
        self.pdb_id = os.path.basename(pdb_path).split(".")[0].split("-")[0]
        self.load_ligand = load_ligand
        if load_ligand:
            self.crd_df, self.seq_df, self.ligands = self._parse_structure()
        else:
            self.crd_df, self.seq_df = self._parse_structure()
        try:
            self.fasta_dict = self._parse_fasta()
        except FileNotFoundError as err:
            raise PDBError("FASTA file not found") from err

    @staticmethod
    def from_id(pdb_id, local_folder="."):
        """Initialize a `PDBEntry` object from a PDB Id.

        Downloads the PDB and FASTA files to the local folder.

        Parameters
        ----------
        pdb_id : str
            PDB Id of the protein
        local_folder : str, default '.'
            Folder where the downloaded files will be stored

        Returns
        -------
        entry : PDBEntry
            A `PDBEntry` object

        """
        pdb_path = download_pdb(pdb_id, local_folder)
        fasta_path = download_fasta(pdb_id, local_folder)
        return PDBEntry(pdb_path=pdb_path, fasta_path=fasta_path)

    def rename_chains(self, chain_dict):
        """Rename chains in the PDB entry.

        The entry is modified in place. Chains that are not mentioned in
        `chain_dict` keep their current names.

        Parameters
        ----------
        chain_dict : dict
            A dictionary mapping from old chain IDs to new chain IDs

        Returns
        -------
        entry : PDBEntry
            A `PDBEntry` object

        """
        current_chains = self.get_chains()
        # The renaming goes through temporary names (the old ID repeated five times)
        # so that swaps like {"A": "B", "B": "A"} do not collide half-way.
        to_temporary = {chain: chain * 5 for chain in current_chains}
        from_temporary = {
            to_temporary[chain]: chain_dict.get(chain, chain)
            for chain in current_chains
        }
        for mapping in (to_temporary, from_temporary):
            self.crd_df["chain_id"] = self.crd_df["chain_id"].replace(mapping)
            self.seq_df["chain_id"] = self.seq_df["chain_id"].replace(mapping)
            self.fasta_dict = {mapping[k]: v for k, v in self.fasta_dict.items()}
        return self

    def merge(self, entry):
        """Merge two PDB entries.

        The chains of `entry` are added to this entry in place. Nothing is
        modified if the chain IDs clash.

        Parameters
        ----------
        entry : PDBEntry
            A `PDBEntry` object

        Returns
        -------
        entry : PDBEntry
            A `PDBEntry` object

        Raises
        ------
        ValueError
            If a chain of `entry` shares its ID (the part before the first `"_"`)
            with a chain of this entry

        """
        # validate before touching any state so that a failed merge leaves the entry intact
        own_ids = {x.split("_")[0] for x in self.get_chains()}
        for chain in entry.get_chains():
            if chain.split("_")[0] in own_ids:
                raise ValueError("Chain IDs must be unique")
        if entry.pdb_id != self.pdb_id:
            self.pdb_id = f"{self.pdb_id}+{entry.pdb_id}"
        self.crd_df = pd.concat([self.crd_df, entry.crd_df], ignore_index=True)
        self.seq_df = pd.concat([self.seq_df, entry.seq_df], ignore_index=True)
        # renumber the atoms so that the numbers stay unique after concatenation
        self.crd_df.loc[:, "atom_number"] = np.arange(len(self.crd_df))
        self.fasta_dict.update(entry.fasta_dict)
        return self

    def _get_relevant_chains(self):
        """Get the chains that are included in the entry."""
        return list(self.seq_df["chain_id"].unique())

    @staticmethod
    def parse_fasta(fasta_path):
        """Read a fasta file.

        Lines that precede the first header line are ignored; a file without any
        header yields an empty dictionary.

        Parameters
        ----------
        fasta_path : str
            Path to the fasta file

        Returns
        -------
        out_dict : dict
            A dictionary containing all the (author) chains in a fasta file (keys)
            and their corresponding sequence (values)

        """
        # list of (header line, sequence lines) pairs in file order
        records = []
        with open(fasta_path) as f:
            for line in f:
                if line.startswith(">"):
                    records.append((line, []))
                elif records:
                    records[-1][1].append(line)

        out_dict = {}
        for name, seq_lines in records:
            seq = "".join(seq_lines).replace("\n", "")
            # one header can describe several chains that share the sequence
            for chain in _retrieve_chain_names(name):
                out_dict[chain] = seq

        return out_dict

    def _parse_fasta(self):
        """Parse the fasta file.

        Returns a dictionary mapping every relevant chain to its sequence. Without
        a FASTA path the sequences are taken from the structure itself.
        """
        chains = self._get_relevant_chains()
        if self.fasta_path is None:
            seqs_dict = {k: self._pdb_sequence(k, suppress_check=True) for k in chains}
        else:
            seqs_dict = self.parse_fasta(self.fasta_path)
        # retrieve sequences that are relevant for this PDB from the fasta file
        seqs_dict = {k.upper(): v for k, v in seqs_dict.items()}
        # if every name is one letter repeated three times (e.g. "AAA"), reduce it to that letter
        if all(len(x) == 3 and len(set(x)) == 1 for x in seqs_dict):
            seqs_dict = {k[0]: v for k, v in seqs_dict.items()}

        if not {x.split("-")[0].upper() for x in chains}.issubset(seqs_dict):
            raise PDBError("Some chains in the PDB do not appear in the fasta file")

        fasta_dict = {k: seqs_dict[k.split("-")[0].upper()] for k in chains}
        return fasta_dict

    def _parse_structure(self):
        """Parse the structure of the protein.

        Returns the coordinates dataframe and the sequence dataframe, followed by
        the ligands dictionary if `load_ligand` is set.
        """
        cif = self.pdb_path.endswith("cif.gz")
        # load coordinates in a nice format
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if cif:
                    p = CustomMmcif().read_mmcif(self.pdb_path).get_model(1)
                else:
                    p = PandasPdb().read_pdb(self.pdb_path).get_model(1)
        except FileNotFoundError as err:
            raise PDBError("PDB / mmCIF file downloaded but not found") from err
        crd_df = p.df["ATOM"]
        crd_df = crd_df[crd_df["record_name"] == "ATOM"].reset_index()
        if "insertion" in crd_df.columns:
            # residue number + insertion code, e.g. "100_A" ("100_" without a code)
            crd_df["unique_residue_number"] = crd_df.apply(
                lambda row: f"{row['residue_number']}_{row['insertion']}", axis=1
            )
        seq_df = p.amino3to1()

        if self.load_ligand:
            try:
                chain2ligands = _get_ligands(
                    self.pdb_id,
                    p,
                    self.pdb_path,
                )
            except Exception as err:
                raise PDBError("Failed to retrieve ligands") from err
            return crd_df, seq_df, chain2ligands

        return crd_df, seq_df

    def _get_chain(self, chain):
        """Check the chain ID.

        `None` is passed through unchanged; any other value has to be one of the
        chains of the entry, otherwise a `PDBError` is raised.
        """
        if chain is None:
            return chain
        if chain not in self.get_chains():
            raise PDBError("Chain not found")
        return chain

    def get_pdb_df(self, chain=None):
        """Return the PDB dataframe.

        If `chain` is provided, only information for this chain is returned.

        Parameters
        ----------
        chain : str, optional
            Chain identifier

        Returns
        -------
        df : pd.DataFrame
            A `BioPandas` style dataframe containing the PDB information

        """
        chain = self._get_chain(chain)
        if chain is None:
            return self.crd_df
        else:
            return self.crd_df[self.crd_df["chain_id"] == chain]

    def get_sequence_df(self, chain=None, suppress_check=False):
        """Return the sequence dataframe.

        If `chain` is provided, only information for this chain is returned.

        Parameters
        ----------
        chain : str, optional
            Chain identifier
        suppress_check : bool, default False
            If True, do not check if the chain is in the PDB

        Returns
        -------
        df : pd.DataFrame
            A dataframe containing the sequence and chain information
            (analogous to the `BioPandas.pdb.PandasPdb.amino3to1` method output)

        """
        if not suppress_check:
            chain = self._get_chain(chain)
        if chain is None:
            return self.seq_df
        else:
            return self.seq_df[self.seq_df["chain_id"] == chain]

    def get_fasta(self):
        """Return the fasta dictionary.

        Returns
        -------
        fasta_dict : dict
            A dictionary containing all the (author) chains in a fasta file (keys)
            and their corresponding sequence (values)

        """
        return self.fasta_dict

    def get_ligands(self):
        """Return the ligands dictionary.

        The `ligands` attribute is only set when the entry was created with
        `load_ligand=True`.

        Returns
        -------
        ligands : dict
            A dictionary containing all the chains in a pdb file (keys)
            and their corresponding processed ligands (values)

        """
        return self.ligands

    def get_chains(self):
        """Return the chains in the PDB.

        Returns
        -------
        chains : list
            A list of chain identifiers

        """
        return list(self.fasta_dict.keys())

    @lru_cache()
    def _pdb_sequence(self, chain, suppress_check=False):
        """Return the PDB sequence for a given chain ID."""
        return "".join(
            self.get_sequence_df(chain, suppress_check=suppress_check)["residue_name"]
        )

    @lru_cache()
    def _align_chain(self, chain):
        """Align the PDB sequence to the FASTA sequence for a given chain ID.

        Returns the aligned PDB sequence (with `"-"` at the gaps) and the FASTA
        sequence. A `PDBError` is raised if the FASTA sequence gets gaps or the
        aligned sequence does not reproduce the PDB sequence.
        """
        chain = self._get_chain(chain)
        pdb_seq = self._pdb_sequence(chain)
        # scores: match 2, mismatch -10, gap opening -0.5, gap extension -0.1
        aligned_seq, fasta_seq, *_ = pairwise2.align.globalms(
            pdb_seq, self.fasta_dict[chain], 2, -10, -0.5, -0.1
        )[0]
        if "-" in fasta_seq or "".join([x for x in aligned_seq if x != "-"]) != pdb_seq:
            raise PDBError("Incorrect alignment")
        return aligned_seq, fasta_seq

    def get_alignment(self, chains=None):
        """Return the alignment between the PDB and the FASTA sequence.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are aligned)

        Returns
        -------
        alignment : dict
            A dictionary containing the aligned sequences for each chain

        """
        if chains is None:
            # `get_chains` is the accessor this class defines for its chain list
            chains = self.get_chains()
        return {chain: self._align_chain(chain)[0] for chain in chains}

    def get_mask(self, chains=None):
        """Return the mask of the alignment between the PDB and the FASTA sequence.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are aligned)

        Returns
        -------
        mask : dict
            A dictionary containing the `np.ndarray` mask for each chain (0 where the
            aligned sequence has gaps and 1 where it does not)

        """
        alignment = self.get_alignment(chains)
        return {
            chain: (np.array(list(seq)) != "-").astype(int)
            for chain, seq in alignment.items()
        }

    def has_unnatural_amino_acids(self, chains=None):
        """Check if the PDB contains unnatural amino acids.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are checked)

        Returns
        -------
        bool
            True if the PDB contains unnatural amino acids, False otherwise

        """
        if chains is None:
            # `None` selects the full dataframe in `get_pdb_df`
            chains = [None]
        for chain in chains:
            crd = self.get_pdb_df(chain)
            if not crd["residue_name"].isin(D3TO1.keys()).all():
                return True
        return False

    def get_coordinates_array(self, chain):
        """Return the coordinates of the PDB as a numpy array.

        The atom order is the same as in the `ProteinEntry.ATOM_ORDER` dictionary.
        The array has zeros where the mask has zeros and that is where the sequence
        alignment to the FASTA has gaps (unknown coordinates).

        Parameters
        ----------
        chain : str
            Chain identifier

        Returns
        -------
        crd_arr : np.ndarray
            A numpy array of shape (n_residues, 14, 3) containing the coordinates
            of the PDB (zeros where the coordinates are unknown)

        Raises
        ------
        PDBError
            If an atom is not expected for its residue or the number of unique
            residue numbers does not match the sequence length

        """
        chain_crd = self.get_pdb_df(chain)

        # align fasta and pdb and check criteria
        mask = self.get_mask([chain])[chain]

        # go over rows of coordinates
        crd_arr = np.zeros((len(mask), 14, 3))

        def arr_index(row):
            atom = row["atom_name"]
            if atom.startswith("H") or atom == "OXT":
                return -1  # ignore hydrogens and OXT
            order = ProteinEntry.ATOM_ORDER[row["residue_name"]]
            try:
                return order.index(atom)
            except ValueError as err:
                raise PDBError(f"Unexpected atoms ({atom})") from err

        indices = chain_crd.apply(arr_index, axis=1)
        indices = indices.astype(int)
        informative_mask = indices != -1
        # positions in the aligned sequence that are present in the structure
        res_indices = np.where(mask == 1)[0]
        unique_numbers = self.get_unique_residue_numbers(chain)
        pdb_seq = self._pdb_sequence(chain)
        if len(unique_numbers) != len(pdb_seq):
            raise PDBError("Inconsistencies in the biopandas dataframe")
        replace_dict = dict(zip(unique_numbers, res_indices))
        # map into a separate series instead of writing into the filtered dataframe,
        # which is a slice of `self.crd_df`
        residue_rows = chain_crd["unique_residue_number"].map(replace_dict)
        crd_arr[
            residue_rows[informative_mask].astype(int),
            indices[informative_mask],
        ] = chain_crd[informative_mask][["x_coord", "y_coord", "z_coord"]]
        return crd_arr

    def get_unique_residue_numbers(self, chain):
        """Return the unique residue numbers (residue number + insertion code).

        Parameters
        ----------
        chain : str
            Chain identifier

        Returns
        -------
        unique_numbers : list
            A list of unique residue numbers

        """
        return self.get_pdb_df(chain)["unique_residue_number"].unique().tolist()

    def _get_atom_dicts(
        self,
        highlight_mask_dict=None,
        style="cartoon",
        highlight_style=None,
        opacity=1,
        colors=None,
        accent_color="#D96181",
    ):
        """Get the atom dictionaries for visualization.

        Returns a list of `_Atom` objects (sorted by chain and residue number),
        each with a `"pymol"` entry holding the style of the atom. The parameters
        are the same as in `visualize`.
        """
        assert style in ["cartoon", "sphere", "stick", "line", "cross"]
        if highlight_style is None:
            highlight_style = style
        assert highlight_style in ["cartoon", "sphere", "stick", "line", "cross"]
        df_ = self.crd_df.sort_values(["chain_id", "residue_number"], inplace=False)
        outstr = [_Atom(row) for _, row in df_.iterrows()]
        chains = self.get_chains()
        if colors is None:
            colors = COLORS
        # cycle through the colors if there are more chains than colors
        colors = {ch: colors[i % len(colors)] for i, ch in enumerate(chains)}
        # 1-based index of the current residue within each chain
        chain_counters = defaultdict(int)
        chain_last_res = defaultdict(lambda: None)
        if highlight_mask_dict is not None:
            for chain, mask in highlight_mask_dict.items():
                if chain in chains:
                    assert len(mask) == len(
                        self._pdb_sequence(chain)
                    ), "Mask length does not match sequence length"
        for at in outstr:
            if isinstance(opacity, dict):
                op_ = opacity[at["chain"]]
            else:
                op_ = opacity
            # a change of the residue ID marks the start of the next residue
            if at["resid"] != chain_last_res[at["chain"]]:
                chain_last_res[at["chain"]] = at["resid"]
                chain_counters[at["chain"]] += 1
            at["pymol"] = {style: {"color": colors[at["chain"]], "opacity": op_}}
            if highlight_mask_dict is not None and at["chain"] in highlight_mask_dict:
                num = chain_counters[at["chain"]]
                if (
                    highlight_mask_dict[at["chain"]][num - 1] == 1
                    and accent_color is not None
                ):
                    at["pymol"] = {
                        highlight_style: {"color": accent_color, "opacity": op_}
                    }
        return outstr

    def visualize(
        self,
        highlight_mask_dict=None,
        style="cartoon",
        highlight_style=None,
        opacity=1,
        colors=None,
        accent_color="#D96181",
        canvas_size=(400, 300),
    ):
        """Visualize the protein in a notebook.

        Parameters
        ----------
        highlight_mask_dict : dict, optional
            A dictionary mapping from chain IDs to a mask of 0s and 1s of the same length as the chain sequence;
            the atoms corresponding to 1s will be highlighted in `accent_color`
        style : str, default 'cartoon'
            The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
        highlight_style : str, optional
            The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
            (defaults to the same as `style`)
        opacity : float or dict, default 1
            Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
        colors : list, optional
            A list of colors to use for different chains
        accent_color : str, optional
            The color of the highlighted atoms (use `None` to disable highlighting)
        canvas_size : tuple, default (400, 300)
            The shape of the canvas

        """
        outstr = self._get_atom_dicts(
            highlight_mask_dict,
            style=style,
            highlight_style=highlight_style,
            opacity=opacity,
            colors=colors,
            accent_color=accent_color,
        )
        vis_string = "".join([str(x) for x in outstr])
        view = _get_view(canvas_size)
        view.addModelsAsFrames(vis_string)
        # atom serial numbers are 1-based and follow the order of `outstr`
        for i, at in enumerate(outstr):
            view.setStyle(
                {"model": -1, "serial": i + 1},
                at["pymol"],
            )
        view.zoomTo()
        view.show()

Subclasses

Static methods

def from_id(pdb_id, local_folder='.')

Initialize a PDBEntry object from a PDB Id.

Downloads the PDB and FASTA files to the local folder.

Parameters

pdb_id : str
PDB Id of the protein
local_folder : str, default '.'
Folder where the downloaded files will be stored

Returns

entry : PDBEntry
A PDBEntry object
Expand source code
@staticmethod
def from_id(pdb_id, local_folder="."):
    """Build a `PDBEntry` for a PDB Id.

    The PDB and FASTA files are downloaded to the local folder first.

    Parameters
    ----------
    pdb_id : str
        PDB Id of the protein
    local_folder : str, default '.'
        Folder where the downloaded files will be stored

    Returns
    -------
    entry : PDBEntry
        A `PDBEntry` object

    """
    # the structure file is fetched first, then the FASTA file
    downloaded = {
        "pdb_path": download_pdb(pdb_id, local_folder),
        "fasta_path": download_fasta(pdb_id, local_folder),
    }
    return PDBEntry(**downloaded)
def parse_fasta(fasta_path)

Read a fasta file.

Parameters

fasta_path : str
Path to the fasta file

Returns

out_dict : dict
A dictionary containing all the (author) chains in a fasta file (keys) and their corresponding sequence (values)
Expand source code
@staticmethod
def parse_fasta(fasta_path):
    """Read a fasta file.

    Lines that precede the first header line are ignored; a file without any
    header yields an empty dictionary.

    Parameters
    ----------
    fasta_path : str
        Path to the fasta file

    Returns
    -------
    out_dict : dict
        A dictionary containing all the (author) chains in a fasta file (keys)
        and their corresponding sequence (values)

    """
    # list of (header line, sequence lines) pairs in file order
    records = []
    with open(fasta_path) as f:
        for line in f:
            if line.startswith(">"):
                records.append((line, []))
            elif records:
                records[-1][1].append(line)

    out_dict = {}
    for name, seq_lines in records:
        seq = "".join(seq_lines).replace("\n", "")
        # one header can describe several chains that share the sequence
        for chain in _retrieve_chain_names(name):
            out_dict[chain] = seq

    return out_dict

Methods

def get_alignment(self, chains=None)

Return the alignment between the PDB and the FASTA sequence.

Parameters

chains : list, optional
A list of chain identifiers (if not provided, all chains are aligned)

Returns

alignment : dict
A dictionary containing the aligned sequences for each chain
Expand source code
def get_alignment(self, chains=None):
    """Return the alignment between the PDB and the FASTA sequence.

    Parameters
    ----------
    chains : list, optional
        A list of chain identifiers (if not provided, all chains are aligned)

    Returns
    -------
    alignment : dict
        A dictionary containing the aligned sequences for each chain

    """
    if chains is None:
        # `get_chains` is the accessor `PDBEntry` defines for its chain list
        chains = self.get_chains()
    return {chain: self._align_chain(chain)[0] for chain in chains}
def get_chains(self)

Return the chains in the PDB.

Returns

chains : list
A list of chain identifiers
Expand source code
def get_chains(self):
    """List the chain identifiers of the entry.

    Returns
    -------
    chains : list
        A list of chain identifiers (the keys of the fasta dictionary)

    """
    return [chain_id for chain_id in self.fasta_dict]
def get_coordinates_array(self, chain)

Return the coordinates of the PDB as a numpy array.

The atom order is the same as in the ProteinEntry.ATOM_ORDER dictionary. The array has zeros where the mask has zeros and that is where the sequence alignment to the FASTA has gaps (unknown coordinates).

Parameters

chain : str
Chain identifier

Returns

crd_arr : np.ndarray
A numpy array of shape (n_residues, 14, 3) containing the coordinates of the PDB (zeros where the coordinates are unknown)
Expand source code
def get_coordinates_array(self, chain):
    """Return the coordinates of the PDB as a numpy array.

    The atom order is the same as in the `ProteinEntry.ATOM_ORDER` dictionary.
    The array has zeros where the mask has zeros and that is where the sequence
    alignment to the FASTA has gaps (unknown coordinates).

    Parameters
    ----------
    chain : str
        Chain identifier

    Returns
    -------
    crd_arr : np.ndarray
        A numpy array of shape (n_residues, 14, 3) containing the coordinates
        of the PDB (zeros where the coordinates are unknown)

    Raises
    ------
    PDBError
        If an atom name is not listed in `ProteinEntry.ATOM_ORDER` for its
        residue, or if the number of unique residue numbers differs from the
        length of the PDB sequence
    KeyError
        If a residue name is not a key of `ProteinEntry.ATOM_ORDER`

    """
    # atom rows that belong to the requested chain
    chain_crd = self.get_pdb_df(chain)

    # alignment mask of this chain: 1 where the aligned sequence has a residue, 0 at gaps
    mask = self.get_mask([chain])[chain]

    # output buffer, one row per aligned position; positions that are never written stay zero
    crd_arr = np.zeros((len(mask), 14, 3))

    def arr_index(row):
        # map one atom row to its slot along the atom axis (-1 = skip this row)
        atom = row["atom_name"]
        if atom.startswith("H") or atom == "OXT":
            return -1  # ignore hydrogens and OXT
        order = ProteinEntry.ATOM_ORDER[row["residue_name"]]
        try:
            return order.index(atom)
        except ValueError:
            raise PDBError(f"Unexpected atoms ({atom})")

    indices = chain_crd.apply(arr_index, axis=1)
    indices = indices.astype(int)
    # rows that carry a usable atom (everything except hydrogens and OXT)
    informative_mask = indices != -1
    # positions in the aligned sequence that are not gaps
    res_indices = np.where(mask == 1)[0]
    unique_numbers = self.get_unique_residue_numbers(chain)
    pdb_seq = self._pdb_sequence(chain)
    if len(unique_numbers) != len(pdb_seq):
        raise PDBError("Inconsistencies in the biopandas dataframe")
    # pair the i-th unique residue number with the i-th non-gap aligned position
    replace_dict = {x: y for x, y in zip(unique_numbers, res_indices)}
    # NOTE(review): `chain_crd` may be a filtered selection of `self.crd_df`, so this
    # assignment can trigger a pandas SettingWithCopyWarning -- confirm it is intended
    chain_crd.loc[:, "unique_residue_number"] = chain_crd[
        "unique_residue_number"
    ].replace(replace_dict)
    # scatter the xyz values into (residue position, atom slot)
    crd_arr[
        chain_crd[informative_mask]["unique_residue_number"].astype(int),
        indices[informative_mask],
    ] = chain_crd[informative_mask][["x_coord", "y_coord", "z_coord"]]
    return crd_arr
def get_fasta(self)

Return the fasta dictionary.

Returns

fasta_dict : dict
A dictionary containing all the (author) chains in a fasta file (keys) and their corresponding sequence (values)
Expand source code
def get_fasta(self):
    """Give access to the FASTA sequences of the entry.

    Returns
    -------
    fasta_dict : dict
        The stored dictionary that maps the (author) chains of the fasta file
        to their sequences (the object itself, not a copy)

    """
    fasta = self.fasta_dict
    return fasta
def get_ligands(self)

Return the ligands dictionary.

Returns

ligands : dict
A dictionary containing all the chains in a pdb file (keys) and their corresponding processed ligands (values)
Expand source code
def get_ligands(self):
    """Give access to the ligands stored on the entry.

    Returns
    -------
    ligands : dict
        The stored dictionary that maps the chains of the pdb file to their
        processed ligands (the object itself, not a copy)

    """
    stored_ligands = self.ligands
    return stored_ligands
def get_mask(self, chains=None)

Return the mask of the alignment between the PDB and the FASTA sequence.

Parameters

chains : list, optional
A list of chain identifiers (if not provided, all chains are aligned)

Returns

mask : dict
A dictionary containing the np.ndarray mask for each chain (0 where the aligned sequence has gaps and 1 where it does not)
Expand source code
def get_mask(self, chains=None):
    """Build gap masks from the PDB-to-FASTA alignment.

    Parameters
    ----------
    chains : list, optional
        A list of chain identifiers (passed on to `get_alignment`; if not
        provided, all chains are aligned)

    Returns
    -------
    mask : dict
        For every aligned chain, an integer `np.ndarray` holding 0 at the
        positions where the aligned sequence has a gap (`"-"`) and 1 elsewhere

    """
    masks = {}
    for chain, aligned in self.get_alignment(chains).items():
        not_gap = np.array(list(aligned)) != "-"
        masks[chain] = not_gap.astype(int)
    return masks
def get_pdb_df(self, chain=None)

Return the PDB dataframe.

If chain is provided, only information for this chain is returned.

Parameters

chain : str, optional
Chain identifier

Returns

df : pd.DataFrame
A BioPandas style dataframe containing the PDB information
Expand source code
def get_pdb_df(self, chain=None):
    """Return the coordinates dataframe, optionally restricted to one chain.

    Parameters
    ----------
    chain : str, optional
        Chain identifier (it is first passed through `_get_chain`)

    Returns
    -------
    df : pd.DataFrame
        The `BioPandas` style dataframe with the PDB information; the whole
        dataframe when no chain is selected, otherwise only the rows whose
        `chain_id` equals the chain

    """
    chain = self._get_chain(chain)
    if chain is not None:
        in_chain = self.crd_df["chain_id"] == chain
        return self.crd_df[in_chain]
    return self.crd_df
def get_sequence_df(self, chain=None, suppress_check=False)

Return the sequence dataframe.

If chain is provided, only information for this chain is returned.

Parameters

chain : str, optional
Chain identifier
suppress_check : bool, default False
If True, do not check if the chain is in the PDB

Returns

df : pd.DataFrame
A dataframe containing the sequence and chain information (analogous to the BioPandas.pdb.PandasPdb.amino3to1 method output)
Expand source code
def get_sequence_df(self, chain=None, suppress_check=False):
    """Return the sequence dataframe, optionally restricted to one chain.

    Parameters
    ----------
    chain : str, optional
        Chain identifier
    suppress_check : bool, default False
        If True, the chain is used as given; otherwise it is first passed
        through `_get_chain`

    Returns
    -------
    df : pd.DataFrame
        The dataframe with the sequence and chain information (analogous to
        the `BioPandas.pdb.PandasPdb.amino3to1` method output); the whole
        dataframe when no chain is selected, otherwise only the rows whose
        `chain_id` equals the chain

    """
    if not suppress_check:
        chain = self._get_chain(chain)
    if chain is not None:
        in_chain = self.seq_df["chain_id"] == chain
        return self.seq_df[in_chain]
    return self.seq_df
def get_unique_residue_numbers(self, chain)

Return the unique residue numbers (residue number + insertion code).

Parameters

chain : str
Chain identifier

Returns

unique_numbers : list
A list of unique residue numbers
Expand source code
def get_unique_residue_numbers(self, chain):
    """List the distinct unique residue numbers (residue number + insertion code) of a chain.

    Parameters
    ----------
    chain : str
        Chain identifier

    Returns
    -------
    unique_numbers : list
        The distinct values of the `unique_residue_number` column of the
        chain's dataframe, in order of first appearance

    """
    residue_numbers = self.get_pdb_df(chain)["unique_residue_number"]
    return residue_numbers.unique().tolist()
def has_unnatural_amino_acids(self, chains=None)

Check if the PDB contains unnatural amino acids.

Parameters

chains : list, optional
A list of chain identifiers (if not provided, all chains are checked)

Returns

bool
True if the PDB contains unnatural amino acids, False otherwise
Expand source code
def has_unnatural_amino_acids(self, chains=None):
    """Tell whether any residue name falls outside the known amino acid set.

    Parameters
    ----------
    chains : list, optional
        A list of chain identifiers (if not provided, the whole dataframe is
        checked at once)

    Returns
    -------
    bool
        True as soon as one checked dataframe has a `residue_name` that is not
        a key of `D3TO1`, False otherwise

    """
    targets = [None] if chains is None else chains
    return any(
        not self.get_pdb_df(chain)["residue_name"].isin(D3TO1.keys()).all()
        for chain in targets
    )
def merge(self, entry)

Merge two PDB entries.

Parameters

entry : PDBEntry
A PDBEntry object

Returns

entry : PDBEntry
A PDBEntry object
Expand source code
def merge(self, entry):
    """Merge another PDB entry into this one (in place).

    All checks run before anything is modified, so a rejected merge leaves
    this entry exactly as it was.

    Parameters
    ----------
    entry : PDBEntry
        A `PDBEntry` object whose chains are added to this entry

    Returns
    -------
    entry : PDBEntry
        This `PDBEntry` object, now also containing the chains of `entry`

    Raises
    ------
    ValueError
        If a chain of `entry` shares its ID (the part before the first `"_"`)
        with a chain of this entry

    """
    # chain IDs are compared on the part before the first underscore
    taken = {x.split("_")[0] for x in self.get_chains()}
    if any(chain.split("_")[0] in taken for chain in entry.get_chains()):
        raise ValueError("Chain IDs must be unique")
    if entry.pdb_id != self.pdb_id:
        self.pdb_id = f"{self.pdb_id}+{entry.pdb_id}"
    self.crd_df = pd.concat([self.crd_df, entry.crd_df], ignore_index=True)
    self.seq_df = pd.concat([self.seq_df, entry.seq_df], ignore_index=True)
    # renumber the atoms so that the numbers stay unique after concatenation
    self.crd_df.loc[:, "atom_number"] = np.arange(len(self.crd_df))
    self.fasta_dict.update(entry.fasta_dict)
    return self
def rename_chains(self, chain_dict)

Rename chains in the PDB entry.

Parameters

chain_dict : dict
A dictionary mapping from old chain IDs to new chain IDs

Returns

entry : PDBEntry
A PDBEntry object
Expand source code
def rename_chains(self, chain_dict):
    """Rename chains in the PDB entry.

    Chains that are not mentioned in `chain_dict` keep their current ID, and
    keys of `chain_dict` that are not chains of this entry are ignored.
    All renames are applied simultaneously, so swapping two IDs is supported.

    Parameters
    ----------
    chain_dict : dict
        A dictionary mapping from old chain IDs to new chain IDs (it is not modified)

    Returns
    -------
    entry : PDBEntry
        This `PDBEntry` object with the renamed chains

    """
    # complete mapping: every chain of the entry gets a target, identity by default
    mapping = {chain: chain_dict.get(chain, chain) for chain in self.get_chains()}

    def _new_id(chain_id):
        # values that are not chains of this entry are left untouched
        return mapping.get(chain_id, chain_id)

    self.crd_df["chain_id"] = self.crd_df["chain_id"].map(_new_id)
    self.seq_df["chain_id"] = self.seq_df["chain_id"].map(_new_id)
    self.fasta_dict = {mapping[k]: v for k, v in self.fasta_dict.items()}
    return self
def visualize(self, highlight_mask_dict=None, style='cartoon', highlight_style=None, opacity=1, colors=None, accent_color='#D96181', canvas_size=(400, 300))

Visualize the protein in a notebook.

Parameters

highlight_mask_dict : dict, optional
A dictionary mapping from chain IDs to a mask of 0s and 1s of the same length as the chain sequence; the atoms corresponding to 1s will be highlighted in the accent color (see accent_color)
style : str, default 'cartoon'
The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
highlight_style : str, optional
The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross' (defaults to the same as style)
opacity : float or dict, default 1
Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
colors : list, optional
A list of colors to use for different chains
accent_color : str, optional
The color of the highlighted atoms (use None to disable highlighting)
canvas_size : tuple, default (400, 300)
The shape of the canvas
Expand source code
def visualize(
    self,
    highlight_mask_dict=None,
    style="cartoon",
    highlight_style=None,
    opacity=1,
    colors=None,
    accent_color="#D96181",
    canvas_size=(400, 300),
):
    """Show the protein in a notebook viewer.

    Parameters
    ----------
    highlight_mask_dict : dict, optional
        A dictionary mapping from chain IDs to a mask of 0s and 1s of the same length as the chain sequence;
        the atoms corresponding to 1s will be highlighted with `accent_color`
    style : str, default 'cartoon'
        The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
    highlight_style : str, optional
        The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
        (defaults to the same as `style`)
    opacity : float or dict, default 1
        Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
    colors : list, optional
        A list of colors to use for different chains
    accent_color : str, optional
        The color of the highlighted atoms (use `None` to disable highlighting)
    canvas_size : tuple, default (400, 300)
        The shape of the canvas

    """
    atom_dicts = self._get_atom_dicts(
        highlight_mask_dict,
        style=style,
        highlight_style=highlight_style,
        opacity=opacity,
        colors=colors,
        accent_color=accent_color,
    )
    # the model text is the concatenation of the string form of every atom
    model_text = "".join(map(str, atom_dicts))
    viewer = _get_view(canvas_size)
    viewer.addModelsAsFrames(model_text)
    # atoms are addressed by their 1-based serial number in the last model
    for serial, atom in enumerate(atom_dicts, start=1):
        viewer.setStyle({"model": -1, "serial": serial}, atom["pymol"])
    viewer.zoomTo()
    viewer.show()
class ProteinEntry (seqs, crds, masks, chain_ids, predict_masks=None, cdrs=None, protein_id=None)

A class to interact with proteinflow data files.

Initialize a ProteinEntry object.

Parameters

seqs : list of str
Amino acid sequences of the protein (one-letter code)
crds : list of np.ndarray
Coordinates of the protein, numpy arrays of shape (L, 14, 3), in the order of N, C, CA, O
masks : list of np.ndarray
Mask arrays where 1 indicates residues with known coordinates and 0 indicates missing values
chain_ids : list of str
Chain IDs of the protein
predict_masks : list of np.ndarray, optional
Mask arrays where 1 indicates residues that were generated by a model and 0 indicates residues with known coordinates
cdrs : list of np.ndarray, optional
'numpy' arrays of shape (L,) where CDR residues are marked with the corresponding type ('H1', 'L1', …) and non-CDR residues are marked with '-'
protein_id : str, optional
ID of the protein
Expand source code
class ProteinEntry:
    """A class to interact with proteinflow data files."""

    ATOM_ORDER = {k: BACKBONE_ORDER + v for k, v in SIDECHAIN_ORDER.items()}
    """A dictionary mapping 3-letter residue names to the order of atoms in the coordinates array."""

    def __init__(
        self,
        seqs,
        crds,
        masks,
        chain_ids,
        predict_masks=None,
        cdrs=None,
        protein_id=None,
    ):
        """Initialize a `ProteinEntry` object.

        Parameters
        ----------
        seqs : list of str
            Amino acid sequences of the protein (one-letter code)
        crds : list of np.ndarray
            Coordinates of the protein, `numpy` arrays of shape `(L, 14, 3)`,
            in the order of `N, C, CA, O`
        masks : list of np.ndarray
            Mask arrays where 1 indicates residues with known coordinates and 0
            indicates missing values
        cdrs : list of np.ndarray
            `'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
            and non-CDR residues are marked with `'-'`
        chain_ids : list of str
            Chain IDs of the protein
        predict_masks : list of np.ndarray, optional
            Mask arrays where 1 indicates residues that were generated by a model and 0
            indicates residues with known coordinates
        cdrs : list of np.ndarray, optional
            `'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
        protein_id : str, optional
            ID of the protein

        """
        if crds[0].shape[1] != 14:
            raise ValueError(
                "Coordinates array must have 14 atoms in the order of N, C, CA, O, sidechain atoms"
            )
        self.seq = {x: seq for x, seq in zip(chain_ids, seqs)}
        self.crd = {x: crd for x, crd in zip(chain_ids, crds)}
        self.mask = {x: mask for x, mask in zip(chain_ids, masks)}
        self.mask_original = {x: mask for x, mask in zip(chain_ids, masks)}
        if cdrs is None:
            cdrs = [None for _ in chain_ids]
        self.cdr = {x: cdr for x, cdr in zip(chain_ids, cdrs)}
        if predict_masks is None:
            predict_masks = [None for _ in chain_ids]
        self.predict_mask = {x: mask for x, mask in zip(chain_ids, predict_masks)}
        self.id = protein_id

    def get_id(self):
        """Return the ID of the protein."""
        return self.id

    def interpolate_coords(self, fill_ends=True):
        """Replace missing coordinates with linearly interpolated values, chain by chain.

        The coordinates and the mask of every chain are overwritten with the
        output of the module-level `interpolate_coords` helper.

        Parameters
        ----------
        fill_ends : bool, default True
            If `True`, fill in missing values at the ends of the protein sequence with the edge values;
            otherwise fill them in with zeros

        """
        for chain_id in self.get_chains():
            new_crd, new_mask = interpolate_coords(
                self.crd[chain_id], self.mask[chain_id], fill_ends=fill_ends
            )
            self.crd[chain_id] = new_crd
            self.mask[chain_id] = new_mask

    def cut_missing_edges(self):
        """Cut off the ends of the protein sequence that have missing coordinates."""
        for chain in self.get_chains():
            mask = self.mask[chain]
            known_ind = np.where(mask == 1)[0]
            start, end = known_ind[0], known_ind[-1] + 1
            self.seq[chain] = self.seq[chain][start:end]
            self.crd[chain] = self.crd[chain][start:end]
            self.mask[chain] = self.mask[chain][start:end]
            if self.cdr[chain] is not None:
                self.cdr[chain] = self.cdr[chain][start:end]

    def get_chains(self):
        """Get the chain IDs of the protein.

        Returns
        -------
        chains : list of str
            Chain IDs of the protein

        """
        return sorted(self.seq.keys())

    def _get_chains_list(self, chains):
        """Get a list of chains to iterate over."""
        if chains is None:
            chains = self.get_chains()
        return chains

    def get_chain_type_dict(self, chains=None):
        """Get the chain types of the protein.

        If the CDRs are not annotated, this function will return `None`.
        If there is no light or heavy chain, the corresponding key will be missing.
        If there is no antigen chain, the `'antigen'` key will map to an empty list.

        Parameters
        ----------
        chains : list of str, default None
            Chain IDs to consider

        Returns
        -------
        chain_type_dict : dict
            A dictionary with keys `'heavy'`, `'light'` and `'antigen'` and values
            the corresponding chain IDs

        """
        if not self.has_cdr():
            return None
        chain_type_dict = {"antigen": []}
        chains = self._get_chains_list(chains)
        for chain, cdr in self.cdr.items():
            if chain not in chains:
                continue
            u = np.unique(cdr)
            if "H1" in u:
                chain_type_dict["heavy"] = chain
            elif "L1" in u:
                chain_type_dict["light"] = chain
            else:
                chain_type_dict["antigen"].append(chain)
        return chain_type_dict

    def get_length(self, chains=None):
        """Get the total length of a set of chains.

        Parameters
        ----------
        chain : str, optional
            Chain ID; if `None`, the length of the whole protein is returned

        Returns
        -------
        length : int
            Length of the chain

        """
        chains = self._get_chains_list(chains)
        return sum([len(self.seq[x]) for x in chains])

    def get_cdr_length(self, chains):
        """Get the length of the CDR regions of a set of chains.

        Parameters
        ----------
        chain : str
            Chain ID

        Returns
        -------
        length : int
            Length of the CDR regions of the chain

        """
        if not self.has_cdr():
            return {x: None for x in ["H1", "H2", "H3", "L1", "L2", "L3"]}
        return {
            x: len(self.get_sequence(chains=chains, cdr=x))
            for x in ["H1", "H2", "H3", "L1", "L2", "L3"]
        }

    def has_cdr(self):
        """Check if the protein is from the SAbDab database.

        Returns
        -------
        is_sabdab : bool
            True if the protein is from the SAbDab database

        """
        return list(self.cdr.values())[0] is not None

    def has_predict_mask(self):
        """Check if the protein has a predicted mask.

        Returns
        -------
        has_predict_mask : bool
            True if the protein has a predicted mask

        """
        return list(self.predict_mask.values())[0] is not None

    def __len__(self):
        """Get the total length of the protein chains."""
        return self.get_length(self.get_chains())

    def get_sequence(self, chains=None, encode=False, cdr=None, only_known=False):
        """Get the amino acid sequence of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the sequences of the specified chains is returned (in the same order);
            otherwise, all sequences are concatenated in alphabetical order of the chain IDs
        encode : bool, default False
            If `True`, the sequence is encoded as a `'numpy'` array of integers
            where each integer corresponds to the index of the amino acid in
            `proteinflow.constants.ALPHABET`
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned
        only_known : bool, default False
            If `True`, only the residues with known coordinates are returned

        Returns
        -------
        seq : str or np.ndarray
            Amino acid sequence of the protein (one-letter code) or an encoded
            sequence as a `'numpy'` array of integers

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        seq = "".join([self.seq[c] for c in chains]).replace("B", "")
        if encode:
            seq = np.array([ALPHABET_REVERSE[aa] for aa in seq])
        elif cdr is not None or only_known:
            seq = np.array(list(seq))
        if cdr is not None:
            cdr_arr = self.get_cdr(chains=chains)
            seq = seq[cdr_arr == cdr]
        if only_known:
            seq = seq[self.get_mask(chains=chains, cdr=cdr).astype(bool)]
        if not encode and not isinstance(seq, str):
            seq = "".join(seq)
        return seq

    def get_coordinates(self, chains=None, bb_only=False, cdr=None, only_known=False):
        """Get the coordinates of the protein.

        Backbone atoms are in the order of `N, C, CA, O`; for the full-atom
        order see `ProteinEntry.ATOM_ORDER` (sidechain atoms come after the
        backbone atoms).

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the coordinates of the specified chains are returned (in the same order);
            otherwise, all coordinates are concatenated in alphabetical order of the chain IDs
        bb_only : bool, default False
            If `True`, only the backbone atoms are returned
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned
        only_known : bool, default False
            If `True`, only return the coordinates of residues with known coordinates

        Returns
        -------
        crd : np.ndarray
            Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)`
            or `(L, 4, 3)` if `bb_only=True`

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        crd = np.concatenate([self.crd[c] for c in chains], axis=0)
        if cdr is not None:
            crd = crd[self.cdr == cdr]
        if bb_only:
            crd = crd[:, :4, :]
        if only_known:
            crd = crd[self.get_mask(chains=chains, cdr=cdr).astype(bool)]
        return crd

    def get_mask(self, chains=None, cdr=None, original=False):
        """Get the mask of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the masks of the specified chains are returned (in the same order);
            otherwise, all masks are concatenated in alphabetical order of the chain IDs
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned
        original : bool, default False
            If `True`, return the original mask (before interpolation)

        Returns
        -------
        mask : np.ndarray
            Mask array where 1 indicates residues with known coordinates and 0
            indicates missing values

        """
        if cdr is not None and self.cdr is None:
            raise ValueError("CDR information not available")
        if cdr is not None:
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        mask = np.concatenate(
            [self.mask_original[c] if original else self.mask[c] for c in chains],
            axis=0,
        )
        if cdr is not None:
            mask = mask[self.cdr == cdr]
        return mask

    def get_cdr(self, chains=None, encode=False):
        """Get the CDR information of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the CDR information of the specified chains is
            returned (in the same order); otherwise, all CDR information is concatenated in
            alphabetical order of the chain IDs
        encode : bool, default False
            If `True`, the CDR information is encoded as a `'numpy'` array of
            integers where each integer corresponds to the index of the CDR
            type in `proteinflow.constants.CDR_ALPHABET`

        Returns
        -------
        cdr : np.ndarray or None
            A `'numpy'` array of shape `(L,)` where CDR residues are marked
            with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR
            residues are marked with `'-'` or an encoded array of integers
            ir `encode=True`; `None` if CDR information is not available
        chains : list of str, optional
            If specified, only the CDR information of the specified chains is
            returned (in the same order); otherwise, all CDR information is concatenated in
            alphabetical order of the chain IDs

        """
        chains = self._get_chains_list(chains)
        if self.cdr is None:
            return None
        cdr = np.concatenate([self.cdr[c] for c in chains], axis=0)
        if encode:
            cdr = np.array([CDR_REVERSE[aa] for aa in cdr])
        return cdr

    def get_atom_mask(self, chains=None, cdr=None):
        """Get the atom mask of the protein.

        The result is the concatenation of the `ATOM_MASKS` entries of the
        selected residues, where the entries of residues with unknown
        coordinates are replaced with zeros.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the atom masks of the specified chains are returned (in the same order);
            otherwise, all atom masks are concatenated in alphabetical order of the chain IDs
        cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
            If specified, only the CDR region of the specified type is returned

        Returns
        -------
        atom_mask : np.ndarray
            Atom mask array where 1 indicates atoms with known coordinates and 0
            indicates missing or non-existing values
            (NOTE(review): the layout is whatever `np.concatenate` produces from
            the `ATOM_MASKS` entries -- confirm the intended shape)

        Raises
        ------
        ValueError
            If `cdr` is specified but CDR information is not available

        """
        if cdr is not None:
            # `self.cdr` is always a dictionary, availability is reported by `has_cdr`
            if not self.has_cdr():
                raise ValueError("CDR information not available")
            assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
        chains = self._get_chains_list(chains)
        seq = "".join([self.seq[c] for c in chains])
        # per-residue flags over the same chains, in the same order as `seq`
        known = self.get_mask(chains=chains).astype(bool)
        if cdr is None:
            selected = np.ones(len(seq), dtype=bool)
        else:
            selected = self.get_cdr(chains=chains) == cdr
        pieces = []
        for aa, is_known, is_selected in zip(seq, known, selected):
            residue_mask = np.asarray(ATOM_MASKS[aa])
            if not is_selected:
                # contribute an empty piece so the result keeps its layout
                residue_mask = residue_mask[:0]
            elif not is_known:
                residue_mask = np.zeros_like(residue_mask)
            pieces.append(residue_mask)
        return np.concatenate(pieces)

    @staticmethod
    def decode_cdr(cdr):
        """Turn integer-encoded CDR information back into CDR labels.

        Parameters
        ----------
        cdr : np.ndarray
            A `'numpy'` array of shape `(L,)` encoded as integers where each
            integer corresponds to the index of the CDR type in
            `proteinflow.constants.CDR_ALPHABET` (tensors and lists are
            converted with `ProteinEntry._to_numpy` first)

        Returns
        -------
        cdr : np.ndarray
            A `'numpy'` array of shape `(L,)` where CDR residues are marked
            with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR
            residues are marked with `'-'`

        """
        indices = ProteinEntry._to_numpy(cdr).astype(int)
        labels = [CDR_ALPHABET[index] for index in indices]
        return np.array(labels)

    @staticmethod
    def _to_numpy(arr):
        if isinstance(arr, Tensor):
            arr = arr.detach().cpu().numpy()
        if isinstance(arr, list):
            arr = np.array(arr)
        return arr

    @staticmethod
    def decode_sequence(seq):
        """Turn an integer-encoded sequence back into one-letter codes.

        Parameters
        ----------
        seq : np.ndarray
            A `'numpy'` array of integers where each integer corresponds to the
            index of the amino acid in `proteinflow.constants.ALPHABET`
            (tensors and lists are converted with `ProteinEntry._to_numpy` first)

        Returns
        -------
        seq : str
            Amino acid sequence of the protein (one-letter code)

        """
        indices = ProteinEntry._to_numpy(seq).astype(int)
        letters = [ALPHABET[index] for index in indices]
        return "".join(letters)

    def _rename_chains(self, chain_dict):
        """Rename the chains of the protein (with no safeguards)."""
        for old_chain, new_chain in chain_dict.items():
            self.seq[new_chain] = self.seq.pop(old_chain)
            self.crd[new_chain] = self.crd.pop(old_chain)
            self.mask[new_chain] = self.mask.pop(old_chain)
            self.mask_original[new_chain] = self.mask_original.pop(old_chain)
            self.cdr[new_chain] = self.cdr.pop(old_chain)
            self.predict_mask[new_chain] = self.predict_mask.pop(old_chain)

    def rename_chains(self, chain_dict):
        """Rename the chains of the protein.

        Parameters
        ----------
        chain_dict : dict
            A dictionary mapping old chain IDs to new chain IDs

        """
        for chain in self.get_chains():
            if chain not in chain_dict:
                chain_dict[chain] = chain
        self._rename_chains({k: k * 5 for k in self.get_chains()})
        self._rename_chains({k * 5: v for k, v in chain_dict.items()})

    def get_predicted_entry(self):
        """Return a `ProteinEntry` object that only contains predicted residues.

        Returns
        -------
        entry : ProteinEntry
            The truncated `ProteinEntry` object

        """
        if self.predict_mask is None:
            raise ValueError("Predicted mask not available")
        entry_dict = self.to_dict()
        for chain in self.get_chains():
            mask_ = self.predict_mask[chain].astype(bool)
            if mask_.sum() == 0:
                entry_dict.pop(chain)
                continue
            if mask_.sum() == len(mask_):
                continue
            seq_arr = np.array(list(entry_dict[chain]["seq"]))
            entry_dict[chain]["seq"] = "".join(seq_arr[mask_])
            entry_dict[chain]["crd_bb"] = entry_dict[chain]["crd_bb"][mask_]
            entry_dict[chain]["crd_sc"] = entry_dict[chain]["crd_sc"][mask_]
            entry_dict[chain]["msk"] = entry_dict[chain]["msk"][mask_]
            entry_dict[chain]["predict_msk"] = entry_dict[chain]["predict_msk"][mask_]
            if "cdr" in entry_dict[chain]:
                entry_dict[chain]["cdr"] = entry_dict[chain]["cdr"][mask_]
        return ProteinEntry.from_dict(entry_dict)

    def get_predicted_chains(self):
        """List the chains that have at least one predicted residue.

        Returns
        -------
        chains : list of str
            IDs of the chains whose prediction mask has a non-zero sum

        Raises
        ------
        ValueError
            If `has_predict_mask()` reports that no prediction mask is available

        """
        if not self.has_predict_mask():
            raise ValueError("Predicted mask not available")
        predicted = []
        for chain_id, chain_mask in self.predict_mask.items():
            if chain_mask.sum() != 0:
                predicted.append(chain_id)
        return predicted

    def merge(self, entry):
        """Merge another `ProteinEntry` object into this one.

        The merge is all-or-nothing: every chain ID is validated before any
        data is copied, so a conflict leaves this entry unchanged.

        Parameters
        ----------
        entry : ProteinEntry
            The `ProteinEntry` object whose chains are added to this one

        Raises
        ------
        ValueError
            If a chain ID of `entry` (the part before the first underscore)
            clashes with a chain that is already present

        """
        new_chains = list(entry.get_chains())
        # Validate first so that a clash on a later chain cannot leave this
        # entry half-merged. The running set also rejects two incoming chains
        # that share a prefix.
        taken = {x.split("_")[0] for x in self.get_chains()}
        for chain in new_chains:
            prefix = chain.split("_")[0]
            if prefix in taken:
                raise ValueError("Chain IDs must be unique")
            taken.add(prefix)
        for chain in new_chains:
            self.seq[chain] = entry.seq[chain]
            self.crd[chain] = entry.crd[chain]
            self.mask[chain] = entry.mask[chain]
            self.mask_original[chain] = entry.mask_original[chain]
            self.cdr[chain] = entry.cdr[chain]
            self.predict_mask[chain] = entry.predict_mask[chain]
        # If only some chains carry a prediction mask, give the others an
        # all-zero mask so that the per-chain masks stay consistent.
        if any(v is not None for v in self.predict_mask.values()):
            for k, v in self.predict_mask.items():
                if v is None:
                    self.predict_mask[k] = np.zeros(len(self.get_sequence(k)))

    @staticmethod
    def from_arrays(
        seqs,
        crds,
        masks,
        chain_id_dict,
        chain_id_array,
        predict_masks=None,
        cdrs=None,
        protein_id=None,
    ):
        """Build a protein entry from concatenated per-residue arrays.

        The arrays are split into chains with `chain_id_array` and each piece
        is decoded / converted before being handed to the constructor.

        Parameters
        ----------
        seqs : np.ndarray
            Amino acid sequences of the protein (encoded as integers, see `proteinflow.constants.ALPHABET`), `'numpy'` array of shape `(L,)`
        crds : np.ndarray
            Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)` or `(L, 4, 3)`;
            when the second dimension is not 14 the values are written into the first four
            atom slots of a zero-filled `(L, 14, 3)` array
        masks : np.ndarray
            Mask array where 1 indicates residues with known coordinates and 0
            indicates missing values, `'numpy'` array of shape `(L,)`
        chain_id_dict : dict
            A dictionary mapping chain IDs to indices in `chain_id_array`
        chain_id_array : np.ndarray
            A `'numpy'` array of chain IDs encoded as integers
        predict_masks : np.ndarray, optional
            Mask array where 1 indicates residues that were generated by a model and 0
            indicates residues with known coordinates, `'numpy'` array of shape `(L,)`
        cdrs : np.ndarray, optional
            A `'numpy'` array of shape `(L,)` where residues are marked
            with the corresponding CDR type (encoded as integers, see `proteinflow.constants.CDR_ALPHABET`)
        protein_id : str, optional
            Protein ID

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        to_numpy = ProteinEntry._to_numpy
        with_predict = predict_masks is not None
        with_cdr = cdrs is not None
        full_atom = crds.shape[1] == 14

        chain_ids, chain_seqs, chain_crds, chain_masks = [], [], [], []
        chain_predict_masks = [] if with_predict else None
        chain_cdrs = [] if with_cdr else None

        for chain_id, chain_index in chain_id_dict.items():
            chain_ids.append(chain_id)
            selector = chain_id_array == chain_index
            chain_seqs.append(ProteinEntry.decode_sequence(seqs[selector]))
            if full_atom:
                chain_crd = to_numpy(crds[selector])
            else:
                # backbone-only input: pad to the 14-atom layout with zeros
                chain_crd = np.zeros((crds[selector].shape[0], 14, 3))
                chain_crd[:, :4, :] = to_numpy(crds[selector])
            chain_crds.append(chain_crd)
            chain_masks.append(to_numpy(masks[selector]))
            if with_predict:
                chain_predict_masks.append(to_numpy(predict_masks[selector]))
            if with_cdr:
                chain_cdrs.append(ProteinEntry.decode_cdr(cdrs[selector]))

        return ProteinEntry(
            chain_seqs,
            chain_crds,
            chain_masks,
            chain_ids,
            chain_predict_masks,
            chain_cdrs,
            protein_id,
        )

    @staticmethod
    def from_dict(dictionary):
        """Build a protein entry from a nested dictionary.

        Parameters
        ----------
        dictionary : dict
            A nested dictionary where first-level keys are chain IDs and
            second-level keys are the following:
            - `'seq'` : amino acid sequence (one-letter code)
            - `'crd_bb'` : backbone coordinates, shaped `(L, 4, 3)`
            - `'crd_sc'` : sidechain coordinates, shaped `(L, 10, 3)`
            - `'msk'` : mask array where 1 indicates residues with known coordinates and 0
                indicates missing values, shaped `(L,)`
            - `'cdr'` (optional): CDR information, shaped `(L,)` where CDR residues are marked
                with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR residues are marked with `'-'`
            - `'predict_msk'` (optional): mask array where 1 indicates residues that were generated by a model and 0
                indicates residues with known coordinates, shaped `(L,)`
            It can also contain a `'protein_id'` first-level key.

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object with the chains in sorted order

        """
        # every first-level key except the ID is treated as a chain
        chain_ids = sorted(key for key in dictionary if key != "protein_id")
        seqs, crds, masks, cdrs, predict_masks = [], [], [], [], []
        for chain_id in chain_ids:
            chain_data = dictionary[chain_id]
            seqs.append(chain_data["seq"])
            # backbone and sidechain atoms are joined along the atom axis
            crds.append(
                np.concatenate([chain_data["crd_bb"], chain_data["crd_sc"]], axis=1)
            )
            masks.append(chain_data["msk"])
            cdrs.append(chain_data.get("cdr", None))
            predict_masks.append(chain_data.get("predict_msk", None))
        return ProteinEntry(
            seqs=seqs,
            crds=crds,
            masks=masks,
            cdrs=cdrs,
            chain_ids=chain_ids,
            predict_masks=predict_masks,
            protein_id=dictionary.get("protein_id"),
        )

    @staticmethod
    def from_pdb_entry(pdb_entry):
        """Build a protein entry from a `PDBEntry` object.

        Parameters
        ----------
        pdb_entry : PDBEntry
            A `PDBEntry` object (CDR annotation is copied as well when it is
            a `SAbDabEntry`)

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        sequences = pdb_entry.get_fasta()
        is_antibody = isinstance(pdb_entry, SAbDabEntry)
        entry_dict = {}
        # the one-element unpacking only accepts chain IDs of length one
        for (chain,) in pdb_entry.get_chains():
            chain_data = {}
            entry_dict[chain] = chain_data
            sequence = sequences[chain]

            chain_mask = pdb_entry.get_mask([chain])[chain]
            if is_antibody:
                chain_data["cdr"] = pdb_entry.get_cdr([chain])[chain]
            chain_data["seq"] = sequence
            chain_data["msk"] = chain_mask

            coords = pdb_entry.get_coordinates_array(chain)
            backbone = coords[:, :4, :]
            chain_data["crd_bb"] = backbone
            chain_data["crd_sc"] = coords[:, 4:, :]
            # Zero out the mask (in place) where the backbone holds exactly
            # four zero-valued coordinate components.
            # NOTE(review): a fully missing backbone would give 12 zeros, not
            # 4 -- confirm that this threshold is intended.
            zero_count = (backbone == 0).sum(-1).sum(-1)
            chain_mask[zero_count == 4] = 0
        entry_dict["protein_id"] = pdb_entry.pdb_id
        return ProteinEntry.from_dict(entry_dict)

    @staticmethod
    def from_pdb(
        pdb_path,
        fasta_path=None,
        heavy_chain=None,
        light_chain=None,
        antigen_chains=None,
    ):
        """Build a protein entry from a PDB file.

        A `SAbDabEntry` is used as the intermediate object when a heavy or a
        light chain is given, a plain `PDBEntry` otherwise.

        Parameters
        ----------
        pdb_path : str
            Path to the PDB file
        fasta_path : str, optional
            Path to the FASTA file; if not specified, the sequence is extracted
            from the PDB file
        heavy_chain : str, optional
            Chain ID of the heavy chain (to load a SAbDab entry)
        light_chain : str, optional
            Chain ID of the light chain (to load a SAbDab entry)
        antigen_chains : list of str, optional
            Chain IDs of the antigen chains (to load a SAbDab entry)

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        no_antibody_chain = heavy_chain is None and light_chain is None
        if no_antibody_chain:
            plain_entry = PDBEntry(pdb_path=pdb_path, fasta_path=fasta_path)
            return ProteinEntry.from_pdb_entry(plain_entry)
        antibody_entry = SAbDabEntry(
            pdb_path=pdb_path,
            fasta_path=fasta_path,
            heavy_chain=heavy_chain,
            light_chain=light_chain,
            antigen_chains=antigen_chains,
        )
        return ProteinEntry.from_pdb_entry(antibody_entry)

    @staticmethod
    def from_id(
        pdb_id,
        local_folder=".",
        heavy_chain=None,
        light_chain=None,
        antigen_chains=None,
    ):
        """Load a protein entry from a PDB ID.

        A `SAbDabEntry` is used as the intermediate object when a heavy or a
        light chain is given, a plain `PDBEntry` otherwise. In both cases
        `local_folder` is forwarded to the corresponding `from_id` loader.

        Parameters
        ----------
        pdb_id : str
            PDB ID of the protein
        local_folder : str, default "."
            Path to the local folder where the PDB file is saved
        heavy_chain : str, optional
            Chain ID of the heavy chain (to load a SAbDab entry)
        light_chain : str, optional
            Chain ID of the light chain (to load a SAbDab entry)
        antigen_chains : list of str, optional
            Chain IDs of the antigen chains (to load a SAbDab entry)

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        if heavy_chain is not None or light_chain is not None:
            pdb_entry = SAbDabEntry.from_id(
                pdb_id=pdb_id,
                local_folder=local_folder,
                heavy_chain=heavy_chain,
                light_chain=light_chain,
                antigen_chains=antigen_chains,
            )
        else:
            # `local_folder` used to be dropped on this branch, so the
            # argument had no effect for non-antibody entries.
            pdb_entry = PDBEntry.from_id(pdb_id=pdb_id, local_folder=local_folder)
        return ProteinEntry.from_pdb_entry(pdb_entry)

    @staticmethod
    def from_pickle(path):
        """Build a protein entry from a pickled dictionary.

        The file is unpickled and the result is passed to `from_dict`.

        Parameters
        ----------
        path : str
            Path to the pickle file

        Returns
        -------
        entry : ProteinEntry
            A `ProteinEntry` object

        """
        with open(path, "rb") as handle:
            entry_dict = pickle.load(handle)
        entry = ProteinEntry.from_dict(entry_dict)
        return entry

    @staticmethod
    def retrieve_ligands_from_pickle(path):
        """Retrieve ligands from a pickle file.

        Only first-level entries that are dictionaries with a `'ligand'` key
        are reported; other entries (such as a string-valued `'protein_id'`)
        are ignored.

        Parameters
        ----------
        path : str
            Path to the pickle file

        Returns
        -------
        chain2ligand : dict
            A dictionary where keys are chain IDs and values are the objects
            stored under the `'ligand'` key of the corresponding chain

        """
        # NOTE: unpickling executes arbitrary code, only load trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)
        # The isinstance guard keeps non-chain entries out: on a string value
        # `"ligand" in value` is a substring test and indexing it would fail.
        return {
            chain: chain_data["ligand"]
            for chain, chain_data in data.items()
            if isinstance(chain_data, dict) and "ligand" in chain_data
        }

    def to_dict(self):
        """Export the protein entry as a nested dictionary.

        Returns
        -------
        dictionary : dict
            A nested dictionary where first-level keys are chain IDs and
            second-level keys are the following:
            - `'seq'` : the sequence stored for the chain
            - `'crd_bb'` : the first four atom slots of the coordinates (backbone)
            - `'crd_sc'` : the remaining atom slots of the coordinates (sidechain)
            - `'msk'` : the mask stored for the chain
            - `'cdr'` (only if not `None`): the CDR annotation stored for the chain
            - `'predict_msk'` (only if not `None`): the prediction mask stored for the chain
            A `'protein_id'` first-level key is added when the entry has an ID.

        """
        entry_dict = {}
        for chain in self.get_chains():
            coords = self.crd[chain]
            chain_data = {
                "seq": self.seq[chain],
                "crd_bb": coords[:, :4],
                "crd_sc": coords[:, 4:],
                "msk": self.mask[chain],
            }
            # optional annotations are only exported when present
            chain_cdr = self.cdr[chain]
            if chain_cdr is not None:
                chain_data["cdr"] = chain_cdr
            chain_predict = self.predict_mask[chain]
            if chain_predict is not None:
                chain_data["predict_msk"] = chain_predict
            entry_dict[chain] = chain_data
        if self.id is not None:
            entry_dict["protein_id"] = self.id
        return entry_dict

    def to_pdb(
        self,
        path,
        only_ca=False,
        skip_oxygens=False,
        only_backbone=False,
        title=None,
    ):
        """Write the protein entry to a PDB file through `PDBBuilder`.

        Parameters
        ----------
        path : str
            Path to the output PDB file
        only_ca : bool, default False
            Forwarded to `PDBBuilder`; presumably restricts the output to
            C-alpha atoms (the previous text said "backbone atoms" -- TODO confirm)
        skip_oxygens : bool, default False
            If `True`, oxygen atoms are not saved
        only_backbone : bool, default False
            If `True`, only backbone atoms are saved
        title : str, optional
            Title of the PDB file (by default either the protein id or "Untitled")

        Raises
        ------
        ValueError
            If a chain ID differs from the upper-cased version of its first character

        """
        # a valid ID is one character long and unchanged by upper-casing
        invalid_ids = [x for x in self.get_chains() if x[0].upper() != x]
        if invalid_ids:
            raise ValueError(
                "Chain IDs must be single uppercase letters, please rename with `rename_chains` before saving."
            )
        builder = PDBBuilder(
            self,
            only_ca=only_ca,
            skip_oxygens=skip_oxygens,
            only_backbone=only_backbone,
        )
        if title is None:
            title = self.id if self.id is not None else "Untitled"
        builder.save_pdb(path, title=title)

    def to_pickle(self, path):
        """Pickle the dictionary representation of the protein entry.

        The file contains the output of `to_dict()`: a nested dictionary where first-level keys are chain Ids and second-level keys are the following:
        - `'crd_bb'`: a `numpy` array of shape `(L, 4, 3)` with backbone atom coordinates (N, C, CA, O),
        - `'crd_sc'`: a `numpy` array of shape `(L, 10, 3)` with sidechain atom coordinates (check `proteinflow.sidechain_order()` for the order of atoms),
        - `'msk'`: a `numpy` array of shape `(L,)` where ones correspond to residues with known coordinates and
            zeros to missing values,
        - `'seq'`: a string of length `L` with residue types.

        In SAbDab datasets, an additional key is added to the dictionary:
        - `'cdr'`: a `'numpy'` array of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
            and non-CDR residues are marked with `'-'`.

        If a prediction mask is available, another additional key is added to the dictionary:
        - `'predict_msk'`: a `numpy` array of shape `(L,)` where ones correspond to residues that were generated by a model and
            zeros to residues with known coordinates.

        Parameters
        ----------
        path : str
            Path to the pickle file

        """
        entry_dict = self.to_dict()
        with open(path, "wb") as handle:
            pickle.dump(entry_dict, handle)

    def dihedral_angles(self, chains=None):
        """Calculate the backbone dihedral angles of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the dihedral angles of the specified chains are returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        angles : np.ndarray
            A `'numpy'` array of shape `(L, 2)` with the two backbone dihedral
            angles of each residue as computed by `_dihedral_angle`.
            NOTE(review): the columns are stacked as (psi, phi) according to
            the labels in the code, while earlier documentation advertised
            (phi, psi) -- confirm which order callers expect.

        """
        per_chain_angles = []
        for chain in self._get_chains_list(chains):
            coords = self.get_coordinates([chain])
            known = self.get_mask([chain])
            # Backbone atom order is N, C, Ca, O.
            # psi: N, Ca, C of a residue followed by N of the next residue;
            # a zero row is appended for the last residue.
            psi_atoms = np.concatenate(
                [coords[:-1, [0, 2, 1], :], coords[1:, [0], :]], 1
            )
            psi_atoms = np.pad(psi_atoms, ((0, 1), (0, 0), (0, 0)))
            psi = _dihedral_angle(psi_atoms, known)
            # phi: C of the previous residue followed by N, Ca, C of the
            # residue; a zero row is prepended for the first residue.
            phi_atoms = np.concatenate(
                [coords[:-1, [1], :], coords[1:, [0, 2, 1]]], 1
            )
            phi_atoms = np.pad(phi_atoms, ((1, 0), (0, 0), (0, 0)))
            phi = _dihedral_angle(phi_atoms, known)
            per_chain_angles.append(np.stack([psi, phi], -1))
        return np.concatenate(per_chain_angles, 0)

    def secondary_structure(self, chains=None):
        """Calculate the secondary structure of the protein.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the secondary structure of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        sse : np.ndarray
            A `'numpy'` array of shape `(L, 3)` with secondary structure
            elements encoded as one-hot vectors (alpha-helix, beta-sheet, loop);
            missing values are marked with zeros

        """
        # labels returned by `_annotate_sse` -> one-hot (helix, sheet, loop)
        one_hot = {
            "a": [1, 0, 0],
            "b": [0, 1, 0],
            "c": [0, 0, 1],
            "": [0, 0, 0],
        }
        encoded = []
        for chain in self._get_chains_list(chains):
            backbone = self.get_coordinates([chain])[:, :4]
            labels = _annotate_sse(backbone)
            encoded.extend([one_hot[label] for label in labels])
        return np.array(encoded)

    def sidechain_coordinates(self, chains=None):
        """Return the sidechain part of the coordinate array.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the sidechain coordinates of the specified chains are returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        crd : np.ndarray
            A `'numpy'` array of shape `(L, 10, 3)` with sidechain atom
            coordinates (check `proteinflow.sidechain_order()` for the order of
            atoms); missing values are marked with zeros

        """
        selected = self._get_chains_list(chains)
        all_atoms = self.get_coordinates(selected)
        # the first four atom slots hold the backbone
        return all_atoms[:, 4:, :]

    def chemical_features(self, chains=None):
        """Calculate chemical features of the protein.

        Every residue letter is mapped through `_PMAP` and the results are
        stacked in residue order.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the chemical features of the specified chains are returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        features : np.ndarray
            A `'numpy'` array with one row of chemical features per residue
            (hydropathy, volume, charge, polarity, acceptor/donor).
            NOTE(review): earlier documentation gave the shape as `(L, 4)`
            while listing five features -- confirm against `_PMAP`.

        """
        selected = self._get_chains_list(chains)
        per_residue = [
            _PMAP(residue) for chain in selected for residue in self.seq[chain]
        ]
        return np.array(per_residue)

    def sidechain_orientation(self, chains=None):
        """Calculate the (global) sidechain orientation of the protein.

        For residue types with an entry in `MAIN_ATOM_DICT` the orientation is
        the vector from backbone atom 2 (Ca in the N, C, Ca, O order) to the
        main sidechain atom. Residue types whose entry is `None` receive a
        random direction. All vectors are normalized.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the sidechain orientation of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        orientation : np.ndarray
            A `'numpy'` array of shape `(L, 3)` with sidechain orientation
            vectors; residues with zero difference vectors stay (close to) zero

        """
        chains = self._get_chains_list(chains)
        crd = self.get_coordinates(chains=chains)
        crd_bb, crd_sc = crd[:, :4, :], crd[:, 4:, :]
        seq = self.get_sequence(chains=chains, encode=True)
        orientation = np.zeros((crd_sc.shape[0], 3))
        for i in range(1, 21):
            residue_mask = seq == i
            main_atom = MAIN_ATOM_DICT[i]
            if main_atom is not None:
                orientation[residue_mask] = (
                    crd_sc[residue_mask, main_atom, :] - crd_bb[residue_mask, 2, :]
                )
            else:
                # The mask must come from the encoded sequence array; comparing
                # the `self.seq` dictionary with an integer is always False and
                # silently skipped this branch.
                orientation[residue_mask] = np.random.rand(
                    *orientation[residue_mask].shape
                )
        # the epsilon avoids a division by zero for all-zero vectors
        orientation /= np.expand_dims(np.linalg.norm(orientation, axis=-1), -1) + 1e-7
        return orientation

    @lru_cache()
    def is_valid_pair(self, chain1, chain2, cutoff=10):
        """Check if two chains are a valid pair based on the distance between them.

        Only the atom at backbone index 2 (Ca in the N, C, Ca, O order) of
        residues with known coordinates is considered. Two chains are a valid
        pair if at least three residue pairs (one residue from each chain) are
        closer than `cutoff` Angstroms.

        Parameters
        ----------
        chain1 : str
            Chain ID of the first chain
        chain2 : str
            Chain ID of the second chain
        cutoff : int, optional
            Distance threshold (in Angstroms)

        Returns
        -------
        valid : bool
            `True` if the two chains are a valid pair, `False` otherwise
            (including when one of the chains has no known residues)

        """
        margin = cutoff * 3
        assert chain1 in self.get_chains(), f"Chain {chain1} not found"
        assert chain2 in self.get_chains(), f"Chain {chain2} not found"
        X1 = self.get_coordinates(chains=[chain1], only_known=True)
        X2 = self.get_coordinates(chains=[chain2], only_known=True)
        # Without known residues there is nothing to compare (and the
        # bounding-box computation below would fail on an empty array).
        if len(X1) == 0 or len(X2) == 0:
            return False
        ca1, ca2 = X1[:, 2, :], X2[:, 2, :]

        # Cheap pre-filter: keep residues that lie inside the other chain's
        # bounding box enlarged by `margin` in every dimension.
        in_box1 = np.all(
            (ca1 >= ca2.min(0) - margin) & (ca1 <= ca2.max(0) + margin), axis=1
        )
        in_box2 = np.all(
            (ca2 >= ca1.min(0) - margin) & (ca2 <= ca1.max(0) + margin), axis=1
        )
        # Drop residues whose coordinates are all zero.
        present1 = (ca1 == 0).sum(-1) != 3
        present2 = (ca2 == 0).sum(-1) != 3

        sel1 = ca1[in_box1 & present1]
        sel2 = ca2[in_box2 & present2]
        diff = sel1[:, np.newaxis, :] - sel2[np.newaxis, :, :]
        distances = np.sqrt(np.sum(diff**2, axis=2))
        return bool(np.sum(distances < cutoff) >= 3)

    def get_index_array(self, chains=None, index_bump=100):
        """Get the index array of the protein.

        The index array is a `'numpy'` array of shape `(L,)` with the index of each residue along the chain.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the index array of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs
        index_bump : int, default 100
            Gap that is added to the running index between consecutive chains

        Returns
        -------
        index_array : np.ndarray
            An integer `'numpy'` array of shape `(L,)` with the index of each residue; the first chain
            starts at 0 and every following chain starts `index_bump` after the end of the previous one

        """
        chains = self._get_chains_list(chains)
        index_array = np.zeros(self.get_length(chains))
        position = 0  # where the current chain starts in the output
        next_index = 0  # index given to the first residue of the current chain
        for chain in chains:
            n_residues = self.get_length([chain])
            stop = position + n_residues
            index_array[position:stop] = next_index + np.arange(n_residues)
            next_index += n_residues + index_bump
            position = stop
        return index_array.astype(int)

    def get_chain_id_dict(self, chains=None):
        """Get the dictionary mapping from chain IDs to chain indices.

        The index of a chain is its position in `get_chains()`, regardless of
        which chains are selected.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the specified chains are included

        Returns
        -------
        chain_id_dict : dict
            A dictionary mapping from chain IDs to chain indices

        """
        selected = self._get_chains_list(chains)
        chain_id_dict = {}
        for index, chain_id in enumerate(self.get_chains()):
            if chain_id in selected:
                chain_id_dict[chain_id] = index
        return chain_id_dict

    def get_chain_id_array(self, chains=None, encode=True):
        """Get the chain ID array of the protein.

        The chain ID array is a `'numpy'` array of shape `(L,)` with the chain ID of each residue.
        When encoded, the chain ID is the index given by `get_chain_id_dict()`.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the chain ID array of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs
        encode : bool, default True
            If True, the chain ID is encoded as a number; otherwise, the chain ID string is stored
            in an array of dtype `object`

        Returns
        -------
        chain_id_array : np.ndarray
            A `'numpy'` array of shape `(L,)` with the chain ID of each residue

        """
        id_lookup = self.get_chain_id_dict()
        total_length = self.get_length(chains)
        if encode:
            chain_id_array = np.zeros(total_length)
        else:
            chain_id_array = np.empty(total_length, dtype=object)
        cursor = 0
        for chain in self._get_chains_list(chains):
            stop = cursor + self.get_length([chain])
            label = id_lookup[chain] if encode else chain
            chain_id_array[cursor:stop] = label
            cursor = stop
        return chain_id_array

    def get_ligand_features(self, ligands, chains=None):
        """Get ligand coordinates, smiles, and chain mapping.

        Parameters
        ----------
        ligands : dict
            A dictionary mapping from chain IDs to a list of ligands, where each ligand is a dictionary
            with at least the keys `'smiles'` and `'X'`
        chains : list of str, optional
            If specified, only the ligands of the specified chains are returned (in the same order);
            otherwise, all ligands are concatenated in alphabetical order of the chain IDs

        Returns
        -------
        X_ligands : torch.Tensor
            A `'torch'` tensor with the concatenated `'X'` arrays of all ligands
        ligand_smiles : str
            A string with the ligand smiles separated by a dot
        ligand_chains : torch.Tensor
            A `'torch'` tensor of shape `(N, 1)` with, for each row of `X_ligands`, the position of
            its chain in the list of selected chains
        """
        selected = self._get_chains_list(chains)
        coords_per_chain = []
        smiles_per_chain = []
        chain_index_rows = []
        for position, chain in enumerate(selected):
            chain_ligands = ligands[chain]
            smiles_per_chain.append(".".join([lig["smiles"] for lig in chain_ligands]))
            chain_coords = np.concatenate([lig["X"] for lig in chain_ligands])
            coords_per_chain.append(chain_coords)
            # one [chain position] row per ligand atom of this chain
            chain_index_rows.extend([[position]] * len(chain_coords))
        ligand_smiles = ".".join(smiles_per_chain)
        X_ligands = from_numpy(np.concatenate(coords_per_chain, 0))
        ligand_chains = Tensor(chain_index_rows)
        return X_ligands, ligand_smiles, ligand_chains

    def _get_highlight_mask_dict(self, highlight_mask=None):
        """Split a highlight mask into per-chain arrays.

        For every chain the entries of `highlight_mask` that belong to that
        chain and have a non-zero value in `get_mask()` are kept. An empty
        dictionary is returned when `highlight_mask` is `None`.
        """
        chain_labels = self.get_chain_id_array(encode=False)
        known = self.get_mask().astype(bool)
        if highlight_mask is None:
            return {}
        return {
            chain: highlight_mask[known & (chain_labels == chain)]
            for chain in self.get_chains()
        }

    def _get_atom_dicts(
        self,
        highlight_mask=None,
        style="cartoon",
        opacity=1,
        colors=None,
        accent_color="#D96181",
    ):
        """Get the atom dictionaries of the protein.

        The entry is written to a temporary PDB file, loaded as a `PDBEntry`
        and the request is delegated to `PDBEntry._get_atom_dicts` with the
        highlight mask converted to a per-chain dictionary.
        """
        per_chain_highlight = self._get_highlight_mask_dict(highlight_mask)
        temp_entry = PDBEntry(self._temp_pdb_file())
        atom_dicts = temp_entry._get_atom_dicts(
            highlight_mask_dict=per_chain_highlight,
            style=style,
            opacity=opacity,
            colors=colors,
            accent_color=accent_color,
        )
        return atom_dicts

    def get_predict_mask(self, chains=None, only_known=False):
        """Get the prediction mask of the protein.

        The prediction mask is a `'numpy'` array of shape `(L,)` with ones
        corresponding to residues that were generated by a model and zeros to
        residues with known coordinates. If the prediction mask is not available
        (or the entry has no chains), `None` is returned.

        Parameters
        ----------
        chains : list of str, optional
            If specified, only the prediction mask of the specified chains is returned (in the same order);
            otherwise, all features are concatenated in alphabetical order of the chain IDs
        only_known : bool, default False
            If `True`, only residues with known coordinates are returned

        Returns
        -------
        predict_mask : np.ndarray or None
            A `'numpy'` array of shape `(L,)` with ones corresponding to residues that were generated by a model and
            zeros to residues with known coordinates

        """
        # The first chain decides whether a mask is available; the default of
        # `next` covers an entry without chains, which used to raise IndexError.
        first_mask = next(iter(self.predict_mask.values()), None)
        if first_mask is None:
            return None
        chains = self._get_chains_list(chains)
        predict_mask = np.concatenate([self.predict_mask[chain] for chain in chains])
        if only_known:
            mask = self.get_mask(chains=chains)
            predict_mask = predict_mask[mask.astype(bool)]
        return predict_mask

    def visualize(
        self,
        highlight_mask=None,
        style="cartoon",
        highlight_style=None,
        opacity=1,
        canvas_size=(400, 300),
    ):
        """Show the protein in a notebook.

        The entry is dumped to a temporary PDB file, loaded as a `PDBEntry`
        and drawn with `PDBEntry.visualize`.

        Parameters
        ----------
        highlight_mask : np.ndarray, optional
            A `'numpy'` array of shape `(L,)` with the residues to highlight
            marked with 1 and the rest marked with 0; if not given and a
            prediction mask is stored, the predicted residues (restricted to
            residues with known coordinates) are highlighted
        style : str, default 'cartoon'
            The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
        highlight_style : str, optional
            The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
            (defaults to the same as `style`)
        opacity : float or dict, default 1
            Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
        canvas_size : tuple, default (400, 300)
            Shape of the canvas

        """
        if highlight_mask is None:
            if list(self.predict_mask.values())[0] is None:
                # nothing to highlight: no explicit mask and no prediction mask
                highlight_mask_dict = None
            else:
                highlight_mask_dict = {}
                for chain in self.get_chains():
                    chain_predict_mask = self.predict_mask[chain]
                    known = self.get_mask([chain]).astype(bool)
                    highlight_mask_dict[chain] = chain_predict_mask[known]
        else:
            highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask)
        with tempfile.NamedTemporaryFile(suffix=".pdb") as pdb_file:
            pdb_path = pdb_file.name
            self.to_pdb(pdb_path)
            pdb_entry = PDBEntry(pdb_path)
        pdb_entry.visualize(
            highlight_mask_dict=highlight_mask_dict,
            style=style,
            highlight_style=highlight_style,
            opacity=opacity,
            canvas_size=canvas_size,
        )

    def blosum62_score(self, seq_before, average=True, only_predicted=True):
        """Calculate the BLOSUM62 score of the protein.

        Parameters
        ----------
        seq_before : str
            A string with the sequence before the mutation
        average : bool, default True
            If `True`, the score is averaged over the residues; otherwise, the score is summed
        only_predicted : bool, default True
            If `True` and prediction masks are available, only predicted residues are considered;
            if no prediction mask is available the full sequences are compared

        Returns
        -------
        score : float
            The BLOSUM62 score of the protein

        """
        seq_after = self.get_sequence(encode=False)
        if only_predicted and self.predict_mask is not None:
            predict_mask = self.get_predict_mask()
            # `get_predict_mask` returns `None` when no mask is stored, in
            # which case there is nothing to filter by
            if predict_mask is not None:
                selector = predict_mask.astype(bool)
                seq_before = np.array(list(seq_before))[selector]
                seq_after = np.array(list(seq_after))[selector]
        score = blosum62_score(seq_before, seq_after)
        if average:
            score /= len(seq_before)
        return score

    def long_repeat_num(self, thr=5):
        """Calculate the number of long repeats in the protein.

        If a prediction mask is available, only the predicted residues are
        passed on; otherwise the full sequence is used.

        Parameters
        ----------
        thr : int, default 5
            The threshold for the minimum length of the repeat

        Returns
        -------
        num : int
            The number of long repeats in the protein

        """
        seq = self.get_sequence(encode=False)
        if self.predict_mask is not None:
            predict_mask = self.get_predict_mask()
            # `get_predict_mask` returns `None` when no mask is stored
            if predict_mask is not None:
                seq = np.array(list(seq))[predict_mask.astype(bool)]
        return long_repeat_num(seq, thr=thr)

    def esm_pll(
        self,
        esm_model_name="esm2_t30_150M_UR50D",
        esm_model_objects=None,
        average=False,
    ):
        """Calculate the ESM PLL score of the protein.

        The per-chain sequences and per-chain prediction masks are passed to
        the module-level `esm_pll` function. When no prediction mask is
        available, masks of ones (every residue selected) are used instead.

        Parameters
        ----------
        esm_model_name : str, default "esm2_t30_150M_UR50D"
            Name of the ESM-2 model to use
        esm_model_objects : tuple, optional
            Tuple of ESM-2 model, batch converter and tok_to_idx dictionary (if not None, `esm_model_name` will be ignored)
        average : bool, default False
            If `True`, the score is averaged over the residues; otherwise, the score is summed

        Returns
        -------
        score : float
            The ESM PLL score of the protein

        """
        chains = self.get_chains()
        chain_sequences = [self.get_sequence(chains=[chain]) for chain in chains]
        predict_masks = None
        if self.predict_mask is not None:
            chain_masks = [self.get_predict_mask(chains=[chain]) for chain in chains]
            # `get_predict_mask` returns `None` when no mask is stored; only
            # use the masks when every chain actually has one
            if all(mask is not None for mask in chain_masks):
                predict_masks = [mask.astype(float) for mask in chain_masks]
        if predict_masks is None:
            predict_masks = [np.ones(len(x)) for x in chain_sequences]
        return esm_pll(
            chain_sequences,
            predict_masks,
            esm_model_name=esm_model_name,
            esm_model_objects=esm_model_objects,
            average=average,
        )

    def ablang_pll(self, ablang_model_name="heavy", average=False):
        """Calculate the AbLang PLL score of the protein.

        The score is computed per chain (for the chains returned by
        `get_predicted_chains`) with the module-level `ablang_pll` function
        and summed. When no prediction mask is available, masks of ones
        (every residue selected) are used instead.

        Parameters
        ----------
        ablang_model_name : str, default "heavy"
            Name of the AbLang model to use
        average : bool, default False
            If `True`, the summed score is divided by the number of masked-in residues;
            otherwise, the score is summed

        Returns
        -------
        score : float
            The AbLang PLL score of the protein

        """
        chains = self.get_predicted_chains()
        chain_sequences = [self.get_sequence(chains=[chain]) for chain in chains]
        predict_masks = None
        if self.predict_mask is not None:
            chain_masks = [self.get_predict_mask(chains=[chain]) for chain in chains]
            # `get_predict_mask` returns `None` when no mask is stored; only
            # use the masks when every chain actually has one
            if all(mask is not None for mask in chain_masks):
                predict_masks = [mask.astype(float) for mask in chain_masks]
        if predict_masks is None:
            predict_masks = [np.ones(len(x)) for x in chain_sequences]
        out = sum(
            ablang_pll(
                sequence,
                predict_mask,
                ablang_model_name=ablang_model_name,
                average=False,
            )
            for sequence, predict_mask in zip(chain_sequences, predict_masks)
        )
        if average:
            # normalise by the masks that were actually used, so the fallback
            # (all-ones) case works as well
            out /= sum(mask.sum() for mask in predict_masks)
        return out

    def accuracy(self, seq_before):
        """Calculate the accuracy of the protein.

        The accuracy is the fraction of positions where `seq_before` and the
        current sequence agree. If a prediction mask is available, only the
        predicted residues are compared; otherwise all residues are.

        Parameters
        ----------
        seq_before : str
            A string with the sequence before the mutation

        Returns
        -------
        score : float
            The accuracy of the protein

        """
        seq_after = np.array(list(self.get_sequence(encode=False)))
        seq_before = np.array(list(seq_before))
        if self.predict_mask is not None:
            predict_mask = self.get_predict_mask()
            # `get_predict_mask` returns `None` when no mask is stored
            if predict_mask is not None:
                selector = predict_mask.astype(bool)
                seq_before = seq_before[selector]
                seq_after = seq_after[selector]
        return np.mean(seq_before == seq_after)

    def ca_rmsd(self, entry, only_predicted=True):
        """Compute the CA RMSD between this entry and another one.

        Only chains present in both entries are used, and only residues with
        known coordinates are compared. Index 2 along the atom axis of the
        coordinate arrays is taken as the CA atom.

        Parameters
        ----------
        entry : ProteinEntry
            A `ProteinEntry` object
        only_predicted : bool, default True
            If `True` and prediction masks are available, only predicted residues are considered

        Returns
        -------
        rmsd : float
            The CA RMSD between the two proteins

        """
        # restricting to predicted residues only makes sense if a mask exists
        use_predict_mask = only_predicted and self.has_predict_mask()
        shared_chains = []
        for chain_id in self.get_chains():
            if chain_id in entry.get_chains():
                shared_chains.append(chain_id)
        own_ca = self.get_coordinates(only_known=True, chains=shared_chains)[:, 2]
        other_ca = entry.get_coordinates(only_known=True, chains=shared_chains)[:, 2]
        if use_predict_mask:
            selector = self.get_predict_mask(
                only_known=True, chains=shared_chains
            ).astype(bool)
            own_ca = own_ca[selector]
            other_ca = other_ca[selector]
        return ca_rmsd(own_ca, other_ca)

    def tm_score(self, entry, chains=None):
        """Compute the TM score between this entry and another one.

        For both entries the coordinates at index 2 of the atom axis and the
        sequences are collected for residues with known coordinates, and
        passed to the module-level `tm_score` function.

        Parameters
        ----------
        entry : ProteinEntry
            A `ProteinEntry` object
        chains : list of str, optional
            A list of chain IDs to consider

        Returns
        -------
        tm_score : float
            The TM score between the two proteins

        """
        selection = {"only_known": True, "chains": chains}
        own_coords = self.get_coordinates(**selection)[:, 2]
        other_coords = entry.get_coordinates(**selection)[:, 2]
        own_seq = self.get_sequence(**selection)
        other_seq = entry.get_sequence(**selection)
        return tm_score(own_coords, other_coords, own_seq, other_seq)

    def _temp_pdb_file(self):
        """Write the entry to a new temporary `.pdb` file and return its path.

        The file is created with `delete=False`, so it stays on disk after
        this method returns; removing it is up to the caller.
        """
        handle = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
        with handle:
            self.to_pdb(handle.name)
        return handle.name

    @staticmethod
    def esmfold_metrics(entries, only_antibody=False):
        """Calculate ESMFold metrics for a list of entries.

        Sequences are generated with `esmfold_generate`, the resulting
        structures are read from `esmfold_output/seq_{i}.pdb`, aligned to the
        corresponding entry (the aligned structure is saved next to the
        prediction with an `_aligned.pdb` suffix) and compared with it.

        Parameters
        ----------
        entries : list of ProteinEntry
            A list of `ProteinEntry` objects
        only_antibody : bool, default False
            If `True`, chains listed as antigen are skipped for entries that have CDR annotation

        Returns
        -------
        plddts_full : list of float
            A list of PLDDT scores averaged over all residues
        plddts_predicted : list of float
            A list of PLDDT scores averaged over predicted residues
        rmsds : list of float
            A list of RMSD values of aligned structures (predicted residues only)
        tm_scores : list of float
            A list of TM scores of aligned structures

        """
        chains_list = [
            [
                chain
                for chain in entry.get_chains()
                if not entry.has_cdr()
                or not only_antibody
                or chain not in entry.get_chain_type_dict()["antigen"]
            ]
            for entry in entries
        ]
        # chains of one entry are joined with ":" into a single input string
        sequences = [
            ":".join(
                [
                    entry.get_sequence(chains=[chain], only_known=True)
                    for chain in chains
                ]
            )
            for entry, chains in zip(entries, chains_list)
        ]
        esmfold_generate(sequences)
        esmfold_paths = [
            os.path.join("esmfold_output", f"seq_{i}.pdb")
            for i in range(len(sequences))
        ]
        plddts_predicted = [
            confidence_from_file(
                path, entry.get_predict_mask(only_known=True, chains=chains)
            )
            for path, entry, chains in zip(esmfold_paths, entries, chains_list)
        ]
        plddts_full = [confidence_from_file(path) for path in esmfold_paths]
        rmsds = []
        tm_scores = []
        for entry, path, chains in zip(entries, esmfold_paths, chains_list):
            esm_entry = ProteinEntry.from_pdb(path)
            # map the generated chain IDs (A, B, ...) back to the original
            # ones; `zip` stops at 26 chains
            esm_entry.rename_chains(dict(zip(string.ascii_uppercase, chains)))
            aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
            reference_path = entry._temp_pdb_file()
            try:
                esm_entry.align_structure(
                    reference_pdb_path=reference_path,
                    save_pdb_path=aligned_path,
                    chain_ids=(
                        entry.get_predicted_chains()
                        if entry.has_predict_mask()
                        else chains
                    ),
                )
            finally:
                # the reference file is only needed for the alignment and is
                # created with `delete=False`
                os.remove(reference_path)
            rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
            tm_scores.append(entry.tm_score(esm_entry, chains=chains))
        return plddts_full, plddts_predicted, rmsds, tm_scores

    @staticmethod
    def igfold_metrics(entries, use_openmm=False):
        """Calculate IgFold metrics for a list of entries.

        Structures are generated with `igfold_generate` for the non-antigen
        chains of every entry, read from `igfold_output/seq_{i}.pdb` (or
        `igfold_refine_output/seq_{i}.pdb` with `use_openmm=True`), aligned to
        the corresponding entry (the aligned structure is saved next to the
        prediction with an `_aligned.pdb` suffix) and compared with it.

        Parameters
        ----------
        entries : list of ProteinEntry
            A list of `ProteinEntry` objects
        use_openmm : bool, default False
            Whether to use refinement with OpenMM

        Returns
        -------
        prmsds_full : list of float
            A list of confidence values read from the generated files with
            `confidence_from_file`, over all residues
        prmsds_predicted : list of float
            A list of confidence values read from the generated files with
            `confidence_from_file`, restricted by the prediction mask
        rmsds : list of float
            A list of RMSD values of aligned structures (predicted residues only)
        tm_scores : list of float
            A list of TM scores between each entry and its generated structure

        """
        chains_list = [
            [
                chain
                for chain in entry.get_chains()
                if chain not in entry.get_chain_type_dict()["antigen"]
            ]
            for entry in entries
        ]
        sequences = [
            {
                chain: entry.get_sequence(chains=[chain], only_known=True)
                for chain in chains
            }
            for entry, chains in zip(entries, chains_list)
        ]
        igfold_generate(sequences, use_openmm=use_openmm)
        folder = "igfold_refine_output" if use_openmm else "igfold_output"
        igfold_paths = [
            os.path.join(folder, f"seq_{i}.pdb") for i in range(len(sequences))
        ]
        prmsds_predicted = [
            confidence_from_file(
                path, entry.get_predict_mask(only_known=True, chains=chains)
            )
            for path, entry, chains in zip(igfold_paths, entries, chains_list)
        ]
        prmsds_full = [confidence_from_file(path) for path in igfold_paths]
        rmsds = []
        tm_scores = []
        for entry, path in zip(entries, igfold_paths):
            igfold_entry = ProteinEntry.from_pdb(path)
            aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
            reference_path = entry._temp_pdb_file()
            try:
                igfold_entry.align_structure(
                    reference_pdb_path=reference_path,
                    save_pdb_path=aligned_path,
                    chain_ids=entry.get_predicted_chains(),
                )
            finally:
                # the reference file is only needed for the alignment and is
                # created with `delete=False`
                os.remove(reference_path)
            rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
            # NOTE(review): unlike the ESMFold / ImmuneBuilder variants no
            # `chains` argument is passed here - confirm this is intended
            tm_scores.append(entry.tm_score(igfold_entry))
        return prmsds_full, prmsds_predicted, rmsds, tm_scores

    @staticmethod
    def immunebuilder_metrics(entries, protein_type="antibody"):
        """Calculate ImmuneBuilder metrics for a list of entries.

        Structures are generated with `immunebuilder_generate` from the heavy
        and light chain sequences of every entry, read from
        `immunebuilder_output/seq_{i}.pdb`, aligned to the corresponding entry
        (the aligned structure is saved next to the prediction with an
        `_aligned.pdb` suffix) and compared with it.

        Parameters
        ----------
        entries : list of ProteinEntry
            A list of `ProteinEntry` objects
        protein_type : {"antibody", "nanobody", "tcr"}, default "antibody"
            The type of the protein

        Returns
        -------
        prmsds_full : list of float
            A list of PRMSD scores averaged over all residues
        prmsds_predicted : list of float
            A list of PRMSD scores averaged over predicted residues
        rmsds : list of float
            A list of RMSD values of aligned structures (predicted residues only)
        tm_scores : list of float
            A list of TM scores of aligned structures

        """
        chains_list = [
            [
                chain
                for chain in entry.get_chains()
                if chain not in entry.get_chain_type_dict()["antigen"]
            ]
            for entry in entries
        ]
        sequences = []
        for entry in entries:
            chain_type_dict = entry.get_chain_type_dict()
            # keys are "H" / "L", values are the sequences of those chains
            sequences.append(
                {
                    key[0].upper(): entry.get_sequence(
                        chains=[chain_type_dict[key]], only_known=True
                    )
                    for key in ["heavy", "light"]
                    if key in chain_type_dict
                }
            )
        immunebuilder_generate(sequences, protein_type=protein_type)
        generated_paths = [
            os.path.join("immunebuilder_output", f"seq_{i}.pdb")
            for i in range(len(sequences))
        ]
        prmsds_predicted = [
            confidence_from_file(
                path, entry.get_predict_mask(only_known=True, chains=chains)
            )
            for path, entry, chains in zip(generated_paths, entries, chains_list)
        ]
        prmsds_full = [confidence_from_file(path) for path in generated_paths]
        rmsds = []
        tm_scores = []
        for entry, path, chains in zip(entries, generated_paths, chains_list):
            generated_entry = ProteinEntry.from_pdb(path)
            chain_type_dict = entry.get_chain_type_dict()
            # map the generated "L" / "H" chain IDs back to the original ones
            chain_rename_dict = {
                key[0].upper(): chain_type_dict[key]
                for key in ("light", "heavy")
                if key in chain_type_dict
            }
            generated_entry.rename_chains(chain_rename_dict)
            aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
            reference_path = entry._temp_pdb_file()
            try:
                generated_entry.align_structure(
                    reference_pdb_path=reference_path,
                    save_pdb_path=aligned_path,
                    chain_ids=entry.get_predicted_chains(),
                )
            finally:
                # the reference file is only needed for the alignment and is
                # created with `delete=False`
                os.remove(reference_path)
            rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
            tm_scores.append(entry.tm_score(generated_entry, chains=chains))
        return prmsds_full, prmsds_predicted, rmsds, tm_scores

    def align_structure(self, reference_pdb_path, save_pdb_path, chain_ids=None):
        """Aligns the structure to a reference structure using the CA atoms.

        The entry is written to a temporary PDB file (removed again once it
        has been parsed), superimposed onto the reference with
        `Bio.PDB.Superimposer` and the transformed structure is saved to
        `save_pdb_path`. Residues without a CA atom fall back to the C atom
        (with a warning); residues with neither are skipped.

        Parameters
        ----------
        reference_pdb_path : str
            Path to the reference structure (in .pdb format)
        save_pdb_path : str
            Path where the aligned structure should be saved (in .pdb format)
        chain_ids : list of str, optional
            If specified, only the chains with the specified IDs are used to compute the
            alignment (the transformation is applied to all atoms of the structure)

        """

        def _alignment_atoms(model, label):
            """Collect one alignment atom per residue of the selected chains."""
            atoms = []
            for chain in model:
                if chain_ids is not None and chain.id not in chain_ids:
                    continue
                for residue in chain:
                    if "CA" in residue:
                        atoms.append(residue["CA"])
                    elif "C" in residue:
                        atoms.append(residue["C"])
                        warnings.warn(
                            f"Using a C atom instead of CA for alignment in the {label} structure"
                        )
            return atoms

        pdb_parser = Bio.PDB.PDBParser(QUIET=True)

        temp_file = self._temp_pdb_file()
        try:
            ref_structure = pdb_parser.get_structure("reference", reference_pdb_path)
            sample_structure = pdb_parser.get_structure("sample", temp_file)
        finally:
            # `_temp_pdb_file` creates the file with `delete=False`; it is not
            # needed once both structures are parsed
            os.remove(temp_file)

        # only the first model of each structure is used
        ref_model = ref_structure[0]
        sample_model = sample_structure[0]

        ref_atoms = _alignment_atoms(ref_model, "reference")
        sample_atoms = _alignment_atoms(sample_model, "sample")

        super_imposer = Bio.PDB.Superimposer()
        super_imposer.set_atoms(ref_atoms, sample_atoms)
        super_imposer.apply(sample_model.get_atoms())

        io = Bio.PDB.PDBIO()
        io.set_structure(sample_structure)
        io.save(save_pdb_path)

    @staticmethod
    @requires_extra("MDAnalysis")
    def combine_multiple_frames(files, output_path="combined.pdb"):
        """Write several structures into one multiframe PDB file.

        Files whose name ends with `.pickle` are loaded with
        `ProteinEntry.from_pickle` and converted to a temporary PDB file
        first; all other files are passed to `MDAnalysis` as they are.

        Parameters
        ----------
        files : list of str
            A list of PDB or proteinflow pickle files
        output_path : str, default 'combined.pdb'
            Path to the .pdb output file

        """
        with mda.Writer(output_path, multiframe=True) as writer:
            for path in files:
                pdb_path = (
                    ProteinEntry.from_pickle(path)._temp_pdb_file()
                    if path.endswith(".pickle")
                    else path
                )
                universe = mda.Universe(pdb_path)
                writer.write(universe)

    def set_predict_mask(self, mask_dict):
        """Store a new prediction mask after validating it.

        Parameters
        ----------
        mask_dict : dict
            A dictionary mapping from chain IDs to a `np.ndarray` mask of 0s and 1s of the same length as the chain sequence

        Raises
        ------
        PDBError
            If a key of `mask_dict` is not a chain of this entry or a mask
            has a different length than the corresponding chain

        """
        for chain_id, chain_mask in mask_dict.items():
            if chain_id not in self.get_chains():
                raise PDBError("Chain not found")
            expected_length = self.get_length([chain_id])
            if len(chain_mask) != expected_length:
                raise PDBError("Mask length does not match sequence length")
        # the dictionary is stored as-is (not copied)
        self.predict_mask = mask_dict

    def apply_mask(self, mask):
        """Build a new entry that keeps only the residues selected by a mask.

        The mask is consumed chain by chain, in the order returned by
        `get_chains`, and every per-residue feature of a chain is indexed
        with the corresponding slice of the mask.

        Parameters
        ----------
        mask : np.ndarray
            A boolean mask of shape `(L,)` where `L` is the length of the protein (the chains are concatenated in alphabetical order)

        Returns
        -------
        entry : ProteinEntry
            A new `ProteinEntry` object

        """
        offset = 0
        entry_dict = {}
        for chain in self.get_chains():
            chain_length = self.get_length([chain])
            selector = mask[offset : offset + chain_length]
            offset += chain_length
            chain_dict = {}
            entry_dict[chain] = chain_dict
            encoded_seq = self.get_sequence(chains=[chain], encode=True)
            chain_dict["seq"] = self.decode_sequence(encoded_seq[selector])
            backbone = self.get_coordinates(chains=[chain], bb_only=True)
            chain_dict["crd_bb"] = backbone[selector]
            # everything after the first four atoms goes to the side chain entry
            side_chain = self.get_coordinates(chains=[chain])[:, 4:]
            chain_dict["crd_sc"] = side_chain[selector]
            chain_dict["msk"] = self.get_mask(chains=[chain])[selector]
            if self.has_cdr():
                encoded_cdr = self.get_cdr([chain], encode=True)
                chain_dict["cdr"] = self.decode_cdr(encoded_cdr[selector])
            if self.has_predict_mask():
                chain_dict["predict_msk"] = self.predict_mask[chain][selector]
        if self.id is not None:
            entry_dict["protein_id"] = self.id
        return ProteinEntry.from_dict(entry_dict)

    def get_protein_class(self):
        """Get the protein class.

        An entry with one chain is a `"single_chain"`. Otherwise every pair of
        chain sequences is compared: if the shorter sequence is less than 90%
        of the length of the longer one, or the edit distance between them is
        more than 10% of the longer length, the entry is a `"heteromer"`; if
        no pair differs that much it is a `"homomer"`.

        Returns
        -------
        protein_class : str
            The protein class ("single_chain", "heteromer", "homomer")

        """
        chains = self.get_chains()
        if len(chains) == 1:
            return "single_chain"
        # compare the sequences, not the chain IDs
        sequences = [self.get_sequence(chains=[chain]) for chain in chains]
        for seq1, seq2 in itertools.combinations(sequences, 2):
            if seq1 == seq2:
                # identical chains: distance is 0 (also avoids 0 / 0 for empty chains)
                continue
            longer = max(len(seq1), len(seq2))
            shorter = min(len(seq1), len(seq2))
            if shorter < 0.9 * longer:
                return "heteromer"
            if edit_distance(seq1, seq2) / longer > 0.1:
                return "heteromer"
        return "homomer"

Class variables

var ATOM_ORDER

A dictionary mapping 3-letter residue names to the order of atoms in the coordinates array.

Static methods

def combine_multiple_frames(files, output_path='combined.pdb')

Combine multiple PDB or proteinflow pickle files into a single multiframe PDB file.

Parameters

files : list of str
A list of PDB or proteinflow pickle files
output_path : str, default 'combined.pdb'
Path to the .pdb output file
Expand source code
@staticmethod
@requires_extra("MDAnalysis")
def combine_multiple_frames(files, output_path="combined.pdb"):
    """Write several structures into one multiframe PDB file.

    Files whose name ends with `.pickle` are loaded with
    `ProteinEntry.from_pickle` and converted to a temporary PDB file
    first; all other files are passed to `MDAnalysis` as they are.

    Parameters
    ----------
    files : list of str
        A list of PDB or proteinflow pickle files
    output_path : str, default 'combined.pdb'
        Path to the .pdb output file

    """
    with mda.Writer(output_path, multiframe=True) as writer:
        for path in files:
            pdb_path = (
                ProteinEntry.from_pickle(path)._temp_pdb_file()
                if path.endswith(".pickle")
                else path
            )
            universe = mda.Universe(pdb_path)
            writer.write(universe)
def decode_cdr(cdr)

Decode the CDR information.

Parameters

cdr : np.ndarray
A 'numpy' array of shape (L,) encoded as integers where each integer corresponds to the index of the CDR type in proteinflow.constants.CDR_ALPHABET

Returns

cdr : np.ndarray
A 'numpy' array of shape (L,) where CDR residues are marked with the corresponding type ('H1', 'L1', …) and non-CDR residues are marked with '-'
Expand source code
@staticmethod
def decode_cdr(cdr):
    """Turn integer-encoded CDR labels back into their string form.

    Parameters
    ----------
    cdr : np.ndarray
        A `'numpy'` array of shape `(L,)` encoded as integers where each
        integer corresponds to the index of the CDR type in
        `proteinflow.constants.CDR_ALPHABET`

    Returns
    -------
    cdr : np.ndarray
        A `'numpy'` array of shape `(L,)` where CDR residues are marked
        with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR
        residues are marked with `'-'`

    """
    # the input is first converted with `_to_numpy`, then cast to int for indexing
    indices = ProteinEntry._to_numpy(cdr).astype(int)
    labels = [CDR_ALPHABET[index] for index in indices]
    return np.array(labels)
def decode_sequence(seq)

Decode the amino acid sequence.

Parameters

seq : np.ndarray
A 'numpy' array of integers where each integer corresponds to the index of the amino acid in proteinflow.constants.ALPHABET

Returns

seq : str
Amino acid sequence of the protein (one-letter code)
Expand source code
@staticmethod
def decode_sequence(seq):
    """Turn an integer-encoded sequence back into a one-letter string.

    Parameters
    ----------
    seq : np.ndarray
        A `'numpy'` array of integers where each integer corresponds to the
        index of the amino acid in `proteinflow.constants.ALPHABET`

    Returns
    -------
    seq : str
        Amino acid sequence of the protein (one-letter code)

    """
    # the input is first converted with `_to_numpy`, then cast to int for indexing
    indices = ProteinEntry._to_numpy(seq).astype(int)
    letters = [ALPHABET[index] for index in indices]
    return "".join(letters)
def esmfold_metrics(entries, only_antibody=False)

Calculate ESMFold metrics for a list of entries.

Parameters

entries : list of ProteinEntry
A list of ProteinEntry objects
only_antibody : bool, default False
If True, only antibody chains are considered

Returns

plddts_full : list of float
A list of PLDDT scores averaged over all residues
plddts_predicted : list of float
A list of PLDDT scores averaged over predicted residues
rmsd : list of float
A list of RMSD values of aligned structures (predicted residues only)
tm_score : list of float
A list of TM scores of aligned structures
Expand source code
@staticmethod
def esmfold_metrics(entries, only_antibody=False):
    """Calculate ESMFold metrics for a list of entries.

    Sequences are generated with `esmfold_generate`, the resulting
    structures are read from `esmfold_output/seq_{i}.pdb`, aligned to the
    corresponding entry (the aligned structure is saved next to the
    prediction with an `_aligned.pdb` suffix) and compared with it.

    Parameters
    ----------
    entries : list of ProteinEntry
        A list of `ProteinEntry` objects
    only_antibody : bool, default False
        If `True`, chains listed as antigen are skipped for entries that have CDR annotation

    Returns
    -------
    plddts_full : list of float
        A list of PLDDT scores averaged over all residues
    plddts_predicted : list of float
        A list of PLDDT scores averaged over predicted residues
    rmsds : list of float
        A list of RMSD values of aligned structures (predicted residues only)
    tm_scores : list of float
        A list of TM scores of aligned structures

    """
    chains_list = [
        [
            chain
            for chain in entry.get_chains()
            if not entry.has_cdr()
            or not only_antibody
            or chain not in entry.get_chain_type_dict()["antigen"]
        ]
        for entry in entries
    ]
    # chains of one entry are joined with ":" into a single input string
    sequences = [
        ":".join(
            [
                entry.get_sequence(chains=[chain], only_known=True)
                for chain in chains
            ]
        )
        for entry, chains in zip(entries, chains_list)
    ]
    esmfold_generate(sequences)
    esmfold_paths = [
        os.path.join("esmfold_output", f"seq_{i}.pdb")
        for i in range(len(sequences))
    ]
    plddts_predicted = [
        confidence_from_file(
            path, entry.get_predict_mask(only_known=True, chains=chains)
        )
        for path, entry, chains in zip(esmfold_paths, entries, chains_list)
    ]
    plddts_full = [confidence_from_file(path) for path in esmfold_paths]
    rmsds = []
    tm_scores = []
    for entry, path, chains in zip(entries, esmfold_paths, chains_list):
        esm_entry = ProteinEntry.from_pdb(path)
        # map the generated chain IDs (A, B, ...) back to the original
        # ones; `zip` stops at 26 chains
        esm_entry.rename_chains(dict(zip(string.ascii_uppercase, chains)))
        aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
        reference_path = entry._temp_pdb_file()
        try:
            esm_entry.align_structure(
                reference_pdb_path=reference_path,
                save_pdb_path=aligned_path,
                chain_ids=(
                    entry.get_predicted_chains()
                    if entry.has_predict_mask()
                    else chains
                ),
            )
        finally:
            # the reference file is only needed for the alignment and is
            # created with `delete=False`
            os.remove(reference_path)
        rmsds.append(entry.ca_rmsd(ProteinEntry.from_pdb(aligned_path)))
        tm_scores.append(entry.tm_score(esm_entry, chains=chains))
    return plddts_full, plddts_predicted, rmsds, tm_scores
def from_arrays(seqs, crds, masks, chain_id_dict, chain_id_array, predict_masks=None, cdrs=None, protein_id=None)

Load a protein entry from arrays.

Parameters

seqs : np.ndarray
Amino acid sequences of the protein (encoded as integers, see proteinflow.constants.ALPHABET), 'numpy' array of shape (L,)
crds : np.ndarray
Coordinates of the protein, 'numpy' array of shape (L, 14, 3) or (L, 4, 3)
masks : np.ndarray
Mask array where 1 indicates residues with known coordinates and 0 indicates missing values, 'numpy' array of shape (L,)
chain_id_dict : dict
A dictionary mapping chain IDs to indices in chain_id_array
chain_id_array : np.ndarray
A 'numpy' array of chain IDs encoded as integers
predict_masks : np.ndarray, optional
Mask array where 1 indicates residues that were generated by a model and 0 indicates residues with known coordinates, 'numpy' array of shape (L,)
cdrs : np.ndarray, optional
A 'numpy' array of shape (L,) where residues are marked with the corresponding CDR type (encoded as integers, see proteinflow.constants.CDR_ALPHABET)
protein_id : str, optional
Protein ID

Returns

entry : ProteinEntry
A ProteinEntry object
Expand source code
@staticmethod
def from_arrays(
    seqs,
    crds,
    masks,
    chain_id_dict,
    chain_id_array,
    predict_masks=None,
    cdrs=None,
    protein_id=None,
):
    """Build a `ProteinEntry` from arrays that concatenate all chains.

    The arrays are split per chain with `chain_id_array == index` for every
    `(chain ID, index)` pair of `chain_id_dict`.

    Parameters
    ----------
    seqs : np.ndarray
        Amino acid sequences of the protein (encoded as integers, see `proteinflow.constants.ALPHABET`),
        `'numpy'` array of shape `(L,)`
    crds : np.ndarray
        Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)` or `(L, 4, 3)`;
        when the second dimension is not 14 the values are written into the first
        four atom slots of a zero-filled `(L, 14, 3)` array
    masks : np.ndarray
        Mask array where 1 indicates residues with known coordinates and 0
        indicates missing values, `'numpy'` array of shape `(L,)`
    chain_id_dict : dict
        A dictionary mapping chain IDs to indices in `chain_id_array`
    chain_id_array : np.ndarray
        A `'numpy'` array of chain IDs encoded as integers
    predict_masks : np.ndarray, optional
        Mask array where 1 indicates residues that were generated by a model and 0
        indicates residues with known coordinates, `'numpy'` array of shape `(L,)`
    cdrs : np.ndarray, optional
        A `'numpy'` array of shape `(L,)` where residues are marked
        with the corresponding CDR type (encoded as integers, see `proteinflow.constants.CDR_ALPHABET`)
    protein_id : str, optional
        Protein ID

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    with_predict = predict_masks is not None
    with_cdr = cdrs is not None
    chain_names, chain_seqs, chain_crds, chain_masks = [], [], [], []
    # optional outputs stay `None` (not empty lists) when the input is absent
    chain_predict = [] if with_predict else None
    chain_cdrs = [] if with_cdr else None
    for name, index in chain_id_dict.items():
        selector = chain_id_array == index
        chain_names.append(name)
        chain_seqs.append(ProteinEntry.decode_sequence(seqs[selector]))
        raw_crd = crds[selector]
        chain_crd = ProteinEntry._to_numpy(raw_crd)
        if crds.shape[1] != 14:
            # backbone-only input: pad the missing atom slots with zeros
            padded = np.zeros((raw_crd.shape[0], 14, 3))
            padded[:, :4, :] = chain_crd
            chain_crd = padded
        chain_crds.append(chain_crd)
        chain_masks.append(ProteinEntry._to_numpy(masks[selector]))
        if with_predict:
            chain_predict.append(ProteinEntry._to_numpy(predict_masks[selector]))
        if with_cdr:
            chain_cdrs.append(ProteinEntry.decode_cdr(cdrs[selector]))
    return ProteinEntry(
        chain_seqs,
        chain_crds,
        chain_masks,
        chain_names,
        chain_predict,
        chain_cdrs,
        protein_id,
    )
def from_dict(dictionary)

Load a protein entry from a dictionary.

Parameters

dictionary : dict
A nested dictionary where first-level keys are chain IDs and second-level keys are the following:
- 'seq' : amino acid sequence (one-letter code)
- 'crd_bb' : backbone coordinates, shaped (L, 4, 3)
- 'crd_sc' : sidechain coordinates, shaped (L, 10, 3)
- 'msk' : mask array where 1 indicates residues with known coordinates and 0 indicates missing values, shaped (L,)
- 'cdr' (optional) : CDR information, shaped (L,), where CDR residues are marked with the corresponding type ('H1', 'L1', …) and non-CDR residues are marked with '-'
- 'predict_msk' (optional) : mask array where 1 indicates residues that were generated by a model and 0 indicates residues with known coordinates, shaped (L,)

The dictionary can also contain a 'protein_id' first-level key.

Returns

entry : ProteinEntry
A ProteinEntry object
Expand source code
@staticmethod
def from_dict(dictionary):
    """Build a `ProteinEntry` from a nested per-chain dictionary.

    Chains are processed in sorted order of their IDs; backbone and sidechain
    coordinates are concatenated along the atom axis.

    Parameters
    ----------
    dictionary : dict
        A nested dictionary where first-level keys are chain IDs and
        second-level keys are the following:
        - `'seq'` : amino acid sequence (one-letter code)
        - `'crd_bb'` : backbone coordinates, shaped `(L, 4, 3)`
        - `'crd_sc'` : sidechain coordinates, shaped `(L, 10, 3)`
        - `'msk'` : mask array where 1 indicates residues with known coordinates and 0
            indicates missing values, shaped `(L,)`
        - `'cdr'` (optional): CDR information, shaped `(L,)` where CDR residues are marked
            with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR residues are marked with `'-'`
        - `'predict_msk'` (optional): mask array where 1 indicates residues that were generated by a model and 0
            indicates residues with known coordinates, shaped `(L,)`
        It can also contain a `'protein_id'` first-level key.

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    # every first-level key except "protein_id" is treated as a chain ID
    chain_ids = sorted(key for key in dictionary.keys() if key != "protein_id")
    per_chain = [dictionary[chain_id] for chain_id in chain_ids]
    return ProteinEntry(
        seqs=[data["seq"] for data in per_chain],
        crds=[
            np.concatenate([data["crd_bb"], data["crd_sc"]], axis=1)
            for data in per_chain
        ],
        masks=[data["msk"] for data in per_chain],
        # missing optional keys become `None` entries, one per chain
        cdrs=[data.get("cdr", None) for data in per_chain],
        chain_ids=chain_ids,
        predict_masks=[data.get("predict_msk", None) for data in per_chain],
        protein_id=dictionary.get("protein_id"),
    )
def from_id(pdb_id, local_folder='.', heavy_chain=None, light_chain=None, antigen_chains=None)

Load a protein entry from a PDB ID.

Parameters

pdb_id : str
PDB ID of the protein
local_folder : str, default "."
Path to the local folder where the PDB file is saved
heavy_chain : str, optional
Chain ID of the heavy chain (to load a SAbDab entry)
light_chain : str, optional
Chain ID of the light chain (to load a SAbDab entry)
antigen_chains : list of str, optional
Chain IDs of the antigen chains (to load a SAbDab entry)

Returns

entry : ProteinEntry
A ProteinEntry object
Expand source code
@staticmethod
def from_id(
    pdb_id,
    local_folder=".",
    heavy_chain=None,
    light_chain=None,
    antigen_chains=None,
):
    """Load a protein entry from a PDB ID.

    If `heavy_chain` or `light_chain` is given, the entry is loaded through
    `SAbDabEntry.from_id`; otherwise through `PDBEntry.from_id`. In both cases
    `local_folder` is forwarded to the loader.

    Parameters
    ----------
    pdb_id : str
        PDB ID of the protein
    local_folder : str, default "."
        Path to the local folder where the PDB file is saved
    heavy_chain : str, optional
        Chain ID of the heavy chain (to load a SAbDab entry)
    light_chain : str, optional
        Chain ID of the light chain (to load a SAbDab entry)
    antigen_chains : list of str, optional
        Chain IDs of the antigen chains (to load a SAbDab entry);
        only used when `heavy_chain` or `light_chain` is specified

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    if heavy_chain is not None or light_chain is not None:
        pdb_entry = SAbDabEntry.from_id(
            pdb_id=pdb_id,
            local_folder=local_folder,
            heavy_chain=heavy_chain,
            light_chain=light_chain,
            antigen_chains=antigen_chains,
        )
    else:
        # forward `local_folder` here as well so that the documented parameter
        # is honoured for plain PDB entries, not only for SAbDab entries
        pdb_entry = PDBEntry.from_id(pdb_id=pdb_id, local_folder=local_folder)
    return ProteinEntry.from_pdb_entry(pdb_entry)
def from_pdb(pdb_path, fasta_path=None, heavy_chain=None, light_chain=None, antigen_chains=None)

Load a protein entry from a PDB file.

Parameters

pdb_path : str
Path to the PDB file
fasta_path : str, optional
Path to the FASTA file; if not specified, the sequence is extracted from the PDB file
heavy_chain : str, optional
Chain ID of the heavy chain (to load a SAbDab entry)
light_chain : str, optional
Chain ID of the light chain (to load a SAbDab entry)
antigen_chains : list of str, optional
Chain IDs of the antigen chains (to load a SAbDab entry)

Returns

entry : ProteinEntry
A ProteinEntry object
Expand source code
@staticmethod
def from_pdb(
    pdb_path,
    fasta_path=None,
    heavy_chain=None,
    light_chain=None,
    antigen_chains=None,
):
    """Build a `ProteinEntry` from a PDB file on disk.

    A `SAbDabEntry` is used as the intermediate object when a heavy or a light
    chain is specified, a plain `PDBEntry` otherwise.

    Parameters
    ----------
    pdb_path : str
        Path to the PDB file
    fasta_path : str, optional
        Path to the FASTA file; if not specified, the sequence is extracted
        from the PDB file
    heavy_chain : str, optional
        Chain ID of the heavy chain (to load a SAbDab entry)
    light_chain : str, optional
        Chain ID of the light chain (to load a SAbDab entry)
    antigen_chains : list of str, optional
        Chain IDs of the antigen chains (to load a SAbDab entry)

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    no_antibody_chains = heavy_chain is None and light_chain is None
    if no_antibody_chains:
        return ProteinEntry.from_pdb_entry(
            PDBEntry(pdb_path=pdb_path, fasta_path=fasta_path)
        )
    sabdab_entry = SAbDabEntry(
        pdb_path=pdb_path,
        fasta_path=fasta_path,
        heavy_chain=heavy_chain,
        light_chain=light_chain,
        antigen_chains=antigen_chains,
    )
    return ProteinEntry.from_pdb_entry(sabdab_entry)
def from_pdb_entry(pdb_entry)

Load a protein entry from a PDBEntry object.

Parameters

pdb_entry : PDBEntry
A PDBEntry object

Returns

entry : ProteinEntry
A ProteinEntry object
Expand source code
@staticmethod
def from_pdb_entry(pdb_entry):
    """Load a protein entry from a `PDBEntry` object.

    For every chain the FASTA sequence, the residue mask, the coordinates and
    (for `SAbDabEntry` inputs) the CDR annotation are collected into a nested
    dictionary, which is then passed to `ProteinEntry.from_dict`.

    Parameters
    ----------
    pdb_entry : PDBEntry
        A `PDBEntry` object (CDR annotations are included when it is a `SAbDabEntry`)

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    pdb_dict = {}
    fasta_dict = pdb_entry.get_fasta()
    # the `(chain,)` pattern requires every item to unpack into exactly one
    # element, so a string chain ID has to be a single character here
    for (chain,) in pdb_entry.get_chains():
        pdb_dict[chain] = {}
        fasta_seq = fasta_dict[chain]

        # residue mask for this chain, as computed by the PDB entry
        mask = pdb_entry.get_mask([chain])[chain]
        if isinstance(pdb_entry, SAbDabEntry):
            pdb_dict[chain]["cdr"] = pdb_entry.get_cdr([chain])[chain]
        pdb_dict[chain]["seq"] = fasta_seq
        # stored by reference: the in-place update below also modifies `mask`
        pdb_dict[chain]["msk"] = mask

        # split the coordinate array into the first four atoms and the rest
        crd_arr = pdb_entry.get_coordinates_array(chain)

        pdb_dict[chain]["crd_bb"] = crd_arr[:, :4, :]
        pdb_dict[chain]["crd_sc"] = crd_arr[:, 4:, :]
        # zero out the mask for residues where exactly 4 of the 12 backbone
        # coordinate values are equal to 0
        # NOTE(review): a fully zero-filled backbone would give a count of 12,
        # not 4 -- confirm whether "== 4" is the intended criterion
        pdb_dict[chain]["msk"][
            (pdb_dict[chain]["crd_bb"] == 0).sum(-1).sum(-1) == 4
        ] = 0
    pdb_dict["protein_id"] = pdb_entry.pdb_id
    return ProteinEntry.from_dict(pdb_dict)
def from_pickle(path)

Load a protein entry from a pickle file.

Parameters

path : str
Path to the pickle file

Returns

entry : ProteinEntry
A ProteinEntry object
Expand source code
@staticmethod
def from_pickle(path):
    """Build a `ProteinEntry` from a pickled dictionary.

    The unpickled object is passed to `ProteinEntry.from_dict` unchanged.

    Parameters
    ----------
    path : str
        Path to the pickle file

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    # NOTE: unpickling executes arbitrary code; only load trusted files
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    entry = ProteinEntry.from_dict(payload)
    return entry
def igfold_metrics(entries, use_openmm=False)

Calculate IgFold metrics for a list of entries.

Parameters

entries : list of ProteinEntry
A list of ProteinEntry objects
use_openmm : bool, default False
Whether to use refinement with OpenMM

Returns

prmsds_full : list of float
A list of PRMSD scores averaged over all residues
prmsds_predicted : list of float
A list of PRMSD scores averaged over predicted residues
rmsds : list of float
A list of RMSD values of aligned structures (predicted residues only)
tm_scores : list of float
A list of TM scores of individual chains (self-consistency)
Expand source code
@staticmethod
def igfold_metrics(entries, use_openmm=False):
    """Compute IgFold-based self-consistency metrics for a list of entries.

    The non-antigen chains of every entry are folded with `igfold_generate`,
    confidence values are read from the generated files and each generated
    structure is aligned to the original entry to compute RMSD and TM score.

    Parameters
    ----------
    entries : list of ProteinEntry
        A list of `ProteinEntry` objects
    use_openmm : bool, default False
        Whether to use refinement with OpenMM

    Returns
    -------
    prmsds_full : list of float
        A list of confidence scores (read with `confidence_from_file`) over all residues
    prmsds_predicted : list of float
        A list of confidence scores (read with `confidence_from_file`) over predicted residues
    rmsds : list of float
        A list of RMSD values of aligned structures (predicted residues only)
    tm_scores : list of float
        A list of TM scores of individual chains (self-consistency)

    """
    # chains that are not listed as antigen chains, one list per entry
    chains_list = []
    for entry in entries:
        chains_list.append(
            [
                chain
                for chain in entry.get_chains()
                if chain not in entry.get_chain_type_dict()["antigen"]
            ]
        )
    sequences = []
    for entry, chains in zip(entries, chains_list):
        sequences.append(
            {
                chain: entry.get_sequence(chains=[chain], only_known=True)
                for chain in chains
            }
        )
    igfold_generate(sequences, use_openmm=use_openmm)
    folder = "igfold_output"
    if use_openmm:
        folder = "igfold_refine_output"
    igfold_paths = [
        os.path.join(folder, f"seq_{i}.pdb") for i, _ in enumerate(sequences)
    ]
    prmsds_predicted = []
    for path, entry, chains in zip(igfold_paths, entries, chains_list):
        predict_mask = entry.get_predict_mask(only_known=True, chains=chains)
        prmsds_predicted.append(confidence_from_file(path, predict_mask))
    prmsds_full = [confidence_from_file(path) for path in igfold_paths]
    rmsds, tm_scores = [], []
    for entry, path in zip(entries, igfold_paths):
        aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
        igfold_entry = ProteinEntry.from_pdb(path)
        # superimpose the generated structure onto the original entry
        igfold_entry.align_structure(
            reference_pdb_path=entry._temp_pdb_file(),
            save_pdb_path=aligned_path,
            chain_ids=entry.get_predicted_chains(),
        )
        aligned_entry = ProteinEntry.from_pdb(aligned_path)
        rmsds.append(entry.ca_rmsd(aligned_entry))
        tm_scores.append(entry.tm_score(igfold_entry))
    return prmsds_full, prmsds_predicted, rmsds, tm_scores
def immunebuilder_metrics(entries, protein_type='antibody')

Calculate ImmuneBuilder metrics for a list of entries.

Parameters

entries : list of ProteinEntry
A list of ProteinEntry objects
protein_type : {"antibody", "nanobody", "tcr"}, default "antibody"
The type of the protein

Returns

prmsds_full : list of float
A list of PRMSD scores averaged over all residues
prmsds_predicted : list of float
A list of PRMSD scores averaged over predicted residues
rmsds : list of float
A list of RMSD values of aligned structures (predicted residues only)
tm_scores : list of float
A list of TM scores of aligned structures
Expand source code
@staticmethod
def immunebuilder_metrics(entries, protein_type="antibody"):
    """Compute ImmuneBuilder-based self-consistency metrics for a list of entries.

    Heavy and light chain sequences of every entry are folded with
    `immunebuilder_generate`; the generated chains `"H"` / `"L"` are renamed to
    the chain IDs of the entry before alignment and scoring.

    Parameters
    ----------
    entries : list of ProteinEntry
        A list of `ProteinEntry` objects
    protein_type : {"antibody", "nanobody", "tcr"}, default "antibody"
        The type of the protein

    Returns
    -------
    prmsds_full : list of float
        A list of PRMSD scores averaged over all residues
    prmsds_predicted : list of float
        A list of PRMSD scores averaged over predicted residues
    rmsds : list of float
        A list of RMSD values of aligned structures (predicted residues only)
    tm_scores : list of float
        A list of TM scores of aligned structures

    """
    # chains that are not listed as antigen chains, one list per entry
    chains_list = []
    for entry in entries:
        chains_list.append(
            [
                chain
                for chain in entry.get_chains()
                if chain not in entry.get_chain_type_dict()["antigen"]
            ]
        )
    sequences = []
    for entry in entries:
        type_to_chain = entry.get_chain_type_dict()
        chain_sequences = {}
        for kind in ["heavy", "light"]:
            if kind not in type_to_chain:
                continue
            # keys are "H" / "L"
            chain_sequences[kind[0].upper()] = entry.get_sequence(
                chains=[type_to_chain[kind]], only_known=True
            )
        sequences.append(chain_sequences)
    immunebuilder_generate(sequences, protein_type=protein_type)
    generated_paths = [
        os.path.join("immunebuilder_output", f"seq_{i}.pdb")
        for i, _ in enumerate(sequences)
    ]
    prmsds_predicted = []
    for path, entry, chains in zip(generated_paths, entries, chains_list):
        predict_mask = entry.get_predict_mask(only_known=True, chains=chains)
        prmsds_predicted.append(confidence_from_file(path, predict_mask))
    prmsds_full = [confidence_from_file(path) for path in generated_paths]
    rmsds, tm_scores = [], []
    for entry, path, chains in zip(entries, generated_paths, chains_list):
        aligned_path = path.rsplit(".", 1)[0] + "_aligned.pdb"
        generated_entry = ProteinEntry.from_pdb(path)
        type_to_chain = entry.get_chain_type_dict()
        # map the generated "L" / "H" chains back to the entry's chain IDs
        renaming = {}
        for generated_id, kind in (("L", "light"), ("H", "heavy")):
            if kind in type_to_chain:
                renaming[generated_id] = type_to_chain[kind]
        generated_entry.rename_chains(renaming)
        generated_entry.align_structure(
            reference_pdb_path=entry._temp_pdb_file(),
            save_pdb_path=aligned_path,
            chain_ids=entry.get_predicted_chains(),
        )
        aligned_entry = ProteinEntry.from_pdb(aligned_path)
        rmsds.append(entry.ca_rmsd(aligned_entry))
        tm_scores.append(entry.tm_score(generated_entry, chains=chains))
    return prmsds_full, prmsds_predicted, rmsds, tm_scores
def retrieve_ligands_from_pickle(path)

Retrieve ligands from a pickle file.

Parameters

path : str
Path to the pickle file

Returns

chain2ligand : dict
A dictionary where keys are chain IDs and values are ligand names
Expand source code
@staticmethod
def retrieve_ligands_from_pickle(path):
    """Retrieve ligands from a pickle file.

    Parameters
    ----------
    path : str
        Path to the pickle file

    Returns
    -------
    chain2ligand : dict
        A dictionary where keys are chain IDs and values are ligand names

    """
    with open(path, "rb") as f:
        data = pickle.load(f)
    chain2ligand = {}
    for chain in data:
        if "ligand" not in data[chain]:
            continue
        chain2ligand[chain] = data[chain]["ligand"]
    return chain2ligand

Methods

def ablang_pll(self, ablang_model_name='heavy', average=False)

Calculate the AbLang PLL score of the protein.

Parameters

ablang_model_name : str, default "heavy"
Name of the AbLang model to use
average : bool, default False
If True, the score is averaged over the residues; otherwise, the score is summed

Returns

score : float
The AbLang PLL score of the protein
Expand source code
def ablang_pll(self, ablang_model_name="heavy", average=False):
    """Calculate the AbLang PLL score of the protein.

    The score is computed separately for every chain returned by
    `get_predicted_chains` and summed. If the entry has a prediction mask,
    only the masked residues are weighted; otherwise all residues are.

    Parameters
    ----------
    ablang_model_name : str, default "heavy"
        Name of the AbLang model to use
    average : bool, default False
        If `True`, the score is averaged over the residues; otherwise, the score is summed

    Returns
    -------
    score : float
        The AbLang PLL score of the protein

    """
    chains = self.get_predicted_chains()
    chain_sequences = [self.get_sequence(chains=[chain]) for chain in chains]
    # `has_predict_mask()` is the check used elsewhere in this class; a plain
    # `self.predict_mask is not None` test is not reliable because `from_dict`
    # passes a per-chain list of `None` values when no mask is stored
    if self.has_predict_mask():
        predict_masks = [
            (self.get_predict_mask(chains=[chain])).astype(float)
            for chain in chains
        ]
    else:
        predict_masks = [np.ones(len(x)) for x in chain_sequences]
    out = sum(
        ablang_pll(
            sequence,
            predict_mask,
            ablang_model_name=ablang_model_name,
            average=False,
        )
        for sequence, predict_mask in zip(chain_sequences, predict_masks)
    )
    if average:
        # normalise by the masks that were actually used for scoring, so the
        # all-ones fallback is averaged over every residue
        out /= np.sum([mask.sum() for mask in predict_masks])
    return out
def accuracy(self, seq_before)

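Calculate the sequence accuracy of the protein: the fraction of residues that are identical to seq_before (only predicted residues are compared when a prediction mask is available).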

Parameters

seq_before : str
A string with the sequence before the mutation

Returns

score : float
The accuracy of the protein
Expand source code
def accuracy(self, seq_before):
    """Calculate the sequence accuracy of the protein.

    The accuracy is the fraction of positions where the current sequence is
    identical to `seq_before`. If the entry has a prediction mask, only the
    predicted residues are compared.

    Parameters
    ----------
    seq_before : str
        A string with the sequence before the mutation (same length as the
        full sequence of the entry)

    Returns
    -------
    score : float
        The accuracy of the protein

    """
    seq_after = np.array(list(self.get_sequence(encode=False)))
    seq_before = np.array(list(seq_before))
    # `has_predict_mask()` is the check used elsewhere in this class; a plain
    # `self.predict_mask is not None` test is not reliable because `from_dict`
    # passes a per-chain list of `None` values when no mask is stored
    if self.has_predict_mask():
        predicted = self.get_predict_mask().astype(bool)
        seq_before = seq_before[predicted]
        seq_after = seq_after[predicted]
    return np.mean(seq_before == seq_after)
def align_structure(self, reference_pdb_path, save_pdb_path, chain_ids=None)

Aligns the structure to a reference structure using the CA atoms.

Parameters

reference_pdb_path : str
Path to the reference structure (in .pdb format)
save_pdb_path : str
Path where the aligned structure should be saved (in .pdb format)
chain_ids : list of str, optional
If specified, only the chains with the specified IDs are aligned
Expand source code
def align_structure(self, reference_pdb_path, save_pdb_path, chain_ids=None):
    """Superimpose this structure onto a reference and save the result.

    One atom per residue is used for the fit: CA when present, otherwise C
    (with a warning). The transformation is applied to all atoms of the first
    model of this structure.

    Parameters
    ----------
    reference_pdb_path : str
        Path to the reference structure (in .pdb format)
    save_pdb_path : str
        Path where the aligned structure should be saved (in .pdb format)
    chain_ids : list of str, optional
        If specified, only the chains with the specified IDs are aligned

    """

    def _alignment_atoms(model, fallback_warning):
        """Collect one alignment atom per residue of the selected chains."""
        atoms = []
        for chain in model:
            if not (chain_ids is None or chain.id in chain_ids):
                continue
            for residue in chain:
                if "CA" in residue:
                    atoms.append(residue["CA"])
                elif "C" in residue:
                    atoms.append(residue["C"])
                    warnings.warn(fallback_warning)
        return atoms

    parser = Bio.PDB.PDBParser(QUIET=True)

    # this entry is written to a temporary PDB file so that it can be parsed
    own_pdb_path = self._temp_pdb_file()
    reference = parser.get_structure("reference", reference_pdb_path)
    sample = parser.get_structure("sample", own_pdb_path)

    # only the first model of each structure is used
    reference_model = reference[0]
    sample_model = sample[0]

    fixed_atoms = _alignment_atoms(
        reference_model,
        "Using a C atom instead of CA for alignment in the reference structure",
    )
    moving_atoms = _alignment_atoms(
        sample_model,
        "Using a C atom instead of CA for alignment in the sample structure",
    )

    fitter = Bio.PDB.Superimposer()
    fitter.set_atoms(fixed_atoms, moving_atoms)
    fitter.apply(sample_model.get_atoms())

    writer = Bio.PDB.PDBIO()
    writer.set_structure(sample)
    writer.save(save_pdb_path)
def apply_mask(self, mask)

Apply a mask to the protein.

Parameters

mask : np.ndarray
A boolean mask of shape (L,) where L is the length of the protein (the chains are concatenated in alphabetical order)

Returns

entry : ProteinEntry
A new ProteinEntry object
Expand source code
def apply_mask(self, mask):
    """Return a new entry that keeps only the residues selected by `mask`.

    The mask is consumed chain by chain, in the order of `get_chains()`.

    Parameters
    ----------
    mask : np.ndarray
        A boolean mask of shape `(L,)` where `L` is the length of the protein (the chains are concatenated in alphabetical order)

    Returns
    -------
    entry : ProteinEntry
        A new `ProteinEntry` object

    """
    selected = {}
    offset = 0
    for chain in self.get_chains():
        chain_length = self.get_length([chain])
        # the slice of the full mask that belongs to this chain
        chain_mask = mask[offset : offset + chain_length]
        offset += chain_length
        encoded_seq = self.get_sequence(chains=[chain], encode=True)
        chain_data = {
            "seq": self.decode_sequence(encoded_seq[chain_mask]),
            "crd_bb": self.get_coordinates(chains=[chain], bb_only=True)[
                chain_mask
            ],
            "crd_sc": self.get_coordinates(chains=[chain])[:, 4:][chain_mask],
            "msk": self.get_mask(chains=[chain])[chain_mask],
        }
        if self.has_cdr():
            encoded_cdr = self.get_cdr([chain], encode=True)
            chain_data["cdr"] = self.decode_cdr(encoded_cdr[chain_mask])
        if self.has_predict_mask():
            chain_data["predict_msk"] = self.predict_mask[chain][chain_mask]
        selected[chain] = chain_data
    if self.id is not None:
        selected["protein_id"] = self.id
    return ProteinEntry.from_dict(selected)
def blosum62_score(self, seq_before, average=True, only_predicted=True)

Calculate the BLOSUM62 score of the protein.

Parameters

seq_before : str
A string with the sequence before the mutation
average : bool, default True
If True, the score is averaged over the residues; otherwise, the score is summed
only_predicted : bool, default True
If True and prediction masks are available, only predicted residues are considered

Returns

score : float
The BLOSUM62 score of the protein
Expand source code
def blosum62_score(self, seq_before, average=True, only_predicted=True):
    """Calculate the BLOSUM62 score of the protein.

    The current sequence of the entry is scored against `seq_before` with
    the module-level `blosum62_score` function.

    Parameters
    ----------
    seq_before : str
        A string with the sequence before the mutation
    average : bool, default True
        If `True`, the score is averaged over the residues; otherwise, the score is summed
    only_predicted : bool, default True
        If `True` and prediction masks are available, only predicted residues are considered

    Returns
    -------
    score : float
        The BLOSUM62 score of the protein

    """
    seq_after = self.get_sequence(encode=False)
    # `has_predict_mask()` is the check used elsewhere in this class; a plain
    # `self.predict_mask is not None` test is not reliable because `from_dict`
    # passes a per-chain list of `None` values when no mask is stored
    if only_predicted and self.has_predict_mask():
        predicted = self.get_predict_mask().astype(bool)
        seq_before = np.array(list(seq_before))[predicted]
        seq_after = np.array(list(seq_after))[predicted]
    score = blosum62_score(seq_before, seq_after)
    if average:
        # the number of compared residues (after masking, if applied)
        score /= len(seq_before)
    return score
def ca_rmsd(self, entry, only_predicted=True)

Calculate CA RMSD between two proteins.

Parameters

entry : ProteinEntry
A ProteinEntry object
only_predicted : bool, default True
If True and prediction masks are available, only predicted residues are considered

Returns

rmsd : float
The CA RMSD between the two proteins
Expand source code
def ca_rmsd(self, entry, only_predicted=True):
    """Compute the CA RMSD between this entry and another one.

    Only chains present in both entries are compared, using residues with
    known coordinates; no superposition is performed here.

    Parameters
    ----------
    entry : ProteinEntry
        A `ProteinEntry` object
    only_predicted : bool, default True
        If `True` and prediction masks are available, only predicted residues are considered

    Returns
    -------
    rmsd : float
        The CA RMSD between the two proteins

    """
    # masking is only possible when this entry carries a prediction mask
    restrict_to_predicted = only_predicted and self.has_predict_mask()
    shared_chains = [
        chain for chain in self.get_chains() if chain in entry.get_chains()
    ]
    # atom index 2 is taken as the CA position
    own_ca = self.get_coordinates(only_known=True, chains=shared_chains)[:, 2]
    other_ca = entry.get_coordinates(only_known=True, chains=shared_chains)[:, 2]
    if restrict_to_predicted:
        predicted = self.get_predict_mask(
            only_known=True, chains=shared_chains
        ).astype(bool)
        own_ca = own_ca[predicted]
        other_ca = other_ca[predicted]
    return ca_rmsd(own_ca, other_ca)
def chemical_features(self, chains=None)

Calculate chemical features of the protein.

Parameters

chains : list of str, optional
If specified, only the chemical features of the specified chains are returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs

Returns

features : np.ndarray
A 'numpy' array of shape (L, 4) with chemical features of the protein (hydropathy, volume, charge, polarity, acceptor/donor); missing values are marked with zeros
chains : list of str, optional
If specified, only the chemical features of the specified chains are returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
Expand source code
def chemical_features(self, chains=None):
    """Compute per-residue chemical features of the protein.

    Every residue letter of the selected chains is mapped through `_PMAP` and
    the results are stacked into one array.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the chemical features of the specified chains are returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs

    Returns
    -------
    features : np.ndarray
        A `'numpy'` array with one row per residue holding the chemical features of the
        protein (hydropathy, volume, charge, polarity, acceptor/donor); missing
        values are marked with zeros

    """
    selected_chains = self._get_chains_list(chains)
    residues = "".join(self.seq[chain] for chain in selected_chains)
    return np.array(list(map(_PMAP, residues)))
def cut_missing_edges(self)

Cut off the ends of the protein sequence that have missing coordinates.

Expand source code
def cut_missing_edges(self):
    """Cut off the ends of the protein sequence that have missing coordinates.

    For every chain, the leading and trailing residues whose mask value is not
    1 are removed in place from the sequence, coordinates, mask, CDR
    annotation and prediction mask (when present). Chains without any known
    residue are left unchanged, since there is no known residue to anchor the
    cut on.
    """
    for chain in self.get_chains():
        mask = self.mask[chain]
        known_ind = np.where(mask == 1)[0]
        if known_ind.size == 0:
            # nothing is known in this chain: indexing `known_ind[0]` would fail
            continue
        start, end = known_ind[0], known_ind[-1] + 1
        self.seq[chain] = self.seq[chain][start:end]
        self.crd[chain] = self.crd[chain][start:end]
        self.mask[chain] = self.mask[chain][start:end]
        if self.cdr[chain] is not None:
            self.cdr[chain] = self.cdr[chain][start:end]
        # keep the prediction mask aligned with the trimmed sequence
        if self.has_predict_mask():
            self.predict_mask[chain] = self.predict_mask[chain][start:end]
def dihedral_angles(self, chains=None)

Calculate the backbone dihedral angles (phi, psi) of the protein.

Parameters

chains : list of str, optional
If specified, only the dihedral angles of the specified chains are returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs

Returns

angles : np.ndarray
A 'numpy' array of shape (L, 2) with backbone dihedral angles (phi, psi) in degrees; missing values are marked with zeros
chains : list of str, optional
If specified, only the dihedral angles of the specified chains are returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
Expand source code
def dihedral_angles(self, chains=None):
    """Calculate the backbone dihedral angles (phi, psi) of the protein.

    For every chain two four-atom point sets per residue are assembled from
    consecutive residues and passed to `_dihedral_angle` together with the
    chain mask; the per-chain results are concatenated.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the dihedral angles of the specified chains are returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs

    Returns
    -------
    angles : np.ndarray
        A `'numpy'` array of shape `(L, 2)` with backbone dihedral angles
        (phi, psi) in degrees; missing values are marked with zeros

    """
    angles = []
    chains = self._get_chains_list(chains)
    # atom order along axis 1 of the coordinates: N, C, Ca, O
    for chain in chains:
        chain_angles = []
        crd = self.get_coordinates([chain])
        mask = self.get_mask([chain])
        # psi: atoms 0, 2, 1 of residue i followed by atom 0 of residue i + 1
        # (N, Ca, C, N given the atom order above)
        p = crd[:-1, [0, 2, 1], :]
        p = np.concatenate([p, crd[1:, [0], :]], 1)
        # the last residue has no successor: append an all-zero row so that
        # there is one row per residue
        p = np.pad(p, ((0, 1), (0, 0), (0, 0)))
        chain_angles.append(_dihedral_angle(p, mask))
        # phi: atom 1 of residue i followed by atoms 0, 2, 1 of residue i + 1
        # (C, N, Ca, C given the atom order above)
        p = crd[:-1, [1], :]
        p = np.concatenate([p, crd[1:, [0, 2, 1]]], 1)
        # the first residue has no predecessor: prepend an all-zero row
        p = np.pad(p, ((1, 0), (0, 0), (0, 0)))
        chain_angles.append(_dihedral_angle(p, mask))
        # NOTE(review): the psi result is appended first, so psi ends up in
        # column 0 although the docstring lists "(phi, psi)" -- confirm the
        # intended column order
        angles.append(np.stack(chain_angles, -1))
    angles = np.concatenate(angles, 0)
    return angles
def esm_pll(self, esm_model_name='esm2_t30_150M_UR50D', esm_model_objects=None, average=False)

Calculate the ESM PLL score of the protein.

Parameters

esm_model_name : str, default "esm2_t30_150M_UR50D"
Name of the ESM-2 model to use
esm_model_objects : tuple, optional
Tuple of ESM-2 model, batch converter and tok_to_idx dictionary (if not None, esm_model_name will be ignored)
average : bool, default False
If True, the score is averaged over the residues; otherwise, the score is summed

Returns

score : float
The ESM PLL score of the protein
Expand source code
def esm_pll(
    self,
    esm_model_name="esm2_t30_150M_UR50D",
    esm_model_objects=None,
    average=False,
):
    """Calculate the ESM PLL score of the protein.

    The score is computed chain by chain. If the entry carries a prediction
    mask, it is passed along as per-chain float masks; otherwise every
    residue of every chain is included.

    Parameters
    ----------
    esm_model_name : str, default "esm2_t30_150M_UR50D"
        Name of the ESM-2 model to use
    esm_model_objects : tuple, optional
        Tuple of ESM-2 model, batch converter and tok_to_idx dictionary (if not None, `esm_model_name` will be ignored)
    average : bool, default False
        If `True`, the score is averaged over the residues; otherwise, the score is summed

    Returns
    -------
    score : float
        The ESM PLL score of the protein

    """
    chains = self.get_chains()
    chain_sequences = [self.get_sequence(chains=[chain]) for chain in chains]
    # `predict_mask` is a per-chain mapping whose values are `None` when no
    # prediction mask was set, so comparing the mapping itself with `None`
    # is not enough: `get_predict_mask` would return `None` and `.astype`
    # would fail. `has_predict_mask` checks the stored values instead.
    if self.predict_mask is not None and self.has_predict_mask():
        predict_masks = [
            self.get_predict_mask(chains=[chain]).astype(float) for chain in chains
        ]
    else:
        predict_masks = [np.ones(len(x)) for x in chain_sequences]
    # Inside the class this name resolves to the module-level `esm_pll` function
    return esm_pll(
        chain_sequences,
        predict_masks,
        esm_model_name=esm_model_name,
        esm_model_objects=esm_model_objects,
        average=average,
    )
def get_atom_mask(self, chains=None, cdr=None)

Get the atom mask of the protein.

Parameters

chains : str, optional
If specified, only the atom masks of the specified chains are returned (in the same order); otherwise, all atom masks are concatenated in alphabetical order of the chain IDs
cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
If specified, only the CDR region of the specified type is returned

Returns

atom_mask : np.ndarray
Atom mask array where 1 indicates atoms with known coordinates and 0 indicates missing or non-existing values, shaped (L, 14, 3)
Expand source code
def get_atom_mask(self, chains=None, cdr=None):
    """Get the atom mask of the protein.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the atom masks of the specified chains are returned (in the same order);
        otherwise, all atom masks are concatenated in alphabetical order of the chain IDs
    cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
        If specified, only the CDR region of the specified type is returned

    Returns
    -------
    atom_mask : np.ndarray
        Atom mask array with one row per residue, where 1 indicates atoms with
        known coordinates and 0 indicates missing or non-existing values

    Raises
    ------
    ValueError
        If `cdr` is given but no CDR information is available

    """
    if cdr is not None and (self.cdr is None or not self.has_cdr()):
        raise ValueError("CDR information not available")
    if cdr is not None:
        assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
    chains = self._get_chains_list(chains)
    seq = "".join([self.seq[c] for c in chains])
    atom_mask = np.concatenate([ATOM_MASKS[aa] for aa in seq])
    if atom_mask.shape[0] != len(seq):
        # NOTE(review): `ATOM_MASKS` is defined outside this block. If its
        # entries are flat per-residue vectors, concatenation yields one long
        # vector; regroup it so that axis 0 indexes residues and can be
        # filtered with per-residue masks below. Confirm against the constant.
        atom_mask = atom_mask.reshape((len(seq), -1) + atom_mask.shape[1:])
    # `self.mask` and `self.cdr` are per-chain mappings, so they cannot be
    # compared with a scalar directly; use the concatenated per-residue arrays.
    atom_mask[self.get_mask(chains=chains) == 0] = 0
    if cdr is not None:
        atom_mask = atom_mask[self.get_cdr(chains=chains) == cdr]
    return atom_mask
def get_cdr(self, chains=None, encode=False)

Get the CDR information of the protein.

Parameters

chains : list of str, optional
If specified, only the CDR information of the specified chains is returned (in the same order); otherwise, all CDR information is concatenated in alphabetical order of the chain IDs
encode : bool, default False
If True, the CDR information is encoded as a 'numpy' array of integers where each integer corresponds to the index of the CDR type in proteinflow.constants.CDR_ALPHABET

Returns

cdr : np.ndarray or None
A 'numpy' array of shape (L,) where CDR residues are marked with the corresponding type ('H1', 'L1', …) and non-CDR residues are marked with '-', or an encoded array of integers if encode=True; None if CDR information is not available
chains : list of str, optional
If specified, only the CDR information of the specified chains is returned (in the same order); otherwise, all CDR information is concatenated in alphabetical order of the chain IDs
Expand source code
def get_cdr(self, chains=None, encode=False):
    """Get the per-residue CDR labels of the protein.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the CDR information of the specified chains is
        returned (in the same order); otherwise, all CDR information is concatenated in
        alphabetical order of the chain IDs
    encode : bool, default False
        If `True`, every label is replaced by its value in
        `proteinflow.constants.CDR_REVERSE`

    Returns
    -------
    cdr : np.ndarray or None
        A `'numpy'` array of shape `(L,)` where CDR residues are marked
        with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR
        residues are marked with `'-'`, or an encoded array if
        `encode=True`; `None` if the `cdr` attribute is `None`

    """
    chain_ids = self._get_chains_list(chains)
    if self.cdr is None:
        return None
    labels = np.concatenate([self.cdr[chain_id] for chain_id in chain_ids], axis=0)
    if not encode:
        return labels
    return np.array([CDR_REVERSE[label] for label in labels])
def get_cdr_length(self, chains)

Get the length of the CDR regions of a set of chains.

Parameters

chains : list of str
Chain IDs to take into account

Returns

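length : dict
A dictionary mapping each CDR type ("H1", "H2", "H3", "L1", "L2", "L3") to the length of that region in the given chains; all values are None if CDR information is not available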
Expand source code
def get_cdr_length(self, chains):
    """Get the length of every CDR region within a set of chains.

    Parameters
    ----------
    chains : list of str
        Chain IDs to take into account

    Returns
    -------
    length : dict
        A dictionary with the keys `"H1"`, `"H2"`, `"H3"`, `"L1"`, `"L2"`
        and `"L3"`; each value is the length of the sequence returned by
        `get_sequence` for that CDR, or `None` for every key if the entry
        has no CDR information

    """
    cdr_names = ["H1", "H2", "H3", "L1", "L2", "L3"]
    if not self.has_cdr():
        return dict.fromkeys(cdr_names)
    lengths = {}
    for cdr_name in cdr_names:
        lengths[cdr_name] = len(self.get_sequence(chains=chains, cdr=cdr_name))
    return lengths
def get_chain_id_array(self, chains=None, encode=True)

Get the chain ID array of the protein.

The chain ID array is a 'numpy' array of shape (L,) with the chain ID of each residue. The chain ID is the index of the chain in the alphabetical order of the chain IDs. To get a mapping from the index to the chain ID, use get_chain_id_dict().

Parameters

chains : list of str, optional
If specified, only the chain ID array of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
encode : bool, default True
If True, the chain ID is encoded as an integer; otherwise, the chain ID is the chain ID string

Returns

chain_id_array : np.ndarray
A 'numpy' array of shape (L,) with the chain ID of each residue
Expand source code
def get_chain_id_array(self, chains=None, encode=True):
    """Get the chain ID array of the protein.

    The result has one element per residue. With `encode=True` the element
    is the value that `get_chain_id_dict()` assigns to the residue's chain;
    with `encode=False` it is the chain ID string itself.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the chain ID array of the specified chains is returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs
    encode : bool, default True
        If True, the chain ID is encoded as a number; otherwise, the chain ID is the chain ID string

    Returns
    -------
    chain_id_array : np.ndarray
        A `'numpy'` array of shape `(L,)` with the chain ID of each residue
        (created with `np.zeros` when encoded, with `dtype=object` otherwise)

    """
    id_dict = self.get_chain_id_dict()
    total_length = self.get_length(chains)
    if encode:
        chain_id_array = np.zeros(total_length)
    else:
        chain_id_array = np.empty(total_length, dtype=object)
    offset = 0
    for chain_id in self._get_chains_list(chains):
        stop = offset + self.get_length([chain_id])
        # A scalar on the right-hand side is broadcast over the whole slice
        chain_id_array[offset:stop] = id_dict[chain_id] if encode else chain_id
        offset = stop
    return chain_id_array
def get_chain_id_dict(self, chains=None)

Get the dictionary mapping from chain IDs to chain indices.

Parameters

chains : list of str, optional
If specified, only the chain IDs of the specified chains are returned

Returns

chain_id_dict : dict
A dictionary mapping from chain IDs to chain indices (the position of the chain in the alphabetically sorted list of all chain IDs)
Expand source code
def get_chain_id_dict(self, chains=None):
    """Get the dictionary mapping chain IDs to chain indices.

    The index of a chain is its position in the list returned by
    `get_chains()`, regardless of which chains are selected.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the specified chains are included in the dictionary

    Returns
    -------
    chain_id_dict : dict
        A dictionary with chain IDs as keys and chain indices as values

    """
    selected = self._get_chains_list(chains)
    chain_id_dict = {}
    for index, chain_id in enumerate(self.get_chains()):
        if chain_id in selected:
            chain_id_dict[chain_id] = index
    return chain_id_dict
def get_chain_type_dict(self, chains=None)

Get the chain types of the protein.

If the CDRs are not annotated, this function will return None. If there is no light or heavy chain, the corresponding key will be missing. If there is no antigen chain, the 'antigen' key will map to an empty list.

Parameters

chains : list of str, default None
Chain IDs to consider

Returns

chain_type_dict : dict
A dictionary with keys 'heavy', 'light' and 'antigen' and values the corresponding chain IDs
Expand source code
def get_chain_type_dict(self, chains=None):
    """Classify the chains of the protein as heavy, light or antigen.

    A chain whose CDR labels contain `'H1'` is the heavy chain, one whose
    labels contain `'L1'` is the light chain, and every other chain is
    listed as an antigen. If the CDRs are not annotated, `None` is returned.
    If there is no light or heavy chain, the corresponding key is missing.
    If there is no antigen chain, the `'antigen'` key maps to an empty list.

    Parameters
    ----------
    chains : list of str, default None
        Chain IDs to consider

    Returns
    -------
    chain_type_dict : dict or None
        A dictionary with keys `'heavy'`, `'light'` (a single chain ID each)
        and `'antigen'` (a list of chain IDs)

    """
    if not self.has_cdr():
        return None
    chain_type_dict = {"antigen": []}
    selected = self._get_chains_list(chains)
    for chain_id, chain_cdr in self.cdr.items():
        if chain_id in selected:
            labels = np.unique(chain_cdr)
            if "H1" in labels:
                chain_type_dict["heavy"] = chain_id
            elif "L1" in labels:
                chain_type_dict["light"] = chain_id
            else:
                chain_type_dict["antigen"].append(chain_id)
    return chain_type_dict
def get_chains(self)

Get the chain IDs of the protein.

Returns

chains : list of str
Chain IDs of the protein
Expand source code
def get_chains(self):
    """Get the chain IDs of the protein in sorted order.

    Returns
    -------
    chains : list of str
        The keys of the sequence dictionary, sorted

    """
    chain_ids = list(self.seq)
    chain_ids.sort()
    return chain_ids
def get_coordinates(self, chains=None, bb_only=False, cdr=None, only_known=False)

Get the coordinates of the protein.

Backbone atoms are in the order of N, C, CA, O; for the full-atom order see ProteinEntry.ATOM_ORDER (sidechain atoms come after the backbone atoms).

Parameters

chains : list of str, optional
If specified, only the coordinates of the specified chains are returned (in the same order); otherwise, all coordinates are concatenated in alphabetical order of the chain IDs
bb_only : bool, default False
If True, only the backbone atoms are returned
cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
If specified, only the CDR region of the specified type is returned
only_known : bool, default False
If True, only return the coordinates of residues with known coordinates

Returns

crd : np.ndarray
Coordinates of the protein, 'numpy' array of shape (L, 14, 3) or (L, 4, 3) if bb_only=True
Expand source code
def get_coordinates(self, chains=None, bb_only=False, cdr=None, only_known=False):
    """Get the coordinates of the protein.

    Backbone atoms are in the order of `N, C, CA, O`; for the full-atom
    order see `ProteinEntry.ATOM_ORDER` (sidechain atoms come after the
    backbone atoms).

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the coordinates of the specified chains are returned (in the same order);
        otherwise, all coordinates are concatenated in alphabetical order of the chain IDs
    bb_only : bool, default False
        If `True`, only the backbone atoms are returned
    cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
        If specified, only the CDR region of the specified type is returned
    only_known : bool, default False
        If `True`, only return the coordinates of residues with known coordinates

    Returns
    -------
    crd : np.ndarray
        Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)`
        or `(L, 4, 3)` if `bb_only=True`

    Raises
    ------
    ValueError
        If `cdr` is given but no CDR information is available

    """
    if cdr is not None and (self.cdr is None or not self.has_cdr()):
        raise ValueError("CDR information not available")
    if cdr is not None:
        assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
    chains = self._get_chains_list(chains)
    crd = np.concatenate([self.crd[c] for c in chains], axis=0)
    known = self.get_mask(chains=chains).astype(bool) if only_known else None
    if cdr is not None:
        # `self.cdr` is a per-chain mapping; comparing it with a string yields
        # a plain `False`, which selects nothing. Compare the concatenated
        # per-residue labels instead.
        in_cdr = self.get_cdr(chains=chains) == cdr
        crd = crd[in_cdr]
        if known is not None:
            # Keep the known-residue mask aligned with the CDR selection
            known = known[in_cdr]
    if bb_only:
        crd = crd[:, :4, :]
    if only_known:
        crd = crd[known]
    return crd
def get_id(self)

Return the ID of the protein.

Expand source code
def get_id(self):
    """Return the ID of the protein.

    Returns
    -------
    id
        The value stored in the `id` attribute of the entry

    """
    protein_id = self.id
    return protein_id
def get_index_array(self, chains=None, index_bump=100)

Get the index array of the protein.

The index array is a 'numpy' array of shape (L,) with the index of each residue along the chain.

Parameters

chains : list of str, optional
If specified, only the index array of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
index_bump : int, default 100
If specified, the index is bumped by this number between chains

Returns

index_array : np.ndarray
A 'numpy' array of shape (L,) with the index of each residue along the chain; if multiple chains are specified, the index is bumped by index_bump at the beginning of each chain
Expand source code
def get_index_array(self, chains=None, index_bump=100):
    """Get the residue index array of the protein.

    Residues are numbered consecutively within a chain. Each following chain
    starts `index_bump` positions after the last index of the previous one.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the index array of the specified chains is returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs
    index_bump : int, default 100
        The gap that is added to the running index between two chains

    Returns
    -------
    index_array : np.ndarray
        An integer `'numpy'` array of shape `(L,)` with the index of each residue along the chain; if multiple chains
        are specified, the index is bumped by `index_bump` at the beginning of each chain

    """
    chain_ids = self._get_chains_list(chains)
    index_array = np.zeros(self.get_length(chain_ids))
    position = 0  # where the current chain starts inside `index_array`
    first_index = 0  # residue index assigned to the first residue of the chain
    for chain_id in chain_ids:
        n_residues = self.get_length([chain_id])
        stop = position + n_residues
        index_array[position:stop] = np.arange(first_index, first_index + n_residues)
        first_index += n_residues + index_bump
        position = stop
    return index_array.astype(int)
def get_length(self, chains=None)

Get the total length of a set of chains.

Parameters

chains : list of str, optional
Chain IDs; if None, the length of the whole protein is returned

Returns

length : int
Length of the chain
Expand source code
def get_length(self, chains=None):
    """Get the total number of residues in a set of chains.

    Parameters
    ----------
    chains : list of str, optional
        Chain IDs; if `None`, the length of the whole protein is returned

    Returns
    -------
    length : int
        Sum of the sequence lengths of the selected chains

    """
    total = 0
    for chain_id in self._get_chains_list(chains):
        total += len(self.seq[chain_id])
    return total
def get_ligand_features(self, ligands, chains=None)

Get ligand coordinates, smiles, and chain mapping.

Parameters

ligands : dict
A dictionary mapping from chain IDs to a list of ligands, where each ligand is a dictionary
chains : list of str, optional
If specified, only the ligands of the specified chains are returned (in the same order); otherwise, all ligands are concatenated in alphabetical order of the chain IDs

Returns

X_ligands : torch.Tensor
A 'torch' tensor of shape (N, 3) with the ligand coordinates
ligand_smiles : str
A string with the ligand smiles separated by a dot
ligand_chains : torch.Tensor
A 'torch' tensor of shape (N, 1) with the chain index of each atom
Expand source code
def get_ligand_features(self, ligands, chains=None):
    """Collect ligand coordinates, SMILES strings and chain assignment.

    Parameters
    ----------
    ligands : dict
        A dictionary mapping from chain IDs to a list of ligands, where each ligand is a dictionary
        with at least the keys `"smiles"` and `"X"` (coordinates)
    chains : list of str, optional
        If specified, only the ligands of the specified chains are returned (in the same order);
        otherwise, all ligands are concatenated in alphabetical order of the chain IDs

    Returns
    -------
    X_ligands : torch.Tensor
        A `'torch'` tensor of shape `(N, 3)` with the ligand coordinates
    ligand_smiles : str
        A string with the ligand smiles separated by a dot
    ligand_chains : torch.Tensor
        A `'torch'` tensor of shape `(N, 1)` with, for each atom, the position of its
        chain in the list of selected chains
    """
    chain_ids = self._get_chains_list(chains)
    coords_per_chain = []
    smiles_per_chain = []
    chain_index_rows = []
    for chain_index, chain_id in enumerate(chain_ids):
        chain_ligands = ligands[chain_id]
        smiles_per_chain.append(".".join(lig["smiles"] for lig in chain_ligands))
        chain_coords = np.concatenate([lig["X"] for lig in chain_ligands])
        coords_per_chain.append(chain_coords)
        # One single-element row per ligand atom of this chain
        chain_index_rows.extend([[chain_index]] * len(chain_coords))
    ligand_smiles = ".".join(smiles_per_chain)
    X_ligands = from_numpy(np.concatenate(coords_per_chain, 0))
    ligand_chains = Tensor(chain_index_rows)
    return X_ligands, ligand_smiles, ligand_chains
def get_mask(self, chains=None, cdr=None, original=False)

Get the mask of the protein.

Parameters

chains : list of str, optional
If specified, only the masks of the specified chains are returned (in the same order); otherwise, all masks are concatenated in alphabetical order of the chain IDs
cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
If specified, only the CDR region of the specified type is returned
original : bool, default False
If True, return the original mask (before interpolation)

Returns

mask : np.ndarray
Mask array where 1 indicates residues with known coordinates and 0 indicates missing values
Expand source code
def get_mask(self, chains=None, cdr=None, original=False):
    """Get the mask of the protein.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the masks of the specified chains are returned (in the same order);
        otherwise, all masks are concatenated in alphabetical order of the chain IDs
    cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
        If specified, only the CDR region of the specified type is returned
    original : bool, default False
        If `True`, return the original mask (before interpolation)

    Returns
    -------
    mask : np.ndarray
        Mask array where 1 indicates residues with known coordinates and 0
        indicates missing values

    Raises
    ------
    ValueError
        If `cdr` is given but no CDR information is available

    """
    if cdr is not None and (self.cdr is None or not self.has_cdr()):
        raise ValueError("CDR information not available")
    if cdr is not None:
        assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
    chains = self._get_chains_list(chains)
    mask = np.concatenate(
        [self.mask_original[c] if original else self.mask[c] for c in chains],
        axis=0,
    )
    if cdr is not None:
        # `self.cdr` is a per-chain mapping; comparing it with a string yields
        # a plain `False`, which selects nothing. Compare the concatenated
        # per-residue labels instead.
        mask = mask[self.get_cdr(chains=chains) == cdr]
    return mask
def get_predict_mask(self, chains=None, only_known=False)

Get the prediction mask of the protein.

The prediction mask is a 'numpy' array of shape (L,) with ones corresponding to residues that were generated by a model and zeros to residues with known coordinates. If the prediction mask is not available, None is returned.

Parameters

chains : list of str, optional
If specified, only the prediction mask of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
only_known : bool, default False
If True, only residues with known coordinates are returned

Returns

predict_mask : np.ndarray
A 'numpy' array of shape (L,) with ones corresponding to residues that were generated by a model and zeros to residues with known coordinates
Expand source code
def get_predict_mask(self, chains=None, only_known=False):
    """Get the prediction mask of the protein.

    The prediction mask is a `'numpy'` array of shape `(L,)` with ones
    corresponding to residues that were generated by a model and zeros to
    residues with known coordinates. Availability is decided by the first
    stored chain: if its mask is `None`, `None` is returned.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the prediction mask of the specified chains is returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs
    only_known : bool, default False
        If `True`, only residues with known coordinates are returned

    Returns
    -------
    predict_mask : np.ndarray or None
        A `'numpy'` array of shape `(L,)` with ones corresponding to residues that were generated by a model and
        zeros to residues with known coordinates

    """
    first_chain_mask = list(self.predict_mask.values())[0]
    if first_chain_mask is None:
        return None
    chain_ids = self._get_chains_list(chains)
    predict_mask = np.concatenate([self.predict_mask[c] for c in chain_ids])
    if not only_known:
        return predict_mask
    known = self.get_mask(chains=chain_ids).astype(bool)
    return predict_mask[known]
def get_predicted_chains(self)

Return a list of chain IDs that contain predicted residues.

Returns

chains : list of str
Chain IDs
Expand source code
def get_predicted_chains(self):
    """Return the IDs of the chains that contain predicted residues.

    Returns
    -------
    chains : list of str
        IDs of the chains whose prediction mask has a non-zero sum

    Raises
    ------
    ValueError
        If the entry has no prediction mask

    """
    if not self.has_predict_mask():
        raise ValueError("Predicted mask not available")
    predicted = []
    for chain_id, chain_mask in self.predict_mask.items():
        if chain_mask.sum() != 0:
            predicted.append(chain_id)
    return predicted
def get_predicted_entry(self)

Return a ProteinEntry object that only contains predicted residues.

Returns

entry : ProteinEntry
The truncated ProteinEntry object
Expand source code
def get_predicted_entry(self):
    """Return a `ProteinEntry` object that only contains predicted residues.

    Chains without any predicted residue are dropped, chains that are
    predicted entirely are kept as they are, and all other chains are
    truncated to their predicted residues.

    Returns
    -------
    entry : ProteinEntry
        The truncated `ProteinEntry` object

    Raises
    ------
    ValueError
        If the entry has no prediction mask

    """
    # `predict_mask` is a per-chain mapping whose values are `None` when no
    # mask was set, so checking the mapping itself against `None` would let
    # such entries through and fail later on `None.astype`.
    if self.predict_mask is None or not self.has_predict_mask():
        raise ValueError("Predicted mask not available")
    entry_dict = self.to_dict()
    for chain in self.get_chains():
        mask_ = self.predict_mask[chain].astype(bool)
        if not mask_.any():
            entry_dict.pop(chain)
            continue
        if mask_.all():
            continue
        chain_dict = entry_dict[chain]
        seq_arr = np.array(list(chain_dict["seq"]))
        chain_dict["seq"] = "".join(seq_arr[mask_])
        for key in ("crd_bb", "crd_sc", "msk", "predict_msk"):
            chain_dict[key] = chain_dict[key][mask_]
        if "cdr" in chain_dict:
            chain_dict["cdr"] = chain_dict["cdr"][mask_]
    return ProteinEntry.from_dict(entry_dict)
def get_protein_class(self)

Get the protein class.

Returns

protein_class : str
The protein class ("single_chain", "heteromer", "homomer")
Expand source code
def get_protein_class(self):
    """Get the protein class.

    An entry with one chain is a `"single_chain"`. Otherwise every pair of
    chain sequences is compared: if the shorter sequence has less than 90%
    of the length of the longer one, or if their edit distance exceeds 10%
    of the longer length, the entry is a `"heteromer"`; if no pair differs
    that much it is a `"homomer"`.

    Returns
    -------
    protein_class : str
        The protein class ("single_chain", "heteromer", "homomer")

    """
    chains = self.get_chains()
    if len(chains) == 1:
        return "single_chain"
    # Compare the sequences, not the chain ID strings
    sequences = [self.seq[chain] for chain in chains]
    for seq1, seq2 in itertools.combinations(sequences, 2):
        shorter, longer = sorted((len(seq1), len(seq2)))
        # The lengths differ by more than 10%
        if shorter < 0.9 * longer:
            return "heteromer"
        # `longer > 0` avoids a division by zero for two empty sequences
        if longer > 0 and edit_distance(seq1, seq2) / longer > 0.1:
            return "heteromer"
    return "homomer"
def get_sequence(self, chains=None, encode=False, cdr=None, only_known=False)

Get the amino acid sequence of the protein.

Parameters

chains : list of str, optional
If specified, only the sequences of the specified chains is returned (in the same order); otherwise, all sequences are concatenated in alphabetical order of the chain IDs
encode : bool, default False
If True, the sequence is encoded as a 'numpy' array of integers where each integer corresponds to the index of the amino acid in proteinflow.constants.ALPHABET
cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
If specified, only the CDR region of the specified type is returned
only_known : bool, default False
If True, only the residues with known coordinates are returned

Returns

seq : str or np.ndarray
Amino acid sequence of the protein (one-letter code) or an encoded sequence as a 'numpy' array of integers
Expand source code
def get_sequence(self, chains=None, encode=False, cdr=None, only_known=False):
    """Get the amino acid sequence of the protein.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the sequences of the specified chains is returned (in the same order);
        otherwise, all sequences are concatenated in alphabetical order of the chain IDs
    encode : bool, default False
        If `True`, the sequence is encoded as a `'numpy'` array where each
        residue is replaced by its value in `proteinflow.constants.ALPHABET_REVERSE`
    cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
        If specified, only the CDR region of the specified type is returned
    only_known : bool, default False
        If `True`, only the residues with known coordinates are returned

    Returns
    -------
    seq : str or np.ndarray
        Amino acid sequence of the protein (one-letter code) or an encoded
        sequence as a `'numpy'` array of integers

    Raises
    ------
    ValueError
        If `cdr` is given but no CDR information is available

    """
    if cdr is not None and (self.cdr is None or not self.has_cdr()):
        raise ValueError("CDR information not available")
    if cdr is not None:
        assert cdr in CDR_REVERSE, f"CDR must be one of {list(CDR_REVERSE.keys())}"
    chains = self._get_chains_list(chains)
    # "B" characters are dropped from the stored sequences
    seq = "".join([self.seq[c] for c in chains]).replace("B", "")
    if encode:
        seq = np.array([ALPHABET_REVERSE[aa] for aa in seq])
    elif cdr is not None or only_known:
        # An array is needed for boolean indexing below
        seq = np.array(list(seq))
    known = self.get_mask(chains=chains).astype(bool) if only_known else None
    if cdr is not None:
        in_cdr = self.get_cdr(chains=chains) == cdr
        seq = seq[in_cdr]
        if known is not None:
            # Restrict the known-residue mask to the same CDR positions here,
            # so the result does not depend on how `get_mask` filters by CDR
            known = known[in_cdr]
    if only_known:
        seq = seq[known]
    if not encode and not isinstance(seq, str):
        seq = "".join(seq)
    return seq
def has_cdr(self)

Check if the protein is from the SAbDab database.

Returns

is_sabdab : bool
True if the protein is from the SAbDab database
Expand source code
def has_cdr(self):
    """Check if the protein has CDR annotations (e.g. it is from the SAbDab database).

    Only the CDR entry of the first stored chain is inspected.

    Returns
    -------
    is_sabdab : bool
        `True` if the CDR entry of the first chain is not `None`

    """
    first_chain_cdr = list(self.cdr.values())[0]
    return first_chain_cdr is not None
def has_predict_mask(self)

Check if the protein has a predicted mask.

Returns

has_predict_mask : bool
True if the protein has a predicted mask
Expand source code
def has_predict_mask(self):
    """Check if the protein has a prediction mask.

    Only the prediction mask of the first stored chain is inspected.

    Returns
    -------
    has_predict_mask : bool
        `True` if the prediction mask of the first chain is not `None`

    """
    first_chain_mask = list(self.predict_mask.values())[0]
    return first_chain_mask is not None
def interpolate_coords(self, fill_ends=True)

Fill in missing values in the coordinates arrays with linear interpolation.

Parameters

fill_ends : bool, default True
If True, fill in missing values at the ends of the protein sequence with the edge values; otherwise fill them in with zeros
Expand source code
def interpolate_coords(self, fill_ends=True):
    """Fill in missing coordinates of every chain by linear interpolation.

    The coordinates and the mask of each chain are replaced in place by the
    values returned from the module-level `interpolate_coords` function.

    Parameters
    ----------
    fill_ends : bool, default True
        If `True`, fill in missing values at the ends of the protein sequence with the edge values;
        otherwise fill them in with zeros

    """
    for chain_id in self.get_chains():
        # Inside the class this name resolves to the module-level function
        new_crd, new_mask = interpolate_coords(
            self.crd[chain_id], self.mask[chain_id], fill_ends=fill_ends
        )
        self.crd[chain_id] = new_crd
        self.mask[chain_id] = new_mask
def is_valid_pair(self, chain1, chain2, cutoff=10)

Check if two chains are a valid pair based on the distance between them.

We consider two chains to be a valid pair if the distance between them is smaller than cutoff Angstroms. The distance is calculated as the minimum distance between any two atoms of the two chains.

Parameters

chain1 : str
Chain ID of the first chain
chain2 : str
Chain ID of the second chain
cutoff : int, optional
Minimum distance between the two chains (in Angstroms)

Returns

valid : bool
True if the two chains are a valid pair, False otherwise
Expand source code
@lru_cache()
def is_valid_pair(self, chain1, chain2, cutoff=10):
    """Check if two chains are a valid pair based on the distance between them.

    Only residues with known coordinates are used and only their CA atoms
    are compared. Two chains are a valid pair if at least three CA-CA pairs
    (one atom from each chain) are closer than `cutoff` Angstroms. If one of
    the chains has no residue with known coordinates, the pair is not valid.

    Parameters
    ----------
    chain1 : str
        Chain ID of the first chain
    chain2 : str
        Chain ID of the second chain
    cutoff : int, optional
        Maximum CA-CA distance (in Angstroms) for a residue pair to count
        as a contact

    Returns
    -------
    valid : bool
        `True` if the two chains are a valid pair, `False` otherwise

    """
    margin = cutoff * 3
    assert chain1 in self.get_chains(), f"Chain {chain1} not found"
    assert chain2 in self.get_chains(), f"Chain {chain2} not found"
    # Atom index 2 is CA (backbone order is N, C, CA, O)
    ca1 = self.get_coordinates(chains=[chain1], only_known=True)[:, 2, :]
    ca2 = self.get_coordinates(chains=[chain2], only_known=True)[:, 2, :]
    if len(ca1) == 0 or len(ca2) == 0:
        # Nothing to compare; `min`/`max` below would fail on empty arrays
        return False

    # Cheap pre-filter: keep only the residues that lie inside the bounding
    # box of the other chain, enlarged by `margin` in every direction
    near1 = np.all(
        (ca1 >= ca2.min(axis=0) - margin) & (ca1 <= ca2.max(axis=0) + margin),
        axis=1,
    )
    near2 = np.all(
        (ca2 >= ca1.min(axis=0) - margin) & (ca2 <= ca1.max(axis=0) + margin),
        axis=1,
    )

    # Residues whose CA coordinates are all exactly zero are ignored
    near1 &= (ca1 == 0).sum(-1) != 3
    near2 &= (ca2 == 0).sum(-1) != 3

    # Pairwise CA-CA distances between the remaining residues
    diff = ca1[near1][:, np.newaxis, :] - ca2[near2][np.newaxis, :, :]
    distances = np.sqrt(np.sum(diff**2, axis=2))

    return bool(np.sum(distances < cutoff) >= 3)
def long_repeat_num(self, thr=5)

Calculate the number of long repeats in the protein.

Parameters

thr : int, default 5
The threshold for the minimum length of the repeat

Returns

num : int
The number of long repeats in the protein
Expand source code
def long_repeat_num(self, thr=5):
    """Calculate the number of long repeats in the protein.

    If a prediction mask is returned by `get_predict_mask`, only the residues
    it marks are taken into account; otherwise the full sequence is used.

    Parameters
    ----------
    thr : int, default 5
        The threshold for the minimum length of the repeat

    Returns
    -------
    num : int
        The number of long repeats in the protein

    """
    seq = self.get_sequence(encode=False)
    if self.predict_mask is not None:
        predict_mask = self.get_predict_mask()
        # `self.predict_mask` is handled as a per-chain dictionary by the other
        # methods of this class, so the attribute check above cannot tell
        # whether a mask was actually set. Only filter the sequence when a
        # mask array is really returned.
        if predict_mask is not None:
            seq = np.array(list(seq))[predict_mask.astype(bool)]
    # A bare name inside a method body is looked up in module scope, so this
    # calls the module-level function of the same name, not this method.
    return long_repeat_num(seq, thr=thr)
def merge(self, entry)

Merge another ProteinEntry object into this one.

Parameters

entry : ProteinEntry
The ProteinEntry object to merge into this one
Expand source code
def merge(self, entry):
    """Merge another `ProteinEntry` object into this one.

    The chains of `entry` are added to this object in place. All chain IDs
    are validated before anything is copied, so a failed merge leaves this
    object unchanged.

    Parameters
    ----------
    entry : ProteinEntry
        The `ProteinEntry` object to merge into this one

    Raises
    ------
    ValueError
        If the part of a chain ID of `entry` before the first underscore is
        already used by this object or by an earlier chain of `entry`

    """
    new_chains = list(entry.get_chains())
    # Validate everything up front; the set grows as we go so that duplicates
    # inside `entry` itself are rejected as well.
    taken = {x.split("_")[0] for x in self.get_chains()}
    for chain in new_chains:
        base_id = chain.split("_")[0]
        if base_id in taken:
            raise ValueError("Chain IDs must be unique")
        taken.add(base_id)
    for chain in new_chains:
        self.seq[chain] = entry.seq[chain]
        self.crd[chain] = entry.crd[chain]
        self.mask[chain] = entry.mask[chain]
        self.mask_original[chain] = entry.mask_original[chain]
        self.cdr[chain] = entry.cdr[chain]
        self.predict_mask[chain] = entry.predict_mask[chain]
    # Keep the prediction masks consistent: if any chain has one, chains
    # without a mask get an all-zero mask of their own length.
    if not all(x is None for x in self.predict_mask.values()):
        for k, v in self.predict_mask.items():
            if v is None:
                # Use the stored per-chain sequence directly instead of
                # passing a bare string where a list of chain IDs is expected.
                self.predict_mask[k] = np.zeros(len(self.seq[k]))
def rename_chains(self, chain_dict)

Rename the chains of the protein.

Parameters

chain_dict : dict
A dictionary mapping old chain IDs to new chain IDs
Expand source code
def rename_chains(self, chain_dict):
    """Rename the chains of the protein.

    Chains that are not listed in `chain_dict` keep their current IDs. The
    dictionary passed by the caller is not modified.

    Parameters
    ----------
    chain_dict : dict
        A dictionary mapping old chain IDs to new chain IDs

    """
    current_chains = self.get_chains()
    # Work on a copy so the identity entries added below do not leak into
    # the caller's dictionary.
    mapping = dict(chain_dict)
    for chain in current_chains:
        mapping.setdefault(chain, chain)
    # Two passes through temporary IDs (each ID repeated five times);
    # presumably this keeps swaps such as {"A": "B", "B": "A"} from
    # overwriting each other inside `_rename_chains`.
    self._rename_chains({k: k * 5 for k in current_chains})
    self._rename_chains({k * 5: v for k, v in mapping.items()})
def secondary_structure(self, chains=None)

Calculate the secondary structure of the protein.

Parameters

chains : list of str, optional
If specified, only the secondary structure of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs

Returns

sse : np.ndarray
A 'numpy' array of shape (L, 3) with secondary structure elements encoded as one-hot vectors (alpha-helix, beta-sheet, loop); missing values are marked with zeros
chains : list of str, optional
If specified, only the secondary structure of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
Expand source code
def secondary_structure(self, chains=None):
    """Compute one-hot secondary structure annotations for the protein.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the secondary structure of the specified chains is returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs

    Returns
    -------
    sse : np.ndarray
        A `'numpy'` array of shape `(L, 3)` with secondary structure
        elements encoded as one-hot vectors (alpha-helix, beta-sheet, loop);
        missing values are marked with zeros

    """
    # One-hot row for every label produced by `_annotate_sse`; the empty
    # label maps to an all-zero row.
    one_hot = {"c": [0, 0, 1], "b": [0, 1, 0], "a": [1, 0, 0], "": [0, 0, 0]}
    rows = []
    for chain_id in self._get_chains_list(chains):
        # Only the first four (backbone) atoms are passed to the annotator.
        backbone = self.get_coordinates([chain_id])[:, :4]
        labels = _annotate_sse(backbone)
        rows.extend([one_hot[label] for label in labels])
    return np.array(rows)
def set_predict_mask(self, mask_dict)

Set the predicted mask.

Parameters

mask_dict : dict
A dictionary mapping from chain IDs to a np.ndarray mask of 0s and 1s of the same length as the chain sequence
Expand source code
def set_predict_mask(self, mask_dict):
    """Set the predicted mask.

    Chains of the protein that are missing from `mask_dict` receive an
    all-zero mask, so that every chain has an entry afterwards (other
    methods such as `to_dict` index the prediction mask by every chain ID).

    Parameters
    ----------
    mask_dict : dict
        A dictionary mapping from chain IDs to a `np.ndarray` mask of 0s and 1s of the same length as the chain sequence

    Raises
    ------
    PDBError
        If a chain ID is not part of the protein or a mask length differs
        from the length of its chain

    """
    chains = self.get_chains()
    for chain in mask_dict:
        if chain not in chains:
            raise PDBError("Chain not found")
        if len(mask_dict[chain]) != self.get_length([chain]):
            raise PDBError("Mask length does not match sequence length")
    # Build a complete per-chain dictionary instead of storing a possibly
    # partial one; this mirrors the zero-filling done in `merge`.
    self.predict_mask = {
        chain: (
            mask_dict[chain]
            if chain in mask_dict
            else np.zeros(self.get_length([chain]))
        )
        for chain in chains
    }
def sidechain_coordinates(self, chains=None)

Get the sidechain coordinates of the protein.

Parameters

chains : list of str, optional
If specified, only the sidechain coordinates of the specified chains are returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs

Returns

crd : np.ndarray
A 'numpy' array of shape (L, 10, 3) with sidechain atom coordinates (check sidechain_order() for the order of atoms); missing values are marked with zeros
chains : list of str, optional
If specified, only the sidechain coordinates of the specified chains are returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
Expand source code
def sidechain_coordinates(self, chains=None):
    """Return the sidechain atom coordinates of the protein.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the sidechain coordinates of the specified chains are returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs

    Returns
    -------
    crd : np.ndarray
        A `'numpy'` array of shape `(L, 10, 3)` with sidechain atom
        coordinates (check `proteinflow.sidechain_order()` for the order of
        atoms); missing values are marked with zeros

    """
    selected = self._get_chains_list(chains)
    all_atoms = self.get_coordinates(selected)
    # The first four atoms of every residue are the backbone; drop them.
    return all_atoms[:, 4:, :]
def sidechain_orientation(self, chains=None)

Calculate the (global) sidechain orientation of the protein.

Parameters

chains : list of str, optional
If specified, only the sidechain orientation of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs

Returns

orientation : np.ndarray
A 'numpy' array of shape (L, 3) with sidechain orientation vectors; missing values are marked with zeros
chains : list of str, optional
If specified, only the sidechain orientation of the specified chains is returned (in the same order); otherwise, all features are concatenated in alphabetical order of the chain IDs
Expand source code
def sidechain_orientation(self, chains=None):
    """Calculate the (global) sidechain orientation of the protein.

    For every residue the orientation is the unit vector from the CA atom to
    the residue's main sidechain atom (as listed in `MAIN_ATOM_DICT`).
    Residue types without a main sidechain atom (`MAIN_ATOM_DICT` entry is
    `None`) are given a random direction.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the sidechain orientation of the specified chains is returned (in the same order);
        otherwise, all features are concatenated in alphabetical order of the chain IDs

    Returns
    -------
    orientation : np.ndarray
        A `'numpy'` array of shape `(L, 3)` with sidechain orientation
        vectors; missing values are marked with zeros

    """
    chains = self._get_chains_list(chains)
    crd = self.get_coordinates(chains=chains)
    crd_bb, crd_sc = crd[:, :4, :], crd[:, 4:, :]
    seq = self.get_sequence(chains=chains, encode=True)
    orientation = np.zeros((crd_sc.shape[0], 3))
    # Residues whose CA coordinates are all zero are treated as missing, so
    # they keep a zero orientation as promised in the docstring.
    known = (crd_bb[:, 2, :] == 0).sum(-1) != 3
    for i in range(1, 21):
        residue_mask = seq == i
        main_atom = MAIN_ATOM_DICT[i]
        if main_atom is not None:
            # Missing residues have zero coordinates, so the difference is
            # zero for them as well.
            orientation[residue_mask] = (
                crd_sc[residue_mask, main_atom, :] - crd_bb[residue_mask, 2, :]
            )
        else:
            # The mask must come from the encoded sequence array; comparing
            # the per-chain `self.seq` dictionary to an integer is always
            # False and would silently skip these residues.
            random_mask = np.logical_and(residue_mask, known)
            orientation[random_mask] = np.random.rand(
                *orientation[random_mask].shape
            )
    # Normalize to unit length; the epsilon avoids division by zero for the
    # all-zero rows.
    orientation /= np.expand_dims(np.linalg.norm(orientation, axis=-1), -1) + 1e-7
    return orientation
def tm_score(self, entry, chains=None)

Calculate TM score between two proteins.

Parameters

entry : ProteinEntry
A ProteinEntry object
chains : list of str, optional
A list of chain IDs to consider

Returns

tm_score : float
The TM score between the two proteins
Expand source code
def tm_score(self, entry, chains=None):
    """Compute the TM score between this protein and another one.

    Only residues with known coordinates are used; the structures are
    represented by the backbone atom at index 2 of every residue.

    Parameters
    ----------
    entry : ProteinEntry
        A `ProteinEntry` object
    chains : list of str, optional
        A list of chain IDs to consider

    Returns
    -------
    tm_score : float
        The TM score between the two proteins

    """
    proteins = (self, entry)
    ca_self, ca_other = [
        protein.get_coordinates(only_known=True, chains=chains)[:, 2]
        for protein in proteins
    ]
    seq_self, seq_other = [
        protein.get_sequence(only_known=True, chains=chains) for protein in proteins
    ]
    # A bare name inside a method body is looked up in module scope, so this
    # calls the module-level function of the same name, not this method.
    return tm_score(ca_self, ca_other, seq_self, seq_other)
def to_dict(self)

Convert a protein entry into a dictionary.

Returns

dictionary : dict
A nested dictionary where first-level keys are chain IDs and second-level keys are the following: - 'seq' : amino acid sequence (one-letter code) - 'crd_bb' : backbone coordinates, shaped (L, 4, 3) - 'crd_sc' : sidechain coordinates, shaped (L, 10, 3) - 'msk' : mask array where 1 indicates residues with known coordinates and 0 indicates missing values, shaped (L,) - 'cdr' (optional): CDR information, shaped (L,) encoded as integers where each integer corresponds to the index of the CDR type in proteinflow.constants.CDR_ALPHABET - 'predict_msk' (optional): mask array where 1 indicates residues that were generated by a model and 0 indicates residues with known coordinates, shaped (L,) It can optionally also contain protein_id as a first-level key.
Expand source code
def to_dict(self):
    """Export the protein entry as a nested dictionary.

    Returns
    -------
    dictionary : dict
        A nested dictionary where first-level keys are chain IDs and
        second-level keys are the following:
        - `'seq'` : amino acid sequence (one-letter code)
        - `'crd_bb'` : backbone coordinates, shaped `(L, 4, 3)`
        - `'crd_sc'` : sidechain coordinates, shaped `(L, 10, 3)`
        - `'msk'` : mask array where 1 indicates residues with known coordinates and 0
            indicates missing values, shaped `(L,)`
        - `'cdr'` (optional): CDR information, shaped `(L,)` encoded as integers where each
            integer corresponds to the index of the CDR type in
            `proteinflow.constants.CDR_ALPHABET`
        - `'predict_msk'` (optional): mask array where 1 indicates residues that were generated by a model and 0
            indicates residues with known coordinates, shaped `(L,)`
        It can optionally also contain `protein_id` as a first-level key.

    """
    result = {}
    for chain_id in self.get_chains():
        record = {"seq": self.seq[chain_id]}
        coordinates = self.crd[chain_id]
        # The first four atoms are the backbone, the rest the sidechain.
        record["crd_bb"] = coordinates[:, :4]
        record["crd_sc"] = coordinates[:, 4:]
        record["msk"] = self.mask[chain_id]
        result[chain_id] = record
        # Optional entries are only written when they are set.
        cdr = self.cdr[chain_id]
        if cdr is not None:
            record["cdr"] = cdr
        predict_mask = self.predict_mask[chain_id]
        if predict_mask is not None:
            record["predict_msk"] = predict_mask
    if self.id is not None:
        result["protein_id"] = self.id
    return result
def to_pdb(self, path, only_ca=False, skip_oxygens=False, only_backbone=False, title=None)

Save the protein entry to a PDB file.

Parameters

path : str
Path to the output PDB file
only_ca : bool, default False
If True, only alpha-carbon (CA) atoms are saved
skip_oxygens : bool, default False
If True, oxygen atoms are not saved
only_backbone : bool, default False
If True, only backbone atoms are saved
title : str, optional
Title of the PDB file (by default either the protein id or "Untitled")
Expand source code
def to_pdb(
    self,
    path,
    only_ca=False,
    skip_oxygens=False,
    only_backbone=False,
    title=None,
):
    """Write the protein entry to a PDB file.

    The atom selection flags are passed through to `PDBBuilder` unchanged.

    Parameters
    ----------
    path : str
        Path to the output PDB file
    only_ca : bool, default False
        If `True`, only CA atoms are saved
    skip_oxygens : bool, default False
        If `True`, oxygen atoms are not saved
    only_backbone : bool, default False
        If `True`, only backbone atoms are saved
    title : str, optional
        Title of the PDB file (by default either the protein id or "Untitled")

    Raises
    ------
    ValueError
        If a chain ID differs from its own upper-cased first character

    """
    # A chain ID passes only if it equals its upper-cased first character.
    invalid_ids = [
        chain_id for chain_id in self.get_chains() if chain_id[0].upper() != chain_id
    ]
    if invalid_ids:
        raise ValueError(
            "Chain IDs must be single uppercase letters, please rename with `rename_chains` before saving."
        )
    builder = PDBBuilder(
        self,
        only_ca=only_ca,
        skip_oxygens=skip_oxygens,
        only_backbone=only_backbone,
    )
    if title is None:
        title = self.id if self.id is not None else "Untitled"
    builder.save_pdb(path, title=title)
def to_pickle(self, path)

Save a protein entry to a pickle file.

The output files are pickled nested dictionaries where first-level keys are chain Ids and second-level keys are the following: - 'crd_bb': a numpy array of shape (L, 4, 3) with backbone atom coordinates (N, C, CA, O), - 'crd_sc': a numpy array of shape (L, 10, 3) with sidechain atom coordinates (check sidechain_order() for the order of atoms), - 'msk': a numpy array of shape (L,) where ones correspond to residues with known coordinates and zeros to missing values, - 'seq': a string of length L with residue types.

In SAbDab datasets, an additional key is added to the dictionary: - 'cdr': a 'numpy' array of shape (L,) where CDR residues are marked with the corresponding type ('H1', 'L1', …) and non-CDR residues are marked with '-'.

If a prediction mask is available, another additional key is added to the dictionary: - 'predict_msk': a numpy array of shape (L,) where ones correspond to residues that were generated by a model and zeros to residues with known coordinates.

Parameters

path : str
Path to the pickle file
Expand source code
def to_pickle(self, path):
    """Serialize the protein entry into a pickle file.

    The pickled object is the nested dictionary returned by `to_dict`, where first-level keys are chain Ids
    and second-level keys are the following:
    - `'crd_bb'`: a `numpy` array of shape `(L, 4, 3)` with backbone atom coordinates (N, C, CA, O),
    - `'crd_sc'`: a `numpy` array of shape `(L, 10, 3)` with sidechain atom coordinates (check `proteinflow.sidechain_order()` for the order of atoms),
    - `'msk'`: a `numpy` array of shape `(L,)` where ones correspond to residues with known coordinates and
        zeros to missing values,
    - `'seq'`: a string of length `L` with residue types.

    In SAbDab datasets, an additional key is added to the dictionary:
    - `'cdr'`: a `'numpy'` array of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
        and non-CDR residues are marked with `'-'`.

    If a prediction mask is available, another additional key is added to the dictionary:
    - `'predict_msk'`: a `numpy` array of shape `(L,)` where ones correspond to residues that were generated by a model and
        zeros to residues with known coordinates.

    Parameters
    ----------
    path : str
        Path to the pickle file

    """
    # Build the payload first so that no file is created if the export fails.
    payload = self.to_dict()
    with open(path, "wb") as handle:
        pickle.dump(payload, handle)
def visualize(self, highlight_mask=None, style='cartoon', highlight_style=None, opacity=1, canvas_size=(400, 300))

Visualize the protein in a notebook.

Parameters

highlight_mask : np.ndarray, optional
A 'numpy' array of shape (L,) with the residues to highlight marked with 1 and the rest marked with 0; if not given and self.predict_mask is not None, the predicted residues are highlighted
style : str, default 'cartoon'
The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
highlight_style : str, optional
The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross' (defaults to the same as style)
opacity : float or dict, default 1
Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
canvas_size : tuple, default (400, 300)
Shape of the canvas
Expand source code
def visualize(
    self,
    highlight_mask=None,
    style="cartoon",
    highlight_style=None,
    opacity=1,
    canvas_size=(400, 300),
):
    """Show the protein in a notebook.

    The entry is written to a temporary PDB file, loaded back as a
    `PDBEntry` and displayed through `PDBEntry.visualize`.

    Parameters
    ----------
    highlight_mask : np.ndarray, optional
        A `'numpy'` array of shape `(L,)` with the residues to highlight
        marked with 1 and the rest marked with 0; if not given and
        `self.predict_mask` is not `None`, the predicted residues are highlighted
    style : str, default 'cartoon'
        The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
    highlight_style : str, optional
        The style of the highlighted atoms; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
        (defaults to the same as `style`)
    opacity : float or dict, default 1
        Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)
    canvas_size : tuple, default (400, 300)
        Shape of the canvas

    """
    highlight_mask_dict = None
    if highlight_mask is not None:
        highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask)
    elif list(self.predict_mask.values())[0] is not None:
        # Fall back to the prediction mask, restricted per chain to the
        # residues selected by that chain's mask.
        highlight_mask_dict = {}
        for chain_id in self.get_chains():
            chain_prediction = self.predict_mask[chain_id]
            known = self.get_mask([chain_id]).astype(bool)
            highlight_mask_dict[chain_id] = chain_prediction[known]
    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp_file:
        tmp_path = tmp_file.name
        self.to_pdb(tmp_path)
        pdb_entry = PDBEntry(tmp_path)
    pdb_entry.visualize(
        highlight_mask_dict=highlight_mask_dict,
        style=style,
        highlight_style=highlight_style,
        opacity=opacity,
        canvas_size=canvas_size,
    )
class SAbDabEntry (pdb_path, fasta_path, heavy_chain=None, light_chain=None, antigen_chains=None)

A class for parsing SAbDab entries.

Initialize the SAbDabEntry.

Parameters

pdb_path : str
Path to the PDB file
fasta_path : str
Path to the FASTA file
heavy_chain : str, optional
Heavy chain identifier (author chain name)
light_chain : str, optional
Light chain identifier (author chain name)
antigen_chains : list, optional
List of antigen chain identifiers (author chain names)
Expand source code
class SAbDabEntry(PDBEntry):
    """A class for parsing SAbDab entries.

    On top of `PDBEntry` it keeps track of which chain is the heavy chain,
    which is the light chain and which are antigen chains, and derives CDR
    annotations for the antibody chains.
    """

    def __init__(
        self,
        pdb_path,
        fasta_path,
        heavy_chain=None,
        light_chain=None,
        antigen_chains=None,
    ):
        """Initialize the SAbDabEntry.

        Parameters
        ----------
        pdb_path : str
            Path to the PDB file
        fasta_path : str
            Path to the FASTA file
        heavy_chain : str, optional
            Heavy chain identifier (author chain name)
        light_chain : str, optional
            Light chain identifier (author chain name)
        antigen_chains : list, optional
            List of antigen chain identifiers (author chain names)

        Raises
        ------
        PDBError
            If neither a heavy nor a light chain is provided

        """
        if heavy_chain is None and light_chain is None:
            raise PDBError("At least one chain must be provided")
        # Copy the identifiers so that later changes to the caller's sequence
        # do not silently change this entry.
        antigen_chains = [] if antigen_chains is None else list(antigen_chains)
        self.chain_dict = {
            "heavy": heavy_chain,
            "light": light_chain,
            "antigen": antigen_chains,
        }
        # Reverse lookup from chain identifier to chain type. An absent heavy
        # or light chain (`None`) must not become a key, otherwise
        # `chain_type(None)` would report a type for a chain that does not
        # exist.
        self.reverse_chain_dict = {}
        if heavy_chain is not None:
            self.reverse_chain_dict[heavy_chain] = "heavy"
        if light_chain is not None:
            self.reverse_chain_dict[light_chain] = "light"
        for antigen_chain in antigen_chains:
            self.reverse_chain_dict[antigen_chain] = "antigen"
        super().__init__(pdb_path, fasta_path)

    def _get_relevant_chains(self):
        """Get the chains that are included in the entry."""
        chains = []
        if self.chain_dict["heavy"] is not None:
            chains.append(self.chain_dict["heavy"])
        if self.chain_dict["light"] is not None:
            chains.append(self.chain_dict["light"])
        chains.extend(self.chain_dict["antigen"])
        return chains

    @staticmethod
    def from_id(
        pdb_id,
        local_folder=".",
        light_chain=None,
        heavy_chain=None,
        antigen_chains=None,
    ):
        """Create a SAbDabEntry from a PDB ID.

        The PDB and FASTA files are downloaded into `local_folder` first.
        Either the light or the heavy chain must be provided.

        Parameters
        ----------
        pdb_id : str
            PDB ID
        local_folder : str, optional
            Local folder to download the PDB and FASTA files
        light_chain : str, optional
            Light chain identifier (author chain name)
        heavy_chain : str, optional
            Heavy chain identifier (author chain name)
        antigen_chains : list, optional
            List of antigen chain identifiers (author chain names)

        Returns
        -------
        entry : SAbDabEntry
            A SAbDabEntry object

        """
        pdb_path = download_pdb(pdb_id, local_folder, sabdab=True)
        fasta_path = download_fasta(pdb_id, local_folder)
        return SAbDabEntry(
            pdb_path=pdb_path,
            fasta_path=fasta_path,
            light_chain=light_chain,
            heavy_chain=heavy_chain,
            antigen_chains=antigen_chains,
        )

    def _get_chain(self, chain):
        """Return the chain identifier.

        The aliases `"heavy"` and `"light"` are translated into the
        corresponding chain identifiers before delegating to the parent class.
        """
        if chain in ["heavy", "light"]:
            chain = self.chain_dict[chain]
        return super()._get_chain(chain)

    def heavy_chain(self):
        """Return the heavy chain identifier.

        Returns
        -------
        chain : str
            The heavy chain identifier (`None` if no heavy chain was provided)

        """
        return self.chain_dict["heavy"]

    def light_chain(self):
        """Return the light chain identifier.

        Returns
        -------
        chain : str
            The light chain identifier (`None` if no light chain was provided)

        """
        return self.chain_dict["light"]

    def antigen_chains(self):
        """Return the antigen chain identifiers.

        Returns
        -------
        chains : list
            The antigen chain identifiers

        """
        return self.chain_dict["antigen"]

    def chains(self):
        """Return the chains in the PDB.

        Returns
        -------
        chains : list
            A list of chain identifiers (heavy, light, then antigen chains);
            a heavy or light chain that was not provided is left out

        """
        # Only one of the heavy / light chains is mandatory, so drop the
        # `None` placeholder of the absent one instead of reporting it as a
        # chain identifier.
        antibody_chains = [
            chain
            for chain in (self.heavy_chain(), self.light_chain())
            if chain is not None
        ]
        return antibody_chains + list(self.antigen_chains())

    def chain_type(self, chain):
        """Return the type of a chain.

        Parameters
        ----------
        chain : str
            Chain identifier

        Returns
        -------
        chain_type : str
            The type of the chain (heavy, light or antigen)

        Raises
        ------
        PDBError
            If the chain is not one of the chains of this entry

        """
        if chain in self.reverse_chain_dict:
            return self.reverse_chain_dict[chain]
        raise PDBError("Chain not found")

    @lru_cache()
    def _get_chain_cdr(self, chain, align_to_fasta=True):
        """Return the CDRs for a given chain ID.

        For heavy and light chains, the integer before the first underscore
        of every `unique_residue_number` is looked up in `CDR_VALUES`; other
        chains are filled with `"-"`. With `align_to_fasta` the array is
        expanded to the aligned sequence, with `"-"` at the gap positions.
        """
        chain = self._get_chain(chain)
        chain_crd = self.get_pdb_df(chain)
        # First letter of the chain type: "H", "L" or "A".
        chain_type = self.chain_type(chain)[0].upper()
        pdb_seq = self._pdb_sequence(chain)
        unique_numbers = chain_crd["unique_residue_number"].unique()
        if len(unique_numbers) != len(pdb_seq):
            raise PDBError("Inconsistencies in the biopandas dataframe")
        if chain_type in ["H", "L"]:
            cdr_arr = [
                CDR_VALUES[chain_type][int(x.split("_")[0])] for x in unique_numbers
            ]
            cdr_arr = np.array(cdr_arr)
        else:
            cdr_arr = np.array(["-"] * len(unique_numbers), dtype=object)
        if align_to_fasta:
            aligned_seq, _ = self._align_chain(chain)
            aligned_seq_arr = np.array(list(aligned_seq))
            cdr_arr_aligned = np.array(["-"] * len(aligned_seq), dtype=object)
            # Positions that are gaps in the aligned sequence keep "-".
            cdr_arr_aligned[aligned_seq_arr != "-"] = cdr_arr
            cdr_arr = cdr_arr_aligned
        return cdr_arr

    def get_cdr(self, chains=None):
        """Return CDR arrays.

        Parameters
        ----------
        chains : list, optional
            A list of chain identifiers (if not provided, all chains are processed)

        Returns
        -------
        cdrs : dict
            A dictionary containing the CDR arrays for each of the chains

        """
        if chains is None:
            chains = self.chains()
        return {chain: self._get_chain_cdr(chain) for chain in chains}

Ancestors

Static methods

def from_id(pdb_id, local_folder='.', light_chain=None, heavy_chain=None, antigen_chains=None)

Create a SAbDabEntry from a PDB ID.

Either the light or the heavy chain must be provided.

Parameters

pdb_id : str
PDB ID
local_folder : str, optional
Local folder to download the PDB and FASTA files
light_chain : str, optional
Light chain identifier (author chain name)
heavy_chain : str, optional
Heavy chain identifier (author chain name)
antigen_chains : list, optional
List of antigen chain identifiers (author chain names)

Returns

entry : SAbDabEntry
A SAbDabEntry object
Expand source code
@staticmethod
def from_id(
    pdb_id,
    local_folder=".",
    light_chain=None,
    heavy_chain=None,
    antigen_chains=None,
):
    """Build a SAbDabEntry for a PDB ID.

    The PDB and FASTA files are downloaded into `local_folder` first.
    Either the light or the heavy chain must be provided.

    Parameters
    ----------
    pdb_id : str
        PDB ID
    local_folder : str, optional
        Local folder to download the PDB and FASTA files
    light_chain : str, optional
        Light chain identifier (author chain name)
    heavy_chain : str, optional
        Heavy chain identifier (author chain name)
    antigen_chains : list, optional
        List of antigen chain identifiers (author chain names)

    Returns
    -------
    entry : SAbDabEntry
        A SAbDabEntry object

    """
    structure_file = download_pdb(pdb_id, local_folder, sabdab=True)
    sequence_file = download_fasta(pdb_id, local_folder)
    entry = SAbDabEntry(
        pdb_path=structure_file,
        fasta_path=sequence_file,
        light_chain=light_chain,
        heavy_chain=heavy_chain,
        antigen_chains=antigen_chains,
    )
    return entry

Methods

def antigen_chains(self)

Return the antigen chain identifiers.

Returns

chains : list
The antigen chain identifiers
Expand source code
def antigen_chains(self):
    """Get the identifiers of the antigen chains.

    Returns
    -------
    chains : list
        The antigen chain identifiers

    """
    antigen_ids = self.chain_dict["antigen"]
    return antigen_ids
def chain_type(self, chain)

Return the type of a chain.

Parameters

chain : str
Chain identifier

Returns

chain_type : str
The type of the chain (heavy, light or antigen)
Expand source code
def chain_type(self, chain):
    """Look up the type of a chain.

    Parameters
    ----------
    chain : str
        Chain identifier

    Returns
    -------
    chain_type : str
        The type of the chain (heavy, light or antigen)

    Raises
    ------
    PDBError
        If the chain identifier is not known to this entry

    """
    known_types = self.reverse_chain_dict
    if chain not in known_types:
        raise PDBError("Chain not found")
    return known_types[chain]
def chains(self)

Return the chains in the PDB.

Returns

chains : list
A list of chain identifiers
Expand source code
def chains(self):
    """Return the chains in the PDB.

    Returns
    -------
    chains : list
        A list of chain identifiers (heavy, light, then antigen chains);
        a heavy or light chain that was not provided is left out

    """
    # Only one of the heavy / light chains is mandatory, so drop the `None`
    # placeholder of the absent one instead of reporting it as a chain
    # identifier.
    antibody_chains = [
        chain for chain in (self.heavy_chain(), self.light_chain()) if chain is not None
    ]
    return antibody_chains + list(self.antigen_chains())
def get_cdr(self, chains=None)

Return CDR arrays.

Parameters

chains : list, optional
A list of chain identifiers (if not provided, all chains are processed)

Returns

cdrs : dict
A dictionary containing the CDR arrays for each of the chains
Expand source code
def get_cdr(self, chains=None):
    """Collect the CDR arrays of the requested chains.

    Parameters
    ----------
    chains : list, optional
        A list of chain identifiers (if not provided, all chains are processed)

    Returns
    -------
    cdrs : dict
        A dictionary containing the CDR arrays for each of the chains

    """
    selected = self.chains() if chains is None else chains
    cdrs = {}
    for chain_id in selected:
        cdrs[chain_id] = self._get_chain_cdr(chain_id)
    return cdrs
def heavy_chain(self)

Return the heavy chain identifier.

Returns

chain : str
The heavy chain identifier
Expand source code
def heavy_chain(self):
    """Get the identifier of the heavy chain.

    Returns
    -------
    chain : str
        The heavy chain identifier

    """
    heavy_id = self.chain_dict["heavy"]
    return heavy_id
def light_chain(self)

Return the light chain identifier.

Returns

chain : str
The light chain identifier
Expand source code
def light_chain(self):
    """Get the identifier of the light chain.

    Returns
    -------
    chain : str
        The light chain identifier

    """
    light_id = self.chain_dict["light"]
    return light_id

Inherited members