diff --git a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java index f9938071..bc91daf6 100644 --- a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java @@ -203,7 +203,7 @@ public boolean isCacheable(LeafReaderContext ctx) { // XXX: this is not thread safe and may run into extremely weird issues // if the searcher uses the parallel collector // Hopefully elastic never runs - MutableSupplier vectorSupplier = new Suppliers.FeatureVectorSupplier(); + MutableSupplier vectorSupplier = new Suppliers.MutableSupplier<>(); FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier); LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier); for (Query q : queries) { diff --git a/src/main/java/com/o19s/es/ltr/utils/Suppliers.java b/src/main/java/com/o19s/es/ltr/utils/Suppliers.java index 94bb7faf..d54fbcea 100644 --- a/src/main/java/com/o19s/es/ltr/utils/Suppliers.java +++ b/src/main/java/com/o19s/es/ltr/utils/Suppliers.java @@ -16,10 +16,8 @@ package com.o19s.es.ltr.utils; -import com.o19s.es.ltr.ranker.LtrRanker; -import org.opensearch.core.Assertions; - import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; public final class Suppliers { @@ -66,33 +64,15 @@ public E get() { * A mutable supplier */ public static class MutableSupplier implements Supplier { - T obj; + private final AtomicReference ref = new AtomicReference<>(); @Override public T get() { - return obj; + return ref.get(); } public void set(T obj) { - this.obj = obj; - } - } - - /** - * Simple wrapper to make sure we run on the same thread - */ - public static class FeatureVectorSupplier extends MutableSupplier { - private final long threadId = Assertions.ENABLED ? Thread.currentThread().getId() : 0; - - public LtrRanker.FeatureVector get() { - assert threadId == Thread.currentThread().getId(); - return super.get(); - } - - @Override - public void set(LtrRanker.FeatureVector obj) { - assert threadId == Thread.currentThread().getId(); - super.set(obj); + this.ref.set(obj); } } }