Skip to content

Commit

Permalink
Merge pull request #73 from visdesignlab/quadratic-trend-fit
Browse files Browse the repository at this point in the history
Add polynomial fit and residual comparison across types
  • Loading branch information
JakeWags authored Sep 11, 2024
2 parents e65bb91 + 5542540 commit 90e749b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
4 changes: 3 additions & 1 deletion src/alttxt/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,6 @@ class IntersectionTrend(Listable):
to the MultiNet implementation's export format.
"""
DRASTIC = "drastically"
MODERATE = "moderately"
RAPID = "rapidly"
QUICK = "quickly"
STEADY = "steadily"
75 changes: 43 additions & 32 deletions src/alttxt/tokenmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,50 +748,61 @@ def calculate_set_divergence(self):
return IndividualSetSize.DIVERGINGABIT.value
else:
return IndividualSetSize.IDENTICAL.value


def calculate_change_trend(self):
    """
    Classify how intersection sizes change across the subsets.

    The sizes are read from ``self.data.subsets`` in their current order and
    indexed 0..n-1. Three models are fitted to them:

    - linear:      ``slope * x + intercept``
    - quadratic:   second-degree polynomial
    - exponential: ``a * exp(-beta * x) + c`` with ``a``, ``beta``, ``c``
      constrained to be non-negative

    The model with the strictly smallest sum of squared residuals decides
    the classification. Ties, and every case where neither the exponential
    nor the quadratic model is strictly best, fall through to the linear
    classification.

    Returns:
        str: The ``.value`` of one of the following IntersectionTrend members:
            - IntersectionTrend.DRASTIC: the exponential fit is best and
              its decay rate (beta) is greater than 0.8.
            - IntersectionTrend.RAPID: the exponential fit is best and
              its decay rate (beta) is less than or equal to 0.8.
            - IntersectionTrend.QUICK: the quadratic fit is best.
            - IntersectionTrend.STEADY: otherwise (including when there
              are fewer than four data points).
    """
    sizes = np.array([subset.size for subset in self.data.subsets], dtype=float)
    positions = np.arange(len(sizes))

    # The quadratic and exponential models each have three free parameters,
    # so with three or fewer points they can reproduce the data exactly and
    # a residual comparison says nothing about the trend. (np.polyfit also
    # reports an empty residual array in that situation.)
    if len(sizes) < 4:
        return IntersectionTrend.STEADY.value

    # Linear fit.
    linear = stats.linregress(positions, sizes)
    linear_fit = linear.slope * positions + linear.intercept
    linear_residuals = np.sum((sizes - linear_fit) ** 2)

    # Quadratic fit; residuals are computed from the fitted curve so the
    # value is always defined, regardless of what polyfit reports.
    quadratic_coeffs = np.polyfit(positions, sizes, 2)
    quadratic_fit = np.polyval(quadratic_coeffs, positions)
    quadratic_residuals = np.sum((sizes - quadratic_fit) ** 2)

    # Exponential decay fit. A fit that fails to converge (or rejects the
    # starting point) simply takes the exponential model out of the running.
    beta = 0.0
    exponential_residuals = np.inf
    try:
        popt, _ = curve_fit(
            lambda x, a, beta, c: a * np.exp(-beta * x) + c,
            positions,
            sizes,
            p0=[sizes.max(), 0.1, sizes.min()],
            bounds=([0, 0, 0], [np.inf, np.inf, np.inf]),
            maxfev=5000,
        )
    except (RuntimeError, ValueError):
        popt = None

    if popt is not None:
        a, beta, c = popt
        # Only a genuinely decaying curve counts as an exponential trend.
        if beta > 0 and a > 0:
            # Evaluate the model at the actual sample positions so all
            # three residuals are measured the same way.
            exponential_fit = a * np.exp(-beta * positions) + c
            exponential_residuals = np.sum((sizes - exponential_fit) ** 2)

    exponential_is_best = (
        exponential_residuals < linear_residuals
        and exponential_residuals < quadratic_residuals
    )
    quadratic_is_best = (
        quadratic_residuals < linear_residuals
        and quadratic_residuals < exponential_residuals
    )

    if exponential_is_best:
        if beta > 0.8:
            return IntersectionTrend.DRASTIC.value
        return IntersectionTrend.RAPID.value
    if quadratic_is_best:
        return IntersectionTrend.QUICK.value
    return IntersectionTrend.STEADY.value

def calculate_largest_factor(self):
sorted_sizes = self.sort_subsets_by_key(SubsetField.SIZE, True)
if len(sorted_sizes) >= 2:
Expand Down

0 comments on commit 90e749b

Please sign in to comment.