Skip to content

Commit

Permalink
Remove embedding code from public repo
Browse files Browse the repository at this point in the history
This code is pretty specific to our workflow,
and adds a bunch of overhead to AnalyzedUrls.
We're going to bring this small part of the code into a private repo,
similar to our ML modeling before it.
  • Loading branch information
ericholscher committed Mar 14, 2024
1 parent e92f4ce commit 3202a68
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 140 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ celerybeat-schedule
celerybeat-schedule.db
celerybeat.pid

# VSCode
.vscode


##########################################################################
# Ad Server specific ignores
Expand Down
1 change: 0 additions & 1 deletion adserver/analyzer/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Backends for analyzing URLs for keywords and topics."""
from .eatopics import EthicalAdsTopicsBackend # noqa
from .naive import NaiveKeywordAnalyzerBackend # noqa
from .st import SentenceTransformerAnalyzerBackend # noqa
from .textacynlp import TextacyAnalyzerBackend # noqa
52 changes: 0 additions & 52 deletions adserver/analyzer/backends/st.py

This file was deleted.

20 changes: 20 additions & 0 deletions adserver/analyzer/migrations/0005_remove_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Generated by Django 4.2.11 on 2024-03-14 18:53
from django.db import migrations


class Migration(migrations.Migration):
    """Remove the ``embedding`` field from analyzed URLs and their history model."""

    # Must run after the migration that originally introduced the field.
    dependencies = [("adserver_analyzer", "0004_add_embeddings")]

    operations = [
        # The field is dropped from the main model ...
        migrations.RemoveField(model_name="analyzedurl", name="embedding"),
        # ... and from its "historical" counterpart.
        migrations.RemoveField(model_name="historicalanalyzedurl", name="embedding"),
    ]
2 changes: 0 additions & 2 deletions adserver/analyzer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class AnalyzedUrl(TimeStampedModel):
),
)

embedding = VectorField(dimensions=384, default=None, null=True, blank=True)

history = HistoricalRecords()

def __str__(self):
Expand Down
17 changes: 15 additions & 2 deletions adserver/analyzer/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from .utils import normalize_url
from config.celery_app import app

# The private extension is registered in INSTALLED_APPS under the
# "ethicalads_ext.embedding" label (see config/settings/base.py), and the API
# urls test for that same label. Testing for the bare package name never
# matches, so the model would never be imported.
# NOTE(review): assumes the model lives in the app's own models module, next
# to ethicalads_ext.embedding.views -- confirm against the private repo.
if "ethicalads_ext.embedding" in settings.INSTALLED_APPS:
    from ethicalads_ext.embedding.models import Embedding  # noqa


log = logging.getLogger(__name__) # noqa

Expand Down Expand Up @@ -91,18 +94,28 @@ def analyze_url(url, publisher_slug, force=False):
publisher=publisher,
defaults={
"keywords": keywords,
"embedding": embedding,
"last_analyzed_date": timezone.now(),
},
)

if not created:
url_obj.keywords = keywords
url_obj.embedding = embedding
url_obj.last_analyzed_date = timezone.now()
url_obj.visits_since_last_analyzed = 0
url_obj.save()

# Persist the embedding only when the private extension app is enabled.
# The app is registered as "ethicalads_ext.embedding" (see
# config/settings/base.py); the bare "ethicalads_ext" label is never in
# INSTALLED_APPS, so testing for it silently skipped this step.
if "ethicalads_ext.embedding" in settings.INSTALLED_APPS:
    # Imported locally so the optional package is only touched when enabled.
    # NOTE(review): assumes the model lives in the app's own models module,
    # next to ethicalads_ext.embedding.views -- confirm against the private repo.
    from ethicalads_ext.embedding.models import Embedding  # noqa

    embedding_obj, embedding_created = Embedding.objects.get_or_create(
        url=url_obj,
        model="v1",
        defaults={
            "embedding": embedding,
        },
    )
    if not embedding_created:
        # An embedding already exists for this URL/model pair: refresh it.
        embedding_obj.embedding = embedding
        embedding_obj.save()


@app.task
def daily_visited_urls_aggregation(day=None):
Expand Down
82 changes: 1 addition & 81 deletions adserver/analyzer/views.py
Original file line number Diff line number Diff line change
@@ -1,81 +1 @@
from urllib.parse import urlparse

from django.conf import settings
from pgvector.django import CosineDistance
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.renderers import StaticHTMLRenderer
from rest_framework.response import Response
from rest_framework.views import APIView

from adserver.analyzer.backends.st import SentenceTransformerAnalyzerBackend
from adserver.analyzer.models import AnalyzedUrl


if "adserver.analyzer" in settings.INSTALLED_APPS:

    class EmbeddingViewSet(APIView):
        """
        Return URLs whose stored embedding is closest to that of an incoming URL.

        Example: http://localhost:5000/api/v1/similar/?url=https://www.gitbook.com/

        .. http:get:: /api/v1/similar/

            Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL

            :query string url: **Required**. The URL to query for similar URLs and scores

            :>json int count: The number of similar URLs returned
            :>json string text: The first 500 characters of the content extracted from the URL
            :>json array results: An array of ``[url, distance]`` pairs, closest first
        """

        permission_classes = [AllowAny]

        def get(self, request):
            """
            Fetch the ``url`` query parameter, embed it and return its nearest analyzed URLs.

            Responds with 400 when ``url`` is missing or its content cannot be
            fetched, and with 404 when three or fewer distinct domains match.
            """
            url = request.query_params.get("url")
            if not url:
                return Response(
                    {"error": "url is required"}, status=status.HTTP_400_BAD_REQUEST
                )

            backend_instance = SentenceTransformerAnalyzerBackend(url)
            response = backend_instance.fetch()
            if not response:
                return Response(
                    {"error": "Not able to fetch content from URL"},
                    status=status.HTTP_400_BAD_REQUEST,
                )

            processed_text = backend_instance.get_content(response)
            analyzed_embedding = backend_instance.embedding(response)

            # The 25 nearest analyzed URLs by cosine distance, restricted to
            # publishers that allow paid campaigns and to rows with an embedding.
            candidates = (
                AnalyzedUrl.objects.filter(publisher__allow_paid_campaigns=True)
                .exclude(embedding=None)
                .annotate(distance=CosineDistance("embedding", analyzed_embedding))
                .order_by("distance")[:25]
            )

            # Keep only the closest match per domain. Candidates arrive sorted by
            # distance, so the first hit for a domain is its best one.
            seen_domains = set()
            urls = []
            for candidate in candidates:
                domain = urlparse(candidate.url).netloc
                if domain in seen_domains:
                    continue
                seen_domains.add(domain)
                urls.append(candidate)

            # More than three distinct domains are required for a useful answer.
            if len(urls) <= 3:
                return Response(
                    {"error": "No similar URLs found"}, status=status.HTTP_404_NOT_FOUND
                )

            return Response(
                {
                    "count": len(urls),
                    "text": processed_text[:500],
                    "results": [[match.url, match.distance] for match in urls],
                }
            )
# Left blank
4 changes: 2 additions & 2 deletions adserver/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
router.register(r"advertisers", AdvertiserViewSet, basename="advertisers")
router.register(r"publishers", PublisherViewSet, basename="publishers")

if "adserver.analyzer" in settings.INSTALLED_APPS:
from adserver.analyzer.views import EmbeddingViewSet
if "ethicalads_ext.embedding" in settings.INSTALLED_APPS:
from ethicalads_ext.embedding.views import EmbeddingViewSet

urlpatterns += [path(r"similar/", EmbeddingViewSet.as_view(), name="similar")]

Expand Down
12 changes: 12 additions & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
except ImproperlyConfigured:
log.info("Unable to read env file. Assuming environment is already set.")

# Private extensions to the ad server live in the optional ``ethicalads_ext``
# package. As a bit of a hack, probe for it with a plain import and record the
# outcome in ``ext``.
try:
    import ethicalads_ext  # noqa
except ImportError:
    ext = False
else:
    ext = True

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.abspath(
Expand Down Expand Up @@ -69,6 +77,10 @@
"corsheaders",
]


if ext:
INSTALLED_APPS.append("ethicalads_ext.embedding")

MIDDLEWARE = [
"django.middleware.security.SecurityMiddleware",
"enforce_host.EnforceHostMiddleware",
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ services:
# Make it so we can edit the start script dynamically,
# for example to install dependencies
- ./docker-compose/django/start:/start
# Load the ethicalads_ext code from the host, so we don't have to rebuild
- ${PWD}/${EA_EXT_PATH:-../ethicalads-ext/ethicalads_ext}:/app/ethicalads_ext
env_file:
- ./.envs/local/django
- ./.envs/local/postgres
Expand Down

0 comments on commit 3202a68

Please sign in to comment.