pol.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. #Copyright (c) 2008, Riccardo De Maria
  2. #All rights reserved.
  3. import math
  4. from functools import reduce
  5. pi=math.pi
  6. _mabs=abs
  7. def abs(a):
  8. if type(a) is pol:
  9. a0,c=a.separate()
  10. if c:
  11. return 1e20
  12. else:
  13. return _mabs(a0)
  14. return _mabs(a)
  15. def normint(n,m):
  16. """coefficient of int^m x^n dx"""
  17. c=1
  18. for i in range(m):
  19. c*=n+i+1
  20. return c
  21. def normder(n,m):
  22. """coefficient of d^m / dx^m x^n"""
  23. c=1
  24. for i in range(m):
  25. c*=n-i
  26. return c
  27. def mkpol(r):
  28. return map(pol,r.split(','))
  29. def pinv(c):
  30. a0,p=c.separate(); p/=a0
  31. lst=[1/a0]
  32. for n in range(1,c.order+1):
  33. lst.append(-lst[-1])
  34. return phorner(lst,p)
  35. def phorner(lst,p):
  36. out=lst.pop()
  37. for i in reversed(lst):
  38. out=out*p+i
  39. return out
  40. class mydict(dict):
  41. """ Dictionary for evaluation"""
  42. def __getitem__(self,k):
  43. if not k in self:
  44. return pol(k)
  45. else:
  46. return dict.__getitem__(self,k)
  47. class pol(dict):
  48. """ Class for multivariate polinomials
  49. >>> from pol import *
  50. >>> print pol('x**3+z-5-x**2')
  51. - 5.0 + z - 1.0*x**2 + x**3
  52. >>> x=pol('x')
  53. >>> y=pol('y')
  54. >>> c=pol(pi)
  55. >>> p=c+pi*x+y**2
  56. >>> print p
  57. 3.14159265359 + 3.14159265359*x + y**2
  58. >>> p2=pol(p)
  59. >>> print p(x=p2,y=p)
  60. 22.8808014558 + 29.6088132033*x + 9.86960440109*x**2 + 9.42477796077*y**2 + 6.28318530718*y**2*x + y**4
  61. >>> pol.out='table'
  62. >>> print p(x=p2,y=p)
  63. o y x
  64. 22.8808014558 0 0 0
  65. 29.6088132033 1 0 1
  66. 9.86960440109 2 0 2
  67. 9.42477796077 2 2 0
  68. 6.28318530718 3 2 1
  69. 1.0 4 4 0
  70. >>> pol.out='pretty'
  71. >>> print pol('1j*x')
  72. 1j*x
  73. """
  74. out='pretty'
  75. def __init__(self,val=None,order=None,eps=1E-18,loc={},m='eval'):
  76. self.vars=[]
  77. self.order=10
  78. self.eps=eps
  79. if val!=None:
  80. if isinstance(val,self.__class__):
  81. self.update(val)
  82. self.vars=val.vars[:]
  83. self.order=val.order
  84. self.eps=val.eps
  85. elif isinstance(val,str):
  86. if m=='eval':
  87. c=compile(val,'eval','eval')
  88. l=dict( (i,pol(i,m='name')) for i in c.co_names if i not in globals())
  89. l.update(loc)
  90. pol.__init__(self,eval(c,globals(),l))
  91. elif m=='name':
  92. self.vars=[val]
  93. self[(1,)]=1.
  94. else:
  95. self.vars=[]
  96. self[()]=val
  97. if order is not None:
  98. self.order=order
  99. def zero(self):
  100. """Extract zero order
  101. """
  102. return self.get((0,)*len(self.vars),0.)
  103. def linear(self):
  104. a=[1]+[0]*(len(self.vars)-1)
  105. out=[]
  106. for i in self.vars:
  107. out.append( self.get(tuple(a),0) )
  108. a.insert(0,a.pop())
  109. return out
  110. def getlind(self,v):
  111. ind=[0]*len(self.vars)
  112. ind[ self.vars.index(v)]=1
  113. return tuple(ind)
  114. def getlcoef(self,v):
  115. return self.get(self.getlind(v),0.)
  116. def setlcoef(self,v,val):
  117. self[self.getlind(v)]=val
  118. def separate(self):
  119. """Return a copy in couple of zero and high order
  120. """
  121. c=self.__class__(self)
  122. a0=c.pop((0,)*len(c.vars),0.)
  123. return a0,c
  124. def truncate(self,order=None,eps=None):
  125. """Truncate to the order indicate in pol.order and
  126. damp the elements smaller than self.eps"""
  127. if not order: order=self.order
  128. if not eps: eps=self.eps
  129. for k in list(self.keys()):
  130. if sum(k)>order or abs(self[k])<eps:
  131. del self[k]
  132. return self
  133. def const(self,vars=None):
  134. """Extract terms that do not depends on vars
  135. >>> from pol import *
  136. >>> print pol('x**3+y+l').const()
  137. 0.0
  138. >>> print pol('x**3+y+l').const('yl')
  139. x**3
  140. """
  141. if not vars:
  142. return self.get((0,)*len(self.vars),0.)
  143. else:
  144. c=self.__class__(self)
  145. for exp in self:
  146. expd=dict(zip(self.vars,exp))
  147. if sum(expd.get(j,0) for j in vars) >0 :
  148. del c[exp]
  149. return c
  150. def dropneg(self):
  151. """ Delete terms with negative exponent"""
  152. for k in self.keys():
  153. for j in k:
  154. if j<0:
  155. del self[k]
  156. break
  157. return self
  158. def addcoef(self,other):
  159. """Add a number to pol"""
  160. new=self.__class__(self)
  161. i=(0,)*len(new.vars)
  162. new[i]=new.get(i,0.)+other
  163. return new.truncate()
  164. def mulcoef(self,other):
  165. """Mul coef to pol"""
  166. new=self.__class__(self)
  167. for i in new:
  168. new[i]*=other
  169. return new.truncate()
  170. def reorder(self,vars):
  171. new=self.__class__(order=self.order)
  172. new.vars=vars
  173. for exp in self:
  174. expd=dict(zip(self.vars,exp))
  175. newexp=tuple([expd.get(j,0) for j in new.vars])
  176. new[newexp]=self[exp]
  177. return new
  178. def addpol(self,other):
  179. """Add pol to pol """
  180. new=self.__class__(order=min(self.order,other.order))
  181. new.vars=list(set(other.vars+self.vars))
  182. for exp in self:
  183. expd=dict(zip(self.vars,exp))
  184. newexp=tuple([expd.get(j,0) for j in new.vars])
  185. new[newexp]=new.get(newexp,0.)+self[exp]
  186. for exp in other:
  187. expd=dict(zip(other.vars,exp))
  188. newexp=tuple([expd.get(j,0) for j in new.vars])
  189. new[newexp]=new.get(newexp,0.)+other[exp]
  190. return new.truncate()
  191. def mulpol(self,other):
  192. """Mul pol to pol """
  193. if other.vars==self.vars:
  194. return self.fmulpol(other)
  195. else:
  196. new=pol(order=min(self.order,other.order))
  197. new.vars=list(set(other.vars+self.vars))
  198. for i in self:
  199. for j in other:
  200. c=self[i]*other[j]
  201. expi=dict(zip(self.vars,i))
  202. expj=dict(zip(other.vars,j))
  203. newexp=tuple([expi.get(k,0)+expj.get(k,0) for k in new.vars])
  204. new[newexp]=new.get(newexp,0.)+c
  205. return new.truncate()
  206. def fmulpol(self,other):
  207. """fast mul pol to pol """
  208. new=self.__class__(order=min(self.order,other.order))
  209. new.vars=self.vars[:]
  210. for i in self:
  211. for j in other:
  212. c=self[i]*other[j]
  213. newexp=tuple([l+m for l,m in zip(i,j)])
  214. new[newexp]=new.get(newexp,0.)+c
  215. return new.truncate()
  216. def divterm(self,other):
  217. """Extract a term from a pol EXPERIMENTAL:
  218. >>> from pol import *
  219. >>> r=pol('x**3+z-5').divterm('x**3')
  220. >>> print r
  221. 1.0
  222. """
  223. other=pol(other)
  224. new=pol(order=min(self.order,other.order))
  225. other=pol(other)
  226. new.vars=list(set(other.vars+self.vars))
  227. for i in self:
  228. for j in other:
  229. c=self[i]*other[j]
  230. expi=dict(zip(self.vars,i))
  231. expj=dict(zip(other.vars,j))
  232. newexp=tuple([expi.get(k,0)-expj.get(k,0) for k in new.vars])
  233. new[newexp]=new.get(newexp,0.)+c
  234. return new.truncate().dropneg()
  235. def __pow__(self,n):
  236. new=self.__class__(self)
  237. if n==0:
  238. return 1
  239. elif n <0:
  240. return pinv(self)**n
  241. else:
  242. for i in range(n-1):
  243. new*=self
  244. return new
  245. def __add__(self,other):
  246. """Addition
  247. >>> from pol import *
  248. >>> x=pol('x')
  249. >>> 1-x+x
  250. 1.0
  251. """
  252. if isinstance(other,self.__class__):
  253. return self.addpol(other)
  254. else:
  255. return self.addcoef(other)
  256. def __radd__(self,other):
  257. return self.addcoef(other)
  258. def __mul__(self,other):
  259. """Addition
  260. >>> from pol import *
  261. >>> x=pol('x')
  262. >>> (1+1.0*x)/-(-x-1)
  263. 1.0
  264. """
  265. if isinstance(other,self.__class__):
  266. return self.mulpol(other)
  267. else:
  268. return self.mulcoef(other)
  269. def __rmul__(self,other):
  270. return self.mulcoef(other)
  271. def __sub__(self,other):
  272. if isinstance(other,pol):
  273. return self.addpol(-other)
  274. else:
  275. return self.addcoef(-other)
  276. def __rsub__(self,other):
  277. return (-self).addcoef(other)
  278. def __truediv__(self,other):
  279. if isinstance(other,pol):
  280. return self.mulpol(pinv(other))
  281. else:
  282. return self.mulcoef(1/other)
  283. def __rtruediv__(self,other):
  284. return pinv(self).mulcoef(other)
  285. def __neg__(self):
  286. return self.mulcoef(-1)
  287. def __pos__(self):
  288. return self
  289. def eval(self,*args,**loc):
  290. if len(args)>0:
  291. loc.update(args[0])
  292. for i in self.vars:
  293. loc.setdefault(i,pol(i))
  294. loc=mydict(loc)
  295. return eval(self.pretty(),{},loc)
  296. __call__=eval
  297. def _pexp(self,i):
  298. out=[]
  299. for j in range(len(i)):
  300. if i[j]==1:
  301. out.append( '%s' % self.vars[j] )
  302. elif i[j]!=0.:
  303. out.append( '%s**%d' %(self.vars[j],i[j]) )
  304. return '*'.join(out)
  305. def _pcoeff(self,c,i):
  306. if isinstance(c,complex):
  307. if abs(c.imag)<self.eps:
  308. return self._pcoeff(c.real,i)
  309. else:
  310. c='+ '+str(c)
  311. elif isinstance(c,float):
  312. sign=c<0 and '-' or '+'
  313. if abs(c-1.0)<self.eps and i:
  314. c='%s %s' % (sign,i)
  315. i=''
  316. else:
  317. c='%s %s' % (sign,abs(c))
  318. else:
  319. c='+ (%s)' % c
  320. if i:
  321. return '%s*%s' % (c,i)
  322. else:
  323. return c
  324. def pretty(self):
  325. lst=sorted([ (sum(i),i,c) for i,c in self.items()])
  326. m=[]
  327. for o,i,c in lst:
  328. i=self._pexp(i)
  329. c=self._pcoeff(c,i)
  330. m.append( c )
  331. if m:
  332. m=' '.join(m)
  333. if m.startswith('+ '):
  334. return ' '+m[2:]
  335. else:
  336. return m
  337. else:
  338. return '0'
  339. def table(self):
  340. fvar=lambda x,y:'%3s%3s' %(x,y)
  341. out=[['',0,0,reduce(fvar,self.vars),' o']]
  342. lst=sorted([ (sum(i),i) for i in self])
  343. rmax,cmax=0,0
  344. for order,exp in lst:
  345. coef=str(self[exp])
  346. r=[coef,len(coef),coef.find('.'),reduce(fvar,exp),str(order)]
  347. rmax=r[1]>rmax and r[1] or rmax
  348. cmax=r[2]>cmax and r[2] or cmax
  349. out.append(r)
  350. nout=[]
  351. for c,l,p,e,o in out:
  352. c=('%%-%ds'%(rmax+cmax)) % (' '*(cmax-p)+ c)
  353. nout.append('%(c)s %(o)2s %(e)s' % locals() )
  354. return '\n'.join(nout)
  355. def __repr__(self):
  356. return getattr(pol,pol.out)(self)
  357. #if __name__=='__main__':
  358. # import doctest
  359. # doctest.testmod()
  360. # import profile
  361. # pol.order=9
  362. # profile.run('pol("sqrt(1+x+y+z+px+py+pz)")',sort='time')