diff --git a/listenbrainz_spark/hdfs/upload.py b/listenbrainz_spark/hdfs/upload.py
index 2d9171a54e..1fd8416526 100644
--- a/listenbrainz_spark/hdfs/upload.py
+++ b/listenbrainz_spark/hdfs/upload.py
@@ -6,7 +6,7 @@
 import logging
 from typing import List
 
-from listenbrainz_spark import schema, path, utils
+from listenbrainz_spark import schema, path, utils, hdfs_connection
 from listenbrainz_spark.hdfs.utils import create_dir
 from listenbrainz_spark.hdfs.utils import delete_dir
 from listenbrainz_spark.hdfs.utils import path_exists
@@ -180,3 +180,6 @@ def process_full_listens_dump(self):
             .partitionBy("year", "month") \
             .mode("overwrite") \
             .parquet(path.LISTENBRAINZ_INTERMEDIATE_STATS_DIRECTORY)
+
+        if path_exists(path.LISTENBRAINZ_BASE_STATS_DIRECTORY):
+            hdfs_connection.client.delete(path.LISTENBRAINZ_BASE_STATS_DIRECTORY, recursive=True, skip_trash=True)
diff --git a/listenbrainz_spark/path.py b/listenbrainz_spark/path.py
index 03b65258c3..84282b805a 100644
--- a/listenbrainz_spark/path.py
+++ b/listenbrainz_spark/path.py
@@ -5,6 +5,9 @@
 LISTENBRAINZ_INTERMEDIATE_STATS_DIRECTORY = os.path.join('/', 'data', 'stats-new')
 
+LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
+LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
+
 # MLHD+ dump files
 MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")
 MLHD_PLUS_DATA_DIRECTORY = os.path.join("/", "mlhd")  # processed MLHD+ dump data
diff --git a/listenbrainz_spark/schema.py b/listenbrainz_spark/schema.py
index e17151e06b..ada774875d 100644
--- a/listenbrainz_spark/schema.py
+++ b/listenbrainz_spark/schema.py
@@ -3,6 +3,15 @@
 from pyspark.sql.types import StructField, StructType, ArrayType, StringType, TimestampType, FloatType, \
     IntegerType, LongType
 
+# Keeps track of the from_date and to_date used to create the partial aggregate from the full dump listens.
+# Assuming dumps are imported twice a month, the aggregates for weekly stats need to be refreshed (generated from
+# a different range of listens in the full dump) sooner. The partial_aggregate_usable method reads this from/to date
+# from the bookkeeping path and compares it with the current day's request to determine whether the aggregate needs to be recreated.
+BOOKKEEPING_SCHEMA = StructType([ + StructField('from_date', TimestampType(), nullable=False), + StructField('to_date', TimestampType(), nullable=False), + StructField('created', TimestampType(), nullable=False), +]) mlhd_schema = StructType([ StructField('user_id', StringType(), nullable=False), diff --git a/listenbrainz_spark/stats/incremental/__init__.py b/listenbrainz_spark/stats/incremental/__init__.py new file mode 100644 index 0000000000..21d4651988 --- /dev/null +++ b/listenbrainz_spark/stats/incremental/__init__.py @@ -0,0 +1,197 @@ +import abc +from datetime import datetime +from pathlib import Path +from typing import List + +from pyspark.errors import AnalysisException +from pyspark.sql import DataFrame +from pyspark.sql.types import StructType + +import listenbrainz_spark +from listenbrainz_spark import hdfs_connection +from listenbrainz_spark.config import HDFS_CLUSTER_URI +from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH +from listenbrainz_spark.schema import BOOKKEEPING_SCHEMA +from listenbrainz_spark.stats import get_dates_for_stats_range +from listenbrainz_spark.utils import read_files_from_HDFS, logger, get_listens_from_dump + + +class IncrementalStats(abc.ABC): + """ + Provides a framework for generating incremental statistics for a given entity (e.g., users, tracks) + over a specified date range. + + In the ListenBrainz Spark cluster, full dump listens (which remain constant for ~15 days) and incremental listens + (ingested daily) are the two main sources of data. Incremental listens are cleared whenever a new full dump is + imported. Aggregating full dump listens daily for various statistics is inefficient since this data does not + change. + + To optimize this process: + + 1. A partial aggregate is generated from the full dump listens the first time a stat is requested. This partial + aggregate is stored in HDFS for future use, eliminating the need for redundant full dump aggregation. + 2. Incremental listens are aggregated daily. Although all incremental listens since the full dump’s import are + used (not just today’s), this introduces some redundant computation. + 3. The incremental aggregate is combined with the existing partial aggregate, forming a combined aggregate from + which final statistics are generated. + + For non-sitewide statistics, further optimization is possible: + + If an entity’s listens (e.g., for a user) are not present in the incremental listens, its statistics do not + need to be recalculated. Similarly, entity-level listener stats can skip recomputation when relevant data + is absent in incremental listens. + """ + + def __init__(self, entity: str, stats_range: str): + """ + Args: + entity: The entity for which statistics are generated. + stats_range: The statistics range to calculate the stats for. + """ + self.entity = entity + self.stats_range = stats_range + self.from_date, self.to_date = get_dates_for_stats_range(stats_range) + self._cache_tables = [] + + @abc.abstractmethod + def get_base_path(self) -> str: + """ Returns the base HDFS path for storing partial data and metadata for this category of statistics. """ + raise NotImplementedError() + + def get_existing_aggregate_path(self) -> str: + """ Returns the HDFS path for existing aggregate data. """ + return f"{self.get_base_path()}/aggregates/{self.entity}/{self.stats_range}" + + def get_bookkeeping_path(self) -> str: + """ Returns the HDFS path for bookkeeping metadata. 
""" + return f"{self.get_base_path()}/bookkeeping/{self.entity}/{self.stats_range}" + + @abc.abstractmethod + def get_partial_aggregate_schema(self) -> StructType: + """ Returns the spark schema of the partial aggregates created during generation of this stat. """ + raise NotImplementedError() + + @abc.abstractmethod + def aggregate(self, table: str, cache_tables: List[str]) -> DataFrame: + """ + Create partial aggregates from the given listens. + + Args: + table: The listen table to aggregation. + cache_tables: List of metadata cache tables. + + Returns: + DataFrame: The aggregated DataFrame. + """ + raise NotImplementedError() + + @abc.abstractmethod + def combine_aggregates(self, existing_aggregate: str, incremental_aggregate: str) -> DataFrame: + """ + Combines existing aggregate and incremental aggregate to get the final aggregate to obtain stats from. + + Args: + existing_aggregate: The table name for existing aggregate. + incremental_aggregate: The table name for incremental aggregate. + + Returns: + DataFrame: The combined DataFrame. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_top_n(self, final_aggregate: str, N: int) -> DataFrame: + """ + Obtain the top N entities for the given statistic from the final aggregate. + + Args: + final_aggregate: The table name for the final aggregate. + N: The number of top entities to retrieve. + + Returns: + DataFrame: The DataFrame containing the top N entities. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_cache_tables(self) -> List[str]: + """ Returns the list of HDFS paths for the metadata cache tables required by the statistic. """ + raise NotImplementedError() + + def setup_cache_tables(self): + """ Set up metadata cache tables by reading data from HDFS and creating temporary views. """ + cache_tables = [] + for idx, df_path in enumerate(self.get_cache_tables()): + df_name = f"entity_data_cache_{idx}" + cache_tables.append(df_name) + read_files_from_HDFS(df_path).createOrReplaceTempView(df_name) + self._cache_tables = cache_tables + + @abc.abstractmethod + def get_table_prefix(self) -> str: + """ Get the prefix for table names based on the stat type, entity and stats range. """ + raise NotImplementedError() + + def partial_aggregate_usable(self) -> bool: + """ Checks whether a partial aggregate exists and is fresh to generate the required stats. """ + metadata_path = self.get_bookkeeping_path() + existing_aggregate_path = self.get_existing_aggregate_path() + + try: + metadata = listenbrainz_spark \ + .session \ + .read \ + .schema(BOOKKEEPING_SCHEMA) \ + .json(f"{HDFS_CLUSTER_URI}{metadata_path}") \ + .collect()[0] + existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"] + existing_aggregate_fresh = existing_from_date.date() == self.from_date.date() + except AnalysisException: + existing_aggregate_fresh = False + + existing_aggregate_exists = hdfs_connection.client.status(existing_aggregate_path, strict=False) + + return existing_aggregate_fresh and existing_aggregate_exists + + def create_partial_aggregate(self) -> DataFrame: + """ + Create a new partial aggregate from full dump listens. + + Returns: + DataFrame: The generated partial aggregate DataFrame. 
+ """ + metadata_path = self.get_bookkeeping_path() + existing_aggregate_path = self.get_existing_aggregate_path() + + table = f"{self.get_table_prefix()}_full_listens" + get_listens_from_dump(self.from_date, self.to_date, include_incremental=False).createOrReplaceTempView(table) + + logger.info("Creating partial aggregate from full dump listens") + hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent) + full_df = self.aggregate(table, self._cache_tables) + full_df.write.mode("overwrite").parquet(existing_aggregate_path) + + hdfs_connection.client.makedirs(Path(metadata_path).parent) + metadata_df = listenbrainz_spark.session.createDataFrame( + [(self.from_date, self.to_date, datetime.now())], + schema=BOOKKEEPING_SCHEMA + ) + metadata_df.write.mode("overwrite").json(metadata_path) + logger.info("Finished creating partial aggregate from full dump listens") + + return full_df + + def incremental_dump_exists(self) -> bool: + return hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False) + + def create_incremental_aggregate(self) -> DataFrame: + """ + Create an incremental aggregate from incremental listens. + + Returns: + DataFrame: The generated incremental aggregate DataFrame. + """ + table = f"{self.get_table_prefix()}_incremental_listens" + read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \ + .createOrReplaceTempView(table) + return self.aggregate(table, self._cache_tables) diff --git a/listenbrainz_spark/stats/incremental/sitewide/__init__.py b/listenbrainz_spark/stats/incremental/sitewide/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/listenbrainz_spark/stats/incremental/sitewide/artist.py b/listenbrainz_spark/stats/incremental/sitewide/artist.py new file mode 100644 index 0000000000..a6f31d4b36 --- /dev/null +++ b/listenbrainz_spark/stats/incremental/sitewide/artist.py @@ -0,0 +1,111 @@ +from typing import List + +from pyspark.sql.types import StructType, StructField, StringType, IntegerType + +from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME +from listenbrainz_spark.stats import run_query +from listenbrainz_spark.stats.incremental.sitewide.entity import SitewideEntity + + +class AritstSitewideEntity(SitewideEntity): + + def __init__(self, stats_range): + super().__init__(entity="artists", stats_range=stats_range) + + def get_cache_tables(self) -> List[str]: + return [ARTIST_COUNTRY_CODE_DATAFRAME] + + def get_partial_aggregate_schema(self): + return StructType([ + StructField("artist_name", StringType(), nullable=False), + StructField("artist_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), + ]) + + def aggregate(self, table, cache_tables): + user_listen_count_limit = self.get_listen_count_limit() + cache_table = cache_tables[0] + result = run_query(f""" + WITH exploded_listens AS ( + SELECT user_id + , artist_name AS artist_credit_name + , explode_outer(artist_credit_mbids) AS artist_mbid + FROM {table} + ), listens_with_mb_data as ( + SELECT user_id + , COALESCE(at.artist_name, el.artist_credit_name) AS artist_name + , el.artist_mbid + FROM exploded_listens el + LEFT JOIN {cache_table} at + ON el.artist_mbid = at.artist_mbid + ), user_counts as ( + SELECT user_id + , first(artist_name) AS any_artist_name + , artist_mbid + , LEAST(count(*), {user_listen_count_limit}) as listen_count + FROM listens_with_mb_data + GROUP BY user_id + , lower(artist_name) + , artist_mbid + ) + SELECT first(any_artist_name) AS artist_name + , artist_mbid + , SUM(listen_count) 
as listen_count + FROM user_counts + GROUP BY lower(any_artist_name) + , artist_mbid + """) + return result + + def combine_aggregates(self, existing_aggregate, incremental_aggregate): + query = f""" + WITH intermediate_table AS ( + SELECT artist_name + , artist_mbid + , listen_count + FROM {existing_aggregate} + UNION ALL + SELECT artist_name + , artist_mbid + , listen_count + FROM {incremental_aggregate} + ) + SELECT first(artist_name) AS artist_name + , artist_mbid + , sum(listen_count) as listen_count + FROM intermediate_table + GROUP BY lower(artist_name) + , artist_mbid + """ + return run_query(query) + + def get_top_n(self, final_aggregate, N): + query = f""" + WITH entity_count AS ( + SELECT count(*) AS total_count + FROM {final_aggregate} + ), ordered_stats AS ( + SELECT * + FROM {final_aggregate} + ORDER BY listen_count DESC + LIMIT {N} + ), grouped_stats AS ( + SELECT sort_array( + collect_list( + struct( + listen_count + , artist_name + , artist_mbid + ) + ) + , false + ) AS stats + FROM ordered_stats + ) + SELECT total_count + , stats + FROM grouped_stats + JOIN entity_count + ON TRUE + """ + return run_query(query) diff --git a/listenbrainz_spark/stats/incremental/sitewide/entity.py b/listenbrainz_spark/stats/incremental/sitewide/entity.py new file mode 100644 index 0000000000..6f03398527 --- /dev/null +++ b/listenbrainz_spark/stats/incremental/sitewide/entity.py @@ -0,0 +1,114 @@ +import logging +from abc import ABC +from typing import Iterator, Dict + +from pydantic import ValidationError +from pyspark.sql import DataFrame + +from data.model.user_artist_stat import ArtistRecord +from data.model.user_recording_stat import RecordingRecord +from data.model.user_release_group_stat import ReleaseGroupRecord +from data.model.user_release_stat import ReleaseRecord +from listenbrainz_spark.path import LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY +from listenbrainz_spark.stats import SITEWIDE_STATS_ENTITY_LIMIT +from listenbrainz_spark.stats.incremental import IncrementalStats +from listenbrainz_spark.utils import read_files_from_HDFS + +logger = logging.getLogger(__name__) + +entity_model_map = { + "artists": ArtistRecord, + "releases": ReleaseRecord, + "recordings": RecordingRecord, + "release_groups": ReleaseGroupRecord, +} + + +class SitewideEntity(IncrementalStats, ABC): + + def get_base_path(self) -> str: + return LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY + + def get_table_prefix(self) -> str: + return f"sitewide_{self.entity}_{self.stats_range}" + + def get_listen_count_limit(self) -> int: + """ Return the per user per entity listen count above which it should + be capped. The rationale is to avoid a single user's listens from + over-influencing the sitewide stats. + + For instance: if the limit for yearly recordings count is 500 and a user + listens to a particular recording for 10000 times, it will be counted as + 500 for calculating the stat. + """ + return 500 + + def generate_stats(self, top_entity_limit: int) -> DataFrame: + """ + Generate statistics of the given type, entity and stats range. + + Args: + top_entity_limit (int): The maximum number of top entities to retrieve. + + Returns a DataFrame of the results. 
+ """ + self.setup_cache_tables() + prefix = self.get_table_prefix() + + if not self.partial_aggregate_usable(): + self.create_partial_aggregate() + partial_df = read_files_from_HDFS(self.get_existing_aggregate_path()) + partial_table = f"{prefix}_existing_aggregate" + partial_df.createOrReplaceTempView(partial_table) + + if self.incremental_dump_exists(): + inc_df = self.create_incremental_aggregate() + inc_table = f"{prefix}_incremental_aggregate" + inc_df.createOrReplaceTempView(inc_table) + final_df = self.combine_aggregates(partial_table, inc_table) + else: + final_df = partial_df + + final_table = f"{prefix}_final_aggregate" + final_df.createOrReplaceTempView(final_table) + + return self.get_top_n(final_table, top_entity_limit) + + def create_messages(self, results: DataFrame) -> Iterator[Dict]: + """ + Create messages to send the data to the webserver via RabbitMQ + + Args: + results: Data to sent to the webserver + + Returns: + messages: A list of messages to be sent via RabbitMQ + """ + message = { + "type": "sitewide_entity", + "stats_range": self.stats_range, + "from_ts": int(self.from_date.timestamp()), + "to_ts": int(self.to_date.timestamp()), + "entity": self.entity, + } + entry = results.collect()[0].asDict(recursive=True) + stats = entry["stats"] + count = entry["total_count"] + + entity_list = [] + for item in stats: + try: + entity_model_map[self.entity](**item) + entity_list.append(item) + except ValidationError: + logger.error("Invalid entry in entity stats", exc_info=True) + count -= 1 + + message["count"] = count + message["data"] = entity_list + + yield message + + def main(self): + results = self.generate_stats(SITEWIDE_STATS_ENTITY_LIMIT) + return self.create_messages(results) diff --git a/listenbrainz_spark/stats/incremental/sitewide/listening_activity.py b/listenbrainz_spark/stats/incremental/sitewide/listening_activity.py new file mode 100644 index 0000000000..d8fb946cfe --- /dev/null +++ b/listenbrainz_spark/stats/incremental/sitewide/listening_activity.py @@ -0,0 +1,81 @@ +from typing import List, Iterator, Dict + +from pyspark.sql import DataFrame +from pyspark.sql.types import StructType, StructField, StringType, IntegerType + +from listenbrainz_spark.stats import run_query +from listenbrainz_spark.stats.common.listening_activity import setup_time_range +from listenbrainz_spark.stats.incremental.sitewide.entity import SitewideEntity + + +class ListeningActivitySitewideEntity(SitewideEntity): + + def __init__(self, stats_range): + super().__init__(entity="listening_activity", stats_range=stats_range) + self.from_date, self.to_date, _, __, self.spark_date_format = setup_time_range(stats_range) + + def get_cache_tables(self) -> List[str]: + return [] + + def get_partial_aggregate_schema(self): + return StructType([ + StructField("time_range", StringType(), nullable=False), + StructField("listen_count", IntegerType(), nullable=False), + ]) + + def aggregate(self, table, cache_tables): + result = run_query(f""" + SELECT date_format(listened_at, '{self.spark_date_format}') AS time_range + , count(listened_at) AS listen_count + FROM {table} + GROUP BY time_range + """) + return result + + def combine_aggregates(self, existing_aggregate, incremental_aggregate): + query = f""" + WITH intermediate_table AS ( + SELECT time_range + , listen_count + FROM {existing_aggregate} + UNION ALL + SELECT time_range + , listen_count + FROM {incremental_aggregate} + ) + SELECT time_range + , sum(listen_count) as listen_count + FROM intermediate_table + GROUP BY time_range + """ + 
return run_query(query) + + def get_top_n(self, final_aggregate, N): + query = f""" + SELECT sort_array( + collect_list( + struct( + to_unix_timestamp(start) AS from_ts + , to_unix_timestamp(end) AS to_ts + , time_range + , COALESCE(listen_count, 0) AS listen_count + ) + ) + ) AS listening_activity + FROM time_range + LEFT JOIN {final_aggregate} + USING (time_range) + """ + return run_query(query) + + def create_messages(self, results: DataFrame) -> Iterator[Dict]: + message = { + "type": "sitewide_listening_activity", + "stats_range": self.stats_range, + "from_ts": int(self.from_date.timestamp()), + "to_ts": int(self.to_date.timestamp()) + } + data = results.collect()[0] + _dict = data.asDict(recursive=True) + message["data"] = _dict["listening_activity"] + yield message diff --git a/listenbrainz_spark/stats/incremental/sitewide/recording.py b/listenbrainz_spark/stats/incremental/sitewide/recording.py new file mode 100644 index 0000000000..f6f869d319 --- /dev/null +++ b/listenbrainz_spark/stats/incremental/sitewide/recording.py @@ -0,0 +1,161 @@ +from typing import List + +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType + +from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME +from listenbrainz_spark.stats import run_query +from listenbrainz_spark.stats.incremental.sitewide.entity import SitewideEntity + + +class RecordingSitewideEntity(SitewideEntity): + + def __init__(self, stats_range): + super().__init__(entity="recordings", stats_range=stats_range) + + def get_cache_tables(self) -> List[str]: + return [RELEASE_METADATA_CACHE_DATAFRAME] + + def get_partial_aggregate_schema(self): + return StructType([ + StructField("recording_name", StringType(), nullable=False), + StructField("recording_mbid", StringType(), nullable=True), + StructField("artist_name", StringType(), nullable=False), + StructField("artist_credit_mbids", ArrayType(StringType()), nullable=True), + StructField("release_name", StringType(), nullable=True), + StructField("release_mbid", StringType(), nullable=True), + StructField("caa_id", IntegerType(), nullable=True), + StructField("caa_release_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), + ]) + + def aggregate(self, table, cache_tables): + user_listen_count_limit = self.get_listen_count_limit() + cache_table = cache_tables[0] + result = run_query(f""" + WITH user_counts as ( + SELECT user_id + , first(l.recording_name) AS recording_name + , nullif(l.recording_mbid, '') AS recording_mbid + , first(l.artist_name) AS artist_name + , l.artist_credit_mbids + , nullif(first(l.release_name), '') as release_name + , l.release_mbid + , rel.caa_id + , rel.caa_release_mbid + , LEAST(count(*), {user_listen_count_limit}) as listen_count + FROM {table} l + LEFT JOIN {cache_table} rel + ON rel.release_mbid = l.release_mbid + GROUP BY l.user_id + , lower(l.recording_name) + , l.recording_mbid + , lower(l.artist_name) + , l.artist_credit_mbids + , lower(l.release_name) + , l.release_mbid + , rel.caa_id + , rel.caa_release_mbid + ) + SELECT first(recording_name) AS recording_name + , recording_mbid + , first(artist_name) AS artist_name + , artist_credit_mbids + , nullif(first(release_name), '') as release_name + , release_mbid + , caa_id + , caa_release_mbid + , SUM(listen_count) as listen_count + FROM user_counts uc + GROUP BY lower(uc.recording_name) + , recording_mbid + , lower(uc.artist_name) + , artist_credit_mbids + , lower(release_name) + , release_mbid + , caa_id + , 
caa_release_mbid + """) + return result + + def combine_aggregates(self, existing_aggregate, incremental_aggregate): + query = f""" + WITH intermediate_table AS ( + SELECT recording_name + , recording_mbid + , artist_name + , artist_credit_mbids + , release_name + , release_mbid + , caa_id + , caa_release_mbid + , listen_count + FROM {existing_aggregate} + UNION ALL + SELECT recording_name + , recording_mbid + , artist_name + , artist_credit_mbids + , release_name + , release_mbid + , caa_id + , caa_release_mbid + , listen_count + FROM {incremental_aggregate} + ) + SELECT first(recording_name) AS recording_name + , recording_mbid + , first(artist_name) AS artist_name + , artist_credit_mbids + , first(release_name) AS release_name + , release_mbid + , caa_id + , caa_release_mbid + , sum(listen_count) as listen_count + FROM intermediate_table + GROUP BY lower(recording_name) + , recording_mbid + , lower(artist_name) + , artist_credit_mbids + , lower(release_name) + , release_mbid + , caa_id + , caa_release_mbid + """ + return run_query(query) + + def get_top_n(self, final_aggregate, N): + query = f""" + WITH entity_count AS ( + SELECT count(*) AS total_count + FROM {final_aggregate} + ), ordered_stats AS ( + SELECT * + FROM {final_aggregate} + ORDER BY listen_count DESC + LIMIT {N} + ), grouped_stats AS ( + SELECT sort_array( + collect_list( + struct( + listen_count + , recording_name AS track_name + , recording_mbid + , artist_name + , coalesce(artist_credit_mbids, array()) AS artist_mbids + , release_name + , release_mbid + , caa_id + , caa_release_mbid + ) + ) + , false + ) AS stats + FROM ordered_stats + ) + SELECT total_count + , stats + FROM grouped_stats + JOIN entity_count + ON TRUE + """ + return run_query(query) diff --git a/listenbrainz_spark/stats/incremental/sitewide/release.py b/listenbrainz_spark/stats/incremental/sitewide/release.py new file mode 100644 index 0000000000..32c8c622fa --- /dev/null +++ b/listenbrainz_spark/stats/incremental/sitewide/release.py @@ -0,0 +1,152 @@ +from typing import List + +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType + +from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME +from listenbrainz_spark.stats import run_query +from listenbrainz_spark.stats.incremental.sitewide.entity import SitewideEntity + + +class ReleaseSitewideEntity(SitewideEntity): + + def __init__(self): + super().__init__(entity="releases") + + def get_cache_tables(self) -> List[str]: + return [RELEASE_METADATA_CACHE_DATAFRAME] + + def get_partial_aggregate_schema(self): + return StructType([ + StructField("release_name", StringType(), nullable=False), + StructField("release_mbid", StringType(), nullable=False), + StructField("artist_name", StringType(), nullable=False), + StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False), + StructField("caa_id", IntegerType(), nullable=True), + StructField("caa_release_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), + ]) + + def aggregate(self, table, cache_tables): + user_listen_count_limit = self.get_listen_count_limit() + cache_table = cache_tables[0] + result = run_query(f""" + WITH gather_release_data AS ( + SELECT user_id + , l.release_mbid + , COALESCE(rel.release_name, l.release_name) AS release_name + , COALESCE(rel.album_artist_name, l.artist_name) AS release_artist_name + , COALESCE(rel.artist_credit_mbids, l.artist_credit_mbids) AS artist_credit_mbids + , rel.caa_id + , rel.caa_release_mbid + FROM 
{table} l + LEFT JOIN {cache_table} rel + ON rel.release_mbid = l.release_mbid + ), user_counts AS ( + SELECT user_id + , first(release_name) AS any_release_name + , release_mbid + , first(release_artist_name) AS any_artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , LEAST(count(*), {user_listen_count_limit}) as listen_count + FROM gather_release_data + WHERE release_name != '' + AND release_name IS NOT NULL + GROUP BY user_id + , lower(release_name) + , release_mbid + , lower(release_artist_name) + , artist_credit_mbids + , caa_id + , caa_release_mbid + ) + SELECT first(any_release_name) AS release_name + , release_mbid + , first(any_artist_name) AS artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , SUM(listen_count) as listen_count + FROM user_counts + GROUP BY lower(any_release_name) + , release_mbid + , lower(any_artist_name) + , artist_credit_mbids + , caa_id + , caa_release_mbid + """) + return result + + def combine_aggregates(self, existing_aggregate, incremental_aggregate): + query = f""" + WITH intermediate_table AS ( + SELECT release_name + , release_mbid + , artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , listen_count + FROM {existing_aggregate} + UNION ALL + SELECT release_name + , release_mbid + , artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , listen_count + FROM {incremental_aggregate} + ) + SELECT first(release_name) AS release_name + , release_mbid + , first(artist_name) AS artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , sum(listen_count) as listen_count + FROM intermediate_table + GROUP BY lower(release_name) + , release_mbid + , lower(artist_name) + , artist_credit_mbids + , caa_id + , caa_release_mbid + """ + return run_query(query) + + def get_top_n(self, final_aggregate, N): + query = f""" + WITH entity_count AS ( + SELECT count(*) AS total_count + FROM {final_aggregate} + ), ordered_stats AS ( + SELECT * + FROM {final_aggregate} + ORDER BY listen_count DESC + LIMIT {N} + ), grouped_stats AS ( + SELECT sort_array( + collect_list( + struct( + listen_count + , release_name + , release_mbid + , artist_name + , coalesce(artist_credit_mbids, array()) AS artist_mbids + , caa_id + , caa_release_mbid + ) + ) + , false + ) AS stats + FROM ordered_stats + ) + SELECT total_count + , stats + FROM grouped_stats + JOIN entity_count + ON TRUE + """ + return run_query(query) diff --git a/listenbrainz_spark/stats/incremental/sitewide/release_group.py b/listenbrainz_spark/stats/incremental/sitewide/release_group.py new file mode 100644 index 0000000000..1f7192e5a7 --- /dev/null +++ b/listenbrainz_spark/stats/incremental/sitewide/release_group.py @@ -0,0 +1,157 @@ +from typing import List + +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType + +from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME +from listenbrainz_spark.stats import run_query +from listenbrainz_spark.stats.incremental.sitewide.entity import SitewideEntity + + +class ReleaseGroupSitewideEntity(SitewideEntity): + + def __init__(self, stats_range): + super().__init__(entity="release_groups", stats_range=stats_range) + + def get_cache_tables(self) -> List[str]: + return [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME] + + def get_partial_aggregate_schema(self): + return StructType([ + StructField("release_group_name", StringType(), nullable=False), + StructField("release_group_mbid", 
StringType(), nullable=False), + StructField("artist_name", StringType(), nullable=False), + StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False), + StructField("caa_id", IntegerType(), nullable=True), + StructField("caa_release_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), + ]) + + def aggregate(self, table, cache_tables): + user_listen_count_limit = self.get_listen_count_limit() + rel_cache_table = cache_tables[0] + rg_cache_table = cache_tables[1] + result = run_query(f""" + WITH gather_release_group_data AS ( + SELECT l.user_id + , rg.release_group_mbid + -- this is intentional as we don't have a release group name field in listen submission json + -- and for the purposes of this stat, they'd usually be the same. + , COALESCE(rg.title, l.release_name) AS release_group_name + , COALESCE(rg.artist_credit_name, l.artist_name) AS release_group_artist_name + , COALESCE(rg.artist_credit_mbids, l.artist_credit_mbids) AS artist_credit_mbids + , rg.caa_id + , rg.caa_release_mbid + FROM {table} l + LEFT JOIN {rel_cache_table} rel + ON rel.release_mbid = l.release_mbid + LEFT JOIN {rg_cache_table} rg + ON rg.release_group_mbid = rel.release_group_mbid + ), user_counts as ( + SELECT user_id + , first(release_group_name) AS any_release_group_name + , release_group_mbid + , first(release_group_artist_name) AS any_artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , LEAST(count(*), {user_listen_count_limit}) as listen_count + FROM gather_release_group_data + WHERE release_group_name != '' + AND release_group_name IS NOT NULL + GROUP BY user_id + , lower(release_group_name) + , release_group_mbid + , lower(release_group_artist_name) + , artist_credit_mbids + , caa_id + , caa_release_mbid + ) + SELECT first(any_release_group_name) AS release_group_name + , release_group_mbid + , first(any_artist_name) AS artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , SUM(listen_count) as listen_count + FROM user_counts + GROUP BY lower(any_release_group_name) + , release_group_mbid + , lower(any_artist_name) + , artist_credit_mbids + , caa_id + , caa_release_mbid + """) + return result + + def combine_aggregates(self, existing_aggregate, incremental_aggregate): + query = f""" + WITH intermediate_table AS ( + SELECT release_group_name + , release_group_mbid + , artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , listen_count + FROM {existing_aggregate} + UNION ALL + SELECT release_group_name + , release_group_mbid + , artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , listen_count + FROM {incremental_aggregate} + ) + SELECT first(release_group_name) AS release_group_name + , release_group_mbid + , first(artist_name) AS artist_name + , artist_credit_mbids + , caa_id + , caa_release_mbid + , sum(listen_count) as listen_count + FROM intermediate_table + GROUP BY lower(release_group_name) + , release_group_mbid + , lower(artist_name) + , artist_credit_mbids + , caa_id + , caa_release_mbid + """ + return run_query(query) + + def get_top_n(self, final_aggregate, N): + query = f""" + WITH entity_count AS ( + SELECT count(*) AS total_count + FROM {final_aggregate} + ), ordered_stats AS ( + SELECT * + FROM {final_aggregate} + ORDER BY listen_count DESC + LIMIT {N} + ), grouped_stats AS ( + SELECT sort_array( + collect_list( + struct( + listen_count + , release_group_name + , release_group_mbid + , artist_name + , coalesce(artist_credit_mbids, array()) AS artist_mbids + , 
caa_id + , caa_release_mbid + ) + ) + , false + ) AS stats + FROM ordered_stats + ) + SELECT total_count + , stats + FROM grouped_stats + JOIN entity_count + ON TRUE + """ + return run_query(query) diff --git a/listenbrainz_spark/stats/sitewide/entity.py b/listenbrainz_spark/stats/sitewide/entity.py index 1035a0f304..4a9599c89b 100644 --- a/listenbrainz_spark/stats/sitewide/entity.py +++ b/listenbrainz_spark/stats/sitewide/entity.py @@ -1,4 +1,3 @@ -import json import logging from datetime import datetime from typing import List, Optional @@ -10,7 +9,12 @@ from data.model.user_release_stat import ReleaseRecord from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME, RELEASE_METADATA_CACHE_DATAFRAME, \ RELEASE_GROUP_METADATA_CACHE_DATAFRAME -from listenbrainz_spark.stats import get_dates_for_stats_range +from listenbrainz_spark.stats import get_dates_for_stats_range, SITEWIDE_STATS_ENTITY_LIMIT +from listenbrainz_spark.stats.incremental.sitewide.artist import AritstSitewideEntity +from listenbrainz_spark.stats.incremental.sitewide.entity import SitewideEntity +from listenbrainz_spark.stats.incremental.sitewide.recording import RecordingSitewideEntity +from listenbrainz_spark.stats.incremental.sitewide.release import ReleaseSitewideEntity +from listenbrainz_spark.stats.incremental.sitewide.release_group import ReleaseGroupSitewideEntity from listenbrainz_spark.stats.sitewide.artist import get_artists from listenbrainz_spark.stats.sitewide.recording import get_recordings from listenbrainz_spark.stats.sitewide.release import get_releases @@ -22,103 +26,17 @@ logger = logging.getLogger(__name__) -entity_handler_map = { - "artists": get_artists, - "releases": get_releases, - "recordings": get_recordings, - "release_groups": get_release_groups, +incremental_sitewide_map = { + "artists": AritstSitewideEntity, + "releases": ReleaseSitewideEntity, + "recordings": RecordingSitewideEntity, + "release_groups": ReleaseGroupSitewideEntity, } -entity_model_map = { - "artists": ArtistRecord, - "releases": ReleaseRecord, - "recordings": RecordingRecord, - "release_groups": ReleaseGroupRecord, -} - -entity_cache_map = { - "artists": [ARTIST_COUNTRY_CODE_DATAFRAME], - "releases": [RELEASE_METADATA_CACHE_DATAFRAME], - "recordings": [RELEASE_METADATA_CACHE_DATAFRAME], - "release_groups": [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME] -} - - -def get_listen_count_limit(stats_range: str) -> int: - """ Return the per user per entity listen count above which it should - be capped. The rationale is to avoid a single user's listens from - over-influencing the sitewide stats. - - For instance: if the limit for yearly recordings count is 500 and a user - listens to a particular recording for 10000 times, it will be counted as - 500 for calculating the stat. 
- """ - return 500 - def get_entity_stats(entity: str, stats_range: str) -> Optional[List[SitewideEntityStatMessage]]: """ Returns top entity stats for given time period """ logger.debug(f"Calculating sitewide_{entity}_{stats_range}...") - - from_date, to_date = get_dates_for_stats_range(stats_range) - listens_df = get_listens_from_dump(from_date, to_date) - table_name = f"sitewide_{entity}_{stats_range}" - listens_df.createOrReplaceTempView(table_name) - - listen_count_limit = get_listen_count_limit(stats_range) - - cache_dfs = [] - for idx, df_path in enumerate(entity_cache_map.get(entity)): - df_name = f"entity_data_cache_{idx}" - cache_dfs.append(df_name) - read_files_from_HDFS(df_path).createOrReplaceTempView(df_name) - - handler = entity_handler_map[entity] - data = handler(table_name, cache_dfs, listen_count_limit) - - messages = create_messages(data=data, entity=entity, stats_range=stats_range, - from_date=from_date, to_date=to_date) - - logger.debug("Done!") - - return messages - - -def create_messages(data, entity: str, stats_range: str, from_date: datetime, to_date: datetime): - """ - Create messages to send the data to the webserver via RabbitMQ - - Args: - data: Data to sent to the webserver - entity: The entity for which statistics are calculated, i.e 'artists', - 'releases' or 'recordings' - stats_range: The range for which the statistics have been calculated - from_date: The start time of the stats - to_date: The end time of the stats - - Returns: - messages: A list of messages to be sent via RabbitMQ - """ - message = { - "type": "sitewide_entity", - "stats_range": stats_range, - "from_ts": int(from_date.timestamp()), - "to_ts": int(to_date.timestamp()), - "entity": entity, - } - entry = next(data).asDict(recursive=True) - stats = entry["stats"] - count = entry["total_count"] - - entity_list = [] - for item in stats: - try: - entity_model_map[entity](**item) - entity_list.append(item) - except ValidationError: - logger.error("Invalid entry in entity stats", exc_info=True) - count -= 1 - message["count"] = count - message["data"] = entity_list - - return [message] + entity_cls = incremental_sitewide_map[entity] + entity_obj: SitewideEntity = entity_cls(stats_range) + return entity_obj.main() diff --git a/listenbrainz_spark/stats/sitewide/listening_activity.py b/listenbrainz_spark/stats/sitewide/listening_activity.py index 9f436ad55f..c1ff79bf66 100644 --- a/listenbrainz_spark/stats/sitewide/listening_activity.py +++ b/listenbrainz_spark/stats/sitewide/listening_activity.py @@ -1,71 +1,12 @@ -import json import logging from datetime import datetime from typing import Iterator, Optional, Dict -from pydantic import ValidationError - -from data.model.common_stat_spark import StatMessage -from data.model.user_listening_activity import ListeningActivityRecord -from listenbrainz_spark.stats import run_query -from listenbrainz_spark.stats.common.listening_activity import setup_time_range -from listenbrainz_spark.utils import get_listens_from_dump -from pyspark.sql.types import (StringType, StructField, StructType, - TimestampType) - -time_range_schema = StructType([ - StructField("time_range", StringType()), - StructField("start", TimestampType()), - StructField("end", TimestampType()) -]) - +from listenbrainz_spark.stats.incremental.sitewide.listening_activity import ListeningActivitySitewideEntity logger = logging.getLogger(__name__) -def calculate_listening_activity(spark_date_format): - """ Calculate number of listens for each user in time ranges given in the "time_range" 
table. - The time ranges are as follows: - 1) week - each day with weekday name of the past 2 weeks. - 2) month - each day the past 2 months. - 3) year - each month of the past 2 years. - 4) all_time - each year starting from LAST_FM_FOUNDING_YEAR (2002) - - Args: - spark_date_format: the date format - """ - # calculates the number of listens in each time range for each user, count(listened_at) so that - # group without listens are counted as 0, count(*) gives 1. - # this query is much different that the user listening activity stats query because an earlier - # version of this query which was similar to that caused OutOfMemory on yearly and all time - # ranges. It turns converting each listened_at to the needed date format and grouping by it is - # much cheaper than joining with a separate time range table. We still join the grouped data with - # a separate time range table to fill any gaps i.e. time ranges with no listens get a value of 0 - # instead of being completely omitted from the final result. - result = run_query(f""" - WITH bucket_listen_counts AS ( - SELECT date_format(listened_at, '{spark_date_format}') AS time_range - , count(listened_at) AS listen_count - FROM listens - GROUP BY time_range - ) - SELECT sort_array( - collect_list( - struct( - to_unix_timestamp(start) AS from_ts - , to_unix_timestamp(end) AS to_ts - , time_range - , COALESCE(listen_count, 0) AS listen_count - ) - ) - ) AS listening_activity - FROM time_range - LEFT JOIN bucket_listen_counts - USING (time_range) - """) - return result.toLocalIterator() - - def get_listening_activity(stats_range: str) -> Iterator[Optional[Dict]]: """ Compute the number of listens for a time range compared to the previous range @@ -76,34 +17,5 @@ def get_listening_activity(stats_range: str) -> Iterator[Optional[Dict]]: details). These values are used on the listening activity reports. 
""" logger.debug(f"Calculating listening_activity_{stats_range}") - from_date, to_date, _, _, spark_date_format = setup_time_range(stats_range) - get_listens_from_dump(from_date, to_date).createOrReplaceTempView("listens") - data = calculate_listening_activity(spark_date_format) - messages = create_messages(data=data, stats_range=stats_range, from_date=from_date, to_date=to_date) - logger.debug("Done!") - return messages - - -def create_messages(data, stats_range: str, from_date: datetime, to_date: datetime): - """ - Create messages to send the data to webserver via RabbitMQ - - Args: - data: Data to send to webserver - stats_range: The range for which the statistics have been calculated - from_date: The start time of the stats - to_date: The end time of the stats - Returns: - messages: A list of messages to be sent via RabbitMQ - """ - message = { - "type": "sitewide_listening_activity", - "stats_range": stats_range, - "from_ts": int(from_date.timestamp()), - "to_ts": int(to_date.timestamp()) - } - - _dict = next(data).asDict(recursive=True) - message["data"] = _dict["listening_activity"] - - return [message] + entity_obj = ListeningActivitySitewideEntity(stats_range) + return entity_obj.main() diff --git a/listenbrainz_spark/stats/sitewide/tests/__init__.py b/listenbrainz_spark/stats/sitewide/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/listenbrainz_spark/stats/sitewide/tests/test_sitewide_entity.py b/listenbrainz_spark/stats/sitewide/tests/test_sitewide_entity.py deleted file mode 100644 index 1005e3a5bd..0000000000 --- a/listenbrainz_spark/stats/sitewide/tests/test_sitewide_entity.py +++ /dev/null @@ -1,69 +0,0 @@ -import json -from datetime import datetime -from unittest.mock import MagicMock, patch - -from listenbrainz_spark.stats.sitewide import entity -from listenbrainz_spark.constants import LAST_FM_FOUNDING_YEAR -from listenbrainz_spark.stats.user.tests import StatsTestCase - - -class SitewideEntityTestCase(StatsTestCase): - - @classmethod - def setUpClass(cls): - super(SitewideEntityTestCase, cls).setUpClass() - entity.entity_handler_map['test'] = MagicMock(return_value="sample_test_data") - entity.entity_cache_map['test'] = [] - - @patch('listenbrainz_spark.stats.sitewide.entity.get_listens_from_dump') - @patch('listenbrainz_spark.stats.sitewide.entity.create_messages') - def test_get_entity_week(self, mock_create_messages, mock_get_listens): - entity.get_entity_stats('test', 'week') - from_date = datetime(2021, 8, 2) - to_date = datetime(2021, 8, 9) - mock_get_listens.assert_called_with(from_date, to_date) - mock_create_messages.assert_called_with(data='sample_test_data', entity='test', stats_range='week', - from_date=from_date, to_date=to_date) - - @patch('listenbrainz_spark.stats.sitewide.entity.get_listens_from_dump') - @patch('listenbrainz_spark.stats.sitewide.entity.create_messages') - def test_get_entity_month(self, mock_create_messages, mock_get_listens): - entity.get_entity_stats('test', 'month') - from_date = datetime(2021, 7, 1) - to_date = datetime(2021, 8, 1) - mock_get_listens.assert_called_with(from_date, to_date) - mock_create_messages.assert_called_with(data='sample_test_data', entity='test', stats_range='month', - from_date=from_date, to_date=to_date) - - @patch('listenbrainz_spark.stats.sitewide.entity.get_listens_from_dump') - @patch('listenbrainz_spark.stats.sitewide.entity.create_messages') - def test_get_entity_year(self, mock_create_messages, mock_get_listens): - entity.get_entity_stats('test', 'year') - from_date = 
datetime(2020, 1, 1) - to_date = datetime(2021, 1, 1) - mock_get_listens.assert_called_with(from_date, to_date) - mock_create_messages.assert_called_with(data='sample_test_data', entity='test', stats_range='year', - from_date=from_date, to_date=to_date) - - @patch('listenbrainz_spark.stats.sitewide.entity.get_listens_from_dump') - @patch('listenbrainz_spark.stats.sitewide.entity.create_messages') - def test_get_entity_all_time(self, mock_create_messages, mock_get_listens): - entity.get_entity_stats('test', 'all_time') - from_date = datetime(LAST_FM_FOUNDING_YEAR, 1, 1) - to_date = datetime(2021, 8, 9, 12, 22, 43) - mock_get_listens.assert_called_with(from_date, to_date) - mock_create_messages.assert_called_with(data='sample_test_data', entity='test', stats_range='all_time', - from_date=from_date, to_date=to_date) - - def test_skip_incorrect_artists_stats(self): - """ Test to check if entries with incorrect data is skipped for top sitewide artists """ - with open(self.path_to_data_file('sitewide_top_artists_incorrect.json')) as f: - data = json.load(f) - - mock_result = MagicMock() - mock_result.asDict.return_value = data - - message = entity.create_messages(iter([mock_result]), 'artists', 'all_time', datetime.now(), datetime.now()) - - # Only the first entry in file is valid, all others must be skipped - self.assertListEqual(data['stats'][:1], message[0]['data']) diff --git a/listenbrainz_spark/utils/__init__.py b/listenbrainz_spark/utils/__init__.py index 5cd18d8536..dee4bf5f8b 100644 --- a/listenbrainz_spark/utils/__init__.py +++ b/listenbrainz_spark/utils/__init__.py @@ -146,12 +146,13 @@ def get_listen_files_list() -> List[str]: return file_names -def get_listens_from_dump(start: datetime, end: datetime) -> DataFrame: +def get_listens_from_dump(start: datetime, end: datetime, include_incremental=True) -> DataFrame: """ Load listens with listened_at between from_ts and to_ts from HDFS in a spark dataframe. Args: start: minimum time to include a listen in the dataframe end: maximum time to include a listen in the dataframe + include_incremental: if True, also include listens from incremental dumps Returns: dataframe of listens with listened_at between start and end @@ -162,17 +163,22 @@ def get_listens_from_dump(start: datetime, end: datetime) -> DataFrame: full_df = get_intermediate_stats_df(start, end) df = df.union(full_df) - if hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False): + if include_incremental and hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False): inc_df = read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) df = df.union(inc_df) - df = df.where(f"listened_at >= to_timestamp('{start}')") - df = df.where(f"listened_at <= to_timestamp('{end}')") + if start: + df = df.where(f"listened_at >= to_timestamp('{start}')") + if end: + df = df.where(f"listened_at <= to_timestamp('{end}')") return df def get_intermediate_stats_df(start: datetime, end: datetime): + if start is None and end is None: + return read_files_from_HDFS(LISTENBRAINZ_INTERMEDIATE_STATS_DIRECTORY) + filters = [] current = start
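A minimal usage sketch of the rewired entry points, for anyone trying this change on a test cluster. It assumes the Spark session and HDFS connection have already been initialised the way the stats job runner normally does before dispatching a request; the printed fields are only for illustration.

    from listenbrainz_spark.stats.sitewide.entity import get_entity_stats
    from listenbrainz_spark.stats.sitewide.listening_activity import get_listening_activity

    # Top sitewide artists for the past week: dispatches to AritstSitewideEntity, reuses the
    # stored partial aggregate when it is still fresh, and yields a single RabbitMQ message.
    for message in get_entity_stats("artists", "week"):
        print(message["type"], message["entity"], message["stats_range"], message["count"])

    # Sitewide listening activity for the past month follows the same partial + incremental flow.
    for message in get_listening_activity("month"):
        print(message["type"], len(message["data"]))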
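The bookkeeping check in partial_aggregate_usable ultimately reduces to a date comparison plus an HDFS existence check; the sketch below spells that decision out with made-up dates (the real values come from the bookkeeping JSON written by create_partial_aggregate).

    from datetime import datetime

    # from_date recorded in the bookkeeping JSON when the partial aggregate was built
    existing_from_date = datetime(2024, 6, 1)
    # from_date of today's "week" stats request; the window has since slid forward
    current_from_date = datetime(2024, 6, 3)

    # The stored aggregate is reused only when both windows start on the same day and the
    # aggregate directory still exists in HDFS; otherwise create_partial_aggregate() rebuilds
    # it from the full dump listens.
    aggregate_fresh = existing_from_date.date() == current_from_date.date()
    print(aggregate_fresh)  # False -> the partial aggregate is rebuilt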