Spaces:
Running
Running
| """ | |
| Classification tasks take in biological sequences and their functional labels. | |
| Multi-class and/or multi-label classification tasks are supported. | |
| """ | |
| import logging | |
| from collections import defaultdict | |
| import datasets | |
| import numpy as np | |
| from dgeb.eval_utils import merge_split_elem_embeds | |
| from dgeb.evaluators import ( | |
| MultiClassMultiOutputKNNClassificationEvaluator, | |
| logRegClassificationEvaluator, | |
| ) | |
| from dgeb.modality import Modality | |
| from dgeb.models import BioSeqTransformer | |
| from dgeb.tasks import Dataset, Task, TaskMetadata, TaskResult | |
| logger = logging.getLogger(__name__) | |
def split_sequences(
    ds: datasets.DatasetDict, max_seq_length: int
) -> datasets.DatasetDict:
    """Chunk every "Sequence" value into pieces of at most ``max_seq_length``.

    Each input row is expanded into one output row per chunk, and the value of
    every other column is copied onto each of the resulting rows. The expansion
    is applied through ``ds.map`` in batched mode with a batch size of one.

    Args:
        ds: Dataset(s) whose rows carry a "Sequence" column.
        max_seq_length: Maximum length of each emitted chunk.

    Returns:
        The result of ``ds.map`` with the chunking function applied.
    """

    def _chunk_row(batch, max_seq_length):
        # Batched map passes each column as a list; with batch_size=1 every
        # list is expected to hold exactly one value.
        assert (
            len(batch["Sequence"]) == 1
        ), "split map function should use batch size of 1."
        row = {column: values[0] for column, values in batch.items()}
        sequence = row.pop("Sequence")

        # Consecutive, non-overlapping windows; the last one may be shorter.
        window_starts = range(0, len(sequence), max_seq_length)
        chunks = [
            sequence[start : start + max_seq_length] for start in window_starts
        ]

        # Copy the remaining columns once per chunk so all columns stay aligned.
        expanded = {column: [value] * len(chunks) for column, value in row.items()}
        expanded["Sequence"] = chunks
        return expanded

    map_options = {
        "batched": True,
        "batch_size": 1,
        "fn_kwargs": {"max_seq_length": max_seq_length},
        "keep_in_memory": True,
        "load_from_cache_file": False,
    }
    return ds.map(_chunk_row, **map_options)
def run_classification_task(
    model: BioSeqTransformer, metadata: TaskMetadata
) -> TaskResult:
    """Evaluate on classification tasks using logistic regression classifier.

    Loads the first dataset listed in ``metadata``, embeds the "Sequence"
    column of its "train" and "test" splits with ``model.encode``, and runs a
    ``logRegClassificationEvaluator`` once per entry of ``model.layers``.

    Args:
        model: Model providing ``encode``, ``layers`` and ``metadata``.
        metadata: Task metadata; ``datasets[0]`` is loaded and
            ``display_name`` is used in the log message.

    Returns:
        The ``TaskResult`` built by ``TaskResult.from_dict`` from the
        per-layer evaluator outputs.
    """
    ds = metadata.datasets[0].load()
    train_split, test_split = ds["train"], ds["test"]

    layer_results = defaultdict(dict)
    train_embeds = model.encode(train_split["Sequence"])
    test_embeds = model.encode(test_split["Sequence"])

    # The label columns are the same for every layer, so read them once here
    # instead of re-reading them from the dataset on each loop iteration.
    train_labels = train_split["Label"]
    test_labels = test_split["Label"]

    for i, layer in enumerate(model.layers):
        # Index i on axis 1 picks the embeddings paired with the i-th entry of
        # model.layers (assumes encode() orders that axis like model.layers).
        layer_results["layers"][layer] = logRegClassificationEvaluator(
            train_embeds[:, i],
            train_labels,
            test_embeds[:, i],
            test_labels,
        )()
        logger.info(
            f"Layer: {layer}, {metadata.display_name} results: {layer_results['layers'][layer]}"
        )
    return TaskResult.from_dict(metadata, layer_results, model.metadata)
class EnzymeCommissionClassification(Task):
    """Enzyme Commission number classification task on protein sequences."""

    # Declarative task description; `run` passes it to the shared runner.
    metadata = TaskMetadata(
        id="ec_classification",
        display_name="EC Classification",
        description="Evaluate on Enzyme Commission number classification task.",
        type="classification",
        modality=Modality.PROTEIN,
        datasets=[
            Dataset(
                path="tattabio/ec_classification",
                revision="ead5570168e6969a5149f6861e8a33d6b5d22498",
            )
        ],
        primary_metric_id="f1",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` by delegating to ``run_classification_task``."""
        return run_classification_task(model, self.metadata)
class EnzymeCommissionDNAClassification(Task):
    """Enzyme Commission number classification task on DNA sequences."""

    # Declarative task description; `run` passes it to the shared runner.
    metadata = TaskMetadata(
        id="ec_dna_classification",
        display_name="EC Classification",
        description="Evaluate on Enzyme Commission number classification task using DNA sequences.",
        type="classification",
        modality=Modality.DNA,
        datasets=[
            Dataset(
                path="tattabio/ec_classification_dna",
                revision="cd61c74b4930cf9f1963e6d73ff7f14e2c8e74dd",
            )
        ],
        primary_metric_id="f1",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` by delegating to ``run_classification_task``."""
        return run_classification_task(model, self.metadata)
class ConvergentEnzymesClassification(Task):
    """Convergent enzymes classification task on protein sequences."""

    # Declarative task description; `run` passes it to the shared runner.
    metadata = TaskMetadata(
        id="convergent_enzymes_classification",
        display_name="Convergent Enzymes Classification",
        description="Evaluate on convergent enzymes classification task, where convergent enzymes are proteins with the same EC number but without blastp hits against each other",
        type="classification",
        modality=Modality.PROTEIN,
        datasets=[
            Dataset(
                path="tattabio/convergent_enzymes",
                revision="37f75609f54de2bc0911ccb72faf1c2f5a4285aa",
            )
        ],
        primary_metric_id="f1",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` by delegating to ``run_classification_task``."""
        return run_classification_task(model, self.metadata)
def run_mibig_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult:
    """
    Run a MIBiG classification task with a multi-class, multi-output KNN evaluator.

    The first dataset in ``metadata`` is loaded. For the DNA modality its
    sequences are first chunked to ``model.max_seq_length`` via
    ``split_sequences``. Embeddings that share an "Entry" ID are then merged
    with ``merge_split_elem_embeds`` before each layer is evaluated.
    """
    ds = metadata.datasets[0].load()
    if metadata.modality == Modality.DNA:
        # MIBiG DNA sequences can be very long. Rather than truncating them to
        # max_seq_length, chunk them here and merge the chunk embeddings below.
        ds = split_sequences(ds, model.max_seq_length)

    def _embed_and_merge(split):
        """Embed one split and return (merged embeddings, per-entry labels)."""
        embeds = model.encode(split["Sequence"])
        entry_ids = split["Entry"]
        label_by_entry = dict(zip(entry_ids, split["class"]))
        # Merge (mean pool, per the helper's intent) embeds with the same ID.
        merged_ids, merged_embeds = merge_split_elem_embeds(entry_ids, embeds)
        # Re-gather labels in the order of the merged IDs.
        labels = np.array([label_by_entry[entry_id] for entry_id in merged_ids])
        return merged_embeds, labels

    layer_results = defaultdict(dict)
    train_embeds, train_labels = _embed_and_merge(ds["train"])
    test_embeds, test_labels = _embed_and_merge(ds["test"])

    for i, layer in enumerate(model.layers):
        knn_eval = MultiClassMultiOutputKNNClassificationEvaluator(
            train_embeds[:, i], train_labels, test_embeds[:, i], test_labels
        )
        layer_results["layers"][layer] = knn_eval()
        logger.info(
            f"Layer: {layer}, MIBiG classification results: {layer_results['layers'][layer]}"
        )
    return TaskResult.from_dict(metadata, layer_results, model.metadata)
class MIBiGProteinClassification(Task):
    """Biosynthetic gene cluster classification on MIBiG protein sequences."""

    # Declarative task description; `run` passes it to the shared MIBiG runner.
    metadata = TaskMetadata(
        id="MIBIG_protein_classification",
        display_name="MIBiG Classification",
        description="Biosynthetic Gene cluster classification using protein sequences on MIBIG dataset.",
        type="classification",
        modality=Modality.PROTEIN,
        datasets=[
            Dataset(
                path="tattabio/mibig_classification_prot",
                revision="915a7ff28dc9820e35c4d7fd03d4c8c44a88ff1f",
            )
        ],
        primary_metric_id="f1",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` by delegating to ``run_mibig_task``."""
        return run_mibig_task(model, self.metadata)
class MIBiGDNAClassification(Task):
    """Biosynthetic gene cluster classification on MIBiG DNA sequences."""

    # Declarative task description; `run` passes it to the shared MIBiG runner.
    metadata = TaskMetadata(
        id="MIBIG_dna_classification",
        display_name="MIBiG Classification",
        description="Biosynthetic Gene cluster classification using DNA sequences on MIBIG dataset.",
        type="classification",
        modality=Modality.DNA,
        datasets=[
            Dataset(
                path="tattabio/mibig_classification_dna",
                revision="b5ca7a76d469e4e66c46f1b655903972571e6b61",
            )
        ],
        primary_metric_id="f1",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` by delegating to ``run_mibig_task``."""
        return run_mibig_task(model, self.metadata)