diff --git a/src/share/CMakeLists.txt b/src/share/CMakeLists.txt index 23f4d6ff2..876a67ac4 100644 --- a/src/share/CMakeLists.txt +++ b/src/share/CMakeLists.txt @@ -733,6 +733,9 @@ ob_set_subtarget(ob_share hybrid_search hybrid_search/ob_query_request.cpp hybrid_search/ob_query_translator.cpp hybrid_search/ob_hybrid_search_executor.cpp + hybrid_search/ob_hybrid_search_fusion_engine.cpp + hybrid_search/ob_rrf_fusion.cpp + hybrid_search/ob_weighted_fusion.cpp ) ob_set_subtarget(ob_share domain_id diff --git a/src/share/hybrid_search/HYBRID_SEARCH_GUIDE.sql b/src/share/hybrid_search/HYBRID_SEARCH_GUIDE.sql new file mode 100644 index 000000000..836db52b8 --- /dev/null +++ b/src/share/hybrid_search/HYBRID_SEARCH_GUIDE.sql @@ -0,0 +1,310 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Hybrid Search Implementation Guide and SQL Examples + * ================================ + * + * This document provides detailed instructions and SQL examples on how to use hybrid search features. + */ + +-- ======================================================== +-- Part 1: Table Structure Design +-- ======================================================== + +-- Create a table with vector and full-text indexes +CREATE TABLE documents ( + id INT PRIMARY KEY, + title VARCHAR(255), + content TEXT, + embedding VECTOR(384), -- 384-dimensional vector + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + -- Full-text index configuration + FULLTEXT INDEX idx_content(content) WITH PARSER jieba, + + -- Vector index configuration + -- DISTANCE=l2: Uses L2 Euclidean distance + -- TYPE=hnsw: Uses HNSW (Hierarchical Navigable Small Worlds) algorithm + -- LIB=vsag: Uses VSAG vector search library + VECTOR INDEX idx_embedding(embedding) WITH(DISTANCE=l2, TYPE=hnsw, LIB=vsag) +) ORGANIZATION = HEAP; + +-- ======================================================== +-- Part 2: Data Insertion Example +-- ======================================================== + +-- Insert sample data +INSERT INTO documents (id, title, content, embedding) VALUES +(1, 'Artificial Intelligence Overview', + 'Machine learning is a branch of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It focuses on developing computer programs that can access data and use it to learn for themselves.', + VECTOR('[0.1, 0.2, 0.3, ..., 0.384]')), + +(2, 'Deep Learning Fundamentals', + 'Deep learning is a subset of machine learning that uses artificial neural networks with multiple layers. It has revolutionized computer vision, natural language processing, and many other AI applications.', + VECTOR('[0.15, 0.25, 0.35, ..., 0.385]')), + +(3, 'Vector Database Technology', + 'Vector databases are specialized databases designed for efficient storage, retrieval, and similarity search of vector embeddings. They support various distance metrics including L2, cosine similarity, and inner product.', + VECTOR('[0.2, 0.3, 0.4, ..., 0.386]')), + +(4, 'Natural Language Processing', + 'Natural language processing (NLP) is a subfield of linguistics, computer science, and artificial intelligence concerned with the interactions between computers and human language. It is used to apply machine learning algorithms to text and speech.', + VECTOR('[0.12, 0.22, 0.32, ..., 0.387]')), + +(5, 'Computer Vision Applications', + 'Computer vision is an interdisciplinary scientific field that deals with how digital images and videos can be used to extract high-level understanding from digital images and videos. It seeks to automate tasks that the human visual system can do.', + VECTOR('[0.18, 0.28, 0.38, ..., 0.388]')); + +-- ======================================================== +-- Part 3: SQL Examples for RRF Fusion Method +-- ======================================================== + +-- Scheme 1.1: Basic RRF Fusion Query +-- Use Case: Automatic normalization needed, robust to outliers +-- Parameter Explanation: +-- rank_constant: 60 (larger values are more favorable to low-ranked documents) +-- rank_window_size: 100 (fuse from 100 results) + +EXPLAIN SELECT + doc_id, + fts_score, + vector_score, + fts_rank, + vector_rank, + final_score +FROM ( + WITH fts_results AS ( + SELECT + id AS doc_id, + MATCH(content) AGAINST('artificial intelligence machine learning' IN NATURAL LANGUAGE MODE) AS fts_score, + ROW_NUMBER() OVER (ORDER BY MATCH(content) AGAINST('artificial intelligence machine learning' IN NATURAL LANGUAGE MODE) DESC) AS fts_rank + FROM documents + WHERE MATCH(content) AGAINST('artificial intelligence machine learning' IN NATURAL LANGUAGE MODE) + LIMIT 100 + ), + vector_results AS ( + SELECT + id AS doc_id, + 1.0 / (1.0 + l2_distance(embedding, '[0.15, 0.25, ...]')) AS vector_score, + ROW_NUMBER() OVER (ORDER BY l2_distance(embedding, '[0.15, 0.25, ...]') ASC) AS vector_rank + FROM documents + LIMIT 100 + ), + rrf_scores AS ( + SELECT + COALESCE(f.doc_id, v.doc_id) AS doc_id, + COALESCE(f.fts_score, 0) AS fts_score, + COALESCE(v.vector_score, 0) AS vector_score, + COALESCE(f.fts_rank, -1) AS fts_rank, + COALESCE(v.vector_rank, -1) AS vector_rank, + -- RRF formula: score = 1 / (rank + rank_constant) + COALESCE(1.0 / (f.fts_rank + 60), 0) + + COALESCE(1.0 / (v.vector_rank + 60), 0) AS final_score + FROM fts_results f + FULL OUTER JOIN vector_results v ON f.doc_id = v.doc_id + ) + SELECT * FROM rrf_scores +) results +ORDER BY final_score DESC +LIMIT 10; + +-- ======================================================== +-- Part 4: SQL Examples for Weighted Fusion Method +-- ======================================================== + +-- Scheme 2.1: Balanced Fusion (50% Full-text + 50% Vector) +-- Use Case: Keyword matching and semantic similarity are equally important + +WITH fts_results AS ( + SELECT + id, + title, + MATCH(content) AGAINST('artificial intelligence' IN NATURAL LANGUAGE MODE) AS fts_score + FROM documents + WHERE MATCH(content) AGAINST('artificial intelligence' IN NATURAL LANGUAGE MODE) + LIMIT 100 +), +vector_results AS ( + SELECT + id, + 1.0 / (1.0 + l2_distance(embedding, '[0.15, 0.25, ...]')) AS vector_score + FROM documents + ORDER BY l2_distance(embedding, '[0.15, 0.25, ...]') + LIMIT 100 +), +score_stats AS ( + SELECT + MAX(f.fts_score) AS max_fts, + MIN(f.fts_score) AS min_fts, + MAX(v.vector_score) AS max_vector, + MIN(v.vector_score) AS min_vector + FROM fts_results f, vector_results v +), +normalized_scores AS ( + SELECT + COALESCE(f.id, v.id) AS id, + COALESCE(f.title, 'N/A') AS title, + COALESCE(f.fts_score, 0) AS fts_score, + COALESCE(v.vector_score, 0) AS vector_score, + -- Min-Max normalization + COALESCE((f.fts_score - s.min_fts) / (s.max_fts - s.min_fts), 0) AS norm_fts, + COALESCE((v.vector_score - s.min_vector) / (s.max_vector - s.min_vector), 0) AS norm_vector, + s.max_fts, + s.min_fts, + s.max_vector, + s.min_vector + FROM fts_results f + FULL OUTER JOIN vector_results v ON f.id = v.id + CROSS JOIN score_stats s +) +SELECT + id, + title, + norm_fts, + norm_vector, + -- Weighted sum: 0.5 * normalized_fts + 0.5 * normalized_vector + (0.5 * norm_fts + 0.5 * norm_vector) AS final_score +FROM normalized_scores +WHERE norm_fts IS NOT NULL OR norm_vector IS NOT NULL +ORDER BY final_score DESC +LIMIT 10; + +-- Scheme 2.2: Keyword Priority Fusion (70% Full-text + 30% Vector) +-- Use Case: Users' search keywords are usually accurate, minimal semantic understanding needed + +WITH fts_results AS ( + SELECT + id, + MATCH(content) AGAINST('machine learning' IN NATURAL LANGUAGE MODE) AS fts_score + FROM documents + WHERE MATCH(content) AGAINST('machine learning' IN NATURAL LANGUAGE MODE) +), +vector_results AS ( + SELECT + id, + 1.0 / (1.0 + l2_distance(embedding, '[0.15, 0.25, ...]')) AS vector_score + FROM documents +), +min_max_norm AS ( + SELECT + COALESCE(f.id, v.id) AS id, + COALESCE(f.fts_score, 0) AS fts_score, + COALESCE(v.vector_score, 0) AS vector_score, + -- Min-Max normalization + CASE WHEN (MAX(f.fts_score) OVER () - MIN(f.fts_score) OVER ()) > 0 + THEN (COALESCE(f.fts_score, 0) - MIN(f.fts_score) OVER ()) / + (MAX(f.fts_score) OVER () - MIN(f.fts_score) OVER ()) + ELSE 0 END AS norm_fts, + CASE WHEN (MAX(v.vector_score) OVER () - MIN(v.vector_score) OVER ()) > 0 + THEN (COALESCE(v.vector_score, 0) - MIN(v.vector_score) OVER ()) / + (MAX(v.vector_score) OVER () - MIN(v.vector_score) OVER ()) + ELSE 0 END AS norm_vector + FROM fts_results f + FULL OUTER JOIN vector_results v ON f.id = v.id +) +SELECT + id, + -- Weighted sum: 0.7 * normalized_fts + 0.3 * normalized_vector + (0.7 * norm_fts + 0.3 * norm_vector) AS final_score +FROM min_max_norm +ORDER BY final_score DESC +LIMIT 10; + +-- Scheme 2.3: Semantic Priority Fusion (30% Full-text + 70% Vector) +-- Use Case: Complex user search intent, need to understand semantics through vector search + +WITH fts_results AS ( + SELECT + id, + MATCH(content) AGAINST('neural network deep learning' IN NATURAL LANGUAGE MODE) AS fts_score + FROM documents + WHERE MATCH(content) AGAINST('neural network deep learning' IN NATURAL LANGUAGE MODE) +), +vector_results AS ( + SELECT + id, + 1.0 / (1.0 + l2_distance(embedding, '[0.15, 0.25, ...]')) AS vector_score + FROM documents +), +weighted_hybrid AS ( + SELECT + COALESCE(f.id, v.id) AS id, + COALESCE(f.fts_score, 0) AS fts_score, + COALESCE(v.vector_score, 0) AS vector_score, + -- Z-Score normalization (using Sigmoid function) + 1.0 / (1.0 + EXP(-(COALESCE(f.fts_score, 0) - AVG(f.fts_score) OVER ()) / + STDDEV(f.fts_score) OVER ())) AS norm_fts, + 1.0 / (1.0 + EXP(-(COALESCE(v.vector_score, 0) - AVG(v.vector_score) OVER ()) / + STDDEV(v.vector_score) OVER ())) AS norm_vector + FROM fts_results f + FULL OUTER JOIN vector_results v ON f.id = v.id +) +SELECT + id, + -- Weighted sum: 0.3 * normalized_fts + 0.7 * normalized_vector + (0.3 * norm_fts + 0.7 * norm_vector) AS final_score +FROM weighted_hybrid +ORDER BY final_score DESC +LIMIT 10; + +-- ======================================================== +-- Part 5: Advanced Normalization Strategy Examples +-- ======================================================== + +-- Scheme 3.1: Min-Max Normalization Example +-- Characteristic: Maps all scores to [0, 1] range + +WITH score_stats AS ( + SELECT + MAX(MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE)) AS max_fts, + MIN(MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE)) AS min_fts, + MAX(l2_distance(embedding, '[0.15, 0.25, ...]')) AS max_vec, + MIN(l2_distance(embedding, '[0.15, 0.25, ...]')) AS min_vec + FROM documents +) +SELECT + id, + -- Full-text search score normalization + (MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE) - s.min_fts) / + (s.max_fts - s.min_fts) * 0.5 + + -- Vector search score normalization (distance to similarity) + (1.0 - (l2_distance(embedding, '[...]') - s.min_vec) / + (s.max_vec - s.min_vec)) * 0.5 AS final_score +FROM documents, score_stats s +WHERE MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE) +ORDER BY final_score DESC +LIMIT 10; + +-- Scheme 3.2: Z-Score Normalization Example +-- Characteristic: Standardizes score distribution, sensitive to outliers + +WITH score_stats AS ( + SELECT + AVG(MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE)) AS avg_fts, + STDDEV(MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE)) AS std_fts, + AVG(l2_distance(embedding, '[...]')) AS avg_vec, + STDDEV(l2_distance(embedding, '[...]')) AS std_vec + FROM documents +) +SELECT + id, + -- Standardized scores (Z-Score) + ((MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE) - s.avg_fts) / s.std_fts) * 0.5 + + ((s.avg_vec - l2_distance(embedding, '[...]')) / s.std_vec) * 0.5 AS final_score +FROM documents, score_stats s +WHERE MATCH(content) AGAINST('query' IN NATURAL LANGUAGE MODE) +ORDER BY final_score DESC +LIMIT 10; diff --git a/src/share/hybrid_search/hybrid_search_demo.cpp b/src/share/hybrid_search/hybrid_search_demo.cpp new file mode 100644 index 000000000..9c67fc488 --- /dev/null +++ b/src/share/hybrid_search/hybrid_search_demo.cpp @@ -0,0 +1,369 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Hybrid Search Demo Examples and Test Cases + * + * This file demonstrates how to use the hybrid search API to implement fusion + * of vector and full-text search. Includes demo code and unit tests for multiple scenarios. + */ + +#include "ob_rrf_fusion.h" +#include "ob_weighted_fusion.h" +#include "lib/allocator/ob_malloc.h" +#include +#include + +namespace oceanbase +{ +namespace common +{ + +/* + * ============================================= + * Demo 1: Using RRF Fusion Method + * ============================================= + * Scenario: Balanced hybrid search with automatic normalization + * Use Case: Applications that need to balance keyword matching and semantic similarity + */ +void demo_rrf_fusion() +{ + std::cout << "\n========== RRF Fusion Demo ==========" << std::endl; + + // 1. Prepare memory allocator + ObMallocAllocator allocator; + + // 2. Create RRF fusion engine + ObRRFFusion rrf_fusion; + + // 3. Configure parameters + // rank_constant 60 means ranking differences have relatively small impact + // rank_window_size 100 means fusing from 100 search results + ObRRFConfig rrf_config(60, 100); + + if (rrf_fusion.init(rrf_config, allocator) != OB_SUCCESS) { + std::cout << "Failed to initialize RRF fusion" << std::endl; + return; + } + + // 4. Prepare full-text search results (BM25 scores, typically between 0 and tens) + common::ObSEArray fts_results; + ObHybridSearchResult fts_result1; + fts_result1.doc_id_ = 1; + fts_result1.fts_score_ = 15.5; + fts_results.push_back(fts_result1); + + ObHybridSearchResult fts_result2; + fts_result2.doc_id_ = 2; + fts_result2.fts_score_ = 12.3; + fts_results.push_back(fts_result2); + + ObHybridSearchResult fts_result3; + fts_result3.doc_id_ = 3; + fts_result3.fts_score_ = 8.7; + fts_results.push_back(fts_result3); + + ObHybridSearchResult fts_result4; + fts_result4.doc_id_ = 4; + fts_result4.fts_score_ = 5.2; + fts_results.push_back(fts_result4); + + std::cout << "FTS Results:" << std::endl; + for (const auto &r : fts_results) { + std::cout << " Doc ID: " << r.doc_id_ << ", Score: " << std::fixed << std::setprecision(2) << r.fts_score_ << std::endl; + } + + // 5. Prepare vector search results (vector similarity, typically between 0-1) + common::ObSEArray vector_results; + ObHybridSearchResult vec_result1; + vec_result1.doc_id_ = 2; + vec_result1.vector_score_ = 0.95; + vector_results.push_back(vec_result1); + + ObHybridSearchResult vec_result2; + vec_result2.doc_id_ = 1; + vec_result2.vector_score_ = 0.88; + vector_results.push_back(vec_result2); + + ObHybridSearchResult vec_result3; + vec_result3.doc_id_ = 5; + vec_result3.vector_score_ = 0.82; + vector_results.push_back(vec_result3); + + ObHybridSearchResult vec_result4; + vec_result4.doc_id_ = 3; + vec_result4.vector_score_ = 0.75; + vector_results.push_back(vec_result4); + + std::cout << "\nVector Results:" << std::endl; + for (const auto &r : vector_results) { + std::cout << " Doc ID: " << r.doc_id_ << ", Score: " << std::fixed << std::setprecision(2) << r.vector_score_ << std::endl; + } + + // 6. Add search results and perform fusion + if (rrf_fusion.add_fts_results(fts_results) != OB_SUCCESS || + rrf_fusion.add_vector_results(vector_results) != OB_SUCCESS || + rrf_fusion.fuse() != OB_SUCCESS) { + std::cout << "Failed to perform RRF fusion" << std::endl; + return; + } + + // 7. Get fusion results (top 5) + common::ObSEArray fused_results; + if (rrf_fusion.get_results(fused_results, 5) != OB_SUCCESS) { + std::cout << "Failed to get results" << std::endl; + return; + } + + std::cout << "\nFused Results (Top 5):" << std::endl; + std::cout << std::left << std::setw(10) << "Doc ID" + << std::setw(15) << "FTS Score" + << std::setw(15) << "Vector Score" + << std::setw(15) << "Final Score" + << std::setw(10) << "Source" << std::endl; + std::cout << std::string(65, '-') << std::endl; + + for (const auto &result : fused_results) { + std::string source = "None"; + if (result.source_flag_ == 1) source = "FTS Only"; + else if (result.source_flag_ == 2) source = "Vec Only"; + else if (result.source_flag_ == 3) source = "Both"; + + std::cout << std::left + << std::setw(10) << result.doc_id_ + << std::setw(15) << std::fixed << std::setprecision(4) << result.fts_score_ + << std::setw(15) << std::fixed << std::setprecision(4) << result.vector_score_ + << std::setw(15) << std::fixed << std::setprecision(4) << result.final_score_ + << std::setw(10) << source << std::endl; + } +} + +/* + * ============================================= + * Demo 2: Using Weighted Fusion Method - Balanced Approach + * ============================================= + * Scenario: Equal weights for full-text and vector search (50% each) + * Use Case: Both keyword matching and semantic similarity are equally important + */ +void demo_weighted_fusion_balanced() +{ + std::cout << "\n========== Weighted Fusion Demo (Balanced 50:50) ==========" << std::endl; + + // 1. Prepare memory allocator + ObMallocAllocator allocator; + + // 2. Create weighted fusion engine + ObWeightedFusion weighted_fusion; + + // 3. Configure parameters - Balanced approach + ObWeightedFusionConfig fusion_config(0.5, 0.5, true); // 50% FTS, 50% Vector, enable normalization + + // 4. Normalization configuration - Use Min-Max normalization + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::MIN_MAX; + + if (weighted_fusion.init(fusion_config, norm_config, allocator) != OB_SUCCESS) { + std::cout << "Failed to initialize weighted fusion" << std::endl; + return; + } + + // 5. Prepare test data (same as RRF demo) + common::ObSEArray fts_results; + ObHybridSearchResult fts_r1 = {1, 15.5, 0.0, -1, -1, 0.0, 0}; + ObHybridSearchResult fts_r2 = {2, 12.3, 0.0, -1, -1, 0.0, 0}; + ObHybridSearchResult fts_r3 = {3, 8.7, 0.0, -1, -1, 0.0, 0}; + fts_results.push_back(fts_r1); + fts_results.push_back(fts_r2); + fts_results.push_back(fts_r3); + + common::ObSEArray vector_results; + ObHybridSearchResult vec_r1 = {2, 0.0, 0.95, -1, -1, 0.0, 0}; + ObHybridSearchResult vec_r2 = {1, 0.0, 0.88, -1, -1, 0.0, 0}; + ObHybridSearchResult vec_r3 = {5, 0.0, 0.82, -1, -1, 0.0, 0}; + vector_results.push_back(vec_r1); + vector_results.push_back(vec_r2); + vector_results.push_back(vec_r3); + + // 6. Perform fusion + if (weighted_fusion.add_fts_results(fts_results) != OB_SUCCESS || + weighted_fusion.add_vector_results(vector_results) != OB_SUCCESS || + weighted_fusion.fuse() != OB_SUCCESS) { + std::cout << "Failed to perform weighted fusion" << std::endl; + return; + } + + // 7. Get results + common::ObSEArray fused_results; + if (weighted_fusion.get_results(fused_results, 5) != OB_SUCCESS) { + std::cout << "Failed to get results" << std::endl; + return; + } + + std::cout << "Fused Results (Balanced 50:50):" << std::endl; + std::cout << std::left << std::setw(10) << "Doc ID" + << std::setw(15) << "Norm FTS" + << std::setw(15) << "Norm Vector" + << std::setw(15) << "Final Score" << std::endl; + std::cout << std::string(55, '-') << std::endl; + + for (const auto &result : fused_results) { + std::cout << std::left + << std::setw(10) << result.doc_id_ + << std::setw(15) << std::fixed << std::setprecision(4) << result.fts_score_ + << std::setw(15) << std::fixed << std::setprecision(4) << result.vector_score_ + << std::setw(15) << std::fixed << std::setprecision(4) << result.final_score_ << std::endl; + } +} + +/* + * ============================================= + * Demo 3: Using Weighted Fusion Method - Exact Match Priority + * ============================================= + * Scenario: Prioritize keyword exact matching (70% FTS, 30% Vector) + * Use Case: Users' search keywords are usually accurate, minimal semantic understanding needed + */ +void demo_weighted_fusion_keyword_priority() +{ + std::cout << "\n========== Weighted Fusion Demo (Keyword Priority 70:30) ==========" << std::endl; + + ObMallocAllocator allocator; + ObWeightedFusion weighted_fusion; + + // 70% full-text search, 30% vector search + ObWeightedFusionConfig fusion_config(0.7, 0.3, true); + + // Use Z-Score normalization + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::Z_SCORE; + + if (weighted_fusion.init(fusion_config, norm_config, allocator) != OB_SUCCESS) { + std::cout << "Failed to initialize weighted fusion" << std::endl; + return; + } + + // Prepare data + common::ObSEArray fts_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result = {static_cast(i+1), 10.0 - i*3, 0.0, -1, -1, 0.0, 0}; + fts_results.push_back(result); + } + + common::ObSEArray vector_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result = {static_cast((i+1)%3+1), 0.0, 0.9 - i*0.05, -1, -1, 0.0, 0}; + vector_results.push_back(result); + } + + if (weighted_fusion.add_fts_results(fts_results) != OB_SUCCESS || + weighted_fusion.add_vector_results(vector_results) != OB_SUCCESS || + weighted_fusion.fuse() != OB_SUCCESS) { + std::cout << "Failed to perform fusion" << std::endl; + return; + } + + common::ObSEArray fused_results; + if (weighted_fusion.get_results(fused_results, 5) != OB_SUCCESS) { + std::cout << "Failed to get results" << std::endl; + return; + } + + std::cout << "Fused Results (Keyword Priority 70:30):" << std::endl; + for (int i = 0; i < fused_results.count(); ++i) { + const auto &result = fused_results.at(i); + std::cout << " Rank " << (i+1) << ": Doc ID " << result.doc_id_ + << ", Score: " << std::fixed << std::setprecision(4) << result.final_score_ << std::endl; + } +} + +/* + * ============================================= + * Demo 4: Using Weighted Fusion Method - Semantic Similarity Priority + * ============================================= + * Scenario: Prioritize semantic similarity (30% FTS, 70% Vector) + * Use Case: Complex user search intent, need to understand semantics through vector search + */ +void demo_weighted_fusion_semantic_priority() +{ + std::cout << "\n========== Weighted Fusion Demo (Semantic Priority 30:70) ==========" << std::endl; + + ObMallocAllocator allocator; + ObWeightedFusion weighted_fusion; + + // 30% full-text search, 70% vector search + ObWeightedFusionConfig fusion_config(0.3, 0.7, true); + + // Use Min-Max normalization + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::MIN_MAX; + + if (weighted_fusion.init(fusion_config, norm_config, allocator) != OB_SUCCESS) { + std::cout << "Failed to initialize weighted fusion" << std::endl; + return; + } + + // Prepare data + common::ObSEArray fts_results; + ObHybridSearchResult fts1 = {1, 8.0, 0.0, -1, -1, 0.0, 0}; + ObHybridSearchResult fts2 = {2, 6.5, 0.0, -1, -1, 0.0, 0}; + fts_results.push_back(fts1); + fts_results.push_back(fts2); + + common::ObSEArray vector_results; + ObHybridSearchResult vec1 = {2, 0.0, 0.92, -1, -1, 0.0, 0}; + ObHybridSearchResult vec2 = {1, 0.0, 0.85, -1, -1, 0.0, 0}; + vector_results.push_back(vec1); + vector_results.push_back(vec2); + + if (weighted_fusion.add_fts_results(fts_results) != OB_SUCCESS || + weighted_fusion.add_vector_results(vector_results) != OB_SUCCESS || + weighted_fusion.fuse() != OB_SUCCESS) { + std::cout << "Failed to perform fusion" << std::endl; + return; + } + + common::ObSEArray fused_results; + if (weighted_fusion.get_results(fused_results) != OB_SUCCESS) { + std::cout << "Failed to get results" << std::endl; + return; + } + + std::cout << "Fused Results (Semantic Priority 30:70):" << std::endl; + for (int i = 0; i < fused_results.count(); ++i) { + const auto &result = fused_results.at(i); + std::cout << " Rank " << (i+1) << ": Doc ID " << result.doc_id_ + << ", Score: " << std::fixed << std::setprecision(4) << result.final_score_ << std::endl; + } +} + +} // namespace common +} // namespace oceanbase + +int main() +{ + std::cout << "========================================" << std::endl; + std::cout << "Hybrid Search Fusion Demonstrations" << std::endl; + std::cout << "========================================" << std::endl; + + // Run all demos + oceanbase::common::demo_rrf_fusion(); + oceanbase::common::demo_weighted_fusion_balanced(); + oceanbase::common::demo_weighted_fusion_keyword_priority(); + oceanbase::common::demo_weighted_fusion_semantic_priority(); + + std::cout << "\n========================================" << std::endl; + std::cout << "All demonstrations completed successfully!" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +} diff --git a/src/share/hybrid_search/ob_hybrid_search_common.h b/src/share/hybrid_search/ob_hybrid_search_common.h new file mode 100644 index 000000000..f8948c897 --- /dev/null +++ b/src/share/hybrid_search/ob_hybrid_search_common.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OB_HYBRID_SEARCH_COMMON_H +#define OB_HYBRID_SEARCH_COMMON_H + +#include "lib/ob_define.h" +#include "lib/oblog/ob_log.h" +#include "lib/container/ob_se_array.h" + +namespace oceanbase +{ +namespace common +{ + +// Hybrid search fusion method types +enum class ObHybridSearchFusionType +{ + UNKNOWN = 0, + RRF = 1, // Reciprocal Rank Fusion + WEIGHT_SUM = 2, // Weighted Sum Fusion + MIN_MAX_NORM = 3, // Min-Max Normalization Fusion + Z_SCORE_NORM = 4 // Z-Score Normalization Fusion +}; + +// Configuration parameters for RRF method +struct ObRRFConfig +{ + // Rank constant for balancing documents with low and high ranks + // Formula: score = 1 / (rank + rank_constant) + // Larger values are more favorable for low-ranked documents + int64_t rank_constant_ = 60; + + // Window size for each sub-query, recommended as 10-20 times the number of final results + int64_t rank_window_size_ = 100; + + ObRRFConfig() = default; + ObRRFConfig(int64_t rank_const, int64_t window_size) + : rank_constant_(rank_const), rank_window_size_(window_size) {} + + TO_STRING_KV(K_(rank_constant), K_(rank_window_size)); +}; + +// Configuration parameters for weighted fusion +struct ObWeightedFusionConfig +{ + // Weight for full-text search, range [0, 1] + double fts_weight_ = 0.5; + + // Weight for vector search, range [0, 1] + double vector_weight_ = 0.5; + + // Normalization strategy: whether to normalize scores + bool enable_normalization_ = true; + + ObWeightedFusionConfig() = default; + ObWeightedFusionConfig(double fts_w, double vec_w, bool normalize) + : fts_weight_(fts_w), vector_weight_(vec_w), enable_normalization_(normalize) {} +}; + +// Normalization strategy configuration +struct ObNormalizationConfig +{ + // Normalization method type + enum class NormalizationType + { + NONE = 0, // No normalization + MIN_MAX = 1, // Min-Max normalization: (x - min) / (max - min) + Z_SCORE = 2, // Z-Score normalization: (x - mean) / stddev + SIGMOID = 3 // Sigmoid normalization: 1 / (1 + exp(-x)) + }; + + NormalizationType norm_type_ = NormalizationType::MIN_MAX; + + // Min and max values for Min-Max normalization + double min_value_ = 0.0; + double max_value_ = 1.0; + + // Mean value and standard deviation for Z-Score normalization + double mean_value_ = 0.0; + double stddev_value_ = 1.0; + + // Note: Normalization statistics (min, max, mean, stddev) are computed automatically from the data + // during the fusion process. The above fields are provided for configuration purposes but are + // not directly used in the ObWeightedFusion implementation. + + ObNormalizationConfig() = default; +}; + +// Single search result item +struct ObHybridSearchResult +{ + // Document ID + uint64_t doc_id_ = 0; + + // Full-text search score (BM25) + double fts_score_ = 0.0; + + // Vector search score (distance or similarity) + double vector_score_ = 0.0; + + // Full-text search rank + int64_t fts_rank_ = -1; + + // Vector search rank + int64_t vector_rank_ = -1; + + // Final score after fusion + double final_score_ = 0.0; + + // Source flag: 1 for FTS only, 2 for vector only, 3 for both + int32_t source_flag_ = 0; + + bool operator<(const ObHybridSearchResult &other) const + { + // Sort by final score in descending order + if (final_score_ != other.final_score_) { + return final_score_ > other.final_score_; + } + return doc_id_ < other.doc_id_; + } + + TO_STRING_KV(K_(doc_id), K_(fts_score), K_(vector_score), + K_(fts_rank), K_(vector_rank), K_(final_score), K_(source_flag)); +}; + +// Vector distance measurement type +enum class ObVectorDistanceType +{ + L2_DISTANCE = 0, // Euclidean distance (L2) + COSINE_DISTANCE = 1, // Cosine distance + INNER_PRODUCT = 2 // Inner product +}; + +// Helper class for vector similarity conversion +class ObVectorMetricConverter +{ +public: + // Convert vector distance to similarity (between 0 and 1) + static double distance_to_similarity(double distance, ObVectorDistanceType type) + { + if (distance < 0) { + distance = 0; + } + + switch (type) { + case ObVectorDistanceType::L2_DISTANCE: + // L2 distance to similarity: similarity = 1 / (1 + distance) + return 1.0 / (1.0 + distance); + + case ObVectorDistanceType::COSINE_DISTANCE: + // Cosine distance to similarity: similarity = (1 - distance) / 2 + // Assumes cosine_distance range is [0, 2] + return (1.0 - distance) / 2.0; + + case ObVectorDistanceType::INNER_PRODUCT: + // Inner product is usually already similarity, but needs mapping to [0, 1] range + // Assumes already normalized + return distance > 1.0 ? 1.0 : (distance < 0.0 ? 0.0 : distance); + + default: + return 0.0; + } + } +}; + +} // namespace common +} // namespace oceanbase + +#endif // OB_HYBRID_SEARCH_COMMON_H diff --git a/src/share/hybrid_search/ob_hybrid_search_executor.cpp b/src/share/hybrid_search/ob_hybrid_search_executor.cpp index 96eb61189..3b59af6d5 100644 --- a/src/share/hybrid_search/ob_hybrid_search_executor.cpp +++ b/src/share/hybrid_search/ob_hybrid_search_executor.cpp @@ -15,7 +15,10 @@ */ #include "ob_hybrid_search_executor.h" +#include "ob_hybrid_search_fusion_engine.h" #include "storage/vector_index/cmd/ob_vector_refresh_index_executor.h" +#include "lib/json_type/ob_json_base.h" +#include "lib/json_type/ob_json_tree.h" #define USING_LOG_PREFIX SHARE @@ -84,16 +87,15 @@ int ObHybridSearchExecutor::execute_search(ObObj &query_res) { LOG_WARN("execute query failed", K(ret), K(query_sql), K(tenant_id_)); } else if (OB_NOT_NULL(result.get_result())) { if (OB_SUCCESS == (ret = result.get_result()->next())) { - ObObj tmp_res; - if (OB_FAIL(result.get_result()->get_obj("hits", tmp_res))) { - if (OB_ERR_NULL_VALUE == ret || OB_ERR_COLUMN_NOT_FOUND == ret) { - query_res.set_null(); - ret = OB_SUCCESS; - } else { - LOG_WARN("fail to extract result. ", K(ret)); - } - } else if (OB_FAIL(common::deep_copy_obj(ctx_->get_allocator(), tmp_res, query_res))) { - LOG_WARN("deep copy query result failed", K(ret)); + // Step 1: Parse FTS and Vector results from SQL result + common::ObSEArray fts_results; + common::ObSEArray vector_results; + + if (OB_FAIL(parse_hybrid_search_result(result.get_result(), fts_results, vector_results))) { + LOG_WARN("fail to parse hybrid search result", K(ret)); + } else if (OB_FAIL(apply_fusion_and_convert_to_json(fts_results, vector_results, + search_arg_.search_params_, query_res))) { + LOG_WARN("fail to apply fusion and convert to json", K(ret)); } } else if (OB_ITER_END == ret) { LOG_INFO("no result return!", K(ret), K(tenant_id_)); @@ -117,6 +119,440 @@ int ObHybridSearchExecutor::execute_get_sql(ObString &sql_result) { return ret; } +int ObHybridSearchExecutor::parse_hybrid_search_result( + const common::sqlclient::ObMySQLResult *result, + common::ObIArray &fts_results, + common::ObIArray &vector_results) { + int ret = OB_SUCCESS; + + if (OB_ISNULL(result)) { + ret = OB_INVALID_ARGUMENT; + LOG_WARN("result is null", K(ret)); + } else { + // Extract results from the SQL query result + // The SQL generates FULL OUTER JOIN of FTS and Vector search results + // Each row contains: doc_id, fts_rank, vector_rank, fts_score, vector_score, fts_matched, vector_matched + + while (OB_SUCC(ret)) { + uint64_t doc_id = 0; + double fts_score = 0.0; + double vector_score = 0.0; + int64_t fts_rank = -1; + int64_t vector_rank = -1; + + // Try to extract FTS-only result (FTS matched, Vector is NULL) + if (OB_FAIL(result->get_uint("doc_id", doc_id))) { + if (OB_ERR_NULL_VALUE == ret || OB_ERR_COLUMN_NOT_FOUND == ret) { + ret = OB_SUCCESS; + break; // No more rows + } + LOG_WARN("fail to extract doc_id", K(ret)); + break; + } else { + // Extract FTS rank and score + ObHybridSearchResult hybrid_result; + hybrid_result.doc_id_ = doc_id; + + // Try to get FTS rank (will be NULL if no FTS match) + if (OB_FAIL(result->get_int("fts_rank", fts_rank))) { + if (OB_ERR_NULL_VALUE == ret) { + fts_rank = -1; + ret = OB_SUCCESS; // FTS not matched + } else if (OB_ERR_COLUMN_NOT_FOUND != ret) { + LOG_WARN("fail to extract fts_rank", K(ret)); + break; + } + } + + // Try to get Vector rank (will be NULL if no Vector match) + if (OB_SUCC(ret)) { + if (OB_FAIL(result->get_int("vector_rank", vector_rank))) { + if (OB_ERR_NULL_VALUE == ret) { + vector_rank = -1; + ret = OB_SUCCESS; // Vector not matched + } else if (OB_ERR_COLUMN_NOT_FOUND != ret) { + LOG_WARN("fail to extract vector_rank", K(ret)); + break; + } + } + } + + // Try to get FTS score + if (OB_SUCC(ret)) { + if (OB_FAIL(result->get_double("fts_score", fts_score))) { + if (OB_ERR_NULL_VALUE == ret) { + fts_score = 0.0; + ret = OB_SUCCESS; + } else if (OB_ERR_COLUMN_NOT_FOUND != ret) { + LOG_WARN("fail to extract fts_score", K(ret)); + break; + } + } + } + + // Try to get Vector score + if (OB_SUCC(ret)) { + if (OB_FAIL(result->get_double("vector_score", vector_score))) { + if (OB_ERR_NULL_VALUE == ret) { + vector_score = 0.0; + ret = OB_SUCCESS; + } else if (OB_ERR_COLUMN_NOT_FOUND != ret) { + LOG_WARN("fail to extract vector_score", K(ret)); + break; + } + } + } + + // Populate result based on which search matched + if (OB_SUCC(ret)) { + if (fts_rank >= 0 && vector_rank >= 0) { + // Both FTS and Vector matched + hybrid_result.fts_rank_ = fts_rank; + hybrid_result.vector_rank_ = vector_rank; + hybrid_result.fts_score_ = fts_score; + hybrid_result.vector_score_ = vector_score; + hybrid_result.source_flag_ = 3; // Both sources + if (OB_FAIL(fts_results.push_back(hybrid_result))) { + LOG_WARN("fail to push fts result", K(ret)); + break; + } + if (OB_FAIL(vector_results.push_back(hybrid_result))) { + LOG_WARN("fail to push vector result", K(ret)); + break; + } + } else if (fts_rank >= 0) { + // FTS only + hybrid_result.fts_rank_ = fts_rank; + hybrid_result.vector_rank_ = -1; + hybrid_result.fts_score_ = fts_score; + hybrid_result.vector_score_ = 0.0; + hybrid_result.source_flag_ = 1; // FTS only + if (OB_FAIL(fts_results.push_back(hybrid_result))) { + LOG_WARN("fail to push fts result", K(ret)); + break; + } + } else if (vector_rank >= 0) { + // Vector only + hybrid_result.fts_rank_ = -1; + hybrid_result.vector_rank_ = vector_rank; + hybrid_result.fts_score_ = 0.0; + hybrid_result.vector_score_ = vector_score; + hybrid_result.source_flag_ = 2; // Vector only + if (OB_FAIL(vector_results.push_back(hybrid_result))) { + LOG_WARN("fail to push vector result", K(ret)); + break; + } + } + } + } + + // Get next row + ret = const_cast(result)->next(); + if (OB_ITER_END == ret) { + ret = OB_SUCCESS; + break; + } else if (OB_FAIL(ret)) { + LOG_WARN("fail to get next result", K(ret)); + break; + } + } + } + + return ret; +} + +int ObHybridSearchExecutor::apply_fusion_and_convert_to_json( + const common::ObIArray &fts_results, + const common::ObIArray &vector_results, + const ObString &search_params_str, + ObObj &query_res) { + int ret = OB_SUCCESS; + + if (OB_ISNULL(ctx_)) { + ret = OB_NOT_INIT; + LOG_WARN("exec context not initialized", K(ret)); + } else { + // Step 1: Initialize fusion engine + ObHybridSearchFusionEngine fusion_engine; + common::ObIAllocator &allocator = ctx_->get_allocator(); + + // Step 2: Determine fusion strategy and parameters from search_params_str + // Parse fusion configuration from search parameters + // Expected format: {"rank": {"rrf": {"rank_constant": 60, "rank_window_size": 1000}}} + ObRRFConfig rrf_config(60, 1000); // Default: rank_constant=60, rank_window_size=1000 + + // Extract RRF parameters from search_params_str if provided via rank.rrf path + if (OB_FAIL(parse_rrf_config_from_params(search_params_str, rrf_config))) { + // If parsing fails, use default configuration + LOG_INFO("fail to parse rrf config from params, using default config", K(ret)); + ret = OB_SUCCESS; // Reset error to continue with default config + } + + if (OB_FAIL(fusion_engine.init(ObHybridSearchFusionEngine::FusionStrategy::RRF, + &rrf_config, allocator))) { + LOG_WARN("fail to init fusion engine", K(ret)); + } else { + // Step 3: Feed FTS and Vector results to fusion engine + if (OB_FAIL(fusion_engine.feed_fts_results(fts_results))) { + LOG_WARN("fail to feed fts results to fusion engine", K(ret)); + } else if (OB_FAIL(fusion_engine.feed_vector_results(vector_results))) { + LOG_WARN("fail to feed vector results to fusion engine", K(ret)); + } else if (OB_FAIL(fusion_engine.execute_fusion())) { + // Step 4: Execute fusion algorithm (RRF in this case) + LOG_WARN("fail to execute fusion", K(ret)); + } else { + // Step 5: Get fused results + common::ObSEArray fused_results; + if (OB_FAIL(fusion_engine.get_fused_results(fused_results))) { + LOG_WARN("fail to get fused results", K(ret)); + } else { + // Step 6: Convert fused results to JSON format + common::ObJsonObject response_json(&allocator); + common::ObJsonArray hits_array(&allocator); + + // Add results array + for (int64_t i = 0; OB_SUCC(ret) && i < fused_results.count(); ++i) { + const ObHybridSearchResult &result = fused_results.at(i); + common::ObJsonObject result_obj(&allocator); + + // Add doc_id + common::ObJsonInt *doc_id_node = OB_NEWx(common::ObJsonInt, (&allocator), + static_cast(result.doc_id_)); + if (OB_ISNULL(doc_id_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate doc_id node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("doc_id"), doc_id_node))) { + LOG_WARN("fail to add doc_id", K(ret)); + break; + } + + // Add final score + common::ObJsonDouble *score_node = OB_NEWx(common::ObJsonDouble, (&allocator), + result.final_score_); + if (OB_ISNULL(score_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate score node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("score"), score_node))) { + LOG_WARN("fail to add score", K(ret)); + break; + } + + // Add FTS score if available + if (result.fts_rank_ >= 0) { + common::ObJsonDouble *fts_score_node = OB_NEWx(common::ObJsonDouble, (&allocator), + result.fts_score_); + if (OB_ISNULL(fts_score_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate fts_score node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("fts_score"), fts_score_node))) { + LOG_WARN("fail to add fts_score", K(ret)); + break; + } + + common::ObJsonInt *fts_rank_node = OB_NEWx(common::ObJsonInt, (&allocator), + result.fts_rank_); + if (OB_ISNULL(fts_rank_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate fts_rank node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("fts_rank"), fts_rank_node))) { + LOG_WARN("fail to add fts_rank", K(ret)); + break; + } + } + + // Add Vector score if available + if (result.vector_rank_ >= 0) { + common::ObJsonDouble *vector_score_node = OB_NEWx(common::ObJsonDouble, (&allocator), + result.vector_score_); + if (OB_ISNULL(vector_score_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate vector_score node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("vector_score"), vector_score_node))) { + LOG_WARN("fail to add vector_score", K(ret)); + break; + } + + common::ObJsonInt *vector_rank_node = OB_NEWx(common::ObJsonInt, (&allocator), + result.vector_rank_); + if (OB_ISNULL(vector_rank_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate vector_rank node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("vector_rank"), vector_rank_node))) { + LOG_WARN("fail to add vector_rank", K(ret)); + break; + } + } + + // Add source flag + common::ObJsonInt *source_node = OB_NEWx(common::ObJsonInt, (&allocator), + static_cast(result.source_flag_)); + if (OB_ISNULL(source_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate source node", K(ret)); + break; + } else if (OB_FAIL(result_obj.add(common::ObString::make_string("source"), source_node))) { + LOG_WARN("fail to add source flag", K(ret)); + break; + } + + // Add result object to hits array + if (OB_FAIL(hits_array.append(&result_obj))) { + LOG_WARN("fail to append result to hits array", K(ret)); + break; + } + } + + if (OB_SUCC(ret)) { + // Add hits array to response + if (OB_FAIL(response_json.add(common::ObString::make_string("hits"), + &hits_array))) { + LOG_WARN("fail to add hits array to response", K(ret)); + } else { + // Add metadata + common::ObJsonInt *total_node = OB_NEWx(common::ObJsonInt, (&allocator), + static_cast(fused_results.count())); + if (OB_ISNULL(total_node)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("fail to allocate total node", K(ret)); + } else if (OB_FAIL(response_json.add(common::ObString::make_string("total"), total_node))) { + LOG_WARN("fail to add total count", K(ret)); + } else { + // Convert JSON object to string and set as result + common::ObStringBuffer json_str(&allocator); + if (OB_FAIL(response_json.print(json_str, 0))) { + LOG_WARN("fail to print json object", K(ret)); + } else { + // Create ObObj from JSON string + ObString json_result(json_str.length(), json_str.ptr()); + if (OB_FAIL(common::deep_copy_obj(allocator, + ObObj(json_result), query_res))) { + LOG_WARN("fail to deep copy json result", K(ret)); + } + } + } + } + } + } + } + } + } + + return ret; +} + +int ObHybridSearchExecutor::parse_rrf_config_from_params( + const ObString &search_params_str, + ObRRFConfig &rrf_config) { + int ret = OB_SUCCESS; + + if (OB_ISNULL(search_params_str.ptr()) || search_params_str.length() == 0) { + // Empty params string, use default configuration + return OB_SUCCESS; + } + + // Parse rank.rrf from JSON search parameters + // Expected format according to ObESQueryParser: + // { + // "rank": { + // "rrf": { + // "rank_constant": 60, + // "rank_window_size": 1000 + // } + // } + // } + ObIJsonBase *json_base = nullptr; + if (OB_FAIL(common::ObJsonBaseFactory::get_json_base(&ctx_->get_allocator(), + search_params_str, + common::ObJsonInType::JSON_TREE, + common::ObJsonInType::JSON_TREE, + json_base))) { + LOG_WARN("fail to parse search params as json", K(ret), K(search_params_str)); + // Not a valid JSON, return success to use default config + ret = OB_SUCCESS; + } else if (OB_ISNULL(json_base)) { + // JSON parse result is null, use default config + ret = OB_SUCCESS; + } else { + common::ObJsonObject *json_obj = static_cast(json_base); + common::ObIJsonBase *rank_base = nullptr; + + // Try to get rank object (first level) + if (OB_FAIL(json_obj->get_object_value(common::ObString::make_string("rank"), + rank_base))) { + if (OB_ERR_JSON_PATH_EXPRESSION_SYNTAX_ERROR == ret) { + // rank not found, use default configuration + LOG_INFO("rank not found in search params, using default config"); + ret = OB_SUCCESS; + } else { + LOG_WARN("fail to get rank from search params", K(ret)); + } + } else if (OB_ISNULL(rank_base)) { + // rank is null, use default config + ret = OB_SUCCESS; + } else { + common::ObJsonObject *rank_obj = static_cast(rank_base); + common::ObIJsonBase *rrf_base = nullptr; + + // Try to get rrf object (second level) + if (OB_FAIL(rank_obj->get_object_value(common::ObString::make_string("rrf"), + rrf_base))) { + if (OB_ERR_JSON_PATH_EXPRESSION_SYNTAX_ERROR == ret) { + // rrf not found, use default configuration + LOG_INFO("rrf not found in rank config, using default config"); + ret = OB_SUCCESS; + } else { + LOG_WARN("fail to get rrf from rank config", K(ret)); + } + } else if (OB_ISNULL(rrf_base)) { + // rrf is null, use default config + ret = OB_SUCCESS; + } else { + common::ObJsonObject *rrf_obj = static_cast(rrf_base); + + // Extract rank_constant + common::ObIJsonBase *rank_const_base = nullptr; + if (OB_FAIL(rrf_obj->get_object_value( + common::ObString::make_string("rank_constant"), rank_const_base))) { + if (OB_ERR_JSON_PATH_EXPRESSION_SYNTAX_ERROR != ret) { + LOG_WARN("fail to get rank_constant from rrf config", K(ret)); + } + ret = OB_SUCCESS; // Not critical, use default + } else if (OB_NOT_NULL(rank_const_base)) { + common::ObJsonNumber *rank_const_num = + static_cast(rank_const_base); + rrf_config.rank_constant_ = static_cast(rank_const_num->get_double()); + LOG_INFO("parsed rank_constant from rank.rrf config", + K(rrf_config.rank_constant_)); + } + + // Extract rank_window_size + common::ObIJsonBase *window_size_base = nullptr; + if (OB_FAIL(rrf_obj->get_object_value( + common::ObString::make_string("rank_window_size"), window_size_base))) { + if (OB_ERR_JSON_PATH_EXPRESSION_SYNTAX_ERROR != ret) { + LOG_WARN("fail to get rank_window_size from rrf config", K(ret)); + } + ret = OB_SUCCESS; // Not critical, use default + } else if (OB_NOT_NULL(window_size_base)) { + common::ObJsonNumber *window_size_num = + static_cast(window_size_base); + rrf_config.rank_window_size_ = static_cast(window_size_num->get_double()); + LOG_INFO("parsed rank_window_size from rank.rrf config", + K(rrf_config.rank_window_size_)); + } + } + } + } + + return ret; +} + int ObHybridSearchExecutor::do_get_sql(const ObString &search_params_str, ObString &sql_result, bool need_wrap_result /*= false*/) { int ret = OB_SUCCESS; @@ -184,7 +620,7 @@ int ObHybridSearchExecutor::parse_search_params( LOG_WARN("No database selected", KR(ret)); } else { ObESQueryParser parser(allocator_, need_wrap_result, &table_name, &database_name); - if (OB_FAIL(construct_column_index_info(allocator_, database_name, table_name, parser.get_index_name_map(), parser.get_user_column_names()))) { + if (OB_FAIL(construct_column_index_info(allocator_, parser))) { LOG_WARN("fail to construnct column index info", KR(ret), K(search_params_str)); } else if (OB_FAIL(parser.parse(search_params_str, query_req))) { LOG_WARN("fail to parse search params", KR(ret), K(search_params_str)); @@ -193,14 +629,17 @@ int ObHybridSearchExecutor::parse_search_params( return ret; } -int ObHybridSearchExecutor::construct_column_index_info(ObIAllocator &alloc, const ObString &database_name, const ObString &table_name, - ColumnIndexNameMap &column_index_info, ObIArray &col_names) +int ObHybridSearchExecutor::construct_column_index_info(ObIAllocator &alloc, ObESQueryParser &parser) { int ret = OB_SUCCESS; share::schema::ObSchemaGetterGuard *schema_guard = NULL; const ObTableSchema *data_table_schema = NULL; ObSEArray simple_index_infos; ObCStringHelper helper; + const ObString &database_name = parser.get_database_name(); + const ObString &table_name = parser.get_table_name(); + ColumnIndexNameMap &column_index_info = parser.get_index_name_map(); + ObIArray &col_names = parser.get_user_column_names(); if (OB_ISNULL(schema_guard = ctx_->get_virtual_table_ctx().schema_guard_)) { ret = OB_ERR_UNEXPECTED; @@ -219,6 +658,8 @@ int ObHybridSearchExecutor::construct_column_index_info(ObIAllocator &alloc, con LOG_WARN("fail to get simple index infos failed", K(ret)); } else if (OB_FAIL(get_basic_column_names(data_table_schema, col_names))) { LOG_WARN("fail to get all column names", K(ret)); + } else if (OB_FAIL(get_partition_info(data_table_schema, parser))) { + LOG_WARN("fail to get partition column names and init alias exprs", K(ret)); } else { for (int64_t i = 0; OB_SUCC(ret) && i < simple_index_infos.count(); ++i) { const ObTableSchema *index_table_schema = nullptr; @@ -349,5 +790,67 @@ int ObHybridSearchExecutor::get_basic_column_names(const ObTableSchema *table_sc return ret; } +int ObHybridSearchExecutor::extract_partition_column_ids(const ObPartitionKeyInfo &part_key_info, + hash::ObPlacementHashSet &column_id_set, + ObIArray &column_ids) +{ + int ret = OB_SUCCESS; + uint64_t column_id = OB_INVALID_ID; + for (int64_t i = 0; OB_SUCC(ret) && i < part_key_info.get_size(); i++) { + if (OB_FAIL(part_key_info.get_column_id(i, column_id))) { + LOG_WARN("failed to get column id from partition key info", K(ret), K(i)); + } else { + int hash_ret = column_id_set.exist_refactored(column_id); + if (OB_HASH_EXIST == hash_ret) { + } else if (OB_HASH_NOT_EXIST == hash_ret) { + if (OB_FAIL(column_id_set.set_refactored(column_id))) { + LOG_WARN("failed to set column id in hash set", K(ret), K(column_id)); + } else if (OB_FAIL(column_ids.push_back(column_id))) { + LOG_WARN("failed to push back column id", K(ret), K(column_id)); + } + } else { + ret = hash_ret; + LOG_WARN("failed to check column id existence", K(ret), K(column_id)); + } + } + } + return ret; +} + +int ObHybridSearchExecutor::get_partition_info(const ObTableSchema *table_schema, ObESQueryParser &parser) +{ + int ret = OB_SUCCESS; + if (OB_ISNULL(table_schema)) { + ret = OB_INVALID_ARGUMENT; + LOG_WARN("invalid argument", K(ret), KP(table_schema)); + } else if (table_schema->get_part_level() != PARTITION_LEVEL_ZERO) { + hash::ObPlacementHashSet column_id_set; // to deduplicate column ids + ObSEArray column_ids; + ObSEArray column_names; + const ObPartitionKeyInfo &part_key_info = table_schema->get_partition_key_info(); + const ObPartitionKeyInfo &subpart_key_info = table_schema->get_subpartition_key_info(); + if (OB_FAIL(extract_partition_column_ids(part_key_info, column_id_set, column_ids))) { + LOG_WARN("failed to extract column ids from partition key info", K(ret)); + } else if (table_schema->get_part_level() == PARTITION_LEVEL_TWO && OB_FAIL(extract_partition_column_ids(subpart_key_info, column_id_set, column_ids))) { + LOG_WARN("failed to extract column ids from subpartition key info", K(ret)); + } else if (column_ids.count() > 0) { + lib::ob_sort(column_ids.begin(), column_ids.end()); + } + for (int64_t i = 0; OB_SUCC(ret) && i < column_ids.count(); i++) { + const ObColumnSchemaV2 *column_schema = table_schema->get_column_schema(column_ids.at(i)); + if (OB_ISNULL(column_schema)) { + ret = OB_ERR_UNEXPECTED; + LOG_WARN("unexpected column schema", K(ret), K(column_ids.at(i))); + } else if (OB_FAIL(column_names.push_back(column_schema->get_column_name_str()))) { + LOG_WARN("failed to push back column name", K(ret)); + } + } + if (OB_SUCC(ret) && OB_FAIL(parser.construct_partition_cols(column_names))) { + LOG_WARN("failed to construct partition column and alias exprs", K(ret)); + } + } + return ret; +} + } // namespace share } // namespace oceanbase diff --git a/src/share/hybrid_search/ob_hybrid_search_executor.h b/src/share/hybrid_search/ob_hybrid_search_executor.h index a80014182..413496a8c 100644 --- a/src/share/hybrid_search/ob_hybrid_search_executor.h +++ b/src/share/hybrid_search/ob_hybrid_search_executor.h @@ -17,6 +17,7 @@ #pragma once #include "ob_query_parse.h" +#include "ob_hybrid_search_fusion_engine.h" #include "pl/ob_pl.h" #include "share/schema/ob_schema_struct.h" #include "sql/engine/ob_exec_context.h" @@ -76,9 +77,23 @@ class ObHybridSearchExecutor { /// int do_get_sql_with_retry(ObString &sql_result); int generate_sql_from_params(const ObString &search_params_str, ObString &sql_result); - int construct_column_index_info(ObIAllocator &alloc, const ObString &database_name, const ObString &table_name, - ColumnIndexNameMap &column_index_info, ObIArray &col_names); + int construct_column_index_info(ObIAllocator &alloc, ObESQueryParser &parser); int get_basic_column_names(const ObTableSchema *table_schema, ObIArray &col_names); + int extract_partition_column_ids(const ObPartitionKeyInfo &part_key_info, + hash::ObPlacementHashSet &column_id_set, + ObIArray &column_ids); + int get_partition_info(const ObTableSchema *table_schema, ObESQueryParser &parser); + + // Hybrid search fusion engine integration + int parse_hybrid_search_result(const common::sqlclient::ObMySQLResult *result, + common::ObIArray &fts_results, + common::ObIArray &vector_results); + int apply_fusion_and_convert_to_json(const common::ObIArray &fts_results, + const common::ObIArray &vector_results, + const ObString &search_params_str, + ObObj &query_res); + int parse_rrf_config_from_params(const ObString &search_params_str, + ObRRFConfig &rrf_config); private: sql::ObExecContext *ctx_; diff --git a/src/share/hybrid_search/ob_hybrid_search_fusion_engine.cpp b/src/share/hybrid_search/ob_hybrid_search_fusion_engine.cpp new file mode 100644 index 000000000..ca76ab902 --- /dev/null +++ b/src/share/hybrid_search/ob_hybrid_search_fusion_engine.cpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define USING_LOG_PREFIX OLOG + +#include "ob_hybrid_search_fusion_engine.h" +#include "ob_rrf_fusion.h" +#include "ob_weighted_fusion.h" +#include "lib/oblog/ob_log.h" + +namespace oceanbase { +namespace share { +using namespace oceanbase::common; + +// ==================== RRF Strategy Adapter ==================== +class ObRRFFusionStrategy : public IObHybridSearchFusionStrategy { +public: + virtual ~ObRRFFusionStrategy() = default; + + int init(const void *config, ObIAllocator &allocator) override { + if (OB_ISNULL(config)) { + return OB_INVALID_ARGUMENT; + } + const ObRRFConfig *rrf_config = static_cast(config); + return rrf_fusion_.init(*rrf_config, allocator); + } + + int feed_fts_results(const ObIArray &results) override { + return rrf_fusion_.add_fts_results(results); + } + + int feed_vector_results(const ObIArray &results) override { + return rrf_fusion_.add_vector_results(results); + } + + int execute_fusion() override { + return rrf_fusion_.fuse(); + } + + int get_fused_results(ObIArray &results, int64_t limit = 0) const override { + return rrf_fusion_.get_results(results, limit); + } + + void reset() override { + rrf_fusion_.reset(); + } + +private: + ObRRFFusion rrf_fusion_; +}; + +// ==================== Weighted Fusion Strategy Adapter ==================== +class ObWeightedFusionStrategy : public IObHybridSearchFusionStrategy { +public: + ObWeightedFusionStrategy(ObNormalizationConfig::NormalizationType norm_type) + : norm_type_(norm_type) {} + + virtual ~ObWeightedFusionStrategy() = default; + + int init(const void *config, ObIAllocator &allocator) override { + if (OB_ISNULL(config)) { + return OB_INVALID_ARGUMENT; + } + const ObWeightedFusionConfig *fusion_config = static_cast(config); + ObNormalizationConfig norm_config; + norm_config.norm_type_ = norm_type_; + return weighted_fusion_.init(*fusion_config, norm_config, allocator); + } + + int feed_fts_results(const ObIArray &results) override { + return weighted_fusion_.add_fts_results(results); + } + + int feed_vector_results(const ObIArray &results) override { + return weighted_fusion_.add_vector_results(results); + } + + int execute_fusion() override { + return weighted_fusion_.fuse(); + } + + int get_fused_results(ObIArray &results, int64_t limit = 0) const override { + return weighted_fusion_.get_results(results, limit); + } + + void reset() override { + weighted_fusion_.reset(); + } + +private: + ObWeightedFusion weighted_fusion_; + ObNormalizationConfig::NormalizationType norm_type_; +}; + +// ==================== Fusion Engine Main Class ==================== +ObHybridSearchFusionEngine::ObHybridSearchFusionEngine() + : strategy_buffer_(), + strategy_(nullptr), + is_initialized_(false), + allocator_(nullptr) +{ +} + +ObHybridSearchFusionEngine::~ObHybridSearchFusionEngine() +{ + if (OB_NOT_NULL(strategy_)) { + strategy_->~IObHybridSearchFusionStrategy(); + strategy_ = nullptr; + } +} + +int ObHybridSearchFusionEngine::create_strategy(FusionStrategy strategy, ObIAllocator &allocator) +{ + int ret = OB_SUCCESS; + + switch (strategy) { + case FusionStrategy::RRF: { + void *buf = strategy_buffer_.get_data(); + if (OB_ISNULL(buf)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + } else { + strategy_ = new (buf) ObRRFFusionStrategy(); + } + break; + } + + case FusionStrategy::WEIGHTED_SUM_MIN_MAX: { + void *buf = strategy_buffer_.get_data(); + if (OB_ISNULL(buf)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + } else { + strategy_ = new (buf) ObWeightedFusionStrategy( + ObNormalizationConfig::NormalizationType::MIN_MAX); + } + break; + } + + case FusionStrategy::WEIGHTED_SUM_Z_SCORE: { + void *buf = strategy_buffer_.get_data(); + if (OB_ISNULL(buf)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + } else { + strategy_ = new (buf) ObWeightedFusionStrategy( + ObNormalizationConfig::NormalizationType::Z_SCORE); + } + break; + } + + case FusionStrategy::WEIGHTED_SUM_SIGMOID: { + void *buf = strategy_buffer_.get_data(); + if (OB_ISNULL(buf)) { + ret = OB_ALLOCATE_MEMORY_FAILED; + } else { + strategy_ = new (buf) ObWeightedFusionStrategy( + ObNormalizationConfig::NormalizationType::SIGMOID); + } + break; + } + + case FusionStrategy::WEIGHTED_SUM: + case FusionStrategy::UNKNOWN: + default: + ret = OB_NOT_SUPPORTED; + OB_LOG(WARN, "unsupported fusion strategy", K(ret), K(strategy)); + } + + return ret; +} + +int ObHybridSearchFusionEngine::init(FusionStrategy strategy, const void *config, + ObIAllocator &allocator) +{ + int ret = OB_SUCCESS; + + if (is_initialized_) { + ret = OB_INIT_TWICE; + OB_LOG(WARN, "fusion engine already initialized", K(ret)); + return ret; + } + + if (OB_FAIL(create_strategy(strategy, allocator))) { + OB_LOG(WARN, "failed to create fusion strategy", K(ret)); + return ret; + } + + if (OB_ISNULL(strategy_)) { + ret = OB_ERR_UNEXPECTED; + OB_LOG(WARN, "strategy is null after creation", K(ret)); + return ret; + } + + if (OB_FAIL(strategy_->init(config, allocator))) { + OB_LOG(WARN, "failed to initialize fusion strategy", K(ret)); + return ret; + } + + allocator_ = &allocator; + is_initialized_ = true; + + return ret; +} + +int ObHybridSearchFusionEngine::feed_fts_results(const ObIArray &results) +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "fusion engine not initialized", K(ret)); + return ret; + } + + if (OB_ISNULL(strategy_)) { + ret = OB_ERR_UNEXPECTED; + OB_LOG(WARN, "strategy is null", K(ret)); + return ret; + } + + return strategy_->feed_fts_results(results); +} + +int ObHybridSearchFusionEngine::feed_vector_results(const ObIArray &results) +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "fusion engine not initialized", K(ret)); + return ret; + } + + if (OB_ISNULL(strategy_)) { + ret = OB_ERR_UNEXPECTED; + OB_LOG(WARN, "strategy is null", K(ret)); + return ret; + } + + return strategy_->feed_vector_results(results); +} + +int ObHybridSearchFusionEngine::execute_fusion() +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "fusion engine not initialized", K(ret)); + return ret; + } + + if (OB_ISNULL(strategy_)) { + ret = OB_ERR_UNEXPECTED; + OB_LOG(WARN, "strategy is null", K(ret)); + return ret; + } + + return strategy_->execute_fusion(); +} + +int ObHybridSearchFusionEngine::get_fused_results(ObIArray &results, + int64_t limit) const +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "fusion engine not initialized", K(ret)); + return ret; + } + + if (OB_ISNULL(strategy_)) { + ret = OB_ERR_UNEXPECTED; + OB_LOG(WARN, "strategy is null", K(ret)); + return ret; + } + + return strategy_->get_fused_results(results, limit); +} + +int64_t ObHybridSearchFusionEngine::get_fused_result_count() const +{ + int ret = OB_SUCCESS; + if (!is_initialized_ || OB_ISNULL(strategy_)) { + return 0; + } + + common::ObSEArray temp_results; + if (OB_SUCC(strategy_->get_fused_results(temp_results))) { + return temp_results.count(); + } + + return 0; +} + +void ObHybridSearchFusionEngine::reset() +{ + if (OB_NOT_NULL(strategy_)) { + strategy_->reset(); + } + + is_initialized_ = false; + allocator_ = nullptr; +} + +} // namespace share +} // namespace oceanbase diff --git a/src/share/hybrid_search/ob_hybrid_search_fusion_engine.h b/src/share/hybrid_search/ob_hybrid_search_fusion_engine.h new file mode 100644 index 000000000..ac70a1e59 --- /dev/null +++ b/src/share/hybrid_search/ob_hybrid_search_fusion_engine.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OB_HYBRID_SEARCH_FUSION_ENGINE_H +#define OB_HYBRID_SEARCH_FUSION_ENGINE_H + +#include "ob_hybrid_search_common.h" +#include "lib/container/ob_se_array.h" +#include "lib/allocator/ob_allocator.h" + +namespace oceanbase { +namespace share { + +/* + * Hybrid Search Fusion Engine + * + * Responsibilities: + * 1. Receive raw search results from FTS and KNN + * 2. Select appropriate fusion strategy based on configuration + * 3. Execute fusion computation and return sorted final results + * + * Usage flow: + * 1. ObHybridSearchFusionEngine engine; + * 2. engine.init(strategy, config, allocator); + * 3. engine.feed_fts_results(fts_results); + * 4. engine.feed_vector_results(vector_results); + * 5. engine.execute_fusion(); + * 6. engine.get_fused_results(output_results); + */ + +class IObHybridSearchFusionStrategy { +public: + virtual ~IObHybridSearchFusionStrategy() = default; + + /* + * Initialize fusion strategy + * + * @param config Configuration object, specific type is determined by strategy + * @param allocator Memory allocator + * @return OB_SUCCESS or error code + */ + virtual int init(const void *config, common::ObIAllocator &allocator) = 0; + + /* + * Feed full-text search results + */ + virtual int feed_fts_results(const common::ObIArray &results) = 0; + + /* + * Feed vector search results + */ + virtual int feed_vector_results(const common::ObIArray &results) = 0; + + /* + * Execute fusion computation + */ + virtual int execute_fusion() = 0; + + /* + * Get fused results + * + * @param results Output parameter containing sorted fused results + * @param limit Limit the number of returned results, 0 means return all results + * @return OB_SUCCESS or error code + */ + virtual int get_fused_results(common::ObIArray &results, + int64_t limit = 0) const = 0; + + /* + * Reset state + */ + virtual void reset() = 0; +}; + +class ObHybridSearchFusionEngine { +public: + /* + * Fusion strategy type + */ + enum class FusionStrategy { + UNKNOWN = 0, + RRF = 1, // Reciprocal Rank Fusion + WEIGHTED_SUM = 2, // Weighted sum (dynamically select normalization strategy) + WEIGHTED_SUM_MIN_MAX = 3, // Min-Max normalization + weighted sum + WEIGHTED_SUM_Z_SCORE = 4, // Z-Score normalization + weighted sum + WEIGHTED_SUM_SIGMOID = 5 // Sigmoid normalization + weighted sum + }; + + ObHybridSearchFusionEngine(); + ~ObHybridSearchFusionEngine(); + + /* + * Initialize fusion engine + * + * @param strategy Fusion strategy + * @param config Configuration object pointer, type depends on strategy + * @param allocator Memory allocator + * @return OB_SUCCESS or error code + * + * Configuration object types: + * - RRF: const ObRRFConfig* + * - WEIGHTED_SUM*: const ObWeightedFusionConfig* + */ + int init(FusionStrategy strategy, const void *config, common::ObIAllocator &allocator); + + /* + * Feed full-text search result list + */ + int feed_fts_results(const common::ObIArray &results); + + /* + * Feed vector search result list + */ + int feed_vector_results(const common::ObIArray &results); + + /* + * Execute fusion computation + */ + int execute_fusion(); + + /* + * Get fused results + * + * @param results Output array to receive fused results + * @param limit Optional, limit the number of returned results, 0 means return all + * @return OB_SUCCESS or error code + */ + int get_fused_results(common::ObIArray &results, + int64_t limit = 0) const; + + /* + * Get count of fused results + */ + int64_t get_fused_result_count() const; + + /* + * Reset engine state, prepare for next fusion + */ + void reset(); + +private: + /* + * Create corresponding strategy object based on strategy type + */ + int create_strategy(FusionStrategy strategy, common::ObIAllocator &allocator); + +private: + // Fusion strategy implementation + common::ObSEArray strategy_buffer_; // Buffer to store strategy object + IObHybridSearchFusionStrategy *strategy_; // Strategy interface pointer + + // Initialization flag + bool is_initialized_; + + // Memory allocator (non-owner) + common::ObIAllocator *allocator_; +}; + +} // namespace share +} // namespace oceanbase + +#endif // OB_HYBRID_SEARCH_FUSION_ENGINE_H diff --git a/src/share/hybrid_search/ob_query_parse.cpp b/src/share/hybrid_search/ob_query_parse.cpp index de04d651f..f44138f1b 100644 --- a/src/share/hybrid_search/ob_query_parse.cpp +++ b/src/share/hybrid_search/ob_query_parse.cpp @@ -37,6 +37,7 @@ const ObString ObESQueryParser::FTS_ALIAS("_fts"); const ObString ObESQueryParser::VS_ALIAS("_vs"); const ObString ObESQueryParser::MSM_KEY("minimum_should_match"); const ObString ObESQueryParser::FTS_SUB_SCORE_PREFIX("_fts_sub_score_"); +const ObString ObESQueryParser::PART_COL_ALIAS_PREFIX("_part_col_"); const ObString ObESQueryParser::HIDDEN_COLUMN_VISIBLE_HINT("opt_param('hidden_column_visible', 'true')"); int ObESQueryParser::parse(const common::ObString &req_str, ObQueryReqFromJson *&query_req) @@ -438,6 +439,13 @@ int ObESQueryParser::knn_fusion(const ObIArray &knn_queries LOG_WARN("fail to create query request", K(ret)); } else if (OB_FAIL(base_table_req->select_items_.push_back(rowkey_expr))) { LOG_WARN("fail to create query request", K(ret)); + } + for (int64_t j = 0; OB_SUCC(ret) && j < part_cols_.count(); j++) { + if (OB_FAIL(base_table_req->select_items_.push_back(part_cols_.at(j)))) { + LOG_WARN("failed to add partition expr to subquery select items", K(ret)); + } + } + if (OB_FAIL(ret)) { } else if (OB_FAIL(construct_sub_query_table(empty_str, knn_queries.at(i), sub_query))) { LOG_WARN("fail to create sub query table", K(ret)); } else if (OB_FAIL(multi_set_table->sub_queries_.push_back(sub_query))) { @@ -459,6 +467,13 @@ int ObESQueryParser::knn_fusion(const ObIArray &knn_queries LOG_WARN("fail to append from item", K(ret)); } else if (OB_FAIL(res->score_items_.push_back(sum_expr))) { LOG_WARN("fail to append select item", K(ret)); + } + for (int64_t i = 0; OB_SUCC(ret) && i < part_aliases_.count(); i++) { + if (OB_FAIL(res->group_items_.push_back(part_aliases_.at(i)))) { + LOG_WARN("failed to add partition column to group items", K(ret), K(i)); + } + } + if (OB_FAIL(ret)) { } else if (OB_FAIL(res->group_items_.push_back(rowkey_expr))) { LOG_WARN("fail to push query order item", K(ret)); } else if (OB_FAIL(construct_order_by_item(sum_expr, false, order_info))) { @@ -591,7 +606,7 @@ int ObESQueryParser::set_output_columns(ObQueryReqFromJson &query_res, bool is_h } } for (uint64_t i = 0; OB_SUCC(ret) && i < out_cols_->count(); i++) { - if (is_hybrid && !is_inner_column(out_cols_->at(i))) { + if (is_hybrid && (!is_inner_column(out_cols_->at(i)) || out_cols_->at(i) == ROWKEY_NAME)) { ObReqColumnExpr *fts_col = NULL; ObReqColumnExpr *vs_col = NULL; ObReqExpr *if_null = NULL; @@ -661,6 +676,49 @@ int ObESQueryParser::construct_join_condition(const ObString &l_table, const ObS return ret; } +int ObESQueryParser::construct_join_multi_condition(const ObString &l_table, const ObString &r_table, + const ObString &rowkey, ObItemType condition, ObReqOpExpr *&join_condition) +{ + int ret = OB_SUCCESS; + ObReqOpExpr *rowkey_condition = nullptr; + ObSEArray conditions; + for (int64_t i = 0; OB_SUCC(ret) && i < part_aliases_.count(); i++) { + const ObString &key_name = part_aliases_.at(i)->expr_name; + ObReqOpExpr *key_condition = nullptr; + if (OB_FAIL(construct_join_condition(l_table, r_table, key_name, key_name, condition, key_condition))) { + LOG_WARN("fail to create partition key join condition", K(ret), K(i)); + } else if (OB_FAIL(conditions.push_back(key_condition))) { + LOG_WARN("fail to push back key condition", K(ret), K(i)); + } + } + if (OB_FAIL(ret)) { + } else if (OB_FAIL(construct_join_condition(l_table, r_table, rowkey, rowkey, condition, rowkey_condition))) { + LOG_WARN("fail to create rowkey join condition", K(ret)); + } else if (OB_FAIL(conditions.push_back(rowkey_condition))) { + LOG_WARN("fail to push back rowkey condition", K(ret)); + } else if (OB_FAIL(ObReqOpExpr::construct_op_expr(join_condition, alloc_, T_OP_AND, conditions))) { + LOG_WARN("fail to construct AND op expr", K(ret)); + } + return ret; +} + +int ObESQueryParser::add_partition_keys_to_select(ObQueryReqFromJson *fts_base, ObQueryReqFromJson *knn_base) +{ + int ret = OB_SUCCESS; + if (part_cols_.empty()) { + // no partition keys, do nothing + } else { + for (int64_t i = 0; OB_SUCC(ret) && i < part_cols_.count(); i++) { + if (OB_NOT_NULL(fts_base) && OB_FAIL(fts_base->select_items_.push_back(part_cols_.at(i)))) { + LOG_WARN("failed to add partition expr to fts select items", K(ret)); + } else if (OB_NOT_NULL(knn_base) && OB_FAIL(knn_base->select_items_.push_back(part_cols_.at(i)))) { + LOG_WARN("failed to add partition expr to knn select items", K(ret)); + } + } + } + return ret; +} + int ObESQueryParser::set_default_score(ObQueryReqFromJson *query_req, double default_score) { int ret = OB_SUCCESS; @@ -745,6 +803,11 @@ int ObESQueryParser::construct_hybrid_query(ObQueryReqFromJson *fts, ObQueryReqF LOG_WARN("fail to create query request", K(ret)); } else if (OB_FAIL(knn_table_type != MULTI_SET && base_table_knn_req->select_items_.push_back(knn_rowkey))) { LOG_WARN("fail to create query request", K(ret)); + } + // add partition keys to base table SELECT items + if (OB_FAIL(ret)) { + } else if (OB_FAIL(add_partition_keys_to_select(base_table_fts_req, knn_table_type != MULTI_SET ? base_table_knn_req : nullptr))) { + LOG_WARN("fail to add partition keys to base tables", K(ret)); } else { fts->output_all_columns_ = false; fts->score_alias_ = FTS_SCORE_NAME; @@ -771,6 +834,15 @@ int ObESQueryParser::construct_hybrid_query(ObQueryReqFromJson *fts, ObQueryReqF LOG_WARN("fail to add score col", K(ret)); } } + // add partition key column references to outer queries if there are any sub queries + for (int64_t i = 0; OB_SUCC(ret) && i < part_aliases_.count(); i++) { + if (base_table_fts_req != fts && OB_FAIL(fts->select_items_.push_back(part_aliases_.at(i)))) { + LOG_WARN("fail to add partition key to fts outer query select items", K(ret), K(i)); + } else if (base_table_knn_req != knn && !knn->output_all_columns_ && + OB_FAIL(knn->select_items_.push_back(part_aliases_.at(i)))) { + LOG_WARN("fail to add partition key to knn outer query select items", K(ret), K(i)); + } + } if (OB_FAIL(ret)) { } else if (base_table_fts_req != fts && OB_FAIL(fts->select_items_.push_back(fts_rowkey))) { LOG_WARN("fail to create query request", K(ret)); @@ -780,29 +852,19 @@ int ObESQueryParser::construct_hybrid_query(ObQueryReqFromJson *fts, ObQueryReqF } if (OB_FAIL(ret)) { - } else if (fusion_config_.method == ObFusionMethod::RRF) { - ObString empty_str = ""; - ObString fts_rank_alias = FTS_RANK_NAME; - ObString vs_rank_alias = VS_RANK_NAME; - if (OB_FAIL(construct_rank_query(empty_str, fts_score, fts_rank_alias, fts))) { - LOG_WARN("fail to construct keyword rank query", K(ret)); - } else if (OB_FAIL(construct_rank_query(empty_str, knn_score, vs_rank_alias, knn))) { - LOG_WARN("fail to construct keyword rank query", K(ret)); - } else if (OB_FAIL(construct_rank_score(fts_alias, fts_rank_alias, fts_score))) { - LOG_WARN("fail to construct keyword rank score expr", K(ret)); - } else if (OB_FAIL(construct_rank_score(knn_alias, vs_rank_alias, knn_score))) { - LOG_WARN("fail to construct knn rank score expr", K(ret)); - } } else if (FALSE_IT(static_cast(fts_score)->table_name = fts_alias)) { } else if (FALSE_IT(static_cast(knn_score)->table_name = knn_alias)) { } + // NOTE: RRF fusion logic has been moved to ObHybridSearchFusionEngine (execution layer) + // SQL generation now produces unified FULL OUTER JOIN without fusion calculations + // Fusion strategy selection and execution is delegated to the fusion engine if (OB_FAIL(ret)) { } else if (OB_FAIL(construct_sub_query_table(fts_alias, fts, fts_table))) { LOG_WARN("fail to create sub query table", K(ret)); } else if (OB_FAIL(construct_sub_query_table(knn_alias, knn, knn_table))) { LOG_WARN("fail to create sub query table", K(ret)); - } else if (OB_FAIL(construct_join_condition(fts_alias, knn_alias, rowkey, rowkey, T_OP_EQ, join_condition))) { - LOG_WARN("fail to create op expr", K(ret)); + } else if (OB_FAIL(construct_join_multi_condition(fts_alias, knn_alias, rowkey, T_OP_EQ, join_condition))) { + LOG_WARN("fail to create join condition", K(ret)); } else if (FALSE_IT(join_table->init(fts_table, knn_table, join_condition, ObReqJoinType::FULL_OUTER_JOIN))) { } else if (OB_FAIL(hybrid->from_items_.push_back(join_table))) { LOG_WARN("fail to append from item", K(ret)); @@ -950,20 +1012,12 @@ int ObESQueryParser::construct_es_expr_field(ObReqColumnExpr *raw_field, ObReqEx ret = OB_INVALID_ARGUMENT; LOG_WARN("raw_field is null", K(ret)); } else { - char *buf = static_cast(alloc_.alloc(OB_MAX_COLUMN_NAME_LENGTH)); - int64_t pos = 0; - if (OB_ISNULL(buf)) { - ret = OB_ALLOCATE_MEMORY_FAILED; - LOG_WARN("fail to allocate memory for field param buffer", K(ret)); - } else if (OB_FAIL(databuff_printf(buf, OB_MAX_COLUMN_NAME_LENGTH, pos, "%.*s", raw_field->expr_name.length(), raw_field->expr_name.ptr()))) { - LOG_WARN("fail to write field name", K(ret)); - } else if (OB_FAIL(databuff_printf(buf, OB_MAX_COLUMN_NAME_LENGTH, pos, "^%.15g", (raw_field->weight_ == -1.0) ? 1.0 : raw_field->weight_))) { - LOG_WARN("fail to write field weight", K(ret)); + ObReqColumnExpr *col_field = nullptr; + double weight = (raw_field->weight_ == -1.0) ? 1.0 : raw_field->weight_; + if (OB_FAIL(ObReqColumnExpr::construct_column_expr(col_field, alloc_, raw_field->expr_name, weight, true))) { + LOG_WARN("fail to create column expr for ES field", K(ret)); } else { - ObString field_param_str(pos, buf); - if (OB_FAIL(ObReqExpr::construct_expr(field, alloc_, field_param_str))) { - LOG_WARN("fail to create field param expr", K(ret)); - } + field = col_field; } } return ret; @@ -2251,28 +2305,28 @@ int ObESQueryParser::parse_query_string(ObIJsonBase &req_node, ObEsQueryInfo &qu } if (OB_SUCC(ret)) { - if (OB_SUCC(parse_query_string_fields(req_node, query_info))) { + if (OB_SUCC(parse_query_string_operator(req_node, query_info))) { parsed_keys++; + } else if (ret == OB_SEARCH_NOT_FOUND) { + ret = OB_SUCCESS; } else { - LOG_WARN("fail to parse query_string fields", K(ret)); + LOG_WARN("fail to parse query_string operator", K(ret)); } } if (OB_SUCC(ret)) { - if (OB_SUCC(parse_query_string_query(req_node, query_info))) { + if (OB_SUCC(parse_query_string_fields(req_node, query_info))) { parsed_keys++; } else { - LOG_WARN("fail to parse query_string query", K(ret)); + LOG_WARN("fail to parse query_string fields", K(ret)); } } if (OB_SUCC(ret)) { - if (OB_SUCC(parse_query_string_operator(req_node, query_info))) { + if (OB_SUCC(parse_query_string_query(req_node, query_info))) { parsed_keys++; - } else if (ret == OB_SEARCH_NOT_FOUND) { - ret = OB_SUCCESS; } else { - LOG_WARN("fail to parse query_string operator", K(ret)); + LOG_WARN("fail to parse query_string query", K(ret)); } } @@ -2507,7 +2561,8 @@ int ObESQueryParser::parse_keyword_query_string(ObEsQueryInfo &query_info, } if (OB_FAIL(ret)) { - } else if (query_info.score_type_ != SCORE_TYPE_CROSS_FIELDS) { + } else if (query_info.score_type_ == SCORE_TYPE_PHRASE || + (query_info.score_type_ != SCORE_TYPE_CROSS_FIELDS && query_info.opr_ != T_OP_AND)) { common::ObSEArray current_phrase_keywords; for (int64_t i = 0; OB_SUCC(ret) && i < raw_keywords.count(); i++) { ObReqConstExpr *current_keyword = raw_keywords.at(i); @@ -3882,6 +3937,38 @@ int ObESQueryParser::concat_const_exprs(const common::ObIArray return ret; } +int ObESQueryParser::construct_partition_cols(const ObIArray &column_names) +{ + int ret = OB_SUCCESS; + for (int64_t i = 0; OB_SUCC(ret) && i < column_names.count(); i++) { + const ObString &col_name = column_names.at(i); + ObReqColumnExpr *part_col = nullptr; + ObReqColumnExpr *part_col_alias = nullptr; + char *alias_buf = nullptr; + const int64_t alias_buf_len = OB_MAX_COLUMN_NAME_LENGTH; + int64_t alias_pos = 0; + ObString alias_str; + if (OB_ISNULL(alias_buf = static_cast(alloc_.alloc(alias_buf_len)))) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("failed to allocate memory for partition expr alias", K(ret)); + } else if (OB_FAIL(databuff_printf(alias_buf, alias_buf_len, alias_pos, "%.*s%ld", + PART_COL_ALIAS_PREFIX.length(), PART_COL_ALIAS_PREFIX.ptr(), i))) { + LOG_WARN("failed to generate partition expr alias", K(ret)); + } else if (OB_FALSE_IT(alias_str.assign_ptr(alias_buf, alias_pos))) { + } else if (OB_FAIL(ObReqColumnExpr::construct_column_expr(part_col_alias, alloc_, alias_str))) { + LOG_WARN("failed to construct partition column alias expr", K(ret), K(alias_str)); + } else if (OB_FAIL(part_aliases_.push_back(part_col_alias))) { + LOG_WARN("failed to push back partition key column alias expr", K(ret), K(part_col_alias)); + } else if (OB_FAIL(ObReqColumnExpr::construct_column_expr(part_col, alloc_, col_name))) { + LOG_WARN("failed to construct partition column expr", K(ret), K(col_name)); + } else if (OB_FALSE_IT(part_col->set_alias(alias_str))) { + } else if (OB_FAIL(part_cols_.push_back(part_col))) { + LOG_WARN("failed to push back partition column expr", K(ret)); + } + } + return ret; +} + void ObEsQueryInfo::set_msm_apply_type() { uint64_t msm_val = msm_info_.get_msm_val(); diff --git a/src/share/hybrid_search/ob_query_parse.h b/src/share/hybrid_search/ob_query_parse.h index 48b3ffd28..6e3baa694 100644 --- a/src/share/hybrid_search/ob_query_parse.h +++ b/src/share/hybrid_search/ob_query_parse.h @@ -56,6 +56,9 @@ enum ObFusionMethod { WEIGHT_SUM = 0, RRF + // NOTE: Fusion strategies are now centralized in ObHybridSearchFusionEngine + // This enum is retained for backward compatibility during migration + // Future versions will deprecate SQL-layer fusion calculations }; enum ObMsmApplyType : int8_t { @@ -251,17 +254,23 @@ class ObESQueryParser public : ObESQueryParser(ObIAllocator &alloc, common::ObString *table_name) : alloc_(alloc), source_cols_(), need_json_wrap_(false), table_name_(*table_name), database_name_(), index_name_map_(), - user_cols_(), out_cols_(nullptr), enable_es_mode_(false), fusion_config_(), default_size_(nullptr) {} + user_cols_(), part_cols_(), part_aliases_(), out_cols_(nullptr), enable_es_mode_(false), fusion_config_(), default_size_(nullptr) {} ObESQueryParser(ObIAllocator &alloc, bool need_json_wrap, const common::ObString *table_name, const common::ObString *database_name = nullptr, bool enable_es_mode = false) : alloc_(alloc), source_cols_(), need_json_wrap_(need_json_wrap), table_name_(*table_name), database_name_(*database_name), index_name_map_(), - user_cols_(), out_cols_(nullptr), enable_es_mode_(enable_es_mode), fusion_config_(), default_size_(nullptr) {} + user_cols_(), part_cols_(), part_aliases_(), out_cols_(nullptr), enable_es_mode_(enable_es_mode), fusion_config_(), default_size_(nullptr) {} virtual ~ObESQueryParser() {} int parse(const common::ObString &req_str, ObQueryReqFromJson *&query_req); inline ColumnIndexNameMap &get_index_name_map() { return index_name_map_; } inline ObIArray &get_user_column_names() { return user_cols_; } + inline ObIArray &get_partition_column_exprs() { return part_cols_; } + inline ObIArray &get_partition_aliases() { return part_aliases_; } + inline ObIAllocator &get_allocator() { return alloc_; } + inline const ObString &get_table_name() const { return table_name_; } + inline const ObString &get_database_name() const { return database_name_; } + int construct_partition_cols(const ObIArray &column_names); private : int parse_query(ObIJsonBase &req_node, ObQueryReqFromJson *&query_req); int parse_knn(ObIJsonBase &req_node, ObQueryReqFromJson *&query_req); @@ -327,6 +336,8 @@ private : int construct_join_condition(const ObString &l_table, const ObString &r_table, const ObString &l_expr_name, const ObString &r_expr_name, ObItemType condition, ObReqOpExpr *&join_condition); + int construct_join_multi_condition(const ObString &l_table, const ObString &r_table, const ObString &rowkey, ObItemType condition, ObReqOpExpr *&join_condition); + int add_partition_keys_to_select(ObQueryReqFromJson *fts_base, ObQueryReqFromJson *knn_base); int construct_weighted_expr(ObReqExpr *base_expr, double weight, ObReqExpr *&weighted_expr); int construct_ip_expr(ObReqColumnExpr *vec_field, ObReqConstExpr *query_vec, ObReqCaseWhenExpr *&case_when/* score */, ObReqOpExpr *&minus_expr/* distance */, ObReqExpr *&order_by_vec); @@ -377,6 +388,8 @@ private : common::ObString database_name_; ColumnIndexNameMap index_name_map_; common::ObSEArray user_cols_; + common::ObSEArray part_cols_; + common::ObSEArray part_aliases_; ObIArray *out_cols_; // if enable es mode bool enable_es_mode_ = false; @@ -395,6 +408,7 @@ private : static const ObString VS_ALIAS; static const ObString MSM_KEY; static const ObString FTS_SUB_SCORE_PREFIX; + static const ObString PART_COL_ALIAS_PREFIX; static const ObString HIDDEN_COLUMN_VISIBLE_HINT; }; diff --git a/src/share/hybrid_search/ob_query_translator.cpp b/src/share/hybrid_search/ob_query_translator.cpp index b9a11184c..b0a3bbe9b 100644 --- a/src/share/hybrid_search/ob_query_translator.cpp +++ b/src/share/hybrid_search/ob_query_translator.cpp @@ -132,10 +132,10 @@ int ObQueryTranslator::translate_select() query_req->score_alias_.empty()) { // do nothing } else if (query_req->score_alias_.empty()) { - DATA_PRINTF(" as _score"); + DATA_PRINTF(" as `_score`"); } else { DATA_PRINTF(" as "); - PRINT_IDENT(query_req->score_alias_); + PRINT_IDENT_WITH_QUOT(query_req->score_alias_); } } } @@ -150,7 +150,7 @@ int ObQueryTranslator::translate_order(const OrderInfo *order_info) LOG_WARN("fail to translate expr", K(ret)); } } else { - PRINT_IDENT(order_info->order_item->alias_name); + PRINT_IDENT_WITH_QUOT(order_info->order_item->alias_name); } if (OB_FAIL(ret)) { } else if (order_info->ascent == false) { @@ -199,10 +199,10 @@ int ObRequestTranslator::translate_table(const ObReqTable *table) int ret = OB_SUCCESS; if (table->table_type_ == ReqTableType::BASE_TABLE) { if (!table->database_name_.empty()) { - PRINT_IDENT(table->database_name_); + PRINT_IDENT_WITH_QUOT(table->database_name_); DATA_PRINTF("."); } - PRINT_IDENT(table->table_name_); + PRINT_IDENT_WITH_QUOT(table->table_name_); } else if (table->table_type_ == ReqTableType::SUB_QUERY) { ObQueryReqFromJson *ref_query = static_cast(table->ref_query_); int64_t res_len = 0; @@ -211,8 +211,11 @@ int ObRequestTranslator::translate_table(const ObReqTable *table) LOG_WARN("failed to translate ref query", K(ret), K(*pos_), K(buf_len_)); } else { (*pos_) += res_len; - DATA_PRINTF(") "); - PRINT_IDENT(table->alias_name_); + DATA_PRINTF(")"); + if (!table->alias_name_.empty()) { + DATA_PRINTF(" "); + PRINT_IDENT_WITH_QUOT(table->alias_name_); + } } } else if (table->table_type_ == ReqTableType::MULTI_SET) { const ObMultiSetTable *mul_tab = static_cast(table); diff --git a/src/share/hybrid_search/ob_request_base.cpp b/src/share/hybrid_search/ob_request_base.cpp index 9ee387543..eaf3635a3 100644 --- a/src/share/hybrid_search/ob_request_base.cpp +++ b/src/share/hybrid_search/ob_request_base.cpp @@ -289,7 +289,7 @@ int ObReqExpr::translate_alias(ObObjPrintParams &print_params_, char *buf_, int6 int ret = OB_SUCCESS; if (!alias_name.empty()) { DATA_PRINTF(" as "); - PRINT_IDENT(alias_name); + PRINT_IDENT_WITH_QUOT(alias_name); } return ret; } @@ -322,10 +322,14 @@ int ObReqColumnExpr::translate_expr(ObObjPrintParams &print_params_, char *buf_, { int ret = OB_SUCCESS; if (!table_name.empty()) { - PRINT_IDENT(table_name); + PRINT_IDENT_WITH_QUOT(table_name); DATA_PRINTF("."); } - PRINT_IDENT(expr_name); + PRINT_IDENT_WITH_QUOT(expr_name); + // For ES mode MATCH field, output weight in format: `column`^weight + if (OB_SUCC(ret) && print_weight_ && weight_ >= 0) { + DATA_PRINTF("^%.15g", weight_); + } if (OB_SUCC(ret) && need_alias && translate_alias(print_params_, buf_, buf_len_, pos_)) { LOG_WARN("fail to translate expr alias", K(ret)); } @@ -443,7 +447,7 @@ int OrderInfo::translate(ObObjPrintParams &print_params_, char *buf_, int64_t bu LOG_WARN("fail to translate expr", K(ret)); } } else { - PRINT_IDENT(order_item->alias_name); + PRINT_IDENT_WITH_QUOT(order_item->alias_name); } if (OB_FAIL(ret)) { } else if (ascent == false) { @@ -731,20 +735,20 @@ int ObReqExpr::construct_expr(ObReqExpr *&expr, ObIAllocator &alloc, const ObStr return ret; } -int ObReqColumnExpr::construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name, double weight) +int ObReqColumnExpr::construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name, double weight, bool print_weight) { int ret = OB_SUCCESS; - if (OB_ISNULL(expr = OB_NEWx(ObReqColumnExpr, &alloc, expr_name, ObString(), weight))) { + if (OB_ISNULL(expr = OB_NEWx(ObReqColumnExpr, &alloc, expr_name, ObString(), weight, print_weight))) { ret = OB_ALLOCATE_MEMORY_FAILED; LOG_WARN("fail to create column expr", K(ret)); } return ret; } -int ObReqColumnExpr::construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name, const ObString &table_name, double weight) +int ObReqColumnExpr::construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name, const ObString &table_name, double weight, bool print_weight) { int ret = OB_SUCCESS; - if (OB_ISNULL(expr = OB_NEWx(ObReqColumnExpr, &alloc, expr_name, table_name, weight))) { + if (OB_ISNULL(expr = OB_NEWx(ObReqColumnExpr, &alloc, expr_name, table_name, weight, print_weight))) { ret = OB_ALLOCATE_MEMORY_FAILED; LOG_WARN("fail to create column expr", K(ret)); } diff --git a/src/share/hybrid_search/ob_request_base.h b/src/share/hybrid_search/ob_request_base.h index 9fdc6417b..926dad20e 100644 --- a/src/share/hybrid_search/ob_request_base.h +++ b/src/share/hybrid_search/ob_request_base.h @@ -155,16 +155,17 @@ class ObReqColumnExpr : public ObReqExpr { public: ObReqColumnExpr() = delete; - static int construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name = ObString(), double weight = -1.0); - static int construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name, const ObString &table_name, double weight = -1.0); + static int construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name = ObString(), double weight = -1.0, bool print_weight = false); + static int construct_column_expr(ObReqColumnExpr *&expr, ObIAllocator &alloc, const ObString &expr_name, const ObString &table_name, double weight = -1.0, bool print_weight = false); virtual ~ObReqColumnExpr() {} virtual int translate_expr(ObObjPrintParams &print_params_, char *buf_, int64_t buf_len_, int64_t *pos_, ObReqScope scope = FIELD_LIST_SCOPE, bool need_alias = true); virtual int get_expr_type() { return ObReqExprType::COLUMN_EXPR; } common::ObString table_name; double weight_; + bool print_weight_; private: - ObReqColumnExpr(const ObString &expr_name, const ObString &table_name = ObString(), double weight = -1.0) - : ObReqExpr(expr_name), table_name(table_name), weight_(weight) {} + ObReqColumnExpr(const ObString &expr_name, const ObString &table_name = ObString(), double weight = -1.0, bool print_weight = false) + : ObReqExpr(expr_name), table_name(table_name), weight_(weight), print_weight_(print_weight) {} }; class ObReqConstExpr : public ObReqExpr diff --git a/src/share/hybrid_search/ob_rrf_fusion.cpp b/src/share/hybrid_search/ob_rrf_fusion.cpp new file mode 100644 index 000000000..cb30c6532 --- /dev/null +++ b/src/share/hybrid_search/ob_rrf_fusion.cpp @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define USING_LOG_PREFIX OLOG + +#include "ob_rrf_fusion.h" +#include "lib/ob_errno.h" + +namespace oceanbase +{ +namespace common +{ +using namespace oceanbase::common; + +ObRRFFusion::ObRRFFusion() + : is_initialized_(false), allocator_(nullptr) +{ +} + +ObRRFFusion::~ObRRFFusion() +{ + reset(); +} + +int ObRRFFusion::init(const ObRRFConfig &config, ObIAllocator &allocator) +{ + int ret = OB_SUCCESS; + + if (is_initialized_) { + ret = OB_INIT_TWICE; + OB_LOG(WARN, "RRF fusion is already initialized", K(ret)); + } else if (OB_FAIL(validate_config(config))) { + OB_LOG(WARN, "failed to validate rrf config", K(ret)); + } else { + config_ = config; + allocator_ = &allocator; + + // Initialize result mapping table + if (OB_FAIL(result_map_.create(10240, common::ObMemAttr(common::OB_SERVER_TENANT_ID)))) { + OB_LOG(WARN, "failed to create result map", K(ret)); + } else { + is_initialized_ = true; + OB_LOG(DEBUG, "RRF fusion initialized successfully"); + } + } + + return ret; +} + +int ObRRFFusion::add_fts_results(const common::ObIArray &fts_results) +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "RRF fusion is not initialized", K(ret)); + } else if (OB_FAIL(fts_results_.assign(fts_results))) { + OB_LOG(WARN, "failed to assign fts results", K(ret)); + } else { + OB_LOG(DEBUG, "add fts results successfully", "count", fts_results_.count()); + } + + return ret; +} + +int ObRRFFusion::add_vector_results(const common::ObIArray &vector_results) +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "RRF fusion is not initialized", K(ret)); + } else if (OB_FAIL(vector_results_.assign(vector_results))) { + OB_LOG(WARN, "failed to assign vector results", K(ret)); + } else { + OB_LOG(DEBUG, "add vector results successfully", "count", vector_results_.count()); + } + + return ret; +} + +double ObRRFFusion::calculate_rrf_score(int64_t rank) const +{ + // RRF formula: score = 1 / (rank + rank_constant) + // rank starts counting from 1 + if (rank <= 0) { + return 0.0; + } + return 1.0 / (rank + config_.rank_constant_); +} + +int ObRRFFusion::fuse() +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "RRF fusion is not initialized", K(ret)); + return ret; + } + + // Clear fused results and mapping table + fused_results_.reuse(); + result_map_.clear(); + + // Process full-text search results + for (int64_t i = 0; OB_SUCC(ret) && i < fts_results_.count(); ++i) { + const ObHybridSearchResult &result = fts_results_.at(i); + int64_t rank = i + 1; // Rank starts from 1 + + ObHybridSearchResult merged_result = result; + merged_result.fts_rank_ = rank; + merged_result.fts_score_ = calculate_rrf_score(rank); + merged_result.source_flag_ |= 1; // Mark as from full-text search + + if (OB_FAIL(result_map_.set_refactored(result.doc_id_, merged_result))) { + OB_LOG(WARN, "failed to insert fts result into map", K(ret), K(result)); + } + } + + // Process vector search results + for (int64_t i = 0; OB_SUCC(ret) && i < vector_results_.count(); ++i) { + const ObHybridSearchResult &result = vector_results_.at(i); + int64_t rank = i + 1; // Rank starts from 1 + + ObHybridSearchResult merged_result = result; + merged_result.vector_rank_ = rank; + merged_result.vector_score_ = calculate_rrf_score(rank); + merged_result.source_flag_ |= 2; // Mark as from vector search + + ObHybridSearchResult existing; + if (OB_FAIL(result_map_.get_refactored(result.doc_id_, existing))) { + if (OB_HASH_NOT_EXIST == ret) { + // New document, insert directly + ret = result_map_.set_refactored(result.doc_id_, merged_result); + if (OB_FAIL(ret)) { + OB_LOG(WARN, "failed to insert vector result into map", K(ret), K(result)); + } + } else { + OB_LOG(WARN, "failed to get result from map", K(ret)); + } + } else { + // Document already exists, update score and rank + existing.vector_rank_ = rank; + existing.vector_score_ = calculate_rrf_score(rank); + existing.source_flag_ |= 2; // Add vector search flag + + if (OB_FAIL(result_map_.set_refactored(result.doc_id_, existing))) { + OB_LOG(WARN, "failed to update result in map", K(ret)); + } + } + } + + // Extract results from mapping table and calculate final score + for (ResultMap::iterator iter = result_map_.begin(); OB_SUCC(ret) && iter != result_map_.end(); ++iter) { + ObHybridSearchResult result = iter->second; + + // Calculate final score (sum of two scores) + result.final_score_ = result.fts_score_ + result.vector_score_; + + if (OB_FAIL(fused_results_.push_back(result))) { + OB_LOG(WARN, "failed to push back fused result", K(ret)); + } + } + + // Sort by final score in descending order + if (OB_SUCC(ret)) { + std::sort(fused_results_.begin(), fused_results_.end(), + [](const ObHybridSearchResult &a, const ObHybridSearchResult &b) { + if (a.final_score_ != b.final_score_) { + return a.final_score_ > b.final_score_; + } + return a.doc_id_ < b.doc_id_; + }); + + OB_LOG(DEBUG, "RRF fusion completed successfully", + "fused_count", fused_results_.count(), + "fts_count", fts_results_.count(), + "vector_count", vector_results_.count()); + } + + return ret; +} + +int ObRRFFusion::get_results(common::ObIArray &results, int64_t limit) const +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "RRF fusion is not initialized", K(ret)); + return ret; + } + + int64_t count = fused_results_.count(); + if (limit > 0 && limit < count) { + count = limit; + } + + for (int64_t i = 0; OB_SUCC(ret) && i < count; ++i) { + if (OB_FAIL(results.push_back(fused_results_.at(i)))) { + OB_LOG(WARN, "failed to push back result", K(ret)); + } + } + + return ret; +} + +const ObHybridSearchResult *ObRRFFusion::get_result_at(int64_t index) const +{ + if (index < 0 || index >= fused_results_.count()) { + return nullptr; + } + return &fused_results_.at(index); +} + +void ObRRFFusion::reset() +{ + fts_results_.reuse(); + vector_results_.reuse(); + fused_results_.reuse(); + if (nullptr != allocator_) { + result_map_.clear(); + result_map_.destroy(); + } + is_initialized_ = false; + allocator_ = nullptr; +} + +int ObRRFFusion::validate_config(const ObRRFConfig &config) const +{ + int ret = OB_SUCCESS; + + if (config.rank_constant_ < 0) { + ret = OB_INVALID_ARGUMENT; + OB_LOG(WARN, "rank constant must be non-negative", K(ret), K(config.rank_constant_)); + } else if (config.rank_window_size_ <= 0) { + ret = OB_INVALID_ARGUMENT; + OB_LOG(WARN, "rank window size must be positive", K(ret), K(config.rank_window_size_)); + } + + return ret; +} + +} // namespace common +} // namespace oceanbase diff --git a/src/share/hybrid_search/ob_rrf_fusion.h b/src/share/hybrid_search/ob_rrf_fusion.h new file mode 100644 index 000000000..12c93a5a7 --- /dev/null +++ b/src/share/hybrid_search/ob_rrf_fusion.h @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OB_RRF_FUSION_H +#define OB_RRF_FUSION_H + +#include "ob_hybrid_search_common.h" +#include "lib/container/ob_se_array.h" +#include "lib/hash/ob_hashmap.h" + +namespace oceanbase +{ +namespace common +{ + +/* + * RRF (Reciprocal Rank Fusion) Fusion Implementation + * + * Basic Principle: + * RRF is a parameter-free fusion algorithm that converts multiple ranked lists into scores, + * and combines these scores to generate a hybrid ranking. + * + * Formula: + * score = 1/(rank + rank_constant) for each search engine + * final_score = score_from_fts + score_from_vector + * + * Advantages: + * 1. Automatic normalization: naturally solves normalization problems between different scoring systems + * 2. Strong robustness: insensitive to outliers + * 3. Simple parameters: only requires configuring rank_constant + * 4. Excellent performance: no extra normalization computation needed + * + * Application Scenarios: + * - Search applications that need to balance keyword matching and semantic similarity + * - Applications robust to anomalous score values + * - Medium-scale datasets (typically rank_window_size = 100-1000) + */ +class ObRRFFusion +{ +public: + typedef common::hash::ObHashMap ResultMap; + + ObRRFFusion(); + virtual ~ObRRFFusion(); + + /* + * Initialize RRF fusion engine + * + * @param config RRF configuration parameters + * @param allocator Memory allocator + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int init(const ObRRFConfig &config, ObIAllocator &allocator); + + /* + * Add full-text search results + * + * @param fts_results Full-text search result list, sorted by relevance in descending order + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int add_fts_results(const common::ObIArray &fts_results); + + /* + * Add vector search results + * + * @param vector_results Vector search result list, sorted by similarity in descending order + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int add_vector_results(const common::ObIArray &vector_results); + + /* + * Execute RRF fusion calculation + * + * This method will: + * 1. Assign ranks to each result in both result lists + * 2. Calculate normalized scores using RRF formula + * 3. Merge results from both lists + * 4. Sort by final score + * + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int fuse(); + + /* + * Get fusion results + * + * @param results Output parameter containing the fused result list + * @param limit Maximum number of results to return, 0 means return all results + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int get_results(common::ObIArray &results, int64_t limit = 0) const; + + /* + * Reset fusion engine state, prepare for next fusion + */ + void reset(); + + /* + * Get count of full-text search results + */ + int64_t get_fts_result_count() const { return fts_results_.count(); } + + /* + * Get count of vector search results + */ + int64_t get_vector_result_count() const { return vector_results_.count(); } + + /* + * Get count of fused results + */ + int64_t get_fused_result_count() const { return fused_results_.count(); } + + /* + * Get single fused result + * + * @param index Result index + * @return Fused result, returns empty result if index is out of bounds + */ + const ObHybridSearchResult *get_result_at(int64_t index) const; + +private: + // Calculate RRF score + double calculate_rrf_score(int64_t rank) const; + + // Validate configuration parameters + int validate_config(const ObRRFConfig &config) const; + +private: + // RRF configuration parameters + ObRRFConfig config_; + + // Full-text search results + common::ObSEArray fts_results_; + + // Vector search results + common::ObSEArray vector_results_; + + // Fused results + common::ObSEArray fused_results_; + + // Record initialization state + bool is_initialized_; + + // Memory allocator (non-owner) + ObIAllocator *allocator_; + + // Result mapping table for deduplication + ResultMap result_map_; +}; + +} // namespace common +} // namespace oceanbase + +#endif // OB_RRF_FUSION_H diff --git a/src/share/hybrid_search/ob_weighted_fusion.cpp b/src/share/hybrid_search/ob_weighted_fusion.cpp new file mode 100644 index 000000000..e68854b44 --- /dev/null +++ b/src/share/hybrid_search/ob_weighted_fusion.cpp @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define USING_LOG_PREFIX OLOG + +#include "ob_weighted_fusion.h" +#include "lib/ob_errno.h" + +namespace oceanbase +{ +namespace common +{ +using namespace oceanbase::common; + +ObWeightedFusion::ObWeightedFusion() + : is_initialized_(false), allocator_(nullptr), stats_calculated_(false) +{ +} + +ObWeightedFusion::~ObWeightedFusion() +{ + reset(); +} + +int ObWeightedFusion::init(const ObWeightedFusionConfig &config, + const ObNormalizationConfig &norm_config, + ObIAllocator &allocator) +{ + int ret = OB_SUCCESS; + + if (is_initialized_) { + ret = OB_INIT_TWICE; + OB_LOG(WARN, "weighted fusion is already initialized", K(ret)); + } else if (OB_FAIL(validate_config(config))) { + OB_LOG(WARN, "failed to validate config", K(ret)); + } else { + fusion_config_ = config; + norm_config_ = norm_config; + allocator_ = &allocator; + + // Initialize result mapping table + if (OB_FAIL(result_map_.create(10240, common::ObMemAttr(common::OB_SERVER_TENANT_ID)))) { + OB_LOG(WARN, "failed to create result map", K(ret)); + } else { + is_initialized_ = true; + OB_LOG(DEBUG, "weighted fusion initialized successfully", + K(config.fts_weight_), K(config.vector_weight_)); + } + } + + return ret; +} + +int ObWeightedFusion::add_fts_results(const common::ObIArray &fts_results) +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "weighted fusion is not initialized", K(ret)); + } else if (OB_FAIL(fts_results_.assign(fts_results))) { + OB_LOG(WARN, "failed to assign fts results", K(ret)); + } else { + OB_LOG(DEBUG, "add fts results successfully", "count", fts_results_.count()); + } + + return ret; +} + +int ObWeightedFusion::add_vector_results(const common::ObIArray &vector_results) +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "weighted fusion is not initialized", K(ret)); + } else if (OB_FAIL(vector_results_.assign(vector_results))) { + OB_LOG(WARN, "failed to assign vector results", K(ret)); + } else { + OB_LOG(DEBUG, "add vector results successfully", "count", vector_results_.count()); + } + + return ret; +} + +int ObWeightedFusion::calculate_statistics() +{ + int ret = OB_SUCCESS; + + // Calculate statistics for full-text search scores + if (fts_results_.count() > 0) { + double sum = 0.0; + fts_stats_.min_score_ = fts_results_.at(0).fts_score_; + fts_stats_.max_score_ = fts_results_.at(0).fts_score_; + fts_stats_.count_ = fts_results_.count(); + + for (int64_t i = 0; i < fts_results_.count(); ++i) { + double score = fts_results_.at(i).fts_score_; + sum += score; + if (score < fts_stats_.min_score_) { + fts_stats_.min_score_ = score; + } + if (score > fts_stats_.max_score_) { + fts_stats_.max_score_ = score; + } + } + + fts_stats_.mean_score_ = sum / fts_stats_.count_; + + // Calculate standard deviation + double variance = 0.0; + for (int64_t i = 0; i < fts_results_.count(); ++i) { + double diff = fts_results_.at(i).fts_score_ - fts_stats_.mean_score_; + variance += diff * diff; + } + fts_stats_.stddev_ = std::sqrt(variance / fts_stats_.count_); + } + + // Calculate statistics for vector search scores + if (vector_results_.count() > 0) { + double sum = 0.0; + vector_stats_.min_score_ = vector_results_.at(0).vector_score_; + vector_stats_.max_score_ = vector_results_.at(0).vector_score_; + vector_stats_.count_ = vector_results_.count(); + + for (int64_t i = 0; i < vector_results_.count(); ++i) { + double score = vector_results_.at(i).vector_score_; + sum += score; + if (score < vector_stats_.min_score_) { + vector_stats_.min_score_ = score; + } + if (score > vector_stats_.max_score_) { + vector_stats_.max_score_ = score; + } + } + + vector_stats_.mean_score_ = sum / vector_stats_.count_; + + // Calculate standard deviation + double variance = 0.0; + for (int64_t i = 0; i < vector_results_.count(); ++i) { + double diff = vector_results_.at(i).vector_score_ - vector_stats_.mean_score_; + variance += diff * diff; + } + vector_stats_.stddev_ = std::sqrt(variance / vector_stats_.count_); + } + + stats_calculated_ = true; + return ret; +} + +double ObWeightedFusion::min_max_normalize(double score, double min_val, double max_val) +{ + if (max_val - min_val < 1e-10) { + return 0.5; // Avoid division by zero + } + return (score - min_val) / (max_val - min_val); +} + +double ObWeightedFusion::z_score_normalize(double score, double mean, double stddev) +{ + if (stddev < 1e-10) { + return 0.0; // Avoid division by zero + } + // Use Sigmoid function to map Z-Score to [0, 1] + double z = (score - mean) / stddev; + return 1.0 / (1.0 + std::exp(-z)); +} + +double ObWeightedFusion::sigmoid_normalize(double score) +{ + return 1.0 / (1.0 + std::exp(-score)); +} + +double ObWeightedFusion::apply_normalization(double score, bool is_fts) +{ + if (!fusion_config_.enable_normalization_) { + return score; + } + + switch (norm_config_.norm_type_) { + case ObNormalizationConfig::NormalizationType::NONE: + return score; + + case ObNormalizationConfig::NormalizationType::MIN_MAX: { + if (is_fts) { + return min_max_normalize(score, fts_stats_.min_score_, fts_stats_.max_score_); + } else { + return min_max_normalize(score, vector_stats_.min_score_, vector_stats_.max_score_); + } + } + + case ObNormalizationConfig::NormalizationType::Z_SCORE: { + if (is_fts) { + return z_score_normalize(score, fts_stats_.mean_score_, fts_stats_.stddev_); + } else { + return z_score_normalize(score, vector_stats_.mean_score_, vector_stats_.stddev_); + } + } + + case ObNormalizationConfig::NormalizationType::SIGMOID: + return sigmoid_normalize(score); + + default: + return score; + } +} + +int ObWeightedFusion::fuse() +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "weighted fusion is not initialized", K(ret)); + return ret; + } + + // 清空融合结果和映射表 + fused_results_.reuse(); + result_map_.clear(); + + // 计算统计信息 + if (OB_FAIL(calculate_statistics())) { + OB_LOG(WARN, "failed to calculate statistics", K(ret)); + return ret; + } + + // 处理全文搜索结果 + for (int64_t i = 0; OB_SUCC(ret) && i < fts_results_.count(); ++i) { + const ObHybridSearchResult &result = fts_results_.at(i); + + ObHybridSearchResult merged_result = result; + merged_result.source_flag_ |= 1; // 标记来自全文搜索 + + if (OB_FAIL(result_map_.set_refactored(result.doc_id_, merged_result))) { + OB_LOG(WARN, "failed to insert fts result into map", K(ret), K(result)); + } + } + + // 处理向量搜索结果 + for (int64_t i = 0; OB_SUCC(ret) && i < vector_results_.count(); ++i) { + const ObHybridSearchResult &result = vector_results_.at(i); + + ObHybridSearchResult merged_result = result; + merged_result.source_flag_ |= 2; // 标记来自向量搜索 + + ObHybridSearchResult existing; + if (OB_FAIL(result_map_.get_refactored(result.doc_id_, existing))) { + if (OB_HASH_NOT_EXIST == ret) { + // 新文档,直接插入 + ret = result_map_.set_refactored(result.doc_id_, merged_result); + if (OB_FAIL(ret)) { + OB_LOG(WARN, "failed to insert vector result into map", K(ret), K(result)); + } + } else { + OB_LOG(WARN, "failed to get result from map", K(ret)); + } + } else { + // 文档已存在,更新向量分数 + existing.vector_score_ = result.vector_score_; + existing.source_flag_ |= 2; // 添加向量搜索标记 + + if (OB_FAIL(result_map_.set_refactored(result.doc_id_, existing))) { + OB_LOG(WARN, "failed to update result in map", K(ret)); + } + } + } + + // 从映射表中提取结果并计算最终得分 + for (ResultMap::iterator iter = result_map_.begin(); OB_SUCC(ret) && iter != result_map_.end(); ++iter) { + ObHybridSearchResult result = iter->second; + + // 规范化分数 + double normalized_fts = apply_normalization(result.fts_score_, true); + double normalized_vector = apply_normalization(result.vector_score_, false); + + // 计算加权和 + result.final_score_ = fusion_config_.fts_weight_ * normalized_fts + + fusion_config_.vector_weight_ * normalized_vector; + + if (OB_FAIL(fused_results_.push_back(result))) { + OB_LOG(WARN, "failed to push back fused result", K(ret)); + } + } + + // 按最终得分降序排序 + if (OB_SUCC(ret)) { + std::sort(fused_results_.begin(), fused_results_.end(), + [](const ObHybridSearchResult &a, const ObHybridSearchResult &b) { + if (a.final_score_ != b.final_score_) { + return a.final_score_ > b.final_score_; + } + return a.doc_id_ < b.doc_id_; + }); + + OB_LOG(DEBUG, "weighted fusion completed successfully", + "fused_count", fused_results_.count(), + "fts_count", fts_results_.count(), + "vector_count", vector_results_.count()); + } + + return ret; +} + +int ObWeightedFusion::get_results(common::ObIArray &results, int64_t limit) const +{ + int ret = OB_SUCCESS; + + if (!is_initialized_) { + ret = OB_NOT_INIT; + OB_LOG(WARN, "weighted fusion is not initialized", K(ret)); + return ret; + } + + int64_t count = fused_results_.count(); + if (limit > 0 && limit < count) { + count = limit; + } + + for (int64_t i = 0; OB_SUCC(ret) && i < count; ++i) { + if (OB_FAIL(results.push_back(fused_results_.at(i)))) { + OB_LOG(WARN, "failed to push back result", K(ret)); + } + } + + return ret; +} + +const ObHybridSearchResult *ObWeightedFusion::get_result_at(int64_t index) const +{ + if (index < 0 || index >= fused_results_.count()) { + return nullptr; + } + return &fused_results_.at(index); +} + +void ObWeightedFusion::reset() +{ + fts_results_.reuse(); + vector_results_.reuse(); + fused_results_.reuse(); + if (nullptr != allocator_) { + result_map_.clear(); + result_map_.destroy(); + } + is_initialized_ = false; + allocator_ = nullptr; + stats_calculated_ = false; +} + +int ObWeightedFusion::validate_config(const ObWeightedFusionConfig &config) const +{ + int ret = OB_SUCCESS; + + // Check weights + if (config.fts_weight_ < 0.0 || config.fts_weight_ > 1.0) { + ret = OB_INVALID_ARGUMENT; + OB_LOG(WARN, "fts weight must be in [0, 1]", K(ret), K(config.fts_weight_)); + } else if (config.vector_weight_ < 0.0 || config.vector_weight_ > 1.0) { + ret = OB_INVALID_ARGUMENT; + OB_LOG(WARN, "vector weight must be in [0, 1]", K(ret), K(config.vector_weight_)); + } else if (config.fts_weight_ + config.vector_weight_ < 1e-10) { + ret = OB_INVALID_ARGUMENT; + OB_LOG(WARN, "sum of weights should be positive", K(ret)); + } + + return ret; +} + +} // namespace common +} // namespace oceanbase diff --git a/src/share/hybrid_search/ob_weighted_fusion.h b/src/share/hybrid_search/ob_weighted_fusion.h new file mode 100644 index 000000000..8b04811f9 --- /dev/null +++ b/src/share/hybrid_search/ob_weighted_fusion.h @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OB_WEIGHTED_FUSION_H +#define OB_WEIGHTED_FUSION_H + +#include "ob_hybrid_search_common.h" +#include "lib/container/ob_se_array.h" +#include "lib/hash/ob_hashmap.h" +#include + +namespace oceanbase +{ +namespace common +{ + +/* + * Weighted Fusion Method Implementation + * + * Basic Principle: + * By assigning weights to full-text and vector search, + * then normalizing and computing weighted sum of scores for each document. + * + * Formula: + * final_score = fts_weight * normalized_fts_score + vector_weight * normalized_vector_score + * + * Advantages: + * 1. Fine-grained control: can precisely control the impact ratio of FTS and vector search + * 2. Flexible adaptation: supports multiple normalization strategies + * 3. Business-oriented: weights can be adjusted dynamically based on business scenarios + * + * Application Scenarios: + * - Applications requiring fine-grained control over FTS and vector search ratio + * - Scenarios with clear business preferences (e.g., prioritizing keyword matching or semantic similarity) + * - Applications that can dynamically adjust weights based on query types + */ +class ObWeightedFusion +{ +public: + typedef common::hash::ObHashMap ResultMap; + + ObWeightedFusion(); + virtual ~ObWeightedFusion(); + + /* + * Initialize weighted fusion engine + * + * @param config Weighted fusion configuration parameters + * @param norm_config Normalization configuration parameters + * @param allocator Memory allocator + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int init(const ObWeightedFusionConfig &config, + const ObNormalizationConfig &norm_config, + ObIAllocator &allocator); + + /* + * Add full-text search results + * + * @param fts_results Full-text search result list + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int add_fts_results(const common::ObIArray &fts_results); + + /* + * Add vector search results + * + * @param vector_results Vector search result list + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int add_vector_results(const common::ObIArray &vector_results); + + /* + * Execute weighted fusion calculation + * + * This method will: + * 1. Collect score statistics from both result lists + * 2. Normalize scores according to normalization strategy + * 3. Apply weights to compute weighted sum + * 4. Sort by final score + * + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int fuse(); + + /* + * Get fusion results + * + * @param results Output parameter containing the fused result list + * @param limit Maximum number of results to return, 0 means return all results + * @return Returns OB_SUCCESS on success, corresponding error code on failure + */ + int get_results(common::ObIArray &results, int64_t limit = 0) const; + + /* + * Reset fusion engine state, prepare for next fusion + */ + void reset(); + + /* + * Get count of fused results + */ + int64_t get_fused_result_count() const { return fused_results_.count(); } + + /* + * Get single fused result + */ + const ObHybridSearchResult *get_result_at(int64_t index) const; + +private: + // Calculate statistics + int calculate_statistics(); + + // Normalize single score + double normalize_score(double score, bool is_fts); + + // Apply normalization strategy + double apply_normalization(double score, bool is_fts); + + // Min-Max normalization + double min_max_normalize(double score, double min_val, double max_val); + + // Z-Score normalization + double z_score_normalize(double score, double mean, double stddev); + + // Sigmoid normalization + double sigmoid_normalize(double score); + + // Validate configuration parameters + int validate_config(const ObWeightedFusionConfig &config) const; + +private: + // Weighted fusion configuration + ObWeightedFusionConfig fusion_config_; + + // Normalization configuration + ObNormalizationConfig norm_config_; + + // Full-text search results + common::ObSEArray fts_results_; + + // Vector search results + common::ObSEArray vector_results_; + + // Fused results + common::ObSEArray fused_results_; + + // Statistics for full-text search scores + struct FTSStats + { + double min_score_ = 0.0; + double max_score_ = 0.0; + double mean_score_ = 0.0; + double stddev_ = 0.0; + int64_t count_ = 0; + } fts_stats_; + + // Statistics for vector search scores + struct VectorStats + { + double min_score_ = 0.0; + double max_score_ = 0.0; + double mean_score_ = 0.0; + double stddev_ = 0.0; + int64_t count_ = 0; + } vector_stats_; + + // Whether initialization is complete + bool is_initialized_; + + // Memory allocator (non-owner) + ObIAllocator *allocator_; + + // Result mapping table for deduplication + ResultMap result_map_; + + // Whether statistics have been calculated + bool stats_calculated_; +}; + +} // namespace common +} // namespace oceanbase + +#endif // OB_WEIGHTED_FUSION_H diff --git a/src/share/hybrid_search/test_hybrid_search.cpp b/src/share/hybrid_search/test_hybrid_search.cpp new file mode 100644 index 000000000..b127cee0b --- /dev/null +++ b/src/share/hybrid_search/test_hybrid_search.cpp @@ -0,0 +1,422 @@ +/* + * Copyright (c) 2025 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Hybrid Search Unit Tests + * Includes comprehensive test cases for RRF fusion and weighted fusion + */ + +#include +#include "ob_rrf_fusion.h" +#include "ob_weighted_fusion.h" +#include "lib/allocator/ob_malloc.h" + +namespace oceanbase +{ +namespace common +{ + +/* + * ============================================= + * RRF 融合单元测试 + * ============================================= + */ + +class RRFFusionTest : public ::testing::Test +{ +protected: + void SetUp() override + { + allocator_ = new ObMallocAllocator(); + } + + void TearDown() override + { + delete allocator_; + } + + ObIAllocator *allocator_; +}; + +// 测试 1: 基础初始化 +TEST_F(RRFFusionTest, BasicInitialization) +{ + ObRRFFusion rrf; + ObRRFConfig config(60, 100); + + int ret = rrf.init(config, *allocator_); + ASSERT_EQ(OB_SUCCESS, ret); +} + +// 测试 2: 重复初始化应该失败 +TEST_F(RRFFusionTest, DuplicateInitialization) +{ + ObRRFFusion rrf; + ObRRFConfig config(60, 100); + + int ret1 = rrf.init(config, *allocator_); + ASSERT_EQ(OB_SUCCESS, ret1); + + int ret2 = rrf.init(config, *allocator_); + ASSERT_EQ(OB_INIT_TWICE, ret2); +} + +// 测试 3: 无效配置应该失败 +TEST_F(RRFFusionTest, InvalidConfig) +{ + ObRRFFusion rrf; + ObRRFConfig config(-1, 100); // 无效的 rank_constant + + int ret = rrf.init(config, *allocator_); + ASSERT_NE(OB_SUCCESS, ret); +} + +// 测试 4: 基本融合功能 +TEST_F(RRFFusionTest, BasicFusion) +{ + ObRRFFusion rrf; + ObRRFConfig config(60, 100); + + ASSERT_EQ(OB_SUCCESS, rrf.init(config, *allocator_)); + + // 准备全文搜索结果 + common::ObSEArray fts_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result; + result.doc_id_ = i + 1; + result.fts_score_ = 10.0 - i * 2.0; + fts_results.push_back(result); + } + + // 准备向量搜索结果 + common::ObSEArray vector_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result; + result.doc_id_ = (i % 3) + 1; + result.vector_score_ = 0.9 - i * 0.1; + vector_results.push_back(result); + } + + ASSERT_EQ(OB_SUCCESS, rrf.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, rrf.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, rrf.fuse()); + + // 验证融合结果 + ASSERT_GT(rrf.get_fused_result_count(), 0); + + // 获取前 2 个结果 + common::ObSEArray results; + ASSERT_EQ(OB_SUCCESS, rrf.get_results(results, 2)); + ASSERT_EQ(2, results.count()); + + // 验证结果按得分排序 + if (results.count() > 1) { + EXPECT_GE(results.at(0).final_score_, results.at(1).final_score_); + } +} + +// 测试 5: 全文搜索结果为空 +TEST_F(RRFFusionTest, EmptyFTSResults) +{ + ObRRFFusion rrf; + ObRRFConfig config(60, 100); + + ASSERT_EQ(OB_SUCCESS, rrf.init(config, *allocator_)); + + common::ObSEArray fts_results; + common::ObSEArray vector_results; + + ObHybridSearchResult result; + result.doc_id_ = 1; + result.vector_score_ = 0.9; + vector_results.push_back(result); + + ASSERT_EQ(OB_SUCCESS, rrf.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, rrf.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, rrf.fuse()); + + ASSERT_GT(rrf.get_fused_result_count(), 0); +} + +// 测试 6: 向量搜索结果为空 +TEST_F(RRFFusionTest, EmptyVectorResults) +{ + ObRRFFusion rrf; + ObRRFConfig config(60, 100); + + ASSERT_EQ(OB_SUCCESS, rrf.init(config, *allocator_)); + + common::ObSEArray fts_results; + common::ObSEArray vector_results; + + ObHybridSearchResult result; + result.doc_id_ = 1; + result.fts_score_ = 10.0; + fts_results.push_back(result); + + ASSERT_EQ(OB_SUCCESS, rrf.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, rrf.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, rrf.fuse()); + + ASSERT_GT(rrf.get_fused_result_count(), 0); +} + +// 测试 7: 重置功能 +TEST_F(RRFFusionTest, Reset) +{ + ObRRFFusion rrf; + ObRRFConfig config(60, 100); + + ASSERT_EQ(OB_SUCCESS, rrf.init(config, *allocator_)); + + common::ObSEArray results; + ObHybridSearchResult result; + result.doc_id_ = 1; + result.fts_score_ = 10.0; + results.push_back(result); + + ASSERT_EQ(OB_SUCCESS, rrf.add_fts_results(results)); + + rrf.reset(); + ASSERT_EQ(0, rrf.get_fts_result_count()); +} + +/* + * ============================================= + * 加权融合单元测试 + * ============================================= + */ + +class WeightedFusionTest : public ::testing::Test +{ +protected: + void SetUp() override + { + allocator_ = new ObMallocAllocator(); + } + + void TearDown() override + { + delete allocator_; + } + + ObIAllocator *allocator_; +}; + +// 测试 1: 基础初始化 +TEST_F(WeightedFusionTest, BasicInitialization) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(0.5, 0.5, true); + ObNormalizationConfig norm_config; + + int ret = fusion.init(fusion_config, norm_config, *allocator_); + ASSERT_EQ(OB_SUCCESS, ret); +} + +// 测试 2: 无效权重配置 +TEST_F(WeightedFusionTest, InvalidWeightConfig) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(-0.5, 1.5, true); // 无效权重 + ObNormalizationConfig norm_config; + + int ret = fusion.init(fusion_config, norm_config, *allocator_); + ASSERT_NE(OB_SUCCESS, ret); +} + +// 测试 3: Min-Max 规范化融合 +TEST_F(WeightedFusionTest, MinMaxNormalization) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(0.5, 0.5, true); + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::MIN_MAX; + + ASSERT_EQ(OB_SUCCESS, fusion.init(fusion_config, norm_config, *allocator_)); + + // 准备数据 + common::ObSEArray fts_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result; + result.doc_id_ = i + 1; + result.fts_score_ = 10.0 - i * 3.0; + fts_results.push_back(result); + } + + common::ObSEArray vector_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result; + result.doc_id_ = (i % 3) + 1; + result.vector_score_ = 0.9 - i * 0.2; + vector_results.push_back(result); + } + + ASSERT_EQ(OB_SUCCESS, fusion.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, fusion.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, fusion.fuse()); + + ASSERT_GT(fusion.get_fused_result_count(), 0); +} + +// 测试 4: Z-Score 规范化融合 +TEST_F(WeightedFusionTest, ZScoreNormalization) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(0.7, 0.3, true); + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::Z_SCORE; + + ASSERT_EQ(OB_SUCCESS, fusion.init(fusion_config, norm_config, *allocator_)); + + // 准备数据 + common::ObSEArray fts_results; + for (int i = 0; i < 5; ++i) { + ObHybridSearchResult result; + result.doc_id_ = i + 1; + result.fts_score_ = 10.0 - i; + fts_results.push_back(result); + } + + common::ObSEArray vector_results; + for (int i = 0; i < 5; ++i) { + ObHybridSearchResult result; + result.doc_id_ = i + 1; + result.vector_score_ = 0.8 - i * 0.1; + vector_results.push_back(result); + } + + ASSERT_EQ(OB_SUCCESS, fusion.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, fusion.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, fusion.fuse()); + + // 验证融合结果 + common::ObSEArray results; + ASSERT_EQ(OB_SUCCESS, fusion.get_results(results)); + + for (int i = 0; i < results.count(); ++i) { + const auto &result = results.at(i); + // 最终得分应该在合理范围内 + EXPECT_GE(result.final_score_, 0.0); + EXPECT_LE(result.final_score_, 2.0); + } +} + +// 测试 5: 关键词优先权重配置 +TEST_F(WeightedFusionTest, KeywordPriorityWeights) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(0.7, 0.3, true); + ObNormalizationConfig norm_config; + + ASSERT_EQ(OB_SUCCESS, fusion.init(fusion_config, norm_config, *allocator_)); + + // 准备数据 + common::ObSEArray fts_results; + ObHybridSearchResult fts_r1; + fts_r1.doc_id_ = 1; + fts_r1.fts_score_ = 15.0; + fts_results.push_back(fts_r1); + + common::ObSEArray vector_results; + ObHybridSearchResult vec_r1; + vec_r1.doc_id_ = 2; + vec_r1.vector_score_ = 0.95; + vector_results.push_back(vec_r1); + + ASSERT_EQ(OB_SUCCESS, fusion.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, fusion.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, fusion.fuse()); + + // 文档 1 应该获得更高的分数(因为全文权重较高) + const auto *result1 = fusion.get_result_at(0); + ASSERT_NE(nullptr, result1); +} + +// 测试 6: Sigmoid 规范化 +TEST_F(WeightedFusionTest, SigmoidNormalization) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(0.5, 0.5, true); + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::SIGMOID; + + ASSERT_EQ(OB_SUCCESS, fusion.init(fusion_config, norm_config, *allocator_)); + + // 准备数据 + common::ObSEArray fts_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result; + result.doc_id_ = i + 1; + result.fts_score_ = 5.0 - i * 1.5; + fts_results.push_back(result); + } + + common::ObSEArray vector_results; + for (int i = 0; i < 3; ++i) { + ObHybridSearchResult result; + result.doc_id_ = i + 1; + result.vector_score_ = 0.7 - i * 0.15; + vector_results.push_back(result); + } + + ASSERT_EQ(OB_SUCCESS, fusion.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, fusion.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, fusion.fuse()); + + ASSERT_GT(fusion.get_fused_result_count(), 0); +} + +// 测试 7: 无规范化 +TEST_F(WeightedFusionTest, NoNormalization) +{ + ObWeightedFusion fusion; + ObWeightedFusionConfig fusion_config(0.5, 0.5, false); + ObNormalizationConfig norm_config; + norm_config.norm_type_ = ObNormalizationConfig::NormalizationType::NONE; + + ASSERT_EQ(OB_SUCCESS, fusion.init(fusion_config, norm_config, *allocator_)); + + // 准备数据 + common::ObSEArray fts_results; + ObHybridSearchResult fts_r; + fts_r.doc_id_ = 1; + fts_r.fts_score_ = 10.0; + fts_results.push_back(fts_r); + + common::ObSEArray vector_results; + ObHybridSearchResult vec_r; + vec_r.doc_id_ = 1; + vec_r.vector_score_ = 0.5; + vector_results.push_back(vec_r); + + ASSERT_EQ(OB_SUCCESS, fusion.add_fts_results(fts_results)); + ASSERT_EQ(OB_SUCCESS, fusion.add_vector_results(vector_results)); + ASSERT_EQ(OB_SUCCESS, fusion.fuse()); + + const auto *result = fusion.get_result_at(0); + ASSERT_NE(nullptr, result); + // final_score = 0.5 * 10.0 + 0.5 * 0.5 = 5.25 + EXPECT_DOUBLE_EQ(5.25, result->final_score_); +} + +} // namespace common +} // namespace oceanbase + +// 运行所有测试 +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}