Skip to content
Merged
Changes from 6 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
243f4f7
Add LangchainPromptDataset to experimental datasets
lrcouto Sep 29, 2025
7a7b0a0
Add credential handling
lrcouto Oct 2, 2025
47557db
Merge branch 'main' into add-langchain-prompt-dataset
lrcouto Oct 2, 2025
6b9247b
Lint
lrcouto Oct 2, 2025
2681408
Cleanup
lrcouto Oct 2, 2025
f6b9504
Separate validation from _create_chat_prompt_template
lrcouto Oct 2, 2025
68c3f94
Change validation function to not try to validate the template format
lrcouto Oct 3, 2025
cd17a1c
Add unit tests for LangChainPromptDataset
lrcouto Oct 3, 2025
0173dc3
Map constant template type to function
lrcouto Oct 3, 2025
a3c23b0
Better docstrings
lrcouto Oct 3, 2025
14591d1
Add LangChainPromptDataset to release notes
lrcouto Oct 3, 2025
f074b5e
Add new dataset to docs
lrcouto Oct 3, 2025
9dd1673
Add new dataset to docs index
lrcouto Oct 3, 2025
4695764
Fix mkdocs error
lrcouto Oct 3, 2025
9dca77a
Add preview method
lrcouto Oct 4, 2025
0f65fd6
Fix preview method, should work on Viz now
lrcouto Oct 4, 2025
de94bac
Add requirements to pyproject.toml
lrcouto Oct 6, 2025
a52d0a6
Improve docstrings
lrcouto Oct 6, 2025
674e487
Add LangchainPromptDataset to experimental datasets
lrcouto Sep 29, 2025
0397ecb
Add credential handling
lrcouto Oct 2, 2025
26571d8
Lint
lrcouto Oct 2, 2025
0174342
Cleanup
lrcouto Oct 2, 2025
ecd1e9d
Separate validation from _create_chat_prompt_template
lrcouto Oct 2, 2025
7ecad93
Change validation function to not try to validate the template format
lrcouto Oct 3, 2025
d5c5dd0
Add unit tests for LangChainPromptDataset
lrcouto Oct 3, 2025
f93f419
Map constant template type to function
lrcouto Oct 3, 2025
cadeec0
Better docstrings
lrcouto Oct 3, 2025
f6c5f31
Add LangChainPromptDataset to release notes
lrcouto Oct 3, 2025
698d618
Add new dataset to docs
lrcouto Oct 3, 2025
1494696
Add new dataset to docs index
lrcouto Oct 3, 2025
1c9bb51
Fix mkdocs error
lrcouto Oct 3, 2025
ce3933e
Add preview method
lrcouto Oct 4, 2025
2ca6722
Fix preview method, should work on Viz now
lrcouto Oct 4, 2025
ebc29b9
Add requirements to pyproject.toml
lrcouto Oct 6, 2025
eabb100
Improve docstrings
lrcouto Oct 6, 2025
d903366
Fix return type on validate function
lrcouto Oct 6, 2025
281f422
Merge branch 'add-langchain-prompt-dataset' of github.com:kedro-org/k…
lrcouto Oct 6, 2025
1195955
Delete coverage.xml
lrcouto Oct 6, 2025
229c47a
Remove coverage files that shouldn't be there
lrcouto Oct 6, 2025
c889e9e
Merge branch 'add-langchain-prompt-dataset' of github.com:kedro-org/k…
lrcouto Oct 6, 2025
c7504e4
Simplify preview function
lrcouto Oct 6, 2025
d7760f8
Add better docstring to class
lrcouto Oct 6, 2025
2e77fe7
Lower required langchain version
lrcouto Oct 6, 2025
662dd24
Lint
lrcouto Oct 6, 2025
9ad29df
Update kedro-datasets/kedro_datasets_experimental/langchain/__init__.py
lrcouto Oct 7, 2025
d6b4f10
Improve docstring
lrcouto Oct 7, 2025
33f1491
Merge branch 'add-langchain-prompt-dataset' of github.com:kedro-org/k…
lrcouto Oct 7, 2025
d08b51f
Add validation for plain string on ChatPromptTemplate
lrcouto Oct 7, 2025
4e31229
Fix indentation on docstring
lrcouto Oct 7, 2025
62da18f
update docstring and version
lrcouto Oct 8, 2025
ca864a8
Remove redundant part of docstring
lrcouto Oct 8, 2025
115839d
Add validation for dataset type
lrcouto Oct 8, 2025
b2077d9
Update docstring for _build_dataset_config
lrcouto Oct 8, 2025
a69820c
Update docstring for _build_dataset_config
lrcouto Oct 8, 2025
bb62c16
Fix indentation on docstring
lrcouto Oct 8, 2025
14d48ad
Make dataset type parameter mandatory
lrcouto Oct 8, 2025
0126125
Split by period and use only the last two names in dataset type validation
lrcouto Oct 8, 2025
e8f6f24
Update kedro-datasets/kedro_datasets_experimental/langchain/langchain…
lrcouto Oct 9, 2025
2943164
Separate validation on build config function
lrcouto Oct 9, 2025
bdfdc48
Lint?
lrcouto Oct 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from copy import deepcopy
from pathlib import Path
from typing import Any

from kedro.io import AbstractDataset, DatasetError
from kedro.io.catalog_config_resolver import CREDENTIALS_KEY
from kedro.io.core import get_filepath_str, parse_dataset_definition
from langchain.prompts import ChatPromptTemplate, PromptTemplate

# Minimum number of elements required for a message (role, content)
MIN_MESSAGE_LENGTH = 2

class LangChainPromptDataset(AbstractDataset[PromptTemplate | ChatPromptTemplate, Any]):
    """Kedro dataset for loading LangChain prompts using existing Kedro datasets.

    The prompt data is read through an underlying Kedro dataset and wrapped in a
    LangChain ``PromptTemplate`` or ``ChatPromptTemplate``. If no dataset config
    is given, the underlying dataset type is inferred from the file extension
    (``.txt`` -> ``text.TextDataset``, ``.json`` -> ``json.JSONDataset``,
    ``.yaml``/``.yml`` -> ``yaml.YAMLDataset``). The dataset is read-only:
    calling ``save`` always raises a ``DatasetError``.

    Example catalog entry::

        qa_prompt:
          type: kedro_datasets_experimental.langchain.LangChainPromptDataset
          filepath: data/01_raw/qa_prompt.yaml
          template: ChatPromptTemplate
    """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please extend class docstrings like we do for the rest of the datasets? It is used in the docs, so it should be quite informative.


# Maps the string accepted by the ``template`` constructor argument to the
# LangChain prompt class it selects.
TEMPLATES = {
    "PromptTemplate": PromptTemplate,
    "ChatPromptTemplate": ChatPromptTemplate,
}

def __init__(  # noqa: PLR0913
    self,
    filepath: str,
    template: str = "PromptTemplate",
    dataset: dict[str, Any] | str | None = None,
    credentials: dict[str, Any] | None = None,
    fs_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
    **kwargs: Any,
):
    """
    Initialize the LangChain prompt dataset.

    Args:
        filepath: Path to the prompt file.
        template: Name of the LangChain template class ("PromptTemplate" or
            "ChatPromptTemplate").
        dataset: Configuration for the underlying Kedro dataset, either a full
            config dict or just a type string. Inferred from the file extension
            when omitted.
        credentials: Credentials passed to the underlying dataset unless already
            defined in its config.
        fs_args: Extra arguments passed to the filesystem, if supported.
        metadata: Arbitrary metadata.
        **kwargs: Additional arguments (only ``protocol`` is read; the rest are
            ignored).

    Raises:
        DatasetError: If ``template`` is not a supported template name, or the
            underlying dataset cannot be created.
    """
    super().__init__()

    self.metadata = metadata
    self._filepath = get_filepath_str(Path(filepath), kwargs.get("protocol"))

    try:
        self._template_class = self.TEMPLATES[template]
    except KeyError:
        # "from None": the KeyError adds no information beyond the message,
        # which already lists the valid options (fixes implicit chaining, B904).
        raise DatasetError(
            f"Invalid template '{template}'. Must be one of: {list(self.TEMPLATES)}"
        ) from None

    # Infer dataset type if not explicitly provided
    dataset_config = self._build_dataset_config(dataset)

    # Handle credentials: top-level values propagate into the underlying
    # dataset config only when it does not define its own.
    self._credentials = deepcopy(credentials or {})
    self._fs_args = deepcopy(fs_args or {})

    if self._credentials:
        if CREDENTIALS_KEY in dataset_config:
            self._logger.warning(
                "Top-level credentials will not propagate into the underlying dataset "
                "since credentials were explicitly defined in the dataset config."
            )
        else:
            dataset_config[CREDENTIALS_KEY] = deepcopy(self._credentials)

    if self._fs_args:
        if "fs_args" in dataset_config:
            self._logger.warning(
                "Top-level fs_args will not propagate into the underlying dataset "
                "since fs_args were explicitly defined in the dataset config."
            )
        else:
            dataset_config["fs_args"] = deepcopy(self._fs_args)

    try:
        dataset_class, dataset_kwargs = parse_dataset_definition(dataset_config)
        self._dataset = dataset_class(**dataset_kwargs)
    except Exception as e:
        # Chain explicitly so the original failure is preserved as __cause__.
        raise DatasetError(f"Failed to create underlying dataset: {e}") from e

def _build_dataset_config(self, dataset: dict[str, Any] | str | None) -> dict[str, Any]:
    """Infer and normalize dataset configuration.

    When *dataset* is ``None`` the underlying dataset type is derived from the
    file extension of ``self._filepath``; a bare type string is wrapped into a
    config dict. The resulting config always carries the dataset's filepath.

    Raises:
        DatasetError: If no config was given and the extension is not one of
            ``.txt``, ``.json``, ``.yaml`` or ``.yml``.
    """
    # Extension -> underlying Kedro dataset type, tried in order.
    suffix_to_type = (
        ((".txt",), "text.TextDataset"),
        ((".json",), "json.JSONDataset"),
        ((".yaml", ".yml"), "yaml.YAMLDataset"),
    )

    if dataset is None:
        for suffixes, dataset_type in suffix_to_type:
            if self._filepath.endswith(suffixes):
                dataset = {"type": dataset_type}
                break
        else:
            raise DatasetError(f"Cannot auto-detect dataset type for file: {self._filepath}")

    if isinstance(dataset, dict):
        dataset_config = deepcopy(dataset)
    else:
        dataset_config = {"type": dataset}
    dataset_config["filepath"] = self._filepath

    return dataset_config

def load(self) -> PromptTemplate | ChatPromptTemplate:
    """Load data using the underlying dataset and wrap it in a LangChain template.

    Returns:
        A ``PromptTemplate`` or ``ChatPromptTemplate``, depending on the
        ``template`` the dataset was configured with.

    Raises:
        DatasetError: If the underlying load fails, produces no data, or the
            data cannot be converted into the requested template class.
    """
    try:
        raw_data = self._dataset.load()
    except Exception as e:
        # Chain explicitly so the underlying failure survives as __cause__ (B904).
        raise DatasetError(f"Failed to load data from {self._filepath}: {e}") from e

    if raw_data is None:
        raise DatasetError(f"No data loaded from {self._filepath}")

    try:
        if self._template_class == ChatPromptTemplate:
            return self._create_chat_prompt_template(raw_data)
        else:
            return self._create_prompt_template(raw_data)
    except Exception as e:
        raise DatasetError(f"Failed to create {self._template_class.__name__}: {e}") from e

def _create_prompt_template(self, raw_data: Any) -> PromptTemplate:
    """Build a ``PromptTemplate`` from a raw string or a keyword mapping."""
    if isinstance(raw_data, dict):
        return PromptTemplate(**raw_data)

    if isinstance(raw_data, str):
        return PromptTemplate.from_template(raw_data)

    raise DatasetError(f"Unsupported data type for PromptTemplate: {type(raw_data)}")

def _validate_chat_prompt_data(self, data: dict | list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Validate chat prompt data and return it as a list of (role, content) tuples.

    Accepts either a message list directly or a dict holding one under the
    ``messages`` key; two-element lists are coerced to tuples.
    """
    if isinstance(data, dict):
        messages = data.get("messages")
    else:
        messages = data

    if not isinstance(messages, list) or not messages:
        raise DatasetError(
            "ChatPromptTemplate requires a non-empty list of messages "
            "(either directly or under the 'messages' key in a dict)."
        )

    validated_data = []
    for msg in messages:
        # Coerce [role, content] lists into tuples; leave everything else as-is
        # for the strict check below.
        if isinstance(msg, list) and len(msg) >= MIN_MESSAGE_LENGTH:
            validated_data.append(tuple(msg))
        else:
            validated_data.append(msg)

    for msg in validated_data:
        if not isinstance(msg, tuple) or len(msg) != MIN_MESSAGE_LENGTH:
            raise DatasetError(f"Unsupported message type found in messages: {validated_data}")

    return validated_data

def _create_chat_prompt_template(self, data: dict | list[tuple[str, str]]) -> ChatPromptTemplate:
    """Validate *data* and build a ``ChatPromptTemplate`` from its messages."""
    return ChatPromptTemplate.from_messages(self._validate_chat_prompt_data(data))

def save(self, data: Any) -> None:
    """Prompts are read-only; saving is intentionally unsupported.

    Raises:
        DatasetError: Always.
    """
    raise DatasetError("Saving is not supported for LangChainPromptDataset")

def _describe(self) -> dict[str, Any]:
    """Describe the dataset, stripping credentials from the underlying config."""
    clean_config: dict[str, Any] = {}
    for key, value in getattr(self._dataset, "_config", {}).items():
        if key != CREDENTIALS_KEY:
            clean_config[key] = value

    return {
        "path": self._filepath,
        "template": self._template_class.__name__,
        "underlying_dataset": type(self._dataset).__name__,
        "dataset_config": clean_config,
    }

def _exists(self) -> bool:
    """Delegate existence checks to the wrapped dataset when it supports them."""
    exists_check = getattr(self._dataset, "_exists", None)
    if exists_check is None:
        # Underlying dataset offers no existence check; assume present.
        return True
    return exists_check()