Skip to content

Commit 3d6423b

Browse files
lucametehauManul from Pathway
authored andcommitted
Improve YAML configuration for "demo-question-answering" (#7230)
GitOrigin-RevId: 78ef14280c258d04dde139a51d7ff52e26058cf3
1 parent 44ea519 commit 3d6423b

File tree

2 files changed

+79
-110
lines changed

2 files changed

+79
-110
lines changed
Lines changed: 37 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import logging
2-
import sys
32

4-
import click
53
import pathway as pw
6-
import yaml
74
from dotenv import load_dotenv
8-
from pathway.udfs import DiskCache, ExponentialBackoffRetryStrategy
9-
from pathway.xpacks.llm import embedders, llms, parsers, splitters
5+
from pathway.xpacks import llm
106
from pathway.xpacks.llm.question_answering import BaseRAGQuestionAnswerer
117
from pathway.xpacks.llm.vector_store import VectorStoreServer
8+
from pydantic import BaseModel, ConfigDict, InstanceOf
9+
from typing_extensions import TypedDict
1210

1311
# To use advanced features with Pathway Scale, get your free license key from
1412
# https://pathway.com/features and paste it below.
@@ -23,77 +21,41 @@
2321

2422
load_dotenv()
2523

24+
host_config = TypedDict("host_config", {"host": str, "port": int})
2625

27-
def data_sources(source_configs) -> list[pw.Table]:
28-
sources = []
29-
for source_config in source_configs:
30-
if source_config["kind"] == "local":
31-
source = pw.io.fs.read(
32-
**source_config["config"],
33-
format="binary",
34-
with_metadata=True,
35-
)
36-
sources.append(source)
37-
elif source_config["kind"] == "gdrive":
38-
source = pw.io.gdrive.read(
39-
**source_config["config"],
40-
with_metadata=True,
41-
)
42-
sources.append(source)
43-
elif source_config["kind"] == "sharepoint":
44-
try:
45-
import pathway.xpacks.connectors.sharepoint as io_sp
46-
47-
source = io_sp.read(**source_config["config"], with_metadata=True)
48-
sources.append(source)
49-
except ImportError:
50-
print(
51-
"The Pathway Sharepoint connector is part of the commercial offering, "
52-
"please contact us for a commercial license."
53-
)
54-
sys.exit(1)
55-
56-
return sources
57-
58-
59-
@click.command()
60-
@click.option("--config_file", default="config.yaml", help="Config file to be used.")
61-
def run(
62-
config_file: str = "config.yaml",
63-
):
64-
with open(config_file) as config_f:
65-
configuration = yaml.safe_load(config_f)
66-
67-
GPT_MODEL = configuration["llm_config"]["model"]
68-
69-
embedder = embedders.OpenAIEmbedder(
70-
model="text-embedding-ada-002",
71-
cache_strategy=DiskCache(),
72-
)
73-
74-
chat = llms.OpenAIChat(
75-
model=GPT_MODEL,
76-
retry_strategy=ExponentialBackoffRetryStrategy(max_retries=6),
77-
cache_strategy=DiskCache(),
78-
temperature=0.05,
79-
)
80-
81-
host_config = configuration["host_config"]
82-
host, port = host_config["host"], host_config["port"]
83-
84-
doc_store = VectorStoreServer(
85-
*data_sources(configuration["sources"]),
86-
embedder=embedder,
87-
splitter=splitters.TokenCountSplitter(max_tokens=400),
88-
parser=parsers.ParseUnstructured(),
89-
)
90-
91-
rag_app = BaseRAGQuestionAnswerer(llm=chat, indexer=doc_store)
92-
93-
rag_app.build_server(host=host, port=port)
94-
95-
rag_app.run_server(with_cache=True, terminate_on_error=False)
26+
27+
class App(BaseModel):
28+
llm: InstanceOf[pw.UDF]
29+
embedder: InstanceOf[llm.embedders.BaseEmbedder]
30+
splitter: InstanceOf[pw.UDF]
31+
parser: InstanceOf[pw.UDF]
32+
33+
sources: list[InstanceOf[pw.Table]]
34+
35+
host_config: host_config
36+
37+
def run(self, config_file: str = "config.yaml") -> None:
38+
# Unpack host and port from config
39+
host, port = self.host_config["host"], self.host_config["port"]
40+
41+
doc_store = VectorStoreServer(
42+
*self.sources,
43+
embedder=self.embedder,
44+
splitter=self.splitter,
45+
parser=self.parser,
46+
)
47+
48+
rag_app = BaseRAGQuestionAnswerer(llm=self.llm, indexer=doc_store)
49+
50+
rag_app.build_server(host=host, port=port)
51+
52+
rag_app.run_server(with_cache=True, terminate_on_error=False)
53+
54+
model_config = ConfigDict(extra="forbid")
9655

9756

9857
if __name__ == "__main__":
99-
run()
58+
with open("config.yaml") as f:
59+
config = pw.load_yaml(f)
60+
app = App(**config)
61+
app.run()
Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,46 @@
1-
llm_config:
1+
llm: !pw.xpacks.llm.llms.OpenAIChat
22
model: "gpt-3.5-turbo"
3+
retry_strategy: !pw.udfs.ExponentialBackoffRetryStrategy
4+
max_retries: 6
5+
cache_strategy: !pw.udfs.DiskCache
6+
temperature: 0.05
7+
capacity: 8
8+
9+
embedder: !pw.xpacks.llm.embedders.OpenAIEmbedder
10+
model: "text-embedding-ada-002"
11+
cache_strategy: !pw.udfs.DiskCache
12+
313
host_config:
414
host: "0.0.0.0"
5-
port: 8000
6-
cache_options:
7-
with_cache: True
8-
cache_folder: "./Cache"
15+
port: 16003
16+
17+
splitter: !pw.xpacks.llm.splitters.TokenCountSplitter
18+
max_tokens: 400
19+
20+
parser: !pw.xpacks.llm.parsers.ParseUnstructured
21+
922
sources:
10-
- local_files:
11-
kind: local
12-
config:
13-
# Please refer to
14-
# https://pathway.com/developers/api-docs/pathway-io/fs#pathway.io.fs.read
15-
# for options definition
16-
path: "data/"
17-
# - google_drive_folder:
18-
# kind: gdrive
19-
# config:
20-
# # Please refer to
21-
# # https://pathway.com/developers/api-docs/pathway-io/gdrive#pathway.io.gdrive.read
22-
# # for options definition
23-
# # Please follow https://pathway.com/developers/user-guide/connectors/gdrive-connector/#setting-up-google-drive
24-
# # for instructions on getting credentials
25-
# object_id: "1cULDv2OaViJBmOfG5WB0oWcgayNrGtVs" # folder used in the managed demo
26-
# service_user_credentials_file: SERVICE_CREDENTIALS
27-
# refresh_interval: 5
28-
# - sharepoint_folder:
29-
# kind: sharepoint
30-
# config:
31-
# # The sharepoint is part of our commercial offering, please contact us to use it
32-
# # Please contact here: `[email protected]`
33-
# root_path: ROOT_PATH
34-
# url: SHAREPOINT_URL
35-
# tenant: SHAREPOINT_TENANT
36-
# client_id: SHAREPOINT_CLIENT_ID
37-
# cert_path: SHAREPOINT.pem
38-
# thumbprint: SHAREPOINT_THUMBPRINT
39-
# refresh_interval: 5
23+
- !pw.io.fs.read
24+
path: data
25+
format: binary
26+
with_metadata: true
27+
28+
# - !pw.xpacks.connectors.sharepoint.read
29+
# url: $SHAREPOINT_URL
30+
# tenant: $SHAREPOINT_TENANT
31+
# client_id: $SHAREPOINT_CLIENT_ID
32+
# cert_path: sharepointcert.pem
33+
# thumbprint: $SHAREPOINT_THUMBPRINT
34+
# root_path: $SHAREPOINT_ROOT
35+
# with_metadata: true
36+
# refresh_interval: 30
37+
38+
# - !pw.io.gdrive.read
39+
# object_id: $DRIVE_ID
40+
# service_user_credentials_file: gdrive_indexer.json
41+
# name_pattern:
42+
# - "*.pdf"
43+
# - "*.pptx"
44+
# object_size_limit: null
45+
# with_metadata: true
46+
# refresh_interval: 30

0 commit comments

Comments
 (0)