Skip to content

Commit

Permalink
Remove typeFullName and referenceTargetTypeFullName APIs from TypeInf…
Browse files Browse the repository at this point in the history
…oProvider.

Also refactored assignmentAstForDestructuringEntry. The right hand side
base is now created outside of this function allowing for other
constructs than pure identifiers. The pure identifier were e.g. wrong
in case of class member references.
  • Loading branch information
ml86 committed Oct 18, 2024
1 parent eda46f2 commit 2c63821
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,7 @@ class AstCreator(

protected def assignmentAstForDestructuringEntry(
entry: KtDestructuringDeclarationEntry,
componentNReceiverName: String,
componentNTypeFullName: String,
rhsBaseAst: Ast,
componentIdx: Integer
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val entryTypeFullName = registerType(
Expand All @@ -428,10 +427,6 @@ class AstCreator(
val assignmentLHSNode = identifierNode(entry, entry.getText, entry.getText, entryTypeFullName)
val assignmentLHSAst = astWithRefEdgeMaybe(assignmentLHSNode.name, assignmentLHSNode)

val componentNIdentifierNode =
identifierNode(entry, componentNReceiverName, componentNReceiverName, componentNTypeFullName)
.argumentIndex(0)

val desc = bindingUtils.getCalledFunctionDesc(entry)
val descFullName = desc
.flatMap(nameRenderer.descFullName)
Expand All @@ -441,7 +436,8 @@ class AstCreator(
.getOrElse(s"${Defines.UnresolvedSignature}()")
val fullName = nameRenderer.combineFunctionFullName(descFullName, signature)

val componentNCallCode = s"$componentNReceiverName.${Constants.componentNPrefix}$componentIdx()"
val componentNCallCode =
s"${rhsBaseAst.root.get.asInstanceOf[ExpressionNew].code}.${Constants.componentNPrefix}$componentIdx()"
val componentNCallNode = callNode(
entry,
componentNCallCode,
Expand All @@ -452,9 +448,8 @@ class AstCreator(
Some(entryTypeFullName)
)

val componentNIdentifierAst = astWithRefEdgeMaybe(componentNIdentifierNode.name, componentNIdentifierNode)
val componentNAst =
callAst(componentNCallNode, Seq(), Option(componentNIdentifierAst))
callAst(componentNCallNode, Seq(), Option(rhsBaseAst))

val assignmentCallNode = NodeBuilders.newOperatorCallNode(
Operators.assignment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,13 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
}

val assignmentsForEntries = nonUnderscoreDestructuringEntries(expr).zipWithIndex.map { case (entry, idx) =>
assignmentAstForDestructuringEntry(entry, localForTmpNode.name, localForTmpNode.typeFullName, idx + 1)
val rhsBaseAst =
astWithRefEdgeMaybe(
localForTmpNode.name,
identifierNode(entry, localForTmpNode.name, localForTmpNode.name, localForTmpNode.typeFullName)
.argumentIndex(0)
)
assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1)
}
localsForEntries ++ Seq(localForTmpAst) ++
Seq(tmpAssignmentAst) ++ tmpAssignmentPrologue ++ assignmentsForEntries
Expand All @@ -397,12 +403,11 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
logger.warn(s"Unhandled case for destructuring declaration: `${expr.getText}` in this file `$relativizedPath`.")
return Seq()
}
val destructuringRHS = typedInit.get

val initTypeFullName = registerType(typeInfoProvider.typeFullName(typedInit.get, TypeConstants.any))
val assignmentsForEntries =
nonUnderscoreDestructuringEntries(expr).zipWithIndex.map { case (entry, idx) =>
assignmentAstForDestructuringEntry(entry, destructuringRHS.getText, initTypeFullName, idx + 1)
val rhsBaseAst = astForNameReference(typedInit.get, Some(1), None)
assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1)
}
val localsForEntries = localsForDestructuringEntries(expr)
localsForEntries ++ assignmentsForEntries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewMember
import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal
import org.jetbrains.kotlin.descriptors.{ClassifierDescriptor, PropertyDescriptor, ValueDescriptor}
import org.jetbrains.kotlin.descriptors.impl.PropertyDescriptorImpl
import org.jetbrains.kotlin.psi.KtAnnotationEntry
import org.jetbrains.kotlin.psi.KtClassLiteralExpression
import org.jetbrains.kotlin.psi.KtConstantExpression
Expand Down Expand Up @@ -105,7 +107,9 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
private def astForNameReferenceToType(expr: KtNameReferenceExpression, argIdx: Option[Int])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
val typeFullName = registerType(typeInfoProvider.typeFullName(expr, TypeConstants.any))
val declDesc =
bindingUtils.getDeclDesc(expr).collect { case classifierDesc: ClassifierDescriptor => classifierDesc }
val typeFullName = registerType(declDesc.flatMap(nameRenderer.descFullName).getOrElse(TypeConstants.any))
val referencesCompanionObject = typeInfoProvider.isRefToCompanionObject(expr)
if (referencesCompanionObject) {
val argAsts = List(
Expand All @@ -130,11 +134,18 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
private def astForNameReferenceToMember(expr: KtNameReferenceExpression, argIdx: Option[Int])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
val typeFullName = registerType(typeInfoProvider.typeFullName(expr, TypeConstants.any))
val referenceTargetTypeFullName = registerType(
typeInfoProvider.referenceTargetTypeFullName(expr, TypeConstants.any)
)
val thisNode = identifierNode(expr, Constants.this_, Constants.this_, referenceTargetTypeFullName)
val declDesc = bindingUtils.getDeclDesc(expr).collect { case propDesc: PropertyDescriptor => propDesc }
val typeFullName = declDesc
.flatMap(desc => nameRenderer.typeFullName(desc.getType))
.getOrElse(TypeConstants.any)
registerType(typeFullName)

val baseTypeFullName = declDesc
.flatMap(desc => nameRenderer.typeFullName(desc.getDispatchReceiverParameter.getType))
.getOrElse(TypeConstants.any)
registerType(baseTypeFullName)

val thisNode = identifierNode(expr, Constants.this_, Constants.this_, baseTypeFullName)
val thisAst = astWithRefEdgeMaybe(Constants.this_, thisNode)
val _fieldIdentifierNode = fieldIdentifierNode(expr, expr.getReferencedName, expr.getReferencedName)
val node = NodeBuilders.newOperatorCallNode(
Expand All @@ -152,17 +163,19 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
argIdx: Option[Int],
argName: Option[String] = None
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val typeFromScopeMaybe = scope.lookupVariable(expr.getIdentifier.getText) match {
case Some(n: NewLocal) => Some(n.typeFullName)
case Some(n: NewMethodParameterIn) => Some(n.typeFullName)
case _ => None
}
val typeFromProvider = typeInfoProvider.typeFullName(expr, Defines.UnresolvedNamespace)
val typeFullName =
typeFromScopeMaybe match {
case Some(fullName) => registerType(fullName)
case None => registerType(typeFromProvider)
val declDesc = bindingUtils.getDeclDesc(expr).collect { case valueDesc: ValueDescriptor => valueDesc }
val typeFullName = declDesc
.flatMap(desc => nameRenderer.typeFullName(desc.getType))
.orElse {
val typeFromScopeMaybe = scope.lookupVariable(expr.getIdentifier.getText) match {
case Some(n: NewLocal) => Some(n.typeFullName)
case Some(n: NewMethodParameterIn) => Some(n.typeFullName)
case _ => None
}
typeFromScopeMaybe
}
.getOrElse(TypeConstants.any)

val name = expr.getIdentifier.getText
val node =
withArgumentName(withArgumentIndex(identifierNode(expr, name, name, typeFullName), argIdx), argName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) {
val assignmentsForEntries =
destructuringDeclEntries.asScala.filterNot(_.getText == Constants.unusedDestructuringEntryText).zipWithIndex.map {
case (entry, idx) =>
assignmentAstForDestructuringEntry(entry, localForTmp.name, localForTmp.typeFullName, idx + 1)
val rhsBaseAst =
astWithRefEdgeMaybe(
localForTmp.name,
identifierNode(entry, localForTmp.name, localForTmp.name, localForTmp.typeFullName)
.argumentIndex(0)
)
assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1)
}

val stmtAsts = astsForExpression(expr.getBody, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlin.descriptors.annotations.AnnotationDescriptor
import org.jetbrains.kotlin.descriptors.{
ClassDescriptor,
ConstructorDescriptor,
DeclarationDescriptor,
FunctionDescriptor,
TypeAliasDescriptor,
VariableDescriptor
Expand All @@ -17,6 +18,7 @@ import org.jetbrains.kotlin.psi.{
KtFunctionLiteral,
KtNamedFunction,
KtParameter,
KtReferenceExpression,
KtTypeAlias,
KtTypeReference
}
Expand Down Expand Up @@ -93,6 +95,10 @@ class BindingContextUtils(bindingContext: BindingContext) {
bindingContext.get(BindingContext.ANNOTATION, entry)
}

def getDeclDesc(nameRefExpr: KtReferenceExpression): Option[DeclarationDescriptor] = {
Option(bindingContext.get(BindingContext.REFERENCE_TARGET, nameRefExpr))
}

def getExprType(expr: KtExpression): Option[KotlinType] = {
Option(bindingContext.get(BindingContext.EXPRESSION_TYPE_INFO, expr))
.flatMap(typeInfo => Option(typeInfo.getType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,29 +201,6 @@ class DefaultTypeInfoProvider(val bindingContext: BindingContext, typeRenderer:
.map(_ => bindingContext.get(BindingContext.REFERENCE_TARGET, expr))
}

def referenceTargetTypeFullName(expr: KtNameReferenceExpression, defaultValue: String): String = {
descriptorForNameReference(expr)
.collect { case desc: PropertyDescriptorImpl => typeRenderer.renderFqNameForDesc(desc.getContainingDeclaration) }
.getOrElse(defaultValue)
}

def typeFullName(expr: KtNameReferenceExpression, defaultValue: String): String = {
descriptorForNameReference(expr)
.flatMap {
case typedDesc: ValueDescriptor => Some(typeRenderer.render(typedDesc.getType))
// TODO: add test cases for the LazyClassDescriptors (`okio` codebase serves as good example)
case typedDesc: LazyClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType))
case typedDesc: LazyJavaClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType))
case typedDesc: DeserializedClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType))
case typedDesc: EnumEntrySyntheticClassDescriptor => Some(typeRenderer.render(typedDesc.getDefaultType))
case typedDesc: LazyPackageViewDescriptorImpl => Some(typeRenderer.renderFqNameForDesc(typedDesc))
case unhandled: Any =>
logger.debug(s"Unhandled class type info fetch in for `${expr.getText}` with class `${unhandled.getClass}`.")
None
case null => None
}
.getOrElse(defaultValue)
}
def typeFromImports(name: String, file: KtFile): Option[String] = {
file.getImportList.getImports.asScala.flatMap { directive =>
if (directive.getImportedName != null && directive.getImportedName.toString == name.stripSuffix("?"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ trait TypeInfoProvider(val typeRenderer: TypeRenderer = new TypeRenderer()) {

def propertyType(expr: KtProperty, defaultValue: String): String

def typeFullName(expr: KtNameReferenceExpression, defaultValue: String): String

def referenceTargetTypeFullName(expr: KtNameReferenceExpression, defaultValue: String): String

def isReferenceToClass(expr: KtNameReferenceExpression): Boolean

def bindingKind(expr: KtQualifiedExpression): CallKind
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CallsToFieldAccessTests extends KotlinCode2CpgFixture(withOssDataflow = fa
val List(c) = cpg.call.codeExact("println(x)").argument.isCall.l
c.code shouldBe "this.x"
c.name shouldBe Operators.fieldAccess
c.typeFullName shouldBe "java.lang.String"
c.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH
c.lineNumber shouldBe Some(6)
c.columnNumber shouldBe Some(16)
Expand Down

0 comments on commit 2c63821

Please sign in to comment.