
I would like to write a curve-fitting script that allows me to fix parameters of a function of the form:

import numpy as np

def func(x, *p):
    assert len(p) % 2 == 0
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval
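Written out with the 1-based names used in the example that follows (p1, p2, ...), func evaluates a sum of exponential terms. For four parameters it reduces to two terms:

```latex
f(x) = \sum_{k=1}^{n/2} p_{2k-1}\, e^{-p_{2k} x},
\qquad
n = 4:\quad f(x) = p_1 e^{-p_2 x} + p_3 e^{-p_4 x}
```

So fixing p2 and p3 holds the first decay rate and the second amplitude constant, leaving p1 and p4 free.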

For example, say I want p = [p1, p2, p3, p4], with p2 and p3 fixed to constants A and B (turning a 4-parameter fit into a 2-parameter fit). I understand that functools.partial doesn't let me do this, which is why I want to write my own wrapper, but I am having a bit of trouble doing so. This is what I have so far:

def fix_params(f, t, pars, fix_pars):
    # fix_pars = ((ind1, A), (ind2, B))
    new_pars = [None]*(len(pars) + len(fix_pars))
    for ind, fix in fix_pars:
        new_pars[ind] = fix
    for par in pars:
        for j, npar in enumerate(new_pars):
            if npar is None:
                new_pars[j] = par
                break
    assert None not in new_pars
    return f(t, *new_pars)

The problem, I think, is that scipy.optimize.curve_fit won't work with a function passed through this kind of wrapper. How should I get around this?
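A small sketch of the functools.partial limitation mentioned above: partial prepends its bound arguments, so it can fix x and a leading run of p, but it cannot skip p1 and bind p2 and p3. The values of A and B below are arbitrary placeholders.

```python
import functools

import numpy as np


def func(x, *p):
    """Sum of exponentials: p[0]*exp(-p[1]*x) + p[2]*exp(-p[3]*x) + ..."""
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


A, B = 0.5, 0.1          # placeholder values for the fixed p2 and p3
x = np.arange(10.0)

# partial binds arguments from the left, so this fixes x and p1 only;
# p2 and p3 still have to be passed on every call.
g = functools.partial(func, x, 1.0)
print(np.allclose(g(A, B, 1.2), func(x, 1.0, A, B, 1.2)))  # True
```

Because there is no way to tell partial to leave a "hole" for p1, a custom wrapper or closure is needed instead.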

firest
  • Unrelated to your question, but it's not good practice to use assert for parameter checking. It's better to use an if statement and raise a proper exception. – Tim Jan 26 '19 at 21:06
  • Sorry, what makes using an assert for parameter checking bad practice? I just wanted a neat one-liner that would guarantee that all parameters have been taken into account. Are there situations where things would go awry with asserts? – firest Jan 26 '19 at 21:26
  • Assert is mostly for unit testing. Basically, it's used to check that your code is running as intended. For example, if you wrote a function sum(x, y) that returns the sum of x and y, you can use assert to see if it's adding as expected, e.g. assert sum(x, y) == x + y. Moreover, if you run Python with the -O flag, all asserts are disabled. See this post for more discussion: https://softwareengineering.stackexchange.com/questions/225956/python-assert-vs-if-return – Tim Jan 26 '19 at 21:31
  • Here's an example of correct assert usage (in the question, the code asserts if 2 different implementations give the same result): https://stackoverflow.com/questions/54382727/are-python-generators-faster-than-nested-for-loops – Tim Jan 26 '19 at 21:38
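A minimal sketch of the check Tim suggests, applied to func from the question: an explicit if/raise replaces the assert, so the check still runs when Python is started with -O. ValueError is my choice of exception here; the thread does not prescribe one.

```python
import numpy as np


def func(x, *p):
    """Sum of exponentials: p[0]*exp(-p[1]*x) + p[2]*exp(-p[3]*x) + ..."""
    # Unlike assert, this check is not stripped when Python runs with -O.
    if len(p) % 2 != 0:
        raise ValueError(f"expected an even number of parameters, got {len(p)}")
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval
```

With this version, func(0.0, 2.0, 1.0) returns 2.0, while func(0.0, 1.0, 2.0, 3.0) raises a ValueError instead of silently indexing past the end of p.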

2 Answers


It sounds like what you want is currying (or, more precisely, partial application). In Python, you can do this with inner functions (closures).

Example:

def foo(x):
    def bar(y):
        return x + y
    return bar

bar = foo(3)
print(type(bar))    # a function (of one variable with the other fixed to 3)
print(bar(8))       # 11
bar = foo(9)
print(bar(8))       # 17

In this way we can fix x in the function x + y. You can also put this into a decorator.
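A minimal sketch of what "put this into a decorator" could look like for a two-argument function. The decorator name curry2 is hypothetical, and this version only handles exactly two arguments, fixing the first.

```python
import functools


def curry2(f):
    """Hypothetical decorator: turns f(x, y) into f(x)(y)."""
    @functools.wraps(f)
    def outer(x):
        def inner(y):
            # x is captured from outer's call, just like x in foo/bar above.
            return f(x, y)
        return inner
    return outer


@curry2
def add(x, y):
    return x + y


add3 = add(3)
print(add3(8))    # 11
print(add(9)(8))  # 17
```

functools.wraps keeps the decorated function's name and docstring, and exposes the original two-argument function as add.__wrapped__.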

Here's a blog post someone wrote on doing this: https://mtomassoli.wordpress.com/2012/03/18/currying-in-python/

As for playing nicely with external libraries: foo returns a function, and in Python functions are first-class objects. So any library you pass the returned function to (curve_fit included) just sees an ordinary callable.
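To illustrate, here is a sketch that builds a closure in the style of foo/bar and hands it to scipy.optimize.curve_fit. The factory name fix_p2_p3 and the synthetic data are made up for illustration. Because the returned model has explicit parameters (x, p1, p4), curve_fit can count the free parameters from its signature, so p0 is optional here.

```python
import numpy as np
from scipy.optimize import curve_fit


def func(x, *p):
    """Sum of exponentials: p[0]*exp(-p[1]*x) + p[2]*exp(-p[3]*x) + ..."""
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


def fix_p2_p3(A, B):
    """Hypothetical factory: returns a model with p2 = A and p3 = B baked in."""
    def model(x, p1, p4):
        # A and B are captured from the enclosing scope, like x in foo/bar.
        return func(x, p1, A, B, p4)
    return model


x = np.linspace(0, 5, 50)
y = func(x, 2.0, 0.5, 1.0, 1.5)      # synthetic, noise-free data
model = fix_p2_p3(0.5, 1.0)          # fix p2 = 0.5 and p3 = 1.0
popt, pcov = curve_fit(model, x, y)  # fits only p1 and p4
print(popt)
```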

Neil
  • I have never heard of currying. Thank you so much. I'll see what I can do with this and let you know if it works. – firest Jan 26 '19 at 21:24

So I think I have something workable. Maybe there is a way to improve on this.

Here is my code (without all the exception handling):

import numpy as np

def func(x, *p):
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval

def fix_params(f, fix_pars):
    # fix_pars = ((1, A), (2, B))
    def new_func(x, *pars):
        new_pars = [None]*(len(pars) + len(fix_pars))
        for j, fp in fix_pars:
            new_pars[j] = fp
        for par in pars:
            for j, npar in enumerate(new_pars):
                if npar is None:
                    new_pars[j] = par
                    break
        return f(x, *new_pars)
    return new_func

p1 = [1, 0.5, 0.1, 1.2]
pfix = ((1, 0.5), (2, 0.1))
p2 = [1, 1.2]

new_func = fix_params(func, pfix)

x = np.arange(10)
dat1 = func(x, *p1)
dat2 = new_func(x, *p2)

if np.allclose(dat1, dat2):
    print("ALL GOOD")
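One possible improvement, sketched with the same conventions as fix_params above: converting fix_pars to a dict and drawing free parameters from an iterator fills new_pars in a single pass instead of rescanning for None. Since new_func takes *pars, curve_fit cannot infer the number of fit parameters from its signature, so p0 has to be passed explicitly (without it, curve_fit raises a ValueError). The data below are synthetic.

```python
import numpy as np
from scipy.optimize import curve_fit


def func(x, *p):
    """Sum of exponentials: p[0]*exp(-p[1]*x) + p[2]*exp(-p[3]*x) + ..."""
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


def fix_params(f, fix_pars):
    """Return f with the parameters at the given indices held fixed."""
    fixed = dict(fix_pars)  # ((1, A), (2, B)) -> {1: A, 2: B}

    def new_func(x, *pars):
        free = iter(pars)
        n = len(pars) + len(fixed)
        # Fixed slots take their stored value; free slots are filled in order.
        new_pars = [fixed[j] if j in fixed else next(free) for j in range(n)]
        return f(x, *new_pars)

    return new_func


new_func = fix_params(func, ((1, 0.5), (2, 1.0)))
x = np.linspace(0, 5, 50)
y = func(x, 2.0, 0.5, 1.0, 1.5)  # synthetic, noise-free data

# p0 tells curve_fit how many free parameters there are (here p1 and p4).
popt, pcov = curve_fit(new_func, x, y, p0=[1.0, 1.0])
print(popt)
```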
firest