2323 - CARTBuilder
2424 - GradientBoostedTreeBuilder
2525
26+ About categorical and categorical-set features with string dictionary:
27+
28+ Categorical and categorical-set features are tied to a dictionary of possible
29+ values. In addition, the special value "out-of-dictionary" (OOD) designates all
30+ the values which are not in the dictionary. For example, the condition
31+ "a in ["x","<OOD>"]" is true if the feature "a" is equal to "x" or to any value
32+ not in the dictionary.
33+
34+ The feature dictionaries are automatically assembled as the union of all the
35+ observed values in the tree conditions. Alternatively, dictionaries can be
36+ read or set manually with "{get,set}_dictionary()" or imported from an existing
37+ dataspec with the "import_dataspec" constructor argument.
38+
39+
2640Usage:
2741
2842```python
5468 neg_child=LeafNode(
5569 value=ProbabilityValue(probability=[0.8, 0.2])))))
5670
71+ # Create a second tree
72+ # f2 in ["x", "y"]
73+ # ├─(pos)─ [0.1, 0.9]
74+ # └─(neg)─ [0.8, 0.2]
75+ #
76+ builder.add_tree(
77+ Tree(
78+ NonLeafNode(
79+ condition=CategoricalIsInCondition(
80+ feature=SimpleColumnSpec(
81+ name="f2",
82+ type=py_tree.dataspec.ColumnType.CATEGORICAL),
83+ mask=["x", "y"],
84+ missing_evaluation=False),
85+ pos_child=LeafNode(
86+ value=ProbabilityValue(probability=[0.1, 0.9])),
87+ neg_child=LeafNode(
88+ value=ProbabilityValue(probability=[0.8, 0.2])))))
89+
90+ # Optionally set the dictionary of the categorical feature "f2".
91+ # If not set, all the values not seen in the model ("z" in this case) will not
92+ # be known by the model and will be treated as OOD (out of dictionary).
93+ #
94+ # Defining a dictionary only has an impact if a condition is testing for the
95+ # `<OOD>` item directly i.e. the test `f2 in ["<OOD>"]` depends on the content
96+ # of the dictionary.
97+ builder.set_dictionary("f2", ["<OOD>", "x", "y", "z"])
98+
5799builder.close()
58100
59101# Load and use the model
@@ -103,7 +145,8 @@ class AbstractBuilder(object):
103145 """Generic model builder."""
104146
105147 def __init__ (self , path : str , objective : py_tree .objective .AbstractObjective ,
106- model_format : Optional [ModelFormat ]):
148+ model_format : Optional [ModelFormat ],
149+ import_dataspec : Optional [data_spec_pb2 .DataSpecification ]):
107150
108151 if not path :
109152 raise ValueError ("The path cannot be empty" )
@@ -125,6 +168,9 @@ def __init__(self, path: str, objective: py_tree.objective.AbstractObjective,
125168
126169 tf .io .gfile .makedirs (self .yggdrasil_model_path ())
127170
171+ if import_dataspec :
172+ self ._import_dataspec (import_dataspec )
173+
128174 def close (self ):
129175 """Finalize the builder work.
130176
@@ -204,6 +250,98 @@ def objective(self) -> py_tree.objective.AbstractObjective:
204250
205251 return self ._objective
206252
def _import_dataspec(self, src_dataspec: data_spec_pb2.DataSpecification):
    """Copies the feature columns of an existing dataspec into this builder.

    Call this right after the object construction, i.e. before any part of
    the model has been built: a source column that is already registered in
    the builder (other than the label or the ranking group) is rejected.

    What is imported, for every non-reserved source column:
      - The name and type of the feature.
      - The feature dictionary (if any).
      - The feature statistics (if any).

    What is not imported:
      - The column indices: feature #3 of the source dataspec is not
        necessarily feature #3 of the builder's dataspec.
      - The label column (and, for ranking objectives, the group column).

    Args:
      src_dataspec: Dataspec to import.

    Raises:
      ValueError: If a non-reserved column of `src_dataspec` already exists in
        the builder's dataspec.
    """

    is_ranking = isinstance(self._objective,
                            py_tree.objective.RankingObjective)

    for src_col in src_dataspec.columns:
        dst_col_idx, created = self._get_or_create_column_idx(src_col.name)

        # The label column is owned by the builder; leave it untouched.
        if dst_col_idx == self._header.label_col_idx:
            continue

        # Same for the ranking group column of ranking models.
        if is_ranking and dst_col_idx == self._header.ranking_group_col_idx:
            continue

        if created:
            # Brand new column: take the source definition as is.
            self._dataspec.columns[dst_col_idx].CopyFrom(src_col)
            continue

        raise ValueError(
            "import_dataspec was called after some of the model was build. "
            "Make sure to call import_dataspec right after the model "
            "constructor.")
293+
def _check_column_has_dictionary(self, column_spec: data_spec_pb2.Column):
    """Validates that a column can carry a (possibly empty) string dictionary.

    Args:
      column_spec: Dataspec column to validate.

    Raises:
      ValueError: If the column is neither CATEGORICAL nor CATEGORICAL_SET, or
        if its categorical values are flagged as already integerized.
    """

    feature_name = column_spec.name
    dictionary_types = (ColumnType.CATEGORICAL, ColumnType.CATEGORICAL_SET)

    if column_spec.type not in dictionary_types:
        raise ValueError(
            f"The feature \"{feature_name}\" is neither a CATEGORICAL "
            "OR CATEGORICAL_SET feature")

    # Integerized columns store raw integer values, not dictionary items.
    if column_spec.categorical.is_already_integerized:
        raise ValueError(
            f"The feature \"{feature_name}\" is already integerized "
            "and do not have a dictionary")
308+
def get_dictionary(self, col_name: str) -> List[str]:
    """Returns the dictionary of a categorical(-set) string feature.

    Args:
      col_name: Name of a feature registered in the builder.

    Returns:
      The dictionary items of the feature, as a list of strings.

    Raises:
      ValueError: If the feature is unknown, or if its column does not hold a
        string dictionary.
    """

    feature_idx = self._dataspec_column_index.get(col_name)
    if feature_idx is None:
        raise ValueError(f"Unknown feature \"{col_name}\"")

    spec = self._dataspec.columns[feature_idx]
    # Rejects non-categorical and already-integerized columns.
    self._check_column_has_dictionary(spec)
    return py_tree.dataspec.categorical_column_dictionary_to_list(spec)
320+
def set_dictionary(self, col_name: str, dictionary: List[str]) -> None:
    """Sets the dictionary of a categorical or categorical-set column.

    The out-of-dictionary (OOD) item always receives the index 0. The other
    items are indexed from 1, in their order of first appearance in
    `dictionary`. Repeated items are ignored so that the indices stay
    contiguous and the stored number of unique values matches the dictionary.

    Args:
      col_name: Name of a feature registered in the builder.
      dictionary: Possible values of the feature. Must contain the
        out-of-dictionary item.

    Raises:
      ValueError: If the feature is unknown, if `dictionary` does not contain
        the out-of-dictionary item, or if the column does not hold a string
        dictionary.
    """

    out_of_dictionary = py_tree.dataspec.OUT_OF_DICTIONARY

    col_idx = self._dataspec_column_index.get(col_name)
    if col_idx is None:
        raise ValueError(f"Unknown feature \"{col_name}\"")

    if out_of_dictionary not in dictionary:
        raise ValueError(
            f"The dictionary should contain an \"{out_of_dictionary}\" value")

    column_spec = self._dataspec.columns[col_idx]
    self._check_column_has_dictionary(column_spec)

    items = column_spec.categorical.items
    items.clear()
    # The OOD value should be the first one.
    items[out_of_dictionary].index = 0
    for item in dictionary:
        if item in items:
            # Skips the OOD item (already inserted) and any repeated value.
            continue
        # The right-hand side is evaluated before the item is inserted, so
        # the new item receives the next free index.
        items[item].index = len(items)

    # Count the items actually stored, not the (possibly redundant) input.
    column_spec.categorical.number_of_unique_values = len(items)
344+
207345 def observe_feature (self ,
208346 feature : inspector_lib .SimpleColumnSpec ,
209347 categorical_values : Optional [Union [List [str ],
@@ -253,7 +391,7 @@ def observe_feature(self,
253391 # The value is stored as a string.
254392 if created :
255393 # Create the out-of-vocabulary item.
256- column .categorical .items ["<OOV>" ].index = 0
394+ column .categorical .items [py_tree . dataspec . OUT_OF_DICTIONARY ].index = 0
257395 column .categorical .number_of_unique_values = 1
258396 for value in categorical_values :
259397 if value not in column .categorical .items :
@@ -291,22 +429,29 @@ def _initialize_header_column_idx(self):
291429 Should be called once before writing the header to disk.
292430 """
293431
432+ assert not self ._dataspec .columns
433+
294434 # The first column is the label.
295435 self ._header .label_col_idx = 0
296436 label_column = self ._dataspec .columns .add ()
297437 label_column .name = self ._objective .label
438+ self ._dataspec_column_index [label_column .name ] = self ._header .label_col_idx
439+
298440 if isinstance (self ._objective , py_tree .objective .ClassificationObjective ):
299441 label_column .type = ColumnType .CATEGORICAL
300442
301443 # One value is reserved for the non-used OOV item.
302444 label_column .categorical .number_of_unique_values = self ._objective .num_classes + 1
303445
304446 if not self ._objective .has_integer_labels :
305- label_column .categorical .items ["<OOV>" ].index = 0
447+ label_column .categorical .items [
448+ py_tree .dataspec .OUT_OF_DICTIONARY ].index = 0
306449 for idx , value in enumerate (self ._objective .classes ):
307450 label_column .categorical .items [value ].index = idx + 1
308451 assert len (label_column .categorical .items
309452 ) == label_column .categorical .number_of_unique_values
453+ else :
454+ label_column .categorical .is_already_integerized = True
310455
311456 elif isinstance (self ._objective , (py_tree .objective .RegressionObjective ,
312457 py_tree .objective .RankingObjective )):
@@ -316,22 +461,29 @@ def _initialize_header_column_idx(self):
316461 raise NotImplementedError (f"No supported objective { self ._objective } " )
317462
318463 if isinstance (self ._objective , py_tree .objective .RankingObjective ):
464+ assert len (self ._dataspec .columns ) == 1
465+
319466 # Create the "group" column for Ranking.
320467 self ._header .ranking_group_col_idx = 1
321468 group_column = self ._dataspec .columns .add ()
322469 group_column .type = ColumnType .HASH
323470 group_column .name = self ._objective .group
471+ self ._dataspec_column_index [
472+ group_column .name ] = self ._header .ranking_group_col_idx
324473
325474
326475@six .add_metaclass (abc .ABCMeta )
327476class AbstractDecisionForestBuilder (AbstractBuilder ):
328477 """Generic decision forest model builder."""
329478
330479 def __init__ (self , path : str , objective : py_tree .objective .AbstractObjective ,
331- model_format : Optional [ModelFormat ]):
480+ model_format : Optional [ModelFormat ],
481+ import_dataspec : Optional [data_spec_pb2 .DataSpecification ]):
482+
483+ super (AbstractDecisionForestBuilder ,
484+ self ).__init__ (path , objective , model_format , import_dataspec )
332485
333- super (AbstractDecisionForestBuilder , self ).__init__ (path , objective ,
334- model_format )
486+ self ._trees = []
335487
336488 num_node_shards = 1 # Store all the nodes in a single shard.
337489 self .specialized_header ().num_node_shards = num_node_shards
@@ -344,6 +496,12 @@ def __init__(self, path: str, objective: py_tree.objective.AbstractObjective,
344496
345497 def close (self ):
346498
499+ assert self .specialized_header ().num_trees == len (self ._trees )
500+
501+ for tree in self ._trees :
502+ self ._write_branch (tree .root )
503+ self ._trees = []
504+
347505 # Write the model specialized header.
348506 _write_binary_proto (
349507 self .specialized_header (),
@@ -372,7 +530,8 @@ def specialized_header_filename(self) -> str:
372530 def add_tree (self , tree : py_tree .tree .Tree ):
373531 """Adds one tree to the model."""
374532
375- self ._write_branch (tree .root )
533+ self ._observe_branch (tree .root )
534+ self ._trees .append (tree )
376535 self .specialized_header ().num_trees += 1
377536
378537 def check_leaf (self , node : py_tree .node .LeafNode ):
@@ -385,13 +544,11 @@ def check_non_leaf(self, node: py_tree.node.NonLeafNode):
385544
386545 pass
387546
388- def _write_branch (self , node : py_tree .node .AbstractNode ):
389- """Write of a node and its children to the writer .
547+ def _observe_branch (self , node : py_tree .node .AbstractNode ):
548+ """Indexes the possible attribute values and check the tree validity .
390549
391- Nodes are written in a Depth First Pre-order traversals (as expected by the
392- model format).
393-
394- This function is the inverse of inspector_lib._extract_branch.
550+ This method should be called on all the trees before any calls to
551+ "_write_branch".
395552
396553 Args:
397554 node: The node to write.
@@ -410,6 +567,23 @@ def _write_branch(self, node: py_tree.node.AbstractNode):
410567 elif isinstance (node , py_tree .node .LeafNode ):
411568 self .check_leaf (node )
412569
570+ # Recursive call on the children.
571+ if isinstance (node , py_tree .node .NonLeafNode ):
572+ self ._observe_branch (node .neg_child )
573+ self ._observe_branch (node .pos_child )
574+
575+ def _write_branch (self , node : py_tree .node .AbstractNode ):
576+ """Write of a node and its children to the writer.
577+
578+ Nodes are written in a Depth First Pre-order traversals (as expected by the
579+ model format).
580+
581+ This function is the inverse of inspector_lib._extract_branch.
582+
583+ Args:
584+ node: The node to write.
585+ """
586+
413587 # Converts the node into a proto node.
414588 core_node = py_tree .node .node_to_core_node (node , self .dataspec )
415589
@@ -430,12 +604,14 @@ def __init__(
430604 path : str ,
431605 objective : py_tree .objective .AbstractObjective ,
432606 model_format : Optional [ModelFormat ] = ModelFormat .TENSORFLOW_SAVED_MODEL ,
433- winner_take_all : Optional [bool ] = False ):
607+ winner_take_all : Optional [bool ] = False ,
608+ import_dataspec : Optional [data_spec_pb2 .DataSpecification ] = None ):
434609 self ._specialized_header = random_forest_pb2 .Header (
435610 winner_take_all_inference = winner_take_all )
436611
437612 # Should be called last.
438- super (RandomForestBuilder , self ).__init__ (path , objective , model_format )
613+ super (RandomForestBuilder , self ).__init__ (path , objective , model_format ,
614+ import_dataspec )
439615
440616 def model_type (self ) -> str :
441617 return "RANDOM_FOREST"
@@ -456,9 +632,9 @@ def check_leaf(self, node: py_tree.node.LeafNode):
456632 if len (node .value .probability ) != self .objective .num_classes :
457633 raise ValueError (
458634 "The number of dimensions of the probability of "
459- "the classification value ({len(node.value.probability)}) does not "
635+ f "the classification value ({ len (node .value .probability )} ) does not "
460636 "match the number of classes of the label in the objective "
461- "({self.objective.num_classes})" )
637+ f "({ self .objective .num_classes } )" )
462638
463639 elif isinstance (self .objective , py_tree .objective .RegressionObjective ):
464640 if not isinstance (node .value , py_tree .value .RegressionValue ):
@@ -496,7 +672,8 @@ def __init__(
496672 path : str ,
497673 objective : py_tree .objective .AbstractObjective ,
498674 bias : Optional [float ] = 0.0 ,
499- model_format : Optional [ModelFormat ] = ModelFormat .TENSORFLOW_SAVED_MODEL ):
675+ model_format : Optional [ModelFormat ] = ModelFormat .TENSORFLOW_SAVED_MODEL ,
676+ import_dataspec : Optional [data_spec_pb2 .DataSpecification ] = None ):
500677
501678 # Compute the number of tree per iterations and loss.
502679 #
@@ -542,8 +719,8 @@ def __init__(
542719 loss = loss )
543720
544721 # Should be called last.
545- super (GradientBoostedTreeBuilder , self ). __init__ ( path , objective ,
546- model_format )
722+ super (GradientBoostedTreeBuilder ,
723+ self ). __init__ ( path , objective , model_format , import_dataspec )
547724
548725 def model_type (self ) -> str :
549726 return "GRADIENT_BOOSTED_TREES"
0 commit comments