Skip to content

Commit

Permalink
Merge pull request #359 from steineggerlab/multimer
Browse files Browse the repository at this point in the history
Foldseek MultimerSearch: implement  --monomer-include-mode
  • Loading branch information
Woosub-Kim authored Sep 27, 2024
2 parents 232a4c4 + 079a5a1 commit 19c8820
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 25 deletions.
6 changes: 3 additions & 3 deletions src/commons/LocalParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LocalParameters::LocalParameters() :
PARAM_N_SAMPLE(PARAM_N_SAMPLE_ID, "--n-sample", "Sample size","pick N random sample" ,typeid(int), (void *) &nsample, "^[0-9]{1}[0-9]*$"),
PARAM_COORD_STORE_MODE(PARAM_COORD_STORE_MODE_ID, "--coord-store-mode", "Coord store mode", "Coordinate storage mode: \n1: C-alpha as float\n2: C-alpha as difference (uint16_t)", typeid(int), (void *) &coordStoreMode, "^[1-2]{1}$",MMseqsParameter::COMMAND_EXPERT),
PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD_ID, "--min-assigned-chains-ratio", "Minimum assigned chains percentage Threshold", "Minimum ratio of assigned chains out of all query chains > thr [0.0,1.0]", typeid(float), (void *) & minAssignedChainsThreshold, "^[0-9]*(\\.[0-9]+)?$", MMseqsParameter::COMMAND_ALIGN),
PARAM_SINGLE_CHAIN_INCLUDE_MODE(PARAM_SINGLE_CHAIN_INCLUDE_MODE_ID, "--single-chain-include-mode", "Single Chained Assignments Inclusion Mode for Multimer", "Single Chained Assignments Inclusion 0: include single chained assignments, 1: NOT include single chained assignment", typeid(int), (void *) & singleChainIncludeMode, "^[0-1]{1}$", MMseqsParameter::COMMAND_ALIGN),
PARAM_MONOMER_INCLUDE_MODE(PARAM_MONOMER_INCLUDE_MODE_ID, "--monomer-include-mode", "Monomer inclusion Mode for MultimerSerch", "Monomer Complex Inclusion 0: include monomers, 1: NOT include monomers", typeid(int), (void *) & monomerIncludeMode, "^[0-1]{1}$", MMseqsParameter::COMMAND_ALIGN),
PARAM_CLUSTER_SEARCH(PARAM_CLUSTER_SEARCH_ID, "--cluster-search", "Cluster search", "first find representative then align all cluster members", typeid(int), (void *) &clusterSearch, "^[0-1]{1}$",MMseqsParameter::COMMAND_MISC),
PARAM_FILE_INCLUDE(PARAM_FILE_INCLUDE_ID, "--file-include", "File Inclusion Regex", "Include file names based on this regex", typeid(std::string), (void *) &fileInclude, "^.*$"),
PARAM_FILE_EXCLUDE(PARAM_FILE_EXCLUDE_ID, "--file-exclude", "File Exclusion Regex", "Exclude file names based on this regex", typeid(std::string), (void *) &fileExclude, "^.*$"),
Expand Down Expand Up @@ -191,7 +191,7 @@ LocalParameters::LocalParameters() :

//scorecmultimer
scoremultimer.push_back(&PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD);
scoremultimer.push_back(&PARAM_SINGLE_CHAIN_INCLUDE_MODE);
scoremultimer.push_back(&PARAM_MONOMER_INCLUDE_MODE);
scoremultimer.push_back(&PARAM_THREADS);
scoremultimer.push_back(&PARAM_V);

Expand Down Expand Up @@ -253,7 +253,7 @@ LocalParameters::LocalParameters() :
maskBfactorThreshold = 0;
chainNameMode = 0;
minAssignedChainsThreshold = 0.0;
singleChainIncludeMode = 0;
monomerIncludeMode = 0;
writeMapping = 0;
tmAlignFast = 1;
exactTMscore = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/commons/LocalParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class LocalParameters : public Parameters {
PARAMETER(PARAM_N_SAMPLE)
PARAMETER(PARAM_COORD_STORE_MODE)
PARAMETER(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD)
PARAMETER(PARAM_SINGLE_CHAIN_INCLUDE_MODE)
PARAMETER(PARAM_MONOMER_INCLUDE_MODE)
PARAMETER(PARAM_CLUSTER_SEARCH)
PARAMETER(PARAM_FILE_INCLUDE)
PARAMETER(PARAM_FILE_EXCLUDE)
Expand Down Expand Up @@ -162,7 +162,7 @@ class LocalParameters : public Parameters {
int nsample;
int coordStoreMode;
float minAssignedChainsThreshold;
int singleChainIncludeMode;
int monomerIncludeMode;
int clusterSearch;
std::string fileInclude;
std::string fileExclude;
Expand Down
2 changes: 1 addition & 1 deletion src/strucclustutils/MultimerUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const float LEARNING_RATE = 0.1;
const float TM_SCORE_MARGIN = 0.7;
const unsigned int MULTIPLE_CHAINED_COMPLEX = 2;
const unsigned int SIZE_OF_SUPERPOSITION_VECTOR = 12;
const int SKIP_SINGLE_CHAIN_ASSIGNMENTS = 1;
const int SKIP_MONOMERS = 1;
typedef std::pair<std::string, std::string> compNameChainName_t;
typedef std::map<unsigned int, unsigned int> chainKeyToComplexId_t;
typedef std::map<unsigned int, std::vector<unsigned int>> complexIdToChainKeys_t;
Expand Down
45 changes: 26 additions & 19 deletions src/strucclustutils/scoremultimer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ struct SearchResult {
dbResidueLen = residueLen;
}

void standardize(int singleChainedAssignmentIncludeMode) {
void standardize(int MonomerIncludeMode) {
if (dbResidueLen == 0)
alnVec.clear();

if (singleChainedAssignmentIncludeMode==SKIP_SINGLE_CHAIN_ASSIGNMENTS && dbChainKeys.size() < MULTIPLE_CHAINED_COMPLEX)
if (MonomerIncludeMode == SKIP_MONOMERS && dbChainKeys.size() < MULTIPLE_CHAINED_COMPLEX)
alnVec.clear();

if (alnVec.empty())
Expand Down Expand Up @@ -182,11 +182,9 @@ bool compareNeighborWithDist(const NeighborsWithDist &first, const NeighborsWith

class DBSCANCluster {
public:
DBSCANCluster(SearchResult &searchResult, std::set<cluster_t> &finalClusters, double minCov, int singleChainMode) : searchResult(searchResult), finalClusters(finalClusters) {
DBSCANCluster(SearchResult &searchResult, std::set<cluster_t> &finalClusters, float minCov) : searchResult(searchResult), finalClusters(finalClusters) {
cLabel = 0;
minimumClusterSize = (unsigned int) ((double) searchResult.qChainKeys.size() * minCov);
if (singleChainMode == SKIP_SINGLE_CHAIN_ASSIGNMENTS)
minimumClusterSize = std::max(MULTIPLE_CHAINED_COMPLEX, minimumClusterSize);
minimumClusterSize = std::ceil((float) searchResult.qChainKeys.size() * minCov);
maximumClusterSize = std::min(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size());
maximumClusterNum = searchResult.alnVec.size() / maximumClusterSize;
prevMaxClusterSize = 0;
Expand All @@ -196,9 +194,9 @@ class DBSCANCluster {
}

bool getAlnClusters() {
// if Query or Target is a Single Chain Complex.
// if Query or Target is a Monomer Complex.
if (std::min(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size()) < MULTIPLE_CHAINED_COMPLEX)
return earlyStopForSingleChainComplex();
return earlyStopForMonomers();

// rbh filter
filterAlnsByRBH();
Expand Down Expand Up @@ -234,15 +232,20 @@ class DBSCANCluster {
std::map<unsigned int, float> qBestTmScore;
std::map<unsigned int, float> dbBestTmScore;

bool earlyStopForSingleChainComplex() {
bool earlyStopForMonomers() {
if (minimumClusterSize >= MULTIPLE_CHAINED_COMPLEX)
return finishDBSCAN();

getSingleChainedCluster();
return finishDBSCAN();
}

void getSingleChainedCluster() {
finalClusters.clear();
for (unsigned int alnIdx = 0; alnIdx < searchResult.alnVec.size(); alnIdx++ ) {
neighbors = {alnIdx};
finalClusters.insert(neighbors);
}
return finishDBSCAN();
}

bool runDBSCAN() {
Expand Down Expand Up @@ -314,6 +317,10 @@ class DBSCANCluster {

eps += learningRate;
}

if (minimumClusterSize < MULTIPLE_CHAINED_COMPLEX && currMaxClusterSize < MULTIPLE_CHAINED_COMPLEX)
getSingleChainedCluster();

return finishDBSCAN();
}

Expand Down Expand Up @@ -472,7 +479,7 @@ class DBSCANCluster {

class ComplexScorer {
public:
ComplexScorer(IndexReader *qDbr3Di, IndexReader *tDbr3Di, DBReader<unsigned int> &alnDbr, IndexReader *qCaDbr, IndexReader *tCaDbr, unsigned int thread_idx, double minAssignedChainsRatio, int singleChainedAssignmentIncludeMode) : alnDbr(alnDbr), qCaDbr(qCaDbr), tCaDbr(tCaDbr), thread_idx(thread_idx), minAssignedChainsRatio(minAssignedChainsRatio), singleChainedAssignmentIncludeMode(singleChainedAssignmentIncludeMode) {
ComplexScorer(IndexReader *qDbr3Di, IndexReader *tDbr3Di, DBReader<unsigned int> &alnDbr, IndexReader *qCaDbr, IndexReader *tCaDbr, unsigned int thread_idx, float minAssignedChainsRatio, int monomerIncludeMode) : alnDbr(alnDbr), qCaDbr(qCaDbr), tCaDbr(tCaDbr), thread_idx(thread_idx), minAssignedChainsRatio(minAssignedChainsRatio), monomerIncludeMode(monomerIncludeMode) {
maxChainLen = std::max(qDbr3Di->sequenceReader->getMaxSeqLen()+1, tDbr3Di->sequenceReader->getMaxSeqLen()+1);
q3diDbr = qDbr3Di;
t3diDbr = tDbr3Di;
Expand Down Expand Up @@ -538,7 +545,7 @@ class ComplexScorer {
paredSearchResult.alnVec.emplace_back(aln);
continue;
}
paredSearchResult.standardize(singleChainedAssignmentIncludeMode);
paredSearchResult.standardize(monomerIncludeMode);
if (!paredSearchResult.alnVec.empty())
searchResults.emplace_back(paredSearchResult);

Expand All @@ -550,7 +557,7 @@ class ComplexScorer {
paredSearchResult.alnVec.emplace_back(aln);
}
currAlns.clear();
paredSearchResult.standardize(singleChainedAssignmentIncludeMode);
paredSearchResult.standardize(monomerIncludeMode);
if (!paredSearchResult.alnVec.empty())
searchResults.emplace_back(paredSearchResult);

Expand All @@ -564,7 +571,7 @@ class ComplexScorer {
tmAligner = new TMaligner(maxResLen, false, true, false);
}
finalClusters.clear();
DBSCANCluster dbscanCluster(searchResult, finalClusters, minAssignedChainsRatio, singleChainedAssignmentIncludeMode);
DBSCANCluster dbscanCluster(searchResult, finalClusters, minAssignedChainsRatio);
if (!dbscanCluster.getAlnClusters()) {
finalClusters.clear();
return;
Expand Down Expand Up @@ -600,7 +607,7 @@ class ComplexScorer {
Coordinate16 qCoords;
Coordinate16 tCoords;
unsigned int thread_idx;
double minAssignedChainsRatio;
float minAssignedChainsRatio;
unsigned int maxResLen;
Chain qChain;
Chain dbChain;
Expand All @@ -610,7 +617,7 @@ class ComplexScorer {
SearchResult paredSearchResult;
std::set<cluster_t> finalClusters;
bool hasBacktrace;
int singleChainedAssignmentIncludeMode;
int monomerIncludeMode;

unsigned int getQueryResidueLength(std::vector<unsigned int> &qChainKeys) {
unsigned int qResidueLen = 0;
Expand Down Expand Up @@ -704,7 +711,7 @@ int scoremultimer(int argc, const char **argv, const Command &command) {
}

float minAssignedChainsRatio = par.minAssignedChainsThreshold > MAX_ASSIGNED_CHAIN_RATIO ? MAX_ASSIGNED_CHAIN_RATIO: par.minAssignedChainsThreshold;
int singleChainIncludeMode = par.singleChainIncludeMode;
int monomerIncludeMode = par.monomerIncludeMode;

std::vector<unsigned int> qComplexIndices;
std::vector<unsigned int> dbComplexIndices;
Expand All @@ -730,13 +737,13 @@ int scoremultimer(int argc, const char **argv, const Command &command) {
std::vector<SearchResult> searchResults;
std::vector<Assignment> assignments;
std::vector<resultToWrite_t> resultToWriteLines;
ComplexScorer complexScorer(q3DiDbr, &t3DiDbr, alnDbr, qCaDbr, &tCaDbr, thread_idx, minAssignedChainsRatio, singleChainIncludeMode);
ComplexScorer complexScorer(q3DiDbr, &t3DiDbr, alnDbr, qCaDbr, &tCaDbr, thread_idx, minAssignedChainsRatio, monomerIncludeMode);
#pragma omp for schedule(dynamic, 1)
// for each q complex
for (size_t qCompIdx = 0; qCompIdx < qComplexIndices.size(); qCompIdx++) {
unsigned int qComplexId = qComplexIndices[qCompIdx];
std::vector<unsigned int> &qChainKeys = qComplexIdToChainKeysMap.at(qComplexId);
if (par.singleChainIncludeMode == SKIP_SINGLE_CHAIN_ASSIGNMENTS && qChainKeys.size() < MULTIPLE_CHAINED_COMPLEX)
if (monomerIncludeMode == SKIP_MONOMERS && qChainKeys.size() < MULTIPLE_CHAINED_COMPLEX)
continue;
complexScorer.getSearchResults(qComplexId, qChainKeys, dbChainKeyToComplexIdMap, dbComplexIdToChainKeysMap, searchResults);
// for each db complex
Expand Down

0 comments on commit 19c8820

Please sign in to comment.