Issue
I want to find the values of X that divide the bimodal distribution into 3 groups. For example, in my code below, based on the bimodal plot, approximate groups would be x less than 5, x between 5 and 90, and x greater than 90.
But I am not getting these values. Here is my code:
import numpy as np
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks
# Question's original attempt, kept verbatim: split a two-mode sample into
# groups using KDE peaks and the valley between them. Review notes below mark
# why it does not produce the three groups the asker expects.
# Generate bimodal data
dist1 = np.random.normal(loc=0, scale=1, size=100)
dist2 = np.random.normal(loc=90, scale=1, size=100)
bimodal = np.concatenate((dist1, dist2))
# Fit KDE
kde = gaussian_kde(bimodal)
# NOTE(review): np.linspace defaults to num=50, so this grid is coarse
# (about 2 units per step) and it starts exactly at 0, the centre of dist1.
xgrid = np.linspace(0,100)
pdf = kde.evaluate(xgrid)
# NOTE(review): plt is used from here on but this snippet never imports it
# (presumably matplotlib.pyplot was intended).
plt.hist(bimodal, bins=100, density=True)
plt.title("Bimodal distribution")
# Find peaks
peaks, _ = find_peaks(pdf)
# Unpacking requires find_peaks to return exactly two indices; any other
# count raises ValueError here.
peak1, peak2 = xgrid[peaks]
# Find valley
# Smallest KDE value strictly between the two peaks, then the grid position(s)
# where it occurs. Boolean-mask indexing makes valley an array, not a scalar.
pdf_min = np.min(pdf[(xgrid > peak1) & (xgrid < peak2)])
valley = xgrid[(pdf == pdf_min) & (xgrid > peak1) & (xgrid < peak2)]
# Create group labels
groups = np.ones(len(bimodal), dtype=int)
groups[bimodal < valley] = 0
# NOTE(review): valley lies above peak1 by construction, so the condition
# below (>= valley and <= peak1) can never be true and label 1 is never set here.
groups[(bimodal >= valley) & (bimodal <= peak1)] = 1
# NOTE(review): this overwrites every value above peak1 with 2, including
# values that were just labelled 0, so only labels 0 and 2 remain.
groups[bimodal > peak1] = 2
# Plot
plt.hist(bimodal)
plt.vlines([valley, peak1], 0, 100, colors='r')
plt.title("Bimodal distribution clustered into 3 groups")
plt.show()
print(groups)
Solution
This is basically a Gaussian mixture model. There are complicated ways to fit this; I show a pretty simple one. If your data resemble the example parameters you wrote then it will work well.
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
import scipy.stats
def generate_opaque(rand: np.random.Generator) -> np.ndarray:
    """Return a synthetic bimodal sample of 600 unsorted values.

    The sample is the concatenation of two normal draws taken from *rand*
    in this order: 450 values with loc=5, scale=2, followed by 150 values
    with loc=45, scale=10.

    Args:
        rand: NumPy random generator used for both draws; its state is
            advanced by the call.

    Returns:
        One-dimensional array of shape (600,), left-mode values first.
    """
    # The data from the question can be reproduced instead with
    # rand.normal(loc=(0, 90), scale=1, size=(100, 2)).ravel().
    return np.concatenate((
        rand.normal(loc=5, scale=2, size=450),
        rand.normal(loc=45, scale=10, size=150),
    ))
def bigaussian(
    x: np.ndarray,
    loc_a: float, scale_a: float,
    loc_b: float, scale_b: float,
    bal: float,
) -> np.ndarray:
    """Return the CDF of a two-component Gaussian mixture evaluated at *x*.

    The mixture is ``bal * N(loc_a, scale_a) + (1 - bal) * N(loc_b, scale_b)``.
    The flat parameter list after *x* is the signature shape that
    ``scipy.optimize.curve_fit`` expects from its model function.

    Args:
        x: Points at which to evaluate the cumulative distribution.
        loc_a: Mean of the first component.
        scale_a: Standard deviation of the first component.
        loc_b: Mean of the second component.
        scale_b: Standard deviation of the second component.
        bal: Weight of the first component; the second gets ``1 - bal``.

    Returns:
        Mixture CDF values, one per element of *x*.
    """
    norm_a = scipy.stats.norm(loc_a, scale_a)
    norm_b = scipy.stats.norm(loc_b, scale_b)
    return bal*norm_a.cdf(x) + (1 - bal)*norm_b.cdf(x)
def main() -> None:
    """Fit a two-Gaussian mixture to a generated sample, then report and plot it.

    Steps: generate and sort the sample, derive rough per-mode estimates by
    splitting at the overall mean, refine all five mixture parameters with a
    least-squares fit of the mixture CDF to the empirical CDF, print both
    sets of estimates, and show a histogram with the fitted density and the
    chosen separation point.
    """
    rand = np.random.default_rng(seed=0)
    data = np.sort(generate_opaque(rand))

    # Empirical CDF: this is the reference against which fitting is compared.
    # scipy reports one probability per distinct sample value, so this lines
    # up element-for-element with `data` only when the sample has no ties.
    ecdf = scipy.stats.ecdf(data).cdf.probabilities

    # Mean of whole dataset. So long as modes are balanced, this will
    # be a roughly sensible estimate of midpoint between the modes.
    mean_est = data.mean()

    # Rough left-side-dominated and right-side-dominated random variable section
    lhs = data[data < mean_est]
    rhs = data[data > mean_est]

    # Rough normal fits for both modes. These are only sensible estimates
    # if inter-mode distance is high and scales are low.
    loc_est_a, scale_est_a = scipy.stats.norm.fit(lhs)
    loc_est_b, scale_est_b = scipy.stats.norm.fit(rhs)

    # Estimated proportion of left normal to entire random variable
    bal_est = lhs.size / data.size

    # Refine all five parameters at once. Means are confined to the data
    # range, scales to non-negative values and the balance to [0, 1].
    popt, _ = scipy.optimize.curve_fit(
        f=bigaussian,
        xdata=data, ydata=ecdf,
        p0=(loc_est_a, scale_est_a, loc_est_b, scale_est_b, bal_est),
        bounds=(
            (data.min(), 0, data.min(), 0, 0),
            (data.max(), np.inf, data.max(), np.inf, 1),
        ),
    )
    (loc_a, scale_a, loc_b, scale_b, bal) = popt
    ecdf_fit = bigaussian(data, *popt)

    # One workable definition of separation: the x at which the fitted CDF
    # reaches `bal`, i.e. where the cumulative mass equals the share
    # attributed to the left mode. Take the midpoint of the two samples
    # that bracket that crossing.
    # Clip so both neighbours exist: index 0 would wrap around to the last
    # sample and index len(data) would raise IndexError.
    i_mid = int(np.clip(ecdf_fit.searchsorted(v=bal), 1, data.size - 1))
    x_mid = (data[i_mid] + data[i_mid-1])/2

    print(f'lhs est x ~ {loc_est_a:5.2f} ±{scale_est_a:.2f}, {bal_est:.1%}')
    print(f'lhs refined x ~ {loc_a:5.2f} ±{scale_a:.2f}, {bal:.1%}')
    print(f'rhs est x ~ {loc_est_b:5.2f} ±{scale_est_b:.2f}, {1-bal_est:.1%}')
    print(f'rhs refined x ~ {loc_b:5.2f} ±{scale_b:.2f}, {1-bal:.1%}')
    print(f'Sep est {mean_est:.2f}')
    print(f'Sep refined {x_mid:.2f}')

    # Fitted mixture density on a fine grid, for overlay on the histogram.
    x_hires = np.linspace(start=data.min(), stop=data.max(), num=1000)
    fit_pdf = (
        scipy.stats.norm(loc_a, scale_a).pdf(x_hires)*bal +
        scipy.stats.norm(loc_b, scale_b).pdf(x_hires)*(1-bal)
    )

    ax: plt.Axes
    fig, ax = plt.subplots()
    ax.hist(data, bins=40, density=True, label='empirical')
    ax.plot(x_hires, fit_pdf, label='fit')
    ax.axvline(x=x_mid, linestyle='--', color='black', label='midpoint')
    ax.legend()
    plt.show()
# Run the demonstration only when executed as a script, not on import.
if __name__ == '__main__':
    main()
lhs est x ~ 4.98 ±2.01, 75.2%
lhs refined x ~ 4.95 ±2.03, 74.9%
rhs est x ~ 44.67 ±9.16, 24.8%
rhs refined x ~ 44.46 ±9.47, 25.1%
Sep est 14.84
Sep refined 15.54
Answered By - Reinderien
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.