-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from dsi-clinic/networkx_record_linkage
update on function to add nodes and their attributes to graph
- Loading branch information
Showing
4 changed files
with
233 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# Output README | ||
--- | ||
'deduplicated_UUIDs.csv' : Following record linkage work in the record_linkage pipeline, this file stores all the original uuids, and indicates the uuids to which the deduplicated uuids have been matched to. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
import networkx as nx | ||
import pandas as pd | ||
import plotly.graph_objects as go | ||
|
||
|
||
def name_identifier(uuid: str, dfs: list[pd.DataFrame]) -> str: | ||
"""Returns the name of the entity given the entity's uuid | ||
Args: | ||
uuid: the uuid of the entity | ||
List of dfs: dataframes that have a uuid column, and an 'name' or | ||
'full_name' column | ||
Return: | ||
The entity's name | ||
""" | ||
for df in dfs: | ||
if "name" in df.columns: | ||
name_in_org = df.loc[df["id"] == uuid] | ||
if len(name_in_org) > 0: | ||
return name_in_org.iloc[0]["name"] | ||
|
||
if "full_name" in df.columns: | ||
name_in_ind = df.loc[df["id"] == uuid] | ||
if len(name_in_ind) > 0: | ||
return name_in_ind.iloc[0]["full_name"] | ||
return None | ||
|
||
|
||
def combine_datasets_for_network_graph(dfs: list[pd.DataFrame]) -> pd.DataFrame: | ||
"""Combines the 3 dataframes into a single dataframe to create a graph | ||
Given 3 dataframes, the func adds a 'recipient_name' column in the | ||
transactions df, merges the dfs together to record transaction info between | ||
entities, then concatenates the dfs into a final df of the merged | ||
transactions and entity dfs. | ||
Args: | ||
list of dataframes in the order: [inds_df, orgs_df, transactions_df] | ||
Transactions dataframe with column: 'recipient_id' | ||
Individuals dataframe with column: 'full_name' | ||
Organizations dataframe with column: 'name' | ||
Returns | ||
A merged dataframe with aggregate contribution amounts between entitites | ||
""" | ||
|
||
inds_df, orgs_df, transactions_df = dfs | ||
|
||
# first update the transactions df to have a recipient name tied to id | ||
transactions_df["recipient_name"] = transactions_df["recipient_id"].apply( | ||
name_identifier, args=([orgs_df, inds_df],) | ||
) | ||
|
||
# next, merge the inds_df and orgs_df ids with the transactions_df donor_id | ||
inds_trans_df = pd.merge( | ||
inds_df, transactions_df, how="left", left_on="id", right_on="donor_id" | ||
) | ||
inds_trans_df = inds_trans_df.dropna(subset=["amount"]) | ||
orgs_trans_df = pd.merge( | ||
orgs_df, transactions_df, how="left", left_on="id", right_on="donor_id" | ||
) | ||
orgs_trans_df = orgs_trans_df.dropna(subset=["amount"]) | ||
orgs_trans_df = orgs_trans_df.rename(columns={"name": "full_name"}) | ||
|
||
# concatenated the merged dfs | ||
merged_df = pd.concat([orgs_trans_df, inds_trans_df]) | ||
|
||
# lastly, create the final dataframe with aggregated attributes | ||
attribute_cols = merged_df.columns.difference( | ||
["donor_id", "recipient_id", "full_name", "recipient_name"] | ||
) | ||
agg_functions = { | ||
col: "sum" if col == "amount" else "first" for col in attribute_cols | ||
} | ||
aggreg_df = ( | ||
merged_df.groupby( | ||
["donor_id", "recipient_id", "full_name", "recipient_name"] | ||
) | ||
.agg(agg_functions) | ||
.reset_index() | ||
) | ||
aggreg_df = aggreg_df.drop(["id"], axis=1) | ||
return aggreg_df | ||
|
||
|
||
def create_network_graph(df: pd.DataFrame) -> nx.MultiDiGraph: | ||
"""Takes in a dataframe and generates a MultiDiGraph where the nodes are | ||
entity names, and the rest of the dataframe columns make the node attributes | ||
Args: | ||
df: a pandas dataframe with merged information from the inds, orgs, & | ||
transactions dataframes | ||
Returns: | ||
A Networkx MultiDiGraph with nodes and edges | ||
""" | ||
G = nx.MultiDiGraph() | ||
edge_columns = [ | ||
"office_sought", | ||
"purpose", | ||
"transaction_type", | ||
"year", | ||
"transaction_id", | ||
"donor_office", | ||
"amount", | ||
] | ||
|
||
for _, row in df.iterrows(): | ||
# add node attributes based on the columns relevant to the entity | ||
G.add_node( | ||
row["full_name"], | ||
**row[df.columns.difference(edge_columns)].dropna().to_dict(), | ||
) | ||
# add the recipient as a node | ||
G.add_node(row["recipient_name"], classification="neutral") | ||
|
||
# add the edge attributes between two nodes | ||
edge_attributes = row[edge_columns].dropna().to_dict() | ||
G.add_edge(row["full_name"], row["recipient_name"], **edge_attributes) | ||
|
||
return G | ||
|
||
|
||
def plot_network_graph(G: nx.MultiDiGraph): | ||
"""Given a networkX Graph, creates a plotly visualization of the nodes and | ||
edges | ||
Args: | ||
A networkX MultiDiGraph with edges including the attribute 'amount' | ||
Returns: None. Creates a plotly graph | ||
""" | ||
edge_trace = go.Scatter( | ||
x=(), | ||
y=(), | ||
line=dict(color="#888", width=1.5), | ||
hoverinfo="text", | ||
mode="lines+markers", | ||
) | ||
hovertext = [] | ||
pos = nx.spring_layout(G) | ||
|
||
for edge in G.edges(data=True): | ||
source = edge[0] | ||
target = edge[1] | ||
hovertext.append(f"Amount: {edge[2]['amount']:.2f}") | ||
# Adding coordinates of source and target nodes to edge_trace | ||
edge_trace["x"] += ( | ||
pos[source][0], | ||
pos[target][0], | ||
None, | ||
) # None creates a gap between line segments | ||
edge_trace["y"] += (pos[source][1], pos[target][1], None) | ||
|
||
edge_trace["hovertext"] = hovertext | ||
|
||
# Define arrow symbol for edges | ||
edge_trace["marker"] = dict( | ||
symbol="arrow", color="#888", size=10, angleref="previous" | ||
) | ||
|
||
node_trace = go.Scatter( | ||
x=[], | ||
y=[], | ||
text=[], | ||
mode="markers", | ||
hoverinfo="text", | ||
marker=dict(showscale=True, colorscale="YlGnBu", size=10), | ||
) | ||
node_trace["marker"]["color"] = [] | ||
|
||
for node in G.nodes(): | ||
node_info = f"Name: {node}<br>" | ||
for key, value in G.nodes[node].items(): | ||
node_info += f"{key}: {value}<br>" | ||
node_trace["text"] += tuple([node_info]) | ||
# Get the classification value for the node | ||
classification = G.nodes[node].get("classification", "neutral") | ||
# Assign a color based on the classification value | ||
if classification == "c": | ||
color = "blue" | ||
elif classification == "f": | ||
color = "red" | ||
else: | ||
color = "green" # Default color for unknown classification | ||
node_trace["marker"]["color"] += tuple([color]) | ||
|
||
# Add node positions to the trace | ||
node_trace["x"] += tuple([pos[node][0]]) | ||
node_trace["y"] += tuple([pos[node][1]]) | ||
|
||
# Define layout settings | ||
layout = go.Layout( | ||
title="Network Graph Indicating Campaign Contributions from 2018-2022", | ||
titlefont=dict(size=16), | ||
showlegend=True, | ||
hovermode="closest", | ||
margin=dict(b=20, l=5, r=5, t=40), | ||
xaxis=dict(showgrid=True, zeroline=True, showticklabels=False), | ||
yaxis=dict(showgrid=True, zeroline=True, showticklabels=False), | ||
) | ||
|
||
fig = go.Figure(data=[edge_trace, node_trace], layout=layout) | ||
fig.show() | ||
|
||
|
||
def construct_network_graph( | ||
start_year: int, end_year: int, dfs: list[pd.DataFrame] | ||
): | ||
"""Runs the network construction pipeline starting from 3 dataframes | ||
Args: | ||
start_year & end_year: the range of the desired data | ||
dfs: dataframes in the order: inds_df, orgs_df, transactions_df | ||
Returns: | ||
""" | ||
inds_df, orgs_df, transactions_df = dfs | ||
transactions_df = transactions_df.loc[ | ||
(transactions_df.year >= start_year) | ||
& (transactions_df.year <= end_year) | ||
] | ||
|
||
aggreg_df = combine_datasets_for_network_graph( | ||
[inds_df, orgs_df, transactions_df] | ||
) | ||
G = create_network_graph(aggreg_df) | ||
plot_network_graph(G) | ||
nx.write_adjlist(G, "Network Graph Node Data") |