Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ New Features
* util-jvm: Add gc pause stats for all collector pools, including G1. ``PHAB_ID=D1176049``
* util-stats: Expose dimensional metrics APIs and allow metrics with an indeterminate
identity to be exported through the Prometheus exporter. ``PHAB_ID=D1218090``

* util-core: Add AsyncStream methods: ``groupBy``, ``distinct``, ``distinctBy``, ``contains``, ``exists``, ``find``, ``collect``,
``collectFirst``, ``to``, ``toList``, ``toSet``, ``toArray``, ``toMap``

24.5.0
------
Expand Down
206 changes: 205 additions & 1 deletion util-core/src/main/scala/com/twitter/concurrent/AsyncStream.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package com.twitter.concurrent

import com.twitter.conversions.SeqUtil
import com.twitter.util.{Future, Return, Throw, Promise}
import com.twitter.util.{Future, Promise, Return, Throw}

import scala.annotation.varargs
import scala.collection.{Factory, mutable}
import scala.reflect.ClassTag

/**
* A representation of a lazy (and possibly infinite) sequence of asynchronous
Expand Down Expand Up @@ -434,6 +437,207 @@ sealed abstract class AsyncStream[+A] {
case Embed(fas) => Embed(fas.map(_.flatten))
}

/**
* Groups elements of the stream by a key function and collects them into collections.
*
* @param f the function to compute a key for each element
* @param factory the factory to create collections for each group
* @tparam K the type of keys
* @tparam C the type of collection to store elements for each key
* @return a Future containing a Map where keys are the result of applying `f`
* and values are collections of elements that had the same key
*
* @example {{{
* val stream = AsyncStream(1, 2, 3, 4, 5, 6)
* stream.groupBy(_ % 2)(List) // Future(Map(0 -> List(2, 4, 6), 1 -> List(1, 3, 5)))
* }}}
*/
def groupBy[K, C](f: A => K)(implicit factory: Factory[A, C]): Future[Map[K, C]] = {
  // One builder per key, created the first time that key is produced by `f`.
  val builders = mutable.Map.empty[K, mutable.Builder[A, C]]
  foreach { a =>
    builders.getOrElseUpdate(f(a), factory.newBuilder) += a
  }.map { _ =>
    // Runs once the whole stream has been consumed: freeze every per-key builder
    // into its final collection.
    val result = Map.newBuilder[K, C]
    result.sizeHint(builders.size)
    // Add an explicit tuple. Passing `(k, v)` as two arguments to the one-argument
    // `addOne` only compiles through argument auto-tupling, which is flagged by
    // `-Xlint:adapted-args`.
    builders.foreach { case (k, b) => result += (k -> b.result()) }
    result.result()
  }
}

/**
* Returns a stream with duplicate elements removed, comparing by equality.
*
* @return a new AsyncStream containing only distinct elements
*
* @example {{{
* val stream = AsyncStream(1, 2, 2, 3, 1, 4)
* stream.distinct // AsyncStream(1, 2, 3, 4)
* }}}
*/
def distinct: AsyncStream[A] =
  // Each element is its own comparison key.
  distinctBy(a => a)

/**
* Returns a stream with duplicate elements removed, where duplicates are determined
* by applying a function to each element.
*
* @param f the function to compute a value for comparison
* @tparam B the type of values used for comparison
* @return a new AsyncStream containing only elements that are distinct according to `f`
*
* @example {{{
* case class Person(name: String, age: Int)
* val stream = AsyncStream(Person("Alice", 25), Person("Bob", 30), Person("Alice", 35))
* stream.distinctBy(_.name) // AsyncStream(Person("Alice", 25), Person("Bob", 30))
* }}}
*/
def distinctBy[B](f: A => B): AsyncStream[A] = {
  // Keys that have already been let through. `add` returns true only when the key
  // was not yet present, so it doubles as the "first occurrence" predicate.
  // NOTE(review): the set is created once per `distinctBy` call and captured by the
  // predicate; it also grows with every new key. Confirm that the returned stream
  // never re-runs the predicate on a later traversal, otherwise the second pass
  // would see an already-populated set.
  val seenKeys = mutable.HashSet.empty[B]
  filter { a =>
    val key = f(a)
    seenKeys.add(key)
  }
}

/**
* Tests whether the stream contains a specific element.
*
* @param elem the element to test for membership
* @tparam A1 the type of the element, which must be a supertype of A
* @return a Future containing true if the element is found, false otherwise
*
* @example {{{
* val stream = AsyncStream(1, 2, 3, 4, 5)
* stream.contains(3) // Future(true)
* stream.contains(6) // Future(false)
* }}}
*/
def contains[A1 >: A](elem: A1): Future[Boolean] =
  // Membership is "some element equals `elem`", with the stream element on the
  // left-hand side of `==`.
  exists(candidate => candidate == elem)

/**
* Tests whether any element of the stream satisfies a predicate.
*
* @param p the predicate to test
* @return a Future containing true if any element satisfies the predicate, false otherwise
*
* @example {{{
* val stream = AsyncStream(1, 2, 3, 4, 5)
* stream.exists(_ > 3) // Future(true)
* stream.exists(_ > 10) // Future(false)
* }}}
*/
def exists(p: A => Boolean): Future[Boolean] = {
  // An element satisfying `p` exists exactly when the filtered stream is non-empty.
  val matching = filter(p)
  matching.isEmpty.map(noneMatched => !noneMatched)
}

/**
* Finds the first element that satisfies a predicate.
*
* @param p the predicate to test
* @return a Future containing Some(element) if found, None otherwise
*
* @example {{{
* val stream = AsyncStream(1, 2, 3, 4, 5)
* stream.find(_ > 3) // Future(Some(4))
* stream.find(_ > 10) // Future(None)
* }}}
*/
def find(p: A => Boolean): Future[Option[A]] = {
  // The first match is the head of the stream restricted to elements satisfying `p`.
  val matching = filter(p)
  matching.head
}

/**
* Builds a new stream by applying a partial function to all elements where it is defined.
*
* @param pf the partial function to apply
* @tparam B the element type of the returned stream
* @return a new AsyncStream containing elements transformed by the partial function
*
* @example {{{
* val stream = AsyncStream(1, 2, 3, 4, 5)
* stream.collect { case x if x % 2 == 0 => x * 2 } // AsyncStream(4, 8)
* }}}
*/
def collect[B](pf: PartialFunction[A, B]): AsyncStream[B] =
  flatMap { a =>
    // `lift` goes through `applyOrElse`, so a partial-function literal runs its
    // patterns and guards once per element. Checking `isDefinedAt` and then calling
    // `pf(a)` would evaluate them twice for every matching element.
    pf.lift(a) match {
      case Some(b) => AsyncStream.of(b)
      case None => AsyncStream.empty
    }
  }

/**
* Finds the first element where the partial function is defined and applies it.
*
* @param pf the partial function to apply
* @tparam B the result type of the partial function
* @return a Future containing Some(result) if the partial function was applicable
* to any element, None otherwise
*
* @example {{{
* val stream = AsyncStream("1", "hello", "2", "world")
* stream.collectFirst { case s if s.forall(_.isDigit) => s.toInt } // Future(Some(1))
* }}}
*/
def collectFirst[B](pf: PartialFunction[A, B]): Future[Option[B]] = {
  // Head of the stream of mapped values for which `pf` is defined.
  val mapped = collect(pf)
  mapped.head
}

/**
* Converts the stream to a specific collection type.
*
* @param factory the factory to create the target collection
* @tparam C the type of the collection to create
* @return a Future containing all elements of the stream in the specified collection type
*
* @example {{{
* val stream = AsyncStream(1, 2, 3, 4, 5)
* stream.to(Vector) // Future(Vector(1, 2, 3, 4, 5))
* stream.to(Set) // Future(Set(1, 2, 3, 4, 5))
* }}}
*/
def to[C](implicit factory: Factory[A, C]): Future[C] = {
  // A fresh builder per call; it is filled as elements are visited and frozen
  // only after `foreach` has completed.
  val acc = factory.newBuilder
  val filled = foreach { a => acc += a }
  filled.map { _ => acc.result() }
}

/**
* Converts the stream to a List.
*
* @return a Future containing all elements of the stream in a List
*
* @example {{{
* val stream = AsyncStream(1, 2, 3)
* stream.toList // Future(List(1, 2, 3))
* }}}
*/
def toList: Future[List[A]] =
  // Target collection spelled out; the `List` companion supplies the factory.
  to[List[A]](List)

/**
* Converts the stream to a Set.
*
* @tparam A1 the element type of the resulting Set, which must be a supertype of A
* @return a Future containing all distinct elements of the stream in a Set
*
* @example {{{
* val stream = AsyncStream(1, 2, 2, 3, 1)
* stream.toSet // Future(Set(1, 2, 3))
* }}}
*/
def toSet[A1 >: A]: Future[Set[A1]] =
  // Target collection spelled out; the `Set` companion supplies the factory.
  to[Set[A1]](Set)

/**
* Converts the stream to an Array.
*
* @param tag the ClassTag for the array element type
* @tparam A1 the element type of the resulting Array, which must be a supertype of A
* @return a Future containing all elements of the stream in an Array
*
* @example {{{
* val stream = AsyncStream(1, 2, 3)
* stream.toArray // Future(Array(1, 2, 3))
* }}}
*/
def toArray[A1 >: A](implicit tag: ClassTag[A1]): Future[Array[A1]] =
  // The implicit `tag` is what lets the `Array` companion build an `Array[A1]`.
  to[Array[A1]](Array)

/**
* Converts the stream to a Map. The stream elements must be key-value pairs.
*
* @param ev evidence that A is a subtype of (K, V)
* @tparam K the key type
* @tparam V the value type
* @return a Future containing all key-value pairs from the stream in a Map
*
* @example {{{
* val stream = AsyncStream(("a", 1), ("b", 2), ("c", 3))
* stream.toMap // Future(Map("a" -> 1, "b" -> 2, "c" -> 3))
* }}}
*/
def toMap[K, V](implicit ev: A <:< (K, V)): Future[Map[K, V]] = {
  // `ev` is itself a function A => (K, V), so mapping with it re-types the
  // elements as pairs before they are collected.
  val pairs = map(ev)
  pairs.to[Map[K, V]](Map)
}

/**
* A Future of the stream realized as a list. This future completes when all
* elements of the stream are resolved.
Expand Down
114 changes: 114 additions & 0 deletions util-core/src/test/scala/com/twitter/concurrent/AsyncStreamTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,120 @@ class AsyncStreamTest extends AnyFunSuite with ScalaCheckDrivenPropertyChecks {

assert(Await.result(stream.toSeq()) == Seq(n))
}

test(s"$impl: groupBy") {
  // Grouping the stream by key must agree with grouping the source Seq.
  forAll { pairs: Seq[(Int, String)] =>
    val grouped = Await.result(fromSeq(pairs).groupBy(_._1)(Seq))
    assert(grouped == pairs.groupBy(_._1))
  }
}

test(s"$impl: distinct") {
  // Same elements, in the same order, as Seq.distinct.
  forAll { xs: Seq[Int] =>
    val actual = Await.result(fromSeq(xs).distinct.toSeq())
    assert(actual == xs.distinct)
  }
}

test(s"$impl: distinctBy") {
  // Keying by remainder mod 3 leaves only a few distinct keys, so most elements
  // of a sample are dropped as duplicates.
  forAll { xs: Seq[Int] =>
    val actual = Await.result(fromSeq(xs).distinctBy(_ % 3).toSeq())
    assert(actual == xs.distinctBy(_ % 3))
  }
}

test(s"$impl: contains") {
  forAll { (xs: Seq[Int], probe: Int) =>
    // An element taken from the input must always be found. Using the last one
    // makes the search walk the whole stream. A fixed or random probe is almost
    // never present in a random Seq[Int], so on its own it only covers the
    // "absent" branch.
    xs.lastOption.foreach { present =>
      assert(Await.result(fromSeq(xs).contains(present)))
    }
    // An arbitrary probe must agree with Seq.contains.
    assert(Await.result(fromSeq(xs).contains(probe)) == xs.contains(probe))
  }
}

test(s"$impl: exists") {
  // Must agree with Seq.exists for the same predicate.
  forAll { xs: Seq[Int] =>
    val actual = Await.result(fromSeq(xs).exists(_ % 3 == 0))
    assert(actual == xs.exists(_ % 3 == 0))
  }
}

test(s"$impl: find") {
  // Must return the same first match (or None) as Seq.find.
  forAll { xs: Seq[Int] =>
    val actual = Await.result(fromSeq(xs).find(_ % 3 == 0))
    assert(actual == xs.find(_ % 3 == 0))
  }
}

test(s"$impl: collect") {
  // Two overlapping cases plus values matching neither, so the partial function
  // is genuinely partial and case order matters.
  val render: PartialFunction[Int, String] = {
    case n if n % 3 == 0 => String.valueOf(n + 2)
    case n if n % 2 == 0 => String.valueOf(n + 17)
  }
  forAll { xs: Seq[Int] =>
    val actual = Await.result(fromSeq(xs).collect(render).toSeq())
    assert(actual == xs.collect(render))
  }
}

test(s"$impl: collectFirst") {
  // Same partial function shape as the `collect` test; only the first mapped
  // value (if any) is compared.
  val render: PartialFunction[Int, String] = {
    case n if n % 3 == 0 => String.valueOf(n + 2)
    case n if n % 2 == 0 => String.valueOf(n + 17)
  }
  forAll { xs: Seq[Int] =>
    val actual = Await.result(fromSeq(xs).collectFirst(render))
    assert(actual == xs.collectFirst(render))
  }
}

test(s"$impl: to[String]") {
  // Collecting chars into a String goes through the implicit Factory[Char, String].
  forAll { chars: Seq[Char] =>
    val actual = Await.result(fromSeq(chars).to[String])
    assert(actual == chars.mkString)
  }
}

test(s"$impl: to[List[Int]]") {
  forAll { xs: List[Int] =>
    // Pass the companion explicitly. The standard library's implicit Factory
    // instances cover only String and Array, so `to[List[Int]]` would have no
    // implicit factory to resolve.
    val actual: List[Int] = Await.result(fromSeq(xs).to(List))
    assert(actual == xs)
  }
}

test(s"$impl: toList") {
  // Round trip: List -> stream -> List preserves elements and order.
  forAll { xs: List[Int] =>
    val actual = Await.result(fromSeq(xs).toList)
    assert(actual == xs)
  }
}

test(s"$impl: toSet") {
  forAll { xs: Seq[Int] =>
    // Feed every element twice so each non-empty sample contains duplicates.
    // Generating the input as a Set never exercises de-duplication.
    val expected: Set[Int] = xs.toSet
    val actual = Await.result(fromSeq(xs ++ xs).toSet)
    assert(actual == expected)
  }
}

test(s"$impl: toArray") {
  // Arrays compare by reference with ==, so compare contents explicitly.
  forAll { expected: Array[Int] =>
    val actual = Await.result(fromSeq(expected.toSeq).toArray)
    assert(java.util.Arrays.equals(actual, expected))
  }
}

test(s"$impl: toMap") {
  // Round trip: Map -> stream of pairs -> Map.
  forAll { expected: Map[String, Int] =>
    val actual = Await.result(fromSeq(expected.toSeq).toMap)
    assert(actual == expected)
  }
}
}

}
Expand Down