Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ matplotlib
scipy>=0.18
pandas
click
joblib
154 changes: 80 additions & 74 deletions src/fitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@
.. sectionauthor:: Thomas Cokelaer, Aug 2014-2020

"""
import logging
import sys
import threading
from datetime import datetime
import logging

import scipy.stats
import numpy as np
import pylab
import pandas as pd
import pylab
import scipy.stats
from easydev import Progress
from joblib import Parallel, delayed
from scipy.stats import entropy as kl_div


logging.getLogger(__name__)
logger = logging.getLogger(__name__)

__all__ = ['get_common_distributions', 'get_distributions', 'Fitter']

Expand All @@ -50,7 +51,7 @@ def get_common_distributions():
distributions = get_distributions()
# to avoid error due to changes in scipy
common = ['cauchy', 'chi2', 'expon', 'exponpow', 'gamma',
'lognorm', 'norm', 'powerlaw', 'rayleigh', 'uniform']
'lognorm', 'norm', 'powerlaw', 'rayleigh', 'uniform']
common = [x for x in common if x in distributions]
return common

Expand Down Expand Up @@ -112,7 +113,7 @@ class Fitter(object):
"""

def __init__(self, data, xmin=None, xmax=None, bins=100,
distributions=None, timeout=30,
distributions=None, timeout=30,
density=True):
""".. rubric:: Constructor

Expand Down Expand Up @@ -181,13 +182,15 @@ def _init(self):
self._aic = {}
self._bic = {}
self._kldiv = {}
self._fit_i = 0 # fit progress
self.pb = Progress(len(self.distributions))

def _update_data_pdf(self):
# np.histogram returns N+1 bin edges for N bins. Replace the edges with the
# N bin centres (midpoint of each consecutive pair of edges) so that self.x
# ends up with the same length as self.y.
self.y, self.x = np.histogram(
self._data, bins=self.bins, density=self._density)
self.x = [(this + self.x[i + 1]) / 2. for i,
this in enumerate(self.x[0:-1])]
this in enumerate(self.x[0:-1])]

def _trim_data(self):
self._data = self._alldata[np.logical_and(
Expand All @@ -204,6 +207,7 @@ def _set_xmin(self, value):
self._xmin = value
self._trim_data()
self._update_data_pdf()

xmin = property(_get_xmin, _set_xmin,
doc="consider only data above xmin. reset if None")

Expand All @@ -218,6 +222,7 @@ def _set_xmax(self, value):
self._xmax = value
self._trim_data()
self._update_data_pdf()

xmax = property(_get_xmax, _set_xmax,
doc="consider only data below xmax. reset if None ")

Expand All @@ -241,7 +246,62 @@ def hist(self):
_ = pylab.hist(self._data, bins=self.bins, density=self._density)
pylab.grid(True)

def fit(self, amp=1, progress=False):
def _fit_single_distribution(self, distribution, progress: bool):
    """Fit one :mod:`scipy.stats` distribution and record its scores.

    On success the fitted parameters and the fitted PDF (evaluated on
    ``self.x``) are stored in ``self.fitted_param`` and ``self.fitted_pdf``,
    and the sum-of-squares error, AIC, BIC and Kullback-Leibler divergence
    are stored in ``self._fitted_errors``, ``self._aic``, ``self._bic`` and
    ``self._kldiv``.

    If any step raises, the exception is not propagated: the four scores
    for this distribution are set to ``np.inf`` and a warning is logged.

    :param str distribution: attribute name of the distribution in
        :mod:`scipy.stats` (e.g. ``"gamma"``)
    :param bool progress: if True, advance the progress bar ``self.pb``
    """
    try:
        # Plain attribute lookup; unlike eval() a malformed name cannot
        # execute arbitrary code, it just raises and is handled below.
        dist = getattr(scipy.stats, distribution)

        # dist.fit may take a long time or hang for some distributions, so
        # it is executed through _timed_run rather than called directly.
        param = self._timed_run(dist.fit, distribution, args=self._data)

        # Relies on fit() returning the parameters in the order pdf() expects.
        pdf_fitted = dist.pdf(self.x, *param)

        self.fitted_param[distribution] = param[:]
        self.fitted_pdf[distribution] = pdf_fitted

        # Sum of squared residuals between the fitted PDF and the histogram.
        sq_error = np.sum((pdf_fitted - self.y) ** 2)

        # Information criteria.
        logLik = np.sum(dist.logpdf(self.x, *param))
        k = len(param)
        n = len(self._data)
        aic = 2 * k - 2 * logLik
        bic = n * np.log(sq_error / n) + k * np.log(n)

        # Kullback-Leibler divergence between fitted PDF and histogram.
        kullback_leibler = kl_div(pdf_fitted, self.y)

        logger.info("Fitted %s distribution with error=%s",
                    distribution, sq_error)

        self._fitted_errors[distribution] = sq_error
        self._aic[distribution] = aic
        self._bic[distribution] = bic
        self._kldiv[distribution] = kullback_leibler
    except Exception:  # pragma: no cover
        logger.warning("SKIPPED %s distribution (taking more than %s seconds)",
                       distribution, self.timeout)
        # The handler catches every failure, not only timeouts; keep the
        # real cause available for debugging.
        logger.debug("Reason %s distribution was skipped",
                     distribution, exc_info=True)
        # If we cannot compute the scores, set them to large values.
        self._fitted_errors[distribution] = np.inf
        self._aic[distribution] = np.inf
        self._bic[distribution] = np.inf
        self._kldiv[distribution] = np.inf
    if progress:
        # NOTE(review): this counter is shared by all workers and is not
        # protected by a lock; it only drives the progress display.
        self._fit_i += 1
        self.pb.animate(self._fit_i)

def fit(self, amp=1, progress=False, n_jobs=-1):
r"""Loop over distributions and find best parameter to fit the data for each

When a distribution is fitted onto the data, we populate a set of
Expand All @@ -258,64 +318,9 @@ def fit(self, amp=1, progress=False):
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

from easydev import Progress
N = len(self.distributions)
pb = Progress(N)
for i, distribution in enumerate(self.distributions):
try:
# need a subprocess to check time it takes. If too long, skip it
dist = eval("scipy.stats." + distribution)

# TODO here, dist.fit may take a while or just hang forever
# with some distributions. So, I thought to use signal module
# to catch the error when signal takes too long. It did not work
# presumably because another try/exception is inside the
# fit function, so I used threading with a recipe from stackoverflow
# See timed_run function above
param = self._timed_run(
dist.fit, distribution, args=self._data)

# with signal, does not work. maybe because another expection is caught
# hoping the order returned by fit is the same as in pdf
pdf_fitted = dist.pdf(self.x, *param)

self.fitted_param[distribution] = param[:]
self.fitted_pdf[distribution] = pdf_fitted

# calculate error
sq_error = pylab.sum(
(self.fitted_pdf[distribution] - self.y)**2)

# calcualte information criteria
logLik = np.sum(dist.logpdf(self.x, *param))
k = len(param[:])
n = len(self._data)
aic = 2 * k - 2 * logLik
bic = n * np.log(sq_error / n) + k * np.log(n)

# calcualte kullback leibler divergence
kullback_leibler = kl_div(
self.fitted_pdf[distribution], self.y)

logging.info("Fitted {} distribution with error={})".format(
distribution, sq_error))

# compute some errors now
self._fitted_errors[distribution] = sq_error
self._aic[distribution] = aic
self._bic[distribution] = bic
self._kldiv[distribution] = kullback_leibler
except Exception as err: #pragma: no cover
logging.warning("SKIPPED {} distribution (taking more than {} seconds)".format(distribution,
self.timeout))
# if we cannot compute the error, set it to large values
self._fitted_errors[distribution] = np.inf
self._aic[distribution] = np.inf
self._bic[distribution] = np.inf
self._kldiv[distribution] = np.inf
if progress:
pb.animate(i+1)

jobs = (delayed(self._fit_single_distribution)(dist, progress) for dist in self.distributions)
pool = Parallel(n_jobs=n_jobs, backend='threading')
_ = pool(jobs)
self.df_errors = pd.DataFrame({'sumsquare_error': self._fitted_errors,
'aic': self._aic,
'bic': self._bic,
Expand Down Expand Up @@ -350,7 +355,7 @@ def plot_pdf(self, names=None, Nbest=5, lw=2, method="sumsquare_error"):
if name in self.fitted_pdf.keys():
pylab.plot(
self.x, self.fitted_pdf[name], lw=lw, label=name)
else: #pragma: no cover
else: # pragma: no cover
logger.warning("%s was not fitted. no parameters available" % name)
pylab.grid(True)
pylab.legend()
Expand Down Expand Up @@ -380,17 +385,18 @@ def summary(self, Nbest=5, lw=2, plot=True, method="sumsquare_error"):
try:
names = self.df_errors.sort_values(
by=method).index[0:Nbest]
except: #pragma: no cover
except: # pragma: no cover
names = self.df_errors.sort(method).index[0:Nbest]
return self.df_errors.loc[names]

def _timed_run(self, func, distribution, args=(), kwargs={}, default=None):
def _timed_run(self, func, distribution, args=(), kwargs={}, default=None):
"""This function will spawn a thread and run the given function
using the args, kwargs and return the given default value if the
timeout is exceeded.

http://stackoverflow.com/questions/492519/timeout-on-a-python-function-call
"""

class InterruptableThread(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
Expand All @@ -400,10 +406,10 @@ def __init__(self):
def run(self):
try:
self.result = func(args, **kwargs)
except Exception as err: #pragma: no cover
except Exception as err: # pragma: no cover
self.exc_info = sys.exc_info()

def suicide(self): # pragma: no cover
def suicide(self): # pragma: no cover
raise RuntimeError('Stop has been called')

it = InterruptableThread()
Expand All @@ -413,11 +419,11 @@ def suicide(self): # pragma: no cover
ended_at = datetime.now()
diff = ended_at - started_at

if it.exc_info[0] is not None: #pragma: no cover ; if there were any exceptions
if it.exc_info[0] is not None: # pragma: no cover ; if there were any exceptions
a, b, c = it.exc_info
raise Exception(a, b, c) # communicate that to caller

if it.isAlive(): #pragma: no cover
if it.isAlive(): # pragma: no cover
it.suicide()
raise RuntimeError
else:
Expand Down
35 changes: 15 additions & 20 deletions test/test_fitter.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from fitter import Fitter, get_distributions, get_common_distributions
from fitter import Fitter, get_common_distributions, get_distributions


def test_dist():
assert 'gamma' in get_common_distributions()
assert len(get_distributions())> 40

assert len(get_distributions()) > 40


def test_fitter():
f = Fitter([1,1,1,2,2,2,2,2,3,3,3,3], distributions=['gamma'], xmin=0, xmax=4)
f = Fitter([1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3], distributions=['gamma'], xmin=0, xmax=4)
try:
f.plot_pdf()
except:
except Exception:
pass
f.fit()
f.summary()
Expand All @@ -24,8 +23,7 @@ def test_fitter():
assert f.xmin == 1
assert f.xmax == 3


f = Fitter([1,1,1,2,2,2,2,2,3,3,3,3], distributions=['gamma'])
f = Fitter([1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3], distributions=['gamma'])
f.fit(progress=True)
f.summary()
assert f.xmin == 1
Expand All @@ -36,12 +34,11 @@ def test_gamma():
from scipy import stats
data = stats.gamma.rvs(2, loc=1.5, scale=2, size=10000)


f = Fitter(data, bins=100)
f.xmin = -10 #should have no effect
f.xmax = 1000000 # no effet
f.xmin=0.1
f.xmax=10
f.xmin = -10 # should have no effect
f.xmax = 1000000 # no effet
f.xmin = 0.1
f.xmax = 10
f.distributions = ['gamma', "alpha"]
f.fit()
df = f.summary()
Expand All @@ -66,11 +63,9 @@ def test_others():
assert f.df_errors.loc["gamma"].loc['aic'] > 100










def test_n_jobs_api():
    """fit() accepts ``n_jobs`` and scores the same distributions either way."""
    from scipy import stats
    data = stats.gamma.rvs(2, loc=1.5, scale=2, size=1000)
    f = Fitter(data, distributions="common")

    # Parallel run on all available workers.
    f.fit(n_jobs=-1)
    parallel_names = sorted(f.df_errors.index)
    assert parallel_names, "parallel fit produced no results"

    # A serial run must report the very same set of distributions.
    f.fit(n_jobs=1)
    assert sorted(f.df_errors.index) == parallel_names