Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 191 additions & 4 deletions packages/pyright-internal/src/analyzer/dataClasses.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,60 @@ export function synthesizeDataClassMethods(
// based on whether this is a NamedTuple or a dataclass.
const constructorType = isNamedTuple ? newType : initType;

// Detect if this class is a Pydantic BaseModel (by MRO full name heuristic).
const isPydanticModel = classType.shared.mro.some(
(m) =>
isClass(m) && (m.shared.fullName === 'pydantic.main.BaseModel' || m.shared.fullName.endsWith('.BaseModel'))
);

// Maintain a list of "type evaluators".
type EntryTypeEvaluator = () => Type;
const localEntryTypeEvaluator: { entry: DataClassEntry; evaluator: EntryTypeEvaluator }[] = [];
let sawKeywordOnlySeparator = false;
// Indicates that at least one field uses a dynamic alias (e.g., validation_alias=AliasChoices or alias_generator)
// In this case, we will relax constructor parameter checking by adding **kwargs and excluding those fields from __init__.
let sawDynamicAlias = false;

ClassType.getSymbolTable(classType).forEach((symbol, name) => {
// Early handling for Pydantic model_config regardless of typing/annotation on the symbol
if (name === 'model_config') {
const decls = symbol.getDeclarations();
for (const decl of decls) {
// We care only about variable assignments
if (decl.type !== DeclarationType.Variable) {
continue;
}
// Find the assignment statement for this declaration
let stmt: ParseNode | undefined = decl.node;
while (stmt && stmt.nodeType !== ParseNodeType.Assignment) {
stmt = stmt.parent;
}
if (stmt && stmt.nodeType === ParseNodeType.Assignment) {
const right = stmt.d.rightExpr;
if (right.nodeType === ParseNodeType.Call) {
// alias_generator implies dynamic aliases
const hasAliasGen = !!right.d.args.find((arg) => arg.d.name?.d.value === 'alias_generator');
if (hasAliasGen) {
sawDynamicAlias = true;
}
const fileInfo = AnalyzerNodeInfo.getFileInfo(node);
const behaviors = (classType.shared.dataClassBehaviors ||= { fieldDescriptorNames: [] });
const popArg = right.d.args.find((arg) => arg.d.name?.d.value === 'populate_by_name');
if (popArg?.d.valueExpr) {
const val = evaluateStaticBoolExpression(
popArg.d.valueExpr,
fileInfo.executionEnvironment,
fileInfo.definedConstants
);
if (val !== undefined) {
behaviors.populateByName = val;
}
}
// ignore frozen for now in tests
}
}
}
}
if (symbol.isIgnoredForProtocolMatch()) {
return;
}
Expand Down Expand Up @@ -331,6 +379,14 @@ export function synthesizeDataClassMethods(
if (defaultValueArg?.d.valueExpr) {
defaultExpr = defaultValueArg.d.valueExpr;
}
// Support positional default as the first argument to Field(...)
if (!hasDefault) {
const firstPositional = statement.d.rightExpr.d.args.find((arg) => !arg.d.name);
if (firstPositional?.d.valueExpr) {
hasDefault = true;
defaultExpr = firstPositional.d.valueExpr;
}
}

const defaultFactoryArg = statement.d.rightExpr.d.args.find(
(arg) => arg.d.name?.d.value === 'default_factory' || arg.d.name?.d.value === 'factory'
Expand All @@ -343,16 +399,60 @@ export function synthesizeDataClassMethods(
defaultExpr = defaultFactoryArg.d.valueExpr;
}

const aliasArg = statement.d.rightExpr.d.args.find((arg) => arg.d.name?.d.value === 'alias');
if (aliasArg) {
// Prefer `validation_alias` over `alias` if both are provided. for pydantic
const validationAliasArg = statement.d.rightExpr.d.args.find(
(arg) => arg.d.name?.d.value === 'validation_alias'
);
const aliasArg =
validationAliasArg ??
statement.d.rightExpr.d.args.find((arg) => arg.d.name?.d.value === 'alias');
if (aliasArg && aliasArg.d.valueExpr) {
const valueType = evaluator.getTypeOfExpression(aliasArg.d.valueExpr).type;
if (
isClassInstance(valueType) &&
ClassType.isBuiltIn(valueType, 'str') &&
isLiteralType(valueType)
) {
// Static, literal alias: use it as the constructor parameter name.
aliasName = valueType.priv.literalValue as string;
} else {
// Dynamic alias (e.g., AliasChoices or computed). We can't know the name statically,
// so exclude this field from the generated __init__ signature and allow **kwargs.
includeInInit = false;
sawDynamicAlias = true;
}
}

// Detect pydantic model_config settings on this class.
if (
variableNameNode?.d.value === 'model_config' &&
statement.d.rightExpr.nodeType === ParseNodeType.Call
) {
// alias_generator implies dynamic aliases
const hasAliasGen = !!statement.d.rightExpr.d.args.find(
(arg) => arg.d.name?.d.value === 'alias_generator'
);
if (hasAliasGen) {
sawDynamicAlias = true;
}

// Extract populate_by_name and frozen flags from ConfigDict
const fileInfo = AnalyzerNodeInfo.getFileInfo(node);
const behaviors = (classType.shared.dataClassBehaviors ||= { fieldDescriptorNames: [] });
const popArg = statement.d.rightExpr.d.args.find(
(arg) => arg.d.name?.d.value === 'populate_by_name'
);
if (popArg?.d.valueExpr) {
const val = evaluateStaticBoolExpression(
popArg.d.valueExpr,
fileInfo.executionEnvironment,
fileInfo.definedConstants
);
if (val !== undefined) {
behaviors.populateByName = val;
}
}
// ignore frozen for now in tests
}

const converterArg = statement.d.rightExpr.d.args.find(
Expand All @@ -363,6 +463,36 @@ export function synthesizeDataClassMethods(
}
}
}

// Detect pydantic model_config on this class assignment
if (
name === 'model_config' &&
statement.nodeType === ParseNodeType.Assignment &&
statement.d.rightExpr.nodeType === ParseNodeType.Call
) {
const hasAliasGen = !!statement.d.rightExpr.d.args.find(
(arg) => arg.d.name?.d.value === 'alias_generator'
);
if (hasAliasGen) {
sawDynamicAlias = true;
}
const fileInfo = AnalyzerNodeInfo.getFileInfo(node);
const behaviors = (classType.shared.dataClassBehaviors ||= { fieldDescriptorNames: [] });
const popArg = statement.d.rightExpr.d.args.find(
(arg) => arg.d.name?.d.value === 'populate_by_name'
);
if (popArg?.d.valueExpr) {
const val = evaluateStaticBoolExpression(
popArg.d.valueExpr,
fileInfo.executionEnvironment,
fileInfo.definedConstants
);
if (val !== undefined) {
behaviors.populateByName = val;
}
}
// ignore frozen for now in tests
}
} else if (statement.nodeType === ParseNodeType.TypeAnnotation) {
if (statement.d.valueExpr.nodeType === ParseNodeType.Name) {
variableNameNode = statement.d.valueExpr;
Expand Down Expand Up @@ -392,6 +522,12 @@ export function synthesizeDataClassMethods(
if (variableNameNode && variableTypeEvaluator) {
const variableName = variableNameNode.d.value;

// In Pydantic BaseModel, attributes starting with an underscore are not fields
// and should not be accepted by the constructor.
if (isPydanticModel && variableName.startsWith('_')) {
includeInInit = false;
}

// Named tuples don't allow attributes that begin with an underscore.
if (isNamedTuple && variableName.startsWith('_')) {
evaluator.addDiagnostic(
Expand Down Expand Up @@ -565,7 +701,7 @@ export function synthesizeDataClassMethods(
if (!skipSynthesizeInit && !hasExistingInitMethod) {
if (allAncestorsKnown) {
fullDataClassEntries.forEach((entry) => {
if (entry.includeInInit) {
if (entry.includeInInit && !sawDynamicAlias) {
let defaultType: Type | undefined;

// If the type refers to Self of the parent class, we need to
Expand Down Expand Up @@ -643,12 +779,18 @@ export function synthesizeDataClassMethods(
);
}

const isPopulateDual = !!(
classType.shared.dataClassBehaviors?.populateByName &&
entry.alias &&
entry.name !== entry.alias
);
const optionalDefault = isPopulateDual ? AnyType.create(/* isEllipsis */ true) : defaultType;
const param = FunctionParam.create(
ParamCategory.Simple,
effectiveType,
FunctionParamFlags.TypeDeclared,
effectiveName,
defaultType,
optionalDefault,
entry.defaultExpr
);

Expand All @@ -658,6 +800,37 @@ export function synthesizeDataClassMethods(
FunctionType.addParam(constructorType, param);
}

// If configured to populate by name, accept both the alias and the original field name.
if (
classType.shared.dataClassBehaviors?.populateByName &&
entry.alias &&
entry.name !== entry.alias
) {
const paramByName = FunctionParam.create(
ParamCategory.Simple,
effectiveType,
FunctionParamFlags.TypeDeclared,
entry.name,
AnyType.create(/* isEllipsis */ true),
entry.defaultExpr
);
if (entry.isKeywordOnly) {
keywordOnlyParams.push(paramByName);
} else {
FunctionType.addParam(constructorType, paramByName);
}
if (replaceType) {
const paramByNameWithDefault = FunctionParam.create(
paramByName.category,
paramByName._type,
paramByName.flags,
paramByName.name,
AnyType.create(/* isEllipsis */ true)
);
FunctionType.addParam(replaceType, paramByNameWithDefault);
}
}

if (replaceType) {
const paramWithDefault = FunctionParam.create(
param.category,
Expand All @@ -672,6 +845,20 @@ export function synthesizeDataClassMethods(
}
});

// If we saw any dynamic aliases, add a **kwargs parameter to relax parameter checking
if (sawDynamicAlias) {
const kwargsParam = FunctionParam.create(
ParamCategory.KwargsDict,
UnknownType.create(),
FunctionParamFlags.TypeDeclared,
'kwargs'
);
FunctionType.addParam(constructorType, kwargsParam);
if (replaceType) {
FunctionType.addParam(replaceType, kwargsParam);
}
}

if (keywordOnlyParams.length > 0) {
FunctionType.addKeywordOnlyParamSeparator(constructorType);
keywordOnlyParams.forEach((param) => {
Expand Down
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18359,7 +18359,8 @@ export function createTypeEvaluator(
}

// Synthesize dataclass methods.
if (ClassType.isDataClass(classType) || isNamedTupleSubclass) {
// Also synthesize for classes that receive dataclass-like behaviors via dataclass_transform.
if (ClassType.isDataClass(classType) || isNamedTupleSubclass || !!classType.shared.dataClassBehaviors) {
const skipSynthesizedInit = ClassType.isDataClassSkipGenerateInit(classType);
let hasExistingInitMethod = skipSynthesizedInit;

Expand Down
2 changes: 2 additions & 0 deletions packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,8 @@ export interface DataClassBehaviors {
keywordOnly?: boolean;
frozen?: boolean;
frozenDefault?: boolean;
// Pydantic-specific: when true, accept both field name and alias in __init__
populateByName?: boolean;
fieldDescriptorNames: string[];
}

Expand Down
54 changes: 54 additions & 0 deletions packages/pyright-internal/src/tests/pydantic.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) BasedSoft Corporation.
* Licensed under the MIT license.
* Author: KotlinIsland
*
* Unit tests for pydantic support.
*/

import * as TestUtils from './testUtils';
import { DiagnosticRule } from '../common/diagnosticRules';

test('aliases', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['pydanticAlias.py']);
TestUtils.validateResultsButBased(analysisResults, {
errors: [
{
code: DiagnosticRule.reportCallIssue,
line: 18,
message: 'Arguments missing for parameters "b1", "b2", "b3"',
},
{ code: DiagnosticRule.reportCallIssue, line: 19, message: 'No parameter named "a1"' },
{ code: DiagnosticRule.reportCallIssue, line: 20, message: 'No parameter named "a2"' },
{ code: DiagnosticRule.reportCallIssue, line: 21, message: 'No parameter named "a3"' },
{ code: DiagnosticRule.reportCallIssue, line: 22, message: 'No parameter named "z"' },
{
code: DiagnosticRule.reportAttributeAccessIssue,
line: 38,
message: 'Cannot access attribute "b1" for class "M"\n\u00A0\u00A0Attribute "b1" is unknown',
},
{
code: DiagnosticRule.reportAttributeAccessIssue,
line: 39,
message: 'Cannot access attribute "b2" for class "M"\n\u00A0\u00A0Attribute "b2" is unknown',
},
{
code: DiagnosticRule.reportAttributeAccessIssue,
line: 40,
message: 'Cannot access attribute "b3" for class "M"\n\u00A0\u00A0Attribute "b3" is unknown',
},
{
code: DiagnosticRule.reportAttributeAccessIssue,
line: 41,
message: 'Cannot access attribute "z" for class "M"\n\u00A0\u00A0Attribute "z" is unknown',
},
],
});
});

test('other features', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['pydanticFeatures.py']);
TestUtils.validateResultsButBased(analysisResults, {
errors: [{ code: DiagnosticRule.reportCallIssue, line: 21, message: 'No parameter named "z"' }],
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# minimal stub for testing

from .config import ConfigDict as ConfigDict
from .aliases import AliasChoices as AliasChoices
from .main import BaseModel as BaseModel
from .fields import Field as Field
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class AliasChoices:
def __init__(self, *choices: str) -> None: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ConfigDict(dict):
def __init__(self, **kwargs): ...
12 changes: 12 additions & 0 deletions packages/pyright-internal/src/tests/samples/pydantic/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any, Callable, Tuple, dataclass_transform

def Field(
default: Any = ...,
default_factory: Callable[[], Any] | None = ...,
alias: str | None = ...,
validation_alias: Any = ...,
kw_only: bool | None = ...,
init: bool | None = ...,
converter: Any = ...,
factory: Callable[[], Any] | None = ...,
) -> Any: ...
10 changes: 10 additions & 0 deletions packages/pyright-internal/src/tests/samples/pydantic/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Any, Callable, Tuple, dataclass_transform
from abc import ABCMeta

from .fields import Field

@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ModelMetaclass(ABCMeta): ...


class BaseModel(metaclass=ModelMetaclass): ...
Loading
Loading