Skip to content

Commit

Permalink
Add mode to compute exact (slow) tmscore.
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-steinegger committed Feb 20, 2024
1 parent 75a50f7 commit 493cefe
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/commons/LocalParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ LocalParameters::LocalParameters() :
PARAM_CHAIN_NAME_MODE(PARAM_CHAIN_NAME_MODE_ID,"--chain-name-mode", "Chain name mode", "Add chain to name:\n0: auto\n1: always add\n",typeid(int), (void *) &chainNameMode, "^[0-1]{1}$", MMseqsParameter::COMMAND_EXPERT),
PARAM_WRITE_MAPPING(PARAM_WRITE_MAPPING_ID, "--write-mapping", "Write mapping file", "write _mapping file containing mapping from internal id to taxonomic identifier", typeid(int), (void *) &writeMapping, "^[0-1]{1}", MMseqsParameter::COMMAND_EXPERT),
PARAM_TMALIGN_FAST(PARAM_TMALIGN_FAST_ID,"--tmalign-fast", "TMalign fast","turn on fast search in TM-align" ,typeid(int), (void *) &tmAlignFast, "^[0-1]{1}$"),
PARAM_EXACT_TMSCORE(PARAM_EXACT_TMSCORE_ID,"--exact-tmscore", "Exact TMscore","turn on fast exact TMscore (slow), default is approximate" ,typeid(int), (void *) &exactTMscore, "^[0-1]{1}$"),
PARAM_N_SAMPLE(PARAM_N_SAMPLE_ID, "--n-sample", "Sample size","pick N random sample" ,typeid(int), (void *) &nsample, "^[0-9]{1}[0-9]*$"),
PARAM_COORD_STORE_MODE(PARAM_COORD_STORE_MODE_ID, "--coord-store-mode", "Coord store mode", "Coordinate storage mode: \n1: C-alpha as float\n2: C-alpha as difference (uint16_t)", typeid(int), (void *) &coordStoreMode, "^[1-2]{1}$",MMseqsParameter::COMMAND_EXPERT),
PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD_ID, "--min-assigned-chains-ratio", "Minimum assigned chains percentage Threshold", "minimum percentage of assigned chains out of all query chains > thr [0,100] %", typeid(float), (void *) & minAssignedChainsThreshold, "^[0-9]*(\\.[0-9]+)?$"),
Expand Down Expand Up @@ -86,6 +87,8 @@ LocalParameters::LocalParameters() :
structurecreatedb.push_back(&PARAM_THREADS);
structurecreatedb.push_back(&PARAM_V);

convertalignments.push_back(&PARAM_EXACT_TMSCORE);

createindex.push_back(&PARAM_INDEX_EXCLUDE);

// tmalign
Expand All @@ -103,6 +106,7 @@ LocalParameters::LocalParameters() :
tmalign.push_back(&PARAM_THREADS);
tmalign.push_back(&PARAM_V);

structurerescorediagonal.push_back(&PARAM_EXACT_TMSCORE);
structurerescorediagonal.push_back(&PARAM_TMSCORE_THRESHOLD);
structurerescorediagonal.push_back(&PARAM_LDDT_THRESHOLD);
structurerescorediagonal.push_back(&PARAM_ALIGNMENT_TYPE);
Expand All @@ -112,6 +116,7 @@ LocalParameters::LocalParameters() :
structurealign.push_back(&PARAM_LDDT_THRESHOLD);
structurealign.push_back(&PARAM_SORT_BY_STRUCTURE_BITS);
structurealign.push_back(&PARAM_ALIGNMENT_TYPE);
structurealign.push_back(&PARAM_EXACT_TMSCORE);
structurealign = combineList(structurealign, align);
// tmalign.push_back(&PARAM_GAP_OPEN);
// tmalign.push_back(&PARAM_GAP_EXTEND);
Expand Down Expand Up @@ -205,6 +210,7 @@ LocalParameters::LocalParameters() :
chainNameMode = 0;
writeMapping = 0;
tmAlignFast = 1;
exactTMscore = 0;
gapOpen = 10;
gapExtend = 1;
nsample = 5000;
Expand Down
2 changes: 2 additions & 0 deletions src/commons/LocalParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class LocalParameters : public Parameters {
PARAMETER(PARAM_CHAIN_NAME_MODE)
PARAMETER(PARAM_WRITE_MAPPING)
PARAMETER(PARAM_TMALIGN_FAST)
PARAMETER(PARAM_EXACT_TMSCORE)
PARAMETER(PARAM_N_SAMPLE)
PARAMETER(PARAM_COORD_STORE_MODE)
PARAMETER(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD)
Expand All @@ -125,6 +126,7 @@ class LocalParameters : public Parameters {
int chainNameMode;
bool writeMapping;
int tmAlignFast;
int exactTMscore;
int nsample;
int coordStoreMode;
float minAssignedChainsThreshold;
Expand Down
121 changes: 116 additions & 5 deletions src/commons/TMaligner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
#include "TMaligner.h"
#include "tmalign/Coordinates.h"
#include <tmalign/TMalign.h>
#include <tmalign/basic_fun.h>
#include "StructureSmithWaterman.h"
#include "StructureSmithWaterman.h"

TMaligner::TMaligner(unsigned int maxSeqLen, bool tmAlignFast, bool tmScoreOnly)
TMaligner::TMaligner(unsigned int maxSeqLen, bool tmAlignFast, bool tmScoreOnly, bool computeExactScore)
: tmAlignFast(tmAlignFast),
xtm(maxSeqLen), ytm(maxSeqLen), xt(maxSeqLen),
r1(maxSeqLen), r2(maxSeqLen){
r1(maxSeqLen), r2(maxSeqLen), computeExactScore(computeExactScore){
affineNW = NULL;
if(tmScoreOnly == false){
affineNW = new AffineNeedlemanWunsch(maxSeqLen, 20);
Expand Down Expand Up @@ -44,9 +45,10 @@ TMaligner::~TMaligner(){
delete [] invmap;
}

TMaligner::TMscoreResult TMaligner::computeTMscore(float *x, float *y, float *z, unsigned int targetLen,
int qStartPos, int dbStartPos, const std::string &backtrace,
int normalizationLen) {

TMaligner::TMscoreResult TMaligner::computeAppoximateTMscore(float *x, float *y, float *z, unsigned int targetLen,
int qStartPos, int dbStartPos, const std::string &backtrace,
int normalizationLen) {
int qPos = qStartPos;
int tPos = dbStartPos;
std::string cigarString = backtrace;
Expand Down Expand Up @@ -100,6 +102,115 @@ TMaligner::TMscoreResult TMaligner::computeTMscore(float *x, float *y, float *z,
return TMaligner::TMscoreResult(u, t, TM, rmsd0);
}


/**
 * Computes the TM-score of a fixed alignment with the exhaustive (slow) search.
 *
 * The alignment is described by a start position in the query and target plus
 * a backtrace string. Each 'M' pairs the current query and target positions
 * and advances both, 'I' advances only the query position and any other
 * character advances only the target position.
 *
 * The query coordinates are taken from the member buffers query_x/query_y/
 * query_z (length queryLen); the target coordinates are copied from x/y/z
 * into the member buffers target_x/target_y/target_z.
 *
 * @param x, y, z           target coordinate arrays, at least targetLen floats each
 * @param targetLen         number of target positions to copy
 * @param qStartPos         first aligned query position
 * @param dbStartPos        first aligned target position
 * @param backtrace         alignment backtrace ('M', 'I', anything else = target-only step)
 * @param normalizationLen  length passed to the TM-score parameter setup
 * @return TMscoreResult built from the rotation matrix u, translation t,
 *         the TM-score and rmsd0
 */
TMaligner::TMscoreResult TMaligner::computeExactTMscore(float *x, float *y, float *z, unsigned int targetLen,
                                                        int qStartPos, int dbStartPos, const std::string &backtrace,
                                                        int normalizationLen) {
    // Build the query -> target map (-1 = unaligned query position).
    // The backtrace is read in place; no copy of the string is needed.
    int qPos = qStartPos;
    int tPos = dbStartPos;
    std::fill(invmap, invmap + queryLen, -1);
    for (size_t btPos = 0; btPos < backtrace.size(); btPos++) {
        const char op = backtrace[btPos];
        if (op == 'M') {
            invmap[qPos] = tPos;
            qPos++;
            tPos++;
        } else if (op == 'I') {
            qPos++;
        } else {
            tPos++;
        }
    }

    memcpy(target_x, x, sizeof(float) * targetLen);
    memcpy(target_y, y, sizeof(float) * targetLen);
    memcpy(target_z, z, sizeof(float) * targetLen);
    Coordinates targetCaCords;
    targetCaCords.x = target_x;
    targetCaCords.y = target_y;
    targetCaCords.z = target_z;
    Coordinates queryCaCords;
    queryCaCords.x = query_x;
    queryCaCords.y = query_y;
    queryCaCords.z = query_z;

    float t[3], u[3][3];
    float D0_MIN;
    float rmsd0 = 0.0;
    float Lnorm; // normalization length
    float score_d8, d0, d0_search, dcu0; // for TMscore search
    parameter_set4search(normalizationLen, normalizationLen, D0_MIN, Lnorm,
                         score_d8, d0, d0_search, dcu0);
    float local_d0_search = d0_search;

    // The first search pass uses a coarser step when tmAlignFast is set.
    int simplify_step = tmAlignFast ? 40 : 1;
    detailed_search_standard(r1, r2, xtm, ytm, xt, targetCaCords, queryCaCords, queryLen,
                             invmap, t, u, simplify_step, local_d0_search, true, Lnorm, score_d8, d0, mem);
    // xt receives the rotated target here and is passed on to TMscore8_search below.
    BasicFunction::do_rotation(targetCaCords, xt, targetLen, t, u);

    // Collect the aligned coordinate pairs. Every aligned pair is kept; no
    // score_d8 distance cutoff is applied. The previous code computed a
    // per-pair distance, but its test was OR-ed with "i >= 0" inside an
    // "i >= 0" branch, so it could never reject a pair.
    int n_ali8 = 0;
    for (unsigned int j = 0; j < queryLen; j++) {
        const int i = invmap[j];
        if (i < 0) {
            continue; // unaligned query position
        }
        r1.x[n_ali8] = targetCaCords.x[i];
        r1.y[n_ali8] = targetCaCords.y[i];
        r1.z[n_ali8] = targetCaCords.z[i];

        r2.x[n_ali8] = queryCaCords.x[j];
        r2.y[n_ali8] = queryCaCords.y[j];
        r2.z[n_ali8] = queryCaCords.z[j];

        xtm.x[n_ali8] = targetCaCords.x[i];
        xtm.y[n_ali8] = targetCaCords.y[i];
        xtm.z[n_ali8] = targetCaCords.z[i];

        ytm.x[n_ali8] = queryCaCords.x[j];
        ytm.y[n_ali8] = queryCaCords.y[j];
        ytm.z[n_ali8] = queryCaCords.z[j];

        n_ali8++;
    }

    // rmsd0 is used for final output, only recalculate rmsd0, not t & u
    KabschFast(r1, r2, n_ali8, &rmsd0, t, u, mem);
    // Guard against 0/0 when the backtrace contains no aligned ('M') column.
    rmsd0 = (n_ali8 > 0) ? sqrt(rmsd0 / n_ali8) : 0.0f;

    // Final pass always uses the finest step.
    simplify_step = 1;
    float Lnorm_0 = normalizationLen;
    // normalized by length of structure A
    parameter_set4final(Lnorm_0, D0_MIN, Lnorm,
                        d0, d0_search);
    local_d0_search = d0_search;

    // NOTE(review): rmsd0 is handed to TMscore8_search by pointer; confirm the
    // callee does not overwrite the value that is returned below.
    double TM = TMscore8_search(r1, r2, xtm, ytm, xt, n_ali8, t, u, simplify_step,
                                &rmsd0, local_d0_search, Lnorm, score_d8, d0, mem);

    return TMaligner::TMscoreResult(u, t, TM, rmsd0);
}

/**
 * Public entry point for scoring a fixed alignment.
 *
 * Forwards all arguments unchanged to computeExactTMscore when the aligner
 * was constructed with computeExactScore set, and to
 * computeAppoximateTMscore otherwise.
 */
TMaligner::TMscoreResult TMaligner::computeTMscore(float *x, float *y, float *z, unsigned int targetLen,
                                                   int qStartPos, int dbStartPos, const std::string &backtrace,
                                                   int normalizationLen) {
    // Approximate scoring is the default path; handle it first and return early.
    if (!computeExactScore) {
        return computeAppoximateTMscore(x, y, z, targetLen,
                                        qStartPos, dbStartPos,
                                        backtrace, normalizationLen);
    }
    return computeExactTMscore(x, y, z, targetLen,
                               qStartPos, dbStartPos,
                               backtrace, normalizationLen);
}


void TMaligner::initQuery(float *x, float *y, float *z, char * querySeq, unsigned int queryLen){
memset(querySecStruc, 0, sizeof(char) * queryLen);
memcpy(query_x, x, sizeof(float) * queryLen);
Expand Down
13 changes: 12 additions & 1 deletion src/commons/TMaligner.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class TMaligner{
public:
TMaligner(unsigned int maxSeqLen, bool tmAlignFast, bool tmScoreOnly);
TMaligner(unsigned int maxSeqLen, bool tmAlignFast, bool tmScoreOnly, bool exact);
~TMaligner();

struct TMscoreResult{
Expand All @@ -39,6 +39,7 @@ class TMaligner{
unsigned int targetLen, int qStartPos,
int targetStartPos, const std::string & backtrace,
int normalizationLen);

Matcher::result_t align(unsigned int dbKey, float *target_x, float *target_y, float *target_z,
char * targetSeq, unsigned int targetLen, float &TM);

Expand All @@ -59,7 +60,17 @@ class TMaligner{
std::string seqM, seqxA, seqyA;// for output alignment
bool tmAlignFast;
Coordinates xtm, ytm, xt, r1, r2;
bool computeExactScore;
int * invmap;

TMscoreResult computeExactTMscore(float *x, float *y, float *z,
unsigned int targetLen, int qStartPos,
int targetStartPos, const std::string & backtrace,
int normalizationLen);
TMscoreResult computeAppoximateTMscore(float *x, float *y, float *z,
unsigned int targetLen, int qStartPos,
int targetStartPos, const std::string & backtrace,
int normalizationLen);
};

#endif //FOLDSEEK_TMALIGNER_H
4 changes: 2 additions & 2 deletions src/strucclustutils/aln2tmscore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#endif

int aln2tmscore(int argc, const char **argv, const Command& command) {
Parameters& par = Parameters::getInstance();
LocalParameters &par = LocalParameters::getLocalInstance();
par.parseParameters(argc, argv, command, true, 0, 0);

// never allow deletions
Expand Down Expand Up @@ -64,7 +64,7 @@ int aln2tmscore(int argc, const char **argv, const Command& command) {
std::string resultsStr;
resultsStr.reserve(10 * 1024);

TMaligner tmaln(std::max(qdbr.getMaxSeqLen() + 1,tdbr->getMaxSeqLen() + 1), false, true);
TMaligner tmaln(std::max(qdbr.getMaxSeqLen() + 1,tdbr->getMaxSeqLen() + 1), false, true, par.exactTMscore);
Coordinate16 qcoords;
Coordinate16 tcoords;

Expand Down
4 changes: 2 additions & 2 deletions src/strucclustutils/scorecomplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ class ComplexScorer {
q3diDbr = qDbr3Di;
t3diDbr = tDbr3Di;
maxResLen = maxChainLen * 2;
tmAligner = new TMaligner(maxResLen, false, true);
tmAligner = new TMaligner(maxResLen, false, true, false);
}

void getSearchResults(unsigned int qComplexId, std::vector<unsigned int> &qChainKeys, chainKeyToComplexId_t &dbChainKeyToComplexIdLookup, complexIdToChainKeys_t &dbComplexIdToChainKeysLookup, std::vector<SearchResult> &searchResults) {
Expand Down Expand Up @@ -568,7 +568,7 @@ class ComplexScorer {
if (maxResLen < maxChainLen * std::min(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size())) {
delete tmAligner;
maxResLen = std::max(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size()) * maxChainLen;
tmAligner = new TMaligner(maxResLen, false, true);
tmAligner = new TMaligner(maxResLen, false, true, false);
}

finalClusters.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/strucclustutils/structurealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ int structurealign(int argc, const char **argv, const Command& command) {
TMaligner *tmaligner = NULL;
if(needTMaligner) {
tmaligner = new TMaligner(
std::max(q3DiDbr->sequenceReader->getMaxSeqLen() + 1, t3DiDbr.sequenceReader->getMaxSeqLen() + 1), false, true);
std::max(q3DiDbr->sequenceReader->getMaxSeqLen() + 1, t3DiDbr.sequenceReader->getMaxSeqLen() + 1), false, true, par.exactTMscore);
}
LDDTCalculator *lddtcalculator = NULL;
if(needLDDT) {
Expand Down
2 changes: 1 addition & 1 deletion src/strucclustutils/structureconvertalis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ R"html(<!DOCTYPE html>
TMaligner *tmaligner = NULL;
if(needTMaligner) {
tmaligner = new TMaligner(
std::max(tDbr->sequenceReader->getMaxSeqLen() + 1, qDbr.sequenceReader->getMaxSeqLen() + 1), false, true);
std::max(tDbr->sequenceReader->getMaxSeqLen() + 1, qDbr.sequenceReader->getMaxSeqLen() + 1), false, true, par.exactTMscore);
}
LDDTCalculator *lddtcalculator = NULL;
if(needLDDT) {
Expand Down
2 changes: 1 addition & 1 deletion src/strucclustutils/structurerescorediagonal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ int structureungappedalign(int argc, const char **argv, const Command& command)
TMaligner *tmaligner = NULL;
if(needTMaligner) {
tmaligner = new TMaligner(
std::max(qdbr3Di.sequenceReader->getMaxSeqLen() + 1, t3DiDbr->sequenceReader->getMaxSeqLen() + 1), false, true);
std::max(qdbr3Di.sequenceReader->getMaxSeqLen() + 1, t3DiDbr->sequenceReader->getMaxSeqLen() + 1), false, true, par.exactTMscore);
}
LDDTCalculator *lddtcalculator = NULL;
if(needLDDT) {
Expand Down
2 changes: 1 addition & 1 deletion src/strucclustutils/tmalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ int tmalign(int argc, const char **argv, const Command& command) {
#ifdef OPENMP
thread_idx = static_cast<unsigned int>(omp_get_thread_num());
#endif
TMaligner tmaln(std::max(qdbr.sequenceReader->getMaxSeqLen() + 1,tdbr->sequenceReader->getMaxSeqLen() + 1), par.tmAlignFast, false);
TMaligner tmaln(std::max(qdbr.sequenceReader->getMaxSeqLen() + 1,tdbr->sequenceReader->getMaxSeqLen() + 1), par.tmAlignFast, false, false);
std::vector<Matcher::result_t> swResults;
swResults.reserve(300);
std::string backtrace;
Expand Down

0 comments on commit 493cefe

Please sign in to comment.