Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,32 @@
import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse;
import modelengine.fel.community.model.openai.entity.image.OpenAiImageRequest;
import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse;
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankRequest;
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankResponse;
import modelengine.fel.community.model.openai.enums.ModelProcessingState;
import modelengine.fel.community.model.openai.util.HttpUtils;
import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.ChatModel;
import modelengine.fel.core.chat.ChatOption;
import modelengine.fel.core.chat.Prompt;
import modelengine.fel.core.chat.support.AiMessage;
import modelengine.fel.core.document.MeasurableDocument;
import modelengine.fel.core.embed.EmbedModel;
import modelengine.fel.core.embed.EmbedOption;
import modelengine.fel.core.embed.Embedding;
import modelengine.fel.core.image.ImageModel;
import modelengine.fel.core.image.ImageOption;
import modelengine.fel.core.model.http.SecureConfig;
import modelengine.fel.core.rerank.RerankModel;
import modelengine.fel.core.rerank.RerankOption;
import modelengine.fit.http.client.HttpClassicClient;
import modelengine.fit.http.client.HttpClassicClientFactory;
import modelengine.fit.http.client.HttpClassicClientRequest;
import modelengine.fit.http.client.HttpClassicClientResponse;
import modelengine.fit.http.entity.Entity;
import modelengine.fit.http.entity.ObjectEntity;
import modelengine.fit.http.protocol.HttpRequestMethod;
import modelengine.fit.http.protocol.HttpResponseStatus;
import modelengine.fit.security.Decryptor;
import modelengine.fitframework.annotation.Component;
import modelengine.fitframework.annotation.Fit;
Expand Down Expand Up @@ -69,7 +76,7 @@
* @since 2024-08-07
*/
@Component
public class OpenAiModel implements EmbedModel, ChatModel, ImageModel {
public class OpenAiModel implements EmbedModel, ChatModel, ImageModel, RerankModel {
private static final Logger log = Logger.get(OpenAiModel.class);
private static final Map<String, Boolean> HTTPS_CONFIG_KEY_MAPS = MapBuilder.<String, Boolean>get()
.put("client.http.secure.ignore-trust", Boolean.FALSE)
Expand Down Expand Up @@ -168,6 +175,43 @@ public List<Media> generate(String prompt, ImageOption option) {
}
}

@Override
public List<MeasurableDocument> generate(List<MeasurableDocument> documents, RerankOption rerankOption) {
notEmpty(documents, "The documents cannot be empty.");
notNull(rerankOption, "The rerank option cannot be null.");
String modelSource = StringUtils.blankIf(rerankOption.baseUri(), this.baseUrl);
HttpClassicClientRequest request = this.httpClient.get()
Comment thread
CodeCasterX marked this conversation as resolved.
Outdated
.createRequest(HttpRequestMethod.POST,
UrlUtils.combine(modelSource, OpenAiApi.RERANK_ENDPOINT));
HttpUtils.setBearerAuth(request, StringUtils.blankIf(rerankOption.apiKey(), this.defaultApiKey));
List<String> docs = documents.stream().map(MeasurableDocument::text).collect(Collectors.toList());
OpenAiRerankRequest fields = new OpenAiRerankRequest(rerankOption, docs);
request.entity(Entity.createObject(request, fields));
OpenAiRerankResponse rerankResponse = this.rerankExchange(request);

return rerankResponse.results()
.stream()
.map(result -> new MeasurableDocument(documents.get(result.index()), result.relevanceScore()))
.sorted((document1, document2) -> (int) (document2.score() - document1.score()))
.collect(Collectors.toList());
}

/**
 * Executes the rerank request and extracts the response object.
 *
 * @param request the prepared rerank HTTP request.
 * @return the response body cast to {@link OpenAiRerankResponse}.
 * @throws FitException when the status code is not OK or the response carries no object entity.
 * @throws IllegalStateException when an {@link IOException} is raised while handling the response.
 */
private OpenAiRerankResponse rerankExchange(HttpClassicClientRequest request) {
    try (HttpClassicClientResponse<Object> response = request.exchange(OpenAiRerankResponse.class)) {
        if (response.statusCode() == HttpResponseStatus.OK.statusCode()) {
            Object body = response.objectEntity()
                    .map(ObjectEntity::object)
                    .orElseThrow(() -> new FitException("The response body is abnormal."));
            return ObjectUtils.cast(body);
        }
        // Any non-OK status is logged with its code and reason before failing the call.
        log.error("Failed to get rerank model response. [code={}, reason={}]",
                response.statusCode(),
                response.reasonPhrase());
        throw new FitException("Failed to get rerank model response.");
    } catch (IOException e) {
        throw new IllegalStateException("Failed to request rerank model.", e);
    }
}

private Choir<ChatMessage> createChatStream(HttpClassicClientRequest request) {
AtomicReference<ModelProcessingState> modelProcessingState =
new AtomicReference<>(ModelProcessingState.INITIAL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ public interface OpenAiApi {
*/
String IMAGE_ENDPOINT = "/images/generations";

/**
 * The endpoint for rerank requests.
 */
String RERANK_ENDPOINT = "/rerank";

/**
* 请求头模型密钥字段。
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,35 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.core.document.support;
package modelengine.fel.community.model.openai.entity.rerank;

import modelengine.fel.core.rerank.RerankOption;
import modelengine.fitframework.annotation.Property;
import modelengine.fitframework.inspection.Validation;
import modelengine.fitframework.serialization.annotation.SerializeStrategy;

import java.util.List;

/**
* 表示 Rerank API 格式的请求
* 表示 OpenAI API 格式的重排请求
*
* @since 2024-09-27
*/
@SerializeStrategy(include = SerializeStrategy.Include.NON_NULL)
public class RerankRequest {
public class OpenAiRerankRequest {
private final String model;
private final String query;
private final List<String> documents;
@Property(name = "top_n")
private final Integer topN;

/**
* 创建 {@link RerankRequest} 的实体。
* 创建 {@link OpenAiRerankRequest} 的实体。
*
* @param documents 表示要重新排序的文档对象。
* @param rerankOption 表示 rerank 模型参数。
*/
public RerankRequest(RerankOption rerankOption, List<String> documents) {
public OpenAiRerankRequest(RerankOption rerankOption, List<String> documents) {
Validation.notNull(rerankOption, "The rerankOption cannot be null.");
this.model = rerankOption.model();
this.query = rerankOption.query();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.core.document.support;
package modelengine.fel.community.model.openai.entity.rerank;

import modelengine.fitframework.annotation.Property;
import modelengine.fitframework.util.CollectionUtils;
Expand All @@ -13,25 +13,25 @@
import java.util.List;

/**
* 表示 Rerank API 格式的请求
* 表示 OpenAI API 格式的重排响应
*
* @since 2024-09-27
*/
public class RerankResponse {
private List<RerankOrder> results;
public class OpenAiRerankResponse {
private List<OpenAiRerankResponse.RerankOrder> results;

/**
* 获取重新排序后的文档列表。
*
* @return 表示重新排序后的文档列表的 {@link List}{@code <}{@link RerankOrder}{@code >}。
* @return 表示重新排序后的文档列表的 {@link List}{@code <}{@link OpenAiRerankResponse.RerankOrder}{@code >}。
*/
public List<RerankOrder> results() {
public List<OpenAiRerankResponse.RerankOrder> results() {
return CollectionUtils.isEmpty(this.results)
? Collections.emptyList()
: Collections.unmodifiableList(this.results);
}

static class RerankOrder {
public static class RerankOrder {
Comment thread
CodeCasterX marked this conversation as resolved.
private int index;
@Property(name = "relevance_score")
private double relevanceScore;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
package modelengine.fel.community.model.openai;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertAll;

import modelengine.fel.community.model.openai.config.OpenAiConfig;
import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.ChatOption;
import modelengine.fel.core.chat.support.ChatMessages;
import modelengine.fel.core.chat.support.HumanMessage;
import modelengine.fel.core.document.Document;
import modelengine.fel.core.document.MeasurableDocument;
import modelengine.fel.core.embed.EmbedOption;
import modelengine.fel.core.embed.Embedding;
import modelengine.fel.core.image.ImageOption;
import modelengine.fel.core.rerank.RerankOption;
import modelengine.fit.http.client.HttpClassicClientFactory;
import modelengine.fitframework.annotation.Fit;
import modelengine.fitframework.conf.Config;
Expand All @@ -31,6 +35,8 @@
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -42,6 +48,9 @@
@MvcTest(classes = TestModelController.class)
public class OpenAiModelTest {
private OpenAiModel openAiModel;
private static final int EXPECTED_TOP_K = 3;
private static final String HIGHEST_RANKED_TEXT = "C++ offers high performance.";
private static final double EXPECTED_HIGHEST_SCORE = 0.999071;
Comment thread
CodeCasterX marked this conversation as resolved.

@Fit
private HttpClassicClientFactory httpClientFactory;
Expand Down Expand Up @@ -91,4 +100,45 @@ void testOpenAiImageModel() {
"456",
"789");
}

@Test
@DisplayName("测试重排模型返回:应返回按相关性排序的前 K 个文档")
void testOpenAiRerankModel() {
// Given: 准备输入文档
List<MeasurableDocument> inputDocs = Arrays.asList(doc("0", "Java is a programming language."),
doc("1", "Python is great for data science."),
doc("2", HIGHEST_RANKED_TEXT),
doc("3", "Rust offers high performance."),
doc("4", "C offers high performance."));

RerankOption rerankOption = RerankOption.custom().model("rerank-model").build();

// When: 调用重排接口
List<MeasurableDocument> result = openAiModel.generate(inputDocs, rerankOption);

// Then: 验证结果
assertAll(() -> assertThat(result).as("应返回 top-%d 结果", EXPECTED_TOP_K).hasSize(EXPECTED_TOP_K),
Comment thread
CodeCasterX marked this conversation as resolved.

() -> {
List<Double> scores = result.stream().map(MeasurableDocument::score).collect(Collectors.toList());
assertThat(scores).as("结果应按相关性分数降序排列").isSortedAccordingTo(Collections.reverseOrder());
},

() -> {
List<String> resultTexts =
result.stream().map(MeasurableDocument::text).collect(Collectors.toList());
List<String> inputTexts =
inputDocs.stream().map(MeasurableDocument::text).collect(Collectors.toList());
assertThat(inputTexts).as("所有返回文档必须来自输入集").containsAll(resultTexts);
},

() -> assertThat(result.get(0).text()).as("得分最高的文档应为 C++").isEqualTo(HIGHEST_RANKED_TEXT),

() -> assertThat(result.get(0).score()).as("最高分应与模拟响应一致").isEqualTo(EXPECTED_HIGHEST_SCORE));
}

/**
 * Builds a {@link MeasurableDocument} with the given id and text, an empty metadata map and a
 * score of 0.0.
 *
 * @param id the document id.
 * @param text the document text.
 * @return the created {@link MeasurableDocument}.
 */
private MeasurableDocument doc(String id, String text) {
    return new MeasurableDocument(
            Document.custom().id(id).text(text).metadata(new HashMap<>()).build(),
            0.0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import static modelengine.fel.community.model.openai.api.OpenAiApi.CHAT_ENDPOINT;
import static modelengine.fel.community.model.openai.api.OpenAiApi.EMBEDDING_ENDPOINT;
import static modelengine.fel.community.model.openai.api.OpenAiApi.IMAGE_ENDPOINT;
import static modelengine.fel.community.model.openai.api.OpenAiApi.RERANK_ENDPOINT;

import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse;
import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse;
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankResponse;
import modelengine.fit.http.annotation.PostMapping;
import modelengine.fitframework.annotation.Component;
import modelengine.fitframework.flowable.Choir;
Expand Down Expand Up @@ -81,4 +83,17 @@ public OpenAiImageResponse image() {
+ "\"data\":[{\"b64_json\":\"123\"}, {\"b64_json\":\"456\"}, {\"b64_json\":\"789\"}]}";
return this.serializer.deserialize(json, OpenAiImageResponse.class);
}

/**
 * Rerank endpoint used for testing; returns a fixed response with three results.
 *
 * @return the rerank response as an {@link OpenAiRerankResponse}.
 */
@PostMapping(RERANK_ENDPOINT)
public OpenAiRerankResponse rerank() {
    // One result per line; the concatenated JSON is the same fixed payload.
    String json = "{\"results\":["
            + "{\"index\":2,\"relevance_score\":0.999071},"
            + "{\"index\":3,\"relevance_score\":0.7867867},"
            + "{\"index\":0,\"relevance_score\":0.32713068}"
            + "]}";
    return this.serializer.deserialize(json, OpenAiRerankResponse.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,32 @@

import modelengine.fel.core.document.DocumentPostProcessor;
import modelengine.fel.core.document.MeasurableDocument;
import modelengine.fit.http.client.HttpClassicClient;
import modelengine.fit.http.client.HttpClassicClientFactory;
import modelengine.fit.http.client.HttpClassicClientRequest;
import modelengine.fit.http.client.HttpClassicClientResponse;
import modelengine.fit.http.entity.Entity;
import modelengine.fit.http.entity.ObjectEntity;
import modelengine.fit.http.protocol.HttpRequestMethod;
import modelengine.fit.http.protocol.HttpResponseStatus;
import modelengine.fitframework.exception.FitException;
import modelengine.fel.core.rerank.RerankModel;
import modelengine.fel.core.rerank.RerankOption;
import modelengine.fitframework.inspection.Validation;
import modelengine.fitframework.log.Logger;
import modelengine.fitframework.resource.UrlUtils;
import modelengine.fitframework.util.CollectionUtils;
import modelengine.fitframework.util.LazyLoader;
import modelengine.fitframework.util.ObjectUtils;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
* 表示检索文档的后置重排序接口。
*
* @since 2024-09-14
*/
public class RerankDocumentProcessor implements DocumentPostProcessor {
private static final Logger log = Logger.get(RerankDocumentProcessor.class);

private final LazyLoader<HttpClassicClient> httpClient;
private final RerankOption rerankOption;
private final RerankModel rerankModel;

/**
* 创建 {@link RerankDocumentProcessor} 的实体。
*
* @param httpClientFactory 表示 {@link HttpClassicClientFactory} 的实例
* @param rerankOption 表示 rerank 模型参数的 {@link RerankOption}
* @param rerankOption 表示 rerank 模型参数的 {@link RerankOption}
Comment thread
CodeCasterX marked this conversation as resolved.
Outdated
* @param rerankModel 表示 rerank 模型接口的 {@link RerankModel}。
*/
public RerankDocumentProcessor(HttpClassicClientFactory httpClientFactory, RerankOption rerankOption) {
Validation.notNull(httpClientFactory, "The httpClientFactory cannot be null.");
this.httpClient =
new LazyLoader<>(() -> httpClientFactory.create(HttpClassicClientFactory.Config.builder().build()));
public RerankDocumentProcessor(RerankOption rerankOption, RerankModel rerankModel) {
this.rerankOption = Validation.notNull(rerankOption, "The rerankOption cannot be null.");
this.rerankModel = Validation.notNull(rerankModel, "The rerankModel cannot be null.");
}

/**
Expand All @@ -63,35 +46,6 @@ public List<MeasurableDocument> process(List<MeasurableDocument> documents) {
if (CollectionUtils.isEmpty(documents)) {
return Collections.emptyList();
}
List<String> docs = documents.stream().map(MeasurableDocument::text).collect(Collectors.toList());
RerankRequest fields = new RerankRequest(this.rerankOption, docs);

HttpClassicClientRequest request = this.httpClient.get()
.createRequest(HttpRequestMethod.POST,
UrlUtils.combine(this.rerankOption.baseUri(), RerankApi.RERANK_ENDPOINT));
request.entity(Entity.createObject(request, fields));
RerankResponse rerankResponse = this.rerankExchange(request);

return rerankResponse.results()
.stream()
.map(result -> new MeasurableDocument(documents.get(result.index()), result.relevanceScore()))
.sorted((document1, document2) -> (int) (document2.score() - document1.score()))
.collect(Collectors.toList());
}

private RerankResponse rerankExchange(HttpClassicClientRequest request) {
try (HttpClassicClientResponse<Object> response = request.exchange(RerankResponse.class)) {
if (response.statusCode() != HttpResponseStatus.OK.statusCode()) {
log.error("Failed to get rerank model response. [code={}, reason={}]",
response.statusCode(),
response.reasonPhrase());
throw new FitException("Failed to get rerank model response.");
}
return ObjectUtils.cast(response.objectEntity()
.map(ObjectEntity::object)
.orElseThrow(() -> new FitException("The response body is abnormal.")));
} catch (IOException e) {
throw new IllegalStateException("Failed to request rerank model.", e);
}
return rerankModel.generate(documents, rerankOption);
}
}
Loading