Skip to content

Commit

Permalink
Merge branch 'fingerprinter' into 'stable'
Browse files Browse the repository at this point in the history
Buffered parallelization for Fingerprinter subtool

See merge request bioinf-mit/ms/sirius_frontend!17
  • Loading branch information
mfleisch committed Jun 12, 2023
2 parents ddae266 + 93db8ae commit d6860ce
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ public class FingerprinterOptions implements StandaloneTool<FingerprinterWorkflo
@CommandLine.Option(names = {"--version", "-v"}, description = "Specify file to write fingerprint version information to", required = false)
private Path version;

@CommandLine.Option(names = {"--bufferSize", "-b"}, description = "Specify buffer size for memory usage", required = false)
private int bufferSize;

/**
 * Creates the fingerprinter workflow from the parsed command line options.
 * <p>
 * The configured {@code bufferSize} is forwarded to the workflow; the stale
 * four-argument constructor call that preceded it has been removed because it
 * made the actual return statement unreachable.
 *
 * @param rootOptions root options that provide the input files
 * @param config      parameter configuration (not used by this tool)
 * @return a new {@link FingerprinterWorkflow} configured with output path, charge,
 * version file and buffer size
 */
@Override
public FingerprinterWorkflow makeWorkflow(RootOptions<?, ?, ?, ?> rootOptions, ParameterConfig config) {
    return new FingerprinterWorkflow(rootOptions, outputPath, charge, version, bufferSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import de.unijena.bioinf.fingerid.fingerprints.FixedFingerprinter;
import de.unijena.bioinf.fingerid.predictor_types.PredictorType;
import de.unijena.bioinf.jjobs.BasicJJob;
import de.unijena.bioinf.jjobs.JJob;
import de.unijena.bioinf.ms.frontend.core.ApplicationCore;
import de.unijena.bioinf.ms.frontend.subtools.RootOptions;
import de.unijena.bioinf.ms.frontend.workflow.Workflow;
Expand All @@ -40,96 +39,122 @@
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
public class FingerprinterWorkflow implements Workflow {

private final Path outputFile;
private final RootOptions<?, ?, ?, ?> rootOptions;

private final Path versionFile;
private final int charge;
private final BlockingSet<BasicJJob<Void>> jobs;
private FingerIdData fdata;
private MaskedFingerprintVersion mask;
private CdkFingerprintVersion cdkVersion;
private final Map<String, Exception> failedComputations = new ConcurrentHashMap<>();

/**
 * Creates a workflow that computes fingerprints for the structures of the input file
 * and writes them to {@code outputFile}.
 *
 * @param rootOptions root options that provide the input files
 * @param outputFile  file the computed fingerprints are written to
 * @param charge      a value {@code > 0} selects the positive predictor data, any other
 *                    value the negative predictor data
 * @param versionFile optional file for the fingerprint definition; may be {@code null}
 * @param bufferSize  maximum number of worker jobs that may be buffered at the same time;
 *                    {@code 0} (option not given) or a negative value selects the default
 *                    of five jobs per CPU thread
 */
public FingerprinterWorkflow(RootOptions<?, ?, ?, ?> rootOptions, Path outputFile, int charge, Path versionFile, int bufferSize) {
    this.outputFile = outputFile;
    this.rootOptions = rootOptions;
    this.charge = charge;
    this.versionFile = versionFile;

    if (bufferSize < 0)
        LoggerFactory.getLogger(getClass()).warn("Ignoring negative buffer size '" + bufferSize + "'. Using default instead.");

    // A non-positive capacity would make BlockingSet.add() wait forever,
    // so everything that is not a valid capacity falls back to the default.
    int bufferSizeInit = bufferSize > 0 ? bufferSize : Math.max(1, 5 * SiriusJobs.getCPUThreads());
    this.jobs = new BlockingSet<>(bufferSizeInit);
}

@Override
public void run() {
//todo maybe fixed size buffer for memory optimization

List<Path> in = rootOptions.getInput().getAllFiles();
if (in.isEmpty())
throw new IllegalArgumentException("No input file given!");

Path inputFile = in.iterator().next();
// get WEB API
WebAPI<?> api = ApplicationCore.WEB_API;

try {
//get FingerIdDate File based on charge
FingerIdData fdata = charge > 0 ? api.getFingerIdData(PredictorType.CSI_FINGERID_POSITIVE) : api.getFingerIdData(PredictorType.CSI_FINGERID_NEGATIVE);
MaskedFingerprintVersion mask = fdata.getFingerprintVersion();
CdkFingerprintVersion cdkVersion = api.getCDKChemDBFingerprintVersion();

LoggerFactory.getLogger(getClass()).info("Reading input from '" + inputFile.toString() + "'...");
List<String> smilesList = readInput(inputFile);

List<BasicJJob<SmilesFpt>> jobs = new ArrayList<>();

LoggerFactory.getLogger(getClass()).info("Creating fingerprint jobs for '" +smilesList.size() + "' input structures.");
for (String smiles : smilesList) {
BasicJJob<SmilesFpt> fpt_job = new BasicJJob<>() {
@Override
protected SmilesFpt compute() {
FixedFingerprinter printer = new FixedFingerprinter(cdkVersion);
return new SmilesFpt(smiles, mask.mask(printer.computeFingerprintFromSMILES(smiles).toIndizesArray()));
}
};
jobs.add(fpt_job);
loadFingerprintVersionData();

try (BufferedWriter bw = Files.newBufferedWriter(outputFile)) {

// Creating producer job that reads smiles, build worker jobs adds them to a fixed blocking set
BasicJJob<Void> producer = buildProducer(inputFile, bw);
SiriusJobs.getGlobalJobManager().submitJob(producer);
producer.awaitResult();
LoggerFactory.getLogger(getClass()).info("DONE!");

} catch (IOException | ExecutionException e) {
LoggerFactory.getLogger(getClass()).error("Unexpected error during fingerprint computation", e);
}

if (!failedComputations.isEmpty()) {
LoggerFactory.getLogger(getClass()).info("Following smiles could not be computed:");
for (String smiles : failedComputations.keySet()) {
LoggerFactory.getLogger(getClass()).info(smiles + ": " + failedComputations.get(smiles).toString());
}
}

LoggerFactory.getLogger(getClass()).info("Computing fingerprints...");
jobs = SiriusJobs.getGlobalJobManager().submitJobsInBatchesByThreads(jobs, SiriusJobs.getCPUThreads());
//collect jobs skipping failed ones (null)
List<SmilesFpt> outList = jobs.stream().map(JJob::getResult).filter(Objects::nonNull).toList();

LoggerFactory.getLogger(getClass()).info("Writing fingerprints to '" + outputFile.toString() + "'...");
writeOutput(outputFile, outList);
if (versionFile != null) {
LoggerFactory.getLogger(getClass()).info("Writing fingerprint definition file to '" + versionFile.toString() + "'...");
try (BufferedWriter bw = Files.newBufferedWriter(versionFile)) {
FingerIdData.write(bw, fdata);
}
if (versionFile != null) {
LoggerFactory.getLogger(getClass()).info("Writing fingerprint definition file to '" + versionFile.toString() + "'...");
try (BufferedWriter bw1 = Files.newBufferedWriter(versionFile)) {
FingerIdData.write(bw1, fdata);
} catch (IOException e) {
LoggerFactory.getLogger(getClass()).error(String.valueOf(e));
}
LoggerFactory.getLogger(getClass()).info("DONE!");
}
}

/**
 * Loads the fingerprint data, the masked fingerprint version and the CDK fingerprint
 * version from the web API and stores them in the corresponding fields.
 * <p>
 * The fields are only assigned after all requests succeeded. If the API access fails
 * with an {@link IOException}, the error is logged once and all three fields keep
 * their previous values ({@code null} if nothing was loaded before).
 */
public void loadFingerprintVersionData() {
    // get WEB API
    WebAPI<?> api = ApplicationCore.WEB_API;
    try {
        // get FingerIdData based on charge
        FingerIdData loadedData = charge > 0 ? api.getFingerIdData(PredictorType.CSI_FINGERID_POSITIVE) : api.getFingerIdData(PredictorType.CSI_FINGERID_NEGATIVE);
        MaskedFingerprintVersion loadedMask = loadedData.getFingerprintVersion();
        CdkFingerprintVersion loadedCdkVersion = api.getCDKChemDBFingerprintVersion();

        // publish all-or-nothing so that a failure never leaves a half initialized state
        fdata = loadedData;
        mask = loadedMask;
        cdkVersion = loadedCdkVersion;
    } catch (IOException e) {
        LoggerFactory.getLogger(getClass()).error("Unexpected error during api access", e);
    }
}

public List<String> readInput(Path in) throws IOException {
List<String> smiles = new ArrayList<>();
try (BufferedReader reader = Files.newBufferedReader(in)) {
String line;
while ((line = reader.readLine()) != null) {
if (line.length() > 0) smiles.add(line);
public BasicJJob<Void> buildProducer(Path inputFile, BufferedWriter bw) {
return new BasicJJob<>() {
@Override
protected Void compute() {
try (BufferedReader br = Files.newBufferedReader(inputFile)) {
String smiles;
while ((smiles = br.readLine()) != null) {
if (smiles.length() > 0) {
BasicJJob<Void> worker = buildWorker(smiles, bw);
jobs.add(worker);
SiriusJobs.getGlobalJobManager().submitJob(worker);
}
}
jobs.waitForEmpty();
} catch (IOException | InterruptedException e) {
throw new RuntimeException(e);
}
return null;
}
return smiles;
}
};
}

public void writeOutput(Path outputFile, List<SmilesFpt> outList) throws IOException {
try (BufferedWriter bw = Files.newBufferedWriter(outputFile)) {
for (SmilesFpt obj : outList) {
bw.write(obj.smiles + "\t" + obj.fpt.toCommaSeparatedString() + System.lineSeparator());
public BasicJJob<Void> buildWorker(String smiles, BufferedWriter bw) {
return new BasicJJob<>() {
@Override
protected Void compute() {
FixedFingerprinter printer = new FixedFingerprinter(cdkVersion);
try {
// computing fingerprint
SmilesFpt smilesFpt = new SmilesFpt(smiles, mask.mask(printer.computeFingerprintFromSMILES(smiles).toIndizesArray()));
bw.write(smilesFpt.smiles + "\t" + smilesFpt.fpt.toCommaSeparatedString() + System.lineSeparator());
} catch (IOException | RuntimeException e) {
// if an error occurs, skip this smiles
failedComputations.put(smiles, e);
}
jobs.remove(this);
return null;
}
}

};
}
}

Expand All @@ -142,3 +167,36 @@ protected SmilesFpt(String smiles, Fingerprint fpt) {
this.smiles = smiles;
}
}

/**
 * A set with a fixed capacity whose {@link #add(Object)} blocks while the set is full.
 * It is used as a buffer that limits how many elements are pending at the same time.
 * <p>
 * All methods synchronize on the instance. Threads may wait on the same monitor for two
 * different conditions (free capacity in {@link #add(Object)}, empty set in
 * {@link #waitForEmpty()}), therefore a successful removal wakes up all waiting threads
 * and each of them re-checks its own condition.
 *
 * @param <E> element type
 */
class BlockingSet<E> {
    private final Set<E> set = new HashSet<>();
    private final int bufferSize;

    /**
     * @param bufferSize maximum number of elements the set may hold, must be positive
     * @throws IllegalArgumentException if {@code bufferSize} is zero or negative, because
     *                                  {@link #add(Object)} could never succeed then
     */
    protected BlockingSet(int bufferSize) {
        if (bufferSize <= 0)
            throw new IllegalArgumentException("bufferSize must be positive but was " + bufferSize);
        this.bufferSize = bufferSize;
    }

    /**
     * Adds the element, waiting as long as the set holds {@code bufferSize} or more elements.
     *
     * @param element element to add
     * @throws InterruptedException if the thread is interrupted while waiting for free capacity
     */
    protected synchronized void add(E element) throws InterruptedException {
        while (set.size() >= bufferSize) {
            wait();
        }
        set.add(element);
    }

    /**
     * Removes the element and, if it was contained, wakes up all waiting threads.
     *
     * @param element element to remove
     * @return {@code true} if the element was contained in the set
     */
    protected synchronized boolean remove(E element) {
        boolean success = set.remove(element);
        if (success) {
            // notify() would wake up only one arbitrary thread. If that thread waits for
            // the other condition, the thread that could proceed would sleep forever.
            notifyAll();
        }
        return success;
    }

    /**
     * @return current number of elements in the set
     */
    protected synchronized int size() {
        return set.size();
    }

    /**
     * Blocks until the set is empty. Returns immediately if it is already empty.
     *
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    protected synchronized void waitForEmpty() throws InterruptedException {
        while (!set.isEmpty()) {
            wait();
        }
    }
}

0 comments on commit d6860ce

Please sign in to comment.