Skip to content

Commit

Permalink
Adds commit classification rule (#397)
Browse files Browse the repository at this point in the history
This PR adds a new rule using the `LLMService`.

It sends the diff of a commit to the LLM and asks whether the commit is
security-relevant. The relevance score of the rule is set to 32 for now,
but this value can be adjusted after evaluation.

Thanks to @tommasoaiello
  • Loading branch information
lauraschauer authored Jul 17, 2024
1 parent 53446b0 commit b8f600f
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 8 deletions.
54 changes: 53 additions & 1 deletion prospector/llm/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import validators
from langchain_core.language_models.llms import LLM
from langchain_core.output_parsers import StrOutputParser
from requests import HTTPError

from llm.instantiation import create_model_instance
from llm.prompts import prompt_best_guess
from llm.prompts.classify_commit import zero_shot as cc_zero_shot
from llm.prompts.get_repository_url import prompt_best_guess
from log.logger import logger
from util.config_parser import LLMServiceConfig
from util.singleton import Singleton
Expand Down Expand Up @@ -74,3 +76,53 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")

return url

def classify_commit(
    self, diff: str, repository_name: str, commit_message: str
) -> bool:
    """Ask an LLM whether a commit is security relevant or not.

    Args:
        diff (str): The diff of the commit to classify.
        repository_name (str): The name of the repository the commit belongs to.
        commit_message (str): The message of the commit to classify.

    Returns:
        True if the commit is deemed security relevant, False if not.

    Raises:
        RuntimeError: if the model invocation fails or the response is not a
            recognisable True/False answer.
    """
    try:
        # Compose prompt -> model -> plain-string output into one LCEL chain.
        chain = cc_zero_shot | self.model | StrOutputParser()

        is_relevant = chain.invoke(
            {
                "diff": diff,
                "repository_name": repository_name,
                "commit_message": commit_message,
            }
        )
        logger.info(f"LLM returned is_relevant={is_relevant}")

    except HTTPError as e:
        # An oversized diff triggers a 400 from the backend; silently treat
        # such commits as not security relevant instead of failing the run.
        status_code = e.response.status_code
        if status_code == 400:
            return False
        raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
    except Exception as e:
        raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")

    # Normalise the verdict: models sometimes wrap it in code fences, an
    # "ANSWER:" prefix, or surrounding whitespace. This accepts a superset
    # of the previously recognised literals ("True", "ANSWER:True",
    # "```ANSWER:True```", and the False equivalents).
    answer = is_relevant.strip().strip("`").strip()
    if answer.startswith("ANSWER:"):
        answer = answer[len("ANSWER:"):].strip()
    if answer == "True":
        return True
    if answer == "False":
        return False
    raise RuntimeError(f"The model returned an invalid response: {is_relevant}")
16 changes: 16 additions & 0 deletions prospector/llm/prompts/classify_commit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from langchain.prompts import PromptTemplate

# Zero-shot classification prompt: given a repository name, commit message,
# and diff, the model must answer with the bare literal "True" or "False"
# indicating whether the commit is security relevant. The consumer
# (LLMService.classify_commit) parses exactly that True/False answer.
zero_shot = PromptTemplate.from_template(
    """Is the following commit security relevant or not?
Please provide the output as a boolean value, either True or False.
If it is security relevant just answer True otherwise answer False. Do not return anything else.
To provide you with some context, the name of the repository is: {repository_name}, and the
commit message is: {commit_message}.
Finally, here is the diff of the commit:
{diff}\n
Your answer:\n"""
)
File renamed without changes.
16 changes: 15 additions & 1 deletion prospector/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,18 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
return False


class CommitIsSecurityRelevant(Rule):
    """Matches commits that are deemed security relevant by the commit classification service."""

    def apply(
        self,
        candidate: Commit,
    ) -> bool:
        # Delegates the decision to the LLM-backed classifier, forwarding the
        # candidate's diff, repository, and message. LLMService is presumably
        # a singleton (llm_service.py imports util.singleton), so constructing
        # it here should reuse the already-configured model — confirm.
        # NOTE(review): unlike the phase-1 rules, apply() takes no
        # advisory_record parameter — verify the phase-2 rule runner invokes
        # rules with this single-argument signature.
        return LLMService().classify_commit(
            candidate.diff, candidate.repository, candidate.message
        )


RULES_PHASE_1: List[Rule] = [
VulnIdInMessage("VULN_ID_IN_MESSAGE", 64),
# CommitMentionedInAdv("COMMIT_IN_ADVISORY", 64),
Expand All @@ -433,4 +445,6 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
CommitHasTwins("COMMIT_HAS_TWINS", 2),
]

RULES_PHASE_2: List[Rule] = []
RULES_PHASE_2: List[Rule] = [
CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)
]
24 changes: 18 additions & 6 deletions prospector/rules/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def candidates():
changed_files={
"core/src/main/java/org/apache/cxf/workqueue/AutomaticWorkQueueImpl.java"
},
minhash=get_encoded_minhash(get_msg("Insecure deserialization", 50)),
minhash=get_encoded_minhash(
get_msg("Insecure deserialization", 50)
),
),
# TODO: Not matched by existing tests: GHSecurityAdvInMessage, ReferencesBug, ChangesRelevantCode, TwinMentionedInAdv, VulnIdInLinkedIssue, SecurityKeywordInLinkedGhIssue, SecurityKeywordInLinkedBug, CrossReferencedBug, CrossReferencedGh, CommitHasTwins, ChangesRelevantFiles, CommitMentionedInAdv, RelevantWordsInMessage
]
Expand All @@ -109,37 +111,47 @@ def advisory_record():
)


def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: AdvisoryRecord):
def test_apply_phase_1_rules(
candidates: List[Commit], advisory_record: AdvisoryRecord
):
annotated_candidates = apply_rules(
candidates, advisory_record, enabled_rules=enabled_rules_from_config
)

# Repo 5: Should match: AdvKeywordsInFiles, SecurityKeywordsInMsg, CommitMentionedInReference
assert len(annotated_candidates[0].matched_rules) == 3

matched_rules_names = [item["id"] for item in annotated_candidates[0].matched_rules]
matched_rules_names = [
item["id"] for item in annotated_candidates[0].matched_rules
]
assert "ADV_KEYWORDS_IN_FILES" in matched_rules_names
assert "COMMIT_IN_REFERENCE" in matched_rules_names
assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names

# Repo 1: Should match: VulnIdInMessage, ReferencesGhIssue
assert len(annotated_candidates[1].matched_rules) == 2

matched_rules_names = [item["id"] for item in annotated_candidates[1].matched_rules]
matched_rules_names = [
item["id"] for item in annotated_candidates[1].matched_rules
]
assert "VULN_ID_IN_MESSAGE" in matched_rules_names
assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names

# Repo 3: Should match: VulnIdInMessage, ReferencesGhIssue
assert len(annotated_candidates[2].matched_rules) == 2

matched_rules_names = [item["id"] for item in annotated_candidates[2].matched_rules]
matched_rules_names = [
item["id"] for item in annotated_candidates[2].matched_rules
]
assert "VULN_ID_IN_MESSAGE" in matched_rules_names
assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names

# Repo 4: Should match: SecurityKeywordsInMsg
assert len(annotated_candidates[3].matched_rules) == 1

matched_rules_names = [item["id"] for item in annotated_candidates[3].matched_rules]
matched_rules_names = [
item["id"] for item in annotated_candidates[3].matched_rules
]
assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names

# Repo 2: Matches nothing
Expand Down

0 comments on commit b8f600f

Please sign in to comment.