From 1594b02e239d1879d74de49dbeb6c08c913b1880 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Tue, 29 Aug 2023 15:14:29 +0200 Subject: [PATCH] Create RefHasher type When only the Abundance type used refhashing, we could get away with it only being a simple function. However, now that I want marker gene prediction to also use refhashing, the existing implementation did not work, and I would need to have two differing implementations that absolutely needed to give the same result. Instead, make a RefHasher type such that the invariants are stored in one place, and also generalize the error message improvement from commit 37277e2. --- test/test_parsebam.py | 2 +- test/test_vambtools.py | 25 ++++++--- vamb/parsebam.py | 46 +++------------- vamb/parsecontigs.py | 4 +- vamb/vambtools.py | 75 +++++++++++++++++++++++---- workflow_avamb/avamb.snake.conda.smk | 1 - workflow_avamb/src/abundances_mask.py | 4 +- 7 files changed, 96 insertions(+), 61 deletions(-) diff --git a/test/test_parsebam.py b/test/test_parsebam.py index 77ec2ba3..a82b1aa1 100644 --- a/test/test_parsebam.py +++ b/test/test_parsebam.py @@ -33,7 +33,7 @@ def test_refhash(self): # Change the refnames slighty cp.identifiers = cp.identifiers.copy() cp.identifiers[3] = cp.identifiers[3] + "w" - cp.refhash = vamb.vambtools.hash_refnames(cp.identifiers) + cp.refhash = vamb.vambtools.RefHasher.hash_refnames(cp.identifiers) with self.assertRaises(ValueError): vamb.parsebam.Abundance.from_files( testtools.BAM_FILES, None, cp, True, 0.97, 4 diff --git a/test/test_vambtools.py b/test/test_vambtools.py index c2727e6a..90767a6d 100644 --- a/test/test_vambtools.py +++ b/test/test_vambtools.py @@ -340,24 +340,33 @@ def test_torch(self): class TestHashRefNames(unittest.TestCase): def test_refhash(self): names = ["foo", "9", "eleven", "a"] - b1 = vamb.vambtools.hash_refnames(names) + b1 = vamb.vambtools.RefHasher.hash_refnames(names) + + # Test that hashing them all at once is the same as hashing them one at a time + 
hasher = vamb.vambtools.RefHasher() + hasher.add_refname(names[0]) + hasher.add_refname(names[1]) + for j in names[2:]: + hasher.add_refname(j) + b7 = hasher.digest() + names[1] = names[1] + "x" - b2 = vamb.vambtools.hash_refnames(names) + b2 = vamb.vambtools.RefHasher.hash_refnames(names) names[1] = names[1][:-1] + " \t" # it strips whitespace off right end - b3 = vamb.vambtools.hash_refnames(names) + b3 = vamb.vambtools.RefHasher.hash_refnames(names) names = names[::-1] - b4 = vamb.vambtools.hash_refnames(names) + b4 = vamb.vambtools.RefHasher.hash_refnames(names) + names = (i + " " for i in names[::-1]) - b5 = vamb.vambtools.hash_refnames(names) - b6 = vamb.vambtools.hash_refnames(names) # now empty generator - b7 = vamb.vambtools.hash_refnames([]) + b5 = vamb.vambtools.RefHasher.hash_refnames(names) + b6 = vamb.vambtools.RefHasher.hash_refnames(names) # now empty generator self.assertNotEqual(b1, b2) self.assertEqual(b1, b3) self.assertNotEqual(b1, b4) self.assertEqual(b1, b5) self.assertNotEqual(b1, b6) - self.assertEqual(b6, b7) + self.assertEqual(b1, b7) class TestBinSplit(unittest.TestCase): diff --git a/vamb/parsebam.py b/vamb/parsebam.py index 54bf8ede..b2f73812 100644 --- a/vamb/parsebam.py +++ b/vamb/parsebam.py @@ -12,7 +12,6 @@ from vamb.parsecontigs import CompositionMetaData from vamb import vambtools from typing import Optional, TypeVar, Union, IO, Sequence, Iterable -from itertools import zip_longest from pathlib import Path import shutil @@ -52,39 +51,6 @@ def nseqs(self) -> int: def nsamples(self) -> int: return len(self.samplenames) - @staticmethod - def verify_refhash( - refhash: bytes, - target_refhash: bytes, - identifiers: Optional[tuple[Iterable[str], Iterable[str]]], - ) -> None: - if refhash != target_refhash: - if identifiers is not None: - for i, (fasta_id, bam_id) in enumerate(zip_longest(*identifiers)): - if fasta_id is None: - raise ValueError( - f"FASTA has only {i} identifier(s), which is fewer than BAM file" - ) - elif bam_id is 
None: - raise ValueError( - f"BAM has only {i} identifier(s), which is fewer than FASTA file" - ) - elif fasta_id != bam_id: - raise ValueError( - f"Identifier number {i+1} does not match for FASTA and BAM files:" - f'FASTA identifier: "{fasta_id}"' - f'BAM identifier: "{bam_id}"' - ) - assert False - else: - raise ValueError( - f"At least one BAM file reference name hash to {refhash.hex()}, " - f"expected {target_refhash.hex()}. " - "Make sure all BAM and FASTA identifiers are identical " - "and in the same order. " - "Note that the identifier is the header before any whitespace." - ) - def save(self, io: Union[Path, IO[bytes]]): _np.savez_compressed( io, @@ -106,7 +72,9 @@ def load( arrs["refhash"].item(), ) if refhash is not None: - cls.verify_refhash(abundance.refhash, refhash, None) + vambtools.RefHasher.verify_refhash( + abundance.refhash, refhash, "Loaded", None, None + ) return abundance @@ -250,14 +218,16 @@ def run_pycoverm( headers = [h for (h, m) in zip(headers, mask) if m] vambtools.numpy_inplace_maskarray(coverage, mask) - refhash = vambtools.hash_refnames(headers) + refhash = vambtools.RefHasher.hash_refnames(headers) if target_identifiers is None: identifier_pairs = None else: - identifier_pairs = (target_identifiers, headers) + identifier_pairs = (headers, target_identifiers) if target_refhash is not None: - Abundance.verify_refhash(refhash, target_refhash, identifier_pairs) + vambtools.RefHasher.verify_refhash( + refhash, target_refhash, "Composition", "BAM", identifier_pairs + ) return (coverage, refhash) diff --git a/vamb/parsecontigs.py b/vamb/parsecontigs.py index 92b4b3a3..42e8b023 100644 --- a/vamb/parsecontigs.py +++ b/vamb/parsecontigs.py @@ -55,7 +55,7 @@ def __init__( self.lengths = lengths self.mask = mask self.minlength = minlength - self.refhash = _vambtools.hash_refnames(identifiers) + self.refhash = _vambtools.RefHasher.hash_refnames(identifiers) @property def nseqs(self) -> int: @@ -73,7 +73,7 @@ def filter_mask(self, mask: 
Sequence[bool]): self.identifiers = self.identifiers[mask] self.lengths = self.lengths[mask] - self.refhash = _vambtools.hash_refnames(self.identifiers) + self.refhash = _vambtools.RefHasher.hash_refnames(self.identifiers) def filter_min_length(self, length: int): "Set or reset minlength of this object" diff --git a/vamb/vambtools.py b/vamb/vambtools.py index 93aee3e4..3d11f4e1 100644 --- a/vamb/vambtools.py +++ b/vamb/vambtools.py @@ -9,6 +9,7 @@ import re as _re from vamb._vambtools import _kmercounts, _overwrite_matrix import collections as _collections +from itertools import zip_longest from hashlib import md5 as _md5 from collections.abc import Iterable, Iterator, Generator from typing import Optional, IO, Union @@ -329,6 +330,71 @@ def byte_iterfasta( yield FastaEntry(header, bytearray().join(buffer)) +class RefHasher: + __slots__ = ["hasher"] + + def __init__(self): + self.hasher = _md5() + + def add_refname(self, ref: str) -> None: + self.hasher.update(ref.encode().rstrip()) + + def add_refnames(self, refs: Iterable[str]): + for ref in refs: + self.add_refname(ref) + return self + + @classmethod + def hash_refnames(cls, refs: Iterable[str]) -> bytes: + return cls().add_refnames(refs).digest() + + def digest(self) -> bytes: + return self.hasher.digest() + + @staticmethod + def verify_refhash( + refhash: bytes, + target_refhash: bytes, + observed_name: Optional[str], + target_name: Optional[str], + identifiers: Optional[tuple[Iterable[str], Iterable[str]]], + ) -> None: + if refhash == target_refhash: + return None + + obs_name = "Observed" if observed_name is None else observed_name + tgt_name = "Target" if target_name is None else target_name + if identifiers is not None: + (observed_ids, target_ids) = identifiers + for i, (observed_id, target_id) in enumerate( + zip_longest(observed_ids, target_ids) + ): + if observed_id is None: + raise ValueError( + f"{obs_name} identifiers has only {i} identifier(s), which is fewer than {tgt_name}" + ) + elif target_id 
is None: + raise ValueError( + f"{tgt_name} identifiers has only {i} identifier(s), which is ffewer than {obs_name}" + ) + elif observed_id != target_id: + raise ValueError( + f"Identifier number {i+1} does not match between {obs_name} and {tgt_name}:" + f'{obs_name}: "{observed_id}"' + f'{tgt_name}: "{target_id}"' + ) + assert False + else: + raise ValueError( + f"Mismatch between reference hash of {obs_name} and {tgt_name}." + f"Observed {obs_name} hash: {refhash.hex()}." + f"Expected {tgt_name} hash: {target_refhash.hex()}" + "Make sure all identifiers are identical " + "and in the same order. " + "Note that the identifier is the header before any whitespace." + ) + + def write_clusters( filehandle: IO[str], clusters: Iterable[tuple[str, set[str]]], @@ -589,15 +655,6 @@ def concatenate_fasta( raise e from None -def hash_refnames(refnames: Iterable[str]) -> bytes: - "Hashes an iterable of strings of reference names using MD5." - hasher = _md5() - for refname in refnames: - hasher.update(refname.encode().rstrip()) - - return hasher.digest() - - def _split_bin( binname: str, headers: Iterable[str], diff --git a/workflow_avamb/avamb.snake.conda.smk b/workflow_avamb/avamb.snake.conda.smk index ddcbc00a..99d0b38c 100644 --- a/workflow_avamb/avamb.snake.conda.smk +++ b/workflow_avamb/avamb.snake.conda.smk @@ -1,7 +1,6 @@ import re import os import sys -from vamb.vambtools import concatenate_fasta, hash_refnames import numpy as np SNAKEDIR = os.path.dirname(workflow.snakefile) diff --git a/workflow_avamb/src/abundances_mask.py b/workflow_avamb/src/abundances_mask.py index f1995089..85c7cc7f 100644 --- a/workflow_avamb/src/abundances_mask.py +++ b/workflow_avamb/src/abundances_mask.py @@ -1,6 +1,6 @@ import numpy as np import argparse -from vamb.vambtools import hash_refnames +from vamb.vambtools import RefHasher from pathlib import Path @@ -24,7 +24,7 @@ def abundances_mask(headers: Path, mask_refhash: Path, min_contig_size: int): np.savez_compressed( mask_refhash, 
mask=np.array(mask, dtype=bool), - refhash=hash_refnames(identifiers), + refhash=RefHasher.hash_refnames(identifiers), )