Skip to content

Commit

Permalink
Fix: Feedback addressed by returning the plots in a dictionary and also…
Browse files Browse the repository at this point in the history
… adding functionality to save the plots. Done by Ismail (Husian)
  • Loading branch information
ismailbhinder committed Feb 2, 2025
1 parent b6396fe commit 88c4b89
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 2,000 deletions.
2,160 changes: 242 additions & 1,918 deletions docs/example.ipynb

Large diffs are not rendered by default.

105 changes: 68 additions & 37 deletions src/datpro/datpro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import altair as alt
from itertools import combinations
from typing import Dict, Union, Optional, List

def summarize_data(df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -117,7 +118,7 @@ def detect_anomalies(df: pd.DataFrame, anomaly_type: Optional[str] = None) -> Di

return report

def plotify(df: pd.DataFrame, plot_types: Optional[List[str]] = None) -> None:
def plotify(df: pd.DataFrame, plot_types: Optional[List[str]] = None, save: bool = False, save_path: str = "plots", file_prefix: str = "plot") -> Dict[str, alt.Chart]:
"""
Visualize a DataFrame by generating specified plots based on column datatypes.
Expand All @@ -136,11 +137,20 @@ def plotify(df: pd.DataFrame, plot_types: Optional[List[str]] = None) -> None:
- 'box' : Plot box plots for numeric vs categorical columns.
- 'stacked_bar' : Plot stacked bar charts for pairwise categorical columns.
If None, all plot types are generated by default.
save : bool, optional
If True, saves the plots to the specified path. Default is False.
save_path : str, optional
The directory where plots should be saved. Default is 'plots'.
file_prefix : str, optional
The prefix for saved plot filenames. Default is 'plot'.
Returns
-------
None
Generates and displays specified plots based on the provided column types.
dict
A dictionary where keys are plot names and values are Altair Chart objects.
Raises
------
Expand All @@ -153,89 +163,110 @@ def plotify(df: pd.DataFrame, plot_types: Optional[List[str]] = None) -> None:
-----
- Numeric columns are those of types 'int64', 'float64'.
- Categorical columns are those of types 'object', 'category', and 'bool'.
Examples
--------
>>> import pandas as pd
>>> df = pd.DataFrame({'A': [1, 2, 3, 4], 'B': ['x', 'y', 'x', 'y']})
>>> charts = plotify(df, plot_types=['histogram', 'bar'])
>>> charts['histogram_A'].show()
"""
import os

# Validate input
if not isinstance(df, pd.DataFrame):
raise TypeError("Input must be a pandas DataFrame.")
if df.empty:
raise ValueError("Input DataFrame is empty.")

if save and not os.path.exists(save_path):
os.makedirs(save_path)

# Set default plot types if not specified
if plot_types is None:
plot_types = ['histogram', 'density', 'bar', 'scatter', 'correlation', 'box', 'stacked_bar']

# Analyze columns
numeric_cols = df.select_dtypes(include='number').columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category', 'bool']).columns.tolist()

plots = {}

# Individual column visualizations
if 'histogram' in plot_types or 'density' in plot_types:
for col in numeric_cols:
print(f"Visualizing numeric column: {col}")
if 'histogram' in plot_types:
hist_chart = alt.Chart(df).mark_bar().encode(
x=alt.X(col, bin=True, title=f"{col} (binned)"),
y=alt.Y('count()', title='Count')
).properties(title=f"Histogram of {col}")
hist_chart.display()
plots[f'histogram_{col}'] = hist_chart
if save:
hist_chart.save(f"{save_path}/{file_prefix}_histogram_{col}.html")
if 'density' in plot_types:
density_chart = alt.Chart(df).transform_density(
col, as_=[col, 'density']
).mark_area(opacity=0.5).encode(
x=alt.X(col, title=col),
y=alt.Y('density:Q', title='Density')
).properties(title=f"Density Plot of {col}")
density_chart.display()

plots[f'density_{col}'] = density_chart
if save:
density_chart.save(f"{save_path}/{file_prefix}_density_{col}.html")

if 'bar' in plot_types:
for col in categorical_cols:
print(f"Visualizing categorical column: {col}")
bar_chart = alt.Chart(df).mark_bar().encode(
x=alt.X(col, title=col),
y=alt.Y('count()', title='Count')
).properties(title=f"Bar Chart of {col}")
bar_chart.display()

# Pairwise relationships
plots[f'bar_{col}'] = bar_chart
if save:
bar_chart.save(f"{save_path}/{file_prefix}_bar_{col}.html")

if 'scatter' in plot_types:
for col1, col2 in combinations(numeric_cols, 2):
print(f"Visualizing numeric vs numeric: {col1} vs {col2}")
scatter_chart = alt.Chart(df).mark_circle(size=60).encode(
x=alt.X(col1, title=col1),
y=alt.Y(col2, title=col2),
tooltip=[col1, col2]
).properties(title=f"Scatter Plot: {col1} vs {col2}")
scatter_chart.display()

plots[f'scatter_{col1}_{col2}'] = scatter_chart
if save:
scatter_chart.save(f"{save_path}/{file_prefix}_scatter_{col1}_{col2}.html")

if 'correlation' in plot_types and len(numeric_cols) > 1:
print("Visualizing correlation heatmap")
corr_matrix = df[numeric_cols].corr().stack().reset_index()
corr_matrix.columns = ['Variable 1', 'Variable 2', 'Correlation']
heatmap = alt.Chart(corr_matrix).mark_rect().encode(
x=alt.X('Variable 1:N'),
y=alt.Y('Variable 2:N'),
color=alt.Color('Correlation:Q', scale=alt.Scale(scheme='viridis')),
tooltip=['Variable 1', 'Variable 2', 'Correlation']
color=alt.Color('Correlation:Q', scale=alt.Scale(scheme='viridis'))
).properties(title='Correlation Heatmap')
heatmap.display()

plots['correlation_heatmap'] = heatmap
if save:
heatmap.save(f"{save_path}/{file_prefix}_correlation_heatmap.html")

if 'box' in plot_types:
for num_col in numeric_cols:
for cat_col in categorical_cols:
print(f"Visualizing numeric vs categorical: {num_col} vs {cat_col}")
for numeric_col in numeric_cols:
for categorical_col in categorical_cols:
box_plot = alt.Chart(df).mark_boxplot().encode(
x=alt.X(cat_col, title=cat_col),
y=alt.Y(num_col, title=num_col),
color=alt.Color(cat_col, legend=None)
).properties(title=f"Box Plot: {num_col} vs {cat_col}")
box_plot.display()

x=alt.X(categorical_col, title=categorical_col),
y=alt.Y(numeric_col, title=numeric_col)
).properties(title=f"Box Plot of {numeric_col} by {categorical_col}")
plots[f'box_{numeric_col}_{categorical_col}'] = box_plot
if save:
box_plot.save(f"{save_path}/{file_prefix}_box_{numeric_col}_{categorical_col}.html")

if 'stacked_bar' in plot_types:
for cat_col1, cat_col2 in combinations(categorical_cols, 2):
print(f"Visualizing categorical vs categorical: {cat_col1} vs {cat_col2}")
stacked_bar = alt.Chart(df).mark_bar().encode(
x=alt.X(cat_col1, title=cat_col1),
for col1, col2 in combinations(categorical_cols, 2):
stacked_bar_chart = alt.Chart(df).mark_bar().encode(
x=alt.X(col1, title=col1),
y=alt.Y('count()', title='Count'),
color=alt.Color(cat_col2, title=cat_col2)
).properties(title=f"Stacked Bar Chart: {cat_col1} vs {cat_col2}")
stacked_bar.display()
color=alt.Color(col2, title=col2)
).properties(title=f"Stacked Bar Chart of {col1} vs {col2}")
plots[f'stacked_bar_{col1}_{col2}'] = stacked_bar_chart
if save:
stacked_bar_chart.save(f"{save_path}/{file_prefix}_stacked_bar_{col1}_{col2}.html")

return plots
98 changes: 53 additions & 45 deletions tests/test_plotify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,35 @@
import pytest
import pandas as pd

@pytest.fixture
def valid_df():
return pd.DataFrame({
'age': [25, 30, 35, 40, 45],
'income': [50000, 60000, 70000, 80000, 90000],
'gender': ['M', 'F', 'M', 'F', 'M'],
'city': ['New York', 'Los Angeles', 'Chicago', 'Houston', 'Phoenix']
})

@pytest.fixture
def empty_df():
return pd.DataFrame()

@pytest.fixture
def numeric_df():
return pd.DataFrame({
'A': [1, 2, 3, 4, 5],
'B': [10, 20, 30, 40, 50]
})

@pytest.fixture
def valid_df():
"""
Fixture to create a sample DataFrame with both numeric and categorical columns.
Returns
-------
pandas.DataFrame
Sample DataFrame with numeric and categorical columns for testing.
"""
data = {
'age': [25, 30, 35, 40, 45],
Expand All @@ -17,95 +42,78 @@ def valid_df():

def test_plotify_valid_df(valid_df):
"""
Test to verify that the plotify function generates all plots when no specific plot types are provided.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing numeric and categorical columns.
Test that plotify generates all plots when no specific plot types are provided.
"""
plotify(valid_df, plot_types=None) # Test with all plot types
result = plotify(valid_df, plot_types=None)
assert isinstance(result, dict)
assert len(result) > 0

def test_plotify_empty_df():
"""
Test to verify that the plotify function raises a ValueError when an empty DataFrame is passed.
This tests the validation logic for empty DataFrames.
Test that plotify raises a ValueError when an empty DataFrame is provided.
"""
empty_df = pd.DataFrame()
with pytest.raises(ValueError):
plotify(empty_df)

def test_plotify_invalid_input():
"""
Test to verify that the plotify function raises a TypeError when the input is not a pandas DataFrame.
This tests the input type validation logic.
Test that plotify raises a TypeError when input is not a pandas DataFrame.
"""
invalid_input = [1, 2, 3, 4]
with pytest.raises(TypeError):
plotify(invalid_input)

def test_plotify_specific_plots(valid_df):
"""
Test to verify that the plotify function generates only specified plot types.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing numeric and categorical columns.
Test that plotify generates only the specified plot types.
"""
plotify(valid_df, plot_types=['histogram', 'scatter']) # Test with specific plot types
result = plotify(valid_df, plot_types=['histogram', 'scatter'])
assert all(plot in result for plot in ['histogram_age', 'scatter_age_income'])

def test_plotify_scatter_plot(valid_df):
"""
Test to verify that the plotify function generates scatter plots for numeric columns.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing numeric columns for pairwise scatter plotting.
Test that plotify generates scatter plots for numeric columns.
"""
plotify(valid_df, plot_types=['scatter'])
result = plotify(valid_df, plot_types=['scatter'])
assert 'scatter_age_income' in result

def test_plotify_correlation_heatmap(valid_df):
"""
Test to verify that the plotify function generates a correlation heatmap when there are multiple numeric columns.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing multiple numeric columns.
Test that plotify generates a correlation heatmap for numeric columns.
"""
plotify(valid_df, plot_types=['correlation'])
result = plotify(valid_df, plot_types=['correlation'])
assert 'correlation_heatmap' in result

def test_plotify_box_plot(valid_df):
"""
Test to verify that the plotify function generates box plots for numeric vs categorical columns.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing both numeric and categorical columns.
Test that plotify generates box plots for numeric vs categorical columns.
"""
plotify(valid_df, plot_types=['box'])
result = plotify(valid_df, plot_types=['box'])
assert 'box_income_gender' in result

def test_plotify_stacked_bar(valid_df):
"""
Test to verify that the plotify function generates stacked bar charts for pairwise categorical columns.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing categorical columns for pairwise stacked bar plotting.
Test that plotify generates stacked bar charts for categorical columns.
"""
plotify(valid_df, plot_types=['stacked_bar'])
result = plotify(valid_df, plot_types=['stacked_bar'])
assert 'stacked_bar_gender_city' in result

def test_plotify_empty_plot_types(valid_df):
"""
Test to verify that the plotify function generates all plots when the plot_types argument is an empty list.
Args:
valid_df (pd.DataFrame): Sample DataFrame containing both numeric and categorical columns.
Test that plotify generates no plots when plot_types is an empty list.
"""
plotify(valid_df, plot_types=[]) # Expect all plots to be generated as the list is empty
result = plotify(valid_df, plot_types=[])
assert isinstance(result, dict)
assert len(result) == 0

def test_plotify_missing_columns():
"""
Test to verify that the plotify function handles cases where only numeric columns are present.
This tests the scenario where only numeric columns are available for generating plots.
Test that plotify handles cases where only numeric columns are present.
"""
df = pd.DataFrame({
'age': [25, 30, 35, 40, 45],
'income': [50000, 60000, 70000, 80000, 90000]
})
plotify(df, plot_types=['scatter', 'box']) # Test scatter plot and box plot
result = plotify(df, plot_types=['scatter'])
assert 'scatter_age_income' in result

0 comments on commit 88c4b89

Please sign in to comment.