Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions app/services/oral_history/chunk_fetcher.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module OralHistory
  # Fetch OralHistoryChunks for RAG querying, from pg db, using `neighbor` gem for vector
  # search cosine similarity.
  #
  # Can also use fancy SQL (a ROW_NUMBER window function in a CTE) to limit to only
  # so many chunks per document, and add other constraints.
  class ChunkFetcher
    attr_reader :top_k, :question_embedding, :max_per_interview, :oversample_factor
    attr_reader :exclude_oral_history_chunk_ids, :exclude_oral_history_content_ids


    # @param top_k [Integer] how many chunks do you want back
    #
    # @param question_embedding [Array<Float>] embedding vector for the question,
    #   compared to chunk embeddings via cosine distance.
    #
    # @param max_per_interview [Integer] if set, only include up to this many chunks
    #   per oral history.
    #
    # @param oversample_factor [Integer] When doing max_per_interview, we need to originally
    #   fetch more than top_k, so we can then apply the max_per_interview limit and still have
    #   enough. It all happens inside a SQL subquery, but we can't afford to rank _everything_,
    #   so the inner query is limited to oversample_factor * top_k.
    #
    # @param exclude_chunks [Array<OralHistoryChunk,Integer>] Array of OralHistoryChunk, or OralHistoryChunk#id, exclude these
    #
    # @param exclude_interviews [Array<Work,OralHistoryContent,Integer>] Interviews to exclude. Can be as Work, OralHistoryContent,
    #   or OralHistoryContent#id
    def initialize(top_k:, question_embedding:, max_per_interview: nil, oversample_factor: 3, exclude_chunks: nil, exclude_interviews: nil)
      @top_k = top_k
      @question_embedding = question_embedding
      @max_per_interview = max_per_interview
      @oversample_factor = oversample_factor

      if exclude_chunks
        # normalize to ids
        @exclude_oral_history_chunk_ids = exclude_chunks.collect { |c| c.kind_of?(OralHistoryChunk) ? c.id : c }
      end

      if exclude_interviews
        # normalize Work / OralHistoryContent / id to OralHistoryContent#id
        @exclude_oral_history_content_ids = exclude_interviews.collect do |i|
          if i.kind_of?(Work)
            i.oral_history_content.id
          elsif i.kind_of?(OralHistoryContent)
            i.id
          else
            i
          end
        end
      end
    end

    # @return [Array<OralHistoryChunk>] Where each one also has a `neighbor_distance` attribute
    #   with cosine distance, added by neighbor gem. Set returned has strict_loading to ensure
    #   pre-loading to avoid n+1 fetch problem.
    def fetch_chunks
      relation = if max_per_interview
        wrap_relation_for_max_per_interview(base_relation: base_relation, max_per_interview: max_per_interview, inner_limit: oversample_factor * top_k)
      else
        base_relation
      end

      relation.limit(top_k).strict_loading
    end

    # Vector-ordered relation with any chunk/interview exclusions applied.
    # Without limit count, we'll add that later.
    #
    # @return [ActiveRecord::Relation]
    def base_relation
      # Preload work, so we can get title or other metadata we might want.
      relation = OralHistoryChunk.neighbors_for_embedding(question_embedding).includes(oral_history_content: :work)

      # exclude specific chunks?
      if exclude_oral_history_chunk_ids.present?
        relation = relation.where.not(id: exclude_oral_history_chunk_ids)
      end

      # exclude interviews?
      if exclude_oral_history_content_ids.present?
        relation = relation.where.not(oral_history_content_id: exclude_oral_history_content_ids)
      end

      relation
    end

    # We need to take base_relation and use it as a Postgres CTE (Common Table Expression)
    # to select from, but adding on a ROW_NUMBER window function, that lets us limit
    # to top max_per_interview rows per interview.
    #
    # Kinda tricky. Got from google and talking to LLMs.
    #
    # @return [ActiveRecord::Relation] that's been wrapped with a CTE to enforce max_per_interview limits.
    def wrap_relation_for_max_per_interview(base_relation:, max_per_interview:, inner_limit:)
      base_relation = base_relation.dup # cause we're gonna mutate select_values, avoid confusion.

      # add a 'select' using semi-private select_values API
      # Not sure what neighbor's type serialization is doing we couldn't get right ourselves, but it works.
      base_relation.select_values += [
        ActiveRecord::Base.sanitize_sql([
          "ROW_NUMBER() OVER (PARTITION BY oral_history_content_id ORDER BY oral_history_chunks.embedding <=> ?) as doc_rank",
          Neighbor::Type::Vector.new.serialize(question_embedding)
        ])
      ]

      # In the inner CTE, have to fetch oversampled, so we can wind up with
      # hopefully enough in outer. Leaving inner unlimited would hurt performance;
      # with a limit, because of how the vector index works, it doesn't need to
      # calculate distances for every row.
      #
      # NOTE: `limit` does NOT mutate, it returns a new relation -- we have to
      # re-assign, or the inner limit is silently discarded.
      base_relation = base_relation.limit(inner_limit)

      # copy the order from inner scope, where neighbor gem set it to be vector distance asc
      # We leave the real limit for the caller to set
      OralHistoryChunk.with(ranked_chunks: base_relation).
        select("*").
        from("ranked_chunks").
        where("doc_rank <= ?", max_per_interview).
        order("neighbor_distance").
        includes(oral_history_content: :work)
    end
  end
end
21 changes: 13 additions & 8 deletions app/services/oral_history/claude_interactor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ class ClaudeInteractor
# claude sonnet 4.5
MODEL_ID = "global.anthropic.claude-sonnet-4-5-20250929-v1:0"

INITIAL_CHUNK_COUNT = 8

ANSWER_UNAVAILABLE_TEXT = "I am unable to answer this question with the methods and sources available."

# should e threadsafe, and better to re-use for re-used connections maybe
Expand Down Expand Up @@ -63,7 +61,7 @@ def extract_answer(response)
#
# can raise a Aws::Errors::ServiceError
def get_response(conversation_record:nil)
chunks = get_chunks(k: INITIAL_CHUNK_COUNT)
chunks = get_chunks

conversation_record&.record_chunks_used(chunks)
conversation_record&.request_sent_at = Time.current
Expand Down Expand Up @@ -106,11 +104,18 @@ def render_user_prompt(chunks)
)
end

# @param k [Integer] how many chunks to get
def get_chunks(k: INITIAL_CHUNK_COUNT)
# TODO: the SQL log for the neighbor query is too huge!!
# Preload work, so we can get title or other metadata we might want.
OralHistoryChunk.neighbors_for_embedding(question_embedding).limit(k).includes(oral_history_content: :work).strict_loading

# Two-phase chunk retrieval: closest chunks overall, plus a breadth pass
# that pulls in one chunk each from interviews not already represented.
#
# @return [Array<OralHistoryChunk>]
def get_chunks
  per_phase = 8

  # Phase 1: the closest-vector chunks, regardless of interview.
  closest = OralHistory::ChunkFetcher.new(
    question_embedding: question_embedding,
    top_k: per_phase
  ).fetch_chunks

  # Phase 2: same count again, but limited to 1-per-interview, and skipping
  # any interviews already included in phase 1.
  broadened = OralHistory::ChunkFetcher.new(
    question_embedding: question_embedding,
    top_k: per_phase,
    max_per_interview: 1,
    exclude_interviews: closest.collect(&:oral_history_content_id).uniq
  ).fetch_chunks

  closest + broadened
end

def format_chunks(chunks)
Expand Down
114 changes: 114 additions & 0 deletions spec/services/oral_history/chunk_fetcher_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
require 'rails_helper'
require 'matrix'

describe OralHistory::ChunkFetcher do
  # Plain-ruby cosine similarity, used to independently verify the db's
  # vector-distance ordering.
  def cosine_similarity(a, b)
    a = Vector[*a]
    b = Vector[*b]
    dot = a.inner_product(b)
    dot / (a.norm * b.norm)
  end


  # These are all arbitrary, we just took some random sample vectors, having them
  # not be all zero gives us more realistic data, give us a prefix, we'll pad
  # with zeroes to correct length.
  def fake_vector(*prefix)
    prefix + ([0.0] * (OralHistoryChunk::FAKE_EMBEDDING.length - prefix.length))
  end

  let(:fake_question_embedding) { fake_vector(0.03263719,-0.021255592,-0.018256947,0.012259656,0.008308401)}


  let(:work1) { create(:oral_history_work) }


  let!(:chunk1) { create(:oral_history_chunk,
    oral_history_content: work1.oral_history_content,
    embedding: fake_vector(0.01759516,-0.0453438, -0.029577527, -0.032289326, 0.012045433),
    speakers: ["SMITH"])}

  let!(:chunk2) { create(:oral_history_chunk,
    oral_history_content: work1.oral_history_content,
    embedding: fake_vector(0.059072047,-0.021131188,-0.013840758,-0.0077753244,-0.02725617),
    speakers: ["SMITH", "JONES"], text: "Chunk 2")}


  let(:work2) { create(:oral_history_work) }


  let!(:chunk3) { create(:oral_history_chunk,
    oral_history_content: work2.oral_history_content,
    embedding: fake_vector(0.015151533,-0.01646033,-0.021422518,-0.024602171,0.009659404),
    speakers: ["SMITH", "JONES"], text: "Chunk 3")}

  # text was copy-pasted as "Chunk 3" before; corrected to match the fixture's name
  let!(:chunk4) { create(:oral_history_chunk,
    oral_history_content: work2.oral_history_content,
    embedding: fake_vector(0.013049184,-0.019433592,-0.024848722,-0.010990473,0.024592385),
    speakers: ["SMITH", "JONES"], text: "Chunk 4")}


  let(:all_chunks) { [chunk1, chunk2, chunk3, chunk4] }

  it "fetches" do
    results = described_class.new(question_embedding: fake_question_embedding, top_k: 2).fetch_chunks

    expect(results.length).to eq 2
    expect(results).to all be_kind_of(OralHistoryChunk)
    expect(results).to all satisfy { |r| r.neighbor_distance.present? }

    # make sure we can follow associations without triggering strict loading error
    results.collect(&:oral_history_content).flatten.collect(&:work)
  end

  describe "max_per_interview" do
    it "fetches with limit" do
      # We ask for 3, but can only get 2 because of per-doc limit
      results = described_class.new(question_embedding: fake_question_embedding, top_k: 3, max_per_interview: 1).fetch_chunks

      expect(results.length).to eq 2

      # Two oral_history_content_id's, each with only 1 chunk
      groups = results.group_by {|c| c.oral_history_content_id }
      expect(groups.count).to eq 2
      expect(groups.values).to all satisfy { |v| v.length == 1 }

      # included chunks have closer vector distance to question than excluded
      excluded = all_chunks.find_all { |c| ! c.id.in?(results.collect(&:id)) }
      excluded_similarity = excluded.collect {|c| cosine_similarity(c.embedding, fake_question_embedding) }

      included_similarity = results.collect {|c| cosine_similarity(c.embedding, fake_question_embedding) }

      # everything included has more similarity than anything excluded!
      expect(included_similarity.min).to be >= excluded_similarity.max

      # make sure we can follow associations without triggering strict loading error
      results.collect(&:oral_history_content).flatten.collect(&:work)
    end
  end

  describe "exclude_chunks" do
    it "can exclude chunks by id" do
      exclude_chunk_ids = [chunk1, chunk3].collect(&:id)
      results = described_class.new(question_embedding: fake_question_embedding, top_k: 100, exclude_chunks: exclude_chunk_ids).fetch_chunks

      expected_ids = all_chunks.collect(&:id) - exclude_chunk_ids
      expect(results.collect(&:id)).to match_array expected_ids
    end

    it "can exclude chunks by chunk" do
      results = described_class.new(question_embedding: fake_question_embedding, top_k: 100, exclude_chunks: [chunk2, chunk4]).fetch_chunks

      expected_ids = all_chunks.collect(&:id) - [chunk2, chunk4].collect(&:id)
      expect(results.collect(&:id)).to match_array expected_ids
    end
  end

  describe "exclude_interviews" do
    it "can exclude interviews by OralHistoryContent model" do
      results = described_class.new(question_embedding: fake_question_embedding, top_k: 100, exclude_interviews: [work1.oral_history_content]).fetch_chunks

      expect(results.collect(&:oral_history_content_id)).not_to include(work1.oral_history_content.id)
    end
  end
end