
Commit 11eb30b

Refactor create messages and stats validation into class
1 parent a1af5f5 commit 11eb30b

7 files changed: +60 -246 lines changed


listenbrainz_spark/path.py

+1 -3
@@ -8,9 +8,7 @@
 LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
 LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
 LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'user')
-
-LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY = os.path.join('/', 'listener_stats_aggregates')
-LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY = os.path.join('/', 'listener_stats_bookkeeping')
+LISTENBRAINZ_LISTENER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'listener')

 # MLHD+ dump files
 MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")

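For reference, this hunk replaces the two standalone listener-stats paths with a single constant nested under the base stats directory, alongside the user and sitewide stats paths. A minimal sketch of what the new constant evaluates to, using only the definitions above:

import os

LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
LISTENBRAINZ_LISTENER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'listener')

print(LISTENBRAINZ_LISTENER_STATS_DIRECTORY)  # prints: /stats/listener
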
listenbrainz_spark/stats/incremental/listener/artist.py

+8 -17
@@ -5,35 +5,26 @@
 from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME
 from listenbrainz_spark.stats import run_query
 from listenbrainz_spark.stats.incremental.listener.entity import EntityListener
-from listenbrainz_spark.stats.incremental.user.entity import UserEntity


 class ArtistEntityListener(EntityListener):

-    def __init__(self):
-        super().__init__(entity="artists")
+    def __init__(self, stats_range, database):
+        super().__init__(entity="artists", stats_range=stats_range, database=database, message_type="entity_listener")

     def get_cache_tables(self) -> List[str]:
         return [ARTIST_COUNTRY_CODE_DATAFRAME]

     def get_partial_aggregate_schema(self):
         return StructType([
-            StructField('artist_name', StringType(), nullable=False),
-            StructField('artist_mbid', StringType(), nullable=True),
-            StructField('user_id', IntegerType(), nullable=False),
-            StructField('listen_count', IntegerType(), nullable=False),
+            StructField("artist_name", StringType(), nullable=False),
+            StructField("artist_mbid", StringType(), nullable=True),
+            StructField("user_id", IntegerType(), nullable=False),
+            StructField("listen_count", IntegerType(), nullable=False),
         ])

-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        query = f"""
-            WITH incremental_artists AS (
-                SELECT DISTINCT artist_mbid FROM {incremental_aggregate}
-            )
-            SELECT *
-            FROM {existing_aggregate} ea
-            WHERE EXISTS(SELECT 1 FROM incremental_artists iu WHERE iu.artist_mbid = ea.artist_mbid)
-        """
-        return run_query(query)
+    def get_entity_id(self):
+        return "artist_mbid"

     def aggregate(self, table, cache_tables):
         cache_table = cache_tables[0]
listenbrainz_spark/stats/incremental/listener/entity.py
@@ -1,132 +1,32 @@
 import abc
 import logging
-from datetime import datetime
-from pathlib import Path
-from typing import List
-
-from pyspark.errors import AnalysisException
-from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType, StructField, TimestampType
-
-import listenbrainz_spark
-from listenbrainz_spark import hdfs_connection
-from listenbrainz_spark.config import HDFS_CLUSTER_URI
-from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, \
-    LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY, LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY
-from listenbrainz_spark.stats import run_query
-from listenbrainz_spark.utils import read_files_from_HDFS, get_listens_from_dump
+from datetime import date
+from typing import Optional

+from listenbrainz_spark.path import LISTENBRAINZ_LISTENER_STATS_DIRECTORY
+from listenbrainz_spark.stats.incremental.user.entity import UserEntity

 logger = logging.getLogger(__name__)
-BOOKKEEPING_SCHEMA = StructType([
-    StructField('from_date', TimestampType(), nullable=False),
-    StructField('to_date', TimestampType(), nullable=False),
-    StructField('created', TimestampType(), nullable=False),
-])
-

-class EntityListener(abc.ABC):
-
-    def __init__(self, entity):
-        self.entity = entity
-
-    def get_existing_aggregate_path(self, stats_range) -> str:
-        return f"{LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY}/{self.entity}/{stats_range}"

-    def get_bookkeeping_path(self, stats_range) -> str:
-        return f"{LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY}/{self.entity}/{stats_range}"
+class EntityListener(UserEntity, abc.ABC):

-    def get_partial_aggregate_schema(self) -> StructType:
-        raise NotImplementedError()
+    def __init__(self, entity: str, stats_range: str, database: Optional[str], message_type: Optional[str]):
+        if not database:
+            database = f"{self.entity}_listeners_{self.stats_range}_{date.today().strftime('%Y%m%d')}"
+        super().__init__(entity, stats_range, database, message_type)

-    def aggregate(self, table, cache_tables) -> DataFrame:
-        raise NotImplementedError()
+    def get_table_prefix(self) -> str:
+        return f"{self.entity}_listener_{self.stats_range}"

-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        raise NotImplementedError()
+    def get_base_path(self) -> str:
+        return LISTENBRAINZ_LISTENER_STATS_DIRECTORY

-    def combine_aggregates(self, existing_aggregate, incremental_aggregate) -> DataFrame:
+    def get_entity_id(self):
         raise NotImplementedError()

-    def get_top_n(self, final_aggregate, N) -> DataFrame:
-        raise NotImplementedError()
-
-    def get_cache_tables(self) -> List[str]:
-        raise NotImplementedError()
-
-    def generate_stats(self, stats_range: str, from_date: datetime,
-                       to_date: datetime, top_entity_limit: int):
-        cache_tables = []
-        for idx, df_path in enumerate(self.get_cache_tables()):
-            df_name = f"entity_data_cache_{idx}"
-            cache_tables.append(df_name)
-            read_files_from_HDFS(df_path).createOrReplaceTempView(df_name)
-
-        metadata_path = self.get_bookkeeping_path(stats_range)
-        try:
-            metadata = listenbrainz_spark \
-                .session \
-                .read \
-                .schema(BOOKKEEPING_SCHEMA) \
-                .json(f"{HDFS_CLUSTER_URI}{metadata_path}") \
-                .collect()[0]
-            existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"]
-            existing_aggregate_usable = existing_from_date.date() == from_date.date()
-        except AnalysisException:
-            existing_aggregate_usable = False
-            logger.info("Existing partial aggregate not found!")
-
-        prefix = f"entity_listener_{self.entity}_{stats_range}"
-        existing_aggregate_path = self.get_existing_aggregate_path(stats_range)
-
-        only_inc_entities = True
-
-        if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable:
-            table = f"{prefix}_full_listens"
-            get_listens_from_dump(from_date, to_date, include_incremental=False).createOrReplaceTempView(table)
-
-            logger.info("Creating partial aggregate from full dump listens")
-            hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent)
-            full_df = self.aggregate(table, cache_tables)
-            full_df.write.mode("overwrite").parquet(existing_aggregate_path)
-
-            hdfs_connection.client.makedirs(Path(metadata_path).parent)
-            metadata_df = listenbrainz_spark.session.createDataFrame(
-                [(from_date, to_date, datetime.now())],
-                schema=BOOKKEEPING_SCHEMA
-            )
-            metadata_df.write.mode("overwrite").json(metadata_path)
-            only_inc_entities = False
-
-        full_df = read_files_from_HDFS(existing_aggregate_path)
-
-        if hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False):
-            table = f"{prefix}_incremental_listens"
-            read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \
-                .createOrReplaceTempView(table)
-            inc_df = self.aggregate(table, cache_tables)
-        else:
-            inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema())
-            only_inc_entities = False
-
-        full_table = f"{prefix}_existing_aggregate"
-        full_df.createOrReplaceTempView(full_table)
-
-        inc_table = f"{prefix}_incremental_aggregate"
-        inc_df.createOrReplaceTempView(inc_table)
-
-        if only_inc_entities:
-            existing_table = f"{prefix}_filtered_aggregate"
-            filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table)
-            filtered_aggregate_df.createOrReplaceTempView(existing_table)
-        else:
-            existing_table = full_table
-
-        combined_df = self.combine_aggregates(existing_table, inc_table)
-
-        combined_table = f"{prefix}_combined_aggregate"
-        combined_df.createOrReplaceTempView(combined_table)
-        results_df = self.get_top_n(combined_table, top_entity_limit)
+    def items_per_message(self):
+        return 10000

-        return only_inc_entities, results_df.toLocalIterator()
-
+    def parse_one_user_stats(self, entry: dict):
+        return entry

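The net effect of this file is that EntityListener stops re-implementing the whole incremental-stats pipeline and instead inherits it from UserEntity, overriding only the listener-specific pieces: table prefix, base path, chunk size, and the abstract get_entity_id. A concrete subclass therefore only supplies the entity-specific hooks. The sketch below is illustrative only: the class name RecordingEntityListener, its entity, and its schema are assumptions and not part of this commit; the constructor signature and hook names are taken from the diffs above.

from typing import List

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

from listenbrainz_spark.stats.incremental.listener.entity import EntityListener


class RecordingEntityListener(EntityListener):
    """ Hypothetical listener-stats subclass, for illustration only. """

    def __init__(self, stats_range, database):
        super().__init__(entity="recordings", stats_range=stats_range,
                         database=database, message_type="entity_listener")

    def get_cache_tables(self) -> List[str]:
        # no metadata cache joins in this simplified example
        return []

    def get_partial_aggregate_schema(self):
        return StructType([
            StructField("recording_mbid", StringType(), nullable=False),
            StructField("user_id", IntegerType(), nullable=False),
            StructField("listen_count", IntegerType(), nullable=False),
        ])

    def get_entity_id(self):
        # consumed by the shared filter_existing_aggregate query in UserEntity
        return "recording_mbid"

    def aggregate(self, table, cache_tables):
        # entity-specific Spark SQL aggregation would go here, as in
        # ArtistEntityListener.aggregate and ReleaseGroupEntityListener.aggregate
        raise NotImplementedError()
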
listenbrainz_spark/stats/incremental/listener/release_group.py

+16 -22
@@ -2,43 +2,37 @@

 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

-from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME, RELEASE_METADATA_CACHE_DATAFRAME, \
+from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME, \
     RELEASE_GROUP_METADATA_CACHE_DATAFRAME
 from listenbrainz_spark.stats import run_query
 from listenbrainz_spark.stats.incremental.listener.entity import EntityListener
-from listenbrainz_spark.stats.incremental.user.entity import UserEntity


 class ReleaseGroupEntityListener(EntityListener):

-    def __init__(self):
-        super().__init__(entity="release_groups")
+    def __init__(self, stats_range, database):
+        super().__init__(
+            entity="release_groups", stats_range=stats_range,
+            database=database, message_type="entity_listener"
+        )

     def get_cache_tables(self) -> List[str]:
         return [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME]

     def get_partial_aggregate_schema(self):
         return StructType([
-            StructField('release_group_mbid', StringType(), nullable=False),
-            StructField('release_group_name', StringType(), nullable=False),
-            StructField('release_group_artist_name', StringType(), nullable=False),
-            StructField('artist_credit_mbids', ArrayType(StringType()), nullable=False),
-            StructField('caa_id', IntegerType(), nullable=True),
-            StructField('caa_release_mbid', StringType(), nullable=True),
-            StructField('user_id', IntegerType(), nullable=False),
-            StructField('listen_count', IntegerType(), nullable=False),
+            StructField("release_group_mbid", StringType(), nullable=False),
+            StructField("release_group_name", StringType(), nullable=False),
+            StructField("release_group_artist_name", StringType(), nullable=False),
+            StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False),
+            StructField("caa_id", IntegerType(), nullable=True),
+            StructField("caa_release_mbid", StringType(), nullable=True),
+            StructField("user_id", IntegerType(), nullable=False),
+            StructField("listen_count", IntegerType(), nullable=False),
         ])

-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        query = f"""
-            WITH incremental_release_groups AS (
-                SELECT DISTINCT release_group_mbid FROM {incremental_aggregate}
-            )
-            SELECT *
-            FROM {existing_aggregate} ea
-            WHERE EXISTS(SELECT 1 FROM incremental_release_groups iu WHERE iu.release_group_mbid = ea.release_group_mbid)
-        """
-        return run_query(query)
+    def get_entity_id(self):
+        return "release_group_mbid"

     def aggregate(self, table, cache_tables):
         rel_cache_table = cache_tables[0]
listenbrainz_spark/stats/incremental/user/entity.py

+12 -7
@@ -1,5 +1,4 @@
 import abc
-import json
 import logging
 from datetime import date
 from typing import Optional, Iterator, Dict, Tuple
@@ -15,10 +14,8 @@
 from listenbrainz_spark.path import LISTENBRAINZ_USER_STATS_DIRECTORY
 from listenbrainz_spark.stats import run_query
 from listenbrainz_spark.stats.incremental import IncrementalStats
-from listenbrainz_spark.stats.user import USERS_PER_MESSAGE
 from listenbrainz_spark.utils import read_files_from_HDFS

-
 logger = logging.getLogger(__name__)

 entity_model_map = {
@@ -31,7 +28,7 @@

 class UserEntity(IncrementalStats, abc.ABC):

-    def __init__(self, entity: str, stats_range: str, database: Optional[str], message_type: Optional[str]):
+    def __init__(self, entity: str, stats_range: str, database: Optional[str], message_type: Optional[str]):
         super().__init__(entity, stats_range)
         if database:
             self.database = database
@@ -45,14 +42,22 @@ def get_base_path(self) -> str:
     def get_table_prefix(self) -> str:
         return f"user_{self.entity}_{self.stats_range}"

+    def get_entity_id(self):
+        return "user_id"
+
+    def items_per_message(self):
+        """ Get the number of items to chunk per message """
+        return 25
+
     def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
+        entity_id = self.get_entity_id()
         query = f"""
             WITH incremental_users AS (
-                SELECT DISTINCT user_id FROM {incremental_aggregate}
+                SELECT DISTINCT {entity_id} FROM {incremental_aggregate}
             )
             SELECT *
             FROM {existing_aggregate} ea
-            WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.user_id = ea.user_id)
+            WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.{entity_id} = ea.{entity_id})
         """
         return run_query(query)

@@ -130,7 +135,7 @@ def create_messages(self, only_inc_users, results: DataFrame) -> Iterator[Dict]:
         to_ts = int(self.to_date.timestamp())

         data = results.toLocalIterator()
-        for entries in chunked(data, USERS_PER_MESSAGE):
+        for entries in chunked(data, self.items_per_message()):
             multiple_user_stats = []
             for entry in entries:
                 row = entry.asDict(recursive=True)

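The create_messages change above swaps the module-level USERS_PER_MESSAGE constant for the new items_per_message() hook, so per-user stats keep chunks of 25 rows while listener stats, which produce far more rows, are batched 10000 per message. A minimal standalone sketch of that chunking, assuming chunked is more_itertools.chunked as used in the surrounding code; the build_messages helper and sample rows are illustrative only:

from more_itertools import chunked

def build_messages(rows, items_per_message):
    # one message per chunk of result rows, mirroring the loop in create_messages
    for entries in chunked(rows, items_per_message):
        yield {"data": list(entries)}

rows = [{"user_id": i, "listen_count": 2 * i} for i in range(60)]
print(len(list(build_messages(rows, 25))))     # 3 messages (user stats chunk size)
print(len(list(build_messages(rows, 10000))))  # 1 message (listener stats chunk size)
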
0 commit comments
