Skip to content

Commit

Permalink
add originalSourcefiles in metadata (#1252)
Browse files Browse the repository at this point in the history
* add originalSourcefiles in metadata

* add - original source linking

* fix test case

* remove extra traversals

* enhance test case

* added test case
  • Loading branch information
khemrajrathore authored Aug 20, 2024
1 parent 4b26368 commit aa17bb4
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 23 deletions.
22 changes: 22 additions & 0 deletions src/main/scala/ai/privado/exporter/JSONExporter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import ai.privado.model.exporter.ViolationEncoderDecoder.*
import ai.privado.model.exporter.CollectionEncoderDecoder.*
import ai.privado.model.exporter.AndroidPermissionsEncoderDecoder.*
import ai.privado.model.exporter.SinkEncoderDecoder.*
import ai.privado.semantic.language.*
import ai.privado.script.ExternalScalaScriptRunner
import better.files.File
import io.circe.Json
Expand All @@ -75,6 +76,7 @@ import scala.language.postfixOps
import scala.util.{Failure, Success, Try}
import io.shiftleft.semanticcpg.language.*
import ai.privado.semantic.language.*
import io.shiftleft.codepropertygraph.generated.nodes.AstNode
object JSONExporter {

private val logger = LoggerFactory.getLogger(getClass)
Expand Down Expand Up @@ -218,6 +220,26 @@ object JSONExporter {
output.addOne(Constants.propertyDependency -> propertyAndUsedAt.asJson)
output.addOne(Constants.propertyFiles -> cpg.property.file.name.dedup.l.asJson)

// All Original source nodes
val originalSources = Try(cpg.all.collectAll[AstNode].originalSource.dedup.l).getOrElse(List.empty)

// Original source file names
val originalSourceFileNames = originalSources.file.name.dedup.l

val originalSourceToDerivedSourceFileMapping = originalSources
.flatMap { source =>
val derivedSourceNodes = source.derivedSource.l
if (derivedSourceNodes.nonEmpty)
Some((source.file.name.headOption.getOrElse(""), derivedSourceNodes.file.name.dedup.l))
else
None
}
.groupBy(_._1)
.map(item => (item._1, item._2.flatMap(_._2).distinct))

output.addOne(Constants.originalSourceFiles -> originalSourceFileNames.asJson)
output.addOne(Constants.originalSourceDependency -> originalSourceToDerivedSourceFileMapping.asJson)

/** For Java the namespace is working as expected, for languages like JS, Python we are getting the namespace as
* <`global`> for nearly all files
*
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/ai/privado/model/Constants.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ object Constants {
val namespaceDependency = "namespaceDependency"
val propertyFiles = "propertyFiles"
val importDependency = "importDependency"
val originalSourceFiles = "originalSourceFiles"
val originalSourceDependency = "originalSourceDependency"

// database details
val dbName = "dbName"
Expand Down
37 changes: 25 additions & 12 deletions src/main/scala/ai/privado/semantic/language/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,25 @@ package object language {
}
}

implicit class NodeToOriginalSourceTraversal(val nodes: Iterator[AstNode]) extends AnyVal {

/** For a given input of nodes returns all the original sources
* @return
*/
def originalSource: Iterator[AstNode] = Try {
nodes.flatMap(n => NodeToOriginalSource(n).originalSource)
}.getOrElse(Iterator.empty)
}

implicit class NodeToOriginalSource(val node: AstNode) extends AnyVal {
def originalSource: Option[AstNode] = {

/** Returns all the original sources for the input node
* @return
*/
def originalSource: Iterator[AstNode] = Try {
val _originalSource = node.out(EdgeTypes.ORIGINAL_SOURCE)
if (_originalSource.nonEmpty && _originalSource.hasNext) {
return Option(_originalSource.next().asInstanceOf[AstNode])
}
None
}
_originalSource.toList.collectAll[AstNode]
}.getOrElse(Iterator.empty)

def originalSource(sourceId: String): Option[AstNode] = {
val _originalSource = node.out(EdgeTypes.ORIGINAL_SOURCE)
Expand All @@ -169,13 +180,15 @@ package object language {
}

implicit class OriginalToDerivedSource(val node: AstNode) extends AnyVal {
def derivedSource: Option[AstNode] = {

/** For a given input of node returns all the derived sources
*
* @return
*/
def derivedSource: Iterator[AstNode] = Try {
val _derivedSource = node.out(EdgeTypes.DERIVED_SOURCE)
if (_derivedSource.nonEmpty && _derivedSource.hasNext) {
return Option(_derivedSource.next().asInstanceOf[AstNode])
}
None
}
_derivedSource.toList.collectAll[AstNode]
}.getOrElse(Iterator.empty)
}

implicit class NodeStartersForModule(cpg: Cpg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ class IdentifierTaggingTests extends CSharpFrontendTestSuite with TraversalValid
"build correct edges between derived and original sources" in {
val List(identifierNode) = cpg.identifier("b").l
val List(phoneNumberMember) = cpg.member("PhoneNumber").l
originalSourceTraversalValidator(identifierNode, "Data.Sensitive.ContactData.PhoneNumber")
derivedSourceTraversalValidator(phoneNumberMember)
originalSourceTraversalValidator(phoneNumberMember, identifierNode, "Data.Sensitive.ContactData.PhoneNumber")
derivedSourceTraversalValidator(phoneNumberMember, identifierNode)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,12 @@ class GoIdentifierTaggingTest extends GoTestBase with TraversalValidator {
"build correct edges between derived and original sources" in {
val List(userIdentifier) = cpg.identifier("user").lineNumber(16).l
val List(firstNameMember) = cpg.member("FirstName").l
originalSourceTraversalValidator(userIdentifier, "Data.Sensitive.PersonalIdentification.FirstName")
derivedSourceTraversalValidator(firstNameMember)
originalSourceTraversalValidator(
firstNameMember,
userIdentifier,
"Data.Sensitive.PersonalIdentification.FirstName"
)
derivedSourceTraversalValidator(firstNameMember, userIdentifier)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ class JavaSourceTaggingTest extends JavaFrontendTestSuite with TraversalValidato
"build correct edges between derived and original sources" in {
val List(_, userIdentifier) = cpg.identifier.nameExact("us").lineNumber(4).l
val List(firstNameMember) = cpg.member("firstName").l
originalSourceTraversalValidator(userIdentifier, "Data.Sensitive.FirstName")
derivedSourceTraversalValidator(firstNameMember)
originalSourceTraversalValidator(firstNameMember, userIdentifier, "Data.Sensitive.FirstName")
derivedSourceTraversalValidator(firstNameMember, userIdentifier)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ class IdentifierTaggingTest extends PhpFrontendTestSuite with TraversalValidator
"build correct edges between derived and original sources" in {
val List(userIdentifier) = cpg.identifier("user").lineNumber(20).l
val List(firstNameMember) = cpg.member("firstName").l
originalSourceTraversalValidator(userIdentifier, "Data.Sensitive.FirstName")
derivedSourceTraversalValidator(firstNameMember)
originalSourceTraversalValidator(firstNameMember, userIdentifier, "Data.Sensitive.FirstName")
derivedSourceTraversalValidator(firstNameMember, userIdentifier)
}

}
Expand Down
13 changes: 10 additions & 3 deletions src/test/scala/ai/privado/traversal/TraversalValidator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,30 @@ import ai.privado.model.InternalTag
import io.shiftleft.codepropertygraph.generated.nodes.AstNode
import io.shiftleft.semanticcpg.language.*
import org.scalatest.Assertion
import org.scalatest.Assertions.assert
import org.scalatest.matchers.should.Matchers

trait TraversalValidator extends Matchers {

def originalSourceTraversalValidator(derivedSourceNode: AstNode, sourceId: String): Assertion = {
def originalSourceTraversalValidator(
originalSource: AstNode,
derivedSourceNode: AstNode,
sourceId: String
): Assertion = {
derivedSourceNode
.originalSource(sourceId)
.get
.tag
.name(InternalTag.ORIGINAL_SOURCE_FOR_DERIVED_NODE.toString)
.nonEmpty shouldBe true
derivedSourceNode.originalSource.contains(originalSource) shouldBe true
}

def derivedSourceTraversalValidator(originalSourceNode: AstNode): Assertion = {
originalSourceNode.derivedSource.get.tag
def derivedSourceTraversalValidator(originalSourceNode: AstNode, derivedSourceNode: AstNode): Assertion = {
originalSourceNode.derivedSource.tag
.name(InternalTag.OBJECT_OF_SENSITIVE_CLASS_BY_MEMBER_NAME.toString)
.nonEmpty shouldBe true
originalSourceNode.derivedSource.contains(derivedSourceNode) shouldBe true
}

}

0 comments on commit aa17bb4

Please sign in to comment.