From 0744abe87bacd1ff79672106b9bcaf93e6e4b816 Mon Sep 17 00:00:00 2001
From: Ross Masters <ross@rossmasters.com>
Date: Sat, 7 Sep 2024 01:56:14 +0100
Subject: [PATCH] fix: Chart cache-warmup task fails on Superset 4.0 (#28706)

---
 superset/tasks/cache.py                     |  6 +-
 superset/tasks/utils.py                     | 48 +++++++++++++++-
 tests/integration_tests/tasks/test_cache.py | 15 +++--
 tests/integration_tests/tasks/test_utils.py | 64 +++++++++++++++++++++
 4 files changed, 126 insertions(+), 7 deletions(-)
 create mode 100644 tests/integration_tests/tasks/test_utils.py

diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index d1051c8fcb89f..1509b0d0624aa 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -29,6 +29,7 @@
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.tags.models import Tag, TaggedObject
+from superset.tasks.utils import fetch_csrf_token
 from superset.utils import json
 from superset.utils.date_parser import parse_human_datetime
 from superset.utils.machine_auth import MachineAuthProvider
@@ -219,7 +220,10 @@ def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
     """
     result = {}
     try:
-        url = get_url_path("Superset.warm_up_cache")
+        # Fetch CSRF token for API request
+        headers.update(fetch_csrf_token(headers))
+
+        url = get_url_path("ChartRestApi.warm_up_cache")
         logger.info("Fetching %s with payload %s", url, data)
         req = request.Request(
             url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py
index 5012330bbd43e..6fc799c4abc2c 100644
--- a/superset/tasks/utils.py
+++ b/superset/tasks/utils.py
@@ -17,12 +17,18 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+import logging
+from http.client import HTTPResponse
+from typing import Optional, TYPE_CHECKING
+from urllib import request
 
+from celery.utils.log import get_task_logger
 from flask import current_app, g
 
 from superset.tasks.exceptions import ExecutorNotFoundError
 from superset.tasks.types import ExecutorType
+from superset.utils import json
+from superset.utils.urls import get_url_path
 
 if TYPE_CHECKING:
     from superset.models.dashboard import Dashboard
@@ -30,6 +36,10 @@
     from superset.reports.models import ReportSchedule
 
 
+logger = get_task_logger(__name__)
+logger.setLevel(logging.INFO)
+
+
 # pylint: disable=too-many-branches
 def get_executor(
     executor_types: list[ExecutorType],
@@ -92,3 +102,39 @@ def get_current_user() -> str | None:
         return user.username
 
     return None
+
+
+def fetch_csrf_token(
+    headers: dict[str, str], session_cookie_name: str = "session"
+) -> dict[str, str]:
+    """
+    Fetches a CSRF token for API requests
+
+    :param headers: A map of headers to use in the request, including the session cookie
+    :param session_cookie_name: Name of the session cookie to read from ``Set-Cookie``
+    :returns: Headers to merge into later requests: ``X-CSRF-Token`` and, if the
+        response set the session cookie, a ``Cookie`` header; ``{}`` if status != 200
+    """
+    url = get_url_path("SecurityRestApi.csrf_token")
+    logger.info("Fetching %s", url)
+    req = request.Request(url, headers=headers, method="GET")
+    response: HTTPResponse
+    with request.urlopen(req, timeout=600) as response:
+        body = response.read().decode("utf-8")
+        session_cookie: Optional[str] = None
+        for cookie in response.headers.get_all("set-cookie") or []:
+            # Only the leading ``name=value`` pair matters; attributes follow the ";"
+            name, sep, value = cookie.split(";", 1)[0].partition("=")
+            if sep and name.strip() == session_cookie_name:
+                session_cookie = value
+                break
+
+        if response.status == 200:
+            res = {"X-CSRF-Token": json.loads(body)["result"]}
+            if session_cookie is not None:
+                # A Cookie header carries ``name=value`` pairs, not a bare value
+                res["Cookie"] = f"{session_cookie_name}={session_cookie}"
+            return res
+
+    logger.error("Error fetching CSRF token, status code: %s", response.status)
+    return {}
diff --git a/tests/integration_tests/tasks/test_cache.py b/tests/integration_tests/tasks/test_cache.py
index 943b444f76936..6e8d3ffe03b4d 100644
--- a/tests/integration_tests/tasks/test_cache.py
+++ b/tests/integration_tests/tasks/test_cache.py
@@ -29,9 +29,10 @@
     ],
     ids=["Without trailing slash", "With trailing slash"],
 )
+@mock.patch("superset.tasks.cache.fetch_csrf_token")
 @mock.patch("superset.tasks.cache.request.Request")
 @mock.patch("superset.tasks.cache.request.urlopen")
-def test_fetch_url(mock_urlopen, mock_request_cls, base_url):
+def test_fetch_url(mock_urlopen, mock_request_cls, mock_fetch_csrf_token, base_url):
     from superset.tasks.cache import fetch_url
 
     mock_request = mock.MagicMock()
@@ -40,18 +41,22 @@ def test_fetch_url(mock_urlopen, mock_request_cls, base_url):
     mock_urlopen.return_value = mock.MagicMock()
     mock_urlopen.return_value.code = 200
 
+    initial_headers = {"Cookie": "cookie", "key": "value"}
+    csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}
+    mock_fetch_csrf_token.return_value = csrf_headers
+
     app.config["WEBDRIVER_BASEURL"] = base_url
-    headers = {"key": "value"}
     data = "data"
     data_encoded = b"data"
 
-    result = fetch_url(data, headers)
+    result = fetch_url(data, initial_headers)
 
     assert data == result["success"]
+    mock_fetch_csrf_token.assert_called_once_with(initial_headers)
     mock_request_cls.assert_called_once_with(
-        "http://base-url/superset/warm_up_cache/",
+        "http://base-url/api/v1/chart/warm_up_cache",
         data=data_encoded,
-        headers=headers,
+        headers=csrf_headers,
         method="PUT",
     )
     # assert the same Request object is used
diff --git a/tests/integration_tests/tasks/test_utils.py b/tests/integration_tests/tasks/test_utils.py
new file mode 100644
index 0000000000000..b1213b78c85a0
--- /dev/null
+++ b/tests/integration_tests/tasks/test_utils.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import mock
+
+import pytest
+
+from tests.integration_tests.test_app import app
+
+
+@pytest.mark.parametrize(
+    "base_url",
+    [
+        "http://base-url",
+        "http://base-url/",
+    ],
+    ids=["Without trailing slash", "With trailing slash"],
+)
+@mock.patch("superset.tasks.utils.request.Request")
+@mock.patch("superset.tasks.utils.request.urlopen")
+def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context):
+    from superset.tasks.utils import fetch_csrf_token
+
+    mock_request = mock.MagicMock()
+    mock_request_cls.return_value = mock_request
+
+    mock_response = mock.MagicMock()
+    mock_urlopen.return_value.__enter__.return_value = mock_response
+
+    mock_response.status = 200
+    mock_response.read.return_value = b'{"result": "csrf_token"}'
+    mock_response.headers.get_all.return_value = [
+        "session=new_session_cookie; HttpOnly; Path=/",
+        "async-token=websocket_cookie",
+    ]
+
+    app.config["WEBDRIVER_BASEURL"] = base_url
+    headers = {"Cookie": "session=original_session_cookie"}
+
+    result_headers = fetch_csrf_token(headers)
+
+    mock_request_cls.assert_called_once_with(
+        "http://base-url/api/v1/security/csrf_token/",
+        headers=headers,
+        method="GET",
+    )
+
+    assert result_headers["X-CSRF-Token"] == "csrf_token"
+    assert result_headers["Cookie"] == "session=new_session_cookie"
+    # assert the same Request object is used
+    mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)