Skip to content

Commit

Permalink
atproto: support bsky.app URLs as targets
Browse files Browse the repository at this point in the history
  • Loading branch information
snarfed committed Sep 13, 2023
1 parent bd09af9 commit 2bb5db8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
25 changes: 20 additions & 5 deletions atproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ def ap_address(self):
return f'@{self.readable_id}@{self.ABBREV}{common.SUPERDOMAIN}'

@classmethod
# TODO: add bsky.app URLs, translating to/from at:// URIs. (to arroba?)
def owns_id(cls, id):
return (id.startswith('at://')
or id.startswith('did:plc:')
or id.startswith('did:web:'))
or id.startswith('did:web:')
or id.startswith('https://bsky.app/'))

@classmethod
def target_for(cls, obj, shared=False):
Expand All @@ -105,11 +105,26 @@ def target_for(cls, obj, shared=False):
Returns:
str
"""
if obj.key.id().startswith('did:'):
id = obj.key.id()
if id.startswith('did:'):
return None

if obj.key.id().startswith('at://'):
repo, collection, rkey = parse_at_uri(obj.key.id())
logger.info(f'Finding ATProto PDS for {id}')
if id.startswith('https://bsky.app/'):
return cls.target_for(Object(id=bluesky.web_url_to_at_uri(id)))

if id.startswith('at://'):
repo, collection, rkey = parse_at_uri(id)

if not repo.startswith('did:'):
# repo is a handle; resolve it
repo_did = did.resolve_handle(repo, get_fn=util.requests_get)
if repo_did:
return cls.target_for(Object(id=id.replace(
f'at://{repo}', f'at://{repo_did}')))
else:
return None

did_obj = ATProto.load(repo)
if did_obj:
return cls._pds_for(did_obj)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ colorama==0.4.6
cryptography==41.0.3
dag-cbor==0.3.2
Deprecated==1.2.14
dnspython==2.4.2
domain2idna==1.12.0
ecdsa==0.18.0
extras==1.0.0
Expand Down
37 changes: 33 additions & 4 deletions tests/test_atproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from arroba.did import encode_did_key
from arroba.repo import Repo
import arroba.util
import dns.resolver
from dns.resolver import NXDOMAIN
from flask import g
from granary.tests.test_bluesky import (
ACTOR_AS,
Expand Down Expand Up @@ -80,12 +82,14 @@ def test_owns_id(self):
self.assertTrue(ATProto.owns_id('at://did:plc:foo/bar/123'))
self.assertTrue(ATProto.owns_id('did:plc:foo'))
self.assertTrue(ATProto.owns_id('did:web:bar.com'))
self.assertTrue(ATProto.owns_id(
'https://bsky.app/profile/snarfed.org/post/3k62u4ht77f2z'))

def test_target_for_did_doc(self):
self.assertIsNone(ATProto.target_for(Object(id='did:plc:foo')))

def test_target_for_stored_did(self):
did_obj = self.store_object(id='did:plc:foo', raw=DID_DOC)
self.store_object(id='did:plc:foo', raw=DID_DOC)
got = ATProto.target_for(Object(id='at://did:plc:foo/co.ll/123'))
self.assertEqual('https://some.pds', got)

Expand All @@ -95,22 +99,47 @@ def test_target_for_fetch_did(self, mock_get):
self.assertEqual('https://some.pds', got)

def test_target_for_user_with_stored_did(self):
did_obj = self.store_object(id='did:plc:foo', raw=DID_DOC)
user = self.make_user('fake:user', cls=Fake, atproto_did='did:plc:foo')
self.store_object(id='did:plc:foo', raw=DID_DOC)
self.make_user('fake:user', cls=Fake, atproto_did='did:plc:foo')
got = ATProto.target_for(Object(id='fake:post', our_as1={
**POST_AS,
'actor': 'fake:user',
}))
self.assertEqual('https://some.pds', got)

def test_target_for_user_no_stored_did(self):
user = self.make_user('fake:user', cls=Fake)
self.make_user('fake:user', cls=Fake)
got = ATProto.target_for(Object(id='fake:post', our_as1={
**POST_AS,
'actor': 'fake:user',
}))
self.assertEqual('http://localhost/', got)

def test_target_for_bsky_app_url_did_stored(self):
self.store_object(id='did:plc:foo', raw=DID_DOC)
self.make_user('fake:user', cls=Fake, atproto_did='did:plc:foo')

got = ATProto.target_for(Object(
id='https://bsky.app/profile/did:plc:foo/post/123'))
self.assertEqual('https://some.pds', got)

@patch('dns.resolver.resolve', side_effect=dns.resolver.NXDOMAIN())
@patch('requests.get', side_effect=[
# resolving handle, HTTPS method
requests_response('did:plc:foo', content_type='text/plain'),
# fetching DID doc
requests_response(DID_DOC),
])
def test_target_for_bsky_app_url_resolve_handle(self, mock_get, _):
got = ATProto.target_for(Object(
id='https://bsky.app/profile/baz.com/post/123'))
self.assertEqual('https://some.pds', got)

mock_get.assert_has_calls((
self.req('https://baz.com/.well-known/atproto-did'),
self.req('https://plc.local/did:plc:foo'),
))

@patch('requests.get', return_value=requests_response({'foo': 'bar'}))
def test_fetch_did_plc(self, mock_get):
obj = Object(id='did:plc:123')
Expand Down

0 comments on commit 2bb5db8

Please sign in to comment.