Skip to content

Commit 4090f82

Browse files
committed
Add test for TPST and support jsonl format
1 parent 04c3803 commit 4090f82

File tree

5 files changed

+30
-4
lines changed

5 files changed

+30
-4
lines changed

DiarizationLM/README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,17 @@ We support 3 different output file formats:
7575
| Format | Description |
7676
| ------ | ----------- |
7777
| `tfrecord` | The [TFRecord format](https://www.tensorflow.org/tutorials/load_data/tfrecord) can be used by various machine learning libraries.|
78-
| `csv` | This format can be used by [OpenAI API](https://platform.openai.com/docs/api-reference/) for finetuning GPT models. OpenAI will usually convert these csv files to jsonl files.|
7978
| `json` | This format is more human-readable and can be used for debugging. It's also useful for finetuning PaLM models via the [Google Cloud API](https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-text-models-supervised#text).|
79+
| `csv` | This format can be used by many existing tools. OpenAI also provides a tool to convert csv files to jsonl files.|
80+
| `jsonl` | This format can be directly used by the [OpenAI API](https://platform.openai.com/docs/api-reference/) for finetuning GPT models.|
8081

8182
Example command:
8283

8384
```bash
8485
python3 train_data_prep.py \
8586
--input="testdata/example_data.json" \
86-
--output="/tmp/example_data.csv" \
87-
--output_type=csv \
87+
--output="/tmp/example_data.jsonl" \
88+
--output_type=jsonl \
8889
--emit_input_length=1000 \
8990
--emit_target_length=1000 \
9091
--prompt_suffix=" --> " \

DiarizationLM/run_tests.sh

100644100755
File mode changed.

DiarizationLM/run_tools.sh

+8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ python3 train_data_prep.py \
1414
--output=/tmp/example_data.tfrecord \
1515
--output_type=tfrecord
1616

17+
python3 train_data_prep.py \
18+
--input=testdata/example_data.json \
19+
--output=/tmp/example_data.jsonl \
20+
--input_feature_key=prompt \
21+
--output_feature_key=completion \
22+
--completion_suffix=" [eod]" \
23+
--output_type=jsonl
24+
1725
python3 postprocess_completions.py \
1826
--input=testdata/example_completion_with_bad_completion.json \
1927
--output=/tmp/example_postprocessed.json

DiarizationLM/train_data_prep.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
flags.DEFINE_enum(
1515
"output_type",
1616
"tfrecord",
17-
["tfrecord", "json", "csv"],
17+
["tfrecord", "json", "csv", "jsonl"],
1818
"Output container formats for different use cases.",
1919
)
2020
flags.DEFINE_string("text_field", "hyp_text", "Name of field to get text")
@@ -100,6 +100,13 @@ def main(argv: Sequence[str]) -> None:
100100
csv_lines.append('"{}","{}"'.format(prompt, target))
101101
with open(FLAGS.output, "wt") as f:
102102
f.write("\n".join(csv_lines))
103+
elif FLAGS.output_type == "jsonl":
104+
json_lines = []
105+
for _, prompt, target in reader.generate_data_tuple():
106+
json_lines.append('{{"{}":"{}","{}":"{}"}}'.format(
107+
FLAGS.input_feature_key, prompt, FLAGS.output_feature_key, target))
108+
with open(FLAGS.output, "wt") as f:
109+
f.write("\n".join(json_lines))
103110

104111
print("Output has been written to:", FLAGS.output)
105112

DiarizationLM/utils_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ def test_get_oracle_speakers(self):
2323
expected = [1, 1, 1, 1, 2, 2, 2, 2]
2424
self.assertEqual(expected, hyp_spk_oracle)
2525

26+
def test_transcript_preserving_speaker_transfer(self):
27+
src_text = "hello good morning hi how are you pretty good"
28+
src_spk = "1 1 1 2 2 2 2 1 1"
29+
tgt_text = "hello morning hi hey are you be good"
30+
tgt_spk = "1 2 2 2 1 1 2 1"
31+
expected = "1 1 2 2 2 2 1 1"
32+
transfered_spk = utils.transcript_preserving_speaker_transfer(
33+
src_text, src_spk, tgt_text, tgt_spk)
34+
self.assertEqual(expected, transfered_spk)
35+
2636
def test_ref_to_oracle(self):
2737
test_data = {
2838
"hyp_text": "yo hello hi wow great",

0 commit comments

Comments
 (0)