Skip to content

Commit b39de9e

Browse files
achoumcopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 387568961
1 parent 9a596dd commit b39de9e

File tree

7 files changed

+321
-29
lines changed

7 files changed

+321
-29
lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@
1717
threads without using the advanced configuration.
1818
- By default, remove the temporary directory used to train the model when the
1919
model python object is garbage collected.
20+
- Add the `import_dataspec` constructor argument to the model builder to
21+
import the feature definition and dictionaries (instead of relying on
22+
automatic discovery).
23+
24+
### Changes
25+
26+
- When saving a model in a directory already containing a model, only the
27+
`assets` directory is entirely removed before the export (instead of the
28+
entire model directory).
29+
30+
### Fixes
31+
32+
- Wrong label shape in the model inspector's objective field for
33+
pre-integerized labels.
2034

2135
## 0.1.7 - 2021-06-23
2236

tensorflow_decision_forests/component/builder/builder.py

Lines changed: 197 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@
2323
- CARTBuilder
2424
- GradientBoostedTreeBuilder
2525
26+
About categorical and categorical-set features with string dictionary:
27+
28+
Categorical and categorical-set features are tied to a dictionary of possible
29+
values. In addition, the special value "out-of-dictionary" (OOD) designates all
30+
the values which are not in the dictionary. For example, the condition
31+
"a in ["x","<OOD>"]" is true if the feature "a" is equal to "x" or to any value
32+
not in the dictionary.
33+
34+
The feature dictionaries are automatically assembled as the union of all the
35+
observed values in the tree conditions. Alternatively, dictionaries can be
36+
get/set manually with "{get,set}_dictionary()" or imported from an existing
37+
dataspec with the "import_dataspec" constructor argument.
38+
39+
2640
Usage:
2741
2842
```python
@@ -54,6 +68,34 @@
5468
neg_child=LeafNode(
5569
value=ProbabilityValue(probability=[0.8, 0.2])))))
5670
71+
# Create a second tree
72+
# f2 in ["x", "y"]
73+
# ├─(pos)─ [0.1, 0.9]
74+
# └─(neg)─ [0.8, 0.2]
75+
#
76+
builder.add_tree(
77+
Tree(
78+
NonLeafNode(
79+
condition=CategoricalIsInCondition(
80+
feature=SimpleColumnSpec(
81+
name="f2",
82+
type=py_tree.dataspec.ColumnType.CATEGORICAL),
83+
mask=["x", "y"],
84+
missing_evaluation=False),
85+
pos_child=LeafNode(
86+
value=ProbabilityValue(probability=[0.1, 0.9])),
87+
neg_child=LeafNode(
88+
value=ProbabilityValue(probability=[0.8, 0.2])))))
89+
90+
# Optionally set the dictionary of the categorical feature "f2".
91+
# If not set, all the values not seen in the model ("z" in this case) will not
92+
# be known by the model and will be treated as OOD (out of dictionary).
93+
#
94+
# Defining a dictionary only has an impact if a condition is testing for the
95+
# `<OOD>` item directly i.e. the test `f2 in ["<OOD>"]` depends on the content
96+
# of the dictionary.
97+
builder.set_dictionary("f2", ["<OOD>", "x", "y", "z"])
98+
5799
builder.close()
58100
59101
# Load and use the model
@@ -103,7 +145,8 @@ class AbstractBuilder(object):
103145
"""Generic model builder."""
104146

105147
def __init__(self, path: str, objective: py_tree.objective.AbstractObjective,
106-
model_format: Optional[ModelFormat]):
148+
model_format: Optional[ModelFormat],
149+
import_dataspec: Optional[data_spec_pb2.DataSpecification]):
107150

108151
if not path:
109152
raise ValueError("The path cannot be empty")
@@ -125,6 +168,9 @@ def __init__(self, path: str, objective: py_tree.objective.AbstractObjective,
125168

126169
tf.io.gfile.makedirs(self.yggdrasil_model_path())
127170

171+
if import_dataspec:
172+
self._import_dataspec(import_dataspec)
173+
128174
def close(self):
129175
"""Finalize the builder work.
130176
@@ -204,6 +250,98 @@ def objective(self) -> py_tree.objective.AbstractObjective:
204250

205251
return self._objective
206252

253+
def _import_dataspec(self, src_dataspec: data_spec_pb2.DataSpecification):
  """Copies the feature columns of an existing dataspec into this builder.

  This method is meant to run right after the builder is constructed, i.e.
  before any feature column was registered by another code path.

  Every column of "src_dataspec" is looked up by name (and created when
  missing) in the builder's dataspec. The column message is then copied as a
  whole, which carries over the feature name, its type and whatever
  dictionary or statistics the source column contains.

  What is NOT imported:
    - Column indices: a feature may end up at a different index than in
      "src_dataspec".
    - The label column, and the ranking group column when the objective is a
      ranking objective. Those columns are left as they are.

  Args:
    src_dataspec: Dataspec to import.

  Raises:
    ValueError: If a column to import (other than the label / ranking group)
      already existed in the builder's dataspec.
  """

  for src_col in src_dataspec.columns:
    dst_col_idx, created = self._get_or_create_column_idx(src_col.name)

    # The label and (for ranking objectives) the group columns are owned by
    # the builder and are never overwritten by the import.
    if dst_col_idx == self._header.label_col_idx or (
        isinstance(self._objective, py_tree.objective.RankingObjective) and
        dst_col_idx == self._header.ranking_group_col_idx):
      continue

    # A pre-existing feature column means the model was already partially
    # built.
    if not created:
      raise ValueError(
          "import_dataspec was called after some of the model was build. "
          "Make sure to call import_dataspec right after the model "
          "constructor.")

    # Overwrite the freshly created destination column with the source one.
    destination = self._dataspec.columns[dst_col_idx]
    destination.CopyFrom(src_col)
293+
294+
def _check_column_has_dictionary(self, column_spec: data_spec_pb2.Column):
  """Validates that a column spec can hold a (possibly empty) dictionary.

  Args:
    column_spec: Dataspec column to validate.

  Raises:
    ValueError: If the column is not of type CATEGORICAL or CATEGORICAL_SET,
      or if its categorical values are flagged as already integerized.
  """

  dictionary_types = (ColumnType.CATEGORICAL, ColumnType.CATEGORICAL_SET)
  is_dictionary_type = column_spec.type in dictionary_types
  if not is_dictionary_type:
    raise ValueError(
        f"The feature \"{column_spec.name}\" is neither a CATEGORICAL "
        "OR CATEGORICAL_SET feature")

  # Already integerized columns store integer values directly; there is no
  # string dictionary to read or write.
  already_integerized = column_spec.categorical.is_already_integerized
  if already_integerized:
    raise ValueError(
        f"The feature \"{column_spec.name}\" is already integerized "
        "and do not have a dictionary")
308+
309+
def get_dictionary(self, col_name: str) -> List[str]:
  """Returns the dictionary of a categorical(-set) string feature.

  Args:
    col_name: Name of the feature.

  Returns:
    The dictionary of the feature, as produced by
    "py_tree.dataspec.categorical_column_dictionary_to_list".

  Raises:
    ValueError: If the feature is unknown to the builder, or if its column
      cannot hold a dictionary (see "_check_column_has_dictionary").
  """

  column_index = self._dataspec_column_index.get(col_name)
  if column_index is not None:
    spec = self._dataspec.columns[column_index]
    self._check_column_has_dictionary(spec)
    return py_tree.dataspec.categorical_column_dictionary_to_list(spec)

  raise ValueError(f"Unknown feature \"{col_name}\"")
320+
321+
def set_dictionary(self, col_name: str, dictionary: List[str]) -> None:
  """Sets the dictionary of a categorical or categorical-set column.

  The previous dictionary of the column (if any) is discarded. The
  out-of-dictionary (OOD) item is always stored at index 0, wherever it
  appears in "dictionary". The remaining items are indexed from 1, in the
  order in which they appear in "dictionary".

  Args:
    col_name: Name of a feature already known by the builder.
    dictionary: Possible values of the feature. Must contain the
      out-of-dictionary item and must not contain duplicated values.

  Raises:
    ValueError: If the feature is unknown, if "dictionary" does not contain
      the out-of-dictionary item, if the column cannot hold a dictionary (see
      "_check_column_has_dictionary"), or if "dictionary" contains duplicated
      values.
  """

  ood_item = py_tree.dataspec.OUT_OF_DICTIONARY

  col_idx = self._dataspec_column_index.get(col_name)
  if col_idx is None:
    raise ValueError(f"Unknown feature \"{col_name}\"")

  if ood_item not in dictionary:
    raise ValueError(
        f"The dictionary should contain an \"{ood_item}\" value")

  column_spec = self._dataspec.columns[col_idx]
  self._check_column_has_dictionary(column_spec)

  # A duplicated value would be assigned two different indices (the last one
  # colliding with the index of a following item) and would inflate the
  # number of unique values.
  if len(set(dictionary)) != len(dictionary):
    raise ValueError(
        f"The dictionary of the feature \"{col_name}\" contains duplicated "
        "values")

  categorical = column_spec.categorical
  categorical.number_of_unique_values = len(dictionary)
  categorical.items.clear()

  # The OOD value should be the first one.
  categorical.items[ood_item].index = 0
  regular_items = [item for item in dictionary if item != ood_item]
  for index, item in enumerate(regular_items, start=1):
    categorical.items[item].index = index
344+
207345
def observe_feature(self,
208346
feature: inspector_lib.SimpleColumnSpec,
209347
categorical_values: Optional[Union[List[str],
@@ -253,7 +391,7 @@ def observe_feature(self,
253391
# The value is stored as a string.
254392
if created:
255393
# Create the out-of-vocabulary item.
256-
column.categorical.items["<OOV>"].index = 0
394+
column.categorical.items[py_tree.dataspec.OUT_OF_DICTIONARY].index = 0
257395
column.categorical.number_of_unique_values = 1
258396
for value in categorical_values:
259397
if value not in column.categorical.items:
@@ -291,22 +429,29 @@ def _initialize_header_column_idx(self):
291429
Should be called once before writing the header to disk.
292430
"""
293431

432+
assert not self._dataspec.columns
433+
294434
# The first column is the label.
295435
self._header.label_col_idx = 0
296436
label_column = self._dataspec.columns.add()
297437
label_column.name = self._objective.label
438+
self._dataspec_column_index[label_column.name] = self._header.label_col_idx
439+
298440
if isinstance(self._objective, py_tree.objective.ClassificationObjective):
299441
label_column.type = ColumnType.CATEGORICAL
300442

301443
# One value is reserved for the non-used OOV item.
302444
label_column.categorical.number_of_unique_values = self._objective.num_classes + 1
303445

304446
if not self._objective.has_integer_labels:
305-
label_column.categorical.items["<OOV>"].index = 0
447+
label_column.categorical.items[
448+
py_tree.dataspec.OUT_OF_DICTIONARY].index = 0
306449
for idx, value in enumerate(self._objective.classes):
307450
label_column.categorical.items[value].index = idx + 1
308451
assert len(label_column.categorical.items
309452
) == label_column.categorical.number_of_unique_values
453+
else:
454+
label_column.categorical.is_already_integerized = True
310455

311456
elif isinstance(self._objective, (py_tree.objective.RegressionObjective,
312457
py_tree.objective.RankingObjective)):
@@ -316,22 +461,29 @@ def _initialize_header_column_idx(self):
316461
raise NotImplementedError(f"No supported objective {self._objective}")
317462

318463
if isinstance(self._objective, py_tree.objective.RankingObjective):
464+
assert len(self._dataspec.columns) == 1
465+
319466
# Create the "group" column for Ranking.
320467
self._header.ranking_group_col_idx = 1
321468
group_column = self._dataspec.columns.add()
322469
group_column.type = ColumnType.HASH
323470
group_column.name = self._objective.group
471+
self._dataspec_column_index[
472+
group_column.name] = self._header.ranking_group_col_idx
324473

325474

326475
@six.add_metaclass(abc.ABCMeta)
327476
class AbstractDecisionForestBuilder(AbstractBuilder):
328477
"""Generic decision forest model builder."""
329478

330479
def __init__(self, path: str, objective: py_tree.objective.AbstractObjective,
331-
model_format: Optional[ModelFormat]):
480+
model_format: Optional[ModelFormat],
481+
import_dataspec: Optional[data_spec_pb2.DataSpecification]):
482+
483+
super(AbstractDecisionForestBuilder,
484+
self).__init__(path, objective, model_format, import_dataspec)
332485

333-
super(AbstractDecisionForestBuilder, self).__init__(path, objective,
334-
model_format)
486+
self._trees = []
335487

336488
num_node_shards = 1 # Store all the nodes in a single shard.
337489
self.specialized_header().num_node_shards = num_node_shards
@@ -344,6 +496,12 @@ def __init__(self, path: str, objective: py_tree.objective.AbstractObjective,
344496

345497
def close(self):
346498

499+
assert self.specialized_header().num_trees == len(self._trees)
500+
501+
for tree in self._trees:
502+
self._write_branch(tree.root)
503+
self._trees = []
504+
347505
# Write the model specialized header.
348506
_write_binary_proto(
349507
self.specialized_header(),
@@ -372,7 +530,8 @@ def specialized_header_filename(self) -> str:
372530
def add_tree(self, tree: py_tree.tree.Tree):
373531
"""Adds one tree to the model."""
374532

375-
self._write_branch(tree.root)
533+
self._observe_branch(tree.root)
534+
self._trees.append(tree)
376535
self.specialized_header().num_trees += 1
377536

378537
def check_leaf(self, node: py_tree.node.LeafNode):
@@ -385,13 +544,11 @@ def check_non_leaf(self, node: py_tree.node.NonLeafNode):
385544

386545
pass
387546

388-
def _write_branch(self, node: py_tree.node.AbstractNode):
389-
"""Write of a node and its children to the writer.
547+
def _observe_branch(self, node: py_tree.node.AbstractNode):
548+
"""Indexes the possible attribute values and checks the tree validity.
390549
391-
Nodes are written in a Depth First Pre-order traversals (as expected by the
392-
model format).
393-
394-
This function is the inverse of inspector_lib._extract_branch.
550+
This method should be called on all the trees before any calls to
551+
"_write_branch".
395552
396553
Args:
397554
node: The node to write.
@@ -410,6 +567,23 @@ def _write_branch(self, node: py_tree.node.AbstractNode):
410567
elif isinstance(node, py_tree.node.LeafNode):
411568
self.check_leaf(node)
412569

570+
# Recursive call on the children.
571+
if isinstance(node, py_tree.node.NonLeafNode):
572+
self._observe_branch(node.neg_child)
573+
self._observe_branch(node.pos_child)
574+
575+
def _write_branch(self, node: py_tree.node.AbstractNode):
576+
"""Writes a node and its children to the writer.
577+
578+
Nodes are written in a Depth First Pre-order traversal (as expected by the
579+
model format).
580+
581+
This function is the inverse of inspector_lib._extract_branch.
582+
583+
Args:
584+
node: The node to write.
585+
"""
586+
413587
# Converts the node into a proto node.
414588
core_node = py_tree.node.node_to_core_node(node, self.dataspec)
415589

@@ -430,12 +604,14 @@ def __init__(
430604
path: str,
431605
objective: py_tree.objective.AbstractObjective,
432606
model_format: Optional[ModelFormat] = ModelFormat.TENSORFLOW_SAVED_MODEL,
433-
winner_take_all: Optional[bool] = False):
607+
winner_take_all: Optional[bool] = False,
608+
import_dataspec: Optional[data_spec_pb2.DataSpecification] = None):
434609
self._specialized_header = random_forest_pb2.Header(
435610
winner_take_all_inference=winner_take_all)
436611

437612
# Should be called last.
438-
super(RandomForestBuilder, self).__init__(path, objective, model_format)
613+
super(RandomForestBuilder, self).__init__(path, objective, model_format,
614+
import_dataspec)
439615

440616
def model_type(self) -> str:
441617
return "RANDOM_FOREST"
@@ -456,9 +632,9 @@ def check_leaf(self, node: py_tree.node.LeafNode):
456632
if len(node.value.probability) != self.objective.num_classes:
457633
raise ValueError(
458634
"The number of dimensions of the probability of "
459-
"the classification value ({len(node.value.probability)}) does not "
635+
f"the classification value ({len(node.value.probability)}) does not "
460636
"match the number of classes of the label in the objective "
461-
"({self.objective.num_classes})")
637+
f"({self.objective.num_classes})")
462638

463639
elif isinstance(self.objective, py_tree.objective.RegressionObjective):
464640
if not isinstance(node.value, py_tree.value.RegressionValue):
@@ -496,7 +672,8 @@ def __init__(
496672
path: str,
497673
objective: py_tree.objective.AbstractObjective,
498674
bias: Optional[float] = 0.0,
499-
model_format: Optional[ModelFormat] = ModelFormat.TENSORFLOW_SAVED_MODEL):
675+
model_format: Optional[ModelFormat] = ModelFormat.TENSORFLOW_SAVED_MODEL,
676+
import_dataspec: Optional[data_spec_pb2.DataSpecification] = None):
500677

501678
# Compute the number of tree per iterations and loss.
502679
#
@@ -542,8 +719,8 @@ def __init__(
542719
loss=loss)
543720

544721
# Should be called last.
545-
super(GradientBoostedTreeBuilder, self).__init__(path, objective,
546-
model_format)
722+
super(GradientBoostedTreeBuilder,
723+
self).__init__(path, objective, model_format, import_dataspec)
547724

548725
def model_type(self) -> str:
549726
return "GRADIENT_BOOSTED_TREES"

0 commit comments

Comments
 (0)