Skip to content

Commit

Permalink
Good and band nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
lavrov committed Sep 14, 2024
1 parent cd6d2a6 commit 5abd251
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 72 deletions.
1 change: 1 addition & 0 deletions cmd/src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ object Main
val table = Resource.eval(RoutingTable[IO](selfId)).await
val node = Node(selfId, Some(port), QueryHandler(selfId, table)).await
Resource.eval(RoutingTableBootstrap(table, node.client)).await
PingRoutine(table, node.client).runForever.background.await
}.useForever
}
}
Expand Down
36 changes: 17 additions & 19 deletions dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,24 @@ object Node {
}
client0 <- Client(selfId, messageSocket.writeMessage, responses.take, generateTransactionId)
_ <-
Resource
.make(
messageSocket.readMessage
.flatMap {
case (a, m: Message.QueryMessage) =>
logger.debug(s"Received $m") >>
queryHandler(a, m.query).flatMap { response =>
val responseMessage = Message.ResponseMessage(m.transactionId, response)
logger.debug(s"Responding with $responseMessage") >>
messageSocket.writeMessage(a, responseMessage)
}
case (a, m: Message.ResponseMessage) => responses.offer((a, m.asRight))
case (a, m: Message.ErrorMessage) => responses.offer((a, m.asLeft))
messageSocket.readMessage
.flatMap {
case (a, m: Message.QueryMessage) =>
logger.debug(s"Received $m") >>
queryHandler(a, m.query).flatMap { response =>
val responseMessage = Message.ResponseMessage(m.transactionId, response)
logger.debug(s"Responding with $responseMessage") >>
messageSocket.writeMessage(a, responseMessage)
}
.recoverWith { case e: Throwable =>
logger.trace(s"Failed to read message: $e")
}
.foreverM
.start
)(_.cancel)
case (a, m: Message.ResponseMessage) => responses.offer((a, m.asRight))
case (a, m: Message.ErrorMessage) => responses.offer((a, m.asLeft))
}
.recoverWith { case e: Throwable =>
logger.trace(s"Failed to read message: $e")
}
.foreverM
.background

yield new Node {
def client: Client[IO] = client0
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.github.torrentdam.bittorrent.dht

import cats.effect.IO
import cats.syntax.all.*
import cats.effect.cps.{*, given}
import org.legogroup.woof.{Logger, given}

import scala.concurrent.duration.DurationInt

class PingRoutine(table: RoutingTable[IO], client: Client[IO])(using logger: Logger[IO]):

def run: IO[Unit] = async[IO]:
val nodes = table.allNodes.await
logger.info(s"Pinging ${nodes.size} nodes").await
val queries = nodes.map { node =>
client.ping(node.address).timeout(5.seconds).attempt.map(_.bimap(_ => node.id, _ => node.id))
}
val results = queries.parSequence.await
val (bad, good) = results.partitionMap(identity)
logger.info(s"Got ${good.size} good nodes and ${bad.size} bad nodes").await
table.updateGoodness(good.toSet, bad.toSet)

def runForever: IO[Unit] =
run
.handleErrorWith: e =>
logger.error(s"PingRoutine failed: $e")
.productR(IO.sleep(10.minutes))
.foreverM
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import cats.implicits.*
import com.comcast.ip4s.*
import com.github.torrentdam.bittorrent.InfoHash
import com.github.torrentdam.bittorrent.PeerInfo

import scala.collection.immutable.ListMap
import scodec.bits.ByteVector

import scala.annotation.tailrec

trait RoutingTable[F[_]] {

def insert(node: NodeInfo): F[Unit]
Expand All @@ -21,13 +24,20 @@ trait RoutingTable[F[_]] {
def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit]

def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]]

def allNodes: F[LazyList[RoutingTable.Node]]

def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit]
}

object RoutingTable {

enum TreeNode:
case Split(center: BigInt, lower: TreeNode, higher: TreeNode)
case Bucket(from: BigInt, until: BigInt, nodes: ListMap[NodeId, SocketAddress[IpAddress]])
case Bucket(from: BigInt, until: BigInt, nodes: ListMap[NodeId, Node])

case class Node(id: NodeId, address: SocketAddress[IpAddress], isGood: Boolean):
def toNodeInfo: NodeInfo = NodeInfo(id, address)

object TreeNode {

Expand All @@ -43,40 +53,39 @@ object RoutingTable {

import TreeNode.*

extension (bucket: TreeNode) {

extension (bucket: TreeNode)
def insert(node: NodeInfo, selfId: NodeId): TreeNode =
bucket match {
bucket match
case b @ Split(center, lower, higher) =>
if (node.id.int < center)
b.copy(lower = lower.insert(node, selfId))
else
b.copy(higher = higher.insert(node, selfId))
case Bucket(from, until, nodes) =>
case b @ Bucket(from, until, nodes) =>
if nodes.size == MaxNodes
then
val tree =
if selfId.int >= from && selfId.int < until
then
// split the bucket because it contains the self node
val center = (from + until) / 2
val splitBucket: TreeNode =
Split(
center,
lower = Bucket(from, center, ListMap.empty),
higher = Bucket(center, until, ListMap.empty)
)
nodes.view.map(NodeInfo.apply.tupled).foldLeft(splitBucket)(_.insert(_, selfId))
else
// drop one node from the bucket
Bucket(from, until, nodes.init)
tree.insert(node, selfId)
if selfId.int >= from && selfId.int < until
then
// split the bucket because it contains the self node
val center = (from + until) / 2
val splitNode =
Split(
center,
lower = Bucket(from, center, nodes.view.filterKeys(_.int < center).to(ListMap)),
higher = Bucket(center, until, nodes.view.filterKeys(_.int >= center).to(ListMap))
)
splitNode.insert(node, selfId)
else
// drop one node from the bucket
val badNode = nodes.values.find(!_.isGood)
badNode match
case Some(badNode) => Bucket(from, until, nodes.removed(badNode.id)).insert(node, selfId)
case None => b
else
Bucket(from, until, nodes.updated(node.id, node.address))
}
Bucket(from, until, nodes.updated(node.id, Node(node.id, node.address, isGood = true)))

def remove(nodeId: NodeId): TreeNode =
bucket match {
bucket match
case b @ Split(center, lower, higher) =>
if (nodeId.int < center)
(lower.remove(nodeId), higher) match {
Expand All @@ -94,29 +103,34 @@ object RoutingTable {
}
case b @ Bucket(_, _, nodes) =>
b.copy(nodes = nodes - nodeId)
}

@tailrec
def findBucket(nodeId: NodeId): Bucket =
bucket match {
bucket match
case Split(center, lower, higher) =>
if (nodeId.int < center)
lower.findBucket(nodeId)
else
higher.findBucket(nodeId)
case b: Bucket => b
}

def findNodes(nodeId: NodeId): LazyList[NodeInfo] =
bucket match {
def findNodes(nodeId: NodeId): LazyList[Node] =
bucket match
case Split(center, lower, higher) =>
if (nodeId.int < center)
lower.findNodes(nodeId) ++ higher.findNodes(nodeId)
else
higher.findNodes(nodeId) ++ lower.findNodes(nodeId)
case b: Bucket => b.nodes.to(LazyList).map(NodeInfo.apply.tupled)
}

}
case b: Bucket => b.nodes.values.to(LazyList)

def update(fn: Node => Node): TreeNode =
bucket match
case b @ Split(_, lower, higher) =>
b.copy(lower = lower.update(fn), higher = higher.update(fn))
case b @ Bucket(from, until, nodes) =>
b.copy(nodes = nodes.view.mapValues(fn).to(ListMap))

end extension

def apply[F[_]: Concurrent](selfId: NodeId): F[RoutingTable[F]] =
for {
Expand All @@ -128,10 +142,10 @@ object RoutingTable {
treeNodeRef.update(_.insert(node, selfId))

def findNodes(nodeId: NodeId): F[LazyList[NodeInfo]] =
treeNodeRef.get.map(_.findNodes(nodeId))
treeNodeRef.get.map(_.findNodes(nodeId).filter(_.isGood).map(_.toNodeInfo))

def findBucket(nodeId: NodeId): F[List[NodeInfo]] =
treeNodeRef.get.map(_.findBucket(nodeId).nodes.map(NodeInfo.apply.tupled).toList)
treeNodeRef.get.map(_.findBucket(nodeId).nodes.values.filter(_.isGood).map(_.toNodeInfo).toList)

def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit] =
peers.update { map =>
Expand All @@ -143,6 +157,17 @@ object RoutingTable {

def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]] =
peers.get.map(_.get(infoHash))

def allNodes: F[LazyList[Node]] =
treeNodeRef.get.map(_.findNodes(selfId))

def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] =
treeNodeRef.update(
_.update(node =>
if good.contains(node.id) then node.copy(isGood = true)
else if bad.contains(node.id) then node.copy(isGood = false)
else node
)
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,30 @@ object RoutingTableBootstrap {
dns: Dns[F],
logger: Logger[F]
): Stream[F, NodeInfo] =
def tryThis(hostname: SocketAddress[Host]): F[Option[NodeInfo]] =
logger.info(s"Trying to reach $hostname") >>
hostname
.resolve[F]
.flatMap: seedAddress =>
def tryThis(hostname: SocketAddress[Host]): Stream[F, NodeInfo] =
Stream.eval(logger.info(s"Trying to reach $hostname")) >>
Stream
.evals(
hostname.host.resolveAll[F]
.recoverWith: e =>
logger.info(s"Failed to resolve $hostname $e").as(List.empty)
)
.evalMap: ipAddress =>
val resolvedAddress = SocketAddress(ipAddress, hostname.port)
client
.ping(seedAddress)
.ping(resolvedAddress)
.timeout(5.seconds)
.map(pong => NodeInfo(pong.id, seedAddress))
.map(_.some)
.recoverWith:
e =>
val msg = e.getMessage
logger.info(s"Failed to reach $hostname $msg $e").as(none)
Stream.emits(bootstrapNodeAddress)
.map(pong => NodeInfo(pong.id, resolvedAddress))
.map(_.some)
.recoverWith: e =>
logger.info(s"Failed to reach $resolvedAddress $e").as(none)
.collect {
case Some(node) => node
}
Stream
.emits(bootstrapNodeAddress)
.covary[F]
.evalMap(tryThis)
.collect {
case Some(node) => node
}
.flatMap(tryThis)

val PublicBootstrapNodes: List[SocketAddress[Host]] = List(
SocketAddress(host"router.bittorrent.com", port"6881"),
Expand Down

0 comments on commit 5abd251

Please sign in to comment.