diff --git a/app/services/oral_history/chunk_fetcher.rb b/app/services/oral_history/chunk_fetcher.rb new file mode 100644 index 000000000..25ace856d --- /dev/null +++ b/app/services/oral_history/chunk_fetcher.rb @@ -0,0 +1,113 @@ +module OralHistory + # Fetch OralHistoryChunks for RAG querying, from pg db, using `neighbor` gem for vector + # search cosine similarity. + # + # Can also use fancy SQL to limit to only so many per document, and add + # other constraints. + class ChunkFetcher + attr_reader :top_k, :question_embedding, :max_per_interview, :oversample_factor + attr_reader :exclude_oral_history_chunk_ids, :exclude_oral_history_content_ids + + + # @param top_k [Integer] how many chunks do you want back + # + # @param max_per_interview [Integer] if set, only include at most this many chunks + # per oral history (interview). + # + # @param oversample_factor [Integer] When doing max_per_interview, we need to kind of originally + # fetch more than that, so we can then apply the max_per_interview limit and still have enough. + # It all happens inside a SQL subquery, but we can't actually rank _everything_. + # + # @param exclude_chunks [Array] Array of OralHistoryChunk, or OralHistoryChunk#id, to exclude + # + # @param exclude_interviews [Array] Interviews to exclude. Can be given as Work, OralHistoryContent, + # or OralHistoryContent#id + def initialize(top_k:, question_embedding:, max_per_interview: nil, oversample_factor: 3, exclude_chunks: nil, exclude_interviews: nil) + @top_k = top_k + @question_embedding = question_embedding + @max_per_interview = max_per_interview + @oversample_factor = oversample_factor + + if exclude_chunks + @exclude_oral_history_chunk_ids = exclude_chunks.collect {|i| i.kind_of?(OralHistoryChunk) ? 
i.id : i } + end + + if exclude_interviews + @exclude_oral_history_content_ids = exclude_interviews.collect do |i| + if i.kind_of?(Work) + i.oral_history_content.id + elsif i.kind_of?(OralHistoryContent) + i.id + else + i + end + end + end + end + + # @return [Array] Where each one also has a `neighbor_distance` attribute + # with cosine distance, added by neighbor gem. The returned set has strict_loading set, to ensure + # pre-loading and avoid the n+1 fetch problem. + def fetch_chunks + relation = if max_per_interview + wrap_relation_for_max_per_interview(base_relation: base_relation, max_per_interview: max_per_interview, inner_limit: oversample_factor * top_k) + else + base_relation + end + + relation.limit(top_k).strict_loading + end + + # Without a limit count; we'll add that later. + def base_relation + # Preload work, so we can get title or other metadata we might want. + relation = OralHistoryChunk.neighbors_for_embedding(question_embedding).includes(oral_history_content: :work) + + # exclude specific chunks? + if exclude_oral_history_chunk_ids.present? + relation = relation.where.not(id: exclude_oral_history_chunk_ids) + end + + # exclude interviews? + if exclude_oral_history_content_ids.present? + relation = relation.where.not(oral_history_content_id: exclude_oral_history_content_ids) + end + + relation + end + + # We need to take base_scope and use it as a Postgres CTE (Common Table Expression) + # to select from, but adding on a ROW_NUMBER window function, that lets us limit + # to top max_per_interview + # + # Kinda tricky. Got from google and talking to LLMs. + # + # @return [ActiveRecord::Relation] that's been wrapped with a CTE to enforce max_per_interview limits. + def wrap_relation_for_max_per_interview(base_relation:, max_per_interview:, inner_limit:) + base_relation = base_relation.dup # because we're gonna mutate it; avoid confusion. 
+ + # add a 'select' using semi-private select_values API + # Not sure what neighbor's type serialization is doing that we couldn't get right ourselves, but it works. + base_relation.select_values += [ + ActiveRecord::Base.sanitize_sql([ + "ROW_NUMBER() OVER (PARTITION BY oral_history_content_id ORDER BY oral_history_chunks.embedding <=> ?) as doc_rank", + Neighbor::Type::Vector.new.serialize(question_embedding) + ]) + ] + + # In the inner CTE we have to fetch oversampled, so we hopefully wind up with enough in the outer query. + # Leaving the inner query unlimited would hurt performance: with a limit, the vector index doesn't need to calculate distance for every row. + # NOTE(review): ActiveRecord `limit` returns a new relation and does not mutate the receiver, so the result of the next line looks discarded; likely should be `base_relation = base_relation.limit(inner_limit)` -- verify. + base_relation.limit(inner_limit) + + # copy the order from inner scope, where neighbor gem set it to be vector distance asc + # We leave the real limit for the caller to set + OralHistoryChunk.with(ranked_chunks: base_relation). + select("*"). + from("ranked_chunks"). + where("doc_rank <= ?", max_per_interview). + order("neighbor_distance"). + includes(oral_history_content: :work) + end + end +end diff --git a/app/services/oral_history/claude_interactor.rb b/app/services/oral_history/claude_interactor.rb index c957c5b8a..c532220ff 100644 --- a/app/services/oral_history/claude_interactor.rb +++ b/app/services/oral_history/claude_interactor.rb @@ -11,8 +11,6 @@ class ClaudeInteractor # claude sonnet 4.5 MODEL_ID = "global.anthropic.claude-sonnet-4-5-20250929-v1:0" - INITIAL_CHUNK_COUNT = 8 - ANSWER_UNAVAILABLE_TEXT = "I am unable to answer this question with the methods and sources available." 
# should e threadsafe, and better to re-use for re-used connections maybe @@ -63,7 +61,7 @@ def extract_answer(response) # # can raise a Aws::Errors::ServiceError def get_response(conversation_record:nil) - chunks = get_chunks(k: INITIAL_CHUNK_COUNT) + chunks = get_chunks conversation_record&.record_chunks_used(chunks) conversation_record&.request_sent_at = Time.current @@ -106,11 +104,18 @@ def render_user_prompt(chunks) ) end - # @param k [Integer] how many chunks to get - def get_chunks(k: INITIAL_CHUNK_COUNT) - # TODO: the SQL log for the neighbor query is too huge!! - # Preload work, so we can get title or other metadata we might want. - OralHistoryChunk.neighbors_for_embedding(question_embedding).limit(k).includes(oral_history_content: :work).strict_loading + + def get_chunks + # fetch first 8 closest-vector chunks + chunks = OralHistory::ChunkFetcher.new(question_embedding: question_embedding, top_k: 8).fetch_chunks + + # now fetch another 8, but only 1-per-interview, not including any interviews from above + chunks += OralHistory::ChunkFetcher.new(question_embedding: question_embedding, + top_k: 8, + max_per_interview: 1, + exclude_interviews: chunks.collect(&:oral_history_content_id).uniq).fetch_chunks + + chunks end def format_chunks(chunks) diff --git a/spec/services/oral_history/chunk_fetcher_spec.rb b/spec/services/oral_history/chunk_fetcher_spec.rb new file mode 100644 index 000000000..2ebe36367 --- /dev/null +++ b/spec/services/oral_history/chunk_fetcher_spec.rb @@ -0,0 +1,114 @@ +require 'rails_helper' +require 'matrix' + +describe OralHistory::ChunkFetcher do + def cosine_similarity(a, b) + a = Vector[*a] + b = Vector[*b] + dot = a.inner_product(b) + dot / (a.norm * b.norm) + end + + + # These are all arbitrary, we just took some random sample vectors, having them + # not be all zero gives us more realistic data, give us a prefix, we'll pad + # with zeroes to correct length. 
+ def fake_vector(*prefix) + prefix + ([0.0] * (OralHistoryChunk::FAKE_EMBEDDING.length - prefix.length)) + end + + let(:fake_question_embedding) { fake_vector(0.03263719,-0.021255592,-0.018256947,0.012259656,0.008308401)} + + + let(:work1) { create(:oral_history_work) } + + + let!(:chunk1) { create(:oral_history_chunk, + oral_history_content: work1.oral_history_content, + embedding: fake_vector(0.01759516,-0.0453438, -0.029577527, -0.032289326, 0.012045433), + speakers: ["SMITH"])} + + let!(:chunk2) { create(:oral_history_chunk, + oral_history_content: work1.oral_history_content, + embedding: fake_vector(0.059072047,-0.021131188,-0.013840758,-0.0077753244,-0.02725617), + speakers: ["SMITH", "JONES"], text: "Chunk 2")} + + + let(:work2) { create(:oral_history_work) } + + + let!(:chunk3) { create(:oral_history_chunk, + oral_history_content: work2.oral_history_content, + embedding: fake_vector(0.015151533,-0.01646033,-0.021422518,-0.024602171,0.009659404), + speakers: ["SMITH", "JONES"], text: "Chunk 3")} + + let!(:chunk4) { create(:oral_history_chunk, + oral_history_content: work2.oral_history_content, + embedding: fake_vector(0.013049184,-0.019433592,-0.024848722,-0.010990473,0.024592385), + speakers: ["SMITH", "JONES"], text: "Chunk 3")} + + + let(:all_chunks) { [chunk1, chunk2, chunk3, chunk4] } + + it "fetches" do + results = described_class.new(question_embedding: fake_question_embedding, top_k: 2).fetch_chunks + + expect(results.length).to be 2 + expect(results).to all be_kind_of(OralHistoryChunk) + expect(results).to all satisfy { |r| r.neighbor_distance.present? 
} + + # make sure we can follow associations without triggering strict loading error + results.collect(&:oral_history_content).flatten.collect(&:work) + end + + describe "max_per_interview" do + it "fetches with limit" do + # We ask for 3, but can only get 2 because of per-doc limit + results = described_class.new(question_embedding: fake_question_embedding, top_k: 3, max_per_interview: 1).fetch_chunks + + expect(results.length).to eq 2 + + # Two oral_history_content_id's, each with only 1 chunk + groups = results.group_by {|c| c.oral_history_content_id } + expect(groups.count).to eq 2 + expect(groups.values).to all satisfy { |v| v.length== 1} + + # included chunks have closer vector distance to question than excluded + excluded = [chunk1, chunk2, chunk3, chunk4].find_all { |c| ! c.id.in?(results.collect(&:id)) } + excluded_similarity = excluded.collect {|c| cosine_similarity(c.embedding, fake_question_embedding) } + + included_similarity = results.collect {|c| cosine_similarity(c.embedding, fake_question_embedding) } + + # everything included has more similarity than anything excluded! 
+ expect(included_similarity.min).to be >= excluded_similarity.max + + # make sure we can follow associations without triggering strict loading error + results.collect(&:oral_history_content).flatten.collect(&:work) + end + end + + describe "exclude_chunks" do + it "can exclude chunks by id" do + exclude_chunk_ids = [chunk1, chunk3].collect(&:id) + results = described_class.new(question_embedding: fake_question_embedding, top_k: 100, exclude_chunks: exclude_chunk_ids).fetch_chunks + + expected_ids = all_chunks.collect(&:id) - exclude_chunk_ids + expect(results.collect(&:id)).to match_array expected_ids + end + + it "can exclude chunks by chunk" do + results = described_class.new(question_embedding: fake_question_embedding, top_k: 100, exclude_chunks: [chunk2, chunk4]).fetch_chunks + + expected_ids = all_chunks.collect(&:id) - [chunk2, chunk4].collect(&:id) + expect(results.collect(&:id)).to match_array expected_ids + end + end + + describe "exclude_interviews" do + it "can exclude interviews by OralHistoryContent model" do + results = described_class.new(question_embedding: fake_question_embedding, top_k: 100, exclude_interviews: [work1.oral_history_content]).fetch_chunks + + expect(results.collect(&:oral_history_content_id)).not_to include(work1.oral_history_content.id) + end + end +end