# -*- coding: utf-8 -*-
"""
Created on Wed Sep 23 19:14:52 2020

@author: jbobowsk
"""

# Created using Spyder(Python 3.7)
# Solving nonlinear systems of equations

# To solve a system of nonlinear equations, we will use 'fsolve()', which is
# provided by the 'scipy.optimize' module.
import numpy as np
from scipy.optimize import fsolve

# First, we will solve a single nonlinear equation.  We start by defining
# the equation as a function.
def f(x):
    """Return 3x^3 - 2x^2 + x - 7, the left-hand side of 3x^3 - 2x^2 + x - 7 = 0.

    A root of this function is a solution of the cubic equation.
    """
    lead = 3*x**3
    quad = 2*x**2
    return lead - quad + x - 7

# With f() defined, it can be evaluated at any x; here, at x = 1:
print(f(1))

# Call fsolve(func, x0): `func` is the equation to solve, written so that it
# reads func(x) = 0, and `x0` is the starting guess for the root.
x = fsolve(f, x0=1)

# fsolve() hands back a NumPy array (here with a single element), not a list
# and not a bare number.
print(x)

# Pull the root out of the array by indexing its first (and only) element.
print(x[0])

# Sanity check: plugging the root back into f should give (nearly) zero.
print(f(x[0]))

# Here's another example: Find y such that tan(e^(-2y)) = 1/y.
def f(y):
    """Return tan(e^(-2y)) - 1/y; its roots solve tan(e^(-2y)) = 1/y.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    lhs = np.tan(np.exp(-2*y))
    rhs = 1/y
    return lhs - rhs

# Solve f(y) = 0 from the starting guess -0.1; the result is again a
# single-element NumPy array.
y = fsolve(f, x0=-0.1)
print(y)

# Plug the root back in to confirm f is (close to) zero there.
print(f(y[0]))

# Alternatively, unpack the single-element array straight into a scalar.  The
# trailing comma in `y, = ...` is what performs the unpacking.
y, = fsolve(f, x0=-0.1)
print(y)
print(f(y))


# We can also solve a system of nonlinear equations.  Below we define a system
# of three equations with three unknown variables.
def equations(vars):
    """Return the residuals of a 3x3 nonlinear system at the point ``vars``.

    The system is
        3xy + y - z  = 10
        x + x^2 y + z = 12
        x - y - z    = -2
    and each entry of the returned list is (left side - right side), so a
    solution makes all three entries zero.

    ``vars`` must unpack into exactly three values (x, y, z).  The parameter
    name shadows the builtin ``vars`` but is kept for interface compatibility.
    """
    x, y, z = vars
    return [
        3*x*y + y - z - 10,
        x + x**2*y + z - 12,
        x - y - z + 2,
    ]

# To evaluate the three equations at particular values of x, y, and z, we can
# enter (this trial point is not a solution, so the residuals are nonzero):
print(equations([1, 0, -3]))

# Here is the fsolve() statement with an initial guess at the solution.
x, y, z = fsolve(equations, (1, 1, 1))

# fsolve() returns a three-element array, so this 'comma' notation unpacks
# the three solutions directly.
print('x =', x, 'y =', y, 'z =', z)

# Here, we check that all three of the equations actually evaluate to something
# close to zero.  SymPy's N() is used to display each residual to 3
# significant figures.
# NOTE(review): the wildcard import is kept because code elsewhere in the file
# may rely on the names it binds; if nothing does, prefer `from sympy import N`.
from sympy import *
# Evaluate the system once at the solution and unpack, rather than calling
# equations() three times to pick out each residual.
eq1, eq2, eq3 = equations([x, y, z])
print('eq1 =', N(eq1, 3), 'eq2 =', N(eq2, 3), 'eq3 =', N(eq3, 3))