Module proteinflow.data.torch
Subclasses of torch.utils.data.Dataset
and torch.utils.data.DataLoader
that are tuned for loading proteinflow data.
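A minimal usage sketch (the paths and loader settings below are hypothetical placeholders, assuming a folder of proteinflow pickle files):

from proteinflow.data.torch import ProteinLoader

train_loader = ProteinLoader.from_args(
    dataset_folder="./data/proteinflow_sample/",  # hypothetical path
    features_folder="./data/tmp/",  # feature cache location
    node_features_type="dihedral",
    batch_size=8,  # forwarded to torch.utils.data.DataLoader
)
for batch in train_loader:
    X = batch["X"]  # (B, L, 4, 3) backbone coordinates, zero-padded
    S = batch["S"]  # (B, L) encoded sequences
    break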
Source code
"""Subclasses of `torch.utils.data.Dataset` and `torch.utils.data.DataLoader` that are tuned for loading proteinflow data."""
import os
import pickle
import random
from collections import defaultdict
from copy import deepcopy
from itertools import combinations, groupby
from operator import itemgetter
import numpy as np
import torch
from p_tqdm import p_map
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from proteinflow.constants import ALPHABET, CDR_REVERSE, D3TO1, MAIN_ATOMS
from proteinflow.data import ProteinEntry
class _PadCollate:
"""A variant of `collate_fn` that pads according to the longest sequence in a batch of sequences."""
def pad_collate(self, batch):
# find longest sequence
out = {}
for key in batch[0].keys():
if key == "X_ligands" or key == "ligand_chains":
max_len = max([b[key].shape[0] for b in batch])
to_pad = [max_len - b[key].shape[0] for b in batch]
else:
max_len = max(map(lambda x: x["S"].shape[0], batch))
# pad according to max_len
to_pad = [max_len - b["S"].shape[0] for b in batch]
if key in [
"chain_id",
"chain_dict",
"pdb_id",
"cdr_id",
"ligand_smiles",
"chain_type_dict",
]:
continue
out[key] = torch.stack(
[
torch.cat([b[key], torch.zeros((pad, *b[key].shape[1:]))], 0)
for b, pad in zip(batch, to_pad)
],
0,
)
out["chain_id"] = torch.tensor([b["chain_id"] for b in batch])
if "cdr_id" in batch[0]:
out["cdr_id"] = torch.tensor([b["cdr_id"] for b in batch])
out["chain_dict"] = [b["chain_dict"] for b in batch]
out["pdb_id"] = [b["pdb_id"] for b in batch]
if "ligand_smiles" in batch[0]:
out["ligand_smiles"] = list([b["ligand_smiles"] for b in batch])
out["ligand_lengths"] = torch.tensor(
[len(b["ligand_chains"]) for b in batch]
)
return out
def __call__(self, batch):
return self.pad_collate(batch)
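# Illustration of the collation contract (shapes assumed): for two samples
# with sequence lengths 5 and 8, every per-residue tensor is zero-padded to
# length 8 and stacked, e.g. batch["S"] -> (2, 8) and batch["X"] -> (2, 8, 4, 3),
# while keys like "pdb_id" and "chain_dict" are gathered into plain lists.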
class ProteinLoader(DataLoader):
"""A subclass of `torch.data.utils.DataLoader` tuned for the `proteinflow` dataset.
Creates and iterates over an instance of `ProteinDataset`, omitting the `'chain_dict'` keys.
See the `ProteinDataset` documentation for more information.
If batch size is larger than one, all objects are padded with zeros at the ends to reach the length of the
longest protein in the batch.
"""
def __init__(
self,
dataset,
collate_func=_PadCollate,
shuffle_batches=True,
*args,
**kwargs,
):
"""Initialize a ProteinLoader instance.
Parameters
----------
dataset : ProteinDataset
a ProteinDataset instance
shuffle_batches : bool, default True
if `True`, the batches are shuffled at each epoch
collate_func : callable, optional
a class whose instances take a list of samples and return a batch; it should inherit from `_PadCollate`
"""
super().__init__(
dataset,
collate_fn=collate_func(),
shuffle=shuffle_batches,
*args,
**kwargs,
)
@staticmethod
def from_args(
dataset_folder,
features_folder="./data/tmp/",
clustering_dict_path=None,
max_length=None,
rewrite=False,
use_fraction=1,
load_to_ram=False,
debug=False,
interpolate="none",
node_features_type=None,
entry_type="biounit", # biounit, chain, pair
classes_to_exclude=None,
lower_limit=15,
upper_limit=100,
mask_residues=True,
mask_whole_chains=False,
mask_frac=None,
force_binding_sites_frac=0,
shuffle_clusters=True,
shuffle_batches=True,
mask_all_cdrs=False,
classes_dict_path=None,
load_ligands=False,
cut_edges=False,
require_antigen=False,
require_light_chain=False,
require_no_light_chain=False,
require_heavy_chain=False,
*args,
**kwargs,
) -> "ProteinLoader":
"""Create a `ProteinLoader` instance with a `ProteinDataset` from the given arguments.
Parameters
----------
dataset_folder : str
the path to the folder with proteinflow format input files (assumes that files are named {biounit_id}.pickle)
features_folder : str
the path to the folder where the ProteinMPNN features will be saved
clustering_dict_path : str, optional
path to the pickled clustering dictionary (keys are cluster ids, values are (biounit id, chain id) tuples)
max_length : int, optional
entries with total length of chains larger than `max_length` will be disregarded
rewrite : bool, default False
if `False`, existing feature files are not overwritten
use_fraction : float, default 1
the fraction of the clusters to use (first N in alphabetic order)
load_to_ram : bool, default False
if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
debug : bool, default False
only process 1000 files
interpolate : {"none", "only_middle", "all"}
`"none"` for no interpolation, `"only_middle"` for only linear interpolation in the middle, `"all"` for linear interpolation + ends generation
node_features_type : {"dihedral", "sidechain_orientation", "chemical", "secondary_structure", "sidechain_coords", or combinations with "+"}, optional
the type of node features, e.g. `"dihedral"` or `"sidechain_orientation+chemical"`
entry_type : {"biounit", "chain", "pair"}
the type of entries to generate (`"biounit"` for biounit-level, `"chain"` for chain-level, `"pair"` for chain-chain pairs)
classes_to_exclude : list of str, optional
a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
lower_limit : int, default 15
the minimum number of residues to mask
upper_limit : int, default 100
the maximum number of residues to mask
mask_residues : bool, default True
if `True`, generate a mask key
mask_whole_chains : bool, default False
if `True`, `upper_limit`, `force_binding_sites_frac` and `lower_limit` are ignored and the whole chain is masked instead
mask_frac : float, optional
if given, the `lower_limit` and `upper_limit` are ignored and the number of residues to mask is `mask_frac` times the length of the chain
force_binding_sites_frac : float, default 0
if > 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be
forced to be in a binding site
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
shuffle_batches : bool, default True
if `True`, the batches are shuffled at each epoch
mask_all_cdrs : bool, default False
if `True`, all CDRs are masked instead of just the sampled one
classes_dict_path : str, optional
path to the pickled classes dictionary; if not given, a `splits_dict/classes.pickle` file is looked up in the parent folder of `dataset_folder`
load_ligands : bool, default False
if `True`, the ligands will be loaded from the PDB files and added to the features
cut_edges : bool, default False
if `True`, missing values at the edges of the sequence will be cut off
require_antigen : bool, default False
if `True`, only entries with an antigen will be included (used if the dataset is SAbDab)
require_light_chain : bool, default False
if `True`, only entries with a light chain will be included (used if the dataset is SAbDab)
require_no_light_chain : bool, default False
if `True`, only entries without a light chain will be included (used if the dataset is SAbDab)
require_heavy_chain : bool, default False
if `True`, only entries with a heavy chain will be included (used if the dataset is SAbDab)
*args
additional arguments to `torch.utils.data.DataLoader`
**kwargs
additional keyword arguments to `torch.utils.data.DataLoader`
"""
dataset = ProteinDataset(
dataset_folder=dataset_folder,
features_folder=features_folder,
clustering_dict_path=clustering_dict_path,
max_length=max_length,
rewrite=rewrite,
use_fraction=use_fraction,
load_to_ram=load_to_ram,
debug=debug,
interpolate=interpolate,
node_features_type=node_features_type,
entry_type=entry_type,
classes_to_exclude=classes_to_exclude,
shuffle_clusters=shuffle_clusters,
classes_dict_path=classes_dict_path,
lower_limit=lower_limit,
upper_limit=upper_limit,
mask_residues=mask_residues,
mask_whole_chains=mask_whole_chains,
mask_frac=mask_frac,
force_binding_sites_frac=force_binding_sites_frac,
mask_all_cdrs=mask_all_cdrs,
load_ligands=load_ligands,
cut_edges=cut_edges,
require_antigen=require_antigen,
require_light_chain=require_light_chain,
require_no_light_chain=require_no_light_chain,
require_heavy_chain=require_heavy_chain,
)
return ProteinLoader(
dataset=dataset,
shuffle_batches=shuffle_batches,
*args,
**kwargs,
)
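# A custom collate class can be supplied through `collate_func`; a hypothetical
# sketch that drops the ligand coordinates after padding:
#
#     class _NoLigandCollate(_PadCollate):
#         def pad_collate(self, batch):
#             out = super().pad_collate(batch)
#             out.pop("X_ligands", None)
#             return out
#
#     loader = ProteinLoader(dataset, collate_func=_NoLigandCollate)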
class ProteinDataset(Dataset):
"""Dataset to load proteinflow data.
Saves the model input tensors as pickle files in `features_folder`. When `clustering_dict_path` is provided,
at each iteration a random biounit from a cluster is sampled.
If a complex contains multiple chains, they are concatenated. The sequence identity information is preserved in the
`'chain_encoding_all'` object and in the `'residue_idx'` arrays the chain change is denoted by a +100 jump.
Returns dictionaries with the following keys and values (all values are `torch` tensors):
- `'X'`: 3D coordinates of N, C, Ca, O, `(total_L, 4, 3)`,
- `'S'`: sequence indices (shape `(total_L)`),
- `'mask'`: residue mask (0 where coordinates are missing, 1 otherwise; with interpolation 0s are replaced with 1s), `(total_L)`,
- `'mask_original'`: residue mask (0 where coordinates are missing, 1 otherwise; not changed with interpolation), `(total_L)`,
- `'residue_idx'`: residue indices (from 0 to length of sequence, +100 where chains change), `(total_L)`,
- `'chain_encoding_all'`: chain indices, `(total_L)`,
- `'chain_id`': a sampled chain index,
- `'chain_dict'`: a dictionary of chain ids (keys are chain ids, e.g. `'A'`, values are the indices used in `'chain_id'` and `'chain_encoding_all'` objects)
You can also choose to include additional features (set in the `node_features_type` parameter):
- `'sidechain_orientation'`: a unit vector in the direction of the sidechain, `(total_L, 3)`,
- `'dihedral'`: the dihedral angles, `(total_L, 2)`,
- `'chemical'`: hydropathy, volume, charge, polarity, acceptor/donor features, `(total_L, 6)`,
- `'secondary_structure'`: a one-hot encoding of secondary structure ([alpha-helix, beta-sheet, coil]), `(total_L, 3)`,
- `'sidechain_coords'`: the coordinates of the sidechain atoms (see `proteinflow.sidechain_order()` for the order), `(total_L, 10, 3)`.
If the dataset contains a `'cdr'` key (if it was generated from SAbDab files), the output files will also additionally contain a `'cdr'`
key with a CDR tensor of length `total_L`. In the array, the CDR residues are marked with the corresponding CDR type
(H1=1, H2=2, H3=3, L1=4, L2=5, L3=6) and the rest of the residues are marked with 0s.
Use the `set_cdr` method to only iterate over specific CDRs.
In order to compute additional features, use the `feature_functions` parameter. It should be a dictionary with keys
corresponding to the feature names and values corresponding to the functions that compute the features. The functions
should take a `proteinflow.data.ProteinEntry` instance and a list of chains and return a `numpy` array shaped as `(#residues, #features)`
where `#residues` is the total number of residues in those chains and the features are concatenated in the order of the list:
`func(data_entry: ProteinEntry, chains: list) -> np.ndarray`.
If `mask_residues` is `True`, an additional `'masked_res'` key is added to the output. The value is a binary
tensor shaped `(B, L)` where 1 denotes the part that needs to be predicted and 0 is everything else. The tensors are generated
according to the following rules:
- if the dataset is generated from SAbDab files, the sampled CDR is masked,
- if `mask_whole_chains` is `True`, the whole chain is masked,
- if `mask_frac` is given, the number of residues to mask is `mask_frac` times the length of the chain,
- otherwise, the number of residues to mask is sampled uniformly from the range [`lower_limit`, `upper_limit`].
If `force_binding_sites_frac` > 0 and `mask_whole_chains` is `False`, in the fraction of cases where a chain
from a polymer is sampled, the center of the masked region will be forced to be in a binding site (in PDB datasets).
"""
def __init__(
self,
dataset_folder,
features_folder="./data/tmp/",
clustering_dict_path=None,
max_length=None,
rewrite=False,
use_fraction=1,
load_to_ram=False,
debug=False,
interpolate="none",
node_features_type="zeros",
debug_file_path=None,
entry_type="biounit", # biounit, chain, pair
classes_to_exclude=None, # heteromers, homomers, single_chains
shuffle_clusters=True,
min_cdr_length=None,
feature_functions=None,
classes_dict_path=None,
cut_edges=False,
mask_residues=True,
lower_limit=15,
upper_limit=100,
mask_frac=None,
mask_whole_chains=False,
mask_sequential=False,
force_binding_sites_frac=0.15,
mask_all_cdrs=False,
load_ligands=False,
pyg_graph=False,
patch_around_mask=False,
initial_patch_size=128,
antigen_patch_size=128,
require_antigen=False,
require_light_chain=False,
require_no_light_chain=False,
require_heavy_chain=False,
):
"""Initialize the dataset.
Parameters
----------
dataset_folder : str
the path to the folder with proteinflow format input files (assumes that files are named {biounit_id}.pickle)
features_folder : str, default "./data/tmp/"
the path to the folder where the ProteinMPNN features will be saved
clustering_dict_path : str, optional
path to the pickled clustering dictionary (keys are cluster ids, values are (biounit id, chain id) tuples)
max_length : int, optional
entries with total length of chains larger than `max_length` will be disregarded
rewrite : bool, default False
if `False`, existing feature files are not overwritten
use_fraction : float, default 1
the fraction of the clusters to use (first N in alphabetic order)
load_to_ram : bool, default False
if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
debug : bool, default False
only process 1000 files
interpolate : {"none", "only_middle", "all"}
`"none"` for no interpolation, `"only_middle"` for only linear interpolation in the middle, `"all"` for linear interpolation + ends generation
node_features_type : {"zeros", "dihedral", "sidechain_orientation", "chemical", "secondary_structure" or combinations with "+"}
the type of node features, e.g. `"dihedral"` or `"sidechain_orientation+chemical"`
debug_file_path : str, optional
if not `None`, open this single file instead of loading the dataset
entry_type : {"biounit", "chain", "pair"}
the type of entries to generate (`"biounit"` for biounit-level complexes, `"chain"` for chain-level, `"pair"`
for chain-chain pairs (all pairs that are seen in the same biounit and have intersecting coordinate clouds))
classes_to_exclude : list of str, optional
a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
min_cdr_length : int, optional
for SAbDab datasets, biounits with CDRs shorter than `min_cdr_length` will be excluded
feature_functions : dict, optional
a dictionary of functions to compute additional features (keys are the names of the features, values are the functions)
classes_dict_path : str, optional
a path to a pickled dictionary with biounit classes (single chain / heteromer / homomer)
cut_edges : bool, default False
if `True`, missing values at the edges of the sequence will be cut off
mask_residues : bool, default True
if `True`, the masked residues will be added to the output
lower_limit : int, default 15
the lower limit of the number of residues to mask
upper_limit : int, default 100
the upper limit of the number of residues to mask
mask_frac : float, optional
if given, the number of residues to mask is `mask_frac` times the length of the chain
mask_whole_chains : bool, default False
if `True`, the whole chain is masked
mask_sequential : bool, default False
if `True`, the masked residues will be neighbors in the sequence; otherwise a geometric
mask is applied based on the coordinates
force_binding_sites_frac : float, default 0.15
if `force_binding_sites_frac` > 0 and `mask_whole_chains` is `False`, in the fraction of cases where a chain
from a polymer is sampled, the center of the masked region will be forced to be in a binding site (in PDB datasets)
mask_all_cdrs : bool, default False
if `True`, all CDRs will be masked (in SAbDab datasets)
load_ligands : bool, default False
if `True`, the ligands will be loaded as well
pyg_graph : bool, default False
if `True`, the output will be a `torch_geometric.data.Data` object instead of a dictionary
patch_around_mask : bool, default False
if `True`, the data entries will be cut around the masked region
initial_patch_size : int, default 128
the size of the initial patch (used if `patch_around_mask` is `True`)
antigen_patch_size : int, default 128
the size of the antigen patch (used if `patch_around_mask` is `True` and the dataset is SAbDab)
require_antigen : bool, default False
if `True`, only entries with an antigen will be included (used if the dataset is SAbDab)
require_light_chain : bool, default False
if `True`, only entries with a light chain will be included (used if the dataset is SAbDab)
require_no_light_chain : bool, default False
if `True`, only entries without a light chain will be included (used if the dataset is SAbDab)
require_heavy_chain : bool, default False
if `True`, only entries with a heavy chain will be included (used if the dataset is SAbDab)
"""
self.debug = False
if classes_dict_path is None:
dataset_parent = os.path.dirname(dataset_folder)
classes_dict_path = os.path.join(
dataset_parent, "splits_dict", "classes.pickle"
)
if not os.path.exists(classes_dict_path):
classes_dict_path = None
alphabet = ALPHABET
self.alphabet_dict = defaultdict(lambda: 0)
for i, letter in enumerate(alphabet):
self.alphabet_dict[letter] = i
self.alphabet_dict["X"] = 0
self.files = defaultdict(lambda: defaultdict(list)) # file path by biounit id
self.loaded = None
self.dataset_folder = dataset_folder
self.features_folder = features_folder
self.cut_edges = cut_edges
self.mask_residues = mask_residues
self.lower_limit = lower_limit
self.upper_limit = upper_limit
self.mask_frac = mask_frac
self.mask_whole_chains = mask_whole_chains
self.force_binding_sites_frac = force_binding_sites_frac
self.mask_all_cdrs = mask_all_cdrs
self.load_ligands = load_ligands
self.pyg_graph = pyg_graph
self.patch_around_mask = patch_around_mask
self.initial_patch_size = initial_patch_size
self.antigen_patch_size = antigen_patch_size
self.mask_sequential = mask_sequential
self.feature_types = []
if node_features_type is not None:
self.feature_types = node_features_type.split("+")
self.entry_type = entry_type
self.shuffle_clusters = shuffle_clusters
self.feature_functions = {
"sidechain_orientation": self._sidechain,
"dihedral": self._dihedral,
"chemical": self._chemical,
"secondary_structure": self._sse,
"sidechain_coords": self._sidechain_coords,
}
self.feature_functions.update(feature_functions or {})
if classes_to_exclude is not None and not all(
[x in ["single_chain", "heteromer", "homomer"] for x in classes_to_exclude]
):
raise ValueError(
"Invalid class to exclude, choose from 'single_chain', 'heteromer', 'homomer'"
)
if debug_file_path is not None:
self.dataset_folder = os.path.dirname(debug_file_path)
debug_file_path = os.path.basename(debug_file_path)
self.main_atom_dict = defaultdict(lambda: None)
d1to3 = {v: k for k, v in D3TO1.items()}
for i, letter in enumerate(alphabet):
if i == 0:
continue
self.main_atom_dict[i] = MAIN_ATOMS[d1to3[letter]]
# create feature folder if it does not exist
if not os.path.exists(self.features_folder):
os.makedirs(self.features_folder)
self.interpolate = interpolate
# generate the feature files
print("Processing files...")
if debug_file_path is None:
to_process = [
x for x in os.listdir(dataset_folder) if x.endswith(".pickle")
]
else:
to_process = [debug_file_path]
if clustering_dict_path is not None and use_fraction < 1:
with open(clustering_dict_path, "rb") as f:
clusters = pickle.load(f)
keys = sorted(clusters.keys())[: int(len(clusters) * use_fraction)]
to_process = set()
for key in keys:
to_process.update([x[0] for x in clusters[key]])
file_set = set(os.listdir(dataset_folder))
to_process = [x for x in to_process if x in file_set]
if debug:
to_process = to_process[:1000]
if self.entry_type == "pair":
print(
"Please note that the pair entry type takes longer to process than the other two. The progress bar is not linear because of the varying number of chains per file."
)
output_tuples_list = p_map(
lambda x: self._process(
x,
rewrite=rewrite,
max_length=max_length,
min_cdr_length=min_cdr_length,
classes_to_exclude=classes_to_exclude,
),
to_process,
)
# save the file names
for output_tuples in output_tuples_list:
for id, filename, chain_set in output_tuples:
for chain in chain_set:
self.files[id][chain].append(filename)
if classes_to_exclude is None:
classes_to_exclude = []
classes = None
if classes_dict_path is not None:
with open(classes_dict_path, "rb") as f:
classes = pickle.load(f)
if clustering_dict_path is not None:
with open(clustering_dict_path, "rb") as f:
self.clusters = pickle.load(f) # list of biounit ids by cluster id
if classes is None: # old way of storing class information
try:
classes = pickle.load(f)
except EOFError:
pass
else:
self.clusters = None
if classes is None and len(classes_to_exclude) > 0:
raise ValueError(
"Classes to exclude are given but no classes dictionary is found, please set classes_dict_path to the path of the classes dictionary"
)
to_exclude = set()
# if classes is not None:
# for c in classes_to_exclude:
# for key, id_arr in classes.get(c, {}).items():
# for id, _ in id_arr:
# to_exclude.add(id)
if (
require_antigen
or require_light_chain
or require_no_light_chain
or require_heavy_chain
):
to_exclude.update(
self._exclude_by_chains(
require_antigen,
require_light_chain,
require_no_light_chain,
require_heavy_chain,
)
)
if self.clusters is not None:
self._exclude_ids_from_clusters(to_exclude)
self.data = list(self.clusters.keys())
else:
self.data = [x for x in self.files.keys() if x not in to_exclude]
# create a smaller dataset if necessary (if we have clustering it's applied earlier)
if self.clusters is None and use_fraction < 1:
self.data = sorted(self.data)[: int(len(self.data) * use_fraction)]
if load_to_ram:
print("Loading to RAM...")
self.loaded = {}
seen = set()
for id in self.files:
for chain, file_list in self.files[id].items():
for file in file_list:
if file in seen:
continue
seen.add(file)
with open(file, "rb") as f:
self.loaded[file] = pickle.load(f)
sample_file = list(self.files.keys())[0]
sample_chain = list(self.files[sample_file].keys())[0]
self.sabdab = "__" in sample_chain
self.cdr = 0
self.set_cdr(None)
def _exclude_ids_from_clusters(self, to_exclude):
for key in list(self.clusters.keys()):
cluster_list = []
for x in self.clusters[key]:
if x[0] in to_exclude:
continue
id = x[0].split(".")[0]
chain = x[1]
if id not in self.files:
continue
if chain not in self.files[id]:
continue
if len(self.files[id][chain]) == 0:
continue
cluster_list.append([id, chain])
self.clusters[key] = cluster_list
if len(self.clusters[key]) == 0:
self.clusters.pop(key)
def _check_chain_types(self, file):
chain_types = set()
with open(file, "rb") as f:
data = pickle.load(f)
chains = data["chain_dict"].values()
for chain in chains:
chain_mask = data["chain_encoding_all"] == chain
cdr = data["cdr"][chain_mask]
cdr_values = cdr.unique()
if len(cdr_values) == 1:
chain_types.add("antigen")
elif CDR_REVERSE["H1"] in cdr_values:
chain_types.add("heavy")
elif CDR_REVERSE["L1"] in cdr_values:
chain_types.add("light")
return chain_types
def _exclude_by_class(
self,
classes_to_exclude,
):
"""Exclude entries that are in the classes to exclude."""
to_exclude = set()
for id in self.files:
for chain in self.files[id]:
filename = self.files[id][chain][0]
with open(filename, "rb") as f:
data = pickle.load(f)
if any(c in data["classes"] for c in classes_to_exclude):
to_exclude.add(id)
return to_exclude
def _exclude_by_chains(
self,
require_antigen,
require_light_chain,
require_no_light_chain,
require_heavy_chain,
):
"""Exclude entries that do not have an antigen or a light chain."""
to_exclude = set()
for id in self.files:
filename = list(self.files[id].values())[0][
0
] # assuming entry type is biounit
chain_types = self._check_chain_types(filename)
if require_antigen and "antigen" not in chain_types:
to_exclude.add(id)
if require_light_chain and "light" not in chain_types:
to_exclude.add(id)
if require_no_light_chain and "light" in chain_types:
to_exclude.add(id)
if require_heavy_chain and "heavy" not in chain_types:
to_exclude.add(id)
return to_exclude
def _get_masked_sequence(
self,
data,
):
"""Get the mask for the residues that need to be predicted.
Depending on the parameters the residues are selected as follows:
- if `mask_whole_chains` is `True`, the whole chain is masked
- if `mask_frac` is given, the number of residues to mask is `mask_frac` times the length of the chain,
- otherwise, the number of residues to mask is sampled uniformly from the range [`lower_limit`, `upper_limit`].
If `mask_sequential` is `True`, the residues are masked based on the order in the sequence, otherwise a
spherical mask is applied based on the coordinates.
If `force_binding_sites_frac` > 0 and `mask_whole_chains` is `False`, in the fraction of cases where a chain
from a polymer is sampled, the center of the masked region will be forced to be in a binding site.
Parameters
----------
data : dict
an entry generated by `ProteinDataset`
Returns
-------
chain_M : torch.Tensor
a `(B, L)` shaped binary tensor where 1 denotes the part that needs to be predicted and
0 is everything else
"""
if "cdr" in data and "cdr_id" in data:
chain_M = torch.zeros_like(data["cdr"])
if self.mask_all_cdrs:
chain_M = data["cdr"] != CDR_REVERSE["-"]
else:
chain_M = data["cdr"] == data["cdr_id"]
else:
chain_M = torch.zeros_like(data["S"])
chain_index = data["chain_id"]
chain_bool = data["chain_encoding_all"] == chain_index
if self.mask_whole_chains:
chain_M[chain_bool] = 1
else:
chains = torch.unique(data["chain_encoding_all"])
chain_start = torch.where(chain_bool)[0][0]
chain = data["X"][chain_bool]
res_i = None
interface = []
non_masked_interface = []
if len(chains) > 1 and self.force_binding_sites_frac > 0:
if random.uniform(0, 1) <= self.force_binding_sites_frac:
X_copy = data["X"]
i_indices = (chain_bool == 0).nonzero().flatten() # global
j_indices = chain_bool.nonzero().flatten() # global
distances = torch.norm(
X_copy[i_indices, 2, :]
- X_copy[j_indices, 2, :].unsqueeze(1),
dim=-1,
).cpu()
close_idx = (
np.where(torch.min(distances, dim=1)[0] <= 10)[0]
+ chain_start.item()
) # global
no_mask_idx = (
np.where(data["mask"][chain_bool])[0] + chain_start.item()
) # global
interface = np.intersect1d(close_idx, j_indices) # global
not_end_mask = np.where(
(X_copy[:, 2, :].cpu() == 0).sum(-1) != 3
)[0]
interface = np.intersect1d(interface, not_end_mask) # global
non_masked_interface = np.intersect1d(interface, no_mask_idx)
interpolate = True
if len(non_masked_interface) > 0:
res_i = non_masked_interface[
random.randint(0, len(non_masked_interface) - 1)
]
elif len(interface) > 0 and interpolate:
res_i = interface[random.randint(0, len(interface) - 1)]
else:
res_i = no_mask_idx[random.randint(0, len(no_mask_idx) - 1)]
if res_i is None:
non_zero = torch.where(data["mask"] * chain_bool)[0]
res_i = non_zero[random.randint(0, len(non_zero) - 1)]
res_coords = data["X"][res_i, 2, :]
neighbor_indices = torch.where(data["mask"][chain_bool])[0]
if self.mask_frac is not None:
assert 0 < self.mask_frac < 1
k = int(len(neighbor_indices) * self.mask_frac)
else:
up = min(
self.upper_limit, int(len(neighbor_indices) * 0.5)
) # do not mask more than half of the sequence
low = min(up - 1, self.lower_limit)
k = random.choice(range(low, up))
if self.mask_sequential:
start = max(1, res_i - chain_start - k // 2)
end = min(len(chain) - 1, res_i - chain_start + k // 2)
chain_M[chain_start + start : chain_start + end] = 1
else:
dist = torch.norm(
chain[neighbor_indices, 2, :] - res_coords.unsqueeze(0), dim=-1
)
closest_indices = neighbor_indices[
torch.topk(dist, k, largest=False)[1]
]
chain_M[closest_indices + chain_start] = 1
return chain_M
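# Illustration of the geometric mask (assumed numbers): with lower_limit=15,
# upper_limit=100 and a fully known chain of length 60, k is drawn from
# range(15, 30) (at most half of the chain), and the k residues whose C-alpha
# atoms lie closest to the sampled center res_i are set to 1 in chain_M.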
def _dihedral(self, data_entry, chains):
"""Return dihedral angles."""
return data_entry.dihedral_angles(chains)
def _sidechain(self, data_entry, chains):
"""Return Sidechain orientation (defined by the 'main atoms' in the `main_atom_dict` dictionary)."""
return data_entry.sidechain_orientation(chains)
def _chemical(self, data_entry, chains):
"""Return hemical features (hydropathy, volume, charge, polarity, acceptor/donor)."""
return data_entry.chemical_features(chains)
def _sse(self, data_entry, chains):
"""Return secondary structure features."""
return data_entry.secondary_structure(chains)
def _sidechain_coords(self, data_entry, chains):
"""Return idechain coordinates."""
return data_entry.sidechain_coordinates(chains)
def _process(
self,
filename,
rewrite=False,
max_length=None,
min_cdr_length=None,
classes_to_exclude=None,
):
"""Process a proteinflow file and save it as ProteinMPNN features."""
input_file = os.path.join(self.dataset_folder, filename)
no_extension_name = filename.split(".")[0]
data_entry = ProteinEntry.from_pickle(input_file)
if self.load_ligands:
ligands = ProteinEntry.retrieve_ligands_from_pickle(input_file)
if classes_to_exclude is not None:
if data_entry.get_protein_class() in classes_to_exclude:
return []
chains = data_entry.get_chains()
if self.entry_type == "biounit":
chain_sets = [chains]
elif self.entry_type == "chain":
chain_sets = [[x] for x in chains]
elif self.entry_type == "pair":
if len(chains) == 1:
return []
chain_sets = list(combinations(chains, 2))
else:
raise RuntimeError(
"Unknown entry type, please choose from ['biounit', 'chain', 'pair']"
)
output_names = []
if self.cut_edges:
data_entry.cut_missing_edges()
for chains_i, chain_set in enumerate(chain_sets):
output_file = os.path.join(
self.features_folder, no_extension_name + f"_{chains_i}.pickle"
)
pass_set = False
add_name = True
if os.path.exists(output_file) and not rewrite:
pass_set = True
if max_length is not None:
if data_entry.get_length(chain_set) > max_length:
add_name = False
if min_cdr_length is not None and data_entry.has_cdr():
cdr_length = data_entry.get_cdr_length(chain_set)
if not all(
[
length >= min_cdr_length
for length in cdr_length.values()
if length > 0
]
):
add_name = False
else:
if max_length is not None:
if data_entry.get_length(chains=chain_set) > max_length:
pass_set = True
add_name = False
if min_cdr_length is not None and data_entry.has_cdr():
cdr_length = data_entry.get_cdr_length(chain_set)
if not all(
[
length >= min_cdr_length
for length in cdr_length.values()
if length > 0
]
):
pass_set = True
add_name = False
if self.entry_type == "pair":
if not data_entry.is_valid_pair(*chain_set):
pass_set = True
add_name = False
out = {}
if add_name:
cdr_chain_set = set()
if data_entry.has_cdr():
out["cdr"] = torch.tensor(
data_entry.get_cdr(chain_set, encode=True)
)
chain_type_dict = data_entry.get_chain_type_dict(chain_set)
out["chain_type_dict"] = chain_type_dict
if "heavy" in chain_type_dict:
cdr_chain_set.update(
[
f"{chain_type_dict['heavy']}__{cdr}"
for cdr in ["H1", "H2", "H3"]
]
)
if "light" in chain_type_dict:
cdr_chain_set.update(
[
f"{chain_type_dict['light']}__{cdr}"
for cdr in ["L1", "L2", "L3"]
]
)
output_names.append(
(
os.path.basename(no_extension_name),
output_file,
chain_set if len(cdr_chain_set) == 0 else cdr_chain_set,
)
)
if pass_set:
continue
if self.interpolate != "none":
data_entry.interpolate_coords(fill_ends=(self.interpolate == "all"))
out["pdb_id"] = no_extension_name.split("-")[0]
out["mask_original"] = torch.tensor(
data_entry.get_mask(chain_set, original=True)
)
out["mask"] = torch.tensor(data_entry.get_mask(chain_set, original=False))
out["S"] = torch.tensor(data_entry.get_sequence(chain_set, encode=True))
out["X"] = torch.tensor(data_entry.get_coordinates(chain_set, bb_only=True))
out["residue_idx"] = torch.tensor(
data_entry.get_index_array(chain_set, index_bump=100)
)
out["chain_encoding_all"] = torch.tensor(
data_entry.get_chain_id_array(chain_set)
)
out["chain_dict"] = data_entry.get_chain_id_dict(chain_set)
if self.load_ligands and len(ligands) != 0:
(
out["X_ligands"],
out["ligand_smiles"],
out["ligand_chains"],
) = data_entry.get_ligand_features(ligands, chain_set)
for name in self.feature_types:
if name not in self.feature_functions:
continue
func = self.feature_functions[name]
out[name] = torch.tensor(func(data_entry, chain_set))
with open(output_file, "wb") as f:
pickle.dump(out, f)
return output_names
def set_cdr(self, cdr):
"""Set the CDR to be iterated over (only for SAbDab datasets).
Parameters
----------
cdr : list | str | None
The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3).
Set to `None` to go back to iterating over all chains.
"""
if not self.sabdab:
cdr = None
if isinstance(cdr, str):
cdr = [cdr]
if cdr == self.cdr:
return
self.cdr = cdr
if cdr is None:
self.indices = list(range(len(self.data)))
else:
self.indices = []
print(f"Setting CDR to {cdr}...")
for i, data in tqdm(enumerate(self.data)):
if self.clusters is not None:
if data.split("__")[1] in cdr:
self.indices.append(i)
else:
add = False
for chain in self.files[data]:
if chain.split("__")[1] in cdr:
add = True
break
if add:
self.indices.append(i)
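# Example (SAbDab datasets only): restrict iteration to H3 loops,
# then restore iteration over all chains:
#     dataset.set_cdr("H3")
#     dataset.set_cdr(None)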
def _to_pyg_graph(self, data):
"""Convert a dictionary of data to a PyTorch Geometric graph."""
from torch_geometric.data import Data
pyg_data = Data(x=data["X"])
for key, value in data.items():
pyg_data[key] = value.unsqueeze(0)
return pyg_data
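# Note (assumption): converting to a PyG graph requires torch_geometric to be
# installed; every tensor is stored with a leading batch dimension of 1, e.g.
# data["S"] of shape (L,) becomes pyg_data["S"] of shape (1, L).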
@staticmethod
def get_anchor_ind(masked_res, mask):
"""Get the indices of the anchor residues.
Anchor residues are defined as the first and last known residues before and
after each continuous masked region.
Parameters
----------
masked_res : torch.Tensor
A boolean tensor indicating which residues should be predicted
mask : torch.Tensor
A boolean tensor indicating which residues are known
Returns
-------
list
A list of indices of the anchor residues
"""
anchor_ind = []
masked_ind = torch.where(masked_res.bool())[0]
known_ind = torch.where(mask.bool())[0]
for _, g in groupby(enumerate(masked_ind), lambda x: x[0] - x[1]):
group = map(itemgetter(1), g)
group = list(map(int, group))
start, end = group[0], group[-1]
start = (
known_ind[known_ind < start][-1]
if (known_ind < start).sum() > 0
else known_ind[0]
)
end = (
known_ind[known_ind > end][0]
if (known_ind > end).sum() > 0
else known_ind[-1]
)
anchor_ind += [start, end]
return anchor_ind
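# Illustration (assumed inputs): with masked_res = [0, 1, 1, 0, 0, 1, 0] and
# mask = [1, 0, 0, 1, 1, 0, 1], the masked runs are {1, 2} and {5}, so the
# anchors are [0, 3] and [4, 6], i.e. anchor_ind == [0, 3, 4, 6].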
def _get_antibody_mask(self, data):
"""Get a mask for the antibody residues."""
mask = torch.zeros_like(data["mask"]).bool()
cdrs = data["cdr"]
chain_enc = data["chain_encoding_all"]
for chain_ind in data["chain_dict"].values():
chain_mask = chain_enc == chain_ind
chain_cdrs = cdrs[chain_mask]
if len(torch.unique(chain_cdrs)) > 1:
mask[chain_mask] = True
return mask
def _patch(self, data):
"""Cut the data around the anchor residues."""
# adapted from diffab
pos_alpha = data["X"][:, 2]
if self.mask_whole_chains:
mask_ = (data["mask"] * data["masked_res"]).bool()
anchor_points = pos_alpha[mask_].mean(0).unsqueeze(0)
anchor_ind = []
else:
anchor_ind = self.get_anchor_ind(data["masked_res"], data["mask"])
anchor_points = torch.stack([pos_alpha[ind] for ind in anchor_ind], dim=0)
dist_anchor = torch.cdist(pos_alpha, anchor_points, p=2).min(dim=1)[0] # (L, )
dist_anchor[~data["mask"].bool()] = float("+inf")
initial_patch_idx = torch.topk(
dist_anchor,
k=min(self.initial_patch_size, dist_anchor.size(0)),
largest=False,
sorted=True,
)[
1
] # (initial_patch_size, )
patch_mask = data["masked_res"].bool().clone()
patch_mask[[int(x) for x in anchor_ind]] = True
patch_mask[initial_patch_idx] = True
if self.sabdab:
antibody_mask = self._get_antibody_mask(data)
antigen_mask = ~antibody_mask
dist_anchor_antigen = dist_anchor.masked_fill(
mask=antibody_mask, value=float("+inf") # Fill antibody with +inf
) # (L, )
antigen_patch_idx = torch.topk(
dist_anchor_antigen,
k=min(self.antigen_patch_size, antigen_mask.sum().item()),
largest=False,
)[
1
] # (ag_size, )
patch_mask[antigen_patch_idx] = True
for key, value in data.items():
if isinstance(value, torch.Tensor):
data[key] = value[patch_mask]
return data
def __len__(self):
"""Return the number of clusters or data entries in the dataset."""
return len(self.indices)
def __getitem__(self, idx):
"""Return an entry from the dataset.
If a clusters file is provided, then the idx is the index of the cluster
and the chain is randomly selected from the cluster. Otherwise, the idx
is the index of the data entry and the chain is randomly selected from
the data entry.
"""
chain_id = None
cdr = None
idx = self.indices[idx]
if self.clusters is None:
id = self.data[idx] # data is already filtered by length
chain_id = random.choice(list(self.files[id].keys()))
if self.cdr is not None:
while chain_id.split("__")[1] not in self.cdr:
chain_id = random.choice(list(self.files[id].keys()))
else:
cluster = self.data[idx]
id = None
chain_n = -1
while (
id is None or len(self.files[id][chain_id]) == 0
): # some IDs can be filtered out by length
if self.shuffle_clusters:
chain_n = random.randint(0, len(self.clusters[cluster]) - 1)
else:
chain_n += 1
id, chain_id = self.clusters[cluster][
chain_n
] # get id and chain from cluster
file = random.choice(self.files[id][chain_id])
if "__" in chain_id:
chain_id, cdr = chain_id.split("__")
if self.loaded is None:
with open(file, "rb") as f:
try:
data = pickle.load(f)
except EOFError:
print("EOFError", file)
raise
else:
data = deepcopy(self.loaded[file])
data["chain_id"] = data["chain_dict"][chain_id]
if cdr is not None:
data["cdr_id"] = CDR_REVERSE[cdr]
if self.mask_residues:
data["masked_res"] = self._get_masked_sequence(data)
if self.patch_around_mask:
data = self._patch(data)
if self.pyg_graph:
data = self._to_pyg_graph(data)
return data
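The `feature_functions` mechanism described in the `ProteinDataset` docstring can be exercised with a small custom feature. A minimal sketch, assuming the hypothetical dataset path below; the feature itself is a toy placeholder built on the same `ProteinEntry.get_length` call used in `_process` above:

import numpy as np

from proteinflow.data import ProteinEntry
from proteinflow.data.torch import ProteinDataset

def chain_length_feature(data_entry: ProteinEntry, chains: list) -> np.ndarray:
    """Toy feature: broadcast the total residue count to every residue."""
    n = data_entry.get_length(chains)
    return np.full((n, 1), n, dtype=np.float32)

dataset = ProteinDataset(
    dataset_folder="./data/proteinflow_sample/",  # hypothetical path
    node_features_type="chain_length",  # must include the custom feature name
    feature_functions={"chain_length": chain_length_feature},
)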
Classes
class ProteinDataset (dataset_folder, features_folder='./data/tmp/', clustering_dict_path=None, max_length=None, rewrite=False, use_fraction=1, load_to_ram=False, debug=False, interpolate='none', node_features_type='zeros', debug_file_path=None, entry_type='biounit', classes_to_exclude=None, shuffle_clusters=True, min_cdr_length=None, feature_functions=None, classes_dict_path=None, cut_edges=False, mask_residues=True, lower_limit=15, upper_limit=100, mask_frac=None, mask_whole_chains=False, mask_sequential=False, force_binding_sites_frac=0.15, mask_all_cdrs=False, load_ligands=False, pyg_graph=False, patch_around_mask=False, initial_patch_size=128, antigen_patch_size=128, require_antigen=False, require_light_chain=False, require_no_light_chain=False, require_heavy_chain=False)
Dataset to load proteinflow data. The full description and parameter reference are given in the `ProteinDataset` docstring in the source code above.
int(len(neighbor_indices) * self.mask_frac) else: up = min( self.upper_limit, int(len(neighbor_indices) * 0.5) ) # do not mask more than half of the sequence low = min(up - 1, self.lower_limit) k = random.choice(range(low, up)) if self.mask_sequential: start = max(1, res_i - chain_start - k // 2) end = min(len(chain) - 1, res_i - chain_start + k // 2) chain_M[chain_start + start : chain_start + end] = 1 else: dist = torch.norm( chain[neighbor_indices, 2, :] - res_coords.unsqueeze(0), dim=-1 ) closest_indices = neighbor_indices[ torch.topk(dist, k, largest=False)[1] ] chain_M[closest_indices + chain_start] = 1 return chain_M def _dihedral(self, data_entry, chains): """Return dihedral angles.""" return data_entry.dihedral_angles(chains) def _sidechain(self, data_entry, chains): """Return Sidechain orientation (defined by the 'main atoms' in the `main_atom_dict` dictionary).""" return data_entry.sidechain_orientation(chains) def _chemical(self, data_entry, chains): """Return hemical features (hydropathy, volume, charge, polarity, acceptor/donor).""" return data_entry.chemical_features(chains) def _sse(self, data_entry, chains): """Return secondary structure features.""" return data_entry.secondary_structure(chains) def _sidechain_coords(self, data_entry, chains): """Return idechain coordinates.""" return data_entry.sidechain_coordinates(chains) def _process( self, filename, rewrite=False, max_length=None, min_cdr_length=None, classes_to_exclude=None, ): """Process a proteinflow file and save it as ProteinMPNN features.""" input_file = os.path.join(self.dataset_folder, filename) no_extension_name = filename.split(".")[0] data_entry = ProteinEntry.from_pickle(input_file) if self.load_ligands: ligands = ProteinEntry.retrieve_ligands_from_pickle(input_file) if classes_to_exclude is not None: if data_entry.get_protein_class() in classes_to_exclude: return [] chains = data_entry.get_chains() if self.entry_type == "biounit": chain_sets = [chains] elif self.entry_type == "chain": chain_sets = [[x] for x in chains] elif self.entry_type == "pair": if len(chains) == 1: return [] chain_sets = list(combinations(chains, 2)) else: raise RuntimeError( "Unknown entry type, please choose from ['biounit', 'chain', 'pair']" ) output_names = [] if self.cut_edges: data_entry.cut_missing_edges() for chains_i, chain_set in enumerate(chain_sets): output_file = os.path.join( self.features_folder, no_extension_name + f"_{chains_i}.pickle" ) pass_set = False add_name = True if os.path.exists(output_file) and not rewrite: pass_set = True if max_length is not None: if data_entry.get_length(chain_set) > max_length: add_name = False if min_cdr_length is not None and data_entry.has_cdr(): cdr_length = data_entry.get_cdr_length(chain_set) if not all( [ length >= min_cdr_length for length in cdr_length.values() if length > 0 ] ): add_name = False else: if max_length is not None: if data_entry.get_length(chains=chain_set) > max_length: pass_set = True add_name = False if min_cdr_length is not None and data_entry.has_cdr(): cdr_length = data_entry.get_cdr_length(chain_set) if not all( [ length >= min_cdr_length for length in cdr_length.values() if length > 0 ] ): pass_set = True add_name = False if self.entry_type == "pair": if not data_entry.is_valid_pair(*chain_set): pass_set = True add_name = False out = {} if add_name: cdr_chain_set = set() if data_entry.has_cdr(): out["cdr"] = torch.tensor( data_entry.get_cdr(chain_set, encode=True) ) chain_type_dict = data_entry.get_chain_type_dict(chain_set) out["chain_type_dict"] = 
chain_type_dict if "heavy" in chain_type_dict: cdr_chain_set.update( [ f"{chain_type_dict['heavy']}__{cdr}" for cdr in ["H1", "H2", "H3"] ] ) if "light" in chain_type_dict: cdr_chain_set.update( [ f"{chain_type_dict['light']}__{cdr}" for cdr in ["L1", "L2", "L3"] ] ) output_names.append( ( os.path.basename(no_extension_name), output_file, chain_set if len(cdr_chain_set) == 0 else cdr_chain_set, ) ) if pass_set: continue if self.interpolate != "none": data_entry.interpolate_coords(fill_ends=(self.interpolate == "all")) out["pdb_id"] = no_extension_name.split("-")[0] out["mask_original"] = torch.tensor( data_entry.get_mask(chain_set, original=True) ) out["mask"] = torch.tensor(data_entry.get_mask(chain_set, original=False)) out["S"] = torch.tensor(data_entry.get_sequence(chain_set, encode=True)) out["X"] = torch.tensor(data_entry.get_coordinates(chain_set, bb_only=True)) out["residue_idx"] = torch.tensor( data_entry.get_index_array(chain_set, index_bump=100) ) out["chain_encoding_all"] = torch.tensor( data_entry.get_chain_id_array(chain_set) ) out["chain_dict"] = data_entry.get_chain_id_dict(chain_set) if self.load_ligands and len(ligands) != 0: ( out["X_ligands"], out["ligand_smiles"], out["ligand_chains"], ) = data_entry.get_ligand_features(ligands, chain_set) for name in self.feature_types: if name not in self.feature_functions: continue func = self.feature_functions[name] out[name] = torch.tensor(func(data_entry, chain_set)) with open(output_file, "wb") as f: pickle.dump(out, f) return output_names def set_cdr(self, cdr): """Set the CDR to be iterated over (only for SAbDab datasets). Parameters ---------- cdr : list | str | None The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3). Set to `None` to go back to iterating over all chains. """ if not self.sabdab: cdr = None if isinstance(cdr, str): cdr = [cdr] if cdr == self.cdr: return self.cdr = cdr if cdr is None: self.indices = list(range(len(self.data))) else: self.indices = [] print(f"Setting CDR to {cdr}...") for i, data in tqdm(enumerate(self.data)): if self.clusters is not None: if data.split("__")[1] in cdr: self.indices.append(i) else: add = False for chain in self.files[data]: if chain.split("__")[1] in cdr: add = True break if add: self.indices.append(i) def _to_pyg_graph(self, data): """Convert a dictionary of data to a PyTorch Geometric graph.""" from torch_geometric.data import Data pyg_data = Data(x=data["X"]) for key, value in data.items(): pyg_data[key] = value.unsqueeze(0) return pyg_data @staticmethod def get_anchor_ind(masked_res, mask): """Get the indices of the anchor residues. Anchor residues are defined as the first and last known residues before and after each continuous masked region. 
Parameters ---------- masked_res : torch.Tensor A boolean tensor indicating which residues should be predicted mask : torch.Tensor A boolean tensor indicating which residues are known Returns ------- list A list of indices of the anchor residues """ anchor_ind = [] masked_ind = torch.where(masked_res.bool())[0] known_ind = torch.where(mask.bool())[0] for _, g in groupby(enumerate(masked_ind), lambda x: x[0] - x[1]): group = map(itemgetter(1), g) group = list(map(int, group)) start, end = group[0], group[-1] start = ( known_ind[known_ind < start][-1] if (known_ind < start).sum() > 0 else known_ind[0] ) end = ( known_ind[known_ind > end][0] if (known_ind > end).sum() > 0 else known_ind[-1] ) anchor_ind += [start, end] return anchor_ind def _get_antibody_mask(self, data): """Get a mask for the antibody residues.""" mask = torch.zeros_like(data["mask"]).bool() cdrs = data["cdr"] chain_enc = data["chain_encoding_all"] for chain_ind in data["chain_dict"].values(): chain_mask = chain_enc == chain_ind chain_cdrs = cdrs[chain_mask] if len(torch.unique(chain_cdrs)) > 1: mask[chain_mask] = True return mask def _patch(self, data): """Cut the data around the anchor residues.""" # adapted from diffab pos_alpha = data["X"][:, 2] if self.mask_whole_chains: mask_ = (data["mask"] * data["masked_res"]).bool() anchor_points = pos_alpha[mask_].mean(0).unsqueeze(0) anchor_ind = [] else: anchor_ind = self.get_anchor_ind(data["masked_res"], data["mask"]) anchor_points = torch.stack([pos_alpha[ind] for ind in anchor_ind], dim=0) dist_anchor = torch.cdist(pos_alpha, anchor_points, p=2).min(dim=1)[0] # (L, ) dist_anchor[~data["mask"].bool()] = float("+inf") initial_patch_idx = torch.topk( dist_anchor, k=min(self.initial_patch_size, dist_anchor.size(0)), largest=False, sorted=True, )[ 1 ] # (initial_patch_size, ) patch_mask = data["masked_res"].bool().clone() patch_mask[[int(x) for x in anchor_ind]] = True patch_mask[initial_patch_idx] = True if self.sabdab: antibody_mask = self._get_antibody_mask(data) antigen_mask = ~antibody_mask dist_anchor_antigen = dist_anchor.masked_fill( mask=antibody_mask, value=float("+inf") # Fill antibody with +inf ) # (L, ) antigen_patch_idx = torch.topk( dist_anchor_antigen, k=min(self.antigen_patch_size, antigen_mask.sum().item()), largest=False, )[ 1 ] # (ag_size, ) patch_mask[antigen_patch_idx] = True for key, value in data.items(): if isinstance(value, torch.Tensor): data[key] = value[patch_mask] return data def __len__(self): """Return the number of clusters or data entries in the dataset.""" return len(self.indices) def __getitem__(self, idx): """Return an entry from the dataset. If a clusters file is provided, then the idx is the index of the cluster and the chain is randomly selected from the cluster. Otherwise, the idx is the index of the data entry and the chain is randomly selected from the data entry. 
""" chain_id = None cdr = None idx = self.indices[idx] if self.clusters is None: id = self.data[idx] # data is already filtered by length chain_id = random.choice(list(self.files[id].keys())) if self.cdr is not None: while chain_id.split("__")[1] not in self.cdr: chain_id = random.choice(list(self.files[id].keys())) else: cluster = self.data[idx] id = None chain_n = -1 while ( id is None or len(self.files[id][chain_id]) == 0 ): # some IDs can be filtered out by length if self.shuffle_clusters: chain_n = random.randint(0, len(self.clusters[cluster]) - 1) else: chain_n += 1 id, chain_id = self.clusters[cluster][ chain_n ] # get id and chain from cluster file = random.choice(self.files[id][chain_id]) if "__" in chain_id: chain_id, cdr = chain_id.split("__") if self.loaded is None: with open(file, "rb") as f: try: data = pickle.load(f) except EOFError: print("EOFError", file) raise else: data = deepcopy(self.loaded[file]) data["chain_id"] = data["chain_dict"][chain_id] if cdr is not None: data["cdr_id"] = CDR_REVERSE[cdr] if self.mask_residues: data["masked_res"] = self._get_masked_sequence(data) if self.patch_around_mask: data = self._patch(data) if self.pyg_graph: data = self._to_pyg_graph(data) return data
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Static methods
def get_anchor_ind(masked_res, mask)
Get the indices of the anchor residues.
Anchor residues are defined as the first and last known residues before and after each continuous masked region.
Parameters
masked_res : torch.Tensor
- A boolean tensor indicating which residues should be predicted
mask : torch.Tensor
- A boolean tensor indicating which residues are known
Returns
list
- A list of indices of the anchor residues
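The anchor logic is easiest to see on a toy input. A small sketch (the tensors are made up): the masked runs [2, 3] and [7] are flanked by known residues 1 and 4, and 6 and 8, respectively, so the returned anchors are [1, 4, 6, 8].

import torch
from proteinflow.data.torch import ProteinDataset

masked_res = torch.tensor([0, 0, 1, 1, 0, 0, 0, 1, 0]).bool()  # residues to predict
mask = torch.tensor([1, 1, 1, 1, 1, 0, 1, 1, 1]).bool()        # residues with known coordinates
anchors = ProteinDataset.get_anchor_ind(masked_res, mask)
print([int(a) for a in anchors])  # [1, 4, 6, 8]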
Methods
def set_cdr(self, cdr)
Set the CDR to be iterated over (only for SAbDab datasets).
Parameters
cdr : list | str | None
- The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3). Set to `None` to go back to iterating over all chains.
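For example (assuming `dataset` was built from a SAbDab-style proteinflow dataset, as in the sketch above):

dataset.set_cdr("H3")  # iterate only over entries that contain an H3 CDR
print(len(dataset))    # reflects the filtered index set
dataset.set_cdr(None)  # back to iterating over all chains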
class ProteinLoader (dataset, collate_func=proteinflow.data.torch._PadCollate, shuffle_batches=True, *args, **kwargs)
A subclass of `torch.utils.data.DataLoader` tuned for the `proteinflow` dataset.
Creates and iterates over an instance of `ProteinDataset`, omitting the `'chain_dict'` keys. See the `ProteinDataset` documentation for more information.
If batch size is larger than one, all objects are padded with zeros at the ends to reach the length of the longest protein in the batch.
Initialize a ProteinLoader instance.
Parameters
dataset : ProteinDataset
- a ProteinDataset instance
shuffle_batches : bool, default True
- if `True`, the batches are shuffled at each epoch
collate_func : callable, optional
- a function that takes a list of samples and returns a batch; must inherit from `_PadCollate`
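A minimal sketch of direct construction (reusing the `dataset` from the earlier example; `batch_size` is forwarded to `torch.utils.data.DataLoader`):

from proteinflow.data.torch import ProteinLoader

loader = ProteinLoader(dataset, batch_size=8)
batch = next(iter(loader))
print(batch["S"].shape)  # (8, max_L): zero-padded to the longest sequence in the batch
print(batch["pdb_id"])   # list of 8 PDB ids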
Ancestors
- torch.utils.data.dataloader.DataLoader
- typing.Generic
Class variables
var batch_size : Optional[int]
var dataset : torch.utils.data.dataset.Dataset[+T_co]
var drop_last : bool
var num_workers : int
var pin_memory : bool
var pin_memory_device : str
var prefetch_factor : Optional[int]
var sampler : Union[torch.utils.data.sampler.Sampler, Iterable]
var timeout : float
Static methods
def from_args(dataset_folder, features_folder='./data/tmp/', clustering_dict_path=None, max_length=None, rewrite=False, use_fraction=1, load_to_ram=False, debug=False, interpolate='none', node_features_type=None, entry_type='biounit', classes_to_exclude=None, lower_limit=15, upper_limit=100, mask_residues=True, mask_whole_chains=False, mask_frac=None, force_binding_sites_frac=0, shuffle_clusters=True, shuffle_batches=True, mask_all_cdrs=False, classes_dict_path=None, load_ligands=False, cut_edges=False, require_antigen=False, require_light_chain=False, require_no_light_chain=False, require_heavy_chain=False, *args, **kwargs) ‑> ProteinLoader
Create a `ProteinLoader` instance with a `ProteinDataset` from the given arguments.
Parameters
dataset_folder : str
- the path to the folder with proteinflow format input files (assumes that files are named {biounit_id}.pickle)
features_folder : str
- the path to the folder where the ProteinMPNN features will be saved
clustering_dict_path : str, optional
- path to the pickled clustering dictionary (keys are cluster ids, values are (biounit id, chain id) tuples)
max_length : int, optional
- entries with total length of chains larger than `max_length` will be disregarded
rewrite : bool, default False
- if `False`, existing feature files are not overwritten
use_fraction : float, default 1
- the fraction of the clusters to use (first N in alphabetic order)
load_to_ram : bool, default False
- if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
debug : bool, default False
- only process 1000 files
interpolate : {"none", "only_middle", "all"}
- `"none"` for no interpolation, `"only_middle"` for only linear interpolation in the middle, `"all"` for linear interpolation + ends generation
node_features_type : {"dihedral", "sidechain_orientation", "chemical", "secondary_structure", "sidechain_coords", or combinations with "+"}, optional
- the type of node features, e.g. `"dihedral"` or `"sidechain_orientation+chemical"`
entry_type : {"biounit", "chain", "pair"}
- the type of entries to generate (`"biounit"` for biounit-level, `"chain"` for chain-level, `"pair"` for chain-chain pairs)
classes_to_exclude : list of str, optional
- a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
lower_limit : int, default 15
- the minimum number of residues to mask
upper_limit : int, default 100
- the maximum number of residues to mask
mask_residues : bool, default True
- if `True`, generate a mask key
mask_whole_chains : bool, default False
- if `True`, `upper_limit`, `force_binding_sites_frac` and `lower_limit` are ignored and the whole chain is masked instead
mask_frac : float, optional
- if given, the `lower_limit` and `upper_limit` are ignored and the number of residues to mask is `mask_frac` times the length of the chain
force_binding_sites_frac : float, default 0
- if > 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be forced to be in a binding site
shuffle_clusters : bool, default True
- if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
shuffle_batches : bool, default True
- if `True`, the batches are shuffled at each epoch
mask_all_cdrs : bool, default False
- if `True`, all CDRs are masked instead of just the sampled one
classes_dict_path : str, optional
- path to the pickled classes dictionary; if not given, we will try to find split dictionaries in the parent folder of `dataset_folder`
load_ligands : bool, default False
- if `True`, the ligands will be loaded from the PDB files and added to the features
cut_edges : bool, default False
- if `True`, missing values at the edges of the sequence will be cut off
require_antigen : bool, default False
- if `True`, only entries with an antigen will be included (used if the dataset is SAbDab)
require_light_chain : bool, default False
- if `True`, only entries with a light chain will be included (used if the dataset is SAbDab)
require_no_light_chain : bool, default False
- if `True`, only entries without a light chain will be included (used if the dataset is SAbDab)
require_heavy_chain : bool, default False
- if `True`, only entries with a heavy chain will be included (used if the dataset is SAbDab)
*args
- additional arguments to `torch.utils.data.DataLoader`
**kwargs
- additional keyword arguments to `torch.utils.data.DataLoader`
Expand source code
@staticmethod
def from_args(
    dataset_folder,
    features_folder="./data/tmp/",
    clustering_dict_path=None,
    max_length=None,
    rewrite=False,
    use_fraction=1,
    load_to_ram=False,
    debug=False,
    interpolate="none",
    node_features_type=None,
    entry_type="biounit",  # biounit, chain, pair
    classes_to_exclude=None,
    lower_limit=15,
    upper_limit=100,
    mask_residues=True,
    mask_whole_chains=False,
    mask_frac=None,
    force_binding_sites_frac=0,
    shuffle_clusters=True,
    shuffle_batches=True,
    mask_all_cdrs=False,
    classes_dict_path=None,
    load_ligands=False,
    cut_edges=False,
    require_antigen=False,
    require_light_chain=False,
    require_no_light_chain=False,
    require_heavy_chain=False,
    *args,
    **kwargs,
) -> "ProteinLoader":
    """Create a `ProteinLoader` instance with a `ProteinDataset` from the given arguments.

    Parameters
    ----------
    dataset_folder : str
        the path to the folder with proteinflow format input files (assumes that files are named {biounit_id}.pickle)
    features_folder : str
        the path to the folder where the ProteinMPNN features will be saved
    clustering_dict_path : str, optional
        path to the pickled clustering dictionary (keys are cluster ids, values are (biounit id, chain id) tuples)
    max_length : int, optional
        entries with total length of chains larger than `max_length` will be disregarded
    rewrite : bool, default False
        if `False`, existing feature files are not overwritten
    use_fraction : float, default 1
        the fraction of the clusters to use (first N in alphabetic order)
    load_to_ram : bool, default False
        if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
    debug : bool, default False
        only process 1000 files
    interpolate : {"none", "only_middle", "all"}
        `"none"` for no interpolation, `"only_middle"` for only linear interpolation in the middle, `"all"` for linear interpolation + ends generation
    node_features_type : {"dihedral", "sidechain_orientation", "chemical", "secondary_structure", "sidechain_coords", or combinations with "+"}, optional
        the type of node features, e.g. `"dihedral"` or `"sidechain_orientation+chemical"`
    entry_type : {"biounit", "chain", "pair"}
        the type of entries to generate (`"biounit"` for biounit-level, `"chain"` for chain-level, `"pair"` for chain-chain pairs)
    classes_to_exclude : list of str, optional
        a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
    lower_limit : int, default 15
        the minimum number of residues to mask
    upper_limit : int, default 100
        the maximum number of residues to mask
    mask_residues : bool, default True
        if `True`, generate a mask key
    mask_whole_chains : bool, default False
        if `True`, `upper_limit`, `force_binding_sites_frac` and `lower_limit` are ignored and the whole chain is masked instead
    mask_frac : float, optional
        if given, the `lower_limit` and `upper_limit` are ignored and the number of residues to mask is `mask_frac` times the length of the chain
    force_binding_sites_frac : float, default 0
        if > 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be forced to be in a binding site
    shuffle_clusters : bool, default True
        if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
    shuffle_batches : bool, default True
        if `True`, the batches are shuffled at each epoch
    mask_all_cdrs : bool, default False
        if `True`, all CDRs are masked instead of just the sampled one
    classes_dict_path : str, optional
        path to the pickled classes dictionary; if not given, we will try to find split dictionaries in the parent folder of `dataset_folder`
    load_ligands : bool, default False
        if `True`, the ligands will be loaded from the PDB files and added to the features
    cut_edges : bool, default False
        if `True`, missing values at the edges of the sequence will be cut off
    require_antigen : bool, default False
        if `True`, only entries with an antigen will be included (used if the dataset is SAbDab)
    require_light_chain : bool, default False
        if `True`, only entries with a light chain will be included (used if the dataset is SAbDab)
    require_no_light_chain : bool, default False
        if `True`, only entries without a light chain will be included (used if the dataset is SAbDab)
    require_heavy_chain : bool, default False
        if `True`, only entries with a heavy chain will be included (used if the dataset is SAbDab)
    *args
        additional arguments to `torch.utils.data.DataLoader`
    **kwargs
        additional keyword arguments to `torch.utils.data.DataLoader`
    """
    dataset = ProteinDataset(
        dataset_folder=dataset_folder,
        features_folder=features_folder,
        clustering_dict_path=clustering_dict_path,
        max_length=max_length,
        rewrite=rewrite,
        use_fraction=use_fraction,
        load_to_ram=load_to_ram,
        debug=debug,
        interpolate=interpolate,
        node_features_type=node_features_type,
        entry_type=entry_type,
        classes_to_exclude=classes_to_exclude,
        shuffle_clusters=shuffle_clusters,
        classes_dict_path=classes_dict_path,
        lower_limit=lower_limit,
        upper_limit=upper_limit,
        mask_residues=mask_residues,
        mask_whole_chains=mask_whole_chains,
        mask_frac=mask_frac,
        force_binding_sites_frac=force_binding_sites_frac,
        mask_all_cdrs=mask_all_cdrs,
        load_ligands=load_ligands,
        cut_edges=cut_edges,
        require_antigen=require_antigen,
        require_light_chain=require_light_chain,
        require_no_light_chain=require_no_light_chain,
        require_heavy_chain=require_heavy_chain,
    )
    return ProteinLoader(
        dataset=dataset,
        shuffle_batches=shuffle_batches,
        *args,
        **kwargs,
    )