diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 834bcfcd92..0908f57c53 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,6 @@ import org.springframework.ai.embedding.AbstractEmbeddingModel; import org.springframework.ai.embedding.Embedding; import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -110,12 +109,16 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataM @Override public EmbeddingResponse call(EmbeddingRequest request) { - var apiRequest = createRequest(request); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + + var apiRequest = createRequest(embeddingRequest); var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) .provider(MistralAiApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(apiRequest)) + .requestOptions(embeddingRequest.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -146,20 +149,29 @@ public EmbeddingResponse call(EmbeddingRequest request) { }); } + private EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + MistralAiEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + MistralAiEmbeddingOptions.class); + } + + // Define request options by merging runtime options and default options + MistralAiEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + MistralAiEmbeddingOptions.class); + + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); + } + private DefaultUsage getDefaultUsage(MistralAiApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } - @SuppressWarnings("unchecked") private MistralAiApi.EmbeddingRequest> createRequest(EmbeddingRequest request) { - var embeddingRequest = new MistralAiApi.EmbeddingRequest<>(request.getInstructions(), - this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat()); - - if (request.getOptions() != null) { - embeddingRequest = ModelOptionsUtils.merge(request.getOptions(), embeddingRequest, - MistralAiApi.EmbeddingRequest.class); - } - return embeddingRequest; + MistralAiEmbeddingOptions requestOptions = (MistralAiEmbeddingOptions) request.getOptions(); + return new MistralAiApi.EmbeddingRequest<>(request.getInstructions(), requestOptions.getModel(), + requestOptions.getEncodingFormat()); } @Override @@ -168,10 +180,6 @@ public float[] embed(Document document) { return this.embed(document.getFormattedContent(this.metadataMode)); } - private EmbeddingOptions buildRequestOptions(MistralAiApi.EmbeddingRequest> request) { - return EmbeddingOptionsBuilder.builder().withModel(request.model()).build(); - } - /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention