Skip to content

Commit

Permalink
Improve node bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
lavrov committed Sep 16, 2024
1 parent b0f0d85 commit c450e4d
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 42 deletions.
3 changes: 1 addition & 2 deletions cmd/src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ object Main
val messageSocket = MessageSocket(none).await
val client = Client(selfId, messageSocket, QueryHandler.noop).await
async[IO]:
val pong = client.ping(nodeIpAddress).await
val response = client.getPeers(NodeInfo(pong.id, nodeIpAddress), infoHash).await
val response = client.getPeers(nodeIpAddress, infoHash).await
IO.println(response).await
ExitCode.Success
}.useEval
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ trait Client {

def id: NodeId

def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]]
def getPeers(address: SocketAddress[IpAddress], infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]]

def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes]
def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes]

def ping(address: SocketAddress[IpAddress]): IO[Response.Ping]

def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]]
def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]]
}

object Client {
Expand Down Expand Up @@ -71,17 +71,17 @@ object Client {
def id: NodeId = selfId

def getPeers(
nodeInfo: NodeInfo,
address: SocketAddress[IpAddress],
infoHash: InfoHash
): IO[Either[Response.Nodes, Response.Peers]] =
requestResponse.sendQuery(nodeInfo.address, Query.GetPeers(selfId, infoHash)).flatMap {
requestResponse.sendQuery(address, Query.GetPeers(selfId, infoHash)).flatMap {
case nodes: Response.Nodes => nodes.asLeft.pure
case peers: Response.Peers => peers.asRight.pure
case _ => IO.raiseError(InvalidResponse())
}

def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] =
requestResponse.sendQuery(nodeInfo.address, Query.FindNode(selfId, target)).flatMap {
def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] =
requestResponse.sendQuery(address, Query.FindNode(selfId, target)).flatMap {
case nodes: Response.Nodes => nodes.pure
case _ => IO.raiseError(InvalidResponse())
}
Expand All @@ -91,8 +91,8 @@ object Client {
case ping: Response.Ping => ping.pure
case _ => IO.raiseError(InvalidResponse())
}
def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
requestResponse.sendQuery(nodeInfo.address, Query.SampleInfoHashes(selfId, target)).flatMap {
def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
requestResponse.sendQuery(address, Query.SampleInfoHashes(selfId, target)).flatMap {
case response: Response.SampleInfoHashes => response.asRight[Response.Nodes].pure
case response: Response.Nodes => response.asLeft[Response.SampleInfoHashes].pure
case _ => IO.raiseError(InvalidResponse())
Expand Down
27 changes: 18 additions & 9 deletions dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,34 @@ object Node {

def id: NodeId = client.id

def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] =
client.getPeers(nodeInfo, infoHash) <* routingTable.insert(nodeInfo)
/** Asks the node at `address` for peers of `infoHash` through the underlying client and,
  * once a response has arrived, records the responder in the routing table.
  *
  * The responder's id is read from whichever response variant came back and is paired
  * with the address that was queried. The client's response is returned unchanged.
  */
def getPeers(address: SocketAddress[IpAddress], infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] =
  for
    reply <- client.getPeers(address, infoHash)
    // Both response variants carry the id of the node that answered.
    responderId = reply.fold(_.id, _.id)
    _ <- routingTable.insert(NodeInfo(responderId, address))
  yield reply

def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] =
client.findNodes(nodeInfo, target).flatTap { response =>
routingTable.insert(NodeInfo(response.id, nodeInfo.address))
def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] =
client.findNodes(address, target).flatTap { response =>
routingTable.insert(NodeInfo(response.id, address))
}

/** Pings the node at `address` through the underlying client and, once it has answered,
  * inserts it into the routing table under the id reported in the response.
  *
  * The ping response itself is returned unchanged.
  */
def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] =
  for
    pong <- client.ping(address)
    _ <- routingTable.insert(NodeInfo(pong.id, address))
  yield pong

def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
client.sampleInfoHashes(nodeInfo, target).flatTap { response =>
def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
client.sampleInfoHashes(address, target).flatTap { response =>
routingTable.insert(
response match
case Left(response) => NodeInfo(response.id, nodeInfo.address)
case Right(response) => NodeInfo(response.id, nodeInfo.address)
case Left(response) => NodeInfo(response.id, address)
case Right(response) => NodeInfo(response.id, address)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ object NodeId {
for bytes <- Random[F].nextBytes(20)
yield NodeId(ByteVector.view(bytes))
}

/** Renders a node id as `NodeId(<hex encoding of its bytes>)`. */
given Show[NodeId] = Show.show(id => s"NodeId(${id.bytes.toHex})")

/** Largest value representable in 20 unsigned bytes, i.e. 2^160 - 1. */
val MaxValue: BigInt = (BigInt(1) << 160) - 1
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@ import cats.effect.IO
import cats.effect.Resource
import cats.instances.all.*
import cats.syntax.all.*
import cats.effect.cps.{given, *}
import cats.Show.Shown
import com.github.torrentdam.bittorrent.InfoHash
import com.github.torrentdam.bittorrent.PeerInfo
import fs2.Stream
import org.legogroup.woof.given
import org.legogroup.woof.Logger

import scala.concurrent.duration.DurationInt
import Logger.withLogContext
import com.comcast.ip4s.{IpAddress, SocketAddress}

trait PeerDiscovery {

def discover(infoHash: InfoHash): Stream[IO, PeerInfo]

def findNodes(NodeId: NodeId): Stream[IO, NodeInfo]
}

object PeerDiscovery {
Expand All @@ -36,7 +41,7 @@ object PeerDiscovery {
_ <- logger.info("Start discovery")
initialNodes <- routingTable.findNodes(NodeId(infoHash.bytes))
initialNodes <- initialNodes.take(100).toList.pure[IO]
_ <- logger.info(s"Got ${initialNodes.size} from routing table")
_ <- logger.info(s"Received ${initialNodes.size} from own routing table")
state <- DiscoveryState(initialNodes, infoHash)
} yield {
start(
Expand All @@ -52,21 +57,80 @@ object PeerDiscovery {
case _ => IO.unit
}
}

/** Iteratively looks up nodes around `nodeId`.
  *
  * The lookup is seeded with up to 10 nodes returned by `routingTable.findNodes(nodeId)`.
  * It then repeatedly advances a [[FindNodesState]] via `next` until that yields `None`,
  * emitting the node list produced by every step.
  *
  * @param nodeId the id the lookup is aimed at
  * @return a stream of the nodes produced by the successive lookup steps
  */
def findNodes(nodeId: NodeId): Stream[IO, NodeInfo] =
  Stream
    .eval(
      for
        _ <- logger.info(s"Start finding nodes for $nodeId")
        candidates <- routingTable.findNodes(nodeId)
        // Order the starting set by distance to the lookup target (not to our own id),
        // which is the ordering FindNodesState uses for everything it queries afterwards.
        closest = candidates
          .take(10)
          .sortBy(nodeInfo => NodeId.distance(nodeInfo.id, nodeId))
          .toList
      yield
        FindNodesState(nodeId, closest)
    )
    .flatMap { state =>
      Stream
        .unfoldEval(state)(_.next)
        .flatMap(Stream.emits)
    }

/** Immutable state of one iterative node lookup aimed at `targetId`.
  *
  * @param targetId       id the lookup is aimed at
  * @param nodesToQuery   nodes that will be queried by the next call to `next`
  * @param seenNodes      nodes that were already queried or already returned by a queried node
  * @param respondedCount number of queried nodes that have answered so far
  */
case class FindNodesState(
  targetId: NodeId,
  nodesToQuery: List[NodeInfo],
  seenNodes: Set[NodeInfo] = Set.empty,
  respondedCount: Int = 0
):
  /** Runs one round of the lookup.
    *
    * Every node in `nodesToQuery` is queried in parallel; a query that fails or does not
    * complete within 5 seconds is treated as "no answer" rather than as an error.
    *
    * @return `None` when there is nothing left to query, otherwise the nodes that answered
    *         in this round together with the state for the next round
    */
  def next: IO[Option[(List[NodeInfo], FindNodesState)]] = async[IO]:
    if nodesToQuery.isEmpty then
      none
    else
      val responses = nodesToQuery
        .parTraverse(nodeInfo =>
          dhtClient
            .findNodes(nodeInfo.address, targetId)
            .map(_.nodes.some)
            .timeout(5.seconds)
            // Best effort: an unreachable or misbehaving node simply contributes nothing.
            .orElse(none.pure[IO])
            .tupleLeft(nodeInfo)
        )
        .await
      val respondedNodes = responses.collect { case (nodeInfo, Some(_)) => nodeInfo }
      val foundNodes = responses.collect { case (_, Some(nodes)) => nodes }.flatten
      // Once more than 10 nodes have answered, only accept candidates strictly closer to
      // the target than the first node of this round; before that accept any distance
      // below MaxValue. `head` is safe here because nodesToQuery is non-empty.
      val threshold =
        if respondedCount > 10
        then NodeId.distance(nodesToQuery.head.id, targetId)
        else NodeId.MaxValue
      // Nodes queried in this round count as seen as well; otherwise a node from the
      // starting set that a peer hands back would be queried and emitted a second time.
      val alreadySeen = seenNodes ++ nodesToQuery
      val closeNodes = foundNodes
        .filterNot(alreadySeen)
        .distinct
        .filter(nodeInfo => NodeId.distance(nodeInfo.id, targetId) < threshold)
        .sortBy(nodeInfo => NodeId.distance(nodeInfo.id, targetId))
        .take(10)
      (
        respondedNodes,
        copy(
          nodesToQuery = closeNodes,
          seenNodes = alreadySeen ++ foundNodes,
          respondedCount = respondedCount + respondedNodes.size)
      ).some
}

private[dht] def start(
infoHash: InfoHash,
getPeers: (NodeInfo, InfoHash) => IO[Either[Response.Nodes, Response.Peers]],
state: DiscoveryState,
parallelism: Int = 10
infoHash: InfoHash,
getPeers: (SocketAddress[IpAddress], InfoHash) => IO[Either[Response.Nodes, Response.Peers]],
state: DiscoveryState,
parallelism: Int = 10
)(using
logger: Logger[IO]
): Stream[IO, PeerInfo] = {

Stream
.repeatEval(state.next)
.parEvalMapUnordered(parallelism) { nodeInfo =>
getPeers(nodeInfo, infoHash).timeout(5.seconds).attempt <* logger.trace(s"Get peers $nodeInfo")
getPeers(nodeInfo.address, infoHash).timeout(5.seconds).attempt <* logger.trace(s"Get peers $nodeInfo")
}
.flatMap {
case Right(response) =>
Expand Down Expand Up @@ -100,10 +164,8 @@ object PeerDiscovery {

def addNodes(nodes: List[NodeInfo]): IO[Unit] = {
ref.modify { state =>
val (seenNodes, newNodes) = {
val newNodes = nodes.filterNot(state.seenNodes)
(state.seenNodes ++ newNodes, newNodes)
}
val newNodes = nodes.filterNot(state.seenNodes)
val seenNodes = state.seenNodes ++ newNodes
val nodesToTry = (newNodes ++ state.nodesToTry).sortBy(n => NodeId.distance(n.id, infoHash))
val waiters = state.waiters.drop(nodesToTry.size)
val newState =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ class PingRoutine(table: RoutingTable[IO], client: Client)(using logger: Logger[
table.updateGoodness(good.toSet, bad.toSet).await

def runForever: IO[Unit] =
run
IO
.sleep(10.minutes)
.productR(run)
.foreverM
.handleErrorWith: e =>
logger.error(s"PingRoutine failed: $e")
.productR(IO.sleep(10.minutes))
.foreverM
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import cats.implicits.*
import cats.MonadError
import cats.effect.IO
import cats.effect.implicits.*
import cats.effect.cps.{given, *}
import com.comcast.ip4s.*
import com.github.torrentdam.bittorrent.InfoHash
import org.legogroup.woof.given
Expand All @@ -24,15 +25,14 @@ object RoutingTableBootstrap {
dns: Dns[IO],
logger: Logger[IO]
): IO[Unit] =
for {
for
_ <- logger.info("Bootstrapping")
count <- resolveNodes(client, bootstrapNodeAddress).compile.count
_ <- logger.info(s"Pinged $count bootstrap nodes")
_ <- logger.info("Discover self to fill up routing table")
_ <- discovery.discover(InfoHash(client.id.bytes)).take(10).compile.drain.timeout(30.seconds).attempt
count <- resolveNodes(client, bootstrapNodeAddress).compile.count.iterateUntil(_ > 0)
_ <- logger.info(s"Communicated with $count bootstrap nodes")
_ <- selfDiscovery(table, client, discovery)
nodeCount <- table.allNodes.map(_.size)
_ <- logger.info(s"Bootstrapping finished with $nodeCount nodes")
} yield {}
yield {}

private def resolveNodes(
client: Client,
Expand Down Expand Up @@ -63,12 +63,25 @@ object RoutingTableBootstrap {
logger.info(s"Failed to reach $resolvedAddress $e").as(none)
.collect {
case Some(node) => node
}
}
Stream
.emits(bootstrapNodeAddress)
.covary[IO]
.flatMap(tryThis)

/** Fills the routing table by repeatedly looking up the client's own id.
  *
  * Each attempt runs `discovery.findNodes(client.id)`, stopping after 30 emitted nodes or
  * 30 seconds, whichever comes first, and then checks the size of `table.allNodes`.
  * Attempts are repeated until the table holds at least 20 nodes.
  *
  * NOTE(review): there is no upper bound on the number of attempts and no pause between
  * them, so this does not complete while the table stays below 20 nodes. Confirm that
  * this is the intended bootstrap behaviour.
  *
  * @param table     routing table whose size decides whether another attempt is needed
  * @param client    supplies the id that is looked up
  * @param discovery performs the node lookup
  */
private def selfDiscovery(
  table: RoutingTable[IO],
  client: Client,
  discovery: PeerDiscovery
)(using Logger[IO]): IO[Unit] =
  def attempt(number: Int): IO[Unit] = async[IO]:
    Logger[IO].info(s"Discover self to fill up routing table (attempt $number)").await
    val count = discovery.findNodes(client.id).take(30).interruptAfter(30.seconds).compile.count.await
    Logger[IO].info(s"Communicated with $count nodes during self discovery").await
    val nodeCount = table.allNodes.await.size
    // Inside `async` the branches are plain values: finish with unit instead of
    // producing an IO that would never be awaited.
    if nodeCount < 20 then attempt(number + 1).await else ()
  attempt(1)

val PublicBootstrapNodes: List[SocketAddress[Host]] = List(
SocketAddress(host"router.bittorrent.com", port"6881"),
SocketAddress(host"router.utorrent.com", port"6881"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class PeerDiscoverySpec extends munit.CatsEffectSuite {
given logger: Logger[IO] = NoOpLogger()

def getPeers(
nodeInfo: NodeInfo,
address: SocketAddress[IpAddress],
infoHash: InfoHash
): IO[Either[Response.Nodes, Response.Peers]] = IO {
nodeInfo.address.port.value match {
address.port.value match {
case 1 =>
Left(
Response.Nodes(
Expand Down

0 comments on commit c450e4d

Please sign in to comment.