forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community: support Databricks Unity Catalog functions as LangChain to…
…ols (langchain-ai#22555) This PR adds support for using Databricks Unity Catalog functions as LangChain tools, which runs inside a Databricks SQL warehouse. * An example notebook is provided.
- Loading branch information
Showing
4 changed files
with
544 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Databricks Unity Catalog (UC)\n", | ||
"\n", | ||
"This notebook shows how to use UC functions as LangChain tools.\n", | ||
"\n", | ||
"See Databricks documentation ([AWS](https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)|[Azure](https://learn.microsoft.com/en-us/azure/databricks/sql/language-manual/sql-ref-syntax-ddl-create-sql-function)|[GCP](https://docs.gcp.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)) to learn how to create SQL or Python functions in UC. Do not skip function and parameter comments, which are critical for LLMs to call functions properly.\n", | ||
"\n", | ||
"In this example notebook, we create a simple Python function that executes arbitrary code and use it as a LangChain tool:\n", | ||
"\n", | ||
"```sql\n", | ||
"CREATE FUNCTION main.tools.python_exec (\n", | ||
" code STRING COMMENT 'Python code to execute. Remember to print the final result to stdout.'\n", | ||
")\n", | ||
"RETURNS STRING\n", | ||
"LANGUAGE PYTHON\n", | ||
"COMMENT 'Executes Python code and returns its stdout.'\n", | ||
"AS $$\n", | ||
" import sys\n", | ||
" from io import StringIO\n", | ||
" stdout = StringIO()\n", | ||
" sys.stdout = stdout\n", | ||
" exec(code)\n", | ||
" return stdout.getvalue()\n", | ||
"$$\n", | ||
"```\n", | ||
"\n", | ||
"It runs in a secure and isolated environment within a Databricks SQL warehouse." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_openai import ChatOpenAI\n", | ||
"\n", | ||
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.tools.databricks import UCFunctionToolkit\n", | ||
"\n", | ||
"tools = (\n", | ||
" UCFunctionToolkit(\n", | ||
" # You can find the SQL warehouse ID in its UI after creation.\n", | ||
" warehouse_id=\"xxxx123456789\"\n", | ||
" )\n", | ||
" .include(\n", | ||
" # Include functions as tools using their qualified names.\n", | ||
" # You can use \"{catalog_name}.{schema_name}.*\" to get all functions in a schema.\n", | ||
" \"main.tools.python_exec\",\n", | ||
" )\n", | ||
" .get_tools()\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.agents import AgentExecutor, create_tool_calling_agent\n", | ||
"from langchain_core.prompts import ChatPromptTemplate\n", | ||
"\n", | ||
"prompt = ChatPromptTemplate.from_messages(\n", | ||
" [\n", | ||
" (\n", | ||
" \"system\",\n", | ||
" \"You are a helpful assistant. Make sure to use tool for information.\",\n", | ||
" ),\n", | ||
" (\"placeholder\", \"{chat_history}\"),\n", | ||
" (\"human\", \"{input}\"),\n", | ||
" (\"placeholder\", \"{agent_scratchpad}\"),\n", | ||
" ]\n", | ||
")\n", | ||
"\n", | ||
"agent = create_tool_calling_agent(llm, tools, prompt)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"\n", | ||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", | ||
"\u001b[32;1m\u001b[1;3m\n", | ||
"Invoking: `main__tools__python_exec` with `{'code': 'print(36939 * 8922.4)'}`\n", | ||
"\n", | ||
"\n", | ||
"\u001b[0m\u001b[36;1m\u001b[1;3m{\"format\": \"SCALAR\", \"value\": \"329584533.59999996\\n\", \"truncated\": false}\u001b[0m\u001b[32;1m\u001b[1;3mThe result of the multiplication 36939 * 8922.4 is 329,584,533.60.\u001b[0m\n", | ||
"\n", | ||
"\u001b[1m> Finished chain.\u001b[0m\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'input': '36939 * 8922.4',\n", | ||
" 'output': 'The result of the multiplication 36939 * 8922.4 is 329,584,533.60.'}" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n", | ||
"agent_executor.invoke({\"input\": \"36939 * 8922.4\"})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "llm", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
3 changes: 3 additions & 0 deletions
3
libs/community/langchain_community/tools/databricks/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from langchain_community.tools.databricks.tool import UCFunctionToolkit | ||
|
||
__all__ = ["UCFunctionToolkit"] |
172 changes: 172 additions & 0 deletions
172
libs/community/langchain_community/tools/databricks/_execution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import json | ||
from dataclasses import dataclass | ||
from io import StringIO | ||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional | ||
|
||
if TYPE_CHECKING: | ||
from databricks.sdk import WorkspaceClient | ||
from databricks.sdk.service.catalog import FunctionInfo | ||
from databricks.sdk.service.sql import StatementParameterListItem | ||
|
||
|
||
def is_scalar(function: "FunctionInfo") -> bool:
    """Return True when the UC function yields a scalar value.

    Table-valued functions report ``TABLE_TYPE`` as their data type; any
    other return type is treated as scalar.
    """
    from databricks.sdk.service.catalog import ColumnTypeName

    if function.data_type == ColumnTypeName.TABLE_TYPE:
        return False
    return True
|
||
|
||
@dataclass
class ParameterizedStatement:
    """A SQL statement together with the bind parameters it references."""

    # SQL text containing named placeholders such as ``:param_name``.
    statement: str
    # Values to bind to the placeholders when the statement is executed.
    parameters: List["StatementParameterListItem"]
|
||
|
||
@dataclass
class FunctionExecutionResult:
    """
    Result of executing a function.
    We always use a string to present the result value for AI model to consume.
    """

    error: Optional[str] = None
    format: Optional[Literal["SCALAR", "CSV"]] = None
    value: Optional[str] = None
    truncated: Optional[bool] = None

    def to_json(self) -> str:
        """Serialize all non-null fields as a JSON object string."""
        payload = {}
        for field_name, field_value in vars(self).items():
            # Skip unset fields so the model only sees meaningful keys.
            if field_value is not None:
                payload[field_name] = field_value
        return json.dumps(payload)
|
||
|
||
def get_execute_function_sql_stmt(
    function: "FunctionInfo", json_params: Dict[str, Any]
) -> ParameterizedStatement:
    """Build a parameterized SQL statement that invokes a UC function.

    Argument values are passed via bind parameters (``:name``) rather than
    interpolated into the SQL text, with ``from_json``/``unbase64`` wrappers
    to reconstruct complex and binary values server-side.

    Args:
        function: UC function metadata (full name, return type, input params).
        json_params: Argument values keyed by parameter name, as decoded from
            the model's JSON tool call.

    Returns:
        A ParameterizedStatement with the SELECT statement and its bind
        parameters.

    Raises:
        ValueError: If a required parameter is missing, or if parameters were
            provided for a function that declares none.
    """
    from databricks.sdk.service.catalog import ColumnTypeName
    from databricks.sdk.service.sql import StatementParameterListItem

    parts = []
    output_params = []
    if is_scalar(function):
        # TODO: IDENTIFIER(:function) did not work
        parts.append(f"SELECT {function.full_name}(")
    else:
        parts.append(f"SELECT * FROM {function.full_name}(")
    if function.input_params is None or function.input_params.parameters is None:
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # and this is validation of caller-provided input.
        if json_params:
            raise ValueError(
                "Function has no parameters but parameters were provided."
            )
    else:
        args = []
        use_named_args = False
        for p in function.input_params.parameters:
            if p.name not in json_params:
                if p.parameter_default is not None:
                    # A defaulted parameter was skipped, so every subsequent
                    # argument must use `name => value` syntax.
                    use_named_args = True
                else:
                    raise ValueError(
                        f"Parameter {p.name} is required but not provided."
                    )
            else:
                arg_clause = ""
                if use_named_args:
                    arg_clause += f"{p.name} => "
                json_value = json_params[p.name]
                if p.type_name in (
                    ColumnTypeName.ARRAY,
                    ColumnTypeName.MAP,
                    ColumnTypeName.STRUCT,
                ):
                    # Use from_json to restore values of complex types.
                    json_value_str = json.dumps(json_value)
                    # TODO: parametrize type
                    arg_clause += f"from_json(:{p.name}, '{p.type_text}')"
                    output_params.append(
                        StatementParameterListItem(name=p.name, value=json_value_str)
                    )
                elif p.type_name == ColumnTypeName.BINARY:
                    # Use unbase64 to restore binary values.
                    arg_clause += f"unbase64(:{p.name})"
                    output_params.append(
                        StatementParameterListItem(name=p.name, value=json_value)
                    )
                else:
                    arg_clause += f":{p.name}"
                    output_params.append(
                        StatementParameterListItem(
                            name=p.name, value=json_value, type=p.type_text
                        )
                    )
                args.append(arg_clause)
        parts.append(",".join(args))
    parts.append(")")
    # TODO: check extra params in kwargs
    statement = "".join(parts)
    return ParameterizedStatement(statement=statement, parameters=output_params)
|
||
|
||
def execute_function(
    ws: "WorkspaceClient",
    warehouse_id: str,
    function: "FunctionInfo",
    parameters: Dict[str, Any],
) -> FunctionExecutionResult:
    """
    Execute a function with the given arguments and return the result.

    Args:
        ws: Workspace client used to submit the SQL statement.
        warehouse_id: ID of the SQL warehouse that runs the statement.
        function: UC function metadata describing what to invoke.
        parameters: Argument values keyed by parameter name.

    Returns:
        A FunctionExecutionResult carrying either an error string (when the
        statement does not succeed) or the output: the single cell as a
        string for scalar functions ("SCALAR" format), or a CSV rendering of
        the rows for table functions ("CSV" format).
    """
    # pandas is only needed to render table results as CSV, so import it
    # lazily and fail with an actionable message when it is missing.
    try:
        import pandas as pd
    except ImportError as e:
        raise ImportError(
            "Could not import pandas python package. "
            "Please install it with `pip install pandas`."
        ) from e
    from databricks.sdk.service.sql import StatementState

    # TODO: async so we can run functions in parallel
    parametrized_statement = get_execute_function_sql_stmt(function, parameters)
    # TODO: configurable limits
    # Bounded wait/row/byte limits keep a single tool call from stalling the
    # agent or returning an oversized payload to the model.
    response = ws.statement_execution.execute_statement(
        statement=parametrized_statement.statement,
        warehouse_id=warehouse_id,
        parameters=parametrized_statement.parameters,
        wait_timeout="30s",
        row_limit=100,
        byte_limit=4096,
    )
    status = response.status
    assert status is not None, f"Statement execution failed: {response}"
    if status.state != StatementState.SUCCEEDED:
        # Return the failure as a structured error string instead of raising,
        # so the calling agent can observe and react to it.
        error = status.error
        assert (
            error is not None
        ), "Statement execution failed but no error message was provided."
        return FunctionExecutionResult(error=f"{error.error_code}: {error.message}")
    manifest = response.manifest
    assert manifest is not None
    truncated = manifest.truncated
    result = response.result
    assert (
        result is not None
    ), "Statement execution succeeded but no result was provided."
    data_array = result.data_array
    if is_scalar(function):
        # Scalar functions produce a single row with a single column;
        # extract that one cell (None when no data came back).
        value = None
        if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
            value = str(data_array[0][0])  # type: ignore
        return FunctionExecutionResult(
            format="SCALAR", value=value, truncated=truncated
        )
    else:
        # Table functions: render all rows as CSV, taking column names from
        # the result schema.
        schema = manifest.schema
        assert (
            schema is not None and schema.columns is not None
        ), "Statement execution succeeded but no schema was provided."
        columns = [c.name for c in schema.columns]
        if data_array is None:
            data_array = []
        pdf = pd.DataFrame.from_records(data_array, columns=columns)
        csv_buffer = StringIO()
        pdf.to_csv(csv_buffer, index=False)
        return FunctionExecutionResult(
            format="CSV", value=csv_buffer.getvalue(), truncated=truncated
        )
Oops, something went wrong.