Skip to content

Commit 111aca4

Browse files
authored
move implementations from script embedded in rust to py-moose-lib (#2508)
1 parent 03120df commit 111aca4

File tree

5 files changed

+250
-251
lines changed

5 files changed

+250
-251
lines changed

apps/framework-cli/src/framework/python/wrappers/consumption_runner.py

Lines changed: 4 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import asyncio
33
import dataclasses
4-
import hashlib
54
import json
65
import os
76
import subprocess
@@ -15,20 +14,19 @@
1514
from http.server import HTTPServer, BaseHTTPRequestHandler
1615

1716
from importlib import import_module
18-
from string import Formatter
19-
from typing import Optional, Dict, Any, Type
17+
from typing import Optional, Dict, Any
2018
from urllib.parse import urlparse, parse_qs
19+
from moose_lib import MooseClient
2120
from moose_lib.query_param import map_params_to_class, convert_consumption_api_param, convert_pydantic_definition
2221
from moose_lib.internal import load_models
2322
from moose_lib.dmv2 import get_consumption_api, get_workflow
2423
from pydantic import BaseModel, ValidationError
2524

2625
import jwt
2726
from clickhouse_connect import get_client
28-
from clickhouse_connect.driver.client import Client as ClickhouseClient
2927

30-
from temporalio.client import Client as TemporalClient, TLSConfig
31-
from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy, WorkflowIDReusePolicy
28+
from moose_lib.commons import EnhancedJSONEncoder
29+
3230
from consumption_wrapper.utils import create_temporal_connection
3331

3432
parser = argparse.ArgumentParser(description='Run Consumption Server')
@@ -76,202 +74,6 @@
7674
sys.path.append(consumption_dir_path)
7775

7876

79-
80-
# TODO: move this to python moose lib
81-
class EnhancedJSONEncoder(json.JSONEncoder):
82-
def default(self, o):
83-
if isinstance(o, datetime):
84-
if o.tzinfo is None:
85-
o = o.replace(tzinfo=timezone.utc)
86-
return o.isoformat()
87-
if isinstance(o, date):
88-
return o.isoformat()
89-
if dataclasses.is_dataclass(o):
90-
return dataclasses.asdict(o)
91-
return super().default(o)
92-
93-
94-
class QueryClient:
95-
def __init__(self, ch_client: ClickhouseClient):
96-
self.ch_client = ch_client
97-
98-
def __call__(self, input, variables):
99-
return self.execute(input, variables)
100-
101-
def execute(self, input, variables, row_type: Type[BaseModel] = None):
102-
params = {}
103-
values = {}
104-
105-
for i, (_, variable_name, _, _) in enumerate(Formatter().parse(input)):
106-
if variable_name:
107-
value = variables[variable_name]
108-
if isinstance(value, list) and len(value) == 1:
109-
# handling passing the value of the query string dict directly to variables
110-
value = value[0]
111-
112-
t = 'String' if isinstance(value, str) else \
113-
'Int64' if isinstance(value, int) else \
114-
'Float64' if isinstance(value, float) else "String" # unknown type
115-
116-
params[variable_name] = f'{{p{i}: {t}}}'
117-
values[f'p{i}'] = value
118-
clickhouse_query = input.format_map(params)
119-
120-
# We are not using the result of the ping
121-
# but this ensures that if the clickhouse cloud service is idle, we
122-
# wake it up, before we send the query.
123-
self.ch_client.ping()
124-
125-
val = self.ch_client.query(clickhouse_query, values)
126-
127-
if row_type is None:
128-
return list(val.named_results())
129-
else:
130-
return list(row_type(**row) for row in val.named_results())
131-
132-
class WorkflowClient:
133-
def __init__(self, temporal_client: TemporalClient):
134-
self.temporal_client = temporal_client
135-
self.configs = self.load_consolidated_configs()
136-
print(f"WorkflowClient - configs: {self.configs}")
137-
138-
# Test workflow executor in rust if this changes significantly
139-
def execute(self, name: str, input_data: Any) -> Dict[str, Any]:
140-
try:
141-
workflow_id, run_id = asyncio.run(self._start_workflow_async(name, input_data))
142-
print(f"WorkflowClient - started workflow: {name}")
143-
return {
144-
"status": 200,
145-
"body": f"Workflow started: {name}. View it in the Temporal dashboard: http://localhost:8080/namespaces/default/workflows/{workflow_id}/{run_id}/history"
146-
}
147-
except Exception as e:
148-
print(f"WorkflowClient - error while starting workflow: {e}")
149-
return {
150-
"status": 400,
151-
"body": str(e)
152-
}
153-
154-
async def _start_workflow_async(self, name: str, input_data: Any):
155-
# Extract configuration based on workflow type
156-
config = self._get_workflow_config(name)
157-
158-
# Process input data and generate workflow ID (common logic)
159-
processed_input, workflow_id = self._process_input_data(name, input_data)
160-
161-
# Create retry policy and timeout (common logic)
162-
retry_policy = RetryPolicy(maximum_attempts=config['retry_count'])
163-
run_timeout = self.parse_timeout_to_timedelta(config['timeout_str'])
164-
165-
print(f"WorkflowClient - starting {'DMv2 ' if config['is_dmv2'] else ''}workflow: {name} with retry policy: {retry_policy} and timeout: {run_timeout}")
166-
167-
# Start workflow with appropriate args
168-
workflow_args = self._build_workflow_args(name, processed_input, config['is_dmv2'])
169-
170-
workflow_handle = await self.temporal_client.start_workflow(
171-
"ScriptWorkflow",
172-
args=workflow_args,
173-
id=workflow_id,
174-
task_queue="python-script-queue",
175-
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
176-
id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE,
177-
retry_policy=retry_policy,
178-
run_timeout=run_timeout
179-
)
180-
181-
return workflow_id, workflow_handle.result_run_id
182-
183-
def _get_workflow_config(self, name: str) -> Dict[str, Any]:
184-
"""Extract workflow configuration from DMv2 or legacy config."""
185-
dmv2_workflow = get_workflow(name)
186-
if dmv2_workflow is not None:
187-
return {
188-
'retry_count': dmv2_workflow.config.retries or 3,
189-
'timeout_str': dmv2_workflow.config.timeout or "1h",
190-
'is_dmv2': True
191-
}
192-
else:
193-
config = self.configs.get(name, {})
194-
return {
195-
'retry_count': config.get('retries', 3),
196-
'timeout_str': config.get('timeout', "1h"),
197-
'is_dmv2': False
198-
}
199-
200-
def _process_input_data(self, name: str, input_data: Any) -> tuple[Any, str]:
201-
"""Process input data and generate workflow ID."""
202-
workflow_id = name
203-
if input_data:
204-
try:
205-
# Handle Pydantic model input for DMv2
206-
if isinstance(input_data, BaseModel):
207-
input_data = input_data.model_dump()
208-
elif isinstance(input_data, str):
209-
input_data = json.loads(input_data)
210-
211-
# Encode with custom encoder
212-
input_data = json.loads(
213-
json.dumps({"data": input_data}, cls=EnhancedJSONEncoder)
214-
)
215-
216-
params_str = json.dumps(input_data, sort_keys=True)
217-
params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16]
218-
workflow_id = f"{name}-{params_hash}"
219-
except Exception as e:
220-
raise ValueError(f"Invalid input data: {e}")
221-
222-
return input_data, workflow_id
223-
224-
def _build_workflow_args(self, name: str, input_data: Any, is_dmv2: bool) -> list:
225-
"""Build workflow arguments based on workflow type."""
226-
if is_dmv2:
227-
return [f"{name}", input_data]
228-
else:
229-
return [f"{os.getcwd()}/app/scripts/{name}", input_data]
230-
231-
def load_consolidated_configs(self):
232-
try:
233-
file_path = os.path.join(os.getcwd(), ".moose", "workflow_configs.json")
234-
with open(file_path, 'r') as file:
235-
data = json.load(file)
236-
config_map = {config['name']: config for config in data}
237-
return config_map
238-
except Exception as e:
239-
raise ValueError(f"Error loading file {file_path}: {e}")
240-
241-
def parse_timeout_to_timedelta(self, timeout_str: str) -> timedelta:
242-
if timeout_str.endswith('h'):
243-
return timedelta(hours=int(timeout_str[:-1]))
244-
elif timeout_str.endswith('m'):
245-
return timedelta(minutes=int(timeout_str[:-1]))
246-
elif timeout_str.endswith('s'):
247-
return timedelta(seconds=int(timeout_str[:-1]))
248-
else:
249-
raise ValueError(f"Unsupported timeout format: {timeout_str}")
250-
251-
class MooseClient:
252-
def __init__(self, ch_client: ClickhouseClient, temporal_client: Optional[TemporalClient] = None):
253-
self.query = QueryClient(ch_client)
254-
self.ch_client = ch_client # Store reference for cleanup
255-
self.temporal_client = temporal_client
256-
if temporal_client:
257-
self.workflow = WorkflowClient(temporal_client)
258-
else:
259-
self.workflow = None
260-
261-
async def cleanup(self):
262-
"""Cleanup resources before shutdown"""
263-
if self.ch_client:
264-
try:
265-
self.ch_client.close()
266-
except Exception as e:
267-
print(f"Error closing Clickhouse client: {e}")
268-
269-
if self.temporal_client:
270-
try:
271-
await self.temporal_client.close()
272-
except Exception as e:
273-
print(f"Error closing Temporal client: {e}")
274-
27577
def verify_jwt(token: str) -> Optional[Dict[str, Any]]:
27678
try:
27779
payload = jwt.decode(token, jwt_secret, algorithms=["RS256"], audience=jwt_audience, issuer=jwt_issuer)

packages/py-moose-lib/moose_lib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010

1111
from .dmv2 import *
1212

13-
from .clients.redis_client import MooseCache
13+
from .clients.redis_client import MooseCache

packages/py-moose-lib/moose_lib/commons.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
import dataclasses
12
import logging
3+
from datetime import datetime, timezone
4+
25
import requests
36
import json
47
from typing import Optional, Literal
58

9+
610
class CliLogData:
711
INFO = "Info"
812
SUCCESS = "Success"
913
ERROR = "Error"
1014
HIGHLIGHT = "Highlight"
1115

12-
def __init__(self, action: str, message: str, message_type: Optional[Literal[INFO, SUCCESS, ERROR, HIGHLIGHT]] = INFO):
16+
def __init__(self, action: str, message: str,
17+
message_type: Optional[Literal[INFO, SUCCESS, ERROR, HIGHLIGHT]] = INFO):
1318
self.message_type = message_type
1419
self.action = action
1520
self.message = message
@@ -31,11 +36,11 @@ def cli_log(log: CliLogData) -> None:
3136

3237
class Logger:
3338
default_action = "Custom"
34-
39+
3540
def __init__(self, action: Optional[str] = None, is_moose_task: bool = False):
3641
self.action = action or Logger.default_action
3742
self._is_moose_task = is_moose_task
38-
43+
3944
def _log(self, message: str, message_type: str) -> None:
4045
if self._is_moose_task:
4146
# We have a task decorator in the lib that initializes a logger
@@ -62,4 +67,32 @@ def error(self, message: str) -> None:
6267
self._log(message, CliLogData.ERROR)
6368

6469
def highlight(self, message: str) -> None:
65-
self._log(message, CliLogData.HIGHLIGHT)
70+
self._log(message, CliLogData.HIGHLIGHT)
71+
72+
73+
class EnhancedJSONEncoder(json.JSONEncoder):
74+
"""
75+
Custom JSON encoder that handles:
76+
- datetime objects (converts to ISO format with timezone)
77+
- dataclass instances (converts to dict)
78+
- Pydantic models (converts to dict)
79+
"""
80+
81+
def default(self, o):
82+
if isinstance(o, datetime):
83+
if o.tzinfo is None:
84+
o = o.replace(tzinfo=timezone.utc)
85+
return o.isoformat()
86+
if hasattr(o, "model_dump"): # Handle Pydantic v2 models
87+
# Convert to dict and handle datetime fields
88+
data = o.model_dump()
89+
# Handle any datetime fields that might be present
90+
for key, value in data.items():
91+
if isinstance(value, datetime):
92+
if value.tzinfo is None:
93+
value = value.replace(tzinfo=timezone.utc)
94+
data[key] = value.isoformat()
95+
return data
96+
if dataclasses.is_dataclass(o):
97+
return dataclasses.asdict(o)
98+
return super().default(o)

0 commit comments

Comments
 (0)