-
Notifications
You must be signed in to change notification settings - Fork 107
feat(datasets): Add LangchainPromptDataset to experimental datasets #1200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
243f4f7
7a7b0a0
47557db
6b9247b
2681408
f6b9504
68c3f94
cd17a1c
0173dc3
a3c23b0
14591d1
f074b5e
9dd1673
4695764
9dca77a
0f65fd6
de94bac
a52d0a6
674e487
0397ecb
26571d8
0174342
ecd1e9d
7ecad93
d5c5dd0
f93f419
cadeec0
f6c5f31
698d618
1494696
1c9bb51
ce3933e
2ca6722
ebc29b9
eabb100
d903366
281f422
1195955
229c47a
c889e9e
c7504e4
d7760f8
2e77fe7
662dd24
9ad29df
d6b4f10
33f1491
d08b51f
4e31229
62da18f
ca864a8
115839d
b2077d9
a69820c
bb62c16
14d48ad
0126125
e8f6f24
2943164
bdfdc48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| from copy import deepcopy | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from kedro.io import AbstractDataset, DatasetError | ||
| from kedro.io.catalog_config_resolver import CREDENTIALS_KEY | ||
| from kedro.io.core import get_filepath_str, parse_dataset_definition | ||
| from langchain.prompts import ChatPromptTemplate, PromptTemplate | ||
|
|
||
| # Minimum number of elements required for a message (role, content) | ||
| MIN_MESSAGE_LENGTH = 2 | ||
|
|
||
class LangChainPromptDataset(AbstractDataset[PromptTemplate | ChatPromptTemplate, Any]):
    """Kedro dataset for loading LangChain prompts using existing Kedro datasets.

    The raw prompt file is read with an underlying Kedro dataset (text, JSON
    or YAML — auto-detected from the file extension when not configured
    explicitly) and the loaded data is wrapped into the requested LangChain
    template class. Saving is intentionally unsupported: prompts are treated
    as read-only pipeline inputs.
    """

    # Supported LangChain template classes, keyed by the user-facing name
    # accepted via the ``template`` constructor argument.
    TEMPLATES = {
        "PromptTemplate": PromptTemplate,
        "ChatPromptTemplate": ChatPromptTemplate,
    }

    def __init__(  # noqa: PLR0913
        self,
        filepath: str,
        template: str = "PromptTemplate",
        dataset: dict[str, Any] | str | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ):
        """
        Initialize the LangChain prompt dataset.

        Args:
            filepath: Path to the prompt file
            template: Name of the LangChain template class ("PromptTemplate" or "ChatPromptTemplate")
            dataset: Configuration for the underlying Kedro dataset
            credentials: Credentials passed to the underlying dataset unless already defined
            fs_args: Extra arguments passed to the filesystem, if supported
            metadata: Arbitrary metadata
            **kwargs: Additional arguments (ignored)

        Raises:
            DatasetError: If ``template`` is not a supported template name, the
                dataset type cannot be auto-detected, or the underlying dataset
                cannot be instantiated.
        """
        super().__init__()

        self.metadata = metadata
        self._filepath = get_filepath_str(Path(filepath), kwargs.get("protocol"))

        try:
            self._template_class = self.TEMPLATES[template]
        except KeyError as exc:
            # Chain the KeyError so the original lookup failure stays visible.
            raise DatasetError(
                f"Invalid template '{template}'. Must be one of: {list(self.TEMPLATES)}"
            ) from exc

        # Infer dataset type if not explicitly provided
        dataset_config = self._build_dataset_config(dataset)

        # Top-level credentials/fs_args are forwarded only when the dataset
        # config does not already define its own — explicit config wins.
        self._credentials = deepcopy(credentials or {})
        self._fs_args = deepcopy(fs_args or {})

        if self._credentials:
            if CREDENTIALS_KEY in dataset_config:
                self._logger.warning(
                    "Top-level credentials will not propagate into the underlying dataset "
                    "since credentials were explicitly defined in the dataset config."
                )
            else:
                dataset_config[CREDENTIALS_KEY] = deepcopy(self._credentials)

        if self._fs_args:
            if "fs_args" in dataset_config:
                self._logger.warning(
                    "Top-level fs_args will not propagate into the underlying dataset "
                    "since fs_args were explicitly defined in the dataset config."
                )
            else:
                dataset_config["fs_args"] = deepcopy(self._fs_args)

        try:
            dataset_class, dataset_kwargs = parse_dataset_definition(dataset_config)
            self._dataset = dataset_class(**dataset_kwargs)
        except Exception as exc:
            raise DatasetError(f"Failed to create underlying dataset: {exc}") from exc

    def _build_dataset_config(self, dataset: dict[str, Any] | str | None) -> dict[str, Any]:
        """Infer and normalize dataset configuration.

        Args:
            dataset: Explicit dataset config (dict), a dataset type name (str),
                or ``None`` to auto-detect the type from the file extension.

        Returns:
            A config dict containing the dataset ``type`` and this dataset's
            filepath.

        Raises:
            DatasetError: If ``dataset`` is ``None`` and the type cannot be
                inferred from the file extension.
        """
        if dataset is None:
            # Auto-detect the underlying dataset from the file extension.
            if self._filepath.endswith(".txt"):
                dataset = {"type": "text.TextDataset"}
            elif self._filepath.endswith(".json"):
                dataset = {"type": "json.JSONDataset"}
            elif self._filepath.endswith((".yaml", ".yml")):
                dataset = {"type": "yaml.YAMLDataset"}
            else:
                raise DatasetError(f"Cannot auto-detect dataset type for file: {self._filepath}")

        dataset_config = dataset if isinstance(dataset, dict) else {"type": dataset}
        dataset_config = deepcopy(dataset_config)
        # The wrapper's filepath always overrides any filepath in the nested config.
        dataset_config["filepath"] = self._filepath

        return dataset_config

    def load(self) -> PromptTemplate | ChatPromptTemplate:
        """Load data using underlying dataset and wrap in LangChain template.

        Raises:
            DatasetError: If the underlying load fails, yields no data, or the
                data cannot be converted into the configured template class.
        """
        try:
            raw_data = self._dataset.load()
        except Exception as exc:
            raise DatasetError(f"Failed to load data from {self._filepath}: {exc}") from exc

        if raw_data is None:
            raise DatasetError(f"No data loaded from {self._filepath}")

        try:
            # Identity comparison: TEMPLATES stores the classes themselves.
            if self._template_class is ChatPromptTemplate:
                return self._create_chat_prompt_template(raw_data)
            return self._create_prompt_template(raw_data)
        except Exception as exc:
            raise DatasetError(f"Failed to create {self._template_class.__name__}: {exc}") from exc

    def _create_prompt_template(self, raw_data: Any) -> PromptTemplate:
        """Create a PromptTemplate from loaded data.

        Accepts a raw template string or a dict of ``PromptTemplate`` keyword
        arguments; anything else is rejected.
        """
        if isinstance(raw_data, str):
            return PromptTemplate.from_template(raw_data)

        if isinstance(raw_data, dict):
            return PromptTemplate(**raw_data)

        raise DatasetError(f"Unsupported data type for PromptTemplate: {type(raw_data)}")

    def _validate_chat_prompt_data(
        self, data: dict | list[tuple[str, str]]
    ) -> list[tuple[str, str]]:
        """Validate and normalize chat prompt data.

        Messages may come directly as a list or under a ``messages`` key of a
        dict. Each message must be a (role, content) pair; two-element lists
        (as produced by JSON/YAML loaders) are normalized to tuples.
        """
        messages = data.get("messages") if isinstance(data, dict) else data
        if not isinstance(messages, list) or not messages:
            raise DatasetError(
                "ChatPromptTemplate requires a non-empty list of messages "
                "(either directly or under the 'messages' key in a dict)."
            )

        validated_data = [
            tuple(msg) if isinstance(msg, list) and len(msg) >= MIN_MESSAGE_LENGTH else msg
            for msg in messages
        ]

        # Anything that is not exactly a (role, content) tuple is rejected.
        if any(not isinstance(msg, tuple) or len(msg) != MIN_MESSAGE_LENGTH for msg in validated_data):
            raise DatasetError(f"Unsupported message type found in messages: {validated_data}")

        return validated_data

    def _create_chat_prompt_template(self, data: dict | list[tuple[str, str]]) -> ChatPromptTemplate:
        """Create a ChatPromptTemplate from validated data."""
        messages = self._validate_chat_prompt_data(data)
        return ChatPromptTemplate.from_messages(messages)

    def save(self, data: Any) -> None:
        """Saving is not supported; prompts are read-only inputs."""
        raise DatasetError("Saving is not supported for LangChainPromptDataset")

    def _describe(self) -> dict[str, Any]:
        """Describe the dataset without exposing credentials."""
        clean_config = {
            k: v for k, v in getattr(self._dataset, "_config", {}).items() if k != CREDENTIALS_KEY
        }
        return {
            "path": self._filepath,
            "template": self._template_class.__name__,
            "underlying_dataset": self._dataset.__class__.__name__,
            "dataset_config": clean_config,
        }

    def _exists(self) -> bool:
        """Delegate the existence check to the underlying dataset when it supports one."""
        return self._dataset._exists() if hasattr(self._dataset, "_exists") else True
Uh oh!
There was an error while loading. Please reload this page.