diff --git a/src/main/scala/com/fulcrumgenomics/fastq/DemuxFastqs.scala b/src/main/scala/com/fulcrumgenomics/fastq/DemuxFastqs.scala index cb5d936f3..7733d69ac 100644 --- a/src/main/scala/com/fulcrumgenomics/fastq/DemuxFastqs.scala +++ b/src/main/scala/com/fulcrumgenomics/fastq/DemuxFastqs.scala @@ -29,6 +29,7 @@ import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.bam.api.{SamOrder, SamRecord, SamWriter} import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool} import com.fulcrumgenomics.commons.CommonsDef.{DirPath, FilePath, PathPrefix, PathToFastq} +import com.fulcrumgenomics.commons.collection.LeastRecentlyUsedCache import com.fulcrumgenomics.commons.io.PathUtil import com.fulcrumgenomics.commons.util.{LazyLogging, Logger} import com.fulcrumgenomics.fastq.FastqDemultiplexer.{DemuxRecord, DemuxResult} @@ -45,6 +46,7 @@ import htsjdk.samtools.util.{Iso8601Date, SequenceUtil} import java.io.Closeable import java.util.concurrent.ForkJoinPool +import scala.collection.immutable.ArraySeq import scala.collection.mutable.ListBuffer object DemuxFastqs { @@ -379,6 +381,7 @@ class DemuxFastqs @arg(doc="Do not keep reads identified as control if true, otherwise keep all reads. Control reads are determined from the comment in the FASTQ header.") val omitControlReads: Boolean = false, @arg(doc="Mask bases with a quality score below the specified threshold as Ns") val maskBasesBelowQuality: Int = 0, + @arg(doc="The number of barcodes to cache; zero will disable the cache.") val cacheSize: Int = 1000000 ) extends FgBioTool with LazyLogging { // Support the deprecated --illumina-standards option @@ -468,7 +471,8 @@ class DemuxFastqs includeOriginal = this.includeAllBasesInFastqs, fastqStandards = this.fastqStandards, omitFailingReads = this.omitFailingReads, - omitControlReads = this.omitControlReads + omitControlReads = this.omitControlReads, + cacheSize = this.cacheSize ) val progress = ProgressLogger(this.logger, unit=1e6.toInt) @@ -763,6 +767,7 @@ private[fastq] object FastqDemultiplexer { * and skipped bases. * @param omitFailingReads true if to remove reads that don't pass QC, marked as 'N' in the header comment * @param omitControlReads false if to keep reads that are marked as internal control reads in the header comment. + * @param cacheSize the number of barcodes to cache; zero will disable the cache. */ private class FastqDemultiplexer(val sampleInfos: Seq[SampleInfo], readStructures: Seq[ReadStructure], @@ -773,7 +778,8 @@ private class FastqDemultiplexer(val sampleInfos: Seq[SampleInfo], val maxNoCalls: Int = 2, val includeOriginal: Boolean = false, val omitFailingReads: Boolean = false, - val omitControlReads: Boolean = false) { + val omitControlReads: Boolean = false, + val cacheSize: Int = 0) extends LazyLogging { import FastqDemultiplexer._ require(readStructures.nonEmpty, "No read structures were given") @@ -787,6 +793,11 @@ private class FastqDemultiplexer(val sampleInfos: Seq[SampleInfo], private val sampleInfosNoUnmatched = sampleInfos.filterNot(_.isUnmatched) private val unmatchedSample = sampleInfos.find(_.isUnmatched).getOrElse(throw new IllegalArgumentException("No unmatched sample provided.")) + private val cache = if (this.cacheSize == 0) None else { + logger.info(f"Using cache of size: $cacheSize") + Some(new LeastRecentlyUsedCache[ArraySeq[Byte], (SampleInfo, Int)](maxEntries=cacheSize)) + } + /** The number of reads that are expected to be given to the [[demultiplex()]] method. */ def expectedNumberOfReads: Int = this.variableReadStructures.length @@ -802,7 +813,20 @@ private class FastqDemultiplexer(val sampleInfos: Seq[SampleInfo], * found, the unmatched sample and [[Int.MaxValue]] are returned. */ private def matchSampleBarcode(subReads: Seq[SubRead]): (SampleInfo, Int) = { val observedBarcode = subReads.filter(_.kind == SegmentType.SampleBarcode).map(_.bases).mkString.getBytes - val numNoCalls = observedBarcode.count(base => SequenceUtil.isNoCall(base)) + cache match { + case None => matchSampleBarcode(observedBarcode=observedBarcode) + case Some(_cache) => + val key = new ArraySeq.ofByte(observedBarcode) + _cache.get(key).getOrElse { + val result = matchSampleBarcode(observedBarcode=observedBarcode) + _cache.put(key, result) + result + } + } + } + + private def matchSampleBarcode(observedBarcode: Array[Byte]): (SampleInfo, Int) = { + val numNoCalls = observedBarcode.count(base => SequenceUtil.isNoCall(base)) // Get the best and second best sample barcode matches. val (bestSampleInfo, bestMismatches, secondBestMismatches) = if (numNoCalls <= maxNoCalls) {