Skip to content

Commit

Permalink
Merge pull request #1221 from Apollon77/entityfeat2
Browse files Browse the repository at this point in the history
More strict entity matching and better handle duplicate enum entity cases
  • Loading branch information
ericzon committed May 24, 2023
2 parents 0fb309a + dafb1f5 commit 5a9fe3b
Show file tree
Hide file tree
Showing 7 changed files with 491 additions and 15 deletions.
7 changes: 6 additions & 1 deletion packages/ner/src/extractor-builtin.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ class ExtractorBuiltin {
input.edges = input.edges || [];
if (newInput.edges) {
for (let i = 0; i < newInput.edges.length; i += 1) {
input.edges.push(newInput.edges[i]);
if (
!input.nerLimitToEntities ||
input.intentEntities.includes(newInput.edges[i].entity)
) {
input.edges.push(newInput.edges[i]);
}
}
}
input.edges = reduceEdges(input.edges, false);
Expand Down
2 changes: 1 addition & 1 deletion packages/ner/src/extractor-enum.js
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ class ExtractorEnum {
}
}
edges.sort((a, b) => a.start - b.start);
input.edges = reduceEdges(edges, false);
input.edges = reduceEdges(edges, false, input.intentEntities);
return input;
}

Expand Down
43 changes: 36 additions & 7 deletions packages/ner/src/ner.js
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,28 @@ class Ner extends Clonable {
return result;
}

decideRules(srcInput) {
decideRules(srcInput, intentEntities) {
const input = srcInput;
input.nerRules = this.getRules(input.locale || 'en');
let nerRules = this.getRules(input.locale || 'en');
if (intentEntities && this.settings.considerOnlyIntentEntities) {
nerRules = nerRules.filter((rule) => intentEntities.includes(rule.name));
} else if (intentEntities) {
// entities in the current intent get a higher priority when
// sorting out overlapping matches
const intentRelevantRule = [];
const nonIntentRelevantRule = [];
nerRules.forEach((rule) => {
if (intentEntities.includes(rule.name)) {
intentRelevantRule.push(rule);
} else {
nonIntentRelevantRule.push(rule);
}
});
nerRules = intentRelevantRule.concat(nonIntentRelevantRule);
}
input.nerRules = nerRules;
input.nerLimitToEntities = this.settings.considerOnlyIntentEntities;
input.intentEntities = intentEntities;
return input;
}

Expand Down Expand Up @@ -343,10 +362,12 @@ class Ner extends Clonable {
input.entities = input.edges;
delete input.edges;
delete input.nerRules;
delete input.nerLimitToEntities;
delete input.intentEntities;
return input;
}

async defaultPipelineProcess(input) {
async defaultPipelineProcess(input, intentEntities) {
if (!this.cache) {
this.cache = {
extractEnum: this.container.get('extract-enum'),
Expand All @@ -371,7 +392,7 @@ class Ner extends Clonable {
this.cache.extractBuiltin = this.container.get('extract-builtin');
}
}
let output = await this.decideRules(input);
let output = await this.decideRules(input, intentEntities);
if (this.cache.extractEnum) {
output = await this.cache.extractEnum.run(output);
}
Expand All @@ -388,8 +409,11 @@ class Ner extends Clonable {
return output;
}

async process(srcInput) {
const input = { threshold: this.settings.threshold || 0.8, ...srcInput };
async process(srcInput, consideredEntities) {
const input = {
threshold: this.settings.threshold || 0.8,
...srcInput,
};
let result;
if (input.locale) {
const pipeline = this.container.getPipeline(
Expand All @@ -402,7 +426,12 @@ class Ner extends Clonable {
result = await this.runPipeline(input, this.pipelineProcess);
}
if (!result) {
result = await this.defaultPipelineProcess(input);
result = await this.defaultPipelineProcess(input, consideredEntities);
} else if (consideredEntities) {
// when custom pipeline is used then we can not be sure it is handled correctly
result.entities = result.entities.filter((entity) =>
consideredEntities.includes(entity.entity)
);
}
delete result.threshold;
return result;
Expand Down
44 changes: 39 additions & 5 deletions packages/ner/src/reduce-edges.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
*/
const { TrimTypesList } = require('./trim-types');

function runDiscard(srcEdge, srcOther, useMaxLength) {
function runDiscard(srcEdge, srcOther, useMaxLength, intentEntities = []) {
let edge;
let other;
if (
Expand All @@ -44,15 +44,29 @@ function runDiscard(srcEdge, srcOther, useMaxLength) {
other.entity === 'number') &&
other.len <= edge.len
) {
// Do nothing! entities have same priority
// Entities have same priority
if (
other.start === edge.start &&
other.end === edge.end &&
other.type === edge.type &&
other.entity === edge.entity &&
other.option === edge.option
) {
// same type and none of them is an enum or both are an enum
other.discarded = true;
} else if (
other.start === edge.start &&
other.end === edge.end &&
other.entity === edge.entity &&
other.type !== edge.type
) {
if (edge.type === 'trim' && other.type !== 'trim') {
edge.discarded = true;
} else if (edge.type !== 'trim' && other.type === 'trim') {
other.discarded = true;
} else {
other.discarded = true;
}
}
} else if (
(useMaxLength ||
Expand All @@ -62,7 +76,18 @@ function runDiscard(srcEdge, srcOther, useMaxLength) {
) {
edge.discarded = true;
} else if (edge.type === 'enum' && other.type === 'enum') {
if (
const edgeIncludedInIntentEntities = intentEntities.includes(edge.entity);
const otherIncludedInIntentEntities = intentEntities.includes(
other.entity
);
if (edgeIncludedInIntentEntities && !otherIncludedInIntentEntities) {
other.discarded = true;
} else if (
!edgeIncludedInIntentEntities &&
otherIncludedInIntentEntities
) {
edge.discarded = true;
} else if (
edge.len <= other.len &&
other.utteranceText.includes(edge.utteranceText)
) {
Expand Down Expand Up @@ -121,7 +146,7 @@ function splitEdges(edges) {
return edges;
}

function reduceEdges(edges, useMaxLength = true) {
function reduceEdges(edges, useMaxLength = true, intentEntities = []) {
edges = splitEdges(edges);
const edgeslen = edges.length;
for (let i = 0; i < edgeslen; i += 1) {
Expand All @@ -133,8 +158,17 @@ function reduceEdges(edges, useMaxLength = true) {
for (let j = i + 1; j < edgeslen; j += 1) {
const other = edges[j];
if (!other.discarded) {
runDiscard(edge, other, useMaxLength);
runDiscard(edge, other, useMaxLength, intentEntities);
}
if (edge.discarded) {
break;
}
}
}
if (!edge.discarded) {
const knownEntityPos = intentEntities.indexOf(edge.entity);
if (knownEntityPos !== -1) {
intentEntities.splice(knownEntityPos, 1);
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion packages/nlp/src/nlp.js
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,10 @@ class Nlp extends Clonable {
}
output.context = context;
if (forceNER || !this.slotManager.isEmpty) {
output = await this.ner.process({ ...output });
const intentEntities = this.slotManager.getIntentEntityNames(
output.intent
);
output = await this.ner.process({ ...output }, intentEntities);
} else {
output.entities = [];
output.sourceEntities = [];
Expand Down
Loading

0 comments on commit 5a9fe3b

Please sign in to comment.