1010import logging
1111import multiprocessing
1212from collections .abc import Iterable
13+ from contextlib import ContextDecorator
1314from typing import Any
1415
1516import django_tenants .migration_executors
1617from django .conf import settings
17- from django .db import connection
18+ from django .db import connection , connections
1819from django .db .migrations .executor import MigrationExecutor
1920from django .db .migrations .recorder import MigrationRecorder
2021from django_tenants .migration_executors .base import run_migrations
2122from django_tenants .signals import schema_migrated
22- from django_tenants .utils import schema_context
23+ from django_tenants .utils import get_tenant_database_alias
2324
2425logger = logging .getLogger ("django_tenants_smart_executor" )
2526
2627
28+ class schema_context_without_public (ContextDecorator ): # noqa: N801
29+ """
30+ Like schema_context, but without public schema.
31+ """
32+
33+ def __init__ (self , * args , ** kwargs ):
34+ self .schema_name = args [0 ]
35+ self .database = kwargs .get ("database" , get_tenant_database_alias ())
36+ super ().__init__ ()
37+
38+ def __enter__ (self ):
39+ self .connection = connections [self .database ]
40+ self .previous_tenant = connection .tenant
41+ self .connection .set_schema (self .schema_name , include_public = False )
42+
43+ def __exit__ (self , * exc ):
44+ if self .previous_tenant is None :
45+ self .connection .set_schema_to_public ()
46+ else :
47+ self .connection .set_tenant (self .previous_tenant )
48+
49+
2750def needs_migrations (nodes : set [tuple [str , str ]], schema_name : str , options : dict ) -> bool :
2851 """
2952 Returns whether we need to run migrations for a given schema.
@@ -35,8 +58,12 @@ def needs_migrations(nodes: set[tuple[str, str]], schema_name: str, options: dic
3558
3659 migrated_already : set [tuple [str , str ]]
3760
38- with schema_context (schema_name ):
39- migrated_already = set (MigrationRecorder (connection = connection ).applied_migrations ().keys ())
61+ # need to exclude public schema so if there's no migration table it doesn't pick up the one in public
62+ with schema_context_without_public (schema_name ):
63+ migration_recorder = MigrationRecorder (connection = connection )
64+ if not migration_recorder .has_table ():
65+ return True
66+ migrated_already = set (migration_recorder .applied_migrations ().keys ())
4067
4168 for node in nodes :
4269 if node not in migrated_already :
0 commit comments