I have a least-squares error function (built on another function, f) that I'd like to minimize, basically to obtain a globally optimal curve fit. It looks something like this:

def err(a, b, X, Y):
    return np.sum((f(a, b, X)-Y)**2)

with X being an array of points at which f is evaluated, depending on the parameters a and b, and Y being the "ground truth" for the points defined in X.

Now according to what I found in questions 25206482 and 31388319 the syntax should be as follows:

Xc = np.array([1.0, 2.0, 3.0, 4.0, 5.0])   # points at which to evaluate error function
Yc = np.array([0.2, 0.4, 0.8, 0.12, 0.15]) # ground truth
g0 = np.array([1.0, 3.0])                  # initial guess for a and b    
res = scipy.optimize.minimize(err, g0, args=(Xc, Yc), method="Powell")

Unfortunately, I get the following error message:

TypeError: err() takes exactly 4 arguments (3 given)

If I delete Xc or Yc from the tuple, the number of arguments given decreases, so I suspect the problem lies in the definition of g0, which seems to be passed to err as a single argument.

How do I call minimize properly if I have more than one parameter for optimization and additional "constant" arguments I want to pass to my function during optimization?

Raketenolli
  • Test your setup with `err(g0, Xc, Yc)`, which evaluates your function at the initial guess the same way minimize will call it. If that fails, you need to change `err` or the `args` tuple. – hpaulj Jul 26 '17 at 16:20

2 Answers

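minimize optimizes only a single argument, but that argument can be an array. It calls the objective as fun(x, *args), so with args=(Xc, Yc) your function receives three arguments: the whole parameter array, then Xc, then Yc. What you have to do is extract a and b from the first parameter in err: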

def err(p, X, Y):
    a, b = p
    return np.sum((f(a, b, X)-Y)**2)
MB-F

I found that the minimization works if I define the error function so that it takes a single list of parameters instead of several separate parameters:

def err(p, X, Y):
    a = p[0]
    b = p[1]
    return np.sum((f(a, b, X)-Y)**2)

and then call minimize:

g0 = [1.0, 3.0]
res = scipy.optimize.minimize(err, g0, args=(Xc, Yc), method="Powell")
Raketenolli