Skip to content

Commit

Permalink
Refresh stale buckets
Browse files Browse the repository at this point in the history
  • Loading branch information
lavrov committed Sep 17, 2024
1 parent c450e4d commit 14cdaf5
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 61 deletions.
2 changes: 1 addition & 1 deletion cmd/src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ object Main
val nodeAddress = SocketAddress.fromString(nodeAddressParam).liftTo[ResourceIO](new Exception("Invalid address")).await
val nodeIpAddress = nodeAddress.resolve[IO].toResource.await
given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await
val selfId = Resource.eval(NodeId.generate[IO]).await
val selfId = Resource.eval(NodeId.random[IO]).await
val infoHash = infoHashFromString(infoHashParam).toResource.await
val messageSocket = MessageSocket(none).await
val client = Client(selfId, messageSocket, QueryHandler.noop).await
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object Node {
logger: Logger[IO]
): Resource[IO, Node] =
for
selfId <- Resource.eval(NodeId.generate[IO])
selfId <- Resource.eval(NodeId.random[IO])
messageSocket <- MessageSocket(port)
routingTable <- RoutingTable[IO](selfId).toResource
queryingNodes <- Queue.unbounded[IO, NodeInfo].toResource
Expand All @@ -41,7 +41,7 @@ object Node {
bootstrapNodes = bootstrapNodeAddress.map(List(_)).getOrElse(RoutingTableBootstrap.PublicBootstrapNodes)
discovery = PeerDiscovery(routingTable, insertingClient)
_ <- RoutingTableBootstrap(routingTable, insertingClient, discovery, bootstrapNodes).toResource
_ <- PingRoutine(routingTable, client).runForever.background
_ <- RoutingTableRefresh(routingTable, client, discovery).runEvery(15.minutes).background
_ <- pingCandidates(queryingNodes, client, routingTable).background
yield new Node(selfId, insertingClient, routingTable, discovery)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,19 @@ object NodeId {

def distance(a: NodeId, b: InfoHash): BigInt = distance(a.bytes, b.bytes)

def generate[F[_]](using Random[F], Monad[F]): F[NodeId] = {
def random[F[_]](using Random[F], Monad[F]): F[NodeId] = {
for bytes <- Random[F].nextBytes(20)
yield NodeId(ByteVector.view(bytes))
}

def fromInt(int: BigInt): NodeId = NodeId(ByteVector.view(int.toByteArray).padTo(20))

def randomInRange[F[_]](from: BigInt, until: BigInt)(using Random[F], Monad[F]): F[NodeId] =
val range = until - from
for
bytes <- Random[F].nextBytes(range.bitLength / 8)
integer = BigInt(1, bytes) + from
yield NodeId(ByteVector.view(integer.toByteArray).padLeft(20))

given Show[NodeId] = nodeId => s"NodeId(${nodeId.bytes.toHex})"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ object PeerDiscovery {
.eval {
for {
_ <- logger.info("Start discovery")
initialNodes <- routingTable.findNodes(NodeId(infoHash.bytes))
initialNodes <- initialNodes.take(100).toList.pure[IO]
initialNodes <- routingTable.goodNodes(NodeId(infoHash.bytes))
initialNodes <- initialNodes.take(16).toList.pure[IO]
_ <- logger.info(s"Received ${initialNodes.size} from own routing table")
state <- DiscoveryState(initialNodes, infoHash)
} yield {
Expand All @@ -63,11 +63,12 @@ object PeerDiscovery {
.eval(
for
_ <- logger.info(s"Start finding nodes for $nodeId")
initialNodes <- routingTable.findNodes(nodeId)
initialNodes <- routingTable.goodNodes(nodeId)
initialNodes <- initialNodes
.take(10)
.take(16)
.toList
.sortBy(nodeInfo => NodeId.distance(nodeInfo.id, dhtClient.id))
.toList.pure[IO]
.pure[IO]
yield
FindNodesState(nodeId, initialNodes)
)
Expand All @@ -80,7 +81,7 @@ object PeerDiscovery {
case class FindNodesState(
targetId: NodeId,
nodesToQuery: List[NodeInfo],
seenNodes: Set[NodeInfo] = Set.empty,
usedNodes: Set[NodeInfo] = Set.empty,
respondedCount: Int = 0
):
def next: IO[Option[(List[NodeInfo], FindNodesState)]] = async[IO]:
Expand All @@ -104,7 +105,7 @@ object PeerDiscovery {
then NodeId.distance(nodesToQuery.head.id, targetId)
else NodeId.MaxValue
val closeNodes = foundNodes
.filterNot(seenNodes)
.filterNot(usedNodes)
.distinct
.filter(nodeInfo => NodeId.distance(nodeInfo.id, targetId) < threshold)
.sortBy(nodeInfo => NodeId.distance(nodeInfo.id, targetId))
Expand All @@ -113,7 +114,7 @@ object PeerDiscovery {
respondedNodes,
copy(
nodesToQuery = closeNodes,
seenNodes = seenNodes ++ foundNodes,
usedNodes = usedNodes ++ respondedNodes,
respondedCount = respondedCount + respondedNodes.size)
).some
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ object QueryHandler {
case Query.Ping(_) =>
Response.Ping(selfId).some.pure[F]
case Query.FindNode(_, target) =>
routingTable.findBucket(target).map { nodes =>
Response.Nodes(selfId, nodes).some
routingTable.goodNodes(target).map { nodes =>
Response.Nodes(selfId, nodes.take(8).toList).some
}
case Query.GetPeers(_, infoHash) =>
routingTable.findPeers(infoHash).flatMap {
case Some(peers) =>
Response.Peers(selfId, peers.toList).some.pure[F]
case None =>
routingTable
.findBucket(NodeId(infoHash.bytes))
.goodNodes(NodeId(infoHash.bytes))
.map { nodes =>
Response.Nodes(selfId, nodes).some
Response.Nodes(selfId, nodes.take(8).toList).some
}
}
case Query.AnnouncePeer(_, infoHash, port) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ trait RoutingTable[F[_]] {

def remove(nodeId: NodeId): F[Unit]

def findNodes(nodeId: NodeId): F[LazyList[NodeInfo]]

def findBucket(nodeId: NodeId): F[List[NodeInfo]]
def goodNodes(nodeId: NodeId): F[Iterable[NodeInfo]]

def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit]

def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]]

def allNodes: F[LazyList[RoutingTable.Node]]
def allNodes: F[Iterable[RoutingTable.Node]]

def buckets: F[Iterable[RoutingTable.TreeNode.Bucket]]

def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit]

Expand All @@ -38,7 +38,7 @@ object RoutingTable {

enum TreeNode:
case Split(center: BigInt, lower: TreeNode, higher: TreeNode)
case Bucket(from: BigInt, until: BigInt, nodes: ListMap[NodeId, Node])
case Bucket(from: BigInt, until: BigInt, nodes: Map[NodeId, Node])

case class Node(id: NodeId, address: SocketAddress[IpAddress], isGood: Boolean, badCount: Int = 0):
def toNodeInfo: NodeInfo = NodeInfo(id, address)
Expand All @@ -49,7 +49,7 @@ object RoutingTable {
TreeNode.Bucket(
from = BigInt(0),
until = BigInt(1, ByteVector.fill(20)(-1: Byte).toArray),
ListMap.empty
Map.empty
)
}

Expand Down Expand Up @@ -118,7 +118,7 @@ object RoutingTable {
higher.findBucket(nodeId)
case b: Bucket => b

def findNodes(nodeId: NodeId): LazyList[Node] =
def findNodes(nodeId: NodeId): Iterable[Node] =
bucket match
case Split(center, lower, higher) =>
if (nodeId.int < center)
Expand All @@ -127,6 +127,11 @@ object RoutingTable {
higher.findNodes(nodeId) ++ lower.findNodes(nodeId)
case b: Bucket => b.nodes.values.to(LazyList)

def buckets: Iterable[Bucket] =
bucket match
case b: Bucket => Iterable(b)
case Split(_, lower, higher) => lower.buckets ++ higher.buckets

def update(fn: Node => Node): TreeNode =
bucket match
case b @ Split(_, lower, higher) =>
Expand All @@ -148,12 +153,9 @@ object RoutingTable {
def remove(nodeId: NodeId): F[Unit] =
treeNodeRef.update(_.remove(nodeId))

def findNodes(nodeId: NodeId): F[LazyList[NodeInfo]] =
def goodNodes(nodeId: NodeId): F[Iterable[NodeInfo]] =
treeNodeRef.get.map(_.findNodes(nodeId).filter(_.isGood).map(_.toNodeInfo))

def findBucket(nodeId: NodeId): F[List[NodeInfo]] =
treeNodeRef.get.map(_.findBucket(nodeId).nodes.values.filter(_.isGood).map(_.toNodeInfo).toList)

def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit] =
peers.update { map =>
map.updatedWith(infoHash) {
Expand All @@ -165,9 +167,12 @@ object RoutingTable {
def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]] =
peers.get.map(_.get(infoHash))

def allNodes: F[LazyList[Node]] =
def allNodes: F[Iterable[Node]] =
treeNodeRef.get.map(_.findNodes(selfId))

def buckets: F[Iterable[TreeNode.Bucket]] =
treeNodeRef.get.map(_.buckets)

def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] =
treeNodeRef.update(
_.update(node =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.github.torrentdam.bittorrent.dht

import cats.effect.IO
import cats.syntax.all.*
import cats.effect.cps.{*, given}
import cats.effect.std.Random
import org.legogroup.woof.{Logger, given}

import scala.concurrent.duration.{DurationInt, FiniteDuration}

class RoutingTableRefresh(table: RoutingTable[IO], client: Client, discovery: PeerDiscovery)(using logger: Logger[IO], random: Random[IO]):

def runOnce: IO[Unit] = async[IO]:
val buckets = table.buckets.await
val (fresh, stale) = buckets.toList.partition(_.nodes.values.exists(_.isGood))
if stale.nonEmpty then
refreshBuckets(stale).await
val nodes = fresh.flatMap(_.nodes.values)
pingNodes(nodes).await

def runEvery(period: FiniteDuration): IO[Unit] =
IO
.sleep(period)
.productR(runOnce)
.foreverM
.handleErrorWith: e =>
logger.error(s"PingRoutine failed: $e")
.foreverM

private def pingNodes(nodes: List[RoutingTable.Node]) = async[IO]:
logger.info(s"Pinging ${nodes.size} nodes").await
val results = nodes
.parTraverse { node =>
client.ping(node.address).timeout(5.seconds).attempt.map(_.bimap(_ => node.id, _ => node.id))
}
.await
val (bad, good) = results.partitionMap(identity)
logger.info(s"Got ${good.size} good nodes and ${bad.size} bad nodes").await
table.updateGoodness(good.toSet, bad.toSet).await

private def refreshBuckets(buckets: List[RoutingTable.TreeNode.Bucket]) = async[IO]:
logger.info(s"Found ${buckets.size} stale buckets").await
buckets
.parTraverse: bucket =>
val randomId = NodeId.randomInRange(bucket.from, bucket.until).await
discovery.findNodes(randomId).take(32).compile.drain
.await


0 comments on commit 14cdaf5

Please sign in to comment.