r"""Script to generate a dataset from the markdown docs files in a repository.
Usage:
$ python docs_dataset.py -h
Example:
$ python docs_dataset.py \
"argilla-io/argilla-python" \
--dataset-name "plaguss/argilla_sdk_docs_raw_unstructured"
$ python docs_dataset.py \
"argilla-io/argilla-python" \
"argilla-io/distilabel" \
--dataset-name "plaguss/argilla_sdk_docs_raw_unstructured"
$ python docs_dataset.py \
"argilla-io/argilla" \
--docs_folder "argilla/docs" \
--dataset-name "plaguss/argilla_sdk_docs_raw_dev"
"""
import os
from pathlib import Path
from typing import List, Optional

import pandas as pd
import requests
from datasets import Dataset, concatenate_datasets
from github import Github
from github.ContentFile import ContentFile
from github.Repository import Repository

# The github related functions are a copy from the following repository:
# https://github.com/Nordgaren/Github-Folder-Downloader/blob/master/gitdl.py
def download(c: ContentFile, out: str) -> None:
    """Download a single file from the repository to the `out` folder."""
    r = requests.get(c.download_url)
    output_path = f"{out}/{c.path}"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "wb") as f:
        print(f"downloading {c.path} to {out}")
        f.write(r.content)


def download_folder(repo: Repository, folder: str, out: str, recursive: bool) -> None:
    """Download every file in `folder` of the repository, recursing into subfolders if requested."""
    contents = repo.get_contents(folder)
    for c in contents:
        if c.download_url is None:
            if recursive:
                download_folder(repo, c.path, out, recursive)
            continue
        download(c, out)
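

# NOTE: `main` below instantiates an unauthenticated client (`Github()`), which is
# subject to the GitHub API's anonymous rate limits. A minimal sketch of an
# authenticated call, assuming a personal access token is exposed through a
# hypothetical GITHUB_TOKEN environment variable (nothing in this script uses it):
#
#   gh = Github(os.environ["GITHUB_TOKEN"])
#   repo = gh.get_repo("argilla-io/argilla-python")
#   download_folder(repo, "docs", "argilla-python", True)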


def create_chunks(md_files: List[Path]) -> dict[str, List[str]]:
    """Create the chunks of text from the markdown files.

    Note:
        We should allow the chunking strategy to take into account the max size delimited
        by the number of tokens.

    Args:
        md_files: List of paths to the markdown files.

    Returns:
        Dictionary from filename to the list of chunks.
    """
    from unstructured.chunking.title import chunk_by_title
    from unstructured.partition.auto import partition

    data = {}
    for file in md_files:
        partitioned_file = partition(filename=file)
        chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
        data[str(file)] = chunks
    return data
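

# The Note in `create_chunks` mentions capping the chunk size. A minimal sketch of how
# that could look, assuming `unstructured`'s `chunk_by_title` keyword arguments
# `max_characters` and `combine_text_under_n_chars`, with characters used as a rough
# proxy for tokens. The function name and the default limit are illustrative only, and
# nothing in this script calls it:
def create_chunks_capped(md_files: List[Path], max_characters: int = 2048) -> dict[str, List[str]]:
    from unstructured.chunking.title import chunk_by_title
    from unstructured.partition.auto import partition

    data = {}
    for file in md_files:
        elements = partition(filename=file)
        chunks = chunk_by_title(
            elements,
            max_characters=max_characters,  # hard upper bound on each chunk's length
            combine_text_under_n_chars=200,  # merge very small sections into the previous chunk
        )
        data[str(file)] = [str(chunk) for chunk in chunks]
    return data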


def create_dataset(data: dict[str, List[str]], repo_name: Optional[str] = None) -> Dataset:
    """Creates a dataset from the dictionary of chunks.

    Args:
        data: Dictionary from filename to the list of chunks,
            as obtained from `create_chunks`.
        repo_name: Name of the repository the docs were downloaded from.
            If given, it is added as an extra `repo_name` column.

    Returns:
        Dataset with `filename` and `chunks` columns (plus `repo_name` if provided).
    """
    df = pd.DataFrame.from_records(
        [(k, v) for k, values in data.items() for v in values],
        columns=["filename", "chunks"],
    )
    if repo_name:
        df["repo_name"] = repo_name
    ds = Dataset.from_pandas(df)
    return ds
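

# A minimal usage sketch of `create_dataset`; the filename and chunk texts are made up,
# just to illustrate the shape of the resulting rows:
#
#   data = {"docs/index.md": ["First chunk of text...", "Second chunk of text..."]}
#   ds = create_dataset(data, repo_name="argilla-io/argilla-python")
#   ds[0]
#   # {"filename": "docs/index.md", "chunks": "First chunk of text...",
#   #  "repo_name": "argilla-io/argilla-python"}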


def main():
    import argparse

    description = (
        "Download the docs from a github repository and generate a dataset "
        "from the markdown files. The dataset will be pushed to the hub."
    )
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "repo",
        nargs="+",
        help="Name of the repository in the hub. For example 'argilla-io/argilla-python'.",
    )
    parser.add_argument(
        "--dataset-name",
        help="Name to give to the new dataset. For example 'my-name/argilla_sdk_docs_raw'.",
    )
    parser.add_argument(
        "--docs_folder",
        default="docs",
        help="Name of the docs folder in the repo, defaults to 'docs'.",
    )
    parser.add_argument(
        "--output_dir",
        help="Path to save the downloaded files from the repo (optional).",
    )
    parser.add_argument(
        "--private",
        action=argparse.BooleanOptionalAction,
        help="Whether to keep the dataset repository private. Defaults to False.",
    )
    args = parser.parse_args()

    # Instantiate the Github object to download the files and build one dataset per repository
    dss = []
    print("Instantiating repositories...")
    for repo_name in args.repo:
        gh = Github()
        repo = gh.get_repo(repo_name)
        docs_path = Path(args.output_dir or repo_name.split("/")[1])
        if docs_path.exists():
            print(f"Folder {docs_path} already exists, skipping download.")
        else:
            print("Start downloading the files...")
            download_folder(repo, args.docs_folder, str(docs_path), True)

        md_files = list(docs_path.glob("**/*.md"))

        # Loop to iterate over the files and generate chunks from the text pieces
        print("Generating the chunks from the markdown files...")
        data = create_chunks(md_files)

        # Create a dataset to push it to the hub
        print("Creating dataset...")
        dss.append(create_dataset(data, repo_name=repo_name))

    # Merge the datasets from the different repositories and push the result to the hub
    ds = concatenate_datasets(dss)
    ds.push_to_hub(args.dataset_name, private=args.private)
    print("Dataset pushed to the hub")


if __name__ == "__main__":
    main()