Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions framework/fit/python/fit_cli/utils/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

type_errors = []

def parse_type(annotation):
def parse_type(annotation, custom_classes):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不对,按照java这边写插件举个例子,你可以看下json schema的官网规范
image
对于这样的一个fitable,这里返回的结构是
image
也就是除非解析不了的类型是object,否则如果是用户自定义的类,就要嵌入的去解析,我看了下你这边的逻辑,返回的结果是这样的
image
直到类就结束了,这样不符合预期

"""解析参数类型"""
global type_errors

Expand All @@ -36,6 +36,8 @@ def parse_type(annotation):
elif isinstance(annotation, ast.Name):
if annotation.id in TYPE_MAP:
return TYPE_MAP[annotation.id], None, True
elif annotation.id in custom_classes:
return "object", None, True
else:
type_errors.append(f"不支持的类型: {annotation.id}")
return "invalid", None, True
Expand All @@ -49,7 +51,7 @@ def parse_type(annotation):

# List[int]
if container in ("list", "List"):
item_type, _, _ = parse_type(annotation.slice)
item_type, _, _ = parse_type(annotation.slice, custom_classes)
if item_type == "invalid":
type_errors.append(f"不支持的列表元素类型: {annotation.slice}")
return "invalid", None, True
Expand All @@ -61,7 +63,7 @@ def parse_type(annotation):

# Optional[int]
elif container == "Optional":
inner_type, inner_items, _ = parse_type(annotation.slice)
inner_type, inner_items, _ = parse_type(annotation.slice, custom_classes)
if inner_type == "invalid":
type_errors.append(f"不支持的Optional类型: {annotation.slice}")
return "invalid", None, False
Expand All @@ -76,22 +78,22 @@ def parse_type(annotation):
items = []
if isinstance(annotation.slice, ast.Tuple):
for elt in annotation.slice.elts:
item_type, _, _ = parse_type(elt)
item_type, _, _ = parse_type(elt, custom_classes)
if item_type == "invalid":
type_errors.append(f"不支持的元组元素类型: {ast.dump(elt)}")
return "invalid", None, True
items.append({"type":item_type})
return "array", f"{items}", True
else:
item_type, _, _ = parse_type(annotation.slice)
item_type, _, _ = parse_type(annotation.slice, custom_classes)
if item_type == "invalid":
type_errors.append(f"不支持的元组元素类型: {ast.dump(annotation.slice)}")
return "invalid", None, True
return "array", {"type":item_type}, True

# Set[int]
elif container in ("set", "Set"):
item_type, _, _ = parse_type(annotation.slice)
item_type, _, _ = parse_type(annotation.slice, custom_classes)
if item_type == "invalid":
type_errors.append(f"不支持的集合元素类型: {annotation.slice}")
return "invalid", None, True
Expand All @@ -106,7 +108,7 @@ def parse_type(annotation):
return "invalid", None, True


def parse_parameters(args):
def parse_parameters(args, custom_classes):
"""解析函数参数"""
properties = {}
order = []
Expand All @@ -115,7 +117,7 @@ def parse_parameters(args):
for arg in args.args:
arg_name = arg.arg
order.append(arg_name)
arg_type, items, is_required = parse_type(arg.annotation)
arg_type, items, is_required = parse_type(arg.annotation, custom_classes)
# 定义参数
prop_def = {
"defaultValue": "",
Expand All @@ -132,12 +134,12 @@ def parse_parameters(args):
return properties, order, required


def parse_return(annotation):
def parse_return(annotation, custom_classes):
"""解析返回值类型"""
if not annotation:
return {"type": "string", "convertor": ""}

return_type, items, _ = parse_type(annotation)
return_type, items, _ = parse_type(annotation, custom_classes)
ret = {
"type": return_type,
**({"items": items} if items else {}),
Expand All @@ -155,8 +157,12 @@ def parse_python_file(file_path: Path):
py_name = file_path.stem
definitions = []
tool_groups = []
custom_classes = set()

for node in tree.body:
if isinstance(node, ast.ClassDef):
custom_classes.add(node.name)

if isinstance(node, ast.FunctionDef):
func_name = node.name
# 默认描述
Expand Down Expand Up @@ -185,8 +191,8 @@ def parse_python_file(file_path: Path):
continue

# 解析参数和返回值
properties, order, required = parse_parameters(node.args)
return_schema = parse_return(node.returns)
properties, order, required = parse_parameters(node.args, custom_classes)
return_schema = parse_return(node.returns, custom_classes)

# definition schema
definition_schema = {
Expand Down