From bcc0f9caa98ac611d2fe2e4a328d709bda6ee43e Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:03:08 +0800 Subject: [PATCH] (improvement)(Chat) llmSqlParser is adapted for tag mode, and rule parsing filters based on the dataset query type. (#804) --- .../chat/core/config/LLMParserConfig.java | 3 + .../chat/core/parser/QueryTypeParser.java | 7 +-- .../parser/sql/llm/LLMRequestService.java | 57 ++++++++++++------- .../core/parser/sql/llm/LLMSqlParser.java | 9 +-- .../parser/sql/rule/AgentCheckParser.java | 19 +------ .../core/parser/sql/rule/RuleSqlParser.java | 5 +- .../chat/core/pojo/QueryContext.java | 17 ++++-- common/pom.xml | 6 +- .../tencent/supersonic/ChatDemoLoader.java | 8 +-- .../com/tencent/supersonic/chat/TagTest.java | 2 +- pom.xml | 5 ++ 11 files changed, 72 insertions(+), 66 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java index 5761cbb59..9b078ce8b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/LLMParserConfig.java @@ -22,6 +22,9 @@ public class LLMParserConfig { @Value("${metric.topn:10}") private Integer metricTopN; + @Value("${tag.topn:20}") + private Integer tagTopN; + @Value("${all.model:false}") private Boolean allModel; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/QueryTypeParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/QueryTypeParser.java index 55232470c..ba76579f9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/QueryTypeParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/QueryTypeParser.java @@ -1,9 +1,7 @@ package com.tencent.supersonic.chat.core.parser; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.pojo.DataSetSchema; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.query.SemanticQuery; @@ -27,10 +25,7 @@ public void parse(QueryContext queryContext, ChatContext chatContext) { semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user); // 2.set queryType SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); - Long dataSetId = parseInfo.getDataSetId(); - DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); - parseInfo.setQueryType(dataSetSchema.getQueryType()); + parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId())); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java index 18852d1a1..c5653f6c4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java @@ -2,6 +2,7 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; @@ -92,9 +93,8 @@ public NL2SQLTool getParserTool(QueryContext queryCtx, Long dataSetId) { return llmParserTool.orElse(null); } - public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, - SemanticSchema semanticSchema, List linkingValues) { - Map dataSetIdToName = semanticSchema.getDataSetIdToName(); + public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, List linkingValues) { + Map dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName(); String queryText = queryCtx.getQueryText(); LLMReq llmReq = new LLMReq(); @@ -190,7 +190,8 @@ protected List getValueList(QueryContext queryCtx, Long dataSetId) .filter(elementMatch -> !elementMatch.isInherited()) .filter(schemaElementMatch -> { SchemaElementType type = schemaElementMatch.getElement().getType(); - return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type); + return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type) + || SchemaElementType.ID.equals(type); }) .map(elementMatch -> { ElementValue elementValue = new ElementValue(); @@ -203,25 +204,38 @@ protected List getValueList(QueryContext queryCtx, Long dataSetId) protected Map getItemIdToName(QueryContext queryCtx, Long dataSetId) { SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - return semanticSchema.getDimensions(dataSetId).stream() + List elements = semanticSchema.getDimensions(dataSetId); + if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) { + elements = semanticSchema.getTags(dataSetId); + } + return elements.stream() .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } private Set getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) { SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - Set results = semanticSchema.getDimensions(dataSetId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getDimensionTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - Set metrics = semanticSchema.getMetrics(dataSetId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getMetricTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - results.addAll(metrics); + Set results = new HashSet<>(); + if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) { + Set tags = semanticSchema.getTags(dataSetId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getDimensionTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + results.addAll(tags); + } else { + Set dimensions = semanticSchema.getDimensions(dataSetId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getDimensionTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + results.addAll(dimensions); + Set metrics = semanticSchema.getMetrics(dataSetId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getMetricTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + results.addAll(metrics); + } return results; } @@ -236,12 +250,15 @@ protected Set getMatchedFieldNames(QueryContext queryCtx, Long dataSetId SchemaElementType elementType = schemaElementMatch.getElement().getType(); return SchemaElementType.METRIC.equals(elementType) || SchemaElementType.DIMENSION.equals(elementType) - || SchemaElementType.VALUE.equals(elementType); + || SchemaElementType.VALUE.equals(elementType) + || SchemaElementType.TAG.equals(elementType) + || SchemaElementType.TAG_VALUE.equals(elementType); }) .map(schemaElementMatch -> { SchemaElement element = schemaElementMatch.getElement(); SchemaElementType elementType = element.getType(); - if (SchemaElementType.VALUE.equals(elementType)) { + if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals( + elementType)) { return itemIdToName.get(element.getId()); } return schemaElementMatch.getWord(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMSqlParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMSqlParser.java index 11714f676..211c810d1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMSqlParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMSqlParser.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.core.parser.sql.llm; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.core.agent.NL2SQLTool; import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.pojo.ChatContext; @@ -10,12 +9,11 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.common.util.ContextUtils; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections4.MapUtils; - import java.util.List; import java.util.Map; import java.util.Objects; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; @Slf4j public class LLMSqlParser implements SemanticParser { @@ -41,8 +39,7 @@ public void parse(QueryContext queryCtx, ChatContext chatCtx) { } //4.construct a request, call the API for the large model, and retrieve the results. List linkingValues = requestService.getValueList(queryCtx, dataSetId); - SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); - LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues); + LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, linkingValues); LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId); if (Objects.isNull(llmResp)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java index f1f2533c5..39a94bd8f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java @@ -8,14 +8,11 @@ import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.QueryContext; -import com.tencent.supersonic.chat.core.query.QueryManager; import com.tencent.supersonic.chat.core.query.SemanticQuery; -import com.tencent.supersonic.common.pojo.enums.QueryType; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; - import java.util.List; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; @Slf4j public class AgentCheckParser implements SemanticParser { @@ -46,18 +43,6 @@ private void filterQueries(QueryContext queryContext, List querie && !tool.getQueryModes().contains(query.getQueryMode())) { return true; } - if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) { - if (QueryManager.isTagQuery(query.getQueryMode())) { - if (!tool.getQueryTypes().contains(QueryType.TAG.name())) { - return true; - } - } - if (QueryManager.isMetricQuery(query.getQueryMode())) { - if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) { - return true; - } - } - } if (CollectionUtils.isEmpty(tool.getDataSetIds())) { return true; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java index 6dae880a5..b176e2ca9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java @@ -1,15 +1,14 @@ package com.tencent.supersonic.chat.core.parser.sql.rule; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery; -import lombok.extern.slf4j.Slf4j; - import java.util.Arrays; import java.util.List; +import lombok.extern.slf4j.Slf4j; /** * RuleSqlParser resolves a specific SemanticQuery according to co-appearance diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java index c8ca91334..d8bd28c53 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/pojo/QueryContext.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.pojo.DataSetSchema; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; @@ -10,17 +11,17 @@ import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.plugin.Plugin; import com.tencent.supersonic.chat.core.query.SemanticQuery; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.util.ContextUtils; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; @Data @Builder @@ -58,4 +59,10 @@ public List getCandidateQueries() { .collect(Collectors.toList()); return candidateQueries; } + + public QueryType getQueryType(Long dataSetId) { + SemanticSchema semanticSchema = this.semanticSchema; + DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); + return dataSetSchema.getQueryType(); + } } diff --git a/common/pom.xml b/common/pom.xml index 887db13b4..aed6cc3eb 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -180,7 +180,10 @@ dev.langchain4j langchain4j-chroma - + + dev.langchain4j + langchain4j-azure-open-ai + org.apache.logging.log4j log4j-api @@ -195,7 +198,6 @@ hanlp ${hanlp.version} - diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java index 02e6108f1..537311f30 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java @@ -20,9 +20,10 @@ import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.chat.server.service.QueryService; import com.tencent.supersonic.common.pojo.SysParameter; -import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.service.SysParameterService; import com.tencent.supersonic.common.util.JsonUtil; +import java.util.Arrays; +import java.util.List; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -32,9 +33,6 @@ import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import java.util.Arrays; -import java.util.List; - @Component @Slf4j @Order(3) @@ -170,7 +168,6 @@ private void addAgent1() { ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setId("0"); ruleQueryTool.setDataSetIds(Lists.newArrayList(1L)); - ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name())); agentConfig.getTools().add(ruleQueryTool); if (demoEnabledNl2SqlLlm) { LLMParserTool llmParserTool = new LLMParserTool(); @@ -196,7 +193,6 @@ private void addAgent2() { ruleQueryTool.setId("0"); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setDataSetIds(Lists.newArrayList(2L)); - ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name())); agentConfig.getTools().add(ruleQueryTool); if (demoEnabledNl2SqlLlm) { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java index 1ee5fbd4c..23cf3d0ed 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java @@ -58,7 +58,7 @@ public void queryTest_tag_list_filter() throws Exception { List list = new ArrayList<>(); list.add("流行"); QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, - "流行", "风格", 6L); + "流行", "风格", 2L); expectedParseInfo.getDimensionFilters().add(dimensionFilter); SchemaElement metric = SchemaElement.builder().name("播放量").build(); diff --git a/pom.xml b/pom.xml index 4812b16f1..817d896d0 100644 --- a/pom.xml +++ b/pom.xml @@ -149,6 +149,11 @@ langchain4j-embeddings-bge-small-zh ${langchain4j.version} + + dev.langchain4j + langchain4j-azure-open-ai + ${langchain4j.version} +