Skip to content

Commit

Permalink
(improvement)(Chat) llmSqlParser is adapted for tag mode, and rule pa…
Browse files Browse the repository at this point in the history
…rsing filters based on the dataset query type. (#804)
  • Loading branch information
lexluo09 authored Mar 12, 2024
1 parent c2316c9 commit bcc0f9c
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public class LLMParserConfig {
@Value("${metric.topn:10}")
private Integer metricTopN;

@Value("${tag.topn:20}")
private Integer tagTopN;

@Value("${all.model:false}")
private Boolean allModel;
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package com.tencent.supersonic.chat.core.parser;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
Expand All @@ -27,10 +25,7 @@ public void parse(QueryContext queryContext, ChatContext chatContext) {
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
// 2.set queryType
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Long dataSetId = parseInfo.getDataSetId();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
parseInfo.setQueryType(dataSetSchema.getQueryType());
parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
Expand Down Expand Up @@ -92,9 +93,8 @@ public NL2SQLTool getParserTool(QueryContext queryCtx, Long dataSetId) {
return llmParserTool.orElse(null);
}

public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, List<ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
String queryText = queryCtx.getQueryText();

LLMReq llmReq = new LLMReq();
Expand Down Expand Up @@ -190,7 +190,8 @@ protected List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId)
.filter(elementMatch -> !elementMatch.isInherited())
.filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type)
|| SchemaElementType.ID.equals(type);
})
.map(elementMatch -> {
ElementValue elementValue = new ElementValue();
Expand All @@ -203,25 +204,38 @@ protected List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId)

protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
return semanticSchema.getDimensions(dataSetId).stream()
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
elements = semanticSchema.getTags(dataSetId);
}
return elements.stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}

private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());

Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());

results.addAll(metrics);
Set<String> results = new HashSet<>();
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
Set<String> tags = semanticSchema.getTags(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(tags);
} else {
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(dimensions);
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
}
return results;
}

Expand All @@ -236,12 +250,15 @@ protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType)
|| SchemaElementType.DIMENSION.equals(elementType)
|| SchemaElementType.VALUE.equals(elementType);
|| SchemaElementType.VALUE.equals(elementType)
|| SchemaElementType.TAG.equals(elementType)
|| SchemaElementType.TAG_VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElement element = schemaElementMatch.getElement();
SchemaElementType elementType = element.getType();
if (SchemaElementType.VALUE.equals(elementType)) {
if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals(
elementType)) {
return itemIdToName.get(element.getId());
}
return schemaElementMatch.getWord();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;

import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
Expand All @@ -10,12 +9,11 @@
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;

@Slf4j
public class LLMSqlParser implements SemanticParser {
Expand All @@ -41,8 +39,7 @@ public void parse(QueryContext queryCtx, ChatContext chatCtx) {
}
//4.construct a request, call the API for the large model, and retrieve the results.
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);

if (Objects.isNull(llmResp)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;

import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;

@Slf4j
public class AgentCheckParser implements SemanticParser {
Expand Down Expand Up @@ -46,18 +43,6 @@ private void filterQueries(QueryContext queryContext, List<SemanticQuery> querie
&& !tool.getQueryModes().contains(query.getQueryMode())) {
return true;
}
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
if (QueryManager.isTagQuery(query.getQueryMode())) {
if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
return true;
}
}
if (QueryManager.isMetricQuery(query.getQueryMode())) {
if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
return true;
}
}
}
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;

import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;

import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;

/**
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
Expand All @@ -10,17 +11,17 @@
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@Builder
Expand Down Expand Up @@ -58,4 +59,10 @@ public List<SemanticQuery> getCandidateQueries() {
.collect(Collectors.toList());
return candidateQueries;
}

public QueryType getQueryType(Long dataSetId) {
SemanticSchema semanticSchema = this.semanticSchema;
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
return dataSetSchema.getQueryType();
}
}
6 changes: 4 additions & 2 deletions common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-chroma</artifactId>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
Expand All @@ -195,7 +198,6 @@
<artifactId>hanlp</artifactId>
<version>${hanlp.version}</version>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
Expand All @@ -32,9 +33,6 @@
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.Arrays;
import java.util.List;

@Component
@Slf4j
@Order(3)
Expand Down Expand Up @@ -170,7 +168,6 @@ private void addAgent1() {
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
Expand All @@ -196,7 +193,6 @@ private void addAgent2() {
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(2L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool);

if (demoEnabledNl2SqlLlm) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void queryTest_tag_list_filter() throws Exception {
List<String> list = new ArrayList<>();
list.add("流行");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"流行", "风格", 6L);
"流行", "风格", 2L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);

SchemaElement metric = SchemaElement.builder().name("播放量").build();
Expand Down
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down

0 comments on commit bcc0f9c

Please sign in to comment.