From 1ccd814102cf0d021a97b63009371cb44ffd14df Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Fri, 22 Aug 2025 08:08:47 -0700 Subject: [PATCH 01/12] added postgres storage support --- graphrag/config/enums.py | 2 + graphrag/config/models/storage_config.py | 26 + .../index/operations/finalize_entities.py | 3 + graphrag/storage/factory.py | 2 + graphrag/storage/postgres_pipeline_storage.py | 1277 +++++++++++++++++ 5 files changed, 1310 insertions(+) create mode 100644 graphrag/storage/postgres_pipeline_storage.py diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 34b7765a67..9eb8a99a35 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -53,6 +53,8 @@ class StorageType(str, Enum): """The blob output type.""" cosmosdb = "cosmosdb" """The cosmosdb output type""" + postgres = "postgres" + """The postgres output type.""" def __repr__(self): """Get a string representation.""" diff --git a/graphrag/config/models/storage_config.py b/graphrag/config/models/storage_config.py index abd0936c7b..e89524479e 100644 --- a/graphrag/config/models/storage_config.py +++ b/graphrag/config/models/storage_config.py @@ -50,3 +50,29 @@ def validate_base_dir(cls, value, info): description="The cosmosdb account url to use.", default=graphrag_config_defaults.storage.cosmosdb_account_url, ) + + ### PostgreSQL + host: str = Field( + description="PostgreSQL server host (for postgres type).", + default="localhost" + ) + port: int = Field( + description="PostgreSQL server port (for postgres type).", + default=5432 + ) + database: str = Field( + description="PostgreSQL database name (for postgres type).", + default="graphrag" + ) + username: str | None = Field( + description="PostgreSQL username for authentication (for postgres type).", + default=None + ) + password: str | None = Field( + description="PostgreSQL password for authentication (for postgres type).", + default=None + ) + collection_prefix: str = Field( + description="Prefix for PostgreSQL collection names (for postgres type).", + default="graphrag_" + ) \ No newline at end of file diff --git a/graphrag/index/operations/finalize_entities.py b/graphrag/index/operations/finalize_entities.py index cd1dbb83eb..7c14bdd9c2 100644 --- a/graphrag/index/operations/finalize_entities.py +++ b/graphrag/index/operations/finalize_entities.py @@ -22,6 +22,9 @@ def finalize_entities( layout_enabled: bool = False, ) -> pd.DataFrame: """All the steps to transform final entities.""" + # Remove the default column degree, x and y for Postgres storage compatibility. And below entities.merge method + # will add them back with calculated values. 
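+    # Clarifying note (editorial, not in the original patch): the Postgres typed entities table
+    # defines degree/x/y columns with defaults and returns them on load, so entities read back from
+    # storage can arrive with these fields pre-populated. Dropping them here lets the merges below
+    # recompute them from the current graph/layout instead of colliding with stale values, and
+    # errors="ignore" keeps this a no-op for backends whose frames lack these columns.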
+ entities = entities.drop(columns=["degree", "x", "y"], errors="ignore") graph = create_graph(relationships, edge_attr=["weight"]) graph_embeddings = None if embed_config is not None and embed_config.enabled: diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py index 81e7ba17b4..35d6255ac6 100644 --- a/graphrag/storage/factory.py +++ b/graphrag/storage/factory.py @@ -12,6 +12,7 @@ from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage from graphrag.storage.file_pipeline_storage import create_file_storage +from graphrag.storage.postgres_pipeline_storage import PostgresPipelineStorage from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage if TYPE_CHECKING: @@ -99,3 +100,4 @@ def is_supported_storage_type(cls, storage_type: str) -> bool: StorageFactory.register(StorageType.cosmosdb.value, create_cosmosdb_storage) StorageFactory.register(StorageType.file.value, create_file_storage) StorageFactory.register(StorageType.memory.value, lambda **_: MemoryPipelineStorage()) +StorageFactory.register(StorageType.postgres.value, PostgresPipelineStorage) diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py new file mode 100644 index 0000000000..64957f6a68 --- /dev/null +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -0,0 +1,1277 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""PostgreSQL Storage implementation of PipelineStorage.""" + +import json +import logging +import re +from collections.abc import Iterator +from datetime import datetime, timezone +from io import BytesIO +from typing import Any + +import numpy as np +import pandas as pd +import asyncpg +from asyncpg import Connection, Pool + +from graphrag.storage.pipeline_storage import ( + PipelineStorage, + get_timestamp_formatted_with_local_tz, +) + +log = logging.getLogger(__name__) + +class PostgresPipelineStorage(PipelineStorage): + """The PostgreSQL Storage Implementation.""" + + _pool: Pool | None + _connection_string: str + _database: str + _collection_prefix: str + _encoding: str + _no_id_prefixes: list[str] + + def __init__( + self, + host: str = "localhost", + port: int = 5432, + database: str = "graphrag", + username: str = "postgres", + password: str | None = None, + collection_prefix: str = "lgr_", + encoding: str = "utf-8", + connection_string: str | None = None, + **kwargs: Any, + ): + """Initialize the PostgreSQL Storage.""" + self._host = host + self._port = port + self._database = database + self._username = username + self._password = password + self._collection_prefix = collection_prefix + self._encoding = encoding + self._no_id_prefixes = [] + self._pool = None + + # Build connection string from components or use provided one + if connection_string: + self._connection_string = connection_string + else: + if password: + self._connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}" + else: + self._connection_string = f"postgresql://{username}@{host}:{port}/{database}" + + log.info( + "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s", + self._host, + self._port, + self._database, + self._collection_prefix, + ) + + async def _get_connection(self) -> Connection: + """Get a database connection from the pool.""" + if self._pool is None: + try: + self._pool = await asyncpg.create_pool( + self._connection_string, + min_size=1, + max_size=10, 
+ command_timeout=60 + ) + log.info("Created PostgreSQL connection pool") + except Exception as e: + log.error("Failed to create PostgreSQL connection pool: %s", e) + raise + + return await self._pool.acquire() + + async def _release_connection(self, conn: Connection) -> None: + """Release a connection back to the pool.""" + if self._pool: + await self._pool.release(conn) + + def _sanitize_table_name(self, name: str) -> str: + """Sanitize a name to be a valid PostgreSQL table name.""" + return name + import re + + # Replace common problematic characters + sanitized = name.replace("-", "_").replace(":", "_").replace(" ", "_") + + # Remove any characters that aren't alphanumeric, underscore, or dollar sign + sanitized = re.sub(r'[^a-zA-Z0-9_$]', '_', sanitized) + + # Remove consecutive underscores + sanitized = re.sub(r'_+', '_', sanitized) + + # Remove leading/trailing underscores + sanitized = sanitized.strip('_') + + # Ensure it starts with a letter or underscore (not a digit) + if sanitized and sanitized[0].isdigit(): + sanitized = f"tbl_{sanitized}" + + # Ensure it's not empty + if not sanitized: + sanitized = "tbl_unnamed" + + # PostgreSQL has a limit of 63 characters for identifiers + if len(sanitized) > 59: # Leave room for prefix + sanitized = sanitized[:59] + log.info(f"Sanitied name {name} to {sanitized}") + return sanitized + + def _get_table_name(self, key: str) -> str: + """Get the table name for a given key.""" + # Extract the base name without file extension + base_name = key.split(".")[0] + + # Sanitize for PostgreSQL compatibility + sanitized_name = self._sanitize_table_name(base_name) + + return f"{self._collection_prefix}{sanitized_name}" + + def _get_prefix(self, key: str) -> str: + """Get the prefix of the filename key.""" + return key.split(".")[0] + + def _get_entities_table_schema(self, table_name: str) -> str: + """Get the SQL schema for entities table.""" + # Sanitize table name for index names + table_name = self._sanitize_table_name(table_name) + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + title TEXT, + type TEXT, + description TEXT, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + frequency INTEGER DEFAULT 0, + degree INTEGER DEFAULT 0, + x DOUBLE PRECISION DEFAULT 0.0, + y DOUBLE PRECISION DEFAULT 0.0, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Performance indexes + CREATE INDEX idx_{table_name}_type ON {table_name}(type); + CREATE INDEX idx_{table_name}_frequency ON {table_name}(frequency); + CREATE INDEX idx_{table_name}_title ON {table_name}(title); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + """ + + def _get_relationships_table_schema(self, table_name: str) -> str: + """Get the SQL schema for relationships table.""" + # Sanitize table name for index names + table_name = self._sanitize_table_name(table_name) + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + source TEXT NOT NULL, + target TEXT NOT NULL, + description TEXT DEFAULT '', + weight DOUBLE PRECISION DEFAULT 0.0, + combined_degree INTEGER DEFAULT 0, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Graph query indexes + CREATE INDEX idx_{table_name}_source ON {table_name}(source); + CREATE INDEX idx_{table_name}_target ON {table_name}(target); + CREATE INDEX 
idx_{table_name}_weight ON {table_name}(weight); + CREATE INDEX idx_{table_name}_source_target ON {table_name}(source, target); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + """ + + def _get_communities_table_schema(self, table_name: str) -> str: + """Get the SQL schema for communities table.""" + # Sanitize table name for index names + table_name = self._sanitize_table_name(table_name) + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + community INTEGER, + level INTEGER DEFAULT 0, + parent INTEGER, + children JSONB DEFAULT '[]'::jsonb, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + entity_ids JSONB DEFAULT '[]'::jsonb, + relationship_ids JSONB DEFAULT '[]'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Community hierarchy indexes + CREATE INDEX idx_{table_name}_community ON {table_name}(community); + CREATE INDEX idx_{table_name}_level ON {table_name}(level); + CREATE INDEX idx_{table_name}_parent ON {table_name}(parent); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); + CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); + """ + + def _get_text_units_table_schema(self, table_name: str) -> str: + """Get the SQL schema for text_units table.""" + # Sanitize table name for index names + table_name = self._sanitize_table_name(table_name) + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + text TEXT, + n_tokens INTEGER DEFAULT 0, + document_ids JSONB DEFAULT '[]'::jsonb, + entity_ids JSONB DEFAULT '[]'::jsonb, + relationship_ids JSONB DEFAULT '[]'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Text search and relationship indexes + CREATE INDEX idx_{table_name}_n_tokens ON {table_name}(n_tokens); + CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING GIN(to_tsvector('english', text)); + CREATE INDEX idx_{table_name}_document_ids_gin ON {table_name} USING GIN(document_ids); + CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); + CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); + """ + + def _get_documents_table_schema(self, table_name: str) -> str: + """Get the SQL schema for documents table.""" + # Sanitize table name for index names + table_name = self._sanitize_table_name(table_name) + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + title TEXT, + text TEXT, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + creation_date TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + metadata JSONB DEFAULT '{{}}'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Document search indexes + CREATE INDEX idx_{table_name}_title ON {table_name}(title); + CREATE INDEX idx_{table_name}_creation_date ON {table_name}(creation_date); + CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING GIN(to_tsvector('english', text)); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); + """ + + def _get_generic_table_schema(self, table_name: str) -> 
str: + """Get the SQL schema for generic data (fallback).""" + # Sanitize table name for index names + table_name = self._sanitize_table_name(table_name) + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + data JSONB NOT NULL, + metadata JSONB DEFAULT '{{}}'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Generic indexes + CREATE INDEX idx_{table_name}_data_gin ON {table_name} USING GIN(data); + CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); + """ + + def _get_table_schema_sql(self, table_name: str) -> str: + """Get the appropriate schema SQL for the table type.""" + # Sanitize table name for schema generation + table_name = self._sanitize_table_name(table_name) + + if 'entities' in table_name: + return self._get_entities_table_schema(table_name) + elif 'relationships' in table_name: + return self._get_relationships_table_schema(table_name) + elif 'communities' in table_name: + return self._get_communities_table_schema(table_name) + elif 'text_units' in table_name: + return self._get_text_units_table_schema(table_name) + elif 'documents' in table_name: + return self._get_documents_table_schema(table_name) + else: + return self._get_generic_table_schema(table_name) + + async def _ensure_table_exists_with_schema(self, table_name: str) -> None: + # Ensure table name is properly sanitized for SQL operations + table_name = self._sanitize_table_name(table_name) + + conn = await self._get_connection() + try: + table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + if not table_exists: + # Create table with appropriate typed schema (pass original table_name for type detection) + schema_sql = self._get_table_schema_sql(table_name) + await conn.execute(schema_sql) + log.info(f"Created table {table_name} with specific schema") + + finally: + await self._release_connection(conn) + + def _process_id_field(self, df: pd.DataFrame, table_name: str) -> list[str]: + """Process ID values - store clean IDs with prefix following CosmosDB pattern in GraphRAG.""" + prefix = self._get_prefix(table_name.replace(self._collection_prefix, "")) + id_values = [] + + if "id" not in df.columns: + # No ID column - create prefixed sequential IDs and track this prefix + for index in range(len(df)): + id_values.append(f"{prefix}:{index}") + if prefix not in self._no_id_prefixes: + self._no_id_prefixes.append(prefix) + log.info(f"No ID column found for {prefix}, generated prefixed sequential IDs") + else: + # Has ID column - process each row with prefix + for index, val in enumerate(df["id"]): + if self._is_scalar_na(val) or val == '' or val == 'nan' or (isinstance(val, list) and (len(val) == 0 or self._is_scalar_na(val[0]) or str(val[0]).strip() == '')): + # Missing ID - create prefixed sequential ID and track this prefix + id_values.append(f"{prefix}:{index}") + if prefix not in self._no_id_prefixes: + self._no_id_prefixes.append(prefix) + else: + # Valid ID - use as is without prefix + if isinstance(val, list): + id_values.append(str(val[0])) + else: + id_values.append(str(val)) + + return id_values + + def _is_scalar_na(self, value: Any) -> bool: + """Safely check if a value is NA/null, avoiding issues with arrays.""" + try: + # Don't check pd.isna on complex objects or large arrays + if isinstance(value, (list, dict)): + return False + if hasattr(value, '__len__') and len(str(value)) > 100: + return False + return 
pd.isna(value) + except (ValueError, TypeError): + # If pd.isna fails, assume it's not NA + return False + + def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[dict]: + """Prepare DataFrame data for PostgreSQL insertion with typed columns.""" + log.info(f"Preparing data for table {table_name}, DataFrame shape: {df.shape}") + log.info(f"DataFrame columns: {df.columns.tolist()}") + + # Add human_readable_id if missing + if 'human_readable_id' not in df.columns: + df = df.copy() + df['human_readable_id'] = range(len(df)) + log.info(f"Generated sequential human_readable_id for {len(df)} records") + + # Process IDs - for typed tables, we can use simpler ID handling + ids = self._process_id_field(df, table_name) + + # Determine if this is a typed table or generic table + is_typed_table = any(table_type in table_name for table_type in + ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + if is_typed_table: + return self._prepare_data_for_typed_table(df, table_name, ids) + else: + return self._prepare_data_for_generic_table(df, table_name, ids) + + def _prepare_data_for_typed_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: + """Prepare data for typed PostgreSQL tables with specific columns.""" + records = [] + + for i in range(len(df)): + record = {'id': ids[i]} + row = df.iloc[i] + + # Map DataFrame columns to table columns based on table type + if 'entities' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'title': str(row.get('title', '')), + 'type': str(row.get('type', '')), + 'description': str(row.get('description', '')), + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), + 'frequency': int(row.get('frequency', 0)) if pd.notna(row.get('frequency', 0)) else 0, + 'degree': int(row.get('degree', 0)) if pd.notna(row.get('degree', 0)) else 0, + 'x': float(row.get('x', 0.0)) if pd.notna(row.get('x', 0.0)) else 0.0, + 'y': float(row.get('y', 0.0)) if pd.notna(row.get('y', 0.0)) else 0.0 + }) + elif 'relationships' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'source': str(row.get('source', '')), + 'target': str(row.get('target', '')), + 'description': str(row.get('description', '')), + 'weight': float(row.get('weight', 0.0)) if pd.notna(row.get('weight', 0.0)) else 0.0, + 'combined_degree': int(row.get('combined_degree', 0)) if pd.notna(row.get('combined_degree', 0)) else 0, + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])) + }) + elif 'communities' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'community': int(row.get('community', 0)) if pd.notna(row.get('community')) and str(row.get('community', '')).strip() != '' else 0, + 'level': int(row.get('level', 0)) if pd.notna(row.get('level', 0)) else 0, + 'parent': int(row.get('parent', 0)) if pd.notna(row.get('parent')) and str(row.get('parent', '')).strip() != '' else None, + 'children': self._ensure_json_list(row.get('children', [])), + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), + 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), + 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) + }) + elif 'text_units' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'text': str(row.get('text', '')), + 'n_tokens': int(row.get('n_tokens', 0)) if pd.notna(row.get('n_tokens', 0)) 
else 0, + 'document_ids': self._ensure_json_list(row.get('document_ids', [])), + 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), + 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) + }) + elif 'documents' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'title': str(row.get('title', '')), + 'text': str(row.get('text', '')), + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), + 'creation_date': self._ensure_datetime(row.get('creation_date')), + 'metadata': self._ensure_json_dict(row.get('metadata', {})) + }) + + records.append(record) + + log.info(f"Prepared {len(records)} records for typed table {table_name}") + if records: + log.info(f"Sample typed record: {list(records[0].keys())}") + + return records + + def _prepare_data_for_generic_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: + """Prepare data for generic PostgreSQL tables (fallback to JSONB storage).""" + records = [] + for i in range(len(df)): + # Create record with ID and all data in JSONB field + record_data = df.iloc[i].to_dict() + + # Convert numpy types to native Python types for JSON serialization + for key, value in record_data.items(): + if isinstance(value, (list, dict)): + record_data[key] = value + elif hasattr(value, 'tolist'): + # Handle numpy arrays and other numpy types + record_data[key] = value.tolist() + elif hasattr(value, 'item') and hasattr(value, 'size') and value.size == 1: + record_data[key] = value.item() + elif isinstance(value, (pd.Timestamp, pd.DatetimeTZDtype)): + record_data[key] = value.isoformat() if pd.notna(value) else None + elif self._is_scalar_na(value): + record_data[key] = None + elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + record_data[key] = [] + else: + record_data[key] = value + + record = { + 'id': ids[i], + 'data': record_data, + 'metadata': {} + } + records.append(record) + + log.info(f"Prepared {len(records)} records for generic table {table_name}") + return records + + def _ensure_json_list(self, value: Any) -> list: + """Ensure a value is a proper list for JSONB storage.""" + if isinstance(value, list): + # Convert any numpy arrays in the list to regular Python lists + return [item.tolist() if hasattr(item, 'tolist') else item for item in value] + elif hasattr(value, 'tolist'): + # Handle numpy arrays directly + converted = value.tolist() + return converted if isinstance(converted, list) else [converted] + elif isinstance(value, str) and value: + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, list) else [] + except (json.JSONDecodeError, TypeError): + return [] + elif value is None or pd.isna(value): + return [] + else: + return [value] if value else [] + + def _ensure_json_dict(self, value: Any) -> dict: + """Ensure a value is a proper dict for JSONB storage.""" + if isinstance(value, dict): + # Convert any numpy arrays in the dict to regular Python objects + result = {} + for k, v in value.items(): + if hasattr(v, 'tolist'): + result[k] = v.tolist() + elif hasattr(v, 'item') and hasattr(v, 'size') and v.size == 1: + result[k] = v.item() + else: + result[k] = v + return result + elif isinstance(value, str) and value: + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, dict) else {} + except (json.JSONDecodeError, TypeError): + return {} + elif value is None or pd.isna(value): + return {} + else: + return {'value': str(value)} if value else {} + + def 
_ensure_timezone_aware_datetimes(self, records: list[dict]) -> list[dict]: + """Ensure all datetime fields in records are timezone-aware for PostgreSQL.""" + datetime_fields = ['creation_date', 'created_at', 'updated_at'] + + for record in records: + for field in datetime_fields: + if field in record: + value = record[field] + if value is not None: + record[field] = self._ensure_datetime(value) + + return records + + def _ensure_datetime(self, value: Any) -> datetime: + """Ensure a value is a proper timezone-aware datetime object for PostgreSQL storage.""" + from dateutil import parser + + if isinstance(value, datetime): + # If it's already a datetime, ensure it has timezone info + if value.tzinfo is None: + # If it's timezone-naive, localize to UTC + return value.replace(tzinfo=timezone.utc) + else: + # Already timezone-aware + return value + elif isinstance(value, pd.Timestamp): + # Convert pandas Timestamp to datetime + dt = value.to_pydatetime() + # Ensure timezone awareness + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + else: + return dt + elif isinstance(value, str) and value: + try: + # Try to parse the string as a datetime + parsed_dt = parser.parse(value) + # Ensure timezone awareness + if parsed_dt.tzinfo is None: + return parsed_dt.replace(tzinfo=timezone.utc) + else: + return parsed_dt + except (ValueError, TypeError): + # If parsing fails, return current time + return datetime.now(timezone.utc) + elif value is None or pd.isna(value): + return datetime.now(timezone.utc) + else: + # For any other type, return current time + return datetime.now(timezone.utc) + + async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict], batch_size: int = 1000) -> None: + """Perform high-performance batch upsert of records using executemany.""" + total_records = len(records) + log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {batch_size}") + + # Ensure all datetime fields are timezone-aware + records = self._ensure_timezone_aware_datetimes(records) + + processed_count = 0 + + # Determine if this is a typed table or generic table + is_typed_table = any(table_type in table_name for table_type in + ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + # Process records in batches for optimal performance + for i in range(0, total_records, batch_size): + batch = records[i:i + batch_size] + batch_end = min(i + batch_size, total_records) + + try: + if is_typed_table: + await self._batch_upsert_typed_records(conn, table_name, batch) + else: + await self._batch_upsert_generic_records(conn, table_name, batch) + + except Exception as e: + log.warning(f"Batch method failed for batch {i}-{batch_end}, falling back to individual inserts: {e}") + + # Fallback to individual inserts within the batch + try: + async with conn.transaction(): + if is_typed_table: + for record in batch: + await self._insert_typed_record(conn, table_name, record) + else: + upsert_sql = f""" + INSERT INTO {table_name} (id, data, updated_at) + VALUES ($1, $2, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + updated_at = NOW() + """ + for record in batch: + await conn.execute(upsert_sql, record['id'], json.dumps(record['data'])) + except Exception as inner_e: + log.error(f"Both batch and individual insert methods failed for batch {i}-{batch_end}: {inner_e}") + raise + + processed_count += len(batch) + + # Log progress every batch for visibility + if i % batch_size == 0 or batch_end == total_records: + 
log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") + + async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: + """Batch upsert for typed tables with specific columns.""" + async with conn.transaction(): + # Ensure table name is properly sanitized for SQL + table_name = self._sanitize_table_name(table_name) + + if 'entities' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + title = EXCLUDED.title, + type = EXCLUDED.type, + description = EXCLUDED.description, + text_unit_ids = EXCLUDED.text_unit_ids, + frequency = EXCLUDED.frequency, + degree = EXCLUDED.degree, + x = EXCLUDED.x, + y = EXCLUDED.y, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['title'], r['type'], r['description'], + json.dumps(r['text_unit_ids']), r['frequency'], r['degree'], r['x'], r['y']) + for r in batch + ] + elif 'relationships' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + source = EXCLUDED.source, + target = EXCLUDED.target, + description = EXCLUDED.description, + weight = EXCLUDED.weight, + combined_degree = EXCLUDED.combined_degree, + text_unit_ids = EXCLUDED.text_unit_ids, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['source'], r['target'], r['description'], + r['weight'], r['combined_degree'], json.dumps(r['text_unit_ids'])) + for r in batch + ] + elif 'communities' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + community = EXCLUDED.community, + level = EXCLUDED.level, + parent = EXCLUDED.parent, + children = EXCLUDED.children, + text_unit_ids = EXCLUDED.text_unit_ids, + entity_ids = EXCLUDED.entity_ids, + relationship_ids = EXCLUDED.relationship_ids, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['community'], r['level'], r['parent'], + json.dumps(r['children']), json.dumps(r['text_unit_ids']), + json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) + for r in batch + ] + elif 'text_units' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + text = EXCLUDED.text, + n_tokens = EXCLUDED.n_tokens, + document_ids = EXCLUDED.document_ids, + entity_ids = EXCLUDED.entity_ids, + relationship_ids = EXCLUDED.relationship_ids, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['text'], r['n_tokens'], + json.dumps(r['document_ids']), json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) + for r in batch + ] + elif 'documents' in table_name: + 
upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, title, text, text_unit_ids, creation_date, metadata, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + title = EXCLUDED.title, + text = EXCLUDED.text, + text_unit_ids = EXCLUDED.text_unit_ids, + creation_date = EXCLUDED.creation_date, + metadata = EXCLUDED.metadata, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['title'], r['text'], + json.dumps(r['text_unit_ids']), + self._ensure_datetime(r['creation_date']), + json.dumps(r['metadata'])) + for r in batch + ] + else: + raise ValueError(f"Unknown typed table: {table_name}") + + await conn.executemany(upsert_sql, batch_data) + + async def _batch_upsert_generic_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: + """Batch upsert for generic tables using JSONB.""" + async with conn.transaction(): + # Ensure table name is properly sanitized for SQL + table_name = self._sanitize_table_name(table_name) + upsert_sql = f""" + INSERT INTO {table_name} (id, data, metadata, updated_at) + VALUES ($1, $2, $3, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + metadata = EXCLUDED.metadata, + updated_at = NOW() + """ + batch_data = [ + (record['id'], json.dumps(record['data']), json.dumps(record['metadata'])) + for record in batch + ] + await conn.executemany(upsert_sql, batch_data) + + async def _insert_typed_record(self, conn: Connection, table_name: str, record: dict) -> None: + """Insert a single typed record (fallback method).""" + # This is a simplified fallback - implement based on table type if needed + # For now, just use the batch method with a single record + await self._batch_upsert_typed_records(conn, table_name, [record]) + + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find data in PostgreSQL tables using a file pattern regex.""" + # This is a synchronous method, but we need async operations + # For now, implement a basic version - in practice, this would need refactoring + log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) + + # Note: This is simplified - full implementation would need async/await support + # in the find method signature or use asyncio.run() + return iter([]) + + async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: + """Retrieve data from PostgreSQL table.""" + try: + table_name = self._get_table_name(key) + log.info(f"Retrieving data from table: {table_name}") + + conn = await self._get_connection() + try: + # Check if table exists + table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + + if not table_exists: + log.warning(f"Table {table_name} does not exist") + return None + + # Determine if this is a typed table or generic table + is_typed_table = any(table_type in table_name for table_type in + ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + if is_typed_table: + # For typed tables, select all columns except created_at/updated_at + if 'documents' in table_name: + query = "SELECT id, human_readable_id, title, text, text_unit_ids, creation_date, metadata FROM {} ORDER BY created_at".format(table_name) + elif 'entities' in table_name: + query = "SELECT 
id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y FROM {} ORDER BY created_at".format(table_name) + elif 'relationships' in table_name: + query = "SELECT id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids FROM {} ORDER BY created_at".format(table_name) + elif 'communities' in table_name: + query = "SELECT id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) + elif 'text_units' in table_name: + query = "SELECT id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) + else: + # Fallback for unknown typed table + query = "SELECT * FROM {} ORDER BY created_at".format(table_name) + else: + # For generic tables, use the data column + query = "SELECT id, data FROM {} ORDER BY created_at".format(table_name) + + rows = await conn.fetch(query) + + if not rows: + log.info(f"No data found in table {table_name}") + return None + + log.info(f"Retrieved {len(rows)} records from table {table_name}") + + # Check if this should be treated as raw data instead of tabular data + if (not key.endswith('.parquet') or + 'state' in key.lower() or + key.endswith('.json') or + 'context' in table_name.lower()): + # Handle state.json or context.json as raw data + # For non-tabular data, return the raw content from the first record + if rows: + if is_typed_table: + # For typed tables, convert row to dict and return as JSON + row_dict = dict(rows[0]) + json_str = json.dumps(row_dict) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str + elif 'data' in rows[0]: + raw_content = rows[0]['data'] + if isinstance(raw_content, dict): + json_str = json.dumps(raw_content) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str + return b"" if as_bytes else "" + + # Convert to DataFrame + records = [] + for row in rows: + if is_typed_table: + # For typed tables, the row is already the data we need + record_data = dict(row) + + # Convert JSONB fields back to proper Python objects + for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: + if field in record_data: + value = record_data[field] + + if value is None: + record_data[field] = {} if field == 'metadata' else [] + elif isinstance(value, str): + # Handle JSONB strings - they should always be valid JSON + try: + parsed = json.loads(value) + # Validate the parsed type + if field == 'metadata': + record_data[field] = parsed if isinstance(parsed, dict) else {} + else: + record_data[field] = parsed if isinstance(parsed, list) else [] + except (json.JSONDecodeError, TypeError): + log.warning(f"Failed to parse JSONB field {field}: {value}") + # Fallback for non-JSON strings + if field == 'metadata': + record_data[field] = {} + else: + record_data[field] = [] + elif isinstance(value, (list, dict)): + # Already correct type (shouldn't happen with JSONB, but handle it) + record_data[field] = value + else: + # Convert other types + if field == 'metadata': + record_data[field] = {'value': str(value)} if value else {} + else: + record_data[field] = [value] if value else [] + else: + # Handle generic table data (JSONB data column) + if isinstance(row['data'], dict): + record_data = dict(row['data']) + else: + # If it's a string, parse it as JSON + record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] + + # Clean 
up the record data - convert None to proper values and handle NaN + cleaned_data = {} + for key_name, value in record_data.items(): + if self._is_scalar_na(value) or value is None: + cleaned_data[key_name] = None + elif isinstance(value, str) and key_name == 'text_unit_ids' and not is_typed_table: + # Try to parse text_unit_ids back from JSON string if needed (only for generic tables) + try: + parsed_value = json.loads(value) + cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [value] + except (json.JSONDecodeError, TypeError): + # If it's not JSON, treat as a single item list or keep as string + cleaned_data[key_name] = [value] if value else [] + elif key_name in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids'] and not is_typed_table: + # Always ensure these columns are lists (only for generic tables - typed tables already handled this) + if isinstance(value, list): + cleaned_data[key_name] = value + elif isinstance(value, str): + try: + parsed_value = json.loads(value) + cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [] + except (json.JSONDecodeError, TypeError): + cleaned_data[key_name] = [] + elif value is None: + cleaned_data[key_name] = [] + else: + # fallback: wrap single value in a list + cleaned_data[key_name] = [value] + elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + # Handle empty arrays/lists + cleaned_data[key_name] = [] + else: + cleaned_data[key_name] = value + + # Always include the ID column for GraphRAG compatibility + # Use the storage ID as is since we simplified ID handling + storage_id = row['id'] + cleaned_data['id'] = storage_id + + records.append(cleaned_data) + + df = pd.DataFrame(records) + + # Additional cleanup for NaN values in the DataFrame + df = df.where(pd.notna(df), None) + log.info(f"Created DataFrame with shape: {df.shape}") + log.info(f"Table {table_name} DataFrame columns: {df.columns.tolist()}") + + if len(df) > 0: + log.debug(f"Table {table_name} Sample record: {df.iloc[0].to_dict()}") + # Debug: Check if children column exists and its type + if 'children' in df.columns: + sample_children = df.iloc[0]['children'] + log.debug(f"Table {table_name} Sample children value: {sample_children}, type: {type(sample_children)}") + + # Handle bytes conversion for GraphRAG compatibility + if as_bytes or kwargs.get("as_bytes"): + log.info(f"Converting DataFrame to parquet bytes for key: {key}") + + # Apply column filtering similar to Milvus implementation + df_clean = df.copy() + + # Define expected columns for each data type + if 'documents' in table_name: + expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] + elif 'entities' in table_name: + expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'] + elif 'relationships' in table_name: + expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] + if 'combined_degree' in df_clean.columns: + expected_columns.append('combined_degree') + elif 'text_units' in table_name: + expected_columns = ['id', 'human_readable_id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] + elif 'communities' in table_name: + expected_columns = ['id', 'human_readable_id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] + else: + expected_columns = list(df_clean.columns) + + # Filter columns + 
available_columns = [col for col in expected_columns if col in df_clean.columns] + if available_columns != expected_columns: + missing = set(expected_columns) - set(available_columns) + extra = set(df_clean.columns) - set(expected_columns) + log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") + + df_clean = df_clean[available_columns] + log.info(f"Table {table_name} final filtered columns: {df_clean.columns.tolist()}") + + # Convert to parquet bytes + try: + # Handle list columns for PyArrow compatibility + df_for_parquet = df_clean.copy() + + # For PyArrow/parquet compatibility, we need to handle list columns carefully + # Instead of converting to JSON strings, let's try a different approach + list_columns = [] + for col in df_for_parquet.columns: + if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # Check if this column contains lists + first_non_null = None + for val in df_for_parquet[col]: + if isinstance(val, list): + first_non_null = val + break + elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + first_non_null = val + break + + if isinstance(first_non_null, list) or col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: + list_columns.append(col) + # Ensure all values in this column are proper lists + df_for_parquet[col] = df_for_parquet[col].apply( + lambda x: x if isinstance(x, list) else ([] if x is None or pd.isna(x) else [str(x)]) + ) + + if list_columns: + log.info(f"Ensured list columns are proper lists for parquet: {list_columns}") + + # Try to convert to parquet without JSON string conversion + buffer = BytesIO() + df_for_parquet.to_parquet(buffer, engine='pyarrow') + buffer.seek(0) + parquet_bytes = buffer.getvalue() + log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") + return parquet_bytes + except Exception as e: + log.warning(f"Direct parquet conversion failed: {e}, trying with JSON string conversion") + + # Fallback: convert lists to JSON strings + try: + df_for_parquet = df_clean.copy() + + # Convert list columns to JSON strings for parquet compatibility + list_columns = [] + for col in df_for_parquet.columns: + if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # Check if this column contains lists + first_non_null = None + for val in df_for_parquet[col]: + if isinstance(val, list): + first_non_null = val + break + elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + first_non_null = val + break + if isinstance(first_non_null, list): + list_columns.append(col) + # Convert lists to JSON strings + df_for_parquet[col] = df_for_parquet[col].apply( + lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + ) + elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: + # These columns should always be lists, even if empty + list_columns.append(col) + df_for_parquet[col] = df_for_parquet[col].apply( + lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + ) + + if list_columns: + log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") + + buffer = BytesIO() + df_for_parquet.to_parquet(buffer, engine='pyarrow') + buffer.seek(0) + parquet_bytes = buffer.getvalue() + log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data (with JSON conversion)") + return 
parquet_bytes + except Exception as e2: + log.exception(f"Failed to convert DataFrame to parquet bytes: {e2}") + return b"" + + return df + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception(f"Error retrieving data from table {table_name}: {e}") + return None + + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + """Insert data into PostgreSQL table with drop/recreate to avoid duplicates.""" + try: + table_name = self._get_table_name(key) + log.info(f"Setting data for key: {key}, table: {table_name}") + + # Use new table creation approach with duplicate prevention + await self._ensure_table_exists_with_schema(table_name) + + conn = await self._get_connection() + try: + if isinstance(value, bytes): + # Parse parquet data + df = pd.read_parquet(BytesIO(value)) + log.info(f"Parsed parquet data, DataFrame shape: {df.shape}") + # output sample record for debugging + log.debug(f"Table {table_name} Sample record (first row): {df.iloc[0].to_dict()}") + log.info(f"Parsed DataFrame columns: {df.columns.tolist()}") + + # Prepare data for PostgreSQL (typed or generic) + records = self._prepare_data_for_postgres(df, table_name) + + if records: + # Use batch insert for much better performance + await self._batch_upsert_records(conn, table_name, records) + + log.info(f"Successfully inserted {len(records)} records to {table_name}") + + # Log ID handling info + if any(record['id'].split(':')[0] in self._no_id_prefixes for record in records if 'id' in record): + log.info(f"Some records used auto-generated IDs in table {table_name}") + + else: + # Handle non-parquet data (e.g., JSON, stats) - always use generic table + log.info(f"Handling non-parquet data for key: {key}") + + record_data = json.loads(value) if isinstance(value, str) else value + + # Use generic table insertion for non-parquet data + records = [{ + 'id': key, + 'data': record_data, + 'metadata': {'type': 'non_parquet', 'created': datetime.now(timezone.utc).isoformat()} + }] + + await self._batch_upsert_generic_records(conn, table_name, records) + log.info("Non-parquet data insertion successful") + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error setting data in PostgreSQL table %s: %s", table_name, e) + raise + + async def has(self, key: str) -> bool: + """Check if data exists for the given key.""" + try: + table_name = self._get_table_name(key) + log.info(f"Checking existence for key: {key}, table: {table_name}") + conn = await self._get_connection() + try: + # Check if table exists + table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + log.debug(f"Table {table_name} exists: {table_exists}") + if not table_exists: + return False + + if key.endswith('.parquet'): + # For parquet files, check if table has any records + total_count = await conn.fetchval( + f"SELECT COUNT(*) FROM {table_name}" + ) + if total_count > 0: + return True + else: + raise ValueError(f"No records found in table {table_name} for parquet key {key}") + else: + # Check for exact key match + exists = await conn.fetchval( + f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", + key + ) + return exists + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error checking existence for key %s: %s", key, e) + return False + + async def delete(self, key: str) -> None: + """Delete data for the given key.""" + try: + table_name = 
self._get_table_name(key) + conn = await self._get_connection() + try: + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + log.info(f"Deleted record for key {key}") + finally: + await self._release_connection(conn) + except Exception as e: + log.exception("Error deleting key %s: %s", key, e) + + async def clear(self) -> None: + """Clear all tables with the configured prefix.""" + try: + conn = await self._get_connection() + try: + # Get all tables with our prefix + tables = await conn.fetch( + "SELECT table_name FROM information_schema.tables WHERE table_name LIKE $1", + f"{self._collection_prefix}%" + ) + + for table_row in tables: + table_name = table_row['table_name'] + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + log.info(f"Dropped table: {table_name}") + + log.info(f"Cleared all tables with prefix: {self._collection_prefix}") + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error clearing tables: %s", e) + + def keys(self) -> list[str]: + """Return the keys in the storage.""" + # This would need to be async to properly implement + # For now, return empty list + log.warning("keys() method not fully implemented for async storage") + return [] + + def child(self, name: str | None) -> PipelineStorage: + """Create a child storage instance.""" + return self + + async def get_creation_date(self, key: str) -> str: + """Get the creation date for data.""" + try: + table_name = self._get_table_name(key) + conn = await self._get_connection() + try: + if key.endswith('.parquet'): + prefix = self._get_prefix(key) + created_at = await conn.fetchval( + f"SELECT MIN(created_at) FROM {table_name} WHERE id LIKE $1", + f"{prefix}:%" + ) + else: + created_at = await conn.fetchval( + f"SELECT created_at FROM {table_name} WHERE id = $1", + key + ) + + if created_at: + return get_timestamp_formatted_with_local_tz(created_at) + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error getting creation date for %s: %s", key, e) + + return "" + + async def close(self) -> None: + """Close the connection pool.""" + if self._pool: + await self._pool.close() + log.info("Closed PostgreSQL connection pool") From d6778b6bda6d585ef8d9bc61cf9d30a59a6cbb9c Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Fri, 22 Aug 2025 09:26:03 -0700 Subject: [PATCH 02/12] added timeout for postgres storage --- graphrag/config/models/storage_config.py | 16 +++++++++ graphrag/storage/postgres_pipeline_storage.py | 36 +++++++++++++------ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/graphrag/config/models/storage_config.py b/graphrag/config/models/storage_config.py index e89524479e..90b73963a2 100644 --- a/graphrag/config/models/storage_config.py +++ b/graphrag/config/models/storage_config.py @@ -75,4 +75,20 @@ def validate_base_dir(cls, value, info): collection_prefix: str = Field( description="Prefix for PostgreSQL collection names (for postgres type).", default="graphrag_" + ) + batch_size: int = Field( + description="Batch size for database operations (for postgres type).", + default=50 + ) + command_timeout: int = Field( + description="Command timeout for database operations (for postgres type).", + default=600 + ) + server_timeout: int = Field( + description="Server timeout for database connections (for postgres type).", + default=120 + ) + connection_timeout: int = Field( + description="Connection timeout for establishing database connections (for postgres type).", + default=60 ) \ No 
newline at end of file diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py index 64957f6a68..df337d5164 100644 --- a/graphrag/storage/postgres_pipeline_storage.py +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -43,6 +43,10 @@ def __init__( collection_prefix: str = "lgr_", encoding: str = "utf-8", connection_string: str | None = None, + command_timeout: int = 600, # 10 minutes for SQL commands + server_timeout: int = 120, # 2 minutes for server connection + connection_timeout: int = 60, # 1 minute to establish connection + batch_size: int = 50, # Smaller batch size to reduce timeout risk **kwargs: Any, ): """Initialize the PostgreSQL Storage.""" @@ -53,6 +57,10 @@ def __init__( self._password = password self._collection_prefix = collection_prefix self._encoding = encoding + self._command_timeout = command_timeout + self._server_timeout = server_timeout + self._connection_timeout = connection_timeout + self._batch_size = batch_size self._no_id_prefixes = [] self._pool = None @@ -66,11 +74,13 @@ def __init__( self._connection_string = f"postgresql://{username}@{host}:{port}/{database}" log.info( - "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s", + "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s, command_timeout: %s, batch_size: %s", self._host, self._port, self._database, self._collection_prefix, + self._command_timeout, + self._batch_size, ) async def _get_connection(self) -> Connection: @@ -81,9 +91,15 @@ async def _get_connection(self) -> Connection: self._connection_string, min_size=1, max_size=10, - command_timeout=60 + command_timeout=self._command_timeout, + server_settings={ + 'application_name': 'graphrag_postgres_storage' + }, + # Use connection_timeout for initial connection establishment + timeout=self._connection_timeout ) - log.info("Created PostgreSQL connection pool") + log.info("Created PostgreSQL connection pool with command_timeout: %s, connection_timeout: %s", + self._command_timeout, self._connection_timeout) except Exception as e: log.error("Failed to create PostgreSQL connection pool: %s", e) raise @@ -590,10 +606,10 @@ def _ensure_datetime(self, value: Any) -> datetime: # For any other type, return current time return datetime.now(timezone.utc) - async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict], batch_size: int = 1000) -> None: + async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict]) -> None: """Perform high-performance batch upsert of records using executemany.""" total_records = len(records) - log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {batch_size}") + log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {self._batch_size}") # Ensure all datetime fields are timezone-aware records = self._ensure_timezone_aware_datetimes(records) @@ -605,10 +621,10 @@ async def _batch_upsert_records(self, conn: Connection, table_name: str, records ['entities', 'relationships', 'communities', 'text_units', 'documents']) # Process records in batches for optimal performance - for i in range(0, total_records, batch_size): - batch = records[i:i + batch_size] - batch_end = min(i + batch_size, total_records) - + for i in range(0, total_records, self._batch_size): + batch = records[i:i + self._batch_size] + batch_end = min(i + self._batch_size, total_records) + try: if is_typed_table: 
await self._batch_upsert_typed_records(conn, table_name, batch) @@ -642,7 +658,7 @@ async def _batch_upsert_records(self, conn: Connection, table_name: str, records processed_count += len(batch) # Log progress every batch for visibility - if i % batch_size == 0 or batch_end == total_records: + if i % self._batch_size == 0 or batch_end == total_records: log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: From 8e2bdff046894b06470644345b4f0c7be00a2ef5 Mon Sep 17 00:00:00 2001 From: Danny Zheng Date: Fri, 25 Jul 2025 14:32:54 -0700 Subject: [PATCH 03/12] skip id generation for entities, relationships, and communities if id exists --- graphrag/index/operations/finalize_entities.py | 9 ++++++--- graphrag/index/operations/finalize_relationships.py | 9 ++++++--- graphrag/index/workflows/create_communities.py | 6 +++++- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/graphrag/index/operations/finalize_entities.py b/graphrag/index/operations/finalize_entities.py index 7c14bdd9c2..55fac07d2a 100644 --- a/graphrag/index/operations/finalize_entities.py +++ b/graphrag/index/operations/finalize_entities.py @@ -48,9 +48,12 @@ def finalize_entities( final_entities["degree"] = final_entities["degree"].fillna(0).astype(int) final_entities.reset_index(inplace=True) final_entities["human_readable_id"] = final_entities.index - final_entities["id"] = final_entities["human_readable_id"].apply( - lambda _x: str(uuid4()) - ) + + # Generate id if id is empty + if "id" not in final_entities.columns or final_entities["id"].isna().all(): + final_entities["id"] = final_entities["human_readable_id"].apply( + lambda _x: str(uuid4()) + ) return final_entities.loc[ :, ENTITIES_FINAL_COLUMNS, diff --git a/graphrag/index/operations/finalize_relationships.py b/graphrag/index/operations/finalize_relationships.py index 21ba413667..163a629fe9 100644 --- a/graphrag/index/operations/finalize_relationships.py +++ b/graphrag/index/operations/finalize_relationships.py @@ -34,9 +34,12 @@ def finalize_relationships( final_relationships.reset_index(inplace=True) final_relationships["human_readable_id"] = final_relationships.index - final_relationships["id"] = final_relationships["human_readable_id"].apply( - lambda _x: str(uuid4()) - ) + + # Generate id if there is no id + if "id" not in final_relationships.columns or final_relationships["id"].isna().all(): + final_relationships["id"] = final_relationships["human_readable_id"].apply( + lambda _x: str(uuid4()) + ) return final_relationships.loc[ :, diff --git a/graphrag/index/workflows/create_communities.py b/graphrag/index/workflows/create_communities.py index c06d5f4b28..ed0542821b 100644 --- a/graphrag/index/workflows/create_communities.py +++ b/graphrag/index/workflows/create_communities.py @@ -125,7 +125,11 @@ def create_communities( # join it all up and add some new fields final_communities = all_grouped.merge(entity_ids, on="community", how="inner") - final_communities["id"] = [str(uuid4()) for _ in range(len(final_communities))] + + # Generate id if there is no id + if "id" not in final_communities.columns or final_communities["id"].isna().all(): + final_communities["id"] = [str(uuid4()) for _ in range(len(final_communities))] + final_communities["human_readable_id"] = final_communities["community"] final_communities["title"] = "Community " + final_communities["community"].astype( str From 
fc716ba0131f90e25280270eeb83315b0b54195c Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Fri, 22 Aug 2025 16:17:06 -0700 Subject: [PATCH 04/12] removed unnessary sanitize func --- graphrag/storage/postgres_pipeline_storage.py | 60 +------------------ 1 file changed, 2 insertions(+), 58 deletions(-) diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py index df337d5164..0a265d8316 100644 --- a/graphrag/storage/postgres_pipeline_storage.py +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -111,46 +111,12 @@ async def _release_connection(self, conn: Connection) -> None: if self._pool: await self._pool.release(conn) - def _sanitize_table_name(self, name: str) -> str: - """Sanitize a name to be a valid PostgreSQL table name.""" - return name - import re - - # Replace common problematic characters - sanitized = name.replace("-", "_").replace(":", "_").replace(" ", "_") - - # Remove any characters that aren't alphanumeric, underscore, or dollar sign - sanitized = re.sub(r'[^a-zA-Z0-9_$]', '_', sanitized) - - # Remove consecutive underscores - sanitized = re.sub(r'_+', '_', sanitized) - - # Remove leading/trailing underscores - sanitized = sanitized.strip('_') - - # Ensure it starts with a letter or underscore (not a digit) - if sanitized and sanitized[0].isdigit(): - sanitized = f"tbl_{sanitized}" - - # Ensure it's not empty - if not sanitized: - sanitized = "tbl_unnamed" - - # PostgreSQL has a limit of 63 characters for identifiers - if len(sanitized) > 59: # Leave room for prefix - sanitized = sanitized[:59] - log.info(f"Sanitied name {name} to {sanitized}") - return sanitized - def _get_table_name(self, key: str) -> str: """Get the table name for a given key.""" # Extract the base name without file extension base_name = key.split(".")[0] - # Sanitize for PostgreSQL compatibility - sanitized_name = self._sanitize_table_name(base_name) - - return f"{self._collection_prefix}{sanitized_name}" + return f"{self._collection_prefix}{base_name}" def _get_prefix(self, key: str) -> str: """Get the prefix of the filename key.""" @@ -158,8 +124,6 @@ def _get_prefix(self, key: str) -> str: def _get_entities_table_schema(self, table_name: str) -> str: """Get the SQL schema for entities table.""" - # Sanitize table name for index names - table_name = self._sanitize_table_name(table_name) return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, @@ -185,8 +149,6 @@ def _get_entities_table_schema(self, table_name: str) -> str: def _get_relationships_table_schema(self, table_name: str) -> str: """Get the SQL schema for relationships table.""" - # Sanitize table name for index names - table_name = self._sanitize_table_name(table_name) return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, @@ -211,8 +173,6 @@ def _get_relationships_table_schema(self, table_name: str) -> str: def _get_communities_table_schema(self, table_name: str) -> str: """Get the SQL schema for communities table.""" - # Sanitize table name for index names - table_name = self._sanitize_table_name(table_name) return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, @@ -239,8 +199,6 @@ def _get_communities_table_schema(self, table_name: str) -> str: def _get_text_units_table_schema(self, table_name: str) -> str: """Get the SQL schema for text_units table.""" - # Sanitize table name for index names - table_name = self._sanitize_table_name(table_name) return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, @@ -264,8 +222,6 @@ def 
_get_text_units_table_schema(self, table_name: str) -> str: def _get_documents_table_schema(self, table_name: str) -> str: """Get the SQL schema for documents table.""" - # Sanitize table name for index names - table_name = self._sanitize_table_name(table_name) return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, @@ -289,8 +245,6 @@ def _get_documents_table_schema(self, table_name: str) -> str: def _get_generic_table_schema(self, table_name: str) -> str: """Get the SQL schema for generic data (fallback).""" - # Sanitize table name for index names - table_name = self._sanitize_table_name(table_name) return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, @@ -307,8 +261,6 @@ def _get_generic_table_schema(self, table_name: str) -> str: def _get_table_schema_sql(self, table_name: str) -> str: """Get the appropriate schema SQL for the table type.""" - # Sanitize table name for schema generation - table_name = self._sanitize_table_name(table_name) if 'entities' in table_name: return self._get_entities_table_schema(table_name) @@ -324,9 +276,6 @@ def _get_table_schema_sql(self, table_name: str) -> str: return self._get_generic_table_schema(table_name) async def _ensure_table_exists_with_schema(self, table_name: str) -> None: - # Ensure table name is properly sanitized for SQL operations - table_name = self._sanitize_table_name(table_name) - conn = await self._get_connection() try: table_exists = await conn.fetchval( @@ -664,8 +613,6 @@ async def _batch_upsert_records(self, conn: Connection, table_name: str, records async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: """Batch upsert for typed tables with specific columns.""" async with conn.transaction(): - # Ensure table name is properly sanitized for SQL - table_name = self._sanitize_table_name(table_name) if 'entities' in table_name: upsert_sql = f""" @@ -779,8 +726,6 @@ async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, b async def _batch_upsert_generic_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: """Batch upsert for generic tables using JSONB.""" async with conn.transaction(): - # Ensure table name is properly sanitized for SQL - table_name = self._sanitize_table_name(table_name) upsert_sql = f""" INSERT INTO {table_name} (id, data, metadata, updated_at) VALUES ($1, $2, $3, NOW()) @@ -972,11 +917,10 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None # Use the storage ID as is since we simplified ID handling storage_id = row['id'] cleaned_data['id'] = storage_id - records.append(cleaned_data) df = pd.DataFrame(records) - + # Additional cleanup for NaN values in the DataFrame df = df.where(pd.notna(df), None) log.info(f"Created DataFrame with shape: {df.shape}") From 93cb87ef31c74ef9acfb9c44900a2b7bbb6f850a Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Fri, 22 Aug 2025 16:32:03 -0700 Subject: [PATCH 05/12] simplified postgres schema --- graphrag/storage/postgres_pipeline_storage.py | 1098 +++------------ .../storage/postgres_pipeline_storage2.py | 1237 +++++++++++++++++ 2 files changed, 1399 insertions(+), 936 deletions(-) create mode 100644 graphrag/storage/postgres_pipeline_storage2.py diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py index 0a265d8316..2bb8971558 100644 --- a/graphrag/storage/postgres_pipeline_storage.py +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -1,20 +1,18 @@ # Copyright (c) 2024 
Microsoft Corporation. # Licensed under the MIT License -"""PostgreSQL Storage implementation of PipelineStorage.""" +"""PostgreSQL Storage implementation of PipelineStorage""" import json import logging import re from collections.abc import Iterator -from datetime import datetime, timezone from io import BytesIO from typing import Any -import numpy as np import pandas as pd import asyncpg -from asyncpg import Connection, Pool +from asyncpg import Connection from graphrag.storage.pipeline_storage import ( PipelineStorage, @@ -24,14 +22,7 @@ log = logging.getLogger(__name__) class PostgresPipelineStorage(PipelineStorage): - """The PostgreSQL Storage Implementation.""" - - _pool: Pool | None - _connection_string: str - _database: str - _collection_prefix: str - _encoding: str - _no_id_prefixes: list[str] + """Simplified PostgreSQL Storage Implementation.""" def __init__( self, @@ -43,10 +34,9 @@ def __init__( collection_prefix: str = "lgr_", encoding: str = "utf-8", connection_string: str | None = None, - command_timeout: int = 600, # 10 minutes for SQL commands - server_timeout: int = 120, # 2 minutes for server connection - connection_timeout: int = 60, # 1 minute to establish connection - batch_size: int = 50, # Smaller batch size to reduce timeout risk + command_timeout: int = 600, + connection_timeout: int = 60, + batch_size: int = 50, **kwargs: Any, ): """Initialize the PostgreSQL Storage.""" @@ -58,13 +48,10 @@ def __init__( self._collection_prefix = collection_prefix self._encoding = encoding self._command_timeout = command_timeout - self._server_timeout = server_timeout self._connection_timeout = connection_timeout self._batch_size = batch_size - self._no_id_prefixes = [] self._pool = None - # Build connection string from components or use provided one if connection_string: self._connection_string = connection_string else: @@ -74,13 +61,8 @@ def __init__( self._connection_string = f"postgresql://{username}@{host}:{port}/{database}" log.info( - "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s, command_timeout: %s, batch_size: %s", - self._host, - self._port, - self._database, - self._collection_prefix, - self._command_timeout, - self._batch_size, + "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s", + self._host, self._port, self._database, self._collection_prefix ) async def _get_connection(self) -> Connection: @@ -92,14 +74,10 @@ async def _get_connection(self) -> Connection: min_size=1, max_size=10, command_timeout=self._command_timeout, - server_settings={ - 'application_name': 'graphrag_postgres_storage' - }, - # Use connection_timeout for initial connection establishment + server_settings={'application_name': 'graphrag_postgres_storage'}, timeout=self._connection_timeout ) - log.info("Created PostgreSQL connection pool with command_timeout: %s, connection_timeout: %s", - self._command_timeout, self._connection_timeout) + log.info("Created PostgreSQL connection pool") except Exception as e: log.error("Failed to create PostgreSQL connection pool: %s", e) raise @@ -113,169 +91,26 @@ async def _release_connection(self, conn: Connection) -> None: def _get_table_name(self, key: str) -> str: """Get the table name for a given key.""" - # Extract the base name without file extension base_name = key.split(".")[0] - return f"{self._collection_prefix}{base_name}" - def _get_prefix(self, key: str) -> str: - """Get the prefix of the filename key.""" - return key.split(".")[0] - - def _get_entities_table_schema(self, 
table_name: str) -> str: - """Get the SQL schema for entities table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - title TEXT, - type TEXT, - description TEXT, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - frequency INTEGER DEFAULT 0, - degree INTEGER DEFAULT 0, - x DOUBLE PRECISION DEFAULT 0.0, - y DOUBLE PRECISION DEFAULT 0.0, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Performance indexes - CREATE INDEX idx_{table_name}_type ON {table_name}(type); - CREATE INDEX idx_{table_name}_frequency ON {table_name}(frequency); - CREATE INDEX idx_{table_name}_title ON {table_name}(title); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - """ - - def _get_relationships_table_schema(self, table_name: str) -> str: - """Get the SQL schema for relationships table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - source TEXT NOT NULL, - target TEXT NOT NULL, - description TEXT DEFAULT '', - weight DOUBLE PRECISION DEFAULT 0.0, - combined_degree INTEGER DEFAULT 0, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Graph query indexes - CREATE INDEX idx_{table_name}_source ON {table_name}(source); - CREATE INDEX idx_{table_name}_target ON {table_name}(target); - CREATE INDEX idx_{table_name}_weight ON {table_name}(weight); - CREATE INDEX idx_{table_name}_source_target ON {table_name}(source, target); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - """ - - def _get_communities_table_schema(self, table_name: str) -> str: - """Get the SQL schema for communities table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - community INTEGER, - level INTEGER DEFAULT 0, - parent INTEGER, - children JSONB DEFAULT '[]'::jsonb, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - entity_ids JSONB DEFAULT '[]'::jsonb, - relationship_ids JSONB DEFAULT '[]'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Community hierarchy indexes - CREATE INDEX idx_{table_name}_community ON {table_name}(community); - CREATE INDEX idx_{table_name}_level ON {table_name}(level); - CREATE INDEX idx_{table_name}_parent ON {table_name}(parent); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); - CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); - """ - - def _get_text_units_table_schema(self, table_name: str) -> str: - """Get the SQL schema for text_units table.""" + def _get_universal_table_schema(self, table_name: str) -> str: + """Universal schema that works for all GraphRAG data types.""" return f""" CREATE TABLE {table_name} ( id TEXT PRIMARY KEY, human_readable_id BIGINT, - text TEXT, - n_tokens INTEGER DEFAULT 0, - document_ids JSONB DEFAULT '[]'::jsonb, - entity_ids JSONB DEFAULT '[]'::jsonb, - relationship_ids JSONB DEFAULT '[]'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Text search and relationship indexes - CREATE INDEX idx_{table_name}_n_tokens ON {table_name}(n_tokens); - CREATE INDEX 
idx_{table_name}_text_gin ON {table_name} USING GIN(to_tsvector('english', text)); - CREATE INDEX idx_{table_name}_document_ids_gin ON {table_name} USING GIN(document_ids); - CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); - CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); - """ - - def _get_documents_table_schema(self, table_name: str) -> str: - """Get the SQL schema for documents table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - title TEXT, - text TEXT, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - creation_date TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - metadata JSONB DEFAULT '{{}}'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Document search indexes - CREATE INDEX idx_{table_name}_title ON {table_name}(title); - CREATE INDEX idx_{table_name}_creation_date ON {table_name}(creation_date); - CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING GIN(to_tsvector('english', text)); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); - """ - - def _get_generic_table_schema(self, table_name: str) -> str: - """Get the SQL schema for generic data (fallback).""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, data JSONB NOT NULL, - metadata JSONB DEFAULT '{{}}'::jsonb, created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() ); - -- Generic indexes - CREATE INDEX idx_{table_name}_data_gin ON {table_name} USING GIN(data); - CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); + CREATE INDEX IF NOT EXISTS idx_{table_name}_data_gin ON {table_name} USING GIN(data); + CREATE INDEX IF NOT EXISTS idx_{table_name}_created_at ON {table_name}(created_at); """ - def _get_table_schema_sql(self, table_name: str) -> str: - """Get the appropriate schema SQL for the table type.""" - - if 'entities' in table_name: - return self._get_entities_table_schema(table_name) - elif 'relationships' in table_name: - return self._get_relationships_table_schema(table_name) - elif 'communities' in table_name: - return self._get_communities_table_schema(table_name) - elif 'text_units' in table_name: - return self._get_text_units_table_schema(table_name) - elif 'documents' in table_name: - return self._get_documents_table_schema(table_name) - else: - return self._get_generic_table_schema(table_name) - - async def _ensure_table_exists_with_schema(self, table_name: str) -> None: + async def _ensure_table_exists(self, table_name: str) -> None: + """Ensure table exists with universal schema.""" conn = await self._get_connection() try: table_exists = await conn.fetchval( @@ -283,485 +118,105 @@ async def _ensure_table_exists_with_schema(self, table_name: str) -> None: table_name ) if not table_exists: - # Create table with appropriate typed schema (pass original table_name for type detection) - schema_sql = self._get_table_schema_sql(table_name) + schema_sql = self._get_universal_table_schema(table_name) await conn.execute(schema_sql) - log.info(f"Created table {table_name} with specific schema") - + log.info(f"Created table {table_name}") finally: await self._release_connection(conn) - def _process_id_field(self, df: pd.DataFrame, table_name: str) -> list[str]: - """Process ID values - store clean IDs 
with prefix following CosmosDB pattern in GraphRAG.""" - prefix = self._get_prefix(table_name.replace(self._collection_prefix, "")) - id_values = [] - - if "id" not in df.columns: - # No ID column - create prefixed sequential IDs and track this prefix - for index in range(len(df)): - id_values.append(f"{prefix}:{index}") - if prefix not in self._no_id_prefixes: - self._no_id_prefixes.append(prefix) - log.info(f"No ID column found for {prefix}, generated prefixed sequential IDs") - else: - # Has ID column - process each row with prefix - for index, val in enumerate(df["id"]): - if self._is_scalar_na(val) or val == '' or val == 'nan' or (isinstance(val, list) and (len(val) == 0 or self._is_scalar_na(val[0]) or str(val[0]).strip() == '')): - # Missing ID - create prefixed sequential ID and track this prefix - id_values.append(f"{prefix}:{index}") - if prefix not in self._no_id_prefixes: - self._no_id_prefixes.append(prefix) - else: - # Valid ID - use as is without prefix - if isinstance(val, list): - id_values.append(str(val[0])) - else: - id_values.append(str(val)) - - return id_values - - def _is_scalar_na(self, value: Any) -> bool: - """Safely check if a value is NA/null, avoiding issues with arrays.""" - try: - # Don't check pd.isna on complex objects or large arrays - if isinstance(value, (list, dict)): - return False - if hasattr(value, '__len__') and len(str(value)) > 100: - return False - return pd.isna(value) - except (ValueError, TypeError): - # If pd.isna fails, assume it's not NA - return False - - def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[dict]: - """Prepare DataFrame data for PostgreSQL insertion with typed columns.""" - log.info(f"Preparing data for table {table_name}, DataFrame shape: {df.shape}") - log.info(f"DataFrame columns: {df.columns.tolist()}") - - # Add human_readable_id if missing - if 'human_readable_id' not in df.columns: - df = df.copy() - df['human_readable_id'] = range(len(df)) - log.info(f"Generated sequential human_readable_id for {len(df)} records") - - # Process IDs - for typed tables, we can use simpler ID handling - ids = self._process_id_field(df, table_name) - - # Determine if this is a typed table or generic table - is_typed_table = any(table_type in table_name for table_type in - ['entities', 'relationships', 'communities', 'text_units', 'documents']) - - if is_typed_table: - return self._prepare_data_for_typed_table(df, table_name, ids) - else: - return self._prepare_data_for_generic_table(df, table_name, ids) - - def _prepare_data_for_typed_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: - """Prepare data for typed PostgreSQL tables with specific columns.""" + def _prepare_dataframe_for_storage(self, df: pd.DataFrame) -> list[dict]: + """Convert DataFrame to records for storage.""" records = [] - for i in range(len(df)): - record = {'id': ids[i]} - row = df.iloc[i] - - # Map DataFrame columns to table columns based on table type - if 'entities' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'title': str(row.get('title', '')), - 'type': str(row.get('type', '')), - 'description': str(row.get('description', '')), - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), - 'frequency': int(row.get('frequency', 0)) if pd.notna(row.get('frequency', 0)) else 0, - 'degree': int(row.get('degree', 0)) if pd.notna(row.get('degree', 0)) else 0, - 'x': float(row.get('x', 0.0)) if pd.notna(row.get('x', 0.0)) else 0.0, - 'y': 
float(row.get('y', 0.0)) if pd.notna(row.get('y', 0.0)) else 0.0 - }) - elif 'relationships' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'source': str(row.get('source', '')), - 'target': str(row.get('target', '')), - 'description': str(row.get('description', '')), - 'weight': float(row.get('weight', 0.0)) if pd.notna(row.get('weight', 0.0)) else 0.0, - 'combined_degree': int(row.get('combined_degree', 0)) if pd.notna(row.get('combined_degree', 0)) else 0, - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])) - }) - elif 'communities' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'community': int(row.get('community', 0)) if pd.notna(row.get('community')) and str(row.get('community', '')).strip() != '' else 0, - 'level': int(row.get('level', 0)) if pd.notna(row.get('level', 0)) else 0, - 'parent': int(row.get('parent', 0)) if pd.notna(row.get('parent')) and str(row.get('parent', '')).strip() != '' else None, - 'children': self._ensure_json_list(row.get('children', [])), - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), - 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), - 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) - }) - elif 'text_units' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'text': str(row.get('text', '')), - 'n_tokens': int(row.get('n_tokens', 0)) if pd.notna(row.get('n_tokens', 0)) else 0, - 'document_ids': self._ensure_json_list(row.get('document_ids', [])), - 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), - 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) - }) - elif 'documents' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'title': str(row.get('title', '')), - 'text': str(row.get('text', '')), - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), - 'creation_date': self._ensure_datetime(row.get('creation_date')), - 'metadata': self._ensure_json_dict(row.get('metadata', {})) - }) - - records.append(record) - - log.info(f"Prepared {len(records)} records for typed table {table_name}") - if records: - log.info(f"Sample typed record: {list(records[0].keys())}") - - return records - - def _prepare_data_for_generic_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: - """Prepare data for generic PostgreSQL tables (fallback to JSONB storage).""" - records = [] - for i in range(len(df)): - # Create record with ID and all data in JSONB field - record_data = df.iloc[i].to_dict() + for i, row in df.iterrows(): + record_data = row.to_dict() - # Convert numpy types to native Python types for JSON serialization + # Convert pandas/numpy types to JSON-serializable types for key, value in record_data.items(): - if isinstance(value, (list, dict)): + if pd.isna(value): + record_data[key] = None + elif isinstance(value, (list, dict)): record_data[key] = value elif hasattr(value, 'tolist'): - # Handle numpy arrays and other numpy types record_data[key] = value.tolist() elif hasattr(value, 'item') and hasattr(value, 'size') and value.size == 1: record_data[key] = value.item() elif isinstance(value, (pd.Timestamp, pd.DatetimeTZDtype)): record_data[key] = value.isoformat() if pd.notna(value) else None - elif self._is_scalar_na(value): - record_data[key] = None - elif isinstance(value, (list, np.ndarray)) and len(value) == 0: - 
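# Standalone sketch of the per-cell conversion performed when a DataFrame row is
# flattened into a JSONB `data` payload, as in _prepare_dataframe_for_storage. The
# helper name `to_jsonable` is an assumption for illustration; list/array values are
# handled before the NaN check so that pd.isna() is only ever applied to scalars.
import numpy as np
import pandas as pd

def to_jsonable(value):
    """Convert a single DataFrame cell to a JSON-serializable value."""
    if isinstance(value, (list, dict)):
        return value
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    if value is None or pd.isna(value):
        return None
    return value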
record_data[key] = [] else: record_data[key] = value - - record = { - 'id': ids[i], - 'data': record_data, - 'metadata': {} - } - records.append(record) - - log.info(f"Prepared {len(records)} records for generic table {table_name}") - return records - def _ensure_json_list(self, value: Any) -> list: - """Ensure a value is a proper list for JSONB storage.""" - if isinstance(value, list): - # Convert any numpy arrays in the list to regular Python lists - return [item.tolist() if hasattr(item, 'tolist') else item for item in value] - elif hasattr(value, 'tolist'): - # Handle numpy arrays directly - converted = value.tolist() - return converted if isinstance(converted, list) else [converted] - elif isinstance(value, str) and value: - try: - parsed = json.loads(value) - return parsed if isinstance(parsed, list) else [] - except (json.JSONDecodeError, TypeError): - return [] - elif value is None or pd.isna(value): - return [] - else: - return [value] if value else [] - - def _ensure_json_dict(self, value: Any) -> dict: - """Ensure a value is a proper dict for JSONB storage.""" - if isinstance(value, dict): - # Convert any numpy arrays in the dict to regular Python objects - result = {} - for k, v in value.items(): - if hasattr(v, 'tolist'): - result[k] = v.tolist() - elif hasattr(v, 'item') and hasattr(v, 'size') and v.size == 1: - result[k] = v.item() - else: - result[k] = v - return result - elif isinstance(value, str) and value: - try: - parsed = json.loads(value) - return parsed if isinstance(parsed, dict) else {} - except (json.JSONDecodeError, TypeError): - return {} - elif value is None or pd.isna(value): - return {} - else: - return {'value': str(value)} if value else {} - - def _ensure_timezone_aware_datetimes(self, records: list[dict]) -> list[dict]: - """Ensure all datetime fields in records are timezone-aware for PostgreSQL.""" - datetime_fields = ['creation_date', 'created_at', 'updated_at'] - - for record in records: - for field in datetime_fields: - if field in record: - value = record[field] - if value is not None: - record[field] = self._ensure_datetime(value) + # Extract ID and human_readable_id + record_id = record_data.pop('id', f"record_{i}") + human_readable_id = record_data.pop('human_readable_id', i) + + records.append({ + 'id': str(record_id), + 'human_readable_id': int(human_readable_id) if pd.notna(human_readable_id) else i, + 'data': record_data + }) return records - def _ensure_datetime(self, value: Any) -> datetime: - """Ensure a value is a proper timezone-aware datetime object for PostgreSQL storage.""" - from dateutil import parser - - if isinstance(value, datetime): - # If it's already a datetime, ensure it has timezone info - if value.tzinfo is None: - # If it's timezone-naive, localize to UTC - return value.replace(tzinfo=timezone.utc) - else: - # Already timezone-aware - return value - elif isinstance(value, pd.Timestamp): - # Convert pandas Timestamp to datetime - dt = value.to_pydatetime() - # Ensure timezone awareness - if dt.tzinfo is None: - return dt.replace(tzinfo=timezone.utc) - else: - return dt - elif isinstance(value, str) and value: - try: - # Try to parse the string as a datetime - parsed_dt = parser.parse(value) - # Ensure timezone awareness - if parsed_dt.tzinfo is None: - return parsed_dt.replace(tzinfo=timezone.utc) - else: - return parsed_dt - except (ValueError, TypeError): - # If parsing fails, return current time - return datetime.now(timezone.utc) - elif value is None or pd.isna(value): - return datetime.now(timezone.utc) - else: - # For 
any other type, return current time - return datetime.now(timezone.utc) - - async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict]) -> None: - """Perform high-performance batch upsert of records using executemany.""" + async def _batch_upsert(self, conn: Connection, table_name: str, records: list[dict]) -> None: + """Perform batch upsert of records.""" total_records = len(records) - log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {self._batch_size}") - - # Ensure all datetime fields are timezone-aware - records = self._ensure_timezone_aware_datetimes(records) - - processed_count = 0 + log.info(f"Starting batch upsert of {total_records} records to {table_name}") - # Determine if this is a typed table or generic table - is_typed_table = any(table_type in table_name for table_type in - ['entities', 'relationships', 'communities', 'text_units', 'documents']) + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, data, updated_at) + VALUES ($1, $2, $3, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + data = EXCLUDED.data, + updated_at = NOW() + """ - # Process records in batches for optimal performance + # Process in batches for i in range(0, total_records, self._batch_size): batch = records[i:i + self._batch_size] - batch_end = min(i + self._batch_size, total_records) - - try: - if is_typed_table: - await self._batch_upsert_typed_records(conn, table_name, batch) - else: - await self._batch_upsert_generic_records(conn, table_name, batch) - - except Exception as e: - log.warning(f"Batch method failed for batch {i}-{batch_end}, falling back to individual inserts: {e}") - - # Fallback to individual inserts within the batch - try: - async with conn.transaction(): - if is_typed_table: - for record in batch: - await self._insert_typed_record(conn, table_name, record) - else: - upsert_sql = f""" - INSERT INTO {table_name} (id, data, updated_at) - VALUES ($1, $2, NOW()) - ON CONFLICT (id) - DO UPDATE SET - data = EXCLUDED.data, - updated_at = NOW() - """ - for record in batch: - await conn.execute(upsert_sql, record['id'], json.dumps(record['data'])) - except Exception as inner_e: - log.error(f"Both batch and individual insert methods failed for batch {i}-{batch_end}: {inner_e}") - raise - - processed_count += len(batch) - - # Log progress every batch for visibility - if i % self._batch_size == 0 or batch_end == total_records: - log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") - - async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: - """Batch upsert for typed tables with specific columns.""" - async with conn.transaction(): - - if 'entities' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - title = EXCLUDED.title, - type = EXCLUDED.type, - description = EXCLUDED.description, - text_unit_ids = EXCLUDED.text_unit_ids, - frequency = EXCLUDED.frequency, - degree = EXCLUDED.degree, - x = EXCLUDED.x, - y = EXCLUDED.y, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['title'], r['type'], r['description'], - json.dumps(r['text_unit_ids']), r['frequency'], 
r['degree'], r['x'], r['y']) - for r in batch - ] - elif 'relationships' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - source = EXCLUDED.source, - target = EXCLUDED.target, - description = EXCLUDED.description, - weight = EXCLUDED.weight, - combined_degree = EXCLUDED.combined_degree, - text_unit_ids = EXCLUDED.text_unit_ids, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['source'], r['target'], r['description'], - r['weight'], r['combined_degree'], json.dumps(r['text_unit_ids'])) - for r in batch - ] - elif 'communities' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - community = EXCLUDED.community, - level = EXCLUDED.level, - parent = EXCLUDED.parent, - children = EXCLUDED.children, - text_unit_ids = EXCLUDED.text_unit_ids, - entity_ids = EXCLUDED.entity_ids, - relationship_ids = EXCLUDED.relationship_ids, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['community'], r['level'], r['parent'], - json.dumps(r['children']), json.dumps(r['text_unit_ids']), - json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) - for r in batch - ] - elif 'text_units' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - text = EXCLUDED.text, - n_tokens = EXCLUDED.n_tokens, - document_ids = EXCLUDED.document_ids, - entity_ids = EXCLUDED.entity_ids, - relationship_ids = EXCLUDED.relationship_ids, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['text'], r['n_tokens'], - json.dumps(r['document_ids']), json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) - for r in batch - ] - elif 'documents' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, title, text, text_unit_ids, creation_date, metadata, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - title = EXCLUDED.title, - text = EXCLUDED.text, - text_unit_ids = EXCLUDED.text_unit_ids, - creation_date = EXCLUDED.creation_date, - metadata = EXCLUDED.metadata, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['title'], r['text'], - json.dumps(r['text_unit_ids']), - self._ensure_datetime(r['creation_date']), - json.dumps(r['metadata'])) - for r in batch - ] - else: - raise ValueError(f"Unknown typed table: {table_name}") - - await conn.executemany(upsert_sql, batch_data) - - async def _batch_upsert_generic_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: - """Batch upsert for generic tables using JSONB.""" - async with conn.transaction(): - upsert_sql = f""" - INSERT INTO {table_name} (id, data, metadata, updated_at) - VALUES ($1, $2, $3, NOW()) - ON CONFLICT (id) - DO UPDATE SET - data = EXCLUDED.data, - metadata 
= EXCLUDED.metadata, - updated_at = NOW() - """ batch_data = [ - (record['id'], json.dumps(record['data']), json.dumps(record['metadata'])) + (record['id'], record['human_readable_id'], json.dumps(record['data'])) for record in batch ] - await conn.executemany(upsert_sql, batch_data) - - async def _insert_typed_record(self, conn: Connection, table_name: str, record: dict) -> None: - """Insert a single typed record (fallback method).""" - # This is a simplified fallback - implement based on table type if needed - # For now, just use the batch method with a single record - await self._batch_upsert_typed_records(conn, table_name, [record]) + + try: + async with conn.transaction(): + await conn.executemany(upsert_sql, batch_data) + + log.info(f"Batch upsert progress: {min(i + self._batch_size, total_records)}/{total_records}") + + except Exception as e: + log.error(f"Batch upsert failed for batch {i}-{min(i + self._batch_size, total_records)}: {e}") + # Fallback to individual inserts + async with conn.transaction(): + for record in batch: + await conn.execute(upsert_sql, record['id'], record['human_readable_id'], json.dumps(record['data'])) + + def _parse_jsonb_field(self, value: Any, default_type: str = "list") -> Any: + """Parse JSONB field back to Python object.""" + if value is None: + return {} if default_type == "dict" else [] + if isinstance(value, (list, dict)): + return value + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return {} if default_type == "dict" else [] + return {} if default_type == "dict" else [] - def find( - self, - file_pattern: re.Pattern[str], - base_dir: str | None = None, - file_filter: dict[str, Any] | None = None, - max_count=-1, - ) -> Iterator[tuple[str, dict[str, Any]]]: - """Find data in PostgreSQL tables using a file pattern regex.""" - # This is a synchronous method, but we need async operations - # For now, implement a basic version - in practice, this would need refactoring - log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) - - # Note: This is simplified - full implementation would need async/await support - # in the find method signature or use asyncio.run() - return iter([]) + def _convert_dataframe_to_parquet_bytes(self, df: pd.DataFrame) -> bytes: + """Convert DataFrame to parquet bytes.""" + try: + buffer = BytesIO() + df.to_parquet(buffer, engine='pyarrow', index=False) + buffer.seek(0) + return buffer.getvalue() + except Exception as e: + log.error(f"Failed to convert DataFrame to parquet bytes: {e}") + return b"" async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: """Retrieve data from PostgreSQL table.""" @@ -780,31 +235,11 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None if not table_exists: log.warning(f"Table {table_name} does not exist") return None - - # Determine if this is a typed table or generic table - is_typed_table = any(table_type in table_name for table_type in - ['entities', 'relationships', 'communities', 'text_units', 'documents']) - - if is_typed_table: - # For typed tables, select all columns except created_at/updated_at - if 'documents' in table_name: - query = "SELECT id, human_readable_id, title, text, text_unit_ids, creation_date, metadata FROM {} ORDER BY created_at".format(table_name) - elif 'entities' in table_name: - query = "SELECT id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y FROM {} ORDER BY 
created_at".format(table_name) - elif 'relationships' in table_name: - query = "SELECT id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids FROM {} ORDER BY created_at".format(table_name) - elif 'communities' in table_name: - query = "SELECT id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) - elif 'text_units' in table_name: - query = "SELECT id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) - else: - # Fallback for unknown typed table - query = "SELECT * FROM {} ORDER BY created_at".format(table_name) - else: - # For generic tables, use the data column - query = "SELECT id, data FROM {} ORDER BY created_at".format(table_name) - - rows = await conn.fetch(query) + + # Get all records + rows = await conn.fetch( + f"SELECT id, human_readable_id, data FROM {table_name} ORDER BY created_at" + ) if not rows: log.info(f"No data found in table {table_name}") @@ -812,242 +247,48 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None log.info(f"Retrieved {len(rows)} records from table {table_name}") - # Check if this should be treated as raw data instead of tabular data - if (not key.endswith('.parquet') or - 'state' in key.lower() or - key.endswith('.json') or - 'context' in table_name.lower()): - # Handle state.json or context.json as raw data - # For non-tabular data, return the raw content from the first record + # Handle non-parquet data (JSON/state files) + if not key.endswith('.parquet') or 'state' in key.lower() or 'context' in key.lower(): if rows: - if is_typed_table: - # For typed tables, convert row to dict and return as JSON - row_dict = dict(rows[0]) - json_str = json.dumps(row_dict) - return json_str.encode(encoding or self._encoding) if as_bytes else json_str - elif 'data' in rows[0]: - raw_content = rows[0]['data'] - if isinstance(raw_content, dict): - json_str = json.dumps(raw_content) - return json_str.encode(encoding or self._encoding) if as_bytes else json_str + first_record_data = rows[0]['data'] + if isinstance(first_record_data, dict): + json_str = json.dumps(first_record_data) + else: + json_str = json.dumps(first_record_data) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str return b"" if as_bytes else "" # Convert to DataFrame records = [] for row in rows: - if is_typed_table: - # For typed tables, the row is already the data we need - record_data = dict(row) - - # Convert JSONB fields back to proper Python objects - for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: - if field in record_data: - value = record_data[field] - - if value is None: - record_data[field] = {} if field == 'metadata' else [] - elif isinstance(value, str): - # Handle JSONB strings - they should always be valid JSON - try: - parsed = json.loads(value) - # Validate the parsed type - if field == 'metadata': - record_data[field] = parsed if isinstance(parsed, dict) else {} - else: - record_data[field] = parsed if isinstance(parsed, list) else [] - except (json.JSONDecodeError, TypeError): - log.warning(f"Failed to parse JSONB field {field}: {value}") - # Fallback for non-JSON strings - if field == 'metadata': - record_data[field] = {} - else: - record_data[field] = [] - elif isinstance(value, (list, dict)): - # Already correct type (shouldn't happen with JSONB, but handle it) - 
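# Minimal sketch of the tolerant JSONB decoding applied when rows are read back in
# get(): asyncpg may return JSONB values as JSON text or as already-decoded Python
# objects depending on configured codecs, so both forms are accepted. The helper name
# `decode_jsonb` is an assumption for illustration only.
import json
from typing import Any

def decode_jsonb(value: Any, default: Any) -> Any:
    """Return a Python object for a JSONB column value, falling back to `default`."""
    if value is None:
        return default
    if isinstance(value, (list, dict)):
        return value
    try:
        return json.loads(value)
    except (TypeError, json.JSONDecodeError):
        return default

# e.g. decode_jsonb('["tu-1", "tu-2"]', []) -> ["tu-1", "tu-2"]; decode_jsonb(None, []) -> []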
record_data[field] = value - else: - # Convert other types - if field == 'metadata': - record_data[field] = {'value': str(value)} if value else {} - else: - record_data[field] = [value] if value else [] - else: - # Handle generic table data (JSONB data column) - if isinstance(row['data'], dict): - record_data = dict(row['data']) - else: - # If it's a string, parse it as JSON - record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] + record_data = dict(row['data']) if isinstance(row['data'], dict) else json.loads(row['data']) - # Clean up the record data - convert None to proper values and handle NaN - cleaned_data = {} - for key_name, value in record_data.items(): - if self._is_scalar_na(value) or value is None: - cleaned_data[key_name] = None - elif isinstance(value, str) and key_name == 'text_unit_ids' and not is_typed_table: - # Try to parse text_unit_ids back from JSON string if needed (only for generic tables) - try: - parsed_value = json.loads(value) - cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [value] - except (json.JSONDecodeError, TypeError): - # If it's not JSON, treat as a single item list or keep as string - cleaned_data[key_name] = [value] if value else [] - elif key_name in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids'] and not is_typed_table: - # Always ensure these columns are lists (only for generic tables - typed tables already handled this) - if isinstance(value, list): - cleaned_data[key_name] = value - elif isinstance(value, str): - try: - parsed_value = json.loads(value) - cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [] - except (json.JSONDecodeError, TypeError): - cleaned_data[key_name] = [] - elif value is None: - cleaned_data[key_name] = [] - else: - # fallback: wrap single value in a list - cleaned_data[key_name] = [value] - elif isinstance(value, (list, np.ndarray)) and len(value) == 0: - # Handle empty arrays/lists - cleaned_data[key_name] = [] - else: - cleaned_data[key_name] = value + # Add back the ID and human_readable_id + record_data['id'] = row['id'] + record_data['human_readable_id'] = row['human_readable_id'] + + # Parse JSONB list fields back to proper Python lists + for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids']: + if field in record_data: + record_data[field] = self._parse_jsonb_field(record_data[field], "list") + + # Parse metadata as dict + if 'metadata' in record_data: + record_data['metadata'] = self._parse_jsonb_field(record_data['metadata'], "dict") - # Always include the ID column for GraphRAG compatibility - # Use the storage ID as is since we simplified ID handling - storage_id = row['id'] - cleaned_data['id'] = storage_id - records.append(cleaned_data) + records.append(record_data) df = pd.DataFrame(records) - - # Additional cleanup for NaN values in the DataFrame + + # Clean up NaN values df = df.where(pd.notna(df), None) + log.info(f"Created DataFrame with shape: {df.shape}") - log.info(f"Table {table_name} DataFrame columns: {df.columns.tolist()}") + log.info(f"DataFrame columns: {df.columns.tolist()}") - if len(df) > 0: - log.debug(f"Table {table_name} Sample record: {df.iloc[0].to_dict()}") - # Debug: Check if children column exists and its type - if 'children' in df.columns: - sample_children = df.iloc[0]['children'] - log.debug(f"Table {table_name} Sample children value: {sample_children}, type: {type(sample_children)}") - - # Handle bytes conversion for 
GraphRAG compatibility + # Convert to bytes if requested if as_bytes or kwargs.get("as_bytes"): - log.info(f"Converting DataFrame to parquet bytes for key: {key}") - - # Apply column filtering similar to Milvus implementation - df_clean = df.copy() - - # Define expected columns for each data type - if 'documents' in table_name: - expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] - elif 'entities' in table_name: - expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'] - elif 'relationships' in table_name: - expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] - if 'combined_degree' in df_clean.columns: - expected_columns.append('combined_degree') - elif 'text_units' in table_name: - expected_columns = ['id', 'human_readable_id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] - elif 'communities' in table_name: - expected_columns = ['id', 'human_readable_id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] - else: - expected_columns = list(df_clean.columns) - - # Filter columns - available_columns = [col for col in expected_columns if col in df_clean.columns] - if available_columns != expected_columns: - missing = set(expected_columns) - set(available_columns) - extra = set(df_clean.columns) - set(expected_columns) - log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") - - df_clean = df_clean[available_columns] - log.info(f"Table {table_name} final filtered columns: {df_clean.columns.tolist()}") - - # Convert to parquet bytes - try: - # Handle list columns for PyArrow compatibility - df_for_parquet = df_clean.copy() - - # For PyArrow/parquet compatibility, we need to handle list columns carefully - # Instead of converting to JSON strings, let's try a different approach - list_columns = [] - for col in df_for_parquet.columns: - if col in df_for_parquet.columns and len(df_for_parquet) > 0: - # Check if this column contains lists - first_non_null = None - for val in df_for_parquet[col]: - if isinstance(val, list): - first_non_null = val - break - elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - first_non_null = val - break - - if isinstance(first_non_null, list) or col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: - list_columns.append(col) - # Ensure all values in this column are proper lists - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: x if isinstance(x, list) else ([] if x is None or pd.isna(x) else [str(x)]) - ) - - if list_columns: - log.info(f"Ensured list columns are proper lists for parquet: {list_columns}") - - # Try to convert to parquet without JSON string conversion - buffer = BytesIO() - df_for_parquet.to_parquet(buffer, engine='pyarrow') - buffer.seek(0) - parquet_bytes = buffer.getvalue() - log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") - return parquet_bytes - except Exception as e: - log.warning(f"Direct parquet conversion failed: {e}, trying with JSON string conversion") - - # Fallback: convert lists to JSON strings - try: - df_for_parquet = df_clean.copy() - - # Convert list columns to JSON strings for parquet compatibility - list_columns = [] - for col in df_for_parquet.columns: - if col in df_for_parquet.columns and 
len(df_for_parquet) > 0: - # Check if this column contains lists - first_non_null = None - for val in df_for_parquet[col]: - if isinstance(val, list): - first_non_null = val - break - elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - first_non_null = val - break - if isinstance(first_non_null, list): - list_columns.append(col) - # Convert lists to JSON strings - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - ) - elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: - # These columns should always be lists, even if empty - list_columns.append(col) - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - ) - - if list_columns: - log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") - - buffer = BytesIO() - df_for_parquet.to_parquet(buffer, engine='pyarrow') - buffer.seek(0) - parquet_bytes = buffer.getvalue() - log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data (with JSON conversion)") - return parquet_bytes - except Exception as e2: - log.exception(f"Failed to convert DataFrame to parquet bytes: {e2}") - return b"" + return self._convert_dataframe_to_parquet_bytes(df) return df @@ -1059,90 +300,70 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: - """Insert data into PostgreSQL table with drop/recreate to avoid duplicates.""" + """Store data in PostgreSQL table.""" try: table_name = self._get_table_name(key) log.info(f"Setting data for key: {key}, table: {table_name}") - # Use new table creation approach with duplicate prevention - await self._ensure_table_exists_with_schema(table_name) + await self._ensure_table_exists(table_name) conn = await self._get_connection() try: if isinstance(value, bytes): - # Parse parquet data + # Handle parquet data df = pd.read_parquet(BytesIO(value)) log.info(f"Parsed parquet data, DataFrame shape: {df.shape}") - # output sample record for debugging - log.debug(f"Table {table_name} Sample record (first row): {df.iloc[0].to_dict()}") - log.info(f"Parsed DataFrame columns: {df.columns.tolist()}") - - # Prepare data for PostgreSQL (typed or generic) - records = self._prepare_data_for_postgres(df, table_name) + records = self._prepare_dataframe_for_storage(df) if records: - # Use batch insert for much better performance - await self._batch_upsert_records(conn, table_name, records) - - log.info(f"Successfully inserted {len(records)} records to {table_name}") - - # Log ID handling info - if any(record['id'].split(':')[0] in self._no_id_prefixes for record in records if 'id' in record): - log.info(f"Some records used auto-generated IDs in table {table_name}") - + await self._batch_upsert(conn, table_name, records) + log.info(f"Successfully stored {len(records)} records to {table_name}") else: - # Handle non-parquet data (e.g., JSON, stats) - always use generic table - log.info(f"Handling non-parquet data for key: {key}") - + # Handle non-parquet data (JSON, etc.) 
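# Usage sketch (assumed calling pattern, not part of the patch): GraphRAG hands set()
# the parquet bytes of a DataFrame, and get() returns a DataFrame for .parquet keys.
# The variable names and the sample row below are illustrative assumptions.
from io import BytesIO
import pandas as pd

async def store_sample_entities(storage) -> None:
    df = pd.DataFrame([{"id": "e-1", "human_readable_id": 0, "title": "Example entity"}])
    buffer = BytesIO()
    df.to_parquet(buffer, index=False)  # requires pyarrow, as used elsewhere in this module
    await storage.set("entities.parquet", buffer.getvalue())
    round_tripped = await storage.get("entities.parquet")  # DataFrame for .parquet keys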
record_data = json.loads(value) if isinstance(value, str) else value - - # Use generic table insertion for non-parquet data - records = [{ + record = { 'id': key, - 'data': record_data, - 'metadata': {'type': 'non_parquet', 'created': datetime.now(timezone.utc).isoformat()} - }] + 'human_readable_id': 0, + 'data': record_data + } - await self._batch_upsert_generic_records(conn, table_name, records) - log.info("Non-parquet data insertion successful") + await conn.execute( + f"""INSERT INTO {table_name} (id, human_readable_id, data, updated_at) + VALUES ($1, $2, $3, NOW()) + ON CONFLICT (id) DO UPDATE SET + data = EXCLUDED.data, updated_at = NOW()""", + record['id'], record['human_readable_id'], json.dumps(record['data']) + ) + log.info(f"Successfully stored non-parquet data for key: {key}") finally: await self._release_connection(conn) except Exception as e: - log.exception("Error setting data in PostgreSQL table %s: %s", table_name, e) + log.exception(f"Error setting data for key {key}: {e}") raise async def has(self, key: str) -> bool: """Check if data exists for the given key.""" try: table_name = self._get_table_name(key) - log.info(f"Checking existence for key: {key}, table: {table_name}") conn = await self._get_connection() try: - # Check if table exists table_exists = await conn.fetchval( "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", table_name ) - log.debug(f"Table {table_name} exists: {table_exists}") if not table_exists: return False if key.endswith('.parquet'): - # For parquet files, check if table has any records - total_count = await conn.fetchval( - f"SELECT COUNT(*) FROM {table_name}" - ) - if total_count > 0: - return True - else: - raise ValueError(f"No records found in table {table_name} for parquet key {key}") + # For parquet files, check if table has records + count = await conn.fetchval(f"SELECT COUNT(*) FROM {table_name}") + return count > 0 else: - # Check for exact key match + # For specific keys, check exact match exists = await conn.fetchval( - f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", - key + f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", key ) return exists @@ -1150,7 +371,7 @@ async def has(self, key: str) -> bool: await self._release_connection(conn) except Exception as e: - log.exception("Error checking existence for key %s: %s", key, e) + log.exception(f"Error checking existence for key {key}: {e}") return False async def delete(self, key: str) -> None: @@ -1160,18 +381,17 @@ async def delete(self, key: str) -> None: conn = await self._get_connection() try: await conn.execute(f"DROP TABLE IF EXISTS {table_name}") - log.info(f"Deleted record for key {key}") + log.info(f"Deleted table for key {key}") finally: await self._release_connection(conn) except Exception as e: - log.exception("Error deleting key %s: %s", key, e) + log.exception(f"Error deleting key {key}: {e}") async def clear(self) -> None: """Clear all tables with the configured prefix.""" try: conn = await self._get_connection() try: - # Get all tables with our prefix tables = await conn.fetch( "SELECT table_name FROM information_schema.tables WHERE table_name LIKE $1", f"{self._collection_prefix}%" @@ -1188,12 +408,10 @@ async def clear(self) -> None: await self._release_connection(conn) except Exception as e: - log.exception("Error clearing tables: %s", e) + log.exception(f"Error clearing tables: {e}") def keys(self) -> list[str]: """Return the keys in the storage.""" - # This would need to be async to properly implement - # For now, 
return empty list log.warning("keys() method not fully implemented for async storage") return [] @@ -1201,6 +419,17 @@ def child(self, name: str | None) -> PipelineStorage: """Create a child storage instance.""" return self + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find data in PostgreSQL tables using a file pattern regex.""" + log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) + return iter([]) + async def get_creation_date(self, key: str) -> str: """Get the creation date for data.""" try: @@ -1208,15 +437,12 @@ async def get_creation_date(self, key: str) -> str: conn = await self._get_connection() try: if key.endswith('.parquet'): - prefix = self._get_prefix(key) created_at = await conn.fetchval( - f"SELECT MIN(created_at) FROM {table_name} WHERE id LIKE $1", - f"{prefix}:%" + f"SELECT MIN(created_at) FROM {table_name}" ) else: created_at = await conn.fetchval( - f"SELECT created_at FROM {table_name} WHERE id = $1", - key + f"SELECT created_at FROM {table_name} WHERE id = $1", key ) if created_at: @@ -1226,7 +452,7 @@ async def get_creation_date(self, key: str) -> str: await self._release_connection(conn) except Exception as e: - log.exception("Error getting creation date for %s: %s", key, e) + log.exception(f"Error getting creation date for {key}: {e}") return "" @@ -1234,4 +460,4 @@ async def close(self) -> None: """Close the connection pool.""" if self._pool: await self._pool.close() - log.info("Closed PostgreSQL connection pool") + log.info("Closed PostgreSQL connection pool") \ No newline at end of file diff --git a/graphrag/storage/postgres_pipeline_storage2.py b/graphrag/storage/postgres_pipeline_storage2.py new file mode 100644 index 0000000000..0a265d8316 --- /dev/null +++ b/graphrag/storage/postgres_pipeline_storage2.py @@ -0,0 +1,1237 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""PostgreSQL Storage implementation of PipelineStorage.""" + +import json +import logging +import re +from collections.abc import Iterator +from datetime import datetime, timezone +from io import BytesIO +from typing import Any + +import numpy as np +import pandas as pd +import asyncpg +from asyncpg import Connection, Pool + +from graphrag.storage.pipeline_storage import ( + PipelineStorage, + get_timestamp_formatted_with_local_tz, +) + +log = logging.getLogger(__name__) + +class PostgresPipelineStorage(PipelineStorage): + """The PostgreSQL Storage Implementation.""" + + _pool: Pool | None + _connection_string: str + _database: str + _collection_prefix: str + _encoding: str + _no_id_prefixes: list[str] + + def __init__( + self, + host: str = "localhost", + port: int = 5432, + database: str = "graphrag", + username: str = "postgres", + password: str | None = None, + collection_prefix: str = "lgr_", + encoding: str = "utf-8", + connection_string: str | None = None, + command_timeout: int = 600, # 10 minutes for SQL commands + server_timeout: int = 120, # 2 minutes for server connection + connection_timeout: int = 60, # 1 minute to establish connection + batch_size: int = 50, # Smaller batch size to reduce timeout risk + **kwargs: Any, + ): + """Initialize the PostgreSQL Storage.""" + self._host = host + self._port = port + self._database = database + self._username = username + self._password = password + self._collection_prefix = collection_prefix + self._encoding = encoding + self._command_timeout = command_timeout + self._server_timeout = server_timeout + self._connection_timeout = connection_timeout + self._batch_size = batch_size + self._no_id_prefixes = [] + self._pool = None + + # Build connection string from components or use provided one + if connection_string: + self._connection_string = connection_string + else: + if password: + self._connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}" + else: + self._connection_string = f"postgresql://{username}@{host}:{port}/{database}" + + log.info( + "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s, command_timeout: %s, batch_size: %s", + self._host, + self._port, + self._database, + self._collection_prefix, + self._command_timeout, + self._batch_size, + ) + + async def _get_connection(self) -> Connection: + """Get a database connection from the pool.""" + if self._pool is None: + try: + self._pool = await asyncpg.create_pool( + self._connection_string, + min_size=1, + max_size=10, + command_timeout=self._command_timeout, + server_settings={ + 'application_name': 'graphrag_postgres_storage' + }, + # Use connection_timeout for initial connection establishment + timeout=self._connection_timeout + ) + log.info("Created PostgreSQL connection pool with command_timeout: %s, connection_timeout: %s", + self._command_timeout, self._connection_timeout) + except Exception as e: + log.error("Failed to create PostgreSQL connection pool: %s", e) + raise + + return await self._pool.acquire() + + async def _release_connection(self, conn: Connection) -> None: + """Release a connection back to the pool.""" + if self._pool: + await self._pool.release(conn) + + def _get_table_name(self, key: str) -> str: + """Get the table name for a given key.""" + # Extract the base name without file extension + base_name = key.split(".")[0] + + return f"{self._collection_prefix}{base_name}" + + def _get_prefix(self, key: str) -> str: + """Get the prefix of the filename 
key.""" + return key.split(".")[0] + + def _get_entities_table_schema(self, table_name: str) -> str: + """Get the SQL schema for entities table.""" + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + title TEXT, + type TEXT, + description TEXT, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + frequency INTEGER DEFAULT 0, + degree INTEGER DEFAULT 0, + x DOUBLE PRECISION DEFAULT 0.0, + y DOUBLE PRECISION DEFAULT 0.0, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Performance indexes + CREATE INDEX idx_{table_name}_type ON {table_name}(type); + CREATE INDEX idx_{table_name}_frequency ON {table_name}(frequency); + CREATE INDEX idx_{table_name}_title ON {table_name}(title); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + """ + + def _get_relationships_table_schema(self, table_name: str) -> str: + """Get the SQL schema for relationships table.""" + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + source TEXT NOT NULL, + target TEXT NOT NULL, + description TEXT DEFAULT '', + weight DOUBLE PRECISION DEFAULT 0.0, + combined_degree INTEGER DEFAULT 0, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Graph query indexes + CREATE INDEX idx_{table_name}_source ON {table_name}(source); + CREATE INDEX idx_{table_name}_target ON {table_name}(target); + CREATE INDEX idx_{table_name}_weight ON {table_name}(weight); + CREATE INDEX idx_{table_name}_source_target ON {table_name}(source, target); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + """ + + def _get_communities_table_schema(self, table_name: str) -> str: + """Get the SQL schema for communities table.""" + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + community INTEGER, + level INTEGER DEFAULT 0, + parent INTEGER, + children JSONB DEFAULT '[]'::jsonb, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + entity_ids JSONB DEFAULT '[]'::jsonb, + relationship_ids JSONB DEFAULT '[]'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Community hierarchy indexes + CREATE INDEX idx_{table_name}_community ON {table_name}(community); + CREATE INDEX idx_{table_name}_level ON {table_name}(level); + CREATE INDEX idx_{table_name}_parent ON {table_name}(parent); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); + CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); + """ + + def _get_text_units_table_schema(self, table_name: str) -> str: + """Get the SQL schema for text_units table.""" + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + text TEXT, + n_tokens INTEGER DEFAULT 0, + document_ids JSONB DEFAULT '[]'::jsonb, + entity_ids JSONB DEFAULT '[]'::jsonb, + relationship_ids JSONB DEFAULT '[]'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Text search and relationship indexes + CREATE INDEX idx_{table_name}_n_tokens ON {table_name}(n_tokens); + CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING 
GIN(to_tsvector('english', text)); + CREATE INDEX idx_{table_name}_document_ids_gin ON {table_name} USING GIN(document_ids); + CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); + CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); + """ + + def _get_documents_table_schema(self, table_name: str) -> str: + """Get the SQL schema for documents table.""" + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + human_readable_id BIGINT, + title TEXT, + text TEXT, + text_unit_ids JSONB DEFAULT '[]'::jsonb, + creation_date TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + metadata JSONB DEFAULT '{{}}'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Document search indexes + CREATE INDEX idx_{table_name}_title ON {table_name}(title); + CREATE INDEX idx_{table_name}_creation_date ON {table_name}(creation_date); + CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING GIN(to_tsvector('english', text)); + CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); + CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); + """ + + def _get_generic_table_schema(self, table_name: str) -> str: + """Get the SQL schema for generic data (fallback).""" + return f""" + CREATE TABLE {table_name} ( + id TEXT PRIMARY KEY, + data JSONB NOT NULL, + metadata JSONB DEFAULT '{{}}'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Generic indexes + CREATE INDEX idx_{table_name}_data_gin ON {table_name} USING GIN(data); + CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); + """ + + def _get_table_schema_sql(self, table_name: str) -> str: + """Get the appropriate schema SQL for the table type.""" + + if 'entities' in table_name: + return self._get_entities_table_schema(table_name) + elif 'relationships' in table_name: + return self._get_relationships_table_schema(table_name) + elif 'communities' in table_name: + return self._get_communities_table_schema(table_name) + elif 'text_units' in table_name: + return self._get_text_units_table_schema(table_name) + elif 'documents' in table_name: + return self._get_documents_table_schema(table_name) + else: + return self._get_generic_table_schema(table_name) + + async def _ensure_table_exists_with_schema(self, table_name: str) -> None: + conn = await self._get_connection() + try: + table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + if not table_exists: + # Create table with appropriate typed schema (pass original table_name for type detection) + schema_sql = self._get_table_schema_sql(table_name) + await conn.execute(schema_sql) + log.info(f"Created table {table_name} with specific schema") + + finally: + await self._release_connection(conn) + + def _process_id_field(self, df: pd.DataFrame, table_name: str) -> list[str]: + """Process ID values - store clean IDs with prefix following CosmosDB pattern in GraphRAG.""" + prefix = self._get_prefix(table_name.replace(self._collection_prefix, "")) + id_values = [] + + if "id" not in df.columns: + # No ID column - create prefixed sequential IDs and track this prefix + for index in range(len(df)): + id_values.append(f"{prefix}:{index}") + if prefix not in self._no_id_prefixes: + self._no_id_prefixes.append(prefix) + log.info(f"No ID 
column found for {prefix}, generated prefixed sequential IDs") + else: + # Has ID column - process each row with prefix + for index, val in enumerate(df["id"]): + if self._is_scalar_na(val) or val == '' or val == 'nan' or (isinstance(val, list) and (len(val) == 0 or self._is_scalar_na(val[0]) or str(val[0]).strip() == '')): + # Missing ID - create prefixed sequential ID and track this prefix + id_values.append(f"{prefix}:{index}") + if prefix not in self._no_id_prefixes: + self._no_id_prefixes.append(prefix) + else: + # Valid ID - use as is without prefix + if isinstance(val, list): + id_values.append(str(val[0])) + else: + id_values.append(str(val)) + + return id_values + + def _is_scalar_na(self, value: Any) -> bool: + """Safely check if a value is NA/null, avoiding issues with arrays.""" + try: + # Don't check pd.isna on complex objects or large arrays + if isinstance(value, (list, dict)): + return False + if hasattr(value, '__len__') and len(str(value)) > 100: + return False + return pd.isna(value) + except (ValueError, TypeError): + # If pd.isna fails, assume it's not NA + return False + + def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[dict]: + """Prepare DataFrame data for PostgreSQL insertion with typed columns.""" + log.info(f"Preparing data for table {table_name}, DataFrame shape: {df.shape}") + log.info(f"DataFrame columns: {df.columns.tolist()}") + + # Add human_readable_id if missing + if 'human_readable_id' not in df.columns: + df = df.copy() + df['human_readable_id'] = range(len(df)) + log.info(f"Generated sequential human_readable_id for {len(df)} records") + + # Process IDs - for typed tables, we can use simpler ID handling + ids = self._process_id_field(df, table_name) + + # Determine if this is a typed table or generic table + is_typed_table = any(table_type in table_name for table_type in + ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + if is_typed_table: + return self._prepare_data_for_typed_table(df, table_name, ids) + else: + return self._prepare_data_for_generic_table(df, table_name, ids) + + def _prepare_data_for_typed_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: + """Prepare data for typed PostgreSQL tables with specific columns.""" + records = [] + + for i in range(len(df)): + record = {'id': ids[i]} + row = df.iloc[i] + + # Map DataFrame columns to table columns based on table type + if 'entities' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'title': str(row.get('title', '')), + 'type': str(row.get('type', '')), + 'description': str(row.get('description', '')), + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), + 'frequency': int(row.get('frequency', 0)) if pd.notna(row.get('frequency', 0)) else 0, + 'degree': int(row.get('degree', 0)) if pd.notna(row.get('degree', 0)) else 0, + 'x': float(row.get('x', 0.0)) if pd.notna(row.get('x', 0.0)) else 0.0, + 'y': float(row.get('y', 0.0)) if pd.notna(row.get('y', 0.0)) else 0.0 + }) + elif 'relationships' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'source': str(row.get('source', '')), + 'target': str(row.get('target', '')), + 'description': str(row.get('description', '')), + 'weight': float(row.get('weight', 0.0)) if pd.notna(row.get('weight', 0.0)) else 0.0, + 'combined_degree': int(row.get('combined_degree', 0)) if pd.notna(row.get('combined_degree', 0)) else 0, + 'text_unit_ids': 
self._ensure_json_list(row.get('text_unit_ids', [])) + }) + elif 'communities' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'community': int(row.get('community', 0)) if pd.notna(row.get('community')) and str(row.get('community', '')).strip() != '' else 0, + 'level': int(row.get('level', 0)) if pd.notna(row.get('level', 0)) else 0, + 'parent': int(row.get('parent', 0)) if pd.notna(row.get('parent')) and str(row.get('parent', '')).strip() != '' else None, + 'children': self._ensure_json_list(row.get('children', [])), + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), + 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), + 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) + }) + elif 'text_units' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'text': str(row.get('text', '')), + 'n_tokens': int(row.get('n_tokens', 0)) if pd.notna(row.get('n_tokens', 0)) else 0, + 'document_ids': self._ensure_json_list(row.get('document_ids', [])), + 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), + 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) + }) + elif 'documents' in table_name: + record.update({ + 'human_readable_id': int(row.get('human_readable_id', i)), + 'title': str(row.get('title', '')), + 'text': str(row.get('text', '')), + 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), + 'creation_date': self._ensure_datetime(row.get('creation_date')), + 'metadata': self._ensure_json_dict(row.get('metadata', {})) + }) + + records.append(record) + + log.info(f"Prepared {len(records)} records for typed table {table_name}") + if records: + log.info(f"Sample typed record: {list(records[0].keys())}") + + return records + + def _prepare_data_for_generic_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: + """Prepare data for generic PostgreSQL tables (fallback to JSONB storage).""" + records = [] + for i in range(len(df)): + # Create record with ID and all data in JSONB field + record_data = df.iloc[i].to_dict() + + # Convert numpy types to native Python types for JSON serialization + for key, value in record_data.items(): + if isinstance(value, (list, dict)): + record_data[key] = value + elif hasattr(value, 'tolist'): + # Handle numpy arrays and other numpy types + record_data[key] = value.tolist() + elif hasattr(value, 'item') and hasattr(value, 'size') and value.size == 1: + record_data[key] = value.item() + elif isinstance(value, (pd.Timestamp, pd.DatetimeTZDtype)): + record_data[key] = value.isoformat() if pd.notna(value) else None + elif self._is_scalar_na(value): + record_data[key] = None + elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + record_data[key] = [] + else: + record_data[key] = value + + record = { + 'id': ids[i], + 'data': record_data, + 'metadata': {} + } + records.append(record) + + log.info(f"Prepared {len(records)} records for generic table {table_name}") + return records + + def _ensure_json_list(self, value: Any) -> list: + """Ensure a value is a proper list for JSONB storage.""" + if isinstance(value, list): + # Convert any numpy arrays in the list to regular Python lists + return [item.tolist() if hasattr(item, 'tolist') else item for item in value] + elif hasattr(value, 'tolist'): + # Handle numpy arrays directly + converted = value.tolist() + return converted if isinstance(converted, list) else [converted] + elif 
isinstance(value, str) and value: + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, list) else [] + except (json.JSONDecodeError, TypeError): + return [] + elif value is None or pd.isna(value): + return [] + else: + return [value] if value else [] + + def _ensure_json_dict(self, value: Any) -> dict: + """Ensure a value is a proper dict for JSONB storage.""" + if isinstance(value, dict): + # Convert any numpy arrays in the dict to regular Python objects + result = {} + for k, v in value.items(): + if hasattr(v, 'tolist'): + result[k] = v.tolist() + elif hasattr(v, 'item') and hasattr(v, 'size') and v.size == 1: + result[k] = v.item() + else: + result[k] = v + return result + elif isinstance(value, str) and value: + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, dict) else {} + except (json.JSONDecodeError, TypeError): + return {} + elif value is None or pd.isna(value): + return {} + else: + return {'value': str(value)} if value else {} + + def _ensure_timezone_aware_datetimes(self, records: list[dict]) -> list[dict]: + """Ensure all datetime fields in records are timezone-aware for PostgreSQL.""" + datetime_fields = ['creation_date', 'created_at', 'updated_at'] + + for record in records: + for field in datetime_fields: + if field in record: + value = record[field] + if value is not None: + record[field] = self._ensure_datetime(value) + + return records + + def _ensure_datetime(self, value: Any) -> datetime: + """Ensure a value is a proper timezone-aware datetime object for PostgreSQL storage.""" + from dateutil import parser + + if isinstance(value, datetime): + # If it's already a datetime, ensure it has timezone info + if value.tzinfo is None: + # If it's timezone-naive, localize to UTC + return value.replace(tzinfo=timezone.utc) + else: + # Already timezone-aware + return value + elif isinstance(value, pd.Timestamp): + # Convert pandas Timestamp to datetime + dt = value.to_pydatetime() + # Ensure timezone awareness + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + else: + return dt + elif isinstance(value, str) and value: + try: + # Try to parse the string as a datetime + parsed_dt = parser.parse(value) + # Ensure timezone awareness + if parsed_dt.tzinfo is None: + return parsed_dt.replace(tzinfo=timezone.utc) + else: + return parsed_dt + except (ValueError, TypeError): + # If parsing fails, return current time + return datetime.now(timezone.utc) + elif value is None or pd.isna(value): + return datetime.now(timezone.utc) + else: + # For any other type, return current time + return datetime.now(timezone.utc) + + async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict]) -> None: + """Perform high-performance batch upsert of records using executemany.""" + total_records = len(records) + log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {self._batch_size}") + + # Ensure all datetime fields are timezone-aware + records = self._ensure_timezone_aware_datetimes(records) + + processed_count = 0 + + # Determine if this is a typed table or generic table + is_typed_table = any(table_type in table_name for table_type in + ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + # Process records in batches for optimal performance + for i in range(0, total_records, self._batch_size): + batch = records[i:i + self._batch_size] + batch_end = min(i + self._batch_size, total_records) + + try: + if is_typed_table: + await 
self._batch_upsert_typed_records(conn, table_name, batch) + else: + await self._batch_upsert_generic_records(conn, table_name, batch) + + except Exception as e: + log.warning(f"Batch method failed for batch {i}-{batch_end}, falling back to individual inserts: {e}") + + # Fallback to individual inserts within the batch + try: + async with conn.transaction(): + if is_typed_table: + for record in batch: + await self._insert_typed_record(conn, table_name, record) + else: + upsert_sql = f""" + INSERT INTO {table_name} (id, data, updated_at) + VALUES ($1, $2, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + updated_at = NOW() + """ + for record in batch: + await conn.execute(upsert_sql, record['id'], json.dumps(record['data'])) + except Exception as inner_e: + log.error(f"Both batch and individual insert methods failed for batch {i}-{batch_end}: {inner_e}") + raise + + processed_count += len(batch) + + # Log progress every batch for visibility + if i % self._batch_size == 0 or batch_end == total_records: + log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") + + async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: + """Batch upsert for typed tables with specific columns.""" + async with conn.transaction(): + + if 'entities' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + title = EXCLUDED.title, + type = EXCLUDED.type, + description = EXCLUDED.description, + text_unit_ids = EXCLUDED.text_unit_ids, + frequency = EXCLUDED.frequency, + degree = EXCLUDED.degree, + x = EXCLUDED.x, + y = EXCLUDED.y, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['title'], r['type'], r['description'], + json.dumps(r['text_unit_ids']), r['frequency'], r['degree'], r['x'], r['y']) + for r in batch + ] + elif 'relationships' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + source = EXCLUDED.source, + target = EXCLUDED.target, + description = EXCLUDED.description, + weight = EXCLUDED.weight, + combined_degree = EXCLUDED.combined_degree, + text_unit_ids = EXCLUDED.text_unit_ids, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['source'], r['target'], r['description'], + r['weight'], r['combined_degree'], json.dumps(r['text_unit_ids'])) + for r in batch + ] + elif 'communities' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + community = EXCLUDED.community, + level = EXCLUDED.level, + parent = EXCLUDED.parent, + children = EXCLUDED.children, + text_unit_ids = EXCLUDED.text_unit_ids, + entity_ids = EXCLUDED.entity_ids, + relationship_ids = EXCLUDED.relationship_ids, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], 
r['community'], r['level'], r['parent'], + json.dumps(r['children']), json.dumps(r['text_unit_ids']), + json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) + for r in batch + ] + elif 'text_units' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + text = EXCLUDED.text, + n_tokens = EXCLUDED.n_tokens, + document_ids = EXCLUDED.document_ids, + entity_ids = EXCLUDED.entity_ids, + relationship_ids = EXCLUDED.relationship_ids, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['text'], r['n_tokens'], + json.dumps(r['document_ids']), json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) + for r in batch + ] + elif 'documents' in table_name: + upsert_sql = f""" + INSERT INTO {table_name} (id, human_readable_id, title, text, text_unit_ids, creation_date, metadata, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + ON CONFLICT (id) + DO UPDATE SET + human_readable_id = EXCLUDED.human_readable_id, + title = EXCLUDED.title, + text = EXCLUDED.text, + text_unit_ids = EXCLUDED.text_unit_ids, + creation_date = EXCLUDED.creation_date, + metadata = EXCLUDED.metadata, + updated_at = NOW() + """ + batch_data = [ + (r['id'], r['human_readable_id'], r['title'], r['text'], + json.dumps(r['text_unit_ids']), + self._ensure_datetime(r['creation_date']), + json.dumps(r['metadata'])) + for r in batch + ] + else: + raise ValueError(f"Unknown typed table: {table_name}") + + await conn.executemany(upsert_sql, batch_data) + + async def _batch_upsert_generic_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: + """Batch upsert for generic tables using JSONB.""" + async with conn.transaction(): + upsert_sql = f""" + INSERT INTO {table_name} (id, data, metadata, updated_at) + VALUES ($1, $2, $3, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + metadata = EXCLUDED.metadata, + updated_at = NOW() + """ + batch_data = [ + (record['id'], json.dumps(record['data']), json.dumps(record['metadata'])) + for record in batch + ] + await conn.executemany(upsert_sql, batch_data) + + async def _insert_typed_record(self, conn: Connection, table_name: str, record: dict) -> None: + """Insert a single typed record (fallback method).""" + # This is a simplified fallback - implement based on table type if needed + # For now, just use the batch method with a single record + await self._batch_upsert_typed_records(conn, table_name, [record]) + + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find data in PostgreSQL tables using a file pattern regex.""" + # This is a synchronous method, but we need async operations + # For now, implement a basic version - in practice, this would need refactoring + log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) + + # Note: This is simplified - full implementation would need async/await support + # in the find method signature or use asyncio.run() + return iter([]) + + async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: + """Retrieve data from PostgreSQL table.""" + try: + table_name = self._get_table_name(key) + log.info(f"Retrieving data from table: 
{table_name}") + + conn = await self._get_connection() + try: + # Check if table exists + table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + + if not table_exists: + log.warning(f"Table {table_name} does not exist") + return None + + # Determine if this is a typed table or generic table + is_typed_table = any(table_type in table_name for table_type in + ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + if is_typed_table: + # For typed tables, select all columns except created_at/updated_at + if 'documents' in table_name: + query = "SELECT id, human_readable_id, title, text, text_unit_ids, creation_date, metadata FROM {} ORDER BY created_at".format(table_name) + elif 'entities' in table_name: + query = "SELECT id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y FROM {} ORDER BY created_at".format(table_name) + elif 'relationships' in table_name: + query = "SELECT id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids FROM {} ORDER BY created_at".format(table_name) + elif 'communities' in table_name: + query = "SELECT id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) + elif 'text_units' in table_name: + query = "SELECT id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) + else: + # Fallback for unknown typed table + query = "SELECT * FROM {} ORDER BY created_at".format(table_name) + else: + # For generic tables, use the data column + query = "SELECT id, data FROM {} ORDER BY created_at".format(table_name) + + rows = await conn.fetch(query) + + if not rows: + log.info(f"No data found in table {table_name}") + return None + + log.info(f"Retrieved {len(rows)} records from table {table_name}") + + # Check if this should be treated as raw data instead of tabular data + if (not key.endswith('.parquet') or + 'state' in key.lower() or + key.endswith('.json') or + 'context' in table_name.lower()): + # Handle state.json or context.json as raw data + # For non-tabular data, return the raw content from the first record + if rows: + if is_typed_table: + # For typed tables, convert row to dict and return as JSON + row_dict = dict(rows[0]) + json_str = json.dumps(row_dict) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str + elif 'data' in rows[0]: + raw_content = rows[0]['data'] + if isinstance(raw_content, dict): + json_str = json.dumps(raw_content) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str + return b"" if as_bytes else "" + + # Convert to DataFrame + records = [] + for row in rows: + if is_typed_table: + # For typed tables, the row is already the data we need + record_data = dict(row) + + # Convert JSONB fields back to proper Python objects + for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: + if field in record_data: + value = record_data[field] + + if value is None: + record_data[field] = {} if field == 'metadata' else [] + elif isinstance(value, str): + # Handle JSONB strings - they should always be valid JSON + try: + parsed = json.loads(value) + # Validate the parsed type + if field == 'metadata': + record_data[field] = parsed if isinstance(parsed, dict) else {} + else: + record_data[field] = parsed if isinstance(parsed, 
list) else [] + except (json.JSONDecodeError, TypeError): + log.warning(f"Failed to parse JSONB field {field}: {value}") + # Fallback for non-JSON strings + if field == 'metadata': + record_data[field] = {} + else: + record_data[field] = [] + elif isinstance(value, (list, dict)): + # Already correct type (shouldn't happen with JSONB, but handle it) + record_data[field] = value + else: + # Convert other types + if field == 'metadata': + record_data[field] = {'value': str(value)} if value else {} + else: + record_data[field] = [value] if value else [] + else: + # Handle generic table data (JSONB data column) + if isinstance(row['data'], dict): + record_data = dict(row['data']) + else: + # If it's a string, parse it as JSON + record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] + + # Clean up the record data - convert None to proper values and handle NaN + cleaned_data = {} + for key_name, value in record_data.items(): + if self._is_scalar_na(value) or value is None: + cleaned_data[key_name] = None + elif isinstance(value, str) and key_name == 'text_unit_ids' and not is_typed_table: + # Try to parse text_unit_ids back from JSON string if needed (only for generic tables) + try: + parsed_value = json.loads(value) + cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [value] + except (json.JSONDecodeError, TypeError): + # If it's not JSON, treat as a single item list or keep as string + cleaned_data[key_name] = [value] if value else [] + elif key_name in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids'] and not is_typed_table: + # Always ensure these columns are lists (only for generic tables - typed tables already handled this) + if isinstance(value, list): + cleaned_data[key_name] = value + elif isinstance(value, str): + try: + parsed_value = json.loads(value) + cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [] + except (json.JSONDecodeError, TypeError): + cleaned_data[key_name] = [] + elif value is None: + cleaned_data[key_name] = [] + else: + # fallback: wrap single value in a list + cleaned_data[key_name] = [value] + elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + # Handle empty arrays/lists + cleaned_data[key_name] = [] + else: + cleaned_data[key_name] = value + + # Always include the ID column for GraphRAG compatibility + # Use the storage ID as is since we simplified ID handling + storage_id = row['id'] + cleaned_data['id'] = storage_id + records.append(cleaned_data) + + df = pd.DataFrame(records) + + # Additional cleanup for NaN values in the DataFrame + df = df.where(pd.notna(df), None) + log.info(f"Created DataFrame with shape: {df.shape}") + log.info(f"Table {table_name} DataFrame columns: {df.columns.tolist()}") + + if len(df) > 0: + log.debug(f"Table {table_name} Sample record: {df.iloc[0].to_dict()}") + # Debug: Check if children column exists and its type + if 'children' in df.columns: + sample_children = df.iloc[0]['children'] + log.debug(f"Table {table_name} Sample children value: {sample_children}, type: {type(sample_children)}") + + # Handle bytes conversion for GraphRAG compatibility + if as_bytes or kwargs.get("as_bytes"): + log.info(f"Converting DataFrame to parquet bytes for key: {key}") + + # Apply column filtering similar to Milvus implementation + df_clean = df.copy() + + # Define expected columns for each data type + if 'documents' in table_name: + expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 
'metadata'] + elif 'entities' in table_name: + expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'] + elif 'relationships' in table_name: + expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] + if 'combined_degree' in df_clean.columns: + expected_columns.append('combined_degree') + elif 'text_units' in table_name: + expected_columns = ['id', 'human_readable_id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] + elif 'communities' in table_name: + expected_columns = ['id', 'human_readable_id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] + else: + expected_columns = list(df_clean.columns) + + # Filter columns + available_columns = [col for col in expected_columns if col in df_clean.columns] + if available_columns != expected_columns: + missing = set(expected_columns) - set(available_columns) + extra = set(df_clean.columns) - set(expected_columns) + log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") + + df_clean = df_clean[available_columns] + log.info(f"Table {table_name} final filtered columns: {df_clean.columns.tolist()}") + + # Convert to parquet bytes + try: + # Handle list columns for PyArrow compatibility + df_for_parquet = df_clean.copy() + + # For PyArrow/parquet compatibility, we need to handle list columns carefully + # Instead of converting to JSON strings, let's try a different approach + list_columns = [] + for col in df_for_parquet.columns: + if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # Check if this column contains lists + first_non_null = None + for val in df_for_parquet[col]: + if isinstance(val, list): + first_non_null = val + break + elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + first_non_null = val + break + + if isinstance(first_non_null, list) or col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: + list_columns.append(col) + # Ensure all values in this column are proper lists + df_for_parquet[col] = df_for_parquet[col].apply( + lambda x: x if isinstance(x, list) else ([] if x is None or pd.isna(x) else [str(x)]) + ) + + if list_columns: + log.info(f"Ensured list columns are proper lists for parquet: {list_columns}") + + # Try to convert to parquet without JSON string conversion + buffer = BytesIO() + df_for_parquet.to_parquet(buffer, engine='pyarrow') + buffer.seek(0) + parquet_bytes = buffer.getvalue() + log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") + return parquet_bytes + except Exception as e: + log.warning(f"Direct parquet conversion failed: {e}, trying with JSON string conversion") + + # Fallback: convert lists to JSON strings + try: + df_for_parquet = df_clean.copy() + + # Convert list columns to JSON strings for parquet compatibility + list_columns = [] + for col in df_for_parquet.columns: + if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # Check if this column contains lists + first_non_null = None + for val in df_for_parquet[col]: + if isinstance(val, list): + first_non_null = val + break + elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + first_non_null = val + break + if isinstance(first_non_null, list): + list_columns.append(col) + # Convert lists to JSON strings + df_for_parquet[col] 
= df_for_parquet[col].apply( + lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + ) + elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: + # These columns should always be lists, even if empty + list_columns.append(col) + df_for_parquet[col] = df_for_parquet[col].apply( + lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + ) + + if list_columns: + log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") + + buffer = BytesIO() + df_for_parquet.to_parquet(buffer, engine='pyarrow') + buffer.seek(0) + parquet_bytes = buffer.getvalue() + log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data (with JSON conversion)") + return parquet_bytes + except Exception as e2: + log.exception(f"Failed to convert DataFrame to parquet bytes: {e2}") + return b"" + + return df + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception(f"Error retrieving data from table {table_name}: {e}") + return None + + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + """Insert data into PostgreSQL table with drop/recreate to avoid duplicates.""" + try: + table_name = self._get_table_name(key) + log.info(f"Setting data for key: {key}, table: {table_name}") + + # Use new table creation approach with duplicate prevention + await self._ensure_table_exists_with_schema(table_name) + + conn = await self._get_connection() + try: + if isinstance(value, bytes): + # Parse parquet data + df = pd.read_parquet(BytesIO(value)) + log.info(f"Parsed parquet data, DataFrame shape: {df.shape}") + # output sample record for debugging + log.debug(f"Table {table_name} Sample record (first row): {df.iloc[0].to_dict()}") + log.info(f"Parsed DataFrame columns: {df.columns.tolist()}") + + # Prepare data for PostgreSQL (typed or generic) + records = self._prepare_data_for_postgres(df, table_name) + + if records: + # Use batch insert for much better performance + await self._batch_upsert_records(conn, table_name, records) + + log.info(f"Successfully inserted {len(records)} records to {table_name}") + + # Log ID handling info + if any(record['id'].split(':')[0] in self._no_id_prefixes for record in records if 'id' in record): + log.info(f"Some records used auto-generated IDs in table {table_name}") + + else: + # Handle non-parquet data (e.g., JSON, stats) - always use generic table + log.info(f"Handling non-parquet data for key: {key}") + + record_data = json.loads(value) if isinstance(value, str) else value + + # Use generic table insertion for non-parquet data + records = [{ + 'id': key, + 'data': record_data, + 'metadata': {'type': 'non_parquet', 'created': datetime.now(timezone.utc).isoformat()} + }] + + await self._batch_upsert_generic_records(conn, table_name, records) + log.info("Non-parquet data insertion successful") + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error setting data in PostgreSQL table %s: %s", table_name, e) + raise + + async def has(self, key: str) -> bool: + """Check if data exists for the given key.""" + try: + table_name = self._get_table_name(key) + log.info(f"Checking existence for key: {key}, table: {table_name}") + conn = await self._get_connection() + try: + # Check if table exists + table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + 
log.debug(f"Table {table_name} exists: {table_exists}") + if not table_exists: + return False + + if key.endswith('.parquet'): + # For parquet files, check if table has any records + total_count = await conn.fetchval( + f"SELECT COUNT(*) FROM {table_name}" + ) + if total_count > 0: + return True + else: + raise ValueError(f"No records found in table {table_name} for parquet key {key}") + else: + # Check for exact key match + exists = await conn.fetchval( + f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", + key + ) + return exists + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error checking existence for key %s: %s", key, e) + return False + + async def delete(self, key: str) -> None: + """Delete data for the given key.""" + try: + table_name = self._get_table_name(key) + conn = await self._get_connection() + try: + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + log.info(f"Deleted record for key {key}") + finally: + await self._release_connection(conn) + except Exception as e: + log.exception("Error deleting key %s: %s", key, e) + + async def clear(self) -> None: + """Clear all tables with the configured prefix.""" + try: + conn = await self._get_connection() + try: + # Get all tables with our prefix + tables = await conn.fetch( + "SELECT table_name FROM information_schema.tables WHERE table_name LIKE $1", + f"{self._collection_prefix}%" + ) + + for table_row in tables: + table_name = table_row['table_name'] + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + log.info(f"Dropped table: {table_name}") + + log.info(f"Cleared all tables with prefix: {self._collection_prefix}") + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error clearing tables: %s", e) + + def keys(self) -> list[str]: + """Return the keys in the storage.""" + # This would need to be async to properly implement + # For now, return empty list + log.warning("keys() method not fully implemented for async storage") + return [] + + def child(self, name: str | None) -> PipelineStorage: + """Create a child storage instance.""" + return self + + async def get_creation_date(self, key: str) -> str: + """Get the creation date for data.""" + try: + table_name = self._get_table_name(key) + conn = await self._get_connection() + try: + if key.endswith('.parquet'): + prefix = self._get_prefix(key) + created_at = await conn.fetchval( + f"SELECT MIN(created_at) FROM {table_name} WHERE id LIKE $1", + f"{prefix}:%" + ) + else: + created_at = await conn.fetchval( + f"SELECT created_at FROM {table_name} WHERE id = $1", + key + ) + + if created_at: + return get_timestamp_formatted_with_local_tz(created_at) + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception("Error getting creation date for %s: %s", key, e) + + return "" + + async def close(self) -> None: + """Close the connection pool.""" + if self._pool: + await self._pool.close() + log.info("Closed PostgreSQL connection pool") From 6e156b354afb183a7f3c772021183d73572729b1 Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Fri, 22 Aug 2025 22:09:58 -0700 Subject: [PATCH 06/12] Save dataframe to json data field --- .../index/operations/finalize_entities.py | 6 +- graphrag/storage/postgres_pipeline_storage.py | 647 ++++++++++++++---- ...ge2.py => postgres_pipeline_storage_bk.py} | 553 ++++++++------- 3 files changed, 817 insertions(+), 389 deletions(-) rename graphrag/storage/{postgres_pipeline_storage2.py => 
postgres_pipeline_storage_bk.py} (69%) diff --git a/graphrag/index/operations/finalize_entities.py b/graphrag/index/operations/finalize_entities.py index 55fac07d2a..3dba9d1a5c 100644 --- a/graphrag/index/operations/finalize_entities.py +++ b/graphrag/index/operations/finalize_entities.py @@ -22,9 +22,9 @@ def finalize_entities( layout_enabled: bool = False, ) -> pd.DataFrame: """All the steps to transform final entities.""" - # Remove the default column degree, x and y for Postgres storage compatibility. And below entities.merge method - # will add them back with calculated values. - entities = entities.drop(columns=["degree", "x", "y"], errors="ignore") + # # Remove the default column degree, x and y for Postgres storage compatibility. And below entities.merge method + # # will add them back with calculated values. + # entities = entities.drop(columns=["degree", "x", "y"], errors="ignore") graph = create_graph(relationships, edge_attr=["weight"]) graph_embeddings = None if embed_config is not None and embed_config.enabled: diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py index 2bb8971558..e28a460d1c 100644 --- a/graphrag/storage/postgres_pipeline_storage.py +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""PostgreSQL Storage implementation of PipelineStorage""" +"""PostgreSQL Storage implementation of PipelineStorage.""" import json import logging @@ -10,10 +10,12 @@ from io import BytesIO from typing import Any +import numpy as np import pandas as pd import asyncpg -from asyncpg import Connection +from asyncpg import Connection, Pool +from graphrag.logger.progress import Progress from graphrag.storage.pipeline_storage import ( PipelineStorage, get_timestamp_formatted_with_local_tz, @@ -22,7 +24,13 @@ log = logging.getLogger(__name__) class PostgresPipelineStorage(PipelineStorage): - """Simplified PostgreSQL Storage Implementation.""" + """The PostgreSQL Storage Implementation.""" + + _pool: Pool | None + _connection_string: str + _database: str + _collection_prefix: str + _encoding: str def __init__( self, @@ -34,9 +42,10 @@ def __init__( collection_prefix: str = "lgr_", encoding: str = "utf-8", connection_string: str | None = None, - command_timeout: int = 600, - connection_timeout: int = 60, - batch_size: int = 50, + command_timeout: int = 600, # 10 minutes for SQL commands + server_timeout: int = 120, # 2 minutes for server connection + connection_timeout: int = 60, # 1 minute to establish connection + batch_size: int = 50, # Smaller batch size to reduce timeout risk **kwargs: Any, ): """Initialize the PostgreSQL Storage.""" @@ -48,10 +57,12 @@ def __init__( self._collection_prefix = collection_prefix self._encoding = encoding self._command_timeout = command_timeout + self._server_timeout = server_timeout self._connection_timeout = connection_timeout self._batch_size = batch_size self._pool = None + # Build connection string from components or use provided one if connection_string: self._connection_string = connection_string else: @@ -61,8 +72,13 @@ def __init__( self._connection_string = f"postgresql://{username}@{host}:{port}/{database}" log.info( - "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s", - self._host, self._port, self._database, self._collection_prefix + "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s, command_timeout: %s, batch_size: %s", 
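
As context for the single `data JSONB` column this patch standardizes on (shown in the schema changes below), here is a small standalone sketch (not part of the patch) of the round trip using asyncpg directly. The table name and DSN are placeholders; note that asyncpg accepts JSONB parameters as JSON text and, without a registered codec, returns JSONB values as JSON strings, which is why the storage code checks the type before parsing on read.

import asyncio
import json

import asyncpg


async def jsonb_roundtrip(dsn: str) -> None:
    # Placeholder table name; the real storage derives table names from key prefixes.
    conn = await asyncpg.connect(dsn)
    try:
        await conn.execute(
            """CREATE TABLE IF NOT EXISTS demo_rows (
                   id TEXT PRIMARY KEY,
                   data JSONB NOT NULL,
                   updated_at TIMESTAMPTZ DEFAULT NOW()
               )"""
        )
        # JSONB parameters are passed as JSON text, hence json.dumps() on write.
        await conn.executemany(
            """INSERT INTO demo_rows (id, data, updated_at)
               VALUES ($1, $2, NOW())
               ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data, updated_at = NOW()""",
            [("e1", json.dumps({"title": "Alpha", "frequency": 3}))],
        )
        row = await conn.fetchrow("SELECT data FROM demo_rows WHERE id = $1", "e1")
        value = row["data"]
        # Without a registered codec, JSONB comes back as a JSON string.
        print(json.loads(value) if isinstance(value, str) else value)
    finally:
        await conn.close()


# Example invocation with an assumed local DSN:
# asyncio.run(jsonb_roundtrip("postgresql://postgres:postgres@localhost:5432/graphrag"))
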
+ self._host, + self._port, + self._database, + self._collection_prefix, + self._command_timeout, + self._batch_size, ) async def _get_connection(self) -> Connection: @@ -74,10 +90,14 @@ async def _get_connection(self) -> Connection: min_size=1, max_size=10, command_timeout=self._command_timeout, - server_settings={'application_name': 'graphrag_postgres_storage'}, + server_settings={ + 'application_name': 'graphrag_postgres_storage' + }, + # Use connection_timeout for initial connection establishment timeout=self._connection_timeout ) - log.info("Created PostgreSQL connection pool") + log.info("Created PostgreSQL connection pool with command_timeout: %s, connection_timeout: %s", + self._command_timeout, self._connection_timeout) except Exception as e: log.error("Failed to create PostgreSQL connection pool: %s", e) raise @@ -91,109 +111,205 @@ async def _release_connection(self, conn: Connection) -> None: def _get_table_name(self, key: str) -> str: """Get the table name for a given key.""" + # Extract the base name without file extension base_name = key.split(".")[0] return f"{self._collection_prefix}{base_name}" - def _get_universal_table_schema(self, table_name: str) -> str: - """Universal schema that works for all GraphRAG data types.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - data JSONB NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_{table_name}_data_gin ON {table_name} USING GIN(data); - CREATE INDEX IF NOT EXISTS idx_{table_name}_created_at ON {table_name}(created_at); - """ + def _get_prefix(self, key: str) -> str: + """Get the prefix of the filename key.""" + return key.split(".")[0] async def _ensure_table_exists(self, table_name: str) -> None: - """Ensure table exists with universal schema.""" + """Ensure a table exists, create if it doesn't.""" conn = await self._get_connection() try: - table_exists = await conn.fetchval( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", - table_name - ) - if not table_exists: - schema_sql = self._get_universal_table_schema(table_name) - await conn.execute(schema_sql) - log.info(f"Created table {table_name}") + # Create table with flexible schema similar to CosmosDB approach + create_sql = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + id TEXT PRIMARY KEY, + data JSONB NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ); + + -- Create indexes for better performance + CREATE INDEX IF NOT EXISTS idx_{table_name}_created_at ON {table_name}(created_at); + CREATE INDEX IF NOT EXISTS idx_{table_name}_data_gin ON {table_name} USING GIN(data); + """ + + await conn.execute(create_sql) + log.debug("Ensured table exists: %s", table_name) finally: await self._release_connection(conn) - def _prepare_dataframe_for_storage(self, df: pd.DataFrame) -> list[dict]: - """Convert DataFrame to records for storage.""" - records = [] + # def _process_id_field(self, df: pd.DataFrame, table_name: str) -> list[str]: + # """Process ID values - store clean IDs with prefix following CosmosDB pattern.""" + # prefix = self._get_prefix(table_name.replace(self._collection_prefix, "")) + # id_values = [] + + # if "id" not in df.columns: + # # No ID column - create prefixed sequential IDs and track this prefix + # for index in range(len(df)): + # id_values.append(f"{prefix}:{index}") + # if prefix not in 
self._no_id_prefixes: + # self._no_id_prefixes.append(prefix) + # log.info(f"No ID column found for {prefix}, generated prefixed sequential IDs") + # else: + # # Has ID column - process each row with prefix + # for index, val in enumerate(df["id"]): + # if self._is_scalar_na(val) or val == '' or val == 'nan' or (isinstance(val, list) and (len(val) == 0 or self._is_scalar_na(val[0]) or str(val[0]).strip() == '')): + # # Missing ID - create prefixed sequential ID and track this prefix + # id_values.append(f"{prefix}:{index}") + # if prefix not in self._no_id_prefixes: + # self._no_id_prefixes.append(prefix) + # else: + # # Valid ID - use with prefix (following CosmosDB pattern) + # if isinstance(val, list): + # id_values.append(f"{prefix}:{val[0]}") + # else: + # id_values.append(f"{prefix}:{val}") + + # return id_values + + def _is_scalar_na(self, value: Any) -> bool: + """Safely check if a value is NA/null, avoiding issues with arrays.""" + try: + # Don't check pd.isna on complex objects or large arrays + if isinstance(value, (list, dict)): + return False + if hasattr(value, '__len__') and len(str(value)) > 100: + return False + return pd.isna(value) + except (ValueError, TypeError): + # If pd.isna fails, assume it's not NA + return False + + def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[dict]: + """Prepare DataFrame data for PostgreSQL insertion following CosmosDB pattern.""" + log.info(f"Preparing data for table {table_name}, DataFrame shape: {df.shape}") + log.info(f"DataFrame columns: {df.columns.tolist()}") + + # Add human_readable_id if missing + if 'human_readable_id' not in df.columns: + df = df.copy() + df['human_readable_id'] = range(len(df)) + log.info(f"Generated sequential human_readable_id for {len(df)} records") - for i, row in df.iterrows(): - record_data = row.to_dict() + # Process IDs with prefix + ids = df['id'].astype(str).tolist() if 'id' in df.columns else [f"{self._get_prefix(table_name.replace(self._collection_prefix, ''))}:{i}" for i in range(len(df))] + + records = [] + for i in range(len(df)): + # Create record with ID and all data in JSONB field + record_data = df.iloc[i].to_dict() - # Convert pandas/numpy types to JSON-serializable types + # Convert numpy types to native Python types for JSON serialization for key, value in record_data.items(): - if pd.isna(value): - record_data[key] = None - elif isinstance(value, (list, dict)): + # Handle different value types carefully + if isinstance(value, (list, dict)): + # Keep lists and dicts as-is (like text_unit_ids) record_data[key] = value - elif hasattr(value, 'tolist'): - record_data[key] = value.tolist() elif hasattr(value, 'item') and hasattr(value, 'size') and value.size == 1: + # Only use .item() for numpy scalars (arrays of size 1) record_data[key] = value.item() + elif hasattr(value, 'tolist'): + # Convert numpy arrays to Python lists + record_data[key] = value.tolist() elif isinstance(value, (pd.Timestamp, pd.DatetimeTZDtype)): record_data[key] = value.isoformat() if pd.notna(value) else None + elif self._is_scalar_na(value): + # Only check pd.isna for scalar-like values + record_data[key] = None + elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + # Handle empty arrays/lists + record_data[key] = [] else: record_data[key] = value - - # Extract ID and human_readable_id - record_id = record_data.pop('id', f"record_{i}") - human_readable_id = record_data.pop('human_readable_id', i) - records.append({ - 'id': str(record_id), - 'human_readable_id': 
int(human_readable_id) if pd.notna(human_readable_id) else i, + record = { + 'id': ids[i], 'data': record_data - }) + } + records.append(record) + log.info(f"Prepared {len(records)} records for PostgreSQL") return records - async def _batch_upsert(self, conn: Connection, table_name: str, records: list[dict]) -> None: - """Perform batch upsert of records.""" + async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict], batch_size: int = 1000) -> None: + """Perform high-performance batch upsert of records using executemany.""" total_records = len(records) - log.info(f"Starting batch upsert of {total_records} records to {table_name}") + log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {batch_size}") - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, data, updated_at) - VALUES ($1, $2, $3, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - data = EXCLUDED.data, - updated_at = NOW() - """ + processed_count = 0 - # Process in batches - for i in range(0, total_records, self._batch_size): - batch = records[i:i + self._batch_size] - batch_data = [ - (record['id'], record['human_readable_id'], json.dumps(record['data'])) - for record in batch - ] + # Process records in batches for optimal performance + for i in range(0, total_records, batch_size): + batch = records[i:i + batch_size] + batch_end = min(i + batch_size, total_records) + + # Prepare batch data + ids = [record['id'] for record in batch] + data_json_list = [json.dumps(record['data']) for record in batch] try: async with conn.transaction(): + # Use executemany for reliable batch processing + upsert_sql = f""" + INSERT INTO {table_name} (id, data, updated_at) + VALUES ($1, $2, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + updated_at = NOW() + """ + + # Prepare data for executemany + batch_data = [(record_id, data_json) for record_id, data_json in zip(ids, data_json_list)] await conn.executemany(upsert_sql, batch_data) - - log.info(f"Batch upsert progress: {min(i + self._batch_size, total_records)}/{total_records}") - + except Exception as e: - log.error(f"Batch upsert failed for batch {i}-{min(i + self._batch_size, total_records)}: {e}") - # Fallback to individual inserts - async with conn.transaction(): - for record in batch: - await conn.execute(upsert_sql, record['id'], record['human_readable_id'], json.dumps(record['data'])) + log.warning(f"Batch method failed for batch {i}-{batch_end}, falling back to individual inserts: {e}") + + # Fallback to individual inserts within the batch + try: + async with conn.transaction(): + upsert_sql = f""" + INSERT INTO {table_name} (id, data, updated_at) + VALUES ($1, $2, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + updated_at = NOW() + """ + + for record_id, data_json in zip(ids, data_json_list): + await conn.execute(upsert_sql, record_id, data_json) + except Exception as inner_e: + log.error(f"Both batch and individual insert methods failed for batch {i}-{batch_end}: {inner_e}") + raise + + processed_count += len(batch) + + # Log progress every batch for visibility + if i % batch_size == 0 or batch_end == total_records: + log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, 
Any]]]: + """Find data in PostgreSQL tables using a file pattern regex.""" + # This is a synchronous method, but we need async operations + # For now, implement a basic version - in practice, this would need refactoring + log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) + + # Note: This is simplified - full implementation would need async/await support + # in the find method signature or use asyncio.run() + return iter([]) + def _parse_jsonb_field(self, value: Any, default_type: str = "list") -> Any: """Parse JSONB field back to Python object.""" if value is None: @@ -237,9 +353,7 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None return None # Get all records - rows = await conn.fetch( - f"SELECT id, human_readable_id, data FROM {table_name} ORDER BY created_at" - ) + rows = await conn.fetch(f"SELECT * FROM {table_name} ORDER BY created_at") if not rows: log.info(f"No data found in table {table_name}") @@ -248,14 +362,15 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None log.info(f"Retrieved {len(rows)} records from table {table_name}") # Handle non-parquet data (JSON/state files) - if not key.endswith('.parquet') or 'state' in key.lower() or 'context' in key.lower(): - if rows: - first_record_data = rows[0]['data'] - if isinstance(first_record_data, dict): - json_str = json.dumps(first_record_data) - else: - json_str = json.dumps(first_record_data) - return json_str.encode(encoding or self._encoding) if as_bytes else json_str + if (not key.endswith('.parquet') or + 'state' in key.lower() or + 'context' in table_name.lower()): + # For non-tabular data, return the raw content from the first record + if rows and 'data' in rows[0]: + raw_content = rows[0]['data'] + if isinstance(raw_content, dict): + json_str = json.dumps(raw_content) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str return b"" if as_bytes else "" # Convert to DataFrame @@ -263,10 +378,6 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None for row in rows: record_data = dict(row['data']) if isinstance(row['data'], dict) else json.loads(row['data']) - # Add back the ID and human_readable_id - record_data['id'] = row['id'] - record_data['human_readable_id'] = row['human_readable_id'] - # Parse JSONB list fields back to proper Python lists for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids']: if field in record_data: @@ -282,9 +393,9 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None # Clean up NaN values df = df.where(pd.notna(df), None) - - log.info(f"Created DataFrame with shape: {df.shape}") - log.info(f"DataFrame columns: {df.columns.tolist()}") + + log.info(f"Get table {table_name} DataFrame with shape: {df.shape}") + log.info(f"Get table {table_name} DataFrame columns: {df.columns.tolist()}") # Convert to bytes if requested if as_bytes or kwargs.get("as_bytes"): @@ -298,9 +409,219 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None except Exception as e: log.exception(f"Error retrieving data from table {table_name}: {e}") return None + async def get1(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: + """Retrieve data from PostgreSQL table.""" + try: + table_name = self._get_table_name(key) + log.info(f"Retrieving data from table: {table_name}") + + conn = await self._get_connection() + try: + # Check if table exists + 
table_exists = await conn.fetchval( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + table_name + ) + + if not table_exists: + log.warning(f"Table {table_name} does not exist") + return None + + # Query all records for this prefix + rows = await conn.fetch(f"SELECT * FROM {table_name} ORDER BY created_at") + + if not rows: + log.info(f"No data found in table {table_name}") + return None + + log.info(f"Retrieved {len(rows)} records from table {table_name}") + + # Check if this should be treated as raw data instead of tabular data + if (not key.endswith('.parquet') or + 'state' in key.lower() or + key.endswith('.json') or + key.endswith('.txt') or + key.endswith('.yaml') or + key.endswith('.yml') or + 'context' in table_name.lower()): + # For non-tabular data, return the raw content from the first record + if rows and 'data' in rows[0]: + raw_content = rows[0]['data'] + if isinstance(raw_content, dict): + json_str = json.dumps(raw_content) + return json_str.encode(encoding or self._encoding) if as_bytes else json_str + return b"" if as_bytes else "" + + # Convert to DataFrame + records = [] + for row in rows: + # Handle JSONB data properly - row['data'] should already be a dict from asyncpg + if isinstance(row['data'], dict): + record_data = dict(row['data']) + else: + # If it's a string, parse it as JSON + record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] + + # Clean up the record data - convert None to proper values and handle NaN + cleaned_data = {} + for key, value in record_data.items(): + if self._is_scalar_na(value) or value is None: + cleaned_data[key] = None + elif isinstance(value, str) and key == 'text_unit_ids': + # Try to parse text_unit_ids back from JSON string if needed + try: + parsed_value = json.loads(value) + cleaned_data[key] = parsed_value if isinstance(parsed_value, list) else [value] + except (json.JSONDecodeError, TypeError): + # If it's not JSON, treat as a single item list or keep as string + cleaned_data[key] = [value] if value else [] + elif key in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids']: + # Always ensure these columns are lists + if isinstance(value, list): + cleaned_data[key] = value + elif isinstance(value, str): + try: + parsed_value = json.loads(value) + cleaned_data[key] = parsed_value if isinstance(parsed_value, list) else [] + except (json.JSONDecodeError, TypeError): + cleaned_data[key] = [] + elif value is None: + cleaned_data[key] = [] + else: + # fallback: wrap single value in a list + cleaned_data[key] = [value] + elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + # Handle empty arrays/lists + cleaned_data[key] = [] + else: + cleaned_data[key] = value + + # # Always include the ID column for GraphRAG compatibility + # # Extract the actual ID from the prefixed storage ID + # storage_id = row['id'] + # # if ':' in storage_id: + # # actual_id = storage_id.split(':', 1)[1] + # # # Only use the actual ID if it's not a sequential index + # # if not actual_id.isdigit() or prefix not in self._no_id_prefixes: + # # cleaned_data['id'] = actual_id + # # else: + # # # For auto-generated sequential IDs, use the storage ID as the ID + # # cleaned_data['id'] = storage_id + # # else: + # # # If no prefix found, use the storage ID as is + # cleaned_data['id'] = storage_id + records.append(cleaned_data) + + df = pd.DataFrame(records) + + # Additional cleanup for NaN values in the DataFrame + df = df.where(pd.notna(df), None) + log.info(f"Get DataFrame with 
shape: {df.shape}") + log.info(f"DataFrame columns: {df.columns.tolist()}") + + # if len(df) > 0: + # log.info(f"Sample record: {df.iloc[0].to_dict()}") + # # Debug: Check if children column exists and its type + # if 'children' in df.columns: + # sample_children = df.iloc[0]['children'] + # log.info(f"Sample children value: {sample_children}, type: {type(sample_children)}") + + # Handle bytes conversion for GraphRAG compatibility + if as_bytes or kwargs.get("as_bytes"): + log.info(f"Converting DataFrame to parquet bytes for key: {key}") + + # Apply column filtering similar to Milvus implementation + df_clean = df.copy() + + # Define expected columns for each data type + if 'documents' in table_name: + expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] + # Include text_unit_ids if it has meaningful data + if 'text_unit_ids' in df_clean.columns and any( + len(tuid) > 0 for tuid in df_clean['text_unit_ids'] if isinstance(tuid, list) + ): + expected_columns.insert(4, 'text_unit_ids') + log.info("Including text_unit_ids (appears to be final documents)") + elif 'entities' in table_name: + # Exclude degree column for GraphRAG compatibility + expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency'] + log.info("Excluding degree column from entities for finalize_entities compatibility") + elif 'relationships' in table_name: + expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] + if 'combined_degree' in df_clean.columns: + expected_columns.append('combined_degree') + elif 'text_units' in table_name: + expected_columns = ['id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] + elif 'communities' in table_name: + expected_columns = ['id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] + else: + expected_columns = list(df_clean.columns) + + # Filter columns + available_columns = [col for col in expected_columns if col in df_clean.columns] + if available_columns != expected_columns: + missing = set(expected_columns) - set(available_columns) + extra = set(df_clean.columns) - set(expected_columns) + log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") + + df_clean = df_clean[available_columns] + log.info(f"Final filtered columns: {df_clean.columns.tolist()}") + + # Convert to parquet bytes + try: + # Handle list columns that PyArrow can't serialize directly + df_for_parquet = df_clean.copy() + + # Convert list columns to JSON strings for parquet compatibility + list_columns = [] + for col in df_for_parquet.columns: + if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # Check if this column contains lists + first_non_null = None + for val in df_for_parquet[col]: + if isinstance(val, list): + first_non_null = val + break + elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + first_non_null = val + break + if isinstance(first_non_null, list): + list_columns.append(col) + # Convert lists to JSON strings + df_for_parquet[col] = df_for_parquet[col].apply( + lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + ) + elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids']: + # These columns should always be lists, even if empty + list_columns.append(col) + df_for_parquet[col] = df_for_parquet[col].apply( + lambda 
x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + ) + + if list_columns: + log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") + + buffer = BytesIO() + df_for_parquet.to_parquet(buffer, engine='pyarrow') + buffer.seek(0) + parquet_bytes = buffer.getvalue() + log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") + return parquet_bytes + except Exception as e: + log.exception(f"Failed to convert DataFrame to parquet bytes: {e}") + return b"" + + return df + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception(f"Error retrieving data from table {table_name}: {e}") + return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: - """Store data in PostgreSQL table.""" + """Insert data into PostgreSQL table with upsert capability.""" try: table_name = self._get_table_name(key) log.info(f"Setting data for key: {key}, table: {table_name}") @@ -310,60 +631,87 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: conn = await self._get_connection() try: if isinstance(value, bytes): - # Handle parquet data + # Parse parquet data df = pd.read_parquet(BytesIO(value)) log.info(f"Parsed parquet data, DataFrame shape: {df.shape}") + log.info(f"Parsed DataFrame head: {df.head()}") + + # Prepare data for PostgreSQL + records = self._prepare_data_for_postgres(df, table_name) - records = self._prepare_dataframe_for_storage(df) if records: - await self._batch_upsert(conn, table_name, records) - log.info(f"Successfully stored {len(records)} records to {table_name}") + # Use batch insert for much better performance + await self._batch_upsert_records(conn, table_name, records) + + log.info(f"Successfully upserted {len(records)} records to {table_name}") + + # # Log duplicate handling info + # if any(record['id'].split(':')[0] in self._no_id_prefixes for record in records): + # log.info("Some records used auto-generated IDs") + else: - # Handle non-parquet data (JSON, etc.) 
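The single-record branch below uses the same ON CONFLICT upsert as the batched path. As a minimal standalone sketch (assuming an open asyncpg connection and a table created with id TEXT PRIMARY KEY, data JSONB, updated_at TIMESTAMPTZ; the helper name is illustrative and not part of this patch):

import json
import asyncpg

async def upsert_json_record(conn: asyncpg.Connection, table: str, record_id: str, payload: dict) -> None:
    # Insert the document, or overwrite its data and bump updated_at if the id already exists.
    await conn.execute(
        f"""
        INSERT INTO {table} (id, data, updated_at)
        VALUES ($1, $2, NOW())
        ON CONFLICT (id)
        DO UPDATE SET data = EXCLUDED.data, updated_at = NOW()
        """,
        record_id,
        json.dumps(payload),  # asyncpg's default JSONB codec expects a JSON string, not a dict
    )

Because the statement is idempotent per id, re-running a pipeline step refreshes the stored payload instead of raising a duplicate-key error.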
+ # Handle non-parquet data (e.g., JSON, stats) + log.info(f"Handling non-parquet data for key: {key}") + record_data = json.loads(value) if isinstance(value, str) else value record = { 'id': key, - 'human_readable_id': 0, 'data': record_data } + upsert_sql = f""" + INSERT INTO {table_name} (id, data, updated_at) + VALUES ($1, $2, NOW()) + ON CONFLICT (id) + DO UPDATE SET + data = EXCLUDED.data, + updated_at = NOW() + """ + await conn.execute( - f"""INSERT INTO {table_name} (id, human_readable_id, data, updated_at) - VALUES ($1, $2, $3, NOW()) - ON CONFLICT (id) DO UPDATE SET - data = EXCLUDED.data, updated_at = NOW()""", - record['id'], record['human_readable_id'], json.dumps(record['data']) + upsert_sql, + record['id'], + json.dumps(record['data']) ) - log.info(f"Successfully stored non-parquet data for key: {key}") + + log.info("Non-parquet data upsert successful") finally: await self._release_connection(conn) except Exception as e: - log.exception(f"Error setting data for key {key}: {e}") - raise + log.exception("Error setting data in PostgreSQL table %s: %s", table_name, e) async def has(self, key: str) -> bool: """Check if data exists for the given key.""" try: table_name = self._get_table_name(key) + log.info(f"Checking existence for key: {key}, table: {table_name}") conn = await self._get_connection() try: + # Check if table exists table_exists = await conn.fetchval( "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", table_name ) + log.debug(f"Table {table_name} exists: {table_exists}") if not table_exists: return False if key.endswith('.parquet'): - # For parquet files, check if table has records - count = await conn.fetchval(f"SELECT COUNT(*) FROM {table_name}") - return count > 0 + # For parquet files, check if table has any records + total_count = await conn.fetchval( + f"SELECT COUNT(*) FROM {table_name}" + ) + if total_count > 0: + return True + else: + raise ValueError(f"No records found in table {table_name} for parquet key {key}") else: - # For specific keys, check exact match + # Check for exact key match exists = await conn.fetchval( - f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", key + f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", + key ) return exists @@ -371,7 +719,7 @@ async def has(self, key: str) -> bool: await self._release_connection(conn) except Exception as e: - log.exception(f"Error checking existence for key {key}: {e}") + log.exception("Error checking existence for key %s: %s", key, e) return False async def delete(self, key: str) -> None: @@ -380,18 +728,34 @@ async def delete(self, key: str) -> None: table_name = self._get_table_name(key) conn = await self._get_connection() try: - await conn.execute(f"DROP TABLE IF EXISTS {table_name}") - log.info(f"Deleted table for key {key}") + if key.endswith('.parquet'): + # Delete all records with this prefix + prefix = self._get_prefix(key) + result = await conn.execute( + f"DELETE FROM {table_name} WHERE id LIKE $1", + f"{prefix}:%" + ) + log.info(f"Deleted records for prefix {prefix}: {result}") + else: + # Delete exact key match + result = await conn.execute( + f"DELETE FROM {table_name} WHERE id = $1", + key + ) + log.info(f"Deleted record for key {key}: {result}") + finally: await self._release_connection(conn) + except Exception as e: - log.exception(f"Error deleting key {key}: {e}") + log.exception("Error deleting key %s: %s", key, e) async def clear(self) -> None: """Clear all tables with the configured prefix.""" try: conn = await 
self._get_connection() try: + # Get all tables with our prefix tables = await conn.fetch( "SELECT table_name FROM information_schema.tables WHERE table_name LIKE $1", f"{self._collection_prefix}%" @@ -408,27 +772,31 @@ async def clear(self) -> None: await self._release_connection(conn) except Exception as e: - log.exception(f"Error clearing tables: {e}") + log.exception("Error clearing tables: %s", e) def keys(self) -> list[str]: """Return the keys in the storage.""" + # This would need to be async to properly implement + # For now, return empty list log.warning("keys() method not fully implemented for async storage") return [] def child(self, name: str | None) -> PipelineStorage: """Create a child storage instance.""" - return self - - def find( - self, - file_pattern: re.Pattern[str], - base_dir: str | None = None, - file_filter: dict[str, Any] | None = None, - max_count=-1, - ) -> Iterator[tuple[str, dict[str, Any]]]: - """Find data in PostgreSQL tables using a file pattern regex.""" - log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) - return iter([]) + if name is None: + return self + + # Create child with modified table prefix + child_prefix = f"{self._collection_prefix}{name}_" + return PostgresPipelineStorage( + host=self._host, + port=self._port, + database=self._database, + username=self._username, + password=self._password, + collection_prefix=child_prefix, + encoding=self._encoding, + ) async def get_creation_date(self, key: str) -> str: """Get the creation date for data.""" @@ -437,12 +805,15 @@ async def get_creation_date(self, key: str) -> str: conn = await self._get_connection() try: if key.endswith('.parquet'): + prefix = self._get_prefix(key) created_at = await conn.fetchval( - f"SELECT MIN(created_at) FROM {table_name}" + f"SELECT MIN(created_at) FROM {table_name} WHERE id LIKE $1", + f"{prefix}:%" ) else: created_at = await conn.fetchval( - f"SELECT created_at FROM {table_name} WHERE id = $1", key + f"SELECT created_at FROM {table_name} WHERE id = $1", + key ) if created_at: @@ -452,7 +823,7 @@ async def get_creation_date(self, key: str) -> str: await self._release_connection(conn) except Exception as e: - log.exception(f"Error getting creation date for {key}: {e}") + log.exception("Error getting creation date for %s: %s", key, e) return "" @@ -460,4 +831,4 @@ async def close(self) -> None: """Close the connection pool.""" if self._pool: await self._pool.close() - log.info("Closed PostgreSQL connection pool") \ No newline at end of file + log.info("Closed PostgreSQL connection pool") diff --git a/graphrag/storage/postgres_pipeline_storage2.py b/graphrag/storage/postgres_pipeline_storage_bk.py similarity index 69% rename from graphrag/storage/postgres_pipeline_storage2.py rename to graphrag/storage/postgres_pipeline_storage_bk.py index 0a265d8316..319886db66 100644 --- a/graphrag/storage/postgres_pipeline_storage2.py +++ b/graphrag/storage/postgres_pipeline_storage_bk.py @@ -764,7 +764,7 @@ def find( return iter([]) async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: - """Retrieve data from PostgreSQL table.""" + """Retrieve data from PostgreSQL table - simplified approach.""" try: table_name = self._get_table_name(key) log.info(f"Retrieving data from table: {table_name}") @@ -778,285 +778,342 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None ) if not table_exists: - log.warning(f"Table {table_name} does not exist") return None - # Determine if this is a 
typed table or generic table - is_typed_table = any(table_type in table_name for table_type in - ['entities', 'relationships', 'communities', 'text_units', 'documents']) - - if is_typed_table: - # For typed tables, select all columns except created_at/updated_at - if 'documents' in table_name: - query = "SELECT id, human_readable_id, title, text, text_unit_ids, creation_date, metadata FROM {} ORDER BY created_at".format(table_name) - elif 'entities' in table_name: - query = "SELECT id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y FROM {} ORDER BY created_at".format(table_name) - elif 'relationships' in table_name: - query = "SELECT id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids FROM {} ORDER BY created_at".format(table_name) - elif 'communities' in table_name: - query = "SELECT id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) - elif 'text_units' in table_name: - query = "SELECT id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) - else: - # Fallback for unknown typed table - query = "SELECT * FROM {} ORDER BY created_at".format(table_name) - else: - # For generic tables, use the data column - query = "SELECT id, data FROM {} ORDER BY created_at".format(table_name) - - rows = await conn.fetch(query) + # Simple approach: get all data and convert directly to DataFrame + rows = await conn.fetch(f"SELECT * FROM {table_name} ORDER BY created_at") if not rows: - log.info(f"No data found in table {table_name}") return None - log.info(f"Retrieved {len(rows)} records from table {table_name}") + # Convert to DataFrame with minimal transformation + records = [dict(row) for row in rows] + df = pd.DataFrame(records) + + # Only handle JSONB fields - convert back from JSON strings to lists/dicts + for col in df.columns: + if col in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: + df[col] = df[col].apply(self._parse_jsonb_field) + + # Handle bytes conversion for GraphRAG compatibility + if as_bytes or kwargs.get("as_bytes"): + return self._dataframe_to_parquet_bytes(df) + + return df + + finally: + await self._release_connection(conn) + + except Exception as e: + log.exception(f"Error retrieving data from table {table_name}: {e}") + return None - # Check if this should be treated as raw data instead of tabular data - if (not key.endswith('.parquet') or - 'state' in key.lower() or - key.endswith('.json') or - 'context' in table_name.lower()): - # Handle state.json or context.json as raw data - # For non-tabular data, return the raw content from the first record - if rows: - if is_typed_table: - # For typed tables, convert row to dict and return as JSON - row_dict = dict(rows[0]) - json_str = json.dumps(row_dict) - return json_str.encode(encoding or self._encoding) if as_bytes else json_str - elif 'data' in rows[0]: - raw_content = rows[0]['data'] - if isinstance(raw_content, dict): - json_str = json.dumps(raw_content) - return json_str.encode(encoding or self._encoding) if as_bytes else json_str - return b"" if as_bytes else "" + def _parse_jsonb_field(self, value): + """Simple JSONB field parser.""" + if value is None: + return [] + if isinstance(value, (list, dict)): + return value + if isinstance(value, str): + try: + return json.loads(value) + except: + return [] + return [] + # async def get(self, 
key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: + # """Retrieve data from PostgreSQL table.""" + # try: + # table_name = self._get_table_name(key) + # log.info(f"Retrieving data from table: {table_name}") + + # conn = await self._get_connection() + # try: + # # Check if table exists + # table_exists = await conn.fetchval( + # "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", + # table_name + # ) + + # if not table_exists: + # log.warning(f"Table {table_name} does not exist") + # return None + + # # Determine if this is a typed table or generic table + # is_typed_table = any(table_type in table_name for table_type in + # ['entities', 'relationships', 'communities', 'text_units', 'documents']) + + # if is_typed_table: + # # For typed tables, select all columns except created_at/updated_at + # if 'documents' in table_name: + # query = "SELECT id, human_readable_id, title, text, text_unit_ids, creation_date, metadata FROM {} ORDER BY created_at".format(table_name) + # elif 'entities' in table_name: + # query = "SELECT id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y FROM {} ORDER BY created_at".format(table_name) + # elif 'relationships' in table_name: + # query = "SELECT id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids FROM {} ORDER BY created_at".format(table_name) + # elif 'communities' in table_name: + # query = "SELECT id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) + # elif 'text_units' in table_name: + # query = "SELECT id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) + # else: + # # Fallback for unknown typed table + # query = "SELECT * FROM {} ORDER BY created_at".format(table_name) + # else: + # # For generic tables, use the data column + # query = "SELECT id, data FROM {} ORDER BY created_at".format(table_name) + + # rows = await conn.fetch(query) + + # if not rows: + # log.info(f"No data found in table {table_name}") + # return None + + # log.info(f"Retrieved {len(rows)} records from table {table_name}") + + # # Check if this should be treated as raw data instead of tabular data + # if (not key.endswith('.parquet') or + # 'state' in key.lower() or + # key.endswith('.json') or + # 'context' in table_name.lower()): + # # Handle state.json or context.json as raw data + # # For non-tabular data, return the raw content from the first record + # if rows: + # if is_typed_table: + # # For typed tables, convert row to dict and return as JSON + # row_dict = dict(rows[0]) + # json_str = json.dumps(row_dict) + # return json_str.encode(encoding or self._encoding) if as_bytes else json_str + # elif 'data' in rows[0]: + # raw_content = rows[0]['data'] + # if isinstance(raw_content, dict): + # json_str = json.dumps(raw_content) + # return json_str.encode(encoding or self._encoding) if as_bytes else json_str + # return b"" if as_bytes else "" - # Convert to DataFrame - records = [] - for row in rows: - if is_typed_table: - # For typed tables, the row is already the data we need - record_data = dict(row) + # # Convert to DataFrame + # records = [] + # for row in rows: + # if is_typed_table: + # # For typed tables, the row is already the data we need + # record_data = dict(row) - # Convert JSONB fields back to proper Python objects - for field in ['text_unit_ids', 
'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: - if field in record_data: - value = record_data[field] + # # Convert JSONB fields back to proper Python objects + # for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: + # if field in record_data: + # value = record_data[field] - if value is None: - record_data[field] = {} if field == 'metadata' else [] - elif isinstance(value, str): - # Handle JSONB strings - they should always be valid JSON - try: - parsed = json.loads(value) - # Validate the parsed type - if field == 'metadata': - record_data[field] = parsed if isinstance(parsed, dict) else {} - else: - record_data[field] = parsed if isinstance(parsed, list) else [] - except (json.JSONDecodeError, TypeError): - log.warning(f"Failed to parse JSONB field {field}: {value}") - # Fallback for non-JSON strings - if field == 'metadata': - record_data[field] = {} - else: - record_data[field] = [] - elif isinstance(value, (list, dict)): - # Already correct type (shouldn't happen with JSONB, but handle it) - record_data[field] = value - else: - # Convert other types - if field == 'metadata': - record_data[field] = {'value': str(value)} if value else {} - else: - record_data[field] = [value] if value else [] - else: - # Handle generic table data (JSONB data column) - if isinstance(row['data'], dict): - record_data = dict(row['data']) - else: - # If it's a string, parse it as JSON - record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] + # if value is None: + # record_data[field] = {} if field == 'metadata' else [] + # elif isinstance(value, str): + # # Handle JSONB strings - they should always be valid JSON + # try: + # parsed = json.loads(value) + # # Validate the parsed type + # if field == 'metadata': + # record_data[field] = parsed if isinstance(parsed, dict) else {} + # else: + # record_data[field] = parsed if isinstance(parsed, list) else [] + # except (json.JSONDecodeError, TypeError): + # log.warning(f"Failed to parse JSONB field {field}: {value}") + # # Fallback for non-JSON strings + # if field == 'metadata': + # record_data[field] = {} + # else: + # record_data[field] = [] + # elif isinstance(value, (list, dict)): + # # Already correct type (shouldn't happen with JSONB, but handle it) + # record_data[field] = value + # else: + # # Convert other types + # if field == 'metadata': + # record_data[field] = {'value': str(value)} if value else {} + # else: + # record_data[field] = [value] if value else [] + # else: + # # Handle generic table data (JSONB data column) + # if isinstance(row['data'], dict): + # record_data = dict(row['data']) + # else: + # # If it's a string, parse it as JSON + # record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] - # Clean up the record data - convert None to proper values and handle NaN - cleaned_data = {} - for key_name, value in record_data.items(): - if self._is_scalar_na(value) or value is None: - cleaned_data[key_name] = None - elif isinstance(value, str) and key_name == 'text_unit_ids' and not is_typed_table: - # Try to parse text_unit_ids back from JSON string if needed (only for generic tables) - try: - parsed_value = json.loads(value) - cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [value] - except (json.JSONDecodeError, TypeError): - # If it's not JSON, treat as a single item list or keep as string - cleaned_data[key_name] = [value] if value else [] - elif key_name in 
['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids'] and not is_typed_table: - # Always ensure these columns are lists (only for generic tables - typed tables already handled this) - if isinstance(value, list): - cleaned_data[key_name] = value - elif isinstance(value, str): - try: - parsed_value = json.loads(value) - cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [] - except (json.JSONDecodeError, TypeError): - cleaned_data[key_name] = [] - elif value is None: - cleaned_data[key_name] = [] - else: - # fallback: wrap single value in a list - cleaned_data[key_name] = [value] - elif isinstance(value, (list, np.ndarray)) and len(value) == 0: - # Handle empty arrays/lists - cleaned_data[key_name] = [] - else: - cleaned_data[key_name] = value + # # Clean up the record data - convert None to proper values and handle NaN + # cleaned_data = {} + # for key_name, value in record_data.items(): + # if self._is_scalar_na(value) or value is None: + # cleaned_data[key_name] = None + # elif isinstance(value, str) and key_name == 'text_unit_ids' and not is_typed_table: + # # Try to parse text_unit_ids back from JSON string if needed (only for generic tables) + # try: + # parsed_value = json.loads(value) + # cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [value] + # except (json.JSONDecodeError, TypeError): + # # If it's not JSON, treat as a single item list or keep as string + # cleaned_data[key_name] = [value] if value else [] + # elif key_name in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids'] and not is_typed_table: + # # Always ensure these columns are lists (only for generic tables - typed tables already handled this) + # if isinstance(value, list): + # cleaned_data[key_name] = value + # elif isinstance(value, str): + # try: + # parsed_value = json.loads(value) + # cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [] + # except (json.JSONDecodeError, TypeError): + # cleaned_data[key_name] = [] + # elif value is None: + # cleaned_data[key_name] = [] + # else: + # # fallback: wrap single value in a list + # cleaned_data[key_name] = [value] + # elif isinstance(value, (list, np.ndarray)) and len(value) == 0: + # # Handle empty arrays/lists + # cleaned_data[key_name] = [] + # else: + # cleaned_data[key_name] = value - # Always include the ID column for GraphRAG compatibility - # Use the storage ID as is since we simplified ID handling - storage_id = row['id'] - cleaned_data['id'] = storage_id - records.append(cleaned_data) + # # Always include the ID column for GraphRAG compatibility + # # Use the storage ID as is since we simplified ID handling + # storage_id = row['id'] + # cleaned_data['id'] = storage_id + # records.append(cleaned_data) - df = pd.DataFrame(records) + # df = pd.DataFrame(records) - # Additional cleanup for NaN values in the DataFrame - df = df.where(pd.notna(df), None) - log.info(f"Created DataFrame with shape: {df.shape}") - log.info(f"Table {table_name} DataFrame columns: {df.columns.tolist()}") + # # Additional cleanup for NaN values in the DataFrame + # df = df.where(pd.notna(df), None) + # log.info(f"Created DataFrame with shape: {df.shape}") + # log.info(f"Table {table_name} DataFrame columns: {df.columns.tolist()}") - if len(df) > 0: - log.debug(f"Table {table_name} Sample record: {df.iloc[0].to_dict()}") - # Debug: Check if children column exists and its type - if 'children' in df.columns: - sample_children = df.iloc[0]['children'] - 
log.debug(f"Table {table_name} Sample children value: {sample_children}, type: {type(sample_children)}") + # if len(df) > 0: + # log.debug(f"Table {table_name} Sample record: {df.iloc[0].to_dict()}") + # # Debug: Check if children column exists and its type + # if 'children' in df.columns: + # sample_children = df.iloc[0]['children'] + # log.debug(f"Table {table_name} Sample children value: {sample_children}, type: {type(sample_children)}") - # Handle bytes conversion for GraphRAG compatibility - if as_bytes or kwargs.get("as_bytes"): - log.info(f"Converting DataFrame to parquet bytes for key: {key}") + # # Handle bytes conversion for GraphRAG compatibility + # if as_bytes or kwargs.get("as_bytes"): + # log.info(f"Converting DataFrame to parquet bytes for key: {key}") - # Apply column filtering similar to Milvus implementation - df_clean = df.copy() + # # Apply column filtering similar to Milvus implementation + # df_clean = df.copy() - # Define expected columns for each data type - if 'documents' in table_name: - expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] - elif 'entities' in table_name: - expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'] - elif 'relationships' in table_name: - expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] - if 'combined_degree' in df_clean.columns: - expected_columns.append('combined_degree') - elif 'text_units' in table_name: - expected_columns = ['id', 'human_readable_id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] - elif 'communities' in table_name: - expected_columns = ['id', 'human_readable_id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] - else: - expected_columns = list(df_clean.columns) + # # Define expected columns for each data type + # if 'documents' in table_name: + # expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] + # elif 'entities' in table_name: + # expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'] + # elif 'relationships' in table_name: + # expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] + # if 'combined_degree' in df_clean.columns: + # expected_columns.append('combined_degree') + # elif 'text_units' in table_name: + # expected_columns = ['id', 'human_readable_id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] + # elif 'communities' in table_name: + # expected_columns = ['id', 'human_readable_id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] + # else: + # expected_columns = list(df_clean.columns) - # Filter columns - available_columns = [col for col in expected_columns if col in df_clean.columns] - if available_columns != expected_columns: - missing = set(expected_columns) - set(available_columns) - extra = set(df_clean.columns) - set(expected_columns) - log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") + # # Filter columns + # available_columns = [col for col in expected_columns if col in df_clean.columns] + # if available_columns != expected_columns: + # missing = set(expected_columns) - set(available_columns) + # extra = set(df_clean.columns) - 
set(expected_columns) + # log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") - df_clean = df_clean[available_columns] - log.info(f"Table {table_name} final filtered columns: {df_clean.columns.tolist()}") + # df_clean = df_clean[available_columns] + # log.info(f"Table {table_name} final filtered columns: {df_clean.columns.tolist()}") - # Convert to parquet bytes - try: - # Handle list columns for PyArrow compatibility - df_for_parquet = df_clean.copy() + # # Convert to parquet bytes + # try: + # # Handle list columns for PyArrow compatibility + # df_for_parquet = df_clean.copy() - # For PyArrow/parquet compatibility, we need to handle list columns carefully - # Instead of converting to JSON strings, let's try a different approach - list_columns = [] - for col in df_for_parquet.columns: - if col in df_for_parquet.columns and len(df_for_parquet) > 0: - # Check if this column contains lists - first_non_null = None - for val in df_for_parquet[col]: - if isinstance(val, list): - first_non_null = val - break - elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - first_non_null = val - break + # # For PyArrow/parquet compatibility, we need to handle list columns carefully + # # Instead of converting to JSON strings, let's try a different approach + # list_columns = [] + # for col in df_for_parquet.columns: + # if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # # Check if this column contains lists + # first_non_null = None + # for val in df_for_parquet[col]: + # if isinstance(val, list): + # first_non_null = val + # break + # elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + # first_non_null = val + # break - if isinstance(first_non_null, list) or col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: - list_columns.append(col) - # Ensure all values in this column are proper lists - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: x if isinstance(x, list) else ([] if x is None or pd.isna(x) else [str(x)]) - ) + # if isinstance(first_non_null, list) or col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: + # list_columns.append(col) + # # Ensure all values in this column are proper lists + # df_for_parquet[col] = df_for_parquet[col].apply( + # lambda x: x if isinstance(x, list) else ([] if x is None or pd.isna(x) else [str(x)]) + # ) - if list_columns: - log.info(f"Ensured list columns are proper lists for parquet: {list_columns}") + # if list_columns: + # log.info(f"Ensured list columns are proper lists for parquet: {list_columns}") - # Try to convert to parquet without JSON string conversion - buffer = BytesIO() - df_for_parquet.to_parquet(buffer, engine='pyarrow') - buffer.seek(0) - parquet_bytes = buffer.getvalue() - log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") - return parquet_bytes - except Exception as e: - log.warning(f"Direct parquet conversion failed: {e}, trying with JSON string conversion") + # # Try to convert to parquet without JSON string conversion + # buffer = BytesIO() + # df_for_parquet.to_parquet(buffer, engine='pyarrow') + # buffer.seek(0) + # parquet_bytes = buffer.getvalue() + # log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") + # return parquet_bytes + # except Exception as e: + # log.warning(f"Direct parquet conversion failed: {e}, trying with 
JSON string conversion") - # Fallback: convert lists to JSON strings - try: - df_for_parquet = df_clean.copy() + # # Fallback: convert lists to JSON strings + # try: + # df_for_parquet = df_clean.copy() - # Convert list columns to JSON strings for parquet compatibility - list_columns = [] - for col in df_for_parquet.columns: - if col in df_for_parquet.columns and len(df_for_parquet) > 0: - # Check if this column contains lists - first_non_null = None - for val in df_for_parquet[col]: - if isinstance(val, list): - first_non_null = val - break - elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - first_non_null = val - break - if isinstance(first_non_null, list): - list_columns.append(col) - # Convert lists to JSON strings - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - ) - elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: - # These columns should always be lists, even if empty - list_columns.append(col) - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - ) + # # Convert list columns to JSON strings for parquet compatibility + # list_columns = [] + # for col in df_for_parquet.columns: + # if col in df_for_parquet.columns and len(df_for_parquet) > 0: + # # Check if this column contains lists + # first_non_null = None + # for val in df_for_parquet[col]: + # if isinstance(val, list): + # first_non_null = val + # break + # elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): + # first_non_null = val + # break + # if isinstance(first_non_null, list): + # list_columns.append(col) + # # Convert lists to JSON strings + # df_for_parquet[col] = df_for_parquet[col].apply( + # lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + # ) + # elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: + # # These columns should always be lists, even if empty + # list_columns.append(col) + # df_for_parquet[col] = df_for_parquet[col].apply( + # lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) + # ) - if list_columns: - log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") + # if list_columns: + # log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") - buffer = BytesIO() - df_for_parquet.to_parquet(buffer, engine='pyarrow') - buffer.seek(0) - parquet_bytes = buffer.getvalue() - log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data (with JSON conversion)") - return parquet_bytes - except Exception as e2: - log.exception(f"Failed to convert DataFrame to parquet bytes: {e2}") - return b"" + # buffer = BytesIO() + # df_for_parquet.to_parquet(buffer, engine='pyarrow') + # buffer.seek(0) + # parquet_bytes = buffer.getvalue() + # log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data (with JSON conversion)") + # return parquet_bytes + # except Exception as e2: + # log.exception(f"Failed to convert DataFrame to parquet bytes: {e2}") + # return b"" - return df + # return df - finally: - await self._release_connection(conn) + # finally: + # await self._release_connection(conn) - except Exception as e: - log.exception(f"Error retrieving data from table {table_name}: {e}") - return 
None + # except Exception as e: + # log.exception(f"Error retrieving data from table {table_name}: {e}") + # return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: """Insert data into PostgreSQL table with drop/recreate to avoid duplicates.""" From e79a6cf93cfb24c82914fabf60e522af78943074 Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Fri, 22 Aug 2025 23:44:12 -0700 Subject: [PATCH 07/12] cleanup --- .../index/operations/finalize_entities.py | 2 +- graphrag/storage/postgres_pipeline_storage.py | 281 +----------------- 2 files changed, 13 insertions(+), 270 deletions(-) diff --git a/graphrag/index/operations/finalize_entities.py b/graphrag/index/operations/finalize_entities.py index 3dba9d1a5c..460b8c1c63 100644 --- a/graphrag/index/operations/finalize_entities.py +++ b/graphrag/index/operations/finalize_entities.py @@ -24,7 +24,7 @@ def finalize_entities( """All the steps to transform final entities.""" # # Remove the default column degree, x and y for Postgres storage compatibility. And below entities.merge method # # will add them back with calculated values. - # entities = entities.drop(columns=["degree", "x", "y"], errors="ignore") + entities = entities.drop(columns=["degree", "x", "y"], errors="ignore") graph = create_graph(relationships, edge_attr=["weight"]) graph_embeddings = None if embed_config is not None and embed_config.enabled: diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py index e28a460d1c..7ecd311714 100644 --- a/graphrag/storage/postgres_pipeline_storage.py +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -9,11 +9,11 @@ from collections.abc import Iterator from io import BytesIO from typing import Any - import numpy as np import pandas as pd import asyncpg from asyncpg import Connection, Pool +import asyncio from graphrag.logger.progress import Progress from graphrag.storage.pipeline_storage import ( @@ -142,35 +142,6 @@ async def _ensure_table_exists(self, table_name: str) -> None: finally: await self._release_connection(conn) - # def _process_id_field(self, df: pd.DataFrame, table_name: str) -> list[str]: - # """Process ID values - store clean IDs with prefix following CosmosDB pattern.""" - # prefix = self._get_prefix(table_name.replace(self._collection_prefix, "")) - # id_values = [] - - # if "id" not in df.columns: - # # No ID column - create prefixed sequential IDs and track this prefix - # for index in range(len(df)): - # id_values.append(f"{prefix}:{index}") - # if prefix not in self._no_id_prefixes: - # self._no_id_prefixes.append(prefix) - # log.info(f"No ID column found for {prefix}, generated prefixed sequential IDs") - # else: - # # Has ID column - process each row with prefix - # for index, val in enumerate(df["id"]): - # if self._is_scalar_na(val) or val == '' or val == 'nan' or (isinstance(val, list) and (len(val) == 0 or self._is_scalar_na(val[0]) or str(val[0]).strip() == '')): - # # Missing ID - create prefixed sequential ID and track this prefix - # id_values.append(f"{prefix}:{index}") - # if prefix not in self._no_id_prefixes: - # self._no_id_prefixes.append(prefix) - # else: - # # Valid ID - use with prefix (following CosmosDB pattern) - # if isinstance(val, list): - # id_values.append(f"{prefix}:{val[0]}") - # else: - # id_values.append(f"{prefix}:{val}") - - # return id_values - def _is_scalar_na(self, value: Any) -> bool: """Safely check if a value is NA/null, avoiding issues with arrays.""" try: @@ -235,9 +206,10 @@ def 
_prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[ log.info(f"Prepared {len(records)} records for PostgreSQL") return records - async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict], batch_size: int = 1000) -> None: + async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict]) -> None: """Perform high-performance batch upsert of records using executemany.""" total_records = len(records) + batch_size = self._batch_size log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {batch_size}") processed_count = 0 @@ -293,23 +265,16 @@ async def _batch_upsert_records(self, conn: Connection, table_name: str, records # Log progress every batch for visibility if i % batch_size == 0 or batch_end == total_records: log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") - + def find( - self, - file_pattern: re.Pattern[str], - base_dir: str | None = None, - file_filter: dict[str, Any] | None = None, - max_count=-1, + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, ) -> Iterator[tuple[str, dict[str, Any]]]: - """Find data in PostgreSQL tables using a file pattern regex.""" - # This is a synchronous method, but we need async operations - # For now, implement a basic version - in practice, this would need refactoring - log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) - - # Note: This is simplified - full implementation would need async/await support - # in the find method signature or use asyncio.run() return iter([]) - + def _parse_jsonb_field(self, value: Any, default_type: str = "list") -> Any: """Parse JSONB field back to Python object.""" if value is None: @@ -409,216 +374,6 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None except Exception as e: log.exception(f"Error retrieving data from table {table_name}: {e}") return None - async def get1(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: - """Retrieve data from PostgreSQL table.""" - try: - table_name = self._get_table_name(key) - log.info(f"Retrieving data from table: {table_name}") - - conn = await self._get_connection() - try: - # Check if table exists - table_exists = await conn.fetchval( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", - table_name - ) - - if not table_exists: - log.warning(f"Table {table_name} does not exist") - return None - - # Query all records for this prefix - rows = await conn.fetch(f"SELECT * FROM {table_name} ORDER BY created_at") - - if not rows: - log.info(f"No data found in table {table_name}") - return None - - log.info(f"Retrieved {len(rows)} records from table {table_name}") - - # Check if this should be treated as raw data instead of tabular data - if (not key.endswith('.parquet') or - 'state' in key.lower() or - key.endswith('.json') or - key.endswith('.txt') or - key.endswith('.yaml') or - key.endswith('.yml') or - 'context' in table_name.lower()): - # For non-tabular data, return the raw content from the first record - if rows and 'data' in rows[0]: - raw_content = rows[0]['data'] - if isinstance(raw_content, dict): - json_str = json.dumps(raw_content) - return json_str.encode(encoding or self._encoding) if as_bytes else json_str - return b"" if as_bytes else "" - - # Convert to DataFrame - 
records = [] - for row in rows: - # Handle JSONB data properly - row['data'] should already be a dict from asyncpg - if isinstance(row['data'], dict): - record_data = dict(row['data']) - else: - # If it's a string, parse it as JSON - record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] - - # Clean up the record data - convert None to proper values and handle NaN - cleaned_data = {} - for key, value in record_data.items(): - if self._is_scalar_na(value) or value is None: - cleaned_data[key] = None - elif isinstance(value, str) and key == 'text_unit_ids': - # Try to parse text_unit_ids back from JSON string if needed - try: - parsed_value = json.loads(value) - cleaned_data[key] = parsed_value if isinstance(parsed_value, list) else [value] - except (json.JSONDecodeError, TypeError): - # If it's not JSON, treat as a single item list or keep as string - cleaned_data[key] = [value] if value else [] - elif key in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids']: - # Always ensure these columns are lists - if isinstance(value, list): - cleaned_data[key] = value - elif isinstance(value, str): - try: - parsed_value = json.loads(value) - cleaned_data[key] = parsed_value if isinstance(parsed_value, list) else [] - except (json.JSONDecodeError, TypeError): - cleaned_data[key] = [] - elif value is None: - cleaned_data[key] = [] - else: - # fallback: wrap single value in a list - cleaned_data[key] = [value] - elif isinstance(value, (list, np.ndarray)) and len(value) == 0: - # Handle empty arrays/lists - cleaned_data[key] = [] - else: - cleaned_data[key] = value - - # # Always include the ID column for GraphRAG compatibility - # # Extract the actual ID from the prefixed storage ID - # storage_id = row['id'] - # # if ':' in storage_id: - # # actual_id = storage_id.split(':', 1)[1] - # # # Only use the actual ID if it's not a sequential index - # # if not actual_id.isdigit() or prefix not in self._no_id_prefixes: - # # cleaned_data['id'] = actual_id - # # else: - # # # For auto-generated sequential IDs, use the storage ID as the ID - # # cleaned_data['id'] = storage_id - # # else: - # # # If no prefix found, use the storage ID as is - # cleaned_data['id'] = storage_id - records.append(cleaned_data) - - df = pd.DataFrame(records) - - # Additional cleanup for NaN values in the DataFrame - df = df.where(pd.notna(df), None) - log.info(f"Get DataFrame with shape: {df.shape}") - log.info(f"DataFrame columns: {df.columns.tolist()}") - - # if len(df) > 0: - # log.info(f"Sample record: {df.iloc[0].to_dict()}") - # # Debug: Check if children column exists and its type - # if 'children' in df.columns: - # sample_children = df.iloc[0]['children'] - # log.info(f"Sample children value: {sample_children}, type: {type(sample_children)}") - - # Handle bytes conversion for GraphRAG compatibility - if as_bytes or kwargs.get("as_bytes"): - log.info(f"Converting DataFrame to parquet bytes for key: {key}") - - # Apply column filtering similar to Milvus implementation - df_clean = df.copy() - - # Define expected columns for each data type - if 'documents' in table_name: - expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] - # Include text_unit_ids if it has meaningful data - if 'text_unit_ids' in df_clean.columns and any( - len(tuid) > 0 for tuid in df_clean['text_unit_ids'] if isinstance(tuid, list) - ): - expected_columns.insert(4, 'text_unit_ids') - log.info("Including text_unit_ids (appears to be final documents)") - elif 'entities' in 
table_name: - # Exclude degree column for GraphRAG compatibility - expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency'] - log.info("Excluding degree column from entities for finalize_entities compatibility") - elif 'relationships' in table_name: - expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] - if 'combined_degree' in df_clean.columns: - expected_columns.append('combined_degree') - elif 'text_units' in table_name: - expected_columns = ['id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] - elif 'communities' in table_name: - expected_columns = ['id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] - else: - expected_columns = list(df_clean.columns) - - # Filter columns - available_columns = [col for col in expected_columns if col in df_clean.columns] - if available_columns != expected_columns: - missing = set(expected_columns) - set(available_columns) - extra = set(df_clean.columns) - set(expected_columns) - log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") - - df_clean = df_clean[available_columns] - log.info(f"Final filtered columns: {df_clean.columns.tolist()}") - - # Convert to parquet bytes - try: - # Handle list columns that PyArrow can't serialize directly - df_for_parquet = df_clean.copy() - - # Convert list columns to JSON strings for parquet compatibility - list_columns = [] - for col in df_for_parquet.columns: - if col in df_for_parquet.columns and len(df_for_parquet) > 0: - # Check if this column contains lists - first_non_null = None - for val in df_for_parquet[col]: - if isinstance(val, list): - first_non_null = val - break - elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - first_non_null = val - break - if isinstance(first_non_null, list): - list_columns.append(col) - # Convert lists to JSON strings - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - ) - elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids']: - # These columns should always be lists, even if empty - list_columns.append(col) - df_for_parquet[col] = df_for_parquet[col].apply( - lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - ) - - if list_columns: - log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") - - buffer = BytesIO() - df_for_parquet.to_parquet(buffer, engine='pyarrow') - buffer.seek(0) - parquet_bytes = buffer.getvalue() - log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") - return parquet_bytes - except Exception as e: - log.exception(f"Failed to convert DataFrame to parquet bytes: {e}") - return b"" - - return df - - finally: - await self._release_connection(conn) - - except Exception as e: - log.exception(f"Error retrieving data from table {table_name}: {e}") - return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: """Insert data into PostgreSQL table with upsert capability.""" @@ -783,20 +538,8 @@ def keys(self) -> list[str]: def child(self, name: str | None) -> PipelineStorage: """Create a child storage instance.""" - if name is None: - return self - - # Create child with modified table prefix - child_prefix = 
f"{self._collection_prefix}{name}_" - return PostgresPipelineStorage( - host=self._host, - port=self._port, - database=self._database, - username=self._username, - password=self._password, - collection_prefix=child_prefix, - encoding=self._encoding, - ) + return self + async def get_creation_date(self, key: str) -> str: """Get the creation date for data.""" From 2afbdd1832b8187e19abd3dd36f13c3f6949d022 Mon Sep 17 00:00:00 2001 From: "dannyzheng@microsoft.com" Date: Sat, 23 Aug 2025 00:27:28 -0700 Subject: [PATCH 08/12] working version but need to clean up --- graphrag/storage/postgres_pipeline_storage.py | 153 ++++++++++++++---- 1 file changed, 123 insertions(+), 30 deletions(-) diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py index 7ecd311714..fb8761e6bb 100644 --- a/graphrag/storage/postgres_pipeline_storage.py +++ b/graphrag/storage/postgres_pipeline_storage.py @@ -145,16 +145,27 @@ async def _ensure_table_exists(self, table_name: str) -> None: def _is_scalar_na(self, value: Any) -> bool: """Safely check if a value is NA/null, avoiding issues with arrays.""" try: - # Don't check pd.isna on complex objects or large arrays - if isinstance(value, (list, dict)): - return False - if hasattr(value, '__len__') and len(str(value)) > 100: - return False + # Handle arrays/lists - check if it's an array-like object + if hasattr(value, '__len__') and hasattr(value, '__getitem__'): + # For arrays, check if all elements are NA + if isinstance(value, (list, tuple)): + return all(pd.isna(item) if not hasattr(item, '__len__') or len(str(item)) < 100 else False for item in value) + elif hasattr(value, 'size'): + # NumPy array - be careful with large arrays + if value.size > 100: + return False + try: + return pd.isna(value).all() if value.size > 1 else pd.isna(value.item()) + except (ValueError, TypeError): + return False + else: + return False + + # For scalar values, use pandas isna return pd.isna(value) except (ValueError, TypeError): # If pd.isna fails, assume it's not NA return False - def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[dict]: """Prepare DataFrame data for PostgreSQL insertion following CosmosDB pattern.""" log.info(f"Preparing data for table {table_name}, DataFrame shape: {df.shape}") @@ -176,9 +187,23 @@ def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[ # Convert numpy types to native Python types for JSON serialization for key, value in record_data.items(): - # Handle different value types carefully - if isinstance(value, (list, dict)): - # Keep lists and dicts as-is (like text_unit_ids) + if key in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids']: + # Clean list fields during storage preparation + if isinstance(value, list): + record_data[key] = self._ensure_string_list(value) + elif self._is_scalar_na(value) or value is None: + record_data[key] = [] + elif hasattr(value, '__len__') and len(value) == 0: + # Handle empty arrays/lists + record_data[key] = [] + elif hasattr(value, '__len__') and len(value) > 0: + # Handle non-empty arrays/lists + record_data[key] = self._ensure_string_list(value.tolist() if hasattr(value, 'tolist') else list(value)) + else: + # Handle single values or other scalar types + record_data[key] = [str(value)] if str(value).strip() else [] + elif isinstance(value, (list, dict)): + # Keep other lists and dicts as-is record_data[key] = value elif hasattr(value, 'item') and hasattr(value, 'size') and value.size 
== 1: # Only use .item() for numpy scalars (arrays of size 1) @@ -186,8 +211,18 @@ def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[ elif hasattr(value, 'tolist'): # Convert numpy arrays to Python lists record_data[key] = value.tolist() - elif isinstance(value, (pd.Timestamp, pd.DatetimeTZDtype)): - record_data[key] = value.isoformat() if pd.notna(value) else None + elif isinstance(value, pd.Timestamp): + # Handle pandas Timestamp objects + try: + record_data[key] = value.isoformat() if not self._is_scalar_na(value) else None + except AttributeError: + record_data[key] = str(value) if not self._is_scalar_na(value) else None + elif hasattr(value, 'isoformat') and callable(getattr(value, 'isoformat', None)): + # Handle other datetime-like objects + try: + record_data[key] = value.isoformat() if not self._is_scalar_na(value) else None + except (AttributeError, TypeError): + record_data[key] = str(value) if not self._is_scalar_na(value) else None elif self._is_scalar_na(value): # Only check pd.isna for scalar-like values record_data[key] = None @@ -205,7 +240,6 @@ def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[ log.info(f"Prepared {len(records)} records for PostgreSQL") return records - async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict]) -> None: """Perform high-performance batch upsert of records using executemany.""" total_records = len(records) @@ -276,16 +310,22 @@ def find( return iter([]) def _parse_jsonb_field(self, value: Any, default_type: str = "list") -> Any: - """Parse JSONB field back to Python object.""" + """Parse JSONB field back to Python object with better type consistency.""" if value is None: return {} if default_type == "dict" else [] if isinstance(value, (list, dict)): return value if isinstance(value, str): try: - return json.loads(value) + parsed = json.loads(value) + # Ensure we return the correct type + if default_type == "dict": + return parsed if isinstance(parsed, dict) else {} + else: + return parsed if isinstance(parsed, list) else [] except (json.JSONDecodeError, TypeError): return {} if default_type == "dict" else [] + # For any other type (including float/NaN), return empty default return {} if default_type == "dict" else [] def _convert_dataframe_to_parquet_bytes(self, df: pd.DataFrame) -> bytes: @@ -299,10 +339,30 @@ def _convert_dataframe_to_parquet_bytes(self, df: pd.DataFrame) -> bytes: log.error(f"Failed to convert DataFrame to parquet bytes: {e}") return b"" + def _ensure_string_list(self, value: Any) -> list[str]: + """Ensure a value is a list of strings, filtering out invalid items.""" + if not isinstance(value, list): + return [] + + result = [] + for item in value: + # Skip None values + if item is None: + continue + # Skip NaN values (both float NaN and string 'nan') + if isinstance(item, float) and (pd.isna(item) or item != item): # NaN check + continue + if isinstance(item, str) and item.lower() in ['nan', 'none', '']: + continue + # Convert to string and add + result.append(str(item)) + + return result + async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: """Retrieve data from PostgreSQL table.""" + table_name = self._get_table_name(key) try: - table_name = self._get_table_name(key) log.info(f"Retrieving data from table: {table_name}") conn = await self._get_connection() @@ -343,25 +403,61 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None for row in 
rows: record_data = dict(row['data']) if isinstance(row['data'], dict) else json.loads(row['data']) - # Parse JSONB list fields back to proper Python lists - for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids']: - if field in record_data: - record_data[field] = self._parse_jsonb_field(record_data[field], "list") - - # Parse metadata as dict - if 'metadata' in record_data: - record_data['metadata'] = self._parse_jsonb_field(record_data['metadata'], "dict") + # Clean up the record data with better type consistency + cleaned_data = {} + for field_name, value in record_data.items(): + if field_name in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids']: + # These should always be lists of strings + parsed_list = self._parse_jsonb_field(value, "list") + # Use the robust string list converter + cleaned_data[field_name] = self._ensure_string_list(parsed_list) + elif field_name == 'metadata': + # Metadata should be a dict + cleaned_data[field_name] = self._parse_jsonb_field(value, "dict") + elif self._is_scalar_na(value) or value is None: + cleaned_data[field_name] = None + elif isinstance(value, float) and pd.isna(value): + cleaned_data[field_name] = None + else: + cleaned_data[field_name] = value - records.append(record_data) + records.append(cleaned_data) df = pd.DataFrame(records) + # Additional cleanup - ensure list columns are properly typed using the robust method + for col in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids']: + if col in df.columns: + df[col] = df[col].apply(self._ensure_string_list) + # Clean up NaN values df = df.where(pd.notna(df), None) log.info(f"Get table {table_name} DataFrame with shape: {df.shape}") log.info(f"Get table {table_name} DataFrame columns: {df.columns.tolist()}") + # Debug: Log sample data to verify type consistency + if 'text_unit_ids' in df.columns and len(df) > 0: + sample_ids = df['text_unit_ids'].iloc[0] + log.info(f"Sample text_unit_ids: {sample_ids}, type: {type(sample_ids)}") + if isinstance(sample_ids, list) and len(sample_ids) > 0: + log.info(f"First item: {sample_ids[0]}, type: {type(sample_ids[0])}") + + # Check all values in the column for mixed types + all_types = set() + for idx, row_ids in enumerate(df['text_unit_ids']): + if isinstance(row_ids, list): + for item in row_ids: + all_types.add(type(item).__name__) + else: + all_types.add(type(row_ids).__name__) + + # Log first few rows for debugging + if idx < 3: + log.debug(f"Row {idx} text_unit_ids: {row_ids}, types: {[type(x).__name__ for x in row_ids] if isinstance(row_ids, list) else type(row_ids).__name__}") + + log.info(f"All types found in text_unit_ids column: {all_types}") + # Convert to bytes if requested if as_bytes or kwargs.get("as_bytes"): return self._convert_dataframe_to_parquet_bytes(df) @@ -372,7 +468,7 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None await self._release_connection(conn) except Exception as e: - log.exception(f"Error retrieving data from table {table_name}: {e}") + log.exception(f"Error retrieving data from table {table_name}: %s", e) return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: @@ -435,7 +531,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: await self._release_connection(conn) except Exception as e: - log.exception("Error setting data in PostgreSQL table %s: %s", table_name, e) + log.exception("Error setting data for key %s: %s", key, e) 
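Note on the list-field cleanup introduced in this patch: JSONB values come back from asyncpg either as Python lists or as JSON-encoded strings, and GraphRAG expects id-list columns such as text_unit_ids to be homogeneous lists of strings. Below is a minimal standalone sketch of that normalization, assuming only the standard library; the function name normalize_string_list is illustrative (the patch itself uses _parse_jsonb_field followed by _ensure_string_list), not part of the diff.

    import json
    import math

    def normalize_string_list(value):
        """Coerce a JSONB field into a clean list of strings (mirrors _ensure_string_list)."""
        if isinstance(value, str):
            # JSONB sometimes round-trips as a JSON-encoded string
            try:
                value = json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return []
        if not isinstance(value, list):
            return []
        cleaned = []
        for item in value:
            if item is None:
                continue
            if isinstance(item, float) and math.isnan(item):
                continue
            if isinstance(item, str) and item.lower() in {"nan", "none", ""}:
                continue
            cleaned.append(str(item))
        return cleaned

    # Example: normalize_string_list('["a", null, ""]') yields ["a"], and
    # normalize_string_list(["a", None, float("nan"), 3]) yields ["a", "3"].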
     async def has(self, key: str) -> bool:
         """Check if data exists for the given key."""
@@ -458,10 +554,7 @@ async def has(self, key: str) -> bool:
                     total_count = await conn.fetchval(
                         f"SELECT COUNT(*) FROM {table_name}"
                     )
-                    if total_count > 0:
-                        return True
-                    else:
-                        raise ValueError(f"No records found in table {table_name} for parquet key {key}")
+                    return total_count > 0
                 else:
                     # Check for exact key match
                     exists = await conn.fetchval(

From f99ff174da1a5557b384e66b06b0bd50e02d824e Mon Sep 17 00:00:00 2001
From: "dannyzheng@microsoft.com"
Date: Sun, 24 Aug 2025 22:21:53 -0700
Subject: [PATCH 09/12] removed debug code

---
 graphrag/storage/postgres_pipeline_storage.py | 24 -------------------
 1 file changed, 24 deletions(-)

diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py
index fb8761e6bb..3139c7fced 100644
--- a/graphrag/storage/postgres_pipeline_storage.py
+++ b/graphrag/storage/postgres_pipeline_storage.py
@@ -13,9 +13,7 @@
 import pandas as pd
 import asyncpg
 from asyncpg import Connection, Pool
-import asyncio
 
-from graphrag.logger.progress import Progress
 from graphrag.storage.pipeline_storage import (
     PipelineStorage,
     get_timestamp_formatted_with_local_tz,
@@ -436,28 +434,6 @@ async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None
            log.info(f"Get table {table_name} DataFrame with shape: {df.shape}")
            log.info(f"Get table {table_name} DataFrame columns: {df.columns.tolist()}")
 
-            # Debug: Log sample data to verify type consistency
-            if 'text_unit_ids' in df.columns and len(df) > 0:
-                sample_ids = df['text_unit_ids'].iloc[0]
-                log.info(f"Sample text_unit_ids: {sample_ids}, type: {type(sample_ids)}")
-                if isinstance(sample_ids, list) and len(sample_ids) > 0:
-                    log.info(f"First item: {sample_ids[0]}, type: {type(sample_ids[0])}")
-
-                # Check all values in the column for mixed types
-                all_types = set()
-                for idx, row_ids in enumerate(df['text_unit_ids']):
-                    if isinstance(row_ids, list):
-                        for item in row_ids:
-                            all_types.add(type(item).__name__)
-                    else:
-                        all_types.add(type(row_ids).__name__)
-
-                    # Log first few rows for debugging
-                    if idx < 3:
-                        log.debug(f"Row {idx} text_unit_ids: {row_ids}, types: {[type(x).__name__ for x in row_ids] if isinstance(row_ids, list) else type(row_ids).__name__}")
-
-                log.info(f"All types found in text_unit_ids column: {all_types}")
-
            # Convert to bytes if requested
            if as_bytes or kwargs.get("as_bytes"):
                return self._convert_dataframe_to_parquet_bytes(df)

From ff3d124cb0a211a61493f9d5d32150c63f2ccfd4 Mon Sep 17 00:00:00 2001
From: "dannyzheng@microsoft.com"
Date: Tue, 2 Sep 2025 20:40:03 -0700
Subject: [PATCH 10/12] removed unused code

---
 .../storage/postgres_pipeline_storage_bk.py | 1294 -----------------
 1 file changed, 1294 deletions(-)
 delete mode 100644 graphrag/storage/postgres_pipeline_storage_bk.py

diff --git a/graphrag/storage/postgres_pipeline_storage_bk.py b/graphrag/storage/postgres_pipeline_storage_bk.py
deleted file mode 100644
index 319886db66..0000000000
--- a/graphrag/storage/postgres_pipeline_storage_bk.py
+++ /dev/null
@@ -1,1294 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
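For context on the parquet handoff this storage class relies on (in both the retained implementation above and the backup copy deleted below): GraphRAG passes set() raw parquet bytes and expects get(..., as_bytes=True) to return parquet bytes. A minimal sketch of that round-trip follows, assuming pandas with the pyarrow engine installed; it mirrors the _convert_dataframe_to_parquet_bytes helper and the pd.read_parquet(BytesIO(value)) call seen in the diffs, and the standalone function names here are illustrative only.

    from io import BytesIO
    import pandas as pd

    def dataframe_to_parquet_bytes(df: pd.DataFrame) -> bytes:
        # Serialize through an in-memory buffer, as _convert_dataframe_to_parquet_bytes does.
        buffer = BytesIO()
        df.to_parquet(buffer, engine="pyarrow")
        return buffer.getvalue()

    def parquet_bytes_to_dataframe(raw: bytes) -> pd.DataFrame:
        # Inverse step used by set() when it receives parquet bytes from the pipeline.
        return pd.read_parquet(BytesIO(raw))

    # Round-trip example with a list-valued column similar to text_unit_ids:
    # df = pd.DataFrame({"id": ["e1"], "text_unit_ids": [["t1", "t2"]]})
    # parquet_bytes_to_dataframe(dataframe_to_parquet_bytes(df))  # -> 1 row, 2 columns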
-# Licensed under the MIT License - -"""PostgreSQL Storage implementation of PipelineStorage.""" - -import json -import logging -import re -from collections.abc import Iterator -from datetime import datetime, timezone -from io import BytesIO -from typing import Any - -import numpy as np -import pandas as pd -import asyncpg -from asyncpg import Connection, Pool - -from graphrag.storage.pipeline_storage import ( - PipelineStorage, - get_timestamp_formatted_with_local_tz, -) - -log = logging.getLogger(__name__) - -class PostgresPipelineStorage(PipelineStorage): - """The PostgreSQL Storage Implementation.""" - - _pool: Pool | None - _connection_string: str - _database: str - _collection_prefix: str - _encoding: str - _no_id_prefixes: list[str] - - def __init__( - self, - host: str = "localhost", - port: int = 5432, - database: str = "graphrag", - username: str = "postgres", - password: str | None = None, - collection_prefix: str = "lgr_", - encoding: str = "utf-8", - connection_string: str | None = None, - command_timeout: int = 600, # 10 minutes for SQL commands - server_timeout: int = 120, # 2 minutes for server connection - connection_timeout: int = 60, # 1 minute to establish connection - batch_size: int = 50, # Smaller batch size to reduce timeout risk - **kwargs: Any, - ): - """Initialize the PostgreSQL Storage.""" - self._host = host - self._port = port - self._database = database - self._username = username - self._password = password - self._collection_prefix = collection_prefix - self._encoding = encoding - self._command_timeout = command_timeout - self._server_timeout = server_timeout - self._connection_timeout = connection_timeout - self._batch_size = batch_size - self._no_id_prefixes = [] - self._pool = None - - # Build connection string from components or use provided one - if connection_string: - self._connection_string = connection_string - else: - if password: - self._connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}" - else: - self._connection_string = f"postgresql://{username}@{host}:{port}/{database}" - - log.info( - "Initializing PostgreSQL storage with host: %s:%s, database: %s, collection_prefix: %s, command_timeout: %s, batch_size: %s", - self._host, - self._port, - self._database, - self._collection_prefix, - self._command_timeout, - self._batch_size, - ) - - async def _get_connection(self) -> Connection: - """Get a database connection from the pool.""" - if self._pool is None: - try: - self._pool = await asyncpg.create_pool( - self._connection_string, - min_size=1, - max_size=10, - command_timeout=self._command_timeout, - server_settings={ - 'application_name': 'graphrag_postgres_storage' - }, - # Use connection_timeout for initial connection establishment - timeout=self._connection_timeout - ) - log.info("Created PostgreSQL connection pool with command_timeout: %s, connection_timeout: %s", - self._command_timeout, self._connection_timeout) - except Exception as e: - log.error("Failed to create PostgreSQL connection pool: %s", e) - raise - - return await self._pool.acquire() - - async def _release_connection(self, conn: Connection) -> None: - """Release a connection back to the pool.""" - if self._pool: - await self._pool.release(conn) - - def _get_table_name(self, key: str) -> str: - """Get the table name for a given key.""" - # Extract the base name without file extension - base_name = key.split(".")[0] - - return f"{self._collection_prefix}{base_name}" - - def _get_prefix(self, key: str) -> str: - """Get the prefix of the filename 
key.""" - return key.split(".")[0] - - def _get_entities_table_schema(self, table_name: str) -> str: - """Get the SQL schema for entities table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - title TEXT, - type TEXT, - description TEXT, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - frequency INTEGER DEFAULT 0, - degree INTEGER DEFAULT 0, - x DOUBLE PRECISION DEFAULT 0.0, - y DOUBLE PRECISION DEFAULT 0.0, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Performance indexes - CREATE INDEX idx_{table_name}_type ON {table_name}(type); - CREATE INDEX idx_{table_name}_frequency ON {table_name}(frequency); - CREATE INDEX idx_{table_name}_title ON {table_name}(title); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - """ - - def _get_relationships_table_schema(self, table_name: str) -> str: - """Get the SQL schema for relationships table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - source TEXT NOT NULL, - target TEXT NOT NULL, - description TEXT DEFAULT '', - weight DOUBLE PRECISION DEFAULT 0.0, - combined_degree INTEGER DEFAULT 0, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Graph query indexes - CREATE INDEX idx_{table_name}_source ON {table_name}(source); - CREATE INDEX idx_{table_name}_target ON {table_name}(target); - CREATE INDEX idx_{table_name}_weight ON {table_name}(weight); - CREATE INDEX idx_{table_name}_source_target ON {table_name}(source, target); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - """ - - def _get_communities_table_schema(self, table_name: str) -> str: - """Get the SQL schema for communities table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - community INTEGER, - level INTEGER DEFAULT 0, - parent INTEGER, - children JSONB DEFAULT '[]'::jsonb, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - entity_ids JSONB DEFAULT '[]'::jsonb, - relationship_ids JSONB DEFAULT '[]'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Community hierarchy indexes - CREATE INDEX idx_{table_name}_community ON {table_name}(community); - CREATE INDEX idx_{table_name}_level ON {table_name}(level); - CREATE INDEX idx_{table_name}_parent ON {table_name}(parent); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); - CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); - """ - - def _get_text_units_table_schema(self, table_name: str) -> str: - """Get the SQL schema for text_units table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - text TEXT, - n_tokens INTEGER DEFAULT 0, - document_ids JSONB DEFAULT '[]'::jsonb, - entity_ids JSONB DEFAULT '[]'::jsonb, - relationship_ids JSONB DEFAULT '[]'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Text search and relationship indexes - CREATE INDEX idx_{table_name}_n_tokens ON {table_name}(n_tokens); - CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING 
GIN(to_tsvector('english', text)); - CREATE INDEX idx_{table_name}_document_ids_gin ON {table_name} USING GIN(document_ids); - CREATE INDEX idx_{table_name}_entity_ids_gin ON {table_name} USING GIN(entity_ids); - CREATE INDEX idx_{table_name}_relationship_ids_gin ON {table_name} USING GIN(relationship_ids); - """ - - def _get_documents_table_schema(self, table_name: str) -> str: - """Get the SQL schema for documents table.""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - human_readable_id BIGINT, - title TEXT, - text TEXT, - text_unit_ids JSONB DEFAULT '[]'::jsonb, - creation_date TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - metadata JSONB DEFAULT '{{}}'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Document search indexes - CREATE INDEX idx_{table_name}_title ON {table_name}(title); - CREATE INDEX idx_{table_name}_creation_date ON {table_name}(creation_date); - CREATE INDEX idx_{table_name}_text_gin ON {table_name} USING GIN(to_tsvector('english', text)); - CREATE INDEX idx_{table_name}_text_unit_ids_gin ON {table_name} USING GIN(text_unit_ids); - CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); - """ - - def _get_generic_table_schema(self, table_name: str) -> str: - """Get the SQL schema for generic data (fallback).""" - return f""" - CREATE TABLE {table_name} ( - id TEXT PRIMARY KEY, - data JSONB NOT NULL, - metadata JSONB DEFAULT '{{}}'::jsonb, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() - ); - - -- Generic indexes - CREATE INDEX idx_{table_name}_data_gin ON {table_name} USING GIN(data); - CREATE INDEX idx_{table_name}_metadata_gin ON {table_name} USING GIN(metadata); - """ - - def _get_table_schema_sql(self, table_name: str) -> str: - """Get the appropriate schema SQL for the table type.""" - - if 'entities' in table_name: - return self._get_entities_table_schema(table_name) - elif 'relationships' in table_name: - return self._get_relationships_table_schema(table_name) - elif 'communities' in table_name: - return self._get_communities_table_schema(table_name) - elif 'text_units' in table_name: - return self._get_text_units_table_schema(table_name) - elif 'documents' in table_name: - return self._get_documents_table_schema(table_name) - else: - return self._get_generic_table_schema(table_name) - - async def _ensure_table_exists_with_schema(self, table_name: str) -> None: - conn = await self._get_connection() - try: - table_exists = await conn.fetchval( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", - table_name - ) - if not table_exists: - # Create table with appropriate typed schema (pass original table_name for type detection) - schema_sql = self._get_table_schema_sql(table_name) - await conn.execute(schema_sql) - log.info(f"Created table {table_name} with specific schema") - - finally: - await self._release_connection(conn) - - def _process_id_field(self, df: pd.DataFrame, table_name: str) -> list[str]: - """Process ID values - store clean IDs with prefix following CosmosDB pattern in GraphRAG.""" - prefix = self._get_prefix(table_name.replace(self._collection_prefix, "")) - id_values = [] - - if "id" not in df.columns: - # No ID column - create prefixed sequential IDs and track this prefix - for index in range(len(df)): - id_values.append(f"{prefix}:{index}") - if prefix not in self._no_id_prefixes: - self._no_id_prefixes.append(prefix) - log.info(f"No ID 
column found for {prefix}, generated prefixed sequential IDs") - else: - # Has ID column - process each row with prefix - for index, val in enumerate(df["id"]): - if self._is_scalar_na(val) or val == '' or val == 'nan' or (isinstance(val, list) and (len(val) == 0 or self._is_scalar_na(val[0]) or str(val[0]).strip() == '')): - # Missing ID - create prefixed sequential ID and track this prefix - id_values.append(f"{prefix}:{index}") - if prefix not in self._no_id_prefixes: - self._no_id_prefixes.append(prefix) - else: - # Valid ID - use as is without prefix - if isinstance(val, list): - id_values.append(str(val[0])) - else: - id_values.append(str(val)) - - return id_values - - def _is_scalar_na(self, value: Any) -> bool: - """Safely check if a value is NA/null, avoiding issues with arrays.""" - try: - # Don't check pd.isna on complex objects or large arrays - if isinstance(value, (list, dict)): - return False - if hasattr(value, '__len__') and len(str(value)) > 100: - return False - return pd.isna(value) - except (ValueError, TypeError): - # If pd.isna fails, assume it's not NA - return False - - def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[dict]: - """Prepare DataFrame data for PostgreSQL insertion with typed columns.""" - log.info(f"Preparing data for table {table_name}, DataFrame shape: {df.shape}") - log.info(f"DataFrame columns: {df.columns.tolist()}") - - # Add human_readable_id if missing - if 'human_readable_id' not in df.columns: - df = df.copy() - df['human_readable_id'] = range(len(df)) - log.info(f"Generated sequential human_readable_id for {len(df)} records") - - # Process IDs - for typed tables, we can use simpler ID handling - ids = self._process_id_field(df, table_name) - - # Determine if this is a typed table or generic table - is_typed_table = any(table_type in table_name for table_type in - ['entities', 'relationships', 'communities', 'text_units', 'documents']) - - if is_typed_table: - return self._prepare_data_for_typed_table(df, table_name, ids) - else: - return self._prepare_data_for_generic_table(df, table_name, ids) - - def _prepare_data_for_typed_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: - """Prepare data for typed PostgreSQL tables with specific columns.""" - records = [] - - for i in range(len(df)): - record = {'id': ids[i]} - row = df.iloc[i] - - # Map DataFrame columns to table columns based on table type - if 'entities' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'title': str(row.get('title', '')), - 'type': str(row.get('type', '')), - 'description': str(row.get('description', '')), - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), - 'frequency': int(row.get('frequency', 0)) if pd.notna(row.get('frequency', 0)) else 0, - 'degree': int(row.get('degree', 0)) if pd.notna(row.get('degree', 0)) else 0, - 'x': float(row.get('x', 0.0)) if pd.notna(row.get('x', 0.0)) else 0.0, - 'y': float(row.get('y', 0.0)) if pd.notna(row.get('y', 0.0)) else 0.0 - }) - elif 'relationships' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'source': str(row.get('source', '')), - 'target': str(row.get('target', '')), - 'description': str(row.get('description', '')), - 'weight': float(row.get('weight', 0.0)) if pd.notna(row.get('weight', 0.0)) else 0.0, - 'combined_degree': int(row.get('combined_degree', 0)) if pd.notna(row.get('combined_degree', 0)) else 0, - 'text_unit_ids': 
self._ensure_json_list(row.get('text_unit_ids', [])) - }) - elif 'communities' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'community': int(row.get('community', 0)) if pd.notna(row.get('community')) and str(row.get('community', '')).strip() != '' else 0, - 'level': int(row.get('level', 0)) if pd.notna(row.get('level', 0)) else 0, - 'parent': int(row.get('parent', 0)) if pd.notna(row.get('parent')) and str(row.get('parent', '')).strip() != '' else None, - 'children': self._ensure_json_list(row.get('children', [])), - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), - 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), - 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) - }) - elif 'text_units' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'text': str(row.get('text', '')), - 'n_tokens': int(row.get('n_tokens', 0)) if pd.notna(row.get('n_tokens', 0)) else 0, - 'document_ids': self._ensure_json_list(row.get('document_ids', [])), - 'entity_ids': self._ensure_json_list(row.get('entity_ids', [])), - 'relationship_ids': self._ensure_json_list(row.get('relationship_ids', [])) - }) - elif 'documents' in table_name: - record.update({ - 'human_readable_id': int(row.get('human_readable_id', i)), - 'title': str(row.get('title', '')), - 'text': str(row.get('text', '')), - 'text_unit_ids': self._ensure_json_list(row.get('text_unit_ids', [])), - 'creation_date': self._ensure_datetime(row.get('creation_date')), - 'metadata': self._ensure_json_dict(row.get('metadata', {})) - }) - - records.append(record) - - log.info(f"Prepared {len(records)} records for typed table {table_name}") - if records: - log.info(f"Sample typed record: {list(records[0].keys())}") - - return records - - def _prepare_data_for_generic_table(self, df: pd.DataFrame, table_name: str, ids: list[str]) -> list[dict]: - """Prepare data for generic PostgreSQL tables (fallback to JSONB storage).""" - records = [] - for i in range(len(df)): - # Create record with ID and all data in JSONB field - record_data = df.iloc[i].to_dict() - - # Convert numpy types to native Python types for JSON serialization - for key, value in record_data.items(): - if isinstance(value, (list, dict)): - record_data[key] = value - elif hasattr(value, 'tolist'): - # Handle numpy arrays and other numpy types - record_data[key] = value.tolist() - elif hasattr(value, 'item') and hasattr(value, 'size') and value.size == 1: - record_data[key] = value.item() - elif isinstance(value, (pd.Timestamp, pd.DatetimeTZDtype)): - record_data[key] = value.isoformat() if pd.notna(value) else None - elif self._is_scalar_na(value): - record_data[key] = None - elif isinstance(value, (list, np.ndarray)) and len(value) == 0: - record_data[key] = [] - else: - record_data[key] = value - - record = { - 'id': ids[i], - 'data': record_data, - 'metadata': {} - } - records.append(record) - - log.info(f"Prepared {len(records)} records for generic table {table_name}") - return records - - def _ensure_json_list(self, value: Any) -> list: - """Ensure a value is a proper list for JSONB storage.""" - if isinstance(value, list): - # Convert any numpy arrays in the list to regular Python lists - return [item.tolist() if hasattr(item, 'tolist') else item for item in value] - elif hasattr(value, 'tolist'): - # Handle numpy arrays directly - converted = value.tolist() - return converted if isinstance(converted, list) else [converted] - elif 
isinstance(value, str) and value: - try: - parsed = json.loads(value) - return parsed if isinstance(parsed, list) else [] - except (json.JSONDecodeError, TypeError): - return [] - elif value is None or pd.isna(value): - return [] - else: - return [value] if value else [] - - def _ensure_json_dict(self, value: Any) -> dict: - """Ensure a value is a proper dict for JSONB storage.""" - if isinstance(value, dict): - # Convert any numpy arrays in the dict to regular Python objects - result = {} - for k, v in value.items(): - if hasattr(v, 'tolist'): - result[k] = v.tolist() - elif hasattr(v, 'item') and hasattr(v, 'size') and v.size == 1: - result[k] = v.item() - else: - result[k] = v - return result - elif isinstance(value, str) and value: - try: - parsed = json.loads(value) - return parsed if isinstance(parsed, dict) else {} - except (json.JSONDecodeError, TypeError): - return {} - elif value is None or pd.isna(value): - return {} - else: - return {'value': str(value)} if value else {} - - def _ensure_timezone_aware_datetimes(self, records: list[dict]) -> list[dict]: - """Ensure all datetime fields in records are timezone-aware for PostgreSQL.""" - datetime_fields = ['creation_date', 'created_at', 'updated_at'] - - for record in records: - for field in datetime_fields: - if field in record: - value = record[field] - if value is not None: - record[field] = self._ensure_datetime(value) - - return records - - def _ensure_datetime(self, value: Any) -> datetime: - """Ensure a value is a proper timezone-aware datetime object for PostgreSQL storage.""" - from dateutil import parser - - if isinstance(value, datetime): - # If it's already a datetime, ensure it has timezone info - if value.tzinfo is None: - # If it's timezone-naive, localize to UTC - return value.replace(tzinfo=timezone.utc) - else: - # Already timezone-aware - return value - elif isinstance(value, pd.Timestamp): - # Convert pandas Timestamp to datetime - dt = value.to_pydatetime() - # Ensure timezone awareness - if dt.tzinfo is None: - return dt.replace(tzinfo=timezone.utc) - else: - return dt - elif isinstance(value, str) and value: - try: - # Try to parse the string as a datetime - parsed_dt = parser.parse(value) - # Ensure timezone awareness - if parsed_dt.tzinfo is None: - return parsed_dt.replace(tzinfo=timezone.utc) - else: - return parsed_dt - except (ValueError, TypeError): - # If parsing fails, return current time - return datetime.now(timezone.utc) - elif value is None or pd.isna(value): - return datetime.now(timezone.utc) - else: - # For any other type, return current time - return datetime.now(timezone.utc) - - async def _batch_upsert_records(self, conn: Connection, table_name: str, records: list[dict]) -> None: - """Perform high-performance batch upsert of records using executemany.""" - total_records = len(records) - log.info(f"Starting batch upsert of {total_records} records to {table_name} with batch size {self._batch_size}") - - # Ensure all datetime fields are timezone-aware - records = self._ensure_timezone_aware_datetimes(records) - - processed_count = 0 - - # Determine if this is a typed table or generic table - is_typed_table = any(table_type in table_name for table_type in - ['entities', 'relationships', 'communities', 'text_units', 'documents']) - - # Process records in batches for optimal performance - for i in range(0, total_records, self._batch_size): - batch = records[i:i + self._batch_size] - batch_end = min(i + self._batch_size, total_records) - - try: - if is_typed_table: - await 
self._batch_upsert_typed_records(conn, table_name, batch) - else: - await self._batch_upsert_generic_records(conn, table_name, batch) - - except Exception as e: - log.warning(f"Batch method failed for batch {i}-{batch_end}, falling back to individual inserts: {e}") - - # Fallback to individual inserts within the batch - try: - async with conn.transaction(): - if is_typed_table: - for record in batch: - await self._insert_typed_record(conn, table_name, record) - else: - upsert_sql = f""" - INSERT INTO {table_name} (id, data, updated_at) - VALUES ($1, $2, NOW()) - ON CONFLICT (id) - DO UPDATE SET - data = EXCLUDED.data, - updated_at = NOW() - """ - for record in batch: - await conn.execute(upsert_sql, record['id'], json.dumps(record['data'])) - except Exception as inner_e: - log.error(f"Both batch and individual insert methods failed for batch {i}-{batch_end}: {inner_e}") - raise - - processed_count += len(batch) - - # Log progress every batch for visibility - if i % self._batch_size == 0 or batch_end == total_records: - log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)") - - async def _batch_upsert_typed_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: - """Batch upsert for typed tables with specific columns.""" - async with conn.transaction(): - - if 'entities' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - title = EXCLUDED.title, - type = EXCLUDED.type, - description = EXCLUDED.description, - text_unit_ids = EXCLUDED.text_unit_ids, - frequency = EXCLUDED.frequency, - degree = EXCLUDED.degree, - x = EXCLUDED.x, - y = EXCLUDED.y, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['title'], r['type'], r['description'], - json.dumps(r['text_unit_ids']), r['frequency'], r['degree'], r['x'], r['y']) - for r in batch - ] - elif 'relationships' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - source = EXCLUDED.source, - target = EXCLUDED.target, - description = EXCLUDED.description, - weight = EXCLUDED.weight, - combined_degree = EXCLUDED.combined_degree, - text_unit_ids = EXCLUDED.text_unit_ids, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['source'], r['target'], r['description'], - r['weight'], r['combined_degree'], json.dumps(r['text_unit_ids'])) - for r in batch - ] - elif 'communities' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - community = EXCLUDED.community, - level = EXCLUDED.level, - parent = EXCLUDED.parent, - children = EXCLUDED.children, - text_unit_ids = EXCLUDED.text_unit_ids, - entity_ids = EXCLUDED.entity_ids, - relationship_ids = EXCLUDED.relationship_ids, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], 
r['community'], r['level'], r['parent'], - json.dumps(r['children']), json.dumps(r['text_unit_ids']), - json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) - for r in batch - ] - elif 'text_units' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - text = EXCLUDED.text, - n_tokens = EXCLUDED.n_tokens, - document_ids = EXCLUDED.document_ids, - entity_ids = EXCLUDED.entity_ids, - relationship_ids = EXCLUDED.relationship_ids, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['text'], r['n_tokens'], - json.dumps(r['document_ids']), json.dumps(r['entity_ids']), json.dumps(r['relationship_ids'])) - for r in batch - ] - elif 'documents' in table_name: - upsert_sql = f""" - INSERT INTO {table_name} (id, human_readable_id, title, text, text_unit_ids, creation_date, metadata, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) - ON CONFLICT (id) - DO UPDATE SET - human_readable_id = EXCLUDED.human_readable_id, - title = EXCLUDED.title, - text = EXCLUDED.text, - text_unit_ids = EXCLUDED.text_unit_ids, - creation_date = EXCLUDED.creation_date, - metadata = EXCLUDED.metadata, - updated_at = NOW() - """ - batch_data = [ - (r['id'], r['human_readable_id'], r['title'], r['text'], - json.dumps(r['text_unit_ids']), - self._ensure_datetime(r['creation_date']), - json.dumps(r['metadata'])) - for r in batch - ] - else: - raise ValueError(f"Unknown typed table: {table_name}") - - await conn.executemany(upsert_sql, batch_data) - - async def _batch_upsert_generic_records(self, conn: Connection, table_name: str, batch: list[dict]) -> None: - """Batch upsert for generic tables using JSONB.""" - async with conn.transaction(): - upsert_sql = f""" - INSERT INTO {table_name} (id, data, metadata, updated_at) - VALUES ($1, $2, $3, NOW()) - ON CONFLICT (id) - DO UPDATE SET - data = EXCLUDED.data, - metadata = EXCLUDED.metadata, - updated_at = NOW() - """ - batch_data = [ - (record['id'], json.dumps(record['data']), json.dumps(record['metadata'])) - for record in batch - ] - await conn.executemany(upsert_sql, batch_data) - - async def _insert_typed_record(self, conn: Connection, table_name: str, record: dict) -> None: - """Insert a single typed record (fallback method).""" - # This is a simplified fallback - implement based on table type if needed - # For now, just use the batch method with a single record - await self._batch_upsert_typed_records(conn, table_name, [record]) - - def find( - self, - file_pattern: re.Pattern[str], - base_dir: str | None = None, - file_filter: dict[str, Any] | None = None, - max_count=-1, - ) -> Iterator[tuple[str, dict[str, Any]]]: - """Find data in PostgreSQL tables using a file pattern regex.""" - # This is a synchronous method, but we need async operations - # For now, implement a basic version - in practice, this would need refactoring - log.info("Searching PostgreSQL tables for pattern %s", file_pattern.pattern) - - # Note: This is simplified - full implementation would need async/await support - # in the find method signature or use asyncio.run() - return iter([]) - - async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: - """Retrieve data from PostgreSQL table - simplified approach.""" - try: - table_name = self._get_table_name(key) - 
log.info(f"Retrieving data from table: {table_name}") - - conn = await self._get_connection() - try: - # Check if table exists - table_exists = await conn.fetchval( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", - table_name - ) - - if not table_exists: - return None - - # Simple approach: get all data and convert directly to DataFrame - rows = await conn.fetch(f"SELECT * FROM {table_name} ORDER BY created_at") - - if not rows: - return None - - # Convert to DataFrame with minimal transformation - records = [dict(row) for row in rows] - df = pd.DataFrame(records) - - # Only handle JSONB fields - convert back from JSON strings to lists/dicts - for col in df.columns: - if col in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: - df[col] = df[col].apply(self._parse_jsonb_field) - - # Handle bytes conversion for GraphRAG compatibility - if as_bytes or kwargs.get("as_bytes"): - return self._dataframe_to_parquet_bytes(df) - - return df - - finally: - await self._release_connection(conn) - - except Exception as e: - log.exception(f"Error retrieving data from table {table_name}: {e}") - return None - - def _parse_jsonb_field(self, value): - """Simple JSONB field parser.""" - if value is None: - return [] - if isinstance(value, (list, dict)): - return value - if isinstance(value, str): - try: - return json.loads(value) - except: - return [] - return [] - # async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None, **kwargs) -> Any: - # """Retrieve data from PostgreSQL table.""" - # try: - # table_name = self._get_table_name(key) - # log.info(f"Retrieving data from table: {table_name}") - - # conn = await self._get_connection() - # try: - # # Check if table exists - # table_exists = await conn.fetchval( - # "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", - # table_name - # ) - - # if not table_exists: - # log.warning(f"Table {table_name} does not exist") - # return None - - # # Determine if this is a typed table or generic table - # is_typed_table = any(table_type in table_name for table_type in - # ['entities', 'relationships', 'communities', 'text_units', 'documents']) - - # if is_typed_table: - # # For typed tables, select all columns except created_at/updated_at - # if 'documents' in table_name: - # query = "SELECT id, human_readable_id, title, text, text_unit_ids, creation_date, metadata FROM {} ORDER BY created_at".format(table_name) - # elif 'entities' in table_name: - # query = "SELECT id, human_readable_id, title, type, description, text_unit_ids, frequency, degree, x, y FROM {} ORDER BY created_at".format(table_name) - # elif 'relationships' in table_name: - # query = "SELECT id, human_readable_id, source, target, description, weight, combined_degree, text_unit_ids FROM {} ORDER BY created_at".format(table_name) - # elif 'communities' in table_name: - # query = "SELECT id, human_readable_id, community, level, parent, children, text_unit_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) - # elif 'text_units' in table_name: - # query = "SELECT id, human_readable_id, text, n_tokens, document_ids, entity_ids, relationship_ids FROM {} ORDER BY created_at".format(table_name) - # else: - # # Fallback for unknown typed table - # query = "SELECT * FROM {} ORDER BY created_at".format(table_name) - # else: - # # For generic tables, use the data column - # query = "SELECT id, data FROM {} ORDER BY created_at".format(table_name) - - # 
rows = await conn.fetch(query) - - # if not rows: - # log.info(f"No data found in table {table_name}") - # return None - - # log.info(f"Retrieved {len(rows)} records from table {table_name}") - - # # Check if this should be treated as raw data instead of tabular data - # if (not key.endswith('.parquet') or - # 'state' in key.lower() or - # key.endswith('.json') or - # 'context' in table_name.lower()): - # # Handle state.json or context.json as raw data - # # For non-tabular data, return the raw content from the first record - # if rows: - # if is_typed_table: - # # For typed tables, convert row to dict and return as JSON - # row_dict = dict(rows[0]) - # json_str = json.dumps(row_dict) - # return json_str.encode(encoding or self._encoding) if as_bytes else json_str - # elif 'data' in rows[0]: - # raw_content = rows[0]['data'] - # if isinstance(raw_content, dict): - # json_str = json.dumps(raw_content) - # return json_str.encode(encoding or self._encoding) if as_bytes else json_str - # return b"" if as_bytes else "" - - # # Convert to DataFrame - # records = [] - # for row in rows: - # if is_typed_table: - # # For typed tables, the row is already the data we need - # record_data = dict(row) - - # # Convert JSONB fields back to proper Python objects - # for field in ['text_unit_ids', 'children', 'entity_ids', 'relationship_ids', 'document_ids', 'metadata']: - # if field in record_data: - # value = record_data[field] - - # if value is None: - # record_data[field] = {} if field == 'metadata' else [] - # elif isinstance(value, str): - # # Handle JSONB strings - they should always be valid JSON - # try: - # parsed = json.loads(value) - # # Validate the parsed type - # if field == 'metadata': - # record_data[field] = parsed if isinstance(parsed, dict) else {} - # else: - # record_data[field] = parsed if isinstance(parsed, list) else [] - # except (json.JSONDecodeError, TypeError): - # log.warning(f"Failed to parse JSONB field {field}: {value}") - # # Fallback for non-JSON strings - # if field == 'metadata': - # record_data[field] = {} - # else: - # record_data[field] = [] - # elif isinstance(value, (list, dict)): - # # Already correct type (shouldn't happen with JSONB, but handle it) - # record_data[field] = value - # else: - # # Convert other types - # if field == 'metadata': - # record_data[field] = {'value': str(value)} if value else {} - # else: - # record_data[field] = [value] if value else [] - # else: - # # Handle generic table data (JSONB data column) - # if isinstance(row['data'], dict): - # record_data = dict(row['data']) - # else: - # # If it's a string, parse it as JSON - # record_data = json.loads(row['data']) if isinstance(row['data'], str) else row['data'] - - # # Clean up the record data - convert None to proper values and handle NaN - # cleaned_data = {} - # for key_name, value in record_data.items(): - # if self._is_scalar_na(value) or value is None: - # cleaned_data[key_name] = None - # elif isinstance(value, str) and key_name == 'text_unit_ids' and not is_typed_table: - # # Try to parse text_unit_ids back from JSON string if needed (only for generic tables) - # try: - # parsed_value = json.loads(value) - # cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [value] - # except (json.JSONDecodeError, TypeError): - # # If it's not JSON, treat as a single item list or keep as string - # cleaned_data[key_name] = [value] if value else [] - # elif key_name in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids'] and not 
is_typed_table: - # # Always ensure these columns are lists (only for generic tables - typed tables already handled this) - # if isinstance(value, list): - # cleaned_data[key_name] = value - # elif isinstance(value, str): - # try: - # parsed_value = json.loads(value) - # cleaned_data[key_name] = parsed_value if isinstance(parsed_value, list) else [] - # except (json.JSONDecodeError, TypeError): - # cleaned_data[key_name] = [] - # elif value is None: - # cleaned_data[key_name] = [] - # else: - # # fallback: wrap single value in a list - # cleaned_data[key_name] = [value] - # elif isinstance(value, (list, np.ndarray)) and len(value) == 0: - # # Handle empty arrays/lists - # cleaned_data[key_name] = [] - # else: - # cleaned_data[key_name] = value - - # # Always include the ID column for GraphRAG compatibility - # # Use the storage ID as is since we simplified ID handling - # storage_id = row['id'] - # cleaned_data['id'] = storage_id - # records.append(cleaned_data) - - # df = pd.DataFrame(records) - - # # Additional cleanup for NaN values in the DataFrame - # df = df.where(pd.notna(df), None) - # log.info(f"Created DataFrame with shape: {df.shape}") - # log.info(f"Table {table_name} DataFrame columns: {df.columns.tolist()}") - - # if len(df) > 0: - # log.debug(f"Table {table_name} Sample record: {df.iloc[0].to_dict()}") - # # Debug: Check if children column exists and its type - # if 'children' in df.columns: - # sample_children = df.iloc[0]['children'] - # log.debug(f"Table {table_name} Sample children value: {sample_children}, type: {type(sample_children)}") - - # # Handle bytes conversion for GraphRAG compatibility - # if as_bytes or kwargs.get("as_bytes"): - # log.info(f"Converting DataFrame to parquet bytes for key: {key}") - - # # Apply column filtering similar to Milvus implementation - # df_clean = df.copy() - - # # Define expected columns for each data type - # if 'documents' in table_name: - # expected_columns = ['id', 'human_readable_id', 'title', 'text', 'creation_date', 'metadata'] - # elif 'entities' in table_name: - # expected_columns = ['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'] - # elif 'relationships' in table_name: - # expected_columns = ['id', 'human_readable_id', 'source', 'target', 'description', 'weight', 'text_unit_ids'] - # if 'combined_degree' in df_clean.columns: - # expected_columns.append('combined_degree') - # elif 'text_units' in table_name: - # expected_columns = ['id', 'human_readable_id', 'text', 'n_tokens', 'document_ids', 'entity_ids', 'relationship_ids'] - # elif 'communities' in table_name: - # expected_columns = ['id', 'human_readable_id', 'community', 'level', 'parent', 'children', 'text_unit_ids', 'entity_ids', 'relationship_ids'] - # else: - # expected_columns = list(df_clean.columns) - - # # Filter columns - # available_columns = [col for col in expected_columns if col in df_clean.columns] - # if available_columns != expected_columns: - # missing = set(expected_columns) - set(available_columns) - # extra = set(df_clean.columns) - set(expected_columns) - # log.warning(f"Column mismatch - Expected: {expected_columns}, Available: {available_columns}, Missing: {missing}, Extra: {extra}") - - # df_clean = df_clean[available_columns] - # log.info(f"Table {table_name} final filtered columns: {df_clean.columns.tolist()}") - - # # Convert to parquet bytes - # try: - # # Handle list columns for PyArrow compatibility - # df_for_parquet = df_clean.copy() - - # # For PyArrow/parquet compatibility, 
we need to handle list columns carefully - # # Instead of converting to JSON strings, let's try a different approach - # list_columns = [] - # for col in df_for_parquet.columns: - # if col in df_for_parquet.columns and len(df_for_parquet) > 0: - # # Check if this column contains lists - # first_non_null = None - # for val in df_for_parquet[col]: - # if isinstance(val, list): - # first_non_null = val - # break - # elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - # first_non_null = val - # break - - # if isinstance(first_non_null, list) or col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: - # list_columns.append(col) - # # Ensure all values in this column are proper lists - # df_for_parquet[col] = df_for_parquet[col].apply( - # lambda x: x if isinstance(x, list) else ([] if x is None or pd.isna(x) else [str(x)]) - # ) - - # if list_columns: - # log.info(f"Ensured list columns are proper lists for parquet: {list_columns}") - - # # Try to convert to parquet without JSON string conversion - # buffer = BytesIO() - # df_for_parquet.to_parquet(buffer, engine='pyarrow') - # buffer.seek(0) - # parquet_bytes = buffer.getvalue() - # log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data") - # return parquet_bytes - # except Exception as e: - # log.warning(f"Direct parquet conversion failed: {e}, trying with JSON string conversion") - - # # Fallback: convert lists to JSON strings - # try: - # df_for_parquet = df_clean.copy() - - # # Convert list columns to JSON strings for parquet compatibility - # list_columns = [] - # for col in df_for_parquet.columns: - # if col in df_for_parquet.columns and len(df_for_parquet) > 0: - # # Check if this column contains lists - # first_non_null = None - # for val in df_for_parquet[col]: - # if isinstance(val, list): - # first_non_null = val - # break - # elif val is not None and not isinstance(val, (list, np.ndarray)) and pd.notna(val): - # first_non_null = val - # break - # if isinstance(first_non_null, list): - # list_columns.append(col) - # # Convert lists to JSON strings - # df_for_parquet[col] = df_for_parquet[col].apply( - # lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - # ) - # elif col in ['children', 'entity_ids', 'relationship_ids', 'text_unit_ids', 'document_ids']: - # # These columns should always be lists, even if empty - # list_columns.append(col) - # df_for_parquet[col] = df_for_parquet[col].apply( - # lambda x: json.dumps(x) if isinstance(x, list) else (json.dumps([]) if x is None else str(x)) - # ) - - # if list_columns: - # log.info(f"Converted list columns to JSON strings for parquet: {list_columns}") - - # buffer = BytesIO() - # df_for_parquet.to_parquet(buffer, engine='pyarrow') - # buffer.seek(0) - # parquet_bytes = buffer.getvalue() - # log.info(f"Successfully converted DataFrame to {len(parquet_bytes)} bytes of parquet data (with JSON conversion)") - # return parquet_bytes - # except Exception as e2: - # log.exception(f"Failed to convert DataFrame to parquet bytes: {e2}") - # return b"" - - # return df - - # finally: - # await self._release_connection(conn) - - # except Exception as e: - # log.exception(f"Error retrieving data from table {table_name}: {e}") - # return None - - async def set(self, key: str, value: Any, encoding: str | None = None) -> None: - """Insert data into PostgreSQL table with drop/recreate to avoid duplicates.""" - try: - table_name = self._get_table_name(key) - 
log.info(f"Setting data for key: {key}, table: {table_name}") - - # Use new table creation approach with duplicate prevention - await self._ensure_table_exists_with_schema(table_name) - - conn = await self._get_connection() - try: - if isinstance(value, bytes): - # Parse parquet data - df = pd.read_parquet(BytesIO(value)) - log.info(f"Parsed parquet data, DataFrame shape: {df.shape}") - # output sample record for debugging - log.debug(f"Table {table_name} Sample record (first row): {df.iloc[0].to_dict()}") - log.info(f"Parsed DataFrame columns: {df.columns.tolist()}") - - # Prepare data for PostgreSQL (typed or generic) - records = self._prepare_data_for_postgres(df, table_name) - - if records: - # Use batch insert for much better performance - await self._batch_upsert_records(conn, table_name, records) - - log.info(f"Successfully inserted {len(records)} records to {table_name}") - - # Log ID handling info - if any(record['id'].split(':')[0] in self._no_id_prefixes for record in records if 'id' in record): - log.info(f"Some records used auto-generated IDs in table {table_name}") - - else: - # Handle non-parquet data (e.g., JSON, stats) - always use generic table - log.info(f"Handling non-parquet data for key: {key}") - - record_data = json.loads(value) if isinstance(value, str) else value - - # Use generic table insertion for non-parquet data - records = [{ - 'id': key, - 'data': record_data, - 'metadata': {'type': 'non_parquet', 'created': datetime.now(timezone.utc).isoformat()} - }] - - await self._batch_upsert_generic_records(conn, table_name, records) - log.info("Non-parquet data insertion successful") - - finally: - await self._release_connection(conn) - - except Exception as e: - log.exception("Error setting data in PostgreSQL table %s: %s", table_name, e) - raise - - async def has(self, key: str) -> bool: - """Check if data exists for the given key.""" - try: - table_name = self._get_table_name(key) - log.info(f"Checking existence for key: {key}, table: {table_name}") - conn = await self._get_connection() - try: - # Check if table exists - table_exists = await conn.fetchval( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)", - table_name - ) - log.debug(f"Table {table_name} exists: {table_exists}") - if not table_exists: - return False - - if key.endswith('.parquet'): - # For parquet files, check if table has any records - total_count = await conn.fetchval( - f"SELECT COUNT(*) FROM {table_name}" - ) - if total_count > 0: - return True - else: - raise ValueError(f"No records found in table {table_name} for parquet key {key}") - else: - # Check for exact key match - exists = await conn.fetchval( - f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)", - key - ) - return exists - - finally: - await self._release_connection(conn) - - except Exception as e: - log.exception("Error checking existence for key %s: %s", key, e) - return False - - async def delete(self, key: str) -> None: - """Delete data for the given key.""" - try: - table_name = self._get_table_name(key) - conn = await self._get_connection() - try: - await conn.execute(f"DROP TABLE IF EXISTS {table_name}") - log.info(f"Deleted record for key {key}") - finally: - await self._release_connection(conn) - except Exception as e: - log.exception("Error deleting key %s: %s", key, e) - - async def clear(self) -> None: - """Clear all tables with the configured prefix.""" - try: - conn = await self._get_connection() - try: - # Get all tables with our prefix - tables = await conn.fetch( - "SELECT 
-                    "SELECT table_name FROM information_schema.tables WHERE table_name LIKE $1",
-                    f"{self._collection_prefix}%"
-                )
-
-                for table_row in tables:
-                    table_name = table_row['table_name']
-                    await conn.execute(f"DROP TABLE IF EXISTS {table_name}")
-                    log.info(f"Dropped table: {table_name}")
-
-                log.info(f"Cleared all tables with prefix: {self._collection_prefix}")
-
-            finally:
-                await self._release_connection(conn)
-
-        except Exception as e:
-            log.exception("Error clearing tables: %s", e)
-
-    def keys(self) -> list[str]:
-        """Return the keys in the storage."""
-        # This would need to be async to properly implement
-        # For now, return empty list
-        log.warning("keys() method not fully implemented for async storage")
-        return []
-
-    def child(self, name: str | None) -> PipelineStorage:
-        """Create a child storage instance."""
-        return self
-
-    async def get_creation_date(self, key: str) -> str:
-        """Get the creation date for data."""
-        try:
-            table_name = self._get_table_name(key)
-            conn = await self._get_connection()
-            try:
-                if key.endswith('.parquet'):
-                    prefix = self._get_prefix(key)
-                    created_at = await conn.fetchval(
-                        f"SELECT MIN(created_at) FROM {table_name} WHERE id LIKE $1",
-                        f"{prefix}:%"
-                    )
-                else:
-                    created_at = await conn.fetchval(
-                        f"SELECT created_at FROM {table_name} WHERE id = $1",
-                        key
-                    )
-
-                if created_at:
-                    return get_timestamp_formatted_with_local_tz(created_at)
-
-            finally:
-                await self._release_connection(conn)
-
-        except Exception as e:
-            log.exception("Error getting creation date for %s: %s", key, e)
-
-        return ""
-
-    async def close(self) -> None:
-        """Close the connection pool."""
-        if self._pool:
-            await self._pool.close()
-            log.info("Closed PostgreSQL connection pool")

From f8036746eec9a1e48e4ff7d32c161da366cc16ce Mon Sep 17 00:00:00 2001
From: "dannyzheng@microsoft.com"
Date: Thu, 4 Sep 2025 17:37:34 -0700
Subject: [PATCH 11/12] updated delete method

---
 graphrag/storage/postgres_pipeline_storage.py | 21 +++--------------
 1 file changed, 5 insertions(+), 16 deletions(-)

diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py
index 3139c7fced..6ec8aaec93 100644
--- a/graphrag/storage/postgres_pipeline_storage.py
+++ b/graphrag/storage/postgres_pipeline_storage.py
@@ -552,22 +552,11 @@ async def delete(self, key: str) -> None:
             table_name = self._get_table_name(key)
             conn = await self._get_connection()
             try:
-                if key.endswith('.parquet'):
-                    # Delete all records with this prefix
-                    prefix = self._get_prefix(key)
-                    result = await conn.execute(
-                        f"DELETE FROM {table_name} WHERE id LIKE $1",
-                        f"{prefix}:%"
-                    )
-                    log.info(f"Deleted records for prefix {prefix}: {result}")
-                else:
-                    # Delete exact key match
-                    result = await conn.execute(
-                        f"DELETE FROM {table_name} WHERE id = $1",
-                        key
-                    )
-                    log.info(f"Deleted record for key {key}: {result}")
-
+                await conn.execute(
+                    f"TRUNCATE TABLE {table_name}"
+                )
+                log.info(f"Deleted records for key: {key}")
+
             finally:
                 await self._release_connection(conn)
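
The TRUNCATE-based delete introduced above clears every row in the key's backing table rather than removing individual records, and because PostgreSQL cannot bind identifiers as query parameters, the table name has to be interpolated into the SQL text. Below is a minimal sketch of the same pattern with a defensive identifier check; the _TABLE_RE regex and the drop_all_rows helper are illustrative additions, not part of this patch:

    import re

    import asyncpg

    _TABLE_RE = re.compile(r"^[a-z_][a-z0-9_]*$")  # conservative whitelist for table names


    async def drop_all_rows(conn: asyncpg.Connection, table_name: str) -> None:
        """Remove every row from table_name after validating the identifier."""
        if not _TABLE_RE.match(table_name):
            msg = f"Unsafe table name: {table_name}"
            raise ValueError(msg)
        # Identifiers cannot be passed as $1 parameters, so interpolate after validation.
        await conn.execute(f'TRUNCATE TABLE "{table_name}"')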

From 6650e15c07c88eef64936eb075754cbd7a4ee82c Mon Sep 17 00:00:00 2001
From: "dannyzheng@microsoft.com"
Date: Fri, 26 Sep 2025 10:37:21 -0700
Subject: [PATCH 12/12] solved conflicts

---
 graphrag/storage/factory.py                   | 54 +++++++------------
 graphrag/storage/postgres_pipeline_storage.py | 43 ++++-----------
 2 files changed, 29 insertions(+), 68 deletions(-)

diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py
index 35d6255ac6..dcbb32afde 100644
--- a/graphrag/storage/factory.py
+++ b/graphrag/storage/factory.py
@@ -5,14 +5,13 @@
 
 from __future__ import annotations
 
-from contextlib import suppress
 from typing import TYPE_CHECKING, ClassVar
 
 from graphrag.config.enums import StorageType
-from graphrag.storage.blob_pipeline_storage import create_blob_storage
-from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
-from graphrag.storage.file_pipeline_storage import create_file_storage
 from graphrag.storage.postgres_pipeline_storage import PostgresPipelineStorage
+from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
+from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
+from graphrag.storage.file_pipeline_storage import FilePipelineStorage
 from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
 
 if TYPE_CHECKING:
@@ -30,8 +29,7 @@ class StorageFactory:
     for individual enforcement of required/optional arguments.
     """
 
-    _storage_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
-    storage_types: ClassVar[dict[str, type]] = {}  # For backward compatibility
+    _registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
 
     @classmethod
     def register(
@@ -41,23 +39,13 @@ def register(
 
         Args:
            storage_type: The type identifier for the storage.
-            creator: A callable that creates an instance of the storage.
-        """
-        cls._storage_registry[storage_type] = creator
+            creator: A class or callable that creates an instance of PipelineStorage.
 
-        # For backward compatibility with code that may access storage_types directly
-        if (
-            callable(creator)
-            and hasattr(creator, "__annotations__")
-            and "return" in creator.__annotations__
-        ):
-            with suppress(TypeError, KeyError):
-                cls.storage_types[storage_type] = creator.__annotations__["return"]
+        """
+        cls._registry[storage_type] = creator
 
     @classmethod
-    def create_storage(
-        cls, storage_type: StorageType | str, kwargs: dict
-    ) -> PipelineStorage:
+    def create_storage(cls, storage_type: str, kwargs: dict) -> PipelineStorage:
         """Create a storage object from the provided type.
 
         Args:
@@ -72,32 +60,26 @@ def create_storage(
         ------
             ValueError: If the storage type is not registered.
         """
-        storage_type_str = (
-            storage_type.value
-            if isinstance(storage_type, StorageType)
-            else storage_type
-        )
-
-        if storage_type_str not in cls._storage_registry:
+        if storage_type not in cls._registry:
             msg = f"Unknown storage type: {storage_type}"
             raise ValueError(msg)
-        return cls._storage_registry[storage_type_str](**kwargs)
+        return cls._registry[storage_type](**kwargs)
 
     @classmethod
     def get_storage_types(cls) -> list[str]:
         """Get the registered storage implementations."""
-        return list(cls._storage_registry.keys())
+        return list(cls._registry.keys())
 
     @classmethod
-    def is_supported_storage_type(cls, storage_type: str) -> bool:
+    def is_supported_type(cls, storage_type: str) -> bool:
         """Check if the given storage type is supported."""
-        return storage_type in cls._storage_registry
+        return storage_type in cls._registry
 
 
-# --- Register default implementations ---
-StorageFactory.register(StorageType.blob.value, create_blob_storage)
-StorageFactory.register(StorageType.cosmosdb.value, create_cosmosdb_storage)
-StorageFactory.register(StorageType.file.value, create_file_storage)
-StorageFactory.register(StorageType.memory.value, lambda **_: MemoryPipelineStorage())
+# --- register built-in storage implementations ---
+StorageFactory.register(StorageType.blob.value, BlobPipelineStorage)
+StorageFactory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage)
+StorageFactory.register(StorageType.file.value, FilePipelineStorage)
+StorageFactory.register(StorageType.memory.value, MemoryPipelineStorage)
 StorageFactory.register(StorageType.postgres.value, PostgresPipelineStorage)

diff --git a/graphrag/storage/postgres_pipeline_storage.py b/graphrag/storage/postgres_pipeline_storage.py
index 6ec8aaec93..b2ee7a961d 100644
--- a/graphrag/storage/postgres_pipeline_storage.py
+++ b/graphrag/storage/postgres_pipeline_storage.py
@@ -175,8 +175,8 @@ def _prepare_data_for_postgres(self, df: pd.DataFrame, table_name: str) -> list[
             df['human_readable_id'] = range(len(df))
             log.info(f"Generated sequential human_readable_id for {len(df)} records")
 
-        # Process IDs with prefix
-        ids = df['id'].astype(str).tolist() if 'id' in df.columns else [f"{self._get_prefix(table_name.replace(self._collection_prefix, ''))}:{i}" for i in range(len(df))]
+        # Process IDs
+        ids = df['id'].astype(str).tolist()
 
         records = []
         for i in range(len(df)):
@@ -295,8 +295,8 @@ async def _batch_upsert_records(self, conn: Connection, table_name: str, records
                 processed_count += len(batch)
 
                 # Log progress every batch for visibility
-                if i % batch_size == 0 or batch_end == total_records:
-                    log.info(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)")
+                # if i % batch_size == 0 or batch_end == total_records:
+                #     log.debug(f"Batch upsert progress: {processed_count}/{total_records} records ({processed_count/total_records*100:.1f}%)")
 
     def find(
         self,
@@ -460,7 +460,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
                 if isinstance(value, bytes):
                     # Parse parquet data
                     df = pd.read_parquet(BytesIO(value))
-                    log.info(f"Parsed parquet data, DataFrame shape: {df.shape}")
+                    log.info(f"Parsed parquet data on set method, DataFrame shape: {df.shape}")
                     log.info(f"Parsed DataFrame head: {df.head()}")
 
                     # Prepare data for PostgreSQL
@@ -472,10 +472,6 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
 
                     log.info(f"Successfully upserted {len(records)} records to {table_name}")
 
-                    # # Log duplicate handling info
-                    # if any(record['id'].split(':')[0] in self._no_id_prefixes for record in records):
-                    #     log.info("Some records used auto-generated IDs")
-
                 else:
                     # Handle non-parquet data (e.g., JSON, stats)
                     log.info(f"Handling non-parquet data for key: {key}")
@@ -521,27 +517,10 @@ async def has(self, key: str) -> bool:
                     "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)",
                     table_name
                 )
-                log.debug(f"Table {table_name} exists: {table_exists}")
-                if not table_exists:
-                    return False
-
-                if key.endswith('.parquet'):
-                    # For parquet files, check if table has any records
-                    total_count = await conn.fetchval(
-                        f"SELECT COUNT(*) FROM {table_name}"
-                    )
-                    return total_count > 0
-                else:
-                    # Check for exact key match
-                    exists = await conn.fetchval(
-                        f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE id = $1)",
-                        key
-                    )
-                    return exists
-
+                log.info(f"Table {table_name} exists: {table_exists}")
+                return table_exists
             finally:
                 await self._release_connection(conn)
-
         except Exception as e:
             log.exception("Error checking existence for key %s: %s", key, e)
             return False
@@ -552,11 +531,11 @@ async def delete(self, key: str) -> None:
             table_name = self._get_table_name(key)
             conn = await self._get_connection()
             try:
-                await conn.execute(
-                    f"TRUNCATE TABLE {table_name}"
+                result = await conn.execute(
+                    f"TRUNCATE TABLE {table_name}"
                 )
-                log.info(f"Deleted records for key: {key}")
-
+                log.info(f"Deleted record for key {key}: {result}")
+
             finally:
                 await self._release_connection(conn)
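
With the registry collapsed to a single _registry dict, classes and plain factory functions are registered the same way, and create_storage simply calls whatever was registered with the supplied kwargs. A rough usage sketch; the constructor keyword arguments below mirror the PostgresPipelineStorage defaults from earlier in this series and are placeholders rather than a fixed contract:

    from graphrag.config.enums import StorageType
    from graphrag.storage.factory import StorageFactory

    assert StorageFactory.is_supported_type(StorageType.postgres.value)

    # create_storage takes the registered type key plus a dict of constructor kwargs.
    storage = StorageFactory.create_storage(
        StorageType.postgres.value,
        {
            "host": "localhost",
            "port": 5432,
            "database": "graphrag",
            "username": "postgres",
            "password": None,
            "collection_prefix": "graphrag_",
        },
    )

Taken together, the series leaves PostgresPipelineStorage writing parquet payloads into prefixed tables, answering has() with a table-existence check, and deleting by truncating the backing table. A minimal end-to-end sketch of exercising those methods against a reachable PostgreSQL instance follows; the connection details and sample DataFrame are placeholders, pyarrow is assumed for the parquet round-trip, and close() is assumed to remain available as in the earlier revision of the file:

    import asyncio
    from io import BytesIO

    import pandas as pd

    from graphrag.storage.postgres_pipeline_storage import PostgresPipelineStorage


    async def main() -> None:
        storage = PostgresPipelineStorage(
            host="localhost", database="graphrag", username="postgres", password="secret"
        )
        # The pipeline hands set() raw parquet bytes; an 'id' column is expected.
        df = pd.DataFrame({"id": ["e1", "e2"], "title": ["alpha", "beta"]})
        buffer = BytesIO()
        df.to_parquet(buffer)
        await storage.set("entities.parquet", buffer.getvalue())

        print(await storage.has("entities.parquet"))  # True once the table exists
        await storage.delete("entities.parquet")      # TRUNCATEs the backing table
        await storage.close()                         # releases the asyncpg pool


    if __name__ == "__main__":
        asyncio.run(main())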