Skip to content

Commit 9119182

Browse files
committed
refactor(llamabot/cli)🔧: Refactor bot functions to accept model names as parameters
- Add model_name parameter to ood_checker_bot, docwriter_bot, and refine_bot functions. - Update write function to pass model_name parameters to bot functions.
1 parent fb336b8 commit 9119182

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

‎llamabot/cli/docs.py‎

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,21 +226,21 @@ def docwriter_sysprompt():
226226
"""
227227

228228

229-
def ood_checker_bot() -> StructuredBot:
229+
def ood_checker_bot(model_name: str = "gpt-4o") -> StructuredBot:
230230
"""Return a StructuredBot for the out-of-date checker."""
231231
return StructuredBot(
232232
system_prompt=ood_checker_sysprompt(),
233233
pydantic_model=DocsOutOfDate,
234-
model_name="gpt-4-turbo",
234+
model_name=model_name,
235235
)
236236

237237

238-
def docwriter_bot() -> StructuredBot:
238+
def docwriter_bot(model_name: str = "gpt-4o") -> StructuredBot:
239239
"""Return a StructuredBot for the documentation writer."""
240240
return StructuredBot(
241241
system_prompt=docwriter_sysprompt(),
242242
pydantic_model=DocumentationContent,
243-
model_name="gpt-4-turbo",
243+
model_name=model_name,
244244
)
245245

246246

@@ -256,11 +256,11 @@ def refine_bot_sysprompt():
256256
"""
257257

258258

259-
def refine_bot() -> SimpleBot:
259+
def refine_bot(model_name: str = "o1-preview") -> SimpleBot:
260260
"""Return a SimpleBot for the documentation writer."""
261261
return SimpleBot(
262262
system_prompt=refine_bot_sysprompt(),
263-
model_name="o1-preview",
263+
model_name=model_name,
264264
)
265265

266266

@@ -270,6 +270,9 @@ def write(
270270
from_scratch: bool = False,
271271
refine: bool = False,
272272
verbose: bool = False,
273+
ood_checker_model_name: str = "gpt-4o",
274+
docwriter_model_name: str = "gpt-4o",
275+
refiner_model_name: str = "o1-preview",
273276
):
274277
"""Write the documentation based on the given source file.
275278
@@ -290,27 +293,32 @@ def write(
290293
291294
:param file_path: Path to the Markdown source file.
292295
:param from_scratch: Whether to start with a blank documentation.
296+
:param refine: Whether to refine the documentation.
297+
:param verbose: Whether to print the verbose output.
298+
:param ood_checker_model_name: The model name for the out-of-date checker.
299+
:param docwriter_model_name: The model name for the docwriter.
300+
:param refiner_model_name: The model name for the refiner.
293301
"""
294302
src_file = MarkdownSourceFile(file_path)
295303

296304
if from_scratch:
297305
src_file.post.content = ""
298306

299-
ood_checker = ood_checker_bot()
307+
ood_checker = ood_checker_bot(model_name=ood_checker_model_name)
300308
result: DocsOutOfDate = ood_checker(
301309
documentation_information(src_file), verbose=verbose
302310
)
303311

304312
if not src_file.post.content or result:
305-
docwriter = docwriter_bot()
313+
docwriter = docwriter_bot(model_name=docwriter_model_name)
306314
response: DocumentationContent = docwriter(
307315
documentation_information(src_file) + "\nNow please write the docs.",
308316
verbose=verbose,
309317
)
310318
src_file.post.content = response.content
311319

312320
if refine:
313-
refiner = refine_bot()
321+
refiner = refine_bot(model_name=refiner_model_name)
314322
response: str = refiner(src_file.post.content, verbose=verbose)
315323
src_file.post.content = response
316324

0 commit comments

Comments
 (0)