
Commit aefe978

milesgranger, jrbourbeau, douglasdavis, and ncclementi authored
Workflow: CSV to parquet (#841)
Co-authored-by: James Bourbeau <[email protected]>
Co-authored-by: Doug Davis <[email protected]>
Co-authored-by: Naty Clementi <[email protected]>
1 parent 4877e22 commit aefe978

File tree

- cluster_kwargs.yaml
- tests/workflows/test_from_csv_to_parquet.py

2 files changed: +109 -0 lines changed


cluster_kwargs.yaml (+7)

@@ -75,3 +75,10 @@ test_work_stealing_on_straggling_worker:
 test_repeated_merge_spill:
   n_workers: 20
   worker_vm_types: [m6i.large]
+
+# For tests/workflows/test_from_csv_to_parquet.py
+from_csv_to_parquet:
+  n_workers: 10
+  worker_vm_types: [m6i.xlarge]  # 4 CPU, 16 GiB (preferred default instance)
+  backend_options:
+    region: "us-east-1"  # Same region as dataset
tests/workflows/test_from_csv_to_parquet.py (+102)

@@ -0,0 +1,102 @@
from collections import OrderedDict

import dask.dataframe as dd
import pytest

SCHEMA = OrderedDict(
    [
        ("GlobalEventID", "Int64"),
        ("Day", "Int64"),
        ("MonthYear", "Int64"),
        ("Year", "Int64"),
        ("FractionDate", "float64"),
        ("Actor1Code", "string[pyarrow]"),
        ("Actor1Name", "string[pyarrow]"),
        ("Actor1CountryCode", "string[pyarrow]"),
        ("Actor1KnownGroupCode", "string[pyarrow]"),
        ("Actor1EthnicCode", "string[pyarrow]"),
        ("Actor1Religion1Code", "string[pyarrow]"),
        ("Actor1Religion2Code", "string[pyarrow]"),
        ("Actor1Type1Code", "string[pyarrow]"),
        ("Actor1Type2Code", "string[pyarrow]"),
        ("Actor1Type3Code", "string[pyarrow]"),
        ("Actor2Code", "string[pyarrow]"),
        ("Actor2Name", "string[pyarrow]"),
        ("Actor2CountryCode", "string[pyarrow]"),
        ("Actor2KnownGroupCode", "string[pyarrow]"),
        ("Actor2EthnicCode", "string[pyarrow]"),
        ("Actor2Religion1Code", "string[pyarrow]"),
        ("Actor2Religion2Code", "string[pyarrow]"),
        ("Actor2Type1Code", "string[pyarrow]"),
        ("Actor2Type2Code", "string[pyarrow]"),
        ("Actor2Type3Code", "string[pyarrow]"),
        ("IsRootEvent", "Int64"),
        ("EventCode", "string[pyarrow]"),
        ("EventBaseCode", "string[pyarrow]"),
        ("EventRootCode", "string[pyarrow]"),
        ("QuadClass", "Int64"),
        ("GoldsteinScale", "float64"),
        ("NumMentions", "Int64"),
        ("NumSources", "Int64"),
        ("NumArticles", "Int64"),
        ("AvgTone", "float64"),
        ("Actor1Geo_Type", "Int64"),
        ("Actor1Geo_Fullname", "string[pyarrow]"),
        ("Actor1Geo_CountryCode", "string[pyarrow]"),
        ("Actor1Geo_ADM1Code", "string[pyarrow]"),
        ("Actor1Geo_Lat", "float64"),
        ("Actor1Geo_Long", "float64"),
        ("Actor1Geo_FeatureID", "string[pyarrow]"),
        ("Actor2Geo_Type", "Int64"),
        ("Actor2Geo_Fullname", "string[pyarrow]"),
        ("Actor2Geo_CountryCode", "string[pyarrow]"),
        ("Actor2Geo_ADM1Code", "string[pyarrow]"),
        ("Actor2Geo_Lat", "float64"),
        ("Actor2Geo_Long", "float64"),
        ("Actor2Geo_FeatureID", "string[pyarrow]"),
        ("ActionGeo_Type", "Int64"),
        ("ActionGeo_Fullname", "string[pyarrow]"),
        ("ActionGeo_CountryCode", "string[pyarrow]"),
        ("ActionGeo_ADM1Code", "string[pyarrow]"),
        ("ActionGeo_Lat", "float64"),
        ("ActionGeo_Long", "float64"),
        ("ActionGeo_FeatureID", "string[pyarrow]"),
        ("DATEADDED", "Int64"),
        ("SOURCEURL", "string[pyarrow]"),
    ]
)


@pytest.mark.client("from_csv_to_parquet")
def test_from_csv_to_parquet(client, s3_factory, s3_url):
    s3 = s3_factory(anon=True)
    files = s3.ls("s3://gdelt-open-data/events/")[:1000]
    files = [f"s3://{f}" for f in files]

    df = dd.read_csv(
        files,
        sep="\t",
        names=SCHEMA.keys(),
        # 'dtype' and 'converters' cannot overlap
        dtype={col: dtype for col, dtype in SCHEMA.items() if dtype != "float64"},
        storage_options=s3.storage_options,
        on_bad_lines="skip",
        # Some bad files have '#' in float values
        converters={
            col: lambda v: float(v.replace("#", "") or "NaN")
            for col, dtype in SCHEMA.items()
            if dtype == "float64"
        },
    )

    # Now we can safely convert the float columns
    df = df.astype({col: dtype for col, dtype in SCHEMA.items() if dtype == "float64"})

    df = df.map_partitions(
        lambda xdf: xdf.drop_duplicates(subset=["SOURCEURL"], keep="first")
    )
    df["national_paper"] = df.SOURCEURL.str.contains(
        "washingtonpost|nytimes", regex=True
    )
    df = df[df["national_paper"]]
    df.to_parquet(f"{s3_url}/from-csv-to-parquet/", write_index=False)
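One detail worth calling out from the file above: pandas does not allow `dtype` and `converters` to name the same columns, so the float columns are parsed through a converter that also tolerates dirty values, and only afterwards cast with `astype`. The standalone sketch below exercises that converter expression with made-up sample values to show why the later cast is safe.

    # Same converter expression used for the float64 columns above, applied to
    # hypothetical raw field values as they might appear in the GDELT CSVs.
    import math

    convert = lambda v: float(v.replace("#", "") or "NaN")

    assert convert("2.5") == 2.5
    assert convert("2.5#") == 2.5    # stray '#' characters are stripped before parsing
    assert math.isnan(convert(""))   # empty fields become NaN instead of raising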

0 commit comments
