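"""Workflow test: read GDELT event CSV files from S3, clean them, and write
the events that mention national newspapers back to S3 as Parquet."""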
from collections import OrderedDict

import dask.dataframe as dd
import pytest

SCHEMA = OrderedDict(
    [
        ("GlobalEventID", "Int64"),
        ("Day", "Int64"),
        ("MonthYear", "Int64"),
        ("Year", "Int64"),
        ("FractionDate", "float64"),
        ("Actor1Code", "string[pyarrow]"),
        ("Actor1Name", "string[pyarrow]"),
        ("Actor1CountryCode", "string[pyarrow]"),
        ("Actor1KnownGroupCode", "string[pyarrow]"),
        ("Actor1EthnicCode", "string[pyarrow]"),
        ("Actor1Religion1Code", "string[pyarrow]"),
        ("Actor1Religion2Code", "string[pyarrow]"),
        ("Actor1Type1Code", "string[pyarrow]"),
        ("Actor1Type2Code", "string[pyarrow]"),
        ("Actor1Type3Code", "string[pyarrow]"),
        ("Actor2Code", "string[pyarrow]"),
        ("Actor2Name", "string[pyarrow]"),
        ("Actor2CountryCode", "string[pyarrow]"),
        ("Actor2KnownGroupCode", "string[pyarrow]"),
        ("Actor2EthnicCode", "string[pyarrow]"),
        ("Actor2Religion1Code", "string[pyarrow]"),
        ("Actor2Religion2Code", "string[pyarrow]"),
        ("Actor2Type1Code", "string[pyarrow]"),
        ("Actor2Type2Code", "string[pyarrow]"),
        ("Actor2Type3Code", "string[pyarrow]"),
        ("IsRootEvent", "Int64"),
        ("EventCode", "string[pyarrow]"),
        ("EventBaseCode", "string[pyarrow]"),
        ("EventRootCode", "string[pyarrow]"),
        ("QuadClass", "Int64"),
        ("GoldsteinScale", "float64"),
        ("NumMentions", "Int64"),
        ("NumSources", "Int64"),
        ("NumArticles", "Int64"),
        ("AvgTone", "float64"),
        ("Actor1Geo_Type", "Int64"),
        ("Actor1Geo_Fullname", "string[pyarrow]"),
        ("Actor1Geo_CountryCode", "string[pyarrow]"),
        ("Actor1Geo_ADM1Code", "string[pyarrow]"),
        ("Actor1Geo_Lat", "float64"),
        ("Actor1Geo_Long", "float64"),
        ("Actor1Geo_FeatureID", "string[pyarrow]"),
        ("Actor2Geo_Type", "Int64"),
        ("Actor2Geo_Fullname", "string[pyarrow]"),
        ("Actor2Geo_CountryCode", "string[pyarrow]"),
        ("Actor2Geo_ADM1Code", "string[pyarrow]"),
        ("Actor2Geo_Lat", "float64"),
        ("Actor2Geo_Long", "float64"),
        ("Actor2Geo_FeatureID", "string[pyarrow]"),
        ("ActionGeo_Type", "Int64"),
        ("ActionGeo_Fullname", "string[pyarrow]"),
        ("ActionGeo_CountryCode", "string[pyarrow]"),
        ("ActionGeo_ADM1Code", "string[pyarrow]"),
        ("ActionGeo_Lat", "float64"),
        ("ActionGeo_Long", "float64"),
        ("ActionGeo_FeatureID", "string[pyarrow]"),
        ("DATEADDED", "Int64"),
        ("SOURCEURL", "string[pyarrow]"),
    ]
)
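# The schema drives the `dtype`/`converters` split in the test below: pandas
# forbids listing a column in both, so float columns go through `converters`
# and everything else through `dtype`. A quick interactive sanity check:
#
#     >>> len(SCHEMA)
#     58
#     >>> sum(dtype == "float64" for dtype in SCHEMA.values())
#     9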


@pytest.mark.client("from_csv_to_parquet")
def test_from_csv_to_parquet(client, s3_factory, s3_url):
    s3 = s3_factory(anon=True)
    files = s3.ls("s3://gdelt-open-data/events/")[:1000]
    files = [f"s3://{f}" for f in files]
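    # Note: s3fs.ls() returns paths without the scheme (e.g.
    # "gdelt-open-data/events/..."), so the "s3://" prefix is re-added above.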

    df = dd.read_csv(
        files,
        sep="\t",
        names=list(SCHEMA),
        # 'dtype' and 'converters' cannot overlap, so float columns are
        # excluded here and handled by 'converters' below
        dtype={col: dtype for col, dtype in SCHEMA.items() if dtype != "float64"},
        storage_options=s3.storage_options,
        on_bad_lines="skip",
        # Some bad files have '#' in float values; strip it and fall back
        # to NaN for empty fields
        converters={
            col: lambda v: float(v.replace("#", "") or "NaN")
            for col, dtype in SCHEMA.items()
            if dtype == "float64"
        },
    )
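    # For reference, the converter's behavior on a hypothetical malformed
    # value and on an empty field:
    #
    #     >>> float("5.2#".replace("#", "") or "NaN")
    #     5.2
    #     >>> float("".replace("#", "") or "NaN")
    #     nan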

    # The converters above return plain Python floats in object-typed columns,
    # so we can now safely cast the float columns to their target dtype
    df = df.astype({col: dtype for col, dtype in SCHEMA.items() if dtype == "float64"})

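    # Deduplicate URLs within each partition only; this avoids a global
    # shuffle, but duplicate SOURCEURLs that span partitions are kept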
    df = df.map_partitions(
        lambda xdf: xdf.drop_duplicates(subset=["SOURCEURL"], keep="first")
    )
    # Keep only events whose source URL points at a national paper
    df["national_paper"] = df.SOURCEURL.str.contains(
        "washingtonpost|nytimes", regex=True
    )
    df = df[df["national_paper"]]
    df.to_parquet(f"{s3_url}/from-csv-to-parquet/", write_index=False)
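    # Reading the result back for inspection is a one-liner (a sketch,
    # assuming the same S3 credentials are still available):
    #
    #     dd.read_parquet(f"{s3_url}/from-csv-to-parquet/")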