Skip to content

Commit bb493eb

Browse files
authored
Update synthesis with usability feedback (#1939)
1 parent aa3ac62 commit bb493eb

File tree

8 files changed

+266
-259
lines changed

8 files changed

+266
-259
lines changed

src/oumi/core/configs/params/synthesis_params.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -176,19 +176,19 @@ def __post_init__(self):
176176

177177

178178
@dataclass
179-
class PermutableAttributeValue:
180-
"""Value to be used for the attribute."""
179+
class SampledAttributeValue:
180+
"""Value to be sampled for the attribute."""
181181

182182
id: str
183183
"""ID to be used when referencing the attribute value during synthesis."""
184184

185-
value: str
186-
"""Value to be used for the attribute.
187-
Referenced as {attribute_id.value}"""
185+
name: str
186+
"""Plaintext name of the attribute value.
187+
Referenced as {attribute_id}"""
188188

189189
description: str
190190
"""Description of the attribute value.
191-
Referenced as {attribute_id.value.description}"""
191+
Referenced as {attribute_id.description}"""
192192

193193
sample_rate: Optional[float] = None
194194
"""Sample rate for the attribute value. If not specified, will assume uniform
@@ -197,34 +197,34 @@ class PermutableAttributeValue:
197197
def __post_init__(self):
198198
"""Verifies/populates params."""
199199
if not self.id:
200-
raise ValueError("PermutableAttributeValue.id cannot be empty.")
201-
if not self.value:
202-
raise ValueError("PermutableAttributeValue.value cannot be empty.")
200+
raise ValueError("SampledAttributeValue.id cannot be empty.")
201+
if not self.name:
202+
raise ValueError("SampledAttributeValue.name cannot be empty.")
203203
if not self.description:
204-
raise ValueError("PermutableAttributeValue.description cannot be empty.")
204+
raise ValueError("SampledAttributeValue.description cannot be empty.")
205205
if self.sample_rate is not None and (
206206
self.sample_rate < 0 or self.sample_rate > 1
207207
):
208208
raise ValueError(
209-
"PermutableAttributeValue.sample_rate must be between 0 and 1."
209+
"SampledAttributeValue.sample_rate must be between 0 and 1."
210210
)
211211

212212

213213
@dataclass
214-
class PermutableAttribute:
215-
"""Attributes to be varied across the dataset."""
214+
class SampledAttribute:
215+
"""Attributes to be sampled across the dataset."""
216216

217217
id: str
218218
"""ID to be used when referencing the attribute during synthesis."""
219219

220-
attribute: str
221-
"""Plaintext name of the attribute. Referenced as {attribute_id}"""
220+
name: str
221+
"""Plaintext name of the attribute. Referenced as {attribute_id.parent}"""
222222

223223
description: str
224-
"""Description of the attribute. Referenced as {attribute_id.description}"""
224+
"""Description of the attribute. Referenced as {attribute_id.parent.description}"""
225225

226-
possible_values: list[PermutableAttributeValue]
227-
"""Type of the attribute."""
226+
possible_values: list[SampledAttributeValue]
227+
"""Values to be sampled for the attribute."""
228228

229229
def get_value_distribution(self) -> dict[str, float]:
230230
"""Get the distribution of attribute values."""
@@ -236,13 +236,13 @@ def get_value_distribution(self) -> dict[str, float]:
236236
def __post_init__(self):
237237
"""Verifies/populates params."""
238238
if not self.id:
239-
raise ValueError("PermutableAttribute.id cannot be empty.")
240-
if not self.attribute:
241-
raise ValueError("PermutableAttribute.attribute cannot be empty.")
239+
raise ValueError("SampledAttribute.id cannot be empty.")
240+
if not self.name:
241+
raise ValueError("SampledAttribute.name cannot be empty.")
242242
if not self.description:
243-
raise ValueError("PermutableAttribute.description cannot be empty.")
243+
raise ValueError("SampledAttribute.description cannot be empty.")
244244
if not self.possible_values:
245-
raise ValueError("PermutableAttribute.possible_values cannot be empty.")
245+
raise ValueError("SampledAttribute.possible_values cannot be empty.")
246246

247247
value_ids = []
248248
sample_rates = []
@@ -252,9 +252,7 @@ def __post_init__(self):
252252

253253
value_ids_set = set(value_ids)
254254
if len(value_ids) != len(value_ids_set):
255-
raise ValueError(
256-
"PermutableAttribute.possible_values must have unique IDs."
257-
)
255+
raise ValueError("SampledAttribute.possible_values must have unique IDs.")
258256

259257
# Normalize sample rates
260258
normalized_sample_rates = []
@@ -267,7 +265,7 @@ def __post_init__(self):
267265
undefined_sample_rate_count += 1
268266

269267
if defined_sample_rate > 1.0:
270-
raise ValueError("PermutableAttribute.possible_values must sum to 1.0.")
268+
raise ValueError("SampledAttribute.possible_values must sum to 1.0.")
271269

272270
# Assign remaining sample rate to undefined sample rates
273271
remaining_sample_rate = 1.0 - defined_sample_rate
@@ -517,7 +515,7 @@ class GeneralSynthesisParams(BaseParams):
517515
Examples will be enumerated during sampling, and attributes can be referenced as
518516
attributes when generating new attributes."""
519517

520-
permutable_attributes: Optional[list[PermutableAttribute]] = None
518+
sampled_attributes: Optional[list[SampledAttribute]] = None
521519
"""Attributes to be varied across the dataset.
522520
523521
Attributes each have a set of possible values which will be randomly sampled
@@ -636,18 +634,18 @@ def _check_example_source_attribute_ids(self, all_attribute_ids: set[str]) -> No
636634
for new_key in example_keys:
637635
self._check_attribute_ids(all_attribute_ids, new_key)
638636

639-
def _check_permutable_attribute_ids(self, all_attribute_ids: set[str]) -> None:
640-
"""Check attribute IDs from permutable attributes for uniqueness."""
641-
if self.permutable_attributes is None:
637+
def _check_sampled_attribute_ids(self, all_attribute_ids: set[str]) -> None:
638+
"""Check attribute IDs from sampled attributes for uniqueness."""
639+
if self.sampled_attributes is None:
642640
return
643641

644-
if len(self.permutable_attributes) == 0:
642+
if len(self.sampled_attributes) == 0:
645643
raise ValueError(
646-
"GeneralSynthesisParams.permutable_attributes cannot be empty."
644+
"GeneralSynthesisParams.sampled_attributes cannot be empty."
647645
)
648646

649-
for permutable_attribute in self.permutable_attributes:
650-
attribute_id = permutable_attribute.id
647+
for sampled_attribute in self.sampled_attributes:
648+
attribute_id = sampled_attribute.id
651649
self._check_attribute_ids(all_attribute_ids, attribute_id)
652650

653651
def _check_generated_attribute_ids(self, all_attribute_ids: set[str]) -> None:
@@ -716,7 +714,7 @@ def __post_init__(self):
716714
self._check_dataset_source_attribute_ids(all_attribute_ids)
717715
self._check_document_source_attribute_ids(all_attribute_ids)
718716
self._check_example_source_attribute_ids(all_attribute_ids)
719-
self._check_permutable_attribute_ids(all_attribute_ids)
717+
self._check_sampled_attribute_ids(all_attribute_ids)
720718
self._check_generated_attribute_ids(all_attribute_ids)
721719
self._check_transformed_attribute_ids(all_attribute_ids)
722720
self._check_passthrough_attribute_ids()

src/oumi/core/synthesis/attribute_formatter.py

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,52 @@
1616
from oumi.utils.placeholders import resolve_placeholders
1717

1818

19-
class _AttributeValueInfo:
20-
"""Information about a value of a permutable attribute.
19+
class _AttributeParentInfo:
20+
"""Information about a parent of a sampled attribute."""
2121

22-
Used to format the string for a sample.
23-
"""
24-
25-
def __init__(self, value_name: str, value_description: str):
26-
"""Initialize the attribute value info."""
27-
self._value_name = value_name
28-
self.description = value_description
22+
def __init__(self, parent_name: str, parent_description: str):
23+
"""Initialize the attribute parent info."""
24+
self._parent_name = parent_name
25+
self.description = parent_description
2926

3027
def __str__(self) -> str:
31-
return self._value_name
28+
return self._parent_name
3229

3330

3431
class _AttributeInfo:
35-
"""Information about a permutable attribute.
32+
"""Information about a sampled attribute.
3633
3734
Used to format the string for a sample.
35+
36+
Example:
37+
attribute_id: "complexity"
38+
parent_name: "Complexity"
39+
parent_description: "The complexity of the text."
40+
value_name: "High"
41+
value_description: "The text is complex."
42+
43+
Formatting string:
44+
{complexity.parent} ({complexity.parent.description})
45+
{complexity} ({complexity.description})
46+
47+
Result:
48+
Complexity (The complexity of the text.)
49+
High (The text is complex.)
3850
"""
3951

4052
def __init__(
4153
self,
4254
attribute_id: str,
43-
attribute_name: str,
44-
attribute_description: str,
55+
parent_name: str,
56+
parent_description: str,
4557
value_name: str,
4658
value_description: str,
4759
):
4860
"""Initialize the attribute info."""
4961
self.attribute_id = attribute_id
50-
self._attribute_name = attribute_name
51-
self.description = attribute_description
52-
self.value = _AttributeValueInfo(value_name, value_description)
62+
self._attribute_name = value_name
63+
self.description = value_description
64+
self.parent = _AttributeParentInfo(parent_name, parent_description)
5365

5466
def __str__(self) -> str:
5567
return self._attribute_name
@@ -65,22 +77,25 @@ class AttributeFormatter:
6577
def __init__(self, params: GeneralSynthesisParams):
6678
"""Initialize the formatter."""
6779
self._params = params
68-
self._permutable_attribute_map = (
69-
{perm_attr.id: perm_attr for perm_attr in params.permutable_attributes}
70-
if params.permutable_attributes
80+
self._sampled_attribute_map = (
81+
{
82+
sampled_attr.id: sampled_attr
83+
for sampled_attr in params.sampled_attributes
84+
}
85+
if params.sampled_attributes
7186
else {}
7287
)
73-
self._permutable_attribute_info = {}
88+
self._sampled_attribute_info = {}
7489

7590
# Pre-compute the attribute info for each possible value
76-
for attribute_id, attribute in self._permutable_attribute_map.items():
91+
for attribute_id, attribute in self._sampled_attribute_map.items():
7792
for value in attribute.possible_values:
7893
key = (attribute_id, value.id)
79-
self._permutable_attribute_info[key] = _AttributeInfo(
94+
self._sampled_attribute_info[key] = _AttributeInfo(
8095
attribute_id=attribute_id,
81-
attribute_name=attribute.attribute,
82-
attribute_description=attribute.description,
83-
value_name=value.value,
96+
parent_name=attribute.name,
97+
parent_description=attribute.description,
98+
value_name=value.name,
8499
value_description=value.description,
85100
)
86101

@@ -102,9 +117,9 @@ def format(
102117
"""
103118
attr_values = {}
104119
for attribute_id, attribute_value in sample.items():
105-
if self._is_permutable_attribute(attribute_id):
120+
if self._is_sampled_attribute(attribute_id):
106121
value_id = attribute_value
107-
attr_values[attribute_id] = self._get_permutable_attribute_value_info(
122+
attr_values[attribute_id] = self._get_sampled_attribute_value_info(
108123
attribute_id, value_id
109124
)
110125
else:
@@ -117,17 +132,17 @@ def format(
117132
)
118133
return formatted_string
119134

120-
def _is_permutable_attribute(self, attribute_id: str) -> bool:
121-
"""Check if the attribute is a permutable attribute."""
122-
return attribute_id in self._permutable_attribute_map
135+
def _is_sampled_attribute(self, attribute_id: str) -> bool:
136+
"""Check if the attribute is a sampled attribute."""
137+
return attribute_id in self._sampled_attribute_map
123138

124-
def _get_permutable_attribute_value_info(
139+
def _get_sampled_attribute_value_info(
125140
self, attribute_id: str, attribute_value_id: str
126141
) -> _AttributeInfo:
127-
"""Get the string representation information for a permutable attribute."""
142+
"""Get the string representation information for a sampled attribute."""
128143
key = (attribute_id, attribute_value_id)
129-
if key in self._permutable_attribute_info:
130-
return self._permutable_attribute_info[key]
144+
if key in self._sampled_attribute_info:
145+
return self._sampled_attribute_info[key]
131146

132147
raise ValueError(
133148
f"Attribute value {attribute_value_id} not found for "

src/oumi/core/synthesis/dataset_planner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
DocumentSource,
2222
ExampleSource,
2323
GeneralSynthesisParams,
24-
PermutableAttribute,
24+
SampledAttribute,
2525
)
2626
from oumi.core.synthesis.dataset_ingestion import DatasetReader
2727
from oumi.core.synthesis.document_ingestion import DocumentReader, DocumentSegmenter
@@ -84,7 +84,7 @@ def plan(
8484
)
8585

8686
permutable_attribute_samples = self._plan_permutable_attributes(
87-
synthesis_params.permutable_attributes,
87+
synthesis_params.sampled_attributes,
8888
synthesis_params.combination_sampling,
8989
sample_count,
9090
)
@@ -190,7 +190,7 @@ def _ingest_document_sources(
190190

191191
def _plan_permutable_attributes(
192192
self,
193-
permutable_attributes: Optional[list[PermutableAttribute]],
193+
permutable_attributes: Optional[list[SampledAttribute]],
194194
combination_sampling: Optional[list[AttributeCombination]],
195195
sample_count: int,
196196
) -> list[dict]:

0 commit comments

Comments
 (0)