Convert entire OpenRAM to use python3. Works with Python 3.6.

Major changes:
Remove mpmath library and use numpy instead.
Convert bytes to new bytearrays.
Fix class name check for duplicate gds instances.
Add explicit integer conversion from floats.
Fix importlib reload from importlib library
Fix new key/index syntax issues.
Fix filter and map conversion to lists.
Fix deprecation warnings.
Fix Circuits vs Netlist in Magic LVS results.
Fix file closing warnings.
This commit is contained in:
Matt Guthaus 2018-05-11 16:32:00 -07:00
parent 58628d7867
commit f34c4eb7dc
179 changed files with 803 additions and 42105 deletions

View File

@ -29,18 +29,18 @@ class design(hierarchy_spice.spice, hierarchy_layout.layout):
# because each reference must be a unique name.
# These modules ensure unique names or have no changes if they
# aren't unique
ok_list = ['ms_flop.ms_flop',
'dff.dff',
'dff_buf.dff_buf',
'bitcell.bitcell',
'contact.contact',
'ptx.ptx',
'sram.sram',
'hierarchical_predecode2x4.hierarchical_predecode2x4',
'hierarchical_predecode3x8.hierarchical_predecode3x8']
ok_list = ['ms_flop',
'dff',
'dff_buf',
'bitcell',
'contact',
'ptx',
'sram',
'hierarchical_predecode2x4',
'hierarchical_predecode3x8']
if name not in design.name_map:
design.name_map.append(name)
elif str(self.__class__) in ok_list:
elif self.__class__.__name__ in ok_list:
pass
else:
debug.error("Duplicate layout reference name {0} of class {1}. GDS2 requires names be unique.".format(name,self.__class__),-1)

View File

@ -116,10 +116,11 @@ class spice(verilog.verilog):
self.spice = f.readlines()
for i in range(len(self.spice)):
self.spice[i] = self.spice[i].rstrip(" \n")
f.close()
# find the correct subckt line in the file
subckt = re.compile("^.subckt {}".format(self.name), re.IGNORECASE)
subckt_line = filter(subckt.search, self.spice)[0]
subckt_line = list(filter(subckt.search, self.spice))[0]
# parses line into ports and remove subckt
self.pins = subckt_line.split(" ")[2:]
else:

View File

@ -20,7 +20,7 @@ class pin_layout:
self.rect = [x.snap_to_grid() for x in self.rect]
# if it's a layer number look up the layer name. this assumes a unique layer number.
if type(layer_name_num)==int:
self.layer = layer.keys()[layer.values().index(layer_name_num)]
self.layer = list(layer.keys())[list(layer.values()).index(layer_name_num)]
else:
self.layer=layer_name_num
self.layer_num = layer[self.layer]

View File

@ -1,9 +1,9 @@
import os
import debug
from globals import OPTS,find_exe,get_tool
import lib
import delay
import setup_hold
from .lib import *
from .delay import *
from .setup_hold import *
debug.info(2,"Initializing characterizer...")

View File

@ -2,9 +2,9 @@ import sys,re,shutil
import debug
import tech
import math
import stimuli
from trim_spice import trim_spice
import charutils as ch
from .stimuli import *
from .trim_spice import *
from .charutils import *
import utils
from globals import OPTS
@ -101,7 +101,7 @@ class delay():
self.sf.write("* Delay stimulus for period of {0}n load={1}fF slew={2}ns\n\n".format(self.period,
self.load,
self.slew))
self.stim = stimuli.stimuli(self.sf, self.corner)
self.stim = stimuli(self.sf, self.corner)
# include files in stimulus file
self.stim.write_include(self.trim_sp_file)
@ -339,16 +339,16 @@ class delay():
# Checking from not data_value to data_value
self.write_delay_stimulus()
self.stim.run_sim()
delay_hl = ch.parse_output("timing", "delay_hl")
delay_lh = ch.parse_output("timing", "delay_lh")
slew_hl = ch.parse_output("timing", "slew_hl")
slew_lh = ch.parse_output("timing", "slew_lh")
delay_hl = parse_output("timing", "delay_hl")
delay_lh = parse_output("timing", "delay_lh")
slew_hl = parse_output("timing", "slew_hl")
slew_lh = parse_output("timing", "slew_lh")
delays = (delay_hl, delay_lh, slew_hl, slew_lh)
read0_power=ch.parse_output("timing", "read0_power")
write0_power=ch.parse_output("timing", "write0_power")
read1_power=ch.parse_output("timing", "read1_power")
write1_power=ch.parse_output("timing", "write1_power")
read0_power=parse_output("timing", "read0_power")
write0_power=parse_output("timing", "write0_power")
read1_power=parse_output("timing", "read1_power")
write1_power=parse_output("timing", "write1_power")
if not self.check_valid_delays(delays):
return (False,{})
@ -378,22 +378,24 @@ class delay():
self.write_power_stimulus(trim=False)
self.stim.run_sim()
leakage_power=ch.parse_output("timing", "leakage_power")
leakage_power=parse_output("timing", "leakage_power")
debug.check(leakage_power!="Failed","Could not measure leakage power.")
self.write_power_stimulus(trim=True)
self.stim.run_sim()
trim_leakage_power=ch.parse_output("timing", "leakage_power")
trim_leakage_power=parse_output("timing", "leakage_power")
debug.check(trim_leakage_power!="Failed","Could not measure leakage power.")
# For debug, you sometimes want to inspect each simulation.
#key=raw_input("press return to continue")
return (leakage_power*1e3, trim_leakage_power*1e3)
def check_valid_delays(self, (delay_hl, delay_lh, slew_hl, slew_lh)):
def check_valid_delays(self, delay_tuple):
""" Check if the measurements are defined and if they are valid. """
(delay_hl, delay_lh, slew_hl, slew_lh) = delay_tuple
# if it failed or the read was longer than a period
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
debug.info(2,"Failed simulation: period {0} load {1} slew {2}, delay_hl={3}n delay_lh={4}ns slew_hl={5}n slew_lh={6}n".format(self.period,
@ -457,7 +459,7 @@ class delay():
else:
lb_period = target_period
if ch.relative_compare(ub_period, lb_period, error_tolerance=0.05):
if relative_compare(ub_period, lb_period, error_tolerance=0.05):
# ub_period is always feasible
return ub_period
@ -471,10 +473,10 @@ class delay():
# Checking from not data_value to data_value
self.write_delay_stimulus()
self.stim.run_sim()
delay_hl = ch.parse_output("timing", "delay_hl")
delay_lh = ch.parse_output("timing", "delay_lh")
slew_hl = ch.parse_output("timing", "slew_hl")
slew_lh = ch.parse_output("timing", "slew_lh")
delay_hl = parse_output("timing", "delay_hl")
delay_lh = parse_output("timing", "delay_lh")
slew_hl = parse_output("timing", "slew_hl")
slew_lh = parse_output("timing", "slew_lh")
# if it failed or the read was longer than a period
if type(delay_hl)!=float or type(delay_lh)!=float or type(slew_lh)!=float or type(slew_hl)!=float:
debug.info(2,"Invalid measures: Period {0}, delay_hl={1}ns, delay_lh={2}ns slew_hl={3}ns slew_lh={4}ns".format(self.period,
@ -495,10 +497,10 @@ class delay():
slew_lh))
return False
else:
if not ch.relative_compare(delay_lh,feasible_delay_lh,error_tolerance=0.05):
if not relative_compare(delay_lh,feasible_delay_lh,error_tolerance=0.05):
debug.info(2,"Delay too big {0} vs {1}".format(delay_lh,feasible_delay_lh))
return False
elif not ch.relative_compare(delay_hl,feasible_delay_hl,error_tolerance=0.05):
elif not relative_compare(delay_hl,feasible_delay_hl,error_tolerance=0.05):
debug.info(2,"Delay too big {0} vs {1}".format(delay_hl,feasible_delay_hl))
return False
@ -602,7 +604,7 @@ class delay():
debug.info(1, "Min Period: {0}n with a delay of {1} / {2}".format(min_period, feasible_delay_lh, feasible_delay_hl))
# 4) Pack up the final measurements
char_data["min_period"] = ch.round_time(min_period)
char_data["min_period"] = round_time(min_period)
return char_data

View File

@ -1,9 +1,9 @@
import os,sys,re
import debug
import math
import setup_hold
import delay
import charutils as ch
from .setup_hold import *
from .delay import *
from .charutils import *
import tech
import numpy as np
from globals import OPTS
@ -186,9 +186,9 @@ class lib:
""" Helper function to create quoted, line wrapped array with each row of given length """
# check that the length is a multiple or give an error!
debug.check(len(values)%length == 0,"Values are not a multiple of the length. Cannot make a full array.")
rounded_values = map(ch.round_time,values)
rounded_values = list(map(round_time,values))
split_values = [rounded_values[i:i+length] for i in range(0, len(rounded_values), length)]
formatted_rows = map(self.create_list,split_values)
formatted_rows = list(map(self.create_list,split_values))
formatted_array = ",\\\n".join(formatted_rows)
return formatted_array
@ -274,11 +274,11 @@ class lib:
self.lib.write(" timing_type : setup_rising; \n")
self.lib.write(" related_pin : \"clk\"; \n")
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
rounded_values = map(ch.round_time,self.times["setup_times_LH"])
rounded_values = list(map(round_time,self.times["setup_times_LH"]))
self.write_values(rounded_values,len(self.slews)," ")
self.lib.write(" }\n")
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
rounded_values = map(ch.round_time,self.times["setup_times_HL"])
rounded_values = list(map(round_time,self.times["setup_times_HL"]))
self.write_values(rounded_values,len(self.slews)," ")
self.lib.write(" }\n")
self.lib.write(" }\n")
@ -286,11 +286,11 @@ class lib:
self.lib.write(" timing_type : hold_rising; \n")
self.lib.write(" related_pin : \"clk\"; \n")
self.lib.write(" rise_constraint(CONSTRAINT_TABLE) {\n")
rounded_values = map(ch.round_time,self.times["hold_times_LH"])
rounded_values = list(map(round_time,self.times["hold_times_LH"]))
self.write_values(rounded_values,len(self.slews)," ")
self.lib.write(" }\n")
self.lib.write(" fall_constraint(CONSTRAINT_TABLE) {\n")
rounded_values = map(ch.round_time,self.times["hold_times_HL"])
rounded_values = list(map(round_time,self.times["hold_times_HL"]))
self.write_values(rounded_values,len(self.slews)," ")
self.lib.write(" }\n")
self.lib.write(" }\n")
@ -413,8 +413,8 @@ class lib:
self.lib.write(" }\n")
self.lib.write(" }\n")
min_pulse_width = ch.round_time(self.char_results["min_period"])/2.0
min_period = ch.round_time(self.char_results["min_period"])
min_pulse_width = round_time(self.char_results["min_period"])/2.0
min_period = round_time(self.char_results["min_period"])
self.lib.write(" timing(){ \n")
self.lib.write(" timing_type :\"min_pulse_width\"; \n")
self.lib.write(" related_pin : clk; \n")
@ -443,7 +443,7 @@ class lib:
try:
self.d
except AttributeError:
self.d = delay.delay(self.sram, self.sp_file, self.corner)
self.d = delay(self.sram, self.sp_file, self.corner)
if self.use_model:
self.char_results = self.d.analytical_delay(self.sram,self.slews,self.loads)
else:
@ -458,7 +458,7 @@ class lib:
try:
self.sh
except AttributeError:
self.sh = setup_hold.setup_hold(self.corner)
self.sh = setup_hold(self.corner)
if self.use_model:
self.times = self.sh.analytical_setuphold(self.slews,self.loads)
else:

View File

@ -1,8 +1,8 @@
import sys
import tech
import stimuli
from .stimuli import *
import debug
import charutils as ch
from .charutils import *
import ms_flop
from globals import OPTS
@ -38,7 +38,7 @@ class setup_hold():
# creates and opens the stimulus file for writing
temp_stim = OPTS.openram_temp + "stim.sp"
self.sf = open(temp_stim, "w")
self.stim = stimuli.stimuli(self.sf, self.corner)
self.stim = stimuli(self.sf, self.corner)
self.write_header(correct_value)
@ -186,8 +186,8 @@ class setup_hold():
target_time=feasible_bound,
correct_value=correct_value)
self.stim.run_sim()
ideal_clk_to_q = ch.convert_to_float(ch.parse_output("timing", "clk2q_delay"))
setuphold_time = ch.convert_to_float(ch.parse_output("timing", "setup_hold_time"))
ideal_clk_to_q = convert_to_float(parse_output("timing", "clk2q_delay"))
setuphold_time = convert_to_float(parse_output("timing", "setup_hold_time"))
debug.info(2,"*** {0} CHECK: {1} Ideal Clk-to-Q: {2} Setup/Hold: {3}".format(mode, correct_value,ideal_clk_to_q,setuphold_time))
if type(ideal_clk_to_q)!=float or type(setuphold_time)!=float:
@ -219,8 +219,8 @@ class setup_hold():
self.stim.run_sim()
clk_to_q = ch.convert_to_float(ch.parse_output("timing", "clk2q_delay"))
setuphold_time = ch.convert_to_float(ch.parse_output("timing", "setup_hold_time"))
clk_to_q = convert_to_float(parse_output("timing", "clk2q_delay"))
setuphold_time = convert_to_float(parse_output("timing", "setup_hold_time"))
if type(clk_to_q)==float and (clk_to_q<1.1*ideal_clk_to_q) and type(setuphold_time)==float:
if mode == "SETUP": # SETUP is clk-din, not din-clk
setuphold_time *= -1e9
@ -235,7 +235,7 @@ class setup_hold():
infeasible_bound = target_time
#raw_input("Press Enter to continue...")
if ch.relative_compare(feasible_bound, infeasible_bound, error_tolerance=0.001):
if relative_compare(feasible_bound, infeasible_bound, error_tolerance=0.001):
debug.info(3,"CONVERGE {0} vs {1}".format(feasible_bound,infeasible_bound))
break

View File

@ -4,10 +4,10 @@ Python GDS Mill Package
GDS Mill is a Python package for the creation and manipulation of binary GDS2 layout files.
"""
from gds2reader import *
from gds2writer import *
from pdfLayout import *
from vlsiLayout import *
from gdsStreamer import *
from gdsPrimitives import *
from .gds2reader import *
from .gds2writer import *
#from .pdfLayout import *
from .vlsiLayout import *
from .gdsStreamer import *
from .gdsPrimitives import *

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
import struct
from gdsPrimitives import *
from .gdsPrimitives import *
class Gds2writer:
"""Class to take a populated layout class and write it to a file in GDSII format"""
@ -14,8 +14,8 @@ class Gds2writer:
def print64AsBinary(self,number):
#debugging method for binary inspection
for index in range(0,64):
print (number>>(63-index))&0x1,
print "\n"
print((number>>(63-index))&0x1,eol='')
print("\n")
def ieeeDoubleFromIbmData(self,ibmData):
#the GDS double is in IBM 370 format like this:
@ -40,9 +40,9 @@ class Gds2writer:
exponent-=1
#check for underflow error -- should handle these properly!
if(exponent<=0):
print "Underflow Error"
print("Underflow Error")
elif(exponent == 2047):
print "Overflow Error"
print("Overflow Error")
#re assemble
newFloat=(sign<<63)|(exponent<<52)|((mantissa>>12)&0xfffffffffffff)
asciiDouble = struct.pack('>q',newFloat)
@ -84,12 +84,12 @@ class Gds2writer:
data = struct.unpack('>q',asciiDouble)[0]
sign = data >> 63
exponent = ((data >> 52) & 0x7ff)-1023
print exponent+1023
print(exponent+1023)
mantissa = data << 12 #chop off sign and exponent
#self.print64AsBinary((sign<<63)|((exponent+1023)<<52)|(mantissa>>12))
asciiDouble = struct.pack('>q',(sign<<63)|(exponent+1023<<52)|(mantissa>>12))
newFloat = struct.unpack('>d',asciiDouble)[0]
print "Check:"+str(newFloat)
print("Check:"+str(newFloat))
def writeRecord(self,record):
recordLength = len(record)+2 #make sure to include this in the length
@ -99,12 +99,12 @@ class Gds2writer:
def writeHeader(self):
## Header
if("gdsVersion" in self.layoutObject.info):
idBits='\x00\x02'
idBits=b'\x00\x02'
gdsVersion = struct.pack(">h",self.layoutObject.info["gdsVersion"])
self.writeRecord(idBits+gdsVersion)
## Modified Date
if("dates" in self.layoutObject.info):
idBits='\x01\x02'
idBits=b'\x01\x02'
modYear = struct.pack(">h",self.layoutObject.info["dates"][0])
modMonth = struct.pack(">h",self.layoutObject.info["dates"][1])
modDay = struct.pack(">h",self.layoutObject.info["dates"][2])
@ -122,43 +122,43 @@ class Gds2writer:
lastAccessMinute+lastAccessSecond)
## LibraryName
if("libraryName" in self.layoutObject.info):
idBits='\x02\x06'
idBits=b'\x02\x06'
if (len(self.layoutObject.info["libraryName"]) % 2 != 0):
libraryName = self.layoutObject.info["libraryName"] + "\0"
libraryName = self.layoutObject.info["libraryName"].encode() + "\0"
else:
libraryName = self.layoutObject.info["libraryName"]
libraryName = self.layoutObject.info["libraryName"].encode()
self.writeRecord(idBits+libraryName)
## reference libraries
if("referenceLibraries" in self.layoutObject.info):
idBits='\x1F\x06'
idBits=b'\x1F\x06'
referenceLibraryA = self.layoutObject.info["referenceLibraries"][0]
referenceLibraryB = self.layoutObject.info["referenceLibraries"][1]
self.writeRecord(idBits+referenceLibraryA+referenceLibraryB)
if("fonts" in self.layoutObject.info):
idBits='\x20\x06'
idBits=b'\x20\x06'
fontA = self.layoutObject.info["fonts"][0]
fontB = self.layoutObject.info["fonts"][1]
fontC = self.layoutObject.info["fonts"][2]
fontD = self.layoutObject.info["fonts"][3]
self.writeRecord(idBits+fontA+fontB+fontC+fontD)
if("attributeTable" in self.layoutObject.info):
idBits='\x23\x06'
idBits=b'\x23\x06'
attributeTable = self.layoutObject.info["attributeTable"]
self.writeRecord(idBits+attributeTable)
if("generations" in self.layoutObject.info):
idBits='\x22\x02'
idBits=b'\x22\x02'
generations = struct.pack(">h",self.layoutObject.info["generations"])
self.writeRecord(idBits+generations)
if("fileFormat" in self.layoutObject.info):
idBits='\x36\x02'
idBits=b'\x36\x02'
fileFormat = struct.pack(">h",self.layoutObject.info["fileFormat"])
self.writeRecord(idBits+fileFormat)
if("mask" in self.layoutObject.info):
idBits='\x37\x06'
idBits=b'\x37\x06'
mask = self.layoutObject.info["mask"]
self.writeRecord(idBits+mask)
if("units" in self.layoutObject.info):
idBits='\x03\x05'
idBits=b'\x03\x05'
userUnits=self.ibmDataFromIeeeDouble(self.layoutObject.info["units"][0])
dbUnits=self.ibmDataFromIeeeDouble((self.layoutObject.info["units"][0]*1e-6/self.layoutObject.info["units"][1])*self.layoutObject.info["units"][1])
@ -176,171 +176,171 @@ class Gds2writer:
self.writeRecord(idBits+userUnits+dbUnits)
if(self.debugToTerminal==1):
print "writer: userUnits %s"%(userUnits.encode("hex"))
print "writer: dbUnits %s"%(dbUnits.encode("hex"))
print("writer: userUnits %s"%(userUnits.encode("hex")))
print("writer: dbUnits %s"%(dbUnits.encode("hex")))
#self.ieeeFloatCheck(1.3e-6)
print "End of GDSII Header Written"
print("End of GDSII Header Written")
return 1
def writeBoundary(self,thisBoundary):
idBits = '\x08\x00' #record Type
idBits=b'\x08\x00' #record Type
self.writeRecord(idBits)
if(thisBoundary.elementFlags!=""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisBoundary.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisBoundary.plex!=""):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisBoundary.plex)
self.writeRecord(idBits+plex)
if(thisBoundary.drawingLayer!=""):
idBits='\x0D\x02' #drawig layer
idBits=b'\x0D\x02' #drawig layer
drawingLayer = struct.pack(">h",thisBoundary.drawingLayer)
self.writeRecord(idBits+drawingLayer)
if(thisBoundary.purposeLayer):
idBits='\x16\x02' #purpose layer
idBits=b'\x16\x02' #purpose layer
purposeLayer = struct.pack(">h",thisBoundary.purposeLayer)
self.writeRecord(idBits+purposeLayer)
if(thisBoundary.dataType!=""):
idBits='\x0E\x02'#DataType
idBits=b'\x0E\x02'#DataType
dataType = struct.pack(">h",thisBoundary.dataType)
self.writeRecord(idBits+dataType)
if(thisBoundary.coordinates!=""):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
for coordinate in thisBoundary.coordinates:
x=struct.pack(">i",coordinate[0])
y=struct.pack(">i",coordinate[1])
x=struct.pack(">i",int(coordinate[0]))
y=struct.pack(">i",int(coordinate[1]))
coordinateRecord+=x
coordinateRecord+=y
self.writeRecord(coordinateRecord)
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writePath(self,thisPath): #writes out a path structure
idBits = '\x09\x00' #record Type
idBits=b'\x09\x00' #record Type
self.writeRecord(idBits)
if(thisPath.elementFlags != ""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisPath.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisPath.plex!=""):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisPath.plex)
self.writeRecord(idBits+plex)
if(thisPath.drawingLayer):
idBits='\x0D\x02' #drawig layer
idBits=b'\x0D\x02' #drawig layer
drawingLayer = struct.pack(">h",thisPath.drawingLayer)
self.writeRecord(idBits+drawingLayer)
if(thisPath.purposeLayer):
idBits='\x16\x02' #purpose layer
idBits=b'\x16\x02' #purpose layer
purposeLayer = struct.pack(">h",thisPath.purposeLayer)
self.writeRecord(idBits+purposeLayer)
if(thisPath.pathType):
idBits='\x21\x02' #Path type
idBits=b'\x21\x02' #Path type
pathType = struct.pack(">h",thisPath.pathType)
self.writeRecord(idBits+pathType)
if(thisPath.pathWidth):
idBits='\x0F\x03'
idBits=b'\x0F\x03'
pathWidth = struct.pack(">i",thisPath.pathWidth)
self.writeRecord(idBits+pathWidth)
if(thisPath.coordinates):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
for coordinate in thisPath.coordinates:
x=struct.pack(">i",coordinate[0])
y=struct.pack(">i",coordinate[1])
x=struct.pack(">i",int(coordinate[0]))
y=struct.pack(">i",int(coordinate[1]))
coordinateRecord+=x
coordinateRecord+=y
self.writeRecord(coordinateRecord)
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writeSref(self,thisSref): #reads in a reference to another structure
idBits = '\x0A\x00' #record Type
idBits=b'\x0A\x00' #record Type
self.writeRecord(idBits)
if(thisSref.elementFlags != ""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisSref.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisSref.plex!=""):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisSref.plex)
self.writeRecord(idBits+plex)
if(thisSref.sName!=""):
idBits='\x12\x06'
idBits=b'\x12\x06'
if (len(thisSref.sName) % 2 != 0):
sName = thisSref.sName+"\0"
else:
sName = thisSref.sName
self.writeRecord(idBits+sName)
self.writeRecord(idBits+sName.encode())
if(thisSref.transFlags!=""):
idBits='\x1A\x01'
idBits=b'\x1A\x01'
mirrorFlag = int(thisSref.transFlags[0])<<15
rotateFlag = int(thisSref.transFlags[1])<<1
magnifyFlag = int(thisSref.transFlags[2])<<3
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
self.writeRecord(idBits+transFlags)
if(thisSref.magFactor!=""):
idBits='\x1B\x05'
idBits=b'\x1B\x05'
magFactor=self.ibmDataFromIeeeDouble(thisSref.magFactor)
self.writeRecord(idBits+magFactor)
if(thisSref.rotateAngle!=""):
idBits='\x1C\x05'
idBits=b'\x1C\x05'
rotateAngle=self.ibmDataFromIeeeDouble(thisSref.rotateAngle)
self.writeRecord(idBits+rotateAngle)
if(thisSref.coordinates!=""):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
coordinate = thisSref.coordinates
x=struct.pack(">i",coordinate[0])
y=struct.pack(">i",coordinate[1])
x=struct.pack(">i",int(coordinate[0]))
y=struct.pack(">i",int(coordinate[1]))
coordinateRecord+=x
coordinateRecord+=y
#print thisSref.coordinates
#print(thisSref.coordinates)
self.writeRecord(coordinateRecord)
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writeAref(self,thisAref): #an array of references
idBits = '\x0B\x00' #record Type
idBits=b'\x0B\x00' #record Type
self.writeRecord(idBits)
if(thisAref.elementFlags!=""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisAref.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisAref.plex):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisAref.plex)
self.writeRecord(idBits+plex)
if(thisAref.aName):
idBits='\x12\x06'
idBits=b'\x12\x06'
if (len(thisAref.aName) % 2 != 0):
aName = thisAref.aName+"\0"
else:
aName = thisAref.aName
self.writeRecord(idBits+aName)
if(thisAref.transFlags):
idBits='\x1A\x01'
idBits=b'\x1A\x01'
mirrorFlag = int(thisAref.transFlags[0])<<15
rotateFlag = int(thisAref.transFlags[1])<<1
magnifyFlag = int(thisAref.transFlags[0])<<3
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
self.writeRecord(idBits+transFlags)
if(thisAref.magFactor):
idBits='\x1B\x05'
idBits=b'\x1B\x05'
magFactor=self.ibmDataFromIeeeDouble(thisAref.magFactor)
self.writeRecord(idBits+magFactor)
if(thisAref.rotateAngle):
idBits='\x1C\x05'
idBits=b'\x1C\x05'
rotateAngle=self.ibmDataFromIeeeDouble(thisAref.rotateAngle)
self.writeRecord(idBits+rotateAngle)
if(thisAref.coordinates):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
for coordinate in thisAref.coordinates:
x=struct.pack(">i",coordinate[0])
@ -348,151 +348,151 @@ class Gds2writer:
coordinateRecord+=x
coordinateRecord+=y
self.writeRecord(coordinateRecord)
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writeText(self,thisText):
idBits = '\x0C\x00' #record Type
idBits=b'\x0C\x00' #record Type
self.writeRecord(idBits)
if(thisText.elementFlags!=""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisText.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisText.plex !=""):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisText.plex)
self.writeRecord(idBits+plex)
if(thisText.drawingLayer != ""):
idBits='\x0D\x02' #drawing layer
idBits=b'\x0D\x02' #drawing layer
drawingLayer = struct.pack(">h",thisText.drawingLayer)
self.writeRecord(idBits+drawingLayer)
#if(thisText.purposeLayer):
idBits='\x16\x02' #purpose layer
idBits=b'\x16\x02' #purpose layer
purposeLayer = struct.pack(">h",thisText.purposeLayer)
self.writeRecord(idBits+purposeLayer)
if(thisText.transFlags != ""):
idBits='\x1A\x01'
idBits=b'\x1A\x01'
mirrorFlag = int(thisText.transFlags[0])<<15
rotateFlag = int(thisText.transFlags[1])<<1
magnifyFlag = int(thisText.transFlags[0])<<3
transFlags = struct.pack(">H",mirrorFlag|rotateFlag|magnifyFlag)
self.writeRecord(idBits+transFlags)
if(thisText.magFactor != ""):
idBits='\x1B\x05'
idBits=b'\x1B\x05'
magFactor=self.ibmDataFromIeeeDouble(thisText.magFactor)
self.writeRecord(idBits+magFactor)
if(thisText.rotateAngle != ""):
idBits='\x1C\x05'
idBits=b'\x1C\x05'
rotateAngle=self.ibmDataFromIeeeDouble(thisText.rotateAngle)
self.writeRecord(idBits+rotateAngle)
if(thisText.pathType !=""):
idBits='\x21\x02' #Path type
idBits=b'\x21\x02' #Path type
pathType = struct.pack(">h",thisText.pathType)
self.writeRecord(idBits+pathType)
if(thisText.pathWidth != ""):
idBits='\x0F\x03'
idBits=b'\x0F\x03'
pathWidth = struct.pack(">i",thisText.pathWidth)
self.writeRecord(idBits+pathWidth)
if(thisText.presentationFlags!=""):
idBits='\x1A\x01'
idBits=b'\x1A\x01'
font = thisText.presentationFlags[0]<<4
verticalFlags = int(thisText.presentationFlags[1])<<2
horizontalFlags = int(thisText.presentationFlags[2])
presentationFlags = struct.pack(">H",font|verticalFlags|horizontalFlags)
self.writeRecord(idBits+transFlags)
if(thisText.coordinates!=""):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
for coordinate in thisText.coordinates:
x=struct.pack(">i",coordinate[0])
y=struct.pack(">i",coordinate[1])
x=struct.pack(">i",int(coordinate[0]))
y=struct.pack(">i",int(coordinate[1]))
coordinateRecord+=x
coordinateRecord+=y
self.writeRecord(coordinateRecord)
if(thisText.textString):
idBits='\x19\x06'
idBits=b'\x19\x06'
textString = thisText.textString
self.writeRecord(idBits+textString)
self.writeRecord(idBits+textString.encode())
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writeNode(self,thisNode):
idBits = '\x15\x00' #record Type
idBits=b'\x15\x00' #record Type
self.writeRecord(idBits)
if(thisNode.elementFlags!=""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisNode.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisNode.plex!=""):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisNode.plex)
self.writeRecord(idBits+plex)
if(thisNode.drawingLayer!=""):
idBits='\x0D\x02' #drawig layer
idBits=b'\x0D\x02' #drawig layer
drawingLayer = struct.pack(">h",thisNode.drawingLayer)
self.writeRecord(idBits+drawingLayer)
if(thisNode.nodeType!=""):
idBits='\x2A\x02'
idBits=b'\x2A\x02'
nodeType = struct.pack(">h",thisNode.nodeType)
self.writeRecord(idBits+nodeType)
if(thisText.coordinates!=""):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
for coordinate in thisText.coordinates:
x=struct.pack(">i",coordinate[0])
y=struct.pack(">i",coordinate[1])
x=struct.pack(">i",int(coordinate[0]))
y=struct.pack(">i",int(coordinate[1]))
coordinateRecord+=x
coordinateRecord+=y
self.writeRecord(coordinateRecord)
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writeBox(self,thisBox):
idBits = '\x2E\x02' #record Type
idBits=b'\x2E\x02' #record Type
self.writeRecord(idBits)
if(thisBox.elementFlags!=""):
idBits='\x26\x01' #ELFLAGS
idBits=b'\x26\x01' #ELFLAGS
elementFlags = struct.pack(">h",thisBox.elementFlags)
self.writeRecord(idBits+elementFlags)
if(thisBox.plex!=""):
idBits='\x2F\x03' #PLEX
idBits=b'\x2F\x03' #PLEX
plex = struct.pack(">i",thisBox.plex)
self.writeRecord(idBits+plex)
if(thisBox.drawingLayer!=""):
idBits='\x0D\x02' #drawig layer
idBits=b'\x0D\x02' #drawig layer
drawingLayer = struct.pack(">h",thisBox.drawingLayer)
self.writeRecord(idBits+drawingLayer)
if(thisBox.purposeLayer):
idBits='\x16\x02' #purpose layer
idBits=b'\x16\x02' #purpose layer
purposeLayer = struct.pack(">h",thisBox.purposeLayer)
self.writeRecord(idBits+purposeLayer)
if(thisBox.boxValue!=""):
idBits='\x2D\x00'
idBits=b'\x2D\x00'
boxValue = struct.pack(">h",thisBox.boxValue)
self.writeRecord(idBits+boxValue)
if(thisBox.coordinates!=""):
idBits='\x10\x03' #XY Data Points
idBits=b'\x10\x03' #XY Data Points
coordinateRecord = idBits
for coordinate in thisBox.coordinates:
x=struct.pack(">i",coordinate[0])
y=struct.pack(">i",coordinate[1])
x=struct.pack(">i",int(coordinate[0]))
y=struct.pack(">i",int(coordinate[1]))
coordinateRecord+=x
coordinateRecord+=y
self.writeRecord(coordinateRecord)
idBits='\x11\x00' #End Of Element
idBits=b'\x11\x00' #End Of Element
coordinateRecord = idBits
self.writeRecord(coordinateRecord)
def writeNextStructure(self,structureName):
#first put in the structure head
thisStructure = self.layoutObject.structures[structureName]
idBits='\x05\x02'
idBits=b'\x05\x02'
createYear = struct.pack(">h",thisStructure.createDate[0])
createMonth = struct.pack(">h",thisStructure.createDate[1])
createDay = struct.pack(">h",thisStructure.createDate[2])
@ -508,12 +508,12 @@ class Gds2writer:
self.writeRecord(idBits+createYear+createMonth+createDay+createHour+createMinute+createSecond\
+modYear+modMonth+modDay+modHour+modMinute+modSecond)
#now the structure name
idBits='\x06\x06'
idBits=b'\x06\x06'
##caveat: the name needs to be an EVEN number of characters
if(len(structureName)%2 == 1):
#pad with a zero
structureName = structureName + '\x00'
self.writeRecord(idBits+structureName)
self.writeRecord(idBits+structureName.encode())
#now go through all the structure elements and write them in
for boundary in thisStructure.boundaries:
@ -531,7 +531,7 @@ class Gds2writer:
for box in thisStructure.boxes:
self.writeBox(box)
#put in the structure tail
idBits='\x07\x00'
idBits=b'\x07\x00'
self.writeRecord(idBits)
def writeGds2(self):
@ -540,7 +540,7 @@ class Gds2writer:
for structureName in self.layoutObject.structures:
self.writeNextStructure(structureName)
#at the end, put in the END LIB record
idBits='\x04\x00'
idBits=b'\x04\x00'
self.writeRecord(idBits)
def writeToFile(self,fileName):

View File

@ -122,11 +122,11 @@ class GdsStreamer:
#stream the gds out from cadence
worker = os.popen("pipo strmout "+self.workingDirectory+"/partStreamOut.tmpl")
#dump the outputs to the screen line by line
print "Streaming Out From Cadence......"
print("Streaming Out From Cadence......")
while 1:
line = worker.readline()
if not line: break #this means sim is finished so jump out
#else: print line #for debug only
#else: print(line) #for debug only
worker.close()
#now remove the template file
os.remove(self.workingDirectory+"/partStreamOut.tmpl")
@ -142,13 +142,13 @@ class GdsStreamer:
#stream the gds out from cadence
worker = os.popen("pipo strmin "+self.workingDirectory+"/partStreamIn.tmpl")
#dump the outputs to the screen line by line
print "Streaming In To Cadence......"
print("Streaming In To Cadence......")
while 1:
line = worker.readline()
if not line: break #this means sim is finished so jump out
#else: print line #for debug only
#else: print(line) #for debug only
worker.close()
#now remove the template file
os.remove(self.workingDirectory+"/partStreamIn.tmpl")
#and go back to whever it was we started from
os.chdir(currentPath)
os.chdir(currentPath)

View File

@ -1,6 +1,6 @@
import pyx
import math
import mpmath
from numpy import matrix
from gdsPrimitives import *
import random
@ -39,12 +39,12 @@ class pdfLayout:
"""
xyCoordinates = []
#setup a translation matrix
tMatrix = mpmath.matrix([[1.0,0.0,origin[0]],[0.0,1.0,origin[1]],[0.0,0.0,1.0]])
tMatrix = matrix([[1.0,0.0,origin[0]],[0.0,1.0,origin[1]],[0.0,0.0,1.0]])
#and a rotation matrix
rMatrix = mpmath.matrix([[uVector[0],vVector[0],0.0],[uVector[1],vVector[1],0.0],[0.0,0.0,1.0]])
rMatrix = matrix([[uVector[0],vVector[0],0.0],[uVector[1],vVector[1],0.0],[0.0,0.0,1.0]])
for coordinate in uvCoordinates:
#grab the point in UV space
uvPoint = mpmath.matrix([coordinate[0],coordinate[1],1.0])
uvPoint = matrix([coordinate[0],coordinate[1],1.0])
#now rotate and translate it back to XY space
xyPoint = rMatrix * uvPoint
xyPoint = tMatrix * xyPoint

View File

@ -1,7 +1,8 @@
from gdsPrimitives import *
from .gdsPrimitives import *
from datetime import *
import mpmath
import gdsPrimitives
#from mpmath import matrix
from numpy import matrix
#import gdsPrimitives
import debug
class VlsiLayout:
@ -10,7 +11,7 @@ class VlsiLayout:
def __init__(self, name=None, units=(0.001,1e-9), libraryName = "DEFAULT.DB", gdsVersion=5):
#keep a list of all the structures in this layout
self.units = units
#print units
#print(units)
modDate = datetime.now()
self.structures=dict()
self.layerNumbersInUse = []
@ -89,7 +90,7 @@ class VlsiLayout:
def newLayout(self,newName):
#if (newName == "" | newName == 0):
# print("ERROR: vlsiLayout.py:newLayout newName is null")
# print("ERROR: vlsiLayout.py:newLayout newName is null")
#make sure the newName is a multiple of 2 characters
#if(len(newName)%2 == 1):
@ -134,13 +135,12 @@ class VlsiLayout:
self.populateCoordinateMap()
def deduceHierarchy(self):
#first, find the root of the tree.
#go through and get the name of every structure.
#then, go through and find which structure is not
#contained by any other structure. this is the root.
""" First, find the root of the tree.
Then go through and get the name of every structure.
Then, go through and find which structure is not
contained by any other structure. this is the root."""
structureNames=[]
for name in self.structures:
#print "deduceHierarchy: structure.name[%s]",name //FIXME: Added By Tom G.
structureNames+=[name]
for name in self.structures:
@ -148,7 +148,7 @@ class VlsiLayout:
for sref in self.structures[name].srefs: #go through each reference
if sref.sName in structureNames: #and compare to our list
structureNames.remove(sref.sName)
self.rootStructureName = structureNames[0]
def traverseTheHierarchy(self, startingStructureName=None, delegateFunction = None,
@ -163,19 +163,20 @@ class VlsiLayout:
rotateAngle = 0
else:
rotateAngle = math.radians(float(rotateAngle))
mRotate = mpmath.matrix([[math.cos(rotateAngle),-math.sin(rotateAngle),0.0],
[math.sin(rotateAngle),math.cos(rotateAngle),0.0],[0.0,0.0,1.0],])
mRotate = matrix([[math.cos(rotateAngle),-math.sin(rotateAngle),0.0],
[math.sin(rotateAngle),math.cos(rotateAngle),0.0],
[0.0,0.0,1.0]])
#set up the translation matrix
translateX = float(coordinates[0])
translateY = float(coordinates[1])
mTranslate = mpmath.matrix([[1.0,0.0,translateX],[0.0,1.0,translateY],[0.0,0.0,1.0]])
mTranslate = matrix([[1.0,0.0,translateX],[0.0,1.0,translateY],[0.0,0.0,1.0]])
#set up the scale matrix (handles mirror X)
scaleX = 1.0
if(transFlags[0]):
scaleY = -1.0
else:
scaleY = 1.0
mScale = mpmath.matrix([[scaleX,0.0,0.0],[0.0,scaleY,0.0],[0.0,0.0,1.0]])
mScale = matrix([[scaleX,0.0,0.0],[0.0,scaleY,0.0],[0.0,0.0,1.0]])
#we need to keep track of all transforms in the hierarchy
#when we add an element to the xy tree, we apply all transforms from the bottom up
@ -197,7 +198,7 @@ class VlsiLayout:
transFlags = sref.transFlags,
coordinates = sref.coordinates)
# else:
# print "WARNING: via encountered, ignoring:", sref.sName
# print("WARNING: via encountered, ignoring:", sref.sName)
#MUST HANDLE AREFs HERE AS WELL
#when we return, drop the last transform from the transformPath
del transformPath[-1]
@ -210,10 +211,10 @@ class VlsiLayout:
def populateCoordinateMap(self):
def addToXyTree(startingStructureName = None,transformPath = None):
#print"populateCoordinateMap"
uVector = mpmath.matrix([1.0,0.0,0.0]) #start with normal basis vectors
vVector = mpmath.matrix([0.0,1.0,0.0])
origin = mpmath.matrix([0.0,0.0,1.0]) #and an origin (Z component is 1.0 to indicate position instead of vector)
#print("populateCoordinateMap")
uVector = matrix([1.0,0.0,0.0]).transpose() #start with normal basis vectors
vVector = matrix([0.0,1.0,0.0]).transpose()
origin = matrix([0.0,0.0,1.0]).transpose() #and an origin (Z component is 1.0 to indicate position instead of vector)
#make a copy of all the transforms and reverse it
reverseTransformPath = transformPath[:]
if len(reverseTransformPath) > 1:
@ -245,7 +246,7 @@ class VlsiLayout:
#userUnitsPerMicron = userUnit / 1e-6
userUnitsPerMicron = userUnit / (userUnit)
layoutUnitsPerMicron = userUnitsPerMicron / self.units[0]
#print "userUnit:",userUnit,"userUnitsPerMicron",userUnitsPerMicron,"layoutUnitsPerMicron",layoutUnitsPerMicron,[microns,microns*layoutUnitsPerMicron]
#print("userUnit:",userUnit,"userUnitsPerMicron",userUnitsPerMicron,"layoutUnitsPerMicron",layoutUnitsPerMicron,[microns,microns*layoutUnitsPerMicron])
return round(microns*layoutUnitsPerMicron,0)
def changeRoot(self,newRoot, create=False):
@ -259,7 +260,7 @@ class VlsiLayout:
# Determine if newRoot exists
# layoutToAdd (default) or nameOfLayout
if (newRoot == 0 | ((newRoot not in self.structures) & ~create)):
print "ERROR: vlsiLayout.changeRoot: Name of new root [%s] not found and create flag is false"%newRoot
print("ERROR: vlsiLayout.changeRoot: Name of new root [%s] not found and create flag is false"%newRoot)
exit(1)
else:
if ((newRoot not in self.structures) & create):
@ -308,13 +309,13 @@ class VlsiLayout:
self.layerNumbersInUse += [layerNumber]
#Also, check if the user units / microns is the same as this Layout
#if (layoutToAdd.units != self.units):
#print "WARNING: VlsiLayout: Units from design to be added do not match target Layout"
#print("WARNING: VlsiLayout: Units from design to be added do not match target Layout")
# if debug: print "DEBUG: vlsilayout: Using %d layers"
# if debug: print("DEBUG: vlsilayout: Using %d layers")
# If we can't find the structure, error
#if StructureFound == False:
#print "ERROR: vlsiLayout.addInstance: [%s] Name not found in local structures, "%(nameOfLayout)
#print("ERROR: vlsiLayout.addInstance: [%s] Name not found in local structures, "%(nameOfLayout))
#return #FIXME: remove!
#exit(1)
@ -353,10 +354,10 @@ class VlsiLayout:
Method to add a box to a layout
"""
offsetInLayoutUnits = (self.userUnits(offsetInMicrons[0]),self.userUnits(offsetInMicrons[1]))
#print "addBox:offsetInLayoutUnits",offsetInLayoutUnits
#print("addBox:offsetInLayoutUnits",offsetInLayoutUnits)
widthInLayoutUnits = self.userUnits(width)
heightInLayoutUnits = self.userUnits(height)
#print "offsetInLayoutUnits",widthInLayoutUnits,"heightInLayoutUnits",heightInLayoutUnits
#print("offsetInLayoutUnits",widthInLayoutUnits,"heightInLayoutUnits",heightInLayoutUnits)
if not center:
coordinates=[offsetInLayoutUnits,
(offsetInLayoutUnits[0]+widthInLayoutUnits,offsetInLayoutUnits[1]),
@ -522,7 +523,7 @@ class VlsiLayout:
heightInBlocks = int(coverageHeight/effectiveBlock)
passFailRecord = []
print "Filling layer:",layerToFill
print("Filling layer:",layerToFill)
def isThisBlockOk(startingStructureName,coordinates,rotateAngle=None):
#go through every boundary and check
for boundary in self.structures[startingStructureName].boundaries:
@ -568,7 +569,7 @@ class VlsiLayout:
#if its bad, this global tempPassFail will be false
#if true, we can add the block
passFailRecord+=[self.tempPassFail]
print "Percent Complete:"+str(percentDone)
print("Percent Complete:"+str(percentDone))
passFailIndex=0
@ -579,7 +580,7 @@ class VlsiLayout:
if passFailRecord[passFailIndex]:
self.addBox(layerToFill, (blockX,blockY), width=blockSize, height=blockSize)
passFailIndex+=1
print "Done\n\n"
print("Done\n\n")
def getLayoutBorder(self,borderlayer):
for boundary in self.structures[self.rootStructureName].boundaries:
@ -591,7 +592,7 @@ class VlsiLayout:
cellSize=[right_top[0]-left_bottom[0],right_top[1]-left_bottom[1]]
cellSizeMicron=[cellSize[0]*self.units[0],cellSize[1]*self.units[0]]
if not(cellSizeMicron):
print "Error: "+str(self.rootStructureName)+".cell_size information not found yet"
print("Error: "+str(self.rootStructureName)+".cell_size information not found yet")
return cellSizeMicron
def measureSize(self,startStructure):
@ -700,7 +701,7 @@ class VlsiLayout:
debug.warning("Did not find pin on layer {0} at coordinate {1}".format(layer, coordinate))
# sort the boundaries, return the max area pin boundary
pin_boundaries.sort(cmpBoundaryAreas,reverse=True)
pin_boundaries.sort(key=boundaryArea,reverse=True)
pin_boundary=pin_boundaries[0]
# Convert to USER units
@ -743,7 +744,8 @@ class VlsiLayout:
shape_list=[]
for label in label_list:
(label_coordinate,label_layer)=label
shape_list.append(self.getPinShapeByDBLocLayer(label_coordinate, label_layer))
shape = self.getPinShapeByDBLocLayer(label_coordinate, label_layer)
shape_list.append(shape)
return shape_list
def getAllPinShapesByLabel(self,label_name):
@ -797,23 +799,23 @@ class VlsiLayout:
# Rectangle is [leftx, bottomy, rightx, topy].
boundaryRect=[left_bottom[0],left_bottom[1],right_top[0],right_top[1]]
boundaryRect=self.transformRectangle(boundaryRect,structureuVector,structurevVector)
boundaryRect=[boundaryRect[0]+structureOrigin[0],boundaryRect[1]+structureOrigin[1],
boundaryRect[2]+structureOrigin[0],boundaryRect[3]+structureOrigin[1]]
boundaryRect=[boundaryRect[0]+structureOrigin[0].item(),boundaryRect[1]+structureOrigin[1].item(),
boundaryRect[2]+structureOrigin[0].item(),boundaryRect[3]+structureOrigin[1].item()]
if self.labelInRectangle(coordinates,boundaryRect):
boundaries.append(boundaryRect)
return boundaries
def transformRectangle(self,orignalRectangle,uVector,vVector):
def transformRectangle(self,originalRectangle,uVector,vVector):
"""
Transforms the four coordinates of a rectangle in space
and recomputes the left, bottom, right, top values.
"""
leftBottom=mpmath.matrix([orignalRectangle[0],orignalRectangle[1]])
leftBottom=[originalRectangle[0],originalRectangle[1]]
leftBottom=self.transformCoordinate(leftBottom,uVector,vVector)
rightTop=mpmath.matrix([orignalRectangle[2],orignalRectangle[3]])
rightTop=[originalRectangle[2],originalRectangle[3]]
rightTop=self.transformCoordinate(rightTop,uVector,vVector)
left=min(leftBottom[0],rightTop[0])
@ -821,14 +823,15 @@ class VlsiLayout:
right=max(leftBottom[0],rightTop[0])
top=max(leftBottom[1],rightTop[1])
return [left,bottom,right,top]
newRectangle = [left,bottom,right,top]
return newRectangle
def transformCoordinate(self,coordinate,uVector,vVector):
"""
Rotate a coordinate in space.
"""
x=coordinate[0]*uVector[0]+coordinate[1]*uVector[1]
y=coordinate[1]*vVector[1]+coordinate[0]*vVector[0]
x=coordinate[0]*uVector[0].item()+coordinate[1]*uVector[1].item()
y=coordinate[1]*vVector[1].item()+coordinate[0]*vVector[0].item()
transformCoordinate=[x,y]
return transformCoordinate
@ -845,18 +848,12 @@ class VlsiLayout:
else:
return False
def cmpBoundaryAreas(A,B):
def boundaryArea(A):
"""
Compares two rectangles and return true if Area(A)>Area(B).
Returns boundary area for sorting.
"""
area_A=(A[2]-A[0])*(A[3]-A[1])
area_B=(B[2]-B[0])*(B[3]-B[1])
if area_A>area_B:
return 1
elif area_A==area_B:
return 0
else:
return -1
return area_A

View File

@ -1,386 +0,0 @@
__version__ = '0.14'
from usertools import monitor, timing
from ctx_fp import FPContext
from ctx_mp import MPContext
fp = FPContext()
mp = MPContext()
fp._mp = mp
mp._mp = mp
mp._fp = fp
fp._fp = fp
# XXX: extremely bad pickle hack
import ctx_mp as _ctx_mp
_ctx_mp._mpf_module.mpf = mp.mpf
_ctx_mp._mpf_module.mpc = mp.mpc
make_mpf = mp.make_mpf
make_mpc = mp.make_mpc
extraprec = mp.extraprec
extradps = mp.extradps
workprec = mp.workprec
workdps = mp.workdps
mag = mp.mag
bernfrac = mp.bernfrac
jdn = mp.jdn
jsn = mp.jsn
jcn = mp.jcn
jtheta = mp.jtheta
calculate_nome = mp.calculate_nome
nint_distance = mp.nint_distance
plot = mp.plot
cplot = mp.cplot
splot = mp.splot
odefun = mp.odefun
jacobian = mp.jacobian
findroot = mp.findroot
multiplicity = mp.multiplicity
isinf = mp.isinf
isnan = mp.isnan
isint = mp.isint
almosteq = mp.almosteq
nan = mp.nan
rand = mp.rand
absmin = mp.absmin
absmax = mp.absmax
fraction = mp.fraction
linspace = mp.linspace
arange = mp.arange
mpmathify = convert = mp.convert
mpc = mp.mpc
mpi = mp.mpi
nstr = mp.nstr
nprint = mp.nprint
chop = mp.chop
fneg = mp.fneg
fadd = mp.fadd
fsub = mp.fsub
fmul = mp.fmul
fdiv = mp.fdiv
fprod = mp.fprod
quad = mp.quad
quadgl = mp.quadgl
quadts = mp.quadts
quadosc = mp.quadosc
pslq = mp.pslq
identify = mp.identify
findpoly = mp.findpoly
richardson = mp.richardson
shanks = mp.shanks
nsum = mp.nsum
nprod = mp.nprod
diff = mp.diff
diffs = mp.diffs
diffun = mp.diffun
differint = mp.differint
taylor = mp.taylor
pade = mp.pade
polyval = mp.polyval
polyroots = mp.polyroots
fourier = mp.fourier
fourierval = mp.fourierval
sumem = mp.sumem
chebyfit = mp.chebyfit
limit = mp.limit
matrix = mp.matrix
eye = mp.eye
diag = mp.diag
zeros = mp.zeros
ones = mp.ones
hilbert = mp.hilbert
randmatrix = mp.randmatrix
swap_row = mp.swap_row
extend = mp.extend
norm = mp.norm
mnorm = mp.mnorm
lu_solve = mp.lu_solve
lu = mp.lu
unitvector = mp.unitvector
inverse = mp.inverse
residual = mp.residual
qr_solve = mp.qr_solve
cholesky = mp.cholesky
cholesky_solve = mp.cholesky_solve
det = mp.det
cond = mp.cond
expm = mp.expm
sqrtm = mp.sqrtm
powm = mp.powm
logm = mp.logm
sinm = mp.sinm
cosm = mp.cosm
mpf = mp.mpf
j = mp.j
exp = mp.exp
expj = mp.expj
expjpi = mp.expjpi
ln = mp.ln
im = mp.im
re = mp.re
inf = mp.inf
ninf = mp.ninf
sign = mp.sign
eps = mp.eps
pi = mp.pi
ln2 = mp.ln2
ln10 = mp.ln10
phi = mp.phi
e = mp.e
euler = mp.euler
catalan = mp.catalan
khinchin = mp.khinchin
glaisher = mp.glaisher
apery = mp.apery
degree = mp.degree
twinprime = mp.twinprime
mertens = mp.mertens
ldexp = mp.ldexp
frexp = mp.frexp
fsum = mp.fsum
fdot = mp.fdot
sqrt = mp.sqrt
cbrt = mp.cbrt
exp = mp.exp
ln = mp.ln
log = mp.log
log10 = mp.log10
power = mp.power
cos = mp.cos
sin = mp.sin
tan = mp.tan
cosh = mp.cosh
sinh = mp.sinh
tanh = mp.tanh
acos = mp.acos
asin = mp.asin
atan = mp.atan
asinh = mp.asinh
acosh = mp.acosh
atanh = mp.atanh
sec = mp.sec
csc = mp.csc
cot = mp.cot
sech = mp.sech
csch = mp.csch
coth = mp.coth
asec = mp.asec
acsc = mp.acsc
acot = mp.acot
asech = mp.asech
acsch = mp.acsch
acoth = mp.acoth
cospi = mp.cospi
sinpi = mp.sinpi
sinc = mp.sinc
sincpi = mp.sincpi
fabs = mp.fabs
re = mp.re
im = mp.im
conj = mp.conj
floor = mp.floor
ceil = mp.ceil
root = mp.root
nthroot = mp.nthroot
hypot = mp.hypot
modf = mp.modf
ldexp = mp.ldexp
frexp = mp.frexp
sign = mp.sign
arg = mp.arg
phase = mp.phase
polar = mp.polar
rect = mp.rect
degrees = mp.degrees
radians = mp.radians
atan2 = mp.atan2
fib = mp.fib
fibonacci = mp.fibonacci
lambertw = mp.lambertw
zeta = mp.zeta
altzeta = mp.altzeta
gamma = mp.gamma
factorial = mp.factorial
fac = mp.fac
fac2 = mp.fac2
beta = mp.beta
betainc = mp.betainc
psi = mp.psi
#psi0 = mp.psi0
#psi1 = mp.psi1
#psi2 = mp.psi2
#psi3 = mp.psi3
polygamma = mp.polygamma
digamma = mp.digamma
#trigamma = mp.trigamma
#tetragamma = mp.tetragamma
#pentagamma = mp.pentagamma
harmonic = mp.harmonic
bernoulli = mp.bernoulli
bernfrac = mp.bernfrac
stieltjes = mp.stieltjes
hurwitz = mp.hurwitz
dirichlet = mp.dirichlet
bernpoly = mp.bernpoly
eulerpoly = mp.eulerpoly
eulernum = mp.eulernum
polylog = mp.polylog
clsin = mp.clsin
clcos = mp.clcos
gammainc = mp.gammainc
gammaprod = mp.gammaprod
binomial = mp.binomial
rf = mp.rf
ff = mp.ff
hyper = mp.hyper
hyp0f1 = mp.hyp0f1
hyp1f1 = mp.hyp1f1
hyp1f2 = mp.hyp1f2
hyp2f1 = mp.hyp2f1
hyp2f2 = mp.hyp2f2
hyp2f0 = mp.hyp2f0
hyp2f3 = mp.hyp2f3
hyp3f2 = mp.hyp3f2
hyperu = mp.hyperu
hypercomb = mp.hypercomb
meijerg = mp.meijerg
appellf1 = mp.appellf1
erf = mp.erf
erfc = mp.erfc
erfi = mp.erfi
erfinv = mp.erfinv
npdf = mp.npdf
ncdf = mp.ncdf
expint = mp.expint
e1 = mp.e1
ei = mp.ei
li = mp.li
ci = mp.ci
si = mp.si
chi = mp.chi
shi = mp.shi
fresnels = mp.fresnels
fresnelc = mp.fresnelc
airyai = mp.airyai
airybi = mp.airybi
ellipe = mp.ellipe
ellipk = mp.ellipk
agm = mp.agm
jacobi = mp.jacobi
chebyt = mp.chebyt
chebyu = mp.chebyu
legendre = mp.legendre
legenp = mp.legenp
legenq = mp.legenq
hermite = mp.hermite
gegenbauer = mp.gegenbauer
laguerre = mp.laguerre
spherharm = mp.spherharm
besselj = mp.besselj
j0 = mp.j0
j1 = mp.j1
besseli = mp.besseli
bessely = mp.bessely
besselk = mp.besselk
hankel1 = mp.hankel1
hankel2 = mp.hankel2
struveh = mp.struveh
struvel = mp.struvel
whitm = mp.whitm
whitw = mp.whitw
ber = mp.ber
bei = mp.bei
ker = mp.ker
kei = mp.kei
coulombc = mp.coulombc
coulombf = mp.coulombf
coulombg = mp.coulombg
lambertw = mp.lambertw
barnesg = mp.barnesg
superfac = mp.superfac
hyperfac = mp.hyperfac
loggamma = mp.loggamma
siegeltheta = mp.siegeltheta
siegelz = mp.siegelz
grampoint = mp.grampoint
zetazero = mp.zetazero
riemannr = mp.riemannr
primepi = mp.primepi
primepi2 = mp.primepi2
primezeta = mp.primezeta
bell = mp.bell
polyexp = mp.polyexp
expm1 = mp.expm1
powm1 = mp.powm1
unitroots = mp.unitroots
cyclotomic = mp.cyclotomic
# be careful when changing this name, don't use test*!
def runtests():
"""
Run all mpmath tests and print output.
"""
import os.path
from inspect import getsourcefile
import tests.runtests as tests
testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
importdir = os.path.abspath(testdir + '/../..')
tests.testit(importdir, testdir)
def doctests():
try:
import psyco; psyco.full()
except ImportError:
pass
import sys
from timeit import default_timer as clock
filter = []
for i, arg in enumerate(sys.argv):
if '__init__.py' in arg:
filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
break
import doctest
globs = globals().copy()
for obj in globs: #sorted(globs.keys()):
if filter:
if not sum([pat in obj for pat in filter]):
continue
print obj,
t1 = clock()
doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
t2 = clock()
print round(t2-t1, 3)
if __name__ == '__main__':
doctests()

View File

@ -1,6 +0,0 @@
import calculus
# XXX: hack to set methods
import approximation
import differentiation
import extrapolation
import polynomials

View File

@ -1,246 +0,0 @@
from calculus import defun
#----------------------------------------------------------------------------#
# Approximation methods #
#----------------------------------------------------------------------------#
# The Chebyshev approximation formula is given at:
# http://mathworld.wolfram.com/ChebyshevApproximationFormula.html
# The only major changes in the following code is that we return the
# expanded polynomial coefficients instead of Chebyshev coefficients,
# and that we automatically transform [a,b] -> [-1,1] and back
# for convenience.
# Coefficient in Chebyshev approximation
def chebcoeff(ctx,f,a,b,j,N):
s = ctx.mpf(0)
h = ctx.mpf(0.5)
for k in range(1, N+1):
t = ctx.cos(ctx.pi*(k-h)/N)
s += f(t*(b-a)*h + (b+a)*h) * ctx.cos(ctx.pi*j*(k-h)/N)
return 2*s/N
# Generate Chebyshev polynomials T_n(ax+b) in expanded form
def chebT(ctx, a=1, b=0):
Tb = [1]
yield Tb
Ta = [b, a]
while 1:
yield Ta
# Recurrence: T[n+1](ax+b) = 2*(ax+b)*T[n](ax+b) - T[n-1](ax+b)
Tmp = [0] + [2*a*t for t in Ta]
for i, c in enumerate(Ta): Tmp[i] += 2*b*c
for i, c in enumerate(Tb): Tmp[i] -= c
Ta, Tb = Tmp, Ta
@defun
def chebyfit(ctx, f, interval, N, error=False):
r"""
Computes a polynomial of degree `N-1` that approximates the
given function `f` on the interval `[a, b]`. With ``error=True``,
:func:`chebyfit` also returns an accurate estimate of the
maximum absolute error; that is, the maximum value of
`|f(x) - P(x)|` for `x \in [a, b]`.
:func:`chebyfit` uses the Chebyshev approximation formula,
which gives a nearly optimal solution: that is, the maximum
error of the approximating polynomial is very close to
the smallest possible for any polynomial of the same degree.
Chebyshev approximation is very useful if one needs repeated
evaluation of an expensive function, such as function defined
implicitly by an integral or a differential equation. (For
example, it could be used to turn a slow mpmath function
into a fast machine-precision version of the same.)
**Examples**
Here we use :func:`chebyfit` to generate a low-degree approximation
of `f(x) = \cos(x)`, valid on the interval `[1, 2]`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> poly, err = chebyfit(cos, [1, 2], 5, error=True)
>>> nprint(poly)
[0.00291682, 0.146166, -0.732491, 0.174141, 0.949553]
>>> nprint(err, 12)
1.61351758081e-5
The polynomial can be evaluated using ``polyval``::
>>> nprint(polyval(poly, 1.6), 12)
-0.0291858904138
>>> nprint(cos(1.6), 12)
-0.0291995223013
Sampling the true error at 1000 points shows that the error
estimate generated by ``chebyfit`` is remarkably good::
>>> error = lambda x: abs(cos(x) - polyval(poly, x))
>>> nprint(max([error(1+n/1000.) for n in range(1000)]), 12)
1.61349954245e-5
**Choice of degree**
The degree `N` can be set arbitrarily high, to obtain an
arbitrarily good approximation. As a rule of thumb, an
`N`-term Chebyshev approximation is good to `N/(b-a)` decimal
places on a unit interval (although this depends on how
well-behaved `f` is). The cost grows accordingly: ``chebyfit``
evaluates the function `(N^2)/2` times to compute the
coefficients and an additional `N` times to estimate the error.
**Possible issues**
One should be careful to use a sufficiently high working
precision both when calling ``chebyfit`` and when evaluating
the resulting polynomial, as the polynomial is sometimes
ill-conditioned. It is for example difficult to reach
15-digit accuracy when evaluating the polynomial using
machine precision floats, no matter the theoretical
accuracy of the polynomial. (The option to return the
coefficients in Chebyshev form should be made available
in the future.)
It is important to note the Chebyshev approximation works
poorly if `f` is not smooth. A function containing singularities,
rapid oscillation, etc can be approximated more effectively by
multiplying it by a weight function that cancels out the
nonsmooth features, or by dividing the interval into several
segments.
"""
a, b = ctx._as_points(interval)
orig = ctx.prec
try:
ctx.prec = orig + int(N**0.5) + 20
c = [chebcoeff(ctx,f,a,b,k,N) for k in range(N)]
d = [ctx.zero] * N
d[0] = -c[0]/2
h = ctx.mpf(0.5)
T = chebT(ctx, ctx.mpf(2)/(b-a), ctx.mpf(-1)*(b+a)/(b-a))
for k in range(N):
Tk = T.next()
for i in range(len(Tk)):
d[i] += c[k]*Tk[i]
d = d[::-1]
# Estimate maximum error
err = ctx.zero
for k in range(N):
x = ctx.cos(ctx.pi*k/N) * (b-a)*h + (b+a)*h
err = max(err, abs(f(x) - ctx.polyval(d, x)))
finally:
ctx.prec = orig
if error:
return d, +err
else:
return d
@defun
def fourier(ctx, f, interval, N):
r"""
Computes the Fourier series of degree `N` of the given function
on the interval `[a, b]`. More precisely, :func:`fourier` returns
two lists `(c, s)` of coefficients (the cosine series and sine
series, respectively), such that
.. math ::
f(x) \sim \sum_{k=0}^N
c_k \cos(k m) + s_k \sin(k m)
where `m = 2 \pi / (b-a)`.
Note that many texts define the first coefficient as `2 c_0` instead
of `c_0`. The easiest way to evaluate the computed series correctly
is to pass it to :func:`fourierval`.
**Examples**
The function `f(x) = x` has a simple Fourier series on the standard
interval `[-\pi, \pi]`. The cosine coefficients are all zero (because
the function has odd symmetry), and the sine coefficients are
rational numbers::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> c, s = fourier(lambda x: x, [-pi, pi], 5)
>>> nprint(c)
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>>> nprint(s)
[0.0, 2.0, -1.0, 0.666667, -0.5, 0.4]
This computes a Fourier series of a nonsymmetric function on
a nonstandard interval::
>>> I = [-1, 1.5]
>>> f = lambda x: x**2 - 4*x + 1
>>> cs = fourier(f, I, 4)
>>> nprint(cs[0])
[0.583333, 1.12479, -1.27552, 0.904708, -0.441296]
>>> nprint(cs[1])
[0.0, -2.6255, 0.580905, 0.219974, -0.540057]
It is instructive to plot a function along with its truncated
Fourier series::
>>> plot([f, lambda x: fourierval(cs, I, x)], I) #doctest: +SKIP
Fourier series generally converge slowly (and may not converge
pointwise). For example, if `f(x) = \cosh(x)`, a 10-term Fourier
series gives an `L^2` error corresponding to 2-digit accuracy::
>>> I = [-1, 1]
>>> cs = fourier(cosh, I, 9)
>>> g = lambda x: (cosh(x) - fourierval(cs, I, x))**2
>>> nprint(sqrt(quad(g, I)))
0.00467963
:func:`fourier` uses numerical quadrature. For nonsmooth functions,
the accuracy (and speed) can be improved by including all singular
points in the interval specification::
>>> nprint(fourier(abs, [-1, 1], 0), 10)
([0.5000441648], [0.0])
>>> nprint(fourier(abs, [-1, 0, 1], 0), 10)
([0.5], [0.0])
"""
interval = ctx._as_points(interval)
a = interval[0]
b = interval[-1]
L = b-a
cos_series = []
sin_series = []
cutoff = ctx.eps*10
for n in xrange(N+1):
m = 2*n*ctx.pi/L
an = 2*ctx.quadgl(lambda t: f(t)*ctx.cos(m*t), interval)/L
bn = 2*ctx.quadgl(lambda t: f(t)*ctx.sin(m*t), interval)/L
if n == 0:
an /= 2
if abs(an) < cutoff: an = ctx.zero
if abs(bn) < cutoff: bn = ctx.zero
cos_series.append(an)
sin_series.append(bn)
return cos_series, sin_series
@defun
def fourierval(ctx, series, interval, x):
"""
Evaluates a Fourier series (in the format computed by
by :func:`fourier` for the given interval) at the point `x`.
The series should be a pair `(c, s)` where `c` is the
cosine series and `s` is the sine series. The two lists
need not have the same length.
"""
cs, ss = series
ab = ctx._as_points(interval)
a = interval[0]
b = interval[-1]
m = 2*ctx.pi/(ab[-1]-ab[0])
s = ctx.zero
s += ctx.fsum(cs[n]*ctx.cos(m*n*x) for n in xrange(len(cs)) if cs[n])
s += ctx.fsum(ss[n]*ctx.sin(m*n*x) for n in xrange(len(ss)) if ss[n])
return s

View File

@ -1,5 +0,0 @@
class CalculusMethods(object):
pass
def defun(f):
setattr(CalculusMethods, f.__name__, f)

View File

@ -1,438 +0,0 @@
from calculus import defun
#----------------------------------------------------------------------------#
# Differentiation #
#----------------------------------------------------------------------------#
@defun
def difference_delta(ctx, s, n):
r"""
Given a sequence `(s_k)` containing at least `n+1` items, returns the
`n`-th forward difference,
.. math ::
\Delta^n = \sum_{k=0}^{\infty} (-1)^{k+n} {n \choose k} s_k.
"""
n = int(n)
d = ctx.zero
b = (-1) ** (n & 1)
for k in xrange(n+1):
d += b * s[k]
b = (b * (k-n)) // (k+1)
return d
@defun
def diff(ctx, f, x, n=1, method='step', scale=1, direction=0):
r"""
Numerically computes the derivative of `f`, `f'(x)`. Optionally,
computes the `n`-th derivative, `f^{(n)}(x)`, for any order `n`.
**Basic examples**
Derivatives of a simple function::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> diff(lambda x: x**2 + x, 1.0)
3.0
>>> diff(lambda x: x**2 + x, 1.0, 2)
2.0
>>> diff(lambda x: x**2 + x, 1.0, 3)
0.0
The exponential function is invariant under differentiation::
>>> nprint([diff(exp, 3, n) for n in range(5)])
[20.0855, 20.0855, 20.0855, 20.0855, 20.0855]
**Method**
One of two differentiation algorithms can be chosen with the
``method`` keyword argument. The two options are ``'step'``,
and ``'quad'``. The default method is ``'step'``.
``'step'``:
The derivative is computed using a finite difference
approximation, with a small step h. This requires n+1 function
evaluations and must be performed at (n+1) times the target
precision. Accordingly, f must support fast evaluation at high
precision.
``'quad'``:
The derivative is computed using complex
numerical integration. This requires a larger number of function
evaluations, but the advantage is that not much extra precision
is required. For high order derivatives, this method may thus
be faster if f is very expensive to evaluate at high precision.
With ``'quad'`` the result is likely to have a small imaginary
component even if the derivative is actually real::
>>> diff(sqrt, 1, method='quad') # doctest:+ELLIPSIS
(0.5 - 9.44...e-27j)
**Scale**
The scale option specifies the scale of variation of f. The step
size in the finite difference is taken to be approximately
eps*scale. Thus, for example if `f(x) = \cos(1000 x)`, the scale
should be set to 1/1000 and if `f(x) = \cos(x/1000)`, the scale
should be 1000. By default, scale = 1.
(In practice, the default scale will work even for `\cos(1000 x)` or
`\cos(x/1000)`. Changing this parameter is a good idea if the scale
is something *preposterous*.)
If numerical integration is used, the radius of integration is
taken to be equal to scale/2. Note that f must not have any
singularities within the circle of radius scale/2 centered around
x. If possible, a larger scale value is preferable because it
typically makes the integration faster and more accurate.
**Direction**
By default, :func:`diff` uses a central difference approximation.
This corresponds to direction=0. Alternatively, it can compute a
left difference (direction=-1) or right difference (direction=1).
This is useful for computing left- or right-sided derivatives
of nonsmooth functions:
>>> diff(abs, 0, direction=0)
0.0
>>> diff(abs, 0, direction=1)
1.0
>>> diff(abs, 0, direction=-1)
-1.0
More generally, if the direction is nonzero, a right difference
is computed where the step size is multiplied by sign(direction).
For example, with direction=+j, the derivative from the positive
imaginary direction will be computed.
This option only makes sense with method='step'. If integration
is used, it is assumed that f is analytic, implying that the
derivative is the same in all directions.
"""
if n == 0:
return f(ctx.convert(x))
orig = ctx.prec
try:
if method == 'step':
ctx.prec = (orig+20) * (n+1)
h = ctx.ldexp(scale, -orig-10)
# Applying the finite difference formula recursively n times,
# we get a step sum weighted by a row of binomial coefficients
# Directed: steps x, x+h, ... x+n*h
if direction:
h *= ctx.sign(direction)
steps = xrange(n+1)
norm = h**n
# Central: steps x-n*h, x-(n-2)*h ..., x, ..., x+(n-2)*h, x+n*h
else:
steps = xrange(-n, n+1, 2)
norm = (2*h)**n
v = ctx.difference_delta([f(x+k*h) for k in steps], n)
v = v / norm
elif method == 'quad':
ctx.prec += 10
radius = ctx.mpf(scale)/2
def g(t):
rei = radius*ctx.expj(t)
z = x + rei
return f(z) / rei**n
d = ctx.quadts(g, [0, 2*ctx.pi])
v = d * ctx.factorial(n) / (2*ctx.pi)
else:
raise ValueError("unknown method: %r" % method)
finally:
ctx.prec = orig
return +v
@defun
def diffs(ctx, f, x, n=None, method='step', scale=1, direction=0):
r"""
Returns a generator that yields the sequence of derivatives
.. math ::
f(x), f'(x), f''(x), \ldots, f^{(k)}(x), \ldots
With ``method='step'``, :func:`diffs` uses only `O(k)`
function evaluations to generate the first `k` derivatives,
rather than the roughly `O(k^2)` evaluations
required if one calls :func:`diff` `k` separate times.
With `n < \infty`, the generator stops as soon as the
`n`-th derivative has been generated. If the exact number of
needed derivatives is known in advance, this is further
slightly more efficient.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15
>>> nprint(list(diffs(cos, 1, 5)))
[0.540302, -0.841471, -0.540302, 0.841471, 0.540302, -0.841471]
>>> for i, d in zip(range(6), diffs(cos, 1)): print i, d
...
0 0.54030230586814
1 -0.841470984807897
2 -0.54030230586814
3 0.841470984807897
4 0.54030230586814
5 -0.841470984807897
"""
if n is None:
n = ctx.inf
else:
n = int(n)
if method != 'step':
k = 0
while k < n:
yield ctx.diff(f, x, k)
k += 1
return
targetprec = ctx.prec
def getvalues(m):
callprec = ctx.prec
try:
ctx.prec = workprec = (targetprec+20) * (m+1)
h = ctx.ldexp(scale, -targetprec-10)
if direction:
h *= ctx.sign(direction)
y = [f(x+h*k) for k in xrange(m+1)]
hnorm = h
else:
y = [f(x+h*k) for k in xrange(-m, m+1, 2)]
hnorm = 2*h
return y, hnorm, workprec
finally:
ctx.prec = callprec
yield f(ctx.convert(x))
if n < 1:
return
if n == ctx.inf:
A, B = 1, 2
else:
A, B = 1, n+1
while 1:
y, hnorm, workprec = getvalues(B)
for k in xrange(A, B):
try:
callprec = ctx.prec
ctx.prec = workprec
d = ctx.difference_delta(y, k) / hnorm**k
finally:
ctx.prec = callprec
yield +d
if k >= n:
return
A, B = B, int(A*1.4+1)
B = min(B, n)
@defun
def differint(ctx, f, x, n=1, x0=0):
r"""
Calculates the Riemann-Liouville differintegral, or fractional
derivative, defined by
.. math ::
\,_{x_0}{\mathbb{D}}^n_xf(x) \frac{1}{\Gamma(m-n)} \frac{d^m}{dx^m}
\int_{x_0}^{x}(x-t)^{m-n-1}f(t)dt
where `f` is a given (presumably well-behaved) function,
`x` is the evaluation point, `n` is the order, and `x_0` is
the reference point of integration (`m` is an arbitrary
parameter selected automatically).
With `n = 1`, this is just the standard derivative `f'(x)`; with `n = 2`,
the second derivative `f''(x)`, etc. With `n = -1`, it gives
`\int_{x_0}^x f(t) dt`, with `n = -2`
it gives `\int_{x_0}^x \left( \int_{x_0}^t f(u) du \right) dt`, etc.
As `n` is permitted to be any number, this operator generalizes
iterated differentiation and iterated integration to a single
operator with a continuous order parameter.
**Examples**
There is an exact formula for the fractional derivative of a
monomial `x^p`, which may be used as a reference. For example,
the following gives a half-derivative (order 0.5)::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> x = mpf(3); p = 2; n = 0.5
>>> differint(lambda t: t**p, x, n)
7.81764019044672
>>> gamma(p+1)/gamma(p-n+1) * x**(p-n)
7.81764019044672
Another useful test function is the exponential function, whose
integration / differentiation formula easy generalizes
to arbitrary order. Here we first compute a third derivative,
and then a triply nested integral. (The reference point `x_0`
is set to `-\infty` to avoid nonzero endpoint terms.)::
>>> differint(lambda x: exp(pi*x), -1.5, 3)
0.278538406900792
>>> exp(pi*-1.5) * pi**3
0.278538406900792
>>> differint(lambda x: exp(pi*x), 3.5, -3, -inf)
1922.50563031149
>>> exp(pi*3.5) / pi**3
1922.50563031149
However, for noninteger `n`, the differentiation formula for the
exponential function must be modified to give the same result as the
Riemann-Liouville differintegral::
>>> x = mpf(3.5)
>>> c = pi
>>> n = 1+2*j
>>> differint(lambda x: exp(c*x), x, n)
(-123295.005390743 + 140955.117867654j)
>>> x**(-n) * exp(c)**x * (x*c)**n * gammainc(-n, 0, x*c) / gamma(-n)
(-123295.005390743 + 140955.117867654j)
"""
m = max(int(ctx.ceil(ctx.re(n)))+1, 1)
r = m-n-1
g = lambda x: ctx.quad(lambda t: (x-t)**r * f(t), [x0, x])
return ctx.diff(g, x, m) / ctx.gamma(m-n)
@defun
def diffun(ctx, f, n=1, **options):
"""
Given a function f, returns a function g(x) that evaluates the nth
derivative f^(n)(x)::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> cos2 = diffun(sin)
>>> sin2 = diffun(sin, 4)
>>> cos(1.3), cos2(1.3)
(0.267498828624587, 0.267498828624587)
>>> sin(1.3), sin2(1.3)
(0.963558185417193, 0.963558185417193)
The function f must support arbitrary precision evaluation.
See :func:`diff` for additional details and supported
keyword options.
"""
if n == 0:
return f
def g(x):
return ctx.diff(f, x, n, **options)
return g
@defun
def taylor(ctx, f, x, n, **options):
r"""
Produces a degree-`n` Taylor polynomial around the point `x` of the
given function `f`. The coefficients are returned as a list.
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> nprint(chop(taylor(sin, 0, 5)))
[0.0, 1.0, 0.0, -0.166667, 0.0, 0.00833333]
The coefficients are computed using high-order numerical
differentiation. The function must be possible to evaluate
to arbitrary precision. See :func:`diff` for additional details
and supported keyword options.
Note that to evaluate the Taylor polynomial as an approximation
of `f`, e.g. with :func:`polyval`, the coefficients must be reversed,
and the point of the Taylor expansion must be subtracted from
the argument:
>>> p = taylor(exp, 2.0, 10)
>>> polyval(p[::-1], 2.5 - 2.0)
12.1824939606092
>>> exp(2.5)
12.1824939607035
"""
return [d/ctx.factorial(i) for i, d in enumerate(ctx.diffs(f, x, n, **options))]
@defun
def pade(ctx, a, L, M):
r"""
Computes a Pade approximation of degree `(L, M)` to a function.
Given at least `L+M+1` Taylor coefficients `a` approximating
a function `A(x)`, :func:`pade` returns coefficients of
polynomials `P, Q` satisfying
.. math ::
P = \sum_{k=0}^L p_k x^k
Q = \sum_{k=0}^M q_k x^k
Q_0 = 1
A(x) Q(x) = P(x) + O(x^{L+M+1})
`P(x)/Q(x)` can provide a good approximation to an analytic function
beyond the radius of convergence of its Taylor series (example
from G.A. Baker 'Essentials of Pade Approximants' Academic Press,
Ch.1A)::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> one = mpf(1)
>>> def f(x):
... return sqrt((one + 2*x)/(one + x))
...
>>> a = taylor(f, 0, 6)
>>> p, q = pade(a, 3, 3)
>>> x = 10
>>> polyval(p[::-1], x)/polyval(q[::-1], x)
1.38169105566806
>>> f(x)
1.38169855941551
"""
# To determine L+1 coefficients of P and M coefficients of Q
# L+M+1 coefficients of A must be provided
assert(len(a) >= L+M+1)
if M == 0:
if L == 0:
return [ctx.one], [ctx.one]
else:
return a[:L+1], [ctx.one]
# Solve first
# a[L]*q[1] + ... + a[L-M+1]*q[M] = -a[L+1]
# ...
# a[L+M-1]*q[1] + ... + a[L]*q[M] = -a[L+M]
A = ctx.matrix(M)
for j in range(M):
for i in range(min(M, L+j+1)):
A[j, i] = a[L+j-i]
v = -ctx.matrix(a[(L+1):(L+M+1)])
x = ctx.lu_solve(A, v)
q = [ctx.one] + list(x)
# compute p
p = [0]*(L+1)
for i in range(L+1):
s = a[i]
for j in range(1, min(M,i) + 1):
s += q[j]*a[i-j]
p[i] = s
return p, q

File diff suppressed because it is too large Load Diff

View File

@ -1,287 +0,0 @@
from bisect import bisect
class ODEMethods(object):
pass
def ode_taylor(ctx, derivs, x0, y0, tol_prec, n):
h = tol = ctx.ldexp(1, -tol_prec)
dim = len(y0)
xs = [x0]
ys = [y0]
x = x0
y = y0
orig = ctx.prec
try:
ctx.prec = orig*(1+n)
# Use n steps with Euler's method to get
# evaluation points for derivatives
for i in range(n):
fxy = derivs(x, y)
y = [y[i]+h*fxy[i] for i in xrange(len(y))]
x += h
xs.append(x)
ys.append(y)
# Compute derivatives
ser = [[] for d in range(dim)]
for j in range(n+1):
s = [0]*dim
b = (-1) ** (j & 1)
k = 1
for i in range(j+1):
for d in range(dim):
s[d] += b * ys[i][d]
b = (b * (j-k+1)) // (-k)
k += 1
scale = h**(-j) / ctx.fac(j)
for d in range(dim):
s[d] = s[d] * scale
ser[d].append(s[d])
finally:
ctx.prec = orig
# Estimate radius for which we can get full accuracy.
# XXX: do this right for zeros
radius = ctx.one
for ts in ser:
if ts[-1]:
radius = min(radius, ctx.nthroot(tol/abs(ts[-1]), n))
radius /= 2 # XXX
return ser, x0+radius
def odefun(ctx, F, x0, y0, tol=None, degree=None, method='taylor', verbose=False):
r"""
Returns a function `y(x) = [y_0(x), y_1(x), \ldots, y_n(x)]`
that is a numerical solution of the `n+1`-dimensional first-order
ordinary differential equation (ODE) system
.. math ::
y_0'(x) = F_0(x, [y_0(x), y_1(x), \ldots, y_n(x)])
y_1'(x) = F_1(x, [y_0(x), y_1(x), \ldots, y_n(x)])
\vdots
y_n'(x) = F_n(x, [y_0(x), y_1(x), \ldots, y_n(x)])
The derivatives are specified by the vector-valued function
*F* that evaluates
`[y_0', \ldots, y_n'] = F(x, [y_0, \ldots, y_n])`.
The initial point `x_0` is specified by the scalar argument *x0*,
and the initial value `y(x_0) = [y_0(x_0), \ldots, y_n(x_0)]` is
specified by the vector argument *y0*.
For convenience, if the system is one-dimensional, you may optionally
provide just a scalar value for *y0*. In this case, *F* should accept
a scalar *y* argument and return a scalar. The solution function
*y* will return scalar values instead of length-1 vectors.
Evaluation of the solution function `y(x)` is permitted
for any `x \ge x_0`.
A high-order ODE can be solved by transforming it into first-order
vector form. This transformation is described in standard texts
on ODEs. Examples will also be given below.
**Options, speed and accuracy**
By default, :func:`odefun` uses a high-order Taylor series
method. For reasonably well-behaved problems, the solution will
be fully accurate to within the working precision. Note that
*F* must be possible to evaluate to very high precision
for the generation of Taylor series to work.
To get a faster but less accurate solution, you can set a large
value for *tol* (which defaults roughly to *eps*). If you just
want to plot the solution or perform a basic simulation,
*tol = 0.01* is likely sufficient.
The *degree* argument controls the degree of the solver (with
*method='taylor'*, this is the degree of the Taylor series
expansion). A higher degree means that a longer step can be taken
before a new local solution must be generated from *F*,
meaning that fewer steps are required to get from `x_0` to a given
`x_1`. On the other hand, a higher degree also means that each
local solution becomes more expensive (i.e., more evaluations of
*F* are required per step, and at higher precision).
The optimal setting therefore involves a tradeoff. Generally,
decreasing the *degree* for Taylor series is likely to give faster
solution at low precision, while increasing is likely to be better
at higher precision.
The function
object returned by :func:`odefun` caches the solutions at all step
points and uses polynomial interpolation between step points.
Therefore, once `y(x_1)` has been evaluated for some `x_1`,
`y(x)` can be evaluated very quickly for any `x_0 \le x \le x_1`.
and continuing the evaluation up to `x_2 > x_1` is also fast.
**Examples of first-order ODEs**
We will solve the standard test problem `y'(x) = y(x), y(0) = 1`
which has explicit solution `y(x) = \exp(x)`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> f = odefun(lambda x, y: y, 0, 1)
>>> for x in [0, 1, 2.5]:
... print f(x), exp(x)
...
1.0 1.0
2.71828182845905 2.71828182845905
12.1824939607035 12.1824939607035
The solution with high precision::
>>> mp.dps = 50
>>> f = odefun(lambda x, y: y, 0, 1)
>>> f(1)
2.7182818284590452353602874713526624977572470937
>>> exp(1)
2.7182818284590452353602874713526624977572470937
Using the more general vectorized form, the test problem
can be input as (note that *f* returns a 1-element vector)::
>>> mp.dps = 15
>>> f = odefun(lambda x, y: [y[0]], 0, [1])
>>> f(1)
[2.71828182845905]
:func:`odefun` can solve nonlinear ODEs, which are generally
impossible (and at best difficult) to solve analytically. As
an example of a nonlinear ODE, we will solve `y'(x) = x \sin(y(x))`
for `y(0) = \pi/2`. An exact solution happens to be known
for this problem, and is given by
`y(x) = 2 \tan^{-1}\left(\exp\left(x^2/2\right)\right)`::
>>> f = odefun(lambda x, y: x*sin(y), 0, pi/2)
>>> for x in [2, 5, 10]:
... print f(x), 2*atan(exp(mpf(x)**2/2))
...
2.87255666284091 2.87255666284091
3.14158520028345 3.14158520028345
3.14159265358979 3.14159265358979
If `F` is independent of `y`, an ODE can be solved using direct
integration. We can therefore obtain a reference solution with
:func:`quad`::
>>> f = lambda x: (1+x**2)/(1+x**3)
>>> g = odefun(lambda x, y: f(x), pi, 0)
>>> g(2*pi)
0.72128263801696
>>> quad(f, [pi, 2*pi])
0.72128263801696
**Examples of second-order ODEs**
We will solve the harmonic oscillator equation `y''(x) + y(x) = 0`.
To do this, we introduce the helper functions `y_0 = y, y_1 = y_0'`
whereby the original equation can be written as `y_1' + y_0' = 0`. Put
together, we get the first-order, two-dimensional vector ODE
.. math ::
\begin{cases}
y_0' = y_1 \\
y_1' = -y_0
\end{cases}
To get a well-defined IVP, we need two initial values. With
`y(0) = y_0(0) = 1` and `-y'(0) = y_1(0) = 0`, the problem will of
course be solved by `y(x) = y_0(x) = \cos(x)` and
`-y'(x) = y_1(x) = \sin(x)`. We check this::
>>> f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
>>> for x in [0, 1, 2.5, 10]:
... nprint(f(x), 15)
... nprint([cos(x), sin(x)], 15)
... print "---"
...
[1.0, 0.0]
[1.0, 0.0]
---
[0.54030230586814, 0.841470984807897]
[0.54030230586814, 0.841470984807897]
---
[-0.801143615546934, 0.598472144103957]
[-0.801143615546934, 0.598472144103957]
---
[-0.839071529076452, -0.54402111088937]
[-0.839071529076452, -0.54402111088937]
---
Note that we get both the sine and the cosine solutions
simultaneously.
**TODO**
* Better automatic choice of degree and step size
* Make determination of Taylor series convergence radius
more robust
* Allow solution for `x < x_0`
* Allow solution for complex `x`
* Test for difficult (ill-conditioned) problems
* Implement Runge-Kutta and other algorithms
"""
if tol:
tol_prec = int(-ctx.log(tol, 2))+10
else:
tol_prec = ctx.prec+10
degree = degree or (3 + int(3*ctx.dps/2.))
workprec = ctx.prec + 40
try:
len(y0)
return_vector = True
except TypeError:
F_ = F
F = lambda x, y: [F_(x, y[0])]
y0 = [y0]
return_vector = False
ser, xb = ode_taylor(ctx, F, x0, y0, tol_prec, degree)
series_boundaries = [x0, xb]
series_data = [(ser, x0, xb)]
# We will be working with vectors of Taylor series
def mpolyval(ser, a):
return [ctx.polyval(s[::-1], a) for s in ser]
# Find nearest expansion point; compute if necessary
def get_series(x):
if x < x0:
raise ValueError
n = bisect(series_boundaries, x)
if n < len(series_boundaries):
return series_data[n-1]
while 1:
ser, xa, xb = series_data[-1]
if verbose:
print "Computing Taylor series for [%f, %f]" % (xa, xb)
y = mpolyval(ser, xb-xa)
xa = xb
ser, xb = ode_taylor(ctx, F, xb, y, tol_prec, degree)
series_boundaries.append(xb)
series_data.append((ser, xa, xb))
if x <= xb:
return series_data[-1]
# Evaluation function
def interpolant(x):
x = ctx.convert(x)
orig = ctx.prec
try:
ctx.prec = workprec
ser, xa, xb = get_series(x)
y = mpolyval(ser, x-xa)
finally:
ctx.prec = orig
if return_vector:
return [+yk for yk in y]
else:
return +y[0]
return interpolant
ODEMethods.odefun = odefun
if __name__ == "__main__":
import doctest
doctest.testmod()

File diff suppressed because it is too large Load Diff

View File

@ -1,189 +0,0 @@
from calculus import defun
#----------------------------------------------------------------------------#
# Polynomials #
#----------------------------------------------------------------------------#
# XXX: extra precision
@defun
def polyval(ctx, coeffs, x, derivative=False):
r"""
Given coefficients `[c_n, \ldots, c_2, c_1, c_0]` and a number `x`,
:func:`polyval` evaluates the polynomial
.. math ::
P(x) = c_n x^n + \ldots + c_2 x^2 + c_1 x + c_0.
If *derivative=True* is set, :func:`polyval` simultaneously
evaluates `P(x)` with the derivative, `P'(x)`, and returns the
tuple `(P(x), P'(x))`.
>>> from mpmath import *
>>> mp.pretty = True
>>> polyval([3, 0, 2], 0.5)
2.75
>>> polyval([3, 0, 2], 0.5, derivative=True)
(2.75, 3.0)
The coefficients and the evaluation point may be any combination
of real or complex numbers.
"""
if not coeffs:
return ctx.zero
p = ctx.convert(coeffs[0])
q = ctx.zero
for c in coeffs[1:]:
if derivative:
q = p + x*q
p = c + x*p
if derivative:
return p, q
else:
return p
@defun
def polyroots(ctx, coeffs, maxsteps=50, cleanup=True, extraprec=10, error=False):
"""
Computes all roots (real or complex) of a given polynomial. The roots are
returned as a sorted list, where real roots appear first followed by
complex conjugate roots as adjacent elements. The polynomial should be
given as a list of coefficients, in the format used by :func:`polyval`.
The leading coefficient must be nonzero.
With *error=True*, :func:`polyroots` returns a tuple *(roots, err)* where
*err* is an estimate of the maximum error among the computed roots.
**Examples**
Finding the three real roots of `x^3 - x^2 - 14x + 24`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> nprint(polyroots([1,-1,-14,24]), 4)
[-4.0, 2.0, 3.0]
Finding the two complex conjugate roots of `4x^2 + 3x + 2`, with an
error estimate::
>>> roots, err = polyroots([4,3,2], error=True)
>>> for r in roots:
... print r
...
(-0.375 + 0.59947894041409j)
(-0.375 - 0.59947894041409j)
>>>
>>> err
2.22044604925031e-16
>>>
>>> polyval([4,3,2], roots[0])
(2.22044604925031e-16 + 0.0j)
>>> polyval([4,3,2], roots[1])
(2.22044604925031e-16 + 0.0j)
The following example computes all the 5th roots of unity; that is,
the roots of `x^5 - 1`::
>>> mp.dps = 20
>>> for r in polyroots([1, 0, 0, 0, 0, -1]):
... print r
...
1.0
(-0.8090169943749474241 + 0.58778525229247312917j)
(-0.8090169943749474241 - 0.58778525229247312917j)
(0.3090169943749474241 + 0.95105651629515357212j)
(0.3090169943749474241 - 0.95105651629515357212j)
**Precision and conditioning**
Provided there are no repeated roots, :func:`polyroots` can typically
compute all roots of an arbitrary polynomial to high precision::
>>> mp.dps = 60
>>> for r in polyroots([1, 0, -10, 0, 1]):
... print r
...
-3.14626436994197234232913506571557044551247712918732870123249
-0.317837245195782244725757617296174288373133378433432554879127
0.317837245195782244725757617296174288373133378433432554879127
3.14626436994197234232913506571557044551247712918732870123249
>>>
>>> sqrt(3) + sqrt(2)
3.14626436994197234232913506571557044551247712918732870123249
>>> sqrt(3) - sqrt(2)
0.317837245195782244725757617296174288373133378433432554879127
**Algorithm**
:func:`polyroots` implements the Durand-Kerner method [1], which
uses complex arithmetic to locate all roots simultaneously.
The Durand-Kerner method can be viewed as approximately performing
simultaneous Newton iteration for all the roots. In particular,
the convergence to simple roots is quadratic, just like Newton's
method.
Although all roots are internally calculated using complex arithmetic,
any root found to have an imaginary part smaller than the estimated
numerical error is truncated to a real number. Real roots are placed
first in the returned list, sorted by value. The remaining complex
roots are sorted by real their parts so that conjugate roots end up
next to each other.
**References**
1. http://en.wikipedia.org/wiki/Durand-Kerner_method
"""
if len(coeffs) <= 1:
if not coeffs or not coeffs[0]:
raise ValueError("Input to polyroots must not be the zero polynomial")
# Constant polynomial with no roots
return []
orig = ctx.prec
weps = +ctx.eps
try:
ctx.prec += 10
tol = ctx.eps * 128
deg = len(coeffs) - 1
# Must be monic
lead = ctx.convert(coeffs[0])
if lead == 1:
coeffs = map(ctx.convert, coeffs)
else:
coeffs = [c/lead for c in coeffs]
f = lambda x: ctx.polyval(coeffs, x)
roots = [ctx.mpc((0.4+0.9j)**n) for n in xrange(deg)]
err = [ctx.one for n in xrange(deg)]
# Durand-Kerner iteration until convergence
for step in xrange(maxsteps):
if abs(max(err)) < tol:
break
for i in xrange(deg):
if not abs(err[i]) < tol:
p = roots[i]
x = f(p)
for j in range(deg):
if i != j:
try:
x /= (p-roots[j])
except ZeroDivisionError:
continue
roots[i] = p - x
err[i] = abs(x)
# Remove small imaginary parts
if cleanup:
for i in xrange(deg):
if abs(ctx._im(roots[i])) < weps:
roots[i] = roots[i].real
elif abs(ctx._re(roots[i])) < weps:
roots[i] = roots[i].imag * 1j
roots.sort(key=lambda x: (abs(ctx._im(x)), ctx._re(x)))
finally:
ctx.prec = orig
if error:
err = max(err)
err = max(err, ctx.ldexp(1, -orig+1))
return [+r for r in roots], +err
else:
return [+r for r in roots]

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +0,0 @@
# The py library is part of the "py.test" testing suite (python-codespeak-lib
# on Debian), see http://codespeak.net/py/
import py
#this makes py.test put mpath directory into the sys.path, so that we can
#"import mpmath" from tests nicely
rootdir = py.magic.autopath().dirpath()

View File

@ -1,324 +0,0 @@
from operator import gt, lt
from functions.functions import SpecialFunctions
from functions.rszeta import RSCache
from calculus.quadrature import QuadratureMethods
from calculus.calculus import CalculusMethods
from calculus.optimization import OptimizationMethods
from calculus.odes import ODEMethods
from matrices.matrices import MatrixMethods
from matrices.calculus import MatrixCalculusMethods
from matrices.linalg import LinearAlgebraMethods
from identification import IdentificationMethods
from visualization import VisualizationMethods
import libmp
class Context(object):
pass
class StandardBaseContext(Context,
SpecialFunctions,
RSCache,
QuadratureMethods,
CalculusMethods,
MatrixMethods,
MatrixCalculusMethods,
LinearAlgebraMethods,
IdentificationMethods,
OptimizationMethods,
ODEMethods,
VisualizationMethods):
NoConvergence = libmp.NoConvergence
ComplexResult = libmp.ComplexResult
def __init__(ctx):
ctx._aliases = {}
# Call those that need preinitialization (e.g. for wrappers)
SpecialFunctions.__init__(ctx)
RSCache.__init__(ctx)
QuadratureMethods.__init__(ctx)
CalculusMethods.__init__(ctx)
MatrixMethods.__init__(ctx)
def _init_aliases(ctx):
for alias, value in ctx._aliases.items():
try:
setattr(ctx, alias, getattr(ctx, value))
except AttributeError:
pass
_fixed_precision = False
# XXX
verbose = False
def warn(ctx, msg):
print "Warning:", msg
def bad_domain(ctx, msg):
raise ValueError(msg)
def _re(ctx, x):
if hasattr(x, "real"):
return x.real
return x
def _im(ctx, x):
if hasattr(x, "imag"):
return x.imag
return ctx.zero
def chop(ctx, x, tol=None):
"""
Chops off small real or imaginary parts, or converts
numbers close to zero to exact zeros. The input can be a
single number or an iterable::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> chop(5+1e-10j, tol=1e-9)
mpf('5.0')
>>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
[1.0, 0.0, 3.0, -4.0, 2.0]
The tolerance defaults to ``100*eps``.
"""
if tol is None:
tol = 100*ctx.eps
try:
x = ctx.convert(x)
absx = abs(x)
if abs(x) < tol:
return ctx.zero
if ctx._is_complex_type(x):
if abs(x.imag) < min(tol, absx*tol):
return x.real
if abs(x.real) < min(tol, absx*tol):
return ctx.mpc(0, x.imag)
except TypeError:
if isinstance(x, ctx.matrix):
return x.apply(lambda a: ctx.chop(a, tol))
if hasattr(x, "__iter__"):
return [ctx.chop(a, tol) for a in x]
return x
def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
r"""
Determine whether the difference between `s` and `t` is smaller
than a given epsilon, either relatively or absolutely.
Both a maximum relative difference and a maximum difference
('epsilons') may be specified. The absolute difference is
defined as `|s-t|` and the relative difference is defined
as `|s-t|/\max(|s|, |t|)`.
If only one epsilon is given, both are set to the same value.
If none is given, both epsilons are set to `2^{-p+m}` where
`p` is the current working precision and `m` is a small
integer. The default setting typically allows :func:`almosteq`
to be used to check for mathematical equality
in the presence of small rounding errors.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15
>>> almosteq(3.141592653589793, 3.141592653589790)
True
>>> almosteq(3.141592653589793, 3.141592653589700)
False
>>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
True
>>> almosteq(1e-20, 2e-20)
True
>>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
False
"""
t = ctx.convert(t)
if abs_eps is None and rel_eps is None:
rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
if abs_eps is None:
abs_eps = rel_eps
elif rel_eps is None:
rel_eps = abs_eps
diff = abs(s-t)
if diff <= abs_eps:
return True
abss = abs(s)
abst = abs(t)
if abss < abst:
err = diff/abst
else:
err = diff/abss
return err <= rel_eps
def arange(ctx, *args):
r"""
This is a generalized version of Python's :func:`range` function
that accepts fractional endpoints and step sizes and
returns a list of ``mpf`` instances. Like :func:`range`,
:func:`arange` can be called with 1, 2 or 3 arguments:
``arange(b)``
`[0, 1, 2, \ldots, x]`
``arange(a, b)``
`[a, a+1, a+2, \ldots, x]`
``arange(a, b, h)``
`[a, a+h, a+h, \ldots, x]`
where `b-1 \le x < b` (in the third case, `b-h \le x < b`).
Like Python's :func:`range`, the endpoint is not included. To
produce ranges where the endpoint is included, :func:`linspace`
is more convenient.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> arange(4)
[mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
>>> arange(1, 2, 0.25)
[mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
>>> arange(1, -1, -0.75)
[mpf('1.0'), mpf('0.25'), mpf('-0.5')]
"""
if not len(args) <= 3:
raise TypeError('arange expected at most 3 arguments, got %i'
% len(args))
if not len(args) >= 1:
raise TypeError('arange expected at least 1 argument, got %i'
% len(args))
# set default
a = 0
dt = 1
# interpret arguments
if len(args) == 1:
b = args[0]
elif len(args) >= 2:
a = args[0]
b = args[1]
if len(args) == 3:
dt = args[2]
a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
assert a + dt != a, 'dt is too small and would cause an infinite loop'
# adapt code for sign of dt
if a > b:
if dt > 0:
return []
op = gt
else:
if dt < 0:
return []
op = lt
# create list
result = []
i = 0
t = a
while 1:
t = a + dt*i
i += 1
if op(t, b):
result.append(t)
else:
break
return result
def linspace(ctx, *args, **kwargs):
"""
``linspace(a, b, n)`` returns a list of `n` evenly spaced
samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
is also valid.
This function is often more convenient than :func:`arange`
for partitioning an interval into subintervals, since
the endpoint is included::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> linspace(1, 4, 4)
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
>>> linspace(mpi(1,4), 4)
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
You may also provide the keyword argument ``endpoint=False``::
>>> linspace(1, 4, 4, endpoint=False)
[mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]
"""
if len(args) == 3:
a = ctx.mpf(args[0])
b = ctx.mpf(args[1])
n = int(args[2])
elif len(args) == 2:
assert hasattr(args[0], '_mpi_')
a = args[0].a
b = args[0].b
n = int(args[1])
else:
raise TypeError('linspace expected 2 or 3 arguments, got %i' \
% len(args))
if n < 1:
raise ValueError('n must be greater than 0')
if not 'endpoint' in kwargs or kwargs['endpoint']:
if n == 1:
return [ctx.mpf(a)]
step = (b - a) / ctx.mpf(n - 1)
y = [i*step + a for i in xrange(n)]
y[-1] = b
else:
step = (b - a) / ctx.mpf(n)
y = [i*step + a for i in xrange(n)]
return y
def cos_sin(ctx, z, **kwargs):
return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)
def _default_hyper_maxprec(ctx, p):
return int(1000 * p**0.25 + 4*p)
_gcd = staticmethod(libmp.gcd)
list_primes = staticmethod(libmp.list_primes)
bernfrac = staticmethod(libmp.bernfrac)
moebius = staticmethod(libmp.moebius)
_ifac = staticmethod(libmp.ifac)
_eulernum = staticmethod(libmp.eulernum)
def sum_accurately(ctx, terms, check_step=1):
prec = ctx.prec
try:
extraprec = 10
while 1:
ctx.prec = prec + extraprec + 5
max_mag = ctx.ninf
s = ctx.zero
k = 0
for term in terms():
s += term
if (not k % check_step) and term:
term_mag = ctx.mag(term)
max_mag = max(max_mag, term_mag)
sum_mag = ctx.mag(s)
if sum_mag - term_mag > ctx.prec:
break
k += 1
cancellation = max_mag - sum_mag
if cancellation != cancellation:
break
if cancellation < extraprec or ctx._fixed_precision:
break
extraprec += min(ctx.prec, cancellation)
return s
finally:
ctx.prec = prec
def power(ctx, x, y):
return ctx.convert(x) ** ctx.convert(y)
def _zeta_int(ctx, n):
return ctx.zeta(n)

View File

@ -1,278 +0,0 @@
from ctx_base import StandardBaseContext
import math
import cmath
import math2
import function_docs
from libmp import mpf_bernoulli, to_float, int_types
import libmp
class FPContext(StandardBaseContext):
"""
Context for fast low-precision arithmetic (53-bit precision, giving at most
about 15-digit accuracy), using Python's builtin float and complex.
"""
def __init__(ctx):
StandardBaseContext.__init__(ctx)
# Override SpecialFunctions implementation
ctx.loggamma = math2.loggamma
ctx._bernoulli_cache = {}
ctx.pretty = False
ctx._init_aliases()
_mpq = lambda cls, x: float(x[0])/x[1]
NoConvergence = libmp.NoConvergence
def _get_prec(ctx): return 53
def _set_prec(ctx, p): return
def _get_dps(ctx): return 15
def _set_dps(ctx, p): return
_fixed_precision = True
prec = property(_get_prec, _set_prec)
dps = property(_get_dps, _set_dps)
zero = 0.0
one = 1.0
eps = math2.EPS
inf = math2.INF
ninf = math2.NINF
nan = math2.NAN
j = 1j
# Called by SpecialFunctions.__init__()
@classmethod
def _wrap_specfun(cls, name, f, wrap):
if wrap:
def f_wrapped(ctx, *args, **kwargs):
convert = ctx.convert
args = [convert(a) for a in args]
return f(ctx, *args, **kwargs)
else:
f_wrapped = f
f_wrapped.__doc__ = function_docs.__dict__.get(name, "<no doc>")
setattr(cls, name, f_wrapped)
def bernoulli(ctx, n):
cache = ctx._bernoulli_cache
if n in cache:
return cache[n]
cache[n] = to_float(mpf_bernoulli(n, 53, 'n'), strict=True)
return cache[n]
pi = math2.pi
e = math2.e
euler = math2.euler
sqrt2 = 1.4142135623730950488
sqrt5 = 2.2360679774997896964
phi = 1.6180339887498948482
ln2 = 0.69314718055994530942
ln10 = 2.302585092994045684
euler = 0.57721566490153286061
catalan = 0.91596559417721901505
khinchin = 2.6854520010653064453
apery = 1.2020569031595942854
absmin = absmax = abs
def _as_points(ctx, x):
return x
def fneg(ctx, x, **kwargs):
return -ctx.convert(x)
def fadd(ctx, x, y, **kwargs):
return ctx.convert(x)+ctx.convert(y)
def fsub(ctx, x, y, **kwargs):
return ctx.convert(x)-ctx.convert(y)
def fmul(ctx, x, y, **kwargs):
return ctx.convert(x)*ctx.convert(y)
def fdiv(ctx, x, y, **kwargs):
return ctx.convert(x)/ctx.convert(y)
def fsum(ctx, args, absolute=False, squared=False):
if absolute:
if squared:
return sum((abs(x)**2 for x in args), ctx.zero)
return sum((abs(x) for x in args), ctx.zero)
if squared:
return sum((x**2 for x in args), ctx.zero)
return sum(args, ctx.zero)
def fdot(ctx, xs, ys=None):
if ys is not None:
xs = zip(xs, ys)
return sum((x*y for (x,y) in xs), ctx.zero)
def is_special(ctx, x):
return x - x != 0.0
def isnan(ctx, x):
return x != x
def isinf(ctx, x):
return abs(x) == math2.INF
def isnpint(ctx, x):
if type(x) is complex:
if x.imag:
return False
x = x.real
return x <= 0.0 and round(x) == x
mpf = float
mpc = complex
def convert(ctx, x):
try:
return float(x)
except:
return complex(x)
power = staticmethod(math2.pow)
sqrt = staticmethod(math2.sqrt)
exp = staticmethod(math2.exp)
ln = log = staticmethod(math2.log)
cos = staticmethod(math2.cos)
sin = staticmethod(math2.sin)
tan = staticmethod(math2.tan)
cos_sin = staticmethod(math2.cos_sin)
acos = staticmethod(math2.acos)
asin = staticmethod(math2.asin)
atan = staticmethod(math2.atan)
cosh = staticmethod(math2.cosh)
sinh = staticmethod(math2.sinh)
tanh = staticmethod(math2.tanh)
gamma = staticmethod(math2.gamma)
fac = factorial = staticmethod(math2.factorial)
floor = staticmethod(math2.floor)
ceil = staticmethod(math2.ceil)
cospi = staticmethod(math2.cospi)
sinpi = staticmethod(math2.sinpi)
cbrt = staticmethod(math2.cbrt)
_nthroot = staticmethod(math2.nthroot)
_ei = staticmethod(math2.ei)
_e1 = staticmethod(math2.e1)
_zeta = _zeta_int = staticmethod(math2.zeta)
# XXX: math2
def arg(ctx, z):
z = complex(z)
return math.atan2(z.imag, z.real)
def expj(ctx, x):
return ctx.exp(ctx.j*x)
def expjpi(ctx, x):
return ctx.exp(ctx.j*ctx.pi*x)
ldexp = math.ldexp
frexp = math.frexp
def mag(ctx, z):
if z:
return ctx.frexp(abs(z))[1]
return ctx.ninf
def isint(ctx, z):
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
if z.imag:
return False
z = z.real
try:
return z == int(z)
except:
return False
def nint_distance(ctx, z):
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
n = round(z.real)
else:
n = round(z)
if n == z:
return n, ctx.ninf
return n, ctx.mag(abs(z-n))
def _convert_param(ctx, z):
if type(z) is tuple:
p, q = z
return ctx.mpf(p) / q, 'R'
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
intz = int(z.real)
else:
intz = int(z)
if z == intz:
return intz, 'Z'
return z, 'R'
def _is_real_type(ctx, z):
return isinstance(z, float) or isinstance(z, int_types)
def _is_complex_type(ctx, z):
return isinstance(z, complex)
def hypsum(ctx, p, q, types, coeffs, z, maxterms=6000, **kwargs):
coeffs = list(coeffs)
num = range(p)
den = range(p,p+q)
tol = ctx.eps
s = t = 1.0
k = 0
while 1:
for i in num: t *= (coeffs[i]+k)
for i in den: t /= (coeffs[i]+k)
k += 1; t /= k; t *= z; s += t
if abs(t) < tol:
return s
if k > maxterms:
raise ctx.NoConvergence
def atan2(ctx, x, y):
return math.atan2(x, y)
def psi(ctx, m, z):
m = int(m)
if m == 0:
return ctx.digamma(z)
return (-1)**(m+1) * ctx.fac(m) * ctx.zeta(m+1, z)
digamma = staticmethod(math2.digamma)
def harmonic(ctx, x):
x = ctx.convert(x)
if x == 0 or x == 1:
return x
return ctx.digamma(x+1) + ctx.euler
nstr = str
def to_fixed(ctx, x, prec):
return int(math.ldexp(x, prec))
def rand(ctx):
import random
return random.random()
_erf = staticmethod(math2.erf)
_erfc = staticmethod(math2.erfc)
def sum_accurately(ctx, terms, check_step=1):
s = ctx.zero
k = 0
for term in terms():
s += term
if (not k % check_step) and term:
if abs(term) <= 1e-18*abs(s):
break
k += 1
return s

File diff suppressed because it is too large Load Diff

View File

@ -1,986 +0,0 @@
#from ctx_base import StandardBaseContext
from libmp import (MPZ, MPZ_ZERO, MPZ_ONE, int_types, repr_dps,
round_floor, round_ceiling, dps_to_prec, round_nearest, prec_to_dps,
ComplexResult, to_pickable, from_pickable, normalize,
from_int, from_float, from_str, to_int, to_float, to_str,
from_rational, from_man_exp,
fone, fzero, finf, fninf, fnan,
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul, mpf_mul_int,
mpf_div, mpf_rdiv_int, mpf_pow_int, mpf_mod,
mpf_eq, mpf_cmp, mpf_lt, mpf_gt, mpf_le, mpf_ge,
mpf_hash, mpf_rand,
mpf_sum,
bitcount, to_fixed,
mpc_to_str,
mpc_to_complex, mpc_hash, mpc_pos, mpc_is_nonzero, mpc_neg, mpc_conjugate,
mpc_abs, mpc_add, mpc_add_mpf, mpc_sub, mpc_sub_mpf, mpc_mul, mpc_mul_mpf,
mpc_mul_int, mpc_div, mpc_div_mpf, mpc_pow, mpc_pow_mpf, mpc_pow_int,
mpc_mpf_div,
mpf_pow,
mpi_mid, mpi_delta, mpi_str,
mpi_abs, mpi_pos, mpi_neg, mpi_add, mpi_sub,
mpi_mul, mpi_div, mpi_pow_int, mpi_pow,
mpf_pi, mpf_degree, mpf_e, mpf_phi, mpf_ln2, mpf_ln10,
mpf_euler, mpf_catalan, mpf_apery, mpf_khinchin,
mpf_glaisher, mpf_twinprime, mpf_mertens,
int_types)
import rational
import function_docs
new = object.__new__
class mpnumeric(object):
"""Base class for mpf and mpc."""
__slots__ = []
def __new__(cls, val):
raise NotImplementedError
class _mpf(mpnumeric):
"""
An mpf instance holds a real-valued floating-point number. mpf:s
work analogously to Python floats, but support arbitrary-precision
arithmetic.
"""
__slots__ = ['_mpf_']
def __new__(cls, val=fzero, **kwargs):
"""A new mpf can be created from a Python float, an int, a
or a decimal string representing a number in floating-point
format."""
prec, rounding = cls.context._prec_rounding
if kwargs:
prec = kwargs.get('prec', prec)
if 'dps' in kwargs:
prec = dps_to_prec(kwargs['dps'])
rounding = kwargs.get('rounding', rounding)
if type(val) is cls:
sign, man, exp, bc = val._mpf_
if (not man) and exp:
return val
v = new(cls)
v._mpf_ = normalize(sign, man, exp, bc, prec, rounding)
return v
elif type(val) is tuple:
if len(val) == 2:
v = new(cls)
v._mpf_ = from_man_exp(val[0], val[1], prec, rounding)
return v
if len(val) == 4:
sign, man, exp, bc = val
v = new(cls)
v._mpf_ = normalize(sign, MPZ(man), exp, bc, prec, rounding)
return v
raise ValueError
else:
v = new(cls)
v._mpf_ = mpf_pos(cls.mpf_convert_arg(val, prec, rounding), prec, rounding)
return v
@classmethod
def mpf_convert_arg(cls, x, prec, rounding):
if isinstance(x, int_types): return from_int(x)
if isinstance(x, float): return from_float(x)
if isinstance(x, basestring): return from_str(x, prec, rounding)
if isinstance(x, cls.context.constant): return x.func(prec, rounding)
if hasattr(x, '_mpf_'): return x._mpf_
if hasattr(x, '_mpmath_'):
t = cls.context.convert(x._mpmath_(prec, rounding))
if hasattr(t, '_mpf_'):
return t._mpf_
raise TypeError("cannot create mpf from " + repr(x))
@classmethod
def mpf_convert_rhs(cls, x):
if isinstance(x, int_types): return from_int(x)
if isinstance(x, float): return from_float(x)
if isinstance(x, complex_types): return cls.context.mpc(x)
if isinstance(x, rational.mpq):
p, q = x
return from_rational(p, q, cls.context.prec)
if hasattr(x, '_mpf_'): return x._mpf_
if hasattr(x, '_mpmath_'):
t = cls.context.convert(x._mpmath_(*cls.context._prec_rounding))
if hasattr(t, '_mpf_'):
return t._mpf_
return t
return NotImplemented
@classmethod
def mpf_convert_lhs(cls, x):
x = cls.mpf_convert_rhs(x)
if type(x) is tuple:
return cls.context.make_mpf(x)
return x
man_exp = property(lambda self: self._mpf_[1:3])
man = property(lambda self: self._mpf_[1])
exp = property(lambda self: self._mpf_[2])
bc = property(lambda self: self._mpf_[3])
real = property(lambda self: self)
imag = property(lambda self: self.context.zero)
conjugate = lambda self: self
def __getstate__(self): return to_pickable(self._mpf_)
def __setstate__(self, val): self._mpf_ = from_pickable(val)
def __repr__(s):
if s.context.pretty:
return str(s)
return "mpf('%s')" % to_str(s._mpf_, s.context._repr_digits)
def __str__(s): return to_str(s._mpf_, s.context._str_digits)
def __hash__(s): return mpf_hash(s._mpf_)
def __int__(s): return int(to_int(s._mpf_))
def __long__(s): return long(to_int(s._mpf_))
def __float__(s): return to_float(s._mpf_)
def __complex__(s): return complex(float(s))
def __nonzero__(s): return s._mpf_ != fzero
def __abs__(s):
cls, new, (prec, rounding) = s._ctxdata
v = new(cls)
v._mpf_ = mpf_abs(s._mpf_, prec, rounding)
return v
def __pos__(s):
cls, new, (prec, rounding) = s._ctxdata
v = new(cls)
v._mpf_ = mpf_pos(s._mpf_, prec, rounding)
return v
def __neg__(s):
cls, new, (prec, rounding) = s._ctxdata
v = new(cls)
v._mpf_ = mpf_neg(s._mpf_, prec, rounding)
return v
def _cmp(s, t, func):
if hasattr(t, '_mpf_'):
t = t._mpf_
else:
t = s.mpf_convert_rhs(t)
if t is NotImplemented:
return t
return func(s._mpf_, t)
def __cmp__(s, t): return s._cmp(t, mpf_cmp)
def __lt__(s, t): return s._cmp(t, mpf_lt)
def __gt__(s, t): return s._cmp(t, mpf_gt)
def __le__(s, t): return s._cmp(t, mpf_le)
def __ge__(s, t): return s._cmp(t, mpf_ge)
def __ne__(s, t):
v = s.__eq__(t)
if v is NotImplemented:
return v
return not v
def __rsub__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if type(t) in int_types:
v = new(cls)
v._mpf_ = mpf_sub(from_int(t), s._mpf_, prec, rounding)
return v
t = s.mpf_convert_lhs(t)
if t is NotImplemented:
return t
return t - s
def __rdiv__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if isinstance(t, int_types):
v = new(cls)
v._mpf_ = mpf_rdiv_int(t, s._mpf_, prec, rounding)
return v
t = s.mpf_convert_lhs(t)
if t is NotImplemented:
return t
return t / s
def __rpow__(s, t):
t = s.mpf_convert_lhs(t)
if t is NotImplemented:
return t
return t ** s
def __rmod__(s, t):
t = s.mpf_convert_lhs(t)
if t is NotImplemented:
return t
return t % s
def sqrt(s):
return s.context.sqrt(s)
def ae(s, t, rel_eps=None, abs_eps=None):
return s.context.almosteq(s, t, rel_eps, abs_eps)
def to_fixed(self, prec):
return to_fixed(self._mpf_, prec)
mpf_binary_op = """
def %NAME%(self, other):
mpf, new, (prec, rounding) = self._ctxdata
sval = self._mpf_
if hasattr(other, '_mpf_'):
tval = other._mpf_
%WITH_MPF%
ttype = type(other)
if ttype in int_types:
%WITH_INT%
elif ttype is float:
tval = from_float(other)
%WITH_MPF%
elif hasattr(other, '_mpc_'):
tval = other._mpc_
mpc = type(other)
%WITH_MPC%
elif ttype is complex:
tval = from_float(other.real), from_float(other.imag)
mpc = self.context.mpc
%WITH_MPC%
if isinstance(other, mpnumeric):
return NotImplemented
try:
other = mpf.context.convert(other, strings=False)
except TypeError:
return NotImplemented
return self.%NAME%(other)
"""
return_mpf = "; obj = new(mpf); obj._mpf_ = val; return obj"
return_mpc = "; obj = new(mpc); obj._mpc_ = val; return obj"
mpf_pow_same = """
try:
val = mpf_pow(sval, tval, prec, rounding) %s
except ComplexResult:
if mpf.context.trap_complex:
raise
mpc = mpf.context.mpc
val = mpc_pow((sval, fzero), (tval, fzero), prec, rounding) %s
""" % (return_mpf, return_mpc)
def binary_op(name, with_mpf='', with_int='', with_mpc=''):
code = mpf_binary_op
code = code.replace("%WITH_INT%", with_int)
code = code.replace("%WITH_MPC%", with_mpc)
code = code.replace("%WITH_MPF%", with_mpf)
code = code.replace("%NAME%", name)
np = {}
exec code in globals(), np
return np[name]
_mpf.__eq__ = binary_op('__eq__',
'return mpf_eq(sval, tval)',
'return mpf_eq(sval, from_int(other))',
'return (tval[1] == fzero) and mpf_eq(tval[0], sval)')
_mpf.__add__ = binary_op('__add__',
'val = mpf_add(sval, tval, prec, rounding)' + return_mpf,
'val = mpf_add(sval, from_int(other), prec, rounding)' + return_mpf,
'val = mpc_add_mpf(tval, sval, prec, rounding)' + return_mpc)
_mpf.__sub__ = binary_op('__sub__',
'val = mpf_sub(sval, tval, prec, rounding)' + return_mpf,
'val = mpf_sub(sval, from_int(other), prec, rounding)' + return_mpf,
'val = mpc_sub((sval, fzero), tval, prec, rounding)' + return_mpc)
_mpf.__mul__ = binary_op('__mul__',
'val = mpf_mul(sval, tval, prec, rounding)' + return_mpf,
'val = mpf_mul_int(sval, other, prec, rounding)' + return_mpf,
'val = mpc_mul_mpf(tval, sval, prec, rounding)' + return_mpc)
_mpf.__div__ = binary_op('__div__',
'val = mpf_div(sval, tval, prec, rounding)' + return_mpf,
'val = mpf_div(sval, from_int(other), prec, rounding)' + return_mpf,
'val = mpc_mpf_div(sval, tval, prec, rounding)' + return_mpc)
_mpf.__mod__ = binary_op('__mod__',
'val = mpf_mod(sval, tval, prec, rounding)' + return_mpf,
'val = mpf_mod(sval, from_int(other), prec, rounding)' + return_mpf,
'raise NotImplementedError("complex modulo")')
_mpf.__pow__ = binary_op('__pow__',
mpf_pow_same,
'val = mpf_pow_int(sval, other, prec, rounding)' + return_mpf,
'val = mpc_pow((sval, fzero), tval, prec, rounding)' + return_mpc)
_mpf.__radd__ = _mpf.__add__
_mpf.__rmul__ = _mpf.__mul__
_mpf.__truediv__ = _mpf.__div__
_mpf.__rtruediv__ = _mpf.__rdiv__
class _constant(_mpf):
"""Represents a mathematical constant with dynamic precision.
When printed or used in an arithmetic operation, a constant
is converted to a regular mpf at the working precision. A
regular mpf can also be obtained using the operation +x."""
def __new__(cls, func, name, docname=''):
a = object.__new__(cls)
a.name = name
a.func = func
a.__doc__ = getattr(function_docs, docname, '')
return a
def __call__(self, prec=None, dps=None, rounding=None):
prec2, rounding2 = self.context._prec_rounding
if not prec: prec = prec2
if not rounding: rounding = rounding2
if dps: prec = dps_to_prec(dps)
return self.context.make_mpf(self.func(prec, rounding))
@property
def _mpf_(self):
prec, rounding = self.context._prec_rounding
return self.func(prec, rounding)
def __repr__(self):
return "<%s: %s~>" % (self.name, self.context.nstr(self))
class _mpc(mpnumeric):
"""
An mpc represents a complex number using a pair of mpf:s (one
for the real part and another for the imaginary part.) The mpc
class behaves fairly similarly to Python's complex type.
"""
__slots__ = ['_mpc_']
def __new__(cls, real=0, imag=0):
s = object.__new__(cls)
if isinstance(real, complex_types):
real, imag = real.real, real.imag
elif hasattr(real, '_mpc_'):
s._mpc_ = real._mpc_
return s
real = cls.context.mpf(real)
imag = cls.context.mpf(imag)
s._mpc_ = (real._mpf_, imag._mpf_)
return s
real = property(lambda self: self.context.make_mpf(self._mpc_[0]))
imag = property(lambda self: self.context.make_mpf(self._mpc_[1]))
def __getstate__(self):
return to_pickable(self._mpc_[0]), to_pickable(self._mpc_[1])
def __setstate__(self, val):
self._mpc_ = from_pickable(val[0]), from_pickable(val[1])
def __repr__(s):
if s.context.pretty:
return str(s)
r = repr(s.real)[4:-1]
i = repr(s.imag)[4:-1]
return "%s(real=%s, imag=%s)" % (type(s).__name__, r, i)
def __str__(s):
return "(%s)" % mpc_to_str(s._mpc_, s.context._str_digits)
def __complex__(s):
return mpc_to_complex(s._mpc_)
def __pos__(s):
cls, new, (prec, rounding) = s._ctxdata
v = new(cls)
v._mpc_ = mpc_pos(s._mpc_, prec, rounding)
return v
def __abs__(s):
prec, rounding = s.context._prec_rounding
v = new(s.context.mpf)
v._mpf_ = mpc_abs(s._mpc_, prec, rounding)
return v
def __neg__(s):
cls, new, (prec, rounding) = s._ctxdata
v = new(cls)
v._mpc_ = mpc_neg(s._mpc_, prec, rounding)
return v
def conjugate(s):
cls, new, (prec, rounding) = s._ctxdata
v = new(cls)
v._mpc_ = mpc_conjugate(s._mpc_, prec, rounding)
return v
def __nonzero__(s):
return mpc_is_nonzero(s._mpc_)
def __hash__(s):
return mpc_hash(s._mpc_)
@classmethod
def mpc_convert_lhs(cls, x):
try:
y = cls.context.convert(x)
return y
except TypeError:
return NotImplemented
def __eq__(s, t):
if not hasattr(t, '_mpc_'):
if isinstance(t, str):
return False
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
return s.real == t.real and s.imag == t.imag
def __ne__(s, t):
b = s.__eq__(t)
if b is NotImplemented:
return b
return not b
def _compare(*args):
raise TypeError("no ordering relation is defined for complex numbers")
__gt__ = _compare
__le__ = _compare
__gt__ = _compare
__ge__ = _compare
def __add__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if not hasattr(t, '_mpc_'):
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
if hasattr(t, '_mpf_'):
v = new(cls)
v._mpc_ = mpc_add_mpf(s._mpc_, t._mpf_, prec, rounding)
return v
v = new(cls)
v._mpc_ = mpc_add(s._mpc_, t._mpc_, prec, rounding)
return v
def __sub__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if not hasattr(t, '_mpc_'):
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
if hasattr(t, '_mpf_'):
v = new(cls)
v._mpc_ = mpc_sub_mpf(s._mpc_, t._mpf_, prec, rounding)
return v
v = new(cls)
v._mpc_ = mpc_sub(s._mpc_, t._mpc_, prec, rounding)
return v
def __mul__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if not hasattr(t, '_mpc_'):
if isinstance(t, int_types):
v = new(cls)
v._mpc_ = mpc_mul_int(s._mpc_, t, prec, rounding)
return v
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
if hasattr(t, '_mpf_'):
v = new(cls)
v._mpc_ = mpc_mul_mpf(s._mpc_, t._mpf_, prec, rounding)
return v
t = s.mpc_convert_lhs(t)
v = new(cls)
v._mpc_ = mpc_mul(s._mpc_, t._mpc_, prec, rounding)
return v
def __div__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if not hasattr(t, '_mpc_'):
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
if hasattr(t, '_mpf_'):
v = new(cls)
v._mpc_ = mpc_div_mpf(s._mpc_, t._mpf_, prec, rounding)
return v
v = new(cls)
v._mpc_ = mpc_div(s._mpc_, t._mpc_, prec, rounding)
return v
def __pow__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if isinstance(t, int_types):
v = new(cls)
v._mpc_ = mpc_pow_int(s._mpc_, t, prec, rounding)
return v
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
v = new(cls)
if hasattr(t, '_mpf_'):
v._mpc_ = mpc_pow_mpf(s._mpc_, t._mpf_, prec, rounding)
else:
v._mpc_ = mpc_pow(s._mpc_, t._mpc_, prec, rounding)
return v
__radd__ = __add__
def __rsub__(s, t):
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
return t - s
def __rmul__(s, t):
cls, new, (prec, rounding) = s._ctxdata
if isinstance(t, int_types):
v = new(cls)
v._mpc_ = mpc_mul_int(s._mpc_, t, prec, rounding)
return v
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
return t * s
def __rdiv__(s, t):
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
return t / s
def __rpow__(s, t):
t = s.mpc_convert_lhs(t)
if t is NotImplemented:
return t
return t ** s
__truediv__ = __div__
__rtruediv__ = __rdiv__
def ae(s, t, rel_eps=None, abs_eps=None):
return s.context.almosteq(s, t, rel_eps, abs_eps)
complex_types = (complex, _mpc)
class PythonMPContext:
def __init__(ctx):
ctx._prec_rounding = [53, round_nearest]
ctx.mpf = type('mpf', (_mpf,), {})
ctx.mpc = type('mpc', (_mpc,), {})
ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding]
ctx.mpf.context = ctx
ctx.mpc.context = ctx
ctx.constant = type('constant', (_constant,), {})
ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
ctx.constant.context = ctx
def make_mpf(ctx, v):
a = new(ctx.mpf)
a._mpf_ = v
return a
def make_mpc(ctx, v):
a = new(ctx.mpc)
a._mpc_ = v
return a
def default(ctx):
ctx._prec = ctx._prec_rounding[0] = 53
ctx._dps = 15
ctx.trap_complex = False
def _set_prec(ctx, n):
ctx._prec = ctx._prec_rounding[0] = max(1, int(n))
ctx._dps = prec_to_dps(n)
def _set_dps(ctx, n):
ctx._prec = ctx._prec_rounding[0] = dps_to_prec(n)
ctx._dps = max(1, int(n))
prec = property(lambda ctx: ctx._prec, _set_prec)
dps = property(lambda ctx: ctx._dps, _set_dps)
def convert(ctx, x, strings=True):
"""
Converts *x* to an ``mpf``, ``mpc`` or ``mpi``. If *x* is of type ``mpf``,
``mpc``, ``int``, ``float``, ``complex``, the conversion
will be performed losslessly.
If *x* is a string, the result will be rounded to the present
working precision. Strings representing fractions or complex
numbers are permitted.
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> mpmathify(3.5)
mpf('3.5')
>>> mpmathify('2.1')
mpf('2.1000000000000001')
>>> mpmathify('3/4')
mpf('0.75')
>>> mpmathify('2+3j')
mpc(real='2.0', imag='3.0')
"""
if type(x) in ctx.types: return x
if isinstance(x, int_types): return ctx.make_mpf(from_int(x))
if isinstance(x, float): return ctx.make_mpf(from_float(x))
if isinstance(x, complex):
return ctx.make_mpc((from_float(x.real), from_float(x.imag)))
prec, rounding = ctx._prec_rounding
if isinstance(x, rational.mpq):
p, q = x
return ctx.make_mpf(from_rational(p, q, prec))
if strings and isinstance(x, basestring):
try:
_mpf_ = from_str(x, prec, rounding)
return ctx.make_mpf(_mpf_)
except ValueError:
pass
if hasattr(x, '_mpf_'): return ctx.make_mpf(x._mpf_)
if hasattr(x, '_mpc_'): return ctx.make_mpc(x._mpc_)
if hasattr(x, '_mpmath_'):
return ctx.convert(x._mpmath_(prec, rounding))
return ctx._convert_fallback(x, strings)
def isnan(ctx, x):
"""
For an ``mpf`` *x*, determines whether *x* is not-a-number (nan)::
>>> from mpmath import *
>>> isnan(nan), isnan(3)
(True, False)
"""
if not hasattr(x, '_mpf_'):
return False
return x._mpf_ == fnan
def isinf(ctx, x):
"""
For an ``mpf`` *x*, determines whether *x* is infinite::
>>> from mpmath import *
>>> isinf(inf), isinf(-inf), isinf(3)
(True, True, False)
"""
if not hasattr(x, '_mpf_'):
return False
return x._mpf_ in (finf, fninf)
def isint(ctx, x):
"""
For an ``mpf`` *x*, or any type that can be converted
to ``mpf``, determines whether *x* is exactly
integer-valued::
>>> from mpmath import *
>>> isint(3), isint(mpf(3)), isint(3.2)
(True, True, False)
"""
if isinstance(x, int_types):
return True
try:
x = ctx.convert(x)
except:
return False
if hasattr(x, '_mpf_'):
if ctx.isnan(x) or ctx.isinf(x):
return False
return x == int(x)
if isinstance(x, ctx.mpq):
p, q = x
return not (p % q)
return False
def fsum(ctx, terms, absolute=False, squared=False):
"""
Calculates a sum containing a finite number of terms (for infinite
series, see :func:`nsum`). The terms will be converted to
mpmath numbers. For len(terms) > 2, this function is generally
faster and produces more accurate results than the builtin
Python function :func:`sum`.
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> fsum([1, 2, 0.5, 7])
mpf('10.5')
With squared=True each term is squared, and with absolute=True
the absolute value of each term is used.
"""
prec, rnd = ctx._prec_rounding
real = []
imag = []
other = 0
for term in terms:
reval = imval = 0
if hasattr(term, "_mpf_"):
reval = term._mpf_
elif hasattr(term, "_mpc_"):
reval, imval = term._mpc_
else:
term = ctx.convert(term)
if hasattr(term, "_mpf_"):
reval = term._mpf_
elif hasattr(term, "_mpc_"):
reval, imval = term._mpc_
else:
if absolute: term = ctx.absmax(term)
if squared: term = term**2
other += term
continue
if imval:
if squared:
if absolute:
real.append(mpf_mul(reval,reval))
real.append(mpf_mul(imval,imval))
else:
reval, imval = mpc_pow_int((reval,imval),2,prec+10)
real.append(reval)
imag.append(imval)
elif absolute:
real.append(mpc_abs((reval,imval), prec))
else:
real.append(reval)
imag.append(imval)
else:
if squared:
reval = mpf_mul(reval, reval)
elif absolute:
reval = mpf_abs(reval)
real.append(reval)
s = mpf_sum(real, prec, rnd, absolute)
if imag:
s = ctx.make_mpc((s, mpf_sum(imag, prec, rnd)))
else:
s = ctx.make_mpf(s)
if other is 0:
return s
else:
return s + other
def fdot(ctx, A, B=None):
r"""
Computes the dot product of the iterables `A` and `B`,
.. math ::
\sum_{k=0} A_k B_k.
Alternatively, :func:`fdot` accepts a single iterable of pairs.
In other words, ``fdot(A,B)`` and ``fdot(zip(A,B))`` are equivalent.
The elements are automatically converted to mpmath numbers.
Examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> A = [2, 1.5, 3]
>>> B = [1, -1, 2]
>>> fdot(A, B)
mpf('6.5')
>>> zip(A, B)
[(2, 1), (1.5, -1), (3, 2)]
>>> fdot(_)
mpf('6.5')
"""
if B:
A = zip(A, B)
prec, rnd = ctx._prec_rounding
real = []
imag = []
other = 0
hasattr_ = hasattr
types = (ctx.mpf, ctx.mpc)
for a, b in A:
if type(a) not in types: a = ctx.convert(a)
if type(b) not in types: b = ctx.convert(b)
a_real = hasattr_(a, "_mpf_")
b_real = hasattr_(b, "_mpf_")
if a_real and b_real:
real.append(mpf_mul(a._mpf_, b._mpf_))
continue
a_complex = hasattr_(a, "_mpc_")
b_complex = hasattr_(b, "_mpc_")
if a_real and b_complex:
aval = a._mpf_
bre, bim = b._mpc_
real.append(mpf_mul(aval, bre))
imag.append(mpf_mul(aval, bim))
elif b_real and a_complex:
are, aim = a._mpc_
bval = b._mpf_
real.append(mpf_mul(are, bval))
imag.append(mpf_mul(aim, bval))
elif a_complex and b_complex:
re, im = mpc_mul(a._mpc_, b._mpc_, prec+20)
real.append(re)
imag.append(im)
else:
other += a*b
s = mpf_sum(real, prec, rnd)
if imag:
s = ctx.make_mpc((s, mpf_sum(imag, prec, rnd)))
else:
s = ctx.make_mpf(s)
if other is 0:
return s
else:
return s + other
def _wrap_libmp_function(ctx, mpf_f, mpc_f=None, mpi_f=None, doc="<no doc>"):
"""
Given a low-level mpf_ function, and optionally similar functions
for mpc_ and mpi_, defines the function as a context method.
It is assumed that the return type is the same as that of
the input; the exception is that propagation from mpf to mpc is possible
by raising ComplexResult.
"""
def f(x, **kwargs):
if type(x) not in ctx.types:
x = ctx.convert(x)
prec, rounding = ctx._prec_rounding
if kwargs:
prec = kwargs.get('prec', prec)
if 'dps' in kwargs:
prec = dps_to_prec(kwargs['dps'])
rounding = kwargs.get('rounding', rounding)
if hasattr(x, '_mpf_'):
try:
return ctx.make_mpf(mpf_f(x._mpf_, prec, rounding))
except ComplexResult:
# Handle propagation to complex
if ctx.trap_complex:
raise
return ctx.make_mpc(mpc_f((x._mpf_, fzero), prec, rounding))
elif hasattr(x, '_mpc_'):
return ctx.make_mpc(mpc_f(x._mpc_, prec, rounding))
elif hasattr(x, '_mpi_'):
if mpi_f:
return ctx.make_mpi(mpi_f(x._mpi_, prec))
raise NotImplementedError("%s of a %s" % (name, type(x)))
name = mpf_f.__name__[4:]
f.__doc__ = function_docs.__dict__.get(name, "Computes the %s of x" % doc)
return f
# Called by SpecialFunctions.__init__()
@classmethod
def _wrap_specfun(cls, name, f, wrap):
if wrap:
def f_wrapped(ctx, *args, **kwargs):
convert = ctx.convert
args = [convert(a) for a in args]
prec = ctx.prec
try:
ctx.prec += 10
retval = f(ctx, *args, **kwargs)
finally:
ctx.prec = prec
return +retval
else:
f_wrapped = f
f_wrapped.__doc__ = function_docs.__dict__.get(name, "<no doc>")
setattr(cls, name, f_wrapped)
def _convert_param(ctx, x):
if hasattr(x, "_mpc_"):
v, im = x._mpc_
if im != fzero:
return x, 'C'
elif hasattr(x, "_mpf_"):
v = x._mpf_
else:
if type(x) in int_types:
return int(x), 'Z'
p = None
if isinstance(x, tuple):
p, q = x
elif isinstance(x, basestring) and '/' in x:
p, q = x.split('/')
p = int(p)
q = int(q)
if p is not None:
if not p % q:
return p // q, 'Z'
return ctx.mpq((p,q)), 'Q'
x = ctx.convert(x)
if hasattr(x, "_mpc_"):
v, im = x._mpc_
if im != fzero:
return x, 'C'
elif hasattr(x, "_mpf_"):
v = x._mpf_
else:
return x, 'U'
sign, man, exp, bc = v
if man:
if exp >= -4:
if sign:
man = -man
if exp >= 0:
return int(man) << exp, 'Z'
if exp >= -4:
p, q = int(man), (1<<(-exp))
return ctx.mpq((p,q)), 'Q'
x = ctx.make_mpf(v)
return x, 'R'
elif not exp:
return 0, 'Z'
else:
return x, 'U'
def _mpf_mag(ctx, x):
sign, man, exp, bc = x
if man:
return exp+bc
if x == fzero:
return ctx.ninf
if x == finf or x == fninf:
return ctx.inf
return ctx.nan
def mag(ctx, x):
"""
Quick logarithmic magnitude estimate of a number.
Returns an integer or infinity `m` such that `|x| <= 2^m`.
It is not guaranteed that `m` is an optimal bound,
but it will never be off by more than 2 (and probably not
more than 1).
"""
if hasattr(x, "_mpf_"):
return ctx._mpf_mag(x._mpf_)
elif hasattr(x, "_mpc_"):
r, i = x._mpc_
if r == fzero:
return ctx._mpf_mag(i)
if i == fzero:
return ctx._mpf_mag(r)
return 1+max(ctx._mpf_mag(r), ctx._mpf_mag(i))
elif isinstance(x, int_types):
if x:
return bitcount(abs(x))
return ctx.ninf
elif isinstance(x, rational.mpq):
p, q = x
if p:
return 1 + bitcount(abs(p)) - bitcount(abs(q))
return ctx.ninf
else:
x = ctx.convert(x)
if hasattr(x, "_mpf_") or hasattr(x, "_mpc_"):
return ctx.mag(x)
else:
raise TypeError("requires an mpf/mpc")

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +0,0 @@
import functions
# Hack to update methods
import factorials
import hypergeometric
import elliptic
import zeta
import rszeta

File diff suppressed because it is too large Load Diff

View File

@ -1,196 +0,0 @@
from functions import defun, defun_wrapped
@defun
def gammaprod(ctx, a, b, _infsign=False):
a = [ctx.convert(x) for x in a]
b = [ctx.convert(x) for x in b]
poles_num = []
poles_den = []
regular_num = []
regular_den = []
for x in a: [regular_num, poles_num][ctx.isnpint(x)].append(x)
for x in b: [regular_den, poles_den][ctx.isnpint(x)].append(x)
# One more pole in numerator or denominator gives 0 or inf
if len(poles_num) < len(poles_den): return ctx.zero
if len(poles_num) > len(poles_den):
# Get correct sign of infinity for x+h, h -> 0 from above
# XXX: hack, this should be done properly
if _infsign:
a = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_num]
b = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_den]
return ctx.sign(ctx.gammaprod(a+regular_num,b+regular_den)) * ctx.inf
else:
return ctx.inf
# All poles cancel
# lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
p = ctx.one
orig = ctx.prec
try:
ctx.prec = orig + 15
while poles_num:
i = poles_num.pop()
j = poles_den.pop()
p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
for x in regular_num: p *= ctx.gamma(x)
for x in regular_den: p /= ctx.gamma(x)
finally:
ctx.prec = orig
return +p
@defun
def beta(ctx, x, y):
x = ctx.convert(x)
y = ctx.convert(y)
if ctx.isinf(y):
x, y = y, x
if ctx.isinf(x):
if x == ctx.inf and not ctx._im(y):
if y == ctx.ninf:
return ctx.nan
if y > 0:
return ctx.zero
if ctx.isint(y):
return ctx.nan
if y < 0:
return ctx.sign(ctx.gamma(y)) * ctx.inf
return ctx.nan
return ctx.gammaprod([x, y], [x+y])
@defun
def binomial(ctx, n, k):
return ctx.gammaprod([n+1], [k+1, n-k+1])
@defun
def rf(ctx, x, n):
return ctx.gammaprod([x+n], [x])
@defun
def ff(ctx, x, n):
return ctx.gammaprod([x+1], [x-n+1])
@defun_wrapped
def fac2(ctx, x):
if ctx.isinf(x):
if x == ctx.inf:
return x
return ctx.nan
return 2**(x/2)*(ctx.pi/2)**((ctx.cospi(x)-1)/4)*ctx.gamma(x/2+1)
@defun_wrapped
def barnesg(ctx, z):
if ctx.isinf(z):
if z == ctx.inf:
return z
return ctx.nan
if ctx.isnan(z):
return z
if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
return z*0
# Account for size (would not be needed if computing log(G))
if abs(z) > 5:
ctx.dps += 2*ctx.log(abs(z),2)
# Reflection formula
if ctx.re(z) < -ctx.dps:
w = 1-z
pi2 = 2*ctx.pi
u = ctx.expjpi(2*w)
v = ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) - \
ctx.j*ctx.polylog(2, u)/pi2
v = ctx.barnesg(2-z)*ctx.exp(v)/pi2**w
if ctx._is_real_type(z):
v = ctx._re(v)
return v
# Estimate terms for asymptotic expansion
# TODO: fixme, obviously
N = ctx.dps // 2 + 5
G = 1
while abs(z) < N or ctx.re(z) < 1:
G /= ctx.gamma(z)
z += 1
z -= 1
s = ctx.mpf(1)/12
s -= ctx.log(ctx.glaisher)
s += z*ctx.log(2*ctx.pi)/2
s += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
s -= 3*z**2/4
z2k = z2 = z**2
for k in xrange(1, N+1):
t = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z2k)
if abs(t) < ctx.eps:
#print k, N # check how many terms were needed
break
z2k *= z2
s += t
#if k == N:
# print "warning: series for barnesg failed to converge", ctx.dps
return G*ctx.exp(s)
@defun
def superfac(ctx, z):
return ctx.barnesg(z+2)
@defun_wrapped
def hyperfac(ctx, z):
# XXX: estimate needed extra bits accurately
if z == ctx.inf:
return z
if abs(z) > 5:
extra = 4*int(ctx.log(abs(z),2))
else:
extra = 0
ctx.prec += extra
if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
n = int(ctx.re(z))
h = ctx.hyperfac(-n-1)
if ((n+1)//2) & 1:
h = -h
if ctx._is_complex_type(z):
return h + 0j
return h
zp1 = z+1
# Wrong branch cut
#v = ctx.gamma(zp1)**z
#ctx.prec -= extra
#return v / ctx.barnesg(zp1)
v = ctx.exp(z*ctx.loggamma(zp1))
ctx.prec -= extra
return v / ctx.barnesg(zp1)
@defun_wrapped
def loggamma(ctx, z):
a = ctx._re(z)
b = ctx._im(z)
if not b and a > 0:
return ctx.ln(ctx.gamma(z))
u = ctx.arg(z)
w = ctx.ln(ctx.gamma(z))
if b:
gi = -b - u/2 + a*u + b*ctx.ln(abs(z))
n = ctx.floor((gi-ctx._im(w))/(2*ctx.pi)+0.5) * (2*ctx.pi)
return w + n*ctx.j
elif a < 0:
n = int(ctx.floor(a))
w += (n-(n%2))*ctx.pi*ctx.j
return w
'''
@defun
def psi0(ctx, z):
"""Shortcut for psi(0,z) (the digamma function)"""
return ctx.psi(0, z)
@defun
def psi1(ctx, z):
"""Shortcut for psi(1,z) (the trigamma function)"""
return ctx.psi(1, z)
@defun
def psi2(ctx, z):
"""Shortcut for psi(2,z) (the tetragamma function)"""
return ctx.psi(2, z)
@defun
def psi3(ctx, z):
"""Shortcut for psi(3,z) (the pentagamma function)"""
return ctx.psi(3, z)
'''

View File

@ -1,435 +0,0 @@
class SpecialFunctions(object):
"""
This class implements special functions using high-level code.
Elementary and some other functions (e.g. gamma function, basecase
hypergeometric series) are assumed to be predefined by the context as
"builtins" or "low-level" functions.
"""
defined_functions = {}
# The series for the Jacobi theta functions converge for |q| < 1;
# in the current implementation they throw a ValueError for
# abs(q) > THETA_Q_LIM
THETA_Q_LIM = 1 - 10**-7
def __init__(self):
cls = self.__class__
for name in cls.defined_functions:
f, wrap = cls.defined_functions[name]
cls._wrap_specfun(name, f, wrap)
self.mpq_1 = self._mpq((1,1))
self.mpq_0 = self._mpq((0,1))
self.mpq_1_2 = self._mpq((1,2))
self.mpq_3_2 = self._mpq((3,2))
self.mpq_1_4 = self._mpq((1,4))
self.mpq_1_16 = self._mpq((1,16))
self.mpq_3_16 = self._mpq((3,16))
self.mpq_5_2 = self._mpq((5,2))
self.mpq_3_4 = self._mpq((3,4))
self.mpq_7_4 = self._mpq((7,4))
self.mpq_5_4 = self._mpq((5,4))
self._aliases.update({
'phase' : 'arg',
'conjugate' : 'conj',
'nthroot' : 'root',
'polygamma' : 'psi',
'hurwitz' : 'zeta',
#'digamma' : 'psi0',
#'trigamma' : 'psi1',
#'tetragamma' : 'psi2',
#'pentagamma' : 'psi3',
'fibonacci' : 'fib',
'factorial' : 'fac',
})
# Default -- do nothing
@classmethod
def _wrap_specfun(cls, name, f, wrap):
setattr(cls, name, f)
# Optional fast versions of common functions in common cases.
# If not overridden, default (generic hypergeometric series)
# implementations will be used
def _besselj(ctx, n, z): raise NotImplementedError
def _erf(ctx, z): raise NotImplementedError
def _erfc(ctx, z): raise NotImplementedError
def _gamma_upper_int(ctx, z, a): raise NotImplementedError
def _expint_int(ctx, n, z): raise NotImplementedError
def _zeta(ctx, s): raise NotImplementedError
def _zetasum_fast(ctx, s, a, n, derivatives, reflect): raise NotImplementedError
def _ei(ctx, z): raise NotImplementedError
def _e1(ctx, z): raise NotImplementedError
def _ci(ctx, z): raise NotImplementedError
def _si(ctx, z): raise NotImplementedError
def _altzeta(ctx, s): raise NotImplementedError
def defun_wrapped(f):
SpecialFunctions.defined_functions[f.__name__] = f, True
def defun(f):
SpecialFunctions.defined_functions[f.__name__] = f, False
def defun_static(f):
setattr(SpecialFunctions, f.__name__, f)
@defun_wrapped
def cot(ctx, z): return ctx.one / ctx.tan(z)
@defun_wrapped
def sec(ctx, z): return ctx.one / ctx.cos(z)
@defun_wrapped
def csc(ctx, z): return ctx.one / ctx.sin(z)
@defun_wrapped
def coth(ctx, z): return ctx.one / ctx.tanh(z)
@defun_wrapped
def sech(ctx, z): return ctx.one / ctx.cosh(z)
@defun_wrapped
def csch(ctx, z): return ctx.one / ctx.sinh(z)
@defun_wrapped
def acot(ctx, z): return ctx.atan(ctx.one / z)
@defun_wrapped
def asec(ctx, z): return ctx.acos(ctx.one / z)
@defun_wrapped
def acsc(ctx, z): return ctx.asin(ctx.one / z)
@defun_wrapped
def acoth(ctx, z): return ctx.atanh(ctx.one / z)
@defun_wrapped
def asech(ctx, z): return ctx.acosh(ctx.one / z)
@defun_wrapped
def acsch(ctx, z): return ctx.asinh(ctx.one / z)
@defun
def sign(ctx, x):
x = ctx.convert(x)
if not x or ctx.isnan(x):
return x
if ctx._is_real_type(x):
return ctx.mpf(cmp(x, 0))
return x / abs(x)
@defun
def agm(ctx, a, b=1):
if b == 1:
return ctx.agm1(a)
a = ctx.convert(a)
b = ctx.convert(b)
return ctx._agm(a, b)
@defun_wrapped
def sinc(ctx, x):
if ctx.isinf(x):
return 1/x
if not x:
return x+1
return ctx.sin(x)/x
@defun_wrapped
def sincpi(ctx, x):
if ctx.isinf(x):
return 1/x
if not x:
return x+1
return ctx.sinpi(x)/(ctx.pi*x)
# TODO: tests; improve implementation
@defun_wrapped
def expm1(ctx, x):
if not x:
return ctx.zero
# exp(x) - 1 ~ x
if ctx.mag(x) < -ctx.prec:
return x + 0.5*x**2
# TODO: accurately eval the smaller of the real/imag parts
return ctx.sum_accurately(lambda: iter([ctx.exp(x),-1]),1)
@defun_wrapped
def powm1(ctx, x, y):
mag = ctx.mag
one = ctx.one
w = x**y - one
M = mag(w)
# Only moderate cancellation
if M > -8:
return w
# Check for the only possible exact cases
if not w:
if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
return w
x1 = x - one
magy = mag(y)
lnx = ctx.ln(x)
# Small y: x^y - 1 ~ log(x)*y + O(log(x)^2 * y^2)
if magy + mag(lnx) < -ctx.prec:
return lnx*y + (lnx*y)**2/2
# TODO: accurately eval the smaller of the real/imag part
return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)
@defun
def _rootof1(ctx, k, n):
k = int(k)
n = int(n)
k %= n
if not k:
return ctx.one
elif 2*k == n:
return -ctx.one
elif 4*k == n:
return ctx.j
elif 4*k == 3*n:
return -ctx.j
return ctx.expjpi(2*ctx.mpf(k)/n)
@defun
def root(ctx, x, n, k=0):
n = int(n)
x = ctx.convert(x)
if k:
# Special case: there is an exact real root
if (n & 1 and 2*k == n-1) and (not ctx.im(x)) and (ctx.re(x) < 0):
return -ctx.root(-x, n)
# Multiply by root of unity
prec = ctx.prec
try:
ctx.prec += 10
v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
finally:
ctx.prec = prec
return +v
return ctx._nthroot(x, n)
@defun
def unitroots(ctx, n, primitive=False):
gcd = ctx._gcd
prec = ctx.prec
try:
ctx.prec += 10
if primitive:
v = [ctx._rootof1(k,n) for k in range(n) if gcd(k,n) == 1]
else:
# TODO: this can be done *much* faster
v = [ctx._rootof1(k,n) for k in range(n)]
finally:
ctx.prec = prec
return [+x for x in v]
@defun
def arg(ctx, x):
x = ctx.convert(x)
re = ctx._re(x)
im = ctx._im(x)
return ctx.atan2(im, re)
@defun
def fabs(ctx, x):
return abs(ctx.convert(x))
@defun
def re(ctx, x):
x = ctx.convert(x)
if hasattr(x, "real"): # py2.5 doesn't have .real/.imag for all numbers
return x.real
return x
@defun
def im(ctx, x):
x = ctx.convert(x)
if hasattr(x, "imag"): # py2.5 doesn't have .real/.imag for all numbers
return x.imag
return ctx.zero
@defun
def conj(ctx, x):
return ctx.convert(x).conjugate()
@defun
def polar(ctx, z):
return (ctx.fabs(z), ctx.arg(z))
@defun_wrapped
def rect(ctx, r, phi):
return r * ctx.mpc(*ctx.cos_sin(phi))
@defun
def log(ctx, x, b=None):
if b is None:
return ctx.ln(x)
wp = ctx.prec + 20
return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)
@defun
def log10(ctx, x):
return ctx.log(x, 10)
@defun
def modf(ctx, x, y):
return ctx.convert(x) % ctx.convert(y)
@defun
def degrees(ctx, x):
return x / ctx.degree
@defun
def radians(ctx, x):
return x * ctx.degree
@defun_wrapped
def lambertw(ctx, z, k=0):
k = int(k)
if ctx.isnan(z):
return z
ctx.prec += 20
mag = ctx.mag(z)
# Start from fp approximation
if ctx is ctx._mp and abs(mag) < 900 and abs(k) < 10000 and \
abs(z+0.36787944117144) > 0.01:
w = ctx._fp.lambertw(z, k)
else:
absz = abs(z)
# We must be extremely careful near the singularities at -1/e and 0
u = ctx.exp(-1)
if absz <= u:
if not z:
# w(0,0) = 0; for all other branches we hit the pole
if not k:
return z
return ctx.ninf
if not k:
w = z
# For small real z < 0, the -1 branch aves roughly like log(-z)
elif k == -1 and not ctx.im(z) and ctx.re(z) < 0:
w = ctx.ln(-z)
# Use a simple asymptotic approximation.
else:
w = ctx.ln(z)
# The branches are roughly logarithmic. This approximation
# gets better for large |k|; need to check that this always
# works for k ~= -1, 0, 1.
if k: w += k * 2*ctx.pi*ctx.j
elif k == 0 and ctx.im(z) and absz <= 0.7:
# Both the W(z) ~= z and W(z) ~= ln(z) approximations break
# down around z ~= -0.5 (converging to the wrong branch), so patch
# with a constant approximation (adjusted for sign)
if abs(z+0.5) < 0.1:
if ctx.im(z) > 0:
w = ctx.mpc(0.7+0.7j)
else:
w = ctx.mpc(0.7-0.7j)
else:
w = z
else:
if z == ctx.inf:
if k == 0:
return z
else:
return z + 2*k*ctx.pi*ctx.j
if z == ctx.ninf:
return (-z) + (2*k+1)*ctx.pi*ctx.j
# Simple asymptotic approximation as above
w = ctx.ln(z)
if k:
w += k * 2*ctx.pi*ctx.j
# Use Halley iteration to solve w*exp(w) = z
two = ctx.mpf(2)
weps = ctx.ldexp(ctx.eps, 15)
for i in xrange(100):
ew = ctx.exp(w)
wew = w*ew
wewz = wew-z
wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
if abs(wn-w) < weps*abs(wn):
return wn
else:
w = wn
ctx.warn("Lambert W iteration failed to converge for %s" % z)
return wn
@defun_wrapped
def bell(ctx, n, x=1):
x = ctx.convert(x)
if not n:
if ctx.isnan(x):
return x
return type(x)(1)
if ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n):
return x**n
if n == 1: return x
if n == 2: return x*(x+1)
if x == 0: return ctx.sincpi(n)
return _polyexp(ctx, n, x, True) / ctx.exp(x)
def _polyexp(ctx, n, x, extra=False):
def _terms():
if extra:
yield ctx.sincpi(n)
t = x
k = 1
while 1:
yield k**n * t
k += 1
t = t*x/k
return ctx.sum_accurately(_terms, check_step=4)
@defun_wrapped
def polyexp(ctx, s, z):
if ctx.isinf(z) or ctx.isinf(s) or ctx.isnan(z) or ctx.isnan(s):
return z**s
if z == 0: return z*s
if s == 0: return ctx.expm1(z)
if s == 1: return ctx.exp(z)*z
if s == 2: return ctx.exp(z)*z*(z+1)
return _polyexp(ctx, s, z)
@defun_wrapped
def cyclotomic(ctx, n, z):
n = int(n)
assert n >= 0
p = ctx.one
if n == 0:
return p
if n == 1:
return z - p
if n == 2:
return z + p
# Use divisor product representation. Unfortunately, this sometimes
# includes singularities for roots of unity, which we have to cancel out.
# Matching zeros/poles pairwise, we have (1-z^a)/(1-z^b) ~ a/b + O(z-1).
a_prod = 1
b_prod = 1
num_zeros = 0
num_poles = 0
for d in range(1,n+1):
if not n % d:
w = ctx.moebius(n//d)
# Use powm1 because it is important that we get 0 only
# if it really is exactly 0
b = -ctx.powm1(z, d)
if b:
p *= b**w
else:
if w == 1:
a_prod *= d
num_zeros += 1
elif w == -1:
b_prod *= d
num_poles += 1
#print n, num_zeros, num_poles
if num_zeros:
if num_zeros > num_poles:
p *= 0
else:
p *= a_prod
p /= b_prod
return p

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,693 +0,0 @@
from functions import defun, defun_wrapped, defun_static
@defun
def stieltjes(ctx, n, a=1):
n = ctx.convert(n)
a = ctx.convert(a)
if n < 0:
return ctx.bad_domain("Stieltjes constants defined for n >= 0")
if hasattr(ctx, "stieltjes_cache"):
stieltjes_cache = ctx.stieltjes_cache
else:
stieltjes_cache = ctx.stieltjes_cache = {}
if a == 1:
if n == 0:
return +ctx.euler
if n in stieltjes_cache:
prec, s = stieltjes_cache[n]
if prec >= ctx.prec:
return +s
mag = 1
def f(x):
xa = x/a
v = (xa-ctx.j)*ctx.ln(a-ctx.j*x)**n/(1+xa**2)/(ctx.exp(2*ctx.pi*x)-1)
return ctx._re(v) / mag
orig = ctx.prec
try:
# Normalize integrand by approx. magnitude to
# speed up quadrature (which uses absolute error)
if n > 50:
ctx.prec = 20
mag = ctx.quad(f, [0,ctx.inf], maxdegree=3)
ctx.prec = orig + 10 + int(n**0.5)
s = ctx.quad(f, [0,ctx.inf], maxdegree=20)
v = ctx.ln(a)**n/(2*a) - ctx.ln(a)**(n+1)/(n+1) + 2*s/a*mag
finally:
ctx.prec = orig
if a == 1 and ctx.isint(n):
stieltjes_cache[n] = (ctx.prec, v)
return +v
@defun_wrapped
def siegeltheta(ctx, t):
if ctx._im(t):
# XXX: cancellation occurs
a = ctx.loggamma(0.25+0.5j*t)
b = ctx.loggamma(0.25-0.5j*t)
return -ctx.ln(ctx.pi)/2*t - 0.5j*(a-b)
else:
if ctx.isinf(t):
return t
return ctx._im(ctx.loggamma(0.25+0.5j*t)) - ctx.ln(ctx.pi)/2*t
@defun_wrapped
def grampoint(ctx, n):
# asymptotic expansion, from
# http://mathworld.wolfram.com/GramPoint.html
g = 2*ctx.pi*ctx.exp(1+ctx.lambertw((8*n+1)/(8*ctx.e)))
return ctx.findroot(lambda t: ctx.siegeltheta(t)-ctx.pi*n, g)
@defun_wrapped
def siegelz(ctx, t):
v = ctx.expj(ctx.siegeltheta(t))*ctx.zeta(0.5+ctx.j*t)
if ctx._is_real_type(t):
return ctx._re(v)
return v
_zeta_zeros = [
14.134725142,21.022039639,25.010857580,30.424876126,32.935061588,
37.586178159,40.918719012,43.327073281,48.005150881,49.773832478,
52.970321478,56.446247697,59.347044003,60.831778525,65.112544048,
67.079810529,69.546401711,72.067157674,75.704690699,77.144840069,
79.337375020,82.910380854,84.735492981,87.425274613,88.809111208,
92.491899271,94.651344041,95.870634228,98.831194218,101.317851006,
103.725538040,105.446623052,107.168611184,111.029535543,111.874659177,
114.320220915,116.226680321,118.790782866,121.370125002,122.946829294,
124.256818554,127.516683880,129.578704200,131.087688531,133.497737203,
134.756509753,138.116042055,139.736208952,141.123707404,143.111845808,
146.000982487,147.422765343,150.053520421,150.925257612,153.024693811,
156.112909294,157.597591818,158.849988171,161.188964138,163.030709687,
165.537069188,167.184439978,169.094515416,169.911976479,173.411536520,
174.754191523,176.441434298,178.377407776,179.916484020,182.207078484,
184.874467848,185.598783678,187.228922584,189.416158656,192.026656361,
193.079726604,195.265396680,196.876481841,198.015309676,201.264751944,
202.493594514,204.189671803,205.394697202,207.906258888,209.576509717,
211.690862595,213.347919360,214.547044783,216.169538508,219.067596349,
220.714918839,221.430705555,224.007000255,224.983324670,227.421444280,
229.337413306,231.250188700,231.987235253,233.693404179,236.524229666,
]
def _load_zeta_zeros(url):
import urllib
d = urllib.urlopen(url)
L = [float(x) for x in d.readlines()]
# Sanity check
assert round(L[0]) == 14
_zeta_zeros[:] = L
@defun
def zetazero(ctx, n, url='http://www.dtc.umn.edu/~odlyzko/zeta_tables/zeros1'):
n = int(n)
if n < 0:
return ctx.zetazero(-n).conjugate()
if n == 0:
raise ValueError("n must be nonzero")
if n > len(_zeta_zeros) and n <= 100000:
_load_zeta_zeros(url)
if n > len(_zeta_zeros):
raise NotImplementedError("n too large for zetazeros")
return ctx.mpc(0.5, ctx.findroot(ctx.siegelz, _zeta_zeros[n-1]))
@defun_wrapped
def riemannr(ctx, x):
if x == 0:
return ctx.zero
# Check if a simple asymptotic estimate is accurate enough
if abs(x) > 1000:
a = ctx.li(x)
b = 0.5*ctx.li(ctx.sqrt(x))
if abs(b) < abs(a)*ctx.eps:
return a
if abs(x) < 0.01:
# XXX
ctx.prec += int(-ctx.log(abs(x),2))
# Sum Gram's series
s = t = ctx.one
u = ctx.ln(x)
k = 1
while abs(t) > abs(s)*ctx.eps:
t = t * u / k
s += t / (k * ctx._zeta_int(k+1))
k += 1
return s
@defun_static
def primepi(ctx, x):
x = int(x)
if x < 2:
return 0
return len(ctx.list_primes(x))
@defun_wrapped
def primepi2(ctx, x):
x = int(x)
if x < 2:
return ctx.mpi(0,0)
if x < 2657:
return ctx.mpi(ctx.primepi(x))
mid = ctx.li(x)
# Schoenfeld's estimate for x >= 2657, assuming RH
err = ctx.sqrt(x,rounding='u')*ctx.ln(x,rounding='u')/8/ctx.pi(rounding='d')
a = ctx.floor((ctx.mpi(mid)-err).a, rounding='d')
b = ctx.ceil((ctx.mpi(mid)+err).b, rounding='u')
return ctx.mpi(a, b)
@defun_wrapped
def primezeta(ctx, s):
if ctx.isnan(s):
return s
if ctx.re(s) <= 0:
raise ValueError("prime zeta function defined only for re(s) > 0")
if s == 1:
return ctx.inf
if s == 0.5:
return ctx.mpc(ctx.ninf, ctx.pi)
r = ctx.re(s)
if r > ctx.prec:
return 0.5**s
else:
wp = ctx.prec + int(r)
def terms():
orig = ctx.prec
# zeta ~ 1+eps; need to set precision
# to get logarithm accurately
k = 0
while 1:
k += 1
u = ctx.moebius(k)
if not u:
continue
ctx.prec = wp
t = u*ctx.ln(ctx.zeta(k*s))/k
if not t:
return
#print ctx.prec, ctx.nstr(t)
ctx.prec = orig
yield t
return ctx.sum_accurately(terms)
# TODO: for bernpoly and eulerpoly, ensure that all exact zeros are covered
@defun_wrapped
def bernpoly(ctx, n, z):
# Slow implementation:
#return sum(ctx.binomial(n,k)*ctx.bernoulli(k)*z**(n-k) for k in xrange(0,n+1))
n = int(n)
if n < 0:
raise ValueError("Bernoulli polynomials only defined for n >= 0")
if z == 0 or (z == 1 and n > 1):
return ctx.bernoulli(n)
if z == 0.5:
return (ctx.ldexp(1,1-n)-1)*ctx.bernoulli(n)
if n <= 3:
if n == 0: return z ** 0
if n == 1: return z - 0.5
if n == 2: return (6*z*(z-1)+1)/6
if n == 3: return z*(z*(z-1.5)+0.5)
if abs(z) == ctx.inf:
return z ** n
if z != z:
return z
if abs(z) > 2:
def terms():
t = ctx.one
yield t
r = ctx.one/z
k = 1
while k <= n:
t = t*(n+1-k)/k*r
if not (k > 2 and k & 1):
yield t*ctx.bernoulli(k)
k += 1
return ctx.sum_accurately(terms) * z**n
else:
def terms():
yield ctx.bernoulli(n)
t = ctx.one
k = 1
while k <= n:
t = t*(n+1-k)/k * z
m = n-k
if not (m > 2 and m & 1):
yield t*ctx.bernoulli(m)
k += 1
return ctx.sum_accurately(terms)
@defun_wrapped
def eulerpoly(ctx, n, z):
n = int(n)
if n < 0:
raise ValueError("Euler polynomials only defined for n >= 0")
if n <= 2:
if n == 0: return z ** 0
if n == 1: return z - 0.5
if n == 2: return z*(z-1)
if abs(z) == ctx.inf:
return z**n
if z != z:
return z
m = n+1
if z == 0:
return -2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
if z == 1:
return 2*(ctx.ldexp(1,m)-1)*ctx.bernoulli(m)/m * z**0
if z == 0.5:
if n % 2:
return ctx.zero
# Use exact code for Euler numbers
if n < 100 or n*ctx.mag(0.46839865*n) < ctx.prec*0.25:
return ctx.ldexp(ctx._eulernum(n), -n)
# http://functions.wolfram.com/Polynomials/EulerE2/06/01/02/01/0002/
def terms():
t = ctx.one
k = 0
w = ctx.ldexp(1,n+2)
while 1:
v = n-k+1
if not (v > 2 and v & 1):
yield (2-w)*ctx.bernoulli(v)*t
k += 1
if k > n:
break
t = t*z*(n-k+2)/k
w *= 0.5
return ctx.sum_accurately(terms) / m
@defun
def eulernum(ctx, n, exact=False):
n = int(n)
if exact:
return int(ctx._eulernum(n))
if n < 100:
return ctx.mpf(ctx._eulernum(n))
if n % 2:
return ctx.zero
return ctx.ldexp(ctx.eulerpoly(n,0.5), n)
# TODO: this should be implemented low-level
def polylog_series(ctx, s, z):
tol = +ctx.eps
l = ctx.zero
k = 1
zk = z
while 1:
term = zk / k**s
l += term
if abs(term) < tol:
break
zk *= z
k += 1
return l
def polylog_continuation(ctx, n, z):
if n < 0:
return z*0
twopij = 2j * ctx.pi
a = -twopij**n/ctx.fac(n) * ctx.bernpoly(n, ctx.ln(z)/twopij)
if ctx._is_real_type(z) and z < 0:
a = ctx._re(a)
if ctx._im(z) < 0 or (ctx._im(z) == 0 and ctx._re(z) >= 1):
a -= twopij*ctx.ln(z)**(n-1)/ctx.fac(n-1)
return a
def polylog_unitcircle(ctx, n, z):
tol = +ctx.eps
if n > 1:
l = ctx.zero
logz = ctx.ln(z)
logmz = ctx.one
m = 0
while 1:
if (n-m) != 1:
term = ctx.zeta(n-m) * logmz / ctx.fac(m)
if term and abs(term) < tol:
break
l += term
logmz *= logz
m += 1
l += ctx.ln(z)**(n-1)/ctx.fac(n-1)*(ctx.harmonic(n-1)-ctx.ln(-ctx.ln(z)))
elif n < 1: # else
l = ctx.fac(-n)*(-ctx.ln(z))**(n-1)
logz = ctx.ln(z)
logkz = ctx.one
k = 0
while 1:
b = ctx.bernoulli(k-n+1)
if b:
term = b*logkz/(ctx.fac(k)*(k-n+1))
if abs(term) < tol:
break
l -= term
logkz *= logz
k += 1
else:
raise ValueError
if ctx._is_real_type(z) and z < 0:
l = ctx._re(l)
return l
def polylog_general(ctx, s, z):
v = ctx.zero
u = ctx.ln(z)
if not abs(u) < 5: # theoretically |u| < 2*pi
raise NotImplementedError("polylog for arbitrary s and z")
t = 1
k = 0
while 1:
term = ctx.zeta(s-k) * t
if abs(term) < ctx.eps:
break
v += term
k += 1
t *= u
t /= k
return ctx.gamma(1-s)*(-u)**(s-1) + v
@defun_wrapped
def polylog(ctx, s, z):
s = ctx.convert(s)
z = ctx.convert(z)
if z == 1:
return ctx.zeta(s)
if z == -1:
return -ctx.altzeta(s)
if s == 0:
return z/(1-z)
if s == 1:
return -ctx.ln(1-z)
if s == -1:
return z/(1-z)**2
if abs(z) <= 0.75 or (not ctx.isint(s) and abs(z) < 0.9):
return polylog_series(ctx, s, z)
if abs(z) >= 1.4 and ctx.isint(s):
return (-1)**(s+1)*polylog_series(ctx, s, 1/z) + polylog_continuation(ctx, s, z)
if ctx.isint(s):
return polylog_unitcircle(ctx, int(s), z)
return polylog_general(ctx, s, z)
#raise NotImplementedError("polylog for arbitrary s and z")
# This could perhaps be used in some cases
#from quadrature import quad
#return quad(lambda t: t**(s-1)/(exp(t)/z-1),[0,inf])/gamma(s)
@defun_wrapped
def clsin(ctx, s, z, pi=False):
if ctx.isint(s) and s < 0 and int(s) % 2 == 1:
return z*0
if pi:
a = ctx.expjpi(z)
else:
a = ctx.expj(z)
if ctx._is_real_type(z) and ctx._is_real_type(s):
return ctx.im(ctx.polylog(s,a))
b = 1/a
return (-0.5j)*(ctx.polylog(s,a) - ctx.polylog(s,b))
@defun_wrapped
def clcos(ctx, s, z, pi=False):
if ctx.isint(s) and s < 0 and int(s) % 2 == 0:
return z*0
if pi:
a = ctx.expjpi(z)
else:
a = ctx.expj(z)
if ctx._is_real_type(z) and ctx._is_real_type(s):
return ctx.re(ctx.polylog(s,a))
b = 1/a
return 0.5*(ctx.polylog(s,a) + ctx.polylog(s,b))
@defun
def altzeta(ctx, s, **kwargs):
try:
return ctx._altzeta(s, **kwargs)
except NotImplementedError:
return ctx._altzeta_generic(s)
@defun_wrapped
def _altzeta_generic(ctx, s):
if s == 1:
return ctx.ln2 + 0*s
return -ctx.powm1(2, 1-s) * ctx.zeta(s)
@defun
def zeta(ctx, s, a=1, derivative=0, method=None, **kwargs):
d = int(derivative)
if a == 1 and not (d or method):
try:
return ctx._zeta(s, **kwargs)
except NotImplementedError:
pass
s = ctx.convert(s)
prec = ctx.prec
method = kwargs.get('method')
verbose = kwargs.get('verbose')
if a == 1 and method != 'euler-maclaurin':
im = abs(ctx._im(s))
re = abs(ctx._re(s))
#if (im < prec or method == 'borwein') and not derivative:
# try:
# if verbose:
# print "zeta: Attempting to use the Borwein algorithm"
# return ctx._zeta(s, **kwargs)
# except NotImplementedError:
# if verbose:
# print "zeta: Could not use the Borwein algorithm"
# pass
if abs(im) > 60*prec and 10*re < prec and derivative <= 4 or \
method == 'riemann-siegel':
try: # py2.4 compatible try block
try:
if verbose:
print "zeta: Attempting to use the Riemann-Siegel algorithm"
return ctx.rs_zeta(s, derivative, **kwargs)
except NotImplementedError:
if verbose:
print "zeta: Could not use the Riemann-Siegel algorithm"
pass
finally:
ctx.prec = prec
if s == 1:
return ctx.inf
abss = abs(s)
if abss == ctx.inf:
if ctx.re(s) == ctx.inf:
if d == 0:
return ctx.one
return ctx.zero
return s*0
elif ctx.isnan(abss):
return 1/s
if ctx.re(s) > 2*ctx.prec and a == 1 and not derivative:
return ctx.one + ctx.power(2, -s)
if verbose:
print "zeta: Using the Euler-Maclaurin algorithm"
prec = ctx.prec
try:
ctx.prec += 10
v = ctx._hurwitz(s, a, d)
finally:
ctx.prec = prec
return +v
@defun
def _hurwitz(ctx, s, a=1, d=0):
# We strongly want to special-case rational a
a, atype = ctx._convert_param(a)
prec = ctx.prec
# TODO: implement reflection for derivatives
res = ctx.re(s)
negs = -s
try:
if res < 0 and not d:
# Integer reflection formula
if ctx.isnpint(s):
n = int(res)
if n <= 0:
return ctx.bernpoly(1-n, a) / (n-1)
t = 1-s
# We now require a to be standardized
v = 0
shift = 0
b = a
while ctx.re(b) > 1:
b -= 1
v -= b**negs
shift -= 1
while ctx.re(b) <= 0:
v += b**negs
b += 1
shift += 1
# Rational reflection formula
if atype == 'Q' or atype == 'Z':
try:
p, q = a
except:
assert a == int(a)
p = int(a)
q = 1
p += shift*q
assert 1 <= p <= q
g = ctx.fsum(ctx.cospi(t/2-2*k*b)*ctx._hurwitz(t,(k,q)) \
for k in range(1,q+1))
g *= 2*ctx.gamma(t)/(2*ctx.pi*q)**t
v += g
return v
# General reflection formula
else:
C1 = ctx.cospi(t/2)
C2 = ctx.sinpi(t/2)
# Clausen functions; could maybe use polylog directly
if C1: C1 *= ctx.clcos(t, 2*a, pi=True)
if C2: C2 *= ctx.clsin(t, 2*a, pi=True)
v += 2*ctx.gamma(t)/(2*ctx.pi)**t*(C1+C2)
return v
except NotImplementedError:
pass
a = ctx.convert(a)
tol = -prec
# Estimate number of terms for Euler-Maclaurin summation; could be improved
M1 = 0
M2 = prec // 3
N = M2
lsum = 0
# This speeds up the recurrence for derivatives
if ctx.isint(s):
s = int(ctx._re(s))
s1 = s-1
while 1:
# Truncated L-series
l = ctx._zetasum(s, M1+a, M2-M1-1, [d])[0][0]
#if d:
# l = ctx.fsum((-ctx.ln(n+a))**d * (n+a)**negs for n in range(M1,M2))
#else:
# l = ctx.fsum((n+a)**negs for n in range(M1,M2))
lsum += l
M2a = M2+a
logM2a = ctx.ln(M2a)
logM2ad = logM2a**d
logs = [logM2ad]
logr = 1/logM2a
rM2a = 1/M2a
M2as = rM2a**s
if d:
tailsum = ctx.gammainc(d+1, s1*logM2a) / s1**(d+1)
else:
tailsum = 1/((s1)*(M2a)**s1)
tailsum += 0.5 * logM2ad * M2as
U = [1]
r = M2as
fact = 2
for j in range(1, N+1):
# TODO: the following could perhaps be tidied a bit
j2 = 2*j
if j == 1:
upds = [1]
else:
upds = [j2-2, j2-1]
for m in upds:
D = min(m,d+1)
if m <= d:
logs.append(logs[-1] * logr)
Un = [0]*(D+1)
for i in xrange(D): Un[i] = (1-m-s)*U[i]
for i in xrange(1,D+1): Un[i] += (d-(i-1))*U[i-1]
U = Un
r *= rM2a
t = ctx.fdot(U, logs) * r * ctx.bernoulli(j2)/(-fact)
tailsum += t
if ctx.mag(t) < tol:
return lsum + (-1)**d * tailsum
fact *= (j2+1)*(j2+2)
M1, M2 = M2, M2*2
@defun
def _zetasum(ctx, s, a, n, derivatives=[0], reflect=False):
"""
Returns [xd0,xd1,...,xdr], [yd0,yd1,...ydr] where
xdk = D^k ( 1/a^s + 1/(a+1)^s + ... + 1/(a+n)^s )
ydk = D^k conj( 1/a^(1-s) + 1/(a+1)^(1-s) + ... + 1/(a+n)^(1-s) )
D^k = kth derivative with respect to s, k ranges over the given list of
derivatives (which should consist of either a single element
or a range 0,1,...r). If reflect=False, the ydks are not computed.
"""
try:
return ctx._zetasum_fast(s, a, n, derivatives, reflect)
except NotImplementedError:
pass
negs = ctx.fneg(s, exact=True)
have_derivatives = derivatives != [0]
have_one_derivative = len(derivatives) == 1
if not reflect:
if not have_derivatives:
return [ctx.fsum((a+k)**negs for k in xrange(n+1))], []
if have_one_derivative:
d = derivatives[0]
x = ctx.fsum(ctx.ln(a+k)**d * (a+k)**negs for k in xrange(n+1))
return [(-1)**d * x], []
maxd = max(derivatives)
if not have_one_derivative:
derivatives = range(maxd+1)
xs = [ctx.zero for d in derivatives]
if reflect:
ys = [ctx.zero for d in derivatives]
else:
ys = []
for k in xrange(n+1):
w = a + k
xterm = w ** negs
if reflect:
yterm = ctx.conj(ctx.one / (w * xterm))
if have_derivatives:
logw = -ctx.ln(w)
if have_one_derivative:
logw = logw ** maxd
xs[0] += xterm * logw
if reflect:
ys[0] += yterm * logw
else:
t = ctx.one
for d in derivatives:
xs[d] += xterm * t
if reflect:
ys[d] += yterm * t
t *= logw
else:
xs[0] += xterm
if reflect:
ys[0] += yterm
return xs, ys
@defun
def dirichlet(ctx, s, chi=[1], derivative=0):
s = ctx.convert(s)
q = len(chi)
d = int(derivative)
if d > 2:
raise NotImplementedError("arbitrary order derivatives")
prec = ctx.prec
try:
ctx.prec += 10
if s == 1:
have_pole = True
for x in chi:
if x and x != 1:
have_pole = False
h = +ctx.eps
ctx.prec *= 2*(d+1)
s += h
if have_pole:
return +ctx.inf
z = ctx.zero
for p in range(1,q+1):
if chi[p%q]:
if d == 1:
z += chi[p%q] * (ctx.zeta(s, (p,q), 1) - \
ctx.zeta(s, (p,q))*ctx.log(q))
else:
z += chi[p%q] * ctx.zeta(s, (p,q))
z /= q**s
finally:
ctx.prec = prec
return +z

View File

@ -1,840 +0,0 @@
"""
Implements the PSLQ algorithm for integer relation detection,
and derivative algorithms for constant recognition.
"""
from libmp import int_types, sqrt_fixed
# round to nearest integer (can be done more elegantly...)
def round_fixed(x, prec):
return ((x + (1<<(prec-1))) >> prec) << prec
class IdentificationMethods(object):
pass
def pslq(ctx, x, tol=None, maxcoeff=1000, maxsteps=100, verbose=False):
r"""
Given a vector of real numbers `x = [x_0, x_1, ..., x_n]`, ``pslq(x)``
uses the PSLQ algorithm to find a list of integers
`[c_0, c_1, ..., c_n]` such that
.. math ::
|c_1 x_1 + c_2 x_2 + ... + c_n x_n| < \mathrm{tol}
and such that `\max |c_k| < \mathrm{maxcoeff}`. If no such vector
exists, :func:`pslq` returns ``None``. The tolerance defaults to
3/4 of the working precision.
**Examples**
Find rational approximations for `\pi`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> pslq([-1, pi], tol=0.01)
[22, 7]
>>> pslq([-1, pi], tol=0.001)
[355, 113]
>>> mpf(22)/7; mpf(355)/113; +pi
3.14285714285714
3.14159292035398
3.14159265358979
Pi is not a rational number with denominator less than 1000::
>>> pslq([-1, pi])
>>>
To within the standard precision, it can however be approximated
by at least one rational number with denominator less than `10^{12}`::
>>> p, q = pslq([-1, pi], maxcoeff=10**12)
>>> print p, q
238410049439 75888275702
>>> mpf(p)/q
3.14159265358979
The PSLQ algorithm can be applied to long vectors. For example,
we can investigate the rational (in)dependence of integer square
roots::
>>> mp.dps = 30
>>> pslq([sqrt(n) for n in range(2, 5+1)])
>>>
>>> pslq([sqrt(n) for n in range(2, 6+1)])
>>>
>>> pslq([sqrt(n) for n in range(2, 8+1)])
[2, 0, 0, 0, 0, 0, -1]
**Machin formulas**
A famous formula for `\pi` is Machin's,
.. math ::
\frac{\pi}{4} = 4 \operatorname{acot} 5 - \operatorname{acot} 239
There are actually infinitely many formulas of this type. Two
others are
.. math ::
\frac{\pi}{4} = \operatorname{acot} 1
\frac{\pi}{4} = 12 \operatorname{acot} 49 + 32 \operatorname{acot} 57
+ 5 \operatorname{acot} 239 + 12 \operatorname{acot} 110443
We can easily verify the formulas using the PSLQ algorithm::
>>> mp.dps = 30
>>> pslq([pi/4, acot(1)])
[1, -1]
>>> pslq([pi/4, acot(5), acot(239)])
[1, -4, 1]
>>> pslq([pi/4, acot(49), acot(57), acot(239), acot(110443)])
[1, -12, -32, 5, -12]
We could try to generate a custom Machin-like formula by running
the PSLQ algorithm with a few inverse cotangent values, for example
acot(2), acot(3) ... acot(10). Unfortunately, there is a linear
dependence among these values, resulting in only that dependence
being detected, with a zero coefficient for `\pi`::
>>> pslq([pi] + [acot(n) for n in range(2,11)])
[0, 1, -1, 0, 0, 0, -1, 0, 0, 0]
We get better luck by removing linearly dependent terms::
>>> pslq([pi] + [acot(n) for n in range(2,11) if n not in (3, 5)])
[1, -8, 0, 0, 4, 0, 0, 0]
In other words, we found the following formula::
>>> 8*acot(2) - 4*acot(7)
3.14159265358979323846264338328
>>> +pi
3.14159265358979323846264338328
**Algorithm**
This is a fairly direct translation to Python of the pseudocode given by
David Bailey, "The PSLQ Integer Relation Algorithm":
http://www.cecm.sfu.ca/organics/papers/bailey/paper/html/node3.html
The present implementation uses fixed-point instead of floating-point
arithmetic, since this is significantly (about 7x) faster.
"""
n = len(x)
assert n >= 2
# At too low precision, the algorithm becomes meaningless
prec = ctx.prec
assert prec >= 53
if verbose and prec // max(2,n) < 5:
print "Warning: precision for PSLQ may be too low"
target = int(prec * 0.75)
if tol is None:
tol = ctx.mpf(2)**(-target)
else:
tol = ctx.convert(tol)
extra = 60
prec += extra
if verbose:
print "PSLQ using prec %i and tol %s" % (prec, ctx.nstr(tol))
tol = ctx.to_fixed(tol, prec)
assert tol
# Convert to fixed-point numbers. The dummy None is added so we can
# use 1-based indexing. (This just allows us to be consistent with
# Bailey's indexing. The algorithm is 100 lines long, so debugging
# a single wrong index can be painful.)
x = [None] + [ctx.to_fixed(ctx.mpf(xk), prec) for xk in x]
# Sanity check on magnitudes
minx = min(abs(xx) for xx in x[1:])
if not minx:
raise ValueError("PSLQ requires a vector of nonzero numbers")
if minx < tol//100:
if verbose:
print "STOPPING: (one number is too small)"
return None
g = sqrt_fixed((4<<prec)//3, prec)
A = {}
B = {}
H = {}
# Initialization
# step 1
for i in xrange(1, n+1):
for j in xrange(1, n+1):
A[i,j] = B[i,j] = (i==j) << prec
H[i,j] = 0
# step 2
s = [None] + [0] * n
for k in xrange(1, n+1):
t = 0
for j in xrange(k, n+1):
t += (x[j]**2 >> prec)
s[k] = sqrt_fixed(t, prec)
t = s[1]
y = x[:]
for k in xrange(1, n+1):
y[k] = (x[k] << prec) // t
s[k] = (s[k] << prec) // t
# step 3
for i in xrange(1, n+1):
for j in xrange(i+1, n):
H[i,j] = 0
if i <= n-1:
if s[i]:
H[i,i] = (s[i+1] << prec) // s[i]
else:
H[i,i] = 0
for j in range(1, i):
sjj1 = s[j]*s[j+1]
if sjj1:
H[i,j] = ((-y[i]*y[j])<<prec)//sjj1
else:
H[i,j] = 0
# step 4
for i in xrange(2, n+1):
for j in xrange(i-1, 0, -1):
#t = floor(H[i,j]/H[j,j] + 0.5)
if H[j,j]:
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
else:
#t = 0
continue
y[j] = y[j] + (t*y[i] >> prec)
for k in xrange(1, j+1):
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
for k in xrange(1, n+1):
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
# Main algorithm
for REP in range(maxsteps):
# Step 1
m = -1
szmax = -1
for i in range(1, n):
h = H[i,i]
sz = (g**i * abs(h)) >> (prec*(i-1))
if sz > szmax:
m = i
szmax = sz
# Step 2
y[m], y[m+1] = y[m+1], y[m]
tmp = {}
for i in xrange(1,n+1): H[m,i], H[m+1,i] = H[m+1,i], H[m,i]
for i in xrange(1,n+1): A[m,i], A[m+1,i] = A[m+1,i], A[m,i]
for i in xrange(1,n+1): B[i,m], B[i,m+1] = B[i,m+1], B[i,m]
# Step 3
if m <= n - 2:
t0 = sqrt_fixed((H[m,m]**2 + H[m,m+1]**2)>>prec, prec)
# A zero element probably indicates that the precision has
# been exhausted. XXX: this could be spurious, due to
# using fixed-point arithmetic
if not t0:
break
t1 = (H[m,m] << prec) // t0
t2 = (H[m,m+1] << prec) // t0
for i in xrange(m, n+1):
t3 = H[i,m]
t4 = H[i,m+1]
H[i,m] = (t1*t3+t2*t4) >> prec
H[i,m+1] = (-t2*t3+t1*t4) >> prec
# Step 4
for i in xrange(m+1, n+1):
for j in xrange(min(i-1, m+1), 0, -1):
try:
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
# Precision probably exhausted
except ZeroDivisionError:
break
y[j] = y[j] + ((t*y[i]) >> prec)
for k in xrange(1, j+1):
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
for k in xrange(1, n+1):
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
# Until a relation is found, the error typically decreases
# slowly (e.g. a factor 1-10) with each step TODO: we could
# compare err from two successive iterations. If there is a
# large drop (several orders of magnitude), that indicates a
# "high quality" relation was detected. Reporting this to
# the user somehow might be useful.
best_err = maxcoeff<<prec
for i in xrange(1, n+1):
err = abs(y[i])
# Maybe we are done?
if err < tol:
# We are done if the coefficients are acceptable
vec = [int(round_fixed(B[j,i], prec) >> prec) for j in \
range(1,n+1)]
if max(abs(v) for v in vec) < maxcoeff:
if verbose:
print "FOUND relation at iter %i/%i, error: %s" % \
(REP, maxsteps, ctx.nstr(err / ctx.mpf(2)**prec, 1))
return vec
best_err = min(err, best_err)
# Calculate a lower bound for the norm. We could do this
# more exactly (using the Euclidean norm) but there is probably
# no practical benefit.
recnorm = max(abs(h) for h in H.values())
if recnorm:
norm = ((1 << (2*prec)) // recnorm) >> prec
norm //= 100
else:
norm = ctx.inf
if verbose:
print "%i/%i: Error: %8s Norm: %s" % \
(REP, maxsteps, ctx.nstr(best_err / ctx.mpf(2)**prec, 1), norm)
if norm >= maxcoeff:
break
if verbose:
print "CANCELLING after step %i/%i." % (REP, maxsteps)
print "Could not find an integer relation. Norm bound: %s" % norm
return None
def findpoly(ctx, x, n=1, **kwargs):
r"""
``findpoly(x, n)`` returns the coefficients of an integer
polynomial `P` of degree at most `n` such that `P(x) \approx 0`.
If no polynomial having `x` as a root can be found,
:func:`findpoly` returns ``None``.
:func:`findpoly` works by successively calling :func:`pslq` with
the vectors `[1, x]`, `[1, x, x^2]`, `[1, x, x^2, x^3]`, ...,
`[1, x, x^2, .., x^n]` as input. Keyword arguments given to
:func:`findpoly` are forwarded verbatim to :func:`pslq`. In
particular, you can specify a tolerance for `P(x)` with ``tol``
and a maximum permitted coefficient size with ``maxcoeff``.
For large values of `n`, it is recommended to run :func:`findpoly`
at high precision; preferably 50 digits or more.
**Examples**
By default (degree `n = 1`), :func:`findpoly` simply finds a linear
polynomial with a rational root::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> findpoly(0.7)
[-10, 7]
The generated coefficient list is valid input to ``polyval`` and
``polyroots``::
>>> nprint(polyval(findpoly(phi, 2), phi), 1)
-2.0e-16
>>> for r in polyroots(findpoly(phi, 2)):
... print r
...
-0.618033988749895
1.61803398874989
Numbers of the form `m + n \sqrt p` for integers `(m, n, p)` are
solutions to quadratic equations. As we find here, `1+\sqrt 2`
is a root of the polynomial `x^2 - 2x - 1`::
>>> findpoly(1+sqrt(2), 2)
[1, -2, -1]
>>> findroot(lambda x: x**2 - 2*x - 1, 1)
2.4142135623731
Despite only containing square roots, the following number results
in a polynomial of degree 4::
>>> findpoly(sqrt(2)+sqrt(3), 4)
[1, 0, -10, 0, 1]
In fact, `x^4 - 10x^2 + 1` is the *minimal polynomial* of
`r = \sqrt 2 + \sqrt 3`, meaning that a rational polynomial of
lower degree having `r` as a root does not exist. Given sufficient
precision, :func:`findpoly` will usually find the correct
minimal polynomial of a given algebraic number.
**Non-algebraic numbers**
If :func:`findpoly` fails to find a polynomial with given
coefficient size and tolerance constraints, that means no such
polynomial exists.
We can verify that `\pi` is not an algebraic number of degree 3 with
coefficients less than 1000::
>>> mp.dps = 15
>>> findpoly(pi, 3)
>>>
It is always possible to find an algebraic approximation of a number
using one (or several) of the following methods:
1. Increasing the permitted degree
2. Allowing larger coefficients
3. Reducing the tolerance
One example of each method is shown below::
>>> mp.dps = 15
>>> findpoly(pi, 4)
[95, -545, 863, -183, -298]
>>> findpoly(pi, 3, maxcoeff=10000)
[836, -1734, -2658, -457]
>>> findpoly(pi, 3, tol=1e-7)
[-4, 22, -29, -2]
It is unknown whether Euler's constant is transcendental (or even
irrational). We can use :func:`findpoly` to check that if is
an algebraic number, its minimal polynomial must have degree
at least 7 and a coefficient of magnitude at least 1000000::
>>> mp.dps = 200
>>> findpoly(euler, 6, maxcoeff=10**6, tol=1e-100, maxsteps=1000)
>>>
Note that the high precision and strict tolerance is necessary
for such high-degree runs, since otherwise unwanted low-accuracy
approximations will be detected. It may also be necessary to set
maxsteps high to prevent a premature exit (before the coefficient
bound has been reached). Running with ``verbose=True`` to get an
idea what is happening can be useful.
"""
x = ctx.mpf(x)
assert n >= 1
if x == 0:
return [1, 0]
xs = [ctx.mpf(1)]
for i in range(1,n+1):
xs.append(x**i)
a = ctx.pslq(xs, **kwargs)
if a is not None:
return a[::-1]
def fracgcd(p, q):
x, y = p, q
while y:
x, y = y, x % y
if x != 1:
p //= x
q //= x
if q == 1:
return p
return p, q
def pslqstring(r, constants):
q = r[0]
r = r[1:]
s = []
for i in range(len(r)):
p = r[i]
if p:
z = fracgcd(-p,q)
cs = constants[i][1]
if cs == '1':
cs = ''
else:
cs = '*' + cs
if isinstance(z, int_types):
if z > 0: term = str(z) + cs
else: term = ("(%s)" % z) + cs
else:
term = ("(%s/%s)" % z) + cs
s.append(term)
s = ' + '.join(s)
if '+' in s or '*' in s:
s = '(' + s + ')'
return s or '0'
def prodstring(r, constants):
q = r[0]
r = r[1:]
num = []
den = []
for i in range(len(r)):
p = r[i]
if p:
z = fracgcd(-p,q)
cs = constants[i][1]
if isinstance(z, int_types):
if abs(z) == 1: t = cs
else: t = '%s**%s' % (cs, abs(z))
([num,den][z<0]).append(t)
else:
t = '%s**(%s/%s)' % (cs, abs(z[0]), z[1])
([num,den][z[0]<0]).append(t)
num = '*'.join(num)
den = '*'.join(den)
if num and den: return "(%s)/(%s)" % (num, den)
if num: return num
if den: return "1/(%s)" % den
def quadraticstring(ctx,t,a,b,c):
if c < 0:
a,b,c = -a,-b,-c
u1 = (-b+ctx.sqrt(b**2-4*a*c))/(2*c)
u2 = (-b-ctx.sqrt(b**2-4*a*c))/(2*c)
if abs(u1-t) < abs(u2-t):
if b: s = '((%s+sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
else: s = '(sqrt(%s)/%s)' % (-4*a*c,2*c)
else:
if b: s = '((%s-sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
else: s = '(-sqrt(%s)/%s)' % (-4*a*c,2*c)
return s
# Transformation y = f(x,c), with inverse function x = f(y,c)
# The third entry indicates whether the transformation is
# redundant when c = 1
transforms = [
(lambda ctx,x,c: x*c, '$y/$c', 0),
(lambda ctx,x,c: x/c, '$c*$y', 1),
(lambda ctx,x,c: c/x, '$c/$y', 0),
(lambda ctx,x,c: (x*c)**2, 'sqrt($y)/$c', 0),
(lambda ctx,x,c: (x/c)**2, '$c*sqrt($y)', 1),
(lambda ctx,x,c: (c/x)**2, '$c/sqrt($y)', 0),
(lambda ctx,x,c: c*x**2, 'sqrt($y)/sqrt($c)', 1),
(lambda ctx,x,c: x**2/c, 'sqrt($c)*sqrt($y)', 1),
(lambda ctx,x,c: c/x**2, 'sqrt($c)/sqrt($y)', 1),
(lambda ctx,x,c: ctx.sqrt(x*c), '$y**2/$c', 0),
(lambda ctx,x,c: ctx.sqrt(x/c), '$c*$y**2', 1),
(lambda ctx,x,c: ctx.sqrt(c/x), '$c/$y**2', 0),
(lambda ctx,x,c: c*ctx.sqrt(x), '$y**2/$c**2', 1),
(lambda ctx,x,c: ctx.sqrt(x)/c, '$c**2*$y**2', 1),
(lambda ctx,x,c: c/ctx.sqrt(x), '$c**2/$y**2', 1),
(lambda ctx,x,c: ctx.exp(x*c), 'log($y)/$c', 0),
(lambda ctx,x,c: ctx.exp(x/c), '$c*log($y)', 1),
(lambda ctx,x,c: ctx.exp(c/x), '$c/log($y)', 0),
(lambda ctx,x,c: c*ctx.exp(x), 'log($y/$c)', 1),
(lambda ctx,x,c: ctx.exp(x)/c, 'log($c*$y)', 1),
(lambda ctx,x,c: c/ctx.exp(x), 'log($c/$y)', 0),
(lambda ctx,x,c: ctx.ln(x*c), 'exp($y)/$c', 0),
(lambda ctx,x,c: ctx.ln(x/c), '$c*exp($y)', 1),
(lambda ctx,x,c: ctx.ln(c/x), '$c/exp($y)', 0),
(lambda ctx,x,c: c*ctx.ln(x), 'exp($y/$c)', 1),
(lambda ctx,x,c: ctx.ln(x)/c, 'exp($c*$y)', 1),
(lambda ctx,x,c: c/ctx.ln(x), 'exp($c/$y)', 0),
]
def identify(ctx, x, constants=[], tol=None, maxcoeff=1000, full=False,
verbose=False):
"""
Given a real number `x`, ``identify(x)`` attempts to find an exact
formula for `x`. This formula is returned as a string. If no match
is found, ``None`` is returned. With ``full=True``, a list of
matching formulas is returned.
As a simple example, :func:`identify` will find an algebraic
formula for the golden ratio::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> identify(phi)
'((1+sqrt(5))/2)'
:func:`identify` can identify simple algebraic numbers and simple
combinations of given base constants, as well as certain basic
transformations thereof. More specifically, :func:`identify`
looks for the following:
1. Fractions
2. Quadratic algebraic numbers
3. Rational linear combinations of the base constants
4. Any of the above after first transforming `x` into `f(x)` where
`f(x)` is `1/x`, `\sqrt x`, `x^2`, `\log x` or `\exp x`, either
directly or with `x` or `f(x)` multiplied or divided by one of
the base constants
5. Products of fractional powers of the base constants and
small integers
Base constants can be given as a list of strings representing mpmath
expressions (:func:`identify` will ``eval`` the strings to numerical
values and use the original strings for the output), or as a dict of
formula:value pairs.
In order not to produce spurious results, :func:`identify` should
be used with high precision; preferrably 50 digits or more.
**Examples**
Simple identifications can be performed safely at standard
precision. Here the default recognition of rational, algebraic,
and exp/log of algebraic numbers is demonstrated::
>>> mp.dps = 15
>>> identify(0.22222222222222222)
'(2/9)'
>>> identify(1.9662210973805663)
'sqrt(((24+sqrt(48))/8))'
>>> identify(4.1132503787829275)
'exp((sqrt(8)/2))'
>>> identify(0.881373587019543)
'log(((2+sqrt(8))/2))'
By default, :func:`identify` does not recognize `\pi`. At standard
precision it finds a not too useful approximation. At slightly
increased precision, this approximation is no longer accurate
enough and :func:`identify` more correctly returns ``None``::
>>> identify(pi)
'(2**(176/117)*3**(20/117)*5**(35/39))/(7**(92/117))'
>>> mp.dps = 30
>>> identify(pi)
>>>
Numbers such as `\pi`, and simple combinations of user-defined
constants, can be identified if they are provided explicitly::
>>> identify(3*pi-2*e, ['pi', 'e'])
'(3*pi + (-2)*e)'
Here is an example using a dict of constants. Note that the
constants need not be "atomic"; :func:`identify` can just
as well express the given number in terms of expressions
given by formulas::
>>> identify(pi+e, {'a':pi+2, 'b':2*e})
'((-2) + 1*a + (1/2)*b)'
Next, we attempt some identifications with a set of base constants.
It is necessary to increase the precision a bit.
>>> mp.dps = 50
>>> base = ['sqrt(2)','pi','log(2)']
>>> identify(0.25, base)
'(1/4)'
>>> identify(3*pi + 2*sqrt(2) + 5*log(2)/7, base)
'(2*sqrt(2) + 3*pi + (5/7)*log(2))'
>>> identify(exp(pi+2), base)
'exp((2 + 1*pi))'
>>> identify(1/(3+sqrt(2)), base)
'((3/7) + (-1/7)*sqrt(2))'
>>> identify(sqrt(2)/(3*pi+4), base)
'sqrt(2)/(4 + 3*pi)'
>>> identify(5**(mpf(1)/3)*pi*log(2)**2, base)
'5**(1/3)*pi*log(2)**2'
An example of an erroneous solution being found when too low
precision is used::
>>> mp.dps = 15
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
'((11/25) + (-158/75)*pi + (76/75)*e + (44/15)*sqrt(2))'
>>> mp.dps = 50
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
'1/(3*pi + (-4)*e + 2*sqrt(2))'
**Finding approximate solutions**
The tolerance ``tol`` defaults to 3/4 of the working precision.
Lowering the tolerance is useful for finding approximate matches.
We can for example try to generate approximations for pi::
>>> mp.dps = 15
>>> identify(pi, tol=1e-2)
'(22/7)'
>>> identify(pi, tol=1e-3)
'(355/113)'
>>> identify(pi, tol=1e-10)
'(5**(339/269))/(2**(64/269)*3**(13/269)*7**(92/269))'
With ``full=True``, and by supplying a few base constants,
``identify`` can generate almost endless lists of approximations
for any number (the output below has been truncated to show only
the first few)::
>>> for p in identify(pi, ['e', 'catalan'], tol=1e-5, full=True):
... print p
... # doctest: +ELLIPSIS
e/log((6 + (-4/3)*e))
(3**3*5*e*catalan**2)/(2*7**2)
sqrt(((-13) + 1*e + 22*catalan))
log(((-6) + 24*e + 4*catalan)/e)
exp(catalan*((-1/5) + (8/15)*e))
catalan*(6 + (-6)*e + 15*catalan)
sqrt((5 + 26*e + (-3)*catalan))/e
e*sqrt(((-27) + 2*e + 25*catalan))
log(((-1) + (-11)*e + 59*catalan))
((3/20) + (21/20)*e + (3/20)*catalan)
...
The numerical values are roughly as close to pi as permitted by the
specified tolerance:
>>> e/log(6-4*e/3)
3.14157719846001
>>> 135*e*catalan**2/98
3.14166950419369
>>> sqrt(e-13+22*catalan)
3.14158000062992
>>> log(24*e-6+4*catalan)-1
3.14158791577159
**Symbolic processing**
The output formula can be evaluated as a Python expression.
Note however that if fractions (like '2/3') are present in
the formula, Python's :func:`eval()` may erroneously perform
integer division. Note also that the output is not necessarily
in the algebraically simplest form::
>>> identify(sqrt(2))
'(sqrt(8)/2)'
As a solution to both problems, consider using SymPy's
:func:`sympify` to convert the formula into a symbolic expression.
SymPy can be used to pretty-print or further simplify the formula
symbolically::
>>> from sympy import sympify
>>> sympify(identify(sqrt(2)))
2**(1/2)
Sometimes :func:`identify` can simplify an expression further than
a symbolic algorithm::
>>> from sympy import simplify
>>> x = sympify('-1/(-3/2+(1/2)*5**(1/2))*(3/2-1/2*5**(1/2))**(1/2)')
>>> x
(3/2 - 5**(1/2)/2)**(-1/2)
>>> x = simplify(x)
>>> x
2/(6 - 2*5**(1/2))**(1/2)
>>> mp.dps = 30
>>> x = sympify(identify(x.evalf(30)))
>>> x
1/2 + 5**(1/2)/2
(In fact, this functionality is available directly in SymPy as the
function :func:`nsimplify`, which is essentially a wrapper for
:func:`identify`.)
**Miscellaneous issues and limitations**
The input `x` must be a real number. All base constants must be
positive real numbers and must not be rationals or rational linear
combinations of each other.
The worst-case computation time grows quickly with the number of
base constants. Already with 3 or 4 base constants,
:func:`identify` may require several seconds to finish. To search
for relations among a large number of constants, you should
consider using :func:`pslq` directly.
The extended transformations are applied to x, not the constants
separately. As a result, ``identify`` will for example be able to
recognize ``exp(2*pi+3)`` with ``pi`` given as a base constant, but
not ``2*exp(pi)+3``. It will be able to recognize the latter if
``exp(pi)`` is given explicitly as a base constant.
"""
solutions = []
def addsolution(s):
if verbose: print "Found: ", s
solutions.append(s)
x = ctx.mpf(x)
# Further along, x will be assumed positive
if x == 0:
if full: return ['0']
else: return '0'
if x < 0:
sol = ctx.identify(-x, constants, tol, maxcoeff, full, verbose)
if sol is None:
return sol
if full:
return ["-(%s)"%s for s in sol]
else:
return "-(%s)" % sol
if tol:
tol = ctx.mpf(tol)
else:
tol = ctx.eps**0.7
M = maxcoeff
if constants:
if isinstance(constants, dict):
constants = [(ctx.mpf(v), name) for (name, v) in constants.items()]
else:
namespace = dict((name, getattr(ctx,name)) for name in dir(ctx))
constants = [(eval(p, namespace), p) for p in constants]
else:
constants = []
# We always want to find at least rational terms
if 1 not in [value for (name, value) in constants]:
constants = [(ctx.mpf(1), '1')] + constants
# PSLQ with simple algebraic and functional transformations
for ft, ftn, red in transforms:
for c, cn in constants:
if red and cn == '1':
continue
t = ft(ctx,x,c)
# Prevent exponential transforms from wreaking havoc
if abs(t) > M**2 or abs(t) < tol:
continue
# Linear combination of base constants
r = ctx.pslq([t] + [a[0] for a in constants], tol, M)
s = None
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
s = pslqstring(r, constants)
# Quadratic algebraic numbers
else:
q = ctx.pslq([ctx.one, t, t**2], tol, M)
if q is not None and len(q) == 3 and q[2]:
aa, bb, cc = q
if max(abs(aa),abs(bb),abs(cc)) <= M:
s = quadraticstring(ctx,t,aa,bb,cc)
if s:
if cn == '1' and ('/$c' in ftn):
s = ftn.replace('$y', s).replace('/$c', '')
else:
s = ftn.replace('$y', s).replace('$c', cn)
addsolution(s)
if not full: return solutions[0]
if verbose:
print "."
# Check for a direct multiplicative formula
if x != 1:
# Allow fractional powers of fractions
ilogs = [2,3,5,7]
# Watch out for existing fractional powers of fractions
logs = []
for a, s in constants:
if not sum(bool(ctx.findpoly(ctx.ln(a)/ctx.ln(i),1)) for i in ilogs):
logs.append((ctx.ln(a), s))
logs = [(ctx.ln(i),str(i)) for i in ilogs] + logs
r = ctx.pslq([ctx.ln(x)] + [a[0] for a in logs], tol, M)
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
addsolution(prodstring(r, logs))
if not full: return solutions[0]
if full:
return sorted(solutions, key=len)
else:
return None
IdentificationMethods.pslq = pslq
IdentificationMethods.findpoly = findpoly
IdentificationMethods.identify = identify
if __name__ == '__main__':
import doctest
doctest.testmod()

View File

@ -1,64 +0,0 @@
from libmpf import (prec_to_dps, dps_to_prec, repr_dps,
round_down, round_up, round_floor, round_ceiling, round_nearest,
to_pickable, from_pickable, ComplexResult,
fzero, fnzero, fone, fnone, ftwo, ften, fhalf, fnan, finf, fninf,
math_float_inf, round_int, normalize, normalize1,
from_man_exp, from_int, to_man_exp, to_int, mpf_ceil, mpf_floor,
from_float, to_float, from_rational, to_rational, to_fixed,
mpf_rand, mpf_eq, mpf_hash, mpf_cmp, mpf_lt, mpf_le, mpf_gt, mpf_ge,
mpf_pos, mpf_neg, mpf_abs, mpf_sign, mpf_add, mpf_sub, mpf_sum,
mpf_mul, mpf_mul_int, mpf_shift, mpf_frexp,
mpf_div, mpf_rdiv_int, mpf_mod, mpf_pow_int,
mpf_perturb,
to_digits_exp, to_str, str_to_man_exp, from_str, from_bstr, to_bstr,
mpf_sqrt, mpf_hypot)
from libmpc import (mpc_one, mpc_zero, mpc_two, mpc_half,
mpc_is_inf, mpc_is_infnan, mpc_to_str, mpc_to_complex, mpc_hash,
mpc_conjugate, mpc_is_nonzero, mpc_add, mpc_add_mpf,
mpc_sub, mpc_sub_mpf, mpc_pos, mpc_neg, mpc_shift, mpc_abs,
mpc_arg, mpc_floor, mpc_ceil, mpc_mul, mpc_square,
mpc_mul_mpf, mpc_mul_imag_mpf, mpc_mul_int,
mpc_div, mpc_div_mpf, mpc_reciprocal, mpc_mpf_div,
complex_int_pow, mpc_pow, mpc_pow_mpf, mpc_pow_int,
mpc_sqrt, mpc_nthroot, mpc_cbrt, mpc_exp, mpc_log, mpc_cos, mpc_sin,
mpc_tan, mpc_cos_pi, mpc_sin_pi, mpc_cosh, mpc_sinh, mpc_tanh,
mpc_atan, mpc_acos, mpc_asin, mpc_asinh, mpc_acosh, mpc_atanh,
mpc_fibonacci, mpf_expj, mpf_expjpi, mpc_expj, mpc_expjpi)
from libelefun import (ln2_fixed, mpf_ln2, ln10_fixed, mpf_ln10,
pi_fixed, mpf_pi, e_fixed, mpf_e, phi_fixed, mpf_phi,
degree_fixed, mpf_degree,
mpf_pow, mpf_nthroot, mpf_cbrt, log_int_fixed, agm_fixed,
mpf_log, mpf_log_hypot, mpf_exp, mpf_cos_sin, mpf_cos, mpf_sin, mpf_tan,
mpf_cos_sin_pi, mpf_cos_pi, mpf_sin_pi, mpf_cosh_sinh,
mpf_cosh, mpf_sinh, mpf_tanh, mpf_atan, mpf_atan2, mpf_asin,
mpf_acos, mpf_asinh, mpf_acosh, mpf_atanh, mpf_fibonacci)
from libhyper import (NoConvergence, make_hyp_summator,
mpf_erf, mpf_erfc, mpf_ei, mpc_ei, mpf_e1, mpc_e1, mpf_expint,
mpf_ci_si, mpf_ci, mpf_si, mpc_ci, mpc_si, mpf_besseljn,
mpc_besseljn, mpf_agm, mpf_agm1, mpc_agm, mpc_agm1,
mpf_ellipk, mpc_ellipk, mpf_ellipe, mpc_ellipe)
from gammazeta import (catalan_fixed, mpf_catalan,
khinchin_fixed, mpf_khinchin, glaisher_fixed, mpf_glaisher,
apery_fixed, mpf_apery, euler_fixed, mpf_euler, mertens_fixed,
mpf_mertens, twinprime_fixed, mpf_twinprime,
mpf_bernoulli, bernfrac, mpf_gamma_int,
mpf_factorial, mpc_factorial, mpf_gamma, mpc_gamma,
mpf_harmonic, mpc_harmonic, mpf_psi0, mpc_psi0,
mpf_psi, mpc_psi, mpf_zeta_int, mpf_zeta, mpc_zeta,
mpf_altzeta, mpc_altzeta, mpf_zetasum, mpc_zetasum)
from libmpi import (mpi_str, mpi_add, mpi_sub, mpi_delta, mpi_mid,
mpi_pos, mpi_neg, mpi_abs, mpi_mul, mpi_div, mpi_exp,
mpi_log, mpi_sqrt, mpi_pow_int, mpi_pow, mpi_cos_sin,
mpi_cos, mpi_sin, mpi_tan, mpi_cot)
from libintmath import (trailing, bitcount, numeral, bin_to_radix,
isqrt, isqrt_small, isqrt_fast, sqrt_fixed, sqrtrem, ifib, ifac,
list_primes, moebius, gcd, eulernum)
from backend import (gmpy, sage, BACKEND, STRICT, MPZ, MPZ_TYPE,
MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_THREE, MPZ_FIVE, int_types)

View File

@ -1,64 +0,0 @@
import os
import sys
#----------------------------------------------------------------------------#
# Support GMPY for high-speed large integer arithmetic. #
# #
# To allow an external module to handle arithmetic, we need to make sure #
# that all high-precision variables are declared of the correct type. MPZ #
# is the constructor for the high-precision type. It defaults to Python's #
# long type but can be assinged another type, typically gmpy.mpz. #
# #
# MPZ must be used for the mantissa component of an mpf and must be used #
# for internal fixed-point operations. #
# #
# Side-effects #
# 1) "is" cannot be used to test for special values. Must use "==". #
# 2) There are bugs in GMPY prior to v1.02 so we must use v1.03 or later. #
#----------------------------------------------------------------------------#
# So we can import it from this module
gmpy = None
sage = None
sage_utils = None
BACKEND = 'python'
MPZ = long
if 'MPMATH_NOGMPY' not in os.environ:
try:
import gmpy
if gmpy.version() >= '1.03':
BACKEND = 'gmpy'
MPZ = gmpy.mpz
except:
pass
if 'MPMATH_NOSAGE' not in os.environ:
try:
import sage.all
import sage.libs.mpmath.utils as _sage_utils
sage = sage.all
sage_utils = _sage_utils
BACKEND = 'sage'
MPZ = sage.Integer
except:
pass
if 'MPMATH_STRICT' in os.environ:
STRICT = True
else:
STRICT = False
MPZ_TYPE = type(MPZ(0))
MPZ_ZERO = MPZ(0)
MPZ_ONE = MPZ(1)
MPZ_TWO = MPZ(2)
MPZ_THREE = MPZ(3)
MPZ_FIVE = MPZ(5)
if BACKEND == 'python':
int_types = (int, long)
else:
int_types = (int, long, MPZ_TYPE)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,461 +0,0 @@
"""
Utility functions for integer math.
TODO: rename, cleanup, perhaps move the gmpy wrapper code
here from settings.py
"""
import math
from bisect import bisect
from backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO
def giant_steps(start, target, n=2):
"""
Return a list of integers ~=
[start, n*start, ..., target/n^2, target/n, target]
but conservatively rounded so that the quotient between two
successive elements is actually slightly less than n.
With n = 2, this describes suitable precision steps for a
quadratically convergent algorithm such as Newton's method;
with n = 3 steps for cubic convergence (Halley's method), etc.
>>> giant_steps(50,1000)
[66, 128, 253, 502, 1000]
>>> giant_steps(50,1000,4)
[65, 252, 1000]
"""
L = [target]
while L[-1] > start*n:
L = L + [L[-1]//n + 2]
return L[::-1]
def rshift(x, n):
"""For an integer x, calculate x >> n with the fastest (floor)
rounding. Unlike the plain Python expression (x >> n), n is
allowed to be negative, in which case a left shift is performed."""
if n >= 0: return x >> n
else: return x << (-n)
def lshift(x, n):
"""For an integer x, calculate x << n. Unlike the plain Python
expression (x << n), n is allowed to be negative, in which case a
right shift with default (floor) rounding is performed."""
if n >= 0: return x << n
else: return x >> (-n)
if BACKEND == 'sage':
import operator
rshift = operator.rshift
lshift = operator.lshift
def python_trailing(n):
"""Count the number of trailing zero bits in abs(n)."""
if not n:
return 0
t = 0
while not n & 1:
n >>= 1
t += 1
return t
def gmpy_trailing(n):
"""Count the number of trailing zero bits in abs(n) using gmpy."""
if n: return MPZ(n).scan1()
else: return 0
# Small powers of 2
powers = [1<<_ for _ in range(300)]
def python_bitcount(n):
"""Calculate bit size of the nonnegative integer n."""
bc = bisect(powers, n)
if bc != 300:
return bc
bc = int(math.log(n, 2)) - 4
return bc + bctable[n>>bc]
def gmpy_bitcount(n):
"""Calculate bit size of the nonnegative integer n."""
if n: return MPZ(n).numdigits(2)
else: return 0
#def sage_bitcount(n):
# if n: return MPZ(n).nbits()
# else: return 0
def sage_trailing(n):
return MPZ(n).trailing_zero_bits()
if BACKEND == 'gmpy':
bitcount = gmpy_bitcount
trailing = gmpy_trailing
elif BACKEND == 'sage':
sage_bitcount = sage_utils.bitcount
bitcount = sage_bitcount
trailing = sage_trailing
else:
bitcount = python_bitcount
trailing = python_trailing
if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy):
bitcount = gmpy.bit_length
# Used to avoid slow function calls as far as possible
trailtable = map(trailing, range(256))
bctable = map(bitcount, range(1024))
# TODO: speed up for bases 2, 4, 8, 16, ...
def bin_to_radix(x, xbits, base, bdigits):
"""Changes radix of a fixed-point number; i.e., converts
x * 2**xbits to floor(x * 10**bdigits)."""
return x * (MPZ(base)**bdigits) >> xbits
stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'
def small_numeral(n, base=10, digits=stddigits):
"""Return the string numeral of a positive integer in an arbitrary
base. Most efficient for small input."""
if base == 10:
return str(n)
digs = []
while n:
n, digit = divmod(n, base)
digs.append(digits[digit])
return "".join(digs[::-1])
def numeral_python(n, base=10, size=0, digits=stddigits):
"""Represent the integer n as a string of digits in the given base.
Recursive division is used to make this function about 3x faster
than Python's str() for converting integers to decimal strings.
The 'size' parameters specifies the number of digits in n; this
number is only used to determine splitting points and need not be
exact."""
if n <= 0:
if not n:
return "0"
return "-" + numeral(-n, base, size, digits)
# Fast enough to do directly
if size < 250:
return small_numeral(n, base, digits)
# Divide in half
half = (size // 2) + (size & 1)
A, B = divmod(n, base**half)
ad = numeral(A, base, half, digits)
bd = numeral(B, base, half, digits).rjust(half, "0")
return ad + bd
def numeral_gmpy(n, base=10, size=0, digits=stddigits):
"""Represent the integer n as a string of digits in the given base.
Recursive division is used to make this function about 3x faster
than Python's str() for converting integers to decimal strings.
The 'size' parameters specifies the number of digits in n; this
number is only used to determine splitting points and need not be
exact."""
if n < 0:
return "-" + numeral(-n, base, size, digits)
# gmpy.digits() may cause a segmentation fault when trying to convert
# extremely large values to a string. The size limit may need to be
# adjusted on some platforms, but 1500000 works on Windows and Linux.
if size < 1500000:
return gmpy.digits(n, base)
# Divide in half
half = (size // 2) + (size & 1)
A, B = divmod(n, MPZ(base)**half)
ad = numeral(A, base, half, digits)
bd = numeral(B, base, half, digits).rjust(half, "0")
return ad + bd
if BACKEND == "gmpy":
numeral = numeral_gmpy
else:
numeral = numeral_python
_1_800 = 1<<800
_1_600 = 1<<600
_1_400 = 1<<400
_1_200 = 1<<200
_1_100 = 1<<100
_1_50 = 1<<50
def isqrt_small_python(x):
"""
Correctly (floor) rounded integer square root, using
division. Fast up to ~200 digits.
"""
if not x:
return x
if x < _1_800:
# Exact with IEEE double precision arithmetic
if x < _1_50:
return int(x**0.5)
# Initial estimate can be any integer >= the true root; round up
r = int(x**0.5 * 1.00000000000001) + 1
else:
bc = bitcount(x)
n = bc//2
r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up
# The following iteration now precisely computes floor(sqrt(x))
# See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
# Perspective"
while 1:
y = (r+x//r)>>1
if y >= r:
return r
r = y
def isqrt_fast_python(x):
"""
Fast approximate integer square root, computed using division-free
Newton iteration for large x. For random integers the result is almost
always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
0.1% probability. If x is very close to an exact square, the answer is
1 ulp wrong with high probability.
With 0 guard bits, the largest error over a set of 10^5 random
inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
almost certainly guarantees a max 1 ulp error.
"""
# Use direct division-based iteration if sqrt(x) < 2^400
# Assume floating-point square root accurate to within 1 ulp, then:
# 0 Newton iterations good to 52 bits
# 1 Newton iterations good to 104 bits
# 2 Newton iterations good to 208 bits
# 3 Newton iterations good to 416 bits
if x < _1_800:
y = int(x**0.5)
if x >= _1_100:
y = (y + x//y) >> 1
if x >= _1_200:
y = (y + x//y) >> 1
if x >= _1_400:
y = (y + x//y) >> 1
return y
bc = bitcount(x)
guard_bits = 10
x <<= 2*guard_bits
bc += 2*guard_bits
bc += (bc&1)
hbc = bc//2
startprec = min(50, hbc)
# Newton iteration for 1/sqrt(x), with floating-point starting value
r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5)
pp = startprec
for p in giant_steps(startprec, hbc):
# r**2, scaled from real size 2**(-bc) to 2**p
r2 = (r*r) >> (2*pp - p)
# x*r**2, scaled from real size ~1.0 to 2**p
xr2 = ((x >> (bc-p)) * r2) >> p
# New value of r, scaled from real size 2**(-bc/2) to 2**p
r = (r * ((3<<p) - xr2)) >> (pp+1)
pp = p
# (1/sqrt(x))*x = sqrt(x)
return (r*(x>>hbc)) >> (p+guard_bits)
def sqrtrem_python(x):
"""Correctly rounded integer (floor) square root with remainder."""
# to check cutoff:
# plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
if x < _1_600:
y = isqrt_small_python(x)
return y, x - y*y
y = isqrt_fast_python(x) + 1
rem = x - y*y
# Correct remainder
while rem < 0:
y -= 1
rem += (1+2*y)
else:
if rem:
while rem > 2*(1+y):
y += 1
rem -= (1+2*y)
return y, rem
def isqrt_python(x):
"""Integer square root with correct (floor) rounding."""
return sqrtrem_python(x)[0]
def sqrt_fixed(x, prec):
return isqrt_fast(x<<prec)
sqrt_fixed2 = sqrt_fixed
if BACKEND == 'gmpy':
isqrt_small = isqrt_fast = isqrt = gmpy.sqrt
sqrtrem = gmpy.sqrtrem
elif BACKEND == 'sage':
isqrt_small = isqrt_fast = isqrt = lambda n: MPZ(n).isqrt()
sqrtrem = lambda n: MPZ(n).sqrtrem()
else:
isqrt_small = isqrt_small_python
isqrt_fast = isqrt_fast_python
isqrt = isqrt_python
sqrtrem = sqrtrem_python
def ifib(n, _cache={}):
"""Computes the nth Fibonacci number as an integer, for
integer n."""
if n < 0:
return (-1)**(-n+1) * ifib(-n)
if n in _cache:
return _cache[n]
m = n
# Use Dijkstra's logarithmic algorithm
# The following implementation is basically equivalent to
# http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
while n:
if n & 1:
aq = a*q
a, b = b*q+aq+a*p, b*p+aq
n -= 1
else:
qq = q*q
p, q = p*p+qq, qq+2*p*q
n >>= 1
if m < 250:
_cache[m] = b
return b
MAX_FACTORIAL_CACHE = 1000
def ifac(n, memo={0:1, 1:1}):
"""Return n factorial (for integers n >= 0 only)."""
f = memo.get(n)
if f:
return f
k = len(memo)
p = memo[k-1]
MAX = MAX_FACTORIAL_CACHE
while k <= n:
p *= k
if k <= MAX:
memo[k] = p
k += 1
return p
if BACKEND == 'gmpy':
ifac = gmpy.fac
elif BACKEND == 'sage':
ifac = lambda n: int(sage.factorial(n))
ifib = sage.fibonacci
def list_primes(n):
n = n + 1
sieve = range(n)
sieve[:2] = [0, 0]
for i in xrange(2, int(n**0.5)+1):
if sieve[i]:
for j in xrange(i**2, n, i):
sieve[j] = 0
return [p for p in sieve if p]
if BACKEND == 'sage':
def list_primes(n):
return list(sage.primes(n+1))
def moebius(n):
"""
Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
is a product of `k` distinct primes and `mu(n) = 0` otherwise.
TODO: speed up using factorization
"""
n = abs(int(n))
if n < 2:
return n
factors = []
for p in xrange(2, n+1):
if not (n % p):
if not (n % p**2):
return 0
if not sum(p % f for f in factors):
factors.append(p)
return (-1)**len(factors)
def gcd(*args):
a = 0
for b in args:
if a:
while b:
a, b = b, a % b
else:
a = b
return a
# Comment by Juan Arias de Reyna:
#
# I learn this method to compute EulerE[2n] from van de Lune.
#
# We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
#
# where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfies
#
# a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0
#
# a(n,j) = a(n-1,j) when n+j is even
# a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
#
#
# But we can use only one array unidimensional a(j) since to compute
# a(n,j) we only need to know a(n-1,k) where k and j are of different parity
# and we have not to conserve the used values.
#
# We cached up the values of Euler numbers to sufficiently high order.
#
# Important Observation: If we pretend to use the numbers
# EulerE[1], EulerE[2], ... , EulerE[n]
# it is convenient to compute first EulerE[n], since the algorithm
# computes first all
# the previous ones, and keeps them in the CACHE
MAX_EULER_CACHE = 500
def eulernum(m, _cache={0:MPZ_ONE}):
r"""
Computes the Euler numbers `E(n)`, which can be defined as
coefficients of the Taylor expansion of `1/cosh x`:
.. math ::
\frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n
Example::
>>> [int(eulernum(n)) for n in range(11)]
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
>>> [int(eulernum(n)) for n in range(11)] # test cache
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
"""
# for odd m > 1, the Euler numbers are zero
if m & 1:
return MPZ_ZERO
f = _cache.get(m)
if f:
return f
MAX = MAX_EULER_CACHE
n = m
a = map(MPZ, [0,0,1,0,0,0])
for n in range(1, m+1):
for j in range(n+1, -1, -2):
a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
a.append(0)
suma = 0
for k in range(n+1, -1, -2):
suma += a[k+1]
if n <= MAX:
_cache[n] = ((-1)**(n//2))*(suma // 2**n)
if n == m:
return ((-1)**(n//2))*suma // 2**n

View File

@ -1,754 +0,0 @@
"""
Low-level functions for complex arithmetic.
"""
from backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_TWO
from libmpf import (\
round_floor, round_ceiling, round_down, round_up,
round_nearest, round_fast, bitcount,
bctable, normalize, normalize1, reciprocal_rnd, rshift, lshift, giant_steps,
negative_rnd,
to_str, to_fixed, from_man_exp, from_float, to_float, from_int, to_int,
fzero, fone, ftwo, fhalf, finf, fninf, fnan, fnone,
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul,
mpf_div, mpf_mul_int, mpf_shift, mpf_sqrt, mpf_hypot,
mpf_rdiv_int, mpf_floor, mpf_ceil,
mpf_sign,
ComplexResult
)
from libelefun import (\
mpf_pi, mpf_exp, mpf_log, mpf_cos_sin, mpf_cosh_sinh, mpf_tan, mpf_pow_int,
mpf_log_hypot,
mpf_cos_sin_pi, mpf_phi,
mpf_atan, mpf_atan2, mpf_cosh, mpf_sinh, mpf_tanh,
mpf_asin, mpf_acos, mpf_acosh, mpf_nthroot, mpf_fibonacci
)
# An mpc value is a (real, imag) tuple
mpc_one = fone, fzero
mpc_zero = fzero, fzero
mpc_two = ftwo, fzero
mpc_half = (fhalf, fzero)
_infs = (finf, fninf)
_infs_nan = (finf, fninf, fnan)
def mpc_is_inf(z):
"""Check if either real or imaginary part is infinite"""
re, im = z
if re in _infs: return True
if im in _infs: return True
return False
def mpc_is_infnan(z):
"""Check if either real or imaginary part is infinite or nan"""
re, im = z
if re in _infs_nan: return True
if im in _infs_nan: return True
return False
def mpc_to_str(z, dps, **kwargs):
re, im = z
rs = to_str(re, dps)
if im[0]:
return rs + " - " + to_str(mpf_neg(im), dps, **kwargs) + "j"
else:
return rs + " + " + to_str(im, dps, **kwargs) + "j"
def mpc_to_complex(z, strict=False):
re, im = z
return complex(to_float(re, strict), to_float(im, strict))
def mpc_hash(z):
try:
return hash(mpc_to_complex(z, strict=True))
except OverflowError:
return hash(z)
def mpc_conjugate(z, prec, rnd=round_fast):
re, im = z
return re, mpf_neg(im, prec, rnd)
def mpc_is_nonzero(z):
return z != mpc_zero
def mpc_add(z, w, prec, rnd=round_fast):
a, b = z
c, d = w
return mpf_add(a, c, prec, rnd), mpf_add(b, d, prec, rnd)
def mpc_add_mpf(z, x, prec, rnd=round_fast):
a, b = z
return mpf_add(a, x, prec, rnd), b
def mpc_sub(z, w, prec, rnd=round_fast):
a, b = z
c, d = w
return mpf_sub(a, c, prec, rnd), mpf_sub(b, d, prec, rnd)
def mpc_sub_mpf(z, p, prec, rnd=round_fast):
a, b = z
return mpf_sub(a, p, prec, rnd), b
def mpc_pos(z, prec, rnd=round_fast):
a, b = z
return mpf_pos(a, prec, rnd), mpf_pos(b, prec, rnd)
def mpc_neg(z, prec=None, rnd=round_fast):
a, b = z
return mpf_neg(a, prec, rnd), mpf_neg(b, prec, rnd)
def mpc_shift(z, n):
a, b = z
return mpf_shift(a, n), mpf_shift(b, n)
def mpc_abs(z, prec, rnd=round_fast):
"""Absolute value of a complex number, |a+bi|.
Returns an mpf value."""
a, b = z
return mpf_hypot(a, b, prec, rnd)
def mpc_arg(z, prec, rnd=round_fast):
"""Argument of a complex number. Returns an mpf value."""
a, b = z
return mpf_atan2(b, a, prec, rnd)
def mpc_floor(z, prec, rnd=round_fast):
a, b = z
return mpf_floor(a, prec, rnd), mpf_floor(b, prec, rnd)
def mpc_ceil(z, prec, rnd=round_fast):
a, b = z
return mpf_ceil(a, prec, rnd), mpf_ceil(b, prec, rnd)
def mpc_mul(z, w, prec, rnd=round_fast):
"""
Complex multiplication.
Returns the real and imaginary part of (a+bi)*(c+di), rounded to
the specified precision. The rounding mode applies to the real and
imaginary parts separately.
"""
a, b = z
c, d = w
p = mpf_mul(a, c)
q = mpf_mul(b, d)
r = mpf_mul(a, d)
s = mpf_mul(b, c)
re = mpf_sub(p, q, prec, rnd)
im = mpf_add(r, s, prec, rnd)
return re, im
def mpc_square(z, prec, rnd=round_fast):
# (a+b*I)**2 == a**2 - b**2 + 2*I*a*b
a, b = z
p = mpf_mul(a,a)
q = mpf_mul(b,b)
r = mpf_mul(a,b, prec, rnd)
re = mpf_sub(p, q, prec, rnd)
im = mpf_shift(r, 1)
return re, im
def mpc_mul_mpf(z, p, prec, rnd=round_fast):
a, b = z
re = mpf_mul(a, p, prec, rnd)
im = mpf_mul(b, p, prec, rnd)
return re, im
def mpc_mul_imag_mpf(z, x, prec, rnd=round_fast):
"""
Multiply the mpc value z by I*x where x is an mpf value.
"""
a, b = z
re = mpf_neg(mpf_mul(b, x, prec, rnd))
im = mpf_mul(a, x, prec, rnd)
return re, im
def mpc_mul_int(z, n, prec, rnd=round_fast):
a, b = z
re = mpf_mul_int(a, n, prec, rnd)
im = mpf_mul_int(b, n, prec, rnd)
return re, im
def mpc_div(z, w, prec, rnd=round_fast):
a, b = z
c, d = w
wp = prec + 10
# mag = c*c + d*d
mag = mpf_add(mpf_mul(c, c), mpf_mul(d, d), wp)
# (a*c+b*d)/mag, (b*c-a*d)/mag
t = mpf_add(mpf_mul(a,c), mpf_mul(b,d), wp)
u = mpf_sub(mpf_mul(b,c), mpf_mul(a,d), wp)
return mpf_div(t,mag,prec,rnd), mpf_div(u,mag,prec,rnd)
def mpc_div_mpf(z, p, prec, rnd=round_fast):
"""Calculate z/p where p is real"""
a, b = z
re = mpf_div(a, p, prec, rnd)
im = mpf_div(b, p, prec, rnd)
return re, im
def mpc_reciprocal(z, prec, rnd=round_fast):
"""Calculate 1/z efficiently"""
a, b = z
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b),prec+10)
re = mpf_div(a, m, prec, rnd)
im = mpf_neg(mpf_div(b, m, prec, rnd))
return re, im
def mpc_mpf_div(p, z, prec, rnd=round_fast):
"""Calculate p/z where p is real efficiently"""
a, b = z
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b), prec+10)
re = mpf_div(mpf_mul(a,p), m, prec, rnd)
im = mpf_div(mpf_neg(mpf_mul(b,p)), m, prec, rnd)
return re, im
def complex_int_pow(a, b, n):
"""Complex integer power: computes (a+b*I)**n exactly for
nonnegative n (a and b must be Python ints)."""
wre = 1
wim = 0
while n:
if n & 1:
wre, wim = wre*a - wim*b, wim*a + wre*b
n -= 1
a, b = a*a - b*b, 2*a*b
n //= 2
return wre, wim
def mpc_pow(z, w, prec, rnd=round_fast):
if w[1] == fzero:
return mpc_pow_mpf(z, w[0], prec, rnd)
return mpc_exp(mpc_mul(mpc_log(z, prec+10), w, prec+10), prec, rnd)
def mpc_pow_mpf(z, p, prec, rnd=round_fast):
psign, pman, pexp, pbc = p
if pexp >= 0:
return mpc_pow_int(z, (-1)**psign * (pman<<pexp), prec, rnd)
if pexp == -1:
sqrtz = mpc_sqrt(z, prec+10)
return mpc_pow_int(sqrtz, (-1)**psign * pman, prec, rnd)
return mpc_exp(mpc_mul_mpf(mpc_log(z, prec+10), p, prec+10), prec, rnd)
def mpc_pow_int(z, n, prec, rnd=round_fast):
a, b = z
if b == fzero:
return mpf_pow_int(a, n, prec, rnd), fzero
if a == fzero:
v = mpf_pow_int(b, n, prec, rnd)
n %= 4
if n == 0:
return v, fzero
elif n == 1:
return fzero, v
elif n == 2:
return mpf_neg(v), fzero
elif n == 3:
return fzero, mpf_neg(v)
if n == 0: return mpc_one
if n == 1: return mpc_pos(z, prec, rnd)
if n == 2: return mpc_square(z, prec, rnd)
if n == -1: return mpc_reciprocal(z, prec, rnd)
if n < 0: return mpc_reciprocal(mpc_pow_int(z, -n, prec+4), prec, rnd)
asign, aman, aexp, abc = a
bsign, bman, bexp, bbc = b
if asign: aman = -aman
if bsign: bman = -bman
de = aexp - bexp
abs_de = abs(de)
exact_size = n*(abs_de + max(abc, bbc))
if exact_size < 10000:
if de > 0:
aman <<= de
aexp = bexp
else:
bman <<= (-de)
bexp = aexp
re, im = complex_int_pow(aman, bman, n)
re = from_man_exp(re, int(n*aexp), prec, rnd)
im = from_man_exp(im, int(n*bexp), prec, rnd)
return re, im
return mpc_exp(mpc_mul_int(mpc_log(z, prec+10), n, prec+10), prec, rnd)
def mpc_sqrt(z, prec, rnd=round_fast):
"""Complex square root (principal branch).
We have sqrt(a+bi) = sqrt((r+a)/2) + b/sqrt(2*(r+a))*i where
r = abs(a+bi), when a+bi is not a negative real number."""
a, b = z
if b == fzero:
if a == fzero:
return (a, b)
# When a+bi is a negative real number, we get a real sqrt times i
if a[0]:
im = mpf_sqrt(mpf_neg(a), prec, rnd)
return (fzero, im)
else:
re = mpf_sqrt(a, prec, rnd)
return (re, fzero)
wp = prec+20
if not a[0]: # case a positive
t = mpf_add(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) + a
u = mpf_shift(t, -1) # u = t/2
re = mpf_sqrt(u, prec, rnd) # re = sqrt(u)
v = mpf_shift(t, 1) # v = 2*t
w = mpf_sqrt(v, wp) # w = sqrt(v)
im = mpf_div(b, w, prec, rnd) # im = b / w
else: # case a negative
t = mpf_sub(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) - a
u = mpf_shift(t, -1) # u = t/2
im = mpf_sqrt(u, prec, rnd) # im = sqrt(u)
v = mpf_shift(t, 1) # v = 2*t
w = mpf_sqrt(v, wp) # w = sqrt(v)
re = mpf_div(b, w, prec, rnd) # re = b/w
if b[0]:
re = mpf_neg(re)
im = mpf_neg(im)
return re, im
def mpc_nthroot_fixed(a, b, n, prec):
# a, b signed integers at fixed precision prec
start = 50
a1 = int(rshift(a, prec - n*start))
b1 = int(rshift(b, prec - n*start))
try:
r = (a1 + 1j * b1)**(1.0/n)
re = r.real
im = r.imag
re = MPZ(int(re))
im = MPZ(int(im))
except OverflowError:
a1 = from_int(a1, start)
b1 = from_int(b1, start)
fn = from_int(n)
nth = mpf_rdiv_int(1, fn, start)
re, im = mpc_pow((a1, b1), (nth, fzero), start)
re = to_int(re)
im = to_int(im)
extra = 10
prevp = start
extra1 = n
for p in giant_steps(start, prec+extra):
# this is slow for large n, unlike int_pow_fixed
re2, im2 = complex_int_pow(re, im, n-1)
re2 = rshift(re2, (n-1)*prevp - p - extra1)
im2 = rshift(im2, (n-1)*prevp - p - extra1)
r4 = (re2*re2 + im2*im2) >> (p + extra1)
ap = rshift(a, prec - p)
bp = rshift(b, prec - p)
rec = (ap * re2 + bp * im2) >> p
imc = (-ap * im2 + bp * re2) >> p
reb = (rec << p) // r4
imb = (imc << p) // r4
re = (reb + (n-1)*lshift(re, p-prevp))//n
im = (imb + (n-1)*lshift(im, p-prevp))//n
prevp = p
return re, im
def mpc_nthroot(z, n, prec, rnd=round_fast):
"""
Complex n-th root.
Use Newton method as in the real case when it is faster,
otherwise use z**(1/n)
"""
a, b = z
if a[0] == 0 and b == fzero:
re = mpf_nthroot(a, n, prec, rnd)
return (re, fzero)
if n < 2:
if n == 0:
return mpc_one
if n == 1:
return mpc_pos((a, b), prec, rnd)
if n == -1:
return mpc_div(mpc_one, (a, b), prec, rnd)
inverse = mpc_nthroot((a, b), -n, prec+5, reciprocal_rnd[rnd])
return mpc_div(mpc_one, inverse, prec, rnd)
if n <= 20:
prec2 = int(1.2 * (prec + 10))
asign, aman, aexp, abc = a
bsign, bman, bexp, bbc = b
pf = mpc_abs((a,b), prec)
if pf[-2] + pf[-1] > -10 and pf[-2] + pf[-1] < prec:
af = to_fixed(a, prec2)
bf = to_fixed(b, prec2)
re, im = mpc_nthroot_fixed(af, bf, n, prec2)
extra = 10
re = from_man_exp(re, -prec2-extra, prec2, rnd)
im = from_man_exp(im, -prec2-extra, prec2, rnd)
return re, im
fn = from_int(n)
prec2 = prec+10 + 10
nth = mpf_rdiv_int(1, fn, prec2)
re, im = mpc_pow((a, b), (nth, fzero), prec2, rnd)
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
return re, im
def mpc_cbrt((a, b), prec, rnd=round_fast):
"""
Complex cubic root.
"""
return mpc_nthroot((a, b), 3, prec, rnd)
def mpc_exp((a, b), prec, rnd=round_fast):
"""
Complex exponential function.
We use the direct formula exp(a+bi) = exp(a) * (cos(b) + sin(b)*i)
for the computation. This formula is very nice because it is
pefectly stable; since we just do real multiplications, the only
numerical errors that can creep in are single-ulp rounding errors.
The formula is efficient since mpmath's real exp is quite fast and
since we can compute cos and sin simultaneously.
It is no problem if a and b are large; if the implementations of
exp/cos/sin are accurate and efficient for all real numbers, then
so is this function for all complex numbers.
"""
if a == fzero:
return mpf_cos_sin(b, prec, rnd)
mag = mpf_exp(a, prec+4, rnd)
c, s = mpf_cos_sin(b, prec+4, rnd)
re = mpf_mul(mag, c, prec, rnd)
im = mpf_mul(mag, s, prec, rnd)
return re, im
def mpc_log(z, prec, rnd=round_fast):
re = mpf_log_hypot(z[0], z[1], prec, rnd)
im = mpc_arg(z, prec, rnd)
return re, im
def mpc_cos((a, b), prec, rnd=round_fast):
"""Complex cosine. The formula used is cos(a+bi) = cos(a)*cosh(b) -
sin(a)*sinh(b)*i.
The same comments apply as for the complex exp: only real
multiplications are pewrormed, so no cancellation errors are
possible. The formula is also efficient since we can compute both
pairs (cos, sin) and (cosh, sinh) in single stwps."""
if a == fzero:
return mpf_cosh(b, prec, rnd), fzero
wp = prec + 6
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(c, ch, prec, rnd)
im = mpf_mul(s, sh, prec, rnd)
return re, mpf_neg(im)
def mpc_sin((a, b), prec, rnd=round_fast):
"""Complex sine. We have sin(a+bi) = sin(a)*cosh(b) +
cos(a)*sinh(b)*i. See the docstring for mpc_cos for additional
comments."""
if a == fzero:
return fzero, mpf_sinh(b, prec, rnd)
wp = prec + 6
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(s, ch, prec, rnd)
im = mpf_mul(c, sh, prec, rnd)
return re, im
def mpc_tan(z, prec, rnd=round_fast):
"""Complex tangent. Computed as tan(a+bi) = sin(2a)/M + sinh(2b)/M*i
where M = cos(2a) + cosh(2b)."""
a, b = z
asign, aman, aexp, abc = a
bsign, bman, bexp, bbc = b
if b == fzero: return mpf_tan(a, prec, rnd), fzero
if a == fzero: return fzero, mpf_tanh(b, prec, rnd)
wp = prec + 15
a = mpf_shift(a, 1)
b = mpf_shift(b, 1)
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
# TODO: handle cancellation when c ~= -1 and ch ~= 1
mag = mpf_add(c, ch, wp)
re = mpf_div(s, mag, prec, rnd)
im = mpf_div(sh, mag, prec, rnd)
return re, im
def mpc_cos_pi((a, b), prec, rnd=round_fast):
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
if a == fzero:
return mpf_cosh(b, prec, rnd), fzero
wp = prec + 6
c, s = mpf_cos_sin_pi(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(c, ch, prec, rnd)
im = mpf_mul(s, sh, prec, rnd)
return re, mpf_neg(im)
def mpc_sin_pi((a, b), prec, rnd=round_fast):
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
if a == fzero:
return fzero, mpf_sinh(b, prec, rnd)
wp = prec + 6
c, s = mpf_cos_sin_pi(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(s, ch, prec, rnd)
im = mpf_mul(c, sh, prec, rnd)
return re, im
def mpc_cosh((a, b), prec, rnd=round_fast):
"""Complex hyperbolic cosine. Computed as cosh(z) = cos(z*i)."""
return mpc_cos((b, mpf_neg(a)), prec, rnd)
def mpc_sinh((a, b), prec, rnd=round_fast):
"""Complex hyperbolic sine. Computed as sinh(z) = -i*sin(z*i)."""
b, a = mpc_sin((b, a), prec, rnd)
return a, b
def mpc_tanh((a, b), prec, rnd=round_fast):
"""Complex hyperbolic tangent. Computed as tanh(z) = -i*tan(z*i)."""
b, a = mpc_tan((b, a), prec, rnd)
return a, b
# TODO: avoid loss of accuracy
def mpc_atan(z, prec, rnd=round_fast):
a, b = z
# atan(z) = (I/2)*(log(1-I*z) - log(1+I*z))
# x = 1-I*z = 1 + b - I*a
# y = 1+I*z = 1 - b + I*a
wp = prec + 15
x = mpf_add(fone, b, wp), mpf_neg(a)
y = mpf_sub(fone, b, wp), a
l1 = mpc_log(x, wp)
l2 = mpc_log(y, wp)
a, b = mpc_sub(l1, l2, prec, rnd)
# (I/2) * (a+b*I) = (-b/2 + a/2*I)
v = mpf_neg(mpf_shift(b,-1)), mpf_shift(a,-1)
# Subtraction at infinity gives correct real part but
# wrong imaginary part (should be zero)
if v[1] == fnan and mpc_is_inf(z):
v = (v[0], fzero)
return v
beta_crossover = from_float(0.6417)
alpha_crossover = from_float(1.5)
def acos_asin(z, prec, rnd, n):
""" complex acos for n = 0, asin for n = 1
The algorithm is described in
T.E. Hull, T.F. Fairgrieve and P.T.P. Tang
'Implementing the Complex Arcsine and Arcosine Functions
using Exception Handling',
ACM Trans. on Math. Software Vol. 23 (1997), p299
The complex acos and asin can be defined as
acos(z) = acos(beta) - I*sign(a)* log(alpha + sqrt(alpha**2 -1))
asin(z) = asin(beta) + I*sign(a)* log(alpha + sqrt(alpha**2 -1))
where z = a + I*b
alpha = (1/2)*(r + s); beta = (1/2)*(r - s) = a/alpha
r = sqrt((a+1)**2 + y**2); s = sqrt((a-1)**2 + y**2)
These expressions are rewritten in different ways in different
regions, delimited by two crossovers alpha_crossover and beta_crossover,
and by abs(a) <= 1, in order to improve the numerical accuracy.
"""
a, b = z
wp = prec + 10
# special cases with real argument
if b == fzero:
am = mpf_sub(fone, mpf_abs(a), wp)
# case abs(a) <= 1
if not am[0]:
if n == 0:
return mpf_acos(a, prec, rnd), fzero
else:
return mpf_asin(a, prec, rnd), fzero
# cases abs(a) > 1
else:
# case a < -1
if a[0]:
pi = mpf_pi(prec, rnd)
c = mpf_acosh(mpf_neg(a), prec, rnd)
if n == 0:
return pi, mpf_neg(c)
else:
return mpf_neg(mpf_shift(pi, -1)), c
# case a > 1
else:
c = mpf_acosh(a, prec, rnd)
if n == 0:
return fzero, c
else:
pi = mpf_pi(prec, rnd)
return mpf_shift(pi, -1), mpf_neg(c)
asign = bsign = 0
if a[0]:
a = mpf_neg(a)
asign = 1
if b[0]:
b = mpf_neg(b)
bsign = 1
am = mpf_sub(fone, a, wp)
ap = mpf_add(fone, a, wp)
r = mpf_hypot(ap, b, wp)
s = mpf_hypot(am, b, wp)
alpha = mpf_shift(mpf_add(r, s, wp), -1)
beta = mpf_div(a, alpha, wp)
b2 = mpf_mul(b,b, wp)
# case beta <= beta_crossover
if not mpf_sub(beta_crossover, beta, wp)[0]:
if n == 0:
re = mpf_acos(beta, wp)
else:
re = mpf_asin(beta, wp)
else:
# to compute the real part in this region use the identity
# asin(beta) = atan(beta/sqrt(1-beta**2))
# beta/sqrt(1-beta**2) = (alpha + a) * (alpha - a)
# alpha + a is numerically accurate; alpha - a can have
# cancellations leading to numerical inaccuracies, so rewrite
# it in differente ways according to the region
Ax = mpf_add(alpha, a, wp)
# case a <= 1
if not am[0]:
# c = b*b/(r + (a+1)); d = (s + (1-a))
# alpha - a = (1/2)*(c + d)
# case n=0: re = atan(sqrt((1/2) * Ax * (c + d))/a)
# case n=1: re = atan(a/sqrt((1/2) * Ax * (c + d)))
c = mpf_div(b2, mpf_add(r, ap, wp), wp)
d = mpf_add(s, am, wp)
re = mpf_shift(mpf_mul(Ax, mpf_add(c, d, wp), wp), -1)
if n == 0:
re = mpf_atan(mpf_div(mpf_sqrt(re, wp), a, wp), wp)
else:
re = mpf_atan(mpf_div(a, mpf_sqrt(re, wp), wp), wp)
else:
# c = Ax/(r + (a+1)); d = Ax/(s - (1-a))
# alpha - a = (1/2)*(c + d)
# case n = 0: re = atan(b*sqrt(c + d)/2/a)
# case n = 1: re = atan(a/(b*sqrt(c + d)/2)
c = mpf_div(Ax, mpf_add(r, ap, wp), wp)
d = mpf_div(Ax, mpf_sub(s, am, wp), wp)
re = mpf_shift(mpf_add(c, d, wp), -1)
re = mpf_mul(b, mpf_sqrt(re, wp), wp)
if n == 0:
re = mpf_atan(mpf_div(re, a, wp), wp)
else:
re = mpf_atan(mpf_div(a, re, wp), wp)
# to compute alpha + sqrt(alpha**2 - 1), if alpha <= alpha_crossover
# replace it with 1 + Am1 + sqrt(Am1*(alpha+1)))
# where Am1 = alpha -1
# if alpha <= alpha_crossover:
if not mpf_sub(alpha_crossover, alpha, wp)[0]:
c1 = mpf_div(b2, mpf_add(r, ap, wp), wp)
# case a < 1
if mpf_neg(am)[0]:
# Am1 = (1/2) * (b*b/(r + (a+1)) + b*b/(s + (1-a))
c2 = mpf_add(s, am, wp)
c2 = mpf_div(b2, c2, wp)
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
else:
# Am1 = (1/2) * (b*b/(r + (a+1)) + (s - (1-a)))
c2 = mpf_sub(s, am, wp)
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
# im = log(1 + Am1 + sqrt(Am1*(alpha+1)))
im = mpf_mul(Am1, mpf_add(alpha, fone, wp), wp)
im = mpf_log(mpf_add(fone, mpf_add(Am1, mpf_sqrt(im, wp), wp), wp), wp)
else:
# im = log(alpha + sqrt(alpha*alpha - 1))
im = mpf_sqrt(mpf_sub(mpf_mul(alpha, alpha, wp), fone, wp), wp)
im = mpf_log(mpf_add(alpha, im, wp), wp)
if asign:
if n == 0:
re = mpf_sub(mpf_pi(wp), re, wp)
else:
re = mpf_neg(re)
if not bsign and n == 0:
im = mpf_neg(im)
if bsign and n == 1:
im = mpf_neg(im)
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
return re, im
def mpc_acos(z, prec, rnd=round_fast):
return acos_asin(z, prec, rnd, 0)
def mpc_asin(z, prec, rnd=round_fast):
return acos_asin(z, prec, rnd, 1)
def mpc_asinh(z, prec, rnd=round_fast):
# asinh(z) = I * asin(-I z)
a, b = z
a, b = mpc_asin((b, mpf_neg(a)), prec, rnd)
return mpf_neg(b), a
def mpc_acosh(z, prec, rnd=round_fast):
# acosh(z) = -I * acos(z) for Im(acos(z)) <= 0
# +I * acos(z) otherwise
a, b = mpc_acos(z, prec, rnd)
if b[0] or b == fzero:
return mpf_neg(b), a
else:
return b, mpf_neg(a)
def mpc_atanh(z, prec, rnd=round_fast):
# atanh(z) = (log(1+z)-log(1-z))/2
wp = prec + 15
a = mpc_add(z, mpc_one, wp)
b = mpc_sub(mpc_one, z, wp)
a = mpc_log(a, wp)
b = mpc_log(b, wp)
v = mpc_shift(mpc_sub(a, b, wp), -1)
# Subtraction at infinity gives correct imaginary part but
# wrong real part (should be zero)
if v[0] == fnan and mpc_is_inf(z):
v = (fzero, v[1])
return v
def mpc_fibonacci(z, prec, rnd=round_fast):
re, im = z
if im == fzero:
return (mpf_fibonacci(re, prec, rnd), fzero)
size = max(abs(re[2]+re[3]), abs(re[2]+re[3]))
wp = prec + size + 20
a = mpf_phi(wp)
b = mpf_add(mpf_shift(a, 1), fnone, wp)
u = mpc_pow((a, fzero), z, wp)
v = mpc_cos_pi(z, wp)
v = mpc_div(v, u, wp)
u = mpc_sub(u, v, wp)
u = mpc_div_mpf(u, b, prec, rnd)
return u
def mpf_expj(x, prec, rnd='f'):
raise ComplexResult
def mpc_expj(z, prec, rnd='f'):
re, im = z
if im == fzero:
return mpf_cos_sin(re, prec, rnd)
if re == fzero:
return mpf_exp(mpf_neg(im), prec, rnd), fzero
ey = mpf_exp(mpf_neg(im), prec+10)
c, s = mpf_cos_sin(re, prec+10)
re = mpf_mul(ey, c, prec, rnd)
im = mpf_mul(ey, s, prec, rnd)
return re, im
def mpf_expjpi(x, prec, rnd='f'):
raise ComplexResult
def mpc_expjpi(z, prec, rnd='f'):
re, im = z
if im == fzero:
return mpf_cos_sin_pi(re, prec, rnd)
sign, man, exp, bc = im
wp = prec+10
if man:
wp += max(0, exp+bc)
im = mpf_neg(mpf_mul(mpf_pi(wp), im, wp))
if re == fzero:
return mpf_exp(im, prec, rnd), fzero
ey = mpf_exp(im, prec+10)
c, s = mpf_cos_sin_pi(re, prec+10)
re = mpf_mul(ey, c, prec, rnd)
im = mpf_mul(ey, s, prec, rnd)
return re, im

File diff suppressed because it is too large Load Diff

View File

@ -1,348 +0,0 @@
"""
Computational functions for interval arithmetic.
"""
from libmpf import (
ComplexResult,
round_down, round_up, round_floor, round_ceiling, round_nearest,
prec_to_dps, repr_dps,
fnan, finf, fninf, fzero, fhalf, fone, fnone,
mpf_sign, mpf_lt, mpf_le, mpf_gt, mpf_ge, mpf_eq, mpf_cmp,
mpf_floor, from_int, to_int, to_str,
mpf_abs, mpf_neg, mpf_pos, mpf_add, mpf_sub, mpf_mul,
mpf_div, mpf_shift, mpf_pow_int)
from libelefun import (
mpf_log, mpf_exp, mpf_sqrt, reduce_angle, calc_cos_sin
)
def mpi_str(s, prec):
sa, sb = s
dps = prec_to_dps(prec) + 5
return "[%s, %s]" % (to_str(sa, dps), to_str(sb, dps))
#dps = prec_to_dps(prec)
#m = mpi_mid(s, prec)
#d = mpf_shift(mpi_delta(s, 20), -1)
#return "%s +/- %s" % (to_str(m, dps), to_str(d, 3))
def mpi_add(s, t, prec):
sa, sb = s
ta, tb = t
a = mpf_add(sa, ta, prec, round_floor)
b = mpf_add(sb, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
return a, b
def mpi_sub(s, t, prec):
sa, sb = s
ta, tb = t
a = mpf_sub(sa, tb, prec, round_floor)
b = mpf_sub(sb, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
return a, b
def mpi_delta(s, prec):
sa, sb = s
return mpf_sub(sb, sa, prec, round_up)
def mpi_mid(s, prec):
sa, sb = s
return mpf_shift(mpf_add(sa, sb, prec, round_nearest), -1)
def mpi_pos(s, prec):
sa, sb = s
a = mpf_pos(sa, prec, round_floor)
b = mpf_pos(sb, prec, round_ceiling)
return a, b
def mpi_neg(s, prec=None):
sa, sb = s
a = mpf_neg(sb, prec, round_floor)
b = mpf_neg(sa, prec, round_ceiling)
return a, b
def mpi_abs(s, prec):
sa, sb = s
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
# Both points nonnegative?
if sas >= 0:
a = mpf_pos(sa, prec, round_floor)
b = mpf_pos(sb, prec, round_ceiling)
# Upper point nonnegative?
elif sbs >= 0:
a = fzero
negsa = mpf_neg(sa)
if mpf_lt(negsa, sb):
b = mpf_pos(sb, prec, round_ceiling)
else:
b = mpf_pos(negsa, prec, round_ceiling)
# Both negative?
else:
a = mpf_neg(sb, prec, round_floor)
b = mpf_neg(sa, prec, round_ceiling)
return a, b
def mpi_mul(s, t, prec):
sa, sb = s
ta, tb = t
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
tas = mpf_sign(ta)
tbs = mpf_sign(tb)
if sas == sbs == 0:
# Should maybe be undefined
if ta == fninf or tb == finf:
return fninf, finf
return fzero, fzero
if tas == tbs == 0:
# Should maybe be undefined
if sa == fninf or sb == finf:
return fninf, finf
return fzero, fzero
if sas >= 0:
# positive * positive
if tas >= 0:
a = mpf_mul(sa, ta, prec, round_floor)
b = mpf_mul(sb, tb, prec, round_ceiling)
if a == fnan: a = fzero
if b == fnan: b = finf
# positive * negative
elif tbs <= 0:
a = mpf_mul(sb, ta, prec, round_floor)
b = mpf_mul(sa, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = fzero
# positive * both signs
else:
a = mpf_mul(sb, ta, prec, round_floor)
b = mpf_mul(sb, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
elif sbs <= 0:
# negative * positive
if tas >= 0:
a = mpf_mul(sa, tb, prec, round_floor)
b = mpf_mul(sb, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = fzero
# negative * negative
elif tbs <= 0:
a = mpf_mul(sb, tb, prec, round_floor)
b = mpf_mul(sa, ta, prec, round_ceiling)
if a == fnan: a = fzero
if b == fnan: b = finf
# negative * both signs
else:
a = mpf_mul(sa, tb, prec, round_floor)
b = mpf_mul(sa, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
else:
# General case: perform all cross-multiplications and compare
# Since the multiplications can be done exactly, we need only
# do 4 (instead of 8: two for each rounding mode)
cases = [mpf_mul(sa, ta), mpf_mul(sa, tb), mpf_mul(sb, ta), mpf_mul(sb, tb)]
if fnan in cases:
a, b = (fninf, finf)
else:
cases = sorted(cases, cmp=mpf_cmp)
a = mpf_pos(cases[0], prec, round_floor)
b = mpf_pos(cases[-1], prec, round_ceiling)
return a, b
def mpi_div(s, t, prec):
sa, sb = s
ta, tb = t
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
tas = mpf_sign(ta)
tbs = mpf_sign(tb)
# 0 / X
if sas == sbs == 0:
# 0 / <interval containing 0>
if (tas < 0 and tbs > 0) or (tas == 0 or tbs == 0):
return fninf, finf
return fzero, fzero
# Denominator contains both negative and positive numbers;
# this should properly be a multi-interval, but the closest
# match is the entire (extended) real line
if tas < 0 and tbs > 0:
return fninf, finf
# Assume denominator to be nonnegative
if tas < 0:
return mpi_div(mpi_neg(s), mpi_neg(t), prec)
# Division by zero
# XXX: make sure all results make sense
if tas == 0:
# Numerator contains both signs?
if sas < 0 and sbs > 0:
return fninf, finf
if tas == tbs:
return fninf, finf
# Numerator positive?
if sas >= 0:
a = mpf_div(sa, tb, prec, round_floor)
b = finf
if sbs <= 0:
a = fninf
b = mpf_div(sb, tb, prec, round_ceiling)
# Division with positive denominator
# We still have to handle nans resulting from inf/0 or inf/inf
else:
# Nonnegative numerator
if sas >= 0:
a = mpf_div(sa, tb, prec, round_floor)
b = mpf_div(sb, ta, prec, round_ceiling)
if a == fnan: a = fzero
if b == fnan: b = finf
# Nonpositive numerator
elif sbs <= 0:
a = mpf_div(sa, ta, prec, round_floor)
b = mpf_div(sb, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = fzero
# Numerator contains both signs?
else:
a = mpf_div(sa, ta, prec, round_floor)
b = mpf_div(sb, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
return a, b
def mpi_exp(s, prec):
sa, sb = s
# exp is monotonous
a = mpf_exp(sa, prec, round_floor)
b = mpf_exp(sb, prec, round_ceiling)
return a, b
def mpi_log(s, prec):
sa, sb = s
# log is monotonous
a = mpf_log(sa, prec, round_floor)
b = mpf_log(sb, prec, round_ceiling)
return a, b
def mpi_sqrt(s, prec):
sa, sb = s
# sqrt is monotonous
a = mpf_sqrt(sa, prec, round_floor)
b = mpf_sqrt(sb, prec, round_ceiling)
return a, b
def mpi_pow_int(s, n, prec):
sa, sb = s
if n < 0:
return mpi_div((fone, fone), mpi_pow_int(s, -n, prec+20), prec)
if n == 0:
return (fone, fone)
if n == 1:
return s
# Odd -- signs are preserved
if n & 1:
a = mpf_pow_int(sa, n, prec, round_floor)
b = mpf_pow_int(sb, n, prec, round_ceiling)
# Even -- important to ensure positivity
else:
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
# Nonnegative?
if sas >= 0:
a = mpf_pow_int(sa, n, prec, round_floor)
b = mpf_pow_int(sb, n, prec, round_ceiling)
# Nonpositive?
elif sbs <= 0:
a = mpf_pow_int(sb, n, prec, round_floor)
b = mpf_pow_int(sa, n, prec, round_ceiling)
# Mixed signs?
else:
a = fzero
# max(-a,b)**n
sa = mpf_neg(sa)
if mpf_ge(sa, sb):
b = mpf_pow_int(sa, n, prec, round_ceiling)
else:
b = mpf_pow_int(sb, n, prec, round_ceiling)
return a, b
def mpi_pow(s, t, prec):
ta, tb = t
if ta == tb and ta not in (finf, fninf):
if ta == from_int(to_int(ta)):
return mpi_pow_int(s, to_int(ta), prec)
if ta == fhalf:
return mpi_sqrt(s, prec)
u = mpi_log(s, prec + 20)
v = mpi_mul(u, t, prec + 20)
return mpi_exp(v, prec)
def MIN(x, y):
if mpf_le(x, y):
return x
return y
def MAX(x, y):
if mpf_ge(x, y):
return x
return y
def mpi_cos_sin(x, prec):
a, b = x
# Guaranteed to contain both -1 and 1
if finf in (a, b) or fninf in (a, b):
return (fnone, fone), (fnone, fone)
y, yswaps, yn = reduce_angle(a, prec+20)
z, zswaps, zn = reduce_angle(b, prec+20)
# Guaranteed to contain both -1 and 1
if zn - yn >= 4:
return (fnone, fone), (fnone, fone)
# Both points in the same quadrant -- cos and sin both strictly monotonous
if yn == zn:
m = yn % 4
if m == 0:
cb, sa = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_floor)
ca, sb = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
if m == 1:
cb, sb = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_ceiling)
ca, sa = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
if m == 2:
ca, sb = calc_cos_sin(0, y, yswaps, prec, round_floor, round_ceiling)
cb, sa = calc_cos_sin(0, z, zswaps, prec, round_ceiling, round_floor)
if m == 3:
ca, sa = calc_cos_sin(0, y, yswaps, prec, round_floor, round_floor)
cb, sb = calc_cos_sin(0, z, zswaps, prec, round_ceiling, round_ceiling)
return (ca, cb), (sa, sb)
# Intervals spanning multiple quadrants
yn %= 4
zn %= 4
case = (yn, zn)
if case == (0, 1):
cb, sy = calc_cos_sin(0, y, yswaps, prec, round_ceiling, round_floor)
ca, sz = calc_cos_sin(0, z, zswaps, prec, round_floor, round_floor)
return (ca, cb), (MIN(sy, sz), fone)
if case == (3, 0):
cy, sa = calc_cos_sin(0, y, yswaps, prec, round_floor, round_floor)
cz, sb = calc_cos_sin(0, z, zswaps, prec, round_floor, round_ceiling)
return (MIN(cy, cz), fone), (sa, sb)
raise NotImplementedError("cos/sin spanning multiple quadrants")
def mpi_cos(x, prec):
return mpi_cos_sin(x, prec)[0]
def mpi_sin(x, prec):
return mpi_cos_sin(x, prec)[1]
def mpi_tan(x, prec):
cos, sin = mpi_cos_sin(x, prec+20)
return mpi_div(sin, cos, prec)
def mpi_cot(x, prec):
cos, sin = mpi_cos_sin(x, prec+20)
return mpi_div(cos, sin, prec)

View File

@ -1,645 +0,0 @@
"""
This module complements the math and cmath builtin modules by providing
fast machine precision versions of some additional functions (gamma, ...)
and wrapping math/cmath functions so that they can be called with either
real or complex arguments.
"""
import operator
import math
import cmath
# Irrational (?) constants
pi = 3.1415926535897932385
e = 2.7182818284590452354
sqrt2 = 1.4142135623730950488
sqrt5 = 2.2360679774997896964
phi = 1.6180339887498948482
ln2 = 0.69314718055994530942
ln10 = 2.302585092994045684
euler = 0.57721566490153286061
catalan = 0.91596559417721901505
khinchin = 2.6854520010653064453
apery = 1.2020569031595942854
logpi = 1.1447298858494001741
def _mathfun_real(f_real, f_complex):
def f(x, **kwargs):
if type(x) is float:
return f_real(x)
if type(x) is complex:
return f_complex(x)
try:
x = float(x)
return f_real(x)
except (TypeError, ValueError):
x = complex(x)
return f_complex(x)
f.__name__ = f_real.__name__
return f
def _mathfun(f_real, f_complex):
def f(x, **kwargs):
if type(x) is complex:
return f_complex(x)
try:
return f_real(float(x))
except (TypeError, ValueError):
return f_complex(complex(x))
f.__name__ = f_real.__name__
return f
def _mathfun_n(f_real, f_complex):
def f(*args, **kwargs):
try:
return f_real(*(float(x) for x in args))
except (TypeError, ValueError):
return f_complex(*(complex(x) for x in args))
f.__name__ = f_real.__name__
return f
pow = _mathfun_n(operator.pow, lambda x, y: complex(x)**y)
log = _mathfun_n(math.log, cmath.log)
sqrt = _mathfun(math.sqrt, cmath.sqrt)
exp = _mathfun_real(math.exp, cmath.exp)
cos = _mathfun_real(math.cos, cmath.cos)
sin = _mathfun_real(math.sin, cmath.sin)
tan = _mathfun_real(math.tan, cmath.tan)
acos = _mathfun(math.acos, cmath.acos)
asin = _mathfun(math.asin, cmath.asin)
atan = _mathfun_real(math.atan, cmath.atan)
cosh = _mathfun_real(math.cosh, cmath.cosh)
sinh = _mathfun_real(math.sinh, cmath.sinh)
tanh = _mathfun_real(math.tanh, cmath.tanh)
floor = _mathfun_real(math.floor,
lambda z: complex(math.floor(z.real), math.floor(z.imag)))
ceil = _mathfun_real(math.ceil,
lambda z: complex(math.ceil(z.real), math.ceil(z.imag)))
cos_sin = _mathfun_real(lambda x: (math.cos(x), math.sin(x)),
lambda z: (cmath.cos(z), cmath.sin(z)))
cbrt = _mathfun(lambda x: x**(1./3), lambda z: z**(1./3))
def nthroot(x, n):
r = 1./n
try:
return float(x) ** r
except (ValueError, TypeError):
return complex(x) ** r
def _sinpi_real(x):
if x < 0:
return -_sinpi_real(-x)
n, r = divmod(x, 0.5)
r *= pi
n %= 4
if n == 0: return math.sin(r)
if n == 1: return math.cos(r)
if n == 2: return -math.sin(r)
if n == 3: return -math.cos(r)
def _cospi_real(x):
if x < 0:
x = -x
n, r = divmod(x, 0.5)
r *= pi
n %= 4
if n == 0: return math.cos(r)
if n == 1: return -math.sin(r)
if n == 2: return -math.cos(r)
if n == 3: return math.sin(r)
def _sinpi_complex(z):
if z.real < 0:
return -_sinpi_complex(-z)
n, r = divmod(z.real, 0.5)
z = pi*complex(r, z.imag)
n %= 4
if n == 0: return cmath.sin(z)
if n == 1: return cmath.cos(z)
if n == 2: return -cmath.sin(z)
if n == 3: return -cmath.cos(z)
def _cospi_complex(z):
if z.real < 0:
z = -z
n, r = divmod(z.real, 0.5)
z = pi*complex(r, z.imag)
n %= 4
if n == 0: return cmath.cos(z)
if n == 1: return -cmath.sin(z)
if n == 2: return -cmath.cos(z)
if n == 3: return cmath.sin(z)
cospi = _mathfun_real(_cospi_real, _cospi_complex)
sinpi = _mathfun_real(_sinpi_real, _sinpi_complex)
def tanpi(x):
try:
return sinpi(x) / cospi(x)
except OverflowError:
if complex(x).imag > 10:
return 1j
if complex(x).imag < 10:
return -1j
raise
def cotpi(x):
try:
return cospi(x) / sinpi(x)
except OverflowError:
if complex(x).imag > 10:
return -1j
if complex(x).imag < 10:
return 1j
raise
INF = 1e300*1e300
NINF = -INF
NAN = INF-INF
EPS = 2.2204460492503131e-16
_exact_gamma = (INF, 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0,
362880.0, 3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
1307674368000.0, 20922789888000.0, 355687428096000.0, 6402373705728000.0,
121645100408832000.0, 2432902008176640000.0)
_max_exact_gamma = len(_exact_gamma)-1
# Lanczos coefficients used by the GNU Scientific Library
_lanczos_g = 7
_lanczos_p = (0.99999999999980993, 676.5203681218851, -1259.1392167224028,
771.32342877765313, -176.61502916214059, 12.507343278686905,
-0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7)
def _gamma_real(x):
_intx = int(x)
if _intx == x:
if _intx <= 0:
#return (-1)**_intx * INF
raise ZeroDivisionError("gamma function pole")
if _intx <= _max_exact_gamma:
return _exact_gamma[_intx]
if x < 0.5:
# TODO: sinpi
return pi / (_sinpi_real(x)*_gamma_real(1-x))
else:
x -= 1.0
r = _lanczos_p[0]
for i in range(1, _lanczos_g+2):
r += _lanczos_p[i]/(x+i)
t = x + _lanczos_g + 0.5
return 2.506628274631000502417 * t**(x+0.5) * math.exp(-t) * r
def _gamma_complex(x):
if not x.imag:
return complex(_gamma_real(x.real))
if x.real < 0.5:
# TODO: sinpi
return pi / (_sinpi_complex(x)*_gamma_complex(1-x))
else:
x -= 1.0
r = _lanczos_p[0]
for i in range(1, _lanczos_g+2):
r += _lanczos_p[i]/(x+i)
t = x + _lanczos_g + 0.5
return 2.506628274631000502417 * t**(x+0.5) * cmath.exp(-t) * r
gamma = _mathfun_real(_gamma_real, _gamma_complex)
def factorial(x):
return gamma(x+1.0)
def arg(x):
if type(x) is float:
return math.atan2(0.0,x)
return math.atan2(x.imag,x.real)
# XXX: broken for negatives
def loggamma(x):
if type(x) not in (float, complex):
try:
x = float(x)
except (ValueError, TypeError):
x = complex(x)
try:
xreal = x.real
ximag = x.imag
except AttributeError: # py2.5
xreal = x
ximag = 0.0
# Reflection formula
# http://functions.wolfram.com/GammaBetaErf/LogGamma/16/01/01/0003/
if xreal < 0.0:
if abs(x) < 0.5:
v = log(gamma(x))
if ximag == 0:
v = v.conjugate()
return v
z = 1-x
try:
re = z.real
im = z.imag
except AttributeError: # py2.5
re = z
im = 0.0
refloor = floor(re)
imsign = cmp(im, 0)
return (-pi*1j)*abs(refloor)*(1-abs(imsign)) + logpi - \
log(sinpi(z-refloor)) - loggamma(z) + 1j*pi*refloor*imsign
if x == 1.0 or x == 2.0:
return x*0
p = 0.
while abs(x) < 11:
p -= log(x)
x += 1.0
s = 0.918938533204672742 + (x-0.5)*log(x) - x
r = 1./x
r2 = r*r
s += 0.083333333333333333333*r; r *= r2
s += -0.0027777777777777777778*r; r *= r2
s += 0.00079365079365079365079*r; r *= r2
s += -0.0005952380952380952381*r; r *= r2
s += 0.00084175084175084175084*r; r *= r2
s += -0.0019175269175269175269*r; r *= r2
s += 0.0064102564102564102564*r; r *= r2
s += -0.02955065359477124183*r
return s + p
_psi_coeff = [
0.083333333333333333333,
-0.0083333333333333333333,
0.003968253968253968254,
-0.0041666666666666666667,
0.0075757575757575757576,
-0.021092796092796092796,
0.083333333333333333333,
-0.44325980392156862745,
3.0539543302701197438,
-26.456212121212121212]
def _digamma_real(x):
_intx = int(x)
if _intx == x:
if _intx <= 0:
raise ZeroDivisionError("polygamma pole")
if x < 0.5:
x = 1.0-x
s = pi*cotpi(x)
else:
s = 0.0
while x < 10.0:
s -= 1.0/x
x += 1.0
x2 = x**-2
t = x2
for c in _psi_coeff:
s -= c*t
if t < 1e-20:
break
t *= x2
return s + math.log(x) - 0.5/x
def _digamma_complex(x):
if not x.imag:
return complex(_digamma_real(x.real))
if x.real < 0.5:
x = 1.0-x
s = pi*cotpi(x)
else:
s = 0.0
while abs(x) < 10.0:
s -= 1.0/x
x += 1.0
x2 = x**-2
t = x2
for c in _psi_coeff:
s -= c*t
if abs(t) < 1e-20:
break
t *= x2
return s + cmath.log(x) - 0.5/x
digamma = _mathfun_real(_digamma_real, _digamma_complex)
# TODO: could implement complex erf and erfc here. Need
# to find an accurate method (avoiding cancellation)
# for approx. 1 < abs(x) < 9.
_erfc_coeff_P = [
1.0000000161203922312,
2.1275306946297962644,
2.2280433377390253297,
1.4695509105618423961,
0.66275911699770787537,
0.20924776504163751585,
0.045459713768411264339,
0.0063065951710717791934,
0.00044560259661560421715][::-1]
_erfc_coeff_Q = [
1.0000000000000000000,
3.2559100272784894318,
4.9019435608903239131,
4.4971472894498014205,
2.7845640601891186528,
1.2146026030046904138,
0.37647108453729465912,
0.080970149639040548613,
0.011178148899483545902,
0.00078981003831980423513][::-1]
def _polyval(coeffs, x):
p = coeffs[0]
for c in coeffs[1:]:
p = c + x*p
return p
def _erf_taylor(x):
# Taylor series assuming 0 <= x <= 1
x2 = x*x
s = t = x
n = 1
while abs(t) > 1e-17:
t *= x2/n
s -= t/(n+n+1)
n += 1
t *= x2/n
s += t/(n+n+1)
n += 1
return 1.1283791670955125739*s
def _erfc_mid(x):
# Rational approximation assuming 0 <= x <= 9
return exp(-x*x)*_polyval(_erfc_coeff_P,x)/_polyval(_erfc_coeff_Q,x)
def _erfc_asymp(x):
# Asymptotic expansion assuming x >= 9
x2 = x*x
v = exp(-x2)/x*0.56418958354775628695
r = t = 0.5 / x2
s = 1.0
for n in range(1,22,4):
s -= t
t *= r * (n+2)
s += t
t *= r * (n+4)
if abs(t) < 1e-17:
break
return s * v
def erf(x):
"""
erf of a real number.
"""
x = float(x)
if x != x:
return x
if x < 0.0:
return -erf(-x)
if x >= 1.0:
if x >= 6.0:
return 1.0
return 1.0 - _erfc_mid(x)
return _erf_taylor(x)
def erfc(x):
"""
erfc of a real number.
"""
x = float(x)
if x != x:
return x
if x < 0.0:
if x < -6.0:
return 2.0
return 2.0-erfc(-x)
if x > 9.0:
return _erfc_asymp(x)
if x >= 1.0:
return _erfc_mid(x)
return 1.0 - _erf_taylor(x)
gauss42 = [\
(0.99839961899006235, 0.0041059986046490839),
(-0.99839961899006235, 0.0041059986046490839),
(0.9915772883408609, 0.009536220301748501),
(-0.9915772883408609,0.009536220301748501),
(0.97934250806374812, 0.014922443697357493),
(-0.97934250806374812, 0.014922443697357493),
(0.96175936533820439,0.020227869569052644),
(-0.96175936533820439, 0.020227869569052644),
(0.93892355735498811, 0.025422959526113047),
(-0.93892355735498811,0.025422959526113047),
(0.91095972490412735, 0.030479240699603467),
(-0.91095972490412735, 0.030479240699603467),
(0.87802056981217269,0.03536907109759211),
(-0.87802056981217269, 0.03536907109759211),
(0.8402859832618168, 0.040065735180692258),
(-0.8402859832618168,0.040065735180692258),
(0.7979620532554873, 0.044543577771965874),
(-0.7979620532554873, 0.044543577771965874),
(0.75127993568948048,0.048778140792803244),
(-0.75127993568948048, 0.048778140792803244),
(0.70049459055617114, 0.052746295699174064),
(-0.70049459055617114,0.052746295699174064),
(0.64588338886924779, 0.056426369358018376),
(-0.64588338886924779, 0.056426369358018376),
(0.58774459748510932, 0.059798262227586649),
(-0.58774459748510932, 0.059798262227586649),
(0.5263957499311922, 0.062843558045002565),
(-0.5263957499311922, 0.062843558045002565),
(0.46217191207042191, 0.065545624364908975),
(-0.46217191207042191, 0.065545624364908975),
(0.39542385204297503, 0.067889703376521934),
(-0.39542385204297503, 0.067889703376521934),
(0.32651612446541151, 0.069862992492594159),
(-0.32651612446541151, 0.069862992492594159),
(0.25582507934287907, 0.071454714265170971),
(-0.25582507934287907, 0.071454714265170971),
(0.18373680656485453, 0.072656175243804091),
(-0.18373680656485453, 0.072656175243804091),
(0.11064502720851986, 0.073460813453467527),
(-0.11064502720851986, 0.073460813453467527),
(0.036948943165351772, 0.073864234232172879),
(-0.036948943165351772, 0.073864234232172879)]
EI_ASYMP_CONVERGENCE_RADIUS = 40.0
def ei_asymp(z, _e1=False):
r = 1./z
s = t = 1.0
k = 1
while 1:
t *= k*r
s += t
if abs(t) < 1e-16:
break
k += 1
v = s*exp(z)/z
if _e1:
if type(z) is complex:
zreal = z.real
zimag = z.imag
else:
zreal = z
zimag = 0.0
if zimag == 0.0 and zreal > 0.0:
v += pi*1j
else:
if type(z) is complex:
if z.imag > 0:
v += pi*1j
if z.imag < 0:
v -= pi*1j
return v
def ei_taylor(z, _e1=False):
s = t = z
k = 2
while 1:
t = t*z/k
term = t/k
if abs(term) < 1e-17:
break
s += term
k += 1
s += euler
if _e1:
s += log(-z)
else:
if type(z) is float or z.imag == 0.0:
s += math.log(abs(z))
else:
s += cmath.log(z)
return s
def ei(z, _e1=False):
typez = type(z)
if typez not in (float, complex):
try:
z = float(z)
typez = float
except (TypeError, ValueError):
z = complex(z)
typez = complex
if not z:
return -INF
absz = abs(z)
if absz > EI_ASYMP_CONVERGENCE_RADIUS:
return ei_asymp(z, _e1)
elif absz <= 2.0 or (typez is float and z > 0.0):
return ei_taylor(z, _e1)
# Integrate, starting from whichever is smaller of a Taylor
# series value or an asymptotic series value
if typez is complex and z.real > 0.0:
zref = z / absz
ref = ei_taylor(zref, _e1)
else:
zref = EI_ASYMP_CONVERGENCE_RADIUS * z / absz
ref = ei_asymp(zref, _e1)
C = (zref-z)*0.5
D = (zref+z)*0.5
s = 0.0
if type(z) is complex:
_exp = cmath.exp
else:
_exp = math.exp
for x,w in gauss42:
t = C*x+D
s += w*_exp(t)/t
ref -= C*s
return ref
def e1(z):
# hack to get consistent signs if the imaginary part if 0
# and signed
typez = type(z)
if type(z) not in (float, complex):
try:
z = float(z)
typez = float
except (TypeError, ValueError):
z = complex(z)
typez = complex
if typez is complex and not z.imag:
z = complex(z.real, 0.0)
# end hack
return -ei(-z, _e1=True)
_zeta_int = [\
-0.5,
0.0,
1.6449340668482264365,1.2020569031595942854,1.0823232337111381915,
1.0369277551433699263,1.0173430619844491397,1.0083492773819228268,
1.0040773561979443394,1.0020083928260822144,1.0009945751278180853,
1.0004941886041194646,1.0002460865533080483,1.0001227133475784891,
1.0000612481350587048,1.0000305882363070205,1.0000152822594086519,
1.0000076371976378998,1.0000038172932649998,1.0000019082127165539,
1.0000009539620338728,1.0000004769329867878,1.0000002384505027277,
1.0000001192199259653,1.0000000596081890513,1.0000000298035035147,
1.0000000149015548284]
_zeta_P = [-3.50000000087575873, -0.701274355654678147,
-0.0672313458590012612, -0.00398731457954257841,
-0.000160948723019303141, -4.67633010038383371e-6,
-1.02078104417700585e-7, -1.68030037095896287e-9,
-1.85231868742346722e-11][::-1]
_zeta_Q = [1.00000000000000000, -0.936552848762465319,
-0.0588835413263763741, -0.00441498861482948666,
-0.000143416758067432622, -5.10691659585090782e-6,
-9.58813053268913799e-8, -1.72963791443181972e-9,
-1.83527919681474132e-11][::-1]
_zeta_1 = [3.03768838606128127e-10, -1.21924525236601262e-8,
2.01201845887608893e-7, -1.53917240683468381e-6,
-5.09890411005967954e-7, 0.000122464707271619326,
-0.000905721539353130232, -0.00239315326074843037,
0.084239750013159168, 0.418938517907442414, 0.500000001921884009]
_zeta_0 = [-3.46092485016748794e-10, -6.42610089468292485e-9,
1.76409071536679773e-7, -1.47141263991560698e-6, -6.38880222546167613e-7,
0.000122641099800668209, -0.000905894913516772796, -0.00239303348507992713,
0.0842396947501199816, 0.418938533204660256, 0.500000000000000052]
def zeta(s):
"""
Riemann zeta function, real argument
"""
if not isinstance(s, (float, int)):
try:
s = float(s)
except (ValueError, TypeError):
try:
s = complex(s)
if not s.imag:
return complex(zeta(s.real))
except (ValueError, TypeError):
pass
raise NotImplementedError
if s == 1:
raise ValueError("zeta(1) pole")
if s >= 27:
return 1.0 + 2.0**(-s) + 3.0**(-s)
n = int(s)
if n == s:
if n >= 0:
return _zeta_int[n]
if not (n % 2):
return 0.0
if s <= 0.0:
return 2.**s*pi**(s-1)*_sinpi_real(0.5*s)*_gamma_real(1-s)*zeta(1-s)
if s <= 2.0:
if s <= 1.0:
return _polyval(_zeta_0,s)/(s-1)
return _polyval(_zeta_1,s)/(s-1)
z = _polyval(_zeta_P,s) / _polyval(_zeta_Q,s)
return 1.0 + 2.0**(-s) + 3.0**(-s) + 4.0**(-s)*z

View File

@ -1,522 +0,0 @@
# TODO: should use diagonalization-based algorithms
class MatrixCalculusMethods:
def _exp_pade(ctx, a):
"""
Exponential of a matrix using Pade approximants.
See G. H. Golub, C. F. van Loan 'Matrix Computations',
third Ed., page 572
TODO:
- find a good estimate for q
- reduce the number of matrix multiplications to improve
performance
"""
def eps_pade(p):
return ctx.mpf(2)**(3-2*p) * \
ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
q = 4
extraq = 8
while 1:
if eps_pade(q) < ctx.eps:
break
q += 1
q += extraq
j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
extra = q
prec = ctx.prec
ctx.dps += extra + 3
try:
a = a/2**j
na = a.rows
den = ctx.eye(na)
num = ctx.eye(na)
x = ctx.eye(na)
c = ctx.mpf(1)
for k in range(1, q+1):
c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
x = a*x
cx = c*x
num += cx
den += (-1)**k * cx
f = ctx.lu_solve_mat(den, num)
for k in range(j):
f = f*f
finally:
ctx.prec = prec
return f*1
def expm(ctx, A, method='taylor'):
r"""
Computes the matrix exponential of a square matrix `A`, which is defined
by the power series
.. math ::
\exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
With method='taylor', the matrix exponential is computed
using the Taylor series. With method='pade', Pade approximants
are used instead.
**Examples**
Basic examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> expm(zeros(3))
[1.0 0.0 0.0]
[0.0 1.0 0.0]
[0.0 0.0 1.0]
>>> expm(eye(3))
[2.71828182845905 0.0 0.0]
[ 0.0 2.71828182845905 0.0]
[ 0.0 0.0 2.71828182845905]
>>> expm([[1,1,0],[1,0,1],[0,1,0]])
[ 3.86814500615414 2.26812870852145 0.841130841230196]
[ 2.26812870852145 2.44114713886289 1.42699786729125]
[0.841130841230196 1.42699786729125 1.6000162976327]
>>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
[ 3.86814500615414 2.26812870852145 0.841130841230196]
[ 2.26812870852145 2.44114713886289 1.42699786729125]
[0.841130841230196 1.42699786729125 1.6000162976327]
>>> expm([[1+j, 0], [1+j,1]])
[(1.46869393991589 + 2.28735528717884j) 0.0]
[ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
Matrices with large entries are allowed::
>>> expm(matrix([[1,2],[2,3]])**25)
[5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
[9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
noncommuting matrices::
>>> A = hilbert(3)
>>> B = A + eye(3)
>>> chop(mnorm(A*B - B*A))
0.0
>>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
0.0
>>> B = A + ones(3)
>>> mnorm(A*B - B*A)
1.8
>>> mnorm(expm(A+B) - expm(A)*expm(B))
42.0927851137247
"""
A = ctx.matrix(A)
if method == 'pade':
return ctx._exp_pade(A)
prec = ctx.prec
j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
j += int(0.5*prec**0.5)
try:
ctx.prec += 10 + 2*j
tol = +ctx.eps
A = A/2**j
T = A
Y = A**0 + A
k = 2
while 1:
T *= A * (1/ctx.mpf(k))
if ctx.mnorm(T, 'inf') < tol:
break
Y += T
k += 1
for k in xrange(j):
Y = Y*Y
finally:
ctx.prec = prec
Y *= 1
return Y
def cosm(ctx, A):
r"""
Gives the cosine of a square matrix `A`, defined in analogy
with the matrix exponential.
Examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> X = eye(3)
>>> cosm(X)
[0.54030230586814 0.0 0.0]
[ 0.0 0.54030230586814 0.0]
[ 0.0 0.0 0.54030230586814]
>>> X = hilbert(3)
>>> cosm(X)
[ 0.424403834569555 -0.316643413047167 -0.221474945949293]
[-0.316643413047167 0.820646708837824 -0.127183694770039]
[-0.221474945949293 -0.127183694770039 0.909236687217541]
>>> X = matrix([[1+j,-2],[0,-j]])
>>> cosm(X)
[(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
[ 0.0 (1.54308063481524 + 0.0j)]
"""
B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
if not sum(A.apply(ctx.im).apply(abs)):
B = B.apply(ctx.re)
return B
def sinm(ctx, A):
r"""
Gives the sine of a square matrix `A`, defined in analogy
with the matrix exponential.
Examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> X = eye(3)
>>> sinm(X)
[0.841470984807897 0.0 0.0]
[ 0.0 0.841470984807897 0.0]
[ 0.0 0.0 0.841470984807897]
>>> X = hilbert(3)
>>> sinm(X)
[0.711608512150994 0.339783913247439 0.220742837314741]
[0.339783913247439 0.244113865695532 0.187231271174372]
[0.220742837314741 0.187231271174372 0.155816730769635]
>>> X = matrix([[1+j,-2],[0,-j]])
>>> sinm(X)
[(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
[ 0.0 (0.0 - 1.1752011936438j)]
"""
B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
if not sum(A.apply(ctx.im).apply(abs)):
B = B.apply(ctx.re)
return B
def _sqrtm_rot(ctx, A, _may_rotate):
# If the iteration fails to converge, cheat by performing
# a rotation by a complex number
u = ctx.j**0.3
return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
def sqrtm(ctx, A, _may_rotate=2):
r"""
Computes a square root of the square matrix `A`, i.e. returns
a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
of a matrix, if it exists, is not unique.
**Examples**
Square roots of some simple matrices::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> sqrtm([[1,0], [0,1]])
[1.0 0.0]
[0.0 1.0]
>>> sqrtm([[0,0], [0,0]])
[0.0 0.0]
[0.0 0.0]
>>> sqrtm([[2,0],[0,1]])
[1.4142135623731 0.0]
[ 0.0 1.0]
>>> sqrtm([[1,1],[1,0]])
[ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
[(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
>>> sqrtm([[1,0],[0,1]])
[1.0 0.0]
[0.0 1.0]
>>> sqrtm([[-1,0],[0,1]])
[(0.0 - 1.0j) 0.0]
[ 0.0 (1.0 + 0.0j)]
>>> sqrtm([[j,0],[0,j]])
[(0.707106781186547 + 0.707106781186547j) 0.0]
[ 0.0 (0.707106781186547 + 0.707106781186547j)]
A square root of a rotation matrix, giving the corresponding
half-angle rotation matrix::
>>> t1 = 0.75
>>> t2 = t1 * 0.5
>>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
>>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
>>> sqrtm(A1)
[0.930507621912314 -0.366272529086048]
[0.366272529086048 0.930507621912314]
>>> A2
[0.930507621912314 -0.366272529086048]
[0.366272529086048 0.930507621912314]
The identity `(A^2)^{1/2} = A` does not necessarily hold::
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
>>> sqrtm(A**2)
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> sqrtm(A)**2
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
>>> sqrtm(A**2)
[ 7.43715112194995 -0.324127569985474 1.8481718827526]
[-0.251549715716942 9.32699765900402 2.48221180985147]
[ 4.11609388833616 0.775751877098258 13.017955697342]
>>> chop(sqrtm(A)**2)
[-4.0 1.0 4.0]
[ 7.0 -8.0 9.0]
[10.0 2.0 11.0]
For some matrices, a square root does not exist::
>>> sqrtm([[0,1], [0,0]])
Traceback (most recent call last):
...
ZeroDivisionError: matrix is numerically singular
Two examples from the documentation for Matlab's ``sqrtm``::
>>> mp.dps = 15; mp.pretty = True
>>> sqrtm([[7,10],[15,22]])
[1.56669890360128 1.74077655955698]
[2.61116483933547 4.17786374293675]
>>>
>>> X = matrix(\
... [[5,-4,1,0,0],
... [-4,6,-4,1,0],
... [1,-4,6,-4,1],
... [0,1,-4,6,-4],
... [0,0,1,-4,5]])
>>> Y = matrix(\
... [[2,-1,-0,-0,-0],
... [-1,2,-1,0,-0],
... [0,-1,2,-1,0],
... [-0,0,-1,2,-1],
... [-0,-0,-0,-1,2]])
>>> mnorm(sqrtm(X) - Y)
4.53155328326114e-19
"""
A = ctx.matrix(A)
# Trivial
if A*0 == A:
return A
prec = ctx.prec
if _may_rotate:
d = ctx.det(A)
if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
return ctx._sqrtm_rot(A, _may_rotate-1)
try:
ctx.prec += 10
tol = ctx.eps * 128
Y = A
Z = I = A**0
k = 0
# Denman-Beavers iteration
while 1:
Yprev = Y
try:
Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
except ZeroDivisionError:
if _may_rotate:
Y = ctx._sqrtm_rot(A, _may_rotate-1)
break
else:
raise
mag1 = ctx.mnorm(Y-Yprev, 'inf')
mag2 = ctx.mnorm(Y, 'inf')
if mag1 <= mag2*tol:
break
if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
return ctx._sqrtm_rot(A, _may_rotate-1)
k += 1
if k > ctx.prec:
raise ctx.NoConvergence
finally:
ctx.prec = prec
Y *= 1
return Y
def logm(ctx, A):
r"""
Computes a logarithm of the square matrix `A`, i.e. returns
a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
of a matrix, if it exists, is not unique.
**Examples**
Logarithms of some simple matrices::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> X = eye(3)
>>> logm(X)
[0.0 0.0 0.0]
[0.0 0.0 0.0]
[0.0 0.0 0.0]
>>> logm(2*X)
[0.693147180559945 0.0 0.0]
[ 0.0 0.693147180559945 0.0]
[ 0.0 0.0 0.693147180559945]
>>> logm(expm(X))
[1.0 0.0 0.0]
[0.0 1.0 0.0]
[0.0 0.0 1.0]
A logarithm of a complex matrix::
>>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
>>> B = logm(X)
>>> nprint(B)
[ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
[ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
[(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
>>> chop(expm(B))
[(2.0 + 1.0j) 1.0 3.0]
[(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
[ -4.0 -5.0 (0.0 + 1.0j)]
A matrix `X` close to the identity matrix, for which
`\log(\exp(X)) = \exp(\log(X)) = X` holds::
>>> X = eye(3) + hilbert(3)/4
>>> X
[ 1.25 0.125 0.0833333333333333]
[ 0.125 1.08333333333333 0.0625]
[0.0833333333333333 0.0625 1.05]
>>> logm(expm(X))
[ 1.25 0.125 0.0833333333333333]
[ 0.125 1.08333333333333 0.0625]
[0.0833333333333333 0.0625 1.05]
>>> expm(logm(X))
[ 1.25 0.125 0.0833333333333333]
[ 0.125 1.08333333333333 0.0625]
[0.0833333333333333 0.0625 1.05]
A logarithm of a rotation matrix, giving back the angle of
the rotation::
>>> t = 3.7
>>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
>>> chop(logm(A))
[ 0.0 -2.58318530717959]
[2.58318530717959 0.0]
>>> (2*pi-t)
2.58318530717959
For some matrices, a logarithm does not exist::
>>> logm([[1,0], [0,0]])
Traceback (most recent call last):
...
ZeroDivisionError: matrix is numerically singular
Logarithm of a matrix with large entries::
>>> logm(hilbert(3) * 10**20).apply(re)
[ 45.5597513593433 1.27721006042799 0.317662687717978]
[ 1.27721006042799 42.5222778973542 2.24003708791604]
[0.317662687717978 2.24003708791604 42.395212822267]
"""
A = ctx.matrix(A)
prec = ctx.prec
try:
ctx.prec += 10
tol = ctx.eps * 128
I = A**0
B = A
n = 0
while 1:
B = ctx.sqrtm(B)
n += 1
if ctx.mnorm(B-I, 'inf') < 0.125:
break
T = X = B-I
L = X*0
k = 1
while 1:
if k & 1:
L += T / k
else:
L -= T / k
T *= X
if ctx.mnorm(T, 'inf') < tol:
break
k += 1
if k > ctx.prec:
raise ctx.NoConvergence
finally:
ctx.prec = prec
L *= 2**n
return L
def powm(ctx, A, r):
r"""
Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
number `r`.
**Examples**
Powers and inverse powers of a matrix::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
>>> powm(A, 2)
[ 63.0 20.0 69.0]
[174.0 89.0 199.0]
[164.0 48.0 179.0]
>>> chop(powm(powm(A, 4), 1/4.))
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> powm(extraprec(20)(powm)(A, -4), -1/4.)
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
A Fibonacci-generating matrix::
>>> powm([[1,1],[1,0]], 10)
[89.0 55.0]
[55.0 34.0]
>>> fib(10)
55.0
>>> powm([[1,1],[1,0]], 6.5)
[(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
[(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
>>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
(10.2078589271083 - 0.0195927472575932j)
>>> powm([[1,1],[1,0]], 6.2)
[ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
[(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
>>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
(8.81733464837593 - 0.0133048601383712j)
"""
A = ctx.matrix(A)
r = ctx.convert(r)
prec = ctx.prec
try:
ctx.prec += 10
if ctx.isint(r):
v = A ** int(r)
elif ctx.isint(r*2):
y = int(r*2)
v = ctx.sqrtm(A) ** y
else:
v = ctx.expm(r*ctx.logm(A))
finally:
ctx.prec = prec
v *= 1
return v

View File

@ -1,516 +0,0 @@
"""
Linear algebra
--------------
Linear equations
................
Basic linear algebra is implemented; you can for example solve the linear
equation system::
x + 2*y = -10
3*x + 4*y = 10
using ``lu_solve``::
>>> A = matrix([[1, 2], [3, 4]])
>>> b = matrix([-10, 10])
>>> x = lu_solve(A, b)
>>> x
matrix(
[['30.0'],
['-20.0']])
If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
>>> residual(A, x, b)
matrix(
[['3.46944695195361e-18'],
['3.46944695195361e-18']])
>>> str(eps)
'2.22044604925031e-16'
As you can see, the solution is quite accurate. The error is caused by the
inaccuracy of the internal floating point arithmetic. Though, it's even smaller
than the current machine epsilon, which basically means you can trust the
result.
If you need more speed, use NumPy. Or choose a faster data type using the
keyword ``force_type``::
>>> lu_solve(A, b, force_type=float)
matrix(
[[29.999999999999996],
[-19.999999999999996]])
``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
such systems, so the residual is minimized instead. Internally this is done
using Cholesky decomposition to compute a least squares approximation. This means
that that ``lu_solve`` will square the errors. If you can't afford this, use
``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
the residual automatically.
Matrix factorization
....................
The function ``lu`` computes an explicit LU factorization of a matrix::
>>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
>>> print P
[0.0 0.0 1.0]
[1.0 0.0 0.0]
[0.0 1.0 0.0]
>>> print L
[ 1.0 0.0 0.0]
[ 0.0 1.0 0.0]
[0.571428571428571 0.214285714285714 1.0]
>>> print U
[7.0 8.0 9.0]
[0.0 2.0 3.0]
[0.0 0.0 0.214285714285714]
>>> print P.T*L*U
[0.0 2.0 3.0]
[4.0 5.0 6.0]
[7.0 8.0 9.0]
Interval matrices
-----------------
Matrices may contain interval elements. This allows one to perform
basic linear algebra operations such as matrix multiplication
and equation solving with rigorous error bounds::
>>> a = matrix([['0.1','0.3','1.0'],
... ['7.1','5.5','4.8'],
... ['3.2','4.4','5.6']], force_type=mpi)
>>>
>>> b = matrix(['4','0.6','0.5'], force_type=mpi)
>>> c = lu_solve(a, b)
>>> c
matrix(
[[[5.2582327113062393041, 5.2582327113062749951]],
[[-13.155049396267856583, -13.155049396267821167]],
[[7.4206915477497212555, 7.4206915477497310922]]])
>>> print a*c
[ [3.9999999999999866773, 4.0000000000000133227]]
[[0.59999999999972430942, 0.60000000000027142733]]
[[0.49999999999982236432, 0.50000000000018474111]]
"""
# TODO:
# *implement high-level qr()
# *test unitvector
# *iterative solving
from copy import copy
class LinearAlgebraMethods(object):
def LU_decomp(ctx, A, overwrite=False, use_cache=True):
"""
LU-factorization of a n*n matrix using the Gauss algorithm.
Returns L and U in one matrix and the pivot indices.
Use overwrite to specify whether A will be overwritten with L and U.
"""
if not A.rows == A.cols:
raise ValueError('need n*n matrix')
# get from cache if possible
if use_cache and isinstance(A, ctx.matrix) and A._LU:
return A._LU
if not overwrite:
orig = A
A = A.copy()
tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
n = A.rows
p = [None]*(n - 1)
for j in xrange(n - 1):
# pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
biggest = 0
for k in xrange(j, n):
s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
if ctx.absmin(s) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
current = 1/s * ctx.absmin(A[k,j])
if current > biggest: # TODO: what if equal?
biggest = current
p[j] = k
# swap rows according to p
ctx.swap_row(A, j, p[j])
if ctx.absmin(A[j,j]) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
# calculate elimination factors and add rows
for i in xrange(j + 1, n):
A[i,j] /= A[j,j]
for k in xrange(j + 1, n):
A[i,k] -= A[i,j]*A[j,k]
if ctx.absmin(A[n - 1,n - 1]) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
# cache decomposition
if not overwrite and isinstance(orig, ctx.matrix):
orig._LU = (A, p)
return A, p
def L_solve(ctx, L, b, p=None):
"""
Solve the lower part of a LU factorized matrix for y.
"""
assert L.rows == L.cols, 'need n*n matrix'
n = L.rows
assert len(b) == n
b = copy(b)
if p: # swap b according to p
for k in xrange(0, len(p)):
ctx.swap_row(b, k, p[k])
# solve
for i in xrange(1, n):
for j in xrange(i):
b[i] -= L[i,j] * b[j]
return b
def U_solve(ctx, U, y):
"""
Solve the upper part of a LU factorized matrix for x.
"""
assert U.rows == U.cols, 'need n*n matrix'
n = U.rows
assert len(y) == n
x = copy(y)
for i in xrange(n - 1, -1, -1):
for j in xrange(i + 1, n):
x[i] -= U[i,j] * x[j]
x[i] /= U[i,i]
return x
def lu_solve(ctx, A, b, **kwargs):
"""
Ax = b => x
Solve a determined or overdetermined linear equations system.
Fast LU decomposition is used, which is less accurate than QR decomposition
(especially for overdetermined systems), but it's twice as efficient.
Use qr_solve if you want more precision or have to solve a very ill-
conditioned system.
If you specify real=True, it does not check for overdeterminded complex
systems.
"""
prec = ctx.prec
try:
ctx.prec += 10
# do not overwrite A nor b
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
if A.rows < A.cols:
raise ValueError('cannot solve underdetermined system')
if A.rows > A.cols:
# use least-squares method if overdetermined
# (this increases errors)
AH = A.H
A = AH * A
b = AH * b
if (kwargs.get('real', False) or
not sum(type(i) is ctx.mpc for i in A)):
# TODO: necessary to check also b?
x = ctx.cholesky_solve(A, b)
else:
x = ctx.lu_solve(A, b)
else:
# LU factorization
A, p = ctx.LU_decomp(A)
b = ctx.L_solve(A, b, p)
x = ctx.U_solve(A, b)
finally:
ctx.prec = prec
return x
def improve_solution(ctx, A, x, b, maxsteps=1):
"""
Improve a solution to a linear equation system iteratively.
This re-uses the LU decomposition and is thus cheap.
Usually 3 up to 4 iterations are giving the maximal improvement.
"""
assert A.rows == A.cols, 'need n*n matrix' # TODO: really?
for _ in xrange(maxsteps):
r = ctx.residual(A, x, b)
if ctx.norm(r, 2) < 10*ctx.eps:
break
# this uses cached LU decomposition and is thus cheap
dx = ctx.lu_solve(A, -r)
x += dx
return x
def lu(ctx, A):
"""
A -> P, L, U
LU factorisation of a square matrix A. L is the lower, U the upper part.
P is the permutation matrix indicating the row swaps.
P*A = L*U
If you need efficiency, use the low-level method LU_decomp instead, it's
much more memory efficient.
"""
# get factorization
A, p = ctx.LU_decomp(A)
n = A.rows
L = ctx.matrix(n)
U = ctx.matrix(n)
for i in xrange(n):
for j in xrange(n):
if i > j:
L[i,j] = A[i,j]
elif i == j:
L[i,j] = 1
U[i,j] = A[i,j]
else:
U[i,j] = A[i,j]
# calculate permutation matrix
P = ctx.eye(n)
for k in xrange(len(p)):
ctx.swap_row(P, k, p[k])
return P, L, U
def unitvector(ctx, n, i):
"""
Return the i-th n-dimensional unit vector.
"""
assert 0 < i <= n, 'this unit vector does not exist'
return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
def inverse(ctx, A, **kwargs):
"""
Calculate the inverse of a matrix.
If you want to solve an equation system Ax = b, it's recommended to use
solve(A, b) instead, it's about 3 times more efficient.
"""
prec = ctx.prec
try:
ctx.prec += 10
# do not overwrite A
A = ctx.matrix(A, **kwargs).copy()
n = A.rows
# get LU factorisation
A, p = ctx.LU_decomp(A)
cols = []
# calculate unit vectors and solve corresponding system to get columns
for i in xrange(1, n + 1):
e = ctx.unitvector(n, i)
y = ctx.L_solve(A, e, p)
cols.append(ctx.U_solve(A, y))
# convert columns to matrix
inv = []
for i in xrange(n):
row = []
for j in xrange(n):
row.append(cols[j][i])
inv.append(row)
result = ctx.matrix(inv, **kwargs)
finally:
ctx.prec = prec
return result
def householder(ctx, A):
"""
(A|b) -> H, p, x, res
(A|b) is the coefficient matrix with left hand side of an optionally
overdetermined linear equation system.
H and p contain all information about the transformation matrices.
x is the solution, res the residual.
"""
assert isinstance(A, ctx.matrix)
m = A.rows
n = A.cols
assert m >= n - 1
# calculate Householder matrix
p = []
for j in xrange(0, n - 1):
s = ctx.fsum((A[i,j])**2 for i in xrange(j, m))
if not abs(s) > ctx.eps:
raise ValueError('matrix is numerically singular')
p.append(-ctx.sign(A[j,j]) * ctx.sqrt(s))
kappa = ctx.one / (s - p[j] * A[j,j])
A[j,j] -= p[j]
for k in xrange(j+1, n):
y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j, m)) * kappa
for i in xrange(j, m):
A[i,k] -= A[i,j] * y
# solve Rx = c1
x = [A[i,n - 1] for i in xrange(n - 1)]
for i in xrange(n - 2, -1, -1):
x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
x[i] /= p[i]
# calculate residual
if not m == n - 1:
r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
else:
# determined system, residual should be 0
r = [0]*m # maybe a bad idea, changing r[i] will change all elements
return A, p, x, r
#def qr(ctx, A):
# """
# A -> Q, R
#
# QR factorisation of a square matrix A using Householder decomposition.
# Q is orthogonal, this leads to very few numerical errors.
#
# A = Q*R
# """
# H, p, x, res = householder(A)
# TODO: implement this
def residual(ctx, A, x, b, **kwargs):
"""
Calculate the residual of a solution to a linear equation system.
r = A*x - b for A*x = b
"""
oldprec = ctx.prec
try:
ctx.prec *= 2
A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
return A*x - b
finally:
ctx.prec = oldprec
def qr_solve(ctx, A, b, norm=None, **kwargs):
"""
Ax = b => x, ||Ax - b||
Solve a determined or overdetermined linear equations system and
calculate the norm of the residual (error).
QR decomposition using Householder factorization is applied, which gives very
accurate results even for ill-conditioned matrices. qr_solve is twice as
efficient.
"""
if norm is None:
norm = ctx.norm
prec = ctx.prec
try:
prec += 10
# do not overwrite A nor b
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
if A.rows < A.cols:
raise ValueError('cannot solve underdetermined system')
H, p, x, r = ctx.householder(ctx.extend(A, b))
res = ctx.norm(r)
# calculate residual "manually" for determined systems
if res == 0:
res = ctx.norm(ctx.residual(A, x, b))
return ctx.matrix(x, **kwargs), res
finally:
ctx.prec = prec
# TODO: possible for complex matrices? -> have a look at GSL
def cholesky(ctx, A):
"""
Cholesky decomposition of a symmetric positive-definite matrix.
Can be used to solve linear equation systems twice as efficient compared
to LU decomposition or to test whether A is positive-definite.
A = L * L.T
Only L (the lower part) is returned.
"""
assert isinstance(A, ctx.matrix)
if not A.rows == A.cols:
raise ValueError('need n*n matrix')
n = A.rows
L = ctx.matrix(n)
for j in xrange(n):
s = A[j,j] - ctx.fsum(L[j,k]**2 for k in xrange(j))
if s < ctx.eps:
raise ValueError('matrix not positive-definite')
L[j,j] = ctx.sqrt(s)
for i in xrange(j, n):
L[i,j] = (A[i,j] - ctx.fsum(L[i,k] * L[j,k] for k in xrange(j))) \
/ L[j,j]
return L
def cholesky_solve(ctx, A, b, **kwargs):
"""
Ax = b => x
Solve a symmetric positive-definite linear equation system.
This is twice as efficient as lu_solve.
Typical use cases:
* A.T*A
* Hessian matrix
* differential equations
"""
prec = ctx.prec
try:
prec += 10
# do not overwrite A nor b
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
if A.rows != A.cols:
raise ValueError('can only solve determined system')
# Cholesky factorization
L = ctx.cholesky(A)
# solve
n = L.rows
assert len(b) == n
for i in xrange(n):
b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
b[i] /= L[i,i]
x = ctx.U_solve(L.T, b)
return x
finally:
ctx.prec = prec
def det(ctx, A):
"""
Calculate the determinant of a matrix.
"""
prec = ctx.prec
try:
# do not overwrite A
A = ctx.matrix(A).copy()
# use LU factorization to calculate determinant
try:
R, p = ctx.LU_decomp(A)
except ZeroDivisionError:
return 0
z = 1
for i, e in enumerate(p):
if i != e:
z *= -1
for i in xrange(A.rows):
z *= R[i,i]
return z
finally:
ctx.prec = prec
def cond(ctx, A, norm=None):
"""
Calculate the condition number of a matrix using a specified matrix norm.
The condition number estimates the sensitivity of a matrix to errors.
Example: small input errors for ill-conditioned coefficient matrices
alter the solution of the system dramatically.
For ill-conditioned matrices it's recommended to use qr_solve() instead
of lu_solve(). This does not help with input errors however, it just avoids
to add additional errors.
Definition: cond(A) = ||A|| * ||A**-1||
"""
if norm is None:
norm = lambda x: ctx.mnorm(x,1)
return norm(A) * norm(ctx.inverse(A))
def lu_solve_mat(ctx, a, b):
"""Solve a * x = b where a and b are matrices."""
r = ctx.matrix(a.rows, b.cols)
for i in range(b.cols):
c = ctx.lu_solve(a, b.column(i))
for j in range(len(c)):
r[j, i] = c[j]
return r

View File

@ -1,858 +0,0 @@
# TODO: interpret list as vectors (for multiplication)
rowsep = '\n'
colsep = ' '
class _matrix(object):
"""
Numerical matrix.
Specify the dimensions or the data as a nested list.
Elements default to zero.
Use a flat list to create a column vector easily.
By default, only mpf is used to store the data. You can specify another type
using force_type=type. It's possible to specify None.
Make sure force_type(force_type()) is fast.
Creating matrices
-----------------
Matrices in mpmath are implemented using dictionaries. Only non-zero values
are stored, so it is cheap to represent sparse matrices.
The most basic way to create one is to use the ``matrix`` class directly.
You can create an empty matrix specifying the dimensions:
>>> from mpmath import *
>>> mp.dps = 15
>>> matrix(2)
matrix(
[['0.0', '0.0'],
['0.0', '0.0']])
>>> matrix(2, 3)
matrix(
[['0.0', '0.0', '0.0'],
['0.0', '0.0', '0.0']])
Calling ``matrix`` with one dimension will create a square matrix.
To access the dimensions of a matrix, use the ``rows`` or ``cols`` keyword:
>>> A = matrix(3, 2)
>>> A
matrix(
[['0.0', '0.0'],
['0.0', '0.0'],
['0.0', '0.0']])
>>> A.rows
3
>>> A.cols
2
You can also change the dimension of an existing matrix. This will set the
new elements to 0. If the new dimension is smaller than before, the
concerning elements are discarded:
>>> A.rows = 2
>>> A
matrix(
[['0.0', '0.0'],
['0.0', '0.0']])
Internally ``mpmathify`` is used every time an element is set. This
is done using the syntax A[row,column], counting from 0:
>>> A = matrix(2)
>>> A[1,1] = 1 + 1j
>>> A
matrix(
[['0.0', '0.0'],
['0.0', '(1.0 + 1.0j)']])
You can use the keyword ``force_type`` to change the function which is
called on every new element:
>>> matrix(2, 5, force_type=int)
matrix(
[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
A more comfortable way to create a matrix lets you use nested lists:
>>> matrix([[1, 2], [3, 4]])
matrix(
[['1.0', '2.0'],
['3.0', '4.0']])
If you want to preserve the type of the elements you can use
``force_type=None``:
>>> matrix([[1, 2.5], [1j, mpf(2)]], force_type=None)
matrix(
[[1, 2.5],
[1j, '2.0']])
Convenient advanced functions are available for creating various standard
matrices, see ``zeros``, ``ones``, ``diag``, ``eye``, ``randmatrix`` and
``hilbert``.
Vectors
.......
Vectors may also be represented by the ``matrix`` class (with rows = 1 or cols = 1).
For vectors there are some things which make life easier. A column vector can
be created using a flat list, a row vectors using an almost flat nested list::
>>> matrix([1, 2, 3])
matrix(
[['1.0'],
['2.0'],
['3.0']])
>>> matrix([[1, 2, 3]])
matrix(
[['1.0', '2.0', '3.0']])
Optionally vectors can be accessed like lists, using only a single index::
>>> x = matrix([1, 2, 3])
>>> x[1]
mpf('2.0')
>>> x[1,0]
mpf('2.0')
Other
.....
Like you probably expected, matrices can be printed::
>>> print randmatrix(3) # doctest:+SKIP
[ 0.782963853573023 0.802057689719883 0.427895717335467]
[0.0541876859348597 0.708243266653103 0.615134039977379]
[ 0.856151514955773 0.544759264818486 0.686210904770947]
Use ``nstr`` or ``nprint`` to specify the number of digits to print::
>>> nprint(randmatrix(5), 3) # doctest:+SKIP
[2.07e-1 1.66e-1 5.06e-1 1.89e-1 8.29e-1]
[6.62e-1 6.55e-1 4.47e-1 4.82e-1 2.06e-2]
[4.33e-1 7.75e-1 6.93e-2 2.86e-1 5.71e-1]
[1.01e-1 2.53e-1 6.13e-1 3.32e-1 2.59e-1]
[1.56e-1 7.27e-2 6.05e-1 6.67e-2 2.79e-1]
As matrices are mutable, you will need to copy them sometimes::
>>> A = matrix(2)
>>> A
matrix(
[['0.0', '0.0'],
['0.0', '0.0']])
>>> B = A.copy()
>>> B[0,0] = 1
>>> B
matrix(
[['1.0', '0.0'],
['0.0', '0.0']])
>>> A
matrix(
[['0.0', '0.0'],
['0.0', '0.0']])
Finally, it is possible to convert a matrix to a nested list. This is very useful,
as most Python libraries involving matrices or arrays (namely NumPy or SymPy)
support this format::
>>> B.tolist()
[[mpf('1.0'), mpf('0.0')], [mpf('0.0'), mpf('0.0')]]
Matrix operations
-----------------
You can add and subtract matrices of compatible dimensions::
>>> A = matrix([[1, 2], [3, 4]])
>>> B = matrix([[-2, 4], [5, 9]])
>>> A + B
matrix(
[['-1.0', '6.0'],
['8.0', '13.0']])
>>> A - B
matrix(
[['3.0', '-2.0'],
['-2.0', '-5.0']])
>>> A + ones(3) # doctest:+ELLIPSIS
Traceback (most recent call last):
...
ValueError: incompatible dimensions for addition
It is possible to multiply or add matrices and scalars. In the latter case the
operation will be done element-wise::
>>> A * 2
matrix(
[['2.0', '4.0'],
['6.0', '8.0']])
>>> A / 4
matrix(
[['0.25', '0.5'],
['0.75', '1.0']])
>>> A - 1
matrix(
[['0.0', '1.0'],
['2.0', '3.0']])
Of course you can perform matrix multiplication, if the dimensions are
compatible::
>>> A * B
matrix(
[['8.0', '22.0'],
['14.0', '48.0']])
>>> matrix([[1, 2, 3]]) * matrix([[-6], [7], [-2]])
matrix(
[['2.0']])
You can raise powers of square matrices::
>>> A**2
matrix(
[['7.0', '10.0'],
['15.0', '22.0']])
Negative powers will calculate the inverse::
>>> A**-1
matrix(
[['-2.0', '1.0'],
['1.5', '-0.5']])
>>> A * A**-1
matrix(
[['1.0', '1.0842021724855e-19'],
['-2.16840434497101e-19', '1.0']])
Matrix transposition is straightforward::
>>> A = ones(2, 3)
>>> A
matrix(
[['1.0', '1.0', '1.0'],
['1.0', '1.0', '1.0']])
>>> A.T
matrix(
[['1.0', '1.0'],
['1.0', '1.0'],
['1.0', '1.0']])
Norms
.....
Sometimes you need to know how "large" a matrix or vector is. Due to their
multidimensional nature it's not possible to compare them, but there are
several functions to map a matrix or a vector to a positive real number, the
so called norms.
For vectors the p-norm is intended, usually the 1-, the 2- and the oo-norm are
used.
>>> x = matrix([-10, 2, 100])
>>> norm(x, 1)
mpf('112.0')
>>> norm(x, 2)
mpf('100.5186549850325')
>>> norm(x, inf)
mpf('100.0')
Please note that the 2-norm is the most used one, though it is more expensive
to calculate than the 1- or oo-norm.
It is possible to generalize some vector norms to matrix norm::
>>> A = matrix([[1, -1000], [100, 50]])
>>> mnorm(A, 1)
mpf('1050.0')
>>> mnorm(A, inf)
mpf('1001.0')
>>> mnorm(A, 'F')
mpf('1006.2310867787777')
The last norm (the "Frobenius-norm") is an approximation for the 2-norm, which
is hard to calculate and not available. The Frobenius-norm lacks some
mathematical properties you might expect from a norm.
"""
def __init__(self, *args, **kwargs):
self.__data = {}
# LU decompostion cache, this is useful when solving the same system
# multiple times, when calculating the inverse and when calculating the
# determinant
self._LU = None
convert = kwargs.get('force_type', self.ctx.convert)
if isinstance(args[0], (list, tuple)):
if isinstance(args[0][0], (list, tuple)):
# interpret nested list as matrix
A = args[0]
self.__rows = len(A)
self.__cols = len(A[0])
for i, row in enumerate(A):
for j, a in enumerate(row):
self[i, j] = convert(a)
else:
# interpret list as row vector
v = args[0]
self.__rows = len(v)
self.__cols = 1
for i, e in enumerate(v):
self[i, 0] = e
elif isinstance(args[0], int):
# create empty matrix of given dimensions
if len(args) == 1:
self.__rows = self.__cols = args[0]
else:
assert isinstance(args[1], int), 'expected int'
self.__rows = args[0]
self.__cols = args[1]
elif isinstance(args[0], _matrix):
A = args[0].copy()
self.__data = A._matrix__data
self.__rows = A._matrix__rows
self.__cols = A._matrix__cols
convert = kwargs.get('force_type', self.ctx.convert)
for i in xrange(A.__rows):
for j in xrange(A.__cols):
A[i,j] = convert(A[i,j])
elif hasattr(args[0], 'tolist'):
A = self.ctx.matrix(args[0].tolist())
self.__data = A._matrix__data
self.__rows = A._matrix__rows
self.__cols = A._matrix__cols
else:
raise TypeError('could not interpret given arguments')
def apply(self, f):
"""
Return a copy of self with the function `f` applied elementwise.
"""
new = self.ctx.matrix(self.__rows, self.__cols)
for i in xrange(self.__rows):
for j in xrange(self.__cols):
new[i,j] = f(self[i,j])
return new
def __nstr__(self, n=None, **kwargs):
# Build table of string representations of the elements
res = []
# Track per-column max lengths for pretty alignment
maxlen = [0] * self.cols
for i in range(self.rows):
res.append([])
for j in range(self.cols):
if n:
string = self.ctx.nstr(self[i,j], n, **kwargs)
else:
string = str(self[i,j])
res[-1].append(string)
maxlen[j] = max(len(string), maxlen[j])
# Patch strings together
for i, row in enumerate(res):
for j, elem in enumerate(row):
# Pad each element up to maxlen so the columns line up
row[j] = elem.rjust(maxlen[j])
res[i] = "[" + colsep.join(row) + "]"
return rowsep.join(res)
def __str__(self):
return self.__nstr__()
def _toliststr(self, avoid_type=False):
"""
Create a list string from a matrix.
If avoid_type: avoid multiple 'mpf's.
"""
# XXX: should be something like self.ctx._types
typ = self.ctx.mpf
s = '['
for i in xrange(self.__rows):
s += '['
for j in xrange(self.__cols):
if not avoid_type or not isinstance(self[i,j], typ):
a = repr(self[i,j])
else:
a = "'" + str(self[i,j]) + "'"
s += a + ', '
s = s[:-2]
s += '],\n '
s = s[:-3]
s += ']'
return s
def tolist(self):
"""
Convert the matrix to a nested list.
"""
return [[self[i,j] for j in range(self.__cols)] for i in range(self.__rows)]
def __repr__(self):
if self.ctx.pretty:
return self.__str__()
s = 'matrix(\n'
s += self._toliststr(avoid_type=True) + ')'
return s
def __getitem__(self, key):
if type(key) is int:
# only sufficent for vectors
if self.__rows == 1:
key = (0, key)
elif self.__cols == 1:
key = (key, 0)
else:
raise IndexError('insufficient indices for matrix')
if key in self.__data:
return self.__data[key]
else:
if key[0] >= self.__rows or key[1] >= self.__cols:
raise IndexError('matrix index out of range')
return self.ctx.zero
def __setitem__(self, key, value):
if type(key) is int:
# only sufficent for vectors
if self.__rows == 1:
key = (0, key)
elif self.__cols == 1:
key = (key, 0)
else:
raise IndexError('insufficient indices for matrix')
if key[0] >= self.__rows or key[1] >= self.__cols:
raise IndexError('matrix index out of range')
value = self.ctx.convert(value)
if value: # only store non-zeros
self.__data[key] = value
elif key in self.__data:
del self.__data[key]
# TODO: maybe do this better, if the performance impact is significant
if self._LU:
self._LU = None
def __iter__(self):
for i in xrange(self.__rows):
for j in xrange(self.__cols):
yield self[i,j]
def __mul__(self, other):
if isinstance(other, self.ctx.matrix):
# dot multiplication TODO: use Strassen's method?
if self.__cols != other.__rows:
raise ValueError('dimensions not compatible for multiplication')
new = self.ctx.matrix(self.__rows, other.__cols)
for i in xrange(self.__rows):
for j in xrange(other.__cols):
new[i, j] = self.ctx.fdot((self[i,k], other[k,j])
for k in xrange(other.__rows))
return new
else:
# try scalar multiplication
new = self.ctx.matrix(self.__rows, self.__cols)
for i in xrange(self.__rows):
for j in xrange(self.__cols):
new[i, j] = other * self[i, j]
return new
def __rmul__(self, other):
# assume other is scalar and thus commutative
assert not isinstance(other, self.ctx.matrix)
return self.__mul__(other)
def __pow__(self, other):
# avoid cyclic import problems
#from linalg import inverse
if not isinstance(other, int):
raise ValueError('only integer exponents are supported')
if not self.__rows == self.__cols:
raise ValueError('only powers of square matrices are defined')
n = other
if n == 0:
return self.ctx.eye(self.__rows)
if n < 0:
n = -n
neg = True
else:
neg = False
i = n
y = 1
z = self.copy()
while i != 0:
if i % 2 == 1:
y = y * z
z = z*z
i = i // 2
if neg:
y = self.ctx.inverse(y)
return y
def __div__(self, other):
# assume other is scalar and do element-wise divison
assert not isinstance(other, self.ctx.matrix)
new = self.ctx.matrix(self.__rows, self.__cols)
for i in xrange(self.__rows):
for j in xrange(self.__cols):
new[i,j] = self[i,j] / other
return new
__truediv__ = __div__
def __add__(self, other):
if isinstance(other, self.ctx.matrix):
if not (self.__rows == other.__rows and self.__cols == other.__cols):
raise ValueError('incompatible dimensions for addition')
new = self.ctx.matrix(self.__rows, self.__cols)
for i in xrange(self.__rows):
for j in xrange(self.__cols):
new[i,j] = self[i,j] + other[i,j]
return new
else:
# assume other is scalar and add element-wise
new = self.ctx.matrix(self.__rows, self.__cols)
for i in xrange(self.__rows):
for j in xrange(self.__cols):
new[i,j] += self[i,j] + other
return new
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
if isinstance(other, self.ctx.matrix) and not (self.__rows == other.__rows
and self.__cols == other.__cols):
raise ValueError('incompatible dimensions for substraction')
return self.__add__(other * (-1))
def __neg__(self):
return (-1) * self
def __rsub__(self, other):
return -self + other
def __eq__(self, other):
return self.__rows == other.__rows and self.__cols == other.__cols \
and self.__data == other.__data
def __len__(self):
if self.rows == 1:
return self.cols
elif self.cols == 1:
return self.rows
else:
return self.rows # do it like numpy
def __getrows(self):
return self.__rows
def __setrows(self, value):
for key in self.__data.copy().iterkeys():
if key[0] >= value:
del self.__data[key]
self.__rows = value
rows = property(__getrows, __setrows, doc='number of rows')
def __getcols(self):
return self.__cols
def __setcols(self, value):
for key in self.__data.copy().iterkeys():
if key[1] >= value:
del self.__data[key]
self.__cols = value
cols = property(__getcols, __setcols, doc='number of columns')
def transpose(self):
new = self.ctx.matrix(self.__cols, self.__rows)
for i in xrange(self.__rows):
for j in xrange(self.__cols):
new[j,i] = self[i,j]
return new
T = property(transpose)
def conjugate(self):
return self.apply(self.ctx.conj)
def transpose_conj(self):
return self.conjugate().transpose()
H = property(transpose_conj)
def copy(self):
new = self.ctx.matrix(self.__rows, self.__cols)
new.__data = self.__data.copy()
return new
__copy__ = copy
def column(self, n):
m = self.ctx.matrix(self.rows, 1)
for i in range(self.rows):
m[i] = self[i,n]
return m
class MatrixMethods(object):
def __init__(ctx):
# XXX: subclass
ctx.matrix = type('matrix', (_matrix,), {})
ctx.matrix.ctx = ctx
ctx.matrix.convert = ctx.convert
def eye(ctx, n, **kwargs):
"""
Create square identity matrix n x n.
"""
A = ctx.matrix(n, **kwargs)
for i in xrange(n):
A[i,i] = 1
return A
def diag(ctx, diagonal, **kwargs):
"""
Create square diagonal matrix using given list.
Example:
>>> from mpmath import diag, mp
>>> mp.pretty = False
>>> diag([1, 2, 3])
matrix(
[['1.0', '0.0', '0.0'],
['0.0', '2.0', '0.0'],
['0.0', '0.0', '3.0']])
"""
A = ctx.matrix(len(diagonal), **kwargs)
for i in xrange(len(diagonal)):
A[i,i] = diagonal[i]
return A
def zeros(ctx, *args, **kwargs):
"""
Create matrix m x n filled with zeros.
One given dimension will create square matrix n x n.
Example:
>>> from mpmath import zeros, mp
>>> mp.pretty = False
>>> zeros(2)
matrix(
[['0.0', '0.0'],
['0.0', '0.0']])
"""
if len(args) == 1:
m = n = args[0]
elif len(args) == 2:
m = args[0]
n = args[1]
else:
raise TypeError('zeros expected at most 2 arguments, got %i' % len(args))
A = ctx.matrix(m, n, **kwargs)
for i in xrange(m):
for j in xrange(n):
A[i,j] = 0
return A
def ones(ctx, *args, **kwargs):
"""
Create matrix m x n filled with ones.
One given dimension will create square matrix n x n.
Example:
>>> from mpmath import ones, mp
>>> mp.pretty = False
>>> ones(2)
matrix(
[['1.0', '1.0'],
['1.0', '1.0']])
"""
if len(args) == 1:
m = n = args[0]
elif len(args) == 2:
m = args[0]
n = args[1]
else:
raise TypeError('ones expected at most 2 arguments, got %i' % len(args))
A = ctx.matrix(m, n, **kwargs)
for i in xrange(m):
for j in xrange(n):
A[i,j] = 1
return A
def hilbert(ctx, m, n=None):
"""
Create (pseudo) hilbert matrix m x n.
One given dimension will create hilbert matrix n x n.
The matrix is very ill-conditioned and symmetric, positive definite if
square.
"""
if n is None:
n = m
A = ctx.matrix(m, n)
for i in xrange(m):
for j in xrange(n):
A[i,j] = ctx.one / (i + j + 1)
return A
def randmatrix(ctx, m, n=None, min=0, max=1, **kwargs):
"""
Create a random m x n matrix.
All values are >= min and <max.
n defaults to m.
Example:
>>> from mpmath import randmatrix
>>> randmatrix(2) # doctest:+SKIP
matrix(
[['0.53491598236191806', '0.57195669543302752'],
['0.85589992269513615', '0.82444367501382143']])
"""
if not n:
n = m
A = ctx.matrix(m, n, **kwargs)
for i in xrange(m):
for j in xrange(n):
A[i,j] = ctx.rand() * (max - min) + min
return A
def swap_row(ctx, A, i, j):
"""
Swap row i with row j.
"""
if i == j:
return
if isinstance(A, ctx.matrix):
for k in xrange(A.cols):
A[i,k], A[j,k] = A[j,k], A[i,k]
elif isinstance(A, list):
A[i], A[j] = A[j], A[i]
else:
raise TypeError('could not interpret type')
def extend(ctx, A, b):
"""
Extend matrix A with column b and return result.
"""
assert isinstance(A, ctx.matrix)
assert A.rows == len(b)
A = A.copy()
A.cols += 1
for i in xrange(A.rows):
A[i, A.cols-1] = b[i]
return A
def norm(ctx, x, p=2):
r"""
Gives the entrywise `p`-norm of an iterable *x*, i.e. the vector norm
`\left(\sum_k |x_k|^p\right)^{1/p}`, for any given `1 \le p \le \infty`.
Special cases:
If *x* is not iterable, this just returns ``absmax(x)``.
``p=1`` gives the sum of absolute values.
``p=2`` is the standard Euclidean vector norm.
``p=inf`` gives the magnitude of the largest element.
For *x* a matrix, ``p=2`` is the Frobenius norm.
For operator matrix norms, use :func:`mnorm` instead.
You can use the string 'inf' as well as float('inf') or mpf('inf')
to specify the infinity norm.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> x = matrix([-10, 2, 100])
>>> norm(x, 1)
mpf('112.0')
>>> norm(x, 2)
mpf('100.5186549850325')
>>> norm(x, inf)
mpf('100.0')
"""
try:
iter(x)
except TypeError:
return ctx.absmax(x)
if type(p) is not int:
p = ctx.convert(p)
if p == ctx.inf:
return max(ctx.absmax(i) for i in x)
elif p == 1:
return ctx.fsum(x, absolute=1)
elif p == 2:
return ctx.sqrt(ctx.fsum(x, absolute=1, squared=1))
elif p > 1:
return ctx.nthroot(ctx.fsum(abs(i)**p for i in x), p)
else:
raise ValueError('p has to be >= 1')
def mnorm(ctx, A, p=1):
r"""
Gives the matrix (operator) `p`-norm of A. Currently ``p=1`` and ``p=inf``
are supported:
``p=1`` gives the 1-norm (maximal column sum)
``p=inf`` gives the `\infty`-norm (maximal row sum).
You can use the string 'inf' as well as float('inf') or mpf('inf')
``p=2`` (not implemented) for a square matrix is the usual spectral
matrix norm, i.e. the largest singular value.
``p='f'`` (or 'F', 'fro', 'Frobenius, 'frobenius') gives the
Frobenius norm, which is the elementwise 2-norm. The Frobenius norm is an
approximation of the spectral norm and satisfies
.. math ::
\frac{1}{\sqrt{\mathrm{rank}(A)}} \|A\|_F \le \|A\|_2 \le \|A\|_F
The Frobenius norm lacks some mathematical properties that might
be expected of a norm.
For general elementwise `p`-norms, use :func:`norm` instead.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> A = matrix([[1, -1000], [100, 50]])
>>> mnorm(A, 1)
mpf('1050.0')
>>> mnorm(A, inf)
mpf('1001.0')
>>> mnorm(A, 'F')
mpf('1006.2310867787777')
"""
A = ctx.matrix(A)
if type(p) is not int:
if type(p) is str and 'frobenius'.startswith(p.lower()):
return ctx.norm(A, 2)
p = ctx.convert(p)
m, n = A.rows, A.cols
if p == 1:
return max(ctx.fsum((A[i,j] for i in xrange(m)), absolute=1) for j in xrange(n))
elif p == ctx.inf:
return max(ctx.fsum((A[i,j] for j in xrange(n)), absolute=1) for i in xrange(m))
else:
raise NotImplementedError("matrix p-norm for arbitrary p")
if __name__ == '__main__':
import doctest
doctest.testmod()

View File

@ -1,106 +0,0 @@
# TODO: use gmpy.mpq when available?
class mpq(tuple):
"""
Rational number type, only intended for internal use.
"""
"""
def _mpmath_(self, prec, rounding):
# XXX
return mp.make_mpf(from_rational(self[0], self[1], prec, rounding))
#(mpf(self[0])/self[1])._mpf_
"""
def __int__(self):
a, b = self
return a // b
def __abs__(self):
a, b = self
return mpq((abs(a), b))
def __neg__(self):
a, b = self
return mpq((-a, b))
def __nonzero__(self):
return bool(self[0])
def __cmp__(self, other):
if type(other) is int and self[1] == 1:
return cmp(self[0], other)
return NotImplemented
def __add__(self, other):
if isinstance(other, mpq):
a, b = self
c, d = other
return mpq((a*d+b*c, b*d))
if isinstance(other, (int, long)):
a, b = self
return mpq((a+b*other, b))
return NotImplemented
__radd__ = __add__
def __sub__(self, other):
if isinstance(other, mpq):
a, b = self
c, d = other
return mpq((a*d-b*c, b*d))
if isinstance(other, (int, long)):
a, b = self
return mpq((a-b*other, b))
return NotImplemented
def __rsub__(self, other):
if isinstance(other, mpq):
a, b = self
c, d = other
return mpq((b*c-a*d, b*d))
if isinstance(other, (int, long)):
a, b = self
return mpq((b*other-a, b))
return NotImplemented
def __mul__(self, other):
if isinstance(other, mpq):
a, b = self
c, d = other
return mpq((a*c, b*d))
if isinstance(other, (int, long)):
a, b = self
return mpq((a*other, b))
return NotImplemented
def __div__(self, other):
if isinstance(other, (int, long)):
if other:
a, b = self
return mpq((a, b*other))
raise ZeroDivisionError
return NotImplemented
def __pow__(self, other):
if type(other) is int:
a, b = self
return mpq((a**other, b**other))
return NotImplemented
__rmul__ = __mul__
mpq_1 = mpq((1,1))
mpq_0 = mpq((0,1))
mpq_1_2 = mpq((1,2))
mpq_3_2 = mpq((3,2))
mpq_1_4 = mpq((1,4))
mpq_1_16 = mpq((1,16))
mpq_3_16 = mpq((3,16))
mpq_5_2 = mpq((5,2))
mpq_3_4 = mpq((3,4))
mpq_7_4 = mpq((7,4))
mpq_5_4 = mpq((5,4))

View File

@ -1,159 +0,0 @@
#!/usr/bin/env python
"""
python runtests.py -py
Use py.test to run tests (more useful for debugging)
python runtests.py -psyco
Enable psyco to make tests run about 50% faster
python runtests.py -coverage
Generate test coverage report. Statistics are written to /tmp
python runtests.py -profile
Generate profile stats (this is much slower)
python runtests.py -nogmpy
Run tests without using GMPY even if it exists
python runtests.py -strict
Enforce extra tests in normalize()
python runtests.py -local
Insert '../..' at the beginning of sys.path to use local mpmath
Additional arguments are used to filter the tests to run. Only files that have
one of the arguments in their name are executed.
"""
import sys, os, traceback
if "-psyco" in sys.argv:
sys.argv.remove('-psyco')
import psyco
psyco.full()
profile = False
if "-profile" in sys.argv:
sys.argv.remove('-profile')
profile = True
coverage = False
if "-coverage" in sys.argv:
sys.argv.remove('-coverage')
coverage = True
if "-nogmpy" in sys.argv:
sys.argv.remove('-nogmpy')
os.environ['MPMATH_NOGMPY'] = 'Y'
if "-strict" in sys.argv:
sys.argv.remove('-strict')
os.environ['MPMATH_STRICT'] = 'Y'
if "-local" in sys.argv:
sys.argv.remove('-local')
importdir = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]),
'../..'))
else:
importdir = ''
# TODO: add a flag for this
testdir = ''
def testit(importdir='', testdir=''):
"""Run all tests in testdir while importing from importdir."""
if importdir:
sys.path.insert(1, importdir)
if testdir:
sys.path.insert(1, testdir)
import os.path
import mpmath
print "mpmath imported from", os.path.dirname(mpmath.__file__)
print "mpmath backend:", mpmath.libmp.backend.BACKEND
print "mpmath mp class:", repr(mpmath.mp)
print "mpmath version:", mpmath.__version__
print "Python version:", sys.version
print
if "-py" in sys.argv:
sys.argv.remove('-py')
import py
py.test.cmdline.main()
else:
import glob
from timeit import default_timer as clock
modules = []
args = sys.argv[1:]
# search for tests in directory of this file if not otherwise specified
if not testdir:
pattern = os.path.dirname(sys.argv[0])
else:
pattern = testdir
if pattern:
pattern += '/'
pattern += 'test*.py'
# look for tests (respecting specified filter)
for f in glob.glob(pattern):
name = os.path.splitext(os.path.basename(f))[0]
# If run as a script, only run tests given as args, if any are given
if args and __name__ == "__main__":
ok = False
for arg in args:
if arg in name:
ok = True
break
if not ok:
continue
module = __import__(name)
priority = module.__dict__.get('priority', 100)
if priority == 666:
modules = [[priority, name, module]]
break
modules.append([priority, name, module])
# execute tests
modules.sort()
tstart = clock()
for priority, name, module in modules:
print name
for f in sorted(module.__dict__.keys()):
if f.startswith('test_'):
if coverage and ('numpy' in f):
continue
print " ", f[5:].ljust(25),
t1 = clock()
try:
module.__dict__[f]()
except:
etype, evalue, trb = sys.exc_info()
if etype in (KeyboardInterrupt, SystemExit):
raise
print
print "TEST FAILED!"
print
traceback.print_exc()
t2 = clock()
print "ok", " ", ("%.7f" % (t2-t1)), "s"
tend = clock()
print
print "finished tests in", ("%.2f" % (tend-tstart)), "seconds"
# clean sys.path
if importdir:
sys.path.remove(importdir)
if testdir:
sys.path.remove(testdir)
if __name__ == '__main__':
if profile:
import cProfile
cProfile.run("testit('%s', '%s')" % (importdir, testdir), sort=1)
elif coverage:
import trace
tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix],
trace=0, count=1)
tracer.run('testit(importdir, testdir)')
r = tracer.results()
r.write_results(show_missing=True, summary=True, coverdir="/tmp")
else:
testit(importdir, testdir)

View File

@ -1,161 +0,0 @@
from mpmath import *
def test_type_compare():
assert mpf(2) == mpc(2,0)
assert mpf(0) == mpc(0)
assert mpf(2) != mpc(2, 0.00001)
assert mpf(2) == 2.0
assert mpf(2) != 3.0
assert mpf(2) == 2
assert mpf(2) != '2.0'
assert mpc(2) != '2.0'
def test_add():
assert mpf(2.5) + mpf(3) == 5.5
assert mpf(2.5) + 3 == 5.5
assert mpf(2.5) + 3.0 == 5.5
assert 3 + mpf(2.5) == 5.5
assert 3.0 + mpf(2.5) == 5.5
assert (3+0j) + mpf(2.5) == 5.5
assert mpc(2.5) + mpf(3) == 5.5
assert mpc(2.5) + 3 == 5.5
assert mpc(2.5) + 3.0 == 5.5
assert mpc(2.5) + (3+0j) == 5.5
assert 3 + mpc(2.5) == 5.5
assert 3.0 + mpc(2.5) == 5.5
assert (3+0j) + mpc(2.5) == 5.5
def test_sub():
assert mpf(2.5) - mpf(3) == -0.5
assert mpf(2.5) - 3 == -0.5
assert mpf(2.5) - 3.0 == -0.5
assert 3 - mpf(2.5) == 0.5
assert 3.0 - mpf(2.5) == 0.5
assert (3+0j) - mpf(2.5) == 0.5
assert mpc(2.5) - mpf(3) == -0.5
assert mpc(2.5) - 3 == -0.5
assert mpc(2.5) - 3.0 == -0.5
assert mpc(2.5) - (3+0j) == -0.5
assert 3 - mpc(2.5) == 0.5
assert 3.0 - mpc(2.5) == 0.5
assert (3+0j) - mpc(2.5) == 0.5
def test_mul():
assert mpf(2.5) * mpf(3) == 7.5
assert mpf(2.5) * 3 == 7.5
assert mpf(2.5) * 3.0 == 7.5
assert 3 * mpf(2.5) == 7.5
assert 3.0 * mpf(2.5) == 7.5
assert (3+0j) * mpf(2.5) == 7.5
assert mpc(2.5) * mpf(3) == 7.5
assert mpc(2.5) * 3 == 7.5
assert mpc(2.5) * 3.0 == 7.5
assert mpc(2.5) * (3+0j) == 7.5
assert 3 * mpc(2.5) == 7.5
assert 3.0 * mpc(2.5) == 7.5
assert (3+0j) * mpc(2.5) == 7.5
def test_div():
assert mpf(6) / mpf(3) == 2.0
assert mpf(6) / 3 == 2.0
assert mpf(6) / 3.0 == 2.0
assert 6 / mpf(3) == 2.0
assert 6.0 / mpf(3) == 2.0
assert (6+0j) / mpf(3.0) == 2.0
assert mpc(6) / mpf(3) == 2.0
assert mpc(6) / 3 == 2.0
assert mpc(6) / 3.0 == 2.0
assert mpc(6) / (3+0j) == 2.0
assert 6 / mpc(3) == 2.0
assert 6.0 / mpc(3) == 2.0
assert (6+0j) / mpc(3) == 2.0
def test_pow():
assert mpf(6) ** mpf(3) == 216.0
assert mpf(6) ** 3 == 216.0
assert mpf(6) ** 3.0 == 216.0
assert 6 ** mpf(3) == 216.0
assert 6.0 ** mpf(3) == 216.0
assert (6+0j) ** mpf(3.0) == 216.0
assert mpc(6) ** mpf(3) == 216.0
assert mpc(6) ** 3 == 216.0
assert mpc(6) ** 3.0 == 216.0
assert mpc(6) ** (3+0j) == 216.0
assert 6 ** mpc(3) == 216.0
assert 6.0 ** mpc(3) == 216.0
assert (6+0j) ** mpc(3) == 216.0
def test_mixed_misc():
assert 1 + mpf(3) == mpf(3) + 1 == 4
assert 1 - mpf(3) == -(mpf(3) - 1) == -2
assert 3 * mpf(2) == mpf(2) * 3 == 6
assert 6 / mpf(2) == mpf(6) / 2 == 3
assert 1.0 + mpf(3) == mpf(3) + 1.0 == 4
assert 1.0 - mpf(3) == -(mpf(3) - 1.0) == -2
assert 3.0 * mpf(2) == mpf(2) * 3.0 == 6
assert 6.0 / mpf(2) == mpf(6) / 2.0 == 3
def test_add_misc():
mp.dps = 15
assert mpf(4) + mpf(-70) == -66
assert mpf(1) + mpf(1.1)/80 == 1 + 1.1/80
assert mpf((1, 10000000000)) + mpf(3) == mpf((1, 10000000000))
assert mpf(3) + mpf((1, 10000000000)) == mpf((1, 10000000000))
assert mpf((1, -10000000000)) + mpf(3) == mpf(3)
assert mpf(3) + mpf((1, -10000000000)) == mpf(3)
assert mpf(1) + 1e-15 != 1
assert mpf(1) + 1e-20 == 1
assert mpf(1.07e-22) + 0 == mpf(1.07e-22)
assert mpf(0) + mpf(1.07e-22) == mpf(1.07e-22)
def test_complex_misc():
# many more tests needed
assert 1 + mpc(2) == 3
assert not mpc(2).ae(2 + 1e-13)
assert mpc(2+1e-15j).ae(2)
def test_complex_zeros():
for a in [0,2]:
for b in [0,3]:
for c in [0,4]:
for d in [0,5]:
assert mpc(a,b)*mpc(c,d) == complex(a,b)*complex(c,d)
def test_hash():
for i in range(-256, 256):
assert hash(mpf(i)) == hash(i)
assert hash(mpf(0.5)) == hash(0.5)
assert hash(mpc(2,3)) == hash(2+3j)
# Check that this doesn't fail
assert hash(inf)
# Check that overflow doesn't assign equal hashes to large numbers
assert hash(mpf('1e1000')) != hash('1e10000')
assert hash(mpc(100,'1e1000')) != hash(mpc(200,'1e1000'))
def test_arithmetic_functions():
import operator
ops = [(operator.add, fadd), (operator.sub, fsub), (operator.mul, fmul),
(operator.div, fdiv)]
a = mpf(0.27)
b = mpf(1.13)
c = mpc(0.51+2.16j)
d = mpc(1.08-0.99j)
for x in [a,b,c,d]:
for y in [a,b,c,d]:
for op, fop in ops:
if fop is not fdiv:
mp.prec = 200
z0 = op(x,y)
mp.prec = 60
z1 = op(x,y)
mp.prec = 53
z2 = op(x,y)
assert fop(x, y, prec=60) == z1
assert fop(x, y) == z2
if fop is not fdiv:
assert fop(x, y, prec=inf) == z0
assert fop(x, y, dps=inf) == z0
assert fop(x, y, exact=True) == z0
assert fneg(fneg(z1, exact=True), prec=inf) == z1
assert fneg(z1) == -(+z1)
mp.dps = 15

View File

@ -1,172 +0,0 @@
"""
Test bit-level integer and mpf operations
"""
from mpmath import *
from mpmath.libmp import *
def test_bitcount():
assert bitcount(0) == 0
assert bitcount(1) == 1
assert bitcount(7) == 3
assert bitcount(8) == 4
assert bitcount(2**100) == 101
assert bitcount(2**100-1) == 100
def test_trailing():
assert trailing(0) == 0
assert trailing(1) == 0
assert trailing(2) == 1
assert trailing(7) == 0
assert trailing(8) == 3
assert trailing(2**100) == 100
assert trailing(2**100-1) == 0
def test_round_down():
assert from_man_exp(0, -4, 4, round_down)[:3] == (0, 0, 0)
assert from_man_exp(0xf0, -4, 4, round_down)[:3] == (0, 15, 0)
assert from_man_exp(0xf1, -4, 4, round_down)[:3] == (0, 15, 0)
assert from_man_exp(0xff, -4, 4, round_down)[:3] == (0, 15, 0)
assert from_man_exp(-0xf0, -4, 4, round_down)[:3] == (1, 15, 0)
assert from_man_exp(-0xf1, -4, 4, round_down)[:3] == (1, 15, 0)
assert from_man_exp(-0xff, -4, 4, round_down)[:3] == (1, 15, 0)
def test_round_up():
assert from_man_exp(0, -4, 4, round_up)[:3] == (0, 0, 0)
assert from_man_exp(0xf0, -4, 4, round_up)[:3] == (0, 15, 0)
assert from_man_exp(0xf1, -4, 4, round_up)[:3] == (0, 1, 4)
assert from_man_exp(0xff, -4, 4, round_up)[:3] == (0, 1, 4)
assert from_man_exp(-0xf0, -4, 4, round_up)[:3] == (1, 15, 0)
assert from_man_exp(-0xf1, -4, 4, round_up)[:3] == (1, 1, 4)
assert from_man_exp(-0xff, -4, 4, round_up)[:3] == (1, 1, 4)
def test_round_floor():
assert from_man_exp(0, -4, 4, round_floor)[:3] == (0, 0, 0)
assert from_man_exp(0xf0, -4, 4, round_floor)[:3] == (0, 15, 0)
assert from_man_exp(0xf1, -4, 4, round_floor)[:3] == (0, 15, 0)
assert from_man_exp(0xff, -4, 4, round_floor)[:3] == (0, 15, 0)
assert from_man_exp(-0xf0, -4, 4, round_floor)[:3] == (1, 15, 0)
assert from_man_exp(-0xf1, -4, 4, round_floor)[:3] == (1, 1, 4)
assert from_man_exp(-0xff, -4, 4, round_floor)[:3] == (1, 1, 4)
def test_round_ceiling():
assert from_man_exp(0, -4, 4, round_ceiling)[:3] == (0, 0, 0)
assert from_man_exp(0xf0, -4, 4, round_ceiling)[:3] == (0, 15, 0)
assert from_man_exp(0xf1, -4, 4, round_ceiling)[:3] == (0, 1, 4)
assert from_man_exp(0xff, -4, 4, round_ceiling)[:3] == (0, 1, 4)
assert from_man_exp(-0xf0, -4, 4, round_ceiling)[:3] == (1, 15, 0)
assert from_man_exp(-0xf1, -4, 4, round_ceiling)[:3] == (1, 15, 0)
assert from_man_exp(-0xff, -4, 4, round_ceiling)[:3] == (1, 15, 0)
def test_round_nearest():
assert from_man_exp(0, -4, 4, round_nearest)[:3] == (0, 0, 0)
assert from_man_exp(0xf0, -4, 4, round_nearest)[:3] == (0, 15, 0)
assert from_man_exp(0xf7, -4, 4, round_nearest)[:3] == (0, 15, 0)
assert from_man_exp(0xf8, -4, 4, round_nearest)[:3] == (0, 1, 4) # 1111.1000 -> 10000.0
assert from_man_exp(0xf9, -4, 4, round_nearest)[:3] == (0, 1, 4) # 1111.1001 -> 10000.0
assert from_man_exp(0xe8, -4, 4, round_nearest)[:3] == (0, 7, 1) # 1110.1000 -> 1110.0
assert from_man_exp(0xe9, -4, 4, round_nearest)[:3] == (0, 15, 0) # 1110.1001 -> 1111.0
assert from_man_exp(-0xf0, -4, 4, round_nearest)[:3] == (1, 15, 0)
assert from_man_exp(-0xf7, -4, 4, round_nearest)[:3] == (1, 15, 0)
assert from_man_exp(-0xf8, -4, 4, round_nearest)[:3] == (1, 1, 4)
assert from_man_exp(-0xf9, -4, 4, round_nearest)[:3] == (1, 1, 4)
assert from_man_exp(-0xe8, -4, 4, round_nearest)[:3] == (1, 7, 1)
assert from_man_exp(-0xe9, -4, 4, round_nearest)[:3] == (1, 15, 0)
def test_rounding_bugs():
# 1 less than power-of-two cases
assert from_man_exp(72057594037927935, -56, 53, round_up) == (0, 1, 0, 1)
assert from_man_exp(73786976294838205979l, -65, 53, round_nearest) == (0, 1, 1, 1)
assert from_man_exp(31, 0, 4, round_up) == (0, 1, 5, 1)
assert from_man_exp(-31, 0, 4, round_floor) == (1, 1, 5, 1)
assert from_man_exp(255, 0, 7, round_up) == (0, 1, 8, 1)
assert from_man_exp(-255, 0, 7, round_floor) == (1, 1, 8, 1)
def test_rounding_issue160():
a = from_man_exp(9867,-100)
b = from_man_exp(9867,-200)
c = from_man_exp(-1,0)
z = (1, 1023, -10, 10)
assert mpf_add(a, c, 10, 'd') == z
assert mpf_add(b, c, 10, 'd') == z
assert mpf_add(c, a, 10, 'd') == z
assert mpf_add(c, b, 10, 'd') == z
def test_perturb():
a = fone
b = from_float(0.99999999999999989)
c = from_float(1.0000000000000002)
assert mpf_perturb(a, 0, 53, round_nearest) == a
assert mpf_perturb(a, 1, 53, round_nearest) == a
assert mpf_perturb(a, 0, 53, round_up) == c
assert mpf_perturb(a, 0, 53, round_ceiling) == c
assert mpf_perturb(a, 0, 53, round_down) == a
assert mpf_perturb(a, 0, 53, round_floor) == a
assert mpf_perturb(a, 1, 53, round_up) == a
assert mpf_perturb(a, 1, 53, round_ceiling) == a
assert mpf_perturb(a, 1, 53, round_down) == b
assert mpf_perturb(a, 1, 53, round_floor) == b
a = mpf_neg(a)
b = mpf_neg(b)
c = mpf_neg(c)
assert mpf_perturb(a, 0, 53, round_nearest) == a
assert mpf_perturb(a, 1, 53, round_nearest) == a
assert mpf_perturb(a, 0, 53, round_up) == a
assert mpf_perturb(a, 0, 53, round_floor) == a
assert mpf_perturb(a, 0, 53, round_down) == b
assert mpf_perturb(a, 0, 53, round_ceiling) == b
assert mpf_perturb(a, 1, 53, round_up) == c
assert mpf_perturb(a, 1, 53, round_floor) == c
assert mpf_perturb(a, 1, 53, round_down) == a
assert mpf_perturb(a, 1, 53, round_ceiling) == a
def test_add_exact():
ff = from_float
assert mpf_add(ff(3.0), ff(2.5)) == ff(5.5)
assert mpf_add(ff(3.0), ff(-2.5)) == ff(0.5)
assert mpf_add(ff(-3.0), ff(2.5)) == ff(-0.5)
assert mpf_add(ff(-3.0), ff(-2.5)) == ff(-5.5)
assert mpf_sub(mpf_add(fone, ff(1e-100)), fone) == ff(1e-100)
assert mpf_sub(mpf_add(ff(1e-100), fone), fone) == ff(1e-100)
assert mpf_sub(mpf_add(fone, ff(-1e-100)), fone) == ff(-1e-100)
assert mpf_sub(mpf_add(ff(-1e-100), fone), fone) == ff(-1e-100)
assert mpf_add(fone, fzero) == fone
assert mpf_add(fzero, fone) == fone
assert mpf_add(fzero, fzero) == fzero
def test_long_exponent_shifts():
mp.dps = 15
# Check for possible bugs due to exponent arithmetic overflow
# in a C implementation
x = mpf(1)
for p in [32, 64]:
a = ldexp(1,2**(p-1))
b = ldexp(1,2**p)
c = ldexp(1,2**(p+1))
d = ldexp(1,-2**(p-1))
e = ldexp(1,-2**p)
f = ldexp(1,-2**(p+1))
assert (x+a) == a
assert (x+b) == b
assert (x+c) == c
assert (x+d) == x
assert (x+e) == x
assert (x+f) == x
assert (a+x) == a
assert (b+x) == b
assert (c+x) == c
assert (d+x) == x
assert (e+x) == x
assert (f+x) == x
assert (x-a) == -a
assert (x-b) == -b
assert (x-c) == -c
assert (x-d) == x
assert (x-e) == x
assert (x-f) == x
assert (a-x) == a
assert (b-x) == b
assert (c-x) == c
assert (d-x) == -x
assert (e-x) == -x
assert (f-x) == -x

View File

@ -1,69 +0,0 @@
from mpmath import *
def test_approximation():
mp.dps = 15
f = lambda x: cos(2-2*x)/x
p, err = chebyfit(f, [2, 4], 8, error=True)
assert err < 1e-5
for i in range(10):
x = 2 + i/5.
assert abs(polyval(p, x) - f(x)) < err
def test_limits():
mp.dps = 15
assert limit(lambda x: (x-sin(x))/x**3, 0).ae(mpf(1)/6)
assert limit(lambda n: (1+1/n)**n, inf).ae(e)
def test_polyval():
assert polyval([], 3) == 0
assert polyval([0], 3) == 0
assert polyval([5], 3) == 5
# 4x^3 - 2x + 5
p = [4, 0, -2, 5]
assert polyval(p,4) == 253
assert polyval(p,4,derivative=True) == (253, 190)
def test_polyroots():
p = polyroots([1,-4])
assert p[0].ae(4)
p, q = polyroots([1,2,3])
assert p.ae(-1 - sqrt(2)*j)
assert q.ae(-1 + sqrt(2)*j)
#this is not a real test, it only tests a specific case
assert polyroots([1]) == []
try:
polyroots([0])
assert False
except ValueError:
pass
def test_pade():
one = mpf(1)
mp.dps = 20
N = 10
a = [one]
k = 1
for i in range(1, N+1):
k *= i
a.append(one/k)
p, q = pade(a, N//2, N//2)
for x in arange(0, 1, 0.1):
r = polyval(p[::-1], x)/polyval(q[::-1], x)
assert(r.ae(exp(x), 1.0e-10))
mp.dps = 15
def test_fourier():
mp.dps = 15
c, s = fourier(lambda x: x+1, [-1, 2], 2)
#plot([lambda x: x+1, lambda x: fourierval((c, s), [-1, 2], x)], [-1, 2])
assert c[0].ae(1.5)
assert c[1].ae(-3*sqrt(3)/(2*pi))
assert c[2].ae(3*sqrt(3)/(4*pi))
assert s[0] == 0
assert s[1].ae(3/(2*pi))
assert s[2].ae(3/(4*pi))
assert fourierval((c, s), [-1, 2], 1).ae(1.9134966715663442)
def test_differint():
mp.dps = 15
assert differint(lambda t: t, 2, -0.5).ae(8*sqrt(2/pi)/3)

View File

@ -1,77 +0,0 @@
from mpmath import *
from random import seed, randint, random
import math
# Test compatibility with Python floats, which are
# IEEE doubles (53-bit)
N = 5000
seed(1)
# Choosing exponents between roughly -140, 140 ensures that
# the Python floats don't overflow or underflow
xs = [(random()-1) * 10**randint(-140, 140) for x in range(N)]
ys = [(random()-1) * 10**randint(-140, 140) for x in range(N)]
# include some equal values
ys[int(N*0.8):] = xs[int(N*0.8):]
# Detect whether Python is compiled to use 80-bit floating-point
# instructions, in which case the double compatibility test breaks
uses_x87 = -4.1974624032366689e+117 / -8.4657370748010221e-47 \
== 4.9581771393902231e+163
def test_double_compatibility():
mp.prec = 53
for x, y in zip(xs, ys):
mpx = mpf(x)
mpy = mpf(y)
assert mpf(x) == x
assert (mpx < mpy) == (x < y)
assert (mpx > mpy) == (x > y)
assert (mpx == mpy) == (x == y)
assert (mpx != mpy) == (x != y)
assert (mpx <= mpy) == (x <= y)
assert (mpx >= mpy) == (x >= y)
assert mpx == mpx
if uses_x87:
mp.prec = 64
a = mpx + mpy
b = mpx * mpy
c = mpx / mpy
d = mpx % mpy
mp.prec = 53
assert +a == x + y
assert +b == x * y
assert +c == x / y
assert +d == x % y
else:
assert mpx + mpy == x + y
assert mpx * mpy == x * y
assert mpx / mpy == x / y
assert mpx % mpy == x % y
assert abs(mpx) == abs(x)
assert mpf(repr(x)) == x
assert ceil(mpx) == math.ceil(x)
assert floor(mpx) == math.floor(x)
def test_sqrt():
# this fails quite often. it appers to be float
# that rounds the wrong way, not mpf
fail = 0
mp.prec = 53
for x in xs:
x = abs(x)
mp.prec = 100
mp_high = mpf(x)**0.5
mp.prec = 53
mp_low = mpf(x)**0.5
fp = x**0.5
assert abs(mp_low-mp_high) <= abs(fp-mp_high)
fail += mp_low != fp
assert fail < N/10
def test_bugs():
# particular bugs
assert mpf(4.4408920985006262E-16) < mpf(1.7763568394002505E-15)
assert mpf(-4.4408920985006262E-16) > mpf(-1.7763568394002505E-15)

View File

@ -1,186 +0,0 @@
import random
from mpmath import *
from mpmath.libmp import *
def test_basic_string():
"""
Test basic string conversion
"""
mp.dps = 15
assert mpf('3') == mpf('3.0') == mpf('0003.') == mpf('0.03e2') == mpf(3.0)
assert mpf('30') == mpf('30.0') == mpf('00030.') == mpf(30.0)
for i in range(10):
for j in range(10):
assert mpf('%ie%i' % (i,j)) == i * 10**j
assert str(mpf('25000.0')) == '25000.0'
assert str(mpf('2500.0')) == '2500.0'
assert str(mpf('250.0')) == '250.0'
assert str(mpf('25.0')) == '25.0'
assert str(mpf('2.5')) == '2.5'
assert str(mpf('0.25')) == '0.25'
assert str(mpf('0.025')) == '0.025'
assert str(mpf('0.0025')) == '0.0025'
assert str(mpf('0.00025')) == '0.00025'
assert str(mpf('0.000025')) == '2.5e-5'
assert str(mpf(0)) == '0.0'
assert str(mpf('2.5e1000000000000000000000')) == '2.5e+1000000000000000000000'
assert str(mpf('2.6e-1000000000000000000000')) == '2.6e-1000000000000000000000'
assert str(mpf(1.23402834e-15)) == '1.23402834e-15'
assert str(mpf(-1.23402834e-15)) == '-1.23402834e-15'
assert str(mpf(-1.2344e-15)) == '-1.2344e-15'
assert repr(mpf(-1.2344e-15)) == "mpf('-1.2343999999999999e-15')"
def test_pretty():
mp.pretty = True
assert repr(mpf(2.5)) == '2.5'
assert repr(mpc(2.5,3.5)) == '(2.5 + 3.5j)'
assert repr(mpi(2.5,3.5)) == '[2.5, 3.5]'
mp.pretty = False
def test_str_whitespace():
assert mpf('1.26 ') == 1.26
def test_unicode():
mp.dps = 15
assert mpf(u'2.76') == 2.76
assert mpf(u'inf') == inf
def test_str_format():
assert to_str(from_float(0.1),15,strip_zeros=False) == '0.100000000000000'
assert to_str(from_float(0.0),15,show_zero_exponent=True) == '0.0e+0'
assert to_str(from_float(0.0),0,show_zero_exponent=True) == '.0e+0'
assert to_str(from_float(0.0),0,show_zero_exponent=False) == '.0'
assert to_str(from_float(0.0),1,show_zero_exponent=True) == '0.0e+0'
assert to_str(from_float(0.0),1,show_zero_exponent=False) == '0.0'
assert to_str(from_float(1.23),3,show_zero_exponent=True) == '1.23e+0'
assert to_str(from_float(1.23456789000000e-2),15,strip_zeros=False,min_fixed=0,max_fixed=0) == '1.23456789000000e-2'
assert to_str(from_float(1.23456789000000e+2),15,strip_zeros=False,min_fixed=0,max_fixed=0) == '1.23456789000000e+2'
assert to_str(from_float(2.1287e14), 15, max_fixed=1000) == '212870000000000.0'
assert to_str(from_float(2.1287e15), 15, max_fixed=1000) == '2128700000000000.0'
assert to_str(from_float(2.1287e16), 15, max_fixed=1000) == '21287000000000000.0'
assert to_str(from_float(2.1287e30), 15, max_fixed=1000) == '2128700000000000000000000000000.0'
def test_tight_string_conversion():
mp.dps = 15
# In an old version, '0.5' wasn't recognized as representing
# an exact binary number and was erroneously rounded up or down
assert from_str('0.5', 10, round_floor) == fhalf
assert from_str('0.5', 10, round_ceiling) == fhalf
def test_eval_repr_invariant():
"""Test that eval(repr(x)) == x"""
random.seed(123)
for dps in [10, 15, 20, 50, 100]:
mp.dps = dps
for i in xrange(1000):
a = mpf(random.random())**0.5 * 10**random.randint(-100, 100)
assert eval(repr(a)) == a
mp.dps = 15
def test_str_bugs():
mp.dps = 15
# Decimal rounding used to give the wrong exponent in some cases
assert str(mpf('1e600')) == '1.0e+600'
assert str(mpf('1e10000')) == '1.0e+10000'
def test_str_prec0():
assert to_str(from_float(1.234), 0) == '.0e+0'
assert to_str(from_float(1e-15), 0) == '.0e-15'
assert to_str(from_float(1e+15), 0) == '.0e+15'
assert to_str(from_float(-1e-15), 0) == '-.0e-15'
assert to_str(from_float(-1e+15), 0) == '-.0e+15'
def test_convert_rational():
mp.dps = 15
assert from_rational(30, 5, 53, round_nearest) == (0, 3, 1, 2)
assert from_rational(-7, 4, 53, round_nearest) == (1, 7, -2, 3)
assert to_rational((0, 1, -1, 1)) == (1, 2)
def test_custom_class():
class mympf:
@property
def _mpf_(self):
return mpf(3.5)._mpf_
class mympc:
@property
def _mpc_(self):
return mpf(3.5)._mpf_, mpf(2.5)._mpf_
assert mpf(2) + mympf() == 5.5
assert mympf() + mpf(2) == 5.5
assert mpf(mympf()) == 3.5
assert mympc() + mpc(2) == mpc(5.5, 2.5)
assert mpc(2) + mympc() == mpc(5.5, 2.5)
assert mpc(mympc()) == (3.5+2.5j)
def test_conversion_methods():
class SomethingRandom:
pass
class SomethingReal:
def _mpmath_(self, prec, rounding):
return mp.make_mpf(from_str('1.3', prec, rounding))
class SomethingComplex:
def _mpmath_(self, prec, rounding):
return mp.make_mpc((from_str('1.3', prec, rounding), \
from_str('1.7', prec, rounding)))
x = mpf(3)
z = mpc(3)
a = SomethingRandom()
y = SomethingReal()
w = SomethingComplex()
for d in [15, 45]:
mp.dps = d
assert (x+y).ae(mpf('4.3'))
assert (y+x).ae(mpf('4.3'))
assert (x+w).ae(mpc('4.3', '1.7'))
assert (w+x).ae(mpc('4.3', '1.7'))
assert (z+y).ae(mpc('4.3'))
assert (y+z).ae(mpc('4.3'))
assert (z+w).ae(mpc('4.3', '1.7'))
assert (w+z).ae(mpc('4.3', '1.7'))
x-y; y-x; x-w; w-x; z-y; y-z; z-w; w-z
x*y; y*x; x*w; w*x; z*y; y*z; z*w; w*z
x/y; y/x; x/w; w/x; z/y; y/z; z/w; w/z
x**y; y**x; x**w; w**x; z**y; y**z; z**w; w**z
x==y; y==x; x==w; w==x; z==y; y==z; z==w; w==z
mp.dps = 15
assert x.__add__(a) is NotImplemented
assert x.__radd__(a) is NotImplemented
assert x.__lt__(a) is NotImplemented
assert x.__gt__(a) is NotImplemented
assert x.__le__(a) is NotImplemented
assert x.__ge__(a) is NotImplemented
assert x.__eq__(a) is NotImplemented
assert x.__ne__(a) is NotImplemented
# implementation detail
if hasattr(x, "__cmp__"):
assert x.__cmp__(a) is NotImplemented
assert x.__sub__(a) is NotImplemented
assert x.__rsub__(a) is NotImplemented
assert x.__mul__(a) is NotImplemented
assert x.__rmul__(a) is NotImplemented
assert x.__div__(a) is NotImplemented
assert x.__rdiv__(a) is NotImplemented
assert x.__mod__(a) is NotImplemented
assert x.__rmod__(a) is NotImplemented
assert x.__pow__(a) is NotImplemented
assert x.__rpow__(a) is NotImplemented
assert z.__add__(a) is NotImplemented
assert z.__radd__(a) is NotImplemented
assert z.__eq__(a) is NotImplemented
assert z.__ne__(a) is NotImplemented
assert z.__sub__(a) is NotImplemented
assert z.__rsub__(a) is NotImplemented
assert z.__mul__(a) is NotImplemented
assert z.__rmul__(a) is NotImplemented
assert z.__div__(a) is NotImplemented
assert z.__rdiv__(a) is NotImplemented
assert z.__pow__(a) is NotImplemented
assert z.__rpow__(a) is NotImplemented
def test_mpmathify():
assert mpmathify('1/2') == 0.5
assert mpmathify('(1.0+1.0j)') == mpc(1, 1)
assert mpmathify('(1.2e-10 - 3.4e5j)') == mpc('1.2e-10', '-3.4e5')
assert mpmathify('1j') == mpc(1j)

View File

@ -1,20 +0,0 @@
from mpmath import *
def test_diff():
assert diff(log, 2.0, n=0).ae(log(2))
assert diff(cos, 1.0).ae(-sin(1))
assert diff(abs, 0.0) == 0
assert diff(abs, 0.0, direction=1) == 1
assert diff(abs, 0.0, direction=-1) == -1
assert diff(exp, 1.0).ae(e)
assert diff(exp, 1.0, n=5).ae(e)
assert diff(exp, 2.0, n=5, direction=3*j).ae(e**2)
assert diff(lambda x: x**2, 3.0, method='quad').ae(6)
assert diff(lambda x: 3+x**5, 3.0, n=2, method='quad').ae(540)
assert diff(lambda x: 3+x**5, 3.0, n=2, method='step').ae(540)
assert diffun(sin)(2).ae(cos(2))
assert diffun(sin, n=2)(2).ae(-sin(2))
def test_taylor():
# Easy to test since the coefficients are exact in floating-point
assert taylor(sqrt, 1, 4) == [1, 0.5, -0.125, 0.0625, -0.0390625]

View File

@ -1,143 +0,0 @@
from mpmath.libmp import *
from mpmath import mpf, mp
from random import randint, choice, seed
all_modes = [round_floor, round_ceiling, round_down, round_up, round_nearest]
fb = from_bstr
fi = from_int
ff = from_float
def test_div_1_3():
a = fi(1)
b = fi(3)
c = fi(-1)
# floor rounds down, ceiling rounds up
assert mpf_div(a, b, 7, round_floor) == fb('0.01010101')
assert mpf_div(a, b, 7, round_ceiling) == fb('0.01010110')
assert mpf_div(a, b, 7, round_down) == fb('0.01010101')
assert mpf_div(a, b, 7, round_up) == fb('0.01010110')
assert mpf_div(a, b, 7, round_nearest) == fb('0.01010101')
# floor rounds up, ceiling rounds down
assert mpf_div(c, b, 7, round_floor) == fb('-0.01010110')
assert mpf_div(c, b, 7, round_ceiling) == fb('-0.01010101')
assert mpf_div(c, b, 7, round_down) == fb('-0.01010101')
assert mpf_div(c, b, 7, round_up) == fb('-0.01010110')
assert mpf_div(c, b, 7, round_nearest) == fb('-0.01010101')
def test_mpf_divi_1_3():
a = 1
b = fi(3)
c = -1
assert mpf_rdiv_int(a, b, 7, round_floor) == fb('0.01010101')
assert mpf_rdiv_int(a, b, 7, round_ceiling) == fb('0.01010110')
assert mpf_rdiv_int(a, b, 7, round_down) == fb('0.01010101')
assert mpf_rdiv_int(a, b, 7, round_up) == fb('0.01010110')
assert mpf_rdiv_int(a, b, 7, round_nearest) == fb('0.01010101')
assert mpf_rdiv_int(c, b, 7, round_floor) == fb('-0.01010110')
assert mpf_rdiv_int(c, b, 7, round_ceiling) == fb('-0.01010101')
assert mpf_rdiv_int(c, b, 7, round_down) == fb('-0.01010101')
assert mpf_rdiv_int(c, b, 7, round_up) == fb('-0.01010110')
assert mpf_rdiv_int(c, b, 7, round_nearest) == fb('-0.01010101')
def test_div_300():
q = fi(1000000)
a = fi(300499999) # a/q is a little less than a half-integer
b = fi(300500000) # b/q exactly a half-integer
c = fi(300500001) # c/q is a little more than a half-integer
# Check nearest integer rounding (prec=9 as 2**8 < 300 < 2**9)
assert mpf_div(a, q, 9, round_down) == fi(300)
assert mpf_div(b, q, 9, round_down) == fi(300)
assert mpf_div(c, q, 9, round_down) == fi(300)
assert mpf_div(a, q, 9, round_up) == fi(301)
assert mpf_div(b, q, 9, round_up) == fi(301)
assert mpf_div(c, q, 9, round_up) == fi(301)
# Nearest even integer is down
assert mpf_div(a, q, 9, round_nearest) == fi(300)
assert mpf_div(b, q, 9, round_nearest) == fi(300)
assert mpf_div(c, q, 9, round_nearest) == fi(301)
# Nearest even integer is up
a = fi(301499999)
b = fi(301500000)
c = fi(301500001)
assert mpf_div(a, q, 9, round_nearest) == fi(301)
assert mpf_div(b, q, 9, round_nearest) == fi(302)
assert mpf_div(c, q, 9, round_nearest) == fi(302)
def test_tight_integer_division():
# Test that integer division at tightest possible precision is exact
N = 100
seed(1)
for i in range(N):
a = choice([1, -1]) * randint(1, 1<<randint(10, 100))
b = choice([1, -1]) * randint(1, 1<<randint(10, 100))
p = a * b
width = bitcount(abs(b)) - trailing(b)
a = fi(a); b = fi(b); p = fi(p)
for mode in all_modes:
assert mpf_div(p, a, width, mode) == b
def test_epsilon_rounding():
# Verify that mpf_div uses infinite precision; this result will
# appear to be exactly 0.101 to a near-sighted algorithm
a = fb('0.101' + ('0'*200) + '1')
b = fb('1.10101')
c = mpf_mul(a, b, 250, round_floor) # exact
assert mpf_div(c, b, bitcount(a[1]), round_floor) == a # exact
assert mpf_div(c, b, 2, round_down) == fb('0.10')
assert mpf_div(c, b, 3, round_down) == fb('0.101')
assert mpf_div(c, b, 2, round_up) == fb('0.11')
assert mpf_div(c, b, 3, round_up) == fb('0.110')
assert mpf_div(c, b, 2, round_floor) == fb('0.10')
assert mpf_div(c, b, 3, round_floor) == fb('0.101')
assert mpf_div(c, b, 2, round_ceiling) == fb('0.11')
assert mpf_div(c, b, 3, round_ceiling) == fb('0.110')
# The same for negative numbers
a = fb('-0.101' + ('0'*200) + '1')
b = fb('1.10101')
c = mpf_mul(a, b, 250, round_floor)
assert mpf_div(c, b, bitcount(a[1]), round_floor) == a
assert mpf_div(c, b, 2, round_down) == fb('-0.10')
assert mpf_div(c, b, 3, round_up) == fb('-0.110')
# Floor goes up, ceiling goes down
assert mpf_div(c, b, 2, round_floor) == fb('-0.11')
assert mpf_div(c, b, 3, round_floor) == fb('-0.110')
assert mpf_div(c, b, 2, round_ceiling) == fb('-0.10')
assert mpf_div(c, b, 3, round_ceiling) == fb('-0.101')
def test_mod():
mp.dps = 15
assert mpf(234) % 1 == 0
assert mpf(-3) % 256 == 253
assert mpf(0.25) % 23490.5 == 0.25
assert mpf(0.25) % -23490.5 == -23490.25
assert mpf(-0.25) % 23490.5 == 23490.25
assert mpf(-0.25) % -23490.5 == -0.25
# Check that these cases are handled efficiently
assert mpf('1e10000000000') % 1 == 0
assert mpf('1.23e-1000000000') % 1 == mpf('1.23e-1000000000')
# test __rmod__
assert 3 % mpf('1.75') == 1.25
def test_div_negative_rnd_bug():
mp.dps = 15
assert (-3) / mpf('0.1531879017645047') == mpf('-19.583791966887116')
assert mpf('-2.6342475750861301') / mpf('0.35126216427941814') == mpf('-7.4993775104985909')

View File

@ -1,537 +0,0 @@
"""
Limited tests of the elliptic functions module. A full suite of
extensive testing can be found in elliptic_torture_tests.py
Author of the first version: M.T. Taschuk
References:
[1] Abramowitz & Stegun. 'Handbook of Mathematical Functions, 9th Ed.',
(Dover duplicate of 1972 edition)
[2] Whittaker 'A Course of Modern Analysis, 4th Ed.', 1946,
Cambridge University Press
"""
import mpmath
import random
from mpmath import *
def mpc_ae(a, b, eps=eps):
res = True
res = res and a.real.ae(b.real, eps)
res = res and a.imag.ae(b.imag, eps)
return res
zero = mpf(0)
one = mpf(1)
def test_calculate_nome():
mp.dps = 100
q = calculate_nome(zero)
assert(q == zero)
mp.dps = 25
# used Mathematica's EllipticNomeQ[m]
math1 = [(mpf(1)/10, mpf('0.006584651553858370274473060')),
(mpf(2)/10, mpf('0.01394285727531826872146409')),
(mpf(3)/10, mpf('0.02227743615715350822901627')),
(mpf(4)/10, mpf('0.03188334731336317755064299')),
(mpf(5)/10, mpf('0.04321391826377224977441774')),
(mpf(6)/10, mpf('0.05702025781460967637754953')),
(mpf(7)/10, mpf('0.07468994353717944761143751')),
(mpf(8)/10, mpf('0.09927369733882489703607378')),
(mpf(9)/10, mpf('0.1401731269542615524091055')),
(mpf(9)/10, mpf('0.1401731269542615524091055'))]
for i in math1:
m = i[0]
q = calculate_nome(sqrt(m))
assert q.ae(i[1])
mp.dps = 15
def test_jtheta():
mp.dps = 25
z = q = zero
for n in range(1,5):
value = jtheta(n, z, q)
assert(value == (n-1)//2)
for q in [one, mpf(2)]:
for n in range(1,5):
raised = True
try:
r = jtheta(n, z, q)
except:
pass
else:
raised = False
assert(raised)
z = one/10
q = one/11
# Mathematical N[EllipticTheta[1, 1/10, 1/11], 25]
res = mpf('0.1069552990104042681962096')
result = jtheta(1, z, q)
assert(result.ae(res))
# Mathematica N[EllipticTheta[2, 1/10, 1/11], 25]
res = mpf('1.101385760258855791140606')
result = jtheta(2, z, q)
assert(result.ae(res))
# Mathematica N[EllipticTheta[3, 1/10, 1/11], 25]
res = mpf('1.178319743354331061795905')
result = jtheta(3, z, q)
assert(result.ae(res))
# Mathematica N[EllipticTheta[4, 1/10, 1/11], 25]
res = mpf('0.8219318954665153577314573')
result = jtheta(4, z, q)
assert(result.ae(res))
# test for sin zeros for jtheta(1, z, q)
# test for cos zeros for jtheta(2, z, q)
z1 = pi
z2 = pi/2
for i in range(10):
qstring = str(random.random())
q = mpf(qstring)
result = jtheta(1, z1, q)
assert(result.ae(0))
result = jtheta(2, z2, q)
assert(result.ae(0))
mp.dps = 15
def test_jtheta_issue39():
# near the circle of covergence |q| = 1 the convergence slows
# down; for |q| > Q_LIM the theta functions raise ValueError
mp.dps = 30
mp.dps += 30
q = mpf(6)/10 - one/10**6 - mpf(8)/10 * j
mp.dps -= 30
# Mathematica run first
# N[EllipticTheta[3, 1, 6/10 - 10^-6 - 8/10*I], 2000]
# then it works:
# N[EllipticTheta[3, 1, 6/10 - 10^-6 - 8/10*I], 30]
res = mpf('32.0031009628901652627099524264') + \
mpf('16.6153027998236087899308935624') * j
result = jtheta(3, 1, q)
# check that for abs(q) > Q_LIM a ValueError exception is raised
mp.dps += 30
q = mpf(6)/10 - one/10**7 - mpf(8)/10 * j
mp.dps -= 30
try:
result = jtheta(3, 1, q)
except ValueError:
pass
else:
assert(False)
# bug reported in issue39
mp.dps = 100
z = (1+j)/3
q = mpf(368983957219251)/10**15 + mpf(636363636363636)/10**15 * j
# Mathematica N[EllipticTheta[1, z, q], 35]
res = mpf('2.4439389177990737589761828991467471') + \
mpf('0.5446453005688226915290954851851490') *j
mp.dps = 30
result = jtheta(1, z, q)
assert(result.ae(res))
mp.dps = 80
z = 3 + 4*j
q = 0.5 + 0.5*j
r1 = jtheta(1, z, q)
mp.dps = 15
r2 = jtheta(1, z, q)
assert r1.ae(r2)
mp.dps = 80
z = 3 + j
q1 = exp(j*3)
# longer test
# for n in range(1, 6)
for n in range(1, 2):
mp.dps = 80
q = q1*(1 - mpf(1)/10**n)
r1 = jtheta(1, z, q)
mp.dps = 15
r2 = jtheta(1, z, q)
assert r1.ae(r2)
mp.dps = 15
# issue 39 about high derivatives
assert jtheta(3, 4.5, 0.25, 9).ae(1359.04892680683)
assert jtheta(3, 4.5, 0.25, 50).ae(-6.14832772630905e+33)
mp.dps = 50
r = jtheta(3, 4.5, 0.25, 9)
assert r.ae('1359.048926806828939547859396600218966947753213803')
r = jtheta(3, 4.5, 0.25, 50)
assert r.ae('-6148327726309051673317975084654262.4119215720343656')
def test_jtheta_identities():
"""
Tests the some of the jacobi identidies found in Abramowitz,
Sec. 16.28, Pg. 576. The identities are tested to 1 part in 10^98.
"""
mp.dps = 110
eps1 = ldexp(eps, 30)
for i in range(10):
qstring = str(random.random())
q = mpf(qstring)
zstring = str(10*random.random())
z = mpf(zstring)
# Abramowitz 16.28.1
# v_1(z, q)**2 * v_4(0, q)**2 = v_3(z, q)**2 * v_2(0, q)**2
# - v_2(z, q)**2 * v_3(0, q)**2
term1 = (jtheta(1, z, q)**2) * (jtheta(4, zero, q)**2)
term2 = (jtheta(3, z, q)**2) * (jtheta(2, zero, q)**2)
term3 = (jtheta(2, z, q)**2) * (jtheta(3, zero, q)**2)
equality = term1 - term2 + term3
assert(equality.ae(0, eps1))
zstring = str(100*random.random())
z = mpf(zstring)
# Abramowitz 16.28.2
# v_2(z, q)**2 * v_4(0, q)**2 = v_4(z, q)**2 * v_2(0, q)**2
# - v_1(z, q)**2 * v_3(0, q)**2
term1 = (jtheta(2, z, q)**2) * (jtheta(4, zero, q)**2)
term2 = (jtheta(4, z, q)**2) * (jtheta(2, zero, q)**2)
term3 = (jtheta(1, z, q)**2) * (jtheta(3, zero, q)**2)
equality = term1 - term2 + term3
assert(equality.ae(0, eps1))
# Abramowitz 16.28.3
# v_3(z, q)**2 * v_4(0, q)**2 = v_4(z, q)**2 * v_3(0, q)**2
# - v_1(z, q)**2 * v_2(0, q)**2
term1 = (jtheta(3, z, q)**2) * (jtheta(4, zero, q)**2)
term2 = (jtheta(4, z, q)**2) * (jtheta(3, zero, q)**2)
term3 = (jtheta(1, z, q)**2) * (jtheta(2, zero, q)**2)
equality = term1 - term2 + term3
assert(equality.ae(0, eps1))
# Abramowitz 16.28.4
# v_4(z, q)**2 * v_4(0, q)**2 = v_3(z, q)**2 * v_3(0, q)**2
# - v_2(z, q)**2 * v_2(0, q)**2
term1 = (jtheta(4, z, q)**2) * (jtheta(4, zero, q)**2)
term2 = (jtheta(3, z, q)**2) * (jtheta(3, zero, q)**2)
term3 = (jtheta(2, z, q)**2) * (jtheta(2, zero, q)**2)
equality = term1 - term2 + term3
assert(equality.ae(0, eps1))
# Abramowitz 16.28.5
# v_2(0, q)**4 + v_4(0, q)**4 == v_3(0, q)**4
term1 = (jtheta(2, zero, q))**4
term2 = (jtheta(4, zero, q))**4
term3 = (jtheta(3, zero, q))**4
equality = term1 + term2 - term3
assert(equality.ae(0, eps1))
mp.dps = 15
def test_jtheta_complex():
mp.dps = 30
z = mpf(1)/4 + j/8
q = mpf(1)/3 + j/7
# Mathematica N[EllipticTheta[1, 1/4 + I/8, 1/3 + I/7], 35]
res = mpf('0.31618034835986160705729105731678285') + \
mpf('0.07542013825835103435142515194358975') * j
r = jtheta(1, z, q)
assert(mpc_ae(r, res))
# Mathematica N[EllipticTheta[2, 1/4 + I/8, 1/3 + I/7], 35]
res = mpf('1.6530986428239765928634711417951828') + \
mpf('0.2015344864707197230526742145361455') * j
r = jtheta(2, z, q)
assert(mpc_ae(r, res))
# Mathematica N[EllipticTheta[3, 1/4 + I/8, 1/3 + I/7], 35]
res = mpf('1.6520564411784228184326012700348340') + \
mpf('0.1998129119671271328684690067401823') * j
r = jtheta(3, z, q)
assert(mpc_ae(r, res))
# Mathematica N[EllipticTheta[4, 1/4 + I/8, 1/3 + I/7], 35]
res = mpf('0.37619082382228348252047624089973824') - \
mpf('0.15623022130983652972686227200681074') * j
r = jtheta(4, z, q)
assert(mpc_ae(r, res))
# check some theta function identities
mp.dos = 100
z = mpf(1)/4 + j/8
q = mpf(1)/3 + j/7
mp.dps += 10
a = [0,0, jtheta(2, 0, q), jtheta(3, 0, q), jtheta(4, 0, q)]
t = [0, jtheta(1, z, q), jtheta(2, z, q), jtheta(3, z, q), jtheta(4, z, q)]
r = [(t[2]*a[4])**2 - (t[4]*a[2])**2 + (t[1] *a[3])**2,
(t[3]*a[4])**2 - (t[4]*a[3])**2 + (t[1] *a[2])**2,
(t[1]*a[4])**2 - (t[3]*a[2])**2 + (t[2] *a[3])**2,
(t[4]*a[4])**2 - (t[3]*a[3])**2 + (t[2] *a[2])**2,
a[2]**4 + a[4]**4 - a[3]**4]
mp.dps -= 10
for x in r:
assert(mpc_ae(x, mpc(0)))
mp.dps = 15
def test_djtheta():
mp.dps = 30
z = one/7 + j/3
q = one/8 + j/5
# Mathematica N[EllipticThetaPrime[1, 1/7 + I/3, 1/8 + I/5], 35]
res = mpf('1.5555195883277196036090928995803201') - \
mpf('0.02439761276895463494054149673076275') * j
result = jtheta(1, z, q, 1)
assert(mpc_ae(result, res))
# Mathematica N[EllipticThetaPrime[2, 1/7 + I/3, 1/8 + I/5], 35]
res = mpf('0.19825296689470982332701283509685662') - \
mpf('0.46038135182282106983251742935250009') * j
result = jtheta(2, z, q, 1)
assert(mpc_ae(result, res))
# Mathematica N[EllipticThetaPrime[3, 1/7 + I/3, 1/8 + I/5], 35]
res = mpf('0.36492498415476212680896699407390026') - \
mpf('0.57743812698666990209897034525640369') * j
result = jtheta(3, z, q, 1)
assert(mpc_ae(result, res))
# Mathematica N[EllipticThetaPrime[4, 1/7 + I/3, 1/8 + I/5], 35]
res = mpf('-0.38936892528126996010818803742007352') + \
mpf('0.66549886179739128256269617407313625') * j
result = jtheta(4, z, q, 1)
assert(mpc_ae(result, res))
for i in range(10):
q = (one*random.random() + j*random.random())/2
# identity in Wittaker, Watson &21.41
a = jtheta(1, 0, q, 1)
b = jtheta(2, 0, q)*jtheta(3, 0, q)*jtheta(4, 0, q)
assert(a.ae(b))
# test higher derivatives
mp.dps = 20
for q,z in [(one/3, one/5), (one/3 + j/8, one/5),
(one/3, one/5 + j/8), (one/3 + j/7, one/5 + j/8)]:
for n in [1, 2, 3, 4]:
r = jtheta(n, z, q, 2)
r1 = diff(lambda zz: jtheta(n, zz, q), z, n=2)
assert r.ae(r1)
r = jtheta(n, z, q, 3)
r1 = diff(lambda zz: jtheta(n, zz, q), z, n=3)
assert r.ae(r1)
# identity in Wittaker, Watson &21.41
q = one/3
z = zero
a = [0]*5
a[1] = jtheta(1, z, q, 3)/jtheta(1, z, q, 1)
for n in [2,3,4]:
a[n] = jtheta(n, z, q, 2)/jtheta(n, z, q)
equality = a[2] + a[3] + a[4] - a[1]
assert(equality.ae(0))
mp.dps = 15
def test_jsn():
"""
Test some special cases of the sn(z, q) function.
"""
mp.dps = 100
# trival case
result = jsn(zero, zero)
assert(result == zero)
# Abramowitz Table 16.5
#
# sn(0, m) = 0
for i in range(10):
qstring = str(random.random())
q = mpf(qstring)
equality = jsn(zero, q)
assert(equality.ae(0))
# Abramowitz Table 16.6.1
#
# sn(z, 0) = sin(z), m == 0
#
# sn(z, 1) = tanh(z), m == 1
#
# It would be nice to test these, but I find that they run
# in to numerical trouble. I'm currently treating as a boundary
# case for sn function.
mp.dps = 25
arg = one/10
#N[JacobiSN[1/10, 2^-100], 25]
res = mpf('0.09983341664682815230681420')
m = ldexp(one, -100)
result = jsn(arg, m)
assert(result.ae(res))
# N[JacobiSN[1/10, 1/10], 25]
res = mpf('0.09981686718599080096451168')
result = jsn(arg, arg)
assert(result.ae(res))
mp.dps = 15
def test_jcn():
"""
Test some special cases of the cn(z, q) function.
"""
mp.dps = 100
# Abramowitz Table 16.5
# cn(0, q) = 1
qstring = str(random.random())
q = mpf(qstring)
cn = jcn(zero, q)
assert(cn.ae(one))
# Abramowitz Table 16.6.2
#
# cn(u, 0) = cos(u), m == 0
#
# cn(u, 1) = sech(z), m == 1
#
# It would be nice to test these, but I find that they run
# in to numerical trouble. I'm currently treating as a boundary
# case for cn function.
mp.dps = 25
arg = one/10
m = ldexp(one, -100)
#N[JacobiCN[1/10, 2^-100], 25]
res = mpf('0.9950041652780257660955620')
result = jcn(arg, m)
assert(result.ae(res))
# N[JacobiCN[1/10, 1/10], 25]
res = mpf('0.9950058256237368748520459')
result = jcn(arg, arg)
assert(result.ae(res))
mp.dps = 15
def test_jdn():
"""
Test some special cases of the dn(z, q) function.
"""
mp.dps = 100
# Abramowitz Table 16.5
# dn(0, q) = 1
mstring = str(random.random())
m = mpf(mstring)
dn = jdn(zero, m)
assert(dn.ae(one))
mp.dps = 25
# N[JacobiDN[1/10, 1/10], 25]
res = mpf('0.9995017055025556219713297')
arg = one/10
result = jdn(arg, arg)
assert(result.ae(res))
mp.dps = 15
def test_sn_cn_dn_identities():
"""
Tests the some of the jacobi elliptic function identities found
on Mathworld. Haven't found in Abramowitz.
"""
mp.dps = 100
N = 5
for i in range(N):
qstring = str(random.random())
q = mpf(qstring)
zstring = str(100*random.random())
z = mpf(zstring)
# MathWorld
# sn(z, q)**2 + cn(z, q)**2 == 1
term1 = jsn(z, q)**2
term2 = jcn(z, q)**2
equality = one - term1 - term2
assert(equality.ae(0))
# MathWorld
# k**2 * sn(z, m)**2 + dn(z, m)**2 == 1
for i in range(N):
mstring = str(random.random())
m = mpf(qstring)
k = m.sqrt()
zstring = str(10*random.random())
z = mpf(zstring)
term1 = k**2 * jsn(z, m)**2
term2 = jdn(z, m)**2
equality = one - term1 - term2
assert(equality.ae(0))
for i in range(N):
mstring = str(random.random())
m = mpf(mstring)
k = m.sqrt()
zstring = str(random.random())
z = mpf(zstring)
# MathWorld
# k**2 * cn(z, m)**2 + (1 - k**2) = dn(z, m)**2
term1 = k**2 * jcn(z, m)**2
term2 = 1 - k**2
term3 = jdn(z, m)**2
equality = term3 - term1 - term2
assert(equality.ae(0))
K = ellipk(k**2)
# Abramowitz Table 16.5
# sn(K, m) = 1; K is K(k), first complete elliptic integral
r = jsn(K, m)
assert(r.ae(one))
# Abramowitz Table 16.5
# cn(K, q) = 0; K is K(k), first complete elliptic integral
equality = jcn(K, m)
assert(equality.ae(0))
# Abramowitz Table 16.6.3
# dn(z, 0) = 1, m == 0
z = m
value = jdn(z, zero)
assert(value.ae(one))
mp.dps = 15
def test_sn_cn_dn_complex():
mp.dps = 30
# N[JacobiSN[1/4 + I/8, 1/3 + I/7], 35] in Mathematica
res = mpf('0.2495674401066275492326652143537') + \
mpf('0.12017344422863833381301051702823') * j
u = mpf(1)/4 + j/8
m = mpf(1)/3 + j/7
r = jsn(u, m)
assert(mpc_ae(r, res))
#N[JacobiCN[1/4 + I/8, 1/3 + I/7], 35]
res = mpf('0.9762691700944007312693721148331') - \
mpf('0.0307203994181623243583169154824')*j
r = jcn(u, m)
#assert r.real.ae(res.real)
#assert r.imag.ae(res.imag)
assert(mpc_ae(r, res))
#N[JacobiDN[1/4 + I/8, 1/3 + I/7], 35]
res = mpf('0.99639490163039577560547478589753039') - \
mpf('0.01346296520008176393432491077244994')*j
r = jdn(u, m)
assert(mpc_ae(r, res))
mp.dps = 15

File diff suppressed because it is too large Load Diff

View File

@ -1,882 +0,0 @@
from mpmath.libmp import *
from mpmath import *
import random
import time
import math
import cmath
def mpc_ae(a, b, eps=eps):
res = True
res = res and a.real.ae(b.real, eps)
res = res and a.imag.ae(b.imag, eps)
return res
#----------------------------------------------------------------------------
# Constants and functions
#
tpi = "3.1415926535897932384626433832795028841971693993751058209749445923078\
1640628620899862803482534211706798"
te = "2.71828182845904523536028747135266249775724709369995957496696762772407\
663035354759457138217852516642743"
tdegree = "0.017453292519943295769236907684886127134428718885417254560971914\
4017100911460344944368224156963450948221"
teuler = "0.5772156649015328606065120900824024310421593359399235988057672348\
84867726777664670936947063291746749516"
tln2 = "0.693147180559945309417232121458176568075500134360255254120680009493\
393621969694715605863326996418687542"
tln10 = "2.30258509299404568401799145468436420760110148862877297603332790096\
757260967735248023599720508959829834"
tcatalan = "0.91596559417721901505460351493238411077414937428167213426649811\
9621763019776254769479356512926115106249"
tkhinchin = "2.6854520010653064453097148354817956938203822939944629530511523\
4555721885953715200280114117493184769800"
tglaisher = "1.2824271291006226368753425688697917277676889273250011920637400\
2174040630885882646112973649195820237439420646"
tapery = "1.2020569031595942853997381615114499907649862923404988817922715553\
4183820578631309018645587360933525815"
tphi = "1.618033988749894848204586834365638117720309179805762862135448622705\
26046281890244970720720418939113748475"
tmertens = "0.26149721284764278375542683860869585905156664826119920619206421\
3924924510897368209714142631434246651052"
ttwinprime = "0.660161815846869573927812110014555778432623360284733413319448\
423335405642304495277143760031413839867912"
def test_constants():
for prec in [3, 7, 10, 15, 20, 37, 80, 100, 29]:
mp.dps = prec
assert pi == mpf(tpi)
assert e == mpf(te)
assert degree == mpf(tdegree)
assert euler == mpf(teuler)
assert ln2 == mpf(tln2)
assert ln10 == mpf(tln10)
assert catalan == mpf(tcatalan)
assert khinchin == mpf(tkhinchin)
assert glaisher == mpf(tglaisher)
assert phi == mpf(tphi)
if prec < 50:
assert mertens == mpf(tmertens)
assert twinprime == mpf(ttwinprime)
mp.dps = 15
assert pi >= -1
assert pi > 2
assert pi > 3
assert pi < 4
def test_exact_sqrts():
for i in range(20000):
assert sqrt(mpf(i*i)) == i
random.seed(1)
for prec in [100, 300, 1000, 10000]:
mp.dps = prec
for i in range(20):
A = random.randint(10**(prec//2-2), 10**(prec//2-1))
assert sqrt(mpf(A*A)) == A
mp.dps = 15
for i in range(100):
for a in [1, 8, 25, 112307]:
assert sqrt(mpf((a*a, 2*i))) == mpf((a, i))
assert sqrt(mpf((a*a, -2*i))) == mpf((a, -i))
def test_sqrt_rounding():
for i in [2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15]:
i = from_int(i)
for dps in [7, 15, 83, 106, 2000]:
mp.dps = dps
a = mpf_pow_int(mpf_sqrt(i, mp.prec, round_down), 2, mp.prec, round_down)
b = mpf_pow_int(mpf_sqrt(i, mp.prec, round_up), 2, mp.prec, round_up)
assert mpf_lt(a, i)
assert mpf_gt(b, i)
random.seed(1234)
prec = 100
for rnd in [round_down, round_nearest, round_ceiling]:
for i in range(100):
a = mpf_rand(prec)
b = mpf_mul(a, a)
assert mpf_sqrt(b, prec, rnd) == a
# Test some extreme cases
mp.dps = 100
a = mpf(9) + 1e-90
b = mpf(9) - 1e-90
mp.dps = 15
assert sqrt(a, rounding='d') == 3
assert sqrt(a, rounding='n') == 3
assert sqrt(a, rounding='u') > 3
assert sqrt(b, rounding='d') < 3
assert sqrt(b, rounding='n') == 3
assert sqrt(b, rounding='u') == 3
# A worst case, from the MPFR test suite
assert sqrt(mpf('7.0503726185518891')) == mpf('2.655253776675949')
def test_float_sqrt():
mp.dps = 15
# These should round identically
for x in [0, 1e-7, 0.1, 0.5, 1, 2, 3, 4, 5, 0.333, 76.19]:
assert sqrt(mpf(x)) == float(x)**0.5
assert sqrt(-1) == 1j
assert sqrt(-2).ae(cmath.sqrt(-2))
assert sqrt(-3).ae(cmath.sqrt(-3))
assert sqrt(-100).ae(cmath.sqrt(-100))
assert sqrt(1j).ae(cmath.sqrt(1j))
assert sqrt(-1j).ae(cmath.sqrt(-1j))
assert sqrt(math.pi + math.e*1j).ae(cmath.sqrt(math.pi + math.e*1j))
assert sqrt(math.pi - math.e*1j).ae(cmath.sqrt(math.pi - math.e*1j))
def test_hypot():
assert hypot(0, 0) == 0
assert hypot(0, 0.33) == mpf(0.33)
assert hypot(0.33, 0) == mpf(0.33)
assert hypot(-0.33, 0) == mpf(0.33)
assert hypot(3, 4) == mpf(5)
def test_exact_cbrt():
for i in range(0, 20000, 200):
assert cbrt(mpf(i*i*i)) == i
random.seed(1)
for prec in [100, 300, 1000, 10000]:
mp.dps = prec
A = random.randint(10**(prec//2-2), 10**(prec//2-1))
assert cbrt(mpf(A*A*A)) == A
mp.dps = 15
def test_exp():
assert exp(0) == 1
assert exp(10000).ae(mpf('8.8068182256629215873e4342'))
assert exp(-10000).ae(mpf('1.1354838653147360985e-4343'))
a = exp(mpf((1, 8198646019315405L, -53, 53)))
assert(a.bc == bitcount(a.man))
mp.prec = 67
a = exp(mpf((1, 1781864658064754565L, -60, 61)))
assert(a.bc == bitcount(a.man))
mp.prec = 53
assert exp(ln2 * 10).ae(1024)
assert exp(2+2j).ae(cmath.exp(2+2j))
def test_issue_33():
mp.dps = 512
a = exp(-1)
b = exp(1)
mp.dps = 15
assert (+a).ae(0.36787944117144233)
assert (+b).ae(2.7182818284590451)
def test_log():
mp.dps = 15
assert log(1) == 0
for x in [0.5, 1.5, 2.0, 3.0, 100, 10**50, 1e-50]:
assert log(x).ae(math.log(x))
assert log(x, x) == 1
assert log(1024, 2) == 10
assert log(10**1234, 10) == 1234
assert log(2+2j).ae(cmath.log(2+2j))
# Accuracy near 1
assert (log(0.6+0.8j).real*10**17).ae(2.2204460492503131)
assert (log(0.6-0.8j).real*10**17).ae(2.2204460492503131)
assert (log(0.8-0.6j).real*10**17).ae(2.2204460492503131)
assert (log(1+1e-8j).real*10**16).ae(0.5)
assert (log(1-1e-8j).real*10**16).ae(0.5)
assert (log(-1+1e-8j).real*10**16).ae(0.5)
assert (log(-1-1e-8j).real*10**16).ae(0.5)
assert (log(1j+1e-8).real*10**16).ae(0.5)
assert (log(1j-1e-8).real*10**16).ae(0.5)
assert (log(-1j+1e-8).real*10**16).ae(0.5)
assert (log(-1j-1e-8).real*10**16).ae(0.5)
assert (log(1+1e-40j).real*10**80).ae(0.5)
assert (log(1j+1e-40).real*10**80).ae(0.5)
# Huge
assert log(ldexp(1.234,10**20)).ae(log(2)*1e20)
assert log(ldexp(1.234,10**200)).ae(log(2)*1e200)
# Some special values
assert log(mpc(0,0)) == mpc(-inf,0)
assert isnan(log(mpc(nan,0)).real)
assert isnan(log(mpc(nan,0)).imag)
assert isnan(log(mpc(0,nan)).real)
assert isnan(log(mpc(0,nan)).imag)
assert isnan(log(mpc(nan,1)).real)
assert isnan(log(mpc(nan,1)).imag)
assert isnan(log(mpc(1,nan)).real)
assert isnan(log(mpc(1,nan)).imag)
def test_trig_hyperb_basic():
for x in (range(100) + range(-100,0)):
t = x / 4.1
assert cos(mpf(t)).ae(math.cos(t))
assert sin(mpf(t)).ae(math.sin(t))
assert tan(mpf(t)).ae(math.tan(t))
assert cosh(mpf(t)).ae(math.cosh(t))
assert sinh(mpf(t)).ae(math.sinh(t))
assert tanh(mpf(t)).ae(math.tanh(t))
assert sin(1+1j).ae(cmath.sin(1+1j))
assert sin(-4-3.6j).ae(cmath.sin(-4-3.6j))
assert cos(1+1j).ae(cmath.cos(1+1j))
assert cos(-4-3.6j).ae(cmath.cos(-4-3.6j))
def test_degrees():
assert cos(0*degree) == 1
assert cos(90*degree).ae(0)
assert cos(180*degree).ae(-1)
assert cos(270*degree).ae(0)
assert cos(360*degree).ae(1)
assert sin(0*degree) == 0
assert sin(90*degree).ae(1)
assert sin(180*degree).ae(0)
assert sin(270*degree).ae(-1)
assert sin(360*degree).ae(0)
def random_complexes(N):
random.seed(1)
a = []
for i in range(N):
x1 = random.uniform(-10, 10)
y1 = random.uniform(-10, 10)
x2 = random.uniform(-10, 10)
y2 = random.uniform(-10, 10)
z1 = complex(x1, y1)
z2 = complex(x2, y2)
a.append((z1, z2))
return a
def test_complex_powers():
for dps in [15, 30, 100]:
# Check accuracy for complex square root
mp.dps = dps
a = mpc(1j)**0.5
assert a.real == a.imag == mpf(2)**0.5 / 2
mp.dps = 15
random.seed(1)
for (z1, z2) in random_complexes(100):
assert (mpc(z1)**mpc(z2)).ae(z1**z2, 1e-12)
assert (e**(-pi*1j)).ae(-1)
mp.dps = 50
assert (e**(-pi*1j)).ae(-1)
mp.dps = 15
def test_complex_sqrt_accuracy():
def test_mpc_sqrt(lst):
for a, b in lst:
z = mpc(a + j*b)
assert mpc_ae(sqrt(z*z), z)
z = mpc(-a + j*b)
assert mpc_ae(sqrt(z*z), -z)
z = mpc(a - j*b)
assert mpc_ae(sqrt(z*z), z)
z = mpc(-a - j*b)
assert mpc_ae(sqrt(z*z), -z)
random.seed(2)
N = 10
mp.dps = 30
dps = mp.dps
test_mpc_sqrt([(random.uniform(0, 10),random.uniform(0, 10)) for i in range(N)])
test_mpc_sqrt([(i + 0.1, (i + 0.2)*10**i) for i in range(N)])
mp.dps = 15
def test_atan():
mp.dps = 15
assert atan(-2.3).ae(math.atan(-2.3))
assert atan(1e-50) == 1e-50
assert atan(1e50).ae(pi/2)
assert atan(-1e-50) == -1e-50
assert atan(-1e50).ae(-pi/2)
assert atan(10**1000).ae(pi/2)
for dps in [25, 70, 100, 300, 1000]:
mp.dps = dps
assert (4*atan(1)).ae(pi)
mp.dps = 15
pi2 = pi/2
assert atan(mpc(inf,-1)).ae(pi2)
assert atan(mpc(inf,0)).ae(pi2)
assert atan(mpc(inf,1)).ae(pi2)
assert atan(mpc(1,inf)).ae(pi2)
assert atan(mpc(0,inf)).ae(pi2)
assert atan(mpc(-1,inf)).ae(-pi2)
assert atan(mpc(-inf,1)).ae(-pi2)
assert atan(mpc(-inf,0)).ae(-pi2)
assert atan(mpc(-inf,-1)).ae(-pi2)
assert atan(mpc(-1,-inf)).ae(-pi2)
assert atan(mpc(0,-inf)).ae(-pi2)
assert atan(mpc(1,-inf)).ae(pi2)
def test_atan2():
mp.dps = 15
assert atan2(1,1).ae(pi/4)
assert atan2(1,-1).ae(3*pi/4)
assert atan2(-1,-1).ae(-3*pi/4)
assert atan2(-1,1).ae(-pi/4)
assert atan2(-1,0).ae(-pi/2)
assert atan2(1,0).ae(pi/2)
assert atan2(0,0) == 0
assert atan2(inf,0).ae(pi/2)
assert atan2(-inf,0).ae(-pi/2)
assert isnan(atan2(inf,inf))
assert isnan(atan2(-inf,inf))
assert isnan(atan2(inf,-inf))
assert isnan(atan2(3,nan))
assert isnan(atan2(nan,3))
assert isnan(atan2(0,nan))
assert isnan(atan2(nan,0))
assert atan2(0,inf) == 0
assert atan2(0,-inf).ae(pi)
assert atan2(10,inf) == 0
assert atan2(-10,inf) == 0
assert atan2(-10,-inf).ae(-pi)
assert atan2(10,-inf).ae(pi)
assert atan2(inf,10).ae(pi/2)
assert atan2(inf,-10).ae(pi/2)
assert atan2(-inf,10).ae(-pi/2)
assert atan2(-inf,-10).ae(-pi/2)
def test_areal_inverses():
assert asin(mpf(0)) == 0
assert asinh(mpf(0)) == 0
assert acosh(mpf(1)) == 0
assert isinstance(asin(mpf(0.5)), mpf)
assert isinstance(asin(mpf(2.0)), mpc)
assert isinstance(acos(mpf(0.5)), mpf)
assert isinstance(acos(mpf(2.0)), mpc)
assert isinstance(atanh(mpf(0.1)), mpf)
assert isinstance(atanh(mpf(1.1)), mpc)
random.seed(1)
for i in range(50):
x = random.uniform(0, 1)
assert asin(mpf(x)).ae(math.asin(x))
assert acos(mpf(x)).ae(math.acos(x))
x = random.uniform(-10, 10)
assert asinh(mpf(x)).ae(cmath.asinh(x).real)
assert isinstance(asinh(mpf(x)), mpf)
x = random.uniform(1, 10)
assert acosh(mpf(x)).ae(cmath.acosh(x).real)
assert isinstance(acosh(mpf(x)), mpf)
x = random.uniform(-10, 0.999)
assert isinstance(acosh(mpf(x)), mpc)
x = random.uniform(-1, 1)
assert atanh(mpf(x)).ae(cmath.atanh(x).real)
assert isinstance(atanh(mpf(x)), mpf)
dps = mp.dps
mp.dps = 300
assert isinstance(asin(0.5), mpf)
mp.dps = 1000
assert asin(1).ae(pi/2)
assert asin(-1).ae(-pi/2)
mp.dps = dps
def test_invhyperb_inaccuracy():
mp.dps = 15
assert (asinh(1e-5)*10**5).ae(0.99999999998333333)
assert (asinh(1e-10)*10**10).ae(1)
assert (asinh(1e-50)*10**50).ae(1)
assert (asinh(-1e-5)*10**5).ae(-0.99999999998333333)
assert (asinh(-1e-10)*10**10).ae(-1)
assert (asinh(-1e-50)*10**50).ae(-1)
assert asinh(10**20).ae(46.744849040440862)
assert asinh(-10**20).ae(-46.744849040440862)
assert (tanh(1e-10)*10**10).ae(1)
assert (tanh(-1e-10)*10**10).ae(-1)
assert (atanh(1e-10)*10**10).ae(1)
assert (atanh(-1e-10)*10**10).ae(-1)
def test_complex_functions():
for x in (range(10) + range(-10,0)):
for y in (range(10) + range(-10,0)):
z = complex(x, y)/4.3 + 0.01j
assert exp(mpc(z)).ae(cmath.exp(z))
assert log(mpc(z)).ae(cmath.log(z))
assert cos(mpc(z)).ae(cmath.cos(z))
assert sin(mpc(z)).ae(cmath.sin(z))
assert tan(mpc(z)).ae(cmath.tan(z))
assert sinh(mpc(z)).ae(cmath.sinh(z))
assert cosh(mpc(z)).ae(cmath.cosh(z))
assert tanh(mpc(z)).ae(cmath.tanh(z))
def test_complex_inverse_functions():
for (z1, z2) in random_complexes(30):
# apparently cmath uses a different branch, so we
# can't use it for comparison
assert sinh(asinh(z1)).ae(z1)
#
assert acosh(z1).ae(cmath.acosh(z1))
assert atanh(z1).ae(cmath.atanh(z1))
assert atan(z1).ae(cmath.atan(z1))
# the reason we set a big eps here is that the cmath
# functions are inaccurate
assert asin(z1).ae(cmath.asin(z1), rel_eps=1e-12)
assert acos(z1).ae(cmath.acos(z1), rel_eps=1e-12)
one = mpf(1)
for i in range(-9, 10, 3):
for k in range(-9, 10, 3):
a = 0.9*j*10**k + 0.8*one*10**i
b = cos(acos(a))
assert b.ae(a)
b = sin(asin(a))
assert b.ae(a)
one = mpf(1)
err = 2*10**-15
for i in range(-9, 9, 3):
for k in range(-9, 9, 3):
a = -0.9*10**k + j*0.8*one*10**i
b = cosh(acosh(a))
assert b.ae(a, err)
b = sinh(asinh(a))
assert b.ae(a, err)
def test_reciprocal_functions():
assert sec(3).ae(-1.01010866590799375)
assert csc(3).ae(7.08616739573718592)
assert cot(3).ae(-7.01525255143453347)
assert sech(3).ae(0.0993279274194332078)
assert csch(3).ae(0.0998215696688227329)
assert coth(3).ae(1.00496982331368917)
assert asec(3).ae(1.23095941734077468)
assert acsc(3).ae(0.339836909454121937)
assert acot(3).ae(0.321750554396642193)
assert asech(0.5).ae(1.31695789692481671)
assert acsch(3).ae(0.327450150237258443)
assert acoth(3).ae(0.346573590279972655)
def test_ldexp():
mp.dps = 15
assert ldexp(mpf(2.5), 0) == 2.5
assert ldexp(mpf(2.5), -1) == 1.25
assert ldexp(mpf(2.5), 2) == 10
assert ldexp(mpf('inf'), 3) == mpf('inf')
def test_frexp():
mp.dps = 15
assert frexp(0) == (0.0, 0)
assert frexp(9) == (0.5625, 4)
assert frexp(1) == (0.5, 1)
assert frexp(0.2) == (0.8, -2)
assert frexp(1000) == (0.9765625, 10)
def test_aliases():
assert ln(7) == log(7)
assert log10(3.75) == log(3.75,10)
assert degrees(5.6) == 5.6 / degree
assert radians(5.6) == 5.6 * degree
assert power(-1,0.5) == j
assert modf(25,7) == 4.0 and isinstance(modf(25,7), mpf)
def test_arg_sign():
assert arg(3) == 0
assert arg(-3).ae(pi)
assert arg(j).ae(pi/2)
assert arg(-j).ae(-pi/2)
assert arg(0) == 0
assert isnan(atan2(3,nan))
assert isnan(atan2(nan,3))
assert isnan(atan2(0,nan))
assert isnan(atan2(nan,0))
assert isnan(atan2(nan,nan))
assert arg(inf) == 0
assert arg(-inf).ae(pi)
assert isnan(arg(nan))
#assert arg(inf*j).ae(pi/2)
assert sign(0) == 0
assert sign(3) == 1
assert sign(-3) == -1
assert sign(inf) == 1
assert sign(-inf) == -1
assert isnan(sign(nan))
assert sign(j) == j
assert sign(-3*j) == -j
assert sign(1+j).ae((1+j)/sqrt(2))
def test_misc_bugs():
# test that this doesn't raise an exception
mp.dps = 1000
log(1302)
mp.dps = 15
def test_arange():
assert arange(10) == [mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0'),
mpf('4.0'), mpf('5.0'), mpf('6.0'), mpf('7.0'),
mpf('8.0'), mpf('9.0')]
assert arange(-5, 5) == [mpf('-5.0'), mpf('-4.0'), mpf('-3.0'),
mpf('-2.0'), mpf('-1.0'), mpf('0.0'),
mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
assert arange(0, 1, 0.1) == [mpf('0.0'), mpf('0.10000000000000001'),
mpf('0.20000000000000001'),
mpf('0.30000000000000004'),
mpf('0.40000000000000002'),
mpf('0.5'), mpf('0.60000000000000009'),
mpf('0.70000000000000007'),
mpf('0.80000000000000004'),
mpf('0.90000000000000002')]
assert arange(17, -9, -3) == [mpf('17.0'), mpf('14.0'), mpf('11.0'),
mpf('8.0'), mpf('5.0'), mpf('2.0'),
mpf('-1.0'), mpf('-4.0'), mpf('-7.0')]
assert arange(0.2, 0.1, -0.1) == [mpf('0.20000000000000001')]
assert arange(0) == []
assert arange(1000, -1) == []
assert arange(-1.23, 3.21, -0.0000001) == []
def test_linspace():
assert linspace(2, 9, 7) == [mpf('2.0'), mpf('3.166666666666667'),
mpf('4.3333333333333339'), mpf('5.5'), mpf('6.666666666666667'),
mpf('7.8333333333333339'), mpf('9.0')] == linspace(mpi(2, 9), 7)
assert linspace(2, 9, 7, endpoint=0) == [mpf('2.0'), mpf('3.0'), mpf('4.0'),
mpf('5.0'), mpf('6.0'), mpf('7.0'), mpf('8.0')]
assert linspace(2, 7, 1) == [mpf(2)]
def test_float_cbrt():
mp.dps = 30
for a in arange(0,10,0.1):
assert cbrt(a*a*a).ae(a, eps)
assert cbrt(-1).ae(0.5 + j*sqrt(3)/2)
one_third = mpf(1)/3
for a in arange(0,10,2.7) + [0.1 + 10**5]:
a = mpc(a + 1.1j)
r1 = cbrt(a)
mp.dps += 10
r2 = pow(a, one_third)
mp.dps -= 10
assert r1.ae(r2, eps)
mp.dps = 100
for n in range(100, 301, 100):
w = 10**n + j*10**-3
z = w*w*w
r = cbrt(z)
assert mpc_ae(r, w, eps)
mp.dps = 15
def test_root():
mp.dps = 30
random.seed(1)
a = random.randint(0, 10000)
p = a*a*a
r = nthroot(mpf(p), 3)
assert r == a
for n in range(4, 10):
p = p*a
assert nthroot(mpf(p), n) == a
mp.dps = 40
for n in range(10, 5000, 100):
for a in [random.random()*10000, random.random()*10**100]:
r = nthroot(a, n)
r1 = pow(a, mpf(1)/n)
assert r.ae(r1)
r = nthroot(a, -n)
r1 = pow(a, -mpf(1)/n)
assert r.ae(r1)
# XXX: this is broken right now
# tests for nthroot rounding
for rnd in ['nearest', 'up', 'down']:
mp.rounding = rnd
for n in [-5, -3, 3, 5]:
prec = 50
for i in xrange(10):
mp.prec = prec
a = rand()
mp.prec = 2*prec
b = a**n
mp.prec = prec
r = nthroot(b, n)
assert r == a
mp.dps = 30
for n in range(3, 21):
a = (random.random() + j*random.random())
assert nthroot(a, n).ae(pow(a, mpf(1)/n))
assert mpc_ae(nthroot(a, n), pow(a, mpf(1)/n))
a = (random.random()*10**100 + j*random.random())
r = nthroot(a, n)
mp.dps += 4
r1 = pow(a, mpf(1)/n)
mp.dps -= 4
assert r.ae(r1)
assert mpc_ae(r, r1, eps)
r = nthroot(a, -n)
mp.dps += 4
r1 = pow(a, -mpf(1)/n)
mp.dps -= 4
assert r.ae(r1)
assert mpc_ae(r, r1, eps)
mp.dps = 15
assert nthroot(4, 1) == 4
assert nthroot(4, 0) == 1
assert nthroot(4, -1) == 0.25
assert nthroot(inf, 1) == inf
assert nthroot(inf, 2) == inf
assert nthroot(inf, 3) == inf
assert nthroot(inf, -1) == 0
assert nthroot(inf, -2) == 0
assert nthroot(inf, -3) == 0
assert nthroot(j, 1) == j
assert nthroot(j, 0) == 1
assert nthroot(j, -1) == -j
assert isnan(nthroot(nan, 1))
assert isnan(nthroot(nan, 0))
assert isnan(nthroot(nan, -1))
assert isnan(nthroot(inf, 0))
assert root(2,3) == nthroot(2,3)
assert root(16,4,0) == 2
assert root(16,4,1) == 2j
assert root(16,4,2) == -2
assert root(16,4,3) == -2j
assert root(16,4,4) == 2
assert root(-125,3,1) == -5
def test_issue_96():
for dps in [20, 80]:
mp.dps = dps
r = nthroot(mpf('-1e-20'), 4)
assert r.ae(mpf(10)**(-5) * (1 + j) * mpf(2)**(-0.5))
mp.dps = 80
assert nthroot('-1e-3', 4).ae(mpf(10)**(-3./4) * (1 + j)/sqrt(2))
assert nthroot('-1e-6', 4).ae((1 + j)/(10 * sqrt(20)))
# Check that this doesn't take eternity to compute
mp.dps = 20
assert nthroot('-1e100000000', 4).ae((1+j)*mpf('1e25000000')/sqrt(2))
mp.dps = 15
def test_perturbation_rounding():
mp.dps = 100
a = pi/10**50
b = -pi/10**50
c = 1 + a
d = 1 + b
mp.dps = 15
assert exp(a) == 1
assert exp(a, rounding='c') > 1
assert exp(b, rounding='c') == 1
assert exp(a, rounding='f') == 1
assert exp(b, rounding='f') < 1
assert cos(a) == 1
assert cos(a, rounding='c') == 1
assert cos(b, rounding='c') == 1
assert cos(a, rounding='f') < 1
assert cos(b, rounding='f') < 1
for f in [sin, atan, asinh, tanh]:
assert f(a) == +a
assert f(a, rounding='c') > a
assert f(a, rounding='f') < a
assert f(b) == +b
assert f(b, rounding='c') > b
assert f(b, rounding='f') < b
for f in [asin, tan, sinh, atanh]:
assert f(a) == +a
assert f(b) == +b
assert f(a, rounding='c') > a
assert f(b, rounding='c') > b
assert f(a, rounding='f') < a
assert f(b, rounding='f') < b
assert ln(c) == +a
assert ln(d) == +b
assert ln(c, rounding='c') > a
assert ln(c, rounding='f') < a
assert ln(d, rounding='c') > b
assert ln(d, rounding='f') < b
assert cosh(a) == 1
assert cosh(b) == 1
assert cosh(a, rounding='c') > 1
assert cosh(b, rounding='c') > 1
assert cosh(a, rounding='f') == 1
assert cosh(b, rounding='f') == 1
def test_integer_parts():
assert floor(3.2) == 3
assert ceil(3.2) == 4
assert floor(3.2+5j) == 3+5j
assert ceil(3.2+5j) == 4+5j
def test_complex_parts():
assert fabs('3') == 3
assert fabs(3+4j) == 5
assert re(3) == 3
assert re(1+4j) == 1
assert im(3) == 0
assert im(1+4j) == 4
assert conj(3) == 3
assert conj(3+4j) == 3-4j
assert mpf(3).conjugate() == 3
def test_cospi_sinpi():
assert sinpi(0) == 0
assert sinpi(0.5) == 1
assert sinpi(1) == 0
assert sinpi(1.5) == -1
assert sinpi(2) == 0
assert sinpi(2.5) == 1
assert sinpi(-0.5) == -1
assert cospi(0) == 1
assert cospi(0.5) == 0
assert cospi(1) == -1
assert cospi(1.5) == 0
assert cospi(2) == 1
assert cospi(2.5) == 0
assert cospi(-0.5) == 0
assert cospi(100000000000.25).ae(sqrt(2)/2)
a = cospi(2+3j)
assert a.real.ae(cos((2+3j)*pi).real)
assert a.imag == 0
b = sinpi(2+3j)
assert b.imag.ae(sin((2+3j)*pi).imag)
assert b.real == 0
mp.dps = 35
x1 = mpf(10000) - mpf('1e-15')
x2 = mpf(10000) + mpf('1e-15')
x3 = mpf(10000.5) - mpf('1e-15')
x4 = mpf(10000.5) + mpf('1e-15')
x5 = mpf(10001) - mpf('1e-15')
x6 = mpf(10001) + mpf('1e-15')
x7 = mpf(10001.5) - mpf('1e-15')
x8 = mpf(10001.5) + mpf('1e-15')
mp.dps = 15
M = 10**15
assert (sinpi(x1)*M).ae(-pi)
assert (sinpi(x2)*M).ae(pi)
assert (cospi(x3)*M).ae(pi)
assert (cospi(x4)*M).ae(-pi)
assert (sinpi(x5)*M).ae(pi)
assert (sinpi(x6)*M).ae(-pi)
assert (cospi(x7)*M).ae(-pi)
assert (cospi(x8)*M).ae(pi)
assert 0.999 < cospi(x1, rounding='d') < 1
assert 0.999 < cospi(x2, rounding='d') < 1
assert 0.999 < sinpi(x3, rounding='d') < 1
assert 0.999 < sinpi(x4, rounding='d') < 1
assert -1 < cospi(x5, rounding='d') < -0.999
assert -1 < cospi(x6, rounding='d') < -0.999
assert -1 < sinpi(x7, rounding='d') < -0.999
assert -1 < sinpi(x8, rounding='d') < -0.999
assert (sinpi(1e-15)*M).ae(pi)
assert (sinpi(-1e-15)*M).ae(-pi)
assert cospi(1e-15) == 1
assert cospi(1e-15, rounding='d') < 1
def test_expj():
assert expj(0) == 1
assert expj(1).ae(exp(j))
assert expj(j).ae(exp(-1))
assert expj(1+j).ae(exp(j*(1+j)))
assert expjpi(0) == 1
assert expjpi(1).ae(exp(j*pi))
assert expjpi(j).ae(exp(-pi))
assert expjpi(1+j).ae(exp(j*pi*(1+j)))
assert expjpi(-10**15 * j).ae('2.22579818340535731e+1364376353841841')
def test_sinc():
assert sinc(0) == sincpi(0) == 1
assert sinc(inf) == sincpi(inf) == 0
assert sinc(-inf) == sincpi(-inf) == 0
assert sinc(2).ae(0.45464871341284084770)
assert sinc(2+3j).ae(0.4463290318402435457-2.7539470277436474940j)
assert sincpi(2) == 0
assert sincpi(1.5).ae(-0.212206590789193781)
def test_fibonacci():
mp.dps = 15
assert [fibonacci(n) for n in range(-5, 10)] == \
[5, -3, 2, -1, 1, 0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
assert fib(2.5).ae(1.4893065462657091)
assert fib(3+4j).ae(-5248.51130728372 - 14195.962288353j)
assert fib(1000).ae(4.3466557686937455e+208)
assert str(fib(10**100)) == '6.24499112864607e+2089876402499787337692720892375554168224592399182109535392875613974104853496745963277658556235103534'
mp.dps = 2100
a = fib(10000)
assert a % 10**10 == 9947366875
mp.dps = 15
assert fibonacci(inf) == inf
assert fib(3+0j) == 2
def test_call_with_dps():
mp.dps = 15
assert abs(exp(1, dps=30)-e(dps=35)) < 1e-29
def test_tanh():
mp.dps = 15
assert tanh(0) == 0
assert tanh(inf) == 1
assert tanh(-inf) == -1
assert isnan(tanh(nan))
assert tanh(mpc('inf', '0')) == 1
def test_atanh():
mp.dps = 15
assert atanh(0) == 0
assert atanh(0.5).ae(0.54930614433405484570)
assert atanh(-0.5).ae(-0.54930614433405484570)
assert atanh(1) == inf
assert atanh(-1) == -inf
assert isnan(atanh(nan))
assert isinstance(atanh(1), mpf)
assert isinstance(atanh(-1), mpf)
# Limits at infinity
jpi2 = j*pi/2
assert atanh(inf).ae(-jpi2)
assert atanh(-inf).ae(jpi2)
assert atanh(mpc(inf,-1)).ae(-jpi2)
assert atanh(mpc(inf,0)).ae(-jpi2)
assert atanh(mpc(inf,1)).ae(jpi2)
assert atanh(mpc(1,inf)).ae(jpi2)
assert atanh(mpc(0,inf)).ae(jpi2)
assert atanh(mpc(-1,inf)).ae(jpi2)
assert atanh(mpc(-inf,1)).ae(jpi2)
assert atanh(mpc(-inf,0)).ae(jpi2)
assert atanh(mpc(-inf,-1)).ae(-jpi2)
assert atanh(mpc(-1,-inf)).ae(-jpi2)
assert atanh(mpc(0,-inf)).ae(-jpi2)
assert atanh(mpc(1,-inf)).ae(-jpi2)
def test_expm1():
mp.dps = 15
assert expm1(0) == 0
assert expm1(3).ae(exp(3)-1)
assert expm1(inf) == inf
assert expm1(1e-10)*1e10
assert expm1(1e-50).ae(1e-50)
assert (expm1(1e-10)*1e10).ae(1.00000000005)
def test_powm1():
mp.dps = 15
assert powm1(2,3) == 7
assert powm1(-1,2) == 0
assert powm1(-1,0) == 0
assert powm1(-2,0) == 0
assert powm1(3+4j,0) == 0
assert powm1(0,1) == -1
assert powm1(0,0) == 0
assert powm1(1,0) == 0
assert powm1(1,2) == 0
assert powm1(1,3+4j) == 0
assert powm1(1,5) == 0
assert powm1(j,4) == 0
assert powm1(-j,4) == 0
assert (powm1(2,1e-100)*1e100).ae(ln2)
assert powm1(2,'1e-100000000000') != 0
assert (powm1(fadd(1,1e-100,exact=True), 5)*1e100).ae(5)
def test_unitroots():
assert unitroots(1) == [1]
assert unitroots(2) == [1, -1]
a, b, c = unitroots(3)
assert a == 1
assert b.ae(-0.5 + 0.86602540378443864676j)
assert c.ae(-0.5 - 0.86602540378443864676j)
assert unitroots(1, primitive=True) == [1]
assert unitroots(2, primitive=True) == [-1]
assert unitroots(3, primitive=True) == unitroots(3)[1:]
assert unitroots(4, primitive=True) == [j, -j]
assert len(unitroots(17, primitive=True)) == 16
assert len(unitroots(16, primitive=True)) == 8
def test_cyclotomic():
mp.dps = 15
assert [cyclotomic(n,1) for n in range(31)] == [1,0,2,3,2,5,1,7,2,3,1,11,1,13,1,1,2,17,1,19,1,1,1,23,1,5,1,3,1,29,1]
assert [cyclotomic(n,-1) for n in range(31)] == [1,-2,0,1,2,1,3,1,2,1,5,1,1,1,7,1,2,1,3,1,1,1,11,1,1,1,13,1,1,1,1]
assert [cyclotomic(n,j) for n in range(21)] == [1,-1+j,1+j,j,0,1,-j,j,2,-j,1,j,3,1,-j,1,2,1,j,j,5]
assert [cyclotomic(n,-j) for n in range(21)] == [1,-1-j,1-j,-j,0,1,j,-j,2,j,1,-j,3,1,j,1,2,1,-j,-j,5]
assert cyclotomic(1624,j) == 1
assert cyclotomic(33600,j) == 1
u = sqrt(j, prec=500)
assert cyclotomic(8, u).ae(0)
assert cyclotomic(30, u).ae(5.8284271247461900976)
assert cyclotomic(2040, u).ae(1)
assert cyclotomic(0,2.5) == 1
assert cyclotomic(1,2.5) == 2.5-1
assert cyclotomic(2,2.5) == 2.5+1
assert cyclotomic(3,2.5) == 2.5**2 + 2.5 + 1
assert cyclotomic(7,2.5) == 406.234375

File diff suppressed because it is too large Load Diff

View File

@ -1,658 +0,0 @@
from mpmath import *
from mpmath.libmp import round_up, from_float, mpf_zeta_int
def test_zeta_int_bug():
assert mpf_zeta_int(0, 10) == from_float(-0.5)
def test_bernoulli():
assert bernfrac(0) == (1,1)
assert bernfrac(1) == (-1,2)
assert bernfrac(2) == (1,6)
assert bernfrac(3) == (0,1)
assert bernfrac(4) == (-1,30)
assert bernfrac(5) == (0,1)
assert bernfrac(6) == (1,42)
assert bernfrac(8) == (-1,30)
assert bernfrac(10) == (5,66)
assert bernfrac(12) == (-691,2730)
assert bernfrac(18) == (43867,798)
p, q = bernfrac(228)
assert p % 10**10 == 164918161
assert q == 625170
p, q = bernfrac(1000)
assert p % 10**10 == 7950421099
assert q == 342999030
mp.dps = 15
assert bernoulli(0) == 1
assert bernoulli(1) == -0.5
assert bernoulli(2).ae(1./6)
assert bernoulli(3) == 0
assert bernoulli(4).ae(-1./30)
assert bernoulli(5) == 0
assert bernoulli(6).ae(1./42)
assert str(bernoulli(10)) == '0.0757575757575758'
assert str(bernoulli(234)) == '7.62772793964344e+267'
assert str(bernoulli(10**5)) == '-5.82229431461335e+376755'
assert str(bernoulli(10**8+2)) == '1.19570355039953e+676752584'
mp.dps = 50
assert str(bernoulli(10)) == '0.075757575757575757575757575757575757575757575757576'
assert str(bernoulli(234)) == '7.6277279396434392486994969020496121553385863373331e+267'
assert str(bernoulli(10**5)) == '-5.8222943146133508236497045360612887555320691004308e+376755'
assert str(bernoulli(10**8+2)) == '1.1957035503995297272263047884604346914602088317782e+676752584'
mp.dps = 1000
assert bernoulli(10).ae(mpf(5)/66)
mp.dps = 50000
assert bernoulli(10).ae(mpf(5)/66)
mp.dps = 15
def test_bernpoly_eulerpoly():
mp.dps = 15
assert bernpoly(0,-1).ae(1)
assert bernpoly(0,0).ae(1)
assert bernpoly(0,'1/2').ae(1)
assert bernpoly(0,'3/4').ae(1)
assert bernpoly(0,1).ae(1)
assert bernpoly(0,2).ae(1)
assert bernpoly(1,-1).ae('-3/2')
assert bernpoly(1,0).ae('-1/2')
assert bernpoly(1,'1/2').ae(0)
assert bernpoly(1,'3/4').ae('1/4')
assert bernpoly(1,1).ae('1/2')
assert bernpoly(1,2).ae('3/2')
assert bernpoly(2,-1).ae('13/6')
assert bernpoly(2,0).ae('1/6')
assert bernpoly(2,'1/2').ae('-1/12')
assert bernpoly(2,'3/4').ae('-1/48')
assert bernpoly(2,1).ae('1/6')
assert bernpoly(2,2).ae('13/6')
assert bernpoly(3,-1).ae(-3)
assert bernpoly(3,0).ae(0)
assert bernpoly(3,'1/2').ae(0)
assert bernpoly(3,'3/4').ae('-3/64')
assert bernpoly(3,1).ae(0)
assert bernpoly(3,2).ae(3)
assert bernpoly(4,-1).ae('119/30')
assert bernpoly(4,0).ae('-1/30')
assert bernpoly(4,'1/2').ae('7/240')
assert bernpoly(4,'3/4').ae('7/3840')
assert bernpoly(4,1).ae('-1/30')
assert bernpoly(4,2).ae('119/30')
assert bernpoly(5,-1).ae(-5)
assert bernpoly(5,0).ae(0)
assert bernpoly(5,'1/2').ae(0)
assert bernpoly(5,'3/4').ae('25/1024')
assert bernpoly(5,1).ae(0)
assert bernpoly(5,2).ae(5)
assert bernpoly(10,-1).ae('665/66')
assert bernpoly(10,0).ae('5/66')
assert bernpoly(10,'1/2').ae('-2555/33792')
assert bernpoly(10,'3/4').ae('-2555/34603008')
assert bernpoly(10,1).ae('5/66')
assert bernpoly(10,2).ae('665/66')
assert bernpoly(11,-1).ae(-11)
assert bernpoly(11,0).ae(0)
assert bernpoly(11,'1/2').ae(0)
assert bernpoly(11,'3/4').ae('-555731/4194304')
assert bernpoly(11,1).ae(0)
assert bernpoly(11,2).ae(11)
assert eulerpoly(0,-1).ae(1)
assert eulerpoly(0,0).ae(1)
assert eulerpoly(0,'1/2').ae(1)
assert eulerpoly(0,'3/4').ae(1)
assert eulerpoly(0,1).ae(1)
assert eulerpoly(0,2).ae(1)
assert eulerpoly(1,-1).ae('-3/2')
assert eulerpoly(1,0).ae('-1/2')
assert eulerpoly(1,'1/2').ae(0)
assert eulerpoly(1,'3/4').ae('1/4')
assert eulerpoly(1,1).ae('1/2')
assert eulerpoly(1,2).ae('3/2')
assert eulerpoly(2,-1).ae(2)
assert eulerpoly(2,0).ae(0)
assert eulerpoly(2,'1/2').ae('-1/4')
assert eulerpoly(2,'3/4').ae('-3/16')
assert eulerpoly(2,1).ae(0)
assert eulerpoly(2,2).ae(2)
assert eulerpoly(3,-1).ae('-9/4')
assert eulerpoly(3,0).ae('1/4')
assert eulerpoly(3,'1/2').ae(0)
assert eulerpoly(3,'3/4').ae('-11/64')
assert eulerpoly(3,1).ae('-1/4')
assert eulerpoly(3,2).ae('9/4')
assert eulerpoly(4,-1).ae(2)
assert eulerpoly(4,0).ae(0)
assert eulerpoly(4,'1/2').ae('5/16')
assert eulerpoly(4,'3/4').ae('57/256')
assert eulerpoly(4,1).ae(0)
assert eulerpoly(4,2).ae(2)
assert eulerpoly(5,-1).ae('-3/2')
assert eulerpoly(5,0).ae('-1/2')
assert eulerpoly(5,'1/2').ae(0)
assert eulerpoly(5,'3/4').ae('361/1024')
assert eulerpoly(5,1).ae('1/2')
assert eulerpoly(5,2).ae('3/2')
assert eulerpoly(10,-1).ae(2)
assert eulerpoly(10,0).ae(0)
assert eulerpoly(10,'1/2').ae('-50521/1024')
assert eulerpoly(10,'3/4').ae('-36581523/1048576')
assert eulerpoly(10,1).ae(0)
assert eulerpoly(10,2).ae(2)
assert eulerpoly(11,-1).ae('-699/4')
assert eulerpoly(11,0).ae('691/4')
assert eulerpoly(11,'1/2').ae(0)
assert eulerpoly(11,'3/4').ae('-512343611/4194304')
assert eulerpoly(11,1).ae('-691/4')
assert eulerpoly(11,2).ae('699/4')
# Potential accuracy issues
assert bernpoly(10000,10000).ae('5.8196915936323387117e+39999')
assert bernpoly(200,17.5).ae(3.8048418524583064909e244)
assert eulerpoly(200,17.5).ae(-3.7309911582655785929e275)
def test_gamma():
mp.dps = 15
assert gamma(0.25).ae(3.6256099082219083119)
assert gamma(0.0001).ae(9999.4228832316241908)
assert gamma(300).ae('1.0201917073881354535e612')
assert gamma(-0.5).ae(-3.5449077018110320546)
assert gamma(-7.43).ae(0.00026524416464197007186)
#assert gamma(Rational(1,2)) == gamma(0.5)
#assert gamma(Rational(-7,3)).ae(gamma(mpf(-7)/3))
assert gamma(1+1j).ae(0.49801566811835604271 - 0.15494982830181068512j)
assert gamma(-1+0.01j).ae(-0.422733904013474115 + 99.985883082635367436j)
assert gamma(20+30j).ae(-1453876687.5534810 + 1163777777.8031573j)
# Should always give exact factorials when they can
# be represented as mpfs under the current working precision
fact = 1
for i in range(1, 18):
assert gamma(i) == fact
fact *= i
for dps in [170, 600]:
fact = 1
mp.dps = dps
for i in range(1, 105):
assert gamma(i) == fact
fact *= i
mp.dps = 100
assert gamma(0.5).ae(sqrt(pi))
mp.dps = 15
assert factorial(0) == fac(0) == 1
assert factorial(3) == 6
assert isnan(gamma(nan))
assert gamma(1100).ae('4.8579168073569433667e2866')
def test_fac2():
mp.dps = 15
assert [fac2(n) for n in range(10)] == [1,1,2,3,8,15,48,105,384,945]
assert fac2(-5).ae(1./3)
assert fac2(-11).ae(-1./945)
assert fac2(50).ae(5.20469842636666623e32)
assert fac2(0.5+0.75j).ae(0.81546769394688069176-0.34901016085573266889j)
assert fac2(inf) == inf
assert isnan(fac2(-inf))
def test_gamma_quotients():
mp.dps = 15
h = 1e-8
ep = 1e-4
G = gamma
assert gammaprod([-1],[-3,-4]) == 0
assert gammaprod([-1,0],[-5]) == inf
assert abs(gammaprod([-1],[-2]) - G(-1+h)/G(-2+h)) < 1e-4
assert abs(gammaprod([-4,-3],[-2,0]) - G(-4+h)*G(-3+h)/G(-2+h)/G(0+h)) < 1e-4
assert rf(3,0) == 1
assert rf(2.5,1) == 2.5
assert rf(-5,2) == 20
assert rf(j,j).ae(gamma(2*j)/gamma(j))
assert ff(-2,0) == 1
assert ff(-2,1) == -2
assert ff(4,3) == 24
assert ff(3,4) == 0
assert binomial(0,0) == 1
assert binomial(1,0) == 1
assert binomial(0,-1) == 0
assert binomial(3,2) == 3
assert binomial(5,2) == 10
assert binomial(5,3) == 10
assert binomial(5,5) == 1
assert binomial(-1,0) == 1
assert binomial(-2,-4) == 3
assert binomial(4.5, 1.5) == 6.5625
assert binomial(1100,1) == 1100
assert binomial(1100,2) == 604450
assert beta(1,1) == 1
assert beta(0,0) == inf
assert beta(3,0) == inf
assert beta(-1,-1) == inf
assert beta(1.5,1).ae(2/3.)
assert beta(1.5,2.5).ae(pi/16)
assert (10**15*beta(10,100)).ae(2.3455339739604649879)
assert beta(inf,inf) == 0
assert isnan(beta(-inf,inf))
assert isnan(beta(-3,inf))
assert isnan(beta(0,inf))
assert beta(inf,0.5) == beta(0.5,inf) == 0
assert beta(inf,-1.5) == inf
assert beta(inf,-0.5) == -inf
assert beta(1+2j,-1-j/2).ae(1.16396542451069943086+0.08511695947832914640j)
assert beta(-0.5,0.5) == 0
assert beta(-3,3).ae(-1/3.)
def test_zeta():
mp.dps = 15
assert zeta(2).ae(pi**2 / 6)
assert zeta(2.0).ae(pi**2 / 6)
assert zeta(mpc(2)).ae(pi**2 / 6)
assert zeta(100).ae(1)
assert zeta(0).ae(-0.5)
assert zeta(0.5).ae(-1.46035450880958681)
assert zeta(-1).ae(-mpf(1)/12)
assert zeta(-2) == 0
assert zeta(-3).ae(mpf(1)/120)
assert zeta(-4) == 0
assert zeta(-100) == 0
assert isnan(zeta(nan))
# Zeros in the critical strip
assert zeta(mpc(0.5, 14.1347251417346937904)).ae(0)
assert zeta(mpc(0.5, 21.0220396387715549926)).ae(0)
assert zeta(mpc(0.5, 25.0108575801456887632)).ae(0)
mp.dps = 50
im = '236.5242296658162058024755079556629786895294952121891237'
assert zeta(mpc(0.5, im)).ae(0, 1e-46)
mp.dps = 15
# Complex reflection formula
assert (zeta(-60+3j) / 10**34).ae(8.6270183987866146+15.337398548226238j)
def test_altzeta():
mp.dps = 15
assert altzeta(-2) == 0
assert altzeta(-4) == 0
assert altzeta(-100) == 0
assert altzeta(0) == 0.5
assert altzeta(-1) == 0.25
assert altzeta(-3) == -0.125
assert altzeta(-5) == 0.25
assert altzeta(-21) == 1180529130.25
assert altzeta(1).ae(log(2))
assert altzeta(2).ae(pi**2/12)
assert altzeta(10).ae(73*pi**10/6842880)
assert altzeta(50) < 1
assert altzeta(60, rounding='d') < 1
assert altzeta(60, rounding='u') == 1
assert altzeta(10000, rounding='d') < 1
assert altzeta(10000, rounding='u') == 1
assert altzeta(3+0j) == altzeta(3)
s = 3+4j
assert altzeta(s).ae((1-2**(1-s))*zeta(s))
s = -3+4j
assert altzeta(s).ae((1-2**(1-s))*zeta(s))
assert altzeta(-100.5).ae(4.58595480083585913e+108)
assert altzeta(1.3).ae(0.73821404216623045)
def test_zeta_huge():
mp.dps = 15
assert zeta(inf) == 1
mp.dps = 50
assert zeta(100).ae('1.0000000000000000000000000000007888609052210118073522')
assert zeta(40*pi).ae('1.0000000000000000000000000000000000000148407238666182')
mp.dps = 10000
v = zeta(33000)
mp.dps = 15
assert str(v-1) == '1.02363019598118e-9934'
assert zeta(pi*1000, rounding=round_up) > 1
assert zeta(3000, rounding=round_up) > 1
assert zeta(pi*1000) == 1
assert zeta(3000) == 1
def test_zeta_negative():
mp.dps = 150
a = -pi*10**40
mp.dps = 15
assert str(zeta(a)) == '2.55880492708712e+1233536161668617575553892558646631323374078'
mp.dps = 50
assert str(zeta(a)) == '2.5588049270871154960875033337384432038436330847333e+1233536161668617575553892558646631323374078'
mp.dps = 15
def test_polygamma():
mp.dps = 15
psi0 = lambda z: psi(0,z)
psi1 = lambda z: psi(1,z)
assert psi0(3) == psi(0,3) == digamma(3)
#assert psi2(3) == psi(2,3) == tetragamma(3)
#assert psi3(3) == psi(3,3) == pentagamma(3)
assert psi0(pi).ae(0.97721330794200673)
assert psi0(-pi).ae(7.8859523853854902)
assert psi0(-pi+1).ae(7.5676424992016996)
assert psi0(pi+j).ae(1.04224048313859376 + 0.35853686544063749j)
assert psi0(-pi-j).ae(1.3404026194821986 - 2.8824392476809402j)
assert findroot(psi0, 1).ae(1.4616321449683622)
assert psi0(inf) == inf
assert psi1(inf) == 0
assert psi(2,inf) == 0
assert psi1(pi).ae(0.37424376965420049)
assert psi1(-pi).ae(53.030438740085385)
assert psi1(pi+j).ae(0.32935710377142464 - 0.12222163911221135j)
assert psi1(-pi-j).ae(-0.30065008356019703 + 0.01149892486928227j)
assert (10**6*psi(4,1+10*pi*j)).ae(-6.1491803479004446 - 0.3921316371664063j)
assert psi0(1+10*pi*j).ae(3.4473994217222650 + 1.5548808324857071j)
assert isnan(psi0(nan))
assert isnan(psi0(-inf))
assert psi0(-100.5).ae(4.615124601338064)
assert psi0(3+0j).ae(psi0(3))
assert psi0(-100+3j).ae(4.6106071768714086321+3.1117510556817394626j)
def test_polygamma_high_prec():
mp.dps = 100
assert str(psi(0,pi)) == "0.9772133079420067332920694864061823436408346099943256380095232865318105924777141317302075654362928734"
assert str(psi(10,pi)) == "-12.98876181434889529310283769414222588307175962213707170773803550518307617769657562747174101900659238"
def test_polygamma_identities():
mp.dps = 15
psi0 = lambda z: psi(0,z)
psi1 = lambda z: psi(1,z)
psi2 = lambda z: psi(2,z)
assert psi0(0.5).ae(-euler-2*log(2))
assert psi0(1).ae(-euler)
assert psi1(0.5).ae(0.5*pi**2)
assert psi1(1).ae(pi**2/6)
assert psi1(0.25).ae(pi**2 + 8*catalan)
assert psi2(1).ae(-2*apery)
mp.dps = 20
u = -182*apery+4*sqrt(3)*pi**3
mp.dps = 15
assert psi(2,5/6.).ae(u)
assert psi(3,0.5).ae(pi**4)
def test_foxtrot_identity():
# A test of the complex digamma function.
# See http://mathworld.wolfram.com/FoxTrotSeries.html and
# http://mathworld.wolfram.com/DigammaFunction.html
psi0 = lambda z: psi(0,z)
mp.dps = 50
a = (-1)**fraction(1,3)
b = (-1)**fraction(2,3)
x = -psi0(0.5*a) - psi0(-0.5*b) + psi0(0.5*(1+a)) + psi0(0.5*(1-b))
y = 2*pi*sech(0.5*sqrt(3)*pi)
assert x.ae(y)
mp.dps = 15
def test_polygamma_high_order():
mp.dps = 100
assert str(psi(50, pi)) == "-1344100348958402765749252447726432491812.641985273160531055707095989227897753035823152397679626136483"
assert str(psi(50, pi + 14*e)) == "-0.00000000000000000189793739550804321623512073101895801993019919886375952881053090844591920308111549337295143780341396"
assert str(psi(50, pi + 14*e*j)) == ("(-0.0000000000000000522516941152169248975225472155683565752375889510631513244785"
"9377385233700094871256507814151956624433 - 0.00000000000000001813157041407010184"
"702414110218205348527862196327980417757665282244728963891298080199341480881811613j)")
mp.dps = 15
assert str(psi(50, pi)) == "-1.34410034895841e+39"
assert str(psi(50, pi + 14*e)) == "-1.89793739550804e-18"
assert str(psi(50, pi + 14*e*j)) == "(-5.2251694115217e-17 - 1.81315704140701e-17j)"
def test_harmonic():
mp.dps = 15
assert harmonic(0) == 0
assert harmonic(1) == 1
assert harmonic(2) == 1.5
assert harmonic(3).ae(1. + 1./2 + 1./3)
assert harmonic(10**10).ae(23.603066594891989701)
assert harmonic(10**1000).ae(2303.162308658947)
assert harmonic(0.5).ae(2-2*log(2))
assert harmonic(inf) == inf
assert harmonic(2+0j) == 1.5+0j
assert harmonic(1+2j).ae(1.4918071802755104+0.92080728264223022j)
def test_gamma_huge_1():
mp.dps = 500
x = mpf(10**10) / 7
mp.dps = 15
assert str(gamma(x)) == "6.26075321389519e+12458010678"
mp.dps = 50
assert str(gamma(x)) == "6.2607532138951929201303779291707455874010420783933e+12458010678"
mp.dps = 15
def test_gamma_huge_2():
mp.dps = 500
x = mpf(10**100) / 19
mp.dps = 15
assert str(gamma(x)) == (\
"1.82341134776679e+5172997469323364168990133558175077136829182824042201886051511"
"9656908623426021308685461258226190190661")
mp.dps = 50
assert str(gamma(x)) == (\
"1.82341134776678875374414910350027596939980412984e+5172997469323364168990133558"
"1750771368291828240422018860515119656908623426021308685461258226190190661")
def test_gamma_huge_3():
mp.dps = 500
x = 10**80 // 3 + 10**70*j / 7
mp.dps = 15
y = gamma(x)
assert str(y.real) == (\
"-6.82925203918106e+2636286142112569524501781477865238132302397236429627932441916"
"056964386399485392600")
assert str(y.imag) == (\
"8.54647143678418e+26362861421125695245017814778652381323023972364296279324419160"
"56964386399485392600")
mp.dps = 50
y = gamma(x)
assert str(y.real) == (\
"-6.8292520391810548460682736226799637356016538421817e+26362861421125695245017814"
"77865238132302397236429627932441916056964386399485392600")
assert str(y.imag) == (\
"8.5464714367841748507479306948130687511711420234015e+263628614211256952450178147"
"7865238132302397236429627932441916056964386399485392600")
def test_gamma_huge_4():
x = 3200+11500j
mp.dps = 15
assert str(gamma(x)) == \
"(8.95783268539713e+5164 - 1.94678798329735e+5164j)"
mp.dps = 50
assert str(gamma(x)) == (\
"(8.9578326853971339570292952697675570822206567327092e+5164"
" - 1.9467879832973509568895402139429643650329524144794e+51"
"64j)")
mp.dps = 15
def test_gamma_huge_5():
mp.dps = 500
x = 10**60 * j / 3
mp.dps = 15
y = gamma(x)
assert str(y.real) == "-3.27753899634941e-227396058973640224580963937571892628368354580620654233316839"
assert str(y.imag) == "-7.1519888950416e-227396058973640224580963937571892628368354580620654233316841"
mp.dps = 50
y = gamma(x)
assert str(y.real) == (\
"-3.2775389963494132168950056995974690946983219123935e-22739605897364022458096393"
"7571892628368354580620654233316839")
assert str(y.imag) == (\
"-7.1519888950415979749736749222530209713136588885897e-22739605897364022458096393"
"7571892628368354580620654233316841")
mp.dps = 15
"""
XXX: fails
def test_gamma_huge_6():
return
mp.dps = 500
x = -10**10 + mpf(10)**(-175)*j
mp.dps = 15
assert str(gamma(x)) == \
"(1.86729378905343e-95657055178 - 4.29960285282433e-95657055002j)"
mp.dps = 50
assert str(gamma(x)) == (\
"(1.8672937890534298925763143275474177736153484820662e-9565705517"
"8 - 4.2996028528243336966001185406200082244961757496106e-9565705"
"5002j)")
mp.dps = 15
"""
def test_gamma_huge_7():
mp.dps = 100
a = 3 + j/mpf(10)**1000
mp.dps = 15
y = gamma(a)
assert str(y.real) == "2.0"
assert str(y.imag) == "2.16735365342606e-1000"
mp.dps = 50
y = gamma(a)
assert str(y.real) == "2.0"
assert str(y.imag) == "2.1673536534260596065418805612488708028522563689298e-1000"
def test_stieltjes():
mp.dps = 15
assert stieltjes(0).ae(+euler)
mp.dps = 25
assert stieltjes(1).ae('-0.07281584548367672486058637587')
assert stieltjes(2).ae('-0.009690363192872318484530386035')
assert stieltjes(3).ae('0.002053834420303345866160046543')
assert stieltjes(4).ae('0.002325370065467300057468170178')
mp.dps = 15
assert stieltjes(1).ae(-0.07281584548367672486058637587)
assert stieltjes(2).ae(-0.009690363192872318484530386035)
assert stieltjes(3).ae(0.002053834420303345866160046543)
assert stieltjes(4).ae(0.0023253700654673000574681701775)
def test_barnesg():
mp.dps = 15
assert barnesg(0) == barnesg(-1) == 0
assert [superfac(i) for i in range(8)] == [1, 1, 2, 12, 288, 34560, 24883200, 125411328000]
assert str(superfac(1000)) == '3.24570818422368e+1177245'
assert isnan(barnesg(nan))
assert isnan(superfac(nan))
assert isnan(hyperfac(nan))
assert barnesg(inf) == inf
assert superfac(inf) == inf
assert hyperfac(inf) == inf
assert isnan(superfac(-inf))
assert barnesg(0.7).ae(0.8068722730141471)
assert barnesg(2+3j).ae(-0.17810213864082169+0.04504542715447838j)
assert [hyperfac(n) for n in range(7)] == [1, 1, 4, 108, 27648, 86400000, 4031078400000]
assert [hyperfac(n) for n in range(0,-7,-1)] == [1,1,-1,-4,108,27648,-86400000]
a = barnesg(-3+0j)
assert a == 0 and isinstance(a, mpc)
a = hyperfac(-3+0j)
assert a == -4 and isinstance(a, mpc)
def test_polylog():
mp.dps = 15
zs = [mpmathify(z) for z in [0, 0.5, 0.99, 4, -0.5, -4, 1j, 3+4j]]
for z in zs: assert polylog(1, z).ae(-log(1-z))
for z in zs: assert polylog(0, z).ae(z/(1-z))
for z in zs: assert polylog(-1, z).ae(z/(1-z)**2)
for z in zs: assert polylog(-2, z).ae(z*(1+z)/(1-z)**3)
for z in zs: assert polylog(-3, z).ae(z*(1+4*z+z**2)/(1-z)**4)
assert polylog(3, 7).ae(5.3192579921456754382-5.9479244480803301023j)
assert polylog(3, -7).ae(-4.5693548977219423182)
assert polylog(2, 0.9).ae(1.2997147230049587252)
assert polylog(2, -0.9).ae(-0.75216317921726162037)
assert polylog(2, 0.9j).ae(-0.17177943786580149299+0.83598828572550503226j)
assert polylog(2, 1.1).ae(1.9619991013055685931-0.2994257606855892575j)
assert polylog(2, -1.1).ae(-0.89083809026228260587)
assert polylog(2, 1.1*sqrt(j)).ae(0.58841571107611387722+1.09962542118827026011j)
assert polylog(-2, 0.9).ae(1710)
assert polylog(-2, -0.9).ae(-90/6859.)
assert polylog(3, 0.9).ae(1.0496589501864398696)
assert polylog(-3, 0.9).ae(48690)
assert polylog(-3, -4).ae(-0.0064)
assert polylog(0.5+j/3, 0.5+j/2).ae(0.31739144796565650535 + 0.99255390416556261437j)
assert polylog(3+4j,1).ae(zeta(3+4j))
assert polylog(3+4j,-1).ae(-altzeta(3+4j))
def test_bell_polyexp():
mp.dps = 15
# TODO: more tests for polyexp
assert (polyexp(0,1e-10)*10**10).ae(1.00000000005)
assert (polyexp(1,1e-10)*10**10).ae(1.0000000001)
assert polyexp(5,3j).ae(-607.7044517476176454+519.962786482001476087j)
assert polyexp(-1,3.5).ae(12.09537536175543444)
# bell(0,x) = 1
assert bell(0,0) == 1
assert bell(0,1) == 1
assert bell(0,2) == 1
assert bell(0,inf) == 1
assert bell(0,-inf) == 1
assert isnan(bell(0,nan))
# bell(1,x) = x
assert bell(1,4) == 4
assert bell(1,0) == 0
assert bell(1,inf) == inf
assert bell(1,-inf) == -inf
assert isnan(bell(1,nan))
# bell(2,x) = x*(1+x)
assert bell(2,-1) == 0
assert bell(2,0) == 0
# large orders / arguments
assert bell(10) == 115975
assert bell(10,1) == 115975
assert bell(10, -8) == 11054008
assert bell(5,-50) == -253087550
assert bell(50,-50).ae('3.4746902914629720259e74')
mp.dps = 80
assert bell(50,-50) == 347469029146297202586097646631767227177164818163463279814268368579055777450
assert bell(40,50) == 5575520134721105844739265207408344706846955281965031698187656176321717550
assert bell(74) == 5006908024247925379707076470957722220463116781409659160159536981161298714301202
mp.dps = 15
assert bell(10,20j) == 7504528595600+15649605360020j
# continuity of the generalization
assert bell(0.5,0).ae(sinc(pi*0.5))
def test_primezeta():
mp.dps = 15
assert primezeta(0.9).ae(1.8388316154446882243 + 3.1415926535897932385j)
assert primezeta(4).ae(0.076993139764246844943)
assert primezeta(1) == inf
assert primezeta(inf) == 0
assert isnan(primezeta(nan))
def test_rs_zeta():
mp.dps = 15
assert zeta(0.5+100000j).ae(1.0730320148577531321 + 5.7808485443635039843j)
assert zeta(0.75+100000j).ae(1.837852337251873704 + 1.9988492668661145358j)
assert zeta(0.5+1000000j, derivative=3).ae(1647.7744105852674733 - 1423.1270943036622097j)
assert zeta(1+1000000j, derivative=3).ae(3.4085866124523582894 - 18.179184721525947301j)
assert zeta(1+1000000j, derivative=1).ae(-0.10423479366985452134 - 0.74728992803359056244j)
assert zeta(0.5-1000000j, derivative=1).ae(11.636804066002521459 + 17.127254072212996004j)
# Additional sanity tests using fp arithmetic.
# Some more high-precision tests are found in the docstrings
def ae(x, y, tol=1e-6):
return abs(x-y) < tol*abs(y)
assert ae(fp.zeta(0.5-100000j), 1.0730320148577531321 - 5.7808485443635039843j)
assert ae(fp.zeta(0.75-100000j), 1.837852337251873704 - 1.9988492668661145358j)
assert ae(fp.zeta(0.5+1e6j), 0.076089069738227100006 + 2.8051021010192989554j)
assert ae(fp.zeta(0.5+1e6j, derivative=1), 11.636804066002521459 - 17.127254072212996004j)
assert ae(fp.zeta(1+1e6j), 0.94738726251047891048 + 0.59421999312091832833j)
assert ae(fp.zeta(1+1e6j, derivative=1), -0.10423479366985452134 - 0.74728992803359056244j)
assert ae(fp.zeta(0.5+100000j, derivative=1), 10.766962036817482375 - 30.92705282105996714j)
assert ae(fp.zeta(0.5+100000j, derivative=2), -119.40515625740538429 + 217.14780631141830251j)
assert ae(fp.zeta(0.5+100000j, derivative=3), 1129.7550282628460881 - 1685.4736895169690346j)
assert ae(fp.zeta(0.5+100000j, derivative=4), -10407.160819314958615 + 13777.786698628045085j)
assert ae(fp.zeta(0.75+100000j, derivative=1), -0.41742276699594321475 - 6.4453816275049955949j)
assert ae(fp.zeta(0.75+100000j, derivative=2), -9.214314279161977266 + 35.07290795337967899j)
assert ae(fp.zeta(0.75+100000j, derivative=3), 110.61331857820103469 - 236.87847130518129926j)
assert ae(fp.zeta(0.75+100000j, derivative=4), -1054.334275898559401 + 1769.9177890161596383j)
def test_zeta_near_1():
# Test for a former bug in mpf_zeta and mpc_zeta
mp.dps = 15
s1 = fadd(1, '1e-10', exact=True)
s2 = fadd(1, '-1e-10', exact=True)
s3 = fadd(1, '1e-10j', exact=True)
assert zeta(s1).ae(1.000000000057721566490881444e10)
assert zeta(s2).ae(-9.99999999942278433510574872e9)
z = zeta(s3)
assert z.real.ae(0.57721566490153286060)
assert z.imag.ae(-9.9999999999999999999927184e9)
mp.dps = 30
s1 = fadd(1, '1e-50', exact=True)
s2 = fadd(1, '-1e-50', exact=True)
s3 = fadd(1, '1e-50j', exact=True)
assert zeta(s1).ae('1e50')
assert zeta(s2).ae('-1e50')
z = zeta(s3)
assert z.real.ae('0.57721566490153286060651209008240243104215933593992')
assert z.imag.ae('-1e50')

View File

@ -1,292 +0,0 @@
"""
Check that the output from irrational functions is accurate for
high-precision input, from 5 to 200 digits. The reference values were
verified with Mathematica.
"""
import time
from mpmath import *
precs = [5, 15, 28, 35, 57, 80, 100, 150, 200]
# sqrt(3) + pi/2
a = \
"3.302847134363773912758768033145623809041389953497933538543279275605"\
"841220051904536395163599428307109666700184672047856353516867399774243594"\
"67433521615861420725323528325327484262075464241255915238845599752675"
# e + 1/euler**2
b = \
"5.719681166601007617111261398629939965860873957353320734275716220045750"\
"31474116300529519620938123730851145473473708966080207482581266469342214"\
"824842256999042984813905047895479210702109260221361437411947323431"
# sqrt(a)
sqrt_a = \
"1.817373691447021556327498239690365674922395036495564333152483422755"\
"144321726165582817927383239308173567921345318453306994746434073691275094"\
"484777905906961689902608644112196725896908619756404253109722911487"
# sqrt(a+b*i).real
sqrt_abi_real = \
"2.225720098415113027729407777066107959851146508557282707197601407276"\
"89160998185797504198062911768240808839104987021515555650875977724230130"\
"3584116233925658621288393930286871862273400475179312570274423840384"
# sqrt(a+b*i).imag
sqrt_abi_imag = \
"1.2849057639084690902371581529110949983261182430040898147672052833653668"\
"0629534491275114877090834296831373498336559849050755848611854282001250"\
"1924311019152914021365263161630765255610885489295778894976075186"
# log(a)
log_a = \
"1.194784864491089550288313512105715261520511949410072046160598707069"\
"4336653155025770546309137440687056366757650909754708302115204338077595203"\
"83005773986664564927027147084436553262269459110211221152925732612"
# log(a+b*i).real
log_abi_real = \
"1.8877985921697018111624077550443297276844736840853590212962006811663"\
"04949387789489704203167470111267581371396245317618589339274243008242708"\
"014251531496104028712866224020066439049377679709216784954509456421"
# log(a+b*i).imag
log_abi_imag = \
"1.0471204952840802663567714297078763189256357109769672185219334169734948"\
"4265809854092437285294686651806426649541504240470168212723133326542181"\
"8300136462287639956713914482701017346851009323172531601894918640"
# exp(a)
exp_a = \
"27.18994224087168661137253262213293847994194869430518354305430976149"\
"382792035050358791398632888885200049857986258414049540376323785711941636"\
"100358982497583832083513086941635049329804685212200507288797531143"
# exp(a+b*i).real
exp_abi_real = \
"22.98606617170543596386921087657586890620262522816912505151109385026"\
"40160179326569526152851983847133513990281518417211964710397233157168852"\
"4963130831190142571659948419307628119985383887599493378056639916701"
# exp(a+b*i).imag
exp_abi_imag = \
"-14.523557450291489727214750571590272774669907424478129280902375851196283"\
"3377162379031724734050088565710975758824441845278120105728824497308303"\
"6065619788140201636218705414429933685889542661364184694108251449"
# a**b
pow_a_b = \
"928.7025342285568142947391505837660251004990092821305668257284426997"\
"361966028275685583421197860603126498884545336686124793155581311527995550"\
"580229264427202446131740932666832138634013168125809402143796691154"
# (a**(a+b*i)).real
pow_a_abi_real = \
"44.09156071394489511956058111704382592976814280267142206420038656267"\
"67707916510652790502399193109819563864568986234654864462095231138500505"\
"8197456514795059492120303477512711977915544927440682508821426093455"
# (a**(a+b*i)).imag
pow_a_abi_imag = \
"27.069371511573224750478105146737852141664955461266218367212527612279886"\
"9322304536553254659049205414427707675802193810711302947536332040474573"\
"8166261217563960235014674118610092944307893857862518964990092301"
# ((a+b*i)**(a+b*i)).real
pow_abi_abi_real = \
"-0.15171310677859590091001057734676423076527145052787388589334350524"\
"8084195882019497779202452975350579073716811284169068082670778986235179"\
"0813026562962084477640470612184016755250592698408112493759742219150452"\
# ((a+b*i)**(a+b*i)).imag
pow_abi_abi_imag = \
"1.2697592504953448936553147870155987153192995316950583150964099070426"\
"4736837932577176947632535475040521749162383347758827307504526525647759"\
"97547638617201824468382194146854367480471892602963428122896045019902"
# sin(a)
sin_a = \
"-0.16055653857469062740274792907968048154164433772938156243509084009"\
"38437090841460493108570147191289893388608611542655654723437248152535114"\
"528368009465836614227575701220612124204622383149391870684288862269631"
# sin(1000*a)
sin_1000a = \
"-0.85897040577443833776358106803777589664322997794126153477060795801"\
"09151695416961724733492511852267067419573754315098042850381158563024337"\
"216458577140500488715469780315833217177634490142748614625281171216863"
# sin(a+b*i)
sin_abi_real = \
"-24.4696999681556977743346798696005278716053366404081910969773939630"\
"7149215135459794473448465734589287491880563183624997435193637389884206"\
"02151395451271809790360963144464736839412254746645151672423256977064"
sin_abi_imag = \
"-150.42505378241784671801405965872972765595073690984080160750785565810981"\
"8314482499135443827055399655645954830931316357243750839088113122816583"\
"7169201254329464271121058839499197583056427233866320456505060735"
# cos
cos_a = \
"-0.98702664499035378399332439243967038895709261414476495730788864004"\
"05406821549361039745258003422386169330787395654908532996287293003581554"\
"257037193284199198069707141161341820684198547572456183525659969145501"
cos_1000a = \
"-0.51202523570982001856195696460663971099692261342827540426136215533"\
"52686662667660613179619804463250686852463876088694806607652218586060613"\
"951310588158830695735537073667299449753951774916401887657320950496820"
# tan
tan_a = \
"0.162666873675188117341401059858835168007137819495998960250142156848"\
"639654718809412181543343168174807985559916643549174530459883826451064966"\
"7996119428949951351938178809444268785629011625179962457123195557310"
tan_abi_real = \
"6.822696615947538488826586186310162599974827139564433912601918442911"\
"1026830824380070400102213741875804368044342309515353631134074491271890"\
"467615882710035471686578162073677173148647065131872116479947620E-6"
tan_abi_imag = \
"0.9999795833048243692245661011298447587046967777739649018690797625964167"\
"1446419978852235960862841608081413169601038230073129482874832053357571"\
"62702259309150715669026865777947502665936317953101462202542168429"
def test_hp():
for dps in precs:
mp.dps = dps + 8
aa = mpf(a)
bb = mpf(b)
a1000 = 1000*mpf(a)
abi = mpc(aa, bb)
mp.dps = dps
assert (sqrt(3) + pi/2).ae(aa)
assert (e + 1/euler**2).ae(bb)
assert sqrt(aa).ae(mpf(sqrt_a))
assert sqrt(abi).ae(mpc(sqrt_abi_real, sqrt_abi_imag))
assert log(aa).ae(mpf(log_a))
assert log(abi).ae(mpc(log_abi_real, log_abi_imag))
assert exp(aa).ae(mpf(exp_a))
assert exp(abi).ae(mpc(exp_abi_real, exp_abi_imag))
assert (aa**bb).ae(mpf(pow_a_b))
assert (aa**abi).ae(mpc(pow_a_abi_real, pow_a_abi_imag))
assert (abi**abi).ae(mpc(pow_abi_abi_real, pow_abi_abi_imag))
assert sin(a).ae(mpf(sin_a))
assert sin(a1000).ae(mpf(sin_1000a))
assert sin(abi).ae(mpc(sin_abi_real, sin_abi_imag))
assert cos(a).ae(mpf(cos_a))
assert cos(a1000).ae(mpf(cos_1000a))
assert tan(a).ae(mpf(tan_a))
assert tan(abi).ae(mpc(tan_abi_real, tan_abi_imag))
# check that complex cancellation is avoided so that both
# real and imaginary parts have high relative accuracy.
# abs_eps should be 0, but has to be set to 1e-205 to pass the
# 200-digit case, probably due to slight inaccuracy in the
# precomputed input
assert (tan(abi).real).ae(mpf(tan_abi_real), abs_eps=1e-205)
assert (tan(abi).imag).ae(mpf(tan_abi_imag), abs_eps=1e-205)
mp.dps = 460
assert str(log(3))[-20:] == '02166121184001409826'
mp.dps = 15
# Since str(a) can differ in the last digit from rounded a, and I want
# to compare the last digits of big numbers with the results in Mathematica,
# I made this hack to get the last 20 digits of rounded a
def last_digits(a):
r = repr(a)
s = str(a)
#dps = mp.dps
#mp.dps += 3
m = 10
r = r.replace(s[:-m],'')
r = r.replace("mpf('",'').replace("')",'')
num0 = 0
for c in r:
if c == '0':
num0 += 1
else:
break
b = float(int(r))/10**(len(r) - m)
if b >= 10**m - 0.5:
raise NotImplementedError
n = int(round(b))
sn = str(n)
s = s[:-m] + '0'*num0 + sn
return s[-20:]
# values checked with Mathematica
def test_log_hp():
mp.dps = 2000
a = mpf(10)**15000/3
r = log(a)
res = last_digits(r)
# Mathematica N[Log[10^15000/3], 2000]
# ...7443804441768333470331
assert res == '44380444176833347033'
# see issue 105
r = log(mpf(3)/2)
# Mathematica N[Log[3/2], 2000]
# ...69653749808140753263288
res = last_digits(r)
assert res == '53749808140753263288'
mp.dps = 10000
r = log(2)
res = last_digits(r)
# Mathematica N[Log[2], 10000]
# ...695615913401856601359655561
assert res == '91340185660135965556'
r = log(mpf(10)**10/3)
res = last_digits(r)
# Mathematica N[Log[10^10/3], 10000]
# ...587087654020631943060007154
assert res == '54020631943060007154', res
r = log(mpf(10)**100/3)
res = last_digits(r)
# Mathematica N[Log[10^100/3], 10000]
# ,,,59246336539088351652334666
assert res == '36539088351652334666', res
mp.dps += 10
a = 1 - mpf(1)/10**10
mp.dps -= 10
r = log(a)
res = last_digits(r)
# ...3310334360482956137216724048322957404
# 372167240483229574038733026370
# Mathematica N[Log[1 - 10^-10]*10^10, 10000]
# ...60482956137216724048322957404
assert res == '37216724048322957404', res
mp.dps = 10000
mp.dps += 100
a = 1 + mpf(1)/10**100
mp.dps -= 100
r = log(a)
res = last_digits(+r)
# Mathematica N[Log[1 + 10^-100]*10^10, 10030]
# ...3994733877377412241546890854692521568292338268273 10^-91
assert res == '39947338773774122415', res
mp.dps = 15
def test_exp_hp():
mp.dps = 4000
r = exp(mpf(1)/10)
# IntegerPart[N[Exp[1/10] * 10^4000, 4000]]
# ...92167105162069688129
assert int(r * 10**mp.dps) % 10**20 == 92167105162069688129

View File

@ -1,19 +0,0 @@
from mpmath import *
def test_pslq():
mp.dps = 15
assert pslq([3*pi+4*e/7, pi, e, log(2)]) == [7, -21, -4, 0]
assert pslq([4.9999999999999991, 1]) == [1, -5]
assert pslq([2,1]) == [1, -2]
def test_identify():
mp.dps = 20
assert identify(zeta(4), ['log(2)', 'pi**4']) == '((1/90)*pi**4)'
mp.dps = 15
assert identify(exp(5)) == 'exp(5)'
assert identify(exp(4)) == 'exp(4)'
assert identify(log(5)) == 'log(5)'
assert identify(exp(3*pi), ['pi']) == 'exp((3*pi))'
assert identify(3, full=True) == ['3', '3', '1/(1/3)', 'sqrt(9)',
'1/sqrt((1/9))', '(sqrt(12)/2)**2', '1/(sqrt(12)/6)**2']
assert identify(pi+1, {'a':+pi}) == '(1 + 1*a)'

View File

@ -1,264 +0,0 @@
from mpmath import *
mpi_to_str = mp.mpi_to_str
mpi_from_str = mp.mpi_from_str
def test_interval_identity():
mp.dps = 15
assert mpi(2) == mpi(2, 2)
assert mpi(2) != mpi(-2, 2)
assert not (mpi(2) != mpi(2, 2))
assert mpi(-1, 1) == mpi(-1, 1)
assert str(mpi('0.1')) == "[0.099999999999999991673, 0.10000000000000000555]"
assert repr(mpi('0.1')) == "mpi(mpf('0.099999999999999992'), mpf('0.10000000000000001'))"
u = mpi(-1, 3)
assert -1 in u
assert 2 in u
assert 3 in u
assert -1.1 not in u
assert 3.1 not in u
assert mpi(-1, 3) in u
assert mpi(0, 1) in u
assert mpi(-1.1, 2) not in u
assert mpi(2.5, 3.1) not in u
w = mpi(-inf, inf)
assert mpi(-5, 5) in w
assert mpi(2, inf) in w
assert mpi(0, 2) in mpi(0, 10)
assert not (3 in mpi(-inf, 0))
def test_interval_arithmetic():
mp.dps = 15
assert mpi(2) + mpi(3,4) == mpi(5,6)
assert mpi(1, 2)**2 == mpi(1, 4)
assert mpi(1) + mpi(0, 1e-50) == mpi(1, mpf('1.0000000000000002'))
x = 1 / (1 / mpi(3))
assert x.a < 3 < x.b
x = mpi(2) ** mpi(0.5)
mp.dps += 5
sq = sqrt(2)
mp.dps -= 5
assert x.a < sq < x.b
assert mpi(1) / mpi(1, inf)
assert mpi(2, 3) / inf == mpi(0, 0)
assert mpi(0) / inf == 0
assert mpi(0) / 0 == mpi(-inf, inf)
assert mpi(inf) / 0 == mpi(-inf, inf)
assert mpi(0) * inf == mpi(-inf, inf)
assert 1 / mpi(2, inf) == mpi(0, 0.5)
assert str((mpi(50, 50) * mpi(-10, -10)) / 3) == \
'[-166.66666666666668561, -166.66666666666665719]'
assert mpi(0, 4) ** 3 == mpi(0, 64)
assert mpi(2,4).mid == 3
mp.dps = 30
a = mpi(pi)
mp.dps = 15
b = +a
assert b.a < a.a
assert b.b > a.b
a = mpi(pi)
assert a == +a
assert abs(mpi(-1,2)) == mpi(0,2)
assert abs(mpi(0.5,2)) == mpi(0.5,2)
assert abs(mpi(-3,2)) == mpi(0,3)
assert abs(mpi(-3,-0.5)) == mpi(0.5,3)
assert mpi(0) * mpi(2,3) == mpi(0)
assert mpi(2,3) * mpi(0) == mpi(0)
assert mpi(1,3).delta == 2
assert mpi(1,2) - mpi(3,4) == mpi(-3,-1)
assert mpi(-inf,0) - mpi(0,inf) == mpi(-inf,0)
assert mpi(-inf,0) - mpi(-inf,inf) == mpi(-inf,inf)
assert mpi(0,inf) - mpi(-inf,1) == mpi(-1,inf)
def test_interval_mul():
assert mpi(-1, 0) * inf == mpi(-inf, 0)
assert mpi(-1, 0) * -inf == mpi(0, inf)
assert mpi(0, 1) * inf == mpi(0, inf)
assert mpi(0, 1) * mpi(0, inf) == mpi(0, inf)
assert mpi(-1, 1) * inf == mpi(-inf, inf)
assert mpi(-1, 1) * mpi(0, inf) == mpi(-inf, inf)
assert mpi(-1, 1) * mpi(-inf, inf) == mpi(-inf, inf)
assert mpi(-inf, 0) * mpi(0, 1) == mpi(-inf, 0)
assert mpi(-inf, 0) * mpi(0, 0) * mpi(-inf, 0)
assert mpi(-inf, 0) * mpi(-inf, inf) == mpi(-inf, inf)
assert mpi(-5,0)*mpi(-32,28) == mpi(-140,160)
assert mpi(2,3) * mpi(-1,2) == mpi(-3,6)
# Should be undefined?
assert mpi(inf, inf) * 0 == mpi(-inf, inf)
assert mpi(-inf, -inf) * 0 == mpi(-inf, inf)
assert mpi(0) * mpi(-inf,2) == mpi(-inf,inf)
assert mpi(0) * mpi(-2,inf) == mpi(-inf,inf)
assert mpi(-2,inf) * mpi(0) == mpi(-inf,inf)
assert mpi(-inf,2) * mpi(0) == mpi(-inf,inf)
def test_interval_pow():
assert mpi(3)**2 == mpi(9, 9)
assert mpi(-3)**2 == mpi(9, 9)
assert mpi(-3, 1)**2 == mpi(0, 9)
assert mpi(-3, -1)**2 == mpi(1, 9)
assert mpi(-3, -1)**3 == mpi(-27, -1)
assert mpi(-3, 1)**3 == mpi(-27, 1)
assert mpi(-2, 3)**2 == mpi(0, 9)
assert mpi(-3, 2)**2 == mpi(0, 9)
assert mpi(4) ** -1 == mpi(0.25, 0.25)
assert mpi(-4) ** -1 == mpi(-0.25, -0.25)
assert mpi(4) ** -2 == mpi(0.0625, 0.0625)
assert mpi(-4) ** -2 == mpi(0.0625, 0.0625)
assert mpi(0, 1) ** inf == mpi(0, 1)
assert mpi(0, 1) ** -inf == mpi(1, inf)
assert mpi(0, inf) ** inf == mpi(0, inf)
assert mpi(0, inf) ** -inf == mpi(0, inf)
assert mpi(1, inf) ** inf == mpi(1, inf)
assert mpi(1, inf) ** -inf == mpi(0, 1)
assert mpi(2, 3) ** 1 == mpi(2, 3)
assert mpi(2, 3) ** 0 == 1
assert mpi(1,3) ** mpi(2) == mpi(1,9)
def test_interval_sqrt():
assert mpi(4) ** 0.5 == mpi(2)
def test_interval_div():
assert mpi(0.5, 1) / mpi(-1, 0) == mpi(-inf, -0.5)
assert mpi(0, 1) / mpi(0, 1) == mpi(0, inf)
assert mpi(inf, inf) / mpi(inf, inf) == mpi(0, inf)
assert mpi(inf, inf) / mpi(2, inf) == mpi(0, inf)
assert mpi(inf, inf) / mpi(2, 2) == mpi(inf, inf)
assert mpi(0, inf) / mpi(2, inf) == mpi(0, inf)
assert mpi(0, inf) / mpi(2, 2) == mpi(0, inf)
assert mpi(2, inf) / mpi(2, 2) == mpi(1, inf)
assert mpi(2, inf) / mpi(2, inf) == mpi(0, inf)
assert mpi(-4, 8) / mpi(1, inf) == mpi(-4, 8)
assert mpi(-4, 8) / mpi(0.5, inf) == mpi(-8, 16)
assert mpi(-inf, 8) / mpi(0.5, inf) == mpi(-inf, 16)
assert mpi(-inf, inf) / mpi(0.5, inf) == mpi(-inf, inf)
assert mpi(8, inf) / mpi(0.5, inf) == mpi(0, inf)
assert mpi(-8, inf) / mpi(0.5, inf) == mpi(-16, inf)
assert mpi(-4, 8) / mpi(inf, inf) == mpi(0, 0)
assert mpi(0, 8) / mpi(inf, inf) == mpi(0, 0)
assert mpi(0, 0) / mpi(inf, inf) == mpi(0, 0)
assert mpi(-inf, 0) / mpi(inf, inf) == mpi(-inf, 0)
assert mpi(-inf, 8) / mpi(inf, inf) == mpi(-inf, 0)
assert mpi(-inf, inf) / mpi(inf, inf) == mpi(-inf, inf)
assert mpi(-8, inf) / mpi(inf, inf) == mpi(0, inf)
assert mpi(0, inf) / mpi(inf, inf) == mpi(0, inf)
assert mpi(8, inf) / mpi(inf, inf) == mpi(0, inf)
assert mpi(inf, inf) / mpi(inf, inf) == mpi(0, inf)
assert mpi(-1, 2) / mpi(0, 1) == mpi(-inf, +inf)
assert mpi(0, 1) / mpi(0, 1) == mpi(0.0, +inf)
assert mpi(-1, 0) / mpi(0, 1) == mpi(-inf, 0.0)
assert mpi(-0.5, -0.25) / mpi(0, 1) == mpi(-inf, -0.25)
assert mpi(0.5, 1) / mpi(0, 1) == mpi(0.5, +inf)
assert mpi(0.5, 4) / mpi(0, 1) == mpi(0.5, +inf)
assert mpi(-1, -0.5) / mpi(0, 1) == mpi(-inf, -0.5)
assert mpi(-4, -0.5) / mpi(0, 1) == mpi(-inf, -0.5)
assert mpi(-1, 2) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(0, 1) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(-1, 0) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(-0.5, -0.25) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(0.5, 1) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(0.5, 4) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(-1, -0.5) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(-4, -0.5) / mpi(-2, 0.5) == mpi(-inf, +inf)
assert mpi(-1, 2) / mpi(-1, 0) == mpi(-inf, +inf)
assert mpi(0, 1) / mpi(-1, 0) == mpi(-inf, 0.0)
assert mpi(-1, 0) / mpi(-1, 0) == mpi(0.0, +inf)
assert mpi(-0.5, -0.25) / mpi(-1, 0) == mpi(0.25, +inf)
assert mpi(0.5, 1) / mpi(-1, 0) == mpi(-inf, -0.5)
assert mpi(0.5, 4) / mpi(-1, 0) == mpi(-inf, -0.5)
assert mpi(-1, -0.5) / mpi(-1, 0) == mpi(0.5, +inf)
assert mpi(-4, -0.5) / mpi(-1, 0) == mpi(0.5, +inf)
assert mpi(-1, 2) / mpi(0.5, 1) == mpi(-2.0, 4.0)
assert mpi(0, 1) / mpi(0.5, 1) == mpi(0.0, 2.0)
assert mpi(-1, 0) / mpi(0.5, 1) == mpi(-2.0, 0.0)
assert mpi(-0.5, -0.25) / mpi(0.5, 1) == mpi(-1.0, -0.25)
assert mpi(0.5, 1) / mpi(0.5, 1) == mpi(0.5, 2.0)
assert mpi(0.5, 4) / mpi(0.5, 1) == mpi(0.5, 8.0)
assert mpi(-1, -0.5) / mpi(0.5, 1) == mpi(-2.0, -0.5)
assert mpi(-4, -0.5) / mpi(0.5, 1) == mpi(-8.0, -0.5)
assert mpi(-1, 2) / mpi(-2, -0.5) == mpi(-4.0, 2.0)
assert mpi(0, 1) / mpi(-2, -0.5) == mpi(-2.0, 0.0)
assert mpi(-1, 0) / mpi(-2, -0.5) == mpi(0.0, 2.0)
assert mpi(-0.5, -0.25) / mpi(-2, -0.5) == mpi(0.125, 1.0)
assert mpi(0.5, 1) / mpi(-2, -0.5) == mpi(-2.0, -0.25)
assert mpi(0.5, 4) / mpi(-2, -0.5) == mpi(-8.0, -0.25)
assert mpi(-1, -0.5) / mpi(-2, -0.5) == mpi(0.25, 2.0)
assert mpi(-4, -0.5) / mpi(-2, -0.5) == mpi(0.25, 8.0)
# Should be undefined?
assert mpi(0, 0) / mpi(0, 0) == mpi(-inf, inf)
assert mpi(0, 0) / mpi(0, 1) == mpi(-inf, inf)
def test_interval_cos_sin():
mp.dps = 15
# Around 0
assert cos(mpi(0)) == 1
assert sin(mpi(0)) == 0
assert cos(mpi(0,1)) == mpi(0.54030230586813965399, 1.0)
assert sin(mpi(0,1)) == mpi(0, 0.8414709848078966159)
assert cos(mpi(1,2)) == mpi(-0.4161468365471424069, 0.54030230586813976501)
assert sin(mpi(1,2)) == mpi(0.84147098480789650488, 1.0)
assert sin(mpi(1,2.5)) == mpi(0.59847214410395643824, 1.0)
assert cos(mpi(-1, 1)) == mpi(0.54030230586813965399, 1.0)
assert cos(mpi(-1, 0.5)) == mpi(0.54030230586813965399, 1.0)
assert cos(mpi(-1, 1.5)) == mpi(0.070737201667702906405, 1.0)
assert sin(mpi(-1,1)) == mpi(-0.8414709848078966159, 0.8414709848078966159)
assert sin(mpi(-1,0.5)) == mpi(-0.8414709848078966159, 0.47942553860420300538)
assert sin(mpi(-1,1e-100)) == mpi(-0.8414709848078966159, 1.00000000000000002e-100)
assert sin(mpi(-2e-100,1e-100)) == mpi(-2.00000000000000004e-100, 1.00000000000000002e-100)
# Same interval
assert cos(mpi(2, 2.5)) == mpi(-0.80114361554693380718, -0.41614683654714235139)
assert cos(mpi(3.5, 4)) == mpi(-0.93645668729079634129, -0.65364362086361182946)
assert cos(mpi(5, 5.5)) == mpi(0.28366218546322624627, 0.70866977429126010168)
assert sin(mpi(2, 2.5)) == mpi(0.59847214410395654927, 0.90929742682568170942)
assert sin(mpi(3.5, 4)) == mpi(-0.75680249530792831347, -0.35078322768961983646)
assert sin(mpi(5, 5.5)) == mpi(-0.95892427466313856499, -0.70554032557039181306)
# Higher roots
mp.dps = 55
w = 4*10**50 + mpf(0.5)
for p in [15, 40, 80]:
mp.dps = p
assert 0 in sin(4*mpi(pi))
assert 0 in sin(4*10**50*mpi(pi))
assert 0 in cos((4+0.5)*mpi(pi))
assert 0 in cos(w*mpi(pi))
assert 1 in cos(4*mpi(pi))
assert 1 in cos(4*10**50*mpi(pi))
mp.dps = 15
assert cos(mpi(2,inf)) == mpi(-1,1)
assert sin(mpi(2,inf)) == mpi(-1,1)
assert cos(mpi(-inf,2)) == mpi(-1,1)
assert sin(mpi(-inf,2)) == mpi(-1,1)
u = tan(mpi(0.5,1))
assert u.a.ae(tan(0.5))
assert u.b.ae(tan(1))
v = cot(mpi(0.5,1))
assert v.a.ae(cot(1))
assert v.b.ae(cot(0.5))
def test_mpi_to_str():
mp.dps = 30
x = mpi(1, 2)
# FIXME: error_dps should not be necessary
assert mpi_to_str(x, mode='plusminus', error_dps=6) == '1.5 +- 0.5'
assert mpi_to_str(x, mode='plusminus', use_spaces=False, error_dps=6
) == '1.5+-0.5'
assert mpi_to_str(x, mode='percent') == '1.5 (33.33%)'
assert mpi_to_str(x, mode='brackets', use_spaces=False) == '[1.0,2.0]'
assert mpi_to_str(x, mode='brackets' , brackets=('<', '>')) == '<1.0, 2.0>'
x = mpi('5.2582327113062393041', '5.2582327113062749951')
assert (mpi_to_str(x, mode='diff') ==
'5.2582327113062[393041, 749951]')
assert (mpi_to_str(cos(mpi(1)), mode='diff', use_spaces=False) ==
'0.54030230586813971740093660744[2955,3053]')
assert (mpi_to_str(mpi('1e123', '1e129'), mode='diff') ==
'[1.0e+123, 1.0e+129]')
assert (mpi_to_str(exp(mpi('5000.1')), mode='diff') ==
'3.2797365856787867069110487[0926, 1191]e+2171')
def test_mpi_from_str():
assert mpi_from_str('1.5 +- 0.5') == mpi(mpf('1.0'), mpf('2.0'))
assert (mpi_from_str('1.5 (33.33333333333333333333333333333%)') ==
mpi(mpf(1), mpf(2)))
assert mpi_from_str('[1, 2]') == mpi(1, 2)
assert mpi_from_str('1[2, 3]') == mpi(12, 13)
assert mpi_from_str('1.[23,46]e-8') == mpi('1.23e-8', '1.46e-8')
assert mpi_from_str('12[3.4,5.9]e4') == mpi('123.4e+4', '125.9e4')

View File

@ -1,243 +0,0 @@
# TODO: don't use round
from __future__ import division
from mpmath import *
# XXX: these shouldn't be visible(?)
LU_decomp = mp.LU_decomp
L_solve = mp.L_solve
U_solve = mp.U_solve
householder = mp.householder
improve_solution = mp.improve_solution
A1 = matrix([[3, 1, 6],
[2, 1, 3],
[1, 1, 1]])
b1 = [2, 7, 4]
A2 = matrix([[ 2, -1, -1, 2],
[ 6, -2, 3, -1],
[-4, 2, 3, -2],
[ 2, 0, 4, -3]])
b2 = [3, -3, -2, -1]
A3 = matrix([[ 1, 0, -1, -1, 0],
[ 0, 1, 1, 0, -1],
[ 4, -5, 2, 0, 0],
[ 0, 0, -2, 9,-12],
[ 0, 5, 0, 0, 12]])
b3 = [0, 0, 0, 0, 50]
A4 = matrix([[10.235, -4.56, 0., -0.035, 5.67],
[-2.463, 1.27, 3.97, -8.63, 1.08],
[-6.58, 0.86, -0.257, 9.32, -43.6 ],
[ 9.83, 7.39, -17.25, 0.036, 24.86],
[-9.31, 34.9, 78.56, 1.07, 65.8 ]])
b4 = [8.95, 20.54, 7.42, 5.60, 58.43]
A5 = matrix([[ 1, 2, -4],
[-2, -3, 5],
[ 3, 5, -8]])
A6 = matrix([[ 1.377360, 2.481400, 5.359190],
[ 2.679280, -1.229560, 25.560210],
[-1.225280+1.e6, 9.910180, -35.049900-1.e6]])
b6 = [23.500000, -15.760000, 2.340000]
A7 = matrix([[1, -0.5],
[2, 1],
[-2, 6]])
b7 = [3, 2, -4]
A8 = matrix([[1, 2, 3],
[-1, 0, 1],
[-1, -2, -1],
[1, 0, -1]])
b8 = [1, 2, 3, 4]
A9 = matrix([[ 4, 2, -2],
[ 2, 5, -4],
[-2, -4, 5.5]])
b9 = [10, 16, -15.5]
A10 = matrix([[1.0 + 1.0j, 2.0, 2.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
b10 = [1.0, 1.0 + 1.0j, 1.0]
def test_LU_decomp():
A = A3.copy()
b = b3
A, p = LU_decomp(A)
y = L_solve(A, b, p)
x = U_solve(A, y)
assert p == [2, 1, 2, 3]
assert [round(i, 14) for i in x] == [3.78953107960742, 2.9989094874591098,
-0.081788440567070006, 3.8713195201744801, 2.9171210468920399]
A = A4.copy()
b = b4
A, p = LU_decomp(A)
y = L_solve(A, b, p)
x = U_solve(A, y)
assert p == [0, 3, 4, 3]
assert [round(i, 14) for i in x] == [2.6383625899619201, 2.6643834462368399,
0.79208015947958998, -2.5088376454101899, -1.0567657691375001]
A = randmatrix(3)
bak = A.copy()
LU_decomp(A, overwrite=1)
assert A != bak
def test_inverse():
for A in [A1, A2, A5]:
inv = inverse(A)
assert mnorm(A*inv - eye(A.rows), 1) < 1.e-14
def test_householder():
mp.dps = 15
A, b = A8, b8
H, p, x, r = householder(extend(A, b))
assert H == matrix(
[[mpf('3.0'), mpf('-2.0'), mpf('-1.0'), 0],
[-1.0,mpf('3.333333333333333'),mpf('-2.9999999999999991'),mpf('2.0')],
[-1.0, mpf('-0.66666666666666674'),mpf('2.8142135623730948'),
mpf('-2.8284271247461898')],
[1.0, mpf('-1.3333333333333333'),mpf('-0.20000000000000018'),
mpf('4.2426406871192857')]])
assert p == [-2, -2, mpf('-1.4142135623730949')]
assert round(norm(r, 2), 10) == 4.2426406870999998
y = [102.102, 58.344, 36.463, 24.310, 17.017, 12.376, 9.282, 7.140, 5.610,
4.488, 3.6465, 3.003]
def coeff(n):
# similiar to Hilbert matrix
A = []
for i in xrange(1, 13):
A.append([1. / (i + j - 1) for j in xrange(1, n + 1)])
return matrix(A)
residuals = []
refres = []
for n in xrange(2, 7):
A = coeff(n)
H, p, x, r = householder(extend(A, y))
x = matrix(x)
y = matrix(y)
residuals.append(norm(r, 2))
refres.append(norm(residual(A, x, y), 2))
assert [round(res, 10) for res in residuals] == [15.1733888877,
0.82378073210000002, 0.302645887, 0.0260109244,
0.00058653999999999998]
assert norm(matrix(residuals) - matrix(refres), inf) < 1.e-13
def test_factorization():
A = randmatrix(5)
P, L, U = lu(A)
assert mnorm(P*A - L*U, 1) < 1.e-15
def test_solve():
assert norm(residual(A6, lu_solve(A6, b6), b6), inf) < 1.e-10
assert norm(residual(A7, lu_solve(A7, b7), b7), inf) < 1.5
assert norm(residual(A8, lu_solve(A8, b8), b8), inf) <= 3 + 1.e-10
assert norm(residual(A6, qr_solve(A6, b6)[0], b6), inf) < 1.e-10
assert norm(residual(A7, qr_solve(A7, b7)[0], b7), inf) < 1.5
assert norm(residual(A8, qr_solve(A8, b8)[0], b8), 2) <= 4.3
assert norm(residual(A10, lu_solve(A10, b10), b10), 2) < 1.e-10
assert norm(residual(A10, qr_solve(A10, b10)[0], b10), 2) < 1.e-10
def test_solve_overdet_complex():
A = matrix([[1, 2j], [3, 4j], [5, 6]])
b = matrix([1 + j, 2, -j])
assert norm(residual(A, lu_solve(A, b), b)) < 1.0208
def test_singular():
mp.dps = 15
A = [[5.6, 1.2], [7./15, .1]]
B = repr(zeros(2))
b = [1, 2]
def _assert_ZeroDivisionError(statement):
try:
eval(statement)
assert False
except (ZeroDivisionError, ValueError):
pass
for i in ['lu_solve(%s, %s)' % (A, b), 'lu_solve(%s, %s)' % (B, b),
'qr_solve(%s, %s)' % (A, b), 'qr_solve(%s, %s)' % (B, b)]:
_assert_ZeroDivisionError(i)
def test_cholesky():
assert fp.cholesky(fp.matrix(A9)) == fp.matrix([[2, 0, 0], [1, 2, 0], [-1, -3/2, 3/2]])
x = fp.cholesky_solve(A9, b9)
assert fp.norm(fp.residual(A9, x, b9), fp.inf) == 0
def test_det():
assert det(A1) == 1
assert round(det(A2), 14) == 8
assert round(det(A3)) == 1834
assert round(det(A4)) == 4443376
assert det(A5) == 1
assert round(det(A6)) == 78356463
assert det(zeros(3)) == 0
def test_cond():
mp.dps = 15
A = matrix([[1.2969, 0.8648], [0.2161, 0.1441]])
assert cond(A, lambda x: mnorm(x,1)) == mpf('327065209.73817754')
assert cond(A, lambda x: mnorm(x,inf)) == mpf('327065209.73817754')
assert cond(A, lambda x: mnorm(x,'F')) == mpf('249729266.80008656')
@extradps(50)
def test_precision():
A = randmatrix(10, 10)
assert mnorm(inverse(inverse(A)) - A, 1) < 1.e-45
def test_interval_matrix():
a = matrix([['0.1','0.3','1.0'],['7.1','5.5','4.8'],['3.2','4.4','5.6']],
force_type=mpi)
b = matrix(['4','0.6','0.5'], force_type=mpi)
c = lu_solve(a, b)
assert c[0].delta < 1e-13
assert c[1].delta < 1e-13
assert c[2].delta < 1e-13
assert 5.25823271130625686059275 in c[0]
assert -13.155049396267837541163 in c[1]
assert 7.42069154774972557628979 in c[2]
def test_LU_cache():
A = randmatrix(3)
LU = LU_decomp(A)
assert A._LU == LU_decomp(A)
A[0,0] = -1000
assert A._LU is None
def test_improve_solution():
A = randmatrix(5, min=1e-20, max=1e20)
b = randmatrix(5, 1, min=-1000, max=1000)
x1 = lu_solve(A, b) + randmatrix(5, 1, min=-1e-5, max=1.e-5)
x2 = improve_solution(A, x1, b)
assert norm(residual(A, x2, b), 2) < norm(residual(A, x1, b), 2)
def test_exp_pade():
for i in range(3):
dps = 15
extra = 5
mp.dps = dps + extra
dm = 0
while not dm:
m = randmatrix(3)
dm = det(m)
m = m/dm
a = diag([1,2,3])
a1 = m**-1 * a * m
mp.dps = dps
e1 = expm(a1, method='pade')
mp.dps = dps + extra
e2 = m * a1 * m**-1
d = e2 - a
#print d
mp.dps = dps
assert norm(d, inf).ae(0)
mp.dps = 15

View File

@ -1,144 +0,0 @@
from mpmath import *
def test_matrix_basic():
A1 = matrix(3)
for i in xrange(3):
A1[i,i] = 1
assert A1 == eye(3)
assert A1 == matrix(A1)
A2 = matrix(3, 2)
assert not A2._matrix__data
A3 = matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
assert list(A3) == range(1, 10)
A3[1,1] = 0
assert not (1, 1) in A3._matrix__data
A4 = matrix([[1, 2, 3], [4, 5, 6]])
A5 = matrix([[6, -1], [3, 2], [0, -3]])
assert A4 * A5 == matrix([[12, -6], [39, -12]])
assert A1 * A3 == A3 * A1 == A3
try:
A2 * A2
assert False
except ValueError:
pass
l = [[10, 20, 30], [40, 0, 60], [70, 80, 90]]
A6 = matrix(l)
assert A6.tolist() == l
assert A6 == eval(repr(A6))
A6 = matrix(A6, force_type=float)
assert A6 == eval(repr(A6))
assert A6*1j == eval(repr(A6*1j))
assert A3 * 10 == 10 * A3 == A6
assert A2.rows == 3
assert A2.cols == 2
A3.rows = 2
A3.cols = 2
assert len(A3._matrix__data) == 3
assert A4 + A4 == 2*A4
try:
A4 + A2
except ValueError:
pass
assert sum(A1 - A1) == 0
A7 = matrix([[1, 2], [3, 4], [5, 6], [7, 8]])
x = matrix([10, -10])
assert A7*x == matrix([-10, -10, -10, -10])
A8 = ones(5)
assert sum((A8 + 1) - (2 - zeros(5))) == 0
assert (1 + ones(4)) / 2 - 1 == zeros(4)
assert eye(3)**10 == eye(3)
try:
A7**2
assert False
except ValueError:
pass
A9 = randmatrix(3)
A10 = matrix(A9)
A9[0,0] = -100
assert A9 != A10
A11 = matrix(randmatrix(2, 3), force_type=mpi)
for a in A11:
assert isinstance(a, mpi)
assert nstr(A9)
def test_matrix_power():
A = matrix([[1, 2], [3, 4]])
assert A**2 == A*A
assert A**3 == A*A*A
assert A**-1 == inverse(A)
assert A**-2 == inverse(A*A)
def test_matrix_transform():
A = matrix([[1, 2], [3, 4], [5, 6]])
assert A.T == A.transpose() == matrix([[1, 3, 5], [2, 4, 6]])
swap_row(A, 1, 2)
assert A == matrix([[1, 2], [5, 6], [3, 4]])
l = [1, 2]
swap_row(l, 0, 1)
assert l == [2, 1]
assert extend(eye(3), [1,2,3]) == matrix([[1,0,0,1],[0,1,0,2],[0,0,1,3]])
def test_matrix_conjugate():
A = matrix([[1 + j, 0], [2, j]])
assert A.conjugate() == matrix([[mpc(1, -1), 0], [2, mpc(0, -1)]])
assert A.transpose_conj() == A.H == matrix([[mpc(1, -1), 2],
[0, mpc(0, -1)]])
def test_matrix_creation():
assert diag([1, 2, 3]) == matrix([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
A1 = ones(2, 3)
assert A1.rows == 2 and A1.cols == 3
for a in A1:
assert a == 1
A2 = zeros(3, 2)
assert A2.rows == 3 and A2.cols == 2
for a in A2:
assert a == 0
assert randmatrix(10) != randmatrix(10)
one = mpf(1)
assert hilbert(3) == matrix([[one, one/2, one/3],
[one/2, one/3, one/4],
[one/3, one/4, one/5]])
def test_norms():
# matrix norms
A = matrix([[1, -2], [-3, -1], [2, 1]])
assert mnorm(A,1) == 6
assert mnorm(A,inf) == 4
assert mnorm(A,'F') == sqrt(20)
# vector norms
assert norm(-3) == 3
x = [1, -2, 7, -12]
assert norm(x, 1) == 22
assert round(norm(x, 2), 10) == 14.0712472795
assert round(norm(x, 10), 10) == 12.0054633727
assert norm(x, inf) == 12
def test_vector():
x = matrix([0, 1, 2, 3, 4])
assert x == matrix([[0], [1], [2], [3], [4]])
assert x[3] == 3
assert len(x._matrix__data) == 4
assert list(x) == range(5)
x[0] = -10
x[4] = 0
assert x[0] == -10
assert len(x) == len(x.T) == 5
assert x.T*x == matrix([[114]])
def test_matrix_copy():
A = ones(6)
B = A.copy()
assert A == B
B[0,0] = 0
assert A != B
def test_matrix_numpy():
try:
import numpy
except ImportError:
return
l = [[1, 2], [3, 4], [5, 6]]
a = numpy.matrix(l)
assert matrix(l) == matrix(a)

View File

@ -1,98 +0,0 @@
from mpmath.libmp import *
from mpmath import *
import random
#----------------------------------------------------------------------------
# Low-level tests
#
# Advanced rounding test
def test_add_rounding():
mp.dps = 15
a = from_float(1e-50)
assert mpf_sub(mpf_add(fone, a, 53, round_up), fone, 53, round_up) == from_float(2.2204460492503131e-16)
assert mpf_sub(fone, a, 53, round_up) == fone
assert mpf_sub(fone, mpf_sub(fone, a, 53, round_down), 53, round_down) == from_float(1.1102230246251565e-16)
assert mpf_add(fone, a, 53, round_down) == fone
def test_almost_equal():
assert mpf(1.2).ae(mpf(1.20000001), 1e-7)
assert not mpf(1.2).ae(mpf(1.20000001), 1e-9)
assert not mpf(-0.7818314824680298).ae(mpf(-0.774695868667929))
#----------------------------------------------------------------------------
# Test basic arithmetic
#
# Test that integer arithmetic is exact
def test_aintegers():
# XXX: re-fix this so that all operations are tested with all rounding modes
random.seed(0)
for prec in [6, 10, 25, 40, 100, 250, 725]:
for rounding in ['d', 'u', 'f', 'c', 'n']:
mp.dps = prec
M = 10**(prec-2)
M2 = 10**(prec//2-2)
for i in range(10):
a = random.randint(-M, M)
b = random.randint(-M, M)
assert mpf(a, rounding=rounding) == a
assert int(mpf(a, rounding=rounding)) == a
assert int(mpf(str(a), rounding=rounding)) == a
assert mpf(a) + mpf(b) == a + b
assert mpf(a) - mpf(b) == a - b
assert -mpf(a) == -a
a = random.randint(-M2, M2)
b = random.randint(-M2, M2)
assert mpf(a) * mpf(b) == a*b
assert mpf_mul(from_int(a), from_int(b), mp.prec, rounding) == from_int(a*b)
mp.dps = 15
def test_odd_int_bug():
assert to_int(from_int(3), round_nearest) == 3
def test_str_1000_digits():
mp.dps = 1001
# last digit may be wrong
assert str(mpf(2)**0.5)[-10:-1] == '9518488472'[:9]
assert str(pi)[-10:-1] == '2164201989'[:9]
mp.dps = 15
def test_str_10000_digits():
mp.dps = 10001
# last digit may be wrong
assert str(mpf(2)**0.5)[-10:-1] == '5873258351'[:9]
assert str(pi)[-10:-1] == '5256375678'[:9]
mp.dps = 15
def test_monitor():
f = lambda x: x**2
a = []
b = []
g = monitor(f, a.append, b.append)
assert g(3) == 9
assert g(4) == 16
assert a[0] == ((3,), {})
assert b[0] == 9
def test_nint_distance():
nint_distance(mpf(-3)) == (-3, -inf)
nint_distance(mpc(-3)) == (-3, -inf)
nint_distance(mpf(-3.1)) == (-3, -3)
nint_distance(mpf(-3.01)) == (-3, -6)
nint_distance(mpf(-3.001)) == (-3, -9)
nint_distance(mpf(-3.0001)) == (-3, -13)
nint_distance(mpf(-2.9)) == (-3, -3)
nint_distance(mpf(-2.99)) == (-3, -6)
nint_distance(mpf(-2.999)) == (-3, -9)
nint_distance(mpf(-2.9999)) == (-3, -13)
nint_distance(mpc(-3+0.1j)) == (-3, -3)
nint_distance(mpc(-3+0.01j)) == (-3, -6)
nint_distance(mpc(-3.1+0.1j)) == (-3, -3)
nint_distance(mpc(-3.01+0.01j)) == (-3, -6)
nint_distance(mpc(-3.001+0.001j)) == (-3, -9)
nint_distance(mpf(0)) == (0, -inf)
nint_distance(mpf(0.01)) == (0, -6)
nint_distance(mpf('1e-100')) == (0, -332)

View File

@ -1,73 +0,0 @@
#from mpmath.calculus import ODE_step_euler, ODE_step_rk4, odeint, arange
from mpmath import odefun, cos, sin, mpf, sinc, mp
'''
solvers = [ODE_step_euler, ODE_step_rk4]
def test_ode1():
"""
Let's solve:
x'' + w**2 * x = 0
i.e. x1 = x, x2 = x1':
x1' = x2
x2' = -x1
"""
def derivs((x1, x2), t):
return x2, -x1
for solver in solvers:
t = arange(0, 3.1415926, 0.005)
sol = odeint(derivs, (0., 1.), t, solver)
x1 = [a[0] for a in sol]
x2 = [a[1] for a in sol]
# the result is x1 = sin(t), x2 = cos(t)
# let's just check the end points for t = pi
assert abs(x1[-1]) < 1e-2
assert abs(x2[-1] - (-1)) < 1e-2
def test_ode2():
"""
Let's solve:
x' - x = 0
i.e. x = exp(x)
"""
def derivs((x), t):
return x
for solver in solvers:
t = arange(0, 1, 1e-3)
sol = odeint(derivs, (1.,), t, solver)
x = [a[0] for a in sol]
# the result is x = exp(t)
# let's just check the end point for t = 1, i.e. x = e
assert abs(x[-1] - 2.718281828) < 1e-2
'''
def test_odefun_rational():
mp.dps = 15
# A rational function
f = lambda t: 1/(1+mpf(t)**2)
g = odefun(lambda x, y: [-2*x*y[0]**2], 0, [f(0)])
assert f(2).ae(g(2)[0])
def test_odefun_sinc_large():
mp.dps = 15
# Sinc function; test for large x
f = sinc
g = odefun(lambda x, y: [(cos(x)-y[0])/x], 1, [f(1)], tol=0.01, degree=5)
assert abs(f(100) - g(100)[0])/f(100) < 0.01
def test_odefun_harmonic():
mp.dps = 15
# Harmonic oscillator
f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
for x in [0, 1, 2.5, 8, 3.7]: # we go back to 3.7 to check caching
c, s = f(x)
assert c.ae(cos(x))
assert s.ae(sin(x))

View File

@ -1,27 +0,0 @@
import os
import tempfile
import pickle
from mpmath import *
def pickler(obj):
fn = tempfile.mktemp()
f = open(fn, 'wb')
pickle.dump(obj, f)
f.close()
f = open(fn, 'rb')
obj2 = pickle.load(f)
f.close()
os.remove(fn)
return obj2
def test_pickle():
obj = mpf('0.5')
assert obj == pickler(obj)
obj = mpc('0.5','0.2')
assert obj == pickler(obj)

View File

@ -1,155 +0,0 @@
from mpmath import *
from mpmath.libmp import *
import random
def test_fractional_pow():
assert mpf(16) ** 2.5 == 1024
assert mpf(64) ** 0.5 == 8
assert mpf(64) ** -0.5 == 0.125
assert mpf(16) ** -2.5 == 0.0009765625
assert (mpf(10) ** 0.5).ae(3.1622776601683791)
assert (mpf(10) ** 2.5).ae(316.2277660168379)
assert (mpf(10) ** -0.5).ae(0.31622776601683794)
assert (mpf(10) ** -2.5).ae(0.0031622776601683794)
assert (mpf(10) ** 0.3).ae(1.9952623149688795)
assert (mpf(10) ** -0.3).ae(0.50118723362727224)
def test_pow_integer_direction():
"""
Test that inexact integer powers are rounded in the right
direction.
"""
random.seed(1234)
for prec in [10, 53, 200]:
for i in range(50):
a = random.randint(1<<(prec-1), 1<<prec)
b = random.randint(2, 100)
ab = a**b
# note: could actually be exact, but that's very unlikely!
assert to_int(mpf_pow(from_int(a), from_int(b), prec, round_down)) < ab
assert to_int(mpf_pow(from_int(a), from_int(b), prec, round_up)) > ab
def test_pow_epsilon_rounding():
"""
Stress test directed rounding for powers with integer exponents.
Basically, we look at the following cases:
>>> 1.0001 ** -5
0.99950014996500702
>>> 0.9999 ** -5
1.000500150035007
>>> (-1.0001) ** -5
-0.99950014996500702
>>> (-0.9999) ** -5
-1.000500150035007
>>> 1.0001 ** -6
0.99940020994401269
>>> 0.9999 ** -6
1.0006002100560125
>>> (-1.0001) ** -6
0.99940020994401269
>>> (-0.9999) ** -6
1.0006002100560125
etc.
We run the tests with values a very small epsilon away from 1:
small enough that the result is indistinguishable from 1 when
rounded to nearest at the output precision. We check that the
result is not erroneously rounded to 1 in cases where the
rounding should be done strictly away from 1.
"""
def powr(x, n, r):
return make_mpf(mpf_pow_int(x._mpf_, n, mp.prec, r))
for (inprec, outprec) in [(100, 20), (5000, 3000)]:
mp.prec = inprec
pos10001 = mpf(1) + mpf(2)**(-inprec+5)
pos09999 = mpf(1) - mpf(2)**(-inprec+5)
neg10001 = -pos10001
neg09999 = -pos09999
mp.prec = outprec
r = round_up
assert powr(pos10001, 5, r) > 1
assert powr(pos09999, 5, r) == 1
assert powr(neg10001, 5, r) < -1
assert powr(neg09999, 5, r) == -1
assert powr(pos10001, 6, r) > 1
assert powr(pos09999, 6, r) == 1
assert powr(neg10001, 6, r) > 1
assert powr(neg09999, 6, r) == 1
assert powr(pos10001, -5, r) == 1
assert powr(pos09999, -5, r) > 1
assert powr(neg10001, -5, r) == -1
assert powr(neg09999, -5, r) < -1
assert powr(pos10001, -6, r) == 1
assert powr(pos09999, -6, r) > 1
assert powr(neg10001, -6, r) == 1
assert powr(neg09999, -6, r) > 1
r = round_down
assert powr(pos10001, 5, r) == 1
assert powr(pos09999, 5, r) < 1
assert powr(neg10001, 5, r) == -1
assert powr(neg09999, 5, r) > -1
assert powr(pos10001, 6, r) == 1
assert powr(pos09999, 6, r) < 1
assert powr(neg10001, 6, r) == 1
assert powr(neg09999, 6, r) < 1
assert powr(pos10001, -5, r) < 1
assert powr(pos09999, -5, r) == 1
assert powr(neg10001, -5, r) > -1
assert powr(neg09999, -5, r) == -1
assert powr(pos10001, -6, r) < 1
assert powr(pos09999, -6, r) == 1
assert powr(neg10001, -6, r) < 1
assert powr(neg09999, -6, r) == 1
r = round_ceiling
assert powr(pos10001, 5, r) > 1
assert powr(pos09999, 5, r) == 1
assert powr(neg10001, 5, r) == -1
assert powr(neg09999, 5, r) > -1
assert powr(pos10001, 6, r) > 1
assert powr(pos09999, 6, r) == 1
assert powr(neg10001, 6, r) > 1
assert powr(neg09999, 6, r) == 1
assert powr(pos10001, -5, r) == 1
assert powr(pos09999, -5, r) > 1
assert powr(neg10001, -5, r) > -1
assert powr(neg09999, -5, r) == -1
assert powr(pos10001, -6, r) == 1
assert powr(pos09999, -6, r) > 1
assert powr(neg10001, -6, r) == 1
assert powr(neg09999, -6, r) > 1
r = round_floor
assert powr(pos10001, 5, r) == 1
assert powr(pos09999, 5, r) < 1
assert powr(neg10001, 5, r) < -1
assert powr(neg09999, 5, r) == -1
assert powr(pos10001, 6, r) == 1
assert powr(pos09999, 6, r) < 1
assert powr(neg10001, 6, r) == 1
assert powr(neg09999, 6, r) < 1
assert powr(pos10001, -5, r) < 1
assert powr(pos09999, -5, r) == 1
assert powr(neg10001, -5, r) == -1
assert powr(neg09999, -5, r) < -1
assert powr(pos10001, -6, r) < 1
assert powr(pos09999, -6, r) == 1
assert powr(neg10001, -6, r) < 1
assert powr(neg09999, -6, r) == 1
mp.dps = 15

View File

@ -1,85 +0,0 @@
from mpmath import *
def ae(a, b):
return abs(a-b) < 10**(-mp.dps+5)
def test_basic_integrals():
for prec in [15, 30, 100]:
mp.dps = prec
assert ae(quadts(lambda x: x**3 - 3*x**2, [-2, 4]), -12)
assert ae(quadgl(lambda x: x**3 - 3*x**2, [-2, 4]), -12)
assert ae(quadts(sin, [0, pi]), 2)
assert ae(quadts(sin, [0, 2*pi]), 0)
assert ae(quadts(exp, [-inf, -1]), 1/e)
assert ae(quadts(lambda x: exp(-x), [0, inf]), 1)
assert ae(quadts(lambda x: exp(-x*x), [-inf, inf]), sqrt(pi))
assert ae(quadts(lambda x: 1/(1+x*x), [-1, 1]), pi/2)
assert ae(quadts(lambda x: 1/(1+x*x), [-inf, inf]), pi)
assert ae(quadts(lambda x: 2*sqrt(1-x*x), [-1, 1]), pi)
mp.dps = 15
def test_quad_symmetry():
assert quadts(sin, [-1, 1]) == 0
assert quadgl(sin, [-1, 1]) == 0
def test_quadgl_linear():
assert quadgl(lambda x: x, [0, 1], maxdegree=1).ae(0.5)
def test_complex_integration():
assert quadts(lambda x: x, [0, 1+j]).ae(j)
def test_quadosc():
mp.dps = 15
assert quadosc(lambda x: sin(x)/x, [0, inf], period=2*pi).ae(pi/2)
# Double integrals
def test_double_trivial():
assert ae(quadts(lambda x, y: x, [0, 1], [0, 1]), 0.5)
assert ae(quadts(lambda x, y: x, [-1, 1], [-1, 1]), 0.0)
def test_double_1():
assert ae(quadts(lambda x, y: cos(x+y/2), [-pi/2, pi/2], [0, pi]), 4)
def test_double_2():
assert ae(quadts(lambda x, y: (x-1)/((1-x*y)*log(x*y)), [0, 1], [0, 1]), euler)
def test_double_3():
assert ae(quadts(lambda x, y: 1/sqrt(1+x*x+y*y), [-1, 1], [-1, 1]), 4*log(2+sqrt(3))-2*pi/3)
def test_double_4():
assert ae(quadts(lambda x, y: 1/(1-x*x * y*y), [0, 1], [0, 1]), pi**2 / 8)
def test_double_5():
assert ae(quadts(lambda x, y: 1/(1-x*y), [0, 1], [0, 1]), pi**2 / 6)
def test_double_6():
assert ae(quadts(lambda x, y: exp(-(x+y)), [0, inf], [0, inf]), 1)
# fails
def xtest_double_7():
assert ae(quadts(lambda x, y: exp(-x*x-y*y), [-inf, inf], [-inf, inf]), pi)
# Test integrals from "Experimentation in Mathematics" by Borwein,
# Bailey & Girgensohn
def test_expmath_integrals():
for prec in [15, 30, 50]:
mp.dps = prec
assert ae(quadts(lambda x: x/sinh(x), [0, inf]), pi**2 / 4)
assert ae(quadts(lambda x: log(x)**2 / (1+x**2), [0, inf]), pi**3 / 8)
assert ae(quadts(lambda x: (1+x**2)/(1+x**4), [0, inf]), pi/sqrt(2))
assert ae(quadts(lambda x: log(x)/cosh(x)**2, [0, inf]), log(pi)-2*log(2)-euler)
assert ae(quadts(lambda x: log(1+x**3)/(1-x+x**2), [0, inf]), 2*pi*log(3)/sqrt(3))
assert ae(quadts(lambda x: log(x)**2 / (x**2+x+1), [0, 1]), 8*pi**3 / (81*sqrt(3)))
assert ae(quadts(lambda x: log(cos(x))**2, [0, pi/2]), pi/2 * (log(2)**2+pi**2/12))
assert ae(quadts(lambda x: x**2 / sin(x)**2, [0, pi/2]), pi*log(2))
assert ae(quadts(lambda x: x**2/sqrt(exp(x)-1), [0, inf]), 4*pi*(log(2)**2 + pi**2/12))
assert ae(quadts(lambda x: x*exp(-x)*sqrt(1-exp(-2*x)), [0, inf]), pi*(1+2*log(2))/8)
mp.dps = 15
# Do not reach full accuracy
def xtest_expmath_fail():
assert ae(quadts(lambda x: sqrt(tan(x)), [0, pi/2]), pi*sqrt(2)/2)
assert ae(quadts(lambda x: atan(x)/(x*sqrt(1-x**2)), [0, 1]), pi*log(1+sqrt(2))/2)
assert ae(quadts(lambda x: log(1+x**2)/x**2, [0, 1]), pi/2-log(2))
assert ae(quadts(lambda x: x**2/((1+x**4)*sqrt(1-x**4)), [0, 1]), pi/8)

View File

@ -1,75 +0,0 @@
from mpmath import *
from mpmath.calculus.optimization import Secant, Muller, Bisection, Illinois, \
Pegasus, Anderson, Ridder, ANewton, Newton, MNewton, MDNewton
def test_findroot():
# old tests, assuming secant
mp.dps = 15
assert findroot(lambda x: 4*x-3, mpf(5)).ae(0.75)
assert findroot(sin, mpf(3)).ae(pi)
assert findroot(sin, (mpf(3), mpf(3.14))).ae(pi)
assert findroot(lambda x: x*x+1, mpc(2+2j)).ae(1j)
# test all solvers with 1 starting point
f = lambda x: cos(x)
for solver in [Newton, Secant, MNewton, Muller, ANewton]:
x = findroot(f, 2., solver=solver)
assert abs(f(x)) < eps
# test all solvers with interval of 2 points
for solver in [Secant, Muller, Bisection, Illinois, Pegasus, Anderson,
Ridder]:
x = findroot(f, (1., 2.), solver=solver)
assert abs(f(x)) < eps
# test types
f = lambda x: (x - 2)**2
#assert isinstance(findroot(f, 1, force_type=mpf, tol=1e-10), mpf)
#assert isinstance(findroot(f, 1., force_type=None, tol=1e-10), float)
#assert isinstance(findroot(f, 1, force_type=complex, tol=1e-10), complex)
assert isinstance(fp.findroot(f, 1, tol=1e-10), float)
assert isinstance(fp.findroot(f, 1+0j, tol=1e-10), complex)
def test_mnewton():
f = lambda x: polyval([1,3,3,1],x)
x = findroot(f, -0.9, solver='mnewton')
assert abs(f(x)) < eps
def test_anewton():
f = lambda x: (x - 2)**100
x = findroot(f, 1., solver=ANewton)
assert abs(f(x)) < eps
def test_muller():
f = lambda x: (2 + x)**3 + 2
x = findroot(f, 1., solver=Muller)
assert abs(f(x)) < eps
def test_multiplicity():
for i in xrange(1, 5):
assert multiplicity(lambda x: (x - 1)**i, 1) == i
assert multiplicity(lambda x: x**2, 1) == 0
def test_multidimensional():
def f(*x):
return [3*x[0]**2-2*x[1]**2-1, x[0]**2-2*x[0]+x[1]**2+2*x[1]-8]
assert mnorm(jacobian(f, (1,-2)) - matrix([[6,8],[0,-2]]),1) < 1.e-7
for x, error in MDNewton(mp, f, (1,-2), verbose=0,
norm=lambda x: norm(x, inf)):
pass
assert norm(f(*x), 2) < 1e-14
# The Chinese mathematician Zhu Shijie was the very first to solve this
# nonlinear system 700 years ago
f1 = lambda x, y: -x + 2*y
f2 = lambda x, y: (x**2 + x*(y**2 - 2) - 4*y) / (x + 4)
f3 = lambda x, y: sqrt(x**2 + y**2)
def f(x, y):
f1x = f1(x, y)
return (f2(x, y) - f1x, f3(x, y) - f1x)
x = findroot(f, (10, 10))
assert [int(round(i)) for i in x] == [3, 4]
def test_trivial():
assert findroot(lambda x: 0, 1) == 1
assert findroot(lambda x: x, 0) == 0
#assert findroot(lambda x, y: x + y, (1, -1)) == (1, -1)

View File

@ -1,112 +0,0 @@
from mpmath import *
def test_special():
assert inf == inf
assert inf != -inf
assert -inf == -inf
assert inf != nan
assert nan != nan
assert isnan(nan)
assert --inf == inf
assert abs(inf) == inf
assert abs(-inf) == inf
assert abs(nan) != abs(nan)
assert isnan(inf - inf)
assert isnan(inf + (-inf))
assert isnan(-inf - (-inf))
assert isnan(inf + nan)
assert isnan(-inf + nan)
assert mpf(2) + inf == inf
assert 2 + inf == inf
assert mpf(2) - inf == -inf
assert 2 - inf == -inf
assert inf > 3
assert 3 < inf
assert 3 > -inf
assert -inf < 3
assert inf > mpf(3)
assert mpf(3) < inf
assert mpf(3) > -inf
assert -inf < mpf(3)
assert not (nan < 3)
assert not (nan > 3)
assert isnan(inf * 0)
assert isnan(-inf * 0)
assert inf * 3 == inf
assert inf * -3 == -inf
assert -inf * 3 == -inf
assert -inf * -3 == inf
assert inf * inf == inf
assert -inf * -inf == inf
assert isnan(nan / 3)
assert inf / -3 == -inf
assert inf / 3 == inf
assert 3 / inf == 0
assert -3 / inf == 0
assert 0 / inf == 0
assert isnan(inf / inf)
assert isnan(inf / -inf)
assert isnan(inf / nan)
assert mpf('inf') == mpf('+inf') == inf
assert mpf('-inf') == -inf
assert isnan(mpf('nan'))
assert isinf(inf)
assert isinf(-inf)
assert not isinf(mpf(0))
assert not isinf(nan)
def test_special_powers():
assert inf**3 == inf
assert isnan(inf**0)
assert inf**-3 == 0
assert (-inf)**2 == inf
assert (-inf)**3 == -inf
assert isnan((-inf)**0)
assert (-inf)**-2 == 0
assert (-inf)**-3 == 0
assert isnan(nan**5)
assert isnan(nan**0)
def test_functions_special():
assert exp(inf) == inf
assert exp(-inf) == 0
assert isnan(exp(nan))
assert log(inf) == inf
assert isnan(sin(inf))
assert isnan(sin(nan))
assert atan(inf).ae(pi/2)
assert atan(-inf).ae(-pi/2)
assert isnan(sqrt(nan))
assert sqrt(inf) == inf
def test_convert_special():
float_inf = 1e300 * 1e300
float_ninf = -float_inf
float_nan = float_inf/float_ninf
assert mpf(3) * float_inf == inf
assert mpf(3) * float_ninf == -inf
assert isnan(mpf(3) * float_nan)
assert not (mpf(3) < float_nan)
assert not (mpf(3) > float_nan)
assert not (mpf(3) <= float_nan)
assert not (mpf(3) >= float_nan)
assert float(mpf('1e1000')) == float_inf
assert float(mpf('-1e1000')) == float_ninf
assert float(mpf('1e100000000000000000')) == float_inf
assert float(mpf('-1e100000000000000000')) == float_ninf
assert float(mpf('1e-100000000000000000')) == 0.0
def test_div_bug():
assert isnan(nan/1)
assert isnan(nan/2)
assert inf/2 == inf
assert (-inf)/2 == -inf

View File

@ -1,15 +0,0 @@
from mpmath import nstr, matrix, inf
def test_nstr():
m = matrix([[0.75, 0.190940654, -0.0299195971],
[0.190940654, 0.65625, 0.205663228],
[-0.0299195971, 0.205663228, 0.64453125e-20]])
assert nstr(m, 4, min_fixed=-inf) == \
'''[ 0.75 0.1909 -0.02992]
[ 0.1909 0.6563 0.2057]
[-0.02992 0.2057 0.000000000000000000006445]'''
assert nstr(m, 4) == \
'''[ 0.75 0.1909 -0.02992]
[ 0.1909 0.6563 0.2057]
[-0.02992 0.2057 6.445e-21]'''

View File

@ -1,51 +0,0 @@
from mpmath import *
def test_sumem():
mp.dps = 15
assert sumem(lambda k: 1/k**2.5, [50, 100]).ae(0.0012524505324784962)
assert sumem(lambda k: k**4 + 3*k + 1, [10, 100]).ae(2050333103)
def test_nsum():
mp.dps = 15
assert nsum(lambda x: x**2, [1, 3]) == 14
assert nsum(lambda k: 1/factorial(k), [0, inf]).ae(e)
assert nsum(lambda k: (-1)**(k+1) / k, [1, inf]).ae(log(2))
assert nsum(lambda k: (-1)**(k+1) / k**2, [1, inf]).ae(pi**2 / 12)
assert nsum(lambda k: (-1)**k / log(k), [2, inf]).ae(0.9242998972229388)
assert nsum(lambda k: 1/k**2, [1, inf]).ae(pi**2 / 6)
assert nsum(lambda k: 2**k/fac(k), [0, inf]).ae(exp(2))
assert nsum(lambda k: 1/k**2, [4, inf], method='e').ae(0.2838229557371153)
def test_nprod():
mp.dps = 15
assert nprod(lambda k: exp(1/k**2), [1,inf], method='r').ae(exp(pi**2/6))
assert nprod(lambda x: x**2, [1, 3]) == 36
def test_fsum():
mp.dps = 15
assert fsum([]) == 0
assert fsum([-4]) == -4
assert fsum([2,3]) == 5
assert fsum([1e-100,1]) == 1
assert fsum([1,1e-100]) == 1
assert fsum([1e100,1]) == 1e100
assert fsum([1,1e100]) == 1e100
assert fsum([1e-100,0]) == 1e-100
assert fsum([1e-100,1e100,1e-100]) == 1e100
assert fsum([2,1+1j,1]) == 4+1j
assert fsum([1,mpi(2,3)]) == mpi(3,4)
assert fsum([2,inf,3]) == inf
assert fsum([2,-1], absolute=1) == 3
assert fsum([2,-1], squared=1) == 5
assert fsum([1,1+j], squared=1) == 1+2j
assert fsum([1,3+4j], absolute=1) == 6
assert fsum([1,2+3j], absolute=1, squared=1) == 14
assert isnan(fsum([inf,-inf]))
assert fsum([inf,-inf], absolute=1) == inf
assert fsum([inf,-inf], squared=1) == inf
assert fsum([inf,-inf], absolute=1, squared=1) == inf
def test_fprod():
mp.dps = 15
assert fprod([]) == 1
assert fprod([2,3]) == 6

View File

@ -1,142 +0,0 @@
from mpmath import *
from mpmath.libmp import *
def test_trig_misc_hard():
mp.prec = 53
# Worst-case input for an IEEE double, from a paper by Kahan
x = ldexp(6381956970095103,797)
assert cos(x) == mpf('-4.6871659242546277e-19')
assert sin(x) == 1
mp.prec = 150
a = mpf(10**50)
mp.prec = 53
assert sin(a).ae(-0.7896724934293100827)
assert cos(a).ae(-0.6135286082336635622)
# Check relative accuracy close to x = zero
assert sin(1e-100) == 1e-100 # when rounding to nearest
assert sin(1e-6).ae(9.999999999998333e-007, rel_eps=2e-15, abs_eps=0)
assert sin(1e-6j).ae(1.0000000000001666e-006j, rel_eps=2e-15, abs_eps=0)
assert sin(-1e-6j).ae(-1.0000000000001666e-006j, rel_eps=2e-15, abs_eps=0)
assert cos(1e-100) == 1
assert cos(1e-6).ae(0.9999999999995)
assert cos(-1e-6j).ae(1.0000000000005)
assert tan(1e-100) == 1e-100
assert tan(1e-6).ae(1.0000000000003335e-006, rel_eps=2e-15, abs_eps=0)
assert tan(1e-6j).ae(9.9999999999966644e-007j, rel_eps=2e-15, abs_eps=0)
assert tan(-1e-6j).ae(-9.9999999999966644e-007j, rel_eps=2e-15, abs_eps=0)
def test_trig_near_zero():
mp.dps = 15
for r in [round_nearest, round_down, round_up, round_floor, round_ceiling]:
assert sin(0, rounding=r) == 0
assert cos(0, rounding=r) == 1
a = mpf('1e-100')
b = mpf('-1e-100')
assert sin(a, rounding=round_nearest) == a
assert sin(a, rounding=round_down) < a
assert sin(a, rounding=round_floor) < a
assert sin(a, rounding=round_up) >= a
assert sin(a, rounding=round_ceiling) >= a
assert sin(b, rounding=round_nearest) == b
assert sin(b, rounding=round_down) > b
assert sin(b, rounding=round_floor) <= b
assert sin(b, rounding=round_up) <= b
assert sin(b, rounding=round_ceiling) > b
assert cos(a, rounding=round_nearest) == 1
assert cos(a, rounding=round_down) < 1
assert cos(a, rounding=round_floor) < 1
assert cos(a, rounding=round_up) == 1
assert cos(a, rounding=round_ceiling) == 1
assert cos(b, rounding=round_nearest) == 1
assert cos(b, rounding=round_down) < 1
assert cos(b, rounding=round_floor) < 1
assert cos(b, rounding=round_up) == 1
assert cos(b, rounding=round_ceiling) == 1
def test_trig_near_n_pi():
mp.dps = 15
a = [n*pi for n in [1, 2, 6, 11, 100, 1001, 10000, 100001]]
mp.dps = 135
a.append(10**100 * pi)
mp.dps = 15
assert sin(a[0]) == mpf('1.2246467991473531772e-16')
assert sin(a[1]) == mpf('-2.4492935982947063545e-16')
assert sin(a[2]) == mpf('-7.3478807948841190634e-16')
assert sin(a[3]) == mpf('4.8998251578625894243e-15')
assert sin(a[4]) == mpf('1.9643867237284719452e-15')
assert sin(a[5]) == mpf('-8.8632615209684813458e-15')
assert sin(a[6]) == mpf('-4.8568235395684898392e-13')
assert sin(a[7]) == mpf('3.9087342299491231029e-11')
assert sin(a[8]) == mpf('-1.369235466754566993528e-36')
r = round_nearest
assert cos(a[0], rounding=r) == -1
assert cos(a[1], rounding=r) == 1
assert cos(a[2], rounding=r) == 1
assert cos(a[3], rounding=r) == -1
assert cos(a[4], rounding=r) == 1
assert cos(a[5], rounding=r) == -1
assert cos(a[6], rounding=r) == 1
assert cos(a[7], rounding=r) == -1
assert cos(a[8], rounding=r) == 1
r = round_up
assert cos(a[0], rounding=r) == -1
assert cos(a[1], rounding=r) == 1
assert cos(a[2], rounding=r) == 1
assert cos(a[3], rounding=r) == -1
assert cos(a[4], rounding=r) == 1
assert cos(a[5], rounding=r) == -1
assert cos(a[6], rounding=r) == 1
assert cos(a[7], rounding=r) == -1
assert cos(a[8], rounding=r) == 1
r = round_down
assert cos(a[0], rounding=r) > -1
assert cos(a[1], rounding=r) < 1
assert cos(a[2], rounding=r) < 1
assert cos(a[3], rounding=r) > -1
assert cos(a[4], rounding=r) < 1
assert cos(a[5], rounding=r) > -1
assert cos(a[6], rounding=r) < 1
assert cos(a[7], rounding=r) > -1
assert cos(a[8], rounding=r) < 1
r = round_floor
assert cos(a[0], rounding=r) == -1
assert cos(a[1], rounding=r) < 1
assert cos(a[2], rounding=r) < 1
assert cos(a[3], rounding=r) == -1
assert cos(a[4], rounding=r) < 1
assert cos(a[5], rounding=r) == -1
assert cos(a[6], rounding=r) < 1
assert cos(a[7], rounding=r) == -1
assert cos(a[8], rounding=r) < 1
r = round_ceiling
assert cos(a[0], rounding=r) > -1
assert cos(a[1], rounding=r) == 1
assert cos(a[2], rounding=r) == 1
assert cos(a[3], rounding=r) > -1
assert cos(a[4], rounding=r) == 1
assert cos(a[5], rounding=r) > -1
assert cos(a[6], rounding=r) == 1
assert cos(a[7], rounding=r) > -1
assert cos(a[8], rounding=r) == 1
mp.dps = 15
if __name__ == '__main__':
for f in globals().keys():
if f.startswith("test_"):
print f
globals()[f]()

View File

@ -1,27 +0,0 @@
"""
Limited tests of the visualization module. Right now it just makes
sure that passing custom Axes works.
"""
from mpmath import mp, fp
def test_axes():
try:
import pylab
except ImportError:
print "\nSkipping test (pylab not available)\n"
return
fig = pylab.figure()
axes = fig.add_subplot(111)
for ctx in [mp, fp]:
ctx.plot(lambda x: x**2, [0, 3], axes=axes)
assert axes.get_xlabel() == 'x'
assert axes.get_ylabel() == 'f(x)'
fig = pylab.figure()
axes = fig.add_subplot(111)
for ctx in [mp, fp]:
ctx.cplot(lambda z: z, [-2, 2], [-10, 10], axes=axes)
assert axes.get_xlabel() == 'Re(z)'
assert axes.get_ylabel() == 'Im(z)'

View File

@ -1,229 +0,0 @@
"""
Torture tests for asymptotics and high precision evaluation of
special functions.
(Other torture tests may also be placed here.)
Running this file (gmpy and psyco recommended!) takes several CPU minutes.
With Python 2.6+, multiprocessing is used automatically to run tests
in parallel if many cores are available. (A single test may take between
a second and several minutes; possibly more.)
The idea:
* We evaluate functions at positive, negative, imaginary, 45- and 135-degree
complex values with magnitudes between 10^-20 to 10^20, at precisions between
5 and 150 digits (we can go even higher for fast functions).
* Comparing the result from two different precision levels provides
a strong consistency check (particularly for functions that use
different algorithms at different precision levels).
* That the computation finishes at all (without failure), within reasonable
time, provides a check that evaluation works at all: that the code runs,
that it doesn't get stuck in an infinite loop, and that it doesn't use
some extremely slowly algorithm where it could use a faster one.
TODO:
* Speed up those functions that take long to finish!
* Generalize to test more cases; more options.
* Implement a timeout mechanism.
* Some functions are notably absent, including the following:
* inverse trigonometric functions (some become inaccurate for complex arguments)
* ci, si (not implemented properly for large complex arguments)
* zeta functions (need to modify test not to try too large imaginary values)
* and others...
"""
import sys, os
from timeit import default_timer as clock
if "-psyco" in sys.argv:
sys.argv.remove('-psyco')
import psyco
psyco.full()
if "-nogmpy" in sys.argv:
sys.argv.remove('-nogmpy')
os.environ['MPMATH_NOGMPY'] = 'Y'
filt = ''
if not sys.argv[-1].endswith(".py"):
filt = sys.argv[-1]
from mpmath import *
def test_asymp(f, maxdps=150, verbose=False, huge_range=False):
dps = [5,15,25,50,90,150,500,1500,5000,10000]
dps = [p for p in dps if p <= maxdps]
def check(x,y,p,inpt):
if abs(x-y)/abs(y) < workprec(20)(power)(10, -p+1):
return
print
print "Error!"
print "Input:", inpt
print "dps =", p
print "Result 1:", x
print "Result 2:", y
print "Absolute error:", abs(x-y)
print "Relative error:", abs(x-y)/abs(y)
raise AssertionError
exponents = range(-20,20)
if huge_range:
exponents += [-1000, -100, -50, 50, 100, 1000]
for n in exponents:
if verbose:
print ".",
mp.dps = 25
xpos = mpf(10)**n / 1.1287
xneg = -xpos
ximag = xpos*j
xcomplex1 = xpos*(1+j)
xcomplex2 = xpos*(-1+j)
for i in range(len(dps)):
if verbose:
print "Testing dps = %s" % dps[i]
mp.dps = dps[i]
new = f(xpos), f(xneg), f(ximag), f(xcomplex1), f(xcomplex2)
if i != 0:
p = dps[i-1]
check(prev[0], new[0], p, xpos)
check(prev[1], new[1], p, xneg)
check(prev[2], new[2], p, ximag)
check(prev[3], new[3], p, xcomplex1)
check(prev[4], new[4], p, xcomplex2)
prev = new
if verbose:
print
a1, a2, a3, a4, a5 = 1.5, -2.25, 3.125, 4, 2
def test_bernoulli_huge():
p, q = bernfrac(9000)
assert p % 10**10 == 9636701091
assert q == 4091851784687571609141381951327092757255270
mp.dps = 15
assert str(bernoulli(10**100)) == '-2.58183325604736e+987675256497386331227838638980680030172857347883537824464410652557820800494271520411283004120790908623'
mp.dps = 50
assert str(bernoulli(10**100)) == '-2.5818332560473632073252488656039475548106223822913e+987675256497386331227838638980680030172857347883537824464410652557820800494271520411283004120790908623'
mp.dps = 15
cases = """\
test_bernoulli_huge()
test_asymp(lambda z: +pi, maxdps=10000)
test_asymp(lambda z: +e, maxdps=10000)
test_asymp(lambda z: +ln2, maxdps=10000)
test_asymp(lambda z: +ln10, maxdps=10000)
test_asymp(lambda z: +phi, maxdps=10000)
test_asymp(lambda z: +catalan, maxdps=5000)
test_asymp(lambda z: +euler, maxdps=5000)
test_asymp(lambda z: +glaisher, maxdps=1000)
test_asymp(lambda z: +khinchin, maxdps=1000)
test_asymp(lambda z: +twinprime, maxdps=150)
test_asymp(lambda z: stieltjes(2), maxdps=150)
test_asymp(lambda z: +mertens, maxdps=150)
test_asymp(lambda z: +apery, maxdps=5000)
test_asymp(sqrt, maxdps=10000, huge_range=True)
test_asymp(cbrt, maxdps=5000, huge_range=True)
test_asymp(lambda z: root(z,4), maxdps=5000, huge_range=True)
test_asymp(lambda z: root(z,-5), maxdps=5000, huge_range=True)
test_asymp(exp, maxdps=5000, huge_range=True)
test_asymp(expm1, maxdps=1500)
test_asymp(ln, maxdps=5000, huge_range=True)
test_asymp(cosh, maxdps=5000)
test_asymp(sinh, maxdps=5000)
test_asymp(tanh, maxdps=1500)
test_asymp(sin, maxdps=5000, huge_range=True)
test_asymp(cos, maxdps=5000, huge_range=True)
test_asymp(tan, maxdps=1500)
test_asymp(agm, maxdps=1500, huge_range=True)
test_asymp(ellipk, maxdps=1500)
test_asymp(ellipe, maxdps=1500)
test_asymp(lambertw, huge_range=True)
test_asymp(lambda z: lambertw(z,-1))
test_asymp(lambda z: lambertw(z,1))
test_asymp(lambda z: lambertw(z,4))
test_asymp(gamma)
test_asymp(loggamma) # huge_range=True ?
test_asymp(ei)
test_asymp(e1)
test_asymp(li, huge_range=True)
test_asymp(ci)
test_asymp(si)
test_asymp(chi)
test_asymp(shi)
test_asymp(erf)
test_asymp(erfc)
test_asymp(erfi)
test_asymp(lambda z: besselj(2, z))
test_asymp(lambda z: bessely(2, z))
test_asymp(lambda z: besseli(2, z))
test_asymp(lambda z: besselk(2, z))
test_asymp(lambda z: besselj(-2.25, z))
test_asymp(lambda z: bessely(-2.25, z))
test_asymp(lambda z: besseli(-2.25, z))
test_asymp(lambda z: besselk(-2.25, z))
test_asymp(airyai)
test_asymp(airybi)
test_asymp(lambda z: hyp0f1(a1, z))
test_asymp(lambda z: hyp1f1(a1, a2, z))
test_asymp(lambda z: hyp1f2(a1, a2, a3, z))
test_asymp(lambda z: hyp2f0(a1, a2, z))
test_asymp(lambda z: hyperu(a1, a2, z))
test_asymp(lambda z: hyp2f1(a1, a2, a3, z))
test_asymp(lambda z: hyp2f2(a1, a2, a3, a4, z))
test_asymp(lambda z: hyp2f3(a1, a2, a3, a4, a5, z))
test_asymp(lambda z: coulombf(a1, a2, z))
test_asymp(lambda z: coulombg(a1, a2, z))
test_asymp(lambda z: polylog(2,z))
test_asymp(lambda z: polylog(3,z))
test_asymp(lambda z: polylog(-2,z))
test_asymp(lambda z: expint(4, z))
test_asymp(lambda z: expint(-4, z))
test_asymp(lambda z: expint(2.25, z))
test_asymp(lambda z: gammainc(2.5, z, 5))
test_asymp(lambda z: gammainc(2.5, 5, z))
test_asymp(lambda z: hermite(3, z))
test_asymp(lambda z: hermite(2.5, z))
test_asymp(lambda z: legendre(3, z))
test_asymp(lambda z: legendre(4, z))
test_asymp(lambda z: legendre(2.5, z))
test_asymp(lambda z: legenp(a1, a2, z))
test_asymp(lambda z: legenq(a1, a2, z), maxdps=90) # abnormally slow
test_asymp(lambda z: jtheta(1, z, 0.5))
test_asymp(lambda z: jtheta(2, z, 0.5))
test_asymp(lambda z: jtheta(3, z, 0.5))
test_asymp(lambda z: jtheta(4, z, 0.5))
test_asymp(lambda z: jtheta(1, z, 0.5, 1))
test_asymp(lambda z: jtheta(2, z, 0.5, 1))
test_asymp(lambda z: jtheta(3, z, 0.5, 1))
test_asymp(lambda z: jtheta(4, z, 0.5, 1))
test_asymp(barnesg, maxdps=90)
"""
def testit(line):
if filt in line:
print line
t1 = clock()
exec line
t2 = clock()
elapsed = t2-t1
print "Time:", elapsed, "for", line, "(OK)"
if __name__ == '__main__':
try:
from multiprocessing import Pool
mapf = Pool(None).map
print "Running tests with multiprocessing"
except ImportError:
print "Not using multiprocessing"
mapf = map
t1 = clock()
tasks = cases.splitlines()
mapf(testit, tasks)
t2 = clock()
print "Cumulative wall time:", t2-t1

View File

@ -1,91 +0,0 @@
def monitor(f, input='print', output='print'):
"""
Returns a wrapped copy of *f* that monitors evaluation by calling
*input* with every input (*args*, *kwargs*) passed to *f* and
*output* with every value returned from *f*. The default action
(specify using the special string value ``'print'``) is to print
inputs and outputs to stdout, along with the total evaluation
count::
>>> from mpmath import *
>>> mp.dps = 5; mp.pretty = False
>>> diff(monitor(exp), 1) # diff will eval f(x-h) and f(x+h)
in 0 (mpf('0.99999999906867742538452148'),) {}
out 0 mpf('2.7182818259274480055282064')
in 1 (mpf('1.0000000009313225746154785'),) {}
out 1 mpf('2.7182818309906424675501024')
mpf('2.7182808')
To disable either the input or the output handler, you may
pass *None* as argument.
Custom input and output handlers may be used e.g. to store
results for later analysis::
>>> mp.dps = 15
>>> input = []
>>> output = []
>>> findroot(monitor(sin, input.append, output.append), 3.0)
mpf('3.1415926535897932')
>>> len(input) # Count number of evaluations
9
>>> print input[3], output[3]
((mpf('3.1415076583334066'),), {}) 8.49952562843408e-5
>>> print input[4], output[4]
((mpf('3.1415928201669122'),), {}) -1.66577118985331e-7
"""
if not input:
input = lambda v: None
elif input == 'print':
incount = [0]
def input(value):
args, kwargs = value
print "in %s %r %r" % (incount[0], args, kwargs)
incount[0] += 1
if not output:
output = lambda v: None
elif output == 'print':
outcount = [0]
def output(value):
print "out %s %r" % (outcount[0], value)
outcount[0] += 1
def f_monitored(*args, **kwargs):
input((args, kwargs))
v = f(*args, **kwargs)
output(v)
return v
return f_monitored
def timing(f, *args, **kwargs):
"""
Returns time elapsed for evaluating ``f()``. Optionally arguments
may be passed to time the execution of ``f(*args, **kwargs)``.
If the first call is very quick, ``f`` is called
repeatedly and the best time is returned.
"""
once = kwargs.get('once')
if 'once' in kwargs:
del kwargs['once']
if args or kwargs:
if len(args) == 1 and not kwargs:
arg = args[0]
g = lambda: f(arg)
else:
g = lambda: f(*args, **kwargs)
else:
g = f
from timeit import default_timer as clock
t1=clock(); v=g(); t2=clock(); t=t2-t1
if t > 0.05 or once:
return t
for i in range(3):
t1=clock();
# Evaluate multiple times because the timer function
# has a significant overhead
g();g();g();g();g();g();g();g();g();g()
t2=clock()
t=min(t,(t2-t1)/10)
return t

View File

@ -1,270 +0,0 @@
"""
Plotting (requires matplotlib)
"""
from colorsys import hsv_to_rgb, hls_to_rgb
class VisualizationMethods(object):
plot_ignore = (ValueError, ArithmeticError, ZeroDivisionError)
def plot(ctx, f, xlim=[-5,5], ylim=None, points=200, file=None, dpi=None,
singularities=[], axes=None):
r"""
Shows a simple 2D plot of a function `f(x)` or list of functions
`[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
specified by *xlim*. Some examples::
plot(lambda x: exp(x)*li(x), [1, 4])
plot([cos, sin], [-4, 4])
plot([fresnels, fresnelc], [-4, 4])
plot([sqrt, cbrt], [-4, 4])
plot(lambda t: zeta(0.5+t*j), [-20, 20])
plot([floor, ceil, abs, sign], [-5, 5])
Points where the function raises a numerical exception or
returns an infinite value are removed from the graph.
Singularities can also be excluded explicitly
as follows (useful for removing erroneous vertical lines)::
plot(cot, ylim=[-5, 5]) # bad
plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi]) # good
For parts where the function assumes complex values, the
real part is plotted with dashes and the imaginary part
is plotted with dots.
NOTE: This function requires matplotlib (pylab).
"""
if file:
axes = None
fig = None
if not axes:
import pylab
fig = pylab.figure()
axes = fig.add_subplot(111)
if not isinstance(f, (tuple, list)):
f = [f]
a, b = xlim
colors = ['b', 'r', 'g', 'm', 'k']
for n, func in enumerate(f):
x = ctx.arange(a, b, (b-a)/float(points))
segments = []
segment = []
in_complex = False
for i in xrange(len(x)):
try:
if i != 0:
for sing in singularities:
if x[i-1] <= sing and x[i] >= sing:
raise ValueError
v = func(x[i])
if ctx.isnan(v) or abs(v) > 1e300:
raise ValueError
if hasattr(v, "imag") and v.imag:
re = float(v.real)
im = float(v.imag)
if not in_complex:
in_complex = True
segments.append(segment)
segment = []
segment.append((float(x[i]), re, im))
else:
if in_complex:
in_complex = False
segments.append(segment)
segment = []
segment.append((float(x[i]), v))
except ctx.plot_ignore:
if segment:
segments.append(segment)
segment = []
if segment:
segments.append(segment)
for segment in segments:
x = [s[0] for s in segment]
y = [s[1] for s in segment]
if not x:
continue
c = colors[n % len(colors)]
if len(segment[0]) == 3:
z = [s[2] for s in segment]
axes.plot(x, y, '--'+c, linewidth=3)
axes.plot(x, z, ':'+c, linewidth=3)
else:
axes.plot(x, y, c, linewidth=3)
axes.set_xlim(xlim)
if ylim:
axes.set_ylim(ylim)
axes.set_xlabel('x')
axes.set_ylabel('f(x)')
axes.grid(True)
if fig:
if file:
pylab.savefig(file, dpi=dpi)
else:
pylab.show()
def default_color_function(ctx, z):
if ctx.isinf(z):
return (1.0, 1.0, 1.0)
if ctx.isnan(z):
return (0.5, 0.5, 0.5)
pi = 3.1415926535898
a = (float(ctx.arg(z)) + ctx.pi) / (2*ctx.pi)
a = (a + 0.5) % 1.0
b = 1.0 - float(1/(1.0+abs(z)**0.3))
return hls_to_rgb(a, b, 0.8)
def cplot(ctx, f, re=[-5,5], im=[-5,5], points=2000, color=None,
verbose=False, file=None, dpi=None, axes=None):
"""
Plots the given complex-valued function *f* over a rectangular part
of the complex plane specified by the pairs of intervals *re* and *im*.
For example::
cplot(lambda z: z, [-2, 2], [-10, 10])
cplot(exp)
cplot(zeta, [0, 1], [0, 50])
By default, the complex argument (phase) is shown as color (hue) and
the magnitude is show as brightness. You can also supply a
custom color function (*color*). This function should take a
complex number as input and return an RGB 3-tuple containing
floats in the range 0.0-1.0.
To obtain a sharp image, the number of points may need to be
increased to 100,000 or thereabout. Since evaluating the
function that many times is likely to be slow, the 'verbose'
option is useful to display progress.
NOTE: This function requires matplotlib (pylab).
"""
if color is None:
color = ctx.default_color_function
import pylab
if file:
axes = None
fig = None
if not axes:
fig = pylab.figure()
axes = fig.add_subplot(111)
rea, reb = re
ima, imb = im
dre = reb - rea
dim = imb - ima
M = int(ctx.sqrt(points*dre/dim)+1)
N = int(ctx.sqrt(points*dim/dre)+1)
x = pylab.linspace(rea, reb, M)
y = pylab.linspace(ima, imb, N)
# Note: we have to be careful to get the right rotation.
# Test with these plots:
# cplot(lambda z: z if z.real < 0 else 0)
# cplot(lambda z: z if z.imag < 0 else 0)
w = pylab.zeros((N, M, 3))
for n in xrange(N):
for m in xrange(M):
z = ctx.mpc(x[m], y[n])
try:
v = color(f(z))
except ctx.plot_ignore:
v = (0.5, 0.5, 0.5)
w[n,m] = v
if verbose:
print n, "of", N
axes.imshow(w, extent=(rea, reb, ima, imb), origin='lower')
axes.set_xlabel('Re(z)')
axes.set_ylabel('Im(z)')
if fig:
if file:
pylab.savefig(file, dpi=dpi)
else:
pylab.show()
def splot(ctx, f, u=[-5,5], v=[-5,5], points=100, keep_aspect=True, \
wireframe=False, file=None, dpi=None, axes=None):
"""
Plots the surface defined by `f`.
If `f` returns a single component, then this plots the surface
defined by `z = f(x,y)` over the rectangular domain with
`x = u` and `y = v`.
If `f` returns three components, then this plots the parametric
surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.
For example, to plot a simple function::
>>> from mpmath import *
>>> f = lambda x, y: sin(x+y)*cos(y)
>>> splot(f, [-pi,pi], [-pi,pi]) # doctest: +SKIP
Plotting a donut::
>>> r, R = 1, 2.5
>>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
>>> splot(f, [0, 2*pi], [0, 2*pi]) # doctest: +SKIP
NOTE: This function requires matplotlib (pylab) 0.98.5.3 or higher.
"""
import pylab
import mpl_toolkits.mplot3d as mplot3d
if file:
axes = None
fig = None
if not axes:
fig = pylab.figure()
axes = mplot3d.axes3d.Axes3D(fig)
ua, ub = u
va, vb = v
du = ub - ua
dv = vb - va
if not isinstance(points, (list, tuple)):
points = [points, points]
M, N = points
u = pylab.linspace(ua, ub, M)
v = pylab.linspace(va, vb, N)
x, y, z = [pylab.zeros((M, N)) for i in xrange(3)]
xab, yab, zab = [[0, 0] for i in xrange(3)]
for n in xrange(N):
for m in xrange(M):
fdata = f(ctx.convert(u[m]), ctx.convert(v[n]))
try:
x[m,n], y[m,n], z[m,n] = fdata
except TypeError:
x[m,n], y[m,n], z[m,n] = u[m], v[n], fdata
for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
if c < cab[0]:
cab[0] = c
if c > cab[1]:
cab[1] = c
if wireframe:
axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
else:
axes.plot_surface(x, y, z, rstride=4, cstride=4)
axes.set_xlabel('x')
axes.set_ylabel('y')
axes.set_zlabel('z')
if keep_aspect:
dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
maxd = max(dx, dy, dz)
if dx < maxd:
delta = maxd - dx
axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
if dy < maxd:
delta = maxd - dy
axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
if dz < maxd:
delta = maxd - dz
axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
if fig:
if file:
pylab.savefig(file, dpi=dpi)
else:
pylab.show()
VisualizationMethods.plot = plot
VisualizationMethods.default_color_function = default_color_function
VisualizationMethods.cplot = cplot
VisualizationMethods.splot = splot

View File

@ -1,8 +1,7 @@
# -*- coding: ISO-8859-1 -*-
#
#
# Copyright (C) 2002-2005 Jörg Lehmann <joergl@users.sourceforge.net>
# Copyright (C) 2002-2006 André Wobst <wobsta@users.sourceforge.net>
# Copyright (C) 2002-2005 Jorg Lehmann <joergl@users.sourceforge.net>
# Copyright (C) 2002-2006 Andre Wobst <wobsta@users.sourceforge.net>
#
# This file is part of PyX (http://pyx.sourceforge.net/).
#
@ -28,8 +27,8 @@ interface. Complex tasks like 2d and 3d plots in publication-ready quality are
built out of these primitives.
"""
import version
__version__ = version.version
from .version import version
__version__ = version
__all__ = ["attr", "box", "bitmap", "canvas", "color", "connector", "deco", "deformer", "document",
"epsfile", "graph", "mesh", "path", "pattern", "style", "trafo", "text", "unit"]

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python2.7
#!/usr/bin/env python
"""
This script will generate a stimulus file for a given period, load, and slew input
for the given dimension SRAM. It is useful for debugging after an SRAM has been

View File

@ -89,11 +89,11 @@ def print_banner():
def check_versions():
""" Run some checks of required software versions. """
# check that we are not using version 3 and at least 2.7
# Now require python >=3.6
major_python_version = sys.version_info.major
minor_python_version = sys.version_info.minor
if not (major_python_version == 2 and minor_python_version >= 7):
debug.error("Python 2.7 is required.",-1)
if not (major_python_version == 3 and minor_python_version >= 6):
debug.error("Python 3.6 or greater is required.",-1)
# FIXME: Check versions of other tools here??
# or, this could be done in each module (e.g. verify, characterizer, etc.)

View File

@ -26,6 +26,7 @@ class bank(design.design):
"bitcell_array", "sense_amp_array", "precharge_array",
"column_mux_array", "write_driver_array", "tri_gate_array",
"bank_select"]
from importlib import reload
for mod_name in mod_list:
config_mod_name = getattr(OPTS, mod_name)
class_file = reload(__import__(config_mod_name))
@ -130,8 +131,8 @@ class bank(design.design):
def compute_sizes(self):
""" Computes the required sizes to create the bank """
self.num_cols = self.words_per_row*self.word_size
self.num_rows = self.num_words / self.words_per_row
self.num_cols = int(self.words_per_row*self.word_size)
self.num_rows = int(self.num_words / self.words_per_row)
self.row_addr_size = int(log(self.num_rows, 2))
self.col_addr_size = int(log(self.words_per_row, 2))
@ -320,7 +321,7 @@ class bank(design.design):
y_offset = self.sense_amp_array.height+self.column_mux_height \
+ self.write_driver_array.height + self.m2_gap + self.tri_gate_array.height
self.tri_gate_array_inst=self.add_inst(name="tri_gate_array",
mod=self.tri_gate_array,
mod=self.tri_gate_array,
offset=vector(0,y_offset).scale(-1,-1))
temp = []
@ -852,9 +853,7 @@ class bank(design.design):
def analytical_delay(self, slew, load):
""" return analytical delay of the bank"""
msf_addr_delay = self.msf_address.analytical_delay(slew, self.row_decoder.input_load())
decoder_delay = self.row_decoder.analytical_delay(msf_addr_delay.slew, self.wordline_driver.input_load())
decoder_delay = self.row_decoder.analytical_delay(slew, self.wordline_driver.input_load())
word_driver_delay = self.wordline_driver.analytical_delay(decoder_delay.slew, self.bitcell_array.input_load())
@ -866,7 +865,6 @@ class bank(design.design):
data_t_DATA_delay = self.tri_gate_array.analytical_delay(bl_t_data_out_delay.slew, load)
result = msf_addr_delay + decoder_delay + word_driver_delay \
+ bitcell_array_delay + bl_t_data_out_delay + data_t_DATA_delay
result = decoder_delay + word_driver_delay + bitcell_array_delay + bl_t_data_out_delay + data_t_DATA_delay
return result

View File

@ -21,6 +21,7 @@ class bitcell_array(design.design):
self.column_size = cols
self.row_size = rows
from importlib import reload
c = reload(__import__(OPTS.bitcell))
self.mod_bitcell = getattr(c, OPTS.bitcell)
self.cell = self.mod_bitcell()

View File

@ -70,6 +70,7 @@ class control_logic(design.design):
self.inv8 = pinv(size=16, height=dff_height)
self.add_mod(self.inv8)
from importlib import reload
c = reload(__import__(OPTS.replica_bitline))
replica_bitline = getattr(c, OPTS.replica_bitline)
# FIXME: These should be tuned according to the size!

View File

@ -28,6 +28,7 @@ class delay_chain(design.design):
self.num_inverters = 1 + sum(fanout_list)
self.num_top_half = round(self.num_inverters / 2.0)
from importlib import reload
c = reload(__import__(OPTS.bitcell))
self.mod_bitcell = getattr(c, OPTS.bitcell)
self.bitcell = self.mod_bitcell()

View File

@ -20,6 +20,7 @@ class dff_array(design.design):
design.design.__init__(self, name)
debug.info(1, "Creating {}".format(self.name))
from importlib import reload
c = reload(__import__(OPTS.dff))
self.mod_dff = getattr(c, OPTS.dff)
self.dff = self.mod_dff("dff")

View File

@ -20,6 +20,7 @@ class dff_buf(design.design):
design.design.__init__(self, name)
debug.info(1, "Creating {}".format(self.name))
from importlib import reload
c = reload(__import__(OPTS.dff))
self.mod_dff = getattr(c, OPTS.dff)
self.dff = self.mod_dff("dff")

View File

@ -19,6 +19,7 @@ class dff_inv(design.design):
design.design.__init__(self, name)
debug.info(1, "Creating {}".format(self.name))
from importlib import reload
c = reload(__import__(OPTS.dff))
self.mod_dff = getattr(c, OPTS.dff)
self.dff = self.mod_dff("dff")

View File

@ -21,6 +21,7 @@ class hierarchical_decoder(design.design):
def __init__(self, rows):
design.design.__init__(self, "hierarchical_decoder_{0}rows".format(rows))
from importlib import reload
c = reload(__import__(OPTS.bitcell))
self.mod_bitcell = getattr(c, OPTS.bitcell)
self.bitcell_height = self.mod_bitcell.height

View File

@ -19,6 +19,7 @@ class hierarchical_predecode(design.design):
self.number_of_outputs = int(math.pow(2, self.number_of_inputs))
design.design.__init__(self, name="pre{0}x{1}".format(self.number_of_inputs,self.number_of_outputs))
from importlib import reload
c = reload(__import__(OPTS.bitcell))
self.mod_bitcell = getattr(c, OPTS.bitcell)

View File

@ -20,6 +20,7 @@ class ms_flop_array(design.design):
design.design.__init__(self, name)
debug.info(1, "Creating {}".format(self.name))
from importlib import reload
c = reload(__import__(OPTS.ms_flop))
self.mod_ms_flop = getattr(c, OPTS.ms_flop)
self.ms = self.mod_ms_flop("ms_flop")
@ -27,7 +28,7 @@ class ms_flop_array(design.design):
self.width = self.columns * self.ms.width
self.height = self.ms.height
self.words_per_row = self.columns / self.word_size
self.words_per_row = int(self.columns / self.word_size)
self.create_layout()
@ -57,13 +58,16 @@ class ms_flop_array(design.design):
else:
base = vector((i+1)*self.ms.width,0)
mirror = "MY"
self.ms_inst[i/self.words_per_row]=self.add_inst(name=name,
index = int(i/self.words_per_row)
self.ms_inst[index]=self.add_inst(name=name,
mod=self.ms,
offset=base,
mirror=mirror)
self.connect_inst(["din[{0}]".format(i/self.words_per_row),
"dout[{0}]".format(i/self.words_per_row),
"dout_bar[{0}]".format(i/self.words_per_row),
self.connect_inst(["din[{0}]".format(index),
"dout[{0}]".format(index),
"dout_bar[{0}]".format(index),
"clk",
"vdd", "gnd"])

View File

@ -18,6 +18,7 @@ class replica_bitline(design.design):
def __init__(self, delay_stages, delay_fanout, bitcell_loads, name="replica_bitline"):
design.design.__init__(self, name)
from importlib import reload
g = reload(__import__(OPTS.delay_chain))
self.mod_delay_chain = getattr(g, OPTS.delay_chain)
@ -132,11 +133,10 @@ class replica_bitline(design.design):
""" Connect all the signals together """
self.route_vdd()
self.route_gnd()
self.route_vdd_gnd()
self.route_access_tx()
def route_vdd_gnd(self):
""" Route all the vdd and gnd pins to the top level """
def route_vdd_gnd(self):
""" Propagate all vdd/gnd pins up to this level for all modules """
# These are the instances that every bank has

View File

@ -14,6 +14,7 @@ class sense_amp_array(design.design):
design.design.__init__(self, "sense_amp_array")
debug.info(1, "Creating {0}".format(self.name))
from importlib import reload
c = reload(__import__(OPTS.sense_amp))
self.mod_sense_amp = getattr(c, OPTS.sense_amp)
self.amp = self.mod_sense_amp("sense_amp")
@ -33,7 +34,8 @@ class sense_amp_array(design.design):
def add_pins(self):
for i in range(0,self.row_size,self.words_per_row):
self.add_pin("data[{0}]".format(i/self.words_per_row))
index = int(i/self.words_per_row)
self.add_pin("data[{0}]".format(index))
self.add_pin("bl[{0}]".format(i))
self.add_pin("br[{0}]".format(i))
@ -62,12 +64,14 @@ class sense_amp_array(design.design):
br_offset = amp_position + br_pin.ll().scale(1,0)
dout_offset = amp_position + dout_pin.ll()
index = int(i/self.words_per_row)
inst = self.add_inst(name=name,
mod=self.amp,
offset=amp_position)
self.connect_inst(["bl[{0}]".format(i),
"br[{0}]".format(i),
"data[{0}]".format(i/self.words_per_row),
"data[{0}]".format(index),
"en", "vdd", "gnd"])
@ -85,19 +89,18 @@ class sense_amp_array(design.design):
layer="metal3",
offset=vdd_pos)
self.add_layout_pin(text="bl[{0}]".format(i/self.words_per_row),
self.add_layout_pin(text="bl[{0}]".format(i),
layer="metal2",
offset=bl_offset,
width=bl_pin.width(),
height=bl_pin.height())
self.add_layout_pin(text="br[{0}]".format(i/self.words_per_row),
self.add_layout_pin(text="br[{0}]".format(i),
layer="metal2",
offset=br_offset,
width=br_pin.width(),
height=br_pin.height())
self.add_layout_pin(text="data[{0}]".format(i/self.words_per_row),
self.add_layout_pin(text="data[{0}]".format(index),
layer="metal2",
offset=dout_offset,
width=dout_pin.width(),

Some files were not shown because too many files have changed in this diff Show More