diff --git a/.github/workflows/docker-build-push-mainland.yml b/.github/workflows/docker-build-push-mainland.yml index 6099a4817..279cdb68e 100644 --- a/.github/workflows/docker-build-push-mainland.yml +++ b/.github/workflows/docker-build-push-mainland.yml @@ -3,6 +3,15 @@ name: Docker Build and Push All Images to tencentyun on: workflow_dispatch: inputs: + version: + description: 'Image version tag (e.g. v1.0.0 or latest)' + required: true + default: 'latest' + push_latest: + description: 'Also push latest tag' + required: false + default: false + type: boolean runner_label_json: description: 'runner array in json format (e.g. ["ubuntu-latest"] or ["self-hosted"])' required: true @@ -23,10 +32,16 @@ jobs: uses: actions/checkout@v4 - name: Build main image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent:amd64 -f make/main/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . + docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-amd64 -f make/main/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push main image (amd64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-amd64 + - name: Tag main image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-amd64 ccr.ccs.tencentyun.com/nexent-hub/nexent:amd64 + - name: Push latest main image (amd64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent:amd64 build-and-push-main-arm64: @@ -43,10 +58,16 @@ jobs: uses: actions/checkout@v4 - name: Build main image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent:arm64 -f make/main/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . + docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-arm64 -f make/main/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push main image (arm64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-arm64 + - name: Tag main image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-arm64 ccr.ccs.tencentyun.com/nexent-hub/nexent:arm64 + - name: Push latest main image (arm64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent:arm64 build-and-push-data-process-amd64: @@ -72,10 +93,16 @@ jobs: rm -rf .git .gitattributes - name: Build data process image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:amd64 -f make/data_process/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . + docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-amd64 -f make/data_process/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push data process image (amd64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-amd64 + - name: Tag data process image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-amd64 ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:amd64 + - name: Push latest data process image (amd64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:amd64 build-and-push-data-process-arm64: @@ -101,10 +128,16 @@ jobs: rm -rf .git .gitattributes - name: Build data process image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:arm64 -f make/data_process/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . + docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-arm64 -f make/data_process/Dockerfile --build-arg MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple --build-arg APT_MIRROR=tsinghua . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push data process image (arm64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-arm64 + - name: Tag data process image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-arm64 ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:arm64 + - name: Push latest data process image (arm64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:arm64 build-and-push-web-amd64: @@ -121,10 +154,16 @@ jobs: uses: actions/checkout@v4 - name: Build web image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-web:amd64 -f make/web/Dockerfile --build-arg MIRROR=https://registry.npmmirror.com --build-arg APK_MIRROR=tsinghua . + docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-amd64 -f make/web/Dockerfile --build-arg MIRROR=https://registry.npmmirror.com --build-arg APK_MIRROR=tsinghua . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push web image (amd64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-amd64 + - name: Tag web image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-amd64 ccr.ccs.tencentyun.com/nexent-hub/nexent-web:amd64 + - name: Push latest web image (amd64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-web:amd64 build-and-push-web-arm64: @@ -141,10 +180,16 @@ jobs: uses: actions/checkout@v4 - name: Build web image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-web:arm64 -f make/web/Dockerfile --build-arg MIRROR=https://registry.npmmirror.com --build-arg APK_MIRROR=tsinghua . + docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-arm64 -f make/web/Dockerfile --build-arg MIRROR=https://registry.npmmirror.com --build-arg APK_MIRROR=tsinghua . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push web image (arm64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-arm64 + - name: Tag web image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-arm64 ccr.ccs.tencentyun.com/nexent-hub/nexent-web:arm64 + - name: Push latest web image (arm64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-web:arm64 build-and-push-terminal-amd64: @@ -161,10 +206,16 @@ jobs: uses: actions/checkout@v4 - name: Build terminal image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:amd64 -f make/terminal/Dockerfile . + docker buildx build --platform linux/amd64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 -f make/terminal/Dockerfile . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push terminal image (amd64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 + - name: Tag terminal image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:amd64 + - name: Push latest terminal image (amd64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:amd64 build-and-push-terminal-arm64: @@ -181,10 +232,16 @@ jobs: uses: actions/checkout@v4 - name: Build terminal image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:arm64 -f make/terminal/Dockerfile . + docker buildx build --platform linux/arm64 --load -t ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 -f make/terminal/Dockerfile . - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Push terminal image (arm64) to Tencent Cloud + run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 + - name: Tag terminal image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:arm64 + - name: Push latest terminal image (arm64) to Tencent Cloud + if: inputs.push_latest == 'true' run: docker push ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:arm64 manifest-push-main: @@ -196,6 +253,13 @@ jobs: - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Create and push manifest for main (Tencent Cloud) + run: | + docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }} \ + ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-amd64 \ + ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }}-arm64 + docker manifest push ccr.ccs.tencentyun.com/nexent-hub/nexent:${{ inputs.version }} + - name: Create and push latest manifest for main (Tencent Cloud) + if: inputs.push_latest == 'true' run: | docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent:latest \ ccr.ccs.tencentyun.com/nexent-hub/nexent:amd64 \ @@ -211,6 +275,13 @@ jobs: - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Create and push manifest for data-process (Tencent Cloud) + run: | + docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }} \ + ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-amd64 \ + ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }}-arm64 + docker manifest push ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${{ inputs.version }} + - name: Create and push latest manifest for data-process (Tencent Cloud) + if: inputs.push_latest == 'true' run: | docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:latest \ ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:amd64 \ @@ -226,6 +297,13 @@ jobs: - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Create and push manifest for web (Tencent Cloud) + run: | + docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }} \ + ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-amd64 \ + ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }}-arm64 + docker manifest push ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${{ inputs.version }} + - name: Create and push latest manifest for web (Tencent Cloud) + if: inputs.push_latest == 'true' run: | docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent-web:latest \ ccr.ccs.tencentyun.com/nexent-hub/nexent-web:amd64 \ @@ -241,6 +319,13 @@ jobs: - name: Login to Tencent Cloud run: echo ${{ secrets.TCR_PASSWORD }} | docker login ccr.ccs.tencentyun.com --username=${{ secrets.TCR_USERNAME }} --password-stdin - name: Create and push manifest for terminal (Tencent Cloud) + run: | + docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }} \ + ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 \ + ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 + docker manifest push ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${{ inputs.version }} + - name: Create and push latest manifest for terminal (Tencent Cloud) + if: inputs.push_latest == 'true' run: | docker manifest create ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:latest \ ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:amd64 \ diff --git a/.github/workflows/docker-build-push-overseas.yml b/.github/workflows/docker-build-push-overseas.yml index e9c48d272..d0483c05c 100644 --- a/.github/workflows/docker-build-push-overseas.yml +++ b/.github/workflows/docker-build-push-overseas.yml @@ -3,6 +3,15 @@ name: Docker Build and Push All Images to DockerHub on: workflow_dispatch: inputs: + version: + description: 'Image version tag (e.g. v1.0.0 or latest)' + required: true + default: 'latest' + push_latest: + description: 'Also push latest tag' + required: false + default: false + type: boolean runner_label_json: description: 'runner array in json format (e.g. ["ubuntu-latest"] or ["self-hosted"])' required: true @@ -23,10 +32,16 @@ jobs: uses: actions/checkout@v4 - name: Build main image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 -t nexent/nexent:amd64 --load -f make/main/Dockerfile . + docker buildx build --platform linux/amd64 -t nexent/nexent:${{ inputs.version }}-amd64 --load -f make/main/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push main image (amd64) to DockerHub + run: docker push nexent/nexent:${{ inputs.version }}-amd64 + - name: Tag main image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent:${{ inputs.version }}-amd64 nexent/nexent:amd64 + - name: Push latest main image (amd64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent:amd64 build-and-push-main-arm64: @@ -43,10 +58,16 @@ jobs: uses: actions/checkout@v4 - name: Build main image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 -t nexent/nexent:arm64 --load -f make/main/Dockerfile . + docker buildx build --platform linux/arm64 -t nexent/nexent:${{ inputs.version }}-arm64 --load -f make/main/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push main image (arm64) to DockerHub + run: docker push nexent/nexent:${{ inputs.version }}-arm64 + - name: Tag main image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent:${{ inputs.version }}-arm64 nexent/nexent:arm64 + - name: Push latest main image (arm64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent:arm64 build-and-push-data-process-amd64: @@ -72,10 +93,16 @@ jobs: rm -rf .git .gitattributes - name: Build data process image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 -t nexent/nexent-data-process:amd64 --load -f make/data_process/Dockerfile . + docker buildx build --platform linux/amd64 -t nexent/nexent-data-process:${{ inputs.version }}-amd64 --load -f make/data_process/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push data process image (amd64) to DockerHub + run: docker push nexent/nexent-data-process:${{ inputs.version }}-amd64 + - name: Tag data process image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent-data-process:${{ inputs.version }}-amd64 nexent/nexent-data-process:amd64 + - name: Push latest data process image (amd64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent-data-process:amd64 build-and-push-data-process-arm64: @@ -101,10 +128,16 @@ jobs: rm -rf .git .gitattributes - name: Build data process image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 -t nexent/nexent-data-process:arm64 --load -f make/data_process/Dockerfile . + docker buildx build --platform linux/arm64 -t nexent/nexent-data-process:${{ inputs.version }}-arm64 --load -f make/data_process/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push data process image (arm64) to DockerHub + run: docker push nexent/nexent-data-process:${{ inputs.version }}-arm64 + - name: Tag data process image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent-data-process:${{ inputs.version }}-arm64 nexent/nexent-data-process:arm64 + - name: Push latest data process image (arm64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent-data-process:arm64 build-and-push-web-amd64: @@ -121,10 +154,16 @@ jobs: uses: actions/checkout@v4 - name: Build web image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 -t nexent/nexent-web:amd64 --load -f make/web/Dockerfile . + docker buildx build --platform linux/amd64 -t nexent/nexent-web:${{ inputs.version }}-amd64 --load -f make/web/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push web image (amd64) to DockerHub + run: docker push nexent/nexent-web:${{ inputs.version }}-amd64 + - name: Tag web image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent-web:${{ inputs.version }}-amd64 nexent/nexent-web:amd64 + - name: Push latest web image (amd64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent-web:amd64 build-and-push-web-arm64: @@ -141,10 +180,16 @@ jobs: uses: actions/checkout@v4 - name: Build web image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 -t nexent/nexent-web:arm64 --load -f make/web/Dockerfile . + docker buildx build --platform linux/arm64 -t nexent/nexent-web:${{ inputs.version }}-arm64 --load -f make/web/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push web image (arm64) to DockerHub + run: docker push nexent/nexent-web:${{ inputs.version }}-arm64 + - name: Tag web image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent-web:${{ inputs.version }}-arm64 nexent/nexent-web:arm64 + - name: Push latest web image (arm64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent-web:arm64 build-and-push-terminal-amd64: @@ -161,10 +206,16 @@ jobs: uses: actions/checkout@v4 - name: Build terminal image (amd64) and load locally run: | - docker buildx build --platform linux/amd64 -t nexent/nexent-ubuntu-terminal:amd64 --load -f make/terminal/Dockerfile . + docker buildx build --platform linux/amd64 -t nexent/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 --load -f make/terminal/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push terminal image (amd64) to DockerHub + run: docker push nexent/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 + - name: Tag terminal image (amd64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 nexent/nexent-ubuntu-terminal:amd64 + - name: Push latest terminal image (amd64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent-ubuntu-terminal:amd64 build-and-push-terminal-arm64: @@ -181,10 +232,16 @@ jobs: uses: actions/checkout@v4 - name: Build terminal image (arm64) and load locally run: | - docker buildx build --platform linux/arm64 -t nexent/nexent-ubuntu-terminal:arm64 --load -f make/terminal/Dockerfile . + docker buildx build --platform linux/arm64 -t nexent/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 --load -f make/terminal/Dockerfile . - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Push terminal image (arm64) to DockerHub + run: docker push nexent/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 + - name: Tag terminal image (arm64) as latest + if: inputs.push_latest == 'true' + run: docker tag nexent/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 nexent/nexent-ubuntu-terminal:arm64 + - name: Push latest terminal image (arm64) to DockerHub + if: inputs.push_latest == 'true' run: docker push nexent/nexent-ubuntu-terminal:arm64 manifest-push-main: @@ -196,6 +253,13 @@ jobs: - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Create and push manifest for main (DockerHub) + run: | + docker manifest create nexent/nexent:${{ inputs.version }} \ + nexent/nexent:${{ inputs.version }}-amd64 \ + nexent/nexent:${{ inputs.version }}-arm64 + docker manifest push nexent/nexent:${{ inputs.version }} + - name: Create and push latest manifest for main (DockerHub) + if: inputs.push_latest == 'true' run: | docker manifest create nexent/nexent:latest \ nexent/nexent:amd64 \ @@ -211,6 +275,13 @@ jobs: - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Create and push manifest for data-process (DockerHub) + run: | + docker manifest create nexent/nexent-data-process:${{ inputs.version }} \ + nexent/nexent-data-process:${{ inputs.version }}-amd64 \ + nexent/nexent-data-process:${{ inputs.version }}-arm64 + docker manifest push nexent/nexent-data-process:${{ inputs.version }} + - name: Create and push latest manifest for data-process (DockerHub) + if: inputs.push_latest == 'true' run: | docker manifest create nexent/nexent-data-process:latest \ nexent/nexent-data-process:amd64 \ @@ -226,6 +297,13 @@ jobs: - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Create and push manifest for web (DockerHub) + run: | + docker manifest create nexent/nexent-web:${{ inputs.version }} \ + nexent/nexent-web:${{ inputs.version }}-amd64 \ + nexent/nexent-web:${{ inputs.version }}-arm64 + docker manifest push nexent/nexent-web:${{ inputs.version }} + - name: Create and push latest manifest for web (DockerHub) + if: inputs.push_latest == 'true' run: | docker manifest create nexent/nexent-web:latest \ nexent/nexent-web:amd64 \ @@ -241,6 +319,13 @@ jobs: - name: Login to DockerHub run: echo ${{ secrets.DOCKERHUB_TOKEN }} | docker login -u nexent --password-stdin - name: Create and push manifest for terminal (DockerHub) + run: | + docker manifest create nexent/nexent-ubuntu-terminal:${{ inputs.version }} \ + nexent/nexent-ubuntu-terminal:${{ inputs.version }}-amd64 \ + nexent/nexent-ubuntu-terminal:${{ inputs.version }}-arm64 + docker manifest push nexent/nexent-ubuntu-terminal:${{ inputs.version }} + - name: Create and push latest manifest for terminal (DockerHub) + if: inputs.push_latest == 'true' run: | docker manifest create nexent/nexent-ubuntu-terminal:latest \ nexent/nexent-ubuntu-terminal:amd64 \ diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index ca801109f..6e8d17740 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -43,7 +43,8 @@ async def create_model_config_list(tenant_id): model_repo=record["model_repo"], model_name=record["model_name"], ), - url=record["base_url"])) + url=record["base_url"], + ssl_verify=record.get("ssl_verify", True))) # fit for old version, main_model and sub_model use default model main_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) @@ -52,13 +53,15 @@ async def create_model_config_list(tenant_id): api_key=main_model_config.get("api_key", ""), model_name=get_model_name_from_config(main_model_config) if main_model_config.get( "model_name") else "", - url=main_model_config.get("base_url", ""))) + url=main_model_config.get("base_url", ""), + ssl_verify=main_model_config.get("ssl_verify", True))) model_list.append( ModelConfig(cite_name="sub_model", api_key=main_model_config.get("api_key", ""), model_name=get_model_name_from_config(main_model_config) if main_model_config.get( "model_name") else "", - url=main_model_config.get("base_url", ""))) + url=main_model_config.get("base_url", ""), + ssl_verify=main_model_config.get("ssl_verify", True))) return model_list diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 19e382ba1..4869ce440 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -1,7 +1,10 @@ import logging +import re from http import HTTPStatus from typing import List, Optional +from urllib.parse import urlparse, urlunparse, unquote, quote +import httpx from fastapi import APIRouter, Body, File, Form, Header, HTTPException, Path as PathParam, Query, UploadFile from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse @@ -12,6 +15,51 @@ logger = logging.getLogger("file_management_app") + +def build_content_disposition_header(filename: Optional[str]) -> str: + """ + Build a Content-Disposition header that keeps the original filename. + + - ASCII filenames are returned directly. + - Non-ASCII filenames include both an ASCII fallback and RFC 5987 encoded value + so modern browsers keep the original name. + """ + safe_name = (filename or "download").strip() or "download" + + def _sanitize_ascii(value: str) -> str: + # Replace problematic characters that break HTTP headers + # Remove control characters (newlines, carriage returns, tabs, etc.) + # Remove control characters (0x00-0x1F and 0x7F) + sanitized = re.sub(r'[\x00-\x1F\x7F]', '', value) + # Replace problematic characters that break HTTP headers + sanitized = sanitized.replace("\\", "_").replace('"', "_") + # Remove leading/trailing spaces and dots (Windows filename restrictions) + sanitized = sanitized.strip(' .') + return sanitized if sanitized else "download" + + try: + safe_name.encode("ascii") + return f'attachment; filename="{_sanitize_ascii(safe_name)}"' + except UnicodeEncodeError: + try: + encoded = quote(safe_name, safe="") + except Exception: + # quote failure, fallback to sanitized ASCII only + logger.warning("Failed to encode filename '%s', using fallback", safe_name) + return f'attachment; filename="{_sanitize_ascii(safe_name)}"' + + fallback = _sanitize_ascii( + safe_name.encode("ascii", "ignore").decode("ascii") or "download" + ) + return f'attachment; filename="{fallback}"; filename*=UTF-8\'\'{encoded}' + except Exception as exc: # pragma: no cover + logger.warning( + "Failed to encode filename '%s': %s. Using fallback.", + safe_name, + exc, + ) + return 'attachment; filename="download"' + # Create API router file_management_runtime_router = APIRouter(prefix="/file") file_management_config_router = APIRouter(prefix="/file") @@ -98,6 +146,64 @@ async def process_files( ) +@file_management_config_router.get("/download/{object_name:path}") +async def get_storage_file( + object_name: str = PathParam(..., description="File object name"), + download: str = Query("ignore", description="How to get the file"), + expires: int = Query(3600, description="URL validity period (seconds)"), + filename: Optional[str] = Query(None, description="Original filename for download (optional)") +): + """ + Get information, download link, or file stream for a single file + + - **object_name**: File object name + - **download**: Download mode: ignore (default, return file info), stream (return file stream), redirect (redirect to download URL) + - **expires**: URL validity period in seconds (default 3600) + - **filename**: Original filename for download (optional, if not provided, will use object_name) + + Returns file information, download link, or file content + """ + try: + logger.info(f"[get_storage_file] Route matched! object_name={object_name}, download={download}, filename={filename}") + if download == "redirect": + # return a redirect download URL + result = await get_file_url_impl(object_name=object_name, expires=expires) + return RedirectResponse(url=result["url"]) + elif download == "stream": + # return a readable file stream + file_stream, content_type = await get_file_stream_impl(object_name=object_name) + logger.info(f"Streaming file: object_name={object_name}, content_type={content_type}") + + # Use provided filename or extract from object_name + download_filename = filename + if not download_filename: + # Extract filename from object_name (get the last part after the last slash) + download_filename = object_name.split("/")[-1] if "/" in object_name else object_name + + # Build Content-Disposition header with proper encoding for non-ASCII characters + content_disposition = build_content_disposition_header(download_filename) + + return StreamingResponse( + file_stream, + media_type=content_type, + headers={ + "Content-Disposition": content_disposition, + "Cache-Control": "public, max-age=3600", + "ETag": f'"{object_name}"', + } + ) + else: + # return file metadata + return await get_file_url_impl(object_name=object_name, expires=expires) + except Exception as e: + logger.error(f"Failed to get file: object_name={object_name}, error={str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Failed to get file information: {str(e)}" + ) + + + @file_management_runtime_router.post("/storage") async def storage_upload_files( files: List[UploadFile] = File(..., description="List of files to upload"), @@ -158,43 +264,204 @@ async def get_storage_files( ) -@file_management_config_router.get("/storage/{path}/{object_name}") -async def get_storage_file( - object_name: str = PathParam(..., description="File object name"), - download: str = Query("ignore", description="How to get the file"), - expires: int = Query(3600, description="URL validity period (seconds)") +def _ensure_http_scheme(raw_url: str) -> str: + """ + Ensure the provided Datamate URL has an explicit HTTP or HTTPS scheme. + """ + candidate = (raw_url or "").strip() + if not candidate: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="URL cannot be empty" + ) + + parsed = urlparse(candidate) + if parsed.scheme: + if parsed.scheme not in ("http", "https"): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="URL must start with http:// or https://" + ) + return candidate + + if candidate.startswith("//"): + return f"http:{candidate}" + + return f"http://{candidate}" + + +def _normalize_datamate_download_url(raw_url: str) -> str: + """ + Normalize Datamate download URL to ensure it follows /data-management/datasets/{datasetId}/files/{fileId}/download + """ + normalized_source = _ensure_http_scheme(raw_url) + parsed_url = urlparse(normalized_source) + path_segments = [segment for segment in parsed_url.path.split("/") if segment] + + if "data-management" not in path_segments: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Invalid Datamate URL: missing 'data-management' segment" + ) + + try: + dm_index = path_segments.index("data-management") + datasets_index = path_segments.index("datasets", dm_index) + dataset_id = path_segments[datasets_index + 1] + files_index = path_segments.index("files", datasets_index) + file_id = path_segments[files_index + 1] + except (ValueError, IndexError): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Invalid Datamate URL: unable to parse dataset_id or file_id" + ) + + prefix_segments = path_segments[:dm_index] + prefix_path = "/" + "/".join(prefix_segments) if prefix_segments else "" + normalized_path = f"{prefix_path}/data-management/datasets/{dataset_id}/files/{file_id}/download" + + normalized_url = urlunparse(( + parsed_url.scheme, + parsed_url.netloc, + normalized_path, + "", + "", + "" + )) + + return normalized_url + + +def _build_datamate_url_from_parts(base_url: str, dataset_id: str, file_id: str) -> str: + """ + Build Datamate download URL from individual parts + """ + if not base_url: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="base_url is required when dataset_id and file_id are provided" + ) + + base_with_scheme = _ensure_http_scheme(base_url) + parsed_base = urlparse(base_with_scheme) + base_prefix = parsed_base.path.rstrip("/") + + if base_prefix and not base_prefix.endswith("/api"): + if base_prefix.endswith("/"): + base_prefix = f"{base_prefix}api" + else: + base_prefix = f"{base_prefix}/api" + elif not base_prefix: + base_prefix = "/api" + + normalized_path = f"{base_prefix}/data-management/datasets/{dataset_id}/files/{file_id}/download" + + return urlunparse(( + parsed_base.scheme, + parsed_base.netloc, + normalized_path, + "", + "", + "" + )) + + +@file_management_config_router.get("/datamate/download") +async def download_datamate_file( + url: Optional[str] = Query(None, description="Datamate file URL to download"), + base_url: Optional[str] = Query(None, description="Datamate base server URL (e.g., host:port)"), + dataset_id: Optional[str] = Query(None, description="Datamate dataset ID"), + file_id: Optional[str] = Query(None, description="Datamate file ID"), + filename: Optional[str] = Query(None, description="Optional filename for download"), + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Get information, download link, or file stream for a single file + Download file from Datamate knowledge base via HTTP URL - - **object_name**: File object name - - **download**: Download mode: ignore (default, return file info), stream (return file stream), redirect (redirect to download URL) - - **expires**: URL validity period in seconds (default 3600) + - **url**: Full HTTP URL of the file to download (optional) + - **base_url**: Base server URL (e.g., host:port) + - **dataset_id**: Datamate dataset ID + - **file_id**: Datamate file ID + - **filename**: Optional filename for the download (extracted automatically if not provided) + - **authorization**: Optional authorizatio n header to pass to the target URL - Returns file information, download link, or file content + Returns file stream for download """ try: - if download == "redirect": - # return a redirect download URL - result = await get_file_url_impl(object_name=object_name, expires=expires) - return RedirectResponse(url=result["url"]) - elif download == "stream": - # return a readable file stream - file_stream, content_type = await get_file_stream_impl(object_name=object_name) + if url: + logger.info(f"[download_datamate_file] Using full URL: {url}") + normalized_url = _normalize_datamate_download_url(url) + elif base_url and dataset_id and file_id: + logger.info(f"[download_datamate_file] Building URL from parts: base_url={base_url}, dataset_id={dataset_id}, file_id={file_id}") + normalized_url = _build_datamate_url_from_parts(base_url, dataset_id, file_id) + else: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Either url or (base_url, dataset_id, file_id) must be provided" + ) + + logger.info(f"[download_datamate_file] Normalized download URL: {normalized_url}") + logger.info(f"[download_datamate_file] Authorization header present: {authorization is not None}") + + headers = {} + if authorization: + headers["Authorization"] = authorization + logger.debug(f"[download_datamate_file] Using authorization header: {authorization[:20]}...") + headers["User-Agent"] = "Nexent-File-Downloader/1.0" + + logger.info(f"[download_datamate_file] Request headers: {list(headers.keys())}") + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(normalized_url, headers=headers, follow_redirects=True) + logger.info(f"[download_datamate_file] Response status: {response.status_code}") + + if response.status_code == 404: + logger.error(f"[download_datamate_file] File not found at URL: {normalized_url}") + logger.error(f"[download_datamate_file] Response headers: {dict(response.headers)}") + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="File not found. Please verify dataset_id and file_id." + ) + + response.raise_for_status() + + content_type = response.headers.get("Content-Type", "application/octet-stream") + + download_filename = filename + if not download_filename: + content_disposition = response.headers.get("Content-Disposition", "") + if content_disposition: + filename_match = re.search(r'filename="?(.+?)"?$', content_disposition) + if filename_match: + download_filename = filename_match.group(1) + + if not download_filename: + path = unquote(urlparse(normalized_url).path) + download_filename = path.split('/')[-1] or "download" + + # Build Content-Disposition header with proper encoding for non-ASCII characters + content_disposition = build_content_disposition_header(download_filename) + return StreamingResponse( - file_stream, + iter([response.content]), media_type=content_type, headers={ - "Content-Disposition": f'inline; filename="{object_name}"' + "Content-Disposition": content_disposition } ) - else: - # return file metadata - return await get_file_url_impl(object_name=object_name, expires=expires) + except httpx.HTTPError as e: + logger.error(f"Failed to download file from URL {url}: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.BAD_GATEWAY, + detail=f"Failed to download file from URL: {str(e)}" + ) + except HTTPException: + raise except Exception as e: + logger.error(f"Failed to download datamate file: {str(e)}") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"Failed to get file information: {str(e)}" + detail=f"Failed to download file: {str(e)}" ) diff --git a/backend/apps/image_app.py b/backend/apps/image_app.py index 3024d4226..61eed2fcc 100644 --- a/backend/apps/image_app.py +++ b/backend/apps/image_app.py @@ -1,7 +1,11 @@ import logging +import base64 from urllib.parse import unquote +from io import BytesIO -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from http import HTTPStatus from services.image_service import proxy_image_impl @@ -12,23 +16,53 @@ logger = logging.getLogger("image_app") -# TODO: To remove this proxy service after frontend uses image filter service as image provider @router.get("/image") -async def proxy_image(url: str): +async def proxy_image(url: str, format: str = "json"): """ - Image proxy service that fetches remote images and returns base64 encoded data - + Image proxy service that fetches remote images + Parameters: url: Remote image URL + format: Response format - "json" (default, returns base64) or "stream" (returns image stream) Returns: - JSON object containing base64 encoded image + JSON object containing base64 encoded image (format=json) or image stream (format=stream) """ try: # URL decode decoded_url = unquote(url) - return await proxy_image_impl(decoded_url) + + if format == "stream": + # Return image as stream for direct use in tags + result = await proxy_image_impl(decoded_url) + if not result.get("success"): + raise HTTPException( + status_code=HTTPStatus.BAD_GATEWAY, + detail=result.get("error", "Failed to fetch image") + ) + + # Decode base64 to bytes + base64_data = result.get("base64", "") + content_type = result.get("content_type", "image/jpeg") + image_bytes = base64.b64decode(base64_data) + + # Return as streaming response + return StreamingResponse( + BytesIO(image_bytes), + media_type=content_type, + headers={ + "Cache-Control": "public, max-age=3600" + } + ) + else: + # Return JSON with base64 (default behavior for backward compatibility) + return await proxy_image_impl(decoded_url) except Exception as e: logger.error( f"Error occurred while proxying image: {str(e)}, URL: {url[:50]}...") + if format == "stream": + raise HTTPException( + status_code=HTTPStatus.BAD_GATEWAY, + detail=str(e) + ) return {"success": False, "error": str(e)} \ No newline at end of file diff --git a/backend/apps/me_model_managment_app.py b/backend/apps/me_model_managment_app.py index 70c4cfab8..d7055474f 100644 --- a/backend/apps/me_model_managment_app.py +++ b/backend/apps/me_model_managment_app.py @@ -4,81 +4,44 @@ from fastapi import APIRouter, Query, HTTPException from fastapi.responses import JSONResponse -from consts.exceptions import TimeoutException, NotFoundException, MEConnectionException -from services.me_model_management_service import get_me_models_impl, check_me_variable_set -from services.model_health_service import check_me_connectivity_impl +from consts.exceptions import MEConnectionException, TimeoutException +from services.me_model_management_service import check_me_variable_set, check_me_connectivity router = APIRouter(prefix="/me") -@router.get("/model/list") -async def get_me_models( - type: str = Query( - default="", description="Model type: embed/chat/rerank"), - timeout: int = Query( - default=2, description="Request timeout in seconds") -): - """ - Get list of models from model engine API - """ - try: - # Pre-check ME environment variables; return empty list if not configured - if not await check_me_variable_set(): - return JSONResponse( - status_code=HTTPStatus.OK, - content={ - "message": "Retrieve skipped", - "data": [] - } - ) - filtered_result = await get_me_models_impl(timeout=timeout, type=type) - return JSONResponse( - status_code=HTTPStatus.OK, - content={ - "message": "Successfully retrieved", - "data": filtered_result - } - ) - except TimeoutException as e: - logging.error(f"Request me model timeout: {str(e)}") - raise HTTPException(status_code=HTTPStatus.REQUEST_TIMEOUT, detail="Failed to get ModelEngine model list: timeout") - except NotFoundException as e: - logging.error(f"Request me model not found: {str(e)}") - raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="ModelEngine model not found") - except Exception as e: - logging.error(f"Failed to get me model list: {str(e)}") - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Failed to get ModelEngine model list") - - @router.get("/healthcheck") -async def check_me_connectivity(timeout: int = Query(default=2, description="Timeout in seconds")): +async def check_me_health(timeout: int = Query(default=30, description="Timeout in seconds")): """ - Health check from model engine API + Health check for ModelEngine platform by actually calling the API. + Returns connectivity status based on actual API response. """ try: - # Pre-check ME environment variables; return not connected if not configured + # First check if environment variables are configured if not await check_me_variable_set(): return JSONResponse( status_code=HTTPStatus.OK, content={ "connectivity": False, - "message": "ModelEngine platform necessary environment variables not configured. Healthcheck skipped.", + "message": "ModelEngine platform environment variables not configured. Healthcheck skipped.", } ) - await check_me_connectivity_impl(timeout) + + # Then check actual connectivity + await check_me_connectivity(timeout) return JSONResponse( status_code=HTTPStatus.OK, content={ "connectivity": True, - "message": "ModelEngine platform connect successfully.", + "message": "ModelEngine platform connected successfully.", } ) except MEConnectionException as e: - logging.error(f"ModelEngine model healthcheck failed: {str(e)}") - raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail="ModelEngine model connect failed.") + logging.error(f"ModelEngine healthcheck failed: {str(e)}") + raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=f"ModelEngine connection failed: {str(e)}") except TimeoutException as e: - logging.error(f"ModelEngine model healthcheck timeout: {str(e)}") - raise HTTPException(status_code=HTTPStatus.REQUEST_TIMEOUT, detail="ModelEngine model connect timeout.") + logging.error(f"ModelEngine healthcheck timeout: {str(e)}") + raise HTTPException(status_code=HTTPStatus.REQUEST_TIMEOUT, detail="ModelEngine connection timeout.") except Exception as e: - logging.error(f"ModelEngine model healthcheck failed with unknown error: {str(e)}.") - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="ModelEngine model connect failed.") + logging.error(f"ModelEngine healthcheck failed with unknown error: {str(e)}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"ModelEngine healthcheck failed: {str(e)}") diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 5383aeed8..0c3c7a8cf 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -157,25 +157,35 @@ async def get_provider_list(request: ProviderModelRequest, authorization: Option @router.post("/update") -async def update_single_model(request: dict, authorization: Optional[str] = Header(None)): - """Update a single model by its `model_id`. +async def update_single_model( + request: dict, + display_name: str = Query(..., description="Current display name of the model to update"), + authorization: Optional[str] = Header(None) +): + """Update a single model by its current `display_name`. - Performs a uniqueness check on `display_name` within the tenant and updates - the record if valid. + The model is looked up using the `display_name` query parameter. The request + body contains the fields to update, which may include a new `display_name`. Args: - request: Arbitrary model fields with required `model_id`. + request: Arbitrary model fields to update (may include new display_name). + display_name: Current display name of the model (query parameter for lookup). authorization: Bearer token header used to derive identity context. Raises: - HTTPException: 409 if `display_name` conflicts, 500 for unexpected errors. + HTTPException: 404 if model not found, 409 if new `display_name` conflicts, + 500 for unexpected errors. """ try: user_id, tenant_id = get_current_user_id(authorization) - await update_single_model_for_tenant(user_id, tenant_id, request) + await update_single_model_for_tenant(user_id, tenant_id, display_name, request) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Model updated successfully" }) + except LookupError as e: + logging.error(f"Failed to update model: {str(e)}") + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, + detail=str(e)) except ValueError as e: logging.error(f"Failed to update model: {str(e)}") raise HTTPException(status_code=HTTPStatus.CONFLICT, diff --git a/backend/consts/const.py b/backend/consts/const.py index 754619d12..8e99ca84d 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -279,7 +279,7 @@ class VectorDatabaseType(str, Enum): os.getenv("LLM_SLOW_TOKEN_RATE_THRESHOLD", "10.0")) # tokens per second # APP Version -APP_VERSION = "v1.7.7" +APP_VERSION = "v1.7.7.1" DEFAULT_ZH_TITLE = "新对话" DEFAULT_EN_TITLE = "New Conversation" diff --git a/backend/consts/provider.py b/backend/consts/provider.py index f82a60e7f..7fd783015 100644 --- a/backend/consts/provider.py +++ b/backend/consts/provider.py @@ -5,8 +5,12 @@ class ProviderEnum(str, Enum): """Supported model providers""" SILICON = "silicon" OPENAI = "openai" + MODELENGINE = "modelengine" # Silicon Flow SILICON_BASE_URL = "https://api.siliconflow.cn/v1/" SILICON_GET_URL = "https://api.siliconflow.cn/v1/models" + +# ModelEngine +# Base URL and API key are loaded from environment variables at runtime diff --git a/backend/database/db_models.py b/backend/database/db_models.py index eeb9d1c34..a4201abad 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -160,6 +160,8 @@ class ModelRecord(TableBase): Integer, doc="Expected chunk size for embedding models, used during document chunking") maximum_chunk_size = Column( Integer, doc="Maximum chunk size for embedding models, used during document chunking") + ssl_verify = Column( + Boolean, default=True, doc="Whether to verify SSL certificates when connecting to this model API. Default is true. Set to false for local services without SSL support.") class ToolInfo(TableBase): diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index b5d0d5b1e..257320499 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -185,6 +185,21 @@ def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dic return model +def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[str, Any]]: + """ + Get all model records by display name (for multi_embedding which creates two records) + + Args: + display_name: Model display name + tenant_id: Tenant ID + + Returns: + List[Dict[str, Any]]: List of model records with the same display_name + """ + filters = {'display_name': display_name} + return get_model_records(filters, tenant_id) + + def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]: """ Get a model ID by display name @@ -252,3 +267,25 @@ def get_models_by_tenant_factory_type(tenant_id: str, model_factory: str, model_ "model_type": model_type } return get_model_records(filters, tenant_id) + + +def get_model_by_name_factory(model_name: str, model_factory: str, tenant_id: str) -> Optional[Dict[str, Any]]: + """ + Get a model record by model_name and model_factory for deduplication. + + Args: + model_name: Model name (e.g., "deepseek-r1-distill-qwen-14b") + model_factory: Model factory (e.g., "ModelEngine") + tenant_id: Tenant ID + + Returns: + Optional[Dict[str, Any]]: Model record if found, None otherwise + """ + filters = { + 'model_name': model_name, + 'model_factory': model_factory + } + records = get_model_records(filters, tenant_id) + return records[0] if records else None + + diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index aa537bb25..5184b0e25 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -41,7 +41,7 @@ update_related_agents ) from database.model_management_db import get_model_by_model_id, get_model_id_by_display_name -from database.remote_mcp_db import check_mcp_name_exists, get_mcp_server_by_name_and_tenant +from database.remote_mcp_db import get_mcp_server_by_name_and_tenant from database.tool_db import ( check_tool_is_available, create_or_update_tool_by_tool_info, @@ -53,8 +53,6 @@ ) from services.conversation_management_service import save_conversation_assistant, save_conversation_user from services.memory_config_service import build_memory_context -from services.remote_mcp_service import add_remote_mcp_server_list -from services.tool_configuration_service import update_tool_list from utils.auth_utils import get_current_user_info, get_user_language from utils.config_utils import tenant_config_manager from utils.memory_utils import build_memory_config @@ -573,6 +571,15 @@ async def get_agent_info_impl(agent_id: int, tenant_id: str): elif "business_logic_model_name" not in agent_info: agent_info["business_logic_model_name"] = None + # Check agent availability + is_available, unavailable_reasons = check_agent_availability( + agent_id=agent_id, + tenant_id=tenant_id, + agent_info=agent_info + ) + agent_info["is_available"] = is_available + agent_info["unavailable_reasons"] = unavailable_reasons + return agent_info @@ -890,52 +897,17 @@ async def import_agent_impl( force_import: bool = False ): """ - Import agent using DFS + Import agent using DFS. + + Note: + MCP server registration and tool list refresh are now handled + on the frontend / dedicated MCP configuration flows. + The backend import logic only consumes the tools that already + exist for the current tenant. """ user_id, tenant_id, _ = get_current_user_info(authorization) agent_id = agent_info.agent_id - # First, add MCP servers if any - if agent_info.mcp_info: - for mcp_info in agent_info.mcp_info: - if mcp_info.mcp_server_name and mcp_info.mcp_url: - try: - # Check if MCP name already exists - if check_mcp_name_exists(mcp_name=mcp_info.mcp_server_name, tenant_id=tenant_id): - # Get existing MCP server info to compare URLs - existing_mcp = get_mcp_server_by_name_and_tenant(mcp_name=mcp_info.mcp_server_name, - tenant_id=tenant_id) - if existing_mcp and existing_mcp == mcp_info.mcp_url: - # Same name and URL, skip - logger.info( - f"MCP server {mcp_info.mcp_server_name} with same URL already exists, skipping") - continue - else: - # Same name but different URL, add import prefix - import_mcp_name = f"import_{mcp_info.mcp_server_name}" - logger.info( - f"MCP server {mcp_info.mcp_server_name} exists with different URL, using name: {import_mcp_name}") - mcp_server_name = import_mcp_name - else: - # Name doesn't exist, use original name - mcp_server_name = mcp_info.mcp_server_name - - await add_remote_mcp_server_list( - tenant_id=tenant_id, - user_id=user_id, - remote_mcp_server=mcp_info.mcp_url, - remote_mcp_server_name=mcp_server_name - ) - except Exception as e: - raise Exception( - f"Failed to add MCP server {mcp_info.mcp_server_name}: {str(e)}") - - # Then, update tool list to include new MCP tools - try: - await update_tool_list(tenant_id=tenant_id, user_id=user_id) - except Exception as e: - raise Exception(f"Failed to update tool list: {str(e)}") - agent_stack = deque([agent_id]) agent_id_set = set() mapping_agent_id = {} @@ -1047,14 +1019,16 @@ async def import_agent_by_agent_id( regeneration_model_id = business_logic_model_id or model_id if regeneration_model_id: try: - agent_name = _regenerate_agent_name_with_llm( + # Offload blocking LLM regeneration to a thread to avoid blocking the event loop + agent_name = await asyncio.to_thread( + _regenerate_agent_name_with_llm, original_name=agent_name, existing_names=existing_names, task_description=import_agent_info.business_description or import_agent_info.description or "", model_id=regeneration_model_id, tenant_id=tenant_id, language=LANGUAGE["ZH"], # Default to Chinese, can be enhanced later - agents_cache=all_agents + agents_cache=all_agents, ) logger.info(f"Regenerated agent name: '{agent_name}'") except Exception as e: @@ -1079,14 +1053,16 @@ async def import_agent_by_agent_id( regeneration_model_id = business_logic_model_id or model_id if regeneration_model_id: try: - agent_display_name = _regenerate_agent_display_name_with_llm( + # Offload blocking LLM regeneration to a thread to avoid blocking the event loop + agent_display_name = await asyncio.to_thread( + _regenerate_agent_display_name_with_llm, original_display_name=agent_display_name, existing_display_names=existing_display_names, task_description=import_agent_info.business_description or import_agent_info.description or "", model_id=regeneration_model_id, tenant_id=tenant_id, language=LANGUAGE["ZH"], # Default to Chinese, can be enhanced later - agents_cache=all_agents + agents_cache=all_agents, ) logger.info(f"Regenerated agent display_name: '{agent_display_name}'") except Exception as e: @@ -1168,23 +1144,13 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: if not agent["enabled"]: continue - unavailable_reasons: list[str] = [] - - tool_info = search_tools_for_sub_agent( - agent_id=agent["agent_id"], tenant_id=tenant_id) - tool_id_list = [tool["tool_id"] - for tool in tool_info if tool.get("tool_id") is not None] - if tool_id_list: - tool_statuses = check_tool_is_available(tool_id_list) - if not all(tool_statuses): - unavailable_reasons.append("tool_unavailable") - - model_reasons = _collect_model_availability_reasons( - agent=agent, + # Use shared availability check function + _, unavailable_reasons = check_agent_availability( + agent_id=agent["agent_id"], tenant_id=tenant_id, + agent_info=agent, model_cache=model_cache ) - unavailable_reasons.extend(model_reasons) # Preserve the raw data so we can adjust availability for duplicates enriched_agents.append({ @@ -1295,6 +1261,56 @@ def _check_single_model_availability( return [] +def check_agent_availability( + agent_id: int, + tenant_id: str, + agent_info: dict | None = None, + model_cache: Dict[int, Optional[dict]] | None = None +) -> tuple[bool, list[str]]: + """ + Check if an agent is available based on its tools and model configuration. + + Args: + agent_id: The agent ID to check + tenant_id: The tenant ID + agent_info: Optional pre-fetched agent info (to avoid duplicate DB queries) + model_cache: Optional model cache for performance optimization + + Returns: + tuple: (is_available: bool, unavailable_reasons: list[str]) + """ + unavailable_reasons: list[str] = [] + + if model_cache is None: + model_cache = {} + + # Fetch agent info if not provided + if agent_info is None: + agent_info = search_agent_info_by_agent_id(agent_id, tenant_id) + + if not agent_info: + return False, ["agent_not_found"] + + # Check tool availability + tool_info = search_tools_for_sub_agent(agent_id=agent_id, tenant_id=tenant_id) + tool_id_list = [tool["tool_id"] for tool in tool_info if tool.get("tool_id") is not None] + if tool_id_list: + tool_statuses = check_tool_is_available(tool_id_list) + if not all(tool_statuses): + unavailable_reasons.append("tool_unavailable") + + # Check model availability + model_reasons = _collect_model_availability_reasons( + agent=agent_info, + tenant_id=tenant_id, + model_cache=model_cache + ) + unavailable_reasons.extend(model_reasons) + + is_available = len(unavailable_reasons) == 0 + return is_available, unavailable_reasons + + def insert_related_agent_impl(parent_agent_id, child_agent_id, tenant_id): # search the agent by bfs, check if there is a circular call search_list = deque([child_agent_id]) diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index a794598df..b14835d90 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -263,8 +263,13 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["Z key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) # Create OpenAIServerModel instance - llm = OpenAIServerModel(model_id=get_model_name_from_config(model_config) if model_config.get("model_name") else "", api_base=model_config.get("base_url", ""), - api_key=model_config.get("api_key", ""), temperature=0.7, top_p=0.95) + llm = OpenAIServerModel( + model_id=get_model_name_from_config(model_config) if model_config.get("model_name") else "", + api_base=model_config.get("base_url", ""), + api_key=model_config.get("api_key", ""), + temperature=0.7, + top_p=0.95 + ) # Build messages user_prompt = Template(prompt_template["USER_PROMPT"], undefined=StrictUndefined).render({ @@ -276,7 +281,7 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["Z "content": user_prompt}] # Call the model - response = llm(messages, max_tokens=10) + response = llm.generate(messages) if not response or not response.content or not response.content.strip(): return DEFAULT_EN_TITLE if language == LANGUAGE["EN"] else DEFAULT_ZH_TITLE return remove_think_blocks(response.content.strip()) diff --git a/backend/services/me_model_management_service.py b/backend/services/me_model_management_service.py index e44aab0d5..9860ffe5b 100644 --- a/backend/services/me_model_management_service.py +++ b/backend/services/me_model_management_service.py @@ -1,61 +1,55 @@ -import asyncio -from typing import List - import aiohttp +import asyncio from consts.const import MODEL_ENGINE_APIKEY, MODEL_ENGINE_HOST -from consts.exceptions import TimeoutException, NotFoundException +from consts.exceptions import MEConnectionException, TimeoutException + + +async def check_me_variable_set() -> bool: + """ + Check if the ME environment variables are correctly set. + Returns: + bool: True if both MODEL_ENGINE_APIKEY and MODEL_ENGINE_HOST are set and non-empty, False otherwise. + """ + return bool(MODEL_ENGINE_APIKEY and MODEL_ENGINE_HOST) -async def get_me_models_impl(timeout: int = 2, type: str = "") -> List: +async def check_me_connectivity(timeout: int = 30) -> bool: """ - Fetches a list of models from the model engine API with response formatting. - Parameters: - timeout (int): The total timeout for the request in seconds. - type (str): The type of model to filter for. If empty, returns all models. + Check ModelEngine connectivity by actually calling the API. + + Args: + timeout: Request timeout in seconds + Returns: - - filtered_result: List of model data dictionaries + bool: True if connection successful, False otherwise + + Raises: + MEConnectionException: If connection failed with specific error + TimeoutException: If request timed out """ + if not await check_me_variable_set(): + return False + try: - headers = { - 'Authorization': f'Bearer {MODEL_ENGINE_APIKEY}', - } + headers = {"Authorization": f"Bearer {MODEL_ENGINE_APIKEY}"} + async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=timeout), - connector=aiohttp.TCPConnector(ssl=False) + timeout=aiohttp.ClientTimeout(total=timeout), + connector=aiohttp.TCPConnector(ssl=False) ) as session: async with session.get( - f"{MODEL_ENGINE_HOST}/open/router/v1/models", - headers=headers + f"{MODEL_ENGINE_HOST}/open/router/v1/models", + headers=headers ) as response: - response.raise_for_status() - result_data = await response.json() - result: list = result_data['data'] - - # Type filtering - filtered_result = [] - if type: - for data in result: - if data['type'] == type: - filtered_result.append(data) - if not filtered_result: - result_types = set(data['type'] for data in result) - raise NotFoundException( - f"No models found with type '{type}'. Available types: {result_types}.") - else: - filtered_result = result - - return filtered_result + if response.status == 200: + return True + else: + raise MEConnectionException( + f"Connection failed, error code: {response.status}") except asyncio.TimeoutError: - raise TimeoutException("Request timeout.") + raise TimeoutException("Connection timed out") + except MEConnectionException: + raise except Exception as e: - raise Exception(f"Failed to get model list: {str(e)}.") - - -async def check_me_variable_set() -> bool: - """ - Check if the ME environment variables are correctly set. - Returns: - bool: True if both MODEL_ENGINE_APIKEY and MODEL_ENGINE_HOST are set and non-empty, False otherwise. - """ - return bool(MODEL_ENGINE_APIKEY and MODEL_ENGINE_HOST) + raise Exception(f"Unknown error occurred: {str(e)}") diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index df98f508a..c6f426789 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -1,15 +1,11 @@ -import asyncio import logging -import aiohttp -from http import HTTPStatus from nexent.core import MessageObserver from nexent.core.models import OpenAIModel, OpenAIVLModel from nexent.core.models.embedding_model import JinaEmbedding, OpenAICompatibleEmbedding from services.voice_service import get_voice_service -from consts.const import MODEL_ENGINE_APIKEY, MODEL_ENGINE_HOST, LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST -from consts.exceptions import MEConnectionException, TimeoutException +from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST from consts.model import ModelConnectStatusEnum from database.model_management_db import get_model_by_display_name, update_model_record from utils.config_utils import get_model_name_from_config @@ -57,6 +53,7 @@ async def _perform_connectivity_check( model_type: str, model_base_url: str, model_api_key: str, + ssl_verify: bool = True, ) -> bool: """ Perform specific model connectivity check @@ -65,6 +62,7 @@ async def _perform_connectivity_check( model_type: Model type model_base_url: Model base URL model_api_key: API key + ssl_verify: Whether to verify SSL certificates (default: True) Returns: bool: Connectivity check result """ @@ -95,7 +93,8 @@ async def _perform_connectivity_check( observer, model_id=model_name, api_base=model_base_url, - api_key=model_api_key + api_key=model_api_key, + ssl_verify=ssl_verify ).check_connectivity() elif model_type == "rerank": connectivity = False @@ -135,11 +134,12 @@ async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: model_type = model["model_type"] model_base_url = model["base_url"] model_api_key = model["api_key"] + ssl_verify = model.get("ssl_verify", True) # Default to True if not present try: # Use the common connectivity check function connectivity = await _perform_connectivity_check( - model_name, model_type, model_base_url, model_api_key + model_name, model_type, model_base_url, model_api_key, ssl_verify ) except Exception as e: update_data = {"connect_status": ModelConnectStatusEnum.UNAVAILABLE.value} @@ -167,32 +167,6 @@ async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: raise e -async def check_me_connectivity_impl(timeout: int): - """ - Check ME connectivity and return structured response data - Args: - timeout: Request timeout in seconds - """ - try: - headers = {'Authorization': f'Bearer {MODEL_ENGINE_APIKEY}'} - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=timeout), - connector=aiohttp.TCPConnector(ssl=False) - ) as session: - async with session.get( - f"{MODEL_ENGINE_HOST}/open/router/v1/models", - headers=headers - ) as response: - if response.status == HTTPStatus.OK: - return - else: - raise MEConnectionException( - f"Connection failed, error code: {response.status}") - except asyncio.TimeoutError: - raise TimeoutException("Connection timed out") - except Exception as e: - raise Exception(f"Unknown error occurred: {str(e)}") async def verify_model_config_connectivity(model_config: dict): @@ -208,11 +182,12 @@ async def verify_model_config_connectivity(model_config: dict): model_type = model_config["model_type"] model_base_url = model_config["base_url"] model_api_key = model_config["api_key"] + ssl_verify = model_config.get("ssl_verify", True) # Default to True if not present try: # Use the common connectivity check function connectivity = await _perform_connectivity_check( - model_name, model_type, model_base_url, model_api_key + model_name, model_type, model_base_url, model_api_key, ssl_verify ) if not connectivity: diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 84936f393..7fe1a86b6 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -9,6 +9,7 @@ create_model_record, delete_model_record, get_model_by_display_name, + get_models_by_display_name, get_model_records, get_models_by_tenant_factory_type, update_model_record, @@ -133,6 +134,9 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay if provider == ProviderEnum.SILICON.value: model_url = SILICON_BASE_URL + elif provider == ProviderEnum.MODELENGINE.value: + # ModelEngine models carry their own base_url in each model dict + model_url = "" else: model_url = "" @@ -195,20 +199,63 @@ async def list_provider_models_for_tenant(tenant_id: str, provider: str, model_t raise Exception(f"Failed to list provider models: {str(e)}") -async def update_single_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict[str, Any]): - """Update a single model by its model_id, ensuring display_name uniqueness.""" +async def update_single_model_for_tenant( + user_id: str, + tenant_id: str, + current_display_name: str, + model_data: Dict[str, Any] +): + """Update model(s) by current display_name. If embedding/multi_embedding, update both types. + + Args: + user_id: The user performing the update. + tenant_id: The tenant context. + current_display_name: The current display_name used to look up the model(s). + model_data: The fields to update, which may include a new display_name. + + Raises: + LookupError: If no model is found with the current_display_name. + ValueError: If a new display_name conflicts with an existing model. + """ try: - existing_model_by_display = get_model_by_display_name(model_data["display_name"], tenant_id) - current_model_id = int(model_data["model_id"]) - existing_model_id = existing_model_by_display["model_id"] if existing_model_by_display else None - - if existing_model_by_display and existing_model_id != current_model_id: - raise ValueError( - f"Name {model_data['display_name']} is already in use, please choose another display name") + # Get all models with the current display_name (may be 1 or 2 for embedding types) + existing_models = get_models_by_display_name(current_display_name, tenant_id) - update_model_record(current_model_id, model_data, user_id) - logging.debug( - f"Model {model_data['display_name']} updated successfully") + if not existing_models: + raise LookupError(f"Model not found: {current_display_name}") + + # Check if a new display_name is being set and if it conflicts + new_display_name = model_data.get("display_name") + if new_display_name and new_display_name != current_display_name: + conflict_models = get_models_by_display_name(new_display_name, tenant_id) + if conflict_models: + raise ValueError( + f"Name {new_display_name} is already in use, please choose another display name" + ) + + # Check if any of the existing models is multi_embedding + has_multi_embedding = any( + m.get("model_type") == "multi_embedding" for m in existing_models + ) + + if has_multi_embedding: + # Update both embedding and multi_embedding records + for model in existing_models: + # Prepare update data, excluding model_type to preserve original type + update_data = {k: v for k, v in model_data.items() if k not in ["model_id", "model_type"]} + update_model_record(model["model_id"], update_data, user_id) + logging.debug( + f"Model {current_display_name} (embedding + multi_embedding) updated successfully") + else: + # Single model update + current_model_id = existing_models[0]["model_id"] + update_data = {k: v for k, v in model_data.items() if k != "model_id"} + update_model_record(current_model_id, update_data, user_id) + logging.debug(f"Model {current_display_name} updated successfully") + except LookupError: + raise + except ValueError: + raise except Exception as e: logging.error(f"Failed to update model: {str(e)}") raise Exception(f"Failed to update model: {str(e)}") @@ -218,7 +265,7 @@ async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_lis """Batch update models for a tenant.""" try: for model in model_list: - update_model_record(model["model_id"], model, user_id) + update_model_record(model["model_id"], model, user_id, tenant_id) logging.debug("Batch update models successfully") except Exception as e: @@ -229,24 +276,24 @@ async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_lis async def delete_model_for_tenant(user_id: str, tenant_id: str, display_name: str): """Delete model(s) by display_name. If embedding/multi_embedding, delete both types.""" try: - model = get_model_by_display_name(display_name, tenant_id) - if not model: + # Get all models with this display_name (may be 1 or 2 for embedding types) + models = get_models_by_display_name(display_name, tenant_id) + if not models: raise LookupError(f"Model not found: {display_name}") deleted_types: List[str] = [] - if model.get("model_type") in ["embedding", "multi_embedding"]: - # Fetch both variants once to avoid repeated lookups - models_by_type: Dict[str, Dict[str, Any]] = {} - for t in ["embedding", "multi_embedding"]: - m = get_model_by_display_name(display_name, tenant_id) - if m and m.get("model_type") == t: - models_by_type[t] = m - - # Best-effort memory cleanup using the fetched variants + + # Check if any of the models is multi_embedding (which means we have both types) + has_multi_embedding = any( + m.get("model_type") == "multi_embedding" for m in models + ) + + if has_multi_embedding: + # Best-effort memory cleanup for embedding models try: vdb_core = get_vector_db_core() base_memory_config = build_memory_config_for_tenant(tenant_id) - for t, m in models_by_type.items(): + for m in models: try: await clear_model_memories( vdb_core=vdb_core, @@ -267,17 +314,21 @@ async def delete_model_for_tenant(user_id: str, tenant_id: str, display_name: st logger.warning( "Memory cleanup preparation failed: %s", outer_cleanup_exc) - # Delete the fetched variants - for t, m in models_by_type.items(): + # Delete all records with the same display_name + for m in models: delete_model_record(m["model_id"], user_id, tenant_id) - deleted_types.append(t) + deleted_types.append(m.get("model_type", "unknown")) else: + # Single model delete + model = models[0] delete_model_record(model["model_id"], user_id, tenant_id) deleted_types.append(model.get("model_type", "unknown")) logging.debug( f"Successfully deleted model(s) in types: {', '.join(deleted_types)}") return display_name + except LookupError: + raise except Exception as e: logging.error(f"Failed to delete model: {str(e)}") raise Exception(f"Failed to delete model: {str(e)}") @@ -288,6 +339,12 @@ async def list_models_for_tenant(tenant_id: str): try: records = get_model_records(None, tenant_id) result: List[Dict[str, Any]] = [] + + # Type mapping for backwards compatibility (chat -> llm for frontend) + type_map = { + "chat": "llm", + } + for record in records: record["model_name"] = add_repo_to_name( model_repo=record["model_repo"], @@ -295,6 +352,11 @@ async def list_models_for_tenant(tenant_id: str): ) record["connect_status"] = ModelConnectStatusEnum.get_value( record.get("connect_status")) + + # Map model_type if necessary (for ModelEngine compatibility) + if record.get("model_type") in type_map: + record["model_type"] = type_map[record["model_type"]] + result.append(record) logging.debug("Successfully retrieved model list") diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index ecde87321..271ad7f99 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -3,14 +3,18 @@ from typing import Dict, List import httpx +import aiohttp from consts.const import ( DEFAULT_LLM_MAX_TOKENS, DEFAULT_EXPECTED_CHUNK_SIZE, - DEFAULT_MAXIMUM_CHUNK_SIZE + DEFAULT_MAXIMUM_CHUNK_SIZE, + MODEL_ENGINE_HOST, + MODEL_ENGINE_APIKEY, ) from consts.model import ModelConnectStatusEnum, ModelRequest from consts.provider import SILICON_GET_URL, ProviderEnum +from consts.exceptions import TimeoutException from database.model_management_db import get_models_by_tenant_factory_type from services.model_health_service import embedding_dimension_check from utils.model_name_utils import split_repo_name, add_repo_to_name @@ -67,6 +71,76 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: return [] +class ModelEngineProvider(AbstractModelProvider): + """Concrete implementation for ModelEngine provider.""" + + async def get_models(self, provider_config: Dict) -> List[Dict]: + """ + Fetch models from ModelEngine API. + + Args: + provider_config: Configuration dict containing model_type + + Returns: + List of models with canonical fields + """ + try: + if not MODEL_ENGINE_HOST or not MODEL_ENGINE_APIKEY: + logger.warning("ModelEngine environment variables not configured") + return [] + + model_type: str = provider_config.get("model_type", "") + headers = {"Authorization": f"Bearer {MODEL_ENGINE_APIKEY}"} + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=30), + connector=aiohttp.TCPConnector(ssl=False) + ) as session: + async with session.get( + f"{MODEL_ENGINE_HOST}/open/router/v1/models", + headers=headers + ) as response: + response.raise_for_status() + data = await response.json() + all_models = data.get("data", []) + + # Type mapping from ModelEngine to internal types + type_map = { + "embed": "embedding", + "chat": "llm", + "asr": "stt", + "tts": "tts", + "rerank": "rerank", + "vlm": "vlm", + } + + # Filter models by type if specified + filtered_models = [] + for model in all_models: + me_type = model.get("type", "") + internal_type = type_map.get(me_type) + + # If model_type filter is provided, only include matching models + if model_type and internal_type != model_type: + continue + + if internal_type: + filtered_models.append({ + "id": model.get("id", ""), + "model_type": internal_type, + "model_tag": me_type, + "max_tokens": DEFAULT_LLM_MAX_TOKENS if internal_type in ("llm", "vlm") else 0, + # ModelEngine models will get base_url and api_key from environment + "base_url": MODEL_ENGINE_HOST, + "api_key": MODEL_ENGINE_APIKEY, + }) + + return filtered_models + except Exception as e: + logger.error(f"Error getting models from ModelEngine: {e}") + return [] + + async def prepare_model_dict(provider: str, model: dict, model_url: str, model_api_key: str) -> dict: """ Construct a model configuration dictionary that is ready to be stored in the @@ -75,11 +149,10 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a the router implementation concise. Args: - provider: Name of the model provider (e.g. "silicon", "openai"). + provider: Name of the model provider (e.g. "silicon", "openai", "modelengine"). model: A single model item coming from the provider list. model_url: Base URL for the provider API. model_api_key: API key that should be saved together with the model. - max_tokens: User-supplied max token / embedding dimension upper-bound. Returns: A dictionary ready to be passed to *create_model_record*. @@ -98,6 +171,18 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a expected_chunk_size = model.get("expected_chunk_size", DEFAULT_EXPECTED_CHUNK_SIZE) maximum_chunk_size = model.get("maximum_chunk_size", DEFAULT_MAXIMUM_CHUNK_SIZE) + # For ModelEngine provider, extract the host from model's base_url + # We'll append the correct path later + if provider == ProviderEnum.MODELENGINE.value: + # Get the raw host URL from model (e.g., "https://120.253.225.102:50001") + raw_model_url = model.get("base_url", "") + # Strip any existing path to get just the host + if raw_model_url: + # Remove any trailing /open/router/v1 or similar paths to get base host + raw_model_url = raw_model_url.split("/open/")[0] if "/open/" in raw_model_url else raw_model_url + model_url = raw_model_url + model_api_key = model.get("api_key", model_api_key) + # Build the canonical representation using the existing Pydantic schema for # consistency of validation and default handling. model_obj = ModelRequest( @@ -117,11 +202,24 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # Determine the correct base_url and, for embeddings, update the actual # dimension by performing a real connectivity check. if model["model_type"] in ["embedding", "multi_embedding"]: - model_dict["base_url"] = f"{model_url}embeddings" + if provider != ProviderEnum.MODELENGINE.value: + model_dict["base_url"] = f"{model_url}embeddings" + else: + # For ModelEngine embedding models, append the embeddings path + model_dict["base_url"] = f"{model_url.rstrip('/')}/open/router/v1/embeddings" # The embedding dimension might differ from the provided max_tokens. model_dict["max_tokens"] = await embedding_dimension_check(model_dict) else: - model_dict["base_url"] = model_url + # For non-embedding models + if provider == ProviderEnum.MODELENGINE.value: + # Ensure ModelEngine models have the full API path + model_dict["base_url"] = f"{model_url.rstrip('/')}/open/router/v1" + else: + model_dict["base_url"] = model_url + + # ModelEngine models don't support SSL verification + if provider == ProviderEnum.MODELENGINE.value: + model_dict["ssl_verify"] = False # All newly created models start in NOT_DETECTED status. model_dict["connect_status"] = ModelConnectStatusEnum.NOT_DETECTED.value @@ -182,5 +280,8 @@ async def get_provider_models(model_data: dict) -> List[dict]: if model_data["provider"] == ProviderEnum.SILICON.value: provider = SiliconModelProvider() model_list = await provider.get_models(model_data) + elif model_data["provider"] == ProviderEnum.MODELENGINE.value: + provider = ModelEngineProvider() + model_list = await provider.get_models(model_data) return model_list diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 55d2a5e4a..e72b3f9f3 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -15,12 +15,11 @@ import time import uuid from datetime import datetime, timezone -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Dict, List, Optional from fastapi import Body, Depends, Path, Query from fastapi.responses import StreamingResponse from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding -from nexent.core.nlp.tokenizer import calculate_term_weights from nexent.vector_database.base import VectorDatabaseCore from nexent.vector_database.elasticsearch_core import ElasticSearchCore diff --git a/backend/services/voice_service.py b/backend/services/voice_service.py index a66f7c15d..0bffec895 100644 --- a/backend/services/voice_service.py +++ b/backend/services/voice_service.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Dict, Any, Optional +from typing import Any, Optional from nexent.core.models.stt_model import STTConfig, STTModel from nexent.core.models.tts_model import TTSConfig, TTSModel diff --git a/doc/docs/en/sdk/core/multimodal.md b/doc/docs/en/sdk/core/multimodal.md new file mode 100644 index 000000000..e83f83416 --- /dev/null +++ b/doc/docs/en/sdk/core/multimodal.md @@ -0,0 +1,327 @@ +# Multimodal Module + +This module provides a native multimodal data processing bus designed for agents. With the `@load_object` and `@save_object` decorators, it supports real-time transmission and processing of text, images, audio, video, and other data formats, enabling seamless cross-modal data flow. + +## 📋 Table of Contents + +- [LoadSaveObjectManager Initialization](#loadsaveobjectmanager-initialization) +- [@load_object Decorator](#load_object-decorator) +- [@save_object Decorator](#save_object-decorator) +- [Combined Usage Example](#combined-usage-example) + +## LoadSaveObjectManager Initialization + +Before using the decorators, you need to initialize a `LoadSaveObjectManager` instance and pass in a storage client (for example, a MinIO client): + +```python +from nexent.multi_modal.load_save_object import LoadSaveObjectManager +from database.client import minio_client + + +# Create manager instance +Multimodal = LoadSaveObjectManager(storage_client=minio_client) +``` + +You can also implement your own storage client based on the `StorageClient` base class in `sdk.nexent.storage.storage_client_base`. +The storage client must implement: + +- `get_file_stream(object_name, bucket)`: get a file stream from storage (for download) +- `upload_fileobj(file_obj, object_name, bucket)`: upload a file-like object to storage (for save) + +## @load_object Decorator + +The `@load_object` decorator downloads files from URLs (S3 / HTTP / HTTPS) **before** the wrapped function is executed, and passes the file content (or transformed data) into the wrapped function. + +### Features + +- **Automatic download**: Automatically detect and download files pointed to by S3, HTTP, or HTTPS URLs. +- **Data transformation**: Use custom transformer functions to convert downloaded bytes into types required by the wrapped function (for example, `PIL.Image`, text, etc.). +- **Batch processing**: Support a single URL or a list of URLs. + +### Parameters + +- `input_names` (`List[str]`): names of function parameters to transform. +- `input_data_transformer` (`Optional[List[Callable[[bytes], Any]]]`): optional list of transformers; each transformer converts raw `bytes` into the target type for the corresponding parameter. + +### Supported URL Formats + +The decorator supports: + +- **S3 URLs** + - `s3://bucket-name/object/file.jpg` + - `/bucket-name/object/file.jpg` (short form) +- **HTTP / HTTPS URLs** + - `http://example.com/file.jpg` + - `https://example.com/file.jpg` + +URL type detection: + +- Starts with `http://` → HTTP URL +- Starts with `https://` → HTTPS URL +- Starts with `s3://` or looks like `/bucket/object` → S3 URL + +### Examples + +#### Basic: download as bytes + +```python +@Multimodal.load_object(input_names=["image_url"]) +def process_image(image_url: bytes): + """image_url will be replaced with downloaded bytes.""" + print(f"File size: {len(image_url)} bytes") + return image_url + + +# Call process_image +result = process_image(image_url="http://example.com/pic.PNG") +``` + +#### Advanced: convert bytes to PIL Image + +If the function parameter is not `bytes` (for example, it expects `PIL.Image.Image`), define a converter (such as `bytes_to_pil`) and pass it to the decorator. + +```python +import io +from PIL import Image + + +def bytes_to_pil(binary_data: bytes) -> Image.Image: + image_stream = io.BytesIO(binary_data) + img = Image.open(image_stream) + return img + + +@Multimodal.load_object( + input_names=["image_url"], + input_data_transformer=[bytes_to_pil], +) +def process_image(image_url: Image.Image) -> Image.Image: + """image_url will be converted into a PIL Image object.""" + resized = image_url.resize((800, 600)) + return resized + + +result = process_image(image_url="http://example.com/pic.PNG") +``` + +#### Multiple inputs + +```python +from PIL import Image + + +@Multimodal.load_object( + input_names=["image_url1", "image_url2"], + input_data_transformer=[bytes_to_pil, bytes_to_pil], +) +def process_two_images(image_url1: Image.Image, image_url2: Image.Image) -> Image.Image: + """Both image URLs will be downloaded and converted into PIL Images.""" + combined = Image.new("RGB", (1600, 600)) + combined.paste(image_url1, (0, 0)) + combined.paste(image_url2, (800, 0)) + return combined + + +result = process_two_images( + image_url1="http://example.com/pic1.PNG", + image_url2="http://example.com/pic2.PNG", +) +``` + +#### List of URLs + +```python +from typing import List +from PIL import Image + + +@Multimodal.load_object( + input_names=["image_urls"], + input_data_transformer=[bytes_to_pil], +) +def process_image_list(image_urls: List[Image.Image]) -> List[Image.Image]: + """Support a list of URLs, each will be downloaded and converted.""" + results: List[Image.Image] = [] + for img in image_urls: + results.append(img.resize((200, 200))) + return results + + +result = process_image_list( + image_urls=[ + "http://example.com/pic1.PNG", + "http://example.com/pic2.PNG", + ] +) +``` + +## @save_object Decorator + +The `@save_object` decorator uploads return values to storage (MinIO) **after** the wrapped function finishes, and returns S3 URLs. + +### Features + +- **Automatic upload**: Automatically upload function return values to MinIO. +- **Data transformation**: Use transformers to convert return values into `bytes` (for example, `PIL.Image` → `bytes`). +- **Batch processing**: Support a single return value or multiple values (tuple). +- **URL return**: Return S3 URLs of the form `s3://bucket/object_name`. + +### Parameters + +- `output_names` (`List[str]`): logical names for each return value. +- `output_transformers` (`Optional[List[Callable[[Any], bytes]]]`): transformers that convert each return value into `bytes`. +- `bucket` (`str`): target bucket name, default `"nexent"`. + +### Examples + +#### Basic: save raw bytes + +```python +@Multimodal.save_object( + output_names=["content"], +) +def generate_file() -> bytes: + """Returned bytes will be uploaded to MinIO automatically.""" + content = b"Hello, World!" + return content +``` + +#### Advanced: convert PIL Image to bytes before upload + +If the function does not return `bytes` (for example, it returns `PIL.Image.Image`), define a converter such as `pil_to_bytes` and pass it to the decorator. + +```python +import io +from typing import Optional +from PIL import Image, ImageFilter + + +def pil_to_bytes(img: Image.Image, format: Optional[str] = None) -> bytes: + """ + Convert a PIL Image to binary data (bytes). + """ + if img is None: + raise ValueError("Input image cannot be None") + + buffer = io.BytesIO() + + # Decide which format to use + if format is None: + # Use original format if available, otherwise default to PNG + format = img.format if img.format else "PNG" + + # For JPEG, ensure RGB (no alpha channel) + if format.upper() == "JPEG" and img.mode in ("RGBA", "LA", "P"): + rgb_img = Image.new("RGB", img.size, (255, 255, 255)) + if img.mode == "P": + img = img.convert("RGBA") + rgb_img.paste( + img, + mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None, + ) + rgb_img.save(buffer, format=format) + else: + img.save(buffer, format=format) + + data = buffer.getvalue() + buffer.close() + return data + + +@Multimodal.save_object( + output_names=["processed_image"], + output_transformers=[pil_to_bytes], +) +def process_image(image: Image.Image) -> Image.Image: + """Returned PIL Image will be converted to bytes and uploaded.""" + blurred = image.filter(ImageFilter.GaussianBlur(radius=5)) + return blurred +``` + +#### Multiple files + +```python +from typing import Tuple + + +@Multimodal.save_object( + output_names=["resized1", "resized2"], + output_transformers=[pil_to_bytes, pil_to_bytes], +) +def process_two_images( + img1: Image.Image, + img2: Image.Image, +) -> Tuple[Image.Image, Image.Image]: + """Both returned images will be uploaded and return corresponding S3 URLs.""" + resized1 = img1.resize((800, 600)) + resized2 = img2.resize((800, 600)) + return resized1, resized2 +``` + +### Return Format + +- **Single return value**: a single S3 URL string, `s3://bucket/object_name`. +- **Multiple return values (tuple)**: a tuple where each element is the corresponding S3 URL. + +### Notes + +- If you do **not** provide a transformer, the function return value must be `bytes`. +- If you provide a transformer, the transformer **must** return `bytes`. +- The number of return values must match the length of `output_names`. + +## Combined Usage Example + +In practice, `@load_object` and `@save_object` are often used together to build a full **download → process → upload** pipeline: + +```python +from typing import Union, List +from PIL import Image, ImageFilter + +from database.client import minio_client +from nexent.multi_modal.load_save_object import LoadSaveObjectManager + + +Multimodal = LoadSaveObjectManager(storage_client=minio_client) + + +@Multimodal.load_object( + input_names=["image_url"], + input_data_transformer=[bytes_to_pil], +) +@Multimodal.save_object( + output_names=["blurred_image"], + output_transformers=[pil_to_bytes], +) +def blur_image_tool( + image_url: Union[str, List[str]], + blur_radius: int = 5, +) -> Image.Image: + """ + Apply a Gaussian blur filter to an image. + + Args: + image_url: S3 URL or HTTP/HTTPS URL of the image. + blur_radius: Blur radius (default 5, valid range 1–50). + + Returns: + Processed PIL Image object (it will be uploaded and returned as an S3 URL). + """ + # At this point, image_url has already been converted to a PIL Image + if image_url is None: + raise ValueError("Failed to load image") + + # Clamp blur radius + blur_radius = max(1, min(50, blur_radius)) + + # Apply blur + blurred_image = image_url.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + return blurred_image + + +# Example usage +result_url = blur_image_tool( + image_url="s3://nexent/images/input.png", + blur_radius=10, +) +# result_url is something like "s3://nexent/attachments/xxx.png" +``` \ No newline at end of file diff --git a/doc/docs/en/user-guide/agent-development.md b/doc/docs/en/user-guide/agent-development.md index 54970e5c2..538041ab1 100644 --- a/doc/docs/en/user-guide/agent-development.md +++ b/doc/docs/en/user-guide/agent-development.md @@ -15,6 +15,14 @@ If you have an existing agent configuration, you can also import it: +> ⚠️ **Note:** If you import an agent with a duplicate name, a prompt dialog will appear. You can choose: +> - **Import anyway**: Keep the duplicate name; the imported agent will be in an unavailable state and requires manual modification of the Agent name and variable name before it can be used +> - **Regenerate and import**: The system will call the LLM to rename the Agent, which will consume a certain amount of model tokens and may take longer + +
+ +
+ ## 👥 Configure Collaborative Agents/Tools You can configure other collaborative agents for your created agent, as well as assign available tools to empower the agent to complete complex tasks. diff --git a/doc/docs/en/user-guide/assets/agent-development/duplicated_import.png b/doc/docs/en/user-guide/assets/agent-development/duplicated_import.png new file mode 100644 index 000000000..3d7e0e6bc Binary files /dev/null and b/doc/docs/en/user-guide/assets/agent-development/duplicated_import.png differ diff --git a/doc/docs/zh/opensource-memorial-wall.md b/doc/docs/zh/opensource-memorial-wall.md index d1c9a3ccc..41fdb6ee8 100644 --- a/doc/docs/zh/opensource-memorial-wall.md +++ b/doc/docs/zh/opensource-memorial-wall.md @@ -24,6 +24,10 @@ 第一次玩开源项目,nexent真的挺好用的!用自然语言就能搞智能体,比我想象的简单多了 ::: +::: info codingcat99 - 2025-05-15 +大家一起加油 +::: + ::: tip bytedancer2023 - 2025-05-18 我们小公司想做客服机器人,之前技术门槛太高了。nexent的多文件格式支持真的帮了大忙,产品经理现在也能自己调智能体了哈哈 ::: @@ -517,15 +521,15 @@ nexent智能体帮助我学到更多的东西,赞! 第一次使用nexent,想借此更快入手ai应用开发呀! ::: -:::info user - 2025-11-26 +::: info user - 2025-11-26 Nexent开发者加油 ::: -:::info NOSN - 2025-11-27 +::: info NOSN - 2025-11-27 Nexent越做越强大! ::: -:::info Chenpi-Sakura - 2025-11-27 +::: info Chenpi-Sakura - 2025-11-27 开源共创未来! ::: @@ -533,11 +537,11 @@ Nexent越做越强大! Nexent加油 ::: -:::info AstreoX - 2025-11-27 +::: info AstreoX - 2025-11-27 感谢Nexent为智能体开发提出了更多可能! ::: -:::info user - 2025-11-26 +::: info user - 2025-11-26 祝nexent平台越做越胡奥 ::: @@ -552,3 +556,54 @@ Nexent加油 ::: info kj - 2025-11-27 祝越来越好 ::: + +::: info aaa - 2025-11-28 +祝nexent平台越来越好 +::: + +::: info hanyuan5888-beep - 2025-11-29 +通过华为ICT大赛接触到的这个平台,前端做的非常好看,并且功能很全面。 +::: + +::: info user - 2025-11-29 +感谢 Nexent 让我踏上了开源之旅!这个项目的文档真的很棒,帮助我快速上手。 +::: + +::: info G-oeX - 2025-11-30 +感谢 Nexent 让我第一次感受到智能体 希望参加ICT比赛过程中可以学到更多知识 能够对该领域有更多的了解和认识!star!!! +::: + +::: tip peri1506 - 2025-11-30 +感谢 Nexent 让我踏上了开源之旅!这个项目的文档真的很棒,帮助我快速上手。 +::: + +::: tip kissmekm - 2025-12-01 +感谢 Nexent 让我踏上了开源之旅!希望能使用nexent开发智能体 +::: + +::: info luna - 2025-12-1 +感谢nexent,祝平台越做越强大 +::: + +::: info sbwrn - 2025-12-02 +祝越来越好 +::: + +::: info sbwrn - 2025-12-02 +祝nexent平台越来越好 + +:::tip 开源新手 - 2025-12-02 +感谢 Nexent 让我踏上了开源之旅!这个项目的文档真的很棒,帮助我快速上手。 +::: + +::: info sbwrn - 2025-12-02 +祝nexent平台越来越好 +::: + +::: info dengpeiying - 2025-12-02 +Nexent开发者加油 +::: + +::: info jinhb - 2025-12-03 +祝nexent平台越来越好 +::: diff --git a/doc/docs/zh/sdk/core/multimodal.md b/doc/docs/zh/sdk/core/multimodal.md new file mode 100644 index 000000000..eec6c66cb --- /dev/null +++ b/doc/docs/zh/sdk/core/multimodal.md @@ -0,0 +1,312 @@ +# 多模态模块 + +本模块提供专为智能体设计的原生多模态数据处理总线,通过 `@load_object`、 `@save_object` 装饰器,支持文本、图像、音频、视频等多种数据格式的实时传输和处理,实现跨模态的无缝数据流转。 + +## 📋 目录 + +- [LoadSaveObjectManager 初始化](#loadsaveobjectmanager-初始化) +- [@load_object装饰器](#@load_object装饰器) +- [@save_object装饰器](#@save_object装饰器) +- [组合使用示例](#组合使用示例) + + +## LoadSaveObjectManager 初始化 + +在使用装饰器之前,需要先初始化 `LoadSaveObjectManager` 实例,并传入存储客户端(如 MinIO 客户端): + +```python +from nexent.multi_modal.load_save_object import LoadSaveObjectManager +from database.client import minio_client + + +# 创建管理器实例 +Multimodal = LoadSaveObjectManager(storage_client=minio_client) +``` + +存储客户端也可以通过`sdk.nexent.storage.storage_client_base`中的`StorageClient`基类,实现自己的存储客户端。存储客户端需要实现以下方法: +- `get_file_stream(object_name, bucket)`: 从存储中获取文件流(用于下载) +- `upload_fileobj(file_obj, object_name, bucket)`: 上传文件对象到存储(用于保存) + + +## @load_object装饰器 + +`@load_object` 装饰器用于在被装饰函数执行前自动从 URL(S3、HTTP、HTTPS)下载文件,并将文件内容(或转换后的数据)传递给被装饰函数。 + +### 功能特性 + +- **自动下载**: 自动识别并下载 S3、HTTP、HTTPS URL 指向的文件 +- **数据转换**: 支持通过自定义转换器将下载的字节数据转换为被装饰函数所需格式(如 PIL Image、文本等) +- **批量处理**: 支持处理单个 URL 或 URL 列表 + + +### 参数说明 + +- `input_names` (List[str]): : 需要处理的函数参数名称列表 +- `input_data_transformer` (Optional[List[Callable[[Any], bytes]]]): 可选的数据转换器列表,用于将下载的字节数据转换为所需格式 + +### 支持的URL格式 + +装饰器支持以下 URL 格式: + +- S3 URL + - `s3://bucket-name/object/file.jpg` + - `/bucket-name/object/file.jpg`(简化格式) +- HTTP/HTTPS URL + - `http://example.com/file.jpg` + - `https://example.com/file.jpg` + + +系统会自动检测 URL 类型: +- 以 `http://` 开头 → HTTP URL +- 以 `https://` 开头 → HTTPS URL +- 以 `s3://` 开头或符合 `/bucket/object` 格式 → S3 URL + +### 使用示例 + +#### 基础用法:下载为字节数据 + +```python +@Multimodal.load_object(input_names=["image_url"]) +def process_image(image_url: bytes): + """file_url 参数会被自动替换为从 URL 下载的字节数据""" + print(f"文件大小: {len(image_url)} bytes") + return image_url + +# 调用process_file方法 +result = process_image(image_url=f"http://example/pic.PNG") +``` + +#### 进阶用法:使用转换器将字节数据转换为所需格式 + +若被装饰函数的入参不是字节数据,而是其他数据类型的数据(如PIL Image)。可以定义一个数据转换的函数(如bytes_to_pil)并将函数名作为入参传给装饰器。 + +```python +import io +import PIL +from PIL import Image + +def bytes_to_pil(binary_data): + image_stream = io.BytesIO(binary_data) + img = Image.open(image_stream) + return img + +@Multimodal.load_object( + input_names=["image_url"], + input_data_transformer=[bytes_to_pil] +) +def process_image(image_url: Image.Image): + """image_url 参数会被自动转换为 PIL Image 对象""" + resized = image_url.resize((800, 600)) + return resized + +# 调用process_file方法 +result = process_image(image_url=f"http://example/pic.PNG") +``` + +#### 处理多个输入 + +```python +@Multimodal.load_object( + input_names=["image_url1", "image_url2"], + input_data_transformer=[bytes_to_pil, bytes_to_pil] +) +def process_two_images(image_url1: Image.Image, image_url2: Image.Image): + """两个图片 URL 都会被下载并转换为 PIL Image""" + combined = Image.new('RGB', (1600, 600)) + combined.paste(image_url1, (0, 0)) + combined.paste(image_url2, (800, 0)) + return combined + +# 调用process_file方法 +result = process_two_images(image_url1=f"http://example/pic1.PNG", image_url2=f"http://example/pic2.PNG") +``` + +#### 处理 URL 列表 + +```python +@Multimodal.load_object( + input_names=["image_urls"], + input_data_transformer=[bytes_to_pil] +) +def process_image_list(image_urls: List[Image.Image]): + """支持传入 URL 列表,每个 URL 都会被下载并转换""" + results = [] + for img in image_urls: + results.append(img.resize((200, 200))) + return results + +# 调用process_file方法 +result = process_image_list(image_urls=["http://example/pic1.PNG", "http://example/pic2.PNG"]) +``` + + +## @save_object装饰器 + +`@save_object` 装饰器用于在被装饰函数执行后自动将返回值上传到存储(MinIO),并返回 S3 URL。 + +### 功能特性 + +- **自动上传**: 自动将被装饰函数返回值上传到 MinIO 存储 +- **数据转换**: 支持通过转换器将返回值转换为字节数据(如 PIL Image 转 bytes) +- **批量处理**: 支持处理单个返回值或多个返回值(tuple) +- **URL 返回**: 返回 S3 URL 格式(`s3://bucket/object_name`) + +### 参数说明 + +- `output_names` (List[str]): 被装饰器函数的输出参数的名称列表 +- `output_transformers` (Optional[List[Callable[[Any], bytes]]]): 可选的数据转换器列表,用于将返回值转换为字节数据 +- `bucket` (str): 存储桶名称,默认为 `"nexent"` + +### 使用示例 + +#### 基础用法:直接保存字节数据 + +```python +@Multimodal.save_object( + output_names=["content"] +) +def generate_file() -> bytes: + """返回的字节数据会被自动上传到 MinIO""" + content = b"Hello, World!" + return content +``` + +#### 进阶用法: 使用转换器将函数返回值转换为字节数据 + +若被装饰函数的出参不是字节数据,而是其他数据类型的数据(如PIL Image)。可以定义一个数据转换的函数(如pil_to_bytes)并将函数名作为入参传给装饰器。 + + +```python +# 定义将PIL对象转换为Bytes的转换器函数 +def pil_to_bytes(img, format=None): + """ + Convert PIL Image to binary data (bytes) + + Args: + img: PIL.Image object + format: Output format ('JPEG', 'PNG', 'BMP', 'WEBP', etc.). + If None, uses the image's original format or defaults to PNG. + + Returns: + bytes: Binary data of the image + """ + if img is None: + raise ValueError("Input image cannot be None") + + # Create memory buffer + buffer = io.BytesIO() + + # Determine format to use + if format is None: + # Use image's original format if available, otherwise default to PNG + format = img.format if img.format else 'PNG' + + # Save image to buffer with specified format + # For JPEG, ensure RGB mode (no transparency) + if format.upper() == 'JPEG' and img.mode in ('RGBA', 'LA', 'P'): + # Convert to RGB for JPEG compatibility + rgb_img = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'P': + img = img.convert('RGBA') + rgb_img.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None) + rgb_img.save(buffer, format=format) + else: + img.save(buffer, format=format) + + # Get binary data + binary_data = buffer.getvalue() + buffer.close() + + return binary_data + + +@Multimodal.save_object( + output_names=["processed_image"], + output_transformers=[pil_to_bytes] +) +def process_image(image: Image.Image) -> Image.Image: + """返回的 PIL Image 会被转换为字节并上传""" + blurred = image.filter(ImageFilter.GaussianBlur(radius=5)) + return blurred +``` + +#### 返回多个文件 + +```python +@Multimodal.save_object( + output_names=["resized1", "resized2"], + output_transformers=[pil_to_bytes, pil_to_bytes] +) +def process_two_images(img1: Image.Image, img2: Image.Image) -> Tuple[Image.Image, Image.Image]: + """返回两个图片,都会被上传并返回对应的 S3 URL""" + resized1 = img1.resize((800, 600)) + resized2 = img2.resize((800, 600)) + return resized1, resized2 +``` + +### 返回值格式 + +- 单个返回值:返回单个 S3 URL 字符串,格式为 `s3://bucket/object_name` +- 多个返回值(tuple):返回 tuple,每个元素是对应的 S3 URL + +### 注意事项 + +- 如果没有提供转换器,被装饰函数的返回值必须是 `bytes` 类型 +- 如果提供了转换器,转换器必须返回 `bytes` 类型 +- 返回值的数量必须与 `output_names` 的长度一致 + + +## 组合使用示例 + +在实际应用中,通常会将 `@load_object` 和 `@save_object` 组合使用,实现完整的"下载-处理-上传"流程: + +```python +from PIL import Image, ImageFilter +from typing import Union, List +from database.client import minio_client +from multi_modal.load_save_object import LoadSaveObjectManager + +Multimodal = LoadSaveObjectManager(storage_client=minio_client) + +@Multimodal.load_object( + input_names=["image_url"], + input_data_transformer=[bytes_to_pil] +) +@Multimodal.save_object( + output_names=["blurred_image"], + output_transformers=[pil_to_bytes] +) +def blur_image_tool( + image_url: Union[str, List[str]], + blur_radius: int = 5 +) -> Image.Image: + """ + 对图片应用高斯模糊滤镜 + + Args: + image_url: 图片的 S3 URL 或 HTTP/HTTPS URL + blur_radius: 模糊半径(默认 5,范围 1-50) + + Returns: + 处理后的 PIL Image 对象(会被自动上传并返回 S3 URL) + """ + # 此时 image_url 已经是 PIL Image 对象 + if image_url is None: + raise ValueError("Failed to load image") + + # 验证并限制模糊半径 + blur_radius = max(1, min(50, blur_radius)) + + # 应用模糊滤镜 + blurred_image = image_url.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # 返回 PIL Image(会被 @save_object 自动上传) + return blurred_image + +# 使用示例 +result_url = blur_image_tool( + image_url="s3://nexent/images/input.png", + blur_radius=10 +) +# result_url 是 "s3://nexent/attachments/xxx.png" +``` \ No newline at end of file diff --git a/doc/docs/zh/user-guide/agent-development.md b/doc/docs/zh/user-guide/agent-development.md index 5f9eb0646..ff4c7c943 100644 --- a/doc/docs/zh/user-guide/agent-development.md +++ b/doc/docs/zh/user-guide/agent-development.md @@ -15,6 +15,14 @@ +> ⚠️ **提示**:如果导入了重名的智能体,系统会弹出提示弹窗。您可以选择: +> - **直接导入**:保留重复名称,导入后的智能体会处于不可用状态,需手动修改 Agent 名称和变量名后才能使用 +> - **重新生成并导入**:系统将调用 LLM 对 Agent 进行重命名,会消耗一定的模型 token 数,可能耗时较长 + +
+ +
+ ## 👥 配置协作智能体/工具 您可以为创建的智能体配置其他协作智能体,也可以为它配置可使用的工具,以赋予智能体能力完成复杂任务。 diff --git a/doc/docs/zh/user-guide/assets/agent-development/duplicated_import.png b/doc/docs/zh/user-guide/assets/agent-development/duplicated_import.png new file mode 100644 index 000000000..e4d51cad5 Binary files /dev/null and b/doc/docs/zh/user-guide/assets/agent-development/duplicated_import.png differ diff --git a/docker/.env.example b/docker/.env.example index e770040e7..1018228ac 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -153,4 +153,4 @@ LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0 LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0 # Market Backend Address -MARKET_BACKEND=http://localhost:8010 \ No newline at end of file +MARKET_BACKEND=https://market.nexent.tech diff --git a/docker/deploy.sh b/docker/deploy.sh index 2b3ea9618..dc39eaf6a 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -1,10 +1,13 @@ #!/bin/bash +# Ensure the script is executed with bash (required for arrays and [[ ]]) +if [ -z "$BASH_VERSION" ]; then + echo "❌ This script must be run with bash. Please use: bash deploy.sh or ./deploy.sh" + exit 0 +fi + # Exit immediately if a command exits with a non-zero status set -e - -ERROR_OCCURRED=0 - set -a source .env @@ -51,6 +54,173 @@ sanitize_input() { printf "%s" "$input" | tr -d '\r' } +is_windows_env() { + # Detect Windows Git Bash / MSYS / MINGW environment + local os_name + os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]') + if [[ "$os_name" == mingw* || "$os_name" == msys* ]]; then + return 0 + fi + return 1 +} + +is_port_in_use() { + # Check if a TCP port is already in use (Linux/macOS/Windows Git Bash) + local port="$1" + + # Prefer lsof when available (typically on Linux/macOS) + if command -v lsof >/dev/null 2>&1 && ! is_windows_env; then + if lsof -iTCP:"$port" -sTCP:LISTEN -P -n >/dev/null 2>&1; then + return 0 + fi + return 1 + fi + + # Fallback to ss if available + if command -v ss >/dev/null 2>&1; then + if ss -ltn 2>/dev/null | awk '{print $4}' | grep -qE "[:\.]${port}$"; then + return 0 + fi + return 1 + fi + + # Fallback to netstat (works on Windows and many Linux distributions) + if command -v netstat >/dev/null 2>&1; then + if netstat -an 2>/dev/null | grep -qE "[:\.]${port}[[:space:]]"; then + return 0 + fi + return 1 + fi + + # If no inspection tool is available, assume the port is free + return 1 +} + +add_port_if_new() { + # Helper to add a port to global arrays only if not already present + local port="$1" + local source="$2" + local existing_port + + for existing_port in "${PORTS_TO_CHECK[@]}"; do + if [ "$existing_port" = "$port" ]; then + return 0 + fi + done + + PORTS_TO_CHECK+=("$port") + PORT_SOURCES+=("$source") +} + +collect_ports_from_env_file() { + # Collect ports from a single env file, based on addresses and *_PORT style variables + local env_file="$1" + + if [ ! -f "$env_file" ]; then + return 0 + fi + + # 1) Address-style values containing :PORT (for example http://host:3000) + # We only care about the numeric port part. + while IFS= read -r match; do + local port="${match#:}" + port=$(echo "$port" | tr -d '[:space:]') + if [[ "$port" =~ ^[0-9]{2,5}$ ]]; then + add_port_if_new "$port" "$env_file (address)" + fi + done < <(grep -Eo ':[0-9]{2,5}' "$env_file" 2>/dev/null | sort -u) + + # 2) Variables that explicitly define a port, for example FOO_PORT=3000 + while IFS= read -r line; do + # Strip inline comments + line="${line%%#*}" + # Extract value part after '=' + local value="${line#*=}" + value=$(echo "$value" | tr -d '[:space:]"'\''') + if [[ "$value" =~ ^[0-9]{2,5}$ ]]; then + add_port_if_new "$value" "$env_file (PORT variable)" + fi + done < <(grep -E '^[A-Za-z_][A-Za-z0-9_]*_PORT *= *[0-9]{2,5}' "$env_file" 2>/dev/null) +} + +check_ports_in_env_files() { + # Preflight check: ensure all ports referenced in env files are free + PORTS_TO_CHECK=() + PORT_SOURCES=() + + # Always include the main .env if present, plus any .env.* files + local env_files=() + if [ -f ".env" ]; then + env_files+=(".env") + fi + + # Include additional env variants such as .env.general and .env.mainland + local f + for f in .env.*; do + if [ -f "$f" ]; then + env_files+=("$f") + fi + done + + # Collect ports from all discovered env files + for f in "${env_files[@]}"; do + collect_ports_from_env_file "$f" + done + + if [ ${#PORTS_TO_CHECK[@]} -eq 0 ]; then + echo "🔍 No port definitions found in environment files, skipping port availability check." + echo "" + echo "--------------------------------" + echo "" + return 0 + fi + + echo "🔍 Checking port availability defined in environment files..." + local occupied_ports=() + local occupied_sources=() + + local idx + for idx in "${!PORTS_TO_CHECK[@]}"; do + local port="${PORTS_TO_CHECK[$idx]}" + local source="${PORT_SOURCES[$idx]}" + + if is_port_in_use "$port"; then + occupied_ports+=("$port") + occupied_sources+=("$source") + echo " ❌ Port $port is already in use." + else + echo " ✅ Port $port is free." + fi + done + + if [ ${#occupied_ports[@]} -gt 0 ]; then + echo "" + echo "❌ Port conflict detected. The following ports required by Nexent are already in use:" + local i + for i in "${!occupied_ports[@]}"; do + echo " - Port ${occupied_ports[$i]}" + done + echo "" + echo "Please free these ports or update the corresponding .env files." + echo "" + + # Ask user whether to continue deployment even if some ports are occupied + local confirm_continue + read -p "👉 Do you still want to continue deployment even though some ports are in use? [y/N]: " confirm_continue + confirm_continue=$(sanitize_input "$confirm_continue") + if ! [[ "$confirm_continue" =~ ^[Yy]$ ]]; then + echo "🚫 Deployment aborted due to port conflicts." + exit 0 + fi + + echo "⚠️ Continuing deployment even though some required ports are already in use." + fi + + echo "" + echo "--------------------------------" + echo "" +} + generate_minio_ak_sk() { echo "🔑 Generating MinIO keys..." @@ -69,7 +239,6 @@ generate_minio_ak_sk() { if [ -z "$ACCESS_KEY" ] || [ -z "$SECRET_KEY" ]; then echo " ❌ ERROR Failed to generate MinIO access keys" - ERROR_OCCURRED=1 return 1 fi @@ -130,7 +299,7 @@ generate_supabase_keys() { generate_elasticsearch_api_key() { # Function to generate Elasticsearch API key - wait_for_elasticsearch_healthy || { echo " ❌ Elasticsearch health check failed"; exit 1; } + wait_for_elasticsearch_healthy || { echo " ❌ Elasticsearch health check failed"; return 0; } # Generate API key echo "🔑 Generating ELASTICSEARCH_API_KEY..." @@ -203,7 +372,7 @@ get_compose_version() { fi echo "unknown" - return 1 + return 0 } disable_dashboard() { @@ -325,7 +494,6 @@ create_dir_with_permission() { # Check if parameters are provided if [ -z "$dir_path" ] || [ -z "$permission" ]; then echo " ❌ ERROR Directory path and permission parameters are required." >&2 - ERROR_OCCURRED=1 return 1 fi @@ -334,7 +502,6 @@ create_dir_with_permission() { mkdir -p "$dir_path" if [ $? -ne 0 ]; then echo " ❌ ERROR Failed to create directory $dir_path." >&2 - ERROR_OCCURRED=1 return 1 fi fi @@ -377,7 +544,7 @@ deploy_core_services() { echo "👀 Starting core services..." if ! ${docker_compose_command} -p nexent -f "docker-compose${COMPOSE_FILE_SUFFIX}" up -d nexent-config nexent-runtime nexent-mcp nexent-northbound nexent-web nexent-data-process; then echo " ❌ ERROR Failed to start core services" - exit 1 + return 0 fi } @@ -394,7 +561,7 @@ deploy_infrastructure() { if ! ${docker_compose_command} -p nexent -f "docker-compose${COMPOSE_FILE_SUFFIX}" up -d $INFRA_SERVICES; then echo " ❌ ERROR Failed to start infrastructure services" - exit 1 + return 0 fi if [ "$ENABLE_TERMINAL_TOOL_CONTAINER" = "true" ]; then @@ -408,14 +575,12 @@ deploy_infrastructure() { # Check if the supabase compose file exists if [ ! -f "docker-compose-supabase${COMPOSE_FILE_SUFFIX}" ]; then echo " ❌ ERROR Supabase compose file not found: docker-compose-supabase${COMPOSE_FILE_SUFFIX}" - ERROR_OCCURRED=1 return 1 fi # Start Supabase services if ! $docker_compose_command -p nexent -f "docker-compose-supabase${COMPOSE_FILE_SUFFIX}" up -d; then echo " ❌ ERROR Failed to start supabase services" - ERROR_OCCURRED=1 return 1 fi @@ -488,8 +653,7 @@ setup_package_install_script() { echo " ✅ Package installation script created/updated" else echo " ❌ ERROR openssh-install-script.sh not found" - ERROR_OCCURRED=1 - return 1 + return 0 fi } @@ -506,7 +670,7 @@ wait_for_elasticsearch_healthy() { if [ $retries -eq $max_retries ]; then echo " ⚠️ Warning: Elasticsearch did not become healthy within expected time" echo " You may need to check the container logs and try again" - return 1 + return 0 else echo " ✅ Elasticsearch is now healthy!" return 0 @@ -580,7 +744,6 @@ select_terminal_tool() { echo "" if [ -z "$input_password" ]; then echo "❌ SSH password cannot be empty" - ERROR_OCCURRED=1 return 1 fi SSH_PASSWORD="$input_password" @@ -589,7 +752,6 @@ select_terminal_tool() { # Validate credentials if [ -z "$SSH_USERNAME" ] || [ -z "$SSH_PASSWORD" ]; then echo "❌ Both username and password are required" - ERROR_OCCURRED=1 return 1 fi @@ -671,25 +833,28 @@ main_deploy() { echo "--------------------------------" echo "" + # Check all relevant ports from environment files before starting deployment + check_ports_in_env_files + # Select deployment version, mode and image source - select_deployment_version || { echo "❌ Deployment version selection failed"; exit 1; } - select_deployment_mode || { echo "❌ Deployment mode selection failed"; exit 1; } - select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 1; } - choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } + select_deployment_version || { echo "❌ Deployment version selection failed"; exit 0; } + select_deployment_mode || { echo "❌ Deployment mode selection failed"; exit 0; } + select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 0; } + choose_image_env || { echo "❌ Image environment setup failed"; exit 0; } # Add permission - prepare_directory_and_data || { echo "❌ Permission setup failed"; exit 1; } - generate_minio_ak_sk || { echo "❌ MinIO key generation failed"; exit 1; } + prepare_directory_and_data || { echo "❌ Permission setup failed"; exit 0; } + generate_minio_ak_sk || { echo "❌ MinIO key generation failed"; exit 0; } # Generate Supabase secrets - generate_supabase_keys || { echo "❌ Supabase secrets generation failed"; exit 1; } + generate_supabase_keys || { echo "❌ Supabase secrets generation failed"; exit 0; } # Deploy infrastructure services - deploy_infrastructure || { echo "❌ Infrastructure deployment failed"; exit 1; } + deploy_infrastructure || { echo "❌ Infrastructure deployment failed"; exit 0; } # Generate Elasticsearch API key - generate_elasticsearch_api_key || { echo "❌ Elasticsearch API key generation failed"; exit 1; } + generate_elasticsearch_api_key || { echo "❌ Elasticsearch API key generation failed"; exit 0; } echo "" echo "--------------------------------" @@ -697,7 +862,7 @@ main_deploy() { # Special handling for infrastructure mode if [ "$DEPLOYMENT_MODE" = "infrastructure" ]; then - generate_env_for_infrastructure || { echo "❌ Environment generation failed"; exit 1; } + generate_env_for_infrastructure || { echo "❌ Environment generation failed"; exit 0; } echo "🎉 Infrastructure deployment completed successfully!" echo " You can now start the core services manually using dev containers" echo " Environment file available at: $(cd .. && pwd)/.env" @@ -706,7 +871,7 @@ main_deploy() { fi # Start core services - deploy_core_services || { echo "❌ Core services deployment failed"; exit 1; } + deploy_core_services || { echo "❌ Core services deployment failed"; exit 0; } echo " ✅ Core services started successfully" echo "" @@ -715,7 +880,7 @@ main_deploy() { # Create default admin user if [ "$DEPLOYMENT_VERSION" = "full" ]; then - create_default_admin_user || { echo "❌ Default admin user creation failed"; exit 1; } + create_default_admin_user || { echo "❌ Default admin user creation failed"; exit 0; } fi echo "🎉 Deployment completed successfully!" @@ -726,7 +891,7 @@ main_deploy() { version_info=$(get_compose_version) if [[ $version_info == "unknown" ]]; then echo "Error: Docker Compose not found or version detection failed" - exit 1 + exit 0 fi # extract version @@ -741,7 +906,7 @@ case $version_type in # The version ​​v1.28.0​​ is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" - exit 1 + exit 0 fi docker_compose_command="docker-compose" ;; @@ -751,14 +916,14 @@ case $version_type in ;; *) echo "Error: Unknown docker compose version type." - exit 1 + exit 0 ;; esac # Execute main deployment with error handling if ! main_deploy; then echo "❌ Deployment failed. Please check the error messages above and try again." - exit 1 + exit 0 fi clean diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 41b0f0c1c..71042c3ef 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -186,6 +186,7 @@ services: - WS_BACKEND=ws://nexent-runtime:5014 - RUNTIME_HTTP_BACKEND=http://nexent-runtime:5014 - MINIO_ENDPOINT=http://nexent-minio:9000 + - MARKET_BACKEND=https://market.nexent.tech logging: driver: "json-file" options: diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index f59430d24..7898ddf61 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -199,6 +199,7 @@ services: - WS_BACKEND=ws://nexent-runtime:5014 - RUNTIME_HTTP_BACKEND=http://nexent-runtime:5014 - MINIO_ENDPOINT=http://nexent-minio:9000 + - MARKET_BACKEND=https://market.nexent.tech logging: driver: "json-file" options: diff --git a/docker/init.sql b/docker/init.sql index a8af3d190..1181c8237 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -167,6 +167,7 @@ CREATE TABLE IF NOT EXISTS "model_record_t" ( "maximum_chunk_size" int4, "display_name" varchar(100) COLLATE "pg_catalog"."default", "connect_status" varchar(100) COLLATE "pg_catalog"."default", + "ssl_verify" boolean DEFAULT true, "create_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "delete_flag" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, "update_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, @@ -189,6 +190,7 @@ COMMENT ON COLUMN "model_record_t".expected_chunk_size IS 'Expected chunk size f COMMENT ON COLUMN "model_record_t".maximum_chunk_size IS 'Maximum chunk size for embedding models, used during document chunking'; COMMENT ON COLUMN "model_record_t"."display_name" IS 'Model name displayed directly in frontend, customized by user'; COMMENT ON COLUMN "model_record_t"."connect_status" IS 'Model connectivity status from last check, optional values: "检测中"、"可用"、"不可用"'; +COMMENT ON COLUMN "model_record_t"."ssl_verify" IS 'Whether to verify SSL certificates when connecting to this model API. Default is true. Set to false for local services without SSL support.'; COMMENT ON COLUMN "model_record_t"."create_time" IS 'Creation time, audit field'; COMMENT ON COLUMN "model_record_t"."delete_flag" IS 'When deleted by user frontend, delete flag will be set to true, achieving soft delete effect. Optional values Y/N'; COMMENT ON COLUMN "model_record_t"."update_time" IS 'Update time, audit field'; diff --git a/docker/sql/1129_add_ssl_verify_to_model_record_t.sql b/docker/sql/1129_add_ssl_verify_to_model_record_t.sql new file mode 100644 index 000000000..aa2c9d9c9 --- /dev/null +++ b/docker/sql/1129_add_ssl_verify_to_model_record_t.sql @@ -0,0 +1,5 @@ +ALTER TABLE nexent.model_record_t +ADD COLUMN ssl_verify BOOLEAN DEFAULT TRUE; + +COMMENT ON COLUMN nexent.model_record_t.ssl_verify IS 'Whether to verify SSL certificates when connecting to this model API. Default is true. Set to false for local services without SSL support.'; + diff --git a/frontend/app/[locale]/agents/AgentsContent.tsx b/frontend/app/[locale]/agents/AgentsContent.tsx index af658c0a4..72d5e66ed 100644 --- a/frontend/app/[locale]/agents/AgentsContent.tsx +++ b/frontend/app/[locale]/agents/AgentsContent.tsx @@ -74,9 +74,11 @@ export default function AgentsContent({ transition={pageTransition} style={{width: "100%", height: "100%"}} > - {canAccessProtectedData ? ( - - ) : null} +
+ {canAccessProtectedData ? ( + + ) : null} +
(null); + + // Agent import wizard states + const [importWizardVisible, setImportWizardVisible] = useState(false); + const [importWizardData, setImportWizardData] = useState(null); // Use generation state passed from parent component, not local state // Delete confirmation popup status @@ -1589,7 +1594,7 @@ export default function AgentSetupOrchestrator({ [runNormalImport, runForceImport] ); - // Handle importing agent + // Handle importing agent - use AgentImportWizard for ExportAndImportDataFormat const handleImportAgent = (t: TFunction) => { // Create a hidden file input element const fileInput = document.createElement("input"); @@ -1618,6 +1623,20 @@ export default function AgentSetupOrchestrator({ return; } + // Check if it's ExportAndImportDataFormat (has agent_id and agent_info) + if (agentInfo.agent_id && agentInfo.agent_info && typeof agentInfo.agent_info === "object") { + // Use AgentImportWizard for full agent import with configuration + const importData: ImportAgentData = { + agent_id: agentInfo.agent_id, + agent_info: agentInfo.agent_info, + mcp_info: agentInfo.mcp_info || [], + }; + setImportWizardData(importData); + setImportWizardVisible(true); + return; + } + + // Fallback to legacy import logic for other formats const normalizeValue = (value?: string | null) => typeof value === "string" ? value.trim() : ""; @@ -1700,6 +1719,13 @@ export default function AgentSetupOrchestrator({ fileInput.click(); }; + // Handle import completion from wizard + const handleImportComplete = () => { + refreshAgentList(t, false); + setImportWizardVisible(false); + setImportWizardData(null); + }; + const handleConfirmedDuplicateImport = useCallback(async () => { if (!pendingImportData) { return; @@ -2256,6 +2282,23 @@ export default function AgentSetupOrchestrator({ {t("businessLogic.config.import.duplicateDescription")}

+ {/* Agent Import Wizard */} + { + setImportWizardVisible(false); + setImportWizardData(null); + }} + initialData={importWizardData} + onImportComplete={handleImportComplete} + title={undefined} // Use default title + agentDisplayName={ + importWizardData?.agent_info?.[String(importWizardData.agent_id)]?.display_name + } + agentDescription={ + importWizardData?.agent_info?.[String(importWizardData.agent_id)]?.description + } + /> {/* Auto unselect knowledge_base_search notice when embedding not configured */} { + log.error("Failed to refresh tools and agents after deletion:", error); + }); } else { message.error(result.message); + // Throw error to prevent modal from closing + throw new Error(result.message); } } catch (error) { message.error(t("mcpConfig.message.deleteServerFailed")); + // Throw error to prevent modal from closing + throw error; } }, }); diff --git a/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx b/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx index f74c46040..88b8594bb 100644 --- a/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx +++ b/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx @@ -288,6 +288,7 @@ export default function SubAgentPool({ const isCurrentlyEditing = editingAgent && String(editingAgent.id) === String(agent.id); // Ensure type matching + const displayName = agent.display_name || agent.name; const agentItem = (
-
+
{!isAvailable && ( )} - {agent.display_name && ( - - {agent.display_name} + {displayName && ( + + {displayName} )} - - {agent.name} - {unsavedAgentId !== null && String(unsavedAgentId) === String(agent.id) && (
diff --git a/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx index c039ee22f..d1160b722 100644 --- a/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx @@ -722,7 +722,7 @@ export default function ToolConfigModal({ - {group.subGroups.map((subGroup, index) => ( - - {subGroup.label} - - } - className={`tool-category-panel ${ - index === 0 ? "mt-1" : "mt-3" - }`} - > + items={group.subGroups.map((subGroup, index) => ({ + key: subGroup.key, + label: ( + + {subGroup.label} + + ), + className: `tool-category-panel ${ + index === 0 ? "mt-1" : "mt-3" + }`, + children: (
{subGroup.tools.map((tool) => ( ))}
-
- ))} - + ), + }))} + />
) : ( @@ -650,15 +648,19 @@ function ToolPool({ {t("toolPool.tooltip.functionGuide")}
} - overlayInnerStyle={{ - backgroundColor: "#ffffff", - color: "#374151", - border: "1px solid #e5e7eb", - borderRadius: "6px", - boxShadow: "0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)", - padding: "12px", - maxWidth: "600px", - minWidth: "400px", + color="#ffffff" + styles={{ + body: { + backgroundColor: "#ffffff", + color: "#374151", + border: "1px solid #e5e7eb", + borderRadius: "6px", + boxShadow: "0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)", + padding: "12px", + maxWidth: "800px", + minWidth: "700px", + width: "fit-content", + }, }} > diff --git a/frontend/app/[locale]/chat/components/chatRightPanel.tsx b/frontend/app/[locale]/chat/components/chatRightPanel.tsx index 874dee081..80792db0a 100644 --- a/frontend/app/[locale]/chat/components/chatRightPanel.tsx +++ b/frontend/app/[locale]/chat/components/chatRightPanel.tsx @@ -1,6 +1,6 @@ import { useState, useEffect, useRef, useCallback } from "react"; import { useTranslation } from "react-i18next"; -import { ExternalLink, Database, X } from "lucide-react"; +import { ExternalLink, Database, X, Server } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; @@ -8,6 +8,8 @@ import { StaticScrollArea } from "@/components/ui/scrollArea"; import { ImageItem, ChatRightPanelProps, SearchResult } from "@/types/chat"; import { API_ENDPOINTS } from "@/services/api"; import { formatDate, formatUrl } from "@/lib/utils"; +import { convertImageUrlToApiUrl, extractObjectNameFromUrl, storageService } from "@/services/storageService"; +import { message } from "antd"; import log from "@/lib/logger"; @@ -92,30 +94,48 @@ export function ChatRightPanel({ })); try { - // Use the proxy service to get the image - const response = await fetch(API_ENDPOINTS.proxy.image(imageUrl)); - const data = await response.json(); + // Convert image URL to backend API URL + const apiUrl = convertImageUrlToApiUrl(imageUrl); + + // Use backend API to get the image + const response = await fetch(apiUrl); + + if (!response.ok) { + throw new Error(`Failed to load image: ${response.statusText}`); + } - if (data.success) { + // Get image as blob and convert to base64 + const blob = await response.blob(); + const reader = new FileReader(); + + reader.onloadend = () => { + const base64Data = reader.result as string; + // Remove data URL prefix (e.g., "data:image/png;base64,") + const base64 = base64Data.split(',')[1] || base64Data; + setImageData((prev) => ({ ...prev, [imageUrl]: { - base64Data: data.base64, - contentType: data.content_type || "image/jpeg", + base64Data: base64, + contentType: blob.type || "image/jpeg", isLoading: false, loadAttempts: currentAttempts + 1, }, })); - } else { - // If loading fails, remove it directly from the list + loadingImages.current.delete(imageUrl); + }; + + reader.onerror = () => { + log.error("Failed to read image blob"); handleImageLoadFail(imageUrl); - } + loadingImages.current.delete(imageUrl); + }; + + reader.readAsDataURL(blob); } catch (error) { log.error(t("chatRightPanel.imageProxyError"), error); // If loading fails, remove it directly from the list handleImageLoadFail(imageUrl); - } finally { - // Whether successful or not, remove the loading mark loadingImages.current.delete(imageUrl); } @@ -200,11 +220,71 @@ export function ChatRightPanel({ // Search result item component const SearchResultItem = ({ result }: { result: SearchResult }) => { const [isExpanded, setIsExpanded] = useState(false); + const [isDownloading, setIsDownloading] = useState(false); const title = result.title || t("chatRightPanel.unknownTitle"); const url = result.url || "#"; const text = result.text || t("chatRightPanel.noContentDescription"); const published_date = result.published_date || ""; const source_type = result.source_type || "url"; + const filename = result.filename || result.title || ""; + const datamateDatasetId = result.score_details?.datamate_dataset_id; + const datamateFileId = result.score_details?.datamate_file_id; + const datamateBaseUrl = result.score_details?.datamate_base_url; + + // Handle file download + const handleFileDownload = async (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + + if (!filename && !url) { + message.error(t("chatRightPanel.fileDownloadError", "File name or URL is missing")); + return; + } + + setIsDownloading(true); + try { + // Handle datamate source type + if (source_type === "datamate") { + if (!datamateDatasetId || !datamateFileId || !datamateBaseUrl) { + if (!url || url === "#") { + message.error(t("chatRightPanel.fileDownloadError", "Missing Datamate dataset or file information")); + return; + } + } + await storageService.downloadDatamateFile({ + url: url !== "#" ? url : undefined, + baseUrl: datamateBaseUrl, + datasetId: datamateDatasetId, + fileId: datamateFileId, + filename: filename || undefined, + }); + message.success(t("chatRightPanel.fileDownloadSuccess", "File download started")); + return; + } + + // Handle regular file source type (source_type === "file") + // For knowledge base files, backend stores the MinIO object_name in path_or_url, + // so we should always try to extract it from the URL and avoid guessing from filename. + let objectName: string | undefined = undefined; + + if (url && url !== "#") { + objectName = extractObjectNameFromUrl(url) || undefined; + } + + if (!objectName) { + message.error(t("chatRightPanel.fileDownloadError", "Cannot determine file object name")); + return; + } + + await storageService.downloadFile(objectName, filename || "download"); + message.success(t("chatRightPanel.fileDownloadSuccess", "File download started")); + } catch (error) { + log.error("Failed to download file:", error); + message.error(t("chatRightPanel.fileDownloadError", "Failed to download file. Please try again.")); + } finally { + setIsDownloading(false); + } + }; return (
@@ -227,6 +307,29 @@ export function ChatRightPanel({ > {title} + ) : source_type === "file" || source_type === "datamate" ? ( + + {isDownloading ? ( + + + {t("chatRightPanel.downloading", "Downloading...")} + + ) : ( + title + )} + ) : (
-
- {source_type === "url" ? ( - - ) : source_type === "file" ? ( - - ) : null} -
- - {formatUrl(result)} - + {source_type === "file" || source_type === "datamate" ? ( + <> + +
+
+ +
+
+ {source_type === "datamate" + ? t("chatRightPanel.source.datamate", "Source: Datamate") + : source_type === "file" + ? t("chatRightPanel.source.nexent", "Source: Nexent") + : ""} +
+
+ + ) : ( +
+
+ +
+ + {formatUrl(result)} + +
+ )}
{text.length > 150 && ( diff --git a/frontend/app/[locale]/chat/internal/chatAttachment.tsx b/frontend/app/[locale]/chat/internal/chatAttachment.tsx index ff7b5ceb8..c08ece8f7 100644 --- a/frontend/app/[locale]/chat/internal/chatAttachment.tsx +++ b/frontend/app/[locale]/chat/internal/chatAttachment.tsx @@ -2,6 +2,9 @@ import { chatConfig } from "@/const/chatConfig"; import { useState } from "react"; import { useTranslation } from "react-i18next"; import { ExternalLink } from "lucide-react"; +import { storageService, convertImageUrlToApiUrl, extractObjectNameFromUrl } from "@/services/storageService"; +import { message } from "antd"; +import log from "@/lib/logger"; import { AiFillFileImage, AiFillFilePdf, @@ -37,6 +40,9 @@ const ImageViewer = ({ }) => { if (!isOpen) return null; const { t } = useTranslation("common"); + + // Convert image URL to backend API URL + const imageUrl = convertImageUrlToApiUrl(url); return ( @@ -47,7 +53,7 @@ const ImageViewer = ({
- Full size + Full size
@@ -56,13 +62,15 @@ const ImageViewer = ({ // File viewer component const FileViewer = ({ + objectName, url, name, contentType, isOpen, onClose, }: { - url: string; + objectName?: string; + url?: string; name: string; contentType?: string; isOpen: boolean; @@ -70,6 +78,109 @@ const FileViewer = ({ }) => { if (!isOpen) return null; const { t } = useTranslation("common"); + const [isDownloading, setIsDownloading] = useState(false); + + + // Handle file download + const handleDownload = async (e: React.MouseEvent) => { + // Prevent dialog from closing immediately + e.preventDefault(); + e.stopPropagation(); + + // Check if URL is a direct http/https URL that can be accessed directly + // Exclude backend API endpoints (containing /api/file/download/) + if ( + url && + (url.startsWith("http://") || url.startsWith("https://")) && + !url.includes("/api/file/download/") + ) { + // Direct download from HTTP/HTTPS URL without backend + const link = document.createElement("a"); + link.href = url; + link.download = name; + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + setTimeout(() => { + document.body.removeChild(link); + }, 100); + message.success(t("chatAttachment.downloadSuccess", "File download started")); + setTimeout(() => { + onClose(); + }, 500); + return; + } + + // Try to get object_name from props or extract from URL + let finalObjectName: string | undefined = objectName; + + if (!finalObjectName && url) { + finalObjectName = extractObjectNameFromUrl(url) || undefined; + } + + if (!finalObjectName) { + // If we still don't have object_name, fall back to direct URL download + if (url) { + // Create a temporary link to download from URL + const link = document.createElement("a"); + link.href = url; + link.download = name; + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + setTimeout(() => { + document.body.removeChild(link); + }, 100); + message.success(t("chatAttachment.downloadSuccess", "File download started")); + return; + } else { + message.error(t("chatAttachment.downloadError", "File object name or URL is missing")); + return; + } + } + + setIsDownloading(true); + try { + // Start download (non-blocking, browser handles it) + await storageService.downloadFile(finalObjectName, name); + // Show success message immediately after triggering download + message.success(t("chatAttachment.downloadSuccess", "File download started")); + // Keep dialog open for a moment to show the message, then close + setTimeout(() => { + setIsDownloading(false); + onClose(); + }, 500); + } catch (error) { + log.error("Failed to download file:", error); + setIsDownloading(false); + // If backend download fails and we have URL, try direct download as fallback + if (url) { + try { + const link = document.createElement("a"); + link.href = url; + link.download = name; + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + setTimeout(() => { + document.body.removeChild(link); + }, 100); + message.success(t("chatAttachment.downloadSuccess", "File download started")); + setTimeout(() => { + onClose(); + }, 500); + } catch (fallbackError) { + message.error( + t("chatAttachment.downloadError", "Failed to download file. Please try again.") + ); + } + } else { + message.error( + t("chatAttachment.downloadError", "Failed to download file. Please try again.") + ); + } + } + }; return ( @@ -89,15 +200,17 @@ const FileViewer = ({

{t("chatAttachment.previewNotSupported")}

- - {t("chatAttachment.downloadToView")} - + {isDownloading + ? t("chatAttachment.downloading", "Downloading...") + : t("chatAttachment.downloadToView")} +
@@ -183,7 +296,8 @@ export function ChatAttachment({ }: ChatAttachmentProps) { const [selectedImage, setSelectedImage] = useState(null); const [selectedFile, setSelectedFile] = useState<{ - url: string; + objectName?: string; + url?: string; name: string; contentType?: string; } | null>(null); @@ -218,6 +332,7 @@ export function ChatAttachment({ } else { // For files, use internal preview setSelectedFile({ + objectName: attachment.object_name, url: attachment.url, name: attachment.name, contentType: attachment.contentType, @@ -252,7 +367,7 @@ export function ChatAttachment({
{attachment.url && ( {attachment.name} e.stopPropagation()} > {t("chatInterface.imagePreview")} { diff --git a/frontend/app/[locale]/chat/streaming/taskWindow.tsx b/frontend/app/[locale]/chat/streaming/taskWindow.tsx index 0180c5b7c..cb1d1cc94 100644 --- a/frontend/app/[locale]/chat/streaming/taskWindow.tsx +++ b/frontend/app/[locale]/chat/streaming/taskWindow.tsx @@ -17,6 +17,7 @@ import { MarkdownRenderer } from "@/components/ui/markdownRenderer"; import { chatConfig } from "@/const/chatConfig"; import { ChatMessageType, TaskMessageType, CardItem, MessageHandler } from "@/types/chat"; import { useChatTaskMessage } from "@/hooks/useChatTaskMessage"; +import { storageService, extractObjectNameFromUrl } from "@/services/storageService"; import log from "@/lib/logger"; // Icon mapping dictionary - map strings to corresponding icon components @@ -31,6 +32,23 @@ const iconMap: Record = { default: , // Default icon }; +type KnowledgeSiteInfo = { + key: string; + domain: string; + displayName: string; + faviconUrl: string; + useDefaultIcon: boolean; + isKnowledgeBase: boolean; + sourceType: string; + url: string; + filename: string; + datamateDatasetId?: string; + datamateFileId?: string; + datamateBaseUrl?: string; + objectName?: string; + canOpenWeb: boolean; +}; + // Define the handlers for different types of messages to improve extensibility const messageHandlers: MessageHandler[] = [ // Preprocess type processor - handles contents array logic @@ -126,77 +144,188 @@ const messageHandlers: MessageHandler[] = [ } ); - // Process website information for display - const siteInfos = uniqueSearchResults.map((result: any) => { - const pageUrl = result.url || ""; - const filename = result.filename || ""; - const sourceType = result.source_type || ""; - let domain = t("taskWindow.unknownSource"); - let displayName = t("taskWindow.unknownSource"); - let baseUrl = ""; - let faviconUrl = ""; - let useDefaultIcon = false; - let isKnowledgeBase = false; - let canClick = true; // whether to allow click to jump + // Process website / knowledge base information for display + const siteInfos: KnowledgeSiteInfo[] = uniqueSearchResults.map( + (result: any, index: number) => { + const pageUrl = result.url || ""; + const filename = result.filename || result.title || ""; + const sourceType = result.source_type || (filename ? "file" : "url"); + const scoreDetails = result.score_details || {}; + const datamateDatasetId = + scoreDetails?.datamate_dataset_id || scoreDetails?.dataset_id; + const datamateFileId = + scoreDetails?.datamate_file_id || scoreDetails?.file_id; + const datamateBaseUrl = + scoreDetails?.datamate_base_url || + scoreDetails?.datamate_baseUrl || + scoreDetails?.base_url; + const objectName = + result.object_name || + scoreDetails?.object_name || + scoreDetails?.minio_object_name; + + let domain = t("taskWindow.unknownSource"); + let displayName = t("taskWindow.unknownSource"); + let baseUrl = ""; + let faviconUrl = ""; + let useDefaultIcon = false; + let isKnowledgeBase = + sourceType === "file" || + sourceType === "datamate" || + (!sourceType && !!filename); + let canOpenWeb = false; + + if (isKnowledgeBase) { + displayName = + filename || result.title || t("taskWindow.knowledgeFile"); + domain = + datamateBaseUrl || + (pageUrl && pageUrl !== "#" + ? (() => { + try { + return new URL(pageUrl).hostname; + } catch { + return t("taskWindow.unknownSource"); + } + })() + : t("taskWindow.unknownSource")); + useDefaultIcon = true; + } else if (pageUrl && pageUrl !== "#") { + try { + const parsedUrl = new URL(pageUrl); + baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}`; + domain = parsedUrl.hostname; + + displayName = domain + .replace(/^www\./, "") + .replace( + /\.(com|cn|org|net|io|gov|edu|co|info|biz|xyz)(\.[a-z]{2})?$/, + "" + ); + if (!displayName) { + displayName = domain; + } + + faviconUrl = `${baseUrl}/favicon.ico`; + canOpenWeb = true; + } catch (e) { + log.error(t("taskWindow.urlParseError"), e); + useDefaultIcon = true; + canOpenWeb = false; + } + } else { + useDefaultIcon = true; + canOpenWeb = false; + } - // first judge based on source_type - if (sourceType === "file") { - isKnowledgeBase = true; - displayName = - filename || result.title || t("taskWindow.knowledgeFile"); - useDefaultIcon = true; - canClick = false; // file type does not allow jump - } - // if there is no source_type, judge based on filename (compatibility processing) - else if (filename) { - isKnowledgeBase = true; - displayName = filename; - useDefaultIcon = true; - canClick = false; // file type does not allow jump + return { + key: `site-${index}-${result.cite_index ?? ""}-${filename ?? ""}`, + domain, + displayName, + faviconUrl, + url: pageUrl, + useDefaultIcon, + isKnowledgeBase, + filename, + sourceType, + datamateDatasetId, + datamateFileId, + datamateBaseUrl, + objectName, + canOpenWeb, + }; } - // handle webpage link - else if (pageUrl && pageUrl !== "#") { - try { - const parsedUrl = new URL(pageUrl); - baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}`; - domain = parsedUrl.hostname; + ); - // Process the domain, remove the www prefix and com/cn etc. suffix - displayName = domain - .replace(/^www\./, "") // Remove the www. prefix - .replace( - /\.(com|cn|org|net|io|gov|edu|co|info|biz|xyz)(\.[a-z]{2})?$/, - "" - ); // Remove common suffixes + const handleKnowledgeFileDownload = async ( + site: KnowledgeSiteInfo + ): Promise => { + try { + if (site.sourceType === "datamate") { + if ( + !site.datamateDatasetId && + !site.datamateFileId && + (!site.url || site.url === "#") + ) { + message.error( + t( + "taskWindow.downloadError", + "Missing Datamate dataset or file information" + ) + ); + return; + } - // If the processing is empty, use the original domain - if (!displayName) { - displayName = domain; + await storageService.downloadDatamateFile({ + url: site.url && site.url !== "#" ? site.url : undefined, + baseUrl: site.datamateBaseUrl, + datasetId: site.datamateDatasetId, + fileId: site.datamateFileId, + filename: site.filename || undefined, + }); + } else { + // Check if URL is a direct http/https URL that can be accessed directly + // Exclude backend API endpoints (containing /api/file/download/) + if ( + site.url && + site.url !== "#" && + (site.url.startsWith("http://") || site.url.startsWith("https://")) && + !site.url.includes("/api/file/download/") + ) { + // Direct download from HTTP/HTTPS URL without backend + const link = document.createElement("a"); + link.href = site.url; + link.download = site.filename || "download"; + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + setTimeout(() => { + document.body.removeChild(link); + }, 100); + message.success( + t("taskWindow.downloadSuccess", "File download started") + ); + return; } - faviconUrl = `${baseUrl}/favicon.ico`; - canClick = true; - } catch (e) { - log.error(t("taskWindow.urlParseError"), e); - useDefaultIcon = true; - canClick = false; + let objectName = site.objectName; + if (!objectName && site.url) { + objectName = + extractObjectNameFromUrl(site.url) || undefined; + } + if (!objectName && site.filename) { + objectName = site.filename.includes("/") + ? site.filename + : `attachments/${site.filename}`; + } + if (!objectName) { + message.error( + t( + "taskWindow.downloadError", + "Failed to download file. Please try again." + ) + ); + return; + } + await storageService.downloadFile( + objectName, + site.filename || undefined + ); } - } else { - useDefaultIcon = true; - canClick = false; - } - return { - domain, - displayName, - faviconUrl, - url: pageUrl, - useDefaultIcon, - isKnowledgeBase, - filename, - canClick, - }; - }); + message.success( + t("taskWindow.downloadSuccess", "File download started") + ); + } catch (error) { + log.error("Failed to download knowledge file:", error); + message.error( + t( + "taskWindow.downloadError", + "Failed to download file. Please try again." + ) + ); + } + }; // Render the search result information bar return ( @@ -237,9 +366,11 @@ const messageHandlers: MessageHandler[] = [ gap: "0.5rem", }} > - {siteInfos.map((site: any, index: number) => ( + {siteInfos.map((site) => { + const isClickable = site.isKnowledgeBase || site.canOpenWeb; + return (
{ - if (site.canClick && site.url) { + if (site.isKnowledgeBase) { + handleKnowledgeFileDownload(site); + } else if (site.canOpenWeb && site.url) { window.open(site.url, "_blank", "noopener,noreferrer"); } }} onMouseEnter={(e) => { - if (site.canClick) { + if (isClickable) { e.currentTarget.style.backgroundColor = "#f3f4f6"; } }} onMouseLeave={(e) => { - if (site.canClick) { + if (isClickable) { e.currentTarget.style.backgroundColor = "#f9fafb"; } }} title={ - site.canClick + site.isKnowledgeBase + ? t("taskWindow.downloadFile", { + name: site.filename || site.displayName, + }) + : site.canOpenWeb ? t("taskWindow.visit", { domain: site.domain }) : site.filename || site.displayName } @@ -314,9 +449,26 @@ const messageHandlers: MessageHandler[] = [ }} /> )} - {site.displayName} + + {site.displayName} + {site.isKnowledgeBase && ( + + )} +
- ))} + ); + })}
@@ -967,6 +1119,19 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { ); }; + // Error messages that should be completely hidden (including the node) + const suppressedErrorMessages = [ + "Model is interrupted by stop event", + "Agent execution interrupted by external stop signal", + ]; + + // Check if a message should be suppressed (not displayed at all) + const shouldSuppressMessage = (message: any) => { + if (message.type !== "error") return false; + const content = message.content || ""; + return suppressedErrorMessages.some((errText) => content.includes(errText)); + }; + // Check if it is the last message const isLastMessage = (index: number, messages: any[]) => { return index === messages.length - 1; @@ -996,15 +1161,20 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { ); } + // Filter out messages that should be suppressed + const filteredGroupedMessages = groupedMessages.filter( + (group) => !shouldSuppressMessage(group.message) + ); + return (
- {groupedMessages.map((group, groupIndex) => { + {filteredGroupedMessages.map((group, groupIndex) => { const message = group.message; const isBlinking = shouldBlinkDot( groupIndex, - groupedMessages.map((g) => g.message) + filteredGroupedMessages.map((g) => g.message) ); return ( diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index b2f7b4f24..995eea580 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -4,7 +4,7 @@ import type React from "react"; import { useState, useEffect, useRef, useLayoutEffect } from "react"; import { useTranslation } from "react-i18next"; -import { App, Modal } from "antd"; +import { App, Modal, Row, Col } from "antd"; import { InfoCircleFilled, WarningFilled } from "@ant-design/icons"; import { DOCUMENT_ACTION_TYPES, @@ -19,7 +19,7 @@ import { KnowledgeBase } from "@/types/knowledgeBase"; import { useConfig } from "@/hooks/useConfig"; import { SETUP_PAGE_CONTAINER, - FLEX_TWO_COLUMN_LAYOUT, + TWO_COLUMN_LAYOUT, STANDARD_CARD, } from "@/const/layoutConstants"; @@ -757,6 +757,40 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(name); }; + // If Embedding model is not configured, show warning container instead of content + if (showEmbeddingWarning) { + return ( +
+
+
+ +
+ {t("embedding.knowledgeBaseDisabledWarningModal.title")} +
+
+
+
+ ); + } + return ( <>
- {showEmbeddingWarning && ( -
- )} - contentRef.current || document.body} - styles={{ body: { padding: 0 } }} - rootClassName="kb-embedding-warning" - > -
-
- -
-
- {t("embedding.knowledgeBaseDisabledWarningModal.title")} -
-
-
-
-
-
- {/* Left knowledge base list - occupies 1/3 space */} -
+ + {}} // No need to trigger repeatedly here as it's already handled in handleKnowledgeBaseClick /> -
- - {/* Right content area - occupies 2/3 space, now unified with config.tsx style */} -
+ + + {isCreatingMode ? (
)} -
-
+ +
); diff --git a/frontend/app/[locale]/knowledges/KnowledgesContent.tsx b/frontend/app/[locale]/knowledges/KnowledgesContent.tsx index cb8f0731b..779ab47dc 100644 --- a/frontend/app/[locale]/knowledges/KnowledgesContent.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgesContent.tsx @@ -99,7 +99,9 @@ export default function KnowledgesContent({ transition={pageTransition} style={{width: "100%", height: "100%"}} > - +
+ +
) : null} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx index b6e1d118b..0022eec70 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx @@ -26,6 +26,8 @@ import { FilePlus2, Goal, X, + Server, + Database, } from "lucide-react"; import { FieldNumberOutlined } from "@ant-design/icons"; import knowledgeBaseService from "@/services/knowledgeBaseService"; @@ -47,6 +49,7 @@ interface Chunk { filename?: string; create_time?: string; score?: number; // Search score (0-1 range) - only present in search results + source_type?: string; // Source type: "file" (nexent) or "datamate" } interface ChunkFormValues { @@ -289,6 +292,7 @@ const DocumentChunk: React.FC = ({ filename: item.filename, create_time: item.create_time, score: item.score, // Preserve search score for display + source_type: item.source_type, // Preserve source type for display }; }); @@ -657,6 +661,37 @@ const DocumentChunk: React.FC = ({ } > + {/* Display filename and source type if available */} + {chunk.filename && ( +
+
+
+
+ +
+
+ {chunk.filename} +
+
+ {chunk.source_type && ( +
+
+ +
+
+ {chunk.source_type === "datamate" + ? t("document.chunk.source.datamate", "来源: Datamate") + : chunk.source_type === "file" || + chunk.source_type === "minio" || + chunk.source_type === "local" + ? t("document.chunk.source.nexent", "来源: Nexent") + : ""} +
+
+ )} +
+
+ )}
{chunk.content || ""}
diff --git a/frontend/app/[locale]/layout.tsx b/frontend/app/[locale]/layout.tsx index 0a5822d8f..26a6f9d6b 100644 --- a/frontend/app/[locale]/layout.tsx +++ b/frontend/app/[locale]/layout.tsx @@ -16,25 +16,28 @@ import log from "@/lib/logger"; const inter = Inter({ subsets: ["latin"] }); export async function generateMetadata(props: { - params: Promise<{ locale: string }>; + params: Promise<{ locale?: string }>; }): Promise { const { locale } = await props.params; + const resolvedLocale = (["zh", "en"].includes(locale ?? "") + ? locale + : "zh") as "zh" | "en"; let messages: any = {}; - if (["zh", "en"].includes(locale)) { + if (["zh", "en"].includes(resolvedLocale)) { try { const filePath = path.join( process.cwd(), "public", "locales", - locale, + resolvedLocale, "common.json" ); const fileContent = await fs.readFile(filePath, "utf8"); messages = JSON.parse(fileContent); } catch (error) { log.error( - `Failed to load i18n messages for locale: ${locale}`, + `Failed to load i18n messages for locale: ${resolvedLocale}`, error ); } @@ -54,15 +57,20 @@ export async function generateMetadata(props: { }; } -export default async function RootLayout(props: { +export default async function RootLayout({ + children, + params, +}: { children: ReactNode; - params: Promise<{ locale: string }>; + params: Promise<{ locale?: string }>; }) { - const { children, params } = props; const { locale } = await params; + const resolvedLocale = (["zh", "en"].includes(locale ?? "") + ? locale + : "zh") as "zh" | "en"; return ( - + void; } -interface ConfigField { - fieldPath: string; // e.g., "duty_prompt", "tools[0].params.api_key" - fieldLabel: string; // User-friendly label - promptHint?: string; // Hint from - currentValue: string; -} - -interface McpServerToInstall { - mcp_server_name: string; - mcp_url: string; - isInstalled: boolean; - isUrlEditable: boolean; // true if url is - editedUrl?: string; -} - -const needsConfig = (value: any): boolean => { - if (typeof value === "string") { - return value.trim() === "" || value.trim().startsWith(" { - if (typeof value !== "string") return undefined; - const match = value.trim().match(/^$/); - return match ? match[1] : undefined; -}; - export default function AgentInstallModal({ visible, onCancel, agentDetails, onInstallComplete, }: AgentInstallModalProps) { - const { t, i18n } = useTranslation("common"); - const isZh = i18n.language === "zh" || i18n.language === "zh-CN"; - const { message } = App.useApp(); - - // Use unified import hook - const { importFromData, isImporting: isInstallingAgent } = useAgentImport({ - onSuccess: () => { - onInstallComplete?.(); - }, - onError: (error) => { - message.error(error.message || t("market.install.error.installFailed", "Failed to install agent")); - }, - }); - - const [currentStep, setCurrentStep] = useState(0); - const [llmModels, setLlmModels] = useState([]); - const [loadingModels, setLoadingModels] = useState(false); - const [selectedModelId, setSelectedModelId] = useState(null); - const [selectedModelName, setSelectedModelName] = useState(""); - - const [configFields, setConfigFields] = useState([]); - const [configValues, setConfigValues] = useState>({}); - - const [mcpServers, setMcpServers] = useState([]); - const [existingMcpServers, setExistingMcpServers] = useState([]); - const [loadingMcpServers, setLoadingMcpServers] = useState(false); - const [installingMcp, setInstallingMcp] = useState>({}); - - // Load LLM models - useEffect(() => { - if (visible) { - loadLLMModels(); - } - }, [visible]); - - // Parse agent details for config fields and MCP servers - useEffect(() => { - if (visible && agentDetails) { - parseConfigFields(); - parseMcpServers(); - } - }, [visible, agentDetails]); - - const loadLLMModels = async () => { - setLoadingModels(true); - try { - const models = await modelService.getLLMModels(); - setLlmModels(models.filter(m => m.connect_status === "available")); - - // Auto-select first available model - if (models.length > 0 && models[0].connect_status === "available") { - setSelectedModelId(models[0].id); - setSelectedModelName(models[0].displayName); - } - } catch (error) { - log.error("Failed to load LLM models:", error); - message.error(t("market.install.error.loadModels", "Failed to load models")); - } finally { - setLoadingModels(false); - } - }; - - const parseConfigFields = () => { - if (!agentDetails) return; - - const fields: ConfigField[] = []; - - // Check basic fields (excluding MCP-related fields) - const basicFields: Array<{ key: keyof MarketAgentDetail; label: string }> = [ - { key: "description", label: t("market.detail.description", "Description") }, - { key: "business_description", label: t("market.detail.businessDescription", "Business Description") }, - { key: "duty_prompt", label: t("market.detail.dutyPrompt", "Duty Prompt") }, - { key: "constraint_prompt", label: t("market.detail.constraintPrompt", "Constraint Prompt") }, - { key: "few_shots_prompt", label: t("market.detail.fewShotsPrompt", "Few Shots Prompt") }, - ]; - - basicFields.forEach(({ key, label }) => { - const value = agentDetails[key]; - if (needsConfig(value)) { - fields.push({ - fieldPath: key, - fieldLabel: label, - promptHint: extractPromptHint(value as string), - currentValue: value as string, - }); - } - }); - - // Check tool params (excluding MCP server names/urls) - agentDetails.tools?.forEach((tool, toolIndex) => { - if (tool.params && typeof tool.params === "object") { - Object.entries(tool.params).forEach(([paramKey, paramValue]) => { - if (needsConfig(paramValue)) { - fields.push({ - fieldPath: `tools[${toolIndex}].params.${paramKey}`, - fieldLabel: `${tool.name || tool.class_name} - ${paramKey}`, - promptHint: extractPromptHint(paramValue as string), - currentValue: paramValue as string, - }); - } - }); - } - }); - - setConfigFields(fields); - - // Initialize config values - const initialValues: Record = {}; - fields.forEach(field => { - initialValues[field.fieldPath] = ""; - }); - setConfigValues(initialValues); - }; - - const parseMcpServers = async () => { - if (!agentDetails?.mcp_servers || agentDetails.mcp_servers.length === 0) { - setMcpServers([]); - return; - } - - setLoadingMcpServers(true); - try { - // Load existing MCP servers from system - const result = await getMcpServerList(); - const existing = result.success ? result.data : []; - setExistingMcpServers(existing); - - // Check each MCP server - const serversToInstall: McpServerToInstall[] = agentDetails.mcp_servers.map(mcp => { - const isUrlConfigNeeded = needsConfig(mcp.mcp_url); - - // Check if already installed (match by both name and url) - const isInstalled = !isUrlConfigNeeded && existing.some( - (existingMcp: McpServer) => - existingMcp.service_name === mcp.mcp_server_name && - existingMcp.mcp_url === mcp.mcp_url - ); - - return { - mcp_server_name: mcp.mcp_server_name, - mcp_url: mcp.mcp_url, - isInstalled, - isUrlEditable: isUrlConfigNeeded, - editedUrl: isUrlConfigNeeded ? "" : mcp.mcp_url, - }; - }); - - setMcpServers(serversToInstall); - } catch (error) { - log.error("Failed to check MCP servers:", error); - message.error(t("market.install.error.checkMcp", "Failed to check MCP servers")); - } finally { - setLoadingMcpServers(false); - } - }; - - const handleMcpUrlChange = (index: number, newUrl: string) => { - setMcpServers(prev => { - const updated = [...prev]; - updated[index].editedUrl = newUrl; - return updated; - }); - }; - - const handleInstallMcp = async (index: number) => { - const mcp = mcpServers[index]; - const urlToUse = mcp.editedUrl || mcp.mcp_url; - - if (!urlToUse || urlToUse.trim() === "") { - message.error(t("market.install.error.mcpUrlRequired", "MCP URL is required")); - return; - } - - const key = `${index}`; - setInstallingMcp(prev => ({ ...prev, [key]: true })); - - try { - const result = await addMcpServer(urlToUse, mcp.mcp_server_name); - if (result.success) { - message.success(t("market.install.success.mcpInstalled", "MCP server installed successfully")); - // Mark as installed - update state directly without re-fetching - setMcpServers(prev => { - const updated = [...prev]; - updated[index].isInstalled = true; - updated[index].editedUrl = urlToUse; - return updated; - }); - } else { - message.error(result.message || t("market.install.error.mcpInstall", "Failed to install MCP server")); + // Convert MarketAgentDetail to ImportAgentData format + const importData: ImportAgentData | null = agentDetails?.agent_json + ? { + agent_id: agentDetails.agent_id, + agent_info: agentDetails.agent_json.agent_info, + mcp_info: agentDetails.agent_json.mcp_info, } - } catch (error) { - log.error("Failed to install MCP server:", error); - message.error(t("market.install.error.mcpInstall", "Failed to install MCP server")); - } finally { - setInstallingMcp(prev => ({ ...prev, [key]: false })); - } - }; - - const handleNext = () => { - if (currentStep === 0) { - // Step 1: Model selection validation - if (!selectedModelId || !selectedModelName) { - message.error(t("market.install.error.modelRequired", "Please select a model")); - return; - } - } else if (currentStep === 1) { - // Step 2: Config fields validation - const emptyFields = configFields.filter(field => !configValues[field.fieldPath]?.trim()); - if (emptyFields.length > 0) { - message.error(t("market.install.error.configRequired", "Please fill in all required fields")); - return; - } - } - - setCurrentStep(prev => prev + 1); - }; - - const handlePrevious = () => { - setCurrentStep(prev => prev - 1); - }; - - const handleInstall = async () => { - try { - // Prepare the data structure for import - const importData = prepareImportData(); - - if (!importData) { - message.error(t("market.install.error.invalidData", "Invalid agent data")); - return; - } - - log.info("Importing agent with data:", importData); - - // Import using unified hook - await importFromData(importData); - - // Success message will be shown by onSuccess callback - message.success(t("market.install.success", "Agent installed successfully!")); - } catch (error) { - // Error message will be shown by onError callback - log.error("Failed to install agent:", error); - } - }; - - const prepareImportData = () => { - if (!agentDetails) return null; - - // Clone agent_json structure - const agentJson = JSON.parse(JSON.stringify(agentDetails.agent_json)); - - // Update model information - const agentInfo = agentJson.agent_info[String(agentDetails.agent_id)]; - if (agentInfo) { - agentInfo.model_id = selectedModelId; - agentInfo.model_name = selectedModelName; - - // Clear business logic model fields - agentInfo.business_logic_model_id = null; - agentInfo.business_logic_model_name = null; - - // Update config fields - configFields.forEach(field => { - const value = configValues[field.fieldPath]; - if (field.fieldPath.includes("tools[")) { - // Handle tool params - const match = field.fieldPath.match(/tools\[(\d+)\]\.params\.(.+)/); - if (match && agentInfo.tools) { - const toolIndex = parseInt(match[1]); - const paramKey = match[2]; - if (agentInfo.tools[toolIndex]) { - agentInfo.tools[toolIndex].params[paramKey] = value; - } - } - } else { - // Handle basic fields - agentInfo[field.fieldPath] = value; - } - }); - - // Update MCP info - if (agentJson.mcp_info) { - agentJson.mcp_info = agentJson.mcp_info.map((mcp: any) => { - const matchingServer = mcpServers.find( - s => s.mcp_server_name === mcp.mcp_server_name - ); - if (matchingServer && matchingServer.editedUrl) { - return { - ...mcp, - mcp_url: matchingServer.editedUrl, - }; - } - return mcp; - }); - } - } - - return agentJson; - }; - - const handleCancel = () => { - // Reset state - setCurrentStep(0); - setSelectedModelId(null); - setSelectedModelName(""); - setConfigFields([]); - setConfigValues({}); - setMcpServers([]); - onCancel(); - }; - - // Filter only required steps for navigation - const steps = [ - { - key: "model", - title: t("market.install.step.model", "Select Model"), - }, - configFields.length > 0 && { - key: "config", - title: t("market.install.step.config", "Configure Fields"), - }, - mcpServers.length > 0 && { - key: "mcp", - title: t("market.install.step.mcp", "MCP Servers"), - }, - ].filter(Boolean) as Array<{ key: string; title: string }>; - - // Check if can proceed to next step - const canProceed = () => { - const currentStepKey = steps[currentStep]?.key; - - if (currentStepKey === "model") { - return selectedModelId !== null && selectedModelName !== ""; - } else if (currentStepKey === "config") { - return configFields.every(field => configValues[field.fieldPath]?.trim()); - } else if (currentStepKey === "mcp") { - // All non-editable MCPs should be installed or have edited URLs - return mcpServers.every(mcp => - mcp.isInstalled || - (mcp.isUrlEditable && mcp.editedUrl && mcp.editedUrl.trim() !== "") || - (!mcp.isUrlEditable && mcp.mcp_url && mcp.mcp_url.trim() !== "") - ); - } - - return true; - }; - - const renderStepContent = () => { - const currentStepKey = steps[currentStep]?.key; - - if (currentStepKey === "model") { - return ( -
- {/* Agent Info - Title and Description Style */} - {agentDetails && ( -
-

- {agentDetails.display_name} -

-

- {agentDetails.description} -

-
- )} - -
-

- {t("market.install.model.description", "Select a model from your configured models to use for this agent.")} -

- -
- -
- {loadingModels ? ( - - ) : ( - - )} -
-
- - {llmModels.length === 0 && !loadingModels && ( -
- {t("market.install.model.noModels", "No available models. Please configure models first.")} -
- )} -
-
- ); - } else if (currentStepKey === "config") { - return ( -
-

- {t("market.install.config.description", "Please configure the following required fields for this agent.")} -

- -
- {configFields.map((field) => ( - - {field.fieldLabel} - * - - } - required={false} - > - { - setConfigValues(prev => ({ - ...prev, - [field.fieldPath]: e.target.value, - })); - }} - placeholder={field.promptHint || t("market.install.config.placeholder", "Enter configuration value")} - rows={3} - size="large" - /> - - ))} -
-
- ); - } else if (currentStepKey === "mcp") { - return ( -
-

- {t("market.install.mcp.description", "This agent requires the following MCP servers. Please install or configure them.")} -

- - {loadingMcpServers ? ( -
- -
- ) : ( -
- {mcpServers.map((mcp, index) => ( -
-
-
-
- - {mcp.mcp_server_name} - - {mcp.isInstalled ? ( - } color="success" className="text-sm"> - {t("market.install.mcp.installed", "Installed")} - - ) : ( - } color="default" className="text-sm"> - {t("market.install.mcp.notInstalled", "Not Installed")} - - )} -
- -
- - MCP URL: - - {(mcp.isUrlEditable || !mcp.isInstalled) ? ( - handleMcpUrlChange(index, e.target.value)} - placeholder={mcp.isUrlEditable - ? t("market.install.mcp.urlPlaceholder", "Enter MCP server URL") - : mcp.mcp_url - } - size="middle" - disabled={mcp.isInstalled} - style={{ maxWidth: "400px" }} - /> - ) : ( - - {mcp.editedUrl || mcp.mcp_url} - - )} -
-
- - {!mcp.isInstalled && ( - - )} -
-
- ))} -
- )} -
- ); - } - - return null; - }; - - const isLastStep = currentStep === steps.length - 1; + : null; return ( - - - {t("market.install.title", "Install Agent")} - - } - open={visible} - onCancel={handleCancel} - width={800} - footer={ -
- - - {currentStep > 0 && ( - - )} - {!isLastStep && ( - - )} - {isLastStep && ( - - )} - -
- } - > -
- ({ - title: step.title, - }))} - className="mb-6" - /> - -
- {renderStepContent()} -
-
-
+ ); } - diff --git a/frontend/app/[locale]/mcp-tools/McpToolsContent.tsx b/frontend/app/[locale]/mcp-tools/McpToolsContent.tsx new file mode 100644 index 000000000..89c6c03d4 --- /dev/null +++ b/frontend/app/[locale]/mcp-tools/McpToolsContent.tsx @@ -0,0 +1,128 @@ +"use client"; + +import React from "react"; +import { motion } from "framer-motion"; +import { useTranslation } from "react-i18next"; +import { Puzzle } from "lucide-react"; + +import { useSetupFlow } from "@/hooks/useSetupFlow"; +import { ConnectionStatus } from "@/const/modelConfig"; + +interface McpToolsContentProps { + /** Connection status */ + connectionStatus?: ConnectionStatus; + /** Is checking connection */ + isCheckingConnection?: boolean; + /** Check connection callback */ + onCheckConnection?: () => void; + /** Callback to expose connection status */ + onConnectionStatusChange?: (status: ConnectionStatus) => void; +} + +/** + * McpToolsContent - MCP tools management coming soon page + * This will allow admins to manage MCP servers and tools + */ +export default function McpToolsContent({ + connectionStatus: externalConnectionStatus, + isCheckingConnection: externalIsCheckingConnection, + onCheckConnection: externalOnCheckConnection, + onConnectionStatusChange, +}: McpToolsContentProps) { + const { t } = useTranslation("common"); + + // Use custom hook for common setup flow logic + const { canAccessProtectedData, pageVariants, pageTransition } = useSetupFlow({ + requireAdmin: true, + externalConnectionStatus, + externalIsCheckingConnection, + onCheckConnection: externalOnCheckConnection, + onConnectionStatusChange, + }); + + return ( + <> + {canAccessProtectedData ? ( + +
+ {/* Icon */} + + + + + {/* Title */} + + {t("mcpTools.comingSoon.title")} + + + {/* Description */} + + {t("mcpTools.comingSoon.description")} + + + {/* Feature list */} + +
  • + + + {t("mcpTools.comingSoon.feature1")} + +
  • +
  • + + + {t("mcpTools.comingSoon.feature2")} + +
  • +
  • + + + {t("mcpTools.comingSoon.feature3")} + +
  • +
    + + {/* Coming soon badge */} + + {t("mcpTools.comingSoon.badge")} + +
    +
    + ) : null} + + ); +} + + diff --git a/frontend/app/[locale]/memory/MemoryContent.tsx b/frontend/app/[locale]/memory/MemoryContent.tsx index e294074c6..dd35eab82 100644 --- a/frontend/app/[locale]/memory/MemoryContent.tsx +++ b/frontend/app/[locale]/memory/MemoryContent.tsx @@ -349,29 +349,31 @@ export default function MemoryContent({ onNavigate }: MemoryContentProps) { style={{ width: "100%", height: "100%" }} > {canAccessProtectedData ? ( -
    +
    - memory.setActiveTabKey(key)} - tabBarStyle={{ - marginBottom: "16px", +
    + > + memory.setActiveTabKey(key)} + tabBarStyle={{ + marginBottom: "16px", + }} + /> +
    ) : null} diff --git a/frontend/app/[locale]/models/ModelsContent.tsx b/frontend/app/[locale]/models/ModelsContent.tsx index e53fead10..6b48e8dac 100644 --- a/frontend/app/[locale]/models/ModelsContent.tsx +++ b/frontend/app/[locale]/models/ModelsContent.tsx @@ -130,16 +130,18 @@ export default function ModelsContent({ transition={pageTransition} style={{width: "100%", height: "100%"}} > - {canAccessProtectedData ? ( - - setLiveSelectedModels(selected) - } - onEmbeddingConnectivityChange={() => {}} - forwardedRef={modelConfigSectionRef} - canAccessProtectedData={canAccessProtectedData} - /> - ) : null} +
    + {canAccessProtectedData ? ( + + setLiveSelectedModels(selected) + } + onEmbeddingConnectivityChange={() => {}} + forwardedRef={modelConfigSectionRef} + canAccessProtectedData={canAccessProtectedData} + /> + ) : null} +
    void; onSuccess: (model?: AddedModel) => Promise; + defaultProvider?: string; // Default provider to select when dialog opens } // Connectivity status type comes from utils @@ -129,6 +130,7 @@ export const ModelAddDialog = ({ isOpen, onClose, onSuccess, + defaultProvider, }: ModelAddDialogProps) => { const { t } = useTranslation(); const { message } = App.useApp(); @@ -166,7 +168,7 @@ export const ModelAddDialog = ({ isMultimodal: false, // Whether to import multiple models at once isBatchImport: false, - provider: "silicon", + provider: "modelengine", vectorDimension: "1024", // Default chunk size range for embedding models chunkSizeRange: [ @@ -219,6 +221,17 @@ export const ModelAddDialog = ({ setLoadingModelList, }); + // Handle default provider when dialog opens + useEffect(() => { + if (isOpen && defaultProvider) { + setForm((prev) => ({ + ...prev, + provider: defaultProvider, + isBatchImport: true, + })); + } + }, [isOpen, defaultProvider]); + const parseModelName = (name: string): string => { if (!name) return ""; const parts = name.split("/"); @@ -638,6 +651,7 @@ export const ModelAddDialog = ({ value={form.provider} onChange={(value) => handleFormChange("provider", value)} > +
    @@ -1058,6 +1072,13 @@ export const ModelAddDialog = ({
    {t("model.dialog.label.currentlySupported")} + + ModelEngine + {form.isBatchImport && ( void; onSuccess: () => Promise; - customModels: ModelOption[]; + models: ModelOption[]; } export const ModelDeleteDialog = ({ isOpen, onClose, onSuccess, - customModels, + models, }: ModelDeleteDialogProps) => { const { t } = useTranslation(); const { message } = App.useApp(); @@ -167,6 +167,8 @@ export const ModelDeleteDialog = ({ return t("model.source.openai"); case MODEL_SOURCES.SILICON: return t("model.source.silicon"); + case MODEL_SOURCES.MODELENGINE: + return t("model.source.modelEngine"); case MODEL_SOURCES.OPENAI_API_COMPATIBLE: return t("model.source.custom"); default: @@ -185,6 +187,12 @@ export const ModelDeleteDialog = ({ text: "text-purple-600", border: "border-purple-100", }; + case MODEL_SOURCES.MODELENGINE: + return { + bg: "bg-blue-50", + text: "text-blue-600", + border: "border-blue-100", + }; case MODEL_SOURCES.OPENAI: return { bg: "bg-indigo-50", @@ -217,6 +225,14 @@ export const ModelDeleteDialog = ({ className="w-5 h-5" /> ); + case MODEL_SOURCES.MODELENGINE: + return ( + ModelEngine + ); case MODEL_SOURCES.OPENAI: return ( @@ -242,12 +258,12 @@ export const ModelDeleteDialog = ({ const getApiKeyByType = (type: ModelType | null): string => { if (!type) return ""; // Prioritize silicon models of the current type - const byType = customModels.find( + const byType = models.find( (m) => m.source === MODEL_SOURCES.SILICON && m.type === type && m.apiKey ); if (byType?.apiKey) return byType.apiKey; // Fall back to any available silicon model - const anySilicon = customModels.find( + const anySilicon = models.find( (m) => m.source === MODEL_SOURCES.SILICON && m.apiKey ); return anySilicon?.apiKey || ""; @@ -266,9 +282,9 @@ export const ModelDeleteDialog = ({ apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", }); setProviderModels(result || []); - // Initialize pending selected switch states (based on current customModels status) + // Initialize pending selected switch states (based on current models status) const currentIds = new Set( - customModels + models .filter( (m) => m.type === modelType && m.source === MODEL_SOURCES.SILICON ) @@ -379,7 +395,7 @@ export const ModelDeleteDialog = ({ // Adjust hierarchical navigation based on remaining count after deletion if (deletingModelType) { - const remainingByTypeAndSource = customModels.filter( + const remainingByTypeAndSource = models.filter( (model) => model.type === deletingModelType && (!selectedSource || model.source === selectedSource) && @@ -389,7 +405,7 @@ export const ModelDeleteDialog = ({ // No models under current source, return to source selection setSelectedSource(null); } - const remainingByType = customModels.filter( + const remainingByType = models.filter( (model) => model.type === deletingModelType && model.displayName !== displayName @@ -452,7 +468,7 @@ export const ModelDeleteDialog = ({ if (selectedSource === MODEL_SOURCES.SILICON && deletingModelType) { try { const currentIds = new Set( - customModels + models .filter( (m) => m.type === deletingModelType && @@ -462,7 +478,7 @@ export const ModelDeleteDialog = ({ ); // Build payload items for the current silicon models in required format - const currentModelPayloads = customModels + const currentModelPayloads = models .filter( (m) => m.type === deletingModelType && @@ -630,12 +646,12 @@ export const ModelDeleteDialog = ({ MODEL_TYPES.TTS, ] as ModelType[] ).map((type) => { - const customModelsByType = customModels.filter( + const modelsByType = models.filter( (model) => model.type === type ); const colorScheme = getModelColorScheme(type); - if (customModelsByType.length === 0) return null; + if (modelsByType.length === 0) return null; return (
    {t("model.dialog.delete.customModelCount", { - count: customModelsByType.length, + count: modelsByType.length, })} {(type === MODEL_TYPES.STT || type === MODEL_TYPES.TTS) && @@ -685,7 +701,7 @@ export const ModelDeleteDialog = ({ })}
    - {customModels.length === 0 && ( + {models.length === 0 && (
    {t("model.dialog.delete.noModels")}
    @@ -717,12 +733,13 @@ export const ModelDeleteDialog = ({
    {( [ + MODEL_SOURCES.MODELENGINE, MODEL_SOURCES.OPENAI, MODEL_SOURCES.SILICON, MODEL_SOURCES.OPENAI_API_COMPATIBLE, ] as ModelSource[] ).map((source) => { - const modelsOfSource = customModels.filter( + const modelsOfSource = models.filter( (model) => model.type === deletingModelType && model.source === source ); @@ -918,7 +935,7 @@ export const ModelDeleteDialog = ({
    ) : (
    - {customModels + {models .filter( (model) => model.type === deletingModelType && @@ -994,7 +1011,7 @@ export const ModelDeleteDialog = ({
    ))} - {customModels.filter( + {models.filter( (model) => model.type === deletingModelType && model.source === selectedSource @@ -1045,7 +1062,7 @@ export const ModelDeleteDialog = ({ onClose={() => setIsProviderConfigOpen(false)} initialApiKey={getApiKeyByType(deletingModelType)} initialMaxTokens={( - customModels.find( + models.find( (m) => m.type === deletingModelType && m.source === "silicon" )?.maxTokens || 4096 ).toString()} diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index c7005206a..feaeacad8 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -153,9 +153,16 @@ export const ModelEditDialog = ({ let maxTokensValue = parseInt(form.maxTokens); if (isEmbeddingModel) maxTokensValue = 0; + // Use original displayName for lookup, pass new displayName in body if changed + const originalDisplayName = model.displayName || model.name; + const newDisplayName = form.displayName; + await modelService.updateSingleModel({ - model_id: model.id.toString(), - displayName: form.displayName, + currentDisplayName: originalDisplayName, + // Only send displayName if it changed + ...(newDisplayName !== originalDisplayName + ? { displayName: newDisplayName } + : {}), url: form.url, apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, ...(maxTokensValue !== 0 ? { maxTokens: maxTokensValue } : {}), @@ -210,6 +217,7 @@ export const ModelEditDialog = ({ message.error(t("model.dialog.error.serverError")); } else { message.error(t("model.dialog.error.editFailed")); + console.error(error); } } finally { setLoading(false); diff --git a/frontend/app/[locale]/models/components/model/ModelListCard.tsx b/frontend/app/[locale]/models/components/model/ModelListCard.tsx index 83d6ccf53..daf7b8610 100644 --- a/frontend/app/[locale]/models/components/model/ModelListCard.tsx +++ b/frontend/app/[locale]/models/components/model/ModelListCard.tsx @@ -13,7 +13,6 @@ import { import { ModelConnectStatus, ModelOption, - ModelSource, ModelType, } from "@/types/modelConfig"; import log from "@/lib/logger"; @@ -119,10 +118,9 @@ interface ModelListCardProps { modelTypeName: string; selectedModel: string; onModelChange: (value: string) => void; - officialModels: ModelOption[]; - customModels: ModelOption[]; - onVerifyModel?: (modelName: string, modelType: ModelType) => void; // New callback for verifying models - errorFields?: { [key: string]: boolean }; // New error field state + models: ModelOption[]; + onVerifyModel?: (modelName: string, modelType: ModelType) => void; + errorFields?: { [key: string]: boolean }; } export const ModelListCard = ({ @@ -131,18 +129,14 @@ export const ModelListCard = ({ modelTypeName, selectedModel, onModelChange, - officialModels, - customModels, + models, onVerifyModel, errorFields, }: ModelListCardProps) => { const { t } = useTranslation(); // Add model list state for updates - const [modelsData, setModelsData] = useState({ - official: [...officialModels], - custom: [...customModels], - }); + const [modelsData, setModelsData] = useState([...models]); // Create a style element in the component containing animation definitions useEffect(() => { @@ -158,110 +152,44 @@ export const ModelListCard = ({ }; }, []); - // When getting model list, need to consider specific option type - const getModelsBySource = (): { - official: ModelOption[]; - custom: ModelOption[]; - } => { - // Each type only shows models of corresponding type - return { - official: modelsData.official.filter((model) => model.type === type), - custom: modelsData.custom.filter((model) => model.type === type), - }; + // Get filtered models by type + const getFilteredModels = (): ModelOption[] => { + return modelsData.filter((model) => model.type === type); }; - // Get model source + // Get model source label based on source field const getModelSource = (displayName: string): string => { - if ( - type === MODEL_TYPES.TTS || - type === MODEL_TYPES.STT || - type === MODEL_TYPES.VLM - ) { - const modelOfType = modelsData.custom.find( - (m) => m.type === type && m.displayName === displayName - ); - if (modelOfType) return t("model.source.custom"); - } - - const officialModel = modelsData.official.find( - (m) => m.type === type && m.name === displayName - ); - if (officialModel) return t("model.source.modelEngine"); - - const customModel = modelsData.custom.find( + const model = modelsData.find( (m) => m.type === type && m.displayName === displayName ); - return customModel ? t("model.source.custom") : t("model.source.unknown"); + + if (!model) return t("model.source.unknown"); + + // Return source label based on model.source + if (model.source === "modelengine") { + return t("model.source.modelEngine"); + } else if (model.source === "silicon") { + return t("model.source.silicon"); + } else if (model.source === "OpenAI-API-Compatible") { + return t("model.source.custom"); + } + + return t("model.source.unknown"); }; - const modelsBySource = getModelsBySource(); - - // Local update model status - const updateLocalModelStatus = ( - displayName: string, - status: ModelConnectStatus - ) => { - setModelsData((prevData) => { - // Find model to update - const modelToUpdate = prevData.custom.find( - (m) => m.displayName === displayName && m.type === type - ); - - if (!modelToUpdate) { - log.warn(t("model.warning.updateNotFound", { displayName, type })); - return prevData; - } - - const updatedCustomModels = prevData.custom.map((model) => { - if (model.displayName === displayName && model.type === type) { - return { - ...model, - connect_status: status, - }; - } - return model; - }); - - return { - official: prevData.official, - custom: updatedCustomModels, - }; - }); + const filteredModels = getFilteredModels(); + + // Group models by source for display + const groupedModels = { + modelengine: filteredModels.filter((m) => m.source === "modelengine"), + silicon: filteredModels.filter((m) => m.source === "silicon"), + custom: filteredModels.filter((m) => m.source === "OpenAI-API-Compatible"), }; // When parent component's model list updates, update local state useEffect(() => { - // Update local state but don't trigger fetchModelsStatus - setModelsData((prevData) => { - const updatedOfficialModels = officialModels.map((model) => { - // Preserve existing connect_status if it exists - const existingModel = prevData.official.find( - (m) => m.name === model.name && m.type === model.type - ); - return { - ...model, - connect_status: - existingModel?.connect_status || - (MODEL_STATUS.AVAILABLE as ModelConnectStatus), - }; - }); - - const updatedCustomModels = customModels.map((model) => { - // Prioritize using newly passed status to reflect latest backend state - return { - ...model, - connect_status: - model.connect_status || - (MODEL_STATUS.UNCHECKED as ModelConnectStatus), - }; - }); - - return { - official: updatedOfficialModels, - custom: updatedCustomModels, - }; - }); - }, [officialModels, customModels, type, modelId]); + setModelsData(models); + }, [models]); // Handle status indicator click event const handleStatusClick = (e: React.MouseEvent, displayName: string) => { @@ -270,9 +198,7 @@ export const ModelListCard = ({ e.nativeEvent.stopImmediatePropagation(); // Prevent all sibling event handlers if (onVerifyModel && displayName) { - // First update local state to "checking" - updateLocalModelStatus(displayName, MODEL_STATUS.CHECKING); - // Then call verification function + // Call verification function (parent component will update status) onVerifyModel(displayName, type); } @@ -317,35 +243,105 @@ export const ModelListCard = ({ errorFields && errorFields[`${type}.${modelId}`] ? "error-select" : "" } > - {modelsBySource.official.length > 0 && ( + {groupedModels.modelengine.length > 0 && ( - {modelsBySource.official.map((model) => ( + {groupedModels.modelengine.map((model) => (
    + ); +} + diff --git a/frontend/components/navigation/SideNavigation.tsx b/frontend/components/navigation/SideNavigation.tsx index afff9fcc1..2de970754 100644 --- a/frontend/components/navigation/SideNavigation.tsx +++ b/frontend/components/navigation/SideNavigation.tsx @@ -17,6 +17,8 @@ import { ChevronLeft, ChevronRight, Home, + Puzzle, + Activity, } from "lucide-react"; import type { MenuProps } from "antd"; import { useAuth } from "@/hooks/useAuth"; @@ -50,7 +52,9 @@ function getMenuKeyFromPathname(pathname: string): string { 'knowledges': '6', // Knowledge base 'models': '7', // Model management 'memory': '8', // Memory management - 'users': '9', // User management + 'users': '9', // User management + 'mcp-tools': '10', // MCP tools management + 'monitoring': '11', // Monitoring and operations }; return pathToKeyMap[pathWithoutLocale] || '0'; @@ -80,16 +84,18 @@ export function SideNavigation({ // If we have a currentView from parent, use it to determine the key if (currentView) { const viewToKeyMap: Record = { - 'home': '0', - 'chat': '1', - 'setup': '2', - 'space': '3', - 'market': '4', - 'agents': '5', - 'knowledges': '6', - 'models': '7', - 'memory': '8', - 'users': '9', + home: "0", + chat: "1", + setup: "2", + space: "3", + market: "4", + agents: "5", + knowledges: "6", + models: "7", + memory: "8", + users: "9", + mcpTools: "10", + monitoring: "11", }; setSelectedKey(viewToKeyMap[currentView] || '0'); } else { @@ -181,6 +187,30 @@ export function SideNavigation({ } }, }, + { + key: "10", + icon: , + label: t("sidebar.mcpToolsManagement"), + onClick: () => { + if (!isSpeedMode && user?.role !== "admin") { + onAdminRequired?.(); + } else { + onViewChange?.("mcpTools"); + } + }, + }, + { + key: "11", + icon: , + label: t("sidebar.monitoringManagement"), + onClick: () => { + if (!isSpeedMode && user?.role !== "admin") { + onAdminRequired?.(); + } else { + onViewChange?.("monitoring"); + } + }, + }, { key: "7", icon: , diff --git a/frontend/const/layoutConstants.ts b/frontend/const/layoutConstants.ts index ccc4fa80b..77b16df09 100644 --- a/frontend/const/layoutConstants.ts +++ b/frontend/const/layoutConstants.ts @@ -48,7 +48,7 @@ export const SETUP_PAGE_CONTAINER = { // Two column layout responsive configuration (based on the first page design) export const TWO_COLUMN_LAYOUT = { // Row/Col spacing configuration - GUTTER: [24, 16] as [number, number], + GUTTER: [16, 16] as [number, number], // Responsive column ratio LEFT_COLUMN: { @@ -68,18 +68,6 @@ export const TWO_COLUMN_LAYOUT = { }, } as const; -// Flex two column layout configuration (based on the KnowledgeBaseManager design) -export const FLEX_TWO_COLUMN_LAYOUT = { - // Left knowledge base list width - LEFT_WIDTH: "33.333333%", // 1/3 - - // Right content area width - RIGHT_WIDTH: "66.666667%", // 2/3 - - // Column spacing - GAP: "12px", -} as const; - // Standard card style configuration (based on the first page design) export const STANDARD_CARD = { // Base style class name diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index efe19ec14..a68b087fb 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -19,6 +19,7 @@ export const MODEL_TYPES = { export const MODEL_SOURCES = { OPENAI: "openai", SILICON: "silicon", + MODELENGINE: "modelengine", OPENAI_API_COMPATIBLE: "OpenAI-API-Compatible", CUSTOM: "custom" } as const; @@ -42,7 +43,6 @@ export const MODEL_PROVIDER_KEYS = [ "qwen", "openai", "siliconflow", - "ponytoken", "jina", "deepseek", "aliyuncs", @@ -55,7 +55,6 @@ export const PROVIDER_HINTS: Record = { qwen: "qwen", openai: "openai", siliconflow: "siliconflow", - ponytoken: "ponytoken", jina: "jina", deepseek: "deepseek", aliyuncs: "aliyuncs", @@ -66,7 +65,6 @@ export const PROVIDER_ICON_MAP: Record = { qwen: "/qwen.png", openai: "/openai.png", siliconflow: "/siliconflow.png", - ponytoken: "/ponytoken.png", jina: "/jina.png", deepseek: "/deepseek.png", aliyuncs: "/aliyuncs.png", diff --git a/frontend/lib/config.ts b/frontend/lib/config.ts index d40ff6069..c8fbb6724 100644 --- a/frontend/lib/config.ts +++ b/frontend/lib/config.ts @@ -42,6 +42,7 @@ class ConfigStoreClass { // Deep merge configuration private deepMerge(target: T, source: Partial): T { if (!source) return target; + if (!target) return source as T; const result = { ...target } as T; @@ -50,7 +51,12 @@ class ConfigStoreClass { const sourceValue = (source as any)[key]; if (sourceValue && typeof sourceValue === 'object' && !Array.isArray(sourceValue)) { - (result as any)[key] = this.deepMerge(targetValue, sourceValue); + // If target has no value for this key, use source value directly + if (targetValue !== undefined && targetValue !== null) { + (result as any)[key] = this.deepMerge(targetValue, sourceValue); + } else { + (result as any)[key] = sourceValue; + } } else if (sourceValue !== undefined) { (result as any)[key] = sourceValue; } diff --git a/frontend/lib/viewPersistence.ts b/frontend/lib/viewPersistence.ts index a08b9f2fe..9021b445d 100644 --- a/frontend/lib/viewPersistence.ts +++ b/frontend/lib/viewPersistence.ts @@ -5,17 +5,19 @@ const VIEW_STORAGE_KEY = 'nexent_current_view'; -type ViewType = - | "home" - | "memory" - | "models" - | "agents" - | "knowledges" - | "space" - | "setup" - | "chat" - | "market" - | "users"; +type ViewType = + | "home" + | "memory" + | "models" + | "agents" + | "knowledges" + | "space" + | "setup" + | "chat" + | "market" + | "users" + | "mcpTools" + | "monitoring"; const VALID_VIEWS: ViewType[] = [ "home", @@ -28,6 +30,8 @@ const VALID_VIEWS: ViewType[] = [ "chat", "market", "users", + "mcpTools", + "monitoring", ]; /** diff --git a/frontend/package.json b/frontend/package.json index d818cf459..9fceddc44 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -61,7 +61,7 @@ "katex": "^0.16.11", "lucide-react": "^0.454.0", "mermaid": "^11.12.0", - "next": "15.4.5", + "next": "15.5.7", "next-i18next": "^15.4.2", "next-themes": "^0.4.4", "react": "18.2.0", @@ -97,7 +97,7 @@ "@types/react": "18.3.20", "@types/react-dom": "18.3.6", "eslint": "^9.34.0", - "eslint-config-next": "15.5.0", + "eslint-config-next": "15.5.7", "postcss": "^8", "tailwindcss": "^3.4.17", "typescript": "5.8.3" diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index e06e3c40c..5ee25a7b8 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -8,6 +8,9 @@ "chatAttachment.imagePreview": "Image Preview", "chatAttachment.previewNotSupported": "Preview not supported for this file type", "chatAttachment.downloadToView": "Please download the file to view", + "chatAttachment.downloading": "Downloading...", + "chatAttachment.downloadSuccess": "File downloaded successfully", + "chatAttachment.downloadError": "Failed to download file. Please try again.", "chatAttachment.image": "Image", "chatInterface.newConversation": "New Conversation", @@ -214,6 +217,11 @@ "chatRightPanel.noAssociatedImages": "This message has no associated images", "chatRightPanel.sources": "Sources", "chatRightPanel.images": "Images", + "chatRightPanel.downloading": "Downloading...", + "chatRightPanel.fileDownloadSuccess": "File download started", + "chatRightPanel.fileDownloadError": "Failed to download file. Please try again.", + "chatRightPanel.source.datamate": "Source: Datamate", + "chatRightPanel.source.nexent": "Source: Nexent", "chatStreamFinalMessage.copyFailed": "Copy failed:", "chatStreamFinalMessage.getMessageIdFailed": "Failed to get message ID:", @@ -240,6 +248,9 @@ "taskWindow.urlParseError": "URL parsing error:", "taskWindow.visit": "Visit {{domain}}", "taskWindow.readingSearchResults": "Reading search results", + "taskWindow.downloadFile": "Download {{name}}", + "taskWindow.downloadSuccess": "File download started", + "taskWindow.downloadError": "Failed to download file. Please try again.", "taskWindow.parseCardError": "Failed to parse card content:", "taskWindow.cannotParseCard": "Cannot parse card content", "taskWindow.parseSearchError": "Failed to parse search results:", @@ -285,8 +296,8 @@ "agent.description": "Agent Description", "agent.descriptionPlaceholder": "Please enter agent description", "agent.detailContent.title": "Agent Detail Content", - "agent.generating.title": "Generating Prompts", - "agent.generating.subtitle": "Please wait, the system is generating intelligent prompts for you...", + "agent.generating.title": "Generating Agent", + "agent.generating.subtitle": "Please wait, the system is generating intelligent agent for you...", "agent.error.loadTools": "Failed to load tool list:", "agent.error.loadToolsRetry": "Failed to get tool list, please refresh the page and try again", "agent.error.fetchAgentList": "Failed to get agent list", @@ -402,30 +413,6 @@ "toolPool.error.requiredFields": "The following required fields are not filled: {{fields}}", "toolPool.tooltip.functionGuide": "1. For local knowledge base search functionality, please enable the knowledge_base_search tool;\n2. For text file parsing functionality, please enable the analyze_text_file tool;\n3. For image parsing functionality, please enable the analyze_image tool.", - "common.loading": "Loading", - "common.save": "Save", - "common.cancel": "Cancel", - "common.confirm": "Confirm", - "common.enabled": "enabled", - "common.disabled": "disabled", - "common.yes": "Yes", - "common.no": "No", - "common.none": "None", - "common.toBeConfigured": "To Be Configured", - "common.required": "Required", - "common.refresh": "Refresh", - "common.unknownError": "Unknown error", - "common.retryLater": "Please try again later", - "common.back": "Back", - "common.delete": "Delete", - "common.notice": "Notice", - "common.button.close": "Close", - "common.button.cancel": "Cancel", - "common.button.save": "Save", - "common.button.saving": "Saving", - "common.button.editConfig": "Edit Configuration", - "common.message.refreshSuccess": "Refresh successful", - "common.message.refreshFailed": "Refresh failed", "tool.message.unavailable": "This tool is currently unavailable and cannot be selected", "tool.error.noMainAgentId": "Main agent ID is not set, cannot update tool status", @@ -599,6 +586,7 @@ "model.dialog.hint.batchImportEnabled": "Batch add enabled. Multiple models will be added at once.", "model.dialog.hint.batchImportDisabled": "Batch add disabled. Only a single model will be added.", "model.provider.silicon": "SiliconFlow", + "model.provider.modelengine": "ModelEngine", "model.dialog.modelList.title": "Show Models", "model.dialog.modelList.searchPlaceholder": "Search models by name", "model.dialog.modelList.noResults": "No models match your search", @@ -612,8 +600,8 @@ "model.dialog.button.verifying": "Verifying...", "model.dialog.button.add": "Add", "model.dialog.help.title": "Model Configuration Guide", - "model.dialog.help.content": "Please fill in the model's basic information. API Key and display name are optional, other fields are required. It's recommended to verify connectivity before adding the model. For detailed configuration methods, please refer to [Model Configuration](https://modelengine-group.github.io/nexent/en/user-guide/model-configuration.html).", - "model.dialog.help.content.batchImport": "Please fill in the provider's basic information. API Key and provider name are required, other fields are optional. It's recommended to verify connectivity before adding the model. For detailed configuration methods, please refer to [Model Configuration](https://modelengine-group.github.io/nexent/en/user-guide/model-configuration.html).", + "model.dialog.help.content": "Please fill in the model's basic information. API Key and display name are optional, other fields are required. It's recommended to verify connectivity before adding the model. For detailed configuration methods, please refer to [Model Configuration](https://modelengine-group.github.io/nexent/en/user-guide/model-management.html).", + "model.dialog.help.content.batchImport": "Please fill in the provider's basic information. API Key and provider name are required, other fields are optional. It's recommended to verify connectivity before adding the model. For detailed configuration methods, please refer to [Model Configuration](https://modelengine-group.github.io/nexent/en/user-guide/model-management.html).", "model.dialog.warning.incompleteForm": "Please complete the model configuration information first", "model.dialog.status.verifying": "Verifying model connectivity...", "model.dialog.success.connectivityVerified": "Model connectivity verification successful!", @@ -664,6 +652,7 @@ "model.type.main": "LLM Model", "model.select.placeholder": "Select Model", "model.group.modelEngine": "ModelEngine Models", + "model.group.silicon": "Silicon Flow Models", "model.group.custom": "Custom Models", "model.status.tooltip": "Click to verify connectivity", @@ -707,8 +696,8 @@ "modelConfig.message.syncSuccess": "Models synced successfully", "modelConfig.message.addSuccess": "Model added successfully", "modelConfig.button.syncModelEngine": "Sync ModelEngine Models", - "modelConfig.button.addCustomModel": "Add Custom Model", - "modelConfig.button.editCustomModel": "Edit Custom Model", + "modelConfig.button.addCustomModel": "Add Model", + "modelConfig.button.editCustomModel": "Edit or Delete Model", "modelConfig.button.checkConnectivity": "Check Model Connectivity", "modelConfig.slider.chunkingSize": "Chunk Size", "modelConfig.slider.expectedChunkSize": "Expected Chunk Size", @@ -1091,6 +1080,8 @@ "sidebar.modelManagement": "Model Management", "sidebar.memoryManagement": "Memory Management", "sidebar.userManagement": "User Management", + "sidebar.mcpToolsManagement": "MCP Tools", + "sidebar.monitoringManagement": "Monitoring & Ops", "market.comingSoon.title": "Agent Market Coming Soon", "market.comingSoon.description": "Discover and install pre-built AI agents from our marketplace. Save time by leveraging community-created solutions.", @@ -1106,6 +1097,20 @@ "users.comingSoon.feature3": "Monitor user activity and usage", "users.comingSoon.badge": "Coming Soon", + "mcpTools.comingSoon.title": "MCP Tools Management Coming Soon", + "mcpTools.comingSoon.description": "Centralized management for MCP servers and tools. Configure connectivity, synchronize tools, and monitor MCP health in one place.", + "mcpTools.comingSoon.feature1": "Register and manage multiple MCP servers", + "mcpTools.comingSoon.feature2": "Sync, inspect, and organize MCP tools", + "mcpTools.comingSoon.feature3": "Monitor MCP connectivity and usage status", + "mcpTools.comingSoon.badge": "Coming Soon", + + "monitoring.comingSoon.title": "Monitoring & Operations Coming Soon", + "monitoring.comingSoon.description": "Unified monitoring and operations center for your agents. Track health, performance, and incidents in real time.", + "monitoring.comingSoon.feature1": "Monitor agent health, latency, and error rates", + "monitoring.comingSoon.feature2": "View and filter agent logs and run history", + "monitoring.comingSoon.feature3": "Configure alerts and operational actions for critical events", + "monitoring.comingSoon.badge": "Coming Soon", + "market.title": "Agent Market", "market.description": "Discover and download pre-built intelligent agents", "market.searchPlaceholder": "Search agents by name or description...", @@ -1169,11 +1174,20 @@ "market.install.button.next": "Next", "market.install.button.install": "Install", "market.install.button.installing": "Installing...", + "market.install.model.mode": "Model Selection Mode", + "market.install.model.mode.unified": "Unified: Use one model for all agents", + "market.install.model.mode.individual": "Individual: Select model for each agent", "market.install.model.description": "Select a model from your configured models to use for this agent.", + "market.install.model.description.unified": "Select a model from your configured models. This model will be applied to all agents (main agent and sub-agents).", + "market.install.model.description.individual": "Select a model for each agent (main agent and sub-agents).", "market.install.model.label": "Model", "market.install.model.placeholder": "Select a model", "market.install.model.noModels": "No available models. Please configure models first.", - "market.install.config.description": "Please configure the following required fields for this agent.", + "market.install.config.description": "Please configure the following required fields for this agent and its sub-agents.", + "market.install.config.fields": "fields", + "market.install.config.noFields": "No configuration fields required.", + "market.install.agent.defaultName": "Agent", + "market.install.agent.main": "Main", "market.install.config.placeholder": "Enter configuration value", "market.install.mcp.description": "This agent requires the following MCP servers. Please install or configure them.", "market.install.mcp.installed": "Installed", @@ -1181,6 +1195,7 @@ "market.install.mcp.urlPlaceholder": "Enter MCP server URL", "market.install.mcp.install": "Install", "market.install.error.modelRequired": "Please select a model", + "market.install.error.allModelsRequired": "Please select models for all agents", "market.install.error.configRequired": "Please fill in all required fields", "market.install.error.mcpUrlRequired": "MCP URL is required", "market.install.error.loadModels": "Failed to load models", @@ -1204,6 +1219,29 @@ "market.error.unknown.title": "Something Went Wrong", "market.error.unknown.description": "An unexpected error occurred. Please try again.", + "common.loading": "Loading", + "common.save": "Save", + "common.cancel": "Cancel", + "common.confirm": "Confirm", + "common.enabled": "enabled", + "common.disabled": "disabled", + "common.yes": "Yes", + "common.no": "No", + "common.none": "None", + "common.required": "Required", + "common.refresh": "Refresh", + "common.unknownError": "Unknown error", + "common.retryLater": "Please try again later", + "common.back": "Back", + "common.delete": "Delete", + "common.notice": "Notice", + "common.button.close": "Close", + "common.button.cancel": "Cancel", + "common.button.save": "Save", + "common.button.saving": "Saving", + "common.button.editConfig": "Edit Configuration", + "common.message.refreshSuccess": "Refresh successful", + "common.message.refreshFailed": "Refresh failed", "common.toBeConfigured": "To Be Configured", "common.source": "Source", "common.category": "Category", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 07d5058d4..65d80dacf 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -8,6 +8,9 @@ "chatAttachment.imagePreview": "图片预览", "chatAttachment.previewNotSupported": "此文件类型暂不支持在线预览", "chatAttachment.downloadToView": "请下载文件后查看", + "chatAttachment.downloading": "正在下载...", + "chatAttachment.downloadSuccess": "文件下载成功", + "chatAttachment.downloadError": "文件下载失败,请重试", "chatAttachment.image": "图片", "chatInterface.newConversation": "新对话", @@ -214,6 +217,11 @@ "chatRightPanel.noAssociatedImages": "这条消息没有关联的图片内容", "chatRightPanel.sources": "来源", "chatRightPanel.images": "图片", + "chatRightPanel.downloading": "正在下载...", + "chatRightPanel.fileDownloadSuccess": "文件下载已开始", + "chatRightPanel.fileDownloadError": "文件下载失败,请重试", + "chatRightPanel.source.datamate": "来源: Datamate", + "chatRightPanel.source.nexent": "来源: Nexent", "chatStreamFinalMessage.copyFailed": "复制失败:", "chatStreamFinalMessage.getMessageIdFailed": "获取消息ID失败:", @@ -240,6 +248,9 @@ "taskWindow.urlParseError": "URL解析错误:", "taskWindow.visit": "访问 {{domain}}", "taskWindow.readingSearchResults": "阅读检索结果", + "taskWindow.downloadFile": "下载 {{name}}", + "taskWindow.downloadSuccess": "文件已开始下载", + "taskWindow.downloadError": "文件下载失败,请稍后重试", "taskWindow.parseCardError": "解析卡片内容失败:", "taskWindow.cannotParseCard": "无法解析卡片内容", "taskWindow.parseSearchError": "解析搜索结果失败:", @@ -286,8 +297,8 @@ "agent.description": "Agent描述", "agent.descriptionPlaceholder": "请输入Agent描述", "agent.detailContent.title": "Agent详细内容", - "agent.generating.title": "正在生成提示词", - "agent.generating.subtitle": "请稍候,系统正在为您生成智能提示词...", + "agent.generating.title": "正在生成智能体", + "agent.generating.subtitle": "请稍候,系统正在为您生成智能智能体...", "agent.error.loadTools": "加载工具列表失败:", "agent.error.loadToolsRetry": "获取工具列表失败,请刷新页面重试", "agent.error.fetchAgentList": "获取Agent列表失败", @@ -403,30 +414,6 @@ "toolPool.error.requiredFields": "以下必填字段未填写: {{fields}}", "toolPool.tooltip.functionGuide": "1. 本地知识库检索功能,请启用knowledge_base_search工具;\n2. 文本文件解析功能,请启用analyze_text_file工具;\n3. 图片解析功能,请启用analyze_image工具。", - "common.loading": "加载中", - "common.save": "保存", - "common.cancel": "取消", - "common.confirm": "确定", - "common.enabled": "已启用", - "common.disabled": "已禁用", - "common.yes": "是", - "common.no": "否", - "common.none": "无", - "common.toBeConfigured": "待配置", - "common.required": "必填", - "common.refresh": "刷新", - "common.unknownError": "未知错误", - "common.retryLater": "请稍后重试", - "common.back": "返回", - "common.delete": "删除", - "common.button.cancel": "取消", - "common.button.save": "保存", - "common.button.saving": "保存中", - "common.notice": "注意", - "common.button.close": "关闭", - "common.button.editConfig": "修改配置", - "common.message.refreshSuccess": "刷新成功", - "common.message.refreshFailed": "刷新失败", "tool.message.unavailable": "该工具当前不可用,无法选择", "tool.error.noMainAgentId": "主代理ID未设置,无法更新工具状态", @@ -600,6 +587,7 @@ "model.dialog.hint.batchImportEnabled": "批量添加模式已启用,可通过API Key一次性导入多个模型", "model.dialog.hint.batchImportDisabled": "批量添加模式已关闭,仅添加单个模型", "model.provider.silicon": "硅基流动", + "model.provider.modelengine": "ModelEngine", "model.dialog.modelList.title": "显示模型", "model.dialog.modelList.searchPlaceholder": "按名称搜索模型", "model.dialog.modelList.noResults": "没有匹配的模型", @@ -613,8 +601,8 @@ "model.dialog.button.verifying": "验证中...", "model.dialog.button.add": "添加", "model.dialog.help.title": "模型配置说明", - "model.dialog.help.content": "请填写模型的基本信息,API Key、展示名称为可选项,其他字段为必填项。建议先验证连通性后再添加模型。详细配置方法请参考[模型配置](https://modelengine-group.github.io/nexent/zh/user-guide/model-configuration.html)。", - "model.dialog.help.content.batchImport": "请填写提供商的基本信息,API Key和提供商名称为必填项,其他字段为可选项。详细配置方法请参考[模型配置](https://modelengine-group.github.io/nexent/zh/user-guide/model-configuration.html)。", + "model.dialog.help.content": "请填写模型的基本信息,API Key、展示名称为可选项,其他字段为必填项。建议先验证连通性后再添加模型。详细配置方法请参考[模型配置](https://modelengine-group.github.io/nexent/zh/user-guide/model-management.html)。", + "model.dialog.help.content.batchImport": "请填写提供商的基本信息,API Key和提供商名称为必填项,其他字段为可选项。详细配置方法请参考[模型配置](https://modelengine-group.github.io/nexent/zh/user-guide/model-management.html)。", "model.dialog.warning.incompleteForm": "请先填写完整的模型配置信息", "model.dialog.status.verifying": "正在验证模型连通性...", "model.dialog.error.connectivityRequired": "请先验证模型连通性且确保连接成功后再添加模型", @@ -663,6 +651,7 @@ "model.type.main": "大语言模型", "model.select.placeholder": "选择模型", "model.group.modelEngine": "ModelEngine模型", + "model.group.silicon": "硅基流动模型", "model.group.custom": "自定义模型", "model.status.tooltip": "点击可验证连通性", "model.dialog.success.updateSuccess": "更新成功", @@ -707,8 +696,8 @@ "modelConfig.message.syncSuccess": "模型同步成功", "modelConfig.message.addSuccess": "添加模型成功", "modelConfig.button.syncModelEngine": "同步ModelEngine模型", - "modelConfig.button.addCustomModel": "添加自定义模型", - "modelConfig.button.editCustomModel": "修改自定义模型", + "modelConfig.button.addCustomModel": "添加模型", + "modelConfig.button.editCustomModel": "修改或删除模型", "modelConfig.button.checkConnectivity": "检查模型连通性", "modelConfig.slider.chunkingSize": "文档切片大小", "modelConfig.slider.expectedChunkSize": "期望切片大小", @@ -1091,6 +1080,8 @@ "sidebar.modelManagement": "模型管理", "sidebar.memoryManagement": "记忆管理", "sidebar.userManagement": "用户管理", + "sidebar.mcpToolsManagement": "MCP 工具", + "sidebar.monitoringManagement": "监控与运维", "market.comingSoon.title": "智能体市场即将推出", "market.comingSoon.description": "从我们的市场中发现并安装预构建的AI智能体。通过使用社区创建的解决方案节省时间。", @@ -1162,11 +1153,20 @@ "market.install.button.next": "下一步", "market.install.button.install": "安装", "market.install.button.installing": "正在安装...", + "market.install.model.mode": "模型选择模式", + "market.install.model.mode.unified": "统一配置:所有智能体使用同一模型", + "market.install.model.mode.individual": "独立配置:为每个智能体单独选择模型", "market.install.model.description": "从已配置的模型中选择一个模型用于该智能体。", + "market.install.model.description.unified": "从已配置的模型中选择一个模型。该模型将应用于所有智能体(主智能体和子智能体)。", + "market.install.model.description.individual": "为每个智能体(主智能体和子智能体)选择模型。", "market.install.model.label": "模型", "market.install.model.placeholder": "选择一个模型", "market.install.model.noModels": "暂无可用模型。请先配置模型。", - "market.install.config.description": "请为该智能体配置以下必填字段。", + "market.install.config.description": "请为该智能体及其子智能体配置以下必填字段。", + "market.install.config.fields": "个字段", + "market.install.config.noFields": "无需配置字段。", + "market.install.agent.defaultName": "智能体", + "market.install.agent.main": "主", "market.install.config.placeholder": "输入配置值", "market.install.mcp.description": "该智能体需要以下 MCP 服务器。请安装或配置它们。", "market.install.mcp.installed": "已安装", @@ -1174,6 +1174,7 @@ "market.install.mcp.urlPlaceholder": "输入 MCP 服务器地址", "market.install.mcp.install": "安装", "market.install.error.modelRequired": "请选择一个模型", + "market.install.error.allModelsRequired": "请为所有智能体选择模型", "market.install.error.configRequired": "请填写所有必填字段", "market.install.error.mcpUrlRequired": "MCP 地址为必填项", "market.install.error.loadModels": "加载模型失败", @@ -1203,7 +1204,44 @@ "users.comingSoon.feature2": "配置精细化权限", "users.comingSoon.feature3": "监控用户活动和使用情况", "users.comingSoon.badge": "即将推出", + + "mcpTools.comingSoon.title": "MCP 工具管理即将推出", + "mcpTools.comingSoon.description": "集中管理 MCP 服务器与工具,在一个页面中完成连接配置、工具同步与健康状态监控。", + "mcpTools.comingSoon.feature1": "注册并管理多个 MCP 服务器", + "mcpTools.comingSoon.feature2": "同步、查看和组织 MCP 工具列表", + "mcpTools.comingSoon.feature3": "监控 MCP 连接状态和使用情况", + "mcpTools.comingSoon.badge": "即将推出", + "monitoring.comingSoon.title": "监控与运维中心即将推出", + "monitoring.comingSoon.description": "面向智能体的统一监控与运维中心,用于实时跟踪健康状态、性能指标与异常事件。", + "monitoring.comingSoon.feature1": "监控智能体健康状态、延迟与错误率", + "monitoring.comingSoon.feature2": "查看并筛选智能体运行日志和历史任务", + "monitoring.comingSoon.feature3": "配置告警策略与关键事件的运维操作", + "monitoring.comingSoon.badge": "即将推出", + + "common.loading": "加载中", + "common.save": "保存", + "common.cancel": "取消", + "common.confirm": "确定", + "common.enabled": "已启用", + "common.disabled": "已禁用", + "common.yes": "是", + "common.no": "否", + "common.none": "无", + "common.required": "必填", + "common.refresh": "刷新", + "common.unknownError": "未知错误", + "common.retryLater": "请稍后重试", + "common.back": "返回", + "common.delete": "删除", + "common.button.cancel": "取消", + "common.button.save": "保存", + "common.button.saving": "保存中", + "common.notice": "注意", + "common.button.close": "关闭", + "common.button.editConfig": "修改配置", + "common.message.refreshSuccess": "刷新成功", + "common.message.refreshFailed": "刷新失败", "common.toBeConfigured": "待配置", "common.source": "来源", "common.category": "分类", diff --git a/frontend/server.js b/frontend/server.js index 5b0f5d66b..45ec14799 100644 --- a/frontend/server.js +++ b/frontend/server.js @@ -2,6 +2,15 @@ const { createServer } = require('http'); const { parse } = require('url'); const next = require('next'); const { createProxyServer } = require('http-proxy'); +const path = require('path'); + +// Load environment variables from .env file in parent directory (project root) +// In container environments, env vars are injected directly by Docker, so .env file may not exist +// Using optional: true to avoid errors if .env file is not found +require('dotenv').config({ + path: path.resolve(__dirname, '../.env'), + override: false // Don't override existing environment variables (important for Docker) +}); const dev = process.env.NODE_ENV !== 'production'; const app = next({ @@ -14,7 +23,7 @@ const HTTP_BACKEND = process.env.HTTP_BACKEND || 'http://localhost:5010'; // con const WS_BACKEND = process.env.WS_BACKEND || 'ws://localhost:5014'; // runtime const RUNTIME_HTTP_BACKEND = process.env.RUNTIME_HTTP_BACKEND || 'http://localhost:5014'; // runtime const MINIO_BACKEND = process.env.MINIO_ENDPOINT || 'http://localhost:9010'; -const MARKET_BACKEND = process.env.MARKET_BACKEND || 'http://localhost:8010'; // market +const MARKET_BACKEND = process.env.MARKET_BACKEND || 'https://market.nexent.tech'; // market const PORT = 3000; const proxy = createProxyServer(); diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 94d7562dd..0af193d52 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -61,23 +61,44 @@ export const API_ENDPOINTS = { storage: { upload: `${API_BASE_URL}/file/storage`, files: `${API_BASE_URL}/file/storage`, - file: (objectName: string, download: string = "ignore") => - `${API_BASE_URL}/file/storage/${objectName}?download=${download}`, + file: (objectName: string, download: string = "ignore", filename?: string) => { + const queryParams = new URLSearchParams(); + queryParams.append("download", download); + if (filename) queryParams.append("filename", filename); + return `${API_BASE_URL}/file/download/${objectName}?${queryParams.toString()}`; + }, + datamateDownload: (params: { + url?: string; + baseUrl?: string; + datasetId?: string; + fileId?: string; + filename?: string; + }) => { + const queryParams = new URLSearchParams(); + if (params.url) queryParams.append("url", params.url); + if (params.baseUrl) queryParams.append("base_url", params.baseUrl); + if (params.datasetId) queryParams.append("dataset_id", params.datasetId); + if (params.fileId) queryParams.append("file_id", params.fileId); + if (params.filename) queryParams.append("filename", params.filename); + return `${API_BASE_URL}/file/datamate/download?${queryParams.toString()}`; + }, delete: (objectName: string) => `${API_BASE_URL}/file/storage/${objectName}`, preprocess: `${API_BASE_URL}/file/preprocess`, }, proxy: { - image: (url: string) => - `${API_BASE_URL}/image?url=${encodeURIComponent(url)}`, + image: (url: string, format: string = "stream") => + `${API_BASE_URL}/image?url=${encodeURIComponent(url)}&format=${format}`, }, model: { - // Official model service - officialModelList: `${API_BASE_URL}/me/model/list`, - officialModelHealthcheck: `${API_BASE_URL}/me/healthcheck`, + // ModelEngine health check + modelEngineHealthcheck: `${API_BASE_URL}/me/healthcheck`, - // Custom model service + // Model lists + officialModelList: `${API_BASE_URL}/model/list`, // ModelEngine models are also in this list customModelList: `${API_BASE_URL}/model/list`, + + // Custom model service customModelCreate: `${API_BASE_URL}/model/create`, customModelCreateProvider: `${API_BASE_URL}/model/provider/create`, customModelBatchCreate: `${API_BASE_URL}/model/provider/batch_create`, @@ -91,7 +112,8 @@ export const API_ENDPOINTS = { displayName )}`, verifyModelConfig: `${API_BASE_URL}/model/temporary_healthcheck`, - updateSingleModel: `${API_BASE_URL}/model/update`, + updateSingleModel: (displayName: string) => + `${API_BASE_URL}/model/update?display_name=${encodeURIComponent(displayName)}`, updateBatchModel: `${API_BASE_URL}/model/batch_update`, // LLM model list for generation llmModelList: `${API_BASE_URL}/model/llm_list`, diff --git a/frontend/services/modelEngineService.ts b/frontend/services/modelEngineService.ts index c3b3f3bcd..839151448 100644 --- a/frontend/services/modelEngineService.ts +++ b/frontend/services/modelEngineService.ts @@ -15,12 +15,12 @@ const fetch = fetchWithAuth; */ const modelEngineService = { /** - * Check ModelEngine connection status + * Check ModelEngine connection status (environment variable configuration check) * @returns Promise Result object containing connection status and check time */ checkConnection: async (): Promise => { try { - const response = await fetch(API_ENDPOINTS.model.officialModelHealthcheck, { + const response = await fetch(API_ENDPOINTS.model.modelEngineHealthcheck, { method: "GET" }) diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index fdb41d537..9de2c5483 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -45,60 +45,15 @@ export class ModelError extends Error { // Model service export const modelService = { - // Get official model list - getOfficialModels: async (): Promise => { - try { - const response = await fetch(API_ENDPOINTS.model.officialModelList, { - headers: getAuthHeaders(), - }); - const result = await response.json(); - - if (response.status === STATUS_CODES.SUCCESS && result.data) { - const modelOptions: ModelOption[] = []; - const typeMap: Record = { - embed: MODEL_TYPES.EMBEDDING, - chat: MODEL_TYPES.LLM, - asr: MODEL_TYPES.STT, - tts: MODEL_TYPES.TTS, - rerank: MODEL_TYPES.RERANK, - vlm: MODEL_TYPES.VLM, - }; - - for (const model of result.data) { - if (typeMap[model.type]) { - modelOptions.push({ - id: model.id, - name: model.id, - type: typeMap[model.type], - maxTokens: 0, - source: MODEL_SOURCES.OPENAI_API_COMPATIBLE, - apiKey: model.api_key, - apiUrl: model.base_url, - displayName: model.id, - }); - } - } - - return modelOptions; - } - // If API call was not successful, return empty array - return []; - } catch (error) { - // In case of any error, return empty array - log.warn("Failed to load official models:", error); - return []; - } - }, - - // Get custom model list - getCustomModels: async (): Promise => { + // Get all models (unified method) + getAllModels: async (): Promise => { try { const response = await fetch(API_ENDPOINTS.model.customModelList, { headers: getAuthHeaders(), }); const result = await response.json(); - if (response.status === 200 && result.data) { + if (response.status === STATUS_CODES.SUCCESS && result.data) { return result.data.map((model: any) => ({ id: model.model_id, name: model.model_name, @@ -114,19 +69,24 @@ export const modelService = { maximumChunkSize: model.maximum_chunk_size, })); } - // If API call was not successful, return empty array - log.warn( - "Failed to load custom models:", - result.message || "Unknown error" - ); return []; } catch (error) { - // In case of any error, return empty array - log.warn("Failed to load custom models:", error); + log.warn("Failed to load models:", error); return []; } }, + // Legacy methods for backward compatibility (will be removed after refactoring) + getOfficialModels: async (): Promise => { + const allModels = await modelService.getAllModels(); + return allModels.filter((model) => model.source === "modelengine"); + }, + + getCustomModels: async (): Promise => { + const allModels = await modelService.getAllModels(); + return allModels.filter((model) => model.source !== "modelengine"); + }, + // Add custom model addCustomModel: async (model: { name: string; @@ -271,8 +231,8 @@ export const modelService = { }, updateSingleModel: async (model: { - model_id: string; - displayName: string; + currentDisplayName: string; + displayName?: string; url: string; apiKey: string; maxTokens?: number; @@ -281,26 +241,30 @@ export const modelService = { maximumChunkSize?: number; }): Promise => { try { - const response = await fetch(API_ENDPOINTS.model.updateSingleModel, { - method: "POST", - headers: getAuthHeaders(), - body: JSON.stringify({ - model_id: model.model_id, - display_name: model.displayName, - base_url: model.url, - api_key: model.apiKey, - ...(model.maxTokens !== undefined - ? { max_tokens: model.maxTokens } - : {}), - model_factory: model.source || "OpenAI-API-Compatible", - ...(model.expectedChunkSize !== undefined - ? { expected_chunk_size: model.expectedChunkSize } - : {}), - ...(model.maximumChunkSize !== undefined - ? { maximum_chunk_size: model.maximumChunkSize } - : {}), - }), - }); + const response = await fetch( + API_ENDPOINTS.model.updateSingleModel(model.currentDisplayName), + { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify({ + ...(model.displayName !== undefined + ? { display_name: model.displayName } + : {}), + base_url: model.url, + api_key: model.apiKey, + ...(model.maxTokens !== undefined + ? { max_tokens: model.maxTokens } + : {}), + model_factory: model.source || "OpenAI-API-Compatible", + ...(model.expectedChunkSize !== undefined + ? { expected_chunk_size: model.expectedChunkSize } + : {}), + ...(model.maximumChunkSize !== undefined + ? { maximum_chunk_size: model.maximumChunkSize } + : {}), + }), + } + ); const result = await response.json(); if (response.status !== 200) { throw new ModelError( diff --git a/frontend/services/storageService.ts b/frontend/services/storageService.ts index 7869bf012..ec60eb187 100644 --- a/frontend/services/storageService.ts +++ b/frontend/services/storageService.ts @@ -5,6 +5,124 @@ import { fetchWithAuth } from '@/lib/auth'; // @ts-ignore const fetch = fetchWithAuth; +/** + * Extract object_name from file URL + * Supports formats like: + * - http://localhost:3000/nexent/attachments/filename.png + * - /nexent/attachments/filename.png + * - attachments/filename.png + * - s3://nexent/attachments/filename.png + * Works for all file types: images, videos, documents, etc. + * @param url File URL (can be image, video, document, or any other file type) + * @returns object_name or null + */ +export function extractObjectNameFromUrl(url: string): string | null { + try { + // Handle s3:// protocol URLs (e.g., s3://nexent/attachments/filename.png) + if (url.startsWith("s3://")) { + // Remove s3:// prefix + const withoutProtocol = url.replace(/^s3:\/\//, ""); + const parts = withoutProtocol.split("/").filter(Boolean); + + // Find attachments in path + const attachmentsIndex = parts.indexOf("attachments"); + if (attachmentsIndex >= 0) { + return parts.slice(attachmentsIndex).join("/"); + } + + // If no attachments found but has bucket and path, return the path after bucket + if (parts.length > 1) { + return parts.slice(1).join("/"); + } + + // If only one part, return it as object_name + if (parts.length === 1) { + return parts[0]; + } + + return null; + } + + // Handle object_name or relative paths directly (e.g. "attachments/xxx.pdf") + const isHttpUrl = url.startsWith("http://") || url.startsWith("https://"); + if (!isHttpUrl) { + // Remove leading "/" if present + const normalized = url.replace(/^\/+/, ""); + if (!normalized) { + return null; + } + + const attachmentsIndex = normalized.indexOf("attachments/"); + if (attachmentsIndex >= 0) { + return normalized.slice(attachmentsIndex); + } + + // If there is no "attachments" segment but this is a plain path, + // treat the whole normalized path as object_name + return normalized; + } + + // Handle relative URLs + if (url.startsWith("/")) { + // Remove leading slash and extract path after /nexent/ or /attachments/ + const parts = url.split("/").filter(Boolean); + const attachmentsIndex = parts.indexOf("attachments"); + if (attachmentsIndex >= 0) { + return parts.slice(attachmentsIndex).join("/"); + } + // If no attachments found, try to find the last part + if (parts.length > 0) { + return parts.join("/"); + } + } + + // Handle full URLs + const urlObj = new URL(url); + const pathname = urlObj.pathname; + const parts = pathname.split("/").filter(Boolean); + + // Find attachments in path + const attachmentsIndex = parts.indexOf("attachments"); + if (attachmentsIndex >= 0) { + return parts.slice(attachmentsIndex).join("/"); + } + + // If no attachments found, return the last meaningful part + if (parts.length > 0) { + return parts.join("/"); + } + + return null; + } catch (error) { + return null; + } +} + +/** + * Convert image URL to backend API URL + * @param url Original image URL (can be MinIO URL or local path) + * @returns Backend API URL for the image + */ +export function convertImageUrlToApiUrl(url: string): string { + // If URL is an external http/https URL (not backend API), use proxy to avoid CORS and 403 errors + if ( + (url.startsWith("http://") || url.startsWith("https://")) && + !url.includes("/api/file/download/") && + !url.includes("/api/image") + ) { + // Use backend proxy to fetch external images (avoids CORS and hotlink protection) + return API_ENDPOINTS.proxy.image(url); + } + + const objectName = extractObjectNameFromUrl(url); + if (objectName) { + // Use the same download endpoint with stream mode for images + return API_ENDPOINTS.storage.file(objectName, "stream"); + } + // Fallback to original URL if extraction fails + return url; +} + export const storageService = { /** * Upload files to storage service @@ -54,5 +172,73 @@ export const storageService = { const data = await response.json(); return data.url; + }, + + /** + * Download file directly using backend API (faster, browser handles download) + * @param objectName File object name + * @param filename Optional filename for download + * @returns Promise that resolves when download link is opened + */ + async downloadFile(objectName: string, filename?: string): Promise { + try { + // Use direct link download for better performance + // Browser will handle the download stream directly + // Pass filename to backend so it can set the correct Content-Disposition header + const downloadUrl = API_ENDPOINTS.storage.file(objectName, "stream", filename); + + // Create download link and trigger download + // Using direct link allows browser to handle download stream efficiently + const link = document.createElement("a"); + link.href = downloadUrl; + // Set download attribute as fallback (browser will use Content-Disposition header if available) + link.download = filename || objectName.split("/").pop() || "download"; + link.style.display = "none"; + document.body.appendChild(link); + + // Trigger download + link.click(); + + // Clean up after a short delay to ensure download starts + setTimeout(() => { + document.body.removeChild(link); + }, 100); + } catch (error) { + throw new Error(`Failed to download file: ${error instanceof Error ? error.message : String(error)}`); + } + }, + + /** + * Download file from Datamate knowledge base via HTTP URL + * @param url HTTP URL of the file to download + * @param filename Optional filename for download + * @returns Promise that resolves when download link is opened + */ + async downloadDatamateFile(options: { + url?: string; + baseUrl?: string; + datasetId?: string; + fileId?: string; + filename?: string; + }): Promise { + try { + const downloadUrl = API_ENDPOINTS.storage.datamateDownload(options); + const link = document.createElement("a"); + link.href = downloadUrl; + // Only set download attribute when caller explicitly provides a filename. + // Otherwise, let the browser use the Content-Disposition header from backend, + // which already encodes the correct filename. + if (options.filename) { + link.download = options.filename; + } + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + setTimeout(() => { + document.body.removeChild(link); + }, 100); + } catch (error) { + throw new Error(`Failed to download datamate file: ${error instanceof Error ? error.message : String(error)}`); + } } }; \ No newline at end of file diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index e979b5fec..700edfdbf 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -95,6 +95,7 @@ export interface AttachmentItem { name: string; size: number; url?: string; + object_name?: string; contentType?: string; } diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index dc21e987e..db97a8c0d 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -19,7 +19,8 @@ export type ModelSource = | "openai" | "custom" | "silicon" - | "OpenAI-API-Compatible"; + | "OpenAI-API-Compatible" + | "modelengine"; // Model type export type ModelType = diff --git a/make/web/Dockerfile b/make/web/Dockerfile index edeff2d44..bdb66102d 100644 --- a/make/web/Dockerfile +++ b/make/web/Dockerfile @@ -22,7 +22,7 @@ RUN if [ -n "$MIRROR" ]; then npm config set registry "$MIRROR"; fi && \ "start": "NODE_ENV=production HOSTNAME=localhost node server.js"\ },\ "dependencies": {\ - "next": "15.4.5",\ + "next": "15.5.7",\ "react": "18.2.0",\ "react-dom": "18.2.0",\ "http-proxy": "^1.18.1",\ @@ -46,8 +46,13 @@ LABEL authors="nexent" RUN if [ "$APK_MIRROR" = "tsinghua" ]; then \ echo "https://mirrors.tuna.tsinghua.edu.cn/alpine/latest-stable/main" > /etc/apk/repositories && \ echo "https://mirrors.tuna.tsinghua.edu.cn/alpine/latest-stable/community" >> /etc/apk/repositories; \ - fi && \ - apk add --no-cache curl + fi + +# Update package index, upgrade busybox first, then install curl +# This avoids trigger script issues in cross-platform builds with QEMU emulation +RUN apk update && \ + apk upgrade --no-cache busybox || true && \ + apk add --no-cache --no-scripts curl WORKDIR /opt/frontend-dist diff --git a/sdk/nexent/core/agents/agent_model.py b/sdk/nexent/core/agents/agent_model.py index c48a2fc13..6eff00718 100644 --- a/sdk/nexent/core/agents/agent_model.py +++ b/sdk/nexent/core/agents/agent_model.py @@ -15,6 +15,7 @@ class ModelConfig(BaseModel): url: str = Field(description="Model endpoint URL") temperature: Optional[float] = Field(description="Temperature", default=0.1) top_p: Optional[float] = Field(description="Top P", default=0.95) + ssl_verify: Optional[bool] = Field(description="Whether to verify SSL certificates", default=True) class ToolConfig(BaseModel): diff --git a/sdk/nexent/core/agents/core_agent.py b/sdk/nexent/core/agents/core_agent.py index 2e3fc05e7..826ef7093 100644 --- a/sdk/nexent/core/agents/core_agent.py +++ b/sdk/nexent/core/agents/core_agent.py @@ -129,7 +129,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: additional_args = { "grammar": self.grammar} if self.grammar is not None else {} chat_message: ChatMessage = self.model(input_messages, - stop_sequences=["", "Observation:", "Calling tools:", "", "Observation:", "Calling tools:", " str: + """Execute DataMate search. + + Args: + query: Search query text. + top_k: Optional override for maximum number of search results. + threshold: Optional override for similarity threshold. + kb_page: Optional override for knowledge base list page index. + kb_page_size: Optional override for knowledge base list page size. + """ + + self.kb_page = kb_page + self.kb_page_size = kb_page_size + + # Send tool run message + if self.observer: + running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en + self.observer.add_message("", ProcessType.TOOL, running_prompt) + card_content = [{"icon": "search", "text": query}] + self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)) + + logger.info( + f"DataMateSearchTool called with query: '{query}', base_url: '{self.server_base_url}', " + f"top_k: {top_k}, threshold: {threshold}" + ) + + try: + # Step 1: Get knowledge base list + knowledge_base_ids = self._get_knowledge_base_list() + if not knowledge_base_ids: + return json.dumps("No knowledge base found. No relevant information found.", ensure_ascii=False) + + # Step 2: Retrieve knowledge base content + kb_search_results = self._retrieve_knowledge_base_content(query, knowledge_base_ids, top_k, threshold + ) + + if not kb_search_results: + raise Exception("No results found! Try a less restrictive/shorter query.") + + # Format search results + search_results_json = [] # Organize search results into a unified format + search_results_return = [] # Format for input to the large model + for index, single_search_result in enumerate(kb_search_results): + # Extract fields from DataMate API response + entity_data = single_search_result.get("entity", {}) + metadata = self._parse_metadata(entity_data.get("metadata")) + dataset_id = self._extract_dataset_id(metadata.get("absolute_directory_path", "")) + file_id = metadata.get("original_file_id") + download_url = self._build_file_download_url(dataset_id, file_id) + + score_details = entity_data.get("scoreDetails", {}) or {} + score_details.update({ + "datamate_dataset_id": dataset_id, + "datamate_file_id": file_id, + "datamate_download_url": download_url, + "datamate_base_url": self.server_base_url.rstrip("/") + }) + + search_result_message = SearchResultTextMessage( + title=metadata.get("file_name", ""), + text=entity_data.get("text", ""), + source_type="datamate", + url=download_url, + filename=metadata.get("file_name", ""), + published_date=entity_data.get("createTime", ""), + score=entity_data.get("score", "0"), + score_details=score_details, + cite_index=self.record_ops + index, + search_type=self.name, + tool_sign=self.tool_sign, + ) + + search_results_json.append(search_result_message.to_dict()) + search_results_return.append(search_result_message.to_model_dict()) + + self.record_ops += len(search_results_return) + + # Record the detailed content of this search + if self.observer: + search_results_data = json.dumps(search_results_json, ensure_ascii=False) + self.observer.add_message("", ProcessType.SEARCH_CONTENT, search_results_data) + return json.dumps(search_results_return, ensure_ascii=False) + + except Exception as e: + error_msg = f"Error during DataMate knowledge base search: {str(e)}" + logger.error(error_msg) + raise Exception(error_msg) + + def _get_knowledge_base_list(self) -> List[str]: + """Get knowledge base list from DataMate API. + + Returns: + List[str]: List of knowledge base IDs. + """ + try: + url = f"{self.server_base_url}/api/knowledge-base/list" + payload = {"page": self.kb_page, "size": self.kb_page_size} + + with httpx.Client(timeout=30) as client: + response = client.post(url, json=payload) + + if response.status_code != 200: + error_detail = ( + response.json().get("detail", "unknown error") + if response.headers.get("content-type", "").startswith("application/json") + else response.text + ) + raise Exception(f"Failed to get knowledge base list (status {response.status_code}): {error_detail}") + + result = response.json() + # Extract knowledge base IDs from response + # Assuming the response structure contains a list of knowledge bases with 'id' field + data = result.get("data", {}) + knowledge_bases = data.get("content", []) if data else [] + + knowledge_base_ids = [] + for kb in knowledge_bases: + kb_id = kb.get("id") + chunk_count = kb.get("chunkCount") + if kb_id and chunk_count: + knowledge_base_ids.append(str(kb_id)) + + logger.info(f"Retrieved {len(knowledge_base_ids)} knowledge base(s): {knowledge_base_ids}") + return knowledge_base_ids + + except httpx.TimeoutException: + raise Exception("Timeout while getting knowledge base list from DataMate API") + except httpx.RequestError as e: + raise Exception(f"Request error while getting knowledge base list: {str(e)}") + except Exception as e: + raise Exception(f"Error getting knowledge base list: {str(e)}") + + def _retrieve_knowledge_base_content( + self, query: str, knowledge_base_ids: List[str], top_k: int, threshold: float + ) -> List[dict]: + """Retrieve knowledge base content from DataMate API. + + Args: + query (str): Search query. + knowledge_base_ids (List[str]): List of knowledge base IDs to search. + top_k (int): Maximum number of results to return. + threshold (float): Similarity threshold. + + Returns: + List[dict]: List of search results. + """ + search_results = [] + for knowledge_base_id in knowledge_base_ids: + try: + url = f"{self.server_base_url}/api/knowledge-base/retrieve" + payload = { + "query": query, + "topK": top_k, + "threshold": threshold, + "knowledgeBaseIds": [knowledge_base_id], + } + + with httpx.Client(timeout=60) as client: + response = client.post(url, json=payload) + + if response.status_code != 200: + error_detail = ( + response.json().get("detail", "unknown error") + if response.headers.get("content-type", "").startswith("application/json") + else response.text + ) + raise Exception( + f"Failed to retrieve knowledge base content (status {response.status_code}): {error_detail}") + + result = response.json() + # Extract search results from response + for data in result.get("data", {}): + search_results.append(data) + except httpx.TimeoutException: + raise Exception("Timeout while retrieving knowledge base content from DataMate API") + except httpx.RequestError as e: + raise Exception(f"Request error while retrieving knowledge base content: {str(e)}") + except Exception as e: + raise Exception(f"Error retrieving knowledge base content: {str(e)}") + logger.info(f"Retrieved {len(search_results)} search result(s)") + return search_results + + @staticmethod + def _parse_metadata(metadata_raw: Optional[str]) -> dict: + """Parse metadata payload safely.""" + if not metadata_raw: + return {} + if isinstance(metadata_raw, dict): + return metadata_raw + try: + return json.loads(metadata_raw) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse metadata payload, falling back to empty metadata.") + return {} + + @staticmethod + def _extract_dataset_id(absolute_path: str) -> str: + """Extract dataset identifier from an absolute directory path.""" + if not absolute_path: + return "" + segments = [segment for segment in absolute_path.strip("/").split("/") if segment] + return segments[-1] if segments else "" + + def _build_file_download_url(self, dataset_id: str, file_id: str) -> str: + """Build the download URL for a dataset file.""" + if not (self.server_base_url and dataset_id and file_id): + return "" + return f"{self.server_base_url}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file diff --git a/sdk/nexent/core/utils/tools_common_message.py b/sdk/nexent/core/utils/tools_common_message.py index a79035d81..f89846fa5 100644 --- a/sdk/nexent/core/utils/tools_common_message.py +++ b/sdk/nexent/core/utils/tools_common_message.py @@ -9,6 +9,7 @@ class ToolSign(Enum): EXA_SEARCH = "b" # Exa search tool identifier LINKUP_SEARCH = "c" # Linkup search tool identifier TAVILY_SEARCH = "d" # Tavily search tool identifier + DATAMATE_KNOWLEDGE_BASE = "e" # DataMate knowledge base search tool identifier FILE_OPERATION = "f" # File operation tool identifier TERMINAL_OPERATION = "t" # Terminal operation tool identifier MULTIMODAL_OPERATION = "m" # Multimodal operation tool identifier @@ -20,6 +21,7 @@ class ToolSign(Enum): "tavily_search": ToolSign.TAVILY_SEARCH.value, "linkup_search": ToolSign.LINKUP_SEARCH.value, "exa_search": ToolSign.EXA_SEARCH.value, + "datamate_knowledge_base_search": ToolSign.DATAMATE_KNOWLEDGE_BASE.value, "file_operation": ToolSign.FILE_OPERATION.value, "terminal_operation": ToolSign.TERMINAL_OPERATION.value, "multimodal_operation": ToolSign.MULTIMODAL_OPERATION.value, diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index d81b3df4a..cd4be8afd 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -267,7 +267,7 @@ async def fake_get_url(object_name, expires): return {"success": True, "url": "http://example.com/a"} monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get_url) - resp = await file_management_app.get_storage_file(object_name="a.txt", download="redirect", expires=60) + resp = await file_management_app.get_storage_file(object_name="a.txt", download="redirect", expires=60, filename="a.txt") # Starlette RedirectResponse defaults to 307 assert 300 <= resp.status_code < 400 assert resp.headers["location"] == "http://example.com/a" @@ -281,9 +281,13 @@ async def gen(): return gen(), "text/plain" monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) - resp = await file_management_app.get_storage_file(object_name="a.txt", download="stream", expires=60) + resp = await file_management_app.get_storage_file(object_name="a.txt", download="stream", expires=60, filename="a.txt") + assert resp.headers["content-type"].startswith("text/plain") assert resp.media_type == "text/plain" - assert "inline; filename=\"a.txt\"" in resp.headers.get("content-disposition", "") + # Content-Disposition should be "attachment" not "inline", and filename should be extracted from object_name + content_disposition = resp.headers.get("content-disposition", "") + assert "attachment" in content_disposition + assert "a.txt" in content_disposition # consume stream chunks = [] async for part in resp.body_iterator: # type: ignore[attr-defined] @@ -297,7 +301,7 @@ async def fake_get_url(object_name, expires): return {"success": True, "url": "http://example.com/x"} monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get_url) - result = await file_management_app.get_storage_file(object_name="x", download="ignore", expires=10) + result = await file_management_app.get_storage_file(object_name="x", download="ignore", expires=10, filename="x.txt") assert result["url"] == "http://example.com/x" @@ -308,7 +312,7 @@ async def boom_url(object_name, expires): monkeypatch.setattr(file_management_app, "get_file_url_impl", boom_url) with pytest.raises(Exception) as ei: - await file_management_app.get_storage_file(object_name="x", download="ignore", expires=1) + await file_management_app.get_storage_file(object_name="x", download="ignore", expires=1, filename="x.txt") assert "Failed to get file information" in str(ei.value) @@ -357,3 +361,467 @@ def fake_get(object_name, expires): assert any(item["object_name"] == "bad" and item["success"] is False for item in out["results"]) +# --- Tests for build_content_disposition_header --- + +def test_build_content_disposition_header_ascii(): + """Test build_content_disposition_header with ASCII filename""" + result = file_management_app.build_content_disposition_header("test.pdf") + assert result == 'attachment; filename="test.pdf"' + + +def test_build_content_disposition_header_non_ascii(): + """Test build_content_disposition_header with non-ASCII filename""" + result = file_management_app.build_content_disposition_header("测试文件.pdf") + assert 'attachment; filename=' in result + assert 'filename*=UTF-8' in result + assert '测试文件' in result or '%E6%B5%8B%E8%AF%95' in result + + +def test_build_content_disposition_header_non_ascii_with_extension(): + """Test build_content_disposition_header with non-ASCII filename and extension""" + result = file_management_app.build_content_disposition_header("文档.docx") + assert 'attachment; filename=' in result + assert 'filename*=UTF-8' in result + assert '.docx' in result + + +def test_build_content_disposition_header_exception_handling(monkeypatch): + """Test build_content_disposition_header exception handling""" + def boom(_value: str, safe: str = "") -> str: + raise RuntimeError("quote failure") + + monkeypatch.setattr("backend.apps.file_management_app.quote", boom) + + result = file_management_app.build_content_disposition_header("测试.pdf") + assert 'attachment; filename=' in result + assert 'filename*=UTF-8' not in result + + +# --- Tests for get_storage_file with filename parameter --- + +@pytest.mark.asyncio +async def test_get_storage_file_stream_with_filename(monkeypatch): + """Test get_storage_file stream mode with filename parameter""" + async def fake_get_stream(object_name): + async def gen(): + yield b"chunk1" + return gen(), "application/pdf" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + resp = await file_management_app.get_storage_file( + object_name="attachments/file.pdf", + download="stream", + expires=60, + filename="原始文件名.pdf" + ) + assert resp.media_type == "application/pdf" + content_disposition = resp.headers.get("content-disposition", "") + assert "原始文件名.pdf" in content_disposition or "filename*=UTF-8" in content_disposition + + +@pytest.mark.asyncio +async def test_get_storage_file_stream_without_filename(monkeypatch): + """Test get_storage_file stream mode without filename parameter (extract from object_name)""" + async def fake_get_stream(object_name): + async def gen(): + yield b"chunk1" + return gen(), "text/plain" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + resp = await file_management_app.get_storage_file( + object_name="attachments/test.txt", + download="stream", + expires=60, + filename=None + ) + assert resp.media_type == "text/plain" + content_disposition = resp.headers.get("content-disposition", "") + assert "test.txt" in content_disposition + + +@pytest.mark.asyncio +async def test_get_storage_file_stream_error(monkeypatch): + """Test get_storage_file stream mode error handling""" + async def fake_get_stream(object_name): + raise RuntimeError("Stream error") + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + with pytest.raises(Exception) as ei: + await file_management_app.get_storage_file( + object_name="test.txt", + download="stream", + expires=60, + filename="test.txt" + ) + assert "Failed to get file information" in str(ei.value) + + +# --- Tests for download_datamate_file --- + +@pytest.mark.asyncio +async def test_download_datamate_file_with_url(monkeypatch): + """Test download_datamate_file with full URL""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"file content" + mock_response.headers = {"Content-Type": "application/pdf", "Content-Disposition": 'attachment; filename="test.pdf"'} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + resp = await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename="test.pdf", + authorization=None, + ) + assert resp.media_type == "application/pdf" + content_disposition = resp.headers.get("content-disposition", "") + assert "test.pdf" in content_disposition + + +@pytest.mark.asyncio +async def test_download_datamate_file_with_parts(monkeypatch): + """Test download_datamate_file with base_url, dataset_id, file_id""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"file content" + mock_response.headers = {"Content-Type": "application/pdf"} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + resp = await file_management_app.download_datamate_file( + url=None, + base_url="http://example.com", + dataset_id="123", + file_id="456", + filename=None, + authorization=None, + ) + assert resp.media_type == "application/pdf" + + +@pytest.mark.asyncio +async def test_download_datamate_file_404_error(monkeypatch): + """Test download_datamate_file with 404 error""" + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.headers = {} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + with pytest.raises(Exception) as ei: + await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization=None, + ) + assert "File not found" in str(ei.value) + + +@pytest.mark.asyncio +async def test_download_datamate_file_http_error(monkeypatch): + """Test download_datamate_file with HTTP error""" + import httpx + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=httpx.HTTPError("Network error")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + with pytest.raises(Exception) as ei: + await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization=None, + ) + assert "Failed to download file from URL" in str(ei.value) + + +@pytest.mark.asyncio +async def test_download_datamate_file_missing_params(): + """Test download_datamate_file with missing parameters""" + with pytest.raises(Exception) as ei: + await file_management_app.download_datamate_file( + url=None, + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization=None, + ) + assert "Either url or (base_url, dataset_id, file_id) must be provided" in str(ei.value) + + +@pytest.mark.asyncio +async def test_download_datamate_file_extract_filename_from_content_disposition(monkeypatch): + """Test download_datamate_file extracting filename from Content-Disposition header""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"file content" + mock_response.headers = {"Content-Type": "application/pdf", "Content-Disposition": 'attachment; filename="extracted.pdf"'} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + resp = await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization=None, + ) + content_disposition = resp.headers.get("content-disposition", "") + assert "extracted.pdf" in content_disposition + + +@pytest.mark.asyncio +async def test_download_datamate_file_extract_filename_from_url(monkeypatch): + """Test download_datamate_file extracting filename from URL path""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"file content" + mock_response.headers = {"Content-Type": "application/pdf"} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + resp = await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization=None, + ) + content_disposition = resp.headers.get("content-disposition", "") + assert "attachment" in content_disposition + + +@pytest.mark.asyncio +async def test_download_datamate_file_with_authorization(monkeypatch): + """Test download_datamate_file with authorization header""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"file content" + mock_response.headers = {"Content-Type": "application/pdf"} + mock_response.raise_for_status = MagicMock() + + call_args_list = [] + async def fake_httpx_get(url, headers=None, follow_redirects=True): + call_args_list.append((url, headers)) + return mock_response + + mock_client = MagicMock() + mock_client.get = fake_httpx_get + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization="Bearer token123", + ) + assert len(call_args_list) > 0 + assert call_args_list[0][1].get("Authorization") == "Bearer token123" + + +@pytest.mark.asyncio +async def test_download_datamate_file_unexpected_exception(monkeypatch): + """Unexpected exceptions should surface with new 500 message.""" + + def fail_normalize(_url: str): + raise ValueError("boom") + + monkeypatch.setattr( + file_management_app, + "_normalize_datamate_download_url", + fail_normalize, + ) + + with pytest.raises(Exception) as exc: + await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + filename=None, + authorization=None, + ) + assert "Failed to download file: boom" in str(exc.value) + + +# --- Tests for _normalize_datamate_download_url --- + +def test_normalize_datamate_download_url_valid(): + """Test _normalize_datamate_download_url with valid URL""" + url = "http://example.com/api/data-management/datasets/123/files/456/download" + result = file_management_app._normalize_datamate_download_url(url) + assert result == url + + +def test_normalize_datamate_download_url_adds_scheme(): + """URLs without scheme should default to https://""" + url = "example.com/api/data-management/datasets/123/files/456/download" + result = file_management_app._normalize_datamate_download_url(url) + assert result.startswith("http://example.com") + + +def test_normalize_datamate_download_url_with_prefix(): + """Test _normalize_datamate_download_url with URL prefix""" + url = "http://example.com/prefix/api/data-management/datasets/123/files/456/download" + result = file_management_app._normalize_datamate_download_url(url) + assert "/prefix/api/data-management/datasets/123/files/456/download" in result + + +def test_normalize_datamate_download_url_missing_data_management(): + """Test _normalize_datamate_download_url with missing data-management segment""" + with pytest.raises(Exception) as ei: + file_management_app._normalize_datamate_download_url("http://example.com/invalid/url") + assert "missing 'data-management' segment" in str(ei.value) + + +def test_normalize_datamate_download_url_invalid_structure(): + """Test _normalize_datamate_download_url with invalid URL structure""" + with pytest.raises(Exception) as ei: + file_management_app._normalize_datamate_download_url("http://example.com/data-management/invalid") + assert "unable to parse dataset_id or file_id" in str(ei.value) + + +# --- Tests for _build_datamate_url_from_parts --- + +def test_build_datamate_url_from_parts_with_api(): + """Test _build_datamate_url_from_parts with base_url ending with /api""" + result = file_management_app._build_datamate_url_from_parts( + "http://example.com/api", + "123", + "456" + ) + assert "/api/data-management/datasets/123/files/456/download" in result + + +def test_build_datamate_url_from_parts_without_scheme(): + """base_url without scheme should default to https://""" + result = file_management_app._build_datamate_url_from_parts( + "example.com", + "123", + "456" + ) + assert result.startswith("http://example.com/api/") + + +def test_build_datamate_url_from_parts_without_api(): + """Test _build_datamate_url_from_parts with base_url without /api""" + result = file_management_app._build_datamate_url_from_parts( + "http://example.com", + "123", + "456" + ) + assert "/api/data-management/datasets/123/files/456/download" in result + + +def test_build_datamate_url_from_parts_with_slash(): + """Test _build_datamate_url_from_parts with base_url ending with slash""" + result = file_management_app._build_datamate_url_from_parts( + "http://example.com/", + "123", + "456" + ) + assert "/api/data-management/datasets/123/files/456/download" in result + + +def test_build_datamate_url_from_parts_appends_api_segment(): + """Ensure /api is appended when missing from base path""" + result = file_management_app._build_datamate_url_from_parts( + "http://example.com/service", + "123", + "456" + ) + assert result.startswith("http://example.com/service/api/") + + +def test_build_datamate_url_from_parts_defaults_api_when_no_path(): + """Ensure empty base path defaults to /api""" + result = file_management_app._build_datamate_url_from_parts( + "http://example.com", + "123", + "456" + ) + assert result.startswith("http://example.com/api/") + + +def test_build_datamate_url_from_parts_trailing_slash_branch(monkeypatch): + """Force branch where rstrip result still ends with slash.""" + + class DummyPath: + def rstrip(self, chars=None): + return "/prefix/" + + class DummyParseResult: + scheme = "http" + netloc = "example.com" + path = DummyPath() + + def fake_urlparse(_url: str): + return DummyParseResult() + + monkeypatch.setattr("backend.apps.file_management_app.urlparse", fake_urlparse) + + result = file_management_app._build_datamate_url_from_parts( + "http://placeholder", + "123", + "456" + ) + assert result.startswith("http://example.com/prefix/api/") + + +def test_build_datamate_url_from_parts_empty_base_url(): + """Test _build_datamate_url_from_parts with empty base_url""" + with pytest.raises(Exception) as ei: + file_management_app._build_datamate_url_from_parts("", "123", "456") + assert "base_url is required" in str(ei.value) + + diff --git a/test/backend/app/test_image_app.py b/test/backend/app/test_image_app.py index 60db95d53..6c1d8f54c 100644 --- a/test/backend/app/test_image_app.py +++ b/test/backend/app/test_image_app.py @@ -202,3 +202,159 @@ async def test_proxy_image_logging(monkeypatch): # Verify the mock was called with the expected URL mock_session.get.assert_called_once() + + +@pytest.mark.asyncio +async def test_proxy_image_stream_format(monkeypatch): + """Test proxy_image with format=stream""" + import base64 + from io import BytesIO + + # Create mock response with base64 image data + test_image_bytes = b"fake image data" + test_base64 = base64.b64encode(test_image_bytes).decode('utf-8') + + success_response_stream = { + "success": True, + "base64": test_base64, + "content_type": "image/png" + } + + async def fake_proxy_image_impl(decoded_url): + return success_response_stream + + from backend.apps import image_app + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + resp = await image_app.proxy_image(url=encoded_test_url, format="stream") + + # Should return StreamingResponse + assert hasattr(resp, 'media_type') + assert resp.media_type == "image/png" + assert "Cache-Control" in resp.headers + assert resp.headers["Cache-Control"] == "public, max-age=3600" + + # Verify content + content = b"" + async for chunk in resp.body_iterator: + content += chunk + assert content == test_image_bytes + + +@pytest.mark.asyncio +async def test_proxy_image_stream_format_error(monkeypatch): + """Test proxy_image with format=stream when proxy_image_impl returns error""" + error_response = { + "success": False, + "error": "Failed to fetch image" + } + + async def fake_proxy_image_impl(decoded_url): + return error_response + + from backend.apps import image_app + from fastapi import HTTPException + + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + with pytest.raises(HTTPException) as exc_info: + await image_app.proxy_image(url=encoded_test_url, format="stream") + + assert exc_info.value.status_code == 502 + assert "Failed to fetch image" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_proxy_image_stream_format_base64_decode_error(monkeypatch): + """Test proxy_image with format=stream when base64 decoding fails""" + import base64 + + # Invalid base64 data + success_response_invalid = { + "success": True, + "base64": "invalid base64!!!", + "content_type": "image/png" + } + + async def fake_proxy_image_impl(decoded_url): + return success_response_invalid + + from backend.apps import image_app + from fastapi import HTTPException + + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + with pytest.raises(HTTPException) as exc_info: + await image_app.proxy_image(url=encoded_test_url, format="stream") + + assert exc_info.value.status_code == 502 + + +@pytest.mark.asyncio +async def test_proxy_image_stream_format_exception(monkeypatch): + """Test proxy_image with format=stream when exception occurs""" + async def fake_proxy_image_impl(decoded_url): + raise ValueError("Unexpected error") + + from backend.apps import image_app + from fastapi import HTTPException + + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + with pytest.raises(HTTPException) as exc_info: + await image_app.proxy_image(url=encoded_test_url, format="stream") + + assert exc_info.value.status_code == 502 + assert "Unexpected error" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_proxy_image_json_format_default(monkeypatch): + """Test proxy_image with format=json (default)""" + async def fake_proxy_image_impl(decoded_url): + return success_response + + from backend.apps import image_app + + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + result = await image_app.proxy_image(url=encoded_test_url, format="json") + + assert result == success_response + + +@pytest.mark.asyncio +async def test_proxy_image_json_format_exception(monkeypatch): + """Test proxy_image with format=json when exception occurs""" + async def fake_proxy_image_impl(decoded_url): + raise RuntimeError("Service unavailable") + + from backend.apps import image_app + + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + result = await image_app.proxy_image(url=encoded_test_url, format="json") + + assert result["success"] is False + assert "Service unavailable" in result["error"] + + +@pytest.mark.asyncio +async def test_proxy_image_url_decoding(monkeypatch): + """Test proxy_image correctly decodes URL""" + special_url = "https://example.com/image with spaces.jpg" + encoded_special_url = "https%3A%2F%2Fexample.com%2Fimage%20with%20spaces.jpg" + + call_urls = [] + async def fake_proxy_image_impl(decoded_url): + call_urls.append(decoded_url) + return success_response + + from backend.apps import image_app + + monkeypatch.setattr(image_app, "proxy_image_impl", fake_proxy_image_impl) + + await image_app.proxy_image(url=encoded_special_url, format="json") + + assert len(call_urls) == 1 + assert call_urls[0] == special_url diff --git a/test/backend/app/test_me_model_managment_app.py b/test/backend/app/test_me_model_managment_app.py index 2234013db..141e41e1e 100644 --- a/test/backend/app/test_me_model_managment_app.py +++ b/test/backend/app/test_me_model_managment_app.py @@ -198,85 +198,33 @@ async def mock_list_models_with_filter(type: str = Query(None)): assert response_data["data"][0]["name"] == "model2" -@pytest.mark.asyncio -async def test_get_me_models_env_not_configured_returns_skip_message_and_empty_list(): - """When ME env not configured, endpoint returns 200 with skip message and empty data.""" - with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=False)): - response = client.get("/me/model/list") - - assert response.status_code == HTTPStatus.OK - data = response.json() - assert data["message"] == "Retrieve skipped" - assert data["data"] == [] - - -@pytest.mark.asyncio -async def test_get_me_models_not_found_filter(): - # Patch the service impl to raise NotFoundException so the route returns 404 - with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.get_me_models_impl') as mock_impl: - mock_impl.side_effect = NotFoundException( - "No models found with type 'nonexistent'.") - response = client.get("/me/model/list?type=nonexistent") - - # Assertions - route maps NotFoundException -> 404 and raises HTTPException with detail - assert response.status_code == HTTPStatus.NOT_FOUND - body = response.json() - assert body["detail"] == "ModelEngine model not found" - - -@pytest.mark.asyncio -async def test_get_me_models_timeout(): - """Test model list retrieval with timeout via real route""" - # Patch service to raise TimeoutException so the real route returns 408 - with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.get_me_models_impl') as mock_impl: - mock_impl.side_effect = TimeoutException("Request timeout.") - - response = client.get("/me/model/list") - - assert response.status_code == HTTPStatus.REQUEST_TIMEOUT - body = response.json() - assert body["detail"] == "Failed to get ModelEngine model list: timeout" - - -@pytest.mark.asyncio -async def test_get_me_models_exception(): - """Test model list retrieval with generic exception""" - with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.get_me_models_impl') as mock_impl: - mock_impl.side_effect = Exception("boom") - response = client.get("/me/model/list") +# NOTE: The following tests are disabled because /me/model/list endpoint has been removed +# Model listing is now handled through the main model management endpoints - # Assertions - assert response.status_code == 500 - response_data = response.json() - assert response_data["detail"] == "Failed to get ModelEngine model list" +# @pytest.mark.asyncio +# async def test_get_me_models_env_not_configured_returns_skip_message_and_empty_list(): +# """When ME env not configured, endpoint returns 200 with skip message and empty data.""" +# pass +# @pytest.mark.asyncio +# async def test_get_me_models_not_found_filter(): +# """Test model list retrieval with not found filter""" +# pass -@pytest.mark.asyncio -async def test_get_me_models_success_response(): - """Test successful model list retrieval with proper JSONResponse format""" - # Mock the service implementation to return test data - with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.get_me_models_impl') as mock_impl: - mock_impl.return_value = [ - {"name": "model1", "type": "embed", "version": "1.0"}, - {"name": "model2", "type": "chat", "version": "1.0"} - ] +# @pytest.mark.asyncio +# async def test_get_me_models_timeout(): +# """Test model list retrieval with timeout via real route""" +# pass - # Test the endpoint - response = client.get("/me/model/list") +# @pytest.mark.asyncio +# async def test_get_me_models_exception(): +# """Test model list retrieval with generic exception""" +# pass - # Assertions - assert response.status_code == HTTPStatus.OK - response_data = response.json() - assert response_data["message"] == "Successfully retrieved" - assert response_data["data"] == [ - {"name": "model1", "type": "embed", "version": "1.0"}, - {"name": "model2", "type": "chat", "version": "1.0"} - ] - assert len(response_data["data"]) == 2 +# @pytest.mark.asyncio +# async def test_get_me_models_success_response(): +# """Test successful model list retrieval with proper JSONResponse format""" +# pass @pytest.mark.asyncio @@ -288,92 +236,64 @@ async def test_check_me_connectivity_env_not_configured_returns_skip_message(): assert response.status_code == HTTPStatus.OK body = response.json() assert body["connectivity"] is False - assert body["message"] == "ModelEngine platform necessary environment variables not configured. Healthcheck skipped." + assert body["message"] == "ModelEngine platform environment variables not configured. Healthcheck skipped." @pytest.mark.asyncio async def test_check_me_connectivity_success(): """Test successful ME connectivity check""" - # Mock the check_me_connectivity_impl function in the app module + # Mock the check_me_connectivity function from the service with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.check_me_connectivity_impl') as mock_connectivity: - mock_connectivity.return_value = ( - HTTPStatus.OK, - "Connection successful", - { - "status": "Connected", - "desc": "Connection successful", - "connect_status": "available" - } - ) - + with patch('backend.apps.me_model_managment_app.check_me_connectivity', AsyncMock(return_value=None)): # Test with TestClient response = client.get("/me/healthcheck") # Assertions - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK response_data = response.json() - assert response_data["connectivity"] - # Updated success message string - with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.check_me_connectivity_impl') as mock_connectivity2: - mock_connectivity2.return_value = ( - HTTPStatus.OK, - "Connection successful", - { - "status": "Connected", - "desc": "Connection successful", - "connect_status": "available" - } - ) - response2 = client.get("/me/healthcheck") - assert response2.status_code == 200 - assert response2.json()[ - "message"] == "ModelEngine platform connect successfully." + assert response_data["connectivity"] is True + assert response_data["message"] == "ModelEngine platform connected successfully." @pytest.mark.asyncio async def test_check_me_connectivity_failure(): """Trigger MEConnectionException to simulate connectivity failure""" - # Patch the impl to raise MEConnectionException so the route returns 503 + # Patch the connectivity check to raise MEConnectionException so the route returns 503 with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.check_me_connectivity_impl') as mock_connectivity: - mock_connectivity.side_effect = MEConnectionException( - "Downstream 404 or similar") - + with patch('backend.apps.me_model_managment_app.check_me_connectivity', AsyncMock(side_effect=MEConnectionException("Downstream 404 or similar"))): response = client.get("/me/healthcheck") assert response.status_code == HTTPStatus.SERVICE_UNAVAILABLE + body = response.json() + assert "ModelEngine connection failed" in body["detail"] @pytest.mark.asyncio async def test_check_me_connectivity_timeout(): """Test ME connectivity check with timeout""" - # Mock the impl to raise TimeoutException so the route returns 408 + # Mock the connectivity check to raise TimeoutException so the route returns 408 with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.check_me_connectivity_impl') as mock_connectivity: - mock_connectivity.side_effect = TimeoutException( - "timeout simulated") - + with patch('backend.apps.me_model_managment_app.check_me_connectivity', AsyncMock(side_effect=TimeoutException("timeout simulated"))): response = client.get("/me/healthcheck") - # Assertions - route maps TimeoutException -> 408 and returns status/desc/connect_status + # Assertions - route maps TimeoutException -> 408 assert response.status_code == HTTPStatus.REQUEST_TIMEOUT + body = response.json() + assert body["detail"] == "ModelEngine connection timeout." @pytest.mark.asyncio async def test_check_me_connectivity_generic_exception(): """Test ME connectivity check with generic exception""" - # Mock the impl to raise a generic Exception so the route returns 500 + # Mock the connectivity check to raise a generic Exception so the route returns 500 with patch('backend.apps.me_model_managment_app.check_me_variable_set', AsyncMock(return_value=True)): - with patch('backend.apps.me_model_managment_app.check_me_connectivity_impl') as mock_connectivity: - mock_connectivity.side_effect = Exception( - "Unexpected error occurred") - + with patch('backend.apps.me_model_managment_app.check_me_connectivity', AsyncMock(side_effect=Exception("Unexpected error occurred"))): response = client.get("/me/healthcheck") - # Assertions - route maps generic Exception -> 500 and returns status/desc/connect_status + # Assertions - route maps generic Exception -> 500 assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + body = response.json() + assert "ModelEngine healthcheck failed" in body["detail"] @pytest.mark.asyncio diff --git a/test/backend/app/test_model_managment_app.py b/test/backend/app/test_model_managment_app.py index 18a41b54f..6162f1773 100644 --- a/test/backend/app/test_model_managment_app.py +++ b/test/backend/app/test_model_managment_app.py @@ -6,8 +6,15 @@ from http import HTTPStatus from unittest.mock import patch, MagicMock -# Add path for correct imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend")) +# Add project root to sys.path so that the top-level `backend` package is importable +PROJECT_ROOT = os.path.join(os.path.dirname(__file__), "../../..") +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +# Also add the backend source directory so that subpackages like `consts` can be imported directly +BACKEND_ROOT = os.path.join(PROJECT_ROOT, "backend") +if BACKEND_ROOT not in sys.path: + sys.path.insert(0, BACKEND_ROOT) # Patch environment variables before any imports that might use them os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') @@ -513,12 +520,21 @@ async def mock_update_single(*args, **kwargs): "provider": "huggingface" } response = client.post( - "/model/update", json=update_data, headers=auth_header) + "/model/update", + params={"display_name": "Updated Test Model"}, + json=update_data, + headers=auth_header, + ) assert response.status_code == HTTPStatus.OK data = response.json() assert "Model updated successfully" in data["message"] - mock_update.assert_called_once_with(user_credentials[0], user_credentials[1], update_data) + mock_update.assert_called_once_with( + user_credentials[0], + user_credentials[1], + "Updated Test Model", + update_data, + ) @pytest.mark.asyncio @@ -527,8 +543,8 @@ async def test_update_single_model_conflict(client, auth_header, user_credential mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) mock_update = mocker.patch( - 'apps.model_managment_app.update_single_model_for_tenant', - side_effect=ValueError("Name 'Conflicting Name' is already in use, please choose another display name") + 'apps.model_managment_app.update_single_model_for_tenant', + side_effect=ValueError("Name 'Conflicting Name' is already in use, please choose another display name"), ) update_data = { @@ -541,13 +557,22 @@ async def test_update_single_model_conflict(client, auth_header, user_credential "provider": "huggingface" } response = client.post( - "/model/update", json=update_data, headers=auth_header) + "/model/update", + params={"display_name": "Conflicting Name"}, + json=update_data, + headers=auth_header, + ) assert response.status_code == HTTPStatus.CONFLICT data = response.json() # Now we return the actual error message assert "Name 'Conflicting Name' is already in use" in data.get("detail", "") - mock_update.assert_called_once_with(user_credentials[0], user_credentials[1], update_data) + mock_update.assert_called_once_with( + user_credentials[0], + user_credentials[1], + "Conflicting Name", + update_data, + ) # Tests for /model/batch_update endpoint diff --git a/test/backend/database/test_model_managment_db.py b/test/backend/database/test_model_managment_db.py index 2537301c7..34160fc3b 100644 --- a/test/backend/database/test_model_managment_db.py +++ b/test/backend/database/test_model_managment_db.py @@ -150,3 +150,247 @@ def test_get_model_by_model_id_fills_default_chunk_sizes(monkeypatch): assert out is not None assert out["expected_chunk_size"] == 1024 assert out["maximum_chunk_size"] == 1536 + + +def test_create_model_record(monkeypatch): + """Test create_model_record function (covers lines 23-42)""" + mock_result = MagicMock() + mock_result.rowcount = 1 + + mock_stmt = MagicMock() + mock_stmt.values.return_value = mock_stmt + + mock_insert = MagicMock(return_value=mock_stmt) + monkeypatch.setattr("backend.database.model_management_db.insert", mock_insert) + + session = MagicMock() + session.execute.return_value = mock_result + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + + # Mock clean_string_values and add_creation_tracking + monkeypatch.setattr("backend.database.model_management_db.db_client.clean_string_values", lambda x: x) + monkeypatch.setattr("backend.database.model_management_db.add_creation_tracking", lambda x, uid: x) + monkeypatch.setattr("backend.database.model_management_db.func.current_timestamp", MagicMock()) + + model_data = {"model_name": "test", "model_type": "llm"} + result = model_mgmt_db.create_model_record(model_data, user_id="u1", tenant_id="t1") + + assert result is True + session.execute.assert_called_once() + + +def test_update_model_record(monkeypatch): + """Test update_model_record function (covers lines 63-84)""" + mock_result = MagicMock() + mock_result.rowcount = 1 + + mock_stmt = MagicMock() + mock_stmt.where.return_value = mock_stmt + mock_stmt.values.return_value = mock_stmt + + mock_update = MagicMock(return_value=mock_stmt) + monkeypatch.setattr("backend.database.model_management_db.update", mock_update) + + session = MagicMock() + session.execute.return_value = mock_result + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + + # Mock clean_string_values and add_update_tracking + monkeypatch.setattr("backend.database.model_management_db.db_client.clean_string_values", lambda x: x) + monkeypatch.setattr("backend.database.model_management_db.add_update_tracking", lambda x, uid: x) + monkeypatch.setattr("backend.database.model_management_db.func.current_timestamp", MagicMock()) + + update_data = {"model_name": "updated"} + result = model_mgmt_db.update_model_record(1, update_data, user_id="u1", tenant_id="t1") + + assert result is True + session.execute.assert_called_once() + + +def test_delete_model_record(monkeypatch): + """Test delete_model_record function (covers lines 99-119)""" + mock_result = MagicMock() + mock_result.rowcount = 1 + + mock_stmt = MagicMock() + mock_stmt.where.return_value = mock_stmt + mock_stmt.values.return_value = mock_stmt + + mock_update = MagicMock(return_value=mock_stmt) + monkeypatch.setattr("backend.database.model_management_db.update", mock_update) + + session = MagicMock() + session.execute.return_value = mock_result + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + + # Mock add_update_tracking + monkeypatch.setattr("backend.database.model_management_db.add_update_tracking", lambda x, uid: x) + monkeypatch.setattr("backend.database.model_management_db.func.current_timestamp", MagicMock()) + + result = model_mgmt_db.delete_model_record(1, user_id="u1", tenant_id="t1") + + assert result is True + session.execute.assert_called_once() + + +def test_get_model_records_with_tenant_id(monkeypatch): + """Test get_model_records with tenant_id filter (covers lines 137->141)""" + mock_model = SimpleNamespace( + model_id=4, + model_factory="openai", + model_type="llm", + tenant_id="tenant4", + delete_flag="N", + ) + mock_scalars = MagicMock() + mock_scalars.all.return_value = [mock_model] + session = MagicMock() + session.scalars.return_value = mock_scalars + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.model_management_db.as_dict", lambda obj: obj.__dict__) + + records = model_mgmt_db.get_model_records({"model_type": "llm"}, tenant_id="tenant4") + assert len(records) == 1 + assert records[0]["tenant_id"] == "tenant4" + + +def test_get_model_records_with_none_filter(monkeypatch): + """Test get_model_records with None value in filter (covers line 145)""" + mock_model = SimpleNamespace( + model_id=5, + model_factory="openai", + model_type="llm", + tenant_id="tenant5", + delete_flag="N", + display_name=None, + ) + mock_scalars = MagicMock() + mock_scalars.all.return_value = [mock_model] + session = MagicMock() + session.scalars.return_value = mock_scalars + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.model_management_db.as_dict", lambda obj: obj.__dict__) + + records = model_mgmt_db.get_model_records({"display_name": None}, tenant_id="tenant5") + assert len(records) == 1 + + +def test_get_model_by_display_name(monkeypatch): + """Test get_model_by_display_name function (covers lines 178-185)""" + mock_model = SimpleNamespace( + model_id=6, + model_factory="openai", + model_name="gpt-4", + display_name="GPT-4", + tenant_id="tenant6", + delete_flag="N", + ) + mock_scalars = MagicMock() + mock_scalars.all.return_value = [mock_model] + session = MagicMock() + session.scalars.return_value = mock_scalars + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.model_management_db.as_dict", lambda obj: obj.__dict__) + + result = model_mgmt_db.get_model_by_display_name("GPT-4", "tenant6") + assert result is not None + assert result["display_name"] == "GPT-4" + + +def test_get_model_id_by_display_name(monkeypatch): + """Test get_model_id_by_display_name function (covers lines 199-200)""" + mock_model = SimpleNamespace( + model_id=7, + model_factory="openai", + model_name="gpt-4", + display_name="GPT-4", + tenant_id="tenant7", + delete_flag="N", + ) + mock_scalars = MagicMock() + mock_scalars.all.return_value = [mock_model] + session = MagicMock() + session.scalars.return_value = mock_scalars + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.model_management_db.as_dict", lambda obj: obj.__dict__) + + result = model_mgmt_db.get_model_id_by_display_name("GPT-4", "tenant7") + assert result == 7 + + +def test_get_model_by_model_id_with_tenant_id(monkeypatch): + """Test get_model_by_model_id with tenant_id filter (covers lines 222->226)""" + mock_model = SimpleNamespace( + model_id=8, + model_factory="openai", + model_type="llm", + tenant_id="tenant8", + delete_flag="N", + ) + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_model + session = MagicMock() + session.scalars.return_value = mock_scalars + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + + result = model_mgmt_db.get_model_by_model_id(8, tenant_id="tenant8") + assert result is not None + assert result["model_id"] == 8 + + +def test_get_model_by_name_factory(monkeypatch): + """Test get_model_by_name_factory function (covers lines 269-274)""" + mock_model = SimpleNamespace( + model_id=9, + model_factory="openai", + model_name="gpt-4", + tenant_id="tenant9", + delete_flag="N", + ) + mock_scalars = MagicMock() + mock_scalars.all.return_value = [mock_model] + session = MagicMock() + session.scalars.return_value = mock_scalars + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.model_management_db.as_dict", lambda obj: obj.__dict__) + + result = model_mgmt_db.get_model_by_name_factory("gpt-4", "openai", "tenant9") + assert result is not None + assert result["model_name"] == "gpt-4" + assert result["model_factory"] == "openai" diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 099f8692c..9c202209c 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -271,12 +271,13 @@ async def test_get_creating_sub_agent_id_service_new_agent(mock_search, mock_cre ) +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test successful retrieval of an agent's information by ID. @@ -284,7 +285,7 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t 1. The function correctly retrieves the agent's basic information 2. It fetches the associated tools 3. It gets the sub-agent ID list - 4. It returns a complete agent information structure + 4. It returns a complete agent information structure with availability status """ # Setup mock_agent_info = { @@ -302,6 +303,9 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t # Mock get_model_by_model_id - return None for model_id=None mock_get_model_by_model_id.return_value = None + + # Mock check_agent_availability - agent is available + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -314,7 +318,9 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, - "business_logic_model_name": None + "business_logic_model_name": None, + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result mock_search_agent_info.assert_called_once_with(123, "test_tenant") @@ -322,6 +328,7 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t agent_id=123, tenant_id="test_tenant") mock_query_sub_agents_id.assert_called_once_with( main_agent_id=123, tenant_id="test_tenant") + mock_check_availability.assert_called_once() @patch('backend.services.agent_service.get_model_by_model_id') @@ -1063,9 +1070,10 @@ async def test_export_agent_impl_no_mcp_tools(mock_get_current_user_info, mock_e mock_export_data_format.assert_called_once() +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.search_agent_info_by_agent_id') -async def test_get_agent_info_impl_with_tool_error(mock_search_agent_info, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_tool_error(mock_search_agent_info, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with an error in retrieving tool information. @@ -1081,6 +1089,7 @@ async def test_get_agent_info_impl_with_tool_error(mock_search_agent_info, mock_ "business_description": "Test agent" } mock_search_agent_info.return_value = mock_agent_info + mock_check_availability.return_value = (True, []) # Mock the search_tools_for_sub_agent function to raise an exception with patch('backend.services.agent_service.search_tools_for_sub_agent') as mock_search_tools, \ @@ -1097,15 +1106,18 @@ async def test_get_agent_info_impl_with_tool_error(mock_search_agent_info, mock_ assert result["tools"] == [] assert result["sub_agent_id_list"] == [] assert result["model_name"] is None + assert result["is_available"] == True + assert result["unavailable_reasons"] == [] mock_search_agent_info.assert_called_once_with(123, "test_tenant") +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_sub_agent_error(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_sub_agent_error(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with an error in retrieving sub agent id list. @@ -1128,6 +1140,7 @@ async def test_get_agent_info_impl_sub_agent_error(mock_search_agent_info, mock_ # Mock query_sub_agents_id_list to raise an exception mock_query_sub_agents_id.side_effect = Exception("Sub agent query error") mock_get_model_by_model_id.return_value = None + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1137,6 +1150,8 @@ async def test_get_agent_info_impl_sub_agent_error(mock_search_agent_info, mock_ assert result["tools"] == mock_tools assert result["sub_agent_id_list"] == [] assert result["model_name"] is None + assert result["is_available"] == True + assert result["unavailable_reasons"] == [] mock_search_agent_info.assert_called_once_with(123, "test_tenant") mock_search_tools.assert_called_once_with( agent_id=123, tenant_id="test_tenant") @@ -1144,12 +1159,13 @@ async def test_get_agent_info_impl_sub_agent_error(mock_search_agent_info, mock_ main_agent_id=123, tenant_id="test_tenant") +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with a valid model_id. @@ -1179,6 +1195,9 @@ async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, "provider": "openai" } mock_get_model_by_model_id.return_value = mock_model_info + + # Mock check_agent_availability - agent is available + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1191,18 +1210,21 @@ async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", - "business_logic_model_name": None + "business_logic_model_name": None, + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result mock_get_model_by_model_id.assert_called_once_with(456) +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with model_id but model has no display_name. @@ -1231,6 +1253,7 @@ async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_age # No display_name field } mock_get_model_by_model_id.return_value = mock_model_info + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1243,18 +1266,21 @@ async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_age "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, - "business_logic_model_name": None + "business_logic_model_name": None, + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result mock_get_model_by_model_id.assert_called_once_with(456) +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with model_id but get_model_by_model_id returns None. @@ -1278,6 +1304,7 @@ async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_age # Mock get_model_by_model_id to return None mock_get_model_by_model_id.return_value = None + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1290,18 +1317,21 @@ async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_age "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, - "business_logic_model_name": None + "business_logic_model_name": None, + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result mock_get_model_by_model_id.assert_called_once_with(456) +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_with_business_logic_model(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_business_logic_model(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with business_logic_model_id. @@ -1348,6 +1378,7 @@ def mock_get_model(model_id): return None mock_get_model_by_model_id.side_effect = mock_get_model + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1361,7 +1392,9 @@ def mock_get_model(model_id): "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", - "business_logic_model_name": "Claude-3.5" + "business_logic_model_name": "Claude-3.5", + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result @@ -1371,12 +1404,13 @@ def mock_get_model(model_id): mock_get_model_by_model_id.assert_any_call(789) +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_with_business_logic_model_none(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_business_logic_model_none(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with business_logic_model_id but get_model_by_model_id returns None. @@ -1415,6 +1449,7 @@ def mock_get_model(model_id): return None mock_get_model_by_model_id.side_effect = mock_get_model + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1428,7 +1463,9 @@ def mock_get_model(model_id): "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", - "business_logic_model_name": None # Should be None when model info is not found + "business_logic_model_name": None, # Should be None when model info is not found + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result @@ -1438,12 +1475,13 @@ def mock_get_model(model_id): mock_get_model_by_model_id.assert_any_call(789) +@patch('backend.services.agent_service.check_agent_availability') @patch('backend.services.agent_service.get_model_by_model_id') @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.search_tools_for_sub_agent') @patch('backend.services.agent_service.search_agent_info_by_agent_id') @pytest.mark.asyncio -async def test_get_agent_info_impl_with_business_logic_model_no_display_name(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): +async def test_get_agent_info_impl_with_business_logic_model_no_display_name(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id, mock_check_availability): """ Test get_agent_info_impl with business_logic_model_id but model has no display_name. @@ -1489,6 +1527,7 @@ def mock_get_model(model_id): return None mock_get_model_by_model_id.side_effect = mock_get_model + mock_check_availability.return_value = (True, []) # Execute result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") @@ -1502,7 +1541,9 @@ def mock_get_model(model_id): "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", - "business_logic_model_name": None # Should be None when display_name is not in model_info + "business_logic_model_name": None, # Should be None when display_name is not in model_info + "is_available": True, + "unavailable_reasons": [] } assert result == expected_result @@ -2342,390 +2383,132 @@ async def test_clear_agent_memory_clear_memory_error(mock_build_config, mock_cle assert mock_clear_memory.call_count == 2 -# Import agent tests -@patch('backend.services.agent_service.import_agent_by_agent_id') -@patch('backend.services.agent_service.update_tool_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.add_remote_mcp_server_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.get_mcp_server_by_name_and_tenant') -@patch('backend.services.agent_service.check_mcp_name_exists') -@patch('backend.services.agent_service.get_current_user_info') -@pytest.mark.asyncio -async def test_import_agent_impl_success_with_mcp(mock_get_current_user_info, mock_check_mcp_exists, - mock_get_mcp_server, - mock_add_mcp_server, mock_update_tool_list, mock_import_agent): - """ - Test successful import of agent with MCP servers. - """ - # Setup - mock_get_current_user_info.return_value = ( - "test_user", "test_tenant", "en") - - # Mock MCP server checks - mock_check_mcp_exists.return_value = False # MCP server doesn't exist - mock_get_mcp_server.return_value = "http://existing-mcp-server.com" - mock_add_mcp_server.return_value = None # Function returns None on success - mock_update_tool_list.return_value = None - - # Create MCP info - mcp_info = MCPInfo(mcp_server_name="test_mcp_server", - mcp_url="http://test-mcp-server.com") - - # Create agent info - agent_info = ExportAndImportAgentInfo( - agent_id=123, - name="Test Agent", - display_name="Test Agent Display", - description="A test agent", - business_description="For testing purposes", - max_steps=10, - provide_run_summary=True, - duty_prompt="Test duty prompt", - constraint_prompt="Test constraint prompt", - few_shots_prompt="Test few shots prompt", - enabled=True, - tools=[], - managed_agents=[] - ) - - # Create export data format - export_data = ExportAndImportDataFormat( - agent_id=123, - agent_info={"123": agent_info}, - mcp_info=[mcp_info] - ) - - # Mock import agent - mock_import_agent.return_value = 456 # New agent ID - - # Execute - await import_agent_impl(export_data, authorization="Bearer token") - - # Assert - mock_get_current_user_info.assert_called_once_with("Bearer token") - mock_check_mcp_exists.assert_called_once_with( - mcp_name="test_mcp_server", tenant_id="test_tenant") - mock_add_mcp_server.assert_called_once_with( - tenant_id="test_tenant", - user_id="test_user", - remote_mcp_server="http://test-mcp-server.com", - remote_mcp_server_name="test_mcp_server" - ) - mock_update_tool_list.assert_called_once_with( - tenant_id="test_tenant", user_id="test_user") - mock_import_agent.assert_called_once_with( - import_agent_info=agent_info, - tenant_id="test_tenant", - user_id="test_user", - skip_duplicate_regeneration=False, - ) - - +@patch('backend.services.agent_service.insert_related_agent') @patch('backend.services.agent_service.import_agent_by_agent_id') -@patch('backend.services.agent_service.update_tool_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.add_remote_mcp_server_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.get_mcp_server_by_name_and_tenant') -@patch('backend.services.agent_service.check_mcp_name_exists') @patch('backend.services.agent_service.get_current_user_info') @pytest.mark.asyncio -async def test_import_agent_impl_mcp_exists_same_url(mock_get_current_user_info, mock_check_mcp_exists, - mock_get_mcp_server, - mock_add_mcp_server, mock_update_tool_list, mock_import_agent): +async def test_import_agent_impl_imports_all_agents_and_links_relations( + mock_get_current_user_info, + mock_import_agent, + mock_insert_relationship, +): """ - Test import of agent when MCP server exists with same URL (should skip). + Import agent implementation should import sub-agents before their parents + and create the relationship between the newly created agent IDs. """ - # Setup - mock_get_current_user_info.return_value = ( - "test_user", "test_tenant", "en") - - # Mock MCP server exists with same URL - mock_check_mcp_exists.return_value = True - mock_get_mcp_server.return_value = "http://test-mcp-server.com" # Same URL - mock_update_tool_list.return_value = None - # Create MCP info - mcp_info = MCPInfo(mcp_server_name="test_mcp_server", - mcp_url="http://test-mcp-server.com") - - # Create agent info - agent_info = ExportAndImportAgentInfo( - agent_id=123, - name="Test Agent", - display_name="Test Agent Display", - description="A test agent", - business_description="For testing purposes", - max_steps=10, + mock_get_current_user_info.return_value = ("test_user", "test_tenant", "en") + # Sub-agent (ID 2) with no managed agents + sub_agent_info = ExportAndImportAgentInfo( + agent_id=2, + name="SubAgent", + display_name="Sub Agent", + description="Sub agent desc", + business_description="Business desc", + max_steps=5, provide_run_summary=True, - duty_prompt="Test duty prompt", - constraint_prompt="Test constraint prompt", - few_shots_prompt="Test few shots prompt", + duty_prompt="Sub duty", + constraint_prompt="Sub constraint", + few_shots_prompt="Sub few shots", enabled=True, tools=[], managed_agents=[] ) - # Create export data format - export_data = ExportAndImportDataFormat( - agent_id=123, - agent_info={"123": agent_info}, - mcp_info=[mcp_info] - ) - - # Mock import agent - mock_import_agent.return_value = 456 - - # Execute - await import_agent_impl(export_data, authorization="Bearer token") - - # Assert - mock_get_current_user_info.assert_called_once_with("Bearer token") - mock_check_mcp_exists.assert_called_once_with( - mcp_name="test_mcp_server", tenant_id="test_tenant") - mock_get_mcp_server.assert_called_once_with( - mcp_name="test_mcp_server", tenant_id="test_tenant") - mock_add_mcp_server.assert_not_called() # Should not add since URL is the same - mock_update_tool_list.assert_called_once_with( - tenant_id="test_tenant", user_id="test_user") - mock_import_agent.assert_called_once() - - -@patch('backend.services.agent_service.import_agent_by_agent_id') -@patch('backend.services.agent_service.update_tool_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.add_remote_mcp_server_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.get_mcp_server_by_name_and_tenant') -@patch('backend.services.agent_service.check_mcp_name_exists') -@patch('backend.services.agent_service.get_current_user_info') -@pytest.mark.asyncio -async def test_import_agent_impl_mcp_exists_different_url(mock_get_current_user_info, mock_check_mcp_exists, - mock_get_mcp_server, - mock_add_mcp_server, mock_update_tool_list, mock_import_agent): - """ - Test import of agent when MCP server exists with different URL (should add with import prefix). - """ - # Setup - mock_get_current_user_info.return_value = ( - "test_user", "test_tenant", "en") - - # Mock MCP server exists with different URL - mock_check_mcp_exists.return_value = True - mock_get_mcp_server.return_value = "http://different-mcp-server.com" # Different URL - mock_add_mcp_server.return_value = None # Function returns None on success - mock_update_tool_list.return_value = None - - # Create MCP info - mcp_info = MCPInfo(mcp_server_name="test_mcp_server", - mcp_url="http://test-mcp-server.com") - - # Create agent info - agent_info = ExportAndImportAgentInfo( - agent_id=123, - name="Test Agent", - display_name="Test Agent Display", - description="A test agent", - business_description="For testing purposes", + # Main agent references sub agent id 2 + main_agent_info = ExportAndImportAgentInfo( + agent_id=1, + name="MainAgent", + display_name="Main Agent", + description="Main desc", + business_description="Business main", max_steps=10, provide_run_summary=True, - duty_prompt="Test duty prompt", - constraint_prompt="Test constraint prompt", - few_shots_prompt="Test few shots prompt", + duty_prompt="Main duty", + constraint_prompt="Main constraint", + few_shots_prompt="Main few shots", enabled=True, tools=[], - managed_agents=[] + managed_agents=[2] ) - # Create export data format export_data = ExportAndImportDataFormat( - agent_id=123, - agent_info={"123": agent_info}, - mcp_info=[mcp_info] + agent_id=1, + agent_info={ + "1": main_agent_info, + "2": sub_agent_info, + }, + mcp_info=[ + MCPInfo(mcp_server_name="test_mcp_server", + mcp_url="http://test-mcp-server.com") + ], ) - # Mock import agent - mock_import_agent.return_value = 456 + # The order of returns matches the import order: sub-agent first, then main agent + mock_import_agent.side_effect = [101, 202] - # Execute await import_agent_impl(export_data, authorization="Bearer token") - # Assert - mock_get_current_user_info.assert_called_once_with("Bearer token") - mock_check_mcp_exists.assert_called_once_with( - mcp_name="test_mcp_server", tenant_id="test_tenant") - mock_get_mcp_server.assert_called_once_with( - mcp_name="test_mcp_server", tenant_id="test_tenant") - # Should add with import prefix - mock_add_mcp_server.assert_called_once_with( - tenant_id="test_tenant", - user_id="test_user", - remote_mcp_server="http://test-mcp-server.com", - remote_mcp_server_name="import_test_mcp_server" - ) - mock_update_tool_list.assert_called_once_with( - tenant_id="test_tenant", user_id="test_user") - mock_import_agent.assert_called_once() - - -@patch('backend.services.agent_service.add_remote_mcp_server_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.get_mcp_server_by_name_and_tenant') -@patch('backend.services.agent_service.check_mcp_name_exists') -@patch('backend.services.agent_service.get_current_user_info') -@pytest.mark.asyncio -async def test_import_agent_impl_mcp_add_failure(mock_get_current_user_info, mock_check_mcp_exists, mock_get_mcp_server, - mock_add_mcp_server): - """ - Test import of agent when MCP server addition fails. - """ - # Setup - mock_get_current_user_info.return_value = ( - "test_user", "test_tenant", "en") - - # Mock MCP server checks - mock_check_mcp_exists.return_value = False # MCP server doesn't exist - mock_get_mcp_server.return_value = "http://existing-mcp-server.com" - - # Mock MCP server addition failure - the function raises an exception - mock_add_mcp_server.side_effect = Exception("MCP server connection failed") - - # Create MCP info - mcp_info = MCPInfo(mcp_server_name="test_mcp_server", - mcp_url="http://test-mcp-server.com") - - # Create agent info - agent_info = ExportAndImportAgentInfo( - agent_id=123, - name="Test Agent", - display_name="Test Agent Display", - description="A test agent", - business_description="For testing purposes", - max_steps=10, - provide_run_summary=True, - duty_prompt="Test duty prompt", - constraint_prompt="Test constraint prompt", - few_shots_prompt="Test few shots prompt", - enabled=True, - tools=[], - managed_agents=[] - ) + # Sub-agent should be imported before main agent + assert mock_import_agent.call_count == 2 + first_call = mock_import_agent.call_args_list[0] + second_call = mock_import_agent.call_args_list[1] - # Create export data format - export_data = ExportAndImportDataFormat( - agent_id=123, - agent_info={"123": agent_info}, - mcp_info=[mcp_info] - ) + assert first_call.kwargs["import_agent_info"] is sub_agent_info + assert first_call.kwargs["skip_duplicate_regeneration"] is False - # Execute & Assert - with pytest.raises(Exception) as context: - await import_agent_impl(export_data, authorization="Bearer token") + assert second_call.kwargs["import_agent_info"] is main_agent_info + assert second_call.kwargs["skip_duplicate_regeneration"] is False - assert "Failed to add MCP server test_mcp_server" in str(context.value) - mock_add_mcp_server.assert_called_once_with( + # Relationship should link newly created ids (main -> sub) + mock_insert_relationship.assert_called_once_with( + parent_agent_id=202, + child_agent_id=101, tenant_id="test_tenant", - user_id="test_user", - remote_mcp_server="http://test-mcp-server.com", - remote_mcp_server_name="test_mcp_server" - ) - - -@patch('backend.services.agent_service.update_tool_list', new_callable=AsyncMock) -@patch('backend.services.agent_service.get_current_user_info') -@pytest.mark.asyncio -async def test_import_agent_impl_update_tool_list_failure(mock_get_current_user_info, mock_update_tool_list): - """ - Test import of agent when tool list update fails. - """ - # Setup - mock_get_current_user_info.return_value = ( - "test_user", "test_tenant", "en") - - # Mock tool list update failure - mock_update_tool_list.side_effect = Exception("Tool list update failed") - - # Create agent info - agent_info = ExportAndImportAgentInfo( - agent_id=123, - name="Test Agent", - display_name="Test Agent Display", - description="A test agent", - business_description="For testing purposes", - max_steps=10, - provide_run_summary=True, - duty_prompt="Test duty prompt", - constraint_prompt="Test constraint prompt", - few_shots_prompt="Test few shots prompt", - enabled=True, - tools=[], - managed_agents=[] - ) - - # Create export data format - export_data = ExportAndImportDataFormat( - agent_id=123, - agent_info={"123": agent_info}, - mcp_info=[] ) - # Execute & Assert - with pytest.raises(Exception) as context: - await import_agent_impl(export_data, authorization="Bearer token") - - assert "Failed to update tool list" in str(context.value) - mock_update_tool_list.assert_called_once_with( - tenant_id="test_tenant", user_id="test_user") - @patch('backend.services.agent_service.import_agent_by_agent_id') -@patch('backend.services.agent_service.update_tool_list', new_callable=AsyncMock) @patch('backend.services.agent_service.get_current_user_info') @pytest.mark.asyncio -async def test_import_agent_impl_no_mcp_info(mock_get_current_user_info, mock_update_tool_list, - mock_import_agent): +async def test_import_agent_impl_force_import_passes_skip_flag( + mock_get_current_user_info, + mock_import_agent, +): """ - Test import of agent without MCP info. + When force_import=True, skip_duplicate_regeneration should be True. """ - # Setup - mock_get_current_user_info.return_value = ( - "test_user", "test_tenant", "en") - mock_update_tool_list.return_value = None + mock_get_current_user_info.return_value = ("test_user", "test_tenant", "en") - # Create agent info agent_info = ExportAndImportAgentInfo( - agent_id=123, - name="Test Agent", - display_name="Test Agent Display", - description="A test agent", - business_description="For testing purposes", - max_steps=10, + agent_id=1, + name="Agent", + display_name="Agent Display", + description="desc", + business_description="biz", + max_steps=5, provide_run_summary=True, - duty_prompt="Test duty prompt", - constraint_prompt="Test constraint prompt", - few_shots_prompt="Test few shots prompt", + duty_prompt="duty", + constraint_prompt="constraint", + few_shots_prompt="few shots", enabled=True, tools=[], managed_agents=[] ) - # Create export data format without MCP info export_data = ExportAndImportDataFormat( - agent_id=123, - agent_info={"123": agent_info}, + agent_id=1, + agent_info={"1": agent_info}, mcp_info=[] ) - # Mock import agent - mock_import_agent.return_value = 456 - - # Execute - await import_agent_impl(export_data, authorization="Bearer token") + await import_agent_impl(export_data, authorization="Bearer token", force_import=True) - # Assert mock_get_current_user_info.assert_called_once_with("Bearer token") - mock_update_tool_list.assert_called_once_with( - tenant_id="test_tenant", user_id="test_user") - mock_import_agent.assert_called_once_with( - import_agent_info=agent_info, - tenant_id="test_tenant", - user_id="test_user", - skip_duplicate_regeneration=False, - ) + mock_import_agent.assert_called_once() + call_kwargs = mock_import_agent.call_args.kwargs + assert call_kwargs["import_agent_info"] is agent_info + assert call_kwargs["skip_duplicate_regeneration"] is True if __name__ == '__main__': @@ -6192,6 +5975,260 @@ def test_check_single_model_availability_returns_empty_for_available_model(): assert reasons == [] +# ============================================================================ +# Tests for check_agent_availability function +# ============================================================================ + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.check_tool_is_available') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +def test_check_agent_availability_all_available( + mock_search_agent_info, + mock_search_tools, + mock_check_tool, + mock_collect_model_reasons +): + """Test check_agent_availability when all tools and models are available.""" + from backend.services.agent_service import check_agent_availability + + mock_agent_info = {"agent_id": 123, "model_id": 456} + mock_search_agent_info.return_value = mock_agent_info + mock_search_tools.return_value = [{"tool_id": 1}, {"tool_id": 2}] + mock_check_tool.return_value = [True, True] + mock_collect_model_reasons.return_value = [] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant" + ) + + assert is_available is True + assert reasons == [] + mock_search_agent_info.assert_called_once_with(123, "test_tenant") + mock_search_tools.assert_called_once_with(agent_id=123, tenant_id="test_tenant") + mock_check_tool.assert_called_once_with([1, 2]) + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.check_tool_is_available') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +def test_check_agent_availability_tool_unavailable( + mock_search_agent_info, + mock_search_tools, + mock_check_tool, + mock_collect_model_reasons +): + """Test check_agent_availability when some tools are unavailable.""" + from backend.services.agent_service import check_agent_availability + + mock_agent_info = {"agent_id": 123, "model_id": 456} + mock_search_agent_info.return_value = mock_agent_info + mock_search_tools.return_value = [{"tool_id": 1}, {"tool_id": 2}] + mock_check_tool.return_value = [True, False] # One tool unavailable + mock_collect_model_reasons.return_value = [] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant" + ) + + assert is_available is False + assert reasons == ["tool_unavailable"] + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.check_tool_is_available') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +def test_check_agent_availability_model_unavailable( + mock_search_agent_info, + mock_search_tools, + mock_check_tool, + mock_collect_model_reasons +): + """Test check_agent_availability when model is unavailable.""" + from backend.services.agent_service import check_agent_availability + + mock_agent_info = {"agent_id": 123, "model_id": 456} + mock_search_agent_info.return_value = mock_agent_info + mock_search_tools.return_value = [{"tool_id": 1}] + mock_check_tool.return_value = [True] + mock_collect_model_reasons.return_value = ["model_unavailable"] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant" + ) + + assert is_available is False + assert reasons == ["model_unavailable"] + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.check_tool_is_available') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +def test_check_agent_availability_both_unavailable( + mock_search_agent_info, + mock_search_tools, + mock_check_tool, + mock_collect_model_reasons +): + """Test check_agent_availability when both tools and model are unavailable.""" + from backend.services.agent_service import check_agent_availability + + mock_agent_info = {"agent_id": 123, "model_id": 456} + mock_search_agent_info.return_value = mock_agent_info + mock_search_tools.return_value = [{"tool_id": 1}] + mock_check_tool.return_value = [False] + mock_collect_model_reasons.return_value = ["model_unavailable"] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant" + ) + + assert is_available is False + assert "tool_unavailable" in reasons + assert "model_unavailable" in reasons + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +def test_check_agent_availability_no_tools( + mock_search_agent_info, + mock_search_tools, + mock_collect_model_reasons +): + """Test check_agent_availability when agent has no tools.""" + from backend.services.agent_service import check_agent_availability + + mock_agent_info = {"agent_id": 123, "model_id": 456} + mock_search_agent_info.return_value = mock_agent_info + mock_search_tools.return_value = [] # No tools + mock_collect_model_reasons.return_value = [] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant" + ) + + assert is_available is True + assert reasons == [] + + +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +def test_check_agent_availability_agent_not_found(mock_search_agent_info): + """Test check_agent_availability when agent is not found.""" + from backend.services.agent_service import check_agent_availability + + mock_search_agent_info.return_value = None + + is_available, reasons = check_agent_availability( + agent_id=999, + tenant_id="test_tenant" + ) + + assert is_available is False + assert reasons == ["agent_not_found"] + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.check_tool_is_available') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +def test_check_agent_availability_with_pre_fetched_agent_info( + mock_search_tools, + mock_check_tool, + mock_collect_model_reasons +): + """Test check_agent_availability with pre-fetched agent_info (avoids duplicate DB query).""" + from backend.services.agent_service import check_agent_availability + + pre_fetched_agent_info = {"agent_id": 123, "model_id": 456} + mock_search_tools.return_value = [{"tool_id": 1}] + mock_check_tool.return_value = [True] + mock_collect_model_reasons.return_value = [] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant", + agent_info=pre_fetched_agent_info + ) + + assert is_available is True + assert reasons == [] + # search_agent_info_by_agent_id should NOT be called since agent_info was provided + mock_search_tools.assert_called_once_with(agent_id=123, tenant_id="test_tenant") + + +@patch('backend.services.agent_service._collect_model_availability_reasons') +@patch('backend.services.agent_service.check_tool_is_available') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +def test_check_agent_availability_with_model_cache( + mock_search_tools, + mock_check_tool, + mock_collect_model_reasons +): + """Test check_agent_availability with pre-populated model cache.""" + from backend.services.agent_service import check_agent_availability + + pre_fetched_agent_info = {"agent_id": 123, "model_id": 456} + model_cache = {456: {"connect_status": "available"}} + mock_search_tools.return_value = [{"tool_id": 1}] + mock_check_tool.return_value = [True] + mock_collect_model_reasons.return_value = [] + + is_available, reasons = check_agent_availability( + agent_id=123, + tenant_id="test_tenant", + agent_info=pre_fetched_agent_info, + model_cache=model_cache + ) + + assert is_available is True + assert reasons == [] + # Verify model_cache was passed to _collect_model_availability_reasons + mock_collect_model_reasons.assert_called_once() + call_args = mock_collect_model_reasons.call_args + assert call_args.kwargs.get("model_cache") == model_cache or call_args[1].get("model_cache") == model_cache + + +@pytest.mark.asyncio +@patch('backend.services.agent_service.check_agent_availability') +@patch('backend.services.agent_service.get_model_by_model_id') +@patch('backend.services.agent_service.query_sub_agents_id_list') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +async def test_get_agent_info_impl_with_unavailable_agent( + mock_search_agent_info, + mock_search_tools, + mock_query_sub_agents_id, + mock_get_model_by_model_id, + mock_check_availability +): + """Test get_agent_info_impl returns is_available=False when agent is unavailable.""" + mock_agent_info = { + "agent_id": 123, + "model_id": 456, + "business_description": "Test agent" + } + mock_search_agent_info.return_value = mock_agent_info + mock_search_tools.return_value = [{"tool_id": 1}] + mock_query_sub_agents_id.return_value = [] + mock_get_model_by_model_id.return_value = {"display_name": "GPT-4"} + # Agent is unavailable due to tool issues + mock_check_availability.return_value = (False, ["tool_unavailable"]) + + result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") + + assert result["is_available"] is False + assert result["unavailable_reasons"] == ["tool_unavailable"] + + @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') diff --git a/test/backend/services/test_conversation_management_service.py b/test/backend/services/test_conversation_management_service.py index 822ed3d87..feeb68d0e 100644 --- a/test/backend/services/test_conversation_management_service.py +++ b/test/backend/services/test_conversation_management_service.py @@ -348,7 +348,7 @@ def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_templat mock_llm_instance = mock_openai.return_value mock_response = MagicMock() mock_response.content = "AI Discussion" - mock_llm_instance.return_value = mock_response + mock_llm_instance.generate.return_value = mock_response # Execute result = call_llm_for_title( @@ -357,7 +357,7 @@ def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_templat # Assert self.assertEqual(result, "AI Discussion") mock_openai.assert_called_once() - mock_llm_instance.assert_called_once() + mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') @patch('backend.services.conversation_management_service.OpenAIServerModel') @@ -380,7 +380,7 @@ def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_g mock_get_prompt_template.return_value = mock_prompt_template mock_llm_instance = mock_openai.return_value - mock_llm_instance.return_value = None + mock_llm_instance.generate.return_value = None # Execute result = call_llm_for_title( @@ -389,6 +389,7 @@ def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_g # Assert self.assertEqual(result, "新对话") mock_openai.assert_called_once() + mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') @patch('backend.services.conversation_management_service.OpenAIServerModel') @@ -411,7 +412,7 @@ def test_call_llm_for_title_response_none_en(self, mock_get_model_config, mock_g mock_get_prompt_template.return_value = mock_prompt_template mock_llm_instance = mock_openai.return_value - mock_llm_instance.return_value = None + mock_llm_instance.generate.return_value = None # Execute result = call_llm_for_title( @@ -420,6 +421,7 @@ def test_call_llm_for_title_response_none_en(self, mock_get_model_config, mock_g # Assert self.assertEqual(result, "New Conversation") mock_openai.assert_called_once() + mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='en') @patch('backend.services.conversation_management_service.rename_conversation') diff --git a/test/backend/services/test_me_model_management_service.py b/test/backend/services/test_me_model_management_service.py index 3abb24485..01676e57a 100644 --- a/test/backend/services/test_me_model_management_service.py +++ b/test/backend/services/test_me_model_management_service.py @@ -1,37 +1,17 @@ import backend.services.me_model_management_service as svc -from consts.exceptions import TimeoutException +from consts.exceptions import MEConnectionException, TimeoutException import sys import os +import asyncio -import aiohttp import pytest -import asyncio -from unittest.mock import patch, MagicMock, AsyncMock +from unittest.mock import patch, AsyncMock, MagicMock # Add the project root directory to sys.path sys.path.insert(0, os.path.abspath( os.path.join(os.path.dirname(__file__), '../../..'))) - -# Sample test data -sample_models_data = { - "data": [ - {"name": "model1", "type": "embed", "version": "1.0"}, - {"name": "model2", "type": "chat", "version": "1.0"}, - {"name": "model3", "type": "rerank", "version": "1.0"}, - {"name": "model4", "type": "embed", "version": "2.0"} - ] -} - -sample_models_list = [ - {"name": "model1", "type": "embed", "version": "1.0"}, - {"name": "model2", "type": "chat", "version": "1.0"}, - {"name": "model3", "type": "rerank", "version": "1.0"}, - {"name": "model4", "type": "embed", "version": "2.0"} -] - - @pytest.mark.asyncio async def test_check_me_variable_set_truthy_when_both_present(): # Patch service module constants to have non-empty values @@ -55,60 +35,15 @@ async def test_check_me_variable_set_falsy_when_host_missing(): @pytest.mark.asyncio -async def test_get_me_models_impl_success_no_filter(): - """Test successful model list retrieval without type filter""" - # Patch service module constants - with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): - - # Create mock response - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=sample_models_data) - mock_response.raise_for_status = MagicMock() - - # Create mock session - mock_session = AsyncMock() - mock_get = AsyncMock() - mock_get.__aenter__.return_value = mock_response - mock_session.get = MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function - data = await svc.get_me_models_impl(timeout=30, type="") - - # Assertions - assert data == sample_models_list - assert len(data) == 4 - - # Verify correct URL and headers were used - mock_session.get.assert_called_once() - called_url = mock_session.get.call_args[0][0] - assert called_url == "http://mock-model-engine-host/open/router/v1/models" - - called_headers = mock_session.get.call_args[1]['headers'] - assert called_headers['Authorization'] == 'Bearer mock-api-key' - - -@pytest.mark.asyncio -async def test_get_me_models_impl_success_with_filter(): - """Test successful model list retrieval with type filter""" - # Patch service module constants +async def test_check_me_connectivity_success(): + """Test successful ME connectivity check""" with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): + patch.object(svc, 'MODEL_ENGINE_HOST', 'https://me-host.com'), \ + patch('backend.services.me_model_management_service.aiohttp.ClientSession') as mock_session_class: # Create mock response mock_response = AsyncMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value=sample_models_data) - mock_response.raise_for_status = MagicMock() # Create mock session mock_session = AsyncMock() @@ -119,38 +54,25 @@ async def test_get_me_models_impl_success_with_filter(): # Create mock session factory mock_client_session = AsyncMock() mock_client_session.__aenter__.return_value = mock_session + mock_session_class.return_value = mock_client_session - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session + # Execute + result = await svc.check_me_connectivity(timeout=30) - # Test the function with embed type filter - data = await svc.get_me_models_impl(timeout=30, type="embed") - - # Assertions - expected_embed_models = [ - {"name": "model1", "type": "embed", "version": "1.0"}, - {"name": "model4", "type": "embed", "version": "2.0"} - ] - assert data == expected_embed_models - assert len(data) == 2 - - # Verify correct URL and headers were used - mock_session.get.assert_called_once() + # Assert + assert result is True @pytest.mark.asyncio -async def test_get_me_models_impl_filter_not_found(): - """Test model list retrieval with non-existent type filter""" - # Patch service module constants +async def test_check_me_connectivity_http_error(): + """Test ME connectivity check with HTTP error response""" with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): + patch.object(svc, 'MODEL_ENGINE_HOST', 'https://me-host.com'), \ + patch('backend.services.me_model_management_service.aiohttp.ClientSession') as mock_session_class: - # Create mock response + # Create mock response with error status mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=sample_models_data) - mock_response.raise_for_status = MagicMock() + mock_response.status = 500 # Create mock session mock_session = AsyncMock() @@ -161,26 +83,22 @@ async def test_get_me_models_impl_filter_not_found(): # Create mock session factory mock_client_session = AsyncMock() mock_client_session.__aenter__.return_value = mock_session + mock_session_class.return_value = mock_client_session - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session + # Execute and expect an exception + with pytest.raises(MEConnectionException) as exc_info: + await svc.check_me_connectivity(timeout=30) - # Test the function with non-existent type filter - should raise Exception (not NotFoundException) - with pytest.raises(Exception) as exc_info: - await svc.get_me_models_impl(timeout=30, type="nonexistent") - - # Verify the exception message contains expected information - assert "Failed to get model list: No models found with type 'nonexistent'" in str( - exc_info.value) + # Assert the exception message + assert "Connection failed, error code: 500" in str(exc_info.value) @pytest.mark.asyncio -async def test_get_me_models_impl_timeout(): - """Test model list retrieval with timeout""" - # Patch service module constants +async def test_check_me_connectivity_timeout(): + """Test ME connectivity check with timeout error""" with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): + patch.object(svc, 'MODEL_ENGINE_HOST', 'https://me-host.com'), \ + patch('backend.services.me_model_management_service.aiohttp.ClientSession') as mock_session_class: # Create mock session that raises TimeoutError mock_session = AsyncMock() @@ -191,217 +109,51 @@ async def test_get_me_models_impl_timeout(): # Create mock session factory mock_client_session = AsyncMock() mock_client_session.__aenter__.return_value = mock_session + mock_session_class.return_value = mock_client_session - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function - should raise TimeoutException - with pytest.raises(TimeoutException) as exc_info: - await svc.get_me_models_impl(timeout=30, type="") - - # Verify the exception message - assert "Request timeout." in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_get_me_models_impl_http_error(): - """Test model list retrieval with HTTP error""" - # Patch service module constants - with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): - - # Create mock response that raises HTTP error - mock_response = AsyncMock() - mock_response.raise_for_status = MagicMock(side_effect=aiohttp.ClientResponseError( - request_info=MagicMock(), - history=(), - status=404, - message="Not Found" - )) - - # Create mock session - mock_session = AsyncMock() - mock_get = AsyncMock() - mock_get.__aenter__.return_value = mock_response - mock_session.get = MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function - should raise Exception - with pytest.raises(Exception) as exc_info: - await svc.get_me_models_impl(timeout=30, type="") - - # Verify the exception message contains expected information - assert "Failed to get model list:" in str(exc_info.value) - assert "404" in str(exc_info.value) - assert "Not Found" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_get_me_models_impl_json_parse_error(): - """Test model list retrieval when JSON parsing fails""" - # Patch service module constants - with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): - - # Create mock response - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(side_effect=Exception("Invalid JSON")) - mock_response.raise_for_status = MagicMock() - - # Create mock session - mock_session = AsyncMock() - mock_get = AsyncMock() - mock_get.__aenter__.return_value = mock_response - mock_session.get = MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function - should raise Exception - with pytest.raises(Exception) as exc_info: - await svc.get_me_models_impl(timeout=30, type="") + # Execute and expect a TimeoutException + with pytest.raises(TimeoutException) as exc_info: + await svc.check_me_connectivity(timeout=30) - # Verify the exception message contains expected information - assert "Failed to get model list: Invalid JSON." in str( - exc_info.value) + # Assert the exception message + assert "Connection timed out" in str(exc_info.value) @pytest.mark.asyncio -async def test_get_me_models_impl_connection_exception(): - """Test model list retrieval when connection exception occurs""" - # Patch service module constants - with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): - - # Create mock session that raises exception - mock_session = AsyncMock() - mock_get = AsyncMock() - mock_get.__aenter__.side_effect = Exception("Connection error") - mock_session.get = MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function - should raise Exception - with pytest.raises(Exception) as exc_info: - await svc.get_me_models_impl(timeout=30, type="") - - # Verify the exception message contains expected information - assert "Failed to get model list: Connection error." in str( - exc_info.value) +async def test_check_me_connectivity_variables_not_set(): + """Test ME connectivity check when environment variables not set""" + with patch.object(svc, 'MODEL_ENGINE_APIKEY', ''), \ + patch.object(svc, 'MODEL_ENGINE_HOST', ''): + # Execute - should return False when env vars not set + result = await svc.check_me_connectivity(timeout=30) -@pytest.mark.asyncio -async def test_get_me_models_impl_different_types(): - """Test model list retrieval with different type filters""" - # Patch service module constants and import the function - with patch('backend.services.me_model_management_service.MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch('backend.services.me_model_management_service.MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): - - # Import the function after mocking - from backend.services.me_model_management_service import get_me_models_impl - - test_cases = [ - ("chat", [{"name": "model2", "type": "chat", "version": "1.0"}]), - ("rerank", [{"name": "model3", "type": "rerank", "version": "1.0"}]), - ("embed", [ - {"name": "model1", "type": "embed", "version": "1.0"}, - {"name": "model4", "type": "embed", "version": "2.0"} - ]) - ] - - for filter_type, expected_models in test_cases: - # Create mock response - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=sample_models_data) - mock_response.raise_for_status = MagicMock() - - # Create mock session - mock_session = AsyncMock() - mock_get = AsyncMock() - mock_get.__aenter__.return_value = mock_response - mock_session.get = MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - - # Patch the ClientSession - with patch('backend.services.me_model_management_service.aiohttp.ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function - data = await svc.get_me_models_impl(timeout=30, type=filter_type) - - # Assertions - assert data == expected_models - assert len(data) == len(expected_models) - - # Verify correct URL was called - mock_session.get.assert_called_once() + # Assert + assert result is False @pytest.mark.asyncio -async def test_get_me_models_impl_empty_response(): - """Test model list retrieval with empty response""" - # Patch service module constants +async def test_check_me_connectivity_general_exception(): + """Test ME connectivity check with general exception (covers lines 54-55)""" with patch.object(svc, 'MODEL_ENGINE_APIKEY', 'mock-api-key'), \ - patch.object(svc, 'MODEL_ENGINE_HOST', 'http://mock-model-engine-host'): - - empty_models_data = {"data": []} + patch.object(svc, 'MODEL_ENGINE_HOST', 'https://me-host.com'), \ + patch('backend.services.me_model_management_service.aiohttp.ClientSession') as mock_session_class: - # Create mock response - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=empty_models_data) - mock_response.raise_for_status = MagicMock() - - # Create mock session + # Create mock session that raises a general exception mock_session = AsyncMock() mock_get = AsyncMock() - mock_get.__aenter__.return_value = mock_response + mock_get.__aenter__.side_effect = ValueError("Unexpected error") mock_session.get = MagicMock(return_value=mock_get) # Create mock session factory mock_client_session = AsyncMock() mock_client_session.__aenter__.return_value = mock_session + mock_session_class.return_value = mock_client_session - # Patch the ClientSession - with patch.object(svc.aiohttp, 'ClientSession') as mock_session_class: - mock_session_class.return_value = mock_client_session - - # Test the function without filter - should return empty list - data = await svc.get_me_models_impl(timeout=30, type="") - - # Assertions for no filter - assert data == [] - assert len(data) == 0 - - # Test the function with filter on empty data - should raise Exception - with pytest.raises(Exception) as exc_info: - await svc.get_me_models_impl(timeout=30, type="embed") + # Execute and expect a generic Exception + with pytest.raises(Exception) as exc_info: + await svc.check_me_connectivity(timeout=30) - # Verify the exception message contains expected information - assert "Failed to get model list: No models found with type 'embed'" in str( - exc_info.value) - assert "Available types: set()" in str(exc_info.value) + # Assert the exception message contains "Unknown error occurred" + assert "Unknown error occurred" in str(exc_info.value) + assert "Unexpected error" in str(exc_info.value) diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py index 5c56d4592..15efe79fa 100644 --- a/test/backend/services/test_model_health_service.py +++ b/test/backend/services/test_model_health_service.py @@ -203,7 +203,8 @@ async def test_perform_connectivity_check_llm(): mock_observer_instance, model_id="gpt-4", api_base="https://api.openai.com", - api_key="test-key" + api_key="test-key", + ssl_verify=True ) mock_model_instance.check_connectivity.assert_called_once() @@ -330,7 +331,8 @@ async def test_perform_connectivity_check_base_url_normalization_localhost(): mock_observer_instance, model_id="gpt-4", api_base="http://host.docker.internal:8080", - api_key="test-key" + api_key="test-key", + ssl_verify=True ) @@ -362,7 +364,8 @@ async def test_perform_connectivity_check_base_url_normalization_127001(): mock_observer_instance, model_id="gpt-4", api_base="http://host.docker.internal:8000", - api_key="test-key" + api_key="test-key", + ssl_verify=True ) @pytest.mark.asyncio @@ -414,7 +417,7 @@ async def test_check_model_connectivity_success(): mock_update_model.assert_any_call( "model123", {"connect_status": "available"}) mock_connectivity_check.assert_called_once_with( - "openai/gpt-4", "llm", "https://api.openai.com", "test-key" + "openai/gpt-4", "llm", "https://api.openai.com", "test-key", True ) @@ -539,7 +542,7 @@ async def test_verify_model_config_connectivity_success(): assert "error" not in response mock_connectivity_check.assert_called_once_with( - "gpt-4", "llm", "https://api.openai.com", "test-key" + "gpt-4", "llm", "https://api.openai.com", "test-key", True ) @@ -785,146 +788,3 @@ async def test_embedding_dimension_check_wrapper_value_error(): ) -@pytest.mark.asyncio -async def test_check_me_connectivity_impl_success(): - """Test successful ME connectivity check""" - # Setup - with mock.patch("backend.services.model_health_service.MODEL_ENGINE_APIKEY", "me-api-key"), \ - mock.patch("backend.services.model_health_service.MODEL_ENGINE_HOST", "https://me-host.com"), \ - mock.patch("backend.services.model_health_service.ModelConnectStatusEnum") as mock_enum, \ - mock.patch("backend.services.model_health_service.aiohttp.ClientSession") as mock_session_class: - - mock_enum.AVAILABLE.value = "available" - mock_enum.UNAVAILABLE.value = "unavailable" - - # Create mock response - mock_response = mock.AsyncMock() - mock_response.status = 200 - - # Create mock session - mock_session = mock.AsyncMock() - mock_get = mock.AsyncMock() - mock_get.__aenter__.return_value = mock_response - mock_session.get = mock.MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = mock.AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - mock_session_class.return_value = mock_client_session - - # Import the function after mocking - from backend.services.model_health_service import check_me_connectivity_impl - - # Execute - the function should return None on success or raise an exception - result = await check_me_connectivity_impl(timeout=30) - - # Assert - should return None on success - assert result is None - - -@pytest.mark.asyncio -async def test_check_me_connectivity_impl_http_error(): - """Test ME connectivity check with HTTP error response""" - # Setup - with mock.patch("backend.services.model_health_service.MODEL_ENGINE_APIKEY", "me-api-key"), \ - mock.patch("backend.services.model_health_service.MODEL_ENGINE_HOST", "https://me-host.com"), \ - mock.patch("backend.services.model_health_service.ModelConnectStatusEnum") as mock_enum, \ - mock.patch("backend.services.model_health_service.aiohttp.ClientSession") as mock_session_class: - - mock_enum.AVAILABLE.value = "available" - mock_enum.UNAVAILABLE.value = "unavailable" - - # Create mock response with error status - mock_response = mock.AsyncMock() - mock_response.status = 500 - - # Create mock session - mock_session = mock.AsyncMock() - mock_get = mock.AsyncMock() - mock_get.__aenter__.return_value = mock_response - mock_session.get = mock.MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = mock.AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - mock_session_class.return_value = mock_client_session - - # Import the function after mocking - from backend.services.model_health_service import check_me_connectivity_impl - - # Execute and expect an exception - with pytest.raises(Exception) as exc_info: - await check_me_connectivity_impl(timeout=30) - - # Assert the exception message - assert "Unknown error occurred: Connection failed, error code: 500" in str( - exc_info.value) - - -@pytest.mark.asyncio -async def test_check_me_connectivity_impl_timeout(): - """Test ME connectivity check with timeout error""" - # Setup - with mock.patch("backend.services.model_health_service.MODEL_ENGINE_APIKEY", "me-api-key"), \ - mock.patch("backend.services.model_health_service.MODEL_ENGINE_HOST", "https://me-host.com"), \ - mock.patch("backend.services.model_health_service.ModelConnectStatusEnum") as mock_enum, \ - mock.patch("backend.services.model_health_service.aiohttp.ClientSession") as mock_session_class: - - mock_enum.AVAILABLE.value = "available" - mock_enum.UNAVAILABLE.value = "unavailable" - - # Create mock session that raises TimeoutError - mock_session = mock.AsyncMock() - mock_get = mock.AsyncMock() - mock_get.__aenter__.side_effect = asyncio.TimeoutError() - mock_session.get = mock.MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = mock.AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - mock_session_class.return_value = mock_client_session - - # Import the function after mocking - from backend.services.model_health_service import check_me_connectivity_impl - - # Execute and expect a TimeoutException - with pytest.raises(TimeoutException) as exc_info: - await check_me_connectivity_impl(timeout=30) - - # Assert the exception message - assert "Connection timed out" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_check_me_connectivity_impl_exception(): - """Test ME connectivity check with general exception""" - # Setup - with mock.patch("backend.services.model_health_service.MODEL_ENGINE_APIKEY", "me-api-key"), \ - mock.patch("backend.services.model_health_service.MODEL_ENGINE_HOST", "https://me-host.com"), \ - mock.patch("backend.services.model_health_service.ModelConnectStatusEnum") as mock_enum, \ - mock.patch("backend.services.model_health_service.aiohttp.ClientSession") as mock_session_class: - - mock_enum.AVAILABLE.value = "available" - mock_enum.UNAVAILABLE.value = "unavailable" - - # Create mock session that raises general exception - mock_session = mock.AsyncMock() - mock_get = mock.AsyncMock() - mock_get.__aenter__.side_effect = Exception("Connection error") - mock_session.get = mock.MagicMock(return_value=mock_get) - - # Create mock session factory - mock_client_session = mock.AsyncMock() - mock_client_session.__aenter__.return_value = mock_session - mock_session_class.return_value = mock_client_session - - # Import the function after mocking - from backend.services.model_health_service import check_me_connectivity_impl - - # Execute and expect an Exception - with pytest.raises(Exception) as exc_info: - await check_me_connectivity_impl(timeout=30) - - # Assert the exception message - assert "Unknown error occurred: Connection error" in str( - exc_info.value) diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 8d5cdcd4d..86e1cac73 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -119,6 +119,7 @@ class _Func: class _ProviderEnum: SILICON = _EnumItem("silicon") + MODELENGINE = _EnumItem("modelengine") consts_provider_mod.ProviderEnum = _ProviderEnum @@ -200,9 +201,15 @@ def _get_models_by_tenant_factory_type(*args, **kwargs): return [] +def _get_models_by_display_name(*args, **kwargs): + """Return an empty list for display name lookups in tests.""" + return [] + + db_mm_mod.create_model_record = _noop db_mm_mod.delete_model_record = _noop db_mm_mod.get_model_by_display_name = _noop +db_mm_mod.get_models_by_display_name = _get_models_by_display_name db_mm_mod.get_model_records = _get_model_records db_mm_mod.get_models_by_tenant_factory_type = _get_models_by_tenant_factory_type @@ -328,6 +335,29 @@ async def test_create_model_for_tenant_conflict_raises(): assert "Failed to create model" in str(exc.value) +@pytest.mark.asyncio +async def test_create_model_for_tenant_display_name_conflict_valueerror(): + """Test that display_name conflict raises ValueError (covers lines 65-72)""" + svc = import_svc() + + existing_model = {"model_id": 1, "display_name": "existing_name"} + with mock.patch.object(svc, "get_model_by_display_name", return_value=existing_model): + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "huggingface/llama", + "display_name": "existing_name", # Conflicts with existing + "base_url": "http://localhost:8000", + "model_type": "llm", + } + + # ValueError is wrapped in Exception, but the error message should contain the original ValueError message + with pytest.raises(Exception) as exc: + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + assert "already in use" in str(exc.value) + assert "existing_name" in str(exc.value) + + @pytest.mark.asyncio async def test_create_model_for_tenant_multi_embedding_creates_two_records(): svc = import_svc() @@ -404,6 +434,40 @@ async def test_create_provider_models_for_tenant_exception(): assert "Failed to create provider models" in str(exc.value) +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_other_provider(): + """Test batch_create_models_for_tenant with non-Silicon/ModelEngine provider (covers lines 138-140)""" + svc = import_svc() + + batch_payload = { + "provider": "openai", # Not Silicon or ModelEngine + "type": "llm", + "models": [ + {"id": "openai/gpt-4", "max_tokens": 4096}, + ], + "api_key": "k", + } + + # Add MODELENGINE to ProviderEnum if it doesn't exist + if not hasattr(svc.ProviderEnum, 'MODELENGINE'): + modelengine_item = _EnumItem("modelengine") + svc.ProviderEnum.MODELENGINE = modelengine_item + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "gpt-4")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="openai/gpt-4"), \ + mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ + mock.patch.object(svc, "create_model_record", return_value=True): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + # Verify prepare_model_dict was called with empty model_url for non-Silicon/ModelEngine provider + call_args = svc.prepare_model_dict.call_args + assert call_args[1]["model_url"] == "" # Should be empty for other providers + + @pytest.mark.asyncio async def test_batch_create_models_for_tenant_flow(): svc = import_svc() @@ -446,6 +510,56 @@ def get_by_display(display_name, tenant_id): mock_create.assert_called_once() +@pytest.mark.asyncio +async def test_batch_create_models_max_tokens_update(): + """Test batch_create_models updates max_tokens when display_name exists and max_tokens changed (covers lines 160->173, 168->171)""" + svc = import_svc() + + batch_payload = { + "provider": "silicon", + "type": "llm", + "models": [ + {"id": "silicon/model1", "max_tokens": 8192}, # Changed from 4096 + {"id": "silicon/model2", "max_tokens": 4096}, # Same as existing + {"id": "silicon/model3", "max_tokens": None}, # None should not update + ], + "api_key": "k", + } + + def get_by_display(display_name, tenant_id): + if display_name == "silicon/model1": + return {"model_id": "id1", "max_tokens": 4096} # Different from new value + elif display_name == "silicon/model2": + return {"model_id": "id2", "max_tokens": 4096} # Same as new value + elif display_name == "silicon/model3": + return {"model_id": "id3", "max_tokens": 2048} # Existing has value, new is None + return None + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \ + mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \ + mock.patch.object(svc, "update_model_record") as mock_update, \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ + mock.patch.object(svc, "create_model_record", return_value=True): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + # Should update model1 (max_tokens changed from 4096 to 8192) + # Note: update_model_record may be called multiple times, so check if it was called with correct args + update_calls = [call for call in mock_update.call_args_list if call[0][0] == "id1"] + if update_calls: + assert update_calls[0][0][1] == {"max_tokens": 8192} + + # Should NOT update model2 (max_tokens same) or model3 (new max_tokens is None) + # Verify model2 and model3 were not updated + model2_calls = [call for call in mock_update.call_args_list if call[0][0] == "id2"] + model3_calls = [call for call in mock_update.call_args_list if call[0][0] == "id3"] + assert len(model2_calls) == 0 # model2 should not be updated (same max_tokens) + assert len(model3_calls) == 0 # model3 should not be updated (new max_tokens is None) + + @pytest.mark.asyncio async def test_batch_create_models_for_tenant_exception(): svc = import_svc() @@ -482,64 +596,87 @@ async def test_list_provider_models_for_tenant_exception(): assert "Failed to list provider models" in str(exc.value) -async def test_update_single_model_for_tenant_success(): +async def test_update_single_model_for_tenant_success_single_model(): + """Update succeeds for a single non-embedding model with no display_name change.""" svc = import_svc() - model = {"model_id": "1", "display_name": "name"} - with mock.patch.object(svc, "get_model_by_display_name", return_value=None) as mock_get, \ + existing_models = [ + {"model_id": 1, "model_type": "llm", "display_name": "name"}, + ] + model_data = { + "model_id": 1, + "display_name": "name", + "description": "updated", + "model_type": "llm", + } + + with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models) as mock_get, \ mock.patch.object(svc, "update_model_record") as mock_update: - await svc.update_single_model_for_tenant("u1", "t1", model) + await svc.update_single_model_for_tenant("u1", "t1", "name", model_data) + mock_get.assert_called_once_with("name", "t1") - mock_update.assert_called_once_with(1, model, "u1") + # update_model_record should be called without model_id in the payload + mock_update.assert_called_once_with( + 1, + {"display_name": "name", "description": "updated", "model_type": "llm"}, + "u1", + ) -async def test_update_single_model_for_tenant_conflict(): +async def test_update_single_model_for_tenant_conflict_new_display_name(): + """Updating to a new conflicting display_name raises ValueError.""" svc = import_svc() - model = {"model_id": "m1", "display_name": "name"} - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": "other"}): - with pytest.raises(Exception) as exc: - await svc.update_single_model_for_tenant("u1", "t1", model) - assert "Failed to update model" in str(exc.value) + existing_models = [ + {"model_id": 1, "model_type": "llm", "display_name": "old_name"}, + ] + conflict_models = [ + {"model_id": 2, "model_type": "llm", "display_name": "new_name"}, + ] + model_data = { + "model_id": 1, + "display_name": "new_name", + } + with mock.patch.object(svc, "get_models_by_display_name", side_effect=[existing_models, conflict_models]): + with pytest.raises(ValueError) as exc: + await svc.update_single_model_for_tenant("u1", "t1", "old_name", model_data) + assert "already in use" in str(exc.value) -async def test_update_single_model_for_tenant_same_model_no_conflict(): - """Test that updating the same model with same display name doesn't raise conflict.""" + +async def test_update_single_model_for_tenant_not_found_raises_lookup_error(): + """If no model is found for current_display_name, raise LookupError.""" svc = import_svc() - model = {"model_id": "123", "display_name": "existing_name"} - # Return the same model_id (as int) to simulate updating the same model - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": 123}) as mock_get, \ - mock.patch.object(svc, "update_model_record") as mock_update: - await svc.update_single_model_for_tenant("u1", "t1", model) - mock_get.assert_called_once_with("existing_name", "t1") - mock_update.assert_called_once_with(123, model, "u1") + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]): + with pytest.raises(LookupError): + await svc.update_single_model_for_tenant("u1", "t1", "missing", {"display_name": "x"}) -async def test_update_single_model_for_tenant_type_conversion(): - """Test that string model_id is properly converted to int for comparison.""" +async def test_update_single_model_for_tenant_multi_embedding_updates_both(): + """Updating multi_embedding models updates both embedding and multi_embedding records.""" svc = import_svc() - model = {"model_id": "456", "display_name": "test_name"} - # Return the same model_id as int to test type conversion - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": 456}) as mock_get, \ - mock.patch.object(svc, "update_model_record") as mock_update: - await svc.update_single_model_for_tenant("u1", "t1", model) - mock_get.assert_called_once_with("test_name", "t1") - mock_update.assert_called_once_with(456, model, "u1") - + existing_models = [ + {"model_id": 10, "model_type": "embedding", "display_name": "emb_name"}, + {"model_id": 11, "model_type": "multi_embedding", "display_name": "emb_name"}, + ] + model_data = { + "model_id": 10, + "display_name": "emb_name", + "description": "updated", + "model_type": "multi_embedding", + } -async def test_update_single_model_for_tenant_different_model_conflict(): - """Test that updating with a display name used by a different model raises conflict.""" - svc = import_svc() + with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models) as mock_get, \ + mock.patch.object(svc, "update_model_record") as mock_update: + await svc.update_single_model_for_tenant("u1", "t1", "emb_name", model_data) - model = {"model_id": "789", "display_name": "conflict_name"} - # Return a different model_id to simulate name conflict - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": 999}): - with pytest.raises(Exception) as exc: - await svc.update_single_model_for_tenant("u1", "t1", model) - assert "Failed to update model" in str(exc.value) - assert "Name conflict_name is already in use" in str(exc.value) + mock_get.assert_called_once_with("emb_name", "t1") + # model_type should be stripped from update payload for multi_embedding flow + expected_update = {"display_name": "emb_name", "description": "updated"} + mock_update.assert_any_call(10, expected_update, "u1") + mock_update.assert_any_call(11, expected_update, "u1") async def test_batch_update_models_for_tenant_success(): @@ -549,8 +686,8 @@ async def test_batch_update_models_for_tenant_success(): with mock.patch.object(svc, "update_model_record") as mock_update: await svc.batch_update_models_for_tenant("u1", "t1", models) assert mock_update.call_count == 2 - mock_update.assert_any_call("a", models[0], "u1") - mock_update.assert_any_call("b", models[1], "u1") + mock_update.assert_any_call("a", models[0], "u1", "t1") + mock_update.assert_any_call("b", models[1], "u1", "t1") async def test_batch_update_models_for_tenant_exception(): @@ -563,48 +700,62 @@ async def test_batch_update_models_for_tenant_exception(): assert "Failed to batch update models" in str(exc.value) -async def test_delete_model_for_tenant_not_found(): +async def test_delete_model_for_tenant_not_found_raises_lookup_error(): + """If no models are found for display_name, raise LookupError.""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None): - with pytest.raises(Exception) as exc: + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]): + with pytest.raises(LookupError): await svc.delete_model_for_tenant("u1", "t1", "missing") - assert "Failed to delete model" in str(exc.value) async def test_delete_model_for_tenant_embedding_deletes_both(): + """Embedding + multi_embedding models are both deleted and memories cleared.""" svc = import_svc() - # Call sequence: initial -> embedding -> multi_embedding - side_effect = [ - {"model_id": "id-emb", "model_type": "embedding"}, - {"model_id": "id-emb", "model_type": "embedding"}, - {"model_id": "id-multi", "model_type": "multi_embedding"}, + models = [ + { + "model_id": "id-emb", + "model_type": "embedding", + "model_repo": "openai", + "model_name": "text-embedding-3-small", + "max_tokens": 1536, + }, + { + "model_id": "id-multi", + "model_type": "multi_embedding", + "model_repo": "openai", + "model_name": "text-embedding-3-small", + "max_tokens": 1536, + }, ] - with mock.patch.object(svc, "get_model_by_display_name", side_effect=side_effect) as mock_get, \ + + with mock.patch.object(svc, "get_models_by_display_name", return_value=models) as mock_get, \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ mock.patch.object(svc, "get_vector_db_core", return_value=object()) as mock_get_vdb, \ mock.patch.object(svc, "build_memory_config_for_tenant", return_value={}) as mock_build_cfg, \ mock.patch.object(svc, "clear_model_memories", new=mock.AsyncMock()) as mock_clear: await svc.delete_model_for_tenant("u1", "t1", "name") + + mock_get.assert_called_once_with("name", "t1") assert mock_delete.call_count == 2 - mock_get.assert_called() mock_get_vdb.assert_called_once() mock_build_cfg.assert_called_once_with("t1") - # Best-effort cleanup may call once or twice depending on state - assert mock_clear.await_count >= 1 + # Best-effort cleanup should be attempted for both records + assert mock_clear.await_count == 2 @pytest.mark.asyncio async def test_delete_model_for_tenant_cleanup_inner_exception(caplog): svc = import_svc() - side_effect = [ - {"model_id": "id-emb", "model_type": "embedding"}, - {"model_id": "id-emb", "model_type": "embedding"}, - {"model_id": "id-multi", "model_type": "multi_embedding"}, + models = [ + {"model_id": "id-emb", "model_type": "embedding", + "model_repo": "r", "model_name": "n", "max_tokens": 1}, + {"model_id": "id-multi", "model_type": "multi_embedding", + "model_repo": "r", "model_name": "n", "max_tokens": 1}, ] - with mock.patch.object(svc, "get_model_by_display_name", side_effect=side_effect), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=models), \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ mock.patch.object(svc, "get_vector_db_core", return_value=object()), \ mock.patch.object(svc, "build_memory_config_for_tenant", return_value={}), \ @@ -622,12 +773,11 @@ async def test_delete_model_for_tenant_cleanup_inner_exception(caplog): async def test_delete_model_for_tenant_cleanup_outer_exception(caplog): svc = import_svc() - side_effect = [ - {"model_id": "id-emb", "model_type": "embedding"}, + models = [ {"model_id": "id-emb", "model_type": "embedding"}, {"model_id": "id-multi", "model_type": "multi_embedding"}, ] - with mock.patch.object(svc, "get_model_by_display_name", side_effect=side_effect), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=models), \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ mock.patch.object(svc, "get_vector_db_core", side_effect=Exception("vdb_down")), \ mock.patch.object(svc, "build_memory_config_for_tenant", return_value={}): @@ -641,12 +791,19 @@ async def test_delete_model_for_tenant_cleanup_outer_exception(caplog): async def test_delete_model_for_tenant_non_embedding(): + """Non-embedding model deletes a single record without memory cleanup.""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": "id", "model_type": "llm"}), \ - mock.patch.object(svc, "delete_model_record") as mock_delete: + models = [ + {"model_id": "id", "model_type": "llm"}, + ] + with mock.patch.object(svc, "get_models_by_display_name", return_value=models), \ + mock.patch.object(svc, "delete_model_record") as mock_delete, \ + mock.patch.object(svc, "get_vector_db_core") as mock_get_vdb: await svc.delete_model_for_tenant("u1", "t1", "name") mock_delete.assert_called_once_with("id", "u1", "t1") + # For non-embedding models we should not prepare vector DB cleanup + mock_get_vdb.assert_not_called() async def test_list_models_for_tenant_success(): @@ -761,6 +918,44 @@ async def test_list_llm_models_for_tenant_normalizes_connect_status(): assert result[1]["connect_status"] == "operational" +async def test_list_models_for_tenant_type_mapping(): + """Test list_models_for_tenant maps model_type from 'chat' to 'llm' (covers line 310)""" + svc = import_svc() + + records = [ + { + "model_id": "llm1", + "model_repo": "openai", + "model_name": "gpt-4", + "display_name": "GPT-4", + "model_type": "chat", # ModelEngine type that should be mapped to "llm" + "connect_status": "operational" + }, + { + "model_id": "llm2", + "model_repo": "anthropic", + "model_name": "claude-3", + "display_name": "Claude 3", + "model_type": "llm", # Already correct type + "connect_status": "not_detected" + } + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"): + + result = await svc.list_models_for_tenant("t1") + + assert len(result) == 2 + # First model should have model_type mapped from "chat" to "llm" (covers line 310) + assert result[0]["model_type"] == "llm" # Should be mapped from "chat" + assert result[0]["model_id"] == "llm1" + # Second model should remain "llm" + assert result[1]["model_type"] == "llm" + assert result[1]["model_id"] == "llm2" + + async def test_list_llm_models_for_tenant_handles_missing_repo(): """Test list_llm_models_for_tenant handles models without repo.""" svc = import_svc() diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index 2702d35bd..ce3a0ab75 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -13,7 +13,7 @@ def __getattr__(cls, item): # without its real heavy dependencies being present during unit-testing. # --------------------------------------------------------------------------- for module_path in [ - "consts", "consts.provider", "consts.model", "consts.const", + "consts", "consts.provider", "consts.model", "consts.const", "consts.exceptions", "utils", "utils.model_name_utils", "services", "services.model_health_service", "database", "database.model_management_db", @@ -31,6 +31,7 @@ def __getattr__(cls, item): # Mock ProviderEnum for get_provider_models tests class _ProviderEnumStub: SILICON = mock.Mock(value="silicon") + MODELENGINE = mock.Mock(value="modelengine") sys.modules["consts.provider"].ProviderEnum = _ProviderEnumStub @@ -42,6 +43,13 @@ class _EnumStub: sys.modules["consts.model"].ModelConnectStatusEnum = _EnumStub +# Mock exception classes +class _TimeoutExceptionStub(Exception): + """Mock TimeoutException for testing.""" + pass + +sys.modules["consts.exceptions"].TimeoutException = _TimeoutExceptionStub + # Mock the database function that merge_existing_model_tokens depends on sys.modules["database.model_management_db"].get_models_by_tenant_factory_type = mock.MagicMock() @@ -739,3 +747,383 @@ async def test_get_provider_models_silicon_with_different_model_types(): assert result == [{"id": "test-model"}] mock_provider_instance.get_models.assert_called_once_with(model_data) + + +# --------------------------------------------------------------------------- +# Test-cases for ModelEngineProvider.get_models +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_modelengine_get_models_no_env_config(): + """ModelEngine provider should return empty list when env vars not configured.""" + # Import ModelEngineProvider + from backend.services.model_provider_service import ModelEngineProvider + + provider_config = {"model_type": "llm"} + + with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", ""), \ + mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", ""): + + result = await ModelEngineProvider().get_models(provider_config) + + assert result == [] + + +@pytest.mark.asyncio +async def test_modelengine_get_models_llm_success(): + """ModelEngine provider should return LLM models with correct type mapping.""" + from backend.services.model_provider_service import ModelEngineProvider + + provider_config = {"model_type": "llm"} + + with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \ + mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientTimeout"), \ + mock.patch("backend.services.model_provider_service.aiohttp.TCPConnector"): + + # Setup mock response + mock_response = mock.AsyncMock() + mock_response.status = 200 + mock_response.raise_for_status = mock.Mock() + mock_response.json = mock.AsyncMock(return_value={ + "data": [ + {"id": "gpt-4", "type": "chat"}, + {"id": "claude-3", "type": "chat"}, + ] + }) + + # Setup mock session with proper async context manager + mock_get_cm = mock.MagicMock() + mock_get_cm.__aenter__ = mock.AsyncMock(return_value=mock_response) + mock_get_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_session_instance = mock.MagicMock() + mock_session_instance.get = mock.Mock(return_value=mock_get_cm) + + mock_session_cm = mock.MagicMock() + mock_session_cm.__aenter__ = mock.AsyncMock(return_value=mock_session_instance) + mock_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_session_class.return_value = mock_session_cm + + result = await ModelEngineProvider().get_models(provider_config) + + assert len(result) == 2 + assert result[0]["id"] == "gpt-4" + assert result[0]["model_type"] == "llm" + assert result[0]["model_tag"] == "chat" + assert result[0]["max_tokens"] == sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS + assert result[0]["base_url"] == "https://model-engine.com" + assert result[0]["api_key"] == "test-key" + + +@pytest.mark.asyncio +async def test_modelengine_get_models_embedding_success(): + """ModelEngine provider should return embedding models with correct type mapping.""" + from backend.services.model_provider_service import ModelEngineProvider + + provider_config = {"model_type": "embedding"} + + with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \ + mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientTimeout"), \ + mock.patch("backend.services.model_provider_service.aiohttp.TCPConnector"): + + mock_response = mock.AsyncMock() + mock_response.status = 200 + mock_response.raise_for_status = mock.Mock() + mock_response.json = mock.AsyncMock(return_value={ + "data": [ + {"id": "text-embedding-ada", "type": "embed"}, + {"id": "gpt-4", "type": "chat"}, # Should be filtered out + ] + }) + + # Setup mock session with proper async context manager + mock_get_cm = mock.MagicMock() + mock_get_cm.__aenter__ = mock.AsyncMock(return_value=mock_response) + mock_get_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_session_instance = mock.MagicMock() + mock_session_instance.get = mock.Mock(return_value=mock_get_cm) + + mock_session_cm = mock.MagicMock() + mock_session_cm.__aenter__ = mock.AsyncMock(return_value=mock_session_instance) + mock_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_session_class.return_value = mock_session_cm + + result = await ModelEngineProvider().get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-ada" + assert result[0]["model_type"] == "embedding" + assert result[0]["model_tag"] == "embed" + assert result[0]["max_tokens"] == 0 + + +@pytest.mark.asyncio +async def test_modelengine_get_models_all_types(): + """ModelEngine provider should return all models when no type filter specified.""" + from backend.services.model_provider_service import ModelEngineProvider + + provider_config = {} # No model_type filter + + with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \ + mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientTimeout"), \ + mock.patch("backend.services.model_provider_service.aiohttp.TCPConnector"): + + mock_response = mock.AsyncMock() + mock_response.status = 200 + mock_response.raise_for_status = mock.Mock() + mock_response.json = mock.AsyncMock(return_value={ + "data": [ + {"id": "gpt-4", "type": "chat"}, + {"id": "text-embedding-ada", "type": "embed"}, + {"id": "whisper", "type": "asr"}, + {"id": "tts-model", "type": "tts"}, + {"id": "rerank-model", "type": "rerank"}, + {"id": "vlm-model", "type": "vlm"}, + {"id": "unknown-model", "type": "unknown"}, # Should be filtered out + ] + }) + + # Setup mock session with proper async context manager + mock_get_cm = mock.MagicMock() + mock_get_cm.__aenter__ = mock.AsyncMock(return_value=mock_response) + mock_get_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_session_instance = mock.MagicMock() + mock_session_instance.get = mock.Mock(return_value=mock_get_cm) + + mock_session_cm = mock.MagicMock() + mock_session_cm.__aenter__ = mock.AsyncMock(return_value=mock_session_instance) + mock_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_session_class.return_value = mock_session_cm + + result = await ModelEngineProvider().get_models(provider_config) + + assert len(result) == 6 + # Verify type mapping + type_map = {model["id"]: model["model_type"] for model in result} + assert type_map["gpt-4"] == "llm" + assert type_map["text-embedding-ada"] == "embedding" + assert type_map["whisper"] == "stt" + assert type_map["tts-model"] == "tts" + assert type_map["rerank-model"] == "rerank" + assert type_map["vlm-model"] == "vlm" + + +@pytest.mark.asyncio +async def test_modelengine_get_models_exception(): + """ModelEngine provider should return empty list on exception.""" + from backend.services.model_provider_service import ModelEngineProvider + + provider_config = {"model_type": "llm"} + + with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \ + mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \ + mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session: + + mock_session_instance = mock.AsyncMock() + mock_session_instance.__aenter__.return_value = mock_session_instance + mock_session_instance.get.side_effect = Exception("Network error") + mock_session.return_value = mock_session_instance + + result = await ModelEngineProvider().get_models(provider_config) + + assert result == [] + + +# --------------------------------------------------------------------------- +# Test-cases for prepare_model_dict with ModelEngine provider +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_prepare_model_dict_modelengine_llm(): + """ModelEngine LLM models should have correct base_url path and ssl_verify=False.""" + with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("modelengine", "gpt-4")), \ + mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="modelengine/gpt-4"), \ + mock.patch("backend.services.model_provider_service.ModelRequest") as mock_model_request, \ + mock.patch("backend.services.model_provider_service.embedding_dimension_check", new_callable=mock.AsyncMock), \ + mock.patch("backend.services.model_provider_service.ProviderEnum") as mock_enum: + + mock_model_req_instance = mock.MagicMock() + dump_dict = { + "model_factory": "modelengine", + "model_name": "gpt-4", + "model_type": "llm", + "api_key": "me-key", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + "display_name": "modelengine/gpt-4", + } + mock_model_req_instance.model_dump.return_value = dump_dict + mock_model_request.return_value = mock_model_req_instance + mock_enum.MODELENGINE.value = "modelengine" + + provider = "modelengine" + model = { + "id": "modelengine/gpt-4", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + "base_url": "https://120.253.225.102:50001/open/router/v1", + "api_key": "me-key" + } + base_url = "https://api.openai.com/v1" + api_key = "original-key" + + result = await prepare_model_dict(provider, model, base_url, api_key) + + expected = dump_dict | { + "model_repo": "modelengine", + "base_url": "https://120.253.225.102:50001/open/router/v1", + "connect_status": "not_detected", + "ssl_verify": False, + } + assert result == expected + assert result["ssl_verify"] == False + assert "/open/router/v1" in result["base_url"] + + +@pytest.mark.asyncio +async def test_prepare_model_dict_modelengine_embedding(): + """ModelEngine embedding models should have correct embeddings path.""" + with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("modelengine", "text-embedding")), \ + mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="modelengine/text-embedding"), \ + mock.patch("backend.services.model_provider_service.ModelRequest") as mock_model_request, \ + mock.patch("backend.services.model_provider_service.embedding_dimension_check", new_callable=mock.AsyncMock, return_value=1536), \ + mock.patch("backend.services.model_provider_service.ProviderEnum") as mock_enum, \ + mock.patch("backend.services.model_provider_service.ModelConnectStatusEnum") as mock_status_enum: + + mock_model_req_instance = mock.MagicMock() + dump_dict = { + "model_factory": "modelengine", + "model_name": "text-embedding", + "model_type": "embedding", + "api_key": "me-key", + "max_tokens": 1024, + "display_name": "modelengine/text-embedding", + } + mock_model_req_instance.model_dump.return_value = dump_dict + mock_model_request.return_value = mock_model_req_instance + mock_enum.MODELENGINE.value = "modelengine" + mock_status_enum.NOT_DETECTED.value = "not_detected" + + provider = "modelengine" + model = { + "id": "modelengine/text-embedding", + "model_type": "embedding", + "max_tokens": 1024, + "base_url": "https://120.253.225.102:50001", + "api_key": "me-key" + } + base_url = "https://api.openai.com/v1" + api_key = "original-key" + + result = await prepare_model_dict(provider, model, base_url, api_key) + + expected = dump_dict | { + "model_repo": "modelengine", + "base_url": "https://120.253.225.102:50001/open/router/v1/embeddings", + "connect_status": "not_detected", + "ssl_verify": False, + "max_tokens": 1536, + } + assert result == expected + assert result["ssl_verify"] == False + assert "/open/router/v1/embeddings" in result["base_url"] + + +@pytest.mark.asyncio +async def test_prepare_model_dict_modelengine_base_url_stripping(): + """ModelEngine should strip existing /open/ paths from base_url.""" + with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("modelengine", "gpt-4")), \ + mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="modelengine/gpt-4"), \ + mock.patch("backend.services.model_provider_service.ModelRequest") as mock_model_request, \ + mock.patch("backend.services.model_provider_service.embedding_dimension_check", new_callable=mock.AsyncMock), \ + mock.patch("backend.services.model_provider_service.ProviderEnum") as mock_enum: + + mock_model_req_instance = mock.MagicMock() + dump_dict = { + "model_factory": "modelengine", + "model_name": "gpt-4", + "model_type": "llm", + "api_key": "me-key", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + "display_name": "modelengine/gpt-4", + } + mock_model_req_instance.model_dump.return_value = dump_dict + mock_model_request.return_value = mock_model_req_instance + mock_enum.MODELENGINE.value = "modelengine" + + provider = "modelengine" + model = { + "id": "modelengine/gpt-4", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + "base_url": "https://120.253.225.102:50001/open/router/v1/some/path", + "api_key": "me-key" + } + base_url = "https://api.openai.com/v1" + api_key = "original-key" + + result = await prepare_model_dict(provider, model, base_url, api_key) + + # Should strip everything after /open/ + assert result["base_url"] == "https://120.253.225.102:50001/open/router/v1" + + +# --------------------------------------------------------------------------- +# Test-cases for get_provider_models with ModelEngine provider +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_get_provider_models_modelengine_success(): + """Should successfully get models from ModelEngine provider.""" + from backend.services.model_provider_service import ModelEngineProvider + + model_data = { + "provider": "modelengine", + "model_type": "llm" + } + + expected_models = [ + {"id": "gpt-4", "model_tag": "chat", "model_type": "llm", "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS} + ] + + with mock.patch("backend.services.model_provider_service.ModelEngineProvider") as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = expected_models + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == expected_models + mock_provider_class.assert_called_once() + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +@pytest.mark.asyncio +async def test_get_provider_models_modelengine_empty_result(): + """Should handle empty result from ModelEngine provider.""" + from backend.services.model_provider_service import ModelEngineProvider + + model_data = { + "provider": "modelengine", + "model_type": "embedding" + } + + with mock.patch("backend.services.model_provider_service.ModelEngineProvider") as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = [] + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == [] + mock_provider_instance.get_models.assert_called_once_with(model_data) diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index 9571c2d15..ba66119c8 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -1248,21 +1248,17 @@ def test_health_check_unhealthy(self): self.assertIn("Health check failed", str(context.exception)) - @patch('backend.services.vectordatabase_service.calculate_term_weights') @patch('database.model_management_db.get_model_by_model_id') - def test_summary_index_name(self, mock_get_model_by_model_id, mock_calculate_weights): + def test_summary_index_name(self, mock_get_model_by_model_id): """ Test generating a summary for an index. This test verifies that: 1. Random documents are retrieved for summarization - 2. Term weights are calculated to identify important keywords - 3. The summary generation stream is properly initialized - 4. A StreamingResponse object is returned for streaming the summary tokens + 2. The summary generation stream is properly initialized using Map-Reduce approach + 3. A StreamingResponse object is returned for streaming the summary tokens """ # Setup - mock_calculate_weights.return_value = { - "keyword1": 0.8, "keyword2": 0.6} mock_get_model_by_model_id.return_value = { 'api_key': 'test_api_key', 'base_url': 'https://api.test.com', diff --git a/test/backend/test_llm_integration.py b/test/backend/test_llm_integration.py index dfd62539b..baada28c1 100644 --- a/test/backend/test_llm_integration.py +++ b/test/backend/test_llm_integration.py @@ -5,10 +5,52 @@ import pytest import sys import os +import types +from unittest.mock import patch, MagicMock # Add backend to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) +# Mock database.client and MinioClient before any imports to avoid MinIO initialization +class _MinioClient: + pass + +if "database.client" not in sys.modules: + database_client_mod = types.ModuleType("database.client") + database_client_mod.MinioClient = _MinioClient + sys.modules["database.client"] = database_client_mod + +# Mock backend.database.client as well +if "backend.database.client" not in sys.modules: + backend_db_client_mod = types.ModuleType("backend.database.client") + backend_db_client_mod.MinioClient = _MinioClient + sys.modules["backend.database.client"] = backend_db_client_mod + +# Ensure database module exists as a package (needs __path__ attribute) +if "database" not in sys.modules: + database_mod = types.ModuleType("database") + database_mod.__path__ = [] # Make it a package + sys.modules["database"] = database_mod + +# Mock database.model_management_db module to avoid MinIO initialization +if "database.model_management_db" not in sys.modules: + model_mgmt_db_mod = types.ModuleType("database.model_management_db") + model_mgmt_db_mod.get_model_by_model_id = MagicMock(return_value=None) + sys.modules["database.model_management_db"] = model_mgmt_db_mod + setattr(sys.modules["database"], "model_management_db", model_mgmt_db_mod) + +# Mock database.tenant_config_db to avoid import errors +if "database.tenant_config_db" not in sys.modules: + tenant_config_db_mod = types.ModuleType("database.tenant_config_db") + # Mock all functions that config_utils imports + tenant_config_db_mod.delete_config_by_tenant_config_id = MagicMock() + tenant_config_db_mod.get_all_configs_by_tenant_id = MagicMock() + tenant_config_db_mod.get_single_config_info = MagicMock() + tenant_config_db_mod.insert_config = MagicMock() + tenant_config_db_mod.update_config_by_tenant_config_id_and_data = MagicMock() + sys.modules["database.tenant_config_db"] = tenant_config_db_mod + setattr(sys.modules["database"], "tenant_config_db", tenant_config_db_mod) + from utils.document_vector_utils import summarize_document, summarize_cluster @@ -32,14 +74,21 @@ def test_summarize_document_with_llm_params_no_config(self): content = "This is a test document with some content about machine learning and AI." filename = "test_doc.txt" - # Test with model_id and tenant_id but no actual LLM call (will fail due to missing config) + # Mock get_model_by_model_id to return None (no config found) + # Use the already mocked module and just ensure it returns None + import database.model_management_db as model_mgmt_db + model_mgmt_db.get_model_by_model_id = MagicMock(return_value=None) + + # Test with model_id and tenant_id but no actual LLM call (will fallback due to missing config) result = summarize_document( content, filename, language="zh", max_words=50, model_id=1, tenant_id="test_tenant" ) - # Should return error message when model config not found - assert "Failed to generate summary" in result or "No model configuration found" in result + # Should return placeholder summary when model config not found (fallback behavior) + assert "[Document Summary: test_doc.txt]" in result + assert "max 50 words" in result + assert "Content:" in result def test_summarize_cluster_without_llm(self): """Test cluster summarization without LLM (fallback mode)""" @@ -63,13 +112,20 @@ def test_summarize_cluster_with_llm_params_no_config(self): "Document 2 discusses neural networks and deep learning." ] + # Mock get_model_by_model_id to return None (no config found) + # Use the already mocked module and just ensure it returns None + import database.model_management_db as model_mgmt_db + model_mgmt_db.get_model_by_model_id = MagicMock(return_value=None) + result = summarize_cluster( document_summaries, language="zh", max_words=100, model_id=1, tenant_id="test_tenant" ) - # Should return error message when model config not found - assert "Failed to generate summary" in result or "No model configuration found" in result + # Should return placeholder summary when model config not found (fallback behavior) + assert "[Cluster Summary]" in result + assert "max 100 words" in result + assert "Based on 2 documents" in result def test_summarize_document_english(self): """Test document summarization in English""" diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index dedbd7b3b..3dc831323 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -400,7 +400,8 @@ def test_create_model_success(nexent_agent_with_models, mock_model_config): api_key=mock_model_config.api_key, api_base=mock_model_config.url, temperature=mock_model_config.temperature, - top_p=mock_model_config.top_p + top_p=mock_model_config.top_p, + ssl_verify=True ) # Verify stop_event was set @@ -426,7 +427,8 @@ def test_create_model_deep_thinking_success(nexent_agent_with_models, mock_deep_ api_key=mock_deep_thinking_model_config.api_key, api_base=mock_deep_thinking_model_config.url, temperature=mock_deep_thinking_model_config.temperature, - top_p=mock_deep_thinking_model_config.top_p + top_p=mock_deep_thinking_model_config.top_p, + ssl_verify=True ) # Verify stop_event was set diff --git a/test/sdk/core/models/test_openai_llm.py b/test/sdk/core/models/test_openai_llm.py index 6bb61bdfa..d666c2bd2 100644 --- a/test/sdk/core/models/test_openai_llm.py +++ b/test/sdk/core/models/test_openai_llm.py @@ -1,4 +1,7 @@ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch, ANY +import importlib.util +import sys +from pathlib import Path import pytest @@ -55,6 +58,28 @@ def decorator(func): nexent_monitor_mock.MonitoringManager = MagicMock nexent_monitor_mock.MonitoringConfig = MagicMock +# Create mock parent package structure for nexent module +nexent_mock = MagicMock() +nexent_mock.monitor = nexent_monitor_mock +nexent_core_mock = MagicMock() +nexent_core_models_mock = MagicMock() +nexent_core_utils_mock = MagicMock() + +# Mock MessageObserver and ProcessType for utils.observer +class MockMessageObserver: + def __init__(self, *args, **kwargs): + self.add_model_new_token = MagicMock() + self.add_model_reasoning_content = MagicMock() + self.flush_remaining_tokens = MagicMock() + +class MockProcessType: + MODEL_OUTPUT_THINKING = "model_output_thinking" + MODEL_OUTPUT = "model_output" + +nexent_core_utils_mock.observer = MagicMock() +nexent_core_utils_mock.observer.MessageObserver = MockMessageObserver +nexent_core_utils_mock.observer.ProcessType = MockProcessType + # Assemble smolagents.* paths and monitoring mocks module_mocks = { "smolagents": mock_smolagents, @@ -62,13 +87,35 @@ def decorator(func): "openai.types": MagicMock(), "openai.types.chat": MagicMock(), "openai.types.chat.chat_completion_message": MagicMock(), + "openai": MagicMock(), + "openai.lib": MagicMock(), + "nexent": nexent_mock, "nexent.monitor": nexent_monitor_mock, "nexent.monitor.monitoring": nexent_monitor_mock, + "nexent.core": nexent_core_mock, + "nexent.core.models": nexent_core_models_mock, + "nexent.core.utils": nexent_core_utils_mock, + "nexent.core.utils.observer": nexent_core_utils_mock.observer, } +# Dynamically load the module directly by file path +MODULE_NAME = "nexent.core.models.openai_llm" +MODULE_PATH = ( + Path(__file__).resolve().parents[4] + / "sdk" + / "nexent" + / "core" + / "models" + / "openai_llm.py" +) + with patch.dict("sys.modules", module_mocks): - # Import after patching so dependencies are satisfied - from sdk.nexent.core.models.openai_llm import OpenAIModel as ImportedOpenAIModel + spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH) + openai_llm_module = importlib.util.module_from_spec(spec) + sys.modules[MODULE_NAME] = openai_llm_module + assert spec and spec.loader + spec.loader.exec_module(openai_llm_module) + ImportedOpenAIModel = openai_llm_module.OpenAIModel # ----------------------------------------------------------------------- # Fixtures @@ -83,6 +130,8 @@ def openai_model_instance(): # Inject dummy attributes required by the method under test model.model_id = "dummy-model" + model.temperature = 0.7 + model.top_p = 0.9 model.custom_role_conversions = {} # Add missing attribute # Client hierarchy: client.chat.completions.create @@ -592,5 +641,219 @@ def test_call_with_reasoning_content_and_content_together(openai_model_instance) "Response text") +# --------------------------------------------------------------------------- +# Tests for __init__ with ssl_verify parameter +# --------------------------------------------------------------------------- + +def test_init_with_ssl_verify_false(): + """Test __init__ method creates http_client when ssl_verify=False""" + + observer = MagicMock() + + # Mock DefaultHttpxClient from openai module + with patch("openai.DefaultHttpxClient") as mock_httpx_client: + mock_httpx_client.return_value = MagicMock() + + # Create model with ssl_verify=False + model = ImportedOpenAIModel(observer=observer, ssl_verify=False) + + # Verify DefaultHttpxClient was called with verify=False + mock_httpx_client.assert_called_once_with(verify=False) + + +def test_init_with_ssl_verify_true(): + """Test __init__ method doesn't create http_client when ssl_verify=True (default)""" + + observer = MagicMock() + + # Mock DefaultHttpxClient from openai module + with patch("openai.DefaultHttpxClient") as mock_httpx_client: + # Create model with ssl_verify=True (default) + model = ImportedOpenAIModel(observer=observer, ssl_verify=True) + + # Verify DefaultHttpxClient was NOT called + mock_httpx_client.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests for monitoring and token_tracker integration +# --------------------------------------------------------------------------- + +def test_call_with_monitoring_and_token_tracker(openai_model_instance): + """Test __call__ method with monitoring and token_tracker enabled""" + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Create mock token_tracker + mock_token_tracker = MagicMock() + mock_token_tracker.record_first_token = MagicMock() + mock_token_tracker.record_token = MagicMock() + mock_token_tracker.record_completion = MagicMock() + + # Mock the stream response + mock_chunk1 = MagicMock() + mock_chunk1.choices = [MagicMock()] + mock_chunk1.choices[0].delta.content = "Hello" + mock_chunk1.choices[0].delta.role = "assistant" + mock_chunk1.choices[0].delta.reasoning_content = None + + mock_chunk2 = MagicMock() + mock_chunk2.choices = [MagicMock()] + mock_chunk2.choices[0].delta.content = " world" + mock_chunk2.choices[0].delta.role = None + mock_chunk2.choices[0].delta.reasoning_content = None + + mock_chunk3 = MagicMock() + mock_chunk3.choices = [MagicMock()] + mock_chunk3.choices[0].delta.content = None + mock_chunk3.choices[0].delta.role = None + mock_chunk3.choices[0].delta.reasoning_content = None + mock_chunk3.usage = MagicMock() + mock_chunk3.usage.prompt_tokens = 10 + mock_chunk3.usage.completion_tokens = 5 + mock_chunk3.usage.total_tokens = 15 + + mock_stream = [mock_chunk1, mock_chunk2, mock_chunk3] + + # Mock ChatMessage.from_dict + mock_result_message = MagicMock() + mock_result_message.raw = mock_stream + mock_result_message.role = MagicMock() + + with patch.object(openai_model_instance, "_prepare_completion_kwargs", return_value={}), \ + patch.object(mock_models_module.ChatMessage, "from_dict", return_value=mock_result_message): + openai_model_instance.client.chat.completions.create.return_value = mock_stream + + # Call with _token_tracker kwarg + result = openai_model_instance.__call__(messages, _token_tracker=mock_token_tracker) + + # Verify monitoring calls + monitoring_manager_mock.add_span_event.assert_any_call("completion_started") + monitoring_manager_mock.set_span_attributes.assert_called() + monitoring_manager_mock.add_span_event.assert_any_call("completion_finished", ANY) + + # Verify token_tracker calls + mock_token_tracker.record_first_token.assert_called_once() + assert mock_token_tracker.record_token.call_count == 2 # "Hello" and " world" + mock_token_tracker.record_completion.assert_called_once_with(10, 5) + + +def test_call_with_token_tracker_on_reasoning_content(openai_model_instance): + """Test __call__ method tracks first token on reasoning_content""" + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Create mock token_tracker + mock_token_tracker = MagicMock() + mock_token_tracker.record_first_token = MagicMock() + mock_token_tracker.record_token = MagicMock() + mock_token_tracker.record_completion = MagicMock() + + # Mock the stream response with reasoning_content first + mock_chunk1 = MagicMock() + mock_chunk1.choices = [MagicMock()] + mock_chunk1.choices[0].delta.content = None + mock_chunk1.choices[0].delta.role = "assistant" + mock_chunk1.choices[0].delta.reasoning_content = "Thinking..." + + mock_chunk2 = MagicMock() + mock_chunk2.choices = [MagicMock()] + mock_chunk2.choices[0].delta.content = "Response" + mock_chunk2.choices[0].delta.role = None + mock_chunk2.choices[0].delta.reasoning_content = None + mock_chunk2.usage = MagicMock() + mock_chunk2.usage.prompt_tokens = 5 + mock_chunk2.usage.completion_tokens = 3 + mock_chunk2.usage.total_tokens = 8 + + mock_stream = [mock_chunk1, mock_chunk2] + + # Mock ChatMessage.from_dict + mock_result_message = MagicMock() + mock_result_message.raw = mock_stream + mock_result_message.role = MagicMock() + + with patch.object(openai_model_instance, "_prepare_completion_kwargs", return_value={}), \ + patch.object(mock_models_module.ChatMessage, "from_dict", return_value=mock_result_message): + openai_model_instance.client.chat.completions.create.return_value = mock_stream + + # Call with _token_tracker kwarg + result = openai_model_instance.__call__(messages, _token_tracker=mock_token_tracker) + + # Verify token_tracker.record_first_token was called when reasoning_content was received + mock_token_tracker.record_first_token.assert_called() + mock_token_tracker.record_token.assert_called_once_with("Response") + + +def test_call_with_stop_event_and_token_tracker(openai_model_instance): + """Test __call__ method adds monitoring event when stop_event is set with token_tracker""" + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Create mock token_tracker + mock_token_tracker = MagicMock() + + # Mock the stream response + mock_chunk = MagicMock() + mock_chunk.choices = [MagicMock()] + mock_chunk.choices[0].delta.content = "Response" + mock_chunk.choices[0].delta.role = "assistant" + mock_chunk.choices[0].delta.reasoning_content = None + + with patch.object(openai_model_instance, "_prepare_completion_kwargs", return_value={}): + openai_model_instance.client.chat.completions.create.return_value = [mock_chunk] + + # Set the stop event before calling + openai_model_instance.stop_event.set() + + # Call the method with token_tracker and expect RuntimeError + with pytest.raises(RuntimeError, match="Model is interrupted by stop event"): + openai_model_instance.__call__(messages, _token_tracker=mock_token_tracker) + + # Verify monitoring event was added + monitoring_manager_mock.add_span_event.assert_any_call("model_stopped", {"reason": "stop_event_set"}) + + +def test_call_exception_with_token_tracker(openai_model_instance): + """Test __call__ method adds error event when exception occurs with token_tracker""" + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Create mock token_tracker + mock_token_tracker = MagicMock() + + with patch.object(openai_model_instance, "_prepare_completion_kwargs", return_value={}): + # Mock the client to raise an exception + openai_model_instance.client.chat.completions.create.side_effect = Exception("API Error") + + # Call the method with token_tracker and expect exception + with pytest.raises(Exception, match="API Error"): + openai_model_instance.__call__(messages, _token_tracker=mock_token_tracker) + + # Verify error event was added + monitoring_manager_mock.add_span_event.assert_any_call("error_occurred", ANY) + + +def test_call_context_length_exceeded_with_token_tracker(openai_model_instance): + """Test __call__ method adds error event for context_length_exceeded with token_tracker""" + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Create mock token_tracker + mock_token_tracker = MagicMock() + + with patch.object(openai_model_instance, "_prepare_completion_kwargs", return_value={}): + # Mock the client to raise context length exceeded error + openai_model_instance.client.chat.completions.create.side_effect = Exception( + "context_length_exceeded: token limit exceeded") + + # Call the method with token_tracker and expect exception + with pytest.raises(Exception, match="context_length_exceeded"): + openai_model_instance.__call__(messages, _token_tracker=mock_token_tracker) + + # Verify error event was added + monitoring_manager_mock.add_span_event.assert_any_call("error_occurred", ANY) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/sdk/core/tools/test_datamate_search_tool.py b/test/sdk/core/tools/test_datamate_search_tool.py new file mode 100644 index 000000000..ebfdb3bba --- /dev/null +++ b/test/sdk/core/tools/test_datamate_search_tool.py @@ -0,0 +1,375 @@ +import json +from typing import List +from unittest.mock import ANY, MagicMock + +import httpx +import pytest +from pytest_mock import MockFixture + +from sdk.nexent.core.tools.datamate_search_tool import DataMateSearchTool +from sdk.nexent.core.utils.observer import MessageObserver, ProcessType + + +@pytest.fixture +def mock_observer() -> MessageObserver: + observer = MagicMock(spec=MessageObserver) + observer.lang = "en" + return observer + + +@pytest.fixture +def datamate_tool(mock_observer: MessageObserver) -> DataMateSearchTool: + return DataMateSearchTool( + server_ip="127.0.0.1", + server_port=8080, + observer=mock_observer, + ) + + +def _build_kb_list_response(ids: List[str]): + return { + "data": { + "content": [ + {"id": kb_id, "chunkCount": 1} + for kb_id in ids + ] + } + } + + +def _build_search_response(kb_id: str, count: int = 2): + return { + "data": [ + { + "entity": { + "id": f"file-{i}", + "text": f"content-{i}", + "createTime": "2024-01-01T00:00:00Z", + "score": 0.9 - i * 0.1, + "metadata": json.dumps( + { + "file_name": f"file-{i}.txt", + "absolute_directory_path": f"/data/{kb_id}", + } + ), + "scoreDetails": {"raw": 0.8}, + } + } + for i in range(count) + ] + } + + +class TestDataMateSearchToolInit: + def test_init_success(self, mock_observer: MessageObserver): + tool = DataMateSearchTool( + server_ip=" datamate.local ", + server_port=1234, + observer=mock_observer, + ) + + assert tool.server_ip == "datamate.local" + assert tool.server_port == 1234 + assert tool.server_base_url == "http://datamate.local:1234" + assert tool.kb_page == 0 + assert tool.kb_page_size == 20 + assert tool.observer is mock_observer + + @pytest.mark.parametrize("server_ip", ["", None]) + def test_init_invalid_server_ip(self, server_ip): + with pytest.raises(ValueError) as excinfo: + DataMateSearchTool(server_ip=server_ip, server_port=8080) + assert "server_ip is required" in str(excinfo.value) + + @pytest.mark.parametrize("server_port", [0, 65536, "8080"]) + def test_init_invalid_server_port(self, server_port): + with pytest.raises(ValueError) as excinfo: + DataMateSearchTool(server_ip="127.0.0.1", server_port=server_port) + assert "server_port must be an integer between 1 and 65535" in str(excinfo.value) + + +class TestHelperMethods: + @pytest.mark.parametrize( + "metadata_raw, expected", + [ + (None, {}), + ({"a": 1}, {"a": 1}), + ('{"b": 2}', {"b": 2}), + ("not-json", {}), + ], + ) + def test_parse_metadata(self, datamate_tool: DataMateSearchTool, metadata_raw, expected): + result = datamate_tool._parse_metadata(metadata_raw) + assert result == expected + + @pytest.mark.parametrize( + "path, expected", + [ + ("", ""), + ("/single", "single"), + ("/a/b/c", "c"), + ("////", ""), + ], + ) + def test_extract_dataset_id(self, datamate_tool: DataMateSearchTool, path, expected): + assert datamate_tool._extract_dataset_id(path) == expected + + @pytest.mark.parametrize( + "dataset_id, file_id, expected", + [ + ("ds1", "f1", "http://127.0.0.1:8080/api/data-management/datasets/ds1/files/f1/download"), + ("", "f1", ""), + ("ds1", "", ""), + ], + ) + def test_build_file_download_url(self, datamate_tool: DataMateSearchTool, dataset_id, file_id, expected): + assert datamate_tool._build_file_download_url(dataset_id, file_id) == expected + + +class TestKnowledgeBaseList: + def test_get_knowledge_base_list_success(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + response = MagicMock() + response.status_code = 200 + response.json.return_value = _build_kb_list_response(["kb1", "kb2"]) + client.post.return_value = response + + kb_ids = datamate_tool._get_knowledge_base_list() + + assert kb_ids == ["kb1", "kb2"] + client.post.assert_called_once_with( + f"{datamate_tool.server_base_url}/api/knowledge-base/list", + json={"page": datamate_tool.kb_page, "size": datamate_tool.kb_page_size}, + ) + + def test_get_knowledge_base_list_http_error_json_detail(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + response = MagicMock() + response.status_code = 500 + response.headers = {"content-type": "application/json"} + response.json.return_value = {"detail": "server error"} + client.post.return_value = response + + with pytest.raises(Exception) as excinfo: + datamate_tool._get_knowledge_base_list() + + assert "Failed to get knowledge base list" in str(excinfo.value) + + def test_get_knowledge_base_list_http_error_text_detail(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + response = MagicMock() + response.status_code = 400 + response.headers = {"content-type": "text/plain"} + response.text = "bad request" + client.post.return_value = response + + with pytest.raises(Exception) as excinfo: + datamate_tool._get_knowledge_base_list() + + assert "bad request" in str(excinfo.value) + + def test_get_knowledge_base_list_timeout(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + client.post.side_effect = httpx.TimeoutException("timeout") + + with pytest.raises(Exception) as excinfo: + datamate_tool._get_knowledge_base_list() + + assert "Timeout while getting knowledge base list" in str(excinfo.value) + + def test_get_knowledge_base_list_request_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + client.post.side_effect = httpx.RequestError("network", request=MagicMock()) + + with pytest.raises(Exception) as excinfo: + datamate_tool._get_knowledge_base_list() + + assert "Request error while getting knowledge base list" in str(excinfo.value) + + +class TestRetrieveKnowledgeBaseContent: + def test_retrieve_content_success(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + response = MagicMock() + response.status_code = 200 + response.json.return_value = _build_search_response("kb1", count=2) + client.post.return_value = response + + results = datamate_tool._retrieve_knowledge_base_content( + "query", + ["kb1"], + top_k=3, + threshold=0.2, + ) + + assert len(results) == 2 + client.post.assert_called_once() + + def test_retrieve_content_http_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + response = MagicMock() + response.status_code = 500 + response.headers = {"content-type": "application/json"} + response.json.return_value = {"detail": "server error"} + client.post.return_value = response + + with pytest.raises(Exception) as excinfo: + datamate_tool._retrieve_knowledge_base_content( + "query", + ["kb1"], + top_k=3, + threshold=0.2, + ) + + assert "Failed to retrieve knowledge base content" in str(excinfo.value) + + def test_retrieve_content_timeout(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + client.post.side_effect = httpx.TimeoutException("timeout") + + with pytest.raises(Exception) as excinfo: + datamate_tool._retrieve_knowledge_base_content( + "query", + ["kb1"], + top_k=3, + threshold=0.2, + ) + + assert "Timeout while retrieving knowledge base content" in str(excinfo.value) + + def test_retrieve_content_request_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + client.post.side_effect = httpx.RequestError("network", request=MagicMock()) + + with pytest.raises(Exception) as excinfo: + datamate_tool._retrieve_knowledge_base_content( + "query", + ["kb1"], + top_k=3, + threshold=0.2, + ) + + assert "Request error while retrieving knowledge base content" in str(excinfo.value) + + +class TestForward: + def _setup_success_flow(self, mocker: MockFixture, tool: DataMateSearchTool): + # Mock knowledge base list + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + kb_response = MagicMock() + kb_response.status_code = 200 + kb_response.json.return_value = _build_kb_list_response(["kb1"]) + + search_response = MagicMock() + search_response.status_code = 200 + search_response.json.return_value = _build_search_response("kb1", count=2) + + # First call for list, second for retrieve + client.post.side_effect = [kb_response, search_response] + return client + + def test_forward_success_with_observer_en(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client = self._setup_success_flow(mocker, datamate_tool) + + result_json = datamate_tool.forward("test query", top_k=2, threshold=0.5) + results = json.loads(result_json) + + assert len(results) == 2 + # Check that observer received running prompt and card + datamate_tool.observer.add_message.assert_any_call( + "", ProcessType.TOOL, datamate_tool.running_prompt_en + ) + datamate_tool.observer.add_message.assert_any_call( + "", ProcessType.CARD, json.dumps([{"icon": "search", "text": "test query"}], ensure_ascii=False) + ) + # Check that search content message is added (payload content is not strictly validated here) + datamate_tool.observer.add_message.assert_any_call( + "", ProcessType.SEARCH_CONTENT, ANY + ) + assert datamate_tool.record_ops == 1 + len(results) + assert all(isinstance(item["index"], str) for item in results) + + # Ensure both list and retrieve endpoints were called + assert client.post.call_count == 2 + + def test_forward_success_with_observer_zh(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + datamate_tool.observer.lang = "zh" + self._setup_success_flow(mocker, datamate_tool) + + datamate_tool.forward("测试查询") + + datamate_tool.observer.add_message.assert_any_call( + "", ProcessType.TOOL, datamate_tool.running_prompt_zh + ) + + def test_forward_no_observer(self, mocker: MockFixture): + tool = DataMateSearchTool(server_ip="127.0.0.1", server_port=8080, observer=None) + self._setup_success_flow(mocker, tool) + + # Should not raise and should not call observer + result_json = tool.forward("query") + assert len(json.loads(result_json)) == 2 + + def test_forward_no_knowledge_bases(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + kb_response = MagicMock() + kb_response.status_code = 200 + kb_response.json.return_value = _build_kb_list_response([]) + client.post.return_value = kb_response + + result = datamate_tool.forward("query") + assert result == json.dumps("No knowledge base found. No relevant information found.", ensure_ascii=False) + + def test_forward_no_results(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + client_cls = mocker.patch("sdk.nexent.core.tools.datamate_search_tool.httpx.Client") + client = client_cls.return_value.__enter__.return_value + + kb_response = MagicMock() + kb_response.status_code = 200 + kb_response.json.return_value = _build_kb_list_response(["kb1"]) + + search_response = MagicMock() + search_response.status_code = 200 + search_response.json.return_value = {"data": []} + + client.post.side_effect = [kb_response, search_response] + + with pytest.raises(Exception) as excinfo: + datamate_tool.forward("query") + + assert "No results found!" in str(excinfo.value) + + def test_forward_wrapped_error(self, mocker: MockFixture, datamate_tool: DataMateSearchTool): + # Simulate error in underlying method to verify top-level error wrapping + mocker.patch.object( + datamate_tool, + "_get_knowledge_base_list", + side_effect=Exception("low level error"), + ) + + with pytest.raises(Exception) as excinfo: + datamate_tool.forward("query") + + msg = str(excinfo.value) + assert "Error during DataMate knowledge base search" in msg + assert "low level error" in msg + +