diff --git a/.gitignore b/.gitignore index 0fa30f9..41a3343 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,9 @@ poetry.toml # LSP config files pyrightconfig.json + +# Generated Protocol Buffer files +generated/ +*_pb2.py +*_pb2.pyi +*_pb2_grpc.py diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3d72f9d --- /dev/null +++ b/Makefile @@ -0,0 +1,113 @@ +# Makefile for DataCloud MCP Query Project +# This Makefile provides convenient commands for building and managing the project + +.PHONY: help install protos clean clean-protos test run dev-install all + +# Default target - show help +help: + @echo "DataCloud MCP Query - Makefile Commands" + @echo "========================================" + @echo "" + @echo "Available targets:" + @echo " make install - Install Python dependencies" + @echo " make dev-install - Install dependencies and compile protos" + @echo " make protos - Compile Protocol Buffer files" + @echo " make clean-protos - Remove generated Protocol Buffer files" + @echo " make clean - Clean all generated files and caches" + @echo " make test - Run tests (if available)" + @echo " make run - Run the MCP server" + @echo " make all - Install dependencies and compile protos" + @echo "" + +# Install Python dependencies +install: + @echo "Installing Python dependencies..." + pip install -r requirements.txt + +# Install dependencies and compile protos (for development) +dev-install: install protos + @echo "Development setup complete!" + +# Compile Protocol Buffer files +protos: + @echo "Compiling Protocol Buffer files..." + @python compile_protos.py + +# Clean generated Protocol Buffer files +clean-protos: + @echo "Removing generated Protocol Buffer files..." + rm -rf generated/ + @echo "Generated proto files removed." + +# Clean all generated files and caches +clean: clean-protos + @echo "Cleaning Python cache files..." + find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete 2>/dev/null || true + find . -type f -name "*.pyo" -delete 2>/dev/null || true + find . -type f -name "*.pyd" -delete 2>/dev/null || true + find . -type f -name ".coverage" -delete 2>/dev/null || true + find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true + find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true + find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true + @echo "Cleanup complete!" + +# Run tests (placeholder - update when tests are available) +test: + @echo "Running tests..." + @if [ -d "tests" ]; then \ + python -m pytest tests/; \ + else \ + echo "No tests directory found. Add tests to enable testing."; \ + fi + +# Run the MCP server +run: + @echo "Starting MCP server..." + @if [ -f "server.py" ]; then \ + python server.py; \ + else \ + echo "server.py not found!"; \ + exit 1; \ + fi + +# Build everything +all: dev-install + @echo "Project setup complete!" + +# Check if dependencies are installed +check-deps: + @echo "Checking dependencies..." + @python -c "import grpc_tools.protoc" 2>/dev/null || \ + (echo "Error: grpcio-tools not installed. Run 'make install' first." && exit 1) + @python -c "import google.protobuf" 2>/dev/null || \ + (echo "Error: protobuf not installed. Run 'make install' first." && exit 1) + @echo "All dependencies are installed." 
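+
+# For reference, the `protos` target wraps roughly the following grpc_tools.protoc
+# invocation (a sketch; see compile_protos.py for the exact command, which also
+# rewrites import paths and creates __init__.py files under generated/):
+#
+#   python -m grpc_tools.protoc --proto_path=protos \
+#       --python_out=generated --grpc_python_out=generated --pyi_out=generated \
+#       salesforce/hyperdb/grpc/v1/hyper_service.proto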
+ +# Compile protos with dependency check +safe-protos: check-deps protos + +# Show current proto files +list-protos: + @echo "Proto files in the project:" + @find protos -name "*.proto" -type f | sort + +# Validate proto files (requires protoc) +validate-protos: + @echo "Validating proto files..." + @for proto in $$(find protos -name "*.proto" -type f); do \ + echo " Checking $$proto..."; \ + python -m grpc_tools.protoc --proto_path=protos $$proto --descriptor_set_out=/dev/null || exit 1; \ + done + @echo "All proto files are valid!" + +# Watch for changes and recompile (requires watchdog) +watch: + @echo "Watching for proto file changes..." + @echo "Note: This requires 'pip install watchdog'" + @which watchmedo >/dev/null 2>&1 || (echo "Error: watchdog not installed. Run 'pip install watchdog'" && exit 1) + watchmedo shell-command \ + --patterns="*.proto" \ + --recursive \ + --command='make protos' \ + protos diff --git a/PROTO_BUILD.md b/PROTO_BUILD.md new file mode 100644 index 0000000..6602a4e --- /dev/null +++ b/PROTO_BUILD.md @@ -0,0 +1,195 @@ +# Protocol Buffer Build System + +This project includes Protocol Buffer (protobuf) definitions for the DataCloud HyperService API. This document explains how to compile and use these proto files in Python. + +## Prerequisites + +The required dependencies are listed in `requirements.txt`: +- `protobuf>=4.25.0` - Protocol Buffer runtime +- `grpcio-tools>=1.60.0` - Protocol Buffer compiler for Python +- `grpcio>=1.60.0` - gRPC runtime + +Install all dependencies: +```bash +pip install -r requirements.txt +``` + +## Building Proto Files + +### Using the Makefile (Recommended) + +The easiest way to compile proto files is using the provided Makefile: + +```bash +# Compile proto files +make protos + +# Clean and recompile +make clean-protos && make protos + +# Install dependencies and compile protos +make dev-install + +# List all proto files +make list-protos + +# Validate proto files +make validate-protos +``` + +### Using the Python Script + +You can also use the Python script directly: + +```bash +python compile_protos.py +``` + +This script will: +1. Find all `.proto` files in the `protos/` directory +2. Compile them to Python code using `grpc_tools.protoc` +3. Generate type stubs (`.pyi` files) for better IDE support +4. Fix import paths to work with the project structure +5. 
Create `__init__.py` files for proper Python packaging + +## Generated Files + +After compilation, the generated files will be in the `generated/` directory: + +``` +generated/ +├── __init__.py +└── salesforce/ + ├── __init__.py + └── hyperdb/ + ├── __init__.py + ├── grpc/ + │ ├── __init__.py + │ └── v1/ + │ ├── __init__.py + │ ├── error_details_pb2.py + │ ├── error_details_pb2.pyi + │ ├── error_details_pb2_grpc.py + │ ├── hyper_service_pb2.py + │ ├── hyper_service_pb2.pyi + │ └── hyper_service_pb2_grpc.py + └── v1/ + ├── __init__.py + ├── query_status_pb2.py + ├── query_status_pb2.pyi + ├── sql_type_pb2.py + └── sql_type_pb2.pyi +``` + +## Using the Generated Code + +### Import the modules + +You can import the generated modules in your Python code: + +```python +# Import from the top-level generated package +from generated import ( + error_details_pb2, + hyper_service_pb2, + hyper_service_pb2_grpc, + query_status_pb2, + sql_type_pb2 +) + +# Or import from the specific package paths +from generated.salesforce.hyperdb.grpc.v1 import hyper_service_pb2 +from generated.salesforce.hyperdb.grpc.v1 import hyper_service_pb2_grpc +``` + +### Example Usage + +```python +# Create a QueryParam message +query_param = hyper_service_pb2.QueryParam( + query="SELECT * FROM Account LIMIT 10", + output_format=hyper_service_pb2.OutputFormat.JSON_ARRAY, + transfer_mode=hyper_service_pb2.QueryParam.TransferMode.SYNC +) + +# Create an ErrorInfo message +error_info = error_details_pb2.ErrorInfo( + primary_message="Query failed", + sqlstate="42000", + customer_detail="Invalid SQL syntax" +) + +# Use with gRPC client (example) +import grpc +channel = grpc.insecure_channel('localhost:50051') +stub = hyper_service_pb2_grpc.HyperServiceStub(channel) +# response = stub.ExecuteQuery(query_param) +``` + +## Proto File Structure + +The project includes the following proto files: + +- **`hyper_service.proto`**: Main service definition for HyperService with RPC methods: + - `ExecuteQuery`: Submit and execute a query + - `GetQueryInfo`: Get information about a query + - `GetQueryResult`: Retrieve query results + - `CancelQuery`: Cancel a running query + +- **`error_details.proto`**: Error detail messages for rich error handling + - `ErrorInfo`: Detailed error information + - `TextPosition`: Position information for errors in SQL text + +- **`query_status.proto`**: Query status related messages + +- **`sql_type.proto`**: SQL type definitions + +## Development Tips + +1. **Auto-rebuild on changes**: If you have `watchdog` installed, you can watch for proto file changes: + ```bash + pip install watchdog + make watch + ``` + +2. **Type hints**: The generated `.pyi` files provide type hints for better IDE support + +3. **Import fixes**: The build script automatically fixes import paths to use `generated.salesforce.*` instead of absolute `salesforce.*` imports + +## Troubleshooting + +### Import Errors + +If you encounter import errors like "No module named 'salesforce'", ensure: +1. The proto files have been compiled: `make protos` +2. The `generated` directory exists and contains the compiled files +3. You're importing from `generated.*` not directly from `salesforce.*` + +### Compilation Errors + +If compilation fails: +1. Check that all dependencies are installed: `pip install -r requirements.txt` +2. Ensure the proto files are valid: `make validate-protos` +3. 
Check the error output for specific issues with proto syntax + +### Clean Build + +If you're having issues, try a clean rebuild: +```bash +make clean +make dev-install +``` + +## Git Ignore + +The generated files are excluded from version control via `.gitignore`: +- `generated/` - The entire generated directory +- `*_pb2.py` - Generated Python protobuf files +- `*_pb2.pyi` - Generated type stub files +- `*_pb2_grpc.py` - Generated gRPC service files + +## Further Information + +- [Protocol Buffers Documentation](https://developers.google.com/protocol-buffers) +- [gRPC Python Documentation](https://grpc.io/docs/languages/python/) +- [Salesforce Data Cloud SQL Reference](https://developer.salesforce.com/docs/data/data-cloud-query-guide/references/dc-sql-reference/) diff --git a/README.md b/README.md index 453091b..fc8b75f 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,13 @@ This MCP server provides a seamless integration between Cursor and Salesforce Da ```bash pip install -r requirements.txt ``` -3. Connect to the MCP server in Cursor: +3. Build the Protocol Buffer files: + ```bash + make protos + # Or: python compile_protos.py + ``` + For more details on the proto build system, see [PROTO_BUILD.md](PROTO_BUILD.md). +4. Connect to the MCP server in Cursor: - Open Cursor IDE. - Go to **Cursor Settings** → **MCP**. - Click on **Add new global MCP server**. diff --git a/compile_protos.py b/compile_protos.py new file mode 100755 index 0000000..8303eb8 --- /dev/null +++ b/compile_protos.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +""" +Script to compile Protocol Buffer files for the DataCloud MCP Query project. +This script generates Python code from .proto files using the protoc compiler. +""" + +import os +import sys +import subprocess +from pathlib import Path + +# Define paths +PROJECT_ROOT = Path(__file__).parent +PROTO_DIR = PROJECT_ROOT / "protos" +OUTPUT_DIR = PROJECT_ROOT / "generated" + + +def ensure_output_dir(): + """Ensure the output directory exists.""" + OUTPUT_DIR.mkdir(exist_ok=True) + + +def find_proto_files(): + """Find all .proto files in the protos directory.""" + proto_files = [] + for root, dirs, files in os.walk(PROTO_DIR): + for file in files: + if file.endswith('.proto'): + proto_path = Path(root) / file + proto_files.append(str(proto_path.relative_to(PROJECT_ROOT))) + return proto_files + + +def compile_protos(): + """Compile all proto files.""" + ensure_output_dir() + + proto_files = find_proto_files() + + if not proto_files: + print("No .proto files found in the protos directory.") + return False + + print(f"Found {len(proto_files)} proto file(s) to compile:") + for proto_file in proto_files: + print(f" - {proto_file}") + + # Prepare the protoc command + # Use mypy_protobuf for better Python type hints + cmd = [ + sys.executable, "-m", "grpc_tools.protoc", + f"--proto_path={PROTO_DIR}", + f"--python_out={OUTPUT_DIR}", + f"--grpc_python_out={OUTPUT_DIR}", + f"--pyi_out={OUTPUT_DIR}", # Generate type stubs + ] + + # Add proto files - they're already relative to PROJECT_ROOT, so we just need the part after 'protos/' + for proto_file in proto_files: + # Remove the 'protos/' prefix from the path + relative_proto = Path(proto_file).relative_to('protos') + cmd.append(str(relative_proto)) + + print("\nCompiling proto files...") + print(f"Command: {' '.join(cmd)}") + + try: + result = subprocess.run( + cmd, capture_output=True, text=True, check=True) + print("✅ Proto files compiled successfully!") + + # Create __init__.py files for proper Python package structure + 
create_init_files() + + # Fix import paths in generated files + fix_import_paths() + + return True + except subprocess.CalledProcessError as e: + print(f"❌ Error compiling proto files:") + print(f"Return code: {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + return False + except Exception as e: + print(f"❌ Unexpected error: {e}") + return False + + +def fix_import_paths(): + """Fix absolute import paths in generated files to use relative imports.""" + print("\nFixing import paths in generated files...") + + import re + fixed_count = 0 + + # First, collect all top-level packages in the generated directory + # These are directories that exist directly under generated/ + top_level_packages = set() + if OUTPUT_DIR.exists(): + for item in OUTPUT_DIR.iterdir(): + if item.is_dir() and not item.name.startswith('__'): + top_level_packages.add(item.name) + + # Find all generated Python files + for root, dirs, files in os.walk(OUTPUT_DIR): + for file in files: + if file.endswith(('.py', '.pyi')) and file != '__init__.py': + file_path = Path(root) / file + + # Read the file + with open(file_path, 'r') as f: + content = f.read() + + original_content = content + + # Dynamically fix imports for any top-level package found in generated/ + for package in top_level_packages: + # Fix "from package." and "import package." patterns + # These are absolute imports that need to be prefixed with "generated." + from_pattern = f'from {package}.' + import_pattern = f'import {package}.' + + if from_pattern in content: + content = content.replace(from_pattern, f'from generated.{package}.') + if import_pattern in content: + content = content.replace(import_pattern, f'import generated.{package}.') + + # For files in the root of generated/, fix imports of other root-level modules + if Path(root) == OUTPUT_DIR: + # Match imports like "import module_pb2 as ..." or "import module_pb2_grpc as ..." 
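+                    # and rewrite them to "from generated import module_pb2 as ..." so they
+                    # resolve when the package is imported from the project root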
+ pattern = r'import\s+(\w+_pb2(?:_grpc)?)\s+as' + matches = re.findall(pattern, content) + for module_name in matches: + # Check if this module exists in the generated root + if (OUTPUT_DIR / f"{module_name}.py").exists(): + # Replace with relative import + old_import = f'import {module_name} as' + new_import = f'from generated import {module_name} as' + content = content.replace(old_import, new_import) + + # Also fix any cross-references between root-level modules + # Pattern: "from module_pb2 import" or standalone "import module_pb2" + root_modules = [f.stem for f in OUTPUT_DIR.glob('*_pb2.py')] + root_modules.extend([f.stem for f in OUTPUT_DIR.glob('*_pb2_grpc.py')]) + for module in root_modules: + # Fix "from module import Something" + from_module_pattern = f'from {module} import' + if from_module_pattern in content: + content = content.replace( + from_module_pattern, + f'from generated.{module} import') + # Fix standalone "import module" (not followed by 'as') + import_module_pattern = f'\nimport {module}\n' + if import_module_pattern in content: + content = content.replace( + import_module_pattern, + f'\nfrom generated import {module}\n') + + # Write back if changed + if content != original_content: + with open(file_path, 'w') as f: + f.write(content) + fixed_count += 1 + print( + f" Fixed imports in: {file_path.relative_to(PROJECT_ROOT)}") + + if fixed_count == 0: + print(" No imports needed fixing") + else: + print(f" Fixed imports in {fixed_count} file(s)") + if top_level_packages: + print(f" Auto-detected packages: {', '.join(sorted(top_level_packages))}") + + +def create_init_files(): + """Create __init__.py files in the generated directory structure.""" + print("\nCreating __init__.py files...") + + # Walk through the generated directory and create __init__.py files + for root, dirs, files in os.walk(OUTPUT_DIR): + # Skip if directory already has __init__.py + init_file = Path(root) / "__init__.py" + if not init_file.exists(): + init_file.touch() + print(f" Created: {init_file.relative_to(PROJECT_ROOT)}") + + # Dynamically generate the main __init__.py content based on generated files + generate_main_init() + + +def generate_main_init(): + """Dynamically generate the main __init__.py file based on generated proto files.""" + # Find all generated proto modules + proto_modules = {} # module_name -> (package_path, is_grpc) + + for root, dirs, files in os.walk(OUTPUT_DIR): + for file in files: + if file.endswith('_pb2.py'): + # Regular proto module + module_name = file[:-3] # Remove .py extension + if not module_name.endswith('_grpc'): + # Get the relative package path + rel_path = Path(root).relative_to(OUTPUT_DIR) + package_path = '.'.join( + rel_path.parts) if rel_path.parts else '' + proto_modules[module_name] = (package_path, False) + elif file.endswith('_pb2_grpc.py'): + # gRPC service module + module_name = file[:-3] # Remove .py extension + rel_path = Path(root).relative_to(OUTPUT_DIR) + package_path = '.'.join( + rel_path.parts) if rel_path.parts else '' + proto_modules[module_name] = (package_path, True) + + # Sort modules for consistent ordering + sorted_modules = sorted(proto_modules.items()) + + # Group modules by package for organized imports + packages = {} + for module_name, (package_path, is_grpc) in sorted_modules: + if package_path not in packages: + packages[package_path] = [] + packages[package_path].append(module_name) + + # Generate the init file content + init_content = '''""" +Generated Protocol Buffer Python code. 
+This module contains the compiled protobuf definitions for the DataCloud MCP Query project. + +Auto-generated based on the following proto files: +''' + + # Add list of proto files + proto_files = find_proto_files() + for proto_file in sorted(proto_files): + init_content += f" - {proto_file}\n" + + init_content += ''' +You can import the generated modules directly using their full paths, +or use the convenient re-exports from this module. +""" + +# Import and re-export for convenience +try: +''' + + # Generate import statements + all_modules = [] + for package_path in sorted(packages.keys()): + modules = packages[package_path] + if package_path: + # Modules in subdirectories + init_content += f" from .{package_path} import (\n" + for module in sorted(modules): + init_content += f" {module},\n" + all_modules.append(module) + init_content += " )\n" + else: + # Modules in the root directory + for module in sorted(modules): + init_content += f" from . import {module}\n" + all_modules.append(module) + + # Generate __all__ list + init_content += ''' + # Re-export for convenience at the top level + __all__ = [ +''' + for module in sorted(all_modules): + init_content += f" '{module}',\n" + init_content += ''' ] + + # Make them available at package level for easier imports + _exported_modules = { +''' + for module in sorted(all_modules): + init_content += f" '{module}': {module},\n" + init_content += ''' } + + locals().update(_exported_modules) + + print(f"Successfully loaded {len(__all__)} proto modules") + +except ImportError as e: + import traceback + print(f"Warning: Could not import generated protobuf modules: {e}") + print("This might happen if you're importing before compilation.") + print("Run 'make protos' or 'python compile_protos.py' to generate the modules.") + __all__ = [] +''' + + # Write the generated content + main_init = OUTPUT_DIR / "__init__.py" + with open(main_init, 'w') as f: + f.write(init_content) + + print( + f" Generated main __init__.py with {len(all_modules)} module imports") + + +def main(): + """Main entry point.""" + print("=" * 60) + print("Protocol Buffer Compilation Script") + print("=" * 60) + + # Check if grpc_tools is installed + try: + import grpc_tools.protoc + except ImportError: + print("❌ Error: grpc_tools is not installed.") + print("Please run: pip install grpcio-tools") + sys.exit(1) + + success = compile_protos() + + if success: + print("\n" + "=" * 60) + print("✅ All proto files compiled successfully!") + print( + f"Generated files are in: {OUTPUT_DIR.relative_to(PROJECT_ROOT)}/") + print("\nYou can now import the generated modules:") + print(" from generated import hyper_service_pb2, hyper_service_pb2_grpc") + print("=" * 60) + else: + print("\n❌ Proto compilation failed. 
Please check the errors above.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/hyper_grpc_client.py b/hyper_grpc_client.py new file mode 100644 index 0000000..69b6ccf --- /dev/null +++ b/hyper_grpc_client.py @@ -0,0 +1,627 @@ +import enum +import logging +import argparse +import os +from urllib.parse import urlparse +import datetime as dt +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Generator, Iterable, Iterator, List, Optional, Tuple + +import grpc + +from generated.salesforce.hyperdb.grpc.v1 import hyper_service_pb2 as hs_pb2 +from generated.salesforce.hyperdb.grpc.v1 import hyper_service_pb2_grpc as hs_grpc +from generated.salesforce.hyperdb.grpc.v1 import error_details_pb2 as ed_pb2 + + +logger = logging.getLogger(__name__) +class HyperGrpcError(Exception): + def __init__( + self, + *, + code: grpc.StatusCode, + details: Optional[str] = None, + sqlstate: Optional[str] = None, + primary_message: Optional[str] = None, + customer_detail: Optional[str] = None, + customer_hint: Optional[str] = None, + error_source: Optional[str] = None, + ) -> None: + self.code = code + self.details = details + self.sqlstate = sqlstate + self.primary_message = primary_message + self.customer_detail = customer_detail + self.customer_hint = customer_hint + self.error_source = error_source + super().__init__(self._format_message()) + + def _format_message(self) -> str: + parts = [f"code={self.code.name}"] + if self.sqlstate: + parts.append(f"sqlstate={self.sqlstate}") + if self.primary_message: + parts.append(f"message={self.primary_message}") + elif self.details: + parts.append(f"details={self.details}") + if self.customer_detail: + parts.append(f"detail={self.customer_detail}") + if self.customer_hint: + parts.append(f"hint={self.customer_hint}") + if self.error_source: + parts.append(f"source={self.error_source}") + return ", ".join(parts) + + +def _convert_grpc_error(exc: grpc.RpcError) -> HyperGrpcError: + """Convert grpc.RpcError with structured details to HyperGrpcError.""" + code = exc.code() if hasattr(exc, "code") else grpc.StatusCode.UNKNOWN + details_text = exc.details() if hasattr(exc, "details") else None + # Parse structured details from trailers (grpc-status-details-bin) + sqlstate = None + primary_message = None + customer_detail = None + customer_hint = None + error_source = None + try: + # Lazy import to avoid hard dependency at module import time + from google.rpc import status_pb2 as google_status_pb2 # type: ignore + md = dict(exc.trailing_metadata() or []) # type: ignore[arg-type] + raw = md.get("grpc-status-details-bin") + if isinstance(raw, str): + raw = raw.encode("latin1", errors="ignore") + if raw: + st = google_status_pb2.Status() + st.MergeFromString(raw) + for any_msg in st.details: + info = ed_pb2.ErrorInfo() + if any_msg.Is(info.DESCRIPTOR) and any_msg.Unpack(info): + sqlstate = info.sqlstate or None + primary_message = info.primary_message or None + customer_detail = info.customer_detail or None + customer_hint = info.customer_hint or None + error_source = info.error_source or None + break + except Exception: # Fall back to plain details + pass + return HyperGrpcError( + code=code, + details=details_text, + sqlstate=sqlstate, + primary_message=primary_message, + customer_detail=customer_detail, + customer_hint=customer_hint, + error_source=error_source, + ) + + + +@dataclass +class ResultChunk: + data: bytes + row_count: int + + +class HyperGrpcClient: + """ + Thin gRPC client for HyperService with 
convenience helpers. + + - Provides ExecuteQuery, GetQueryInfo, GetQueryResult, CancelQuery wrappers + - Use AdaptiveQueryResultIterator to stream results in ADAPTIVE mode + """ + + def __init__(self, channel: grpc.Channel, default_metadata: Optional[List[Tuple[str, str]]] = None): + self._channel = channel + self._stub = hs_grpc.HyperServiceStub(channel) + self._default_metadata = tuple(default_metadata or []) + + @classmethod + def secure_channel( + cls, + target: str, + credentials: Optional[grpc.ChannelCredentials] = None, + default_metadata: Optional[List[Tuple[str, str]]] = None, + ) -> "HyperGrpcClient": + creds = credentials or grpc.ssl_channel_credentials() + channel = grpc.secure_channel(target, creds) + return cls(channel, default_metadata) + + def execute_query( + self, + sql: str, + *, + metadata: Optional[List[Tuple[str, str]]] = None, + grpc_timeout: Optional[dt.timedelta] = None, + ) -> Iterator[hs_pb2.ExecuteQueryResponse]: + params = hs_pb2.QueryParam( + query=sql, + output_format=hs_pb2.OutputFormat.ARROW_IPC, + transfer_mode=hs_pb2.QueryParam.TransferMode.ADAPTIVE, + ) + + md = self._merge_metadata(metadata) + logger.debug("ExecuteQuery") + grpc_timeout = None if grpc_timeout is None else grpc_timeout.total_seconds() + return self._stub.ExecuteQuery(params, metadata=md, timeout=grpc_timeout) + + def get_query_info( + self, + query_id: str, + *, + streaming: bool = True, + metadata: Optional[List[Tuple[str, str]]] = None, + grpc_timeout: Optional[dt.timedelta] = None, + ) -> Iterator[hs_pb2.QueryInfo]: + request = hs_pb2.QueryInfoParam(query_id=query_id, streaming=streaming) + md = self._merge_metadata([("x-hyperdb-query-id", query_id)], metadata) + grpc_timeout = None if grpc_timeout is None else grpc_timeout.total_seconds() + return self._stub.GetQueryInfo(request, metadata=md, timeout=grpc_timeout) + + def get_query_result( + self, + query_id: str, + *, + chunk_id: Optional[int] = None, + omit_schema: bool = True, + metadata: Optional[List[Tuple[str, str]]] = None, + grpc_timeout: Optional[dt.timedelta] = None, + ) -> Iterator[hs_pb2.QueryResult]: + request = hs_pb2.QueryResultParam( + query_id=query_id, + output_format=hs_pb2.OutputFormat.ARROW_IPC, + omit_schema=omit_schema, + ) + if chunk_id is not None: + request.chunk_id = int(chunk_id) + md = self._merge_metadata([("x-hyperdb-query-id", query_id)], metadata) + grpc_timeout = None if grpc_timeout is None else grpc_timeout.total_seconds() + return self._stub.GetQueryResult(request, metadata=md, timeout=grpc_timeout) + + def cancel_query(self, query_id: str, metadata: Optional[List[Tuple[str, str]]] = None) -> None: + request = hs_pb2.CancelQueryParam(query_id=query_id) + md = self._merge_metadata([("x-hyperdb-query-id", query_id)], metadata) + # Fire-and-forget; ignore result + try: + self._stub.CancelQuery(request, metadata=md) + except grpc.RpcError as e: + # Cancellation may race with query completion; not fatal + logger.debug("CancelQuery failed: %s", e) + + def adaptive_iterator( + self, + sql: str, + *, + metadata: Optional[List[Tuple[str, str]]] = None, + grpc_timeout: Optional[dt.timedelta] = None, + ) -> "AdaptiveQueryResultIterator": + execute_stream = self.execute_query( + sql, + metadata=metadata, + grpc_timeout=grpc_timeout, + ) + return AdaptiveQueryResultIterator(client=self, execute_stream=execute_stream) + + def _merge_metadata( + self, + *metadata: Optional[Iterable[Tuple[str, str]]], + ) -> List[Tuple[str, str]]: + md: List[Tuple[str, str]] = list(self._default_metadata) + for m in metadata: + 
if m: + md.extend(list(m)) + return md + + +class _State(enum.Enum): + PROCESS_EXECUTE_QUERY_STREAM = enum.auto() + CHECK_FOR_MORE_DATA = enum.auto() + PROCESS_QUERY_RESULT_STREAM = enum.auto() + PROCESS_QUERY_INFO_STREAM = enum.auto() + COMPLETED = enum.auto() + + +@dataclass +class _Context: + query_id: Optional[str] = None + status: Optional[hs_pb2.QueryStatus] = None + high_water: int = 1 # next chunk id to request + + # active streams + execute_stream: Optional[Iterator[hs_pb2.ExecuteQueryResponse]] = None + info_stream: Optional[Iterator[hs_pb2.QueryInfo]] = None + result_stream: Optional[Iterator[hs_pb2.QueryResult]] = None + + # buffered results to yield next + result_queue: Deque[hs_pb2.QueryResult] = field(default_factory=deque) + + def has_more_chunks(self) -> bool: + return bool(self.status and (self.high_water < int(self.status.chunk_count))) + + def all_results_produced(self) -> bool: + if not self.status: + return False + return self.status.completion_status in ( + hs_pb2.QueryStatus.CompletionStatus.RESULTS_PRODUCED, + hs_pb2.QueryStatus.CompletionStatus.FINISHED, + ) + + +class AdaptiveQueryResultIterator: + """ + Implements the adaptive state machine inspired by the Java/C++ clients. + + Iterate over this object to receive `QueryResult` messages (ARROW_IPC or JSON_ARRAY parts). + """ + + def __init__( + self, + *, + client: HyperGrpcClient, + execute_stream: Iterator[hs_pb2.ExecuteQueryResponse], + ) -> None: + self._client = client + self._state: _State = _State.PROCESS_EXECUTE_QUERY_STREAM + self._ctx = _Context(execute_stream=execute_stream) + + @property + def query_id(self) -> Optional[str]: + return self._ctx.query_id + + @property + def query_status(self) -> Optional[hs_pb2.QueryStatus]: + return self._ctx.status + + def __iter__(self) -> Iterator[hs_pb2.QueryResult]: + while True: + # If we already have results buffered, yield them immediately + if self._ctx.result_queue: + yield self._ctx.result_queue.popleft() + continue + + if self._state == _State.PROCESS_EXECUTE_QUERY_STREAM: + try: + assert self._ctx.execute_stream is not None + response = next(self._ctx.execute_stream) + if response.HasField("query_info"): + self._update_query_context(response.query_info) + elif response.HasField("query_result"): + self._ctx.result_queue.append(response.query_result) + else: + if not response.optional: + raise RuntimeError( + "Received unexpected non-optional ExecuteQueryResponse" + ) + except StopIteration: + self._transition(_State.CHECK_FOR_MORE_DATA) + except grpc.RpcError as exc: + if exc.code() == grpc.StatusCode.CANCELLED: + logger.warning( + "ExecuteQuery stream cancelled; retrying via status") + self._ctx.execute_stream = None + self._transition(_State.CHECK_FOR_MORE_DATA) + else: + raise _convert_grpc_error(exc) + + elif self._state == _State.CHECK_FOR_MORE_DATA: + if self._ctx.has_more_chunks(): + chunk_id = self._ctx.high_water + self._ctx.high_water += 1 + assert self._ctx.query_id + self._ctx.result_stream = self._client.get_query_result( + self._ctx.query_id, + chunk_id=chunk_id, + omit_schema=True, + ) + self._transition(_State.PROCESS_QUERY_RESULT_STREAM) + elif not self._ctx.all_results_produced(): + assert self._ctx.query_id + self._ctx.info_stream = self._client.get_query_info( + self._ctx.query_id, streaming=True) + self._transition(_State.PROCESS_QUERY_INFO_STREAM) + else: + self._transition(_State.COMPLETED) + + elif self._state == _State.PROCESS_QUERY_RESULT_STREAM: + try: + assert self._ctx.result_stream is not None + result = 
next(self._ctx.result_stream) + self._ctx.result_queue.append(result) + except StopIteration: + self._transition(_State.CHECK_FOR_MORE_DATA) + except grpc.RpcError as exc: + if exc.code() == grpc.StatusCode.CANCELLED: + logger.warning( + "GetQueryResult stream cancelled; retrying") + self._ctx.result_stream = None + # Reset any partial results and retry the same chunk via CHECK_FOR_MORE_DATA + self._transition(_State.CHECK_FOR_MORE_DATA) + elif exc.code() == grpc.StatusCode.FAILED_PRECONDITION: + # Rely on GetQueryInfo to surface actual query error + self._drain_info_stream_if_open() + self._transition(_State.PROCESS_QUERY_INFO_STREAM) + else: + raise _convert_grpc_error(exc) + + elif self._state == _State.PROCESS_QUERY_INFO_STREAM: + try: + progressed = False + while not self._ctx.has_more_chunks(): + if self._ctx.info_stream is None: + break + try: + info = next(self._ctx.info_stream) + self._update_query_context(info) + progressed = True + except StopIteration: + self._ctx.info_stream = None + break + # Either chunks became available, or stream ended; re-check + self._transition(_State.CHECK_FOR_MORE_DATA) + # If we made progress by reading infos, loop continues + except grpc.RpcError as exc: + if exc.code() == grpc.StatusCode.CANCELLED: + logger.warning( + "GetQueryInfo stream cancelled; retrying") + self._ctx.info_stream = None + self._transition(_State.CHECK_FOR_MORE_DATA) + else: + raise _convert_grpc_error(exc) + + elif self._state == _State.COMPLETED: + return + + def stream_arrow_ipc(self) -> Iterator[bytes]: + for qr in self: + if qr.HasField("binary_part"): + yield qr.binary_part.data + elif qr.HasField("string_part"): + # For JSON_ARRAY format, users likely prefer textual chunks + yield qr.string_part.data.encode("utf-8") + + def _update_query_context(self, info: hs_pb2.QueryInfo) -> None: + if info.optional: + return + if info.HasField("query_status"): + self._ctx.status = info.query_status + if not self._ctx.query_id: + self._ctx.query_id = info.query_status.query_id + + def _transition(self, new_state: _State) -> None: + logger.debug("state transition: %s -> %s (qid=%s)", + self._state, new_state, self._ctx.query_id) + self._state = new_state + + def _drain_info_stream_if_open(self) -> None: + if self._ctx.info_stream is None: + return + try: + for _ in self._ctx.info_stream: + pass + except grpc.RpcError: + # Ignore during drain + pass + finally: + self._ctx.info_stream = None + + +__all__ = [ + "HyperGrpcClient", + "AdaptiveQueryResultIterator", + "ArrowIpcRowIterator", + "OutputFormat", + "ResultChunk", +] + +# Public alias so callers can reference OutputFormat directly from this module +OutputFormat = hs_pb2.OutputFormat + + +class ArrowIpcRowIterator: + """ + Helper that consumes an AdaptiveQueryResultIterator's Arrow IPC bytes and + yields Python dict rows. It buffers the full Arrow IPC stream in memory. + """ + + def __init__(self, adaptive_iter: AdaptiveQueryResultIterator) -> None: + self._adaptive_iter = adaptive_iter + + def iter_record_batches(self) -> Iterator[object]: + """ + Stream-decode Arrow IPC without buffering all chunks in memory. + + This wraps the chunk iterator in a file-like object that feeds bytes + on-demand to PyArrow's streaming reader. + """ + try: + import io + import pyarrow as pa + import pyarrow.ipc as pa_ipc + except ModuleNotFoundError as e: + raise RuntimeError( + "pyarrow is required for Arrow decoding. Install with `pip install pyarrow`." 
+ ) from e + + class _ChunkInput(io.RawIOBase): + def __init__(self, chunk_iter: Iterator[bytes]) -> None: + self._iter = iter(chunk_iter) + self._buffer = bytearray() + self._eof = False + + def readable(self) -> bool: + return True + + def _fill_buffer(self) -> None: + if self._eof: + return + try: + next_chunk = next(self._iter) + if next_chunk: + self._buffer.extend(next_chunk) + else: + # Treat empty chunk as no-op; fetch next on subsequent reads + pass + except StopIteration: + self._eof = True + + def readinto(self, b: bytearray) -> int: + # Ensure at least some data is available or we reached EOF + while not self._buffer and not self._eof: + self._fill_buffer() + if not self._buffer and self._eof: + return 0 + n = min(len(b), len(self._buffer)) + b[:n] = self._buffer[:n] + del self._buffer[:n] + return n + + def read(self, size: int = -1) -> bytes: + if size is None or size < 0: + chunks: list[bytes] = [] + while True: + if self._buffer: + chunks.append(bytes(self._buffer)) + self._buffer.clear() + if self._eof: + break + self._fill_buffer() + if not self._buffer and self._eof: + break + return b"".join(chunks) + # Sized read + out = bytearray() + while len(out) < size: + if self._buffer: + take = min(size - len(out), len(self._buffer)) + out += self._buffer[:take] + del self._buffer[:take] + continue + if self._eof: + break + self._fill_buffer() + if not self._buffer and self._eof: + break + return bytes(out) + + stream = pa.input_stream(_ChunkInput(self._adaptive_iter.stream_arrow_ipc())) + reader = pa_ipc.open_stream(stream) + try: + while True: + batch = reader.read_next_batch() + if batch is None: + break + yield batch + except StopIteration: + return + + def iter_rows(self) -> Iterator[dict]: + for batch in self.iter_record_batches(): + num_cols = batch.num_columns + if num_cols == 0: + continue + # Preserve duplicate column names by disambiguating with numeric suffixes + original_names = [batch.schema.names[i] for i in range(num_cols)] + name_counts: dict[str, int] = {} + output_names: list[str] = [] + for name in original_names: + count = name_counts.get(name, 0) + 1 + name_counts[name] = count + if count == 1: + output_names.append(name) + else: + output_names.append(f"{name}_{count}") + + column_values_lists = [batch.column(i).to_pylist() for i in range(num_cols)] + num_rows = len(column_values_lists[0]) if column_values_lists else 0 + for row_index in range(num_rows): + yield {output_names[i]: column_values_lists[i][row_index] for i in range(num_cols)} + +if __name__ == "__main__": + # Lazy import to avoid hard dependency for library users + from oauth import OAuthConfig, OAuthSession + + parser = argparse.ArgumentParser(description="Execute a SQL query via Hyper gRPC Adaptive iterator using OAuth Data Cloud session") + parser.add_argument("--sql", required=False, default=os.getenv("HYPER_SQL", "SELECT 1"), help="SQL to execute") + # output format fixed to ARROW_IPC + # row/byte limits removed for simplicity + parser.add_argument("--timeout", type=float, default=None, help="RPC timeout in seconds for calls") + parser.add_argument("--print-rows", action="store_true", help="Decode Arrow IPC and print rows as Python dicts") + parser.add_argument("--max-rows", type=int, default=0, help="Optional cap when printing decoded rows (0 = no limit)") + parser.add_argument("--metadata", action="append", default=[], help="Additional metadata headers key=value (repeatable)") + parser.add_argument("-v", "--verbose", action="count", default=0, help="Increase verbosity (-v, -vv)") + args 
= parser.parse_args() + + log_level = logging.WARNING + if args.verbose == 1: + log_level = logging.INFO + elif args.verbose >= 2: + log_level = logging.DEBUG + logging.basicConfig(level=log_level, format='%(asctime)s %(levelname)s %(name)s: %(message)s') + + # Build default metadata from --metadata key=value + md: list[tuple[str, str]] = [] + for item in args.metadata: + if "=" in item: + k, v = item.split("=", 1) + md.append((k.strip(), v.strip())) + elif item: + logger.warning("Ignoring malformed metadata entry (expected key=value): %s", item) + + # Initialize OAuth Data Cloud session and derive secure gRPC target from instance URL + cfg = OAuthConfig.from_env() + base_session = OAuthSession(cfg) + dc_session = base_session.create_dc_session() + token = dc_session.get_token() + instance_url = dc_session.get_instance_url() + parsed = urlparse(instance_url if "://" in instance_url else f"https://{instance_url}") + host = parsed.netloc or parsed.path + if not host: + raise RuntimeError(f"Invalid instance URL: {instance_url}") + target = f"{host}:443" + + # Always secure; add Authorization metadata + md.append(("authorization", f"Bearer {token}")) + client = HyperGrpcClient.secure_channel(target, default_metadata=md) + + td_timeout = None if args.timeout is None else dt.timedelta(seconds=float(args.timeout)) + iterator = client.adaptive_iterator( + args.sql, + grpc_timeout=td_timeout, + ) + + total_parts = 0 + total_bytes = 0 + last_qid = None + + try: + if args.print_rows: + printed = 0 + row_iter = ArrowIpcRowIterator(iterator) + for row in row_iter.iter_rows(): + if iterator.query_id and iterator.query_id != last_qid: + print(f"query_id={iterator.query_id}") + last_qid = iterator.query_id + print(row) + printed += 1 + if args.max_rows and printed >= args.max_rows: + break + else: + for part in iterator: + if iterator.query_id and iterator.query_id != last_qid: + print(f"query_id={iterator.query_id}") + last_qid = iterator.query_id + if part.HasField("binary_part"): + data = part.binary_part.data + elif part.HasField("string_part"): + data = part.string_part.data.encode("utf-8") + else: + data = b"" + total_parts += 1 + total_bytes += len(data) + print(f"chunk {total_parts}: {len(data)} bytes, rows={part.result_part_row_count}") + except grpc.RpcError as e: + err = _convert_grpc_error(e) + print(f"gRPC error: {err}") + raise err + + print(f"done: parts={total_parts}, bytes={total_bytes}") diff --git a/oauth.py b/oauth.py index 9cb876a..8cb8082 100644 --- a/oauth.py +++ b/oauth.py @@ -186,3 +186,106 @@ def get_token(self) -> str: def get_instance_url(self) -> str: self.ensure_access() return self.instance_url + + def create_dc_session(self) -> "OAuthDCSession": + """Create an OAuthDCSession that manages a Data Cloud token via token exchange.""" + return OAuthDCSession(self) + + +class OAuthDCSession: + def __init__(self, base_session: OAuthSession): + self._base = base_session + self.dc_token: str | None = None + self.dc_exp: datetime | None = None + self.dc_instance_url: str | None = None + + def _exchange_for_dc_token(self) -> dict: + """ + Exchange the current Salesforce access token for a Data Cloud (A360) token. 
+ + Uses the OAuth 2.0 Token Exchange spec: + POST {instance_url}/services/a360/token + """ + access_token = self._base.ensure_access() + assert self._base.instance_url, "instance_url must be set after OAuth access" + + token_exchange_url = f"{self._base.instance_url}/services/a360/token" + response = requests.post( + token_exchange_url, + data={ + "grant_type": "urn:salesforce:grant-type:external:cdp", + "subject_token": access_token, + "subject_token_type": "urn:ietf:params:oauth:token-type:access_token", + }, + headers={"Accept": "application/json"}, + ) + + logger.info( + f"A360 token exchange response: status={response.status_code}, elapsed={response.elapsed.total_seconds():.2f}s" + ) + + if response.status_code >= 400: + logger.error(f"A360 token exchange failed: {response.text}") + + response.raise_for_status() + payload = response.json() + # Prefer any instance url present in response, fallback to base + self.dc_instance_url = ( + payload.get("instance_url") + or payload.get("a360_instance_url") + or self._base.instance_url + ) + return payload + + def ensure_dc_access(self) -> str: + if self.dc_exp is not None and datetime.now() > self.dc_exp: + self.dc_exp = None + self.dc_token = None + + if self.dc_token is None: + dc_auth_info = self._exchange_for_dc_token() + self.dc_token = dc_auth_info.get("access_token") + expires_in = dc_auth_info.get("expires_in") + if isinstance(expires_in, (int, float)) and expires_in > 0: + self.dc_exp = datetime.now() + timedelta(seconds=max(60, int(expires_in) - 60)) + else: + self.dc_exp = datetime.now() + timedelta(minutes=50) + + if not self.dc_token: + raise RuntimeError("A360 token exchange did not return access_token") + + return self.dc_token + + def get_token(self) -> str: + return self.ensure_dc_access() + + def get_instance_url(self) -> str: + # Ensure base is valid and token exchange sets dc_instance_url if needed + self._base.ensure_access() + if not self.dc_instance_url: + self.dc_instance_url = self._base.instance_url + return self.dc_instance_url # type: ignore[return-value] + + def get_requests_session(self) -> requests.Session: + token = self.ensure_dc_access() + sess = requests.Session() + sess.headers.update({"Authorization": f"Bearer {token}"}) + return sess + + +if __name__ == "__main__": + # Basic CLI to initialize a Data Cloud OAuth session + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + logger.setLevel(logging.DEBUG) + + cfg = OAuthConfig.from_env() + base_session = OAuthSession(cfg) + dc_session = base_session.create_dc_session() + + # Trigger token acquisition (will run interactive OAuth flow if needed) + dc_session.get_token() + print(f"Initialized Data Cloud session. Instance URL: {dc_session.get_instance_url()}") diff --git a/protos/salesforce/hyperdb/grpc/v1/README.md b/protos/salesforce/hyperdb/grpc/v1/README.md new file mode 100644 index 0000000..3fbe3df --- /dev/null +++ b/protos/salesforce/hyperdb/grpc/v1/README.md @@ -0,0 +1,7 @@ +# Hyper Service definition + +This is the main definition of Hyper Service gRPC API. + +## Updating the definition + +When you update this definition, please also update https://git.soma.salesforce.com/Hyper/hyper-service-proto. 
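+
+## Using the generated stubs (sketch)
+
+A minimal sketch of how the Python stubs generated from these protos can be wired together
+with the OAuth helpers in this repository. It assumes the protos have been compiled
+(`make protos`), that `pyarrow` is installed for row decoding, and that the environment
+variables expected by `OAuthConfig.from_env()` are set; it is condensed from the
+`__main__` blocks of `oauth.py` and `hyper_grpc_client.py`, not a separate API.
+
+```python
+from urllib.parse import urlparse
+
+from oauth import OAuthConfig, OAuthSession
+from hyper_grpc_client import HyperGrpcClient, ArrowIpcRowIterator
+
+# Exchange the core Salesforce token for a Data Cloud token
+dc_session = OAuthSession(OAuthConfig.from_env()).create_dc_session()
+
+# Derive the gRPC target from the Data Cloud instance URL
+instance_url = dc_session.get_instance_url()
+parsed = urlparse(instance_url if "://" in instance_url else f"https://{instance_url}")
+target = f"{parsed.netloc or parsed.path}:443"
+
+# TLS channel with the bearer token sent as default metadata on every call
+client = HyperGrpcClient.secure_channel(
+    target,
+    default_metadata=[("authorization", f"Bearer {dc_session.get_token()}")],
+)
+
+# Stream a query in ADAPTIVE mode and decode the Arrow IPC chunks into dict rows
+iterator = client.adaptive_iterator("SELECT 1")
+for row in ArrowIpcRowIterator(iterator).iter_rows():
+    print(row)
+```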
diff --git a/protos/salesforce/hyperdb/grpc/v1/error_details.proto b/protos/salesforce/hyperdb/grpc/v1/error_details.proto
new file mode 100644
index 0000000..b20f759
--- /dev/null
+++ b/protos/salesforce/hyperdb/grpc/v1/error_details.proto
@@ -0,0 +1,86 @@
+/*
+ This definition is kept in sync at
+ * https://github.com/hyper-db-emu/hyper-db/blob/main/protos/salesforce/hyperdb/grpc/v1/error_details.proto
+ * https://git.soma.salesforce.com/a360/cdp-protos/blob/master/proto/hyperdb-proto/salesforce/hyperdb/grpc/v1/error_details.proto
+
+ The version in https://github.com/hyper-db-emu/hyper-db is the source of
+ truth. Always update that version first and then copy over the changes into
+ the other version.
+
+ Furthermore, this is also mirrored into
+ https://github.com/forcedotcom/datacloud-jdbc/blob/main/jdbc-proto/src/main/proto/error_details.proto
+
+ The public version is updated on demand, as we pull in new versions of hyperd
+ via HyperAPI. (The JDBC driver relies on the hyperd packaged in Hyper API for
+ testing).
+*/
+
+/*
+ This file contains the richer error model message types for HyperService
+ (defined in hyper_service.proto). For more details on the richer error model see:
+ https://grpc.io/docs/guides/error/#richer-error-model
+*/
+syntax = "proto3";
+
+package salesforce.hyperdb.grpc.v1;
+
+// Ensure that we have a sensible java package name and that
+// messages are individual classes for better dev experience.
+option java_multiple_files = true;
+option java_package = "salesforce.cdp.hyperdb.v1";
+
+// Positional information on the error, references the user-provided SQL text
+message TextPosition {
+  // Start offset, measured in unicode code points
+  uint64 error_begin_character_offset = 2;
+  // End offset, measured in unicode code points
+  uint64 error_end_character_offset = 3;
+}
+
+// Error details for HyperService rpcs (defined in hyper_service.proto)
+//
+// This message is used in two ways:
+// 1. by Hyper to report error information to its clients
+// 2. by upstream services to report errors to Hyper
+//
+// We distinguish between customer-visible information and internal ("system")
+// information. Hyper makes sure to not expose the system-internal
+// information via any of its public endpoints.
+//
+// Note for Upstream services:
+// - Use `ErrorInfo` to enable Hyper to provide detailed error messages to
+// customers and improve debuggability
+// - When providing detailed information, you MUST provide both customer_detail
+// and system_detail
+// - If Hyper is unable to validate correctness of `ErrorInfo`, it will fall back
+// to the standard gRPC error model and map the error message to system_detail
+message ErrorInfo {
+  // The primary (terse) error message
+  // MUST NOT contain sensitive data as it will be logged and returned to the user
+  // MANDATORY FIELD
+  string primary_message = 1;
+  // The SQL state error code
+  // For upstream services:
+  // - ALLOWED to be empty
+  // - if set, restrict to Class 28 (Invalid Authorization Specification) or
+  // Class 42 (Syntax Error or Access Rule Violation)
+  string sqlstate = 2;
+  // A suggestion on what to do about the problem
+  // Differs from customer_detail by offering advice rather than hard facts
+  // Can be returned to the customer but MUST NOT be logged
+  string customer_hint = 3;
+  // Error detail with customer data classification
+  // Can be returned to the customer but MUST NOT be logged
+  string customer_detail = 4;
+  // Error detail with system data classification
+  // Can be logged but MUST NOT be forwarded to untrusted clients or customers
+  string system_detail = 5;
+  // Position information pertaining to the error
+  // This will be IGNORED for ErrorInfo coming from upstream services of Hyper
+  TextPosition position = 6;
+  // The cause of the error
+  // ALLOWED values are "User" and "System" (case sensitive)
+  // Use "User" if and only if the error can only be caused by the end-user or
+  // customer
+  string error_source = 7;
+}
diff --git a/protos/salesforce/hyperdb/grpc/v1/hyper_service.proto b/protos/salesforce/hyperdb/grpc/v1/hyper_service.proto
new file mode 100644
index 0000000..76a69bc
--- /dev/null
+++ b/protos/salesforce/hyperdb/grpc/v1/hyper_service.proto
@@ -0,0 +1,517 @@
+/*
+ This definition is kept in sync at
+ * https://github.com/hyper-db-emu/hyper-db/blob/main/protos/salesforce/hyperdb/grpc/v1/hyper_service.proto
+ * https://git.soma.salesforce.com/a360/cdp-protos/blob/master/proto/hyperdb-proto/salesforce/hyperdb/grpc/v1/hyper_service.proto
+
+ The version in https://github.com/hyper-db-emu/hyper-db is the source of
+ truth. Always update that version first and then copy over the changes into
+ the other version.
+
+ Furthermore, this is also mirrored into
+ https://github.com/forcedotcom/datacloud-jdbc/blob/main/jdbc-proto/src/main/proto/hyper_service.proto
+
+ The public version is updated on demand, as we pull in new versions of hyperd
+ via HyperAPI. (The JDBC driver relies on the hyperd packaged in Hyper API for
+ testing).
+*/
+syntax = "proto3";
+
+import "google/protobuf/empty.proto";
+import "google/protobuf/timestamp.proto";
+
+package salesforce.hyperdb.grpc.v1;
+
+// Ensure that we have a sensible java package name and that
+// messages are individual classes for better dev experience.
+option java_multiple_files = true;
+option java_package = "salesforce.cdp.hyperdb.v1";
+option java_outer_classname = "HyperDatabaseServiceProto";
+
+// All methods under `HyperService` use the richer error model
+// (https://grpc.io/docs/guides/error/#richer-error-model). Error details
+// can contain message types defined in error_details.proto
+service HyperService {
+  // Submit a query / statement for execution and retrieve its result.
+ // The result stream will contain the result schema and rows in the requested + // OutputFormat. + rpc ExecuteQuery(QueryParam) returns (stream ExecuteQueryResponse); + // Get query information for a previous `ExecuteQuery` call. See `QueryInfo`. + // By default, this call will only return one update before ending the + // stream. A client can opt into a streaming mode that will continuously push + // updates until the query is done or a timeout is reached. + rpc GetQueryInfo(QueryInfoParam) returns (stream QueryInfo); + // Retrieve the results of a previous `ExecuteQuery`. See + // `QueryParam::TransferMode` + rpc GetQueryResult(QueryResultParam) returns (stream QueryResult); + // Attempts to cancel a query started via `ExecuteQuery`. + // The call is successful regardless of whether the query was actually + // canceled. + rpc CancelQuery(CancelQueryParam) returns (google.protobuf.Empty); +} + +// ---------------------------------------------------------------------------- +// Parameters passed to ExecuteQuery +// ---------------------------------------------------------------------------- + +// QueryParam represents a query SQL text and the execution context, +// such as settings, attached databases and parameters. Additionally, the +// transfer mode and output format for result chunks can be configured. +message QueryParam { + // The requested result transfer mode. + // By default, we recommend using `ADAPTIVE` for most workloads. + // + // `ADAPTIVE` and `ASYNC` are only supported for the `ARROW_IPC` and + // `JSON_ARRAY` output formats. All other output formats use `SYNC` mode. + enum TransferMode { + // TRANSFER_MODE_UNSPECIFIED defaults to ADAPTIVE + TRANSFER_MODE_UNSPECIFIED = 0; + // Only returns the header including the schema and query id. Results need + // to be fetched via `GetQueryResult`. Only supported for JSON_ARRAY and + // ARROW_IPC output format. + ASYNC = 1; + // All results will be returned with the `ExecuteQuery` call + // Using this mode is discouraged, as there is a 100s timeout. + SYNC = 2; + // Returns up to one result chunk synchronously. Subsequent result chunks + // may be retrieved via `GetQueryResult`. If the client does not retrieve + // the chunk in time, the server will close the connection. This does not + // imply that the query failed. It is the client’s obligation to query the + // status or refetch remaining rows from the first chunk in case of + // timeouts. + ADAPTIVE = 3; + } + + // The parameter style + // + // Hyper supports three different syntaxes for query parameters in SQL. + // This field allows specifying which syntax is used by the query. + enum ParameterStyle { + // PARAMETER_STYLE_UNSPECIFIED defaults to QUESTION_MARK + PARAMETER_STYLE_UNSPECIFIED = 0; + // Use the question mark `?` + QUESTION_MARK = 3; + // Use a Dollar sign followed by a number, e.g., `$1` + DOLLAR_NUMBERED = 1; + // Use named parameter, e.g., `:param` + NAMED = 2; + } + + // The SQL query text. + // See + // https://developer.salesforce.com/docs/data/data-cloud-query-guide/references/dc-sql-reference/ + // for documentation of Hyper's SQL. + string query = 1; + // Specify the list of attached databases for this query (optional) + repeated AttachedDatabase databases = 2; + // Specify the output format for query result data chunks. ARROW_IPC is + // recommended. The default is "unspecified" which currently maps to + // JSON_LEGACY_DICT. + OutputFormat output_format = 3; + // Settings to allow adjusting the execution of a query. 
+ // See + // https://tableau.github.io/hyper-db/docs/hyper-api/connection#connection-settings + map settings = 4; + // See `TransferMode` + TransferMode transfer_mode = 5; + // See `ParameterStyle` + ParameterStyle param_style = 6; + // The parameter values for the parameters in the SQL query. + // + // The parameters are specified as a JSON object or an Arrow IPC message. + // + // The JSON object is a map of parameter names to values. The Arrow IPC + // message is a serialized Arrow schema and a serialized Arrow record batch. + // + // The Arrow IPC message is preferred, as it is more efficient. + oneof parameters { + // The Arrow parameters + QueryParameterArrow arrow_parameters = 7; + // The JSON parameters + QueryParameterJson json_parameters = 8; + } + // Specifies limits on the row count and byte size from the result set + // returned in this call. Only applicable for transfer mode ADAPTIVE and + // output formats JSON_ARRAY and ARROW_IPC. + optional ResultRange result_range = 9; + // Acts like a SQL LIMIT clause. It limits the output rows and stops executing + // once those are produced. Set to 0 in order to only retrieve the schema + // without actually producing any rows. + optional uint64 query_row_limit = 10; +} + +// The query parameters of type Arrow, used in `parameters` field of the +// `QueryParam` message +message QueryParameterArrow { + bytes data = 127; +} + +// The query parameters of type JSON, used in `parameters` field of the +// `QueryParam` message +message QueryParameterJson { + string data = 127; +} + +// Defines limits on the rows retrieved from a query result set +// Only applicable for output formats JSON_ARRAY and ARROW_IPC. +// Used by both `ExecuteQuery` and `GetQueryResult`. +message ResultRange { + // The (zero-based) row offset where the range starts inside the query result. + // This parameter is only applicable to `GetQueryResult`. + // When calling `ExecuteQuery`, it must be left at its default, i.e. zero. + uint64 row_offset = 1; + // The maximum number of rows to include in the `ExecuteQuery` or + // `GetQueryResult` response stream. Must be greater than zero if specified. + // If specified, less rows may be returned e.g. due to timeouts. Returning + // less rows is not an error. Just fetch the next rows using a new + // `GetQueryResult` call in this case. + optional uint64 row_limit = 2; + // The targeted maximum total size of the rows in the `ExecuteQuery` or + // `GetQueryResult` response stream, measured in bytes. Must be specified with + // a value greater than zero. If greater than the setting + // `row_based_pagination_max_byte_limit` (default: 20 MB), an error occurs. + // Returning less bytes is not an error. Just fetch the next rows using a new + // `GetQueryResult` call in this case. + uint64 byte_limit = 3; +} + +// The output formats currently supported by HyperService +// +// Since Hyper's protocol went through multiple iterations, we have a few +// deprecated, non-recommend formats. +// +// Only `ARROW_IPC` and `JSON_ARRAY` should be used for new workloads. +// The other formats will likely be removed in the future. Many of the other +// formats only support the `SYNC` transfer mode and are not fully supported +// for all HyperService methods. +enum OutputFormat { + // Encode the result chunk in a text-based format intended for debugging gRPC + // on the command line. Currently, this format is the same as + // `JSON_LEGACY_DICT`, which encodes the result as a JSON array. However, this + // format might change in the future. 
`JSON_ARRAY` or `ARROW_IPC` is strictly + // preferable. Not supported by `GetQueryResult` + OUTPUT_FORMAT_UNSPECIFIED = 0; + + // Formerly Hyper-Binary. Reserved as long as we expect clients to send it. + reserved 1; + + // Encode the result chunk in a proprietary variant similar to the open-source + // "Arrow IPC" format. + // + // Do not use this format when onboarding any new workloads. Not supported by + // `GetQueryResult`. `ARROW_IPC` is strictly preferable. + // + // Each result chunk consists of a schema and a record batch message. This is + // the original format of the gRPC proxy. For the JDBC Tableau connector, this + // format is passed through directly to a public Data Cloud API endpoint. As + // such, we cannot just drop support. + ARROW_LEGACY = 2; + + // Encode the result chunk as a JSON array of objects using the Query Service + // V1 SQL API convention. Not supported by `GetQueryResult`. + // + // Do not use this format when onboarding any new workloads. Not supported by + // `GetQueryResult`. `ARROW_IPC` and `JSON_ARRAY` are strictly preferable. + JSON_LEGACY_DICT = 3; + + // Encode the result chunk as part of a single Arrow IPC stream that + // encompasses all result chunks of a query. The first returned message will + // be a `QueryResultHeader` describing the schema, or a successful command. + // Only the first result chunk will contain an ARROW schema message. The + // following result chunks contain one or more record batch messages. + // + // Do not use this format when onboarding any new workloads. Not supported by + // `GetQueryResult`. `ARROW_IPC` is strictly preferable. + ARROW_LEGACY_IPC = 4; + + // The first message in the response stream is the `QueryStatus` with the + // query id. The result is encoded in multiple `QueryResultPartString` + // messages. In concatenation, these form one single Arrow IPC stream, with + // one Arrow schema message and one or more Arrow RecordBatches. Unlike + // ARROW_LEGACY_IPC, does not return QueryResultHeader. + ARROW_IPC = 5; + + // The first message in the response stream is the QueryStatus with the query + // id. Each following `QueryResultPartString` message is a JSON object. The + // first result message contains a `columns` array describing the column names + // and types. E.g. + // `{"columns":[{"name":"IntCol","type":"numeric","precision":38,"scale":18,"nullable":false},{"name":"TextCol","type":"varchar","nullable": + // true}]}` The following messages contain the result rows encoded as an array + // of array of JSON types. Each tuple is encoded as one array. E.g. + // `{"data":[[42, "Foo"], [1.4, null]]}` + JSON_ARRAY = 6; +} + +message AttachedDatabase { + // Access path for the database + string path = 1; + // Alias for the database under which it should be available in SQL + string alias = 2; +} + +// ---------------------------------------------------------------------------- +// Parameters for GetQueryInfo +// ---------------------------------------------------------------------------- + +// The parameters of the `GetQueryInfo` call +message QueryInfoParam { + // The query id unambiguously identifies a query. + // !!! You also have to send the query id as header (== gRPC metadata) with + // key `x-hyperdb-query-id`. + string query_id = 1; + // Whether new updates will be streamed to the client until the query is done + // or the timeout of 100s is reached. By default, only the current info + // message is sent. + bool streaming = 2; + // Specifies the output format for the query schema. 
+ // OUTPUT_FORMAT_UNSPECIFIED means we won't send a schema. + // Currently, only JSON_ARRAY and ARROW_IPC are supported. + OutputFormat schema_output_format = 3; +} + +// ---------------------------------------------------------------------------- +// Parameters for GetQueryResult +// ---------------------------------------------------------------------------- + +// The parameters of the `GetQueryResult` call to unambiguously identify the +// query and the requested data +message QueryResultParam { + // The query id unambiguously identifies a query. + // !!! You also have to send the query id as header (== gRPC metadata) with + // key `x-hyperdb-query-id`. + string query_id = 1; + // Specifies the output format for the query result data. + // Currently, only JSON_ARRAY and ARROW_IPC are supported. + OutputFormat output_format = 2; + // One can either request a specific chunk or a specific range of rows. + oneof requested_data { + // The id of the chunk to retrieve. + uint64 chunk_id = 3; + // Limits on the rows retrieved from the query result set. + ResultRange result_range = 5; + } + // By default, both the schema and the data are sent (a complete Arrow IPC + // stream in the case of ARROW_IPC). If the schema is not needed, it can be + // omitted. + bool omit_schema = 4; +} + +// ---------------------------------------------------------------------------- +// Parameters for CancelQuery +// ---------------------------------------------------------------------------- + +// The parameters of the `CancelQuery` call +message CancelQueryParam { + // The query id unambiguously identifies a query. + // !!! You also have to send the query id as header (== gRPC metadata) with + // key `x-hyperdb-query-id`. + string query_id = 1; +} + +// ---------------------------------------------------------------------------- +// Metadata about a query. +// ---------------------------------------------------------------------------- + +// Information about a query, such as its status, schema, and result size. +message QueryInfo { + oneof content { + // The status of the query + QueryStatus query_status = 1; + // The schema of the query result for a binary format (if requested via + // `schema_output_format`) + QueryResultPartBinary binary_schema = 3; + // The schema of the query result for a text format (if requested via + // `schema_output_format`) + QueryResultPartString string_schema = 4; + } + // Whether this message is optional or required for client processing. Clients + // MUST ignore optional messages that they do not understand. + bool optional = 2; +} + +// The query status of a previous `ExecuteQuery` call +message QueryStatus { + // The completion status of the query. Errors will be indicated via structured + // gRPC errors. + enum CompletionStatus { + // The query is in progress. + // Note: RUNNING had to be renamed to RUNNING_OR_UNSPECIFIED in order to + // satisfy the Salesforce proto guidelines. This is a band-aid solution to + // avoid breaking existing clients: the behavior of clients receiving an + // unknown enum value (e.g., RUNNING = 3, if we introduced it) is + // implementation defined. The Java protobuf library does not fall back to + // the default enum value but deserializes to a special "UNRECOGNIZED". We + // cannot rely on all our existing clients handling this edge case correctly + // and uniformly. + RUNNING_OR_UNSPECIFIED = 0; + // The query completed. + // All results are ready to be fetched by the client.
+ RESULTS_PRODUCED = 1; + // The query status and results have been persisted and + // are now guaranteed to be available until the expiration time. + FINISHED = 2; + } + // The query id unambiguously identifies a query. + string query_id = 1; + // See `CompletionStatus` + CompletionStatus completion_status = 2; + // The number of chunks that the query has produced. If `completion_status == + // RUNNING_OR_UNSPECIFIED` this value may not be final. The chunks reported + // here can be retrieved via `GetQueryResult`. + uint64 chunk_count = 3; + // The number of rows that the query has produced. If `completion_status == + // RUNNING_OR_UNSPECIFIED` this value may not be final. The rows reported here + // can be retrieved via `GetQueryResult`. + uint64 row_count = 4; + // A number between 0.0 and 1.0 that indicates how much progress the query has + // made. For `completion_status = RESULTS_PRODUCED` and `completion_status = + // FINISHED`, this is always 1.0. + double progress = 5; + // A timestamp (seconds since Unix epoch) indicating when the results won’t be + // available anymore. If `completion_status != FINISHED` this value may not be + // final. + google.protobuf.Timestamp expiration_time = 6; + // The query execution statistics, which include the elapsed time + QueryExecutionStatistics execution_stats = 7; +} + +// The query execution stats present in the QueryStatus response +message QueryExecutionStatistics { + // Server-side elapsed wall clock time in seconds + double wall_clock_time = 1; + // Total number of rows processed, including native, byolFileFederation, and + // byolLiveQuery rows + uint64 rows_processed = 2; +} + +// ---------------------------------------------------------------------------- +// Query results +// ---------------------------------------------------------------------------- + +// The result of a query execution +message ExecuteQueryResponse { + oneof result { + // DEPRECATED + // Header is only used for legacy formats; see the QueryResultHeader message + // comments. In an ExecuteQueryResponse stream, the result header will always + // come first and will be followed by the configured result chunk type. + QueryResultHeader header = 1; + // DEPRECATED + // New formats use query_result.binary_part instead. + // A result part in binary format + QueryResultPartBinary binary_part = 4; + // DEPRECATED + // New formats use query_result.string_part instead. + // A result part in textual format + QueryResultPartString string_part = 5; + // Information on the query + QueryInfo query_info = 6; + // Query result data + QueryResult query_result = 7; + } + // Whether this message is optional or required for client processing. Clients + // can skip over optional messages if they have no logic to process them. + bool optional = 9; +} + +// Result data of a query +message QueryResult { + oneof result { + // A result part in binary format + QueryResultPartBinary binary_part = 1; + // A result part in textual format + QueryResultPartString string_part = 2; + } + + // The number of rows contained in the result. + // If `result` only contains the schema, this field is zero.
+ uint64 result_part_row_count = 3; +} + +// Describes the schema of the query result +// Is only included for the following formats [OUTPUT_FORMAT_UNSPECIFIED, +// ARROW_LEGACY, JSON_LEGACY_DICT, ARROW_LEGACY_IPC] +message QueryResultHeader { + oneof header { + // Returned for normal queries (i.e., SELECT) + QueryResultSchema schema = 1; + // Returned when the query was of statement type + QueryCommandOk command = 2; + } +} + +// Returned for statements, some statements additionally return the affected row +// count. The server will only send this message once the changes of the +// statement are committed successfully. +message QueryCommandOk { + oneof command_return { + google.protobuf.Empty empty = 2; + uint64 affected_rows = 1; + } +} + +// Schema of the query result +message QueryResultSchema { + repeated ColumnDescription columns = 1; +} + +// Describes a column +message ColumnDescription { + string name = 1; + SqlType type = 2; +} + +// Type of a result column, provides additional information through the modifier +// field +message SqlType { + enum TypeTag { + HYPER_UNSPECIFIED = 0; + HYPER_BOOL = 1; + HYPER_BIG_INT = 2; + HYPER_SMALL_INT = 3; + HYPER_INT = 4; + HYPER_NUMERIC = 5; + HYPER_DOUBLE = 6; + HYPER_OID = 7; + HYPER_BYTE_A = 8; + HYPER_TEXT = 9; + HYPER_VARCHAR = 10; + HYPER_CHAR = 11; + HYPER_JSON = 12; + HYPER_DATE = 13; + HYPER_INTERVAL = 14; + HYPER_TIME = 15; + HYPER_TIMESTAMP = 16; + HYPER_TIMESTAMP_TZ = 17; + HYPER_GEOGRAPHY = 18; + HYPER_FLOAT = 19; + HYPER_ARRAY_OF_FLOAT = 20; + } + + // The precision of a numeric column + message NumericModifier { + uint32 precision = 1; + uint32 scale = 2; + } + + // Matches hyperapi::SqlType enum + TypeTag tag = 1; + // Additional type information, e.g. about precision + oneof modifier { + google.protobuf.Empty empty = 2; + // Only available if tag is a text type + uint32 max_length = 3; + // Only available if tag is a numeric type + NumericModifier numeric_modifier = 4; + } +} + +// A result part which contains multiple rows encoded in the binary format +// requested via the `output_format` field of the `QueryParam` message +message QueryResultPartBinary { + bytes data = 127; +} + +// A result part which contains multiple rows encoded in the textual format +// requested via the `output_format` field of the `QueryParam` message +message QueryResultPartString { + string data = 127; +} diff --git a/protos/salesforce/hyperdb/v1/query_status.proto b/protos/salesforce/hyperdb/v1/query_status.proto new file mode 100644 index 0000000..61979f8 --- /dev/null +++ b/protos/salesforce/hyperdb/v1/query_status.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package salesforce.hyperdb.v1; + +import "salesforce/hyperdb/grpc/v1/hyper_service.proto"; +import "salesforce/hyperdb/grpc/v1/error_details.proto"; + +// The format of a query status file for Query API V3. +// This is used only internally. +message QueryStatusFile { + // The query status + salesforce.hyperdb.grpc.v1.QueryStatus query_status = 1; + + // The serialized schema of the spooled data (if it exists) + salesforce.hyperdb.grpc.v1.QueryResultPartString result_schema = 3; + + // The error information + salesforce.hyperdb.grpc.v1.ErrorInfo error_info = 4; + + // The identifying features of the original requester, formatted as JSON. + // In CDP, for example, this would include tenant id, data space, user id and + // possibly more. + string original_requester_details = 5; + + // The result id + string result_id = 7; + // A list that stores the end offsets of tuples for each chunk. 
+ // Since chunks are stored consecutively, the range of tuples for a chunk at + // position `i` extends from `chunk_offsets[i - 1]` (inclusive) to + // `chunk_offsets[i]` (exclusive). The first chunk starts from zero, so + // `chunk_offsets[0]` represents the end offset of the first chunk. + repeated uint64 chunk_offsets = 6; +} diff --git a/protos/salesforce/hyperdb/v1/sql_type.proto b/protos/salesforce/hyperdb/v1/sql_type.proto new file mode 100644 index 0000000..79cadfa --- /dev/null +++ b/protos/salesforce/hyperdb/v1/sql_type.proto @@ -0,0 +1,81 @@ +/* + This definition is kept in sync at + * https://github.com/hyper-db-emu/hyper-db/blob/main/protos/salesforce/hyperdb/v1/sql_type.proto + * https://git.soma.salesforce.com/a360/cdp-protos/blob/master/proto/hyperdb-proto/salesforce/hyperdb/v1/sql_type.proto + When you update one definition, please also update the other! +*/ +syntax = "proto3"; + +package salesforce.hyperdb.v1; + +option java_multiple_files = true; +option java_package = "com.salesforce.hyperdb.v1"; +option java_outer_classname = "HyperSqlTypeProto"; + +message SQLType { + bool nullable = 1; + + oneof kind { + Boolean boolean = 2; + SmallInt smallint = 3; + Integer integer = 5; + BigInt bigint = 7; + Float float = 25; + Double double = 11; + + String string = 12; + Binary binary = 13; + Timestamp timestamp = 14; + Date date = 16; + Time time = 17; + + FixedChar fixedchar = 21; + VarChar varchar = 22; + Decimal decimal = 24; + } + + message Boolean {} + message SmallInt {} + message Integer {} + message BigInt {} + message Float {} + message Double {} + message String {} + message Binary {} + message Timestamp {} + message Date {} + message Time {} + + message FixedChar { + uint32 length = 1; + } + message VarChar { + uint32 length = 1; + } + + message Decimal { + uint32 scale = 1; + uint32 precision = 2; + } +} + +message QualifiedTableName { + string database_name = 1; + string schema_name = 2; + string table_name = 3; +} + +message ColumnDescriptor { + string name = 1; + SQLType type = 2; + map<string, string> properties = 3; +} + +message TableDescriptor { + QualifiedTableName name = 1; + repeated ColumnDescriptor columns = 2; + map<string, string> metadata_properties = 3; + bool descriptor_only = 4; + // All tables referenced by this view, in arbitrary order + repeated TableDescriptor referenced_tables = 5; +} diff --git a/requirements.txt b/requirements.txt index 3237391..12c7ac0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,9 @@ pydantic requests rfc3986 -mcp \ No newline at end of file +mcp +protobuf>=4.25.0 +grpcio-tools>=1.60.0 +grpcio>=1.60.0 +pyarrow>=16.0.0 +googleapis-common-protos>=1.63.0 \ No newline at end of file
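For orientation, the sketch below shows how the stubs generated from `hyper_service.proto` might be called once `make protos` has run. It is a minimal, hypothetical example and not part of the change above: the import path under `generated/`, the endpoint, and the `chunk_id` value are assumptions to adjust for your environment, and authentication metadata is omitted.

```python
# Hypothetical sketch only -- module paths, endpoint, and chunk numbering are
# assumptions; authentication metadata is omitted.
import grpc

# Assumed import layout produced by compile_protos.py; adjust as needed.
from generated.salesforce.hyperdb.grpc.v1 import hyper_service_pb2
from generated.salesforce.hyperdb.grpc.v1 import hyper_service_pb2_grpc


def run_query(sql: str, endpoint: str = "example.invalid:443") -> None:
    channel = grpc.secure_channel(endpoint, grpc.ssl_channel_credentials())
    stub = hyper_service_pb2_grpc.HyperServiceStub(channel)

    # ADAPTIVE transfer with JSON_ARRAY output: the stream starts with the
    # query status (carrying the query id) and at most one result chunk.
    param = hyper_service_pb2.QueryParam(
        query=sql,
        output_format=hyper_service_pb2.JSON_ARRAY,
        transfer_mode=hyper_service_pb2.QueryParam.ADAPTIVE,
    )

    query_id = None
    for msg in stub.ExecuteQuery(param):
        kind = msg.WhichOneof("result")
        if kind == "query_info" and msg.query_info.WhichOneof("content") == "query_status":
            query_id = msg.query_info.query_status.query_id
        elif kind == "query_result":
            print(msg.query_result.string_part.data)

    # Remaining chunks are fetched via GetQueryResult. Per the proto comments,
    # the query id must also be sent as gRPC metadata under x-hyperdb-query-id.
    # chunk_id=1 is purely illustrative; real code would take the chunk count
    # from GetQueryInfo.
    if query_id is not None:
        result_param = hyper_service_pb2.QueryResultParam(
            query_id=query_id,
            output_format=hyper_service_pb2.JSON_ARRAY,
            chunk_id=1,
        )
        metadata = (("x-hyperdb-query-id", query_id),)
        for part in stub.GetQueryResult(result_param, metadata=metadata):
            print(part.string_part.data)
```

With `ARROW_IPC` instead of `JSON_ARRAY`, the concatenated result parts form a single Arrow IPC stream per the `OutputFormat` comments, which is presumably why `pyarrow` is added to `requirements.txt`.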