Skip to content

Commit

Permalink
Read Token.user_id instead of Token.userid
Browse files Browse the repository at this point in the history
Remove all **reads** of the `Token.userid` column, replacing them with
reads of the new `Token.user_id` column instead.

This prepares the way for actually removing the `Token.userid` column.

For now the code does still need to **write** `Token.userid` because
it's not nullable.
  • Loading branch information
seanh committed Feb 14, 2024
1 parent 338c4a2 commit 5970e3d
Show file tree
Hide file tree
Showing 29 changed files with 68 additions and 73 deletions.
2 changes: 2 additions & 0 deletions h/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sqlalchemy
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped

from h.db import Base, mixins

Expand Down Expand Up @@ -55,6 +56,7 @@ class Token(Base, mixins.Timestamps):
index=True,
nullable=True,
)
user: Mapped["User"] = sqlalchemy.orm.relationship(back_populates="tokens")

#: The authclient which created the token.
#: A NULL value means it is a developer token.
Expand Down
2 changes: 2 additions & 0 deletions h/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ def activate(self):
#: upgrading their passwords and setting this column to None.
salt = sa.Column(sa.UnicodeText(), nullable=True)

tokens = sa.orm.relationship("Token", back_populates="user")

@sa.orm.validates("email")
def validate_email(self, _key, email):
if email is None:
Expand Down
2 changes: 1 addition & 1 deletion h/services/auth_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LongLivedToken:

def __init__(self, token):
self.expires = token.expires
self.userid = token.userid
self.userid = token.user.userid

# Associates the userid with a given transaction/web request.
newrelic.agent.add_custom_attribute("userid", self.userid)
Expand Down
6 changes: 4 additions & 2 deletions h/services/developer_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def create(self, userid):
"""
user = self.user_svc.fetch(userid)
token = models.Token(
userid=user.userid, user_id=user.id, value=self._generate_token()
user=user, userid=user.userid, value=self._generate_token()
)
self.session.add(token)
return token
Expand All @@ -68,9 +68,11 @@ def _fetch(self, userid):
if userid is None:
return None

user = self.user_svc.fetch(userid)

return (
self.session.query(models.Token)
.filter_by(userid=userid, authclient=None)
.filter_by(user=user, authclient=None)
.order_by(models.Token.created.desc())
.one_or_none()
)
Expand Down
7 changes: 3 additions & 4 deletions h/services/oauth/_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ class OAuthValidator( # pylint: disable=too-many-public-methods, abstract-metho
This implements the ``oauthlib.oauth2.RequestValidator`` interface.
"""

def __init__(self, session, user_svc):
def __init__(self, session):
self.session = session
self.user_svc = user_svc

self._cached_find_authz_code = lru_cache_in_transaction(self.session)(
self._find_authz_code
Expand Down Expand Up @@ -223,8 +222,8 @@ def save_bearer_token(self, token, request, *args, **kwargs):
] # We don't want to render this in the response.

oauth_token = models.Token(
user=request.user,
userid=request.user.userid,
user_id=request.user.id,
value=token["access_token"],
refresh_token=token["refresh_token"],
expires=expires,
Expand Down Expand Up @@ -321,7 +320,7 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs
):
return False

request.user = self.user_svc.fetch(token.userid)
request.user = token.user
return True

def validate_response_type(
Expand Down
2 changes: 1 addition & 1 deletion h/services/oauth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def factory(_context, request):
user_svc = request.find_service(name="user")

return OAuthProviderService(
oauth_validator=OAuthValidator(session=request.db, user_svc=user_svc),
oauth_validator=OAuthValidator(session=request.db),
user_svc=user_svc,
domain=request.domain,
)
2 changes: 1 addition & 1 deletion h/services/user_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def delete_user(self, user: User):
in the group that have been made by other users, the user is unassigned
as creator but the group persists.
"""
self._db.execute(sa.delete(Token).where(Token.userid == user.userid))
self._db.execute(sa.delete(Token).where(Token.user == user))

# Delete all annotations
self._annotation_delete_service.delete_annotations(
Expand Down
9 changes: 0 additions & 9 deletions h/services/user_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ def rename(self, user, new_username):
# https://michael.merickel.org/projects/pyramid_auth_demo/auth_vs_auth.html
self._purge_auth_tickets(user)

# For OAuth tokens, only the token's value is stored by clients, so we
# can just update the userid.
self._update_tokens(old_userid, new_userid)

self._change_annotations(old_userid, new_userid)
tasks.job_queue.add_annotations_from_user.delay(
"sync_annotation",
Expand All @@ -68,11 +64,6 @@ def _purge_auth_tickets(self, user):
models.AuthTicket.user_id == user.id
).delete()

def _update_tokens(self, old_userid, new_userid):
self.session.query(models.Token).filter(
models.Token.userid == old_userid
).update({"userid": new_userid}, synchronize_session="fetch")

def _change_annotations(self, old_userid, new_userid):
annotations = self._fetch_annotations(old_userid)

Expand Down
2 changes: 1 addition & 1 deletion h/views/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def api_token_error(context, request):

def _present_debug_token(token):
data = {
"userid": token.userid,
"userid": token.user.userid,
"expires_at": utc_iso8601(token.expires),
"issued_at": utc_iso8601(token.created),
"expired": token.expired,
Expand Down
17 changes: 6 additions & 11 deletions tests/common/factories/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
from h.services.oauth import ACCESS_TOKEN_PREFIX, REFRESH_TOKEN_PREFIX

from .auth_client import AuthClient
from .base import FAKER, ModelFactory
from .base import ModelFactory
from .user import User


class DeveloperToken(ModelFactory):
class Meta:
model = models.Token
sqlalchemy_session_persistence = "flush"

userid = factory.LazyAttribute(
lambda _: (
"acct:" + FAKER.user_name() + "@example.com" # pylint:disable=no-member
)
)
user = factory.SubFactory(User)
userid = factory.LazyAttribute(lambda developer_token: developer_token.user.userid)
value = factory.LazyAttribute(
lambda _: (DEVELOPER_TOKEN_PREFIX + security.token_urlsafe())
)
Expand All @@ -30,11 +28,8 @@ class Meta:
model = models.Token
sqlalchemy_session_persistence = "flush"

userid = factory.LazyAttribute(
lambda _: (
"acct:" + FAKER.user_name() + "@example.com" # pylint:disable=no-member
)
)
user = factory.SubFactory(User)
userid = factory.LazyAttribute(lambda developer_token: developer_token.user.userid)
value = factory.LazyAttribute(
lambda _: (ACCESS_TOKEN_PREFIX + security.token_urlsafe())
)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/annotations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def user_annotation(db_session, user, factories):

@pytest.fixture
def user_with_token(user, db_session, factories):
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _make_headers(authority):

@pytest.fixture
def token_auth_header(db_session, factories, user):
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def append_header(headers=None):
@pytest.fixture
def user_with_token(db_session, factories):
user = factories.User()
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.commit()
return (user, token)

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/flags_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def user(db_session, factories):

@pytest.fixture
def user_with_token(user, db_session, factories):
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
2 changes: 1 addition & 1 deletion tests/functional/api/groups/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def auth_client_header(auth_client):
@pytest.fixture
def user_with_token(db_session, factories):
user = factories.User()
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/api/groups/members_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def group_member(group, db_session, factories):

@pytest.fixture
def group_member_with_token(group_member, db_session, factories):
token = factories.DeveloperToken(userid=group_member.userid)
token = factories.DeveloperToken(user=group_member)
db_session.add(token)
db_session.commit()
return (group_member, token)
Expand All @@ -267,7 +267,7 @@ def group_member_with_token(group_member, db_session, factories):
@pytest.fixture
def user_with_token(db_session, factories):
user = factories.User()
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/groups/read_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def auth_client_header(auth_client):
@pytest.fixture
def user_with_token(db_session, factories):
user = factories.User()
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/groups/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def first_party_group(db_session, factories, first_party_user):

@pytest.fixture
def user_with_token(db_session, factories, first_party_user):
token = factories.DeveloperToken(userid=first_party_user.userid)
token = factories.DeveloperToken(user=first_party_user)
db_session.add(token)
db_session.commit()
return (first_party_user, token)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/groups/upsert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def first_party_group(db_session, factories, first_party_user):

@pytest.fixture
def user_with_token(db_session, factories, first_party_user):
token = factories.DeveloperToken(userid=first_party_user.userid)
token = factories.DeveloperToken(user=first_party_user)
db_session.add(token)
db_session.commit()
return (first_party_user, token)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/api/moderation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def private_group_annotation(group, db_session, factories, other_user):

@pytest.fixture
def user_with_token(user, db_session, factories):
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
4 changes: 2 additions & 2 deletions tests/functional/api/profile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def user(groups, db_session, factories):

@pytest.fixture
def user_with_token(user, db_session, factories):
token = factories.DeveloperToken(userid=user.userid)
token = factories.DeveloperToken(user=user)
db_session.add(token)
db_session.commit()
return (user, token)
Expand Down Expand Up @@ -184,6 +184,6 @@ def open_group(auth_client, db_session, factories):

@pytest.fixture
def third_party_user_with_token(third_party_user, db_session, factories):
token = factories.DeveloperToken(userid=third_party_user.userid)
token = factories.DeveloperToken(user=third_party_user)
db_session.commit()
return (third_party_user, token)
2 changes: 1 addition & 1 deletion tests/functional/moderation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ def moderator(db_session, factories):

@pytest.fixture
def moderator_with_token(moderator, db_session, factories):
token = factories.DeveloperToken(userid=moderator.userid)
token = factories.DeveloperToken(user=moderator)
db_session.commit()
return (moderator, token)
6 changes: 2 additions & 4 deletions tests/unit/h/services/auth_token_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_validate_returns_database_token(self, svc, factories):
result = svc.validate(token_model.value)

assert result.expires == token_model.expires
assert result.userid == token_model.userid
assert result.userid == token_model.user.userid

def test_validate_caches_database_token(self, svc, factories, db_session):
token_model = factories.DeveloperToken(expires=self.time(1))
Expand Down Expand Up @@ -103,9 +103,7 @@ class TestLongLivedToken:
),
)
def test_it(self, expires, is_valid, factories):
token = LongLivedToken(
factories.OAuth2Token(userid="acct:[email protected]", expires=expires)
)
token = LongLivedToken(factories.OAuth2Token(expires=expires))

assert token.is_valid() == is_valid

Expand Down
24 changes: 17 additions & 7 deletions tests/unit/h/services/developer_token_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,24 @@


class TestDeveloperTokenService:
def test_fetch_returns_developer_token_for_userid(self, svc, developer_token, user):
def test_fetch_returns_developer_token_for_userid(
self, svc, developer_token, user, user_service
):
user_service.fetch.return_value = user

assert svc.fetch(user.userid) == developer_token
user_service.fetch.assert_called_once_with(user.userid)

def test_fetch_returns_none_for_missing_userid(self, svc):
assert svc.fetch(None) is None

def test_fetch_returns_none_for_missing_developer_token(self, svc, user):
def test_fetch_returns_none_for_missing_developer_token(
self, svc, user, user_service
):
user_service.fetch.return_value = user

assert svc.fetch(user.userid) is None
user_service.fetch.assert_called_once_with(user.userid)

def test_create_creates_new_developer_token_for_userid(
self, svc, db_session, user, user_service
Expand All @@ -31,7 +41,7 @@ def test_create_creates_new_developer_token_for_userid(
user_service.fetch.assert_called_once_with(user.userid)
assert db_session.query(models.Token).all() == [
Any.instance_of(models.Token).with_attrs(
{"userid": user.userid, "user_id": user.id}
{"userid": user.userid, "user": user}
)
]

Expand All @@ -44,19 +54,19 @@ def test_create_returns_new_developer_token_for_userid(

token = svc.create(user.userid)

assert token.userid == user.userid
assert token.user == user
assert token.value == "6879-secure-token"
assert token.expires is None
assert token.authclient is None
assert token.refresh_token is None

def test_regenerate_sets_a_new_token_value(self, svc, developer_token):
old_userid = developer_token.userid
old_user = developer_token.user
old_value = developer_token.value

svc.regenerate(developer_token)

assert old_userid == developer_token.userid
assert old_user == developer_token.user
assert old_value != developer_token.value

@pytest.fixture
Expand All @@ -65,7 +75,7 @@ def svc(self, pyramid_request):

@pytest.fixture
def developer_token(self, factories, user):
return factories.DeveloperToken(userid=user.userid, user_id=user.id)
return factories.DeveloperToken(user=user)

@pytest.fixture
def user(self, factories):
Expand Down
Loading

0 comments on commit 5970e3d

Please sign in to comment.