8. SciPy范德蒙矩阵多项式逼近

本章基于范德蒙(德)Vandermonde矩阵求若干点的多项式表达式。

8.1 范德蒙矩阵

范德蒙矩阵行数为m,列数为n,矩阵具有最大的秩min(m, n),其形式如下所示:

在numpyPy模块下可以利用polynomial子模块的polynomial.polyvander构造出范德蒙矩阵来或者用scipy的vander也可构造出范德蒙矩阵来。

8.2 多项式逼近与范德蒙

范德蒙矩阵和多项式有啥关系呢?假设现在有4个点(1,1),(2,3),(3,5),(4,4)有没有一个x的3次多项式刚好经过这四个点呢?即$f(x) = c_0x^0 + c_1x^1 + c_2x^2 + c_3x^3$,如果能得到$(c_0,c_1,c_2,c_3)$这个四个系数就能确定$f(x)$了。由于假设这四个点经过了$f(x)$,那么便有:

写成矩阵表示形式如下:

上边$4\times4$的矩阵不就是范德蒙矩阵么?形如$Ax=y$的且知道A和y就可以用solve函数来求得x的值,这里的x就是$f(x) = c_0x^0 + c_1x^1 + c_2x^2 + c_3x^3$方程里的系数$c_i$构成的向量$(c_0,c_1,c_2,c_3)$。而scipy.vander刚好可以基于四个点的x坐标$(1,2,3,4)$构造出这个$4\times4的矩阵$

import scipy
import numpy as np
import numpy.polynomial.polynomial as npp
x = np.array([1,2,3,4])
A = scipy.vander(x,increasing=True)
print A,"#A"
A = npp.polyvander(x, 3)
print A,"#A"

程序执行结果:

[[ 1  1  1  1]
 [ 1  2  4  8]
 [ 1  3  9 27]
 [ 1  4 16 64]] #A
[[  1.   1.   1.   1.]
 [  1.   2.   4.   8.]
 [  1.   3.   9.  27.]
 [  1.   4.  16.  64.]] #A

8.3 范德蒙矩阵多项式逼近

好回到原问题求$f(x)$的各个系数就可最终知道$f(x)$了:

1). 首先给出四个点的x和y坐标

import numpy as np
x = np.array([1, 2, 3, 4])
y = np.array([1, 3, 5, 4])

2).用$x$构造上边那个$4\times4$矩阵,可以用scipy.vander函数构造。

import scipy
A = scipy.vander(x,increasing=True)
print A,"#scipy"

或者用numpy的模块来构造也行,

import numpy.polynomial.polynomial as npp
A = npp.polyvander(x, 3)
print A,"#A"

3). 方程$Ax=y$有了$A$和$y$求x(这里是求c系数)可以用solve函数求解。

c = np.linalg.solve(A, y)
print c

c即是$f(x) = c_0x^0 + c_1x^1 + c_2x^2 + c_3x^3$方程里的各个系数$c_i$,这样方程就找到了。

8.4 SciPy多项式逼近程序

基于范德蒙矩阵多项式逼近完整的程序如下所示:

import numpy.polynomial.polynomial as npp
import numpy as np
import scipy

x = np.array([1, 2, 3, 4])
y = np.array([1, 3, 5, 4])
A = npp.polyvander(x, 3)
print A,"#numpy"
A = scipy.vander(x,increasing=True)
print A,"#scipy"
c = np.linalg.solve(A, y)
print c,"#c"
x = np.linspace(1,50, 50)
y = c[0] + c[1] * x + c[2]*(x**2) + c[3]*(x ** 3)
import matplotlib.pyplot as plt
plt.plot(x[:4], y[:4],'go')
plt.xlim(-1,6)
plt.ylim(-1,6)
plt.show()
plt.plot(x, y)
plt.plot(x, y,'go')
plt.show()

程序执行结果:

[[  1.   1.   1.   1.]
 [  1.   2.   4.   8.]
 [  1.   3.   9.  27.]
 [  1.   4.  16.  64.]] #numpy
[[ 1  1  1  1]
 [ 1  2  4  8]
 [ 1  3  9 27]
 [ 1  4 16 64]] #scipy
[ 2.  -3.5  3.  -0.5]#c

从c的输出[ 2. -3.5 3. -0.5]得知方程为:

$f(x) = c_0x^0 + c_1x^1 + c_2x^2 + c_3x^3 =2\times x^0 -3.5\times x^1 + 3\times x^2 -0.5\times x^3 =2-3.5 x + 3x^2 -0.5 x^3$

可视化输出(1,1)、(2,3)、(3,5)、(4, 4)四个点如下所示: