|
1 | 1 | import argparse |
2 | 2 | import asyncio |
3 | 3 | import dataclasses |
4 | | -import hashlib |
5 | 4 | import json |
6 | 5 | import os |
7 | 6 | import subprocess |
|
15 | 14 | from http.server import HTTPServer, BaseHTTPRequestHandler |
16 | 15 |
|
17 | 16 | from importlib import import_module |
18 | | -from string import Formatter |
19 | | -from typing import Optional, Dict, Any, Type |
| 17 | +from typing import Optional, Dict, Any |
20 | 18 | from urllib.parse import urlparse, parse_qs |
| 19 | +from moose_lib import MooseClient |
21 | 20 | from moose_lib.query_param import map_params_to_class, convert_consumption_api_param, convert_pydantic_definition |
22 | 21 | from moose_lib.internal import load_models |
23 | 22 | from moose_lib.dmv2 import get_consumption_api, get_workflow |
24 | 23 | from pydantic import BaseModel, ValidationError |
25 | 24 |
|
26 | 25 | import jwt |
27 | 26 | from clickhouse_connect import get_client |
28 | | -from clickhouse_connect.driver.client import Client as ClickhouseClient |
29 | 27 |
|
30 | | -from temporalio.client import Client as TemporalClient, TLSConfig |
31 | | -from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy, WorkflowIDReusePolicy |
| 28 | +from moose_lib.commons import EnhancedJSONEncoder |
| 29 | + |
32 | 30 | from consumption_wrapper.utils import create_temporal_connection |
33 | 31 |
|
34 | 32 | parser = argparse.ArgumentParser(description='Run Consumption Server') |
|
76 | 74 | sys.path.append(consumption_dir_path) |
77 | 75 |
|
78 | 76 |
|
79 | | - |
80 | | -# TODO: move this to python moose lib |
class EnhancedJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands datetimes, dates and dataclasses.

    * ``datetime`` values are emitted as ISO-8601 strings; a naive datetime is
      tagged with UTC before formatting.
    * ``date`` values are emitted as ISO-8601 strings.
    * dataclasses are converted with ``dataclasses.asdict``.

    Anything else is delegated to ``json.JSONEncoder.default``.
    """

    def default(self, o):
        # datetime is checked before date because datetime subclasses date.
        if isinstance(o, datetime):
            stamped = o.replace(tzinfo=timezone.utc) if o.tzinfo is None else o
            return stamped.isoformat()
        if isinstance(o, date):
            return o.isoformat()
        if not dataclasses.is_dataclass(o):
            return super().default(o)
        return dataclasses.asdict(o)
92 | | - |
93 | | - |
class QueryClient:
    """Runs ``{name}``-templated queries through a ClickHouse client.

    Each ``{name}`` placeholder in the query text is rewritten to a
    ``{pN: Type}`` marker and the corresponding value is passed separately to
    ``ch_client.query``.
    """

    def __init__(self, ch_client: ClickhouseClient):
        self.ch_client = ch_client

    def __call__(self, input, variables):
        """Shorthand for ``execute(input, variables)`` without a row type."""
        return self.execute(input, variables)

    def execute(self, input, variables, row_type: Type[BaseModel] = None):
        """Execute ``input`` with ``variables`` bound to its placeholders.

        Args:
            input: Query text containing ``{name}`` placeholders.
            variables: Mapping from placeholder name to value. A one-element
                list is unwrapped to its single item.
            row_type: Optional class; when given, every result row is passed
                to it as keyword arguments.

        Returns:
            A list of row dicts, or a list of ``row_type`` instances.

        Raises:
            KeyError: if a placeholder has no entry in ``variables``.
        """
        placeholders = {}
        bound_values = {}

        for index, parsed_field in enumerate(Formatter().parse(input)):
            field_name = parsed_field[1]
            if not field_name:
                # Literal-only segment: nothing to bind.
                continue

            raw = variables[field_name]
            if isinstance(raw, list) and len(raw) == 1:
                # handling passing the value of the query string dict directly to variables
                raw = raw[0]

            if isinstance(raw, str):
                ch_type = 'String'
            elif isinstance(raw, int):
                ch_type = 'Int64'
            elif isinstance(raw, float):
                ch_type = 'Float64'
            else:
                ch_type = "String"  # unknown type

            placeholders[field_name] = f'{{p{index}: {ch_type}}}'
            bound_values[f'p{index}'] = raw

        clickhouse_query = input.format_map(placeholders)

        # The ping result is not used; it is sent so that an idle ClickHouse
        # cloud service is woken up before the real query goes out.
        self.ch_client.ping()

        result = self.ch_client.query(clickhouse_query, bound_values)
        rows = result.named_results()

        if row_type is None:
            return list(rows)
        return [row_type(**row) for row in rows]
131 | | - |
class WorkflowClient:
    """Starts ``ScriptWorkflow`` runs on Temporal for legacy and DMv2 workflows.

    Settings for legacy workflows (retries, timeout) are loaded once, at
    construction, from ``.moose/workflow_configs.json`` under the current
    working directory. DMv2 workflows are looked up through ``get_workflow``.
    """

    def __init__(self, temporal_client: TemporalClient):
        self.temporal_client = temporal_client
        self.configs = self.load_consolidated_configs()
        print(f"WorkflowClient - configs: {self.configs}")

    # Test workflow executor in rust if this changes significantly
    def execute(self, name: str, input_data: Any) -> Dict[str, Any]:
        """Start workflow ``name`` and report the outcome as a status/body dict.

        Args:
            name: Workflow name (DMv2) or script name (legacy).
            input_data: Optional workflow input; see ``_process_input_data``.

        Returns:
            ``{"status": 200, "body": <message>}`` when the workflow was
            started, ``{"status": 400, "body": str(error)}`` when any
            ``Exception`` was raised while starting it.
        """
        try:
            workflow_id, run_id = asyncio.run(self._start_workflow_async(name, input_data))
            print(f"WorkflowClient - started workflow: {name}")
            return {
                "status": 200,
                "body": f"Workflow started: {name}. View it in the Temporal dashboard: http://localhost:8080/namespaces/default/workflows/{workflow_id}/{run_id}/history"
            }
        except Exception as e:
            print(f"WorkflowClient - error while starting workflow: {e}")
            return {
                "status": 400,
                "body": str(e)
            }

    async def _start_workflow_async(self, name: str, input_data: Any):
        """Resolve config, build the arguments and start the Temporal workflow.

        Returns:
            A ``(workflow_id, run_id)`` tuple for the started workflow.
        """
        # Extract configuration based on workflow type
        config = self._get_workflow_config(name)

        # Process input data and generate workflow ID (common logic)
        processed_input, workflow_id = self._process_input_data(name, input_data)

        # Create retry policy and timeout (common logic)
        retry_policy = RetryPolicy(maximum_attempts=config['retry_count'])
        run_timeout = self.parse_timeout_to_timedelta(config['timeout_str'])

        print(f"WorkflowClient - starting {'DMv2 ' if config['is_dmv2'] else ''}workflow: {name} with retry policy: {retry_policy} and timeout: {run_timeout}")

        # Start workflow with appropriate args
        workflow_args = self._build_workflow_args(name, processed_input, config['is_dmv2'])

        workflow_handle = await self.temporal_client.start_workflow(
            "ScriptWorkflow",
            args=workflow_args,
            id=workflow_id,
            task_queue="python-script-queue",
            id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
            id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE,
            retry_policy=retry_policy,
            run_timeout=run_timeout
        )

        return workflow_id, workflow_handle.result_run_id

    def _get_workflow_config(self, name: str) -> Dict[str, Any]:
        """Extract workflow configuration from DMv2 or legacy config.

        Returns a dict with ``retry_count``, ``timeout_str`` and ``is_dmv2``.
        Missing values fall back to 3 retries and a ``"1h"`` timeout.
        """
        dmv2_workflow = get_workflow(name)
        if dmv2_workflow is not None:
            return {
                'retry_count': dmv2_workflow.config.retries or 3,
                'timeout_str': dmv2_workflow.config.timeout or "1h",
                'is_dmv2': True
            }

        config = self.configs.get(name, {})
        return {
            'retry_count': config.get('retries', 3),
            'timeout_str': config.get('timeout', "1h"),
            'is_dmv2': False
        }

    def _process_input_data(self, name: str, input_data: Any) -> tuple[Any, str]:
        """Process input data and generate workflow ID.

        Falsy input is returned unchanged and the workflow ID is just ``name``.
        Otherwise a pydantic model is dumped to a dict, a string is parsed as
        JSON, and the value is wrapped as ``{"data": ...}`` and round-tripped
        through ``EnhancedJSONEncoder``. The workflow ID becomes ``name``
        followed by ``-`` and the first 16 hex characters of the SHA-256 of
        the sorted-key JSON, so identical inputs map to the same ID.

        Raises:
            ValueError: if the input cannot be parsed or serialized; the
                original exception is attached as the cause.
        """
        workflow_id = name
        if input_data:
            try:
                # Handle Pydantic model input for DMv2
                if isinstance(input_data, BaseModel):
                    input_data = input_data.model_dump()
                elif isinstance(input_data, str):
                    input_data = json.loads(input_data)

                # Encode with custom encoder
                input_data = json.loads(
                    json.dumps({"data": input_data}, cls=EnhancedJSONEncoder)
                )

                params_str = json.dumps(input_data, sort_keys=True)
                params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16]
                workflow_id = f"{name}-{params_hash}"
            except Exception as e:
                raise ValueError(f"Invalid input data: {e}") from e

        return input_data, workflow_id

    def _build_workflow_args(self, name: str, input_data: Any, is_dmv2: bool) -> list:
        """Build workflow arguments based on workflow type.

        DMv2 workflows are addressed by name; legacy workflows by the script
        path ``<cwd>/app/scripts/<name>``.
        """
        if is_dmv2:
            return [f"{name}", input_data]
        return [f"{os.getcwd()}/app/scripts/{name}", input_data]

    def load_consolidated_configs(self):
        """Load ``.moose/workflow_configs.json`` into a ``{name: config}`` map.

        Raises:
            ValueError: if the file cannot be read or parsed, or if an entry
                has no ``name`` key; the original exception is the cause.
        """
        # Resolved outside the try block so the error message below can
        # always reference the path.
        file_path = os.path.join(os.getcwd(), ".moose", "workflow_configs.json")
        try:
            with open(file_path, 'r') as file:
                data = json.load(file)
            return {config['name']: config for config in data}
        except Exception as e:
            raise ValueError(f"Error loading file {file_path}: {e}") from e

    def parse_timeout_to_timedelta(self, timeout_str: str) -> timedelta:
        """Convert a timeout such as ``"1h"``, ``"30m"`` or ``"45s"`` to a timedelta.

        Raises:
            ValueError: if the suffix is not ``h``/``m``/``s`` or the part
                before the suffix is not an integer.
        """
        for suffix, unit in (('h', 'hours'), ('m', 'minutes'), ('s', 'seconds')):
            if timeout_str.endswith(suffix):
                try:
                    amount = int(timeout_str[:-1])
                except ValueError as e:
                    # Report a malformed number the same way as a bad suffix.
                    raise ValueError(f"Unsupported timeout format: {timeout_str}") from e
                return timedelta(**{unit: amount})
        raise ValueError(f"Unsupported timeout format: {timeout_str}")
250 | | - |
class MooseClient:
    """Groups a ``QueryClient`` and, when a Temporal client is given, a ``WorkflowClient``.

    Attributes set by the constructor:
        query: ``QueryClient`` wrapping ``ch_client``.
        workflow: ``WorkflowClient`` wrapping ``temporal_client``, or ``None``
            when no Temporal client was supplied.
        ch_client / temporal_client: the raw clients, kept for ``cleanup``.
    """

    def __init__(self, ch_client: ClickhouseClient, temporal_client: Optional[TemporalClient] = None):
        self.query = QueryClient(ch_client)
        # Raw client references are retained so cleanup() can close them.
        self.ch_client = ch_client
        self.temporal_client = temporal_client
        self.workflow = WorkflowClient(temporal_client) if temporal_client else None

    async def cleanup(self):
        """Cleanup resources before shutdown.

        Closes the ClickHouse client and then the Temporal client, when set.
        A failure while closing either one is printed and otherwise ignored,
        so one failing close does not prevent the other.
        """
        if self.ch_client:
            try:
                self.ch_client.close()
            except Exception as close_error:
                print(f"Error closing Clickhouse client: {close_error}")

        if self.temporal_client:
            try:
                await self.temporal_client.close()
            except Exception as close_error:
                print(f"Error closing Temporal client: {close_error}")
274 | | - |
275 | 77 | def verify_jwt(token: str) -> Optional[Dict[str, Any]]: |
276 | 78 | try: |
277 | 79 | payload = jwt.decode(token, jwt_secret, algorithms=["RS256"], audience=jwt_audience, issuer=jwt_issuer) |
|
0 commit comments