
I've written a function that solves a system of equations, but it does not work when the solutions contain a square root. It works for other equations as long as there are no square roots.

I am getting the following error:

TypeError: No loop matching the specified signature and casting was found for ufunc solve1

I could calculate the square root and get a decimal number, but I don't want that. I need to do my calculations with exact numbers: I would rather have it return sqrt(5) than 2.236067977.
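The distinction here is between a floating-point approximation and an exact symbolic value. A minimal sketch, assuming sympy is installed and imported as `sy` (as in the full function further down): `math.sqrt` returns a float, while `sy.sqrt` keeps the value symbolic and only produces decimals when asked to.

```python
import math

import sympy as sy

approx = math.sqrt(5)   # a float approximation, about 2.2360679775
exact = sy.sqrt(5)      # an exact sympy expression, printed as sqrt(5)

print(approx)
print(exact)
print(exact**2)         # exact arithmetic: this is the integer 5
print(sy.N(exact, 30))  # a decimal only when explicitly requested, to any precision
```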

I'm currently trying to solve the following recurrence relation:

eqs :=
[
s(n) = s(n-1)+s(n-2),
s(0) = 1,
s(1) = 1
];
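To check whatever closed form a solver produces, it helps to have the first few terms of this recurrence computed directly. A small Python sketch (the helper name `recurrence_terms` is hypothetical, not part of the code below):

```python
def recurrence_terms(count, s0=1, s1=1):
    """Hypothetical helper: first `count` terms of s(n) = s(n-1) + s(n-2)."""
    terms = [s0, s1]
    while len(terms) < count:
        terms.append(terms[-1] + terms[-2])
    return terms[:count]


print(recurrence_terms(8))  # [1, 1, 2, 3, 5, 8, 13, 21]
```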

I've written down the steps and their outputs below, in outline form. The approach works for equations without square roots. How can I get linalg to work with sqrt, or should I use a different approach?

def solve_homogeneous_equation(init_conditions, associated):
    # Write down characteristic equation for r
    # -> eq_string = r^2-1*r^(2-1)-+1*r^(2-2)

    # Find the roots for r
    # -> r_solutions = [1/2 + sqrt(5)/2, -sqrt(5)/2 + 1/2]

    # Write down general solution (for solver)
    # -> two lists, one with the variables and one with the outcomes
    general_solution_variables = [[1, 1], [1/2 + sqrt(5)/2, -sqrt(5)/2 + 1/2]]
    general_solution_outcomes = [1, 1]

    # Solve the system of equations
    # THIS IS WHERE THE ERROR OCCURS
    solution = np.linalg.solve(general_solution_variables, general_solution_outcomes)

    # Write the solution
    # (here I rewrite the general solution with the values found)
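A likely explanation for the error: `np.linalg.solve` converts the nested lists of sympy expressions into a NumPy array with `dtype=object`, and its compiled LAPACK-based routines only accept floating-point or complex arrays, so the call fails with a `TypeError`. Converting to `float` would avoid the error but gives decimals, which is exactly what is not wanted. A minimal sketch of one alternative, assuming sympy is available: build a `sympy.Matrix` from the same data and solve it with `Matrix.LUsolve`, which keeps every value exact. Note that the `1/2` from the printed roots has to be written as `sy.Rational(1, 2)` in Python, otherwise it becomes the float `0.5`.

```python
import numpy as np
import sympy as sy

sqrt5 = sy.sqrt(5)
half = sy.Rational(1, 2)  # exact 1/2; plain 1/2 in Python would be the float 0.5

# Same matrix and right-hand side as in the outline above (rows are r**0 and r**1)
general_solution_variables = [[1, 1], [half + sqrt5/2, -sqrt5/2 + half]]
general_solution_outcomes = [1, 1]

# NumPy builds an object-dtype array here, which its linalg routines reject
try:
    np.linalg.solve(general_solution_variables, general_solution_outcomes)
except TypeError as err:
    print("NumPy failed:", err)

# sympy solves the same system exactly
A = sy.Matrix(general_solution_variables)
b = sy.Matrix(general_solution_outcomes)
solution = [sy.simplify(c) for c in A.LUsolve(b)]
print(solution)  # coefficients 1/2 + sqrt(5)/10 and 1/2 - sqrt(5)/10, possibly printed in a different order of terms
```

`sy.linsolve` is another sympy option for the same step if you prefer working with symbols instead of a matrix.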

The full function is below, in case you want to look deeper into the code:

import numpy as np
import sympy as sy

def solve_homogeneous_equation(init_conditions, associated):
    print("Starting solver")
    # Write down characteristic equation for r
    eq_length = len(init_conditions)
    associated[0] = str('r^' + str(eq_length))
    for i in range(eq_length, 0, -1):
        if i in associated.keys() :
            associated[i] = associated[i] + str('*r^(') + str(eq_length) + str('-') + str(i) + str(')')
    print("Associated equation: " + str(associated))
    eq_string = ''
    for i in range(0, eq_length+1, 1):
        if i in associated.keys():
            if i < eq_length:
                eq_string = eq_string + associated[i] + '-'
            else:
                eq_string = eq_string + associated[i]
    print("Equation: " + eq_string)

    # Find the roots for r
    r_symbol = sy.Symbol('r')
    r_solutions = sy.solve(eq_string, r_symbol)
    r_length = len(r_solutions)
    print("Solutions: " + str(r_solutions))
    print("Eq length: " + str(eq_length) + " ; Amount of solutions: " + str(r_length))

    # If equation length is equal to solutions
    if eq_length == r_length:

        # Write down general solution (for solver)
        general_solution_variables = []
        general_solution_outcomes = []
        for i in range(0, eq_length):
            general_solution_variables_single = []
            for j in range(0, eq_length + 1):
                if j != eq_length:
                    k = r_solutions[j]**i
                    general_solution_variables_single.append(k)
                if j == eq_length:
                    k = init_conditions[i]
                    general_solution_outcomes.append(int(k))
            general_solution_variables.append(general_solution_variables_single)
        print("General solution variables: " + str(general_solution_variables))
        print("General solution outcomes: " + str(general_solution_outcomes))

        # Solve the system of equations
        solution = np.linalg.solve(general_solution_variables, general_solution_outcomes)
        print("Solutions: " + str(solution))

        # Write the solution
        solution_full = ""
        for i in range(0, eq_length):
            if i > 0:
                solution_full = solution_full + " + "
            # str() instead of int(): int() would truncate non-integer values such as 1/2 + sqrt(5)/2
            solution_full = solution_full + "(" + str(solution[i]) + ")*(" + str(r_solutions[i]) + ")^n"
        print("Solved equation: " + solution_full)
        return(solution_full)

    # If equation length is not equal to solutions
    elif eq_length > r_length:
        print("NonEqual")
        return 0
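For comparison, here is a hedged sketch of the whole procedure done in sympy only, so the roots and coefficients stay exact from start to finish. It assumes a different input format from the function above: the recurrence is given as a list `coeffs` for s(n) = c1*s(n-1) + ... + ck*s(n-k), and the name `solve_linear_recurrence` is hypothetical. Like the `eq_length == r_length` branch above, it only handles distinct roots, and it builds the characteristic polynomial as a sympy expression instead of a string.

```python
import sympy as sy

n = sy.Symbol('n')


def solve_linear_recurrence(coeffs, init_conditions):
    """Hypothetical helper: closed form of s(n) = c1*s(n-1) + ... + ck*s(n-k).

    coeffs: [c1, ..., ck]; init_conditions: [s(0), ..., s(k-1)].
    Assumes the characteristic polynomial has k distinct roots.
    """
    r = sy.Symbol('r')
    k = len(coeffs)
    # Characteristic polynomial: r**k - c1*r**(k-1) - ... - ck
    char_poly = r**k - sum(c * r**(k - i) for i, c in enumerate(coeffs, start=1))
    roots = sy.roots(char_poly, r)  # {root: multiplicity}, roots kept exact
    if len(roots) != k:
        raise ValueError("repeated roots (or roots sympy cannot express) are not handled")
    root_list = list(roots)

    # Linear system: a_1*root_1**i + ... + a_k*root_k**i = s(i), for i = 0..k-1
    A = sy.Matrix(k, k, lambda i, j: root_list[j]**i)
    b = sy.Matrix(init_conditions)
    a = [sy.simplify(x) for x in A.LUsolve(b)]

    return sum(a[j] * root_list[j]**n for j in range(k))


closed_form = solve_linear_recurrence([1, 1], [1, 1])
print(closed_form)
# Evaluate for n = 0..7; expand() multiplies out the powers so the square roots cancel
print([sy.expand(closed_form.subs(n, m)) for m in range(8)])
```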
Rodi

1 Answer


I haven't tried very hard to read your code, but note that you can solve that recurrence symbolically using sympy.

  • As usual with sympy, all terms in function definitions and equations must be moved to one side of the equals sign.
  • The initial conditions are passed in a dict.

>>> from sympy import *
>>> from sympy.solvers.recurr import rsolve
>>> var('n')
n
>>> s = Function('s')
>>> f = s(n) - s(n-1) -s(n-2)
>>> rsolve(f, s(n), {s(0):1, s(1):1})
(1/2 + sqrt(5)/2)**n*(sqrt(5)/10 + 1/2) + (-sqrt(5)/2 + 1/2)**n*(-sqrt(5)/10 + 1/2)
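As a sanity check, the closed form returned by `rsolve` can be evaluated for small n and compared with the terms obtained by iterating the recurrence directly. A short sketch, using the same sympy calls as the session above:

```python
from sympy import Function, expand, rsolve, symbols

n = symbols('n')
s = Function('s')

closed_form = rsolve(s(n) - s(n - 1) - s(n - 2), s(n), {s(0): 1, s(1): 1})

# Iterate the recurrence directly for comparison
terms = [1, 1]
for _ in range(8):
    terms.append(terms[-1] + terms[-2])

# expand() multiplies out the powers of (1/2 +/- sqrt(5)/2), so each value collapses to an integer
values = [expand(closed_form.subs(n, m)) for m in range(len(terms))]
print(values)
print(values == terms)
```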
Bill Bell