Skip to content

Commit

Permalink
Port over ES-LTR change to fix threading issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sstults committed Sep 16, 2024
1 parent 61d9a71 commit e0ad7d4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/o19s/es/ltr/query/RankerQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public boolean isCacheable(LeafReaderContext ctx) {
// XXX: this is not thread safe and may run into extremely weird issues
// if the searcher uses the parallel collector
// Hopefully elastic never runs
MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.FeatureVectorSupplier();
MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.MutableSupplier<>();
FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier);
LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier);
for (Query q : queries) {
Expand Down
28 changes: 4 additions & 24 deletions src/main/java/com/o19s/es/ltr/utils/Suppliers.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@

package com.o19s.es.ltr.utils;

import com.o19s.es.ltr.ranker.LtrRanker;
import org.opensearch.core.Assertions;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

public final class Suppliers {
Expand Down Expand Up @@ -66,33 +64,15 @@ public E get() {
* A mutable supplier
*/
public static class MutableSupplier<T> implements Supplier<T> {
T obj;
private final AtomicReference<T> ref = new AtomicReference<>();

@Override
public T get() {
return obj;
return ref.get();
}

public void set(T obj) {
this.obj = obj;
}
}

/**
* Simple wrapper to make sure we run on the same thread
*/
public static class FeatureVectorSupplier extends MutableSupplier<LtrRanker.FeatureVector> {
private final long threadId = Assertions.ENABLED ? Thread.currentThread().getId() : 0;

public LtrRanker.FeatureVector get() {
assert threadId == Thread.currentThread().getId();
return super.get();
}

@Override
public void set(LtrRanker.FeatureVector obj) {
assert threadId == Thread.currentThread().getId();
super.set(obj);
this.ref.set(obj);
}
}
}

0 comments on commit e0ad7d4

Please sign in to comment.